wrap transforms into build_transforms()

pull/119/head
KaiyangZhou 2018-11-01 22:29:19 +00:00
parent e302af017d
commit c388c24880
5 changed files with 61 additions and 55 deletions

View File

@ -2,6 +2,7 @@ from __future__ import absolute_import
from __future__ import division
from torchvision.transforms import *
import torch
from PIL import Image
import random
@ -13,8 +14,8 @@ class Random2DTranslation(object):
With a probability, first increase image size to (1 + 1/8), and then perform random crop.
Args:
- height (int): target height.
- width (int): target width.
- height (int): target image height.
- width (int): target image width.
- p (float): probability of performing this transformation. Default: 0.5.
"""
def __init__(self, height, width, p=0.5, interpolation=Image.BILINEAR):
@ -38,4 +39,49 @@ class Random2DTranslation(object):
x1 = int(round(random.uniform(0, x_maxrange)))
y1 = int(round(random.uniform(0, y_maxrange)))
croped_img = resized_img.crop((x1, y1, x1 + self.width, y1 + self.height))
return croped_img
return croped_img
def build_transforms(height, width, is_train, **kwargs):
    """Assemble the image preprocessing pipeline for one phase.

    Args:
    - height (int): target image height.
    - width (int): target image width.
    - is_train (bool): True selects the train pipeline, False the test one.
    - kwargs: optional test-phase switches. Only ``five_crop`` and
      ``ten_crop`` are read; ``five_crop`` is checked first, so it wins
      when both are truthy. Ignored when ``is_train`` is True.

    Returns the selected transforms wrapped in a single ``Compose``.
    """
    # imagenet mean and std are used as the default normalisation
    normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    if is_train:
        # TRAIN: random translation + horizontal flip, then tensor + normalise
        return Compose([
            Random2DTranslation(height, width),
            RandomHorizontalFlip(),
            ToTensor(),
            normalize,
        ])

    # TEST: choose the multi-crop transform, if one was requested
    if kwargs.get('five_crop'):
        multi_crop = FiveCrop
    elif kwargs.get('ten_crop'):
        multi_crop = TenCrop
    else:
        multi_crop = None

    if multi_crop is None:
        # plain test pipeline: resize straight to the target size
        return Compose([Resize((height, width)), ToTensor(), normalize])

    # Multi-crop pipeline: resize to 1.125x the target, cut target-sized
    # crops, then convert and normalise each crop and stack them into one
    # tensor.
    enlarged_size = (int(height * 1.125), int(width * 1.125))
    return Compose([
        Resize(enlarged_size),
        multi_crop((height, width)),
        Lambda(lambda crops: torch.stack([normalize(ToTensor()(crop)) for crop in crops])),
    ])

View File

@ -17,7 +17,7 @@ from torch.optim import lr_scheduler
from torchreid import data_manager
from torchreid.dataset_loader import ImageDataset
from torchreid import transforms as T
from torchreid.transforms import build_transforms
from torchreid import models
from torchreid.losses import CrossEntropyLoss, DeepSupervision
from torchreid.utils.iotools import save_checkpoint, check_isfile
@ -135,18 +135,8 @@ def main():
cuhk03_labeled=args.cuhk03_labeled, cuhk03_classic_split=args.cuhk03_classic_split,
)
transform_train = T.Compose([
T.Random2DTranslation(args.height, args.width),
T.RandomHorizontalFlip(),
T.ToTensor(),
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
transform_test = T.Compose([
T.Resize((args.height, args.width)),
T.ToTensor(),
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
transform_train = build_transforms(args.height, args.width, is_train=True)
transform_test = build_transforms(args.height, args.width, is_train=False)
pin_memory = True if use_gpu else False

View File

@ -17,7 +17,7 @@ from torch.optim import lr_scheduler
from torchreid import data_manager
from torchreid.dataset_loader import ImageDataset
from torchreid import transforms as T
from torchreid.transforms import build_transforms
from torchreid import models
from torchreid.losses import CrossEntropyLoss, TripletLoss, DeepSupervision
from torchreid.utils.iotools import save_checkpoint, check_isfile
@ -140,18 +140,8 @@ def main():
cuhk03_labeled=args.cuhk03_labeled, cuhk03_classic_split=args.cuhk03_classic_split,
)
transform_train = T.Compose([
T.Random2DTranslation(args.height, args.width),
T.RandomHorizontalFlip(),
T.ToTensor(),
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
transform_test = T.Compose([
T.Resize((args.height, args.width)),
T.ToTensor(),
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
transform_train = build_transforms(args.height, args.width, is_train=True)
transform_test = build_transforms(args.height, args.width, is_train=False)
pin_memory = True if use_gpu else False

View File

@ -17,7 +17,7 @@ from torch.optim import lr_scheduler
from torchreid import data_manager
from torchreid.dataset_loader import ImageDataset, VideoDataset
from torchreid import transforms as T
from torchreid.transforms import build_transforms
from torchreid import models
from torchreid.losses import CrossEntropyLoss
from torchreid.utils.iotools import save_checkpoint, check_isfile
@ -126,18 +126,8 @@ def main():
print("Initializing dataset {}".format(args.dataset))
dataset = data_manager.init_vidreid_dataset(root=args.root, name=args.dataset)
transform_train = T.Compose([
T.Random2DTranslation(args.height, args.width),
T.RandomHorizontalFlip(),
T.ToTensor(),
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
transform_test = T.Compose([
T.Resize((args.height, args.width)),
T.ToTensor(),
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
transform_train = build_transforms(args.height, args.width, is_train=True)
transform_test = build_transforms(args.height, args.width, is_train=False)
pin_memory = True if use_gpu else False

View File

@ -17,7 +17,7 @@ from torch.optim import lr_scheduler
from torchreid import data_manager
from torchreid.dataset_loader import ImageDataset, VideoDataset
from torchreid import transforms as T
from torchreid.transforms import build_transforms
from torchreid import models
from torchreid.losses import CrossEntropyLoss, TripletLoss, DeepSupervision
from torchreid.utils.iotools import save_checkpoint, check_isfile
@ -131,18 +131,8 @@ def main():
print("Initializing dataset {}".format(args.dataset))
dataset = data_manager.init_vidreid_dataset(root=args.root, name=args.dataset)
transform_train = T.Compose([
T.Random2DTranslation(args.height, args.width),
T.RandomHorizontalFlip(),
T.ToTensor(),
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
transform_test = T.Compose([
T.Resize((args.height, args.width)),
T.ToTensor(),
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
transform_train = build_transforms(args.height, args.width, is_train=True)
transform_test = build_transforms(args.height, args.width, is_train=False)
pin_memory = True if use_gpu else False