add argument 'transforms' to datamanager
parent
c1dd1b6f24
commit
709018d948
scripts
torchreid/data
|
@ -33,16 +33,8 @@ def init_parser():
|
|||
help='sampler for trainloader')
|
||||
parser.add_argument('--combineall', action='store_true',
|
||||
help='combine all data in a dataset (train+query+gallery) for training')
|
||||
|
||||
# ************************************************************
|
||||
# Data augmentation
|
||||
# ************************************************************
|
||||
parser.add_argument('--random-erase', action='store_true',
|
||||
help='use random erasing for data augmentation')
|
||||
parser.add_argument('--color-jitter', action='store_true',
|
||||
help='randomly change the brightness, contrast and saturation')
|
||||
parser.add_argument('--color-aug', action='store_true',
|
||||
help='randomly alter the intensities of RGB channels')
|
||||
parser.add_argument('--transforms', type=str, default='random_flip', nargs='+',
|
||||
help='transformations applied to model training')
|
||||
|
||||
# ************************************************************
|
||||
# Video datasets
|
||||
|
@ -203,9 +195,7 @@ def imagedata_kwargs(parsed_args):
|
|||
'targets': parsed_args.targets,
|
||||
'height': parsed_args.height,
|
||||
'width': parsed_args.width,
|
||||
'random_erase': parsed_args.random_erase,
|
||||
'color_jitter': parsed_args.color_jitter,
|
||||
'color_aug': parsed_args.color_aug,
|
||||
'transforms': parsed_args.transforms,
|
||||
'use_cpu': parsed_args.use_cpu,
|
||||
'split_id': parsed_args.split_id,
|
||||
'combineall': parsed_args.combineall,
|
||||
|
@ -227,9 +217,7 @@ def videodata_kwargs(parsed_args):
|
|||
'targets': parsed_args.targets,
|
||||
'height': parsed_args.height,
|
||||
'width': parsed_args.width,
|
||||
'random_erase': parsed_args.random_erase,
|
||||
'color_jitter': parsed_args.color_jitter,
|
||||
'color_aug': parsed_args.color_aug,
|
||||
'transforms': parsed_args.transforms,
|
||||
'use_cpu': parsed_args.use_cpu,
|
||||
'split_id': parsed_args.split_id,
|
||||
'combineall': parsed_args.combineall,
|
||||
|
|
|
@ -10,7 +10,7 @@ from torchreid.data.datasets import init_image_dataset, init_video_dataset
|
|||
|
||||
|
||||
class DataManager(object):
|
||||
"""Base data manager.
|
||||
r"""Base data manager.
|
||||
|
||||
Args:
|
||||
sources (str or list): source dataset(s).
|
||||
|
@ -18,14 +18,13 @@ class DataManager(object):
|
|||
it equals to ``sources``.
|
||||
height (int, optional): target image height. Default is 256.
|
||||
width (int, optional): target image width. Default is 128.
|
||||
random_erase (bool, optional): use random erasing. Default is False.
|
||||
color_jitter (bool, optional): use color jittering. Default is False.
|
||||
color_aug (bool, optional): use color augmentation. Default is False.
|
||||
transforms (str or list of str, optional): transformations applied to model training.
|
||||
Default is 'random_flip'.
|
||||
use_cpu (bool, optional): use cpu. Default is False.
|
||||
"""
|
||||
|
||||
def __init__(self, sources=None, targets=None, height=256, width=128, random_erase=False,
|
||||
color_jitter=False, color_aug=False, use_cpu=False):
|
||||
def __init__(self, sources=None, targets=None, height=256, width=128, transforms='random_flip',
|
||||
use_cpu=False):
|
||||
self.sources = sources
|
||||
self.targets = targets
|
||||
|
||||
|
@ -42,10 +41,7 @@ class DataManager(object):
|
|||
self.targets = [self.targets]
|
||||
|
||||
self.transform_tr, self.transform_te = build_transforms(
|
||||
height, width,
|
||||
random_erase=random_erase,
|
||||
color_jitter=color_jitter,
|
||||
color_aug=color_aug
|
||||
height, width, transforms
|
||||
)
|
||||
|
||||
self.use_gpu = (torch.cuda.is_available() and not use_cpu)
|
||||
|
@ -75,7 +71,7 @@ class DataManager(object):
|
|||
|
||||
|
||||
class ImageDataManager(DataManager):
|
||||
"""Image data manager.
|
||||
r"""Image data manager.
|
||||
|
||||
Args:
|
||||
root (str): root path to datasets.
|
||||
|
@ -84,9 +80,8 @@ class ImageDataManager(DataManager):
|
|||
it equals to ``sources``.
|
||||
height (int, optional): target image height. Default is 256.
|
||||
width (int, optional): target image width. Default is 128.
|
||||
random_erase (bool, optional): use random erasing. Default is False.
|
||||
color_jitter (bool, optional): use color jittering. Default is False.
|
||||
color_aug (bool, optional): use color augmentation. Default is False.
|
||||
transforms (str or list of str, optional): transformations applied to model training.
|
||||
Default is 'random_flip'.
|
||||
use_cpu (bool, optional): use cpu. Default is False.
|
||||
split_id (int, optional): split id (*0-based*). Default is 0.
|
||||
combineall (bool, optional): combine train, query and gallery in a dataset for
|
||||
|
@ -114,14 +109,13 @@ class ImageDataManager(DataManager):
|
|||
)
|
||||
"""
|
||||
|
||||
def __init__(self, root='', sources=None, targets=None, height=256, width=128, random_erase=False,
|
||||
color_jitter=False, color_aug=False, use_cpu=False, split_id=0, combineall=False,
|
||||
def __init__(self, root='', sources=None, targets=None, height=256, width=128, transforms='random_flip',
|
||||
use_cpu=False, split_id=0, combineall=False,
|
||||
batch_size=32, workers=4, num_instances=4, train_sampler='',
|
||||
cuhk03_labeled=False, cuhk03_classic_split=False, market1501_500k=False):
|
||||
|
||||
super(ImageDataManager, self).__init__(sources=sources, targets=targets, height=height, width=width,
|
||||
random_erase=random_erase, color_jitter=color_jitter,
|
||||
color_aug=color_aug, use_cpu=use_cpu)
|
||||
transforms=transforms, use_cpu=use_cpu)
|
||||
|
||||
print('=> Loading train (source) dataset')
|
||||
trainset = []
|
||||
|
@ -223,7 +217,7 @@ class ImageDataManager(DataManager):
|
|||
|
||||
|
||||
class VideoDataManager(DataManager):
|
||||
"""Video data manager.
|
||||
r"""Video data manager.
|
||||
|
||||
Args:
|
||||
root (str): root path to datasets.
|
||||
|
@ -232,9 +226,8 @@ class VideoDataManager(DataManager):
|
|||
it equals to ``sources``.
|
||||
height (int, optional): target image height. Default is 256.
|
||||
width (int, optional): target image width. Default is 128.
|
||||
random_erase (bool, optional): use random erasing. Default is False.
|
||||
color_jitter (bool, optional): use color jittering. Default is False.
|
||||
color_aug (bool, optional): use color augmentation. Default is False.
|
||||
transforms (str or list of str, optional): transformations applied to model training.
|
||||
Default is 'random_flip'.
|
||||
use_cpu (bool, optional): use cpu. Default is False.
|
||||
split_id (int, optional): split id (*0-based*). Default is 0.
|
||||
combineall (bool, optional): combine train, query and gallery in a dataset for
|
||||
|
@ -263,14 +256,13 @@ class VideoDataManager(DataManager):
|
|||
)
|
||||
"""
|
||||
|
||||
def __init__(self, root='', sources=None, targets=None, height=256, width=128, random_erase=False,
|
||||
color_jitter=False, color_aug=False, use_cpu=False, split_id=0, combineall=False,
|
||||
def __init__(self, root='', sources=None, targets=None, height=256, width=128, transforms='random_flip',
|
||||
use_cpu=False, split_id=0, combineall=False,
|
||||
batch_size=3, workers=4, num_instances=4, train_sampler=None,
|
||||
seq_len=15, sample_method='evenly'):
|
||||
|
||||
super(VideoDataManager, self).__init__(sources=sources, targets=targets, height=height, width=width,
|
||||
random_erase=random_erase, color_jitter=color_jitter,
|
||||
color_aug=color_aug, use_cpu=use_cpu)
|
||||
transforms=transforms, use_cpu=use_cpu)
|
||||
|
||||
print('=> Loading train (source) dataset')
|
||||
trainset = []
|
||||
|
|
|
@ -15,7 +15,7 @@ class Random2DTranslation(object):
|
|||
"""Randomly translates the input image with a probability.
|
||||
|
||||
Specifically, given a predefined shape (height, width), the input is first
|
||||
resized with a factor of 1.25, leading to (height*1.25, width*1.25), then
|
||||
resized with a factor of 1.125, leading to (height*1.125, width*1.125), then
|
||||
a random crop is performed. Such operation is done with a probability.
|
||||
|
||||
Args:
|
||||
|
@ -131,38 +131,54 @@ class ColorAugmentation(object):
|
|||
return tensor
|
||||
|
||||
|
||||
def build_transforms(height, width, random_erase=False, color_jitter=False,
|
||||
color_aug=False, norm_mean=[0.485, 0.456, 0.406],
|
||||
def build_transforms(height, width, transforms='random_flip', norm_mean=[0.485, 0.456, 0.406],
|
||||
norm_std=[0.229, 0.224, 0.225], **kwargs):
|
||||
"""Builds train and test transform functions
|
||||
"""Builds train and test transform functions.
|
||||
|
||||
Args:
|
||||
height (int): target image height.
|
||||
width (int): target image width.
|
||||
random_erase (bool, optional): use random erasing. Default is False.
|
||||
color_jitter (bool, optional): use color jittering. Default is False.
|
||||
color_aug (bool, optional): use color augmentation. Default is False.
|
||||
transforms (str or list of str, optional): transformations applied to model training.
|
||||
Default is 'random_flip'.
|
||||
norm_mean (list): normalization mean values. Default is ImageNet means.
|
||||
norm_std (list): normalization standard deviation values. Default is
|
||||
ImageNet standard deviation values.
|
||||
"""
|
||||
if isinstance(transforms, str):
|
||||
transforms = [transforms]
|
||||
|
||||
if not isinstance(transforms, list):
|
||||
raise ValueError('transforms must be a list of strings, but found to be {}'.format(type(transforms)))
|
||||
|
||||
transforms = [t.lower() for t in transforms]
|
||||
|
||||
normalize = Normalize(mean=norm_mean, std=norm_std)
|
||||
|
||||
# build train transformations
|
||||
print('Building train transforms ...')
|
||||
transform_tr = []
|
||||
transform_tr += [Random2DTranslation(height, width)]
|
||||
transform_tr += [RandomHorizontalFlip()]
|
||||
if color_jitter:
|
||||
if 'random_flip' in transforms:
|
||||
print('+ random flip')
|
||||
transform_tr += [RandomHorizontalFlip()]
|
||||
if 'random_crop' in transforms:
|
||||
print('+ random crop (enlarge to {}x{} and ' \
|
||||
'crop {}x{})'.format(int(round(height*1.125)), int(round(width*1.125)), height, width))
|
||||
transform_tr += [Random2DTranslation(height, width)]
|
||||
if 'color_jitter' in transforms:
|
||||
print('+ color jitter')
|
||||
transform_tr += [ColorJitter(brightness=0.2, contrast=0.15, saturation=0, hue=0)]
|
||||
print('+ to torch tensor of range [0, 1]')
|
||||
transform_tr += [ToTensor()]
|
||||
if color_aug:
|
||||
transform_tr += [ColorAugmentation()]
|
||||
print('+ normalization (mean={}, std={})'.format(norm_mean, norm_std))
|
||||
transform_tr += [normalize]
|
||||
if random_erase:
|
||||
if 'random_erase' in transforms:
|
||||
print('+ random erase')
|
||||
transform_tr += [RandomErasing()]
|
||||
transform_tr = Compose(transform_tr)
|
||||
|
||||
# build test transformations
|
||||
print('Building test transforms ...')
|
||||
print('+ resize to {}x{}'.format(height, width))
|
||||
print('+ to torch tensor of range [0, 1]')
|
||||
print('+ normalization (mean={}, std={})'.format(norm_mean, norm_std))
|
||||
transform_te = Compose([
|
||||
Resize((height, width)),
|
||||
ToTensor(),
|
||||
|
|
Loading…
Reference in New Issue