add k_tfm
parent
d18b3ad99a
commit
a5ee40744c
|
@ -23,10 +23,11 @@ def get_default_config():
|
|||
cfg.data.width = 128 # image width
|
||||
cfg.data.combineall = False # combine train, query and gallery for training
|
||||
cfg.data.transforms = ['random_flip'] # data augmentation
|
||||
cfg.data.k_tfm = 1 # number of times to apply augmentation to an image independently
|
||||
cfg.data.norm_mean = [0.485, 0.456, 0.406] # default is imagenet mean
|
||||
cfg.data.norm_std = [0.229, 0.224, 0.225] # default is imagenet std
|
||||
cfg.data.save_dir = 'log' # path to save log
|
||||
cfg.data.load_train_targets = False
|
||||
cfg.data.load_train_targets = False # load training set from target dataset
|
||||
|
||||
# specific datasets
|
||||
cfg.market1501 = CN()
|
||||
|
@ -114,6 +115,7 @@ def imagedata_kwargs(cfg):
|
|||
'height': cfg.data.height,
|
||||
'width': cfg.data.width,
|
||||
'transforms': cfg.data.transforms,
|
||||
'k_tfm': cfg.data.k_tfm,
|
||||
'norm_mean': cfg.data.norm_mean,
|
||||
'norm_std': cfg.data.norm_std,
|
||||
'use_gpu': cfg.use_gpu,
|
||||
|
|
|
@ -98,6 +98,10 @@ class ImageDataManager(DataManager):
|
|||
width (int, optional): target image width. Default is 128.
|
||||
transforms (str or list of str, optional): transformations applied to model training.
|
||||
Default is 'random_flip'.
|
||||
k_tfm (int, optional): number of times to apply augmentation to an image
|
||||
independently. If k_tfm > 1, the transform function will be
|
||||
applied k_tfm times to an image. This variable will only be
|
||||
useful for training and is currently valid for image datasets only. Default is 1.
|
||||
norm_mean (list or None, optional): data mean. Default is None (use imagenet mean).
|
||||
norm_std (list or None, optional): data std. Default is None (use imagenet std).
|
||||
use_gpu (bool, optional): use gpu. Default is True.
|
||||
|
@ -150,6 +154,7 @@ class ImageDataManager(DataManager):
|
|||
height=256,
|
||||
width=128,
|
||||
transforms='random_flip',
|
||||
k_tfm=1,
|
||||
norm_mean=None,
|
||||
norm_std=None,
|
||||
use_gpu=True,
|
||||
|
@ -184,6 +189,7 @@ class ImageDataManager(DataManager):
|
|||
trainset_ = init_image_dataset(
|
||||
name,
|
||||
transform=self.transform_tr,
|
||||
k_tfm=k_tfm,
|
||||
mode='train',
|
||||
combineall=combineall,
|
||||
root=root,
|
||||
|
@ -225,6 +231,7 @@ class ImageDataManager(DataManager):
|
|||
trainset_t_ = init_image_dataset(
|
||||
name,
|
||||
transform=self.transform_tr,
|
||||
k_tfm=k_tfm,
|
||||
mode='train',
|
||||
combineall=False, # only use the training data
|
||||
root=root,
|
||||
|
|
|
@ -19,6 +19,10 @@ class Dataset(object):
|
|||
query (list): contains tuples of (img_path(s), pid, camid).
|
||||
gallery (list): contains tuples of (img_path(s), pid, camid).
|
||||
transform: transform function.
|
||||
k_tfm (int): number of times to apply augmentation to an image
|
||||
independently. If k_tfm > 1, the transform function will be
|
||||
applied k_tfm times to an image. This variable will only be
|
||||
useful for training and is currently valid for image datasets only.
|
||||
mode (str): 'train', 'query' or 'gallery'.
|
||||
combineall (bool): combines train, query and gallery in a
|
||||
dataset for training.
|
||||
|
@ -33,6 +37,7 @@ class Dataset(object):
|
|||
query,
|
||||
gallery,
|
||||
transform=None,
|
||||
k_tfm=1,
|
||||
mode='train',
|
||||
combineall=False,
|
||||
verbose=True,
|
||||
|
@ -42,6 +47,7 @@ class Dataset(object):
|
|||
self.query = query
|
||||
self.gallery = gallery
|
||||
self.transform = transform
|
||||
self.k_tfm = k_tfm
|
||||
self.mode = mode
|
||||
self.combineall = combineall
|
||||
self.verbose = verbose
|
||||
|
@ -245,6 +251,19 @@ class Dataset(object):
|
|||
|
||||
return msg
|
||||
|
||||
def _transform_image(self, tfm, k_tfm, img0):
|
||||
"""Transform a raw image (img0) k_tfm times."""
|
||||
img_list = []
|
||||
|
||||
for k in range(k_tfm):
|
||||
img_list.append(tfm(img0))
|
||||
|
||||
img = img_list
|
||||
if len(img) == 1:
|
||||
img = img[0]
|
||||
|
||||
return img
|
||||
|
||||
|
||||
class ImageDataset(Dataset):
|
||||
"""A base class representing ImageDataset.
|
||||
|
@ -264,7 +283,7 @@ class ImageDataset(Dataset):
|
|||
img_path, pid, camid = self.data[index]
|
||||
img = read_image(img_path)
|
||||
if self.transform is not None:
|
||||
img = self.transform(img)
|
||||
img = self._transform_image(self.transform, self.k_tfm, img)
|
||||
item = {'img': img, 'pid': pid, 'camid': camid, 'impath': img_path}
|
||||
return item
|
||||
|
||||
|
|
Loading…
Reference in New Issue