add argument 'transforms' to datamanager

pull/201/head
kaiyangzhou 2019-07-03 12:05:55 +01:00
parent c1dd1b6f24
commit 709018d948
3 changed files with 53 additions and 57 deletions

View File

@ -33,16 +33,8 @@ def init_parser():
help='sampler for trainloader')
parser.add_argument('--combineall', action='store_true',
help='combine all data in a dataset (train+query+gallery) for training')
# ************************************************************
# Data augmentation
# ************************************************************
parser.add_argument('--random-erase', action='store_true',
help='use random erasing for data augmentation')
parser.add_argument('--color-jitter', action='store_true',
help='randomly change the brightness, contrast and saturation')
parser.add_argument('--color-aug', action='store_true',
help='randomly alter the intensities of RGB channels')
parser.add_argument('--transforms', type=str, default='random_flip', nargs='+',
help='transformations applied to model training')
# ************************************************************
# Video datasets
@ -203,9 +195,7 @@ def imagedata_kwargs(parsed_args):
'targets': parsed_args.targets,
'height': parsed_args.height,
'width': parsed_args.width,
'random_erase': parsed_args.random_erase,
'color_jitter': parsed_args.color_jitter,
'color_aug': parsed_args.color_aug,
'transforms': parsed_args.transforms,
'use_cpu': parsed_args.use_cpu,
'split_id': parsed_args.split_id,
'combineall': parsed_args.combineall,
@ -227,9 +217,7 @@ def videodata_kwargs(parsed_args):
'targets': parsed_args.targets,
'height': parsed_args.height,
'width': parsed_args.width,
'random_erase': parsed_args.random_erase,
'color_jitter': parsed_args.color_jitter,
'color_aug': parsed_args.color_aug,
'transforms': parsed_args.transforms,
'use_cpu': parsed_args.use_cpu,
'split_id': parsed_args.split_id,
'combineall': parsed_args.combineall,

View File

