Random erasing crash fix and args pass through
parent
9c3859fb9c
commit
c328b155e9
|
@ -18,25 +18,25 @@ class PrefetchLoader:
|
|||
|
||||
def __init__(self,
|
||||
loader,
|
||||
random_erasing=0.,
|
||||
rand_erase_prob=0.,
|
||||
rand_erase_pp=False,
|
||||
mean=IMAGENET_DEFAULT_MEAN,
|
||||
std=IMAGENET_DEFAULT_STD):
|
||||
self.loader = loader
|
||||
self.random_erasing = random_erasing
|
||||
self.stream = torch.cuda.Stream()
|
||||
self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(1, 3, 1, 1)
|
||||
self.std = torch.tensor([x * 255 for x in std]).cuda().view(1, 3, 1, 1)
|
||||
if random_erasing:
|
||||
if rand_erase_prob:
|
||||
self.random_erasing = RandomErasingTorch(
|
||||
probability=random_erasing, per_pixel=False)
|
||||
probability=rand_erase_prob, per_pixel=rand_erase_pp)
|
||||
else:
|
||||
self.random_erasing = None
|
||||
|
||||
def __iter__(self):
|
||||
stream = torch.cuda.Stream()
|
||||
first = True
|
||||
|
||||
for next_input, next_target in self.loader:
|
||||
with torch.cuda.stream(stream):
|
||||
with torch.cuda.stream(self.stream):
|
||||
next_input = next_input.cuda(non_blocking=True)
|
||||
next_target = next_target.cuda(non_blocking=True)
|
||||
next_input = next_input.float().sub_(self.mean).div_(self.std)
|
||||
|
@ -48,7 +48,7 @@ class PrefetchLoader:
|
|||
else:
|
||||
first = False
|
||||
|
||||
torch.cuda.current_stream().wait_stream(stream)
|
||||
torch.cuda.current_stream().wait_stream(self.stream)
|
||||
input = next_input
|
||||
target = next_target
|
||||
|
||||
|
@ -68,7 +68,8 @@ def create_loader(
|
|||
batch_size,
|
||||
is_training=False,
|
||||
use_prefetcher=True,
|
||||
random_erasing=0.,
|
||||
rand_erase_prob=0.,
|
||||
rand_erase_pp=False,
|
||||
mean=IMAGENET_DEFAULT_MEAN,
|
||||
std=IMAGENET_DEFAULT_STD,
|
||||
num_workers=1,
|
||||
|
@ -110,7 +111,8 @@ def create_loader(
|
|||
if use_prefetcher:
|
||||
loader = PrefetchLoader(
|
||||
loader,
|
||||
random_erasing=random_erasing if is_training else 0.,
|
||||
rand_erase_prob=rand_erase_prob if is_training else 0.,
|
||||
rand_erase_pp=rand_erase_pp,
|
||||
mean=mean,
|
||||
std=std)
|
||||
|
||||
|
|
|
@ -110,7 +110,7 @@ class RandomErasingTorch:
|
|||
h = int(round(math.sqrt(target_area * aspect_ratio)))
|
||||
w = int(round(math.sqrt(target_area / aspect_ratio)))
|
||||
if self.rand_color:
|
||||
c = torch.empty((chan, 1, 1), dtype=batch.dtype).cuda().normal_()
|
||||
c = torch.empty((chan, 1, 1), dtype=batch.dtype).normal_().cuda()
|
||||
elif not self.per_pixel:
|
||||
c = torch.zeros((chan, 1, 1), dtype=batch.dtype).cuda()
|
||||
if w < img_w and h < img_h:
|
||||
|
@ -118,7 +118,7 @@ class RandomErasingTorch:
|
|||
left = random.randint(0, img_w - w)
|
||||
if self.per_pixel:
|
||||
img[:, top:top + h, left:left + w] = torch.empty(
|
||||
(chan, h, w), dtype=batch.dtype).cuda().normal_()
|
||||
(chan, h, w), dtype=batch.dtype).normal_().cuda()
|
||||
else:
|
||||
img[:, top:top + h, left:left + w] = c
|
||||
break
|
||||
|
|
7
train.py
7
train.py
|
@ -61,6 +61,10 @@ parser.add_argument('--sched', default='step', type=str, metavar='SCHEDULER',
|
|||
help='LR scheduler (default: "step"')
|
||||
parser.add_argument('--drop', type=float, default=0.0, metavar='DROP',
|
||||
help='Dropout rate (default: 0.1)')
|
||||
parser.add_argument('--reprob', type=float, default=0.4, metavar='PCT',
|
||||
help='Random erase prob (default: 0.4)')
|
||||
parser.add_argument('--repp', action='store_true', default=False,
|
||||
help='Random erase per-pixel (default: False)')
|
||||
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
|
||||
help='learning rate (default: 0.01)')
|
||||
parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR',
|
||||
|
@ -196,7 +200,8 @@ def main():
|
|||
batch_size=args.batch_size,
|
||||
is_training=True,
|
||||
use_prefetcher=True,
|
||||
random_erasing=0.3,
|
||||
rand_erase_prob=args.reprob,
|
||||
rand_erase_pp=args.repp,
|
||||
mean=data_mean,
|
||||
std=data_std,
|
||||
num_workers=args.workers,
|
||||
|
|
Loading…
Reference in New Issue