wrap transforms into build_transforms()
parent
e302af017d
commit
c388c24880
|
@ -2,6 +2,7 @@ from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
|
|
||||||
from torchvision.transforms import *
|
from torchvision.transforms import *
|
||||||
|
import torch
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import random
|
import random
|
||||||
|
@ -13,8 +14,8 @@ class Random2DTranslation(object):
|
||||||
With a probability, first increase image size to (1 + 1/8), and then perform random crop.
|
With a probability, first increase image size to (1 + 1/8), and then perform random crop.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
- height (int): target height.
|
- height (int): target image height.
|
||||||
- width (int): target width.
|
- width (int): target image width.
|
||||||
- p (float): probability of performing this transformation. Default: 0.5.
|
- p (float): probability of performing this transformation. Default: 0.5.
|
||||||
"""
|
"""
|
||||||
def __init__(self, height, width, p=0.5, interpolation=Image.BILINEAR):
|
def __init__(self, height, width, p=0.5, interpolation=Image.BILINEAR):
|
||||||
|
@ -39,3 +40,48 @@ class Random2DTranslation(object):
|
||||||
y1 = int(round(random.uniform(0, y_maxrange)))
|
y1 = int(round(random.uniform(0, y_maxrange)))
|
||||||
croped_img = resized_img.crop((x1, y1, x1 + self.width, y1 + self.height))
|
croped_img = resized_img.crop((x1, y1, x1 + self.width, y1 + self.height))
|
||||||
return croped_img
|
return croped_img
|
||||||
|
|
||||||
|
|
||||||
|
def build_transforms(height, width, is_train, **kwargs):
    """Assemble the image preprocessing pipeline for the train or test phase.

    Args:
    - height (int): target image height.
    - width (int): target image width.
    - is_train (bool): train or test phase.
    - kwargs: optional switches, only consulted when ``is_train`` is false:
        - five_crop: if truthy, resize to 1.125x the target size, take
          ``FiveCrop`` at the target size and stack the normalized crops.
        - ten_crop: if truthy (and ``five_crop`` is not), do the same
          with ``TenCrop``.

    Returns:
    - the selected steps wrapped in ``Compose``.
    """
    # ImageNet channel statistics are used as the default normalization.
    normalize = Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    )

    if is_train:
        # TRAIN pipeline: random translation + flip, then tensor conversion.
        return Compose([
            Random2DTranslation(height, width),
            RandomHorizontalFlip(),
            ToTensor(),
            normalize,
        ])

    # TEST pipeline: five_crop takes precedence over ten_crop.
    if kwargs.get('five_crop'):
        multi_crop = FiveCrop
    elif kwargs.get('ten_crop'):
        multi_crop = TenCrop
    else:
        multi_crop = None

    if multi_crop is None:
        # Plain single-view evaluation.
        steps = [Resize((height, width)), ToTensor(), normalize]
    else:
        enlarged_size = (int(height * 1.125), int(width * 1.125))

        def _stack_crops(crops):
            # Each crop is still a PIL image here, so convert and
            # normalize per crop before stacking into one tensor.
            return torch.stack([normalize(ToTensor()(crop)) for crop in crops])

        steps = [
            Resize(enlarged_size),
            multi_crop((height, width)),
            Lambda(_stack_crops),
        ]

    return Compose(steps)
