Random erasing crash fix and args pass through
parent
9c3859fb9c
commit
c328b155e9
|
@ -18,25 +18,25 @@ class PrefetchLoader:
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
loader,
|
loader,
|
||||||
random_erasing=0.,
|
rand_erase_prob=0.,
|
||||||
|
rand_erase_pp=False,
|
||||||
mean=IMAGENET_DEFAULT_MEAN,
|
mean=IMAGENET_DEFAULT_MEAN,
|
||||||
std=IMAGENET_DEFAULT_STD):
|
std=IMAGENET_DEFAULT_STD):
|
||||||
self.loader = loader
|
self.loader = loader
|
||||||
self.random_erasing = random_erasing
|
self.stream = torch.cuda.Stream()
|
||||||
self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(1, 3, 1, 1)
|
self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(1, 3, 1, 1)
|
||||||
self.std = torch.tensor([x * 255 for x in std]).cuda().view(1, 3, 1, 1)
|
self.std = torch.tensor([x * 255 for x in std]).cuda().view(1, 3, 1, 1)
|
||||||
if random_erasing:
|
if rand_erase_prob:
|
||||||
self.random_erasing = RandomErasingTorch(
|
self.random_erasing = RandomErasingTorch(
|
||||||
probability=random_erasing, per_pixel=False)
|
probability=rand_erase_prob, per_pixel=rand_erase_pp)
|
||||||
else:
|
else:
|
||||||
self.random_erasing = None
|
self.random_erasing = None
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
stream = torch.cuda.Stream()
|
|
||||||
first = True
|
first = True
|
||||||
|
|
||||||
for next_input, next_target in self.loader:
|
for next_input, next_target in self.loader:
|
||||||
with torch.cuda.stream(stream):
|
with torch.cuda.stream(self.stream):
|
||||||
next_input = next_input.cuda(non_blocking=True)
|
next_input = next_input.cuda(non_blocking=True)
|
||||||
next_target = next_target.cuda(non_blocking=True)
|
next_target = next_target.cuda(non_blocking=True)
|
||||||
next_input = next_input.float().sub_(self.mean).div_(self.std)
|
next_input = next_input.float().sub_(self.mean).div_(self.std)
|
||||||
|
@ -48,7 +48,7 @@ class PrefetchLoader:
|
||||||
else:
|
else:
|
||||||
first = False
|
first = False
|
||||||
|
|
||||||
torch.cuda.current_stream().wait_stream(stream)
|
torch.cuda.current_stream().wait_stream(self.stream)
|
||||||
input = next_input
|
input = next_input
|
||||||
target = next_target
|
target = next_target
|
||||||
|
|
||||||
|
@ -68,7 +68,8 @@ def create_loader(
|
||||||
batch_size,
|
batch_size,
|
||||||
is_training=False,
|
is_training=False,
|
||||||
use_prefetcher=True,
|
use_prefetcher=True,
|
||||||
random_erasing=0.,
|
rand_erase_prob=0.,
|
||||||
|
rand_erase_pp=False,
|
||||||
mean=IMAGENET_DEFAULT_MEAN,
|
mean=IMAGENET_DEFAULT_MEAN,
|
||||||
std=IMAGENET_DEFAULT_STD,
|
std=IMAGENET_DEFAULT_STD,
|
||||||
num_workers=1,
|
num_workers=1,
|
||||||
|
@ -110,7 +111,8 @@ def create_loader(
|
||||||
if use_prefetcher:
|
if use_prefetcher:
|
||||||
loader = PrefetchLoader(
|
loader = PrefetchLoader(
|
||||||
loader,
|
loader,
|
||||||
random_erasing=random_erasing if is_training else 0.,
|
rand_erase_prob=rand_erase_prob if is_training else 0.,
|
||||||
|
rand_erase_pp=rand_erase_pp,
|
||||||
mean=mean,
|
mean=mean,
|
||||||
std=std)
|
std=std)
|
||||||
|
|
||||||
|
|
|
@ -110,7 +110,7 @@ class RandomErasingTorch:
|
||||||
h = int(round(math.sqrt(target_area * aspect_ratio)))
|
h = int(round(math.sqrt(target_area * aspect_ratio)))
|
||||||
w = int(round(math.sqrt(target_area / aspect_ratio)))
|
w = int(round(math.sqrt(target_area / aspect_ratio)))
|
||||||
if self.rand_color:
|
if self.rand_color:
|
||||||
c = torch.empty((chan, 1, 1), dtype=batch.dtype).cuda().normal_()
|
c = torch.empty((chan, 1, 1), dtype=batch.dtype).normal_().cuda()
|
||||||
elif not self.per_pixel:
|
elif not self.per_pixel:
|
||||||
c = torch.zeros((chan, 1, 1), dtype=batch.dtype).cuda()
|
c = torch.zeros((chan, 1, 1), dtype=batch.dtype).cuda()
|
||||||
if w < img_w and h < img_h:
|
if w < img_w and h < img_h:
|
||||||
|
@ -118,7 +118,7 @@ class RandomErasingTorch:
|
||||||
left = random.randint(0, img_w - w)
|
left = random.randint(0, img_w - w)
|
||||||
if self.per_pixel:
|
if self.per_pixel:
|
||||||
img[:, top:top + h, left:left + w] = torch.empty(
|
img[:, top:top + h, left:left + w] = torch.empty(
|
||||||
(chan, h, w), dtype=batch.dtype).cuda().normal_()
|
(chan, h, w), dtype=batch.dtype).normal_().cuda()
|
||||||
else:
|
else:
|
||||||
img[:, top:top + h, left:left + w] = c
|
img[:, top:top + h, left:left + w] = c
|
||||||
break
|
break
|
||||||
|
|
7
train.py
7
train.py
|
@ -61,6 +61,10 @@ parser.add_argument('--sched', default='step', type=str, metavar='SCHEDULER',
|
||||||
help='LR scheduler (default: "step"')
|
help='LR scheduler (default: "step"')
|
||||||
parser.add_argument('--drop', type=float, default=0.0, metavar='DROP',
|
parser.add_argument('--drop', type=float, default=0.0, metavar='DROP',
|
||||||
help='Dropout rate (default: 0.1)')
|
help='Dropout rate (default: 0.1)')
|
||||||
|
parser.add_argument('--reprob', type=float, default=0.4, metavar='PCT',
|
||||||
|
help='Random erase prob (default: 0.4)')
|
||||||
|
parser.add_argument('--repp', action='store_true', default=False,
|
||||||
|
help='Random erase per-pixel (default: False)')
|
||||||
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
|
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
|
||||||
help='learning rate (default: 0.01)')
|
help='learning rate (default: 0.01)')
|
||||||
parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR',
|
parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR',
|
||||||
|
@ -196,7 +200,8 @@ def main():
|
||||||
batch_size=args.batch_size,
|
batch_size=args.batch_size,
|
||||||
is_training=True,
|
is_training=True,
|
||||||
use_prefetcher=True,
|
use_prefetcher=True,
|
||||||
random_erasing=0.3,
|
rand_erase_prob=args.reprob,
|
||||||
|
rand_erase_pp=args.repp,
|
||||||
mean=data_mean,
|
mean=data_mean,
|
||||||
std=data_std,
|
std=data_std,
|
||||||
num_workers=args.workers,
|
num_workers=args.workers,
|
||||||
|
|
Loading…
Reference in New Issue