From 2295cf56c2554363f26c84662e4678cce6b3bdc5 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sun, 10 Mar 2019 14:23:16 -0700 Subject: [PATCH] Add some Nvidia performance enhancements (prefetch loader, fast collate), and refactor some of training and model fact/transforms --- data/__init__.py | 4 + dataset.py => data/dataset.py | 0 data/random_erasing.py | 131 +++++++++++++++++++++++++ data/transforms.py | 53 +++++++++++ data/utils.py | 65 +++++++++++++ models/__init__.py | 2 +- models/model_factory.py | 3 - models/random_erasing.py | 61 ------------ models/transforms.py | 80 ---------------- optim/__init__.py | 2 + train.py | 174 +++++++++++++++++----------------- 11 files changed, 345 insertions(+), 230 deletions(-) create mode 100644 data/__init__.py rename dataset.py => data/dataset.py (100%) create mode 100644 data/random_erasing.py create mode 100644 data/transforms.py create mode 100644 data/utils.py delete mode 100644 models/random_erasing.py delete mode 100644 models/transforms.py create mode 100644 optim/__init__.py diff --git a/data/__init__.py b/data/__init__.py new file mode 100644 index 00000000..07868e03 --- /dev/null +++ b/data/__init__.py @@ -0,0 +1,4 @@ +from data.dataset import Dataset +from data.transforms import transforms_imagenet_eval, transforms_imagenet_train +from data.utils import fast_collate, PrefetchLoader +from data.random_erasing import RandomErasingTorch, RandomErasingNumpy \ No newline at end of file diff --git a/dataset.py b/data/dataset.py similarity index 100% rename from dataset.py rename to data/dataset.py diff --git a/data/random_erasing.py b/data/random_erasing.py new file mode 100644 index 00000000..867eb578 --- /dev/null +++ b/data/random_erasing.py @@ -0,0 +1,131 @@ +from __future__ import absolute_import + +#from torchvision.transforms import * + +from PIL import Image +import random +import math +import numpy as np +import torch + + +class RandomErasingNumpy: + """ Randomly selects a rectangle region in an image and erases its pixels. + 'Random Erasing Data Augmentation' by Zhong et al. + See https://arxiv.org/pdf/1708.04896.pdf + + This 'Numpy' variant of RandomErasing is intended to be applied on a per + image basis after transforming the image to uint8 numpy array in + range 0-255 prior to tensor conversion and normalization + Args: + probability: The probability that the Random Erasing operation will be performed. + sl: Minimum proportion of erased area against input image. + sh: Maximum proportion of erased area against input image. + r1: Minimum aspect ratio of erased area. + mean: Erasing value. 
+ """ + + def __init__( + self, + probability=0.5, sl=0.02, sh=1/3, min_aspect=0.3, + per_pixel=False, rand_color=False, + pl=0, ph=255, mean=[255 * 0.485, 255 * 0.456, 255 * 0.406], + out_type=np.uint8): + self.probability = probability + if not per_pixel and not rand_color: + self.mean = np.array(mean).round().astype(out_type) + else: + self.mean = None + self.sl = sl + self.sh = sh + self.min_aspect = min_aspect + self.pl = pl + self.ph = ph + self.per_pixel = per_pixel # per pixel random, bounded by [pl, ph] + self.rand_color = rand_color # per block random, bounded by [pl, ph] + self.out_type = out_type + + def __call__(self, img): + if random.random() > self.probability: + return img + + chan, img_h, img_w = img.shape + area = img_h * img_w + for attempt in range(100): + target_area = random.uniform(self.sl, self.sh) * area + aspect_ratio = random.uniform(self.min_aspect, 1 / self.min_aspect) + + h = int(round(math.sqrt(target_area * aspect_ratio))) + w = int(round(math.sqrt(target_area / aspect_ratio))) + if self.rand_color: + c = np.random.randint(self.pl, self.ph + 1, (chan,), self.out_type) + elif not self.per_pixel: + c = self.mean[:chan] + if w < img_w and h < img_h: + top = random.randint(0, img_h - h) + left = random.randint(0, img_w - w) + if self.per_pixel: + img[:, top:top + h, left:left + w] = np.random.randint( + self.pl, self.ph + 1, (chan, h, w), self.out_type) + else: + img[:, top:top + h, left:left + w] = c + return img + + return img + + +class RandomErasingTorch: + """ Randomly selects a rectangle region in an image and erases its pixels. + 'Random Erasing Data Augmentation' by Zhong et al. + See https://arxiv.org/pdf/1708.04896.pdf + + This 'Torch' variant of RandomErasing is intended to be applied to a full batch + tensor after it has been normalized by dataset mean and std. + Args: + probability: The probability that the Random Erasing operation will be performed. + sl: Minimum proportion of erased area against input image. + sh: Maximum proportion of erased area against input image. + r1: Minimum aspect ratio of erased area. 
+ """ + + def __init__( + self, + probability=0.5, sl=0.02, sh=1/3, min_aspect=0.3, + per_pixel=False, rand_color=False, + device='cuda'): + self.probability = probability + self.sl = sl + self.sh = sh + self.min_aspect = min_aspect + self.per_pixel = per_pixel # per pixel random, bounded by [pl, ph] + self.rand_color = rand_color # per block random, bounded by [pl, ph] + self.device = device + + def __call__(self, batch): + batch_size, chan, img_h, img_w = batch.size() + area = img_h * img_w + for i in range(batch_size): + if random.random() > self.probability: + continue + img = batch[i] + for attempt in range(100): + target_area = random.uniform(self.sl, self.sh) * area + aspect_ratio = random.uniform(self.min_aspect, 1 / self.min_aspect) + + h = int(round(math.sqrt(target_area * aspect_ratio))) + w = int(round(math.sqrt(target_area / aspect_ratio))) + if self.rand_color: + c = torch.empty(chan, dtype=batch.dtype, device=self.device).normal_() + elif not self.per_pixel: + c = torch.zeros(chan, dtype=batch.dtype, device=self.device) + if w < img_w and h < img_h: + top = random.randint(0, img_h - h) + left = random.randint(0, img_w - w) + if self.per_pixel: + img[:, top:top + h, left:left + w] = torch.empty( + (chan, h, w), dtype=batch.dtype, device=self.device).normal_() + else: + img[:, top:top + h, left:left + w] = c + break + + return batch diff --git a/data/transforms.py b/data/transforms.py new file mode 100644 index 00000000..80491ef6 --- /dev/null +++ b/data/transforms.py @@ -0,0 +1,53 @@ +import torch +from torchvision import transforms +from PIL import Image +import math +import numpy as np +from data.random_erasing import RandomErasingNumpy + +DEFAULT_CROP_PCT = 0.875 + +IMAGENET_DPN_MEAN = [124 / 255, 117 / 255, 104 / 255] +IMAGENET_DPN_STD = [1 / (.0167 * 255)] * 3 +IMAGENET_INCEPTION_MEAN = [0.5, 0.5, 0.5] +IMAGENET_INCEPTION_STD = [0.5, 0.5, 0.5] +IMAGENET_DEFAULT_MEAN = [0.485, 0.456, 0.406] +IMAGENET_DEFAULT_STD = [0.229, 0.224, 0.225] + + +class AsNumpy: + + def __call__(self, pil_img): + np_img = np.array(pil_img, dtype=np.uint8) + if np_img.ndim < 3: + np_img = np.expand_dims(np_img, axis=-1) + np_img = np.rollaxis(np_img, 2) # HWC to CHW + return np_img + + +def transforms_imagenet_train( + img_size=224, + scale=(0.1, 1.0), + color_jitter=(0.4, 0.4, 0.4), + random_erasing=0.4): + + tfl = [ + transforms.RandomResizedCrop(img_size, scale=scale), + transforms.RandomHorizontalFlip(), + transforms.ColorJitter(*color_jitter), + AsNumpy(), + ] + #if random_erasing > 0.: + # tfl.append(RandomErasingNumpy(random_erasing, per_pixel=True)) + return transforms.Compose(tfl) + + +def transforms_imagenet_eval(img_size=224, crop_pct=None): + crop_pct = crop_pct or DEFAULT_CROP_PCT + scale_size = int(math.floor(img_size / crop_pct)) + + return transforms.Compose([ + transforms.Resize(scale_size, Image.BICUBIC), + transforms.CenterCrop(img_size), + AsNumpy(), + ]) diff --git a/data/utils.py b/data/utils.py new file mode 100644 index 00000000..f4afa60e --- /dev/null +++ b/data/utils.py @@ -0,0 +1,65 @@ +import torch +from data.random_erasing import RandomErasingTorch + + +def fast_collate(batch): + targets = torch.tensor([b[1] for b in batch], dtype=torch.int64) + batch_size = len(targets) + tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8) + for i in range(batch_size): + tensor[i] += torch.from_numpy(batch[i][0]) + + return tensor, targets + + +class PrefetchLoader: + + def __init__(self, + loader, + fp16=False, + random_erasing=True, + mean=[0.485, 0.456, 
0.406], + std=[0.229, 0.224, 0.225]): + self.loader = loader + self.fp16 = fp16 + self.random_erasing = random_erasing + self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(1, 3, 1, 1) + self.std = torch.tensor([x * 255 for x in std]).cuda().view(1, 3, 1, 1) + if random_erasing: + self.random_erasing = RandomErasingTorch(per_pixel=True) + else: + self.random_erasing = None + + if self.fp16: + self.mean = self.mean.half() + self.std = self.std.half() + + def __iter__(self): + stream = torch.cuda.Stream() + first = True + + for next_input, next_target in self.loader: + with torch.cuda.stream(stream): + next_input = next_input.cuda(non_blocking=True) + next_target = next_target.cuda(non_blocking=True) + if self.fp16: + next_input = next_input.half() + else: + next_input = next_input.float() + next_input = next_input.sub_(self.mean).div_(self.std) + if self.random_erasing is not None: + next_input = self.random_erasing(next_input) + + if not first: + yield input, target + else: + first = False + + torch.cuda.current_stream().wait_stream(stream) + input = next_input + target = next_target + + yield input, target + + def __len__(self): + return len(self.loader) diff --git a/models/__init__.py b/models/__init__.py index 13423620..b24eceb5 100644 --- a/models/__init__.py +++ b/models/__init__.py @@ -1,2 +1,2 @@ from .model_factory import create_model -from .transforms import transforms_imagenet_eval, transforms_imagenet_train + diff --git a/models/model_factory.py b/models/model_factory.py index 37eeac97..813ff43a 100644 --- a/models/model_factory.py +++ b/models/model_factory.py @@ -1,7 +1,4 @@ import torch -from torchvision import transforms -from PIL import Image -import math import os from .inception_v4 import inception_v4 diff --git a/models/random_erasing.py b/models/random_erasing.py deleted file mode 100644 index f544525a..00000000 --- a/models/random_erasing.py +++ /dev/null @@ -1,61 +0,0 @@ -from __future__ import absolute_import - -from torchvision.transforms import * - -from PIL import Image -import random -import math -import numpy as np -import torch - - -class RandomErasing: - """ Randomly selects a rectangle region in an image and erases its pixels. - 'Random Erasing Data Augmentation' by Zhong et al. - See https://arxiv.org/pdf/1708.04896.pdf - Args: - probability: The probability that the Random Erasing operation will be performed. - sl: Minimum proportion of erased area against input image. - sh: Maximum proportion of erased area against input image. - r1: Minimum aspect ratio of erased area. - mean: Erasing value. 
- """ - - def __init__( - self, - probability=0.5, sl=0.02, sh=1/3, min_aspect=0.3, - per_pixel=False, random=False, - pl=0, ph=1., mean=[0.485, 0.456, 0.406]): - self.probability = probability - self.mean = torch.tensor(mean) - self.sl = sl - self.sh = sh - self.min_aspect = min_aspect - self.pl = pl - self.ph = ph - self.per_pixel = per_pixel # per pixel random, bounded by [pl, ph] - self.random = random # per block random, bounded by [pl, ph] - - def __call__(self, img): - if random.random() > self.probability: - return img - - chan, img_h, img_w = img.size() - area = img_h * img_w - for attempt in range(100): - target_area = random.uniform(self.sl, self.sh) * area - aspect_ratio = random.uniform(self.min_aspect, 1 / self.min_aspect) - - h = int(round(math.sqrt(target_area * aspect_ratio))) - w = int(round(math.sqrt(target_area / aspect_ratio))) - c = torch.empty((chan)).uniform_(self.pl, self.ph) if self.random else self.mean[:chan] - if w < img_w and h < img_h: - top = random.randint(0, img_h - h) - left = random.randint(0, img_w - w) - if self.per_pixel: - img[:, top:top + h, left:left + w] = torch.empty((chan, h, w)).uniform_(self.pl, self.ph) - else: - img[:, top:top + h, left:left + w] = c - return img - - return img diff --git a/models/transforms.py b/models/transforms.py deleted file mode 100644 index 94768b48..00000000 --- a/models/transforms.py +++ /dev/null @@ -1,80 +0,0 @@ -import torch -from torchvision import transforms -from PIL import Image -import math -from models.random_erasing import RandomErasing - -DEFAULT_CROP_PCT = 0.875 - -IMAGENET_DPN_MEAN = [124 / 255, 117 / 255, 104 / 255] -IMAGENET_DPN_STD = [1 / (.0167 * 255)] * 3 -IMAGENET_DEFAULT_MEAN = [0.485, 0.456, 0.406] -IMAGENET_DEFAULT_STD = [0.229, 0.224, 0.225] - - -class LeNormalize(object): - """Normalize to -1..1 in Google Inception style - """ - def __call__(self, tensor): - for t in tensor: - t.sub_(0.5).mul_(2.0) - return tensor - - -def transforms_imagenet_train( - model_name, - img_size=224, - scale=(0.1, 1.0), - color_jitter=(0.4, 0.4, 0.4), - random_erasing=0.4): - if 'dpn' in model_name: - normalize = transforms.Normalize( - mean=IMAGENET_DPN_MEAN, - std=IMAGENET_DPN_STD) - elif 'inception' in model_name: - normalize = LeNormalize() - else: - normalize = transforms.Normalize( - mean=IMAGENET_DEFAULT_MEAN, - std=IMAGENET_DEFAULT_STD) - - tfl = [ - transforms.RandomResizedCrop(img_size, scale=scale), - transforms.RandomHorizontalFlip(), - transforms.ColorJitter(*color_jitter), - transforms.ToTensor()] - if random_erasing > 0.: - tfl.append(RandomErasing(random_erasing, per_pixel=True)) - return transforms.Compose(tfl + [normalize]) - - -def transforms_imagenet_eval(model_name, img_size=224, crop_pct=None): - crop_pct = crop_pct or DEFAULT_CROP_PCT - if 'dpn' in model_name: - if crop_pct is None: - # Use default 87.5% crop for model's native img_size - # but use 100% crop for larger than native as it - # improves test time results across all models. 
- if img_size == 224: - scale_size = int(math.floor(img_size / DEFAULT_CROP_PCT)) - else: - scale_size = img_size - else: - scale_size = int(math.floor(img_size / crop_pct)) - normalize = transforms.Normalize( - mean=IMAGENET_DPN_MEAN, - std=IMAGENET_DPN_STD) - elif 'inception' in model_name: - scale_size = int(math.floor(img_size / crop_pct)) - normalize = LeNormalize() - else: - scale_size = int(math.floor(img_size / crop_pct)) - normalize = transforms.Normalize( - mean=IMAGENET_DEFAULT_MEAN, - std=IMAGENET_DEFAULT_STD) - - return transforms.Compose([ - transforms.Resize(scale_size, Image.BICUBIC), - transforms.CenterCrop(img_size), - transforms.ToTensor(), - normalize]) diff --git a/optim/__init__.py b/optim/__init__.py new file mode 100644 index 00000000..d8995736 --- /dev/null +++ b/optim/__init__.py @@ -0,0 +1,2 @@ +from optim.adabound import AdaBound +from optim.nadam import Nadam \ No newline at end of file diff --git a/train.py b/train.py index df7d91df..90cabead 100644 --- a/train.py +++ b/train.py @@ -3,10 +3,10 @@ import time from collections import OrderedDict from datetime import datetime -from dataset import Dataset -from models import model_factory, transforms_imagenet_eval, transforms_imagenet_train +from data import * +from models import model_factory from utils import * -from optim import nadam, adabound +from optim import Nadam, AdaBound import scheduler import torch @@ -95,24 +95,32 @@ def main(): dataset_train = Dataset( os.path.join(args.data, 'train'), - transform=transforms_imagenet_train(args.model)) + transform=transforms_imagenet_train()) loader_train = data.DataLoader( dataset_train, batch_size=batch_size, shuffle=True, - num_workers=args.workers + num_workers=args.workers, + collate_fn=fast_collate ) + loader_train = PrefetchLoader( + loader_train, random_erasing=True, + ) dataset_eval = Dataset( os.path.join(args.data, 'validation'), - transform=transforms_imagenet_eval(args.model)) + transform=transforms_imagenet_eval()) loader_eval = data.DataLoader( dataset_eval, batch_size=4 * args.batch_size, shuffle=False, - num_workers=args.workers + num_workers=args.workers, + collate_fn=fast_collate, + ) + loader_eval = PrefetchLoader( + loader_eval, random_erasing=False, ) model = model_factory.create_model( @@ -156,66 +164,11 @@ def main(): train_loss_fn = validate_loss_fn = torch.nn.CrossEntropyLoss().cuda() - if args.opt.lower() == 'sgd': - optimizer = optim.SGD( - model.parameters(), lr=args.lr, - momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True) - elif args.opt.lower() == 'adam': - optimizer = optim.Adam( - model.parameters(), lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps) - elif args.opt.lower() == 'nadam': - optimizer = nadam.Nadam( - model.parameters(), lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps) - elif args.opt.lower() == 'adabound': - optimizer = adabound.AdaBound( - model.parameters(), lr=args.lr / 1000, weight_decay=args.weight_decay, eps=args.opt_eps, - final_lr=args.lr) - elif args.opt.lower() == 'adadelta': - optimizer = optim.Adadelta( - model.parameters(), lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps) - elif args.opt.lower() == 'rmsprop': - optimizer = optim.RMSprop( - model.parameters(), lr=args.lr, alpha=0.9, eps=args.opt_eps, - momentum=args.momentum, weight_decay=args.weight_decay) - else: - assert False and "Invalid optimizer" - exit(1) + optimizer = create_optimizer(args, model.parameters()) + if optimizer_state is not None: + optimizer.load_state_dict(optimizer_state) - #if 
optimizer_state is not None: - # optimizer.load_state_dict(optimizer_state) - - num_epochs = args.epochs - if args.sched == 'cosine': - lr_scheduler = scheduler.CosineLRScheduler( - optimizer, - t_initial=args.epochs, - t_mul=1.0, - lr_min=1e-5, - decay_rate=args.decay_rate, - warmup_lr_init=1e-4, - warmup_t=3, - cycle_limit=1, - t_in_epochs=True, - ) - num_epochs = lr_scheduler.get_cycle_length() + 10 - elif args.sched == 'tanh': - lr_scheduler = scheduler.TanhLRScheduler( - optimizer, - t_initial=args.epochs, - t_mul=1.0, - lr_min=1e-5, - warmup_lr_init=.001, - warmup_t=3, - cycle_limit=1, - t_in_epochs=True, - ) - num_epochs = lr_scheduler.get_cycle_length() + 10 - else: - lr_scheduler = scheduler.StepLRScheduler( - optimizer, - decay_t=args.decay_epochs, - decay_rate=args.decay_rate, - ) + lr_scheduler, num_epochs = create_scheduler(args, optimizer) print(num_epochs) saver = CheckpointSaver(checkpoint_dir=output_dir) @@ -244,7 +197,6 @@ def main(): 'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict(), 'args': args, - 'gp': args.gp, }, epoch=epoch + 1, metric=eval_metrics['eval_loss']) @@ -271,12 +223,6 @@ def train_epoch( last_batch = batch_idx == last_idx data_time_m.update(time.time() - end) - input = input.cuda() - if isinstance(target, (tuple, list)): - target = [t.cuda() for t in target] - else: - target = target.cuda() - output = model(input) loss = loss_fn(output, target) @@ -286,6 +232,7 @@ def train_epoch( loss.backward() optimizer.step() + torch.cuda.synchronize() num_updates += 1 batch_time_m.update(time.time() - end) @@ -316,7 +263,7 @@ def train_epoch( padding=0, normalize=True) - if saver is not None and last_batch or batch_idx % args.recovery_interval == 0: + if saver is not None and last_batch or (batch_idx + 1) % args.recovery_interval == 0: save_epoch = epoch + 1 if last_batch else epoch saver.save_recovery({ 'epoch': save_epoch, @@ -324,7 +271,6 @@ def train_epoch( 'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict(), 'args': args, - 'gp': args.gp, }, epoch=save_epoch, batch_idx=batch_idx) @@ -351,12 +297,6 @@ def validate(model, loader, loss_fn, args): for batch_idx, (input, target) in enumerate(loader): last_batch = batch_idx == last_idx - input = input.cuda() - if isinstance(target, list): - target = target[0].cuda() - else: - target = target.cuda() - output = model(input) if isinstance(output, (tuple, list)): output = output[0] @@ -367,12 +307,12 @@ def validate(model, loader, loss_fn, args): output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2) target = target[0:target.size(0):reduce_factor] - # calc loss loss = loss_fn(output, target) - losses_m.update(loss.item(), input.size(0)) - - # metrics prec1, prec5 = accuracy(output, target, topk=(1, 5)) + + torch.cuda.synchronize() + + losses_m.update(loss.item(), input.size(0)) prec1_m.update(prec1.item(), output.size(0)) prec5_m.update(prec5.item(), output.size(0)) @@ -393,5 +333,69 @@ def validate(model, loader, loss_fn, args): return metrics +def create_optimizer(args, parameters): + if args.opt.lower() == 'sgd': + optimizer = optim.SGD( + parameters, lr=args.lr, + momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True) + elif args.opt.lower() == 'adam': + optimizer = optim.Adam( + parameters, lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps) + elif args.opt.lower() == 'nadam': + optimizer = Nadam( + parameters, lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps) + elif args.opt.lower() == 'adabound': + optimizer = AdaBound( + 
parameters, lr=args.lr / 100, weight_decay=args.weight_decay, eps=args.opt_eps, + final_lr=args.lr) + elif args.opt.lower() == 'adadelta': + optimizer = optim.Adadelta( + parameters, lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps) + elif args.opt.lower() == 'rmsprop': + optimizer = optim.RMSprop( + parameters, lr=args.lr, alpha=0.9, eps=args.opt_eps, + momentum=args.momentum, weight_decay=args.weight_decay) + else: + assert False and "Invalid optimizer" + raise ValueError + return optimizer + + +def create_scheduler(args, optimizer): + num_epochs = args.epochs + if args.sched == 'cosine': + lr_scheduler = scheduler.CosineLRScheduler( + optimizer, + t_initial=num_epochs, + t_mul=1.0, + lr_min=1e-5, + decay_rate=args.decay_rate, + warmup_lr_init=1e-4, + warmup_t=0, + cycle_limit=1, + t_in_epochs=True, + ) + num_epochs = lr_scheduler.get_cycle_length() + 10 + elif args.sched == 'tanh': + lr_scheduler = scheduler.TanhLRScheduler( + optimizer, + t_initial=num_epochs, + t_mul=1.0, + lr_min=1e-5, + warmup_lr_init=.001, + warmup_t=3, + cycle_limit=1, + t_in_epochs=True, + ) + num_epochs = lr_scheduler.get_cycle_length() + 10 + else: + lr_scheduler = scheduler.StepLRScheduler( + optimizer, + decay_t=args.decay_epochs, + decay_rate=args.decay_rate, + ) + return lr_scheduler, num_epochs + + if __name__ == '__main__': main()
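
The snippet below is not part of the patch; it is a minimal usage sketch showing how the pieces introduced above are meant to fit together, mirroring what the new train.py does. fast_collate stacks the uint8 CHW numpy arrays produced by the AsNumpy transform into a single uint8 batch tensor, and PrefetchLoader then performs the host-to-device copy, float (or half) conversion, mean/std normalization, and batch-wise RandomErasingTorch on a side CUDA stream so it overlaps with compute. Assumed for illustration only: the '/path/to/imagenet/train' path, the batch_size/num_workers values, and torchvision's resnet18 standing in for a model built via model_factory; a CUDA device is required, since PrefetchLoader moves every batch to the GPU unconditionally.

    from torch.utils.data import DataLoader
    from torchvision.models import resnet18  # stand-in model, not built via this repo's model_factory

    from data import Dataset, fast_collate, PrefetchLoader, transforms_imagenet_train

    # The train transform pipeline ends with AsNumpy, so each sample is a uint8 CHW array.
    dataset = Dataset('/path/to/imagenet/train',  # placeholder path
                      transform=transforms_imagenet_train(img_size=224))

    loader = DataLoader(
        dataset, batch_size=64, shuffle=True, num_workers=4,
        collate_fn=fast_collate)  # stacks uint8 arrays; no float conversion on the CPU

    # Wrap the DataLoader: normalization, dtype conversion and random erasing now
    # happen on the GPU, one batch ahead of the training step.
    loader = PrefetchLoader(loader, fp16=False, random_erasing=True)

    model = resnet18().cuda()

    for input, target in loader:
        # input arrives on the GPU already converted to float, normalized and random-erased
        output = model(input)
        break

Two consequences of this design show up elsewhere in the diff: the explicit .cuda() calls were removed from train_epoch and validate because batches now arrive on the GPU, and torch.cuda.synchronize() was added, presumably so the timing meters measure real step time rather than just asynchronous kernel launches. Transferring uint8 instead of float32 also cuts host-to-device traffic by roughly 4x.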