@ -10,7 +10,7 @@ from torchreid.data.datasets import init_image_dataset, init_video_dataset
class DataManager(object):
"""Base data manager.
r"""Base data manager.
Args:
sources (str or list): source dataset(s).
@ -18,14 +18,13 @@ class DataManager(object):
it equals to ``sources``.
height (int, optional): target image height. Default is 256.
width (int, optional): target image width. Default is 128.
random_erase (bool, optional): use random erasing. Default is False.
color_jitter (bool, optional): use color jittering. Default is False.
color_aug (bool, optional): use color augmentation. Default is False.
transforms (str or list of str, optional): transformations applied to model training.
Default is 'random_flip'.
use_cpu (bool, optional): use cpu. Default is False.
"""
def __init__(self, sources=None, targets=None, height=256, width=128, random_erase=False,
color_jitter=False, color_aug=False, use_cpu=False):
def __init__(self, sources=None, targets=None, height=256, width=128, transforms='random_flip',
use_cpu=False):
self.sources = sources
self.targets = targets
@ -42,10 +41,7 @@ class DataManager(object):
self.targets = [self.targets]
self.transform_tr, self.transform_te = build_transforms(
height, width,
random_erase=random_erase,
color_jitter=color_jitter,
color_aug=color_aug
height, width, transforms
)
self.use_gpu = (torch.cuda.is_available() and not use_cpu)
@ -75,7 +71,7 @@ class DataManager(object):
class ImageDataManager(DataManager):
"""Image data manager.
r"""Image data manager.
Args:
root (str): root path to datasets.
@ -84,9 +80,8 @@ class ImageDataManager(DataManager):
it equals to ``sources``.
height (int, optional): target image height. Default is 256.
width (int, optional): target image width. Default is 128.
random_erase (bool, optional): use random erasing. Default is False.
color_jitter (bool, optional): use color jittering. Default is False.
color_aug (bool, optional): use color augmentation. Default is False.
transforms (str or list of str, optional): transformations applied to model training.
Default is 'random_flip'.
use_cpu (bool, optional): use cpu. Default is False.
split_id (int, optional): split id (*0-based*). Default is 0.
combineall (bool, optional): combine train, query and gallery in a dataset for
@ -114,14 +109,13 @@ class ImageDataManager(DataManager):
)
"""
def __init__(self, root='', sources=None, targets=None, height=256, width=128, random_erase=False,
color_jitter=False, color_aug=False, use_cpu=False, split_id=0, combineall=False,
def __init__(self, root='', sources=None, targets=None, height=256, width=128, transforms='random_flip',
use_cpu=False, split_id=0, combineall=False,
batch_size=32, workers=4, num_instances=4, train_sampler='',
cuhk03_labeled=False, cuhk03_classic_split=False, market1501_500k=False):
super(ImageDataManager, self).__init__(sources=sources, targets=targets, height=height, width=width,
random_erase=random_erase, color_jitter=color_jitter,
color_aug=color_aug, use_cpu=use_cpu)
transforms=transforms, use_cpu=use_cpu)
print('=> Loading train (source) dataset')
trainset = []
@ -223,7 +217,7 @@ class ImageDataManager(DataManager):
class VideoDataManager(DataManager):
"""Video data manager.
r"""Video data manager.
Args:
root (str): root path to datasets.
@ -232,9 +226,8 @@ class VideoDataManager(DataManager):
it equals to ``sources``.
height (int, optional): target image height. Default is 256.
width (int, optional): target image width. Default is 128.
random_erase (bool, optional): use random erasing. Default is False.
color_jitter (bool, optional): use color jittering. Default is False.
color_aug (bool, optional): use color augmentation. Default is False.
transforms (str or list of str, optional): transformations applied to model training.
Default is 'random_flip'.
use_cpu (bool, optional): use cpu. Default is False.
split_id (int, optional): split id (*0-based*). Default is 0.
combineall (bool, optional): combine train, query and gallery in a dataset for
@ -263,14 +256,13 @@ class VideoDataManager(DataManager):
)
"""
def __init__(self, root='', sources=None, targets=None, height=256, width=128, random_erase=False,
color_jitter=False, color_aug=False, use_cpu=False, split_id=0, combineall=False,
def __init__(self, root='', sources=None, targets=None, height=256, width=128, transforms='random_flip',
use_cpu=False, split_id=0, combineall=False,
batch_size=3, workers=4, num_instances=4, train_sampler=None,
seq_len=15, sample_method='evenly'):
super(VideoDataManager, self).__init__(sources=sources, targets=targets, height=height, width=width,
random_erase=random_erase, color_jitter=color_jitter,
color_aug=color_aug, use_cpu=use_cpu)
transforms=transforms, use_cpu=use_cpu)
print('=> Loading train (source) dataset')
trainset = []

View File

@ -15,7 +15,7 @@ class Random2DTranslation(object):
"""Randomly translates the input image with a probability.
Specifically, given a predefined shape (height, width), the input is first
resized with a factor of 1.25, leading to (height*1.25, width*1.25), then
resized with a factor of 1.125, leading to (height*1.125, width*1.125), then
a random crop is performed. Such operation is done with a probability.
Args:
@ -131,38 +131,54 @@ class ColorAugmentation(object):
return tensor
def build_transforms(height, width, random_erase=False, color_jitter=False,
color_aug=False, norm_mean=[0.485, 0.456, 0.406],
def build_transforms(height, width, transforms='random_flip', norm_mean=[0.485, 0.456, 0.406],
norm_std=[0.229, 0.224, 0.225], **kwargs):
"""Builds train and test transform functions
"""Builds train and test transform functions.
Args:
height (int): target image height.
width (int): target image width.
random_erase (bool, optional): use random erasing. Default is False.
color_jitter (bool, optional): use color jittering. Default is False.
color_aug (bool, optional): use color augmentation. Default is False.
transforms (str or list of str, optional): transformations applied to model training.
Default is 'random_flip'.
norm_mean (list): normalization mean values. Default is ImageNet means.
norm_std (list): normalization standard deviation values. Default is
ImageNet standard deviation values.
"""
if isinstance(transforms, str):
transforms = [transforms]
if not isinstance(transforms, list):
raise ValueError('transforms must be a list of strings, but found to be {}'.format(type(transforms)))
transforms = [t.lower() for t in transforms]
normalize = Normalize(mean=norm_mean, std=norm_std)
# build train transformations
print('Building train transforms ...')
transform_tr = []
transform_tr += [Random2DTranslation(height, width)]
transform_tr += [RandomHorizontalFlip()]
if color_jitter:
if 'random_flip' in transforms:
print('+ random flip')
transform_tr += [RandomHorizontalFlip()]
if 'random_crop' in transforms:
print('+ random crop (enlarge to {}x{} and ' \
'crop {}x{})'.format(int(round(height*1.125)), int(round(width*1.125)), height, width))
transform_tr += [Random2DTranslation(height, width)]
if 'color_jitter' in transforms:
print('+ color jitter')
transform_tr += [ColorJitter(brightness=0.2, contrast=0.15, saturation=0, hue=0)]
print('+ to torch tensor of range [0, 1]')
transform_tr += [ToTensor()]
if color_aug:
transform_tr += [ColorAugmentation()]
print('+ normalization (mean={}, std={})'.format(norm_mean, norm_std))
transform_tr += [normalize]
if random_erase:
if 'random_erase' in transforms:
print('+ random erase')
transform_tr += [RandomErasing()]
transform_tr = Compose(transform_tr)
# build test transformations
print('Building test transforms ...')
print('+ resize to {}x{}'.format(height, width))
print('+ to torch tensor of range [0, 1]')
print('+ normalization (mean={}, std={})'.format(norm_mean, norm_std))
transform_te = Compose([
Resize((height, width)),
ToTensor(),