32 lines
848 B
Python
32 lines
848 B
Python
# encoding: utf-8
|
|
"""
|
|
@author: liaoxingyu
|
|
@contact: liaoxingyu2@jd.com
|
|
"""
|
|
|
|
import torchvision.transforms as T
|
|
|
|
from .transforms import RandomErasing
|
|
|
|
|
|
def build_transforms(cfg, is_train=True):
    """Assemble the image preprocessing pipeline selected by ``is_train``.

    Args:
        cfg: config object; only its ``INPUT`` section is read (PIXEL_MEAN,
            PIXEL_STD, SIZE_TRAIN, SIZE_TEST, PROB, PADDING, RE_PROB).
        is_train: True selects the randomized training pipeline; False selects
            the deterministic evaluation pipeline.

    Returns:
        A ``T.Compose`` that applies the chosen steps in order.
    """
    inp = cfg.INPUT
    # Shared by both pipelines; applied after ToTensor.
    normalize = T.Normalize(mean=inp.PIXEL_MEAN, std=inp.PIXEL_STD)

    if not is_train:
        # Evaluation: resize to the test size, convert, normalize — no randomness.
        return T.Compose([T.Resize(inp.SIZE_TEST), T.ToTensor(), normalize])

    train_steps = [
        T.Resize(inp.SIZE_TRAIN),
        T.RandomHorizontalFlip(p=inp.PROB),
        # Pad, then randomly crop back to SIZE_TRAIN: a random-shift augmentation.
        T.Pad(inp.PADDING),
        T.RandomCrop(inp.SIZE_TRAIN),
        T.ToTensor(),
        normalize,
        # NOTE(review): this runs after Normalize, so if RandomErasing uses
        # ``mean`` as its fill value, PIXEL_MEAN is written into
        # already-normalized tensors — confirm this is intended.
        RandomErasing(probability=inp.RE_PROB, mean=inp.PIXEL_MEAN),
    ]
    return T.Compose(train_steps)
|