|
|
@ -17,7 +17,7 @@ from torch.optim import lr_scheduler
|
||||||
|
|
||||||
from torchreid import data_manager
|
from torchreid import data_manager
|
||||||
from torchreid.dataset_loader import ImageDataset
|
from torchreid.dataset_loader import ImageDataset
|
||||||
from torchreid import transforms as T
|
from torchreid.transforms import build_transforms
|
||||||
from torchreid import models
|
from torchreid import models
|
||||||
from torchreid.losses import CrossEntropyLoss, DeepSupervision
|
from torchreid.losses import CrossEntropyLoss, DeepSupervision
|
||||||
from torchreid.utils.iotools import save_checkpoint, check_isfile
|
from torchreid.utils.iotools import save_checkpoint, check_isfile
|
||||||
|
@ -135,18 +135,8 @@ def main():
|
||||||
cuhk03_labeled=args.cuhk03_labeled, cuhk03_classic_split=args.cuhk03_classic_split,
|
cuhk03_labeled=args.cuhk03_labeled, cuhk03_classic_split=args.cuhk03_classic_split,
|
||||||
)
|
)
|
||||||
|
|
||||||
transform_train = T.Compose([
|
transform_train = build_transforms(args.height, args.width, is_train=True)
|
||||||
T.Random2DTranslation(args.height, args.width),
|
transform_test = build_transforms(args.height, args.width, is_train=False)
|
||||||
T.RandomHorizontalFlip(),
|
|
||||||
T.ToTensor(),
|
|
||||||
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
|
||||||
])
|
|
||||||
|
|
||||||
transform_test = T.Compose([
|
|
||||||
T.Resize((args.height, args.width)),
|
|
||||||
T.ToTensor(),
|
|
||||||
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
|
||||||
])
|
|
||||||
|
|
||||||
pin_memory = True if use_gpu else False
|
pin_memory = True if use_gpu else False
|
||||||
|
|
||||||
|
|
|
@ -17,7 +17,7 @@ from torch.optim import lr_scheduler
|
||||||
|
|
||||||
from torchreid import data_manager
|
from torchreid import data_manager
|
||||||
from torchreid.dataset_loader import ImageDataset
|
from torchreid.dataset_loader import ImageDataset
|
||||||
from torchreid import transforms as T
|
from torchreid.transforms import build_transforms
|
||||||
from torchreid import models
|
from torchreid import models
|
||||||
from torchreid.losses import CrossEntropyLoss, TripletLoss, DeepSupervision
|
from torchreid.losses import CrossEntropyLoss, TripletLoss, DeepSupervision
|
||||||
from torchreid.utils.iotools import save_checkpoint, check_isfile
|
from torchreid.utils.iotools import save_checkpoint, check_isfile
|
||||||
|
@ -140,18 +140,8 @@ def main():
|
||||||
cuhk03_labeled=args.cuhk03_labeled, cuhk03_classic_split=args.cuhk03_classic_split,
|
cuhk03_labeled=args.cuhk03_labeled, cuhk03_classic_split=args.cuhk03_classic_split,
|
||||||
)
|
)
|
||||||
|
|
||||||
transform_train = T.Compose([
|
transform_train = build_transforms(args.height, args.width, is_train=True)
|
||||||
T.Random2DTranslation(args.height, args.width),
|
transform_test = build_transforms(args.height, args.width, is_train=False)
|
||||||
T.RandomHorizontalFlip(),
|
|
||||||
T.ToTensor(),
|
|
||||||
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
|
||||||
])
|
|
||||||
|
|
||||||
transform_test = T.Compose([
|
|
||||||
T.Resize((args.height, args.width)),
|
|
||||||
T.ToTensor(),
|
|
||||||
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
|
||||||
])
|
|
||||||
|
|
||||||
pin_memory = True if use_gpu else False
|
pin_memory = True if use_gpu else False
|
||||||
|
|
||||||
|
|
|
@ -17,7 +17,7 @@ from torch.optim import lr_scheduler
|
||||||
|
|
||||||
from torchreid import data_manager
|
from torchreid import data_manager
|
||||||
from torchreid.dataset_loader import ImageDataset, VideoDataset
|
from torchreid.dataset_loader import ImageDataset, VideoDataset
|
||||||
from torchreid import transforms as T
|
from torchreid.transforms import build_transforms
|
||||||
from torchreid import models
|
from torchreid import models
|
||||||
from torchreid.losses import CrossEntropyLoss
|
from torchreid.losses import CrossEntropyLoss
|
||||||
from torchreid.utils.iotools import save_checkpoint, check_isfile
|
from torchreid.utils.iotools import save_checkpoint, check_isfile
|
||||||
|
@ -126,18 +126,8 @@ def main():
|
||||||
print("Initializing dataset {}".format(args.dataset))
|
print("Initializing dataset {}".format(args.dataset))
|
||||||
dataset = data_manager.init_vidreid_dataset(root=args.root, name=args.dataset)
|
dataset = data_manager.init_vidreid_dataset(root=args.root, name=args.dataset)
|
||||||
|
|
||||||
transform_train = T.Compose([
|
transform_train = build_transforms(args.height, args.width, is_train=True)
|
||||||
T.Random2DTranslation(args.height, args.width),
|
transform_test = build_transforms(args.height, args.width, is_train=False)
|
||||||
T.RandomHorizontalFlip(),
|
|
||||||
T.ToTensor(),
|
|
||||||
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
|
||||||
])
|
|
||||||
|
|
||||||
transform_test = T.Compose([
|
|
||||||
T.Resize((args.height, args.width)),
|
|
||||||
T.ToTensor(),
|
|
||||||
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
|
||||||
])
|
|
||||||
|
|
||||||
pin_memory = True if use_gpu else False
|
pin_memory = True if use_gpu else False
|
||||||
|
|
||||||
|
|
|
@ -17,7 +17,7 @@ from torch.optim import lr_scheduler
|
||||||
|
|
||||||
from torchreid import data_manager
|
from torchreid import data_manager
|
||||||
from torchreid.dataset_loader import ImageDataset, VideoDataset
|
from torchreid.dataset_loader import ImageDataset, VideoDataset
|
||||||
from torchreid import transforms as T
|
from torchreid.transforms import build_transforms
|
||||||
from torchreid import models
|
from torchreid import models
|
||||||
from torchreid.losses import CrossEntropyLoss, TripletLoss, DeepSupervision
|
from torchreid.losses import CrossEntropyLoss, TripletLoss, DeepSupervision
|
||||||
from torchreid.utils.iotools import save_checkpoint, check_isfile
|
from torchreid.utils.iotools import save_checkpoint, check_isfile
|
||||||
|
@ -131,18 +131,8 @@ def main():
|
||||||
print("Initializing dataset {}".format(args.dataset))
|
print("Initializing dataset {}".format(args.dataset))
|
||||||
dataset = data_manager.init_vidreid_dataset(root=args.root, name=args.dataset)
|
dataset = data_manager.init_vidreid_dataset(root=args.root, name=args.dataset)
|
||||||
|
|
||||||
transform_train = T.Compose([
|
transform_train = build_transforms(args.height, args.width, is_train=True)
|
||||||
T.Random2DTranslation(args.height, args.width),
|
transform_test = build_transforms(args.height, args.width, is_train=False)
|
||||||
T.RandomHorizontalFlip(),
|
|
||||||
T.ToTensor(),
|
|
||||||
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
|
||||||
])
|
|
||||||
|
|
||||||
transform_test = T.Compose([
|
|
||||||
T.Resize((args.height, args.width)),
|
|
||||||
T.ToTensor(),
|
|
||||||
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
|
||||||
])
|
|
||||||
|
|
||||||
pin_memory = True if use_gpu else False
|
pin_memory = True if use_gpu else False
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue