Merge pull request #16 from rwightman/misc-epoch
Weights, arguments, epoch counting, and more
commit 8a05b8d555
@@ -67,6 +67,7 @@ I've leveraged the training scripts in this repository to train a few of the mod

 |Model | Prec@1 (Err) | Prec@5 (Err) | Param # | Image Scaling |
 |---|---|---|---|---|
 | resnext50_32x4d | 78.512 (21.488) | 94.042 (5.958) | 25M | bicubic |
+| resnet50 | 78.470 (21.530) | 94.266 (5.734) | 25.6M | bicubic |
 | seresnext26_32x4d | 77.104 (22.896) | 93.316 (6.684) | 16.8M | bicubic |
 | efficientnet_b0 | 76.912 (23.088) | 93.210 (6.790) | 5.29M | bicubic |
 | mobilenetv3_100 | 75.634 (24.366) | 92.708 (7.292) | 5.5M | bicubic |

@@ -86,7 +87,7 @@ I've leveraged the training scripts in this repository to train a few of the mod

 #### @ 260x260

 |Model | Prec@1 (Err) | Prec@5 (Err) | Param # | Image Scaling |
 |---|---|---|---|---|
-| efficientnet_b2 | 79.668 (20.332) | 94.634 (5.366) | 9.11M | bicubic |
+| efficientnet_b2 | 79.760 (20.240) | 94.714 (5.286) | 9.11M | bicubic |

 ### Ported Weights
@@ -70,7 +70,7 @@ def main():
     logging.info('Model %s created, param count: %d' %
                  (args.model, sum([m.numel() for m in model.parameters()])))

-    config = resolve_data_config(model, args)
+    config = resolve_data_config(vars(args), model=model)
     model, test_time_pool = apply_test_time_pool(model, config, args)

     if args.num_gpu > 1:
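The call-site change above is mechanical but worth spelling out: `resolve_data_config` now takes a plain dict first, so argparse callers pass `vars(args)` and supply the model as an optional keyword. A minimal sketch of the new pattern (the `timm.data` import path is assumed, not shown in this diff):

```python
import argparse
from timm.data import resolve_data_config  # import path assumed

parser = argparse.ArgumentParser()
parser.add_argument('--img-size', type=int, default=None)
parser.add_argument('--interpolation', default='')
args = parser.parse_args([])

# Namespace -> dict; model is optional now, defaults/default_cfg fill the gaps
config = resolve_data_config(vars(args), model=None)
```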
@@ -2,35 +2,43 @@ import logging
 from .constants import *


-def resolve_data_config(model, args, default_cfg={}, verbose=True):
+def resolve_data_config(args, default_cfg={}, model=None, verbose=True):
     new_config = {}
     default_cfg = default_cfg
-    if not default_cfg and hasattr(model, 'default_cfg'):
+    if not default_cfg and model is not None and hasattr(model, 'default_cfg'):
         default_cfg = model.default_cfg

     # Resolve input/image size
     # FIXME grayscale/chans arg to use different # channels?
     in_chans = 3
+    if 'chans' in args and args['chans'] is not None:
+        in_chans = args['chans']

     input_size = (in_chans, 224, 224)
-    if args.img_size is not None:
-        # FIXME support passing img_size as tuple, non-square
-        assert isinstance(args.img_size, int)
-        input_size = (in_chans, args.img_size, args.img_size)
+    if 'input_size' in args and args['input_size'] is not None:
+        assert isinstance(args['input_size'], (tuple, list))
+        assert len(args['input_size']) == 3
+        input_size = tuple(args['input_size'])
+        in_chans = input_size[0]  # input_size overrides in_chans
+    elif 'img_size' in args and args['img_size'] is not None:
+        assert isinstance(args['img_size'], int)
+        input_size = (in_chans, args['img_size'], args['img_size'])
     elif 'input_size' in default_cfg:
         input_size = default_cfg['input_size']
     new_config['input_size'] = input_size

     # resolve interpolation method
-    new_config['interpolation'] = 'bilinear'
-    if args.interpolation:
-        new_config['interpolation'] = args.interpolation
+    new_config['interpolation'] = 'bicubic'
+    if 'interpolation' in args and args['interpolation']:
+        new_config['interpolation'] = args['interpolation']
     elif 'interpolation' in default_cfg:
         new_config['interpolation'] = default_cfg['interpolation']

     # resolve dataset + model mean for normalization
-    new_config['mean'] = get_mean_by_model(args.model)
-    if args.mean is not None:
-        mean = tuple(args.mean)
+    new_config['mean'] = IMAGENET_DEFAULT_MEAN
+    if 'model' in args:
+        new_config['mean'] = get_mean_by_model(args['model'])
+    if 'mean' in args and args['mean'] is not None:
+        mean = tuple(args['mean'])
         if len(mean) == 1:
             mean = tuple(list(mean) * in_chans)
         else:
@@ -40,9 +48,11 @@ def resolve_data_config(model, args, default_cfg={}, verbose=True):
         new_config['mean'] = default_cfg['mean']

     # resolve dataset + model std deviation for normalization
-    new_config['std'] = get_std_by_model(args.model)
-    if args.std is not None:
-        std = tuple(args.std)
+    new_config['std'] = IMAGENET_DEFAULT_STD
+    if 'model' in args:
+        new_config['std'] = get_std_by_model(args['model'])
+    if 'std' in args and args['std'] is not None:
+        std = tuple(args['std'])
         if len(std) == 1:
             std = tuple(list(std) * in_chans)
         else:
@@ -53,7 +63,9 @@ def resolve_data_config(model, args, default_cfg={}, verbose=True):

     # resolve default crop percentage
     new_config['crop_pct'] = DEFAULT_CROP_PCT
-    if 'crop_pct' in default_cfg:
+    if 'crop_pct' in args and args['crop_pct'] is not None:
+        new_config['crop_pct'] = args['crop_pct']
+    elif 'crop_pct' in default_cfg:
         new_config['crop_pct'] = default_cfg['crop_pct']

     if verbose:
@@ -64,29 +76,11 @@ def resolve_data_config(model, args, default_cfg={}, verbose=True):
     return new_config


-def get_mean_by_name(name):
-    if name == 'dpn':
-        return IMAGENET_DPN_MEAN
-    elif name == 'inception' or name == 'le':
-        return IMAGENET_INCEPTION_MEAN
-    else:
-        return IMAGENET_DEFAULT_MEAN
-
-
-def get_std_by_name(name):
-    if name == 'dpn':
-        return IMAGENET_DPN_STD
-    elif name == 'inception' or name == 'le':
-        return IMAGENET_INCEPTION_STD
-    else:
-        return IMAGENET_DEFAULT_STD
-
-
 def get_mean_by_model(model_name):
     model_name = model_name.lower()
     if 'dpn' in model_name:
         return IMAGENET_DPN_STD
-    elif 'ception' in model_name or 'nasnet' in model_name:
+    elif 'ception' in model_name or ('nasnet' in model_name and 'mnasnet' not in model_name):
         return IMAGENET_INCEPTION_MEAN
     else:
         return IMAGENET_DEFAULT_MEAN

@@ -96,7 +90,7 @@ def get_std_by_model(model_name):
     model_name = model_name.lower()
     if 'dpn' in model_name:
         return IMAGENET_DEFAULT_STD
-    elif 'ception' in model_name or 'nasnet' in model_name:
+    elif 'ception' in model_name or ('nasnet' in model_name and 'mnasnet' not in model_name):
         return IMAGENET_INCEPTION_STD
     else:
         return IMAGENET_DEFAULT_STD
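The `mnasnet` guard above closes a substring trap: `'nasnet' in 'mnasnet_100'` is true, so MNASNet models were silently picking up Inception normalization statistics. A hypothetical spot-check (model names illustrative; note the DPN branch of `get_mean_by_model` returning `IMAGENET_DPN_STD` looks like a pre-existing quirk this diff leaves untouched):

```python
# hypothetical spot-check of the substring guard; model names illustrative
assert 'nasnet' in 'mnasnet_100'  # the trap: a bare substring test is too loose

def uses_inception_stats(model_name):
    name = model_name.lower()
    return 'ception' in name or ('nasnet' in name and 'mnasnet' not in name)

assert uses_inception_stats('nasnetalarge')
assert not uses_inception_stats('mnasnet_100')
```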
@@ -86,6 +86,7 @@ def create_loader(
         use_prefetcher=True,
         rand_erase_prob=0.,
+        rand_erase_mode='const',
         color_jitter=0.4,
         interpolation='bilinear',
         mean=IMAGENET_DEFAULT_MEAN,
         std=IMAGENET_DEFAULT_STD,

@@ -107,6 +108,7 @@ def create_loader(
     if is_training:
         transform = transforms_imagenet_train(
             img_size,
+            color_jitter=color_jitter,
             interpolation=interpolation,
             use_prefetcher=use_prefetcher,
             mean=mean,
@@ -6,12 +6,13 @@ import torch

 def _get_pixels(per_pixel, rand_color, patch_size, dtype=torch.float32, device='cuda'):
     # NOTE I've seen CUDA illegal memory access errors being caused by the normal_()
     # paths, flip the order so normal is run on CPU if this becomes a problem
-    # ie torch.empty(patch_size, dtype=dtype).normal_().to(device=device)
+    # Issue has been fixed in master https://github.com/pytorch/pytorch/issues/19508
+    # will revert back to doing normal_() on GPU when it's in next release
     if per_pixel:
         return torch.empty(
-            patch_size, dtype=dtype, device=device).normal_()
+            patch_size, dtype=dtype).normal_().to(device=device)
     elif rand_color:
-        return torch.empty((patch_size[0], 1, 1), dtype=dtype, device=device).normal_()
+        return torch.empty((patch_size[0], 1, 1), dtype=dtype).normal_().to(device=device)
     else:
         return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device)
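The flip above is a workaround pattern rather than an API change: sample the noise on CPU, then move it, because `normal_()` on a CUDA tensor could hit illegal memory accesses in the PyTorch releases of the time (fixed upstream in pytorch/pytorch#19508). A runnable sketch of the pattern, targeting CPU here so it works anywhere:

```python
import torch

patch_size = (3, 16, 16)
device = 'cpu'  # stand-in; the real code targets 'cuda'

# allocate without a device, sample on CPU, then move to the target device
patch = torch.empty(patch_size, dtype=torch.float32).normal_().to(device=device)
```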
@@ -156,7 +156,7 @@ class RandomResizedCropAndInterpolation(object):
 def transforms_imagenet_train(
         img_size=224,
         scale=(0.08, 1.0),
-        color_jitter=(0.4, 0.4, 0.4),
+        color_jitter=0.4,
         interpolation='random',
         random_erasing=0.4,
         random_erasing_mode='const',

@@ -164,6 +164,13 @@ def transforms_imagenet_train(
         mean=IMAGENET_DEFAULT_MEAN,
         std=IMAGENET_DEFAULT_STD
 ):
+    if isinstance(color_jitter, (list, tuple)):
+        # color jitter should be a 3-tuple/list if spec brightness/contrast/saturation
+        # or 4 if also augmenting hue
+        assert len(color_jitter) in (3, 4)
+    else:
+        # if it's a scalar, duplicate for brightness, contrast, and saturation, no hue
+        color_jitter = (float(color_jitter),) * 3
+
     tfl = [
         RandomResizedCropAndInterpolation(
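The scalar form is just shorthand: downstream the tuple ends up as positional arguments to torchvision's `ColorJitter`, so a scalar means equal brightness/contrast/saturation jitter with hue off. A sketch of the normalization, assuming torchvision's standard `ColorJitter` signature:

```python
import torchvision.transforms as transforms

color_jitter = 0.4  # scalar shorthand
if isinstance(color_jitter, (list, tuple)):
    assert len(color_jitter) in (3, 4)  # (b, c, s) or (b, c, s, h)
else:
    color_jitter = (float(color_jitter),) * 3  # no hue jitter

jitter_tf = transforms.ColorJitter(*color_jitter)
```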
@@ -35,9 +35,9 @@ def _cfg(url=''):
 default_cfgs = {
     'dpn68': _cfg(
         url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn68-66bebafa7.pth'),
-    'dpn68b_extra': _cfg(
+    'dpn68b': _cfg(
         url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn68b_extra-84854c156.pth'),
-    'dpn92_extra': _cfg(
+    'dpn92': _cfg(
         url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn92_extra-b040e4a9b.pth'),
     'dpn98': _cfg(
         url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn98-5b90dec4d.pth'),
@@ -84,7 +84,7 @@ default_cfgs = {
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b1-533bc792.pth',
         input_size=(3, 240, 240), pool_size=(8, 8), interpolation='bicubic', crop_pct=0.882),
     'efficientnet_b2': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b2-d4105846.pth',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b2-cf78dc4d.pth',
         input_size=(3, 260, 260), pool_size=(9, 9), interpolation='bicubic', crop_pct=0.890),
     'efficientnet_b3': _cfg(
         url='', input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904),
@@ -50,11 +50,9 @@ default_cfgs = {
     'gluon_resnext50_32x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnext50_32x4d-e6a097c1.pth'),
     'gluon_resnext101_32x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnext101_32x4d-b253c8c4.pth'),
     'gluon_resnext101_64x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnext101_64x4d-f9a8e184.pth'),
-    'gluon_resnext152_32x4d': _cfg(url=''),
     'gluon_seresnext50_32x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_seresnext50_32x4d-90cf2d6e.pth'),
     'gluon_seresnext101_32x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_seresnext101_32x4d-cf52900d.pth'),
     'gluon_seresnext101_64x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_seresnext101_64x4d-f9926f93.pth'),
-    'gluon_seresnext152_32x4d': _cfg(url=''),
     'gluon_senet154': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_senet154-70a1a3c0.pth'),
 }

@@ -617,20 +615,6 @@ def gluon_resnext101_64x4d(pretrained=False, num_classes=1000, in_chans=3, **kwa
     return model


-@register_model
-def gluon_resnext152_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
-    """Constructs a ResNeXt152-32x4d model.
-    """
-    default_cfg = default_cfgs['gluon_resnext152_32x4d']
-    model = GluonResNet(
-        BottleneckGl, [3, 8, 36, 3], cardinality=32, base_width=4,
-        num_classes=num_classes, in_chans=in_chans, **kwargs)
-    model.default_cfg = default_cfg
-    if pretrained:
-        load_pretrained(model, default_cfg, num_classes, in_chans)
-    return model
-
-
 @register_model
 def gluon_seresnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
     """Constructs a SEResNeXt50-32x4d model.

@@ -673,20 +657,6 @@ def gluon_seresnext101_64x4d(pretrained=False, num_classes=1000, in_chans=3, **k
     return model


-@register_model
-def gluon_seresnext152_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
-    """Constructs a SEResNeXt152-32x4d model.
-    """
-    default_cfg = default_cfgs['gluon_seresnext152_32x4d']
-    model = GluonResNet(
-        BottleneckGl, [3, 8, 36, 3], cardinality=32, base_width=4, use_se=True,
-        num_classes=num_classes, in_chans=in_chans, **kwargs)
-    model.default_cfg = default_cfg
-    #if pretrained:
-    #    load_pretrained(model, default_cfg, num_classes, in_chans)
-    return model
-
-
 @register_model
 def gluon_senet154(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
     """Constructs an SENet-154 model.
@@ -28,8 +28,9 @@ def load_checkpoint(model, checkpoint_path, use_ema=False):
         raise FileNotFoundError()


-def resume_checkpoint(model, checkpoint_path, start_epoch=None):
+def resume_checkpoint(model, checkpoint_path):
     optimizer_state = None
+    resume_epoch = None
     if os.path.isfile(checkpoint_path):
         checkpoint = torch.load(checkpoint_path)
         if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:

@@ -40,13 +41,15 @@ def resume_checkpoint(model, checkpoint_path, start_epoch=None):
             model.load_state_dict(new_state_dict)
             if 'optimizer' in checkpoint:
                 optimizer_state = checkpoint['optimizer']
-            start_epoch = checkpoint['epoch'] if start_epoch is None else start_epoch
+            if 'epoch' in checkpoint:
+                resume_epoch = checkpoint['epoch']
+                if 'version' in checkpoint and checkpoint['version'] > 1:
+                    resume_epoch += 1  # start at the next epoch, old checkpoints incremented before save
             logging.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
         else:
             model.load_state_dict(checkpoint)
-            start_epoch = 0 if start_epoch is None else start_epoch
             logging.info("Loaded checkpoint '{}'".format(checkpoint_path))
-        return optimizer_state, start_epoch
+        return optimizer_state, resume_epoch
     else:
         logging.error("No checkpoint found at '{}'".format(checkpoint_path))
         raise FileNotFoundError()
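The net effect is a cleaner contract: `resume_checkpoint` reports what the checkpoint recorded (already advanced by one for `version >= 2` files, since those store the just-finished epoch) and leaves the final decision to the caller, as the train.py hunk further down does. A caller-side sketch (import path and checkpoint path are assumptions; the call raises `FileNotFoundError` if the file is missing):

```python
import torch.nn as nn
from timm.models.helpers import resume_checkpoint  # import path assumed

model = nn.Linear(4, 2)  # stand-in model; checkpoint path hypothetical
optimizer_state, resume_epoch = resume_checkpoint(model, './output/checkpoint.pth.tar')
start_epoch = resume_epoch if resume_epoch is not None else 0
```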
@@ -5,22 +5,36 @@ from collections import defaultdict

 __all__ = ['list_models', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules']

-_module_to_models = defaultdict(set)
-_model_to_module = {}
-_model_entrypoints = {}
+_module_to_models = defaultdict(set)  # dict of sets to check membership of model in module
+_model_to_module = {}  # mapping of model names to module names
+_model_entrypoints = {}  # mapping of model names to entrypoint fns
+_model_has_pretrained = set()  # set of model names that have pretrained weight url present


 def register_model(fn):
+    # lookup containing module
     mod = sys.modules[fn.__module__]
     module_name_split = fn.__module__.split('.')
     module_name = module_name_split[-1] if len(module_name_split) else ''

+    # add model to __all__ in module
+    model_name = fn.__name__
     if hasattr(mod, '__all__'):
-        mod.__all__.append(fn.__name__)
+        mod.__all__.append(model_name)
     else:
-        mod.__all__ = [fn.__name__]
-    _model_entrypoints[fn.__name__] = fn
-    _model_to_module[fn.__name__] = module_name
-    _module_to_models[module_name].add(fn.__name__)
+        mod.__all__ = [model_name]
+
+    # add entries to registry dict/sets
+    _model_entrypoints[model_name] = fn
+    _model_to_module[model_name] = module_name
+    _module_to_models[module_name].add(model_name)
+    has_pretrained = False  # check if model has a pretrained url to allow filtering on this
+    if hasattr(mod, 'default_cfgs') and model_name in mod.default_cfgs:
+        # this will catch all models that have entrypoint matching cfg key, but miss any aliasing
+        # entrypoints or non-matching combos
+        has_pretrained = 'url' in mod.default_cfgs[model_name] and 'http' in mod.default_cfgs[model_name]['url']
+    if has_pretrained:
+        _model_has_pretrained.add(model_name)
     return fn


@@ -28,7 +42,7 @@ def _natural_key(string_):
     return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]


-def list_models(filter='', module=''):
+def list_models(filter='', module='', pretrained=False):
     """ Return list of available model names, sorted alphabetically

     Args:

@@ -45,6 +59,8 @@ def list_models(filter='', module=''):
     models = _model_entrypoints.keys()
     if filter:
         models = fnmatch.filter(models, filter)
+    if pretrained:
+        models = _model_has_pretrained.intersection(models)
     return list(sorted(models, key=_natural_key))
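With the `_model_has_pretrained` set in place, the registry can answer "what can I actually download?" directly. A usage sketch (import path assumed from the module layout in this diff):

```python
from timm.models.registry import list_models  # import path assumed

every_model = list_models()                   # all registered entrypoints
downloadable = list_models(pretrained=True)   # only names backed by a weight URL
seresnexts = list_models('gluon_seresnext*')  # fnmatch-style filter still applies
```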
@@ -33,14 +33,22 @@ default_cfgs = {
     'resnet18': _cfg(url='https://download.pytorch.org/models/resnet18-5c106cde.pth'),
     'resnet34': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet34-43635321.pth'),
-    'resnet50': _cfg(url='https://download.pytorch.org/models/resnet50-19c8e357.pth'),
+    'resnet50': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/rw_resnet50-86acaeed.pth',
+        interpolation='bicubic'),
     'resnet101': _cfg(url='https://download.pytorch.org/models/resnet101-5d3b4d8f.pth'),
     'resnet152': _cfg(url='https://download.pytorch.org/models/resnet152-b121ed2d.pth'),
-    'resnext50_32x4d': _cfg(url='https://www.dropbox.com/s/yxci33lfew51p6a/resnext50_32x4d-068914d1.pth?dl=1',
-                            interpolation='bicubic'),
+    'tv_resnet34': _cfg(url='https://download.pytorch.org/models/resnet34-333f7ec4.pth'),
+    'tv_resnet50': _cfg(url='https://download.pytorch.org/models/resnet50-19c8e357.pth'),
+    'wide_resnet50_2': _cfg(url='https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth'),
+    'wide_resnet101_2': _cfg(url='https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth'),
+    'resnext50_32x4d': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnext50_32x4d-068914d1.pth',
+        interpolation='bicubic'),
     'resnext101_32x4d': _cfg(url=''),
+    'resnext101_32x8d': _cfg(url='https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth'),
     'resnext101_64x4d': _cfg(url=''),
-    'resnext152_32x4d': _cfg(url=''),
+    'tv_resnext50_32x4d': _cfg(url='https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth'),
+    'ig_resnext101_32x8d': _cfg(url='https://download.pytorch.org/models/ig_resnext101_32x8-c38310e5.pth'),
+    'ig_resnext101_32x16d': _cfg(url='https://download.pytorch.org/models/ig_resnext101_32x16-c6f796b0.pth'),
+    'ig_resnext101_32x32d': _cfg(url='https://download.pytorch.org/models/ig_resnext101_32x32-e4b90b00.pth'),

@@ -285,6 +293,61 @@ def resnet152(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
     return model


+@register_model
+def tv_resnet34(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
+    """Constructs a ResNet-34 model with original Torchvision weights.
+    """
+    model = ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, **kwargs)
+    model.default_cfg = default_cfgs['tv_resnet34']
+    if pretrained:
+        load_pretrained(model, model.default_cfg, num_classes, in_chans)
+    return model
+
+
+@register_model
+def tv_resnet50(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
+    """Constructs a ResNet-50 model with original Torchvision weights.
+    """
+    model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, **kwargs)
+    model.default_cfg = default_cfgs['tv_resnet50']
+    if pretrained:
+        load_pretrained(model, model.default_cfg, num_classes, in_chans)
+    return model
+
+
+@register_model
+def wide_resnet50_2(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
+    """Constructs a Wide ResNet-50-2 model.
+    The model is the same as ResNet except for the bottleneck number of channels
+    which is twice larger in every block. The number of channels in outer 1x1
+    convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
+    channels, and in Wide ResNet-50-2 has 2048-1024-2048.
+    """
+    model = ResNet(
+        Bottleneck, [3, 4, 6, 3], base_width=128,
+        num_classes=num_classes, in_chans=in_chans, **kwargs)
+    model.default_cfg = default_cfgs['wide_resnet50_2']
+    if pretrained:
+        load_pretrained(model, model.default_cfg, num_classes, in_chans)
+    return model
+
+
+@register_model
+def wide_resnet101_2(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
+    """Constructs a Wide ResNet-101-2 model.
+    The model is the same as ResNet except for the bottleneck number of channels
+    which is twice larger in every block. The number of channels in outer 1x1
+    convolutions is the same.
+    """
+    model = ResNet(
+        Bottleneck, [3, 4, 23, 3], base_width=128,
+        num_classes=num_classes, in_chans=in_chans, **kwargs)
+    model.default_cfg = default_cfgs['wide_resnet101_2']
+    if pretrained:
+        load_pretrained(model, model.default_cfg, num_classes, in_chans)
+    return model
+
+
 @register_model
 def resnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
     """Constructs a ResNeXt50-32x4d model.

@@ -301,7 +364,7 @@ def resnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):

 @register_model
 def resnext101_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
-    """Constructs a ResNeXt-101 model.
+    """Constructs a ResNeXt-101 32x4d model.
     """
     default_cfg = default_cfgs['resnext101_32x4d']
     model = ResNet(

@@ -313,6 +376,20 @@ def resnext101_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
     return model


+@register_model
+def resnext101_32x8d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
+    """Constructs a ResNeXt-101 32x8d model.
+    """
+    default_cfg = default_cfgs['resnext101_32x8d']
+    model = ResNet(
+        Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=8,
+        num_classes=num_classes, in_chans=in_chans, **kwargs)
+    model.default_cfg = default_cfg
+    if pretrained:
+        load_pretrained(model, default_cfg, num_classes, in_chans)
+    return model
+
+
 @register_model
 def resnext101_64x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
     """Constructs a ResNeXt101-64x4d model.

@@ -328,12 +405,12 @@ def resnext101_64x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):


 @register_model
-def resnext152_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
-    """Constructs a ResNeXt152-32x4d model.
+def tv_resnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
+    """Constructs a ResNeXt50-32x4d model with original Torchvision weights.
     """
-    default_cfg = default_cfgs['resnext152_32x4d']
+    default_cfg = default_cfgs['tv_resnext50_32x4d']
     model = ResNet(
-        Bottleneck, [3, 8, 36, 3], cardinality=32, base_width=4,
+        Bottleneck, [3, 4, 6, 3], cardinality=32, base_width=4,
         num_classes=num_classes, in_chans=in_chans, **kwargs)
     model.default_cfg = default_cfg
     if pretrained:
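The `tv_` and `wide_` entrypoints registered above behave like any other registry name. A hedged usage sketch, assuming the usual `create_model` factory resolves these names (weights download on `pretrained=True`):

```python
from timm.models import create_model  # factory/import path assumed

# torchvision-compatible ResNeXt50 with the original torchvision weights
model = create_model('tv_resnext50_32x4d', pretrained=True)
print(model.default_cfg['url'])  # the download.pytorch.org weight listed above
```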
@@ -56,7 +56,7 @@ class Scheduler:

     def step(self, epoch: int, metric: float = None) -> None:
         self.metric = metric
-        values = self.get_epoch_values(epoch + 1)  # +1 to calculate for next epoch
+        values = self.get_epoch_values(epoch)
         if values is not None:
             self.update_groups(values)
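Removing the hidden `+1` makes `step(epoch)` mean "apply the values *for* epoch `epoch`"; the off-by-one bookkeeping moves to the caller, which now asks for `epoch + 1` at the end of an epoch (see the train.py hunk below). A toy scheduler showing the convention (class hypothetical, not part of this repo):

```python
class ToyScheduler:
    """Hypothetical scheduler: step(e) applies the value for epoch e."""
    def __init__(self, base_lr=0.1, decay=0.5):
        self.base_lr, self.decay = base_lr, decay
        self.lr = base_lr

    def get_epoch_values(self, epoch):
        return [self.base_lr * (self.decay ** epoch)]

    def step(self, epoch):
        self.lr = self.get_epoch_values(epoch)[0]

sched = ToyScheduler()
sched.step(1)            # the caller supplies the +1 explicitly now
assert sched.lr == 0.05  # 0.1 * 0.5**1
```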
@@ -83,7 +83,8 @@ class CheckpointSaver:
             'arch': args.model,
             'state_dict': get_state_dict(model),
             'optimizer': optimizer.state_dict(),
-            'args': args
+            'args': args,
+            'version': 2,  # version < 2 increments epoch before save
         }
         if model_ema is not None:
             save_state['state_dict_ema'] = get_state_dict(model_ema)
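Combined with the `resume_checkpoint` change earlier, the `version` field resolves the old ambiguity about whether a saved `epoch` means "finished" or "next". A runnable sketch of the two on-disk conventions:

```python
# version 2 checkpoint: 'epoch' is the epoch that just finished;
# resume_checkpoint adds +1 to continue at the next one.
new_style = {'epoch': 9, 'version': 2}

# pre-version checkpoints stored the already-incremented epoch,
# so they resume at 'epoch' as-is.
old_style = {'epoch': 10}

def resumed_epoch(ckpt):
    e = ckpt['epoch']
    return e + 1 if ckpt.get('version', 1) > 1 else e

assert resumed_epoch(new_style) == resumed_epoch(old_style) == 10
```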
@@ -1 +1 @@
-__version__ = '0.1.2'
+__version__ = '0.1.4'
train.py (87 changes)
@@ -27,22 +27,21 @@ import torchvision.utils
 torch.backends.cudnn.benchmark = True

 parser = argparse.ArgumentParser(description='Training')
+# Dataset / Model parameters
 parser.add_argument('data', metavar='DIR',
                     help='path to dataset')
 parser.add_argument('--model', default='resnet101', type=str, metavar='MODEL',
                     help='Name of model to train (default: "countception"')
-parser.add_argument('--num-classes', type=int, default=1000, metavar='N',
-                    help='number of label classes (default: 1000)')
-parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
-                    help='Optimizer (default: "sgd"')
-parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
-                    help='Optimizer Epsilon (default: 1e-8)')
-parser.add_argument('--gp', default='avg', type=str, metavar='POOL',
-                    help='Type of global pool, "avg", "max", "avgmax", "avgmaxc" (default: "avg")')
-parser.add_argument('--tta', type=int, default=0, metavar='N',
-                    help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
 parser.add_argument('--pretrained', action='store_true', default=False,
                     help='Start with pretrained version of specified network (if avail)')
+parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',
+                    help='Initialize model from this checkpoint (default: none)')
+parser.add_argument('--resume', default='', type=str, metavar='PATH',
+                    help='Resume full model and optimizer state from checkpoint (default: none)')
+parser.add_argument('--num-classes', type=int, default=1000, metavar='N',
+                    help='number of label classes (default: 1000)')
+parser.add_argument('--gp', default='avg', type=str, metavar='POOL',
+                    help='Type of global pool, "avg", "max", "avgmax", "avgmaxc" (default: "avg")')
 parser.add_argument('--img-size', type=int, default=None, metavar='N',
                     help='Image patch size (default: None => model default)')
 parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
@@ -53,8 +52,24 @@ parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
                     help='Image resize interpolation type (overrides model)')
 parser.add_argument('-b', '--batch-size', type=int, default=32, metavar='N',
                     help='input batch size for training (default: 32)')
-parser.add_argument('-s', '--initial-batch-size', type=int, default=0, metavar='N',
-                    help='initial input batch size for training (default: 0)')
+parser.add_argument('--drop', type=float, default=0.0, metavar='DROP',
+                    help='Dropout rate (default: 0.)')
+# Optimizer parameters
+parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
+                    help='Optimizer (default: "sgd"')
+parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
+                    help='Optimizer Epsilon (default: 1e-8)')
+parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
+                    help='SGD momentum (default: 0.9)')
+parser.add_argument('--weight-decay', type=float, default=0.0001,
+                    help='weight decay (default: 0.0001)')
+# Learning rate schedule parameters
+parser.add_argument('--sched', default='step', type=str, metavar='SCHEDULER',
+                    help='LR scheduler (default: "step"')
+parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
+                    help='learning rate (default: 0.01)')
+parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR',
+                    help='warmup learning rate (default: 0.0001)')
 parser.add_argument('--epochs', type=int, default=200, metavar='N',
                     help='number of epochs to train (default: 2)')
 parser.add_argument('--start-epoch', default=None, type=int, metavar='N',
@@ -65,40 +80,34 @@ parser.add_argument('--warmup-epochs', type=int, default=3, metavar='N',
                     help='epochs to warmup LR, if scheduler supports')
 parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
                     help='LR decay rate (default: 0.1)')
-parser.add_argument('--sched', default='step', type=str, metavar='SCHEDULER',
-                    help='LR scheduler (default: "step"')
-parser.add_argument('--drop', type=float, default=0.0, metavar='DROP',
-                    help='Dropout rate (default: 0.)')
+# Augmentation parameters
+parser.add_argument('--color_jitter', type=float, default=0.4, metavar='PCT',
+                    help='Color jitter factor (default: 0.4)')
 parser.add_argument('--reprob', type=float, default=0., metavar='PCT',
                     help='Random erase prob (default: 0.)')
 parser.add_argument('--remode', type=str, default='const',
                     help='Random erase mode (default: "const")')
-parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
-                    help='learning rate (default: 0.01)')
-parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR',
-                    help='warmup learning rate (default: 0.0001)')
-parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
-                    help='SGD momentum (default: 0.9)')
-parser.add_argument('--weight-decay', type=float, default=0.0001,
-                    help='weight decay (default: 0.0001)')
 parser.add_argument('--mixup', type=float, default=0.0,
                     help='mixup alpha, mixup enabled if > 0. (default: 0.)')
 parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
                     help='turn off mixup after this epoch, disabled if 0 (default: 0)')
 parser.add_argument('--smoothing', type=float, default=0.1,
                     help='label smoothing (default: 0.1)')
+# Batch norm parameters (only works with gen_efficientnet based models currently)
 parser.add_argument('--bn-tf', action='store_true', default=False,
                     help='Use Tensorflow BatchNorm defaults for models that support it (default: False)')
 parser.add_argument('--bn-momentum', type=float, default=None,
                     help='BatchNorm momentum override (if not None)')
 parser.add_argument('--bn-eps', type=float, default=None,
                     help='BatchNorm epsilon override (if not None)')
+# Model Exponential Moving Average
 parser.add_argument('--model-ema', action='store_true', default=False,
                     help='Enable tracking moving average of model weights')
 parser.add_argument('--model-ema-force-cpu', action='store_true', default=False,
                     help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')
 parser.add_argument('--model-ema-decay', type=float, default=0.9998,
                     help='decay factor for model weights moving average (default: 0.9998)')
+# Misc
 parser.add_argument('--seed', type=int, default=42, metavar='S',
                     help='random seed (default: 42)')
 parser.add_argument('--log-interval', type=int, default=50, metavar='N',
@@ -109,10 +118,6 @@ parser.add_argument('-j', '--workers', type=int, default=4, metavar='N',
                     help='how many training processes to use (default: 1)')
 parser.add_argument('--num-gpu', type=int, default=1,
                     help='Number of GPUS to use')
-parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',
-                    help='path to init checkpoint (default: none)')
-parser.add_argument('--resume', default='', type=str, metavar='PATH',
-                    help='path to latest checkpoint (default: none)')
 parser.add_argument('--save-images', action='store_true', default=False,
                     help='save images of input bathes every log interval for debugging')
 parser.add_argument('--amp', action='store_true', default=False,
@@ -125,6 +130,8 @@ parser.add_argument('--output', default='', type=str, metavar='PATH',
                     help='path to output folder (default: none, current dir)')
 parser.add_argument('--eval-metric', default='prec1', type=str, metavar='EVAL_METRIC',
                     help='Best metric (default: "prec1"')
+parser.add_argument('--tta', type=int, default=0, metavar='N',
+                    help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
 parser.add_argument("--local_rank", default=0, type=int)

@@ -174,13 +181,13 @@ def main():
     logging.info('Model %s created, param count: %d' %
                  (args.model, sum([m.numel() for m in model.parameters()])))

-    data_config = resolve_data_config(model, args, verbose=args.local_rank == 0)
+    data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)

     # optionally resume from a checkpoint
-    start_epoch = 0
     optimizer_state = None
+    resume_epoch = None
     if args.resume:
-        optimizer_state, start_epoch = resume_checkpoint(model, args.resume, args.start_epoch)
+        optimizer_state, resume_epoch = resume_checkpoint(model, args.resume)

     if args.num_gpu > 1:
         if args.amp:
@@ -232,8 +239,15 @@ def main():
         # NOTE: EMA model does not need to be wrapped by DDP

     lr_scheduler, num_epochs = create_scheduler(args, optimizer)
+    start_epoch = 0
+    if args.start_epoch is not None:
+        # a specified start_epoch will always override the resume epoch
+        start_epoch = args.start_epoch
+    elif resume_epoch is not None:
+        start_epoch = resume_epoch
+    if start_epoch > 0:
+        lr_scheduler.step(start_epoch)

     if args.local_rank == 0:
         logging.info('Scheduled epochs: {}'.format(num_epochs))
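The resolution order above is worth making explicit: an explicit `--start-epoch` wins, then the checkpoint's `resume_epoch`, then zero. A standalone sketch of the same precedence:

```python
def resolve_start_epoch(arg_start_epoch, resume_epoch):
    # explicit CLI value > checkpoint epoch > fresh start
    if arg_start_epoch is not None:
        return arg_start_epoch
    if resume_epoch is not None:
        return resume_epoch
    return 0

assert resolve_start_epoch(None, 12) == 12
assert resolve_start_epoch(5, 12) == 5
assert resolve_start_epoch(None, None) == 0
```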
@@ -255,6 +269,7 @@ def main():
         use_prefetcher=args.prefetcher,
         rand_erase_prob=args.reprob,
+        rand_erase_mode=args.remode,
         color_jitter=args.color_jitter,
         interpolation='random',  # FIXME cleanly resolve this? data_config['interpolation'],
         mean=data_config['mean'],
         std=data_config['std'],
@@ -327,7 +342,8 @@ def main():
             eval_metrics = ema_eval_metrics

         if lr_scheduler is not None:
-            lr_scheduler.step(epoch, eval_metrics[eval_metric])
+            # step LR for next epoch
+            lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

         update_summary(
             epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
@@ -338,9 +354,7 @@ def main():
                 save_metric = eval_metrics[eval_metric]
                 best_metric, best_epoch = saver.save_checkpoint(
                     model, optimizer, args,
-                    epoch=epoch + 1,
-                    model_ema=model_ema,
-                    metric=save_metric)
+                    epoch=epoch, model_ema=model_ema, metric=save_metric)

     except KeyboardInterrupt:
         pass
@@ -433,9 +447,8 @@ def train_epoch(

         if saver is not None and args.recovery_interval and (
                 last_batch or (batch_idx + 1) % args.recovery_interval == 0):
-            save_epoch = epoch + 1 if last_batch else epoch
             saver.save_recovery(
-                model, optimizer, args, save_epoch, model_ema=model_ema, batch_idx=batch_idx)
+                model, optimizer, args, epoch, model_ema=model_ema, batch_idx=batch_idx)

         if lr_scheduler is not None:
             lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
@@ -71,7 +71,7 @@ def validate(args):
     param_count = sum([m.numel() for m in model.parameters()])
     logging.info('Model %s created, param count: %d' % (args.model, param_count))

-    data_config = resolve_data_config(model, args)
+    data_config = resolve_data_config(vars(args), model=model)
     model, test_time_pool = apply_test_time_pool(model, data_config, args)

     if args.num_gpu > 1: