From 709018d9481774f442cfa05cddf031e6ba7fcdcd Mon Sep 17 00:00:00 2001 From: kaiyangzhou Date: Wed, 3 Jul 2019 12:05:55 +0100 Subject: [PATCH] add argument 'transforms' to datamanager --- scripts/default_parser.py | 20 +++------------ torchreid/data/datamanager.py | 44 ++++++++++++++------------------- torchreid/data/transforms.py | 46 +++++++++++++++++++++++------------ 3 files changed, 53 insertions(+), 57 deletions(-) diff --git a/scripts/default_parser.py b/scripts/default_parser.py index d98a7c1..a54b096 100644 --- a/scripts/default_parser.py +++ b/scripts/default_parser.py @@ -33,16 +33,8 @@ def init_parser(): help='sampler for trainloader') parser.add_argument('--combineall', action='store_true', help='combine all data in a dataset (train+query+gallery) for training') - - # ************************************************************ - # Data augmentation - # ************************************************************ - parser.add_argument('--random-erase', action='store_true', - help='use random erasing for data augmentation') - parser.add_argument('--color-jitter', action='store_true', - help='randomly change the brightness, contrast and saturation') - parser.add_argument('--color-aug', action='store_true', - help='randomly alter the intensities of RGB channels') + parser.add_argument('--transforms', type=str, default='random_flip', nargs='+', + help='transformations applied to model training') # ************************************************************ # Video datasets @@ -203,9 +195,7 @@ def imagedata_kwargs(parsed_args): 'targets': parsed_args.targets, 'height': parsed_args.height, 'width': parsed_args.width, - 'random_erase': parsed_args.random_erase, - 'color_jitter': parsed_args.color_jitter, - 'color_aug': parsed_args.color_aug, + 'transforms': parsed_args.transforms, 'use_cpu': parsed_args.use_cpu, 'split_id': parsed_args.split_id, 'combineall': parsed_args.combineall, @@ -227,9 +217,7 @@ def videodata_kwargs(parsed_args): 'targets': 
parsed_args.targets, 'height': parsed_args.height, 'width': parsed_args.width, - 'random_erase': parsed_args.random_erase, - 'color_jitter': parsed_args.color_jitter, - 'color_aug': parsed_args.color_aug, + 'transforms': parsed_args.transforms, 'use_cpu': parsed_args.use_cpu, 'split_id': parsed_args.split_id, 'combineall': parsed_args.combineall, diff --git a/torchreid/data/datamanager.py b/torchreid/data/datamanager.py index 001d2b4..f353254 100644 --- a/torchreid/data/datamanager.py +++ b/torchreid/data/datamanager.py @@ -10,7 +10,7 @@ from torchreid.data.datasets import init_image_dataset, init_video_dataset class DataManager(object): - """Base data manager. + r"""Base data manager. Args: sources (str or list): source dataset(s). @@ -18,14 +18,13 @@ class DataManager(object): it equals to ``sources``. height (int, optional): target image height. Default is 256. width (int, optional): target image width. Default is 128. - random_erase (bool, optional): use random erasing. Default is False. - color_jitter (bool, optional): use color jittering. Default is False. - color_aug (bool, optional): use color augmentation. Default is False. + transforms (str or list of str, optional): transformations applied to model training. + Default is 'random_flip'. use_cpu (bool, optional): use cpu. Default is False. 
""" - def __init__(self, sources=None, targets=None, height=256, width=128, random_erase=False, - color_jitter=False, color_aug=False, use_cpu=False): + def __init__(self, sources=None, targets=None, height=256, width=128, transforms='random_flip', + use_cpu=False): self.sources = sources self.targets = targets @@ -42,10 +41,7 @@ class DataManager(object): self.targets = [self.targets] self.transform_tr, self.transform_te = build_transforms( - height, width, - random_erase=random_erase, - color_jitter=color_jitter, - color_aug=color_aug + height, width, transforms ) self.use_gpu = (torch.cuda.is_available() and not use_cpu) @@ -75,7 +71,7 @@ class DataManager(object): class ImageDataManager(DataManager): - """Image data manager. + r"""Image data manager. Args: root (str): root path to datasets. @@ -84,9 +80,8 @@ class ImageDataManager(DataManager): it equals to ``sources``. height (int, optional): target image height. Default is 256. width (int, optional): target image width. Default is 128. - random_erase (bool, optional): use random erasing. Default is False. - color_jitter (bool, optional): use color jittering. Default is False. - color_aug (bool, optional): use color augmentation. Default is False. + transforms (str or list of str, optional): transformations applied to model training. + Default is 'random_flip'. use_cpu (bool, optional): use cpu. Default is False. split_id (int, optional): split id (*0-based*). Default is 0. 
combineall (bool, optional): combine train, query and gallery in a dataset for @@ -114,14 +109,13 @@ class ImageDataManager(DataManager): ) """ - def __init__(self, root='', sources=None, targets=None, height=256, width=128, random_erase=False, - color_jitter=False, color_aug=False, use_cpu=False, split_id=0, combineall=False, + def __init__(self, root='', sources=None, targets=None, height=256, width=128, transforms='random_flip', + use_cpu=False, split_id=0, combineall=False, batch_size=32, workers=4, num_instances=4, train_sampler='', cuhk03_labeled=False, cuhk03_classic_split=False, market1501_500k=False): super(ImageDataManager, self).__init__(sources=sources, targets=targets, height=height, width=width, - random_erase=random_erase, color_jitter=color_jitter, - color_aug=color_aug, use_cpu=use_cpu) + transforms=transforms, use_cpu=use_cpu) print('=> Loading train (source) dataset') trainset = [] @@ -223,7 +217,7 @@ class ImageDataManager(DataManager): class VideoDataManager(DataManager): - """Video data manager. + r"""Video data manager. Args: root (str): root path to datasets. @@ -232,9 +226,8 @@ class VideoDataManager(DataManager): it equals to ``sources``. height (int, optional): target image height. Default is 256. width (int, optional): target image width. Default is 128. - random_erase (bool, optional): use random erasing. Default is False. - color_jitter (bool, optional): use color jittering. Default is False. - color_aug (bool, optional): use color augmentation. Default is False. + transforms (str or list of str, optional): transformations applied to model training. + Default is 'random_flip'. use_cpu (bool, optional): use cpu. Default is False. split_id (int, optional): split id (*0-based*). Default is 0. 
combineall (bool, optional): combine train, query and gallery in a dataset for @@ -263,14 +256,13 @@ class VideoDataManager(DataManager): ) """ - def __init__(self, root='', sources=None, targets=None, height=256, width=128, random_erase=False, - color_jitter=False, color_aug=False, use_cpu=False, split_id=0, combineall=False, + def __init__(self, root='', sources=None, targets=None, height=256, width=128, transforms='random_flip', + use_cpu=False, split_id=0, combineall=False, batch_size=3, workers=4, num_instances=4, train_sampler=None, seq_len=15, sample_method='evenly'): super(VideoDataManager, self).__init__(sources=sources, targets=targets, height=height, width=width, - random_erase=random_erase, color_jitter=color_jitter, - color_aug=color_aug, use_cpu=use_cpu) + transforms=transforms, use_cpu=use_cpu) print('=> Loading train (source) dataset') trainset = [] diff --git a/torchreid/data/transforms.py b/torchreid/data/transforms.py index d4724f2..afacdd6 100644 --- a/torchreid/data/transforms.py +++ b/torchreid/data/transforms.py @@ -15,7 +15,7 @@ class Random2DTranslation(object): """Randomly translates the input image with a probability. Specifically, given a predefined shape (height, width), the input is first - resized with a factor of 1.25, leading to (height*1.25, width*1.25), then + resized with a factor of 1.125, leading to (height*1.125, width*1.125), then a random crop is performed. Such operation is done with a probability. Args:
- random_erase (bool, optional): use random erasing. Default is False. - color_jitter (bool, optional): use color jittering. Default is False. - color_aug (bool, optional): use color augmentation. Default is False. + transforms (str or list of str, optional): transformations applied to model training. + Default is 'random_flip'. norm_mean (list): normalization mean values. Default is ImageNet means. norm_std (list): normalization standard deviation values. Default is ImageNet standard deviation values. """ + if isinstance(transforms, str): + transforms = [transforms] + + if not isinstance(transforms, list): + raise ValueError('transforms must be a list of strings, but found to be {}'.format(type(transforms))) + + transforms = [t.lower() for t in transforms] + normalize = Normalize(mean=norm_mean, std=norm_std) - # build train transformations + print('Building train transforms ...') transform_tr = [] - transform_tr += [Random2DTranslation(height, width)] - transform_tr += [RandomHorizontalFlip()] - if color_jitter: + if 'random_flip' in transforms: + print('+ random flip') + transform_tr += [RandomHorizontalFlip()] + if 'random_crop' in transforms: + print('+ random crop (enlarge to {}x{} and ' \ + 'crop {}x{})'.format(int(round(height*1.125)), int(round(width*1.125)), height, width)) + transform_tr += [Random2DTranslation(height, width)] + if 'color_jitter' in transforms: + print('+ color jitter') transform_tr += [ColorJitter(brightness=0.2, contrast=0.15, saturation=0, hue=0)] + print('+ to torch tensor of range [0, 1]') transform_tr += [ToTensor()] - if color_aug: - transform_tr += [ColorAugmentation()] + print('+ normalization (mean={}, std={})'.format(norm_mean, norm_std)) transform_tr += [normalize] - if random_erase: + if 'random_erase' in transforms: + print('+ random erase') transform_tr += [RandomErasing()] transform_tr = Compose(transform_tr) - # build test transformations + print('Building test transforms ...') + print('+ resize to {}x{}'.format(height, 
width)) + print('+ to torch tensor of range [0, 1]') + print('+ normalization (mean={}, std={})'.format(norm_mean, norm_std)) transform_te = Compose([ Resize((height, width)), ToTensor(),