https://github.com/huggingface/pytorch-image-models.git
Add per model crop pct, interpolation defaults, tie it all together
* create one resolve fn to pull together model defaults + cmd line args
* update attribution comments in some models
* test update train/validation/inference scripts
parent c328b155e9
commit 0562b91c38
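In short: each model's default_cfg now carries its preferred preprocessing (input_size, interpolation, mean/std, crop_pct), and a single resolve_data_config() merges those defaults with any command-line overrides before the loader is built. A minimal sketch of the intended flow (the `args` namespace, model name, and data path are illustrative; function and key names are taken from the diffs below):

    from models import create_model
    from data import Dataset, create_loader, resolve_data_config

    model = create_model('dpn92', num_classes=1000, pretrained=True)

    # args is the parsed command-line namespace; any of --img-size/--mean/--std/
    # --interpolation that were set take precedence over the model's default_cfg.
    config = resolve_data_config(model, args)

    loader = create_loader(
        Dataset('/path/to/imagenet/validation'),   # hypothetical data path
        input_size=config['input_size'],
        batch_size=64,
        interpolation=config['interpolation'],
        mean=config['mean'],
        std=config['std'],
        crop_pct=config['crop_pct'])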
data/loader.py

@@ -23,20 +23,20 @@ class PrefetchLoader:
                  mean=IMAGENET_DEFAULT_MEAN,
                  std=IMAGENET_DEFAULT_STD):
         self.loader = loader
-        self.stream = torch.cuda.Stream()
         self.mean = torch.tensor([x * 255 for x in mean]).cuda().view(1, 3, 1, 1)
         self.std = torch.tensor([x * 255 for x in std]).cuda().view(1, 3, 1, 1)
-        if rand_erase_prob:
+        if rand_erase_prob > 0.:
             self.random_erasing = RandomErasingTorch(
                 probability=rand_erase_prob, per_pixel=rand_erase_pp)
         else:
             self.random_erasing = None

     def __iter__(self):
+        stream = torch.cuda.Stream()
         first = True

         for next_input, next_target in self.loader:
-            with torch.cuda.stream(self.stream):
+            with torch.cuda.stream(stream):
                 next_input = next_input.cuda(non_blocking=True)
                 next_target = next_target.cuda(non_blocking=True)
                 next_input = next_input.float().sub_(self.mean).div_(self.std)
@@ -48,7 +48,7 @@ class PrefetchLoader:
             else:
                 first = False

-            torch.cuda.current_stream().wait_stream(self.stream)
+            torch.cuda.current_stream().wait_stream(stream)
             input = next_input
             target = next_target

@@ -64,28 +64,35 @@ class PrefetchLoader:


 def create_loader(
         dataset,
-        img_size,
+        input_size,
         batch_size,
         is_training=False,
         use_prefetcher=True,
         rand_erase_prob=0.,
         rand_erase_pp=False,
+        interpolation='bilinear',
         mean=IMAGENET_DEFAULT_MEAN,
         std=IMAGENET_DEFAULT_STD,
         num_workers=1,
         distributed=False,
         crop_pct=None,
 ):
+    if isinstance(input_size, tuple):
+        img_size = input_size[-2:]
+    else:
+        img_size = input_size

     if is_training:
         transform = transforms_imagenet_train(
             img_size,
+            interpolation=interpolation,
             use_prefetcher=use_prefetcher,
             mean=mean,
             std=std)
     else:
         transform = transforms_imagenet_eval(
             img_size,
+            interpolation=interpolation,
             use_prefetcher=use_prefetcher,
             mean=mean,
             std=std,
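Note the CUDA stream is now created inside __iter__ rather than stored on the loader, so each fresh iteration gets its own side stream. For reference, a stripped-down sketch of the pattern the class implements (illustrative only, not part of the diff): stage batch N+1's host-to-device copy on a side stream while batch N is consumed, then have the default stream wait on the side stream before swapping buffers:

    import torch

    def prefetched(loader, mean, std):
        stream = torch.cuda.Stream()   # one stream per iterator, as in the change above
        first = True
        input = target = None
        for next_input, next_target in loader:
            with torch.cuda.stream(stream):
                next_input = next_input.cuda(non_blocking=True)
                next_target = next_target.cuda(non_blocking=True)
                next_input = next_input.float().sub_(mean).div_(std)
            if not first:
                yield input, target
            else:
                first = False
            # block the default stream until the staged copies have finished
            torch.cuda.current_stream().wait_stream(stream)
            input, target = next_input, next_target
        if not first:
            yield input, target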
data/transforms.py

@@ -15,28 +15,66 @@ IMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255)
 IMAGENET_DPN_STD = tuple([1 / (.0167 * 255)] * 3)


-def get_mean_and_std(model, args, num_chan=3):
-    if hasattr(model, 'default_cfg'):
-        mean = model.default_cfg['mean']
-        std = model.default_cfg['std']
-    else:
-        if args.mean is not None:
-            mean = tuple(args.mean)
-            if len(mean) == 1:
-                mean = tuple(list(mean) * num_chan)
-            else:
-                assert len(mean) == num_chan
-        else:
-            mean = get_mean_by_model(args.model)
-        if args.std is not None:
-            std = tuple(args.std)
-            if len(std) == 1:
-                std = tuple(list(std) * num_chan)
-            else:
-                assert len(std) == num_chan
-        else:
-            std = get_std_by_model(args.model)
-    return mean, std
+def resolve_data_config(model, args, default_cfg={}, verbose=True):
+    new_config = {}
+    default_cfg = default_cfg
+    if not default_cfg and hasattr(model, 'default_cfg'):
+        default_cfg = model.default_cfg
+
+    # Resolve input/image size
+    # FIXME grayscale/chans arg to use different # channels?
+    in_chans = 3
+    input_size = (in_chans, 224, 224)
+    if args.img_size is not None:
+        # FIXME support passing img_size as tuple, non-square
+        assert isinstance(args.img_size, int)
+        input_size = (in_chans, args.img_size, args.img_size)
+    elif 'input_size' in default_cfg:
+        input_size = default_cfg['input_size']
+    new_config['input_size'] = input_size
+
+    # resolve interpolation method
+    new_config['interpolation'] = 'bilinear'
+    if args.interpolation:
+        new_config['interpolation'] = args.interpolation
+    elif 'interpolation' in default_cfg:
+        new_config['interpolation'] = default_cfg['interpolation']
+
+    # resolve dataset + model mean for normalization
+    new_config['mean'] = get_mean_by_model(args.model)
+    if args.mean is not None:
+        mean = tuple(args.mean)
+        if len(mean) == 1:
+            mean = tuple(list(mean) * in_chans)
+        else:
+            assert len(mean) == in_chans
+        new_config['mean'] = mean
+    elif 'mean' in default_cfg:
+        new_config['mean'] = default_cfg['mean']
+
+    # resolve dataset + model std deviation for normalization
+    new_config['std'] = get_std_by_model(args.model)
+    if args.std is not None:
+        std = tuple(args.std)
+        if len(std) == 1:
+            std = tuple(list(std) * in_chans)
+        else:
+            assert len(std) == in_chans
+        new_config['std'] = std
+    else:
+        new_config['std'] = default_cfg['std']
+
+    # resolve default crop percentage
+    new_config['crop_pct'] = DEFAULT_CROP_PCT
+    if 'crop_pct' in default_cfg:
+        new_config['crop_pct'] = default_cfg['crop_pct']
+
+    if verbose:
+        print('Data processing configuration for current model + dataset:')
+        for n, v in new_config.items():
+            print('\t%s: %s' % (n, str(v)))
+
+    return new_config


 def get_mean_by_name(name):
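The precedence is: explicit command-line value first, then the model's default_cfg, then a library-wide fallback (bilinear interpolation, DEFAULT_CROP_PCT, dataset mean/std looked up by model name). A hypothetical call illustrating this (namespace fields mirror the new CLI options):

    from types import SimpleNamespace

    # No overrides given: every field falls through to the model's default_cfg.
    args = SimpleNamespace(model='dpn92', img_size=None,
                           mean=None, std=None, interpolation='')
    config = resolve_data_config(model, args)   # model as created earlier
    # For a DPN cfg as defined below, this would report bicubic interpolation,
    # the DPN mean/std, and crop_pct 0.875; setting args.interpolation='bilinear'
    # (i.e. passing --interpolation bilinear) would win over the cfg instead.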
@@ -104,6 +142,7 @@ def transforms_imagenet_train(
         img_size=224,
         scale=(0.1, 1.0),
         color_jitter=(0.4, 0.4, 0.4),
+        interpolation='bilinear',
         random_erasing=0.4,
         use_prefetcher=False,
         mean=IMAGENET_DEFAULT_MEAN,
@@ -112,7 +151,8 @@ def transforms_imagenet_train(

     tfl = [
         transforms.RandomResizedCrop(
-            img_size, scale=scale, interpolation=Image.BICUBIC),
+            img_size, scale=scale,
+            interpolation=Image.BILINEAR if interpolation == 'bilinear' else Image.BICUBIC),
         transforms.RandomHorizontalFlip(),
         transforms.ColorJitter(*color_jitter),
     ]
@@ -135,14 +175,24 @@ def transforms_imagenet_train(
 def transforms_imagenet_eval(
         img_size=224,
         crop_pct=None,
+        interpolation='bilinear',
         use_prefetcher=False,
         mean=IMAGENET_DEFAULT_MEAN,
         std=IMAGENET_DEFAULT_STD):
     crop_pct = crop_pct or DEFAULT_CROP_PCT
-    scale_size = int(math.floor(img_size / crop_pct))
+
+    if isinstance(img_size, tuple):
+        assert len(img_size) == 2
+        if img_size[0] == img_size[1]:
+            # fall-back to older behaviour so Resize scales to shortest edge if target is square
+            scale_size = int(math.floor(img_size[0] / crop_pct))
+        else:
+            scale_size = tuple([int(x[0] / crop_pct) for x in img_size])
+    else:
+        scale_size = int(math.floor(img_size / crop_pct))

     tfl = [
-        transforms.Resize(scale_size, Image.BICUBIC),
+        transforms.Resize(scale_size, Image.BILINEAR if interpolation == 'bilinear' else Image.BICUBIC),
         transforms.CenterCrop(img_size),
     ]
     if use_prefetcher:
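A worked example of the crop_pct arithmetic used here: with the conventional crop_pct of 0.875 and a 224 target, the shorter image edge is resized to floor(224 / 0.875) = 256 before the 224 center crop; the Inception-family cfgs below use 0.8975, giving floor(299 / 0.8975) = 333 for a 299 crop, which matches the comment retained in the Xception cfg:

    import math
    math.floor(224 / 0.875)    # 256: Resize shorter edge, then CenterCrop(224)
    math.floor(299 / 0.8975)   # 333: Resize, then CenterCrop(299)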
inference.py
@@ -12,7 +12,7 @@ import numpy as np
 import torch

 from models import create_model, apply_test_time_pool
-from data import Dataset, create_loader, get_mean_and_std
+from data import Dataset, create_loader, resolve_data_config
 from utils import AverageMeter

 torch.backends.cudnn.benchmark = True
@@ -30,6 +30,12 @@ parser.add_argument('-b', '--batch-size', default=256, type=int,
                     metavar='N', help='mini-batch size (default: 256)')
 parser.add_argument('--img-size', default=224, type=int,
                     metavar='N', help='Input image dimension')
+parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
+                    help='Override mean pixel value of dataset')
+parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
+                    help='Override std deviation of dataset')
+parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
+                    help='Image resize interpolation type (overrides model)')
 parser.add_argument('--num-classes', type=int, default=1000,
                     help='Number classes in dataset')
 parser.add_argument('--print-freq', '-p', default=10, type=int,
@@ -40,8 +46,8 @@ parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                     help='use pre-trained model')
 parser.add_argument('--num-gpu', type=int, default=1,
                     help='Number of GPUS to use')
-parser.add_argument('--no-test-pool', dest='test_time_pool', action='store_false',
-                    help='use pre-trained model')
+parser.add_argument('--no-test-pool', dest='no_test_pool', action='store_true',
+                    help='disable test time pool')


 def main():
@@ -58,8 +64,8 @@ def main():
     print('Model %s created, param count: %d' %
           (args.model, sum([m.numel() for m in model.parameters()])))

-    data_mean, data_std = get_mean_and_std(model, args)
-    model, test_time_pool = apply_test_time_pool(model, args)
+    config = resolve_data_config(model, args)
+    model, test_time_pool = apply_test_time_pool(model, config, args)

     if args.num_gpu > 1:
         model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
@@ -68,12 +74,14 @@ def main():

     loader = create_loader(
         Dataset(args.data),
-        img_size=args.img_size,
+        input_size=config['input_size'],
         batch_size=args.batch_size,
         use_prefetcher=True,
-        mean=data_mean,
-        std=data_std,
-        num_workers=args.workers)
+        interpolation=config['interpolation'],
+        mean=config['mean'],
+        std=config['std'],
+        num_workers=args.workers,
+        crop_pct=1.0 if test_time_pool else config['crop_pct'])

     model.eval()

models/densenet.py

@@ -1,4 +1,4 @@
-"""Pytorch Densenet implementation tweaks
+"""Pytorch Densenet implementation w/ tweaks
 This file is a copy of https://github.com/pytorch/vision 'densenet.py' (BSD-3-Clause) with
 fixed kwargs passthrough and addition of dynamic global avg/max pool.
 """
@@ -18,6 +18,7 @@ __all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161']
 def _cfg(url=''):
     return {
         'url': url, 'num_classes': 1000, 'input_size': (3, 224, 244), 'pool_size': (7, 7),
+        'crop_pct': 0.875, 'interpolation': 'bilinear',
         'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
         'first_conv': 'features.conv0', 'classifier': 'classifier',
     }
models/dpn.py

@@ -25,6 +25,7 @@ __all__ = ['DPN', 'dpn68', 'dpn92', 'dpn98', 'dpn131', 'dpn107']
 def _cfg(url=''):
     return {
         'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
+        'crop_pct': 0.875, 'interpolation': 'bicubic',
         'mean': IMAGENET_DPN_MEAN, 'std': IMAGENET_DPN_STD,
         'first_conv': 'features.conv1_1.conv', 'classifier': 'classifier',
     }
models/helpers.py

@@ -26,7 +26,6 @@ def load_checkpoint(model, checkpoint_path):


 def resume_checkpoint(model, checkpoint_path, start_epoch=None):
-    start_epoch = 0 if start_epoch is None else start_epoch
     optimizer_state = None
     if os.path.isfile(checkpoint_path):
         print("=> loading checkpoint '{}'".format(checkpoint_path))
@@ -46,6 +45,7 @@ def resume_checkpoint(model, checkpoint_path, start_epoch=None):
             start_epoch = checkpoint['epoch'] if start_epoch is None else start_epoch
         else:
             model.load_state_dict(checkpoint)
+            start_epoch = 0 if start_epoch is None else start_epoch
         return optimizer_state, start_epoch
     else:
         print("=> No checkpoint found at '{}'".format(checkpoint_path))
models/inception_resnet_v2.py

@@ -14,6 +14,7 @@ default_cfgs = {
     'inception_resnet_v2': {
         'url': 'http://data.lip6.fr/cadene/pretrainedmodels/inceptionresnetv2-520b38e4.pth',
         'num_classes': 1001, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
+        'crop_pct': 0.8975, 'interpolation': 'bicubic',
         'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
         'first_conv': 'conv2d_1a.conv', 'classifier': 'last_linear',
     }
models/inception_v4.py

@@ -14,6 +14,7 @@ default_cfgs = {
     'inception_v4': {
         'url': 'http://webia.lip6.fr/~cadene/Downloads/inceptionv4-97ef9c30.pth',
         'num_classes': 1001, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
+        'crop_pct': 0.875, 'interpolation': 'bicubic',
         'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
         'first_conv': 'features.0.conv', 'classifier': 'classif',
     }
models/pnasnet.py

@@ -1,3 +1,10 @@
+"""
+ pnasnet5large implementation grabbed from Cadene's pretrained models
+ Additional credit to https://github.com/creafz
+
+ https://github.com/Cadene/pretrained-models.pytorch/blob/master/pretrainedmodels/models/pnasnet.py
+
+"""
 from __future__ import print_function, division, absolute_import
 from collections import OrderedDict

@@ -13,9 +20,10 @@ default_cfgs = {
         'url': 'http://data.lip6.fr/cadene/pretrainedmodels/pnasnet5large-bf079911.pth',
         'input_size': (3, 331, 331),
         'pool_size': (11, 11),
+        'crop_pct': 0.875,
+        'interpolation': 'bicubic',
         'mean': (0.5, 0.5, 0.5),
         'std': (0.5, 0.5, 0.5),
-        'crop_pct': 0.8975,
         'num_classes': 1001,
         'first_conv': 'conv_0.conv',
         'classifier': 'last_linear',
models/resnet.py

@@ -1,6 +1,8 @@
-"""Pytorch ResNet implementation tweaks
+"""Pytorch ResNet implementation w/ tweaks
 This file is a copy of https://github.com/pytorch/vision 'resnet.py' (BSD-3-Clause) with
 additional dropout and dynamic global avg/max pool.
+
+ResNext additions added by Ross Wightman
 """
 import torch
 import torch.nn as nn
@@ -18,7 +20,8 @@ def _cfg(url=''):
     return {
         'url': url,
         'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
-        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'crop_pct': 0.875,
+        'crop_pct': 0.875, 'interpolation': 'bilinear',
+        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
         'first_conv': 'conv1', 'classifier': 'fc',
     }

@@ -271,7 +274,7 @@ def resnet152(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
 def resnext50_32x4d(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
     """Constructs a ResNeXt50-32x4d model.
     """
-    default_cfg = default_cfgs['resnext50_32x4d2']
+    default_cfg = default_cfgs['resnext50_32x4d']
     model = ResNet(
         Bottleneck, [3, 4, 6, 3], cardinality=32, base_width=4,
         num_classes=num_classes, in_chans=in_chans, **kwargs)
models/senet.py

@@ -1,4 +1,10 @@
 """
+SEResNet implementation from Cadene's pretrained models
+https://github.com/Cadene/pretrained-models.pytorch/blob/master/pretrainedmodels/models/senet.py
+Additional credit to https://github.com/creafz
+
+Original model: https://github.com/hujie-frank/SENet
+
 ResNet code gently borrowed from
 https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
 """
@@ -20,7 +26,8 @@ __all__ = ['SENet', 'senet154', 'seresnet50', 'seresnet101', 'seresnet152',
 def _cfg(url=''):
     return {
         'url': url, 'num_classes': 1000, 'input_size': (3, 224, 244), 'pool_size': (7, 7),
-        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'crop_pct': 0.875,
+        'crop_pct': 0.875, 'interpolation': 'bilinear',
+        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
         'first_conv': 'layer0.conv1', 'classifier': 'last_linear',
     }

models/test_time_pool.py

@@ -26,11 +26,13 @@ class TestTimePoolHead(nn.Module):
         return x.view(x.size(0), -1)


-def apply_test_time_pool(model, args):
+def apply_test_time_pool(model, config, args):
     test_time_pool = False
-    if args.img_size > model.default_cfg['input_size'][-1] and not args.no_test_pool:
-        print('Target input size (%d) > pretrained default (%d), using test time pooling' %
-              (args.img_size, model.default_cfg['input_size'][-1]))
+    if not args.no_test_pool and \
+            config['input_size'][-1] > model.default_cfg['input_size'][-1] and \
+            config['input_size'][-2] > model.default_cfg['input_size'][-2]:
+        print('Target input size (%s) > pretrained default (%s), using test time pooling' %
+              (str(config['input_size'][-2:]), str(model.default_cfg['input_size'][-2:])))
         model = TestTimePoolHead(model, original_pool=model.default_cfg['pool_size'])
         test_time_pool = True
     return model, test_time_pool
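With the resolved config in hand, the scripts compare the requested spatial size against the pretrained default to decide on test-time pooling, and force crop_pct to 1.0 when it is enabled so the full resized image reaches the enlarged pooling head. Sketch of the call-site pattern (hypothetical sizes):

    # e.g. a model pretrained at (3, 224, 224) evaluated with --img-size 320:
    # config['input_size'] = (3, 320, 320) exceeds the default in both spatial
    # dims, so the head is wrapped in TestTimePoolHead (unless --no-test-pool).
    model, test_time_pool = apply_test_time_pool(model, config, args)
    crop_pct = 1.0 if test_time_pool else config['crop_pct']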
models/xception.py

@@ -37,10 +37,11 @@ default_cfgs = {
     'xception': {
         'url': 'http://data.lip6.fr/cadene/pretrainedmodels/xception-43020ad28.pth',
         'input_size': (3, 299, 299),
+        'crop_pct': 0.8975,
+        'interpolation': 'bicubic',
         'mean': (0.5, 0.5, 0.5),
         'std': (0.5, 0.5, 0.5),
         'num_classes': 1000,
-        'crop_pct': 0.8975,
         'first_conv': 'conv1',
         'classifier': 'fc'
         # The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299
train.py
@@ -43,6 +43,12 @@ parser.add_argument('--pretrained', action='store_true', default=False,
                     help='Start with pretrained version of specified network (if avail)')
 parser.add_argument('--img-size', type=int, default=224, metavar='N',
                     help='Image patch size (default: 224)')
+parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
+                    help='Override mean pixel value of dataset')
+parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
+                    help='Override std deviation of dataset')
+parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
+                    help='Image resize interpolation type (overrides model)')
 parser.add_argument('-b', '--batch-size', type=int, default=32, metavar='N',
                     help='input batch size for training (default: 32)')
 parser.add_argument('-s', '--initial-batch-size', type=int, default=0, metavar='N',
@@ -150,13 +156,13 @@ def main():
         global_pool=args.gp,
         checkpoint_path=args.initial_checkpoint)

-    data_mean, data_std = get_mean_and_std(model, args)
+    data_config = resolve_data_config(model, args, verbose=args.local_rank == 0)

     # optionally resume from a checkpoint
     start_epoch = 0
     optimizer_state = None
     if args.resume:
-        start_epoch, optimizer_state = resume_checkpoint(model, args.resume, args.start_epoch)
+        optimizer_state, start_epoch = resume_checkpoint(model, args.resume, args.start_epoch)

     if args.num_gpu > 1:
         if args.amp:
@@ -196,14 +202,15 @@ def main():

     loader_train = create_loader(
         dataset_train,
-        img_size=args.img_size,
+        input_size=data_config['input_size'],
         batch_size=args.batch_size,
         is_training=True,
         use_prefetcher=True,
         rand_erase_prob=args.reprob,
         rand_erase_pp=args.repp,
-        mean=data_mean,
-        std=data_std,
+        interpolation=data_config['interpolation'],
+        mean=data_config['mean'],
+        std=data_config['std'],
         num_workers=args.workers,
         distributed=args.distributed,
     )
@@ -216,12 +223,13 @@ def main():

     loader_eval = create_loader(
         dataset_eval,
-        img_size=args.img_size,
+        input_size=data_config['input_size'],
         batch_size=4 * args.batch_size,
         is_training=False,
         use_prefetcher=True,
-        mean=data_mean,
-        std=data_std,
+        interpolation=data_config['interpolation'],
+        mean=data_config['mean'],
+        std=data_config['std'],
         num_workers=args.workers,
         distributed=args.distributed,
     )
validate.py
@@ -10,7 +10,7 @@ import torch.nn as nn
 import torch.nn.parallel

 from models import create_model, apply_test_time_pool
-from data import Dataset, create_loader, get_mean_and_std
+from data import Dataset, create_loader, resolve_data_config
 from utils import accuracy, AverageMeter

 torch.backends.cudnn.benchmark = True
@@ -24,8 +24,14 @@ parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
                     help='number of data loading workers (default: 2)')
 parser.add_argument('-b', '--batch-size', default=256, type=int,
                     metavar='N', help='mini-batch size (default: 256)')
-parser.add_argument('--img-size', default=224, type=int,
-                    metavar='N', help='Input image dimension')
+parser.add_argument('--img-size', default=None, type=int,
+                    metavar='N', help='Input image dimension, uses model default if empty')
+parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
+                    help='Override mean pixel value of dataset')
+parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
+                    help='Override std deviation of dataset')
+parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
+                    help='Image resize interpolation type (overrides model)')
 parser.add_argument('--num-classes', type=int, default=1000,
                     help='Number classes in dataset')
 parser.add_argument('--print-freq', '-p', default=10, type=int,
@@ -37,7 +43,7 @@ parser.add_argument('--pretrained', dest='pretrained', action='store_true',
 parser.add_argument('--num-gpu', type=int, default=1,
                     help='Number of GPUS to use')
 parser.add_argument('--no-test-pool', dest='no_test_pool', action='store_true',
-                    help='disable test time pool for DPN models')
+                    help='disable test time pool')


 def main():
@@ -54,9 +60,8 @@ def main():
     print('Model %s created, param count: %d' %
           (args.model, sum([m.numel() for m in model.parameters()])))

-    data_mean, data_std = get_mean_and_std(model, args)
-
-    model, test_time_pool = apply_test_time_pool(model, args)
+    data_config = resolve_data_config(model, args)
+    model, test_time_pool = apply_test_time_pool(model, data_config, args)

     if args.num_gpu > 1:
         model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
@@ -68,13 +73,14 @@ def main():

     loader = create_loader(
         Dataset(args.data),
-        img_size=args.img_size,
+        input_size=data_config['input_size'],
         batch_size=args.batch_size,
-        use_prefetcher=False,
-        mean=data_mean,
-        std=data_std,
+        use_prefetcher=True,
+        interpolation=data_config['interpolation'],
+        mean=data_config['mean'],
+        std=data_config['std'],
         num_workers=args.workers,
-        crop_pct=1.0 if test_time_pool else None)
+        crop_pct=1.0 if test_time_pool else data_config['crop_pct'])

     batch_time = AverageMeter()
     losses = AverageMeter()
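Net effect: validation now picks up each model's preferred preprocessing automatically, while every piece of it remains overridable per run. Hypothetical invocations (dataset path illustrative):

    python validate.py /imagenet/validation --model dpn92 --pretrained
    # cfg-driven: bicubic interpolation, DPN mean/std, crop_pct 0.875

    python validate.py /imagenet/validation --model dpn92 --pretrained --interpolation bilinear
    # same model, but the command-line interpolation override wins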