add k_tfm

pull/405/head
KaiyangZhou 2020-06-23 15:21:11 +01:00
parent d18b3ad99a
commit a5ee40744c
3 changed files with 30 additions and 2 deletions

View File

@ -23,10 +23,11 @@ def get_default_config():
cfg.data.width = 128 # image width
cfg.data.combineall = False # combine train, query and gallery for training
cfg.data.transforms = ['random_flip'] # data augmentation
cfg.data.k_tfm = 1 # number of times to apply augmentation to an image independently
cfg.data.norm_mean = [0.485, 0.456, 0.406] # default is imagenet mean
cfg.data.norm_std = [0.229, 0.224, 0.225] # default is imagenet std
cfg.data.save_dir = 'log' # path to save log
cfg.data.load_train_targets = False
cfg.data.load_train_targets = False # load training set from target dataset
# specific datasets
cfg.market1501 = CN()
@ -114,6 +115,7 @@ def imagedata_kwargs(cfg):
'height': cfg.data.height,
'width': cfg.data.width,
'transforms': cfg.data.transforms,
'k_tfm': cfg.data.k_tfm,
'norm_mean': cfg.data.norm_mean,
'norm_std': cfg.data.norm_std,
'use_gpu': cfg.use_gpu,

View File

@ -98,6 +98,10 @@ class ImageDataManager(DataManager):
width (int, optional): target image width. Default is 128.
transforms (str or list of str, optional): transformations applied to model training.
Default is 'random_flip'.
k_tfm (int, optional): number of times to apply augmentation to an image
independently. If k_tfm > 1, the transform function will be
applied k_tfm times to the same image. This argument is only
used for training and is currently valid for image datasets only.
Default is 1.
norm_mean (list or None, optional): data mean. Default is None (use imagenet mean).
norm_std (list or None, optional): data std. Default is None (use imagenet std).
use_gpu (bool, optional): use gpu. Default is True.
@ -150,6 +154,7 @@ class ImageDataManager(DataManager):
height=256,
width=128,
transforms='random_flip',
k_tfm=1,
norm_mean=None,
norm_std=None,
use_gpu=True,
@ -184,6 +189,7 @@ class ImageDataManager(DataManager):
trainset_ = init_image_dataset(
name,
transform=self.transform_tr,
k_tfm=k_tfm,
mode='train',
combineall=combineall,
root=root,
@ -225,6 +231,7 @@ class ImageDataManager(DataManager):
trainset_t_ = init_image_dataset(
name,
transform=self.transform_tr,
k_tfm=k_tfm,
mode='train',
combineall=False, # only use the training data
root=root,

View File

@ -19,6 +19,10 @@ class Dataset(object):
query (list): contains tuples of (img_path(s), pid, camid).
gallery (list): contains tuples of (img_path(s), pid, camid).
transform: transform function.
k_tfm (int): number of times to apply augmentation to an image
independently. If k_tfm > 1, the transform function will be
applied k_tfm times to an image. This variable will only be
useful for training and is currently valid for image datasets only.
mode (str): 'train', 'query' or 'gallery'.
combineall (bool): combines train, query and gallery in a
dataset for training.
@ -33,6 +37,7 @@ class Dataset(object):
query,
gallery,
transform=None,
k_tfm=1,
mode='train',
combineall=False,
verbose=True,
@ -42,6 +47,7 @@ class Dataset(object):
self.query = query
self.gallery = gallery
self.transform = transform
self.k_tfm = k_tfm
self.mode = mode
self.combineall = combineall
self.verbose = verbose
@ -245,6 +251,19 @@ class Dataset(object):
return msg
def _transform_image(self, tfm, k_tfm, img0):
"""Transform a raw image (img0) k_tfm times."""
img_list = []
for k in range(k_tfm):
img_list.append(tfm(img0))
img = img_list
if len(img) == 1:
img = img[0]
return img
class ImageDataset(Dataset):
"""A base class representing ImageDataset.
@ -264,7 +283,7 @@ class ImageDataset(Dataset):
img_path, pid, camid = self.data[index]
img = read_image(img_path)
if self.transform is not None:
img = self.transform(img)
img = self._transform_image(self.transform, self.k_tfm, img)
item = {'img': img, 'pid': pid, 'camid': camid, 'impath': img_path}
return item