fast-reid/fastreid/data/transforms/build.py

60 lines
1.5 KiB
Python
Raw Normal View History

2020-02-10 07:38:56 +08:00
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import torchvision.transforms as T
from .transforms import *
def build_transforms(cfg, is_train=True):
    """Build the image preprocessing pipeline described by ``cfg.INPUT``.

    Args:
        cfg: config node. Reads ``cfg.INPUT.SIZE_TEST`` for evaluation and,
            for training, ``cfg.INPUT.SIZE_TRAIN`` plus the flip
            (``DO_FLIP``/``FLIP_PROB``), padding (``DO_PAD``/``PADDING``/
            ``PADDING_MODE``), AugMix (``DO_AUGMIX``), color-jitter
            (``DO_CJ``), random-erasing (``REA.*``) and random-patch
            (``RPT.*``) options.
        is_train: if True, build the training pipeline with every enabled
            augmentation; otherwise build the test pipeline (resize only).

    Returns:
        A ``torchvision.transforms.Compose`` that applies the selected
        transforms in order and ends with ``ToTensor()``.
    """
    res = []

    if is_train:
        size_train = cfg.INPUT.SIZE_TRAIN

        # horizontal flip
        do_flip = cfg.INPUT.DO_FLIP
        flip_prob = cfg.INPUT.FLIP_PROB

        # padding
        do_pad = cfg.INPUT.DO_PAD
        padding = cfg.INPUT.PADDING
        padding_mode = cfg.INPUT.PADDING_MODE

        # augmix augmentation
        do_augmix = cfg.INPUT.DO_AUGMIX

        # color jitter
        do_cj = cfg.INPUT.DO_CJ

        # random erasing
        do_rea = cfg.INPUT.REA.ENABLED
        rea_prob = cfg.INPUT.REA.PROB
        rea_mean = cfg.INPUT.REA.MEAN

        # random patch
        do_rpt = cfg.INPUT.RPT.ENABLED
        rpt_prob = cfg.INPUT.RPT.PROB

        # interpolation=3 is PIL's BICUBIC resampling filter.
        res.append(T.Resize(size_train, interpolation=3))
        if do_flip:
            res.append(T.RandomHorizontalFlip(p=flip_prob))
        if do_pad:
            # Pad, then randomly crop back to size_train: acts as a random
            # translation while keeping the output size fixed.
            res.extend([T.Pad(padding, padding_mode=padding_mode),
                        T.RandomCrop(size_train)])
        if do_cj:
            res.append(ColorJitter())
        if do_augmix:
            res.append(AugMix())
        if do_rea:
            res.append(RandomErasing(probability=rea_prob, mean=rea_mean))
        if do_rpt:
            res.append(RandomPatch(prob_happen=rpt_prob))
    else:
        size_test = cfg.INPUT.SIZE_TEST
        # Same bicubic resize as training, without any augmentation.
        res.append(T.Resize(size_test, interpolation=3))

    # Shared final step: both the train and the test pipeline must end with
    # tensor conversion so the data loader receives tensors either way.
    res.append(ToTensor())
    return T.Compose(res)