diff --git a/data/__init__.py b/data/__init__.py index e5289973..f1e1c182 100644 --- a/data/__init__.py +++ b/data/__init__.py @@ -1,4 +1,4 @@ from data.dataset import Dataset -from data.transforms import transforms_imagenet_eval, transforms_imagenet_train, get_model_meanstd -from data.utils import create_loader +from data.transforms import * +from data.loader import create_loader from data.random_erasing import RandomErasingTorch, RandomErasingNumpy \ No newline at end of file diff --git a/data/utils.py b/data/loader.py similarity index 94% rename from data/utils.py rename to data/loader.py index 964f4812..23ae0c7c 100644 --- a/data/utils.py +++ b/data/loader.py @@ -94,6 +94,9 @@ def create_loader( sampler = None if distributed: + # FIXME note, doing this for validation isn't technically correct + # There currently is no fixed order distributed sampler that corrects + # for padded entries sampler = tdata.distributed.DistributedSampler(dataset) loader = tdata.DataLoader( diff --git a/data/random_erasing.py b/data/random_erasing.py index 478b6253..81253311 100644 --- a/data/random_erasing.py +++ b/data/random_erasing.py @@ -1,8 +1,5 @@ from __future__ import absolute_import -#from torchvision.transforms import * - -from PIL import Image import random import math import numpy as np diff --git a/data/transforms.py b/data/transforms.py index 90419ae0..92d37ee6 100644 --- a/data/transforms.py +++ b/data/transforms.py @@ -7,26 +7,57 @@ from data.random_erasing import RandomErasingNumpy DEFAULT_CROP_PCT = 0.875 -IMAGENET_DPN_MEAN = [124 / 255, 117 / 255, 104 / 255] -IMAGENET_DPN_STD = [1 / (.0167 * 255)] * 3 -IMAGENET_INCEPTION_MEAN = [0.5, 0.5, 0.5] -IMAGENET_INCEPTION_STD = [0.5, 0.5, 0.5] -IMAGENET_DEFAULT_MEAN = [0.485, 0.456, 0.406] -IMAGENET_DEFAULT_STD = [0.229, 0.224, 0.225] +IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) +IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) +IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5) +IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5) +IMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255) +IMAGENET_DPN_STD = tuple([1 / (.0167 * 255)] * 3) -# FIXME replace these mean/std fn with model factory based values from config dict -def get_model_meanstd(model_name): - model_name = model_name.lower() - if 'dpn' in model_name: - return IMAGENET_DPN_MEAN, IMAGENET_DPN_STD - elif 'ception' in model_name or 'nasnet' in model_name: - return IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD +def get_mean_and_std(model, args, num_chan=3): + if hasattr(model, 'default_cfg'): + mean = model.default_cfg['mean'] + std = model.default_cfg['std'] else: - return IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD + if args.mean is not None: + mean = tuple(args.mean) + if len(mean) == 1: + mean = tuple(list(mean) * num_chan) + else: + assert len(mean) == num_chan + else: + mean = get_mean_by_model(args.model) + if args.std is not None: + std = tuple(args.std) + if len(std) == 1: + std = tuple(list(std) * num_chan) + else: + assert len(std) == num_chan + else: + std = get_std_by_model(args.model) + return mean, std -def get_model_mean(model_name): +def get_mean_by_name(name): + if name == 'dpn': + return IMAGENET_DPN_MEAN + elif name == 'inception' or name == 'le': + return IMAGENET_INCEPTION_MEAN + else: + return IMAGENET_DEFAULT_MEAN + + +def get_std_by_name(name): + if name == 'dpn': + return IMAGENET_DPN_STD + elif name == 'inception' or name == 'le': + return IMAGENET_INCEPTION_STD + else: + return IMAGENET_DEFAULT_STD + + +def get_mean_by_model(model_name): model_name = model_name.lower() if 'dpn' in model_name: return IMAGENET_DPN_STD @@ -36,7 +67,7 @@ def get_model_mean(model_name): return IMAGENET_DEFAULT_MEAN -def get_model_std(model_name): +def get_std_by_model(model_name): model_name = model_name.lower() if 'dpn' in model_name: return IMAGENET_DEFAULT_STD @@ -93,8 +124,8 @@ def transforms_imagenet_train( tfl += [ ToTensor(), transforms.Normalize( - mean=torch.tensor(mean) * 255, - std=torch.tensor(std) * 255) + mean=torch.tensor(mean), + std=torch.tensor(std)) ] if random_erasing > 0.: tfl.append(RandomErasingNumpy(random_erasing, per_pixel=True)) @@ -124,11 +155,5 @@ def transforms_imagenet_eval( mean=torch.tensor(mean), std=torch.tensor(std)) ] - # tfl += [ - # ToTensor(), - # transforms.Normalize( - # mean=torch.tensor(mean) * 255, - # std=torch.tensor(std) * 255) - # ] return transforms.Compose(tfl) diff --git a/inference.py b/inference.py index 8b696090..a7c0c851 100644 --- a/inference.py +++ b/inference.py @@ -11,10 +11,11 @@ import argparse import numpy as np import torch -from models import create_model, load_checkpoint, TestTimePoolHead -from data import Dataset, create_loader, get_model_meanstd +from models import create_model, apply_test_time_pool +from data import Dataset, create_loader, get_mean_and_std from utils import AverageMeter +torch.backends.cudnn.benchmark = True parser = argparse.ArgumentParser(description='PyTorch ImageNet Inference') parser.add_argument('data', metavar='DIR', @@ -29,6 +30,8 @@ parser.add_argument('-b', '--batch-size', default=256, type=int, metavar='N', help='mini-batch size (default: 256)') parser.add_argument('--img-size', default=224, type=int, metavar='N', help='Input image dimension') +parser.add_argument('--num-classes', type=int, default=1000, + help='Number classes in dataset') parser.add_argument('--print-freq', '-p', default=10, type=int, metavar='N', help='print frequency (default: 10)') parser.add_argument('--checkpoint', default='', type=str, metavar='PATH', @@ -45,26 +48,24 @@ def main(): args = parser.parse_args() # create model - num_classes = 1000 model = create_model( args.model, - num_classes=num_classes, - pretrained=args.pretrained) + num_classes=args.num_classes, + in_chans=3, + pretrained=args.pretrained, + checkpoint_path=args.checkpoint) print('Model %s created, param count: %d' % (args.model, sum([m.numel() for m in model.parameters()]))) - # load a checkpoint - if not args.pretrained: - if not load_checkpoint(model, args.checkpoint): - exit(1) + data_mean, data_std = get_mean_and_std(model, args) + model, test_time_pool = apply_test_time_pool(model, args) if args.num_gpu > 1: model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda() else: model = model.cuda() - data_mean, data_std = get_model_meanstd(args.model) loader = create_loader( Dataset(args.data), img_size=args.img_size, diff --git a/models/__init__.py b/models/__init__.py index e975d08d..2cd97ca6 100644 --- a/models/__init__.py +++ b/models/__init__.py @@ -1,3 +1,4 @@ -from .model_factory import create_model, load_checkpoint -from .test_time_pool import TestTimePoolHead +from models.model_factory import create_model +from models.helpers import load_checkpoint, resume_checkpoint +from models.test_time_pool import TestTimePoolHead, apply_test_time_pool diff --git a/models/densenet.py b/models/densenet.py index 46f8f7b9..c37659d0 100644 --- a/models/densenet.py +++ b/models/densenet.py @@ -5,19 +5,29 @@ fixed kwargs passthrough and addition of dynamic global avg/max pool. import torch import torch.nn as nn import torch.nn.functional as F -import torch.utils.model_zoo as model_zoo from collections import OrderedDict -from .adaptive_avgmax_pool import * + +from models.helpers import load_pretrained +from models.adaptive_avgmax_pool import * +from data.transforms import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD import re __all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161'] -model_urls = { - 'densenet121': 'https://download.pytorch.org/models/densenet121-241335ed.pth', - 'densenet169': 'https://download.pytorch.org/models/densenet169-6f0f7f60.pth', - 'densenet201': 'https://download.pytorch.org/models/densenet201-4c113574.pth', - 'densenet161': 'https://download.pytorch.org/models/densenet161-17b70270.pth', +def _cfg(url=''): + return { + 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 244), 'pool_size': (7, 7), + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'features.conv0', 'classifier': 'classifier', + } + + +default_cfgs = { + 'densenet121': _cfg(url='https://download.pytorch.org/models/densenet121-241335ed.pth'), + 'densenet169': _cfg(url='https://download.pytorch.org/models/densenet169-6f0f7f60.pth'), + 'densenet201': _cfg(url='https://download.pytorch.org/models/densenet201-4c113574.pth'), + 'densenet161': _cfg(url='https://download.pytorch.org/models/densenet161-17b70270.pth'), } @@ -34,59 +44,56 @@ def _filter_pretrained(state_dict): return state_dict -def densenet121(pretrained=False, **kwargs): +def densenet121(num_classes=1000, in_chans=3, pretrained=False, **kwargs): r"""Densenet-121 model from `"Densely Connected Convolutional Networks" ` - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet """ - model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16), **kwargs) + default_cfg = default_cfgs['densenet121'] + model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16), + num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg if pretrained: - state_dict = model_zoo.load_url(model_urls['densenet121']) - model.load_state_dict(_filter_pretrained(state_dict)) + load_pretrained(model, default_cfg, num_classes, in_chans, filter_fn=_filter_pretrained) return model -def densenet169(pretrained=False, **kwargs): +def densenet169(num_classes=1000, in_chans=3, pretrained=False, **kwargs): r"""Densenet-169 model from `"Densely Connected Convolutional Networks" ` - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet """ - model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32), **kwargs) + default_cfg = default_cfgs['densenet169'] + model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32), + num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg if pretrained: - state_dict = model_zoo.load_url(model_urls['densenet169']) - model.load_state_dict(_filter_pretrained(state_dict)) + load_pretrained(model, default_cfg, num_classes, in_chans, filter_fn=_filter_pretrained) return model -def densenet201(pretrained=False, **kwargs): +def densenet201(num_classes=1000, in_chans=3, pretrained=False, **kwargs): r"""Densenet-201 model from `"Densely Connected Convolutional Networks" ` - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet """ - model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32), **kwargs) + default_cfg = default_cfgs['densenet201'] + model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32), + num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg if pretrained: - state_dict = model_zoo.load_url(model_urls['densenet201']) - model.load_state_dict(_filter_pretrained(state_dict)) + load_pretrained(model, default_cfg, num_classes, in_chans, filter_fn=_filter_pretrained) return model -def densenet161(pretrained=False, **kwargs): +def densenet161(num_classes=1000, in_chans=3, pretrained=False, **kwargs): r"""Densenet-201 model from `"Densely Connected Convolutional Networks" ` - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet """ - model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24), **kwargs) + print(num_classes, in_chans, pretrained) + default_cfg = default_cfgs['densenet161'] + model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24), + num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg if pretrained: - state_dict = model_zoo.load_url(model_urls['densenet161']) - model.load_state_dict(_filter_pretrained(state_dict)) + load_pretrained(model, default_cfg, num_classes, in_chans, filter_fn=_filter_pretrained) return model @@ -142,14 +149,15 @@ class DenseNet(nn.Module): num_classes (int) - number of classification classes """ def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), - num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000, global_pool='avg'): + num_init_features=64, bn_size=4, drop_rate=0, + num_classes=1000, in_chans=3, global_pool='avg'): self.global_pool = global_pool self.num_classes = num_classes super(DenseNet, self).__init__() # First convolution self.features = nn.Sequential(OrderedDict([ - ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)), + ('conv0', nn.Conv2d(in_chans, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)), ('norm0', nn.BatchNorm2d(num_init_features)), ('relu0', nn.ReLU(inplace=True)), ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)), @@ -172,7 +180,7 @@ class DenseNet(nn.Module): self.features.add_module('norm5', nn.BatchNorm2d(num_features)) # Linear layer - self.classifier = torch.nn.Linear(num_features, num_classes) + self.classifier = nn.Linear(num_features, num_classes) self.num_features = num_features @@ -184,7 +192,7 @@ class DenseNet(nn.Module): self.num_classes = num_classes del self.classifier if num_classes: - self.classifier = torch.nn.Linear(self.num_features, num_classes) + self.classifier = nn.Linear(self.num_features, num_classes) else: self.classifier = None diff --git a/models/dpn.py b/models/dpn.py index ec8fe9d2..c1ab8c6b 100644 --- a/models/dpn.py +++ b/models/dpn.py @@ -13,94 +13,108 @@ import os import torch import torch.nn as nn import torch.nn.functional as F -import torch.utils.model_zoo as model_zoo from collections import OrderedDict -from .adaptive_avgmax_pool import select_adaptive_pool2d +from models.helpers import load_pretrained +from models.adaptive_avgmax_pool import select_adaptive_pool2d +from data.transforms import IMAGENET_DPN_MEAN, IMAGENET_DPN_STD __all__ = ['DPN', 'dpn68', 'dpn92', 'dpn98', 'dpn131', 'dpn107'] -model_urls = { +def _cfg(url=''): + return { + 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'mean': IMAGENET_DPN_MEAN, 'std': IMAGENET_DPN_STD, + 'first_conv': 'features.conv1_1.conv', 'classifier': 'classifier', + } + + +default_cfgs = { 'dpn68': - 'http://data.lip6.fr/cadene/pretrainedmodels/dpn68-66bebafa7.pth', + _cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/dpn68-66bebafa7.pth'), 'dpn68b_extra': - 'http://data.lip6.fr/cadene/pretrainedmodels/' - 'dpn68b_extra-84854c156.pth', - 'dpn92': '', + _cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/dpn68b_extra-84854c156.pth'), 'dpn92_extra': - 'http://data.lip6.fr/cadene/pretrainedmodels/' - 'dpn92_extra-b040e4a9b.pth', + _cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/dpn92_extra-b040e4a9b.pth'), 'dpn98': - 'http://data.lip6.fr/cadene/pretrainedmodels/dpn98-5b90dec4d.pth', + _cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/dpn98-5b90dec4d.pth'), 'dpn131': - 'http://data.lip6.fr/cadene/pretrainedmodels/dpn131-71dfe43e0.pth', + _cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/dpn131-71dfe43e0.pth'), 'dpn107_extra': - 'http://data.lip6.fr/cadene/pretrainedmodels/' - 'dpn107_extra-1ac7121e2.pth' + _cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/dpn107_extra-1ac7121e2.pth') } -def dpn68(num_classes=1000, pretrained=False): +def dpn68(num_classes=1000, in_chans=3, pretrained=False): + default_cfg = default_cfgs['dpn68'] model = DPN( small=True, num_init_features=10, k_r=128, groups=32, k_sec=(3, 4, 12, 3), inc_sec=(16, 32, 32, 64), - num_classes=num_classes) + num_classes=num_classes, in_chans=in_chans) + model.default_cfg = default_cfg if pretrained: - model.load_state_dict(model_zoo.load_url(model_urls['dpn68'])) + load_pretrained(model, default_cfg, num_classes, in_chans) return model -def dpn68b(num_classes=1000, pretrained=False): +def dpn68b(num_classes=1000, in_chans=3, pretrained=False): + default_cfg = default_cfgs['dpn68b_extra'] model = DPN( small=True, num_init_features=10, k_r=128, groups=32, b=True, k_sec=(3, 4, 12, 3), inc_sec=(16, 32, 32, 64), - num_classes=num_classes) + num_classes=num_classes, in_chans=in_chans) + model.default_cfg = default_cfg if pretrained: - model.load_state_dict(model_zoo.load_url(model_urls['dpn68b_extra'])) + load_pretrained(model, default_cfg, num_classes, in_chans) return model -def dpn92(num_classes=1000, pretrained=False, extra=True): +def dpn92(num_classes=1000, in_chans=3, pretrained=False): + default_cfg = default_cfgs['dpn92_extra'] model = DPN( num_init_features=64, k_r=96, groups=32, k_sec=(3, 4, 20, 3), inc_sec=(16, 32, 24, 128), - num_classes=num_classes) + num_classes=num_classes, in_chans=in_chans) + model.default_cfg = default_cfg if pretrained: - if extra: - model.load_state_dict(model_zoo.load_url(model_urls['dpn92_extra'])) - else: - model.load_state_dict(model_zoo.load_url(model_urls['dpn92'])) + load_pretrained(model, default_cfg, num_classes, in_chans) return model -def dpn98(num_classes=1000, pretrained=False): +def dpn98(num_classes=1000, in_chans=3, pretrained=False): + default_cfg = default_cfgs['dpn98'] model = DPN( num_init_features=96, k_r=160, groups=40, k_sec=(3, 6, 20, 3), inc_sec=(16, 32, 32, 128), - num_classes=num_classes) + num_classes=num_classes, in_chans=in_chans) + model.default_cfg = default_cfg if pretrained: - model.load_state_dict(model_zoo.load_url(model_urls['dpn98'])) + load_pretrained(model, default_cfg, num_classes, in_chans) return model -def dpn131(num_classes=1000, pretrained=False): +def dpn131(num_classes=1000, in_chans=3, pretrained=False): + default_cfg = default_cfgs['dpn131'] model = DPN( num_init_features=128, k_r=160, groups=40, k_sec=(4, 8, 28, 3), inc_sec=(16, 32, 32, 128), - num_classes=num_classes) + num_classes=num_classes, in_chans=in_chans) + model.default_cfg = default_cfg if pretrained: - model.load_state_dict(model_zoo.load_url(model_urls['dpn131'])) + load_pretrained(model, default_cfg, num_classes, in_chans) return model -def dpn107(num_classes=1000, pretrained=False): +def dpn107(num_classes=1000, in_chans=3, pretrained=False): + default_cfg = default_cfgs['dpn107_extra'] model = DPN( num_init_features=128, k_r=200, groups=50, k_sec=(4, 8, 20, 3), inc_sec=(20, 64, 64, 128), - num_classes=num_classes) + num_classes=num_classes, in_chans=in_chans) + model.default_cfg = default_cfg if pretrained: - model.load_state_dict(model_zoo.load_url(model_urls['dpn107_extra'])) + load_pretrained(model, default_cfg, num_classes, in_chans) return model @@ -128,11 +142,11 @@ class BnActConv2d(nn.Module): class InputBlock(nn.Module): - def __init__(self, num_init_features, kernel_size=7, + def __init__(self, num_init_features, kernel_size=7, in_chans=3, padding=3, activation_fn=nn.ReLU(inplace=True)): super(InputBlock, self).__init__() self.conv = nn.Conv2d( - 3, num_init_features, kernel_size=kernel_size, stride=2, padding=padding, bias=False) + in_chans, num_init_features, kernel_size=kernel_size, stride=2, padding=padding, bias=False) self.bn = nn.BatchNorm2d(num_init_features, eps=0.001) self.act = activation_fn self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) @@ -212,7 +226,7 @@ class DualPathBlock(nn.Module): class DPN(nn.Module): def __init__(self, small=False, num_init_features=64, k_r=96, groups=32, b=False, k_sec=(3, 4, 20, 3), inc_sec=(16, 32, 24, 128), - num_classes=1000, fc_act=nn.ELU(inplace=True)): + num_classes=1000, in_chans=3, fc_act=nn.ELU(inplace=True)): super(DPN, self).__init__() self.num_classes = num_classes self.b = b @@ -222,9 +236,11 @@ class DPN(nn.Module): # conv1 if small: - blocks['conv1_1'] = InputBlock(num_init_features, kernel_size=3, padding=1) + blocks['conv1_1'] = InputBlock( + num_init_features, in_chans=in_chans, kernel_size=3, padding=1) else: - blocks['conv1_1'] = InputBlock(num_init_features, kernel_size=7, padding=3) + blocks['conv1_1'] = InputBlock( + num_init_features, in_chans=in_chans, kernel_size=7, padding=3) # conv2 bw = 64 * bw_factor diff --git a/models/helpers.py b/models/helpers.py new file mode 100644 index 00000000..7bb98dd3 --- /dev/null +++ b/models/helpers.py @@ -0,0 +1,89 @@ +import torch +import torch.utils.model_zoo as model_zoo +import os +from collections import OrderedDict + + +def load_checkpoint(model, checkpoint_path): + if checkpoint_path and os.path.isfile(checkpoint_path): + print("=> Loading checkpoint '{}'".format(checkpoint_path)) + checkpoint = torch.load(checkpoint_path) + if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: + new_state_dict = OrderedDict() + for k, v in checkpoint['state_dict'].items(): + if k.startswith('module'): + name = k[7:] # remove `module.` + else: + name = k + new_state_dict[name] = v + model.load_state_dict(new_state_dict) + else: + model.load_state_dict(checkpoint) + print("=> Loaded checkpoint '{}'".format(checkpoint_path)) + else: + print("=> Error: No checkpoint found at '{}'".format(checkpoint_path)) + raise FileNotFoundError() + + +def resume_checkpoint(model, checkpoint_path, start_epoch=None): + start_epoch = 0 if start_epoch is None else start_epoch + optimizer_state = None + if os.path.isfile(checkpoint_path): + print("=> loading checkpoint '{}'".format(checkpoint_path)) + checkpoint = torch.load(checkpoint_path) + if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: + new_state_dict = OrderedDict() + for k, v in checkpoint['state_dict'].items(): + if k.startswith('module'): + name = k[7:] # remove `module.` + else: + name = k + new_state_dict[name] = v + model.load_state_dict(new_state_dict) + if 'optimizer' in checkpoint: + optimizer_state = checkpoint['optimizer'] + print("=> loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch'])) + start_epoch = checkpoint['epoch'] if start_epoch is None else start_epoch + else: + model.load_state_dict(checkpoint) + return optimizer_state, start_epoch + else: + print("=> No checkpoint found at '{}'".format(checkpoint_path)) + raise FileNotFoundError() + + +def load_pretrained(model, default_cfg, num_classes=1000, in_chans=3, filter_fn=None): + state_dict = model_zoo.load_url(default_cfg['url']) + + if in_chans == 1: + conv1_name = default_cfg['first_conv'] + print('Converting first conv (%s) from 3 to 1 channel' % conv1_name) + conv1_weight = state_dict[conv1_name + '.weight'] + state_dict[conv1_name + '.weight'] = conv1_weight.sum(dim=1, keepdim=True) + elif in_chans != 3: + assert False, "Invalid in_chans for pretrained weights" + + strict = True + classifier_name = default_cfg['classifier'] + if num_classes == 1000 and default_cfg['num_classes'] == 1001: + # special case for imagenet trained models with extra background class in pretrained weights + classifier_weight = state_dict[classifier_name + '.weight'] + state_dict[classifier_name + '.weight'] = classifier_weight[1:] + classifier_bias = state_dict[classifier_name + '.bias'] + state_dict[classifier_name + '.bias'] = classifier_bias[1:] + elif num_classes != default_cfg['num_classes']: + # completely discard fully connected for all other differences between pretrained and created model + del state_dict[classifier_name + '.weight'] + del state_dict[classifier_name + '.bias'] + strict = False + + if filter_fn is not None: + state_dict = filter_fn(state_dict) + + model.load_state_dict(state_dict, strict=strict) + + + + + + diff --git a/models/inception_resnet_v2.py b/models/inception_resnet_v2.py index fabd9731..2b4f2a6a 100644 --- a/models/inception_resnet_v2.py +++ b/models/inception_resnet_v2.py @@ -5,12 +5,18 @@ based upon Google's Tensorflow implementation and pretrained weights (Apache 2.0 import torch import torch.nn as nn import torch.nn.functional as F -import torch.utils.model_zoo as model_zoo -import numpy as np -from .adaptive_avgmax_pool import * +from models.helpers import load_pretrained +from models.adaptive_avgmax_pool import * +from data.transforms import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -model_urls = { - 'imagenet': 'http://data.lip6.fr/cadene/pretrainedmodels/inceptionresnetv2-520b38e4.pth' + +default_cfgs = { + 'inception_resnet_v2': { + 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/inceptionresnetv2-520b38e4.pth', + 'num_classes': 1001, 'input_size': (3, 299, 299), 'pool_size': (8, 8), + 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, + 'first_conv': 'conv2d_1a.conv', 'classifier': 'last_linear', + } } @@ -204,12 +210,14 @@ class Block8(nn.Module): class InceptionResnetV2(nn.Module): - def __init__(self, num_classes=1001, drop_rate=0., global_pool='avg'): + def __init__(self, num_classes=1001, in_chans=3, drop_rate=0., global_pool='avg'): super(InceptionResnetV2, self).__init__() self.drop_rate = drop_rate self.global_pool = global_pool self.num_classes = num_classes - self.conv2d_1a = BasicConv2d(3, 32, kernel_size=3, stride=2) + self.num_features = 1536 + + self.conv2d_1a = BasicConv2d(in_chans, 32, kernel_size=3, stride=2) self.conv2d_2a = BasicConv2d(32, 32, kernel_size=3, stride=1) self.conv2d_2b = BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1) self.maxpool_3a = nn.MaxPool2d(3, stride=2) @@ -265,29 +273,21 @@ class InceptionResnetV2(nn.Module): Block8(scale=0.20) ) self.block8 = Block8(noReLU=True) - self.conv2d_7b = BasicConv2d(2080, 1536, kernel_size=1, stride=1) - self.num_features = 1536 - self.last_linear = nn.Linear(1536, num_classes) + self.conv2d_7b = BasicConv2d(2080, self.num_features, kernel_size=1, stride=1) + self.last_linear = nn.Linear(self.num_features, num_classes) def get_classifier(self): - return self.classif + return self.last_linear def reset_classifier(self, num_classes, global_pool='avg'): self.global_pool = global_pool self.num_classes = num_classes - del self.classif + del self.last_linear if num_classes: - self.last_linear = torch.nn.Linear(1536, num_classes) + self.last_linear = torch.nn.Linear(self.num_features, num_classes) else: self.last_linear = None - def trim_classifier(self, trim=1): - self.num_classes -= trim - new_last_linear = nn.Linear(1536, self.num_classes) - new_last_linear.weight.data = self.last_linear.weight.data[trim:] - new_last_linear.bias.data = self.last_linear.bias.data[trim:] - self.last_linear = new_last_linear - def forward_features(self, x, pool=True): x = self.conv2d_1a(x) x = self.conv2d_2a(x) @@ -318,19 +318,15 @@ class InceptionResnetV2(nn.Module): return x -def inception_resnet_v2(pretrained=False, num_classes=1000, **kwargs): +def inception_resnet_v2(num_classes=1000, in_chans=3, pretrained=False, **kwargs): r"""InceptionResnetV2 model architecture from the `"InceptionV4, Inception-ResNet..." `_ paper. - - Args: - pretrained ('string'): If True, returns a model pre-trained on ImageNet """ - extra_class = 1 if pretrained else 0 - model = InceptionResnetV2(num_classes=num_classes + extra_class, **kwargs) + default_cfg = default_cfgs['inception_resnet_v2'] + model = InceptionResnetV2(num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg if pretrained: - print('Loading pretrained from %s' % model_urls['imagenet']) - model.load_state_dict(model_zoo.load_url(model_urls['imagenet'])) - model.trim_classifier() + load_pretrained(model, default_cfg, num_classes, in_chans) return model diff --git a/models/inception_v4.py b/models/inception_v4.py index 3de774df..67cf74d0 100644 --- a/models/inception_v4.py +++ b/models/inception_v4.py @@ -5,11 +5,18 @@ based upon Google's Tensorflow implementation and pretrained weights (Apache 2.0 import torch import torch.nn as nn import torch.nn.functional as F -import torch.utils.model_zoo as model_zoo -from .adaptive_avgmax_pool import * +from models.helpers import load_pretrained +from models.adaptive_avgmax_pool import * +from data.transforms import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -model_urls = { - 'imagenet': 'http://webia.lip6.fr/~cadene/Downloads/inceptionv4-97ef9c30.pth' + +default_cfgs = { + 'inception_v4': { + 'url': 'http://webia.lip6.fr/~cadene/Downloads/inceptionv4-97ef9c30.pth', + 'num_classes': 1001, 'input_size': (3, 299, 299), 'pool_size': (8, 8), + 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, + 'first_conv': 'features.0.conv', 'classifier': 'classif', + } } @@ -230,13 +237,15 @@ class Inception_C(nn.Module): class InceptionV4(nn.Module): - def __init__(self, num_classes=1001, drop_rate=0., global_pool='avg'): + def __init__(self, num_classes=1001, in_chans=3, drop_rate=0., global_pool='avg'): super(InceptionV4, self).__init__() self.drop_rate = drop_rate self.global_pool = global_pool self.num_classes = num_classes + self.num_features = 1536 + self.features = nn.Sequential( - BasicConv2d(3, 32, kernel_size=3, stride=2), + BasicConv2d(in_chans, 32, kernel_size=3, stride=2), BasicConv2d(32, 32, kernel_size=3, stride=1), BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1), Mixed_3a(), @@ -259,7 +268,7 @@ class InceptionV4(nn.Module): Inception_C(), Inception_C(), ) - self.classif = nn.Linear(1536, num_classes) + self.classif = nn.Linear(self.num_features, num_classes) def get_classifier(self): return self.classif @@ -267,12 +276,12 @@ class InceptionV4(nn.Module): def reset_classifier(self, num_classes, global_pool='avg'): self.global_pool = global_pool self.num_classes = num_classes - self.classif = nn.Linear(1536, num_classes) + self.classif = nn.Linear(self.num_features, num_classes) def forward_features(self, x, pool=True): x = self.features(x) if pool: - x = select_adaptive_pool2d(x, self.global_pool, count_include_pad=False) + x = select_adaptive_pool2d(x, self.global_pool) x = x.view(x.size(0), -1) return x @@ -284,10 +293,12 @@ class InceptionV4(nn.Module): return x -def inception_v4(pretrained=False, num_classes=1001, **kwargs): - model = InceptionV4(num_classes=num_classes, **kwargs) +def inception_v4(num_classes=1000, in_chans=3, pretrained=False, **kwargs): + default_cfg = default_cfgs['inception_v4'] + model = InceptionV4(num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg if pretrained: - model.load_state_dict(model_zoo.load_url(model_urls['imagenet'])) + load_pretrained(model, default_cfg, num_classes, in_chans) return model diff --git a/models/model_factory.py b/models/model_factory.py index 06940496..66d67692 100644 --- a/models/model_factory.py +++ b/models/model_factory.py @@ -1,155 +1,34 @@ -import torch -import os -from collections import OrderedDict - -from .inception_v4 import inception_v4 -from .inception_resnet_v2 import inception_resnet_v2 -from .densenet import densenet161, densenet121, densenet169, densenet201 -from .resnet import resnet18, resnet34, resnet50, resnet101, resnet152, \ +from models.inception_v4 import inception_v4 +from models.inception_resnet_v2 import inception_resnet_v2 +from models.densenet import densenet161, densenet121, densenet169, densenet201 +from models.resnet import resnet18, resnet34, resnet50, resnet101, resnet152, \ resnext50_32x4d, resnext101_32x4d, resnext101_64x4d, resnext152_32x4d -from .dpn import dpn68, dpn68b, dpn92, dpn98, dpn131, dpn107 -from .senet import seresnet18, seresnet34, seresnet50, seresnet101, seresnet152, \ +from models.dpn import dpn68, dpn68b, dpn92, dpn98, dpn131, dpn107 +from models.senet import seresnet18, seresnet34, seresnet50, seresnet101, seresnet152, \ seresnext26_32x4d, seresnext50_32x4d, seresnext101_32x4d -#from .resnext import resnext50, resnext101, resnext152 -from .xception import xception -from .pnasnet import pnasnet5large +from models.xception import xception +from models.pnasnet import pnasnet5large -model_config_dict = { - 'resnet18': { - 'model_name': 'resnet18', 'num_classes': 1000, 'input_size': 224, 'normalizer': 'tv'}, - 'resnet34': { - 'model_name': 'resnet34', 'num_classes': 1000, 'input_size': 224, 'normalizer': 'tv'}, - 'resnet50': { - 'model_name': 'resnet50', 'num_classes': 1000, 'input_size': 224, 'normalizer': 'tv'}, - 'resnet101': { - 'model_name': 'resnet101', 'num_classes': 1000, 'input_size': 224, 'normalizer': 'tv'}, - 'resnet152': { - 'model_name': 'resnet152', 'num_classes': 1000, 'input_size': 224, 'normalizer': 'tv'}, - 'densenet121': { - 'model_name': 'densenet121', 'num_classes': 1000, 'input_size': 224, 'normalizer': 'tv'}, - 'densenet169': { - 'model_name': 'densenet169', 'num_classes': 1000, 'input_size': 224, 'normalizer': 'tv'}, - 'densenet201': { - 'model_name': 'densenet201', 'num_classes': 1000, 'input_size': 224, 'normalizer': 'tv'}, - 'densenet161': { - 'model_name': 'densenet161', 'num_classes': 1000, 'input_size': 224, 'normalizer': 'tv'}, - 'dpn107': { - 'model_name': 'dpn107', 'num_classes': 1000, 'input_size': 299, 'normalizer': 'dpn'}, - 'dpn92_extra': { - 'model_name': 'dpn92', 'num_classes': 1000, 'input_size': 299, 'normalizer': 'dpn'}, - 'dpn92': { - 'model_name': 'dpn92', 'num_classes': 1000, 'input_size': 299, 'normalizer': 'dpn'}, - 'dpn68': { - 'model_name': 'dpn68', 'num_classes': 1000, 'input_size': 299, 'normalizer': 'dpn'}, - 'dpn68b': { - 'model_name': 'dpn68b', 'num_classes': 1000, 'input_size': 299, 'normalizer': 'dpn'}, - 'dpn68b_extra': { - 'model_name': 'dpn68b', 'num_classes': 1000, 'input_size': 299, 'normalizer': 'dpn'}, - 'inception_resnet_v2': { - 'model_name': 'inception_resnet_v2', 'num_classes': 1000, 'input_size': 299, 'normalizer': 'le'}, - 'xception': { - 'model_name': 'xception', 'num_classes': 1000, 'input_size': 299, 'normalizer': 'le'}, - 'pnasnet5large': { - 'model_name': 'pnasnet5large', 'num_classes': 1000, 'input_size': 331, 'normalizer': 'le'} -} +from models.helpers import load_checkpoint def create_model( model_name='resnet50', pretrained=None, num_classes=1000, + in_chans=3, checkpoint_path='', **kwargs): - if model_name == 'dpn68': - model = dpn68(num_classes=num_classes, pretrained=pretrained) - elif model_name == 'dpn68b': - model = dpn68b(num_classes=num_classes, pretrained=pretrained) - elif model_name == 'dpn92': - model = dpn92(num_classes=num_classes, pretrained=pretrained) - elif model_name == 'dpn98': - model = dpn98(num_classes=num_classes, pretrained=pretrained) - elif model_name == 'dpn131': - model = dpn131(num_classes=num_classes, pretrained=pretrained) - elif model_name == 'dpn107': - model = dpn107(num_classes=num_classes, pretrained=pretrained) - elif model_name == 'resnet18': - model = resnet18(num_classes=num_classes, pretrained=pretrained, **kwargs) - elif model_name == 'resnet34': - model = resnet34(num_classes=num_classes, pretrained=pretrained, **kwargs) - elif model_name == 'resnet50': - model = resnet50(num_classes=num_classes, pretrained=pretrained, **kwargs) - elif model_name == 'resnet101': - model = resnet101(num_classes=num_classes, pretrained=pretrained, **kwargs) - elif model_name == 'resnet152': - model = resnet152(num_classes=num_classes, pretrained=pretrained, **kwargs) - elif model_name == 'densenet121': - model = densenet121(num_classes=num_classes, pretrained=pretrained, **kwargs) - elif model_name == 'densenet161': - model = densenet161(num_classes=num_classes, pretrained=pretrained, **kwargs) - elif model_name == 'densenet169': - model = densenet169(num_classes=num_classes, pretrained=pretrained, **kwargs) - elif model_name == 'densenet201': - model = densenet201(num_classes=num_classes, pretrained=pretrained, **kwargs) - elif model_name == 'inception_resnet_v2': - model = inception_resnet_v2(num_classes=num_classes, pretrained=pretrained, **kwargs) - elif model_name == 'inception_v4': - model = inception_v4(num_classes=num_classes, pretrained=pretrained, **kwargs) - elif model_name == 'seresnet18': - model = seresnet18(num_classes=num_classes, pretrained=pretrained, **kwargs) - elif model_name == 'seresnet34': - model = seresnet34(num_classes=num_classes, pretrained=pretrained, **kwargs) - elif model_name == 'seresnet50': - model = seresnet50(num_classes=num_classes, pretrained=pretrained, **kwargs) - elif model_name == 'seresnet101': - model = seresnet101(num_classes=num_classes, pretrained=pretrained, **kwargs) - elif model_name == 'seresnet152': - model = seresnet152(num_classes=num_classes, pretrained=pretrained, **kwargs) - elif model_name == 'seresnext26_32x4d': - model = seresnext26_32x4d(num_classes=num_classes, pretrained=pretrained, **kwargs) - elif model_name == 'seresnext50_32x4d': - model = seresnext50_32x4d(num_classes=num_classes, pretrained=pretrained, **kwargs) - elif model_name == 'seresnext101_32x4d': - model = seresnext101_32x4d(num_classes=num_classes, pretrained=pretrained, **kwargs) - elif model_name == 'resnext50_32x4d': - model = resnext50_32x4d(num_classes=num_classes, pretrained=pretrained, **kwargs) - elif model_name == 'resnext101_32x4d': - model = resnext101_32x4d(num_classes=num_classes, pretrained=pretrained, **kwargs) - elif model_name == 'resnext101_64x4d': - model = resnext101_32x4d(num_classes=num_classes, pretrained=pretrained, **kwargs) - elif model_name == 'resnext152_32x4d': - model = resnext152_32x4d(num_classes=num_classes, pretrained=pretrained, **kwargs) - elif model_name == 'xception': - model = xception(num_classes=num_classes, pretrained=pretrained) - elif model_name == 'pnasnet5large': - model = pnasnet5large(num_classes=num_classes, pretrained=pretrained) + margs = dict(num_classes=num_classes, in_chans=in_chans, pretrained=pretrained) + + if model_name in globals(): + create_fn = globals()[model_name] + model = create_fn(**margs, **kwargs) else: - assert False and "Invalid model" + raise RuntimeError('Unknown model (%s)' % model_name) if checkpoint_path and not pretrained: - print(checkpoint_path) load_checkpoint(model, checkpoint_path) return model - - -def load_checkpoint(model, checkpoint_path): - if checkpoint_path and os.path.isfile(checkpoint_path): - print("=> Loading checkpoint '{}'".format(checkpoint_path)) - checkpoint = torch.load(checkpoint_path) - if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: - new_state_dict = OrderedDict() - for k, v in checkpoint['state_dict'].items(): - if k.startswith('module'): - name = k[7:] # remove `module.` - else: - name = k - new_state_dict[name] = v - model.load_state_dict(new_state_dict) - else: - model.load_state_dict(checkpoint) - print("=> Loaded checkpoint '{}'".format(checkpoint_path)) - return True - else: - print("=> Error: No checkpoint found at '{}'".format(checkpoint_path)) - return False diff --git a/models/pnasnet.py b/models/pnasnet.py index 6aebb772..61542af5 100644 --- a/models/pnasnet.py +++ b/models/pnasnet.py @@ -3,29 +3,23 @@ from collections import OrderedDict import torch import torch.nn as nn -import torch.utils.model_zoo as model_zoo +import torch.nn.functional as F -pretrained_settings = { +from models.helpers import load_pretrained +from models.adaptive_avgmax_pool import SelectAdaptivePool2d + +default_cfgs = { 'pnasnet5large': { - 'imagenet': { - 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/pnasnet5large-bf079911.pth', - 'input_space': 'RGB', - 'input_size': [3, 331, 331], - 'input_range': [0, 1], - 'mean': [0.5, 0.5, 0.5], - 'std': [0.5, 0.5, 0.5], - 'num_classes': 1000 - }, - 'imagenet+background': { - 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/pnasnet5large-bf079911.pth', - 'input_space': 'RGB', - 'input_size': [3, 331, 331], - 'input_range': [0, 1], - 'mean': [0.5, 0.5, 0.5], - 'std': [0.5, 0.5, 0.5], - 'num_classes': 1001 - } - } + 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/pnasnet5large-bf079911.pth', + 'input_size': (3, 331, 331), + 'pool_size': (11, 11), + 'mean': (0.5, 0.5, 0.5), + 'std': (0.5, 0.5, 0.5), + 'crop_pct': 0.8975, + 'num_classes': 1001, + 'first_conv': 'conv_0.conv', + 'classifier': 'last_linear', + }, } @@ -288,13 +282,14 @@ class Cell(CellBase): class PNASNet5Large(nn.Module): - def __init__(self, num_classes=1001): + def __init__(self, num_classes=1001, in_chans=3, drop_rate=0.5, global_pool='avg'): super(PNASNet5Large, self).__init__() self.num_classes = num_classes self.num_features = 4320 + self.drop_rate = drop_rate self.conv_0 = nn.Sequential(OrderedDict([ - ('conv', nn.Conv2d(3, 96, kernel_size=3, stride=2, bias=False)), + ('conv', nn.Conv2d(in_chans, 96, kernel_size=3, stride=2, bias=False)), ('bn', nn.BatchNorm2d(96, eps=0.001)) ])) self.cell_stem_0 = CellStem0(in_channels_left=96, out_channels_left=54, @@ -334,18 +329,18 @@ class PNASNet5Large(nn.Module): self.cell_11 = Cell(in_channels_left=4320, out_channels_left=864, in_channels_right=4320, out_channels_right=864) self.relu = nn.ReLU() - self.avg_pool = nn.AvgPool2d(11, stride=1, padding=0) - self.dropout = nn.Dropout(0.5) - self.last_linear = nn.Linear(self.num_features, num_classes) + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) + self.last_linear = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes) def get_classifier(self): return self.last_linear - def reset_classifier(self, num_classes): + def reset_classifier(self, num_classes, global_pool='avg'): self.num_classes = num_classes + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) del self.last_linear if num_classes: - self.last_linear = nn.Linear(self.num_features, num_classes) + self.last_linear = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes) else: self.last_linear = None @@ -367,38 +362,27 @@ class PNASNet5Large(nn.Module): x_cell_11 = self.cell_11(x_cell_9, x_cell_10) x = self.relu(x_cell_11) if pool: - x = self.avg_pool(x) + x = self.global_pool(x) x = x.view(x.size(0), -1) return x def forward(self, input): x = self.forward_features(input) - x = self.dropout(x) + if self.drop_rate > 0: + x = F.dropout(x, self.drop_rate, training=self.training) x = self.last_linear(x) return x -def pnasnet5large(num_classes=1001, pretrained='imagenet'): +def pnasnet5large(num_classes=1000, in_chans=3, pretrained='imagenet'): r"""PNASNet-5 model architecture from the `"Progressive Neural Architecture Search" `_ paper. """ + default_cfg = default_cfgs['pnasnet5large'] + model = PNASNet5Large(num_classes=1000, in_chans=in_chans) + model.default_cfg = default_cfg if pretrained: - settings = pretrained_settings['pnasnet5large']['imagenet'] - assert num_classes == settings[ - 'num_classes'], 'num_classes should be {}, but is {}'.format( - settings['num_classes'], num_classes) + load_pretrained(model, default_cfg, num_classes, in_chans) - # both 'imagenet'&'imagenet+background' are loaded from same parameters - model = PNASNet5Large(num_classes=1001) - model.load_state_dict(model_zoo.load_url(settings['url'])) - - #if pretrained == 'imagenet': - new_last_linear = nn.Linear(model.last_linear.in_features, 1000) - new_last_linear.weight.data = model.last_linear.weight.data[1:] - new_last_linear.bias.data = model.last_linear.bias.data[1:] - model.last_linear = new_last_linear - - else: - model = PNASNet5Large(num_classes=num_classes) return model diff --git a/models/resnet.py b/models/resnet.py index 08d836e7..3a7e63cc 100644 --- a/models/resnet.py +++ b/models/resnet.py @@ -6,17 +6,33 @@ import torch import torch.nn as nn import torch.nn.functional as F import math -import torch.utils.model_zoo as model_zoo -from .adaptive_avgmax_pool import SelectAdaptivePool2d +from models.helpers import load_pretrained +from models.adaptive_avgmax_pool import SelectAdaptivePool2d +from data.transforms import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152'] +__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', + 'resnext50_32x4d', 'resnext101_32x4d', 'resnext101_64x4d', 'resnext152_32x4d'] -model_urls = { - 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', - 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', - 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', - 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', - 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', + +def _cfg(url=''): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'crop_pct': 0.875, + 'first_conv': 'conv1', 'classifier': 'fc', + } + + +default_cfgs = { + 'resnet18': _cfg(url='https://download.pytorch.org/models/resnet18-5c106cde.pth'), + 'resnet34': _cfg(url='https://download.pytorch.org/models/resnet34-333f7ec4.pth'), + 'resnet50': _cfg(url='https://download.pytorch.org/models/resnet50-19c8e357.pth'), + 'resnet101': _cfg(url='https://download.pytorch.org/models/resnet101-5d3b4d8f.pth'), + 'resnet152': _cfg(url='https://download.pytorch.org/models/resnet152-b121ed2d.pth'), + 'resnext50_32x4d': _cfg(url=''), + 'resnext101_32x4d': _cfg(url=''), + 'resnext101_64x4d': _cfg(url=''), + 'resnext152_32x4d': _cfg(url=''), } @@ -116,7 +132,7 @@ class Bottleneck(nn.Module): class ResNet(nn.Module): - def __init__(self, block, layers, num_classes=1000, + def __init__(self, block, layers, num_classes=1000, in_chans=3, cardinality=1, base_width=64, drop_rate=0.0, block_drop_rate=0.0, global_pool='avg'): @@ -127,7 +143,7 @@ class ResNet(nn.Module): self.drop_rate = drop_rate self.expansion = block.expansion super(ResNet, self).__init__() - self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) + self.conv1 = nn.Conv2d(in_chans, 64, kernel_size=7, stride=2, padding=3, bias=False) self.bn1 = nn.BatchNorm2d(64) self.relu = nn.ReLU(inplace=True) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) @@ -197,109 +213,108 @@ class ResNet(nn.Module): return x -def resnet18(pretrained=False, **kwargs): +def resnet18(num_classes=1000, in_chans=3, pretrained=False, **kwargs): """Constructs a ResNet-18 model. - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet """ - model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) + default_cfg = default_cfgs['resnet18'] + model = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg if pretrained: - model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) + load_pretrained(model, default_cfg, num_classes, in_chans) return model -def resnet34(pretrained=False, **kwargs): +def resnet34(num_classes=1000, in_chans=3, pretrained=False, **kwargs): """Constructs a ResNet-34 model. - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet """ - model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) + default_cfg = default_cfgs['resnet34'] + model = ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg if pretrained: - model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) + load_pretrained(model, default_cfg, num_classes, in_chans) return model -def resnet50(pretrained=False, **kwargs): +def resnet50(num_classes=1000, in_chans=3, pretrained=False, **kwargs): """Constructs a ResNet-50 model. - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet """ - model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) + default_cfg = default_cfgs['resnet50'] + model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg if pretrained: - model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) + load_pretrained(model, default_cfg, num_classes, in_chans) return model -def resnet101(pretrained=False, **kwargs): +def resnet101(num_classes=1000, in_chans=3, pretrained=False, **kwargs): """Constructs a ResNet-101 model. - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet """ - model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) + default_cfg = default_cfgs['resnet101'] + model = ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg if pretrained: - model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) + load_pretrained(model, default_cfg, num_classes, in_chans) return model -def resnet152(pretrained=False, **kwargs): +def resnet152(num_classes=1000, in_chans=3, pretrained=False, **kwargs): """Constructs a ResNet-152 model. - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet """ - model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) + default_cfg = default_cfgs['resnet152'] + model = ResNet(Bottleneck, [3, 8, 36, 3], num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg if pretrained: - model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) + load_pretrained(model, default_cfg, num_classes, in_chans) return model -def resnext50_32x4d(cardinality=32, base_width=4, pretrained=False, **kwargs): +def resnext50_32x4d(num_classes=1000, in_chans=3, pretrained=False, **kwargs): """Constructs a ResNeXt50-32x4d model. - - Args: - cardinality (int): Cardinality of the aggregated transform - base_width (int): Base width of the grouped convolution """ + default_cfg = default_cfgs['resnext50_32x4d2'] model = ResNet( - Bottleneck, [3, 4, 6, 3], cardinality=cardinality, base_width=base_width, **kwargs) + Bottleneck, [3, 4, 6, 3], cardinality=32, base_width=4, + num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) return model -def resnext101_32x4d(cardinality=32, base_width=4, pretrained=False, **kwargs): +def resnext101_32x4d(num_classes=1000, in_chans=3, pretrained=False, **kwargs): """Constructs a ResNeXt-101 model. - - Args: - cardinality (int): Cardinality of the aggregated transform - base_width (int): Base width of the grouped convolution """ + default_cfg = default_cfgs['resnext101_32x4d'] model = ResNet( - Bottleneck, [3, 4, 23, 3], cardinality=cardinality, base_width=base_width, **kwargs) + Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=4, + num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) return model -def resnext101_64x4d(cardinality=64, base_width=4, pretrained=False, **kwargs): +def resnext101_64x4d(num_classes=1000, in_chans=3, pretrained=False, **kwargs): """Constructs a ResNeXt101-64x4d model. - - Args: - cardinality (int): Cardinality of the aggregated transform - base_width (int): Base width of the grouped convolution """ + default_cfg = default_cfgs['resnext101_32x4d'] model = ResNet( - Bottleneck, [3, 4, 23, 3], cardinality=cardinality, base_width=base_width, **kwargs) + Bottleneck, [3, 4, 23, 3], cardinality=64, base_width=4, + num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) return model -def resnext152_32x4d(cardinality=32, base_width=4, pretrained=False, **kwargs): +def resnext152_32x4d(num_classes=1000, in_chans=3, pretrained=False, **kwargs): """Constructs a ResNeXt152-32x4d model. - - Args: - cardinality (int): Cardinality of the aggregated transform - base_width (int): Base width of the grouped convolution """ + default_cfg = default_cfgs['resnext152_32x4d'] model = ResNet( - Bottleneck, [3, 8, 36, 3], cardinality=cardinality, base_width=base_width, **kwargs) + Bottleneck, [3, 8, 36, 3], cardinality=32, base_width=4, + num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) return model diff --git a/models/senet.py b/models/senet.py index bacec15f..03df6432 100644 --- a/models/senet.py +++ b/models/senet.py @@ -8,21 +8,40 @@ import math import torch.nn as nn import torch.nn.functional as F -from torch.utils import model_zoo + +from models.helpers import load_pretrained from models.adaptive_avgmax_pool import SelectAdaptivePool2d +from data.transforms import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD __all__ = ['SENet', 'senet154', 'seresnet50', 'seresnet101', 'seresnet152', 'seresnext50_32x4d', 'seresnext101_32x4d'] -model_urls = { - 'senet154': 'http://data.lip6.fr/cadene/pretrainedmodels/senet154-c7b49a05.pth', - 'seresnet18': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth', - 'seresnet34': 'https://www.dropbox.com/s/q31ccy22aq0fju7/seresnet34-a4004e63.pth?dl=1', - 'seresnet50': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth', - 'seresnet101': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet101-7e38fcc6.pth', - 'seresnet152': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet152-d17c99b7.pth', - 'seresnext50_32x4d': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth', - 'seresnext101_32x4d': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth', + +def _cfg(url=''): + return { + 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 244), 'pool_size': (7, 7), + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'crop_pct': 0.875, + 'first_conv': 'layer0.conv1', 'classifier': 'last_linear', + } + + +default_cfgs = { + 'senet154': + _cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/senet154-c7b49a05.pth'), + 'seresnet18': + _cfg(url=''), + 'seresnet34': + _cfg(url='https://www.dropbox.com/s/q31ccy22aq0fju7/seresnet34-a4004e63.pth?dl=1'), + 'seresnet50': + _cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth'), + 'seresnet101': + _cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/se_resnet101-7e38fcc6.pth'), + 'seresnet152': + _cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/se_resnet152-d17c99b7.pth'), + 'seresnext50_32x4d': + _cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth'), + 'seresnext101_32x4d': + _cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth'), } @@ -197,7 +216,7 @@ class SEResNetBlock(nn.Module): class SENet(nn.Module): def __init__(self, block, layers, groups, reduction, drop_rate=0.2, - inchans=3, inplanes=128, input_3x3=True, downsample_kernel_size=3, + in_chans=3, inplanes=128, input_3x3=True, downsample_kernel_size=3, downsample_padding=1, num_classes=1000, global_pool='avg'): """ Parameters @@ -247,7 +266,7 @@ class SENet(nn.Module): self.num_classes = num_classes if input_3x3: layer0_modules = [ - ('conv1', nn.Conv2d(inchans, 64, 3, stride=2, padding=1, bias=False)), + ('conv1', nn.Conv2d(in_chans, 64, 3, stride=2, padding=1, bias=False)), ('bn1', nn.BatchNorm2d(64)), ('relu1', nn.ReLU(inplace=True)), ('conv2', nn.Conv2d(64, 64, 3, stride=1, padding=1, bias=False)), @@ -260,7 +279,7 @@ class SENet(nn.Module): else: layer0_modules = [ ('conv1', nn.Conv2d( - inchans, inplanes, kernel_size=7, stride=2, padding=3, bias=False)), + in_chans, inplanes, kernel_size=7, stride=2, padding=3, bias=False)), ('bn1', nn.BatchNorm2d(inplanes)), ('relu1', nn.ReLU(inplace=True)), ] @@ -368,99 +387,107 @@ class SENet(nn.Module): return x -def _load_pretrained(model, url, inchans=3): - state_dict = model_zoo.load_url(url) - if inchans == 1: - conv1_weight = state_dict['conv1.weight'] - state_dict['conv1.weight'] = conv1_weight.sum(dim=1, keepdim=True) - elif inchans != 3: - assert False, "Invalid inchans for pretrained weights" - model.load_state_dict(state_dict) - - -def senet154(num_classes=1000, inchans=3, pretrained='imagenet', **kwargs): - model = SENet(SEBottleneck, [3, 8, 36, 3], groups=64, reduction=16, - num_classes=num_classes, **kwargs) - if pretrained: - _load_pretrained(model, model_urls['senet154'], inchans) - return model - - -def seresnet18(num_classes=1000, inchans=3, pretrained='imagenet', **kwargs): +def seresnet18(num_classes=1000, in_chans=3, pretrained=False, **kwargs): + default_cfg = default_cfgs['seresnet18'] model = SENet(SEResNetBlock, [2, 2, 2, 2], groups=1, reduction=16, inplanes=64, input_3x3=False, downsample_kernel_size=1, downsample_padding=0, - num_classes=num_classes, **kwargs) + num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg if pretrained: - _load_pretrained(model, model_urls['seresnet18'], inchans) + load_pretrained(model, default_cfg, num_classes, in_chans) return model -def seresnet34(num_classes=1000, inchans=3, pretrained='imagenet', **kwargs): +def seresnet34(num_classes=1000, in_chans=3, pretrained=False, **kwargs): + default_cfg = default_cfgs['seresnet34'] model = SENet(SEResNetBlock, [3, 4, 6, 3], groups=1, reduction=16, inplanes=64, input_3x3=False, downsample_kernel_size=1, downsample_padding=0, - num_classes=num_classes, **kwargs) + num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg if pretrained: - _load_pretrained(model, model_urls['seresnet34'], inchans) + load_pretrained(model, default_cfg, num_classes, in_chans) return model -def seresnet50(num_classes=1000, inchans=3, pretrained='imagenet', **kwargs): +def seresnet50(num_classes=1000, in_chans=3, pretrained=False, **kwargs): + default_cfg = default_cfgs['seresnet50'] model = SENet(SEResNetBottleneck, [3, 4, 6, 3], groups=1, reduction=16, inplanes=64, input_3x3=False, downsample_kernel_size=1, downsample_padding=0, - num_classes=num_classes, **kwargs) + num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg if pretrained: - _load_pretrained(model, model_urls['seresnet50'], inchans) + load_pretrained(model, default_cfg, num_classes, in_chans) return model -def seresnet101(num_classes=1000, inchans=3, pretrained='imagenet', **kwargs): +def seresnet101(num_classes=1000, in_chans=3, pretrained=False, **kwargs): + default_cfg = default_cfgs['seresnet101'] model = SENet(SEResNetBottleneck, [3, 4, 23, 3], groups=1, reduction=16, inplanes=64, input_3x3=False, downsample_kernel_size=1, downsample_padding=0, - num_classes=num_classes, **kwargs) + num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg if pretrained: - _load_pretrained(model, model_urls['seresnet101'], inchans) + load_pretrained(model, default_cfg, num_classes, in_chans) return model -def seresnet152(num_classes=1000, inchans=3, pretrained='imagenet', **kwargs): +def seresnet152(num_classes=1000, in_chans=3, pretrained=False, **kwargs): + default_cfg = default_cfgs['seresnet152'] model = SENet(SEResNetBottleneck, [3, 8, 36, 3], groups=1, reduction=16, inplanes=64, input_3x3=False, downsample_kernel_size=1, downsample_padding=0, - num_classes=num_classes, **kwargs) + num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg if pretrained: - _load_pretrained(model, model_urls['seresnet152'], inchans) + load_pretrained(model, default_cfg, num_classes, in_chans) return model -def seresnext26_32x4d(num_classes=1000, inchans=3, pretrained='imagenet', **kwargs): +def senet154(num_classes=1000, in_chans=3, pretrained=False, **kwargs): + default_cfg = default_cfgs['senet154'] + model = SENet(SEBottleneck, [3, 8, 36, 3], groups=64, reduction=16, + num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + +def seresnext26_32x4d(num_classes=1000, in_chans=3, pretrained=False, **kwargs): + default_cfg = default_cfgs['seresnext26_32x4d'] model = SENet(SEResNeXtBottleneck, [2, 2, 2, 2], groups=32, reduction=16, inplanes=64, input_3x3=False, downsample_kernel_size=1, downsample_padding=0, - num_classes=num_classes, **kwargs) + num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg if pretrained: - _load_pretrained(model, model_urls['se_resnext26_32x4d'], inchans) + load_pretrained(model, default_cfg, num_classes, in_chans) return model -def seresnext50_32x4d(num_classes=1000, inchans=3, pretrained='imagenet', **kwargs): +def seresnext50_32x4d(num_classes=1000, in_chans=3, pretrained=False, **kwargs): + default_cfg = default_cfgs['seresnext50_32x4d'] model = SENet(SEResNeXtBottleneck, [3, 4, 6, 3], groups=32, reduction=16, inplanes=64, input_3x3=False, downsample_kernel_size=1, downsample_padding=0, - num_classes=num_classes, **kwargs) + num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg if pretrained: - _load_pretrained(model, model_urls['seresnext50_32x4d'], inchans) + load_pretrained(model, default_cfg, num_classes, in_chans) return model -def seresnext101_32x4d(num_classes=1000, inchans=3, pretrained='imagenet', **kwargs): +def seresnext101_32x4d(num_classes=1000, in_chans=3, pretrained=False, **kwargs): + default_cfg = default_cfgs['seresnext101_32x4d'] model = SENet(SEResNeXtBottleneck, [3, 4, 23, 3], groups=32, reduction=16, inplanes=64, input_3x3=False, downsample_kernel_size=1, downsample_padding=0, - num_classes=num_classes, **kwargs) + num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg if pretrained: - _load_pretrained(model, model_urls['seresnext101_32x4d'], inchans) + load_pretrained(model, default_cfg, num_classes, in_chans) return model diff --git a/models/test_time_pool.py b/models/test_time_pool.py index 269f15f8..ee95b81e 100644 --- a/models/test_time_pool.py +++ b/models/test_time_pool.py @@ -25,3 +25,12 @@ class TestTimePoolHead(nn.Module): x = adaptive_avgmax_pool2d(x, 1) return x.view(x.size(0), -1) + +def apply_test_time_pool(model, args): + test_time_pool = False + if args.img_size > model.default_cfg['input_size'][-1] and not args.no_test_pool: + print('Target input size (%d) > pretrained default (%d), using test time pooling' % + (args.img_size, model.default_cfg['input_size'][-1])) + model = TestTimePoolHead(model, original_pool=model.default_cfg['pool_size']) + test_time_pool = True + return model, test_time_pool diff --git a/models/xception.py b/models/xception.py index 97b3947d..4a981979 100644 --- a/models/xception.py +++ b/models/xception.py @@ -26,24 +26,24 @@ import math import torch import torch.nn as nn import torch.nn.functional as F -import torch.utils.model_zoo as model_zoo -from torch.nn import init + +from models.helpers import load_pretrained +from models.adaptive_avgmax_pool import select_adaptive_pool2d + __all__ = ['xception'] -pretrained_config = { +default_cfgs = { 'xception': { - 'imagenet': { - 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/xception-43020ad28.pth', - 'input_space': 'RGB', - 'input_size': [3, 299, 299], - 'input_range': [0, 1], - 'mean': [0.5, 0.5, 0.5], - 'std': [0.5, 0.5, 0.5], - 'num_classes': 1000, - 'scale': 0.8975 + 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/xception-43020ad28.pth', + 'input_size': (3, 299, 299), + 'mean': (0.5, 0.5, 0.5), + 'std': (0.5, 0.5, 0.5), + 'num_classes': 1000, + 'crop_pct': 0.8975, + 'first_conv': 'conv1', + 'classifier': 'fc' # The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299 - } } } @@ -120,16 +120,18 @@ class Xception(nn.Module): https://arxiv.org/pdf/1610.02357.pdf """ - def __init__(self, num_classes=1000): + def __init__(self, num_classes=1000, in_chans=3, drop_rate=0., global_pool='avg'): """ Constructor Args: num_classes: number of classes """ super(Xception, self).__init__() + self.drop_rate = drop_rate + self.global_pool = global_pool self.num_classes = num_classes self.num_features = 2048 - self.conv1 = nn.Conv2d(3, 32, 3, 2, 0, bias=False) + self.conv1 = nn.Conv2d(in_chans, 32, 3, 2, 0, bias=False) self.bn1 = nn.BatchNorm2d(32) self.relu = nn.ReLU(inplace=True) @@ -173,8 +175,9 @@ class Xception(nn.Module): def get_classifier(self): return self.fc - def reset_classifier(self, num_classes): + def reset_classifier(self, num_classes, global_pool='avg'): self.num_classes = num_classes + self.global_pool = global_pool del self.fc if num_classes: self.fc = nn.Linear(self.num_features, num_classes) @@ -212,24 +215,23 @@ class Xception(nn.Module): x = self.relu(x) if pool: - x = F.adaptive_avg_pool2d(x, (1, 1)) + x = select_adaptive_pool2d(x, pool_type=self.global_pool) x = x.view(x.size(0), -1) return x def forward(self, input): x = self.forward_features(input) + if self.drop_rate: + F.dropout(x, self.drop_rate, training=self.training) x = self.fc(x) return x -def xception(num_classes=1000, pretrained=False): - model = Xception(num_classes=num_classes) +def xception(num_classes=1000, in_chans=3, pretrained=False, **kwargs): + default_cfg = default_cfgs['xception'] + model = Xception(num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg if pretrained: - config = pretrained_config['xception']['imagenet'] - assert num_classes == config['num_classes'], \ - "num_classes should be {}, but is {}".format(config['num_classes'], num_classes) - - model = Xception(num_classes=num_classes) - model.load_state_dict(model_zoo.load_url(config['url'])) + load_pretrained(model, default_cfg, num_classes, in_chans) return model diff --git a/optim/__init__.py b/optim/__init__.py index d8995736..62ab43ba 100644 --- a/optim/__init__.py +++ b/optim/__init__.py @@ -1,2 +1,3 @@ from optim.adabound import AdaBound -from optim.nadam import Nadam \ No newline at end of file +from optim.nadam import Nadam +from optim.optim_factory import create_optimizer \ No newline at end of file diff --git a/optim/optim_factory.py b/optim/optim_factory.py new file mode 100644 index 00000000..d207dbcf --- /dev/null +++ b/optim/optim_factory.py @@ -0,0 +1,30 @@ +from torch import optim as optim +from optim import Nadam, AdaBound + + +def create_optimizer(args, parameters): + if args.opt.lower() == 'sgd': + optimizer = optim.SGD( + parameters, lr=args.lr, + momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True) + elif args.opt.lower() == 'adam': + optimizer = optim.Adam( + parameters, lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps) + elif args.opt.lower() == 'nadam': + optimizer = Nadam( + parameters, lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps) + elif args.opt.lower() == 'adabound': + optimizer = AdaBound( + parameters, lr=args.lr / 100, weight_decay=args.weight_decay, eps=args.opt_eps, + final_lr=args.lr) + elif args.opt.lower() == 'adadelta': + optimizer = optim.Adadelta( + parameters, lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps) + elif args.opt.lower() == 'rmsprop': + optimizer = optim.RMSprop( + parameters, lr=args.lr, alpha=0.9, eps=args.opt_eps, + momentum=args.momentum, weight_decay=args.weight_decay) + else: + assert False and "Invalid optimizer" + raise ValueError + return optimizer diff --git a/scheduler/__init__.py b/scheduler/__init__.py index 8242163f..4d84a052 100644 --- a/scheduler/__init__.py +++ b/scheduler/__init__.py @@ -1,4 +1,5 @@ -from .cosine_lr import CosineLRScheduler -from .plateau_lr import PlateauLRScheduler -from .step_lr import StepLRScheduler -from .tanh_lr import TanhLRScheduler \ No newline at end of file +from scheduler.cosine_lr import CosineLRScheduler +from scheduler.plateau_lr import PlateauLRScheduler +from scheduler.step_lr import StepLRScheduler +from scheduler.tanh_lr import TanhLRScheduler +from scheduler.scheduler_factory import create_scheduler \ No newline at end of file diff --git a/scheduler/scheduler_factory.py b/scheduler/scheduler_factory.py new file mode 100644 index 00000000..55c4927d --- /dev/null +++ b/scheduler/scheduler_factory.py @@ -0,0 +1,43 @@ +from scheduler.cosine_lr import CosineLRScheduler +from scheduler.plateau_lr import PlateauLRScheduler +from scheduler.tanh_lr import TanhLRScheduler +from scheduler.step_lr import StepLRScheduler + + +def create_scheduler(args, optimizer): + num_epochs = args.epochs + #FIXME expose cycle parms of the scheduler config to arguments + if args.sched == 'cosine': + lr_scheduler = CosineLRScheduler( + optimizer, + t_initial=num_epochs, + t_mul=1.0, + lr_min=1e-5, + decay_rate=args.decay_rate, + warmup_lr_init=args.warmup_lr, + warmup_t=args.warmup_epochs, + cycle_limit=1, + t_in_epochs=True, + ) + num_epochs = lr_scheduler.get_cycle_length() + 10 + elif args.sched == 'tanh': + lr_scheduler = TanhLRScheduler( + optimizer, + t_initial=num_epochs, + t_mul=1.0, + lr_min=1e-5, + warmup_lr_init=args.warmup_lr, + warmup_t=args.warmup_epochs, + cycle_limit=1, + t_in_epochs=True, + ) + num_epochs = lr_scheduler.get_cycle_length() + 10 + else: + lr_scheduler = StepLRScheduler( + optimizer, + decay_t=args.decay_epochs, + decay_rate=args.decay_rate, + warmup_lr_init=args.warmup_lr, + warmup_t=args.warmup_epochs, + ) + return lr_scheduler, num_epochs diff --git a/train.py b/train.py index 5494b3ff..d2931ca9 100644 --- a/train.py +++ b/train.py @@ -1,7 +1,6 @@ import argparse import time -from collections import OrderedDict from datetime import datetime try: @@ -12,17 +11,14 @@ except ImportError: has_apex = False from data import * -from models import model_factory +from models import create_model, resume_checkpoint from utils import * -from optim import Nadam, AdaBound from loss import LabelSmoothingCrossEntropy -import scheduler +from optim import create_optimizer +from scheduler import create_scheduler import torch import torch.nn as nn -import torch.nn.functional as F -import torch.optim as optim -import torch.utils.data as data import torch.distributed as dist import torchvision.utils @@ -33,6 +29,8 @@ parser.add_argument('data', metavar='DIR', help='path to dataset') parser.add_argument('--model', default='resnet101', type=str, metavar='MODEL', help='Name of model to train (default: "countception"') +parser.add_argument('--num-classes', type=int, default=1000, metavar='N', + help='number of label classes (default: 1000)') parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER', help='Optimizer (default: "sgd"') parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON', @@ -120,10 +118,13 @@ def main(): r = torch.distributed.get_rank() if args.distributed: - print('Training in distributed mode with %d processes, 1 GPU per process. Process %d.' - % (args.world_size, r)) + print('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.' + % (r, args.world_size)) else: - print('Training with a single process with %d GPUs.' % args.num_gpu) + print('Training with a single process on %d GPUs.' % args.num_gpu) + + # FIXME seed handling for multi-process distributed? + torch.manual_seed(args.seed) output_dir = '' if args.local_rank == 0: @@ -137,80 +138,21 @@ def main(): str(args.img_size)]) output_dir = get_outdir(output_base, 'train', exp_name) - batch_size = args.batch_size - torch.manual_seed(args.seed) - - data_mean, data_std = get_model_meanstd(args.model) - - dataset_train = Dataset(os.path.join(args.data, 'train')) - - loader_train = create_loader( - dataset_train, - img_size=args.img_size, - batch_size=batch_size, - is_training=True, - use_prefetcher=True, - random_erasing=0.3, - mean=data_mean, - std=data_std, - num_workers=args.workers, - distributed=args.distributed, - ) - - dataset_eval = Dataset(os.path.join(args.data, 'validation')) - - loader_eval = create_loader( - dataset_eval, - img_size=args.img_size, - batch_size=4 * args.batch_size, - is_training=False, - use_prefetcher=True, - mean=data_mean, - std=data_std, - num_workers=args.workers, - distributed=args.distributed, - ) - - model = model_factory.create_model( + model = create_model( args.model, pretrained=args.pretrained, - num_classes=1000, + num_classes=args.num_classes, drop_rate=args.drop, global_pool=args.gp, checkpoint_path=args.initial_checkpoint) + data_mean, data_std = get_mean_and_std(model, args) + # optionally resume from a checkpoint - start_epoch = 0 if args.start_epoch is None else args.start_epoch + start_epoch = 0 optimizer_state = None if args.resume: - if os.path.isfile(args.resume): - print("=> loading checkpoint '{}'".format(args.resume)) - checkpoint = torch.load(args.resume) - if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: - new_state_dict = OrderedDict() - for k, v in checkpoint['state_dict'].items(): - if k.startswith('module'): - name = k[7:] # remove `module.` - else: - name = k - new_state_dict[name] = v - model.load_state_dict(new_state_dict) - if 'optimizer' in checkpoint: - optimizer_state = checkpoint['optimizer'] - print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch'])) - start_epoch = checkpoint['epoch'] if args.start_epoch is None else args.start_epoch - else: - model.load_state_dict(checkpoint) - else: - print("=> no checkpoint found at '{}'".format(args.resume)) - return False - - if args.smoothing: - train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda() - validate_loss_fn = nn.CrossEntropyLoss().cuda() - else: - train_loss_fn = nn.CrossEntropyLoss().cuda() - validate_loss_fn = train_loss_fn + start_epoch, optimizer_state = resume_checkpoint(model, args.resume, args.start_epoch) if args.num_gpu > 1: if args.amp: @@ -237,9 +179,55 @@ def main(): model = DDP(model, delay_allreduce=True) lr_scheduler, num_epochs = create_scheduler(args, optimizer) + if start_epoch > 0: + lr_scheduler.step(start_epoch) if args.local_rank == 0: print('Scheduled epochs: ', num_epochs) + train_dir = os.path.join(args.data, 'train') + if not os.path.exists(train_dir): + print('Error: training folder does not exist at: %s' % train_dir) + exit(1) + dataset_train = Dataset(train_dir) + + loader_train = create_loader( + dataset_train, + img_size=args.img_size, + batch_size=args.batch_size, + is_training=True, + use_prefetcher=True, + random_erasing=0.3, + mean=data_mean, + std=data_std, + num_workers=args.workers, + distributed=args.distributed, + ) + + eval_dir = os.path.join(args.data, 'validation') + if not os.path.isdir(eval_dir): + print('Error: validation folder does not exist at: %s' % eval_dir) + exit(1) + dataset_eval = Dataset(eval_dir) + + loader_eval = create_loader( + dataset_eval, + img_size=args.img_size, + batch_size=4 * args.batch_size, + is_training=False, + use_prefetcher=True, + mean=data_mean, + std=data_std, + num_workers=args.workers, + distributed=args.distributed, + ) + + if args.smoothing: + train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda() + validate_loss_fn = nn.CrossEntropyLoss().cuda() + else: + train_loss_fn = nn.CrossEntropyLoss().cuda() + validate_loss_fn = train_loss_fn + eval_metric = args.eval_metric saver = None if output_dir: @@ -429,76 +417,9 @@ def validate(model, loader, loss_fn, args): return metrics -def create_optimizer(args, parameters): - if args.opt.lower() == 'sgd': - optimizer = optim.SGD( - parameters, lr=args.lr, - momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True) - elif args.opt.lower() == 'adam': - optimizer = optim.Adam( - parameters, lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps) - elif args.opt.lower() == 'nadam': - optimizer = Nadam( - parameters, lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps) - elif args.opt.lower() == 'adabound': - optimizer = AdaBound( - parameters, lr=args.lr / 100, weight_decay=args.weight_decay, eps=args.opt_eps, - final_lr=args.lr) - elif args.opt.lower() == 'adadelta': - optimizer = optim.Adadelta( - parameters, lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps) - elif args.opt.lower() == 'rmsprop': - optimizer = optim.RMSprop( - parameters, lr=args.lr, alpha=0.9, eps=args.opt_eps, - momentum=args.momentum, weight_decay=args.weight_decay) - else: - assert False and "Invalid optimizer" - raise ValueError - return optimizer - - -def create_scheduler(args, optimizer): - num_epochs = args.epochs - #FIXME expose cycle parms of the scheduler config to arguments - if args.sched == 'cosine': - lr_scheduler = scheduler.CosineLRScheduler( - optimizer, - t_initial=num_epochs, - t_mul=1.0, - lr_min=1e-5, - decay_rate=args.decay_rate, - warmup_lr_init=args.warmup_lr, - warmup_t=args.warmup_epochs, - cycle_limit=1, - t_in_epochs=True, - ) - num_epochs = lr_scheduler.get_cycle_length() + 10 - elif args.sched == 'tanh': - lr_scheduler = scheduler.TanhLRScheduler( - optimizer, - t_initial=num_epochs, - t_mul=1.0, - lr_min=1e-5, - warmup_lr_init=args.warmup_lr, - warmup_t=args.warmup_epochs, - cycle_limit=1, - t_in_epochs=True, - ) - num_epochs = lr_scheduler.get_cycle_length() + 10 - else: - lr_scheduler = scheduler.StepLRScheduler( - optimizer, - decay_t=args.decay_epochs, - decay_rate=args.decay_rate, - warmup_lr_init=args.warmup_lr, - warmup_t=args.warmup_epochs, - ) - return lr_scheduler, num_epochs - - def reduce_tensor(tensor, n): rt = tensor.clone() - dist.all_reduce(rt, op=dist.reduce_op.SUM) + dist.all_reduce(rt, op=dist.ReduceOp.SUM) rt /= n return rt diff --git a/validate.py b/validate.py index 1e82a1fc..07aac71c 100644 --- a/validate.py +++ b/validate.py @@ -6,13 +6,14 @@ import argparse import os import time import torch -import torch.backends.cudnn as cudnn import torch.nn as nn import torch.nn.parallel -from models import create_model, load_checkpoint, TestTimePoolHead -from data import Dataset, create_loader, get_model_meanstd +from models import create_model, apply_test_time_pool +from data import Dataset, create_loader, get_mean_and_std +from utils import accuracy, AverageMeter +torch.backends.cudnn.benchmark = True parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation') parser.add_argument('data', metavar='DIR', @@ -25,6 +26,8 @@ parser.add_argument('-b', '--batch-size', default=256, type=int, metavar='N', help='mini-batch size (default: 256)') parser.add_argument('--img-size', default=224, type=int, metavar='N', help='Input image dimension') +parser.add_argument('--num-classes', type=int, default=1000, + help='Number classes in dataset') parser.add_argument('--print-freq', '-p', default=10, type=int, metavar='N', help='print frequency (default: 10)') parser.add_argument('--checkpoint', default='', type=str, metavar='PATH', @@ -41,25 +44,19 @@ def main(): args = parser.parse_args() # create model - num_classes = 1000 model = create_model( args.model, - num_classes=num_classes, - pretrained=args.pretrained) + num_classes=args.num_classes, + in_chans=3, + pretrained=args.pretrained, + checkpoint_path=args.checkpoint) print('Model %s created, param count: %d' % (args.model, sum([m.numel() for m in model.parameters()]))) - # load a checkpoint - if not args.pretrained: - if not load_checkpoint(model, args.checkpoint): - exit(1) + data_mean, data_std = get_mean_and_std(model, args) - test_time_pool = False - # FIXME make this work for networks with default img size != 224 and default pool k != 7 - if args.img_size > 224 and not args.no_test_pool: - model = TestTimePoolHead(model) - test_time_pool = True + model, test_time_pool = apply_test_time_pool(model, args) if args.num_gpu > 1: model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda() @@ -69,14 +66,11 @@ def main(): # define loss function (criterion) and optimizer criterion = nn.CrossEntropyLoss().cuda() - cudnn.benchmark = True - - data_mean, data_std = get_model_meanstd(args.model) loader = create_loader( Dataset(args.data), img_size=args.img_size, batch_size=args.batch_size, - use_prefetcher=True, + use_prefetcher=False, mean=data_mean, std=data_std, num_workers=args.workers, @@ -111,51 +105,17 @@ def main(): if i % args.print_freq == 0: print('Test: [{0}/{1}]\t' - 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' + 'Time {batch_time.val:.3f} ({batch_time.avg:.3f}, {rate_avg:.3f}/s) \t' 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( - i, len(loader), batch_time=batch_time, loss=losses, - top1=top1, top5=top5)) + i, len(loader), batch_time=batch_time, + rate_avg=input.size(0) / batch_time.avg, + loss=losses, top1=top1, top5=top5)) print(' * Prec@1 {top1.avg:.3f} ({top1a:.3f}) Prec@5 {top5.avg:.3f} ({top5a:.3f})'.format( top1=top1, top1a=100-top1.avg, top5=top5, top5a=100.-top5.avg)) -class AverageMeter(object): - """Computes and stores the average and current value""" - - def __init__(self): - self.reset() - - def reset(self): - self.val = 0 - self.avg = 0 - self.sum = 0 - self.count = 0 - - def update(self, val, n=1): - self.val = val - self.sum += val * n - self.count += n - self.avg = self.sum / self.count - - -def accuracy(output, target, topk=(1,)): - """Computes the precision@k for the specified values of k""" - maxk = max(topk) - batch_size = target.size(0) - - _, pred = output.topk(maxk, 1, True, True) - pred = pred.t() - correct = pred.eq(target.view(1, -1).expand_as(pred)) - - res = [] - for k in topk: - correct_k = correct[:k].view(-1).float().sum(0) - res.append(correct_k.mul_(100.0 / batch_size)) - return res - - if __name__ == '__main__': main()