2020-02-10 07:38:56 +08:00
|
|
|
# encoding: utf-8
|
|
|
|
"""
|
|
|
|
@author: liaoxingyu
|
|
|
|
@contact: sherlockliao01@gmail.com
|
|
|
|
"""
|
|
|
|
|
|
|
|
import torchvision.transforms as T
|
|
|
|
|
|
|
|
from .transforms import *
|
|
|
|
|
|
|
|
|
|
|
|
def build_transforms(cfg, is_train=True):
    """Assemble the image preprocessing pipeline described by ``cfg.INPUT``.

    Args:
        cfg: config node; only attributes under ``cfg.INPUT`` are read.
        is_train: when True, build the training pipeline (resize followed by
            whichever optional augmentations the config switches on);
            otherwise build the evaluation pipeline (resize only).

    Returns:
        A ``torchvision.transforms.Compose`` whose final step is ``ToTensor()``.
    """
    inp = cfg.INPUT

    if not is_train:
        # Evaluation: resize to the test size, then convert to a tensor.
        # interpolation=3 is PIL's bicubic resampling code (Image.BICUBIC).
        return T.Compose([
            T.Resize(inp.SIZE_TEST, interpolation=3),
            ToTensor(),
        ])

    # Training. All settings are read up front, before any transform is
    # constructed, whether or not the corresponding switch is enabled.
    size_train = inp.SIZE_TRAIN

    # horizontal flip
    flip_on, flip_p = inp.DO_FLIP, inp.FLIP_PROB

    # padding followed by a random crop back to the training size
    pad_on, pad_amount, pad_mode = inp.DO_PAD, inp.PADDING, inp.PADDING_MODE

    # AugMix and color jitter have no extra parameters in the config
    augmix_on = inp.DO_AUGMIX
    cj_on = inp.DO_CJ

    # random erasing
    rea_cfg = inp.REA
    rea_on, rea_p, rea_fill = rea_cfg.ENABLED, rea_cfg.PROB, rea_cfg.MEAN

    # random patch
    rpt_cfg = inp.RPT
    rpt_on, rpt_p = rpt_cfg.ENABLED, rpt_cfg.PROB

    # Order matters: resize first, then geometric ops, then photometric ops,
    # then erasing/patching, and tensor conversion last.
    # interpolation=3 is PIL's bicubic resampling code (Image.BICUBIC).
    pipeline = [T.Resize(size_train, interpolation=3)]
    if flip_on:
        pipeline.append(T.RandomHorizontalFlip(p=flip_p))
    if pad_on:
        # Crop back to size_train so padding does not change the output size.
        pipeline += [T.Pad(pad_amount, padding_mode=pad_mode),
                     T.RandomCrop(size_train)]
    # The unqualified transforms below presumably come from the package's own
    # `.transforms` module (star-imported at the top of this file).
    if cj_on:
        pipeline.append(ColorJitter())
    if augmix_on:
        pipeline.append(AugMix())
    if rea_on:
        pipeline.append(RandomErasing(probability=rea_p, mean=rea_fill))
    if rpt_on:
        pipeline.append(RandomPatch(prob_happen=rpt_p))
    pipeline.append(ToTensor())

    return T.Compose(pipeline)
|