Some transform/data/loader refactoring, hopefully didn't break things
* factor out data-related constants to their own file
* move data-related config helpers to their own file
* add a variant of RandomResizedCrop that randomizes the interpolation method
* remove the old Numpy version of RandomErasing
* clean up the torch version of RandomErasing and use it in either GPU loader batch mode or as a single-image CPU Transform
pull/2/head
parent
e3377b0409
commit
76539d905e
|
@ -1,4 +1,5 @@
|
|||
from data.constants import *
|
||||
from data.config import resolve_data_config
|
||||
from data.dataset import Dataset
|
||||
from data.transforms import *
|
||||
from data.loader import create_loader
|
||||
from data.random_erasing import RandomErasingTorch, RandomErasingNumpy
|
|
@ -0,0 +1,101 @@
|
|||
from data.constants import *
|
||||
|
||||
|
||||
def _resolve_norm_stat(arg_value, cfg_key, default_cfg, in_chans, model_default):
    """Pick one normalization statistic (mean or std) by precedence.

    Precedence is: explicit value from ``args``, then the ``default_cfg``
    entry under ``cfg_key``, then ``model_default()``.

    Args:
        arg_value: value taken from ``args`` (an iterable of numbers) or None.
        cfg_key: key to look up in ``default_cfg`` ('mean' or 'std').
        default_cfg: dict of model defaults.
        in_chans: number of channels the statistic must cover.
        model_default: zero-argument callable producing the fallback value;
            it is only invoked when neither override is present.

    Returns:
        The resolved statistic.

    Raises:
        AssertionError: if ``arg_value`` has more than one element but not
            exactly ``in_chans`` elements.
    """
    if arg_value is not None:
        stat = tuple(arg_value)
        if len(stat) == 1:
            # A single value is repeated for every channel.
            stat = stat * in_chans
        else:
            assert len(stat) == in_chans
        return stat
    if cfg_key in default_cfg:
        return default_cfg[cfg_key]
    return model_default()


def resolve_data_config(model, args, default_cfg=None, verbose=True):
    """Build the data pre-processing config for a model.

    Each entry is resolved from, in order of priority, the command line
    ``args``, the model's ``default_cfg`` and a built-in fallback.

    Args:
        model: model object; its ``default_cfg`` attribute (if any) is used
            when ``default_cfg`` is not supplied or is empty.
        args: namespace providing ``img_size``, ``interpolation``, ``mean``,
            ``std`` and ``model`` attributes.
        default_cfg: optional dict of defaults that takes the place of
            ``model.default_cfg``.
        verbose: when true, print the resolved configuration.

    Returns:
        dict with the keys 'input_size', 'interpolation', 'mean', 'std' and
        'crop_pct' (inserted in that order).

    Raises:
        AssertionError: if ``args.img_size`` is set but is not an int, or if
            ``args.mean`` / ``args.std`` has a length other than 1 or 3.
    """
    if not default_cfg:
        # Fall back to the model's own defaults; tolerate a missing/None attr.
        default_cfg = getattr(model, 'default_cfg', None) or {}

    new_config = {}

    # Resolve input/image size
    # FIXME grayscale/chans arg to use different # channels?
    in_chans = 3
    if args.img_size is not None:
        # FIXME support passing img_size as tuple, non-square
        assert isinstance(args.img_size, int)
        input_size = (in_chans, args.img_size, args.img_size)
    else:
        input_size = default_cfg.get('input_size', (in_chans, 224, 224))
    new_config['input_size'] = input_size

    # Resolve interpolation method (an empty/falsy arg means "not set").
    if args.interpolation:
        new_config['interpolation'] = args.interpolation
    else:
        new_config['interpolation'] = default_cfg.get('interpolation', 'bilinear')

    # Resolve dataset + model mean / std deviation for normalization. The
    # model-name based lookup is deferred so it only runs when it is needed.
    new_config['mean'] = _resolve_norm_stat(
        args.mean, 'mean', default_cfg, in_chans,
        lambda: get_mean_by_model(args.model))
    new_config['std'] = _resolve_norm_stat(
        args.std, 'std', default_cfg, in_chans,
        lambda: get_std_by_model(args.model))

    # Resolve default crop percentage
    new_config['crop_pct'] = (
        default_cfg['crop_pct'] if 'crop_pct' in default_cfg else DEFAULT_CROP_PCT)

    if verbose:
        print('Data processing configuration for current model + dataset:')
        for n, v in new_config.items():
            print('\t%s: %s' % (n, str(v)))

    return new_config
|
||||
|
||||
|
||||
def get_mean_by_name(name):
    """Return the normalization mean for a named preprocessing scheme.

    'dpn' selects the DPN mean, 'inception' or 'le' selects the Inception
    mean, and any other name selects the ImageNet default mean.
    """
    if name == 'dpn':
        return IMAGENET_DPN_MEAN
    if name in ('inception', 'le'):
        return IMAGENET_INCEPTION_MEAN
    return IMAGENET_DEFAULT_MEAN
|
||||
|
||||
|
||||
def get_std_by_name(name):
    """Return the normalization std deviation for a named preprocessing scheme.

    'dpn' selects the DPN std, 'inception' or 'le' selects the Inception
    std, and any other name selects the ImageNet default std.
    """
    if name == 'dpn':
        return IMAGENET_DPN_STD
    if name in ('inception', 'le'):
        return IMAGENET_INCEPTION_STD
    return IMAGENET_DEFAULT_STD
|
||||
|
||||
|
||||
def get_mean_by_model(model_name):
    """Return the normalization mean implied by a model name.

    Matching is a case-insensitive substring test on ``model_name``:
    names containing 'dpn' get the DPN mean, names containing 'ception' or
    'nasnet' get the Inception mean, everything else gets the ImageNet
    default mean.

    Args:
        model_name: model identifier string.

    Returns:
        One of the IMAGENET_*_MEAN constants.
    """
    model_name = model_name.lower()
    if 'dpn' in model_name:
        # Previously returned IMAGENET_DPN_STD here, which fed the std
        # values into the mean slot of the normalization.
        return IMAGENET_DPN_MEAN
    # NOTE(review): 'nasnet' also matches names such as 'mnasnet_100' --
    # confirm those models are meant to use the Inception statistics.
    if 'ception' in model_name or 'nasnet' in model_name:
        return IMAGENET_INCEPTION_MEAN
    return IMAGENET_DEFAULT_MEAN
|
||||
|
||||
|
||||
def get_std_by_model(model_name):
    """Return the normalization std deviation implied by a model name.

    Matching is a case-insensitive substring test on ``model_name``:
    names containing 'dpn' get the DPN std, names containing 'ception' or
    'nasnet' get the Inception std, everything else gets the ImageNet
    default std.

    Args:
        model_name: model identifier string.

    Returns:
        One of the IMAGENET_*_STD constants.
    """
    model_name = model_name.lower()
    if 'dpn' in model_name:
        # Previously returned IMAGENET_DEFAULT_STD, which made the DPN
        # branch indistinguishable from the default branch.
        return IMAGENET_DPN_STD
    # NOTE(review): 'nasnet' also matches names such as 'mnasnet_100' --
    # confirm those models are meant to use the Inception statistics.
    if 'ception' in model_name or 'nasnet' in model_name:
        return IMAGENET_INCEPTION_STD
    return IMAGENET_DEFAULT_STD
|
|
@ -0,0 +1,7 @@
|
|||
# Default crop fraction used when a model config does not provide 'crop_pct'.
DEFAULT_CROP_PCT = 0.875

# Per-channel normalization statistics, expressed for inputs scaled to [0, 1].
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)
IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)
# DPN statistics are written in terms of 0-255 pixel values and rescaled here.
IMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255)
IMAGENET_DPN_STD = (1 / (.0167 * 255),) * 3
|
|
@ -1,6 +1,4 @@
|
|||
import torch
|
||||
import torch.utils.data
|
||||
from data.random_erasing import RandomErasingTorch
|
||||
from data.transforms import *
|
||||
from data.distributed_sampler import OrderedDistributedSampler
|
||||
|
||||
|
@ -27,7 +25,7 @@ class PrefetchLoader:
|
|||
self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(1, 3, 1, 1)
|
||||
self.std = torch.tensor([x * 255 for x in std]).cuda().view(1, 3, 1, 1)
|
||||
if rand_erase_prob > 0.:
|
||||
self.random_erasing = RandomErasingTorch(
|
||||
self.random_erasing = RandomErasing(
|
||||
probability=rand_erase_prob, per_pixel=rand_erase_pp)
|
||||
else:
|
||||
self.random_erasing = None
|
||||
|
|
|
@ -2,125 +2,68 @@ from __future__ import absolute_import
|
|||
|
||||
import random
|
||||
import math
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
class RandomErasingNumpy:
|
||||
def _get_patch(per_pixel, rand_color, patch_size, dtype=torch.float32, device='cuda'):
|
||||
if per_pixel:
|
||||
return torch.empty(
|
||||
patch_size, dtype=dtype, device=device).normal_()
|
||||
elif rand_color:
|
||||
return torch.empty((patch_size[0], 1, 1), dtype=dtype, device=device).normal_()
|
||||
else:
|
||||
return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device)
|
||||
|
||||
|
||||
class RandomErasing:
|
||||
""" Randomly selects a rectangle region in an image and erases its pixels.
|
||||
'Random Erasing Data Augmentation' by Zhong et al.
|
||||
See https://arxiv.org/pdf/1708.04896.pdf
|
||||
|
||||
This 'Numpy' variant of RandomErasing is intended to be applied on a per
|
||||
image basis after transforming the image to uint8 numpy array in
|
||||
range 0-255 prior to tensor conversion and normalization
|
||||
This variant of RandomErasing is intended to be applied to either a batch
|
||||
or single image tensor after it has been normalized by dataset mean and std.
|
||||
Args:
|
||||
probability: The probability that the Random Erasing operation will be performed.
|
||||
sl: Minimum proportion of erased area against input image.
|
||||
sh: Maximum proportion of erased area against input image.
|
||||
r1: Minimum aspect ratio of erased area.
|
||||
mean: Erasing value.
|
||||
min_aspect: Minimum aspect ratio of erased area.
|
||||
per_pixel: random value for each pixel in the erase region, precedence over rand_color
|
||||
rand_color: random color for whole erase region, 0 if neither this or per_pixel set
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
probability=0.5, sl=0.02, sh=1/3, min_aspect=0.3,
|
||||
per_pixel=False, rand_color=False,
|
||||
pl=0, ph=255, mean=[255 * 0.485, 255 * 0.456, 255 * 0.406],
|
||||
out_type=np.uint8):
|
||||
per_pixel=False, rand_color=False, device='cuda'):
|
||||
self.probability = probability
|
||||
if not per_pixel and not rand_color:
|
||||
self.mean = np.array(mean).round().astype(out_type)
|
||||
else:
|
||||
self.mean = None
|
||||
self.sl = sl
|
||||
self.sh = sh
|
||||
self.min_aspect = min_aspect
|
||||
self.pl = pl
|
||||
self.ph = ph
|
||||
self.per_pixel = per_pixel # per pixel random, bounded by [pl, ph]
|
||||
self.rand_color = rand_color # per block random, bounded by [pl, ph]
|
||||
self.out_type = out_type
|
||||
self.device = device
|
||||
|
||||
def __call__(self, img):
|
||||
def _erase(self, img, chan, img_h, img_w, dtype):
|
||||
if random.random() > self.probability:
|
||||
return img
|
||||
|
||||
chan, img_h, img_w = img.shape
|
||||
return
|
||||
area = img_h * img_w
|
||||
for attempt in range(100):
|
||||
target_area = random.uniform(self.sl, self.sh) * area
|
||||
aspect_ratio = random.uniform(self.min_aspect, 1 / self.min_aspect)
|
||||
|
||||
h = int(round(math.sqrt(target_area * aspect_ratio)))
|
||||
w = int(round(math.sqrt(target_area / aspect_ratio)))
|
||||
if self.rand_color:
|
||||
c = np.random.randint(self.pl, self.ph + 1, (chan,), self.out_type)
|
||||
elif not self.per_pixel:
|
||||
c = self.mean[:chan]
|
||||
if w < img_w and h < img_h:
|
||||
top = random.randint(0, img_h - h)
|
||||
left = random.randint(0, img_w - w)
|
||||
if self.per_pixel:
|
||||
img[:, top:top + h, left:left + w] = np.random.randint(
|
||||
self.pl, self.ph + 1, (chan, h, w), self.out_type)
|
||||
else:
|
||||
img[:, top:top + h, left:left + w] = c
|
||||
return img
|
||||
img[:, top:top + h, left:left + w] = _get_patch(
|
||||
self.per_pixel, self.rand_color, (chan, h, w), dtype=dtype, device=self.device)
|
||||
break
|
||||
|
||||
return img
|
||||
|
||||
|
||||
class RandomErasingTorch:
|
||||
""" Randomly selects a rectangle region in an image and erases its pixels.
|
||||
'Random Erasing Data Augmentation' by Zhong et al.
|
||||
See https://arxiv.org/pdf/1708.04896.pdf
|
||||
|
||||
This 'Torch' variant of RandomErasing is intended to be applied to a full batch
|
||||
tensor after it has been normalized by dataset mean and std.
|
||||
Args:
|
||||
probability: The probability that the Random Erasing operation will be performed.
|
||||
sl: Minimum proportion of erased area against input image.
|
||||
sh: Maximum proportion of erased area against input image.
|
||||
r1: Minimum aspect ratio of erased area.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
probability=0.5, sl=0.02, sh=1/3, min_aspect=0.3,
|
||||
per_pixel=False, rand_color=False):
|
||||
self.probability = probability
|
||||
self.sl = sl
|
||||
self.sh = sh
|
||||
self.min_aspect = min_aspect
|
||||
self.per_pixel = per_pixel # per pixel random, bounded by [pl, ph]
|
||||
self.rand_color = rand_color # per block random, bounded by [pl, ph]
|
||||
|
||||
def __call__(self, batch):
|
||||
batch_size, chan, img_h, img_w = batch.size()
|
||||
area = img_h * img_w
|
||||
for i in range(batch_size):
|
||||
if random.random() > self.probability:
|
||||
continue
|
||||
img = batch[i]
|
||||
for attempt in range(100):
|
||||
target_area = random.uniform(self.sl, self.sh) * area
|
||||
aspect_ratio = random.uniform(self.min_aspect, 1 / self.min_aspect)
|
||||
|
||||
h = int(round(math.sqrt(target_area * aspect_ratio)))
|
||||
w = int(round(math.sqrt(target_area / aspect_ratio)))
|
||||
if self.rand_color:
|
||||
c = torch.empty((chan, 1, 1), dtype=batch.dtype).normal_().cuda()
|
||||
elif not self.per_pixel:
|
||||
c = torch.zeros((chan, 1, 1), dtype=batch.dtype).cuda()
|
||||
if w < img_w and h < img_h:
|
||||
top = random.randint(0, img_h - h)
|
||||
left = random.randint(0, img_w - w)
|
||||
if self.per_pixel:
|
||||
img[:, top:top + h, left:left + w] = torch.empty(
|
||||
(chan, h, w), dtype=batch.dtype).normal_().cuda()
|
||||
else:
|
||||
img[:, top:top + h, left:left + w] = c
|
||||
break
|
||||
|
||||
return batch
|
||||
def __call__(self, input):
|
||||
if len(input.size()) == 3:
|
||||
self._erase(input, *input.size(), input.dtype)
|
||||
else:
|
||||
batch_size, chan, img_h, img_w = input.size()
|
||||
for i in range(batch_size):
|
||||
self._erase(input[i], chan, img_h, img_w, input.dtype)
|
||||
return input
|
||||
|
|
|
@ -1,118 +1,14 @@
|
|||
import torch
|
||||
from torchvision import transforms
|
||||
import torchvision.transforms.functional as F
|
||||
from PIL import Image
|
||||
import warnings
|
||||
import math
|
||||
import random
|
||||
import numpy as np
|
||||
from data.random_erasing import RandomErasingNumpy
|
||||
|
||||
DEFAULT_CROP_PCT = 0.875
|
||||
|
||||
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
|
||||
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
|
||||
IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)
|
||||
IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)
|
||||
IMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255)
|
||||
IMAGENET_DPN_STD = tuple([1 / (.0167 * 255)] * 3)
|
||||
|
||||
|
||||
def resolve_data_config(model, args, default_cfg={}, verbose=True):
|
||||
new_config = {}
|
||||
default_cfg = default_cfg
|
||||
if not default_cfg and hasattr(model, 'default_cfg'):
|
||||
default_cfg = model.default_cfg
|
||||
|
||||
# Resolve input/image size
|
||||
# FIXME grayscale/chans arg to use different # channels?
|
||||
in_chans = 3
|
||||
input_size = (in_chans, 224, 224)
|
||||
if args.img_size is not None:
|
||||
# FIXME support passing img_size as tuple, non-square
|
||||
assert isinstance(args.img_size, int)
|
||||
input_size = (in_chans, args.img_size, args.img_size)
|
||||
elif 'input_size' in default_cfg:
|
||||
input_size = default_cfg['input_size']
|
||||
new_config['input_size'] = input_size
|
||||
|
||||
# resolve interpolation method
|
||||
new_config['interpolation'] = 'bilinear'
|
||||
if args.interpolation:
|
||||
new_config['interpolation'] = args.interpolation
|
||||
elif 'interpolation' in default_cfg:
|
||||
new_config['interpolation'] = default_cfg['interpolation']
|
||||
|
||||
# resolve dataset + model mean for normalization
|
||||
new_config['mean'] = get_mean_by_model(args.model)
|
||||
if args.mean is not None:
|
||||
mean = tuple(args.mean)
|
||||
if len(mean) == 1:
|
||||
mean = tuple(list(mean) * in_chans)
|
||||
else:
|
||||
assert len(mean) == in_chans
|
||||
new_config['mean'] = mean
|
||||
elif 'mean' in default_cfg:
|
||||
new_config['mean'] = default_cfg['mean']
|
||||
|
||||
# resolve dataset + model std deviation for normalization
|
||||
new_config['std'] = get_std_by_model(args.model)
|
||||
if args.std is not None:
|
||||
std = tuple(args.std)
|
||||
if len(std) == 1:
|
||||
std = tuple(list(std) * in_chans)
|
||||
else:
|
||||
assert len(std) == in_chans
|
||||
new_config['std'] = std
|
||||
elif 'std' in default_cfg:
|
||||
new_config['std'] = default_cfg['std']
|
||||
|
||||
# resolve default crop percentage
|
||||
new_config['crop_pct'] = DEFAULT_CROP_PCT
|
||||
if 'crop_pct' in default_cfg:
|
||||
new_config['crop_pct'] = default_cfg['crop_pct']
|
||||
|
||||
if verbose:
|
||||
print('Data processing configuration for current model + dataset:')
|
||||
for n, v in new_config.items():
|
||||
print('\t%s: %s' % (n, str(v)))
|
||||
|
||||
return new_config
|
||||
|
||||
|
||||
def get_mean_by_name(name):
|
||||
if name == 'dpn':
|
||||
return IMAGENET_DPN_MEAN
|
||||
elif name == 'inception' or name == 'le':
|
||||
return IMAGENET_INCEPTION_MEAN
|
||||
else:
|
||||
return IMAGENET_DEFAULT_MEAN
|
||||
|
||||
|
||||
def get_std_by_name(name):
|
||||
if name == 'dpn':
|
||||
return IMAGENET_DPN_STD
|
||||
elif name == 'inception' or name == 'le':
|
||||
return IMAGENET_INCEPTION_STD
|
||||
else:
|
||||
return IMAGENET_DEFAULT_STD
|
||||
|
||||
|
||||
def get_mean_by_model(model_name):
|
||||
model_name = model_name.lower()
|
||||
if 'dpn' in model_name:
|
||||
return IMAGENET_DPN_STD
|
||||
elif 'ception' in model_name or 'nasnet' in model_name:
|
||||
return IMAGENET_INCEPTION_MEAN
|
||||
else:
|
||||
return IMAGENET_DEFAULT_MEAN
|
||||
|
||||
|
||||
def get_std_by_model(model_name):
|
||||
model_name = model_name.lower()
|
||||
if 'dpn' in model_name:
|
||||
return IMAGENET_DEFAULT_STD
|
||||
elif 'ception' in model_name or 'nasnet' in model_name:
|
||||
return IMAGENET_INCEPTION_STD
|
||||
else:
|
||||
return IMAGENET_DEFAULT_STD
|
||||
from data import DEFAULT_CROP_PCT, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from data.random_erasing import RandomErasing
|
||||
|
||||
|
||||
class ToNumpy:
|
||||
|
@ -138,6 +34,16 @@ class ToTensor:
|
|||
return torch.from_numpy(np_img).to(dtype=self.dtype)
|
||||
|
||||
|
||||
_pil_interpolation_to_str = {
|
||||
Image.NEAREST: 'PIL.Image.NEAREST',
|
||||
Image.BILINEAR: 'PIL.Image.BILINEAR',
|
||||
Image.BICUBIC: 'PIL.Image.BICUBIC',
|
||||
Image.LANCZOS: 'PIL.Image.LANCZOS',
|
||||
Image.HAMMING: 'PIL.Image.HAMMING',
|
||||
Image.BOX: 'PIL.Image.BOX',
|
||||
}
|
||||
|
||||
|
||||
def _pil_interp(method):
|
||||
if method == 'bicubic':
|
||||
return Image.BICUBIC
|
||||
|
@ -150,21 +56,118 @@ def _pil_interp(method):
|
|||
return Image.BILINEAR
|
||||
|
||||
|
||||
RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC)
|
||||
|
||||
|
||||
class RandomResizedCropAndInterpolation(object):
    """Crop the given PIL Image to random size and aspect ratio with random interpolation.

    A crop of random size (default: of 0.08 to 1.0) of the original size and a random
    aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop
    is finally resized to given size.
    This is popularly used to train the Inception networks.

    Args:
        size: expected output size of each edge; an int gives a square
            output, a tuple/list is used as given
        scale: range of size of the origin size cropped
        ratio: range of aspect ratio of the origin aspect ratio cropped
        interpolation: interpolation method name passed to ``_pil_interp``,
            or 'random' to pick one of ``RANDOM_INTERPOLATION`` on every
            call. Default: 'bilinear'
    """

    def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.),
                 interpolation='bilinear'):
        if isinstance(size, (tuple, list)):
            # Lists used to be wrapped as (list, list); treat them like tuples.
            self.size = tuple(size)
        else:
            self.size = (size, size)
        if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
            warnings.warn("range should be of kind (min, max)")

        if interpolation == 'random':
            self.interpolation = RANDOM_INTERPOLATION
        else:
            self.interpolation = _pil_interp(interpolation)
        self.scale = scale
        self.ratio = ratio

    @staticmethod
    def get_params(img, scale, ratio):
        """Get parameters for ``crop`` for a random sized crop.

        Args:
            img (PIL Image): Image to be cropped.
            scale (tuple): range of size of the origin size cropped
            ratio (tuple): range of aspect ratio of the origin aspect ratio cropped

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for a random
                sized crop.
        """
        img_w, img_h = img.size[0], img.size[1]
        area = img_w * img_h

        for attempt in range(10):
            target_area = random.uniform(*scale) * area
            aspect_ratio = random.uniform(*ratio)

            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))

            # On very small images w can round down to 0; guard the division.
            if w > 0 and random.random() < 0.5 and min(ratio) <= (h / w) <= max(ratio):
                w, h = h, w

            # Reject degenerate (zero-sized) crops as well as oversized ones.
            if 0 < w <= img_w and 0 < h <= img_h:
                i = random.randint(0, img_h - h)
                j = random.randint(0, img_w - w)
                return i, j, h, w

        # Fallback: centered square crop on the shorter side.
        w = min(img_w, img_h)
        i = (img_h - w) // 2
        j = (img_w - w) // 2
        return i, j, w, w

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Image to be cropped and resized.

        Returns:
            PIL Image: Randomly cropped and resized image.
        """
        i, j, h, w = self.get_params(img, self.scale, self.ratio)
        if isinstance(self.interpolation, (tuple, list)):
            # 'random' mode: draw a new interpolation method for every image.
            interpolation = random.choice(self.interpolation)
        else:
            interpolation = self.interpolation
        return F.resized_crop(img, i, j, h, w, self.size, interpolation)

    def __repr__(self):
        if isinstance(self.interpolation, (tuple, list)):
            interpolate_str = ' '.join([_pil_interpolation_to_str[x] for x in self.interpolation])
        else:
            interpolate_str = _pil_interpolation_to_str[self.interpolation]
        format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
        format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale))
        format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio))
        format_string += ', interpolation={0})'.format(interpolate_str)
        return format_string
|
||||
|
||||
|
||||
def transforms_imagenet_train(
|
||||
img_size=224,
|
||||
scale=(0.1, 1.0),
|
||||
scale=(0.08, 1.0),
|
||||
color_jitter=(0.4, 0.4, 0.4),
|
||||
interpolation='bilinear',
|
||||
interpolation='random',
|
||||
random_erasing=0.4,
|
||||
random_erasing_pp=True,
|
||||
use_prefetcher=False,
|
||||
mean=IMAGENET_DEFAULT_MEAN,
|
||||
std=IMAGENET_DEFAULT_STD
|
||||
):
|
||||
|
||||
tfl = [
|
||||
transforms.RandomResizedCrop(
|
||||
img_size, scale=scale,
|
||||
interpolation=_pil_interp(interpolation)),
|
||||
RandomResizedCropAndInterpolation(
|
||||
img_size, scale=scale, interpolation=interpolation),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
transforms.ColorJitter(*color_jitter),
|
||||
]
|
||||
|
@ -174,13 +177,13 @@ def transforms_imagenet_train(
|
|||
tfl += [ToNumpy()]
|
||||
else:
|
||||
tfl += [
|
||||
ToTensor(),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(
|
||||
mean=torch.tensor(mean),
|
||||
std=torch.tensor(std))
|
||||
]
|
||||
if random_erasing > 0.:
|
||||
tfl.append(RandomErasingNumpy(random_erasing, per_pixel=True))
|
||||
tfl.append(RandomErasing(random_erasing, per_pixel=random_erasing_pp, device='cpu'))
|
||||
return transforms.Compose(tfl)
|
||||
|
||||
|
||||
|
|
|
@ -9,7 +9,7 @@ from collections import OrderedDict
|
|||
|
||||
from models.helpers import load_pretrained
|
||||
from models.adaptive_avgmax_pool import *
|
||||
from data.transforms import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
import re
|
||||
|
||||
__all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161']
|
||||
|
|
|
@ -17,7 +17,7 @@ from collections import OrderedDict
|
|||
|
||||
from models.helpers import load_pretrained
|
||||
from models.adaptive_avgmax_pool import select_adaptive_pool2d
|
||||
from data.transforms import IMAGENET_DPN_MEAN, IMAGENET_DPN_STD
|
||||
from data import IMAGENET_DPN_MEAN, IMAGENET_DPN_STD
|
||||
|
||||
__all__ = ['DPN', 'dpn68', 'dpn92', 'dpn98', 'dpn131', 'dpn107']
|
||||
|
||||
|
|
|
@ -24,7 +24,7 @@ import torch.nn.functional as F
|
|||
from models.helpers import load_pretrained
|
||||
from models.adaptive_avgmax_pool import SelectAdaptivePool2d
|
||||
from models.conv2d_same import sconv2d
|
||||
from data.transforms import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
|
||||
__all__ = ['GenMobileNet', 'mnasnet_050', 'mnasnet_075', 'mnasnet_100', 'mnasnet_140',
|
||||
'semnasnet_050', 'semnasnet_075', 'semnasnet_100', 'semnasnet_140', 'mnasnet_small',
|
||||
|
|
|
@ -7,8 +7,7 @@ import torch.nn as nn
|
|||
import torch.nn.functional as F
|
||||
from models.helpers import load_pretrained
|
||||
from models.adaptive_avgmax_pool import *
|
||||
from data.transforms import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
||||
|
||||
from data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
||||
|
||||
default_cfgs = {
|
||||
'inception_resnet_v2': {
|
||||
|
|
|
@ -7,8 +7,7 @@ import torch.nn as nn
|
|||
import torch.nn.functional as F
|
||||
from models.helpers import load_pretrained
|
||||
from models.adaptive_avgmax_pool import *
|
||||
from data.transforms import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
||||
|
||||
from data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
||||
|
||||
default_cfgs = {
|
||||
'inception_v4': {
|
||||
|
|
|
@ -10,7 +10,7 @@ import torch.nn.functional as F
|
|||
import math
|
||||
from models.helpers import load_pretrained
|
||||
from models.adaptive_avgmax_pool import SelectAdaptivePool2d
|
||||
from data.transforms import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
|
||||
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152',
|
||||
'resnext50_32x4d', 'resnext101_32x4d', 'resnext101_64x4d', 'resnext152_32x4d']
|
||||
|
|
|
@ -17,7 +17,7 @@ import torch.nn.functional as F
|
|||
|
||||
from models.helpers import load_pretrained
|
||||
from models.adaptive_avgmax_pool import SelectAdaptivePool2d
|
||||
from data.transforms import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
|
||||
__all__ = ['SENet', 'senet154', 'seresnet50', 'seresnet101', 'seresnet152',
|
||||
'seresnext50_32x4d', 'seresnext101_32x4d']
|
||||
|
|
4
train.py
4
train.py
|
@ -10,7 +10,7 @@ try:
|
|||
except ImportError:
|
||||
has_apex = False
|
||||
|
||||
from data import *
|
||||
from data import Dataset, create_loader, resolve_data_config
|
||||
from models import create_model, resume_checkpoint
|
||||
from utils import *
|
||||
from loss import LabelSmoothingCrossEntropy, SparseLabelCrossEntropy
|
||||
|
@ -224,7 +224,7 @@ def main():
|
|||
use_prefetcher=True,
|
||||
rand_erase_prob=args.reprob,
|
||||
rand_erase_pp=args.repp,
|
||||
interpolation=data_config['interpolation'],
|
||||
interpolation='random', # FIXME cleanly resolve this? data_config['interpolation'],
|
||||
mean=data_config['mean'],
|
||||
std=data_config['std'],
|
||||
num_workers=args.workers,
|
||||
|
|
Loading…
Reference in New Issue