Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
Uniform pretrained model handling.
* All models have 'default_cfgs' dict
* load/resume/pretrained helpers factored out
* pretrained load operates on state_dict based on default_cfg
* test all models in validate
* schedule, optim factory factored out
* test time pool wrapper applied based on default_cfg
This commit is contained in:
parent 63e677d03b
commit 9c3859fb9c
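The commit message above summarizes the new contract: every model entrypoint attaches a default_cfg dict (weight URL, input size, mean/std, plus the names of the first conv and classifier layers), and shared helpers consume it. On the consumer side, the reworked factory call looks like this; the keyword arguments are taken from the inference.py and model_factory.py hunks below, with 'resnet50' as an illustrative name:

# Illustrative use of the unified factory after this commit; the kwargs match
# the new create_model signature shown in the hunks below.
from models import create_model

model = create_model(
    'resnet50',
    num_classes=1000,
    in_chans=3,
    pretrained=True,        # weights fetched via model.default_cfg['url']
    checkpoint_path='')     # or a local checkpoint instead of pretrained
mean, std = model.default_cfg['mean'], model.default_cfg['std']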
@@ -1,4 +1,4 @@
from data.dataset import Dataset
from data.transforms import transforms_imagenet_eval, transforms_imagenet_train, get_model_meanstd
from data.utils import create_loader
from data.transforms import *
from data.loader import create_loader
from data.random_erasing import RandomErasingTorch, RandomErasingNumpy
@@ -94,6 +94,9 @@ def create_loader(

    sampler = None
    if distributed:
        # FIXME note, doing this for validation isn't technically correct
        # There currently is no fixed order distributed sampler that corrects
        # for padded entries
        sampler = tdata.distributed.DistributedSampler(dataset)

    loader = tdata.DataLoader(
@@ -1,8 +1,5 @@
from __future__ import absolute_import

#from torchvision.transforms import *

from PIL import Image
import random
import math
import numpy as np

@@ -7,26 +7,57 @@ from data.random_erasing import RandomErasingNumpy

DEFAULT_CROP_PCT = 0.875

IMAGENET_DPN_MEAN = [124 / 255, 117 / 255, 104 / 255]
IMAGENET_DPN_STD = [1 / (.0167 * 255)] * 3
IMAGENET_INCEPTION_MEAN = [0.5, 0.5, 0.5]
IMAGENET_INCEPTION_STD = [0.5, 0.5, 0.5]
IMAGENET_DEFAULT_MEAN = [0.485, 0.456, 0.406]
IMAGENET_DEFAULT_STD = [0.229, 0.224, 0.225]
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)
IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)
IMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255)
IMAGENET_DPN_STD = tuple([1 / (.0167 * 255)] * 3)


# FIXME replace these mean/std fn with model factory based values from config dict
def get_model_meanstd(model_name):
    model_name = model_name.lower()
    if 'dpn' in model_name:
        return IMAGENET_DPN_MEAN, IMAGENET_DPN_STD
    elif 'ception' in model_name or 'nasnet' in model_name:
        return IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
def get_mean_and_std(model, args, num_chan=3):
    if hasattr(model, 'default_cfg'):
        mean = model.default_cfg['mean']
        std = model.default_cfg['std']
    else:
        return IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
    if args.mean is not None:
        mean = tuple(args.mean)
        if len(mean) == 1:
            mean = tuple(list(mean) * num_chan)
        else:
            assert len(mean) == num_chan
    else:
        mean = get_mean_by_model(args.model)
    if args.std is not None:
        std = tuple(args.std)
        if len(std) == 1:
            std = tuple(list(std) * num_chan)
        else:
            assert len(std) == num_chan
    else:
        std = get_std_by_model(args.model)
    return mean, std


def get_model_mean(model_name):
def get_mean_by_name(name):
    if name == 'dpn':
        return IMAGENET_DPN_MEAN
    elif name == 'inception' or name == 'le':
        return IMAGENET_INCEPTION_MEAN
    else:
        return IMAGENET_DEFAULT_MEAN


def get_std_by_name(name):
    if name == 'dpn':
        return IMAGENET_DPN_STD
    elif name == 'inception' or name == 'le':
        return IMAGENET_INCEPTION_STD
    else:
        return IMAGENET_DEFAULT_STD


def get_mean_by_model(model_name):
    model_name = model_name.lower()
    if 'dpn' in model_name:
        return IMAGENET_DPN_STD
@@ -36,7 +67,7 @@ def get_model_mean(model_name):
        return IMAGENET_DEFAULT_MEAN


def get_model_std(model_name):
def get_std_by_model(model_name):
    model_name = model_name.lower()
    if 'dpn' in model_name:
        return IMAGENET_DEFAULT_STD
@@ -93,8 +124,8 @@ def transforms_imagenet_train(
    tfl += [
        ToTensor(),
        transforms.Normalize(
            mean=torch.tensor(mean) * 255,
            std=torch.tensor(std) * 255)
            mean=torch.tensor(mean),
            std=torch.tensor(std))
    ]
    if random_erasing > 0.:
        tfl.append(RandomErasingNumpy(random_erasing, per_pixel=True))
@@ -124,11 +155,5 @@
            mean=torch.tensor(mean),
            std=torch.tensor(std))
    ]
    # tfl += [
    #     ToTensor(),
    #     transforms.Normalize(
    #         mean=torch.tensor(mean) * 255,
    #         std=torch.tensor(std) * 255)
    # ]

    return transforms.Compose(tfl)
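One behavior worth noting in the new get_mean_and_std: a single-element --mean/--std override is broadcast across all channels. Tracing the branch above with a concrete (hypothetical) override value:

# Tracing the broadcast branch of get_mean_and_std with std=(0.25,):
num_chan = 3
std = tuple([0.25])
if len(std) == 1:
    std = tuple(list(std) * num_chan)  # -> (0.25, 0.25, 0.25)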
inference.py | 21
@@ -11,10 +11,11 @@ import argparse
import numpy as np
import torch

from models import create_model, load_checkpoint, TestTimePoolHead
from data import Dataset, create_loader, get_model_meanstd
from models import create_model, apply_test_time_pool
from data import Dataset, create_loader, get_mean_and_std
from utils import AverageMeter

torch.backends.cudnn.benchmark = True

parser = argparse.ArgumentParser(description='PyTorch ImageNet Inference')
parser.add_argument('data', metavar='DIR',
@@ -29,6 +30,8 @@ parser.add_argument('-b', '--batch-size', default=256, type=int,
                    metavar='N', help='mini-batch size (default: 256)')
parser.add_argument('--img-size', default=224, type=int,
                    metavar='N', help='Input image dimension')
parser.add_argument('--num-classes', type=int, default=1000,
                    help='Number classes in dataset')
parser.add_argument('--print-freq', '-p', default=10, type=int,
                    metavar='N', help='print frequency (default: 10)')
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
@@ -45,26 +48,24 @@ def main():
    args = parser.parse_args()

    # create model
    num_classes = 1000
    model = create_model(
        args.model,
        num_classes=num_classes,
        pretrained=args.pretrained)
        num_classes=args.num_classes,
        in_chans=3,
        pretrained=args.pretrained,
        checkpoint_path=args.checkpoint)

    print('Model %s created, param count: %d' %
          (args.model, sum([m.numel() for m in model.parameters()])))

    # load a checkpoint
    if not args.pretrained:
        if not load_checkpoint(model, args.checkpoint):
            exit(1)
    data_mean, data_std = get_mean_and_std(model, args)
    model, test_time_pool = apply_test_time_pool(model, args)

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
    else:
        model = model.cuda()

    data_mean, data_std = get_model_meanstd(args.model)
    loader = create_loader(
        Dataset(args.data),
        img_size=args.img_size,
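The structural change in main() above: checkpoint loading moved into the factory, so the explicit load_checkpoint branch (and its exit(1) failure path) disappears. A condensed before/after, with values as in the script:

# Before: two steps, script-managed failure handling
#   model = create_model(args.model, num_classes=1000, pretrained=args.pretrained)
#   load_checkpoint(model, args.checkpoint)
# After: one call; per model_factory.py below, checkpoint_path is ignored
# when empty or when pretrained=True
model = create_model(args.model, num_classes=args.num_classes, in_chans=3,
                     pretrained=args.pretrained, checkpoint_path=args.checkpoint)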
@@ -1,3 +1,4 @@
from .model_factory import create_model, load_checkpoint
from .test_time_pool import TestTimePoolHead
from models.model_factory import create_model
from models.helpers import load_checkpoint, resume_checkpoint
from models.test_time_pool import TestTimePoolHead, apply_test_time_pool
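apply_test_time_pool, newly exported here, is implemented in models/test_time_pool.py, which this excerpt does not show. Based on the commit message ("test time pool wrapper applied based on default_cfg") and the inference.py call site, a plausible sketch, offered as an assumption rather than the actual code:

# Sketch only: the real logic lives in models/test_time_pool.py (not shown).
def apply_test_time_pool(model, args):
    test_time_pool = False
    if args.img_size > model.default_cfg['input_size'][-1]:
        # evaluating above the pretrained resolution: wrap the head so logits
        # are pooled over the larger spatial output
        model = TestTimePoolHead(model, original_pool=model.default_cfg['pool_size'])
        test_time_pool = True
    return model, test_time_pool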
@@ -5,19 +5,29 @@ fixed kwargs passthrough and addition of dynamic global avg/max pool.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
from collections import OrderedDict
from .adaptive_avgmax_pool import *

from models.helpers import load_pretrained
from models.adaptive_avgmax_pool import *
from data.transforms import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
import re

__all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet161']


model_urls = {
    'densenet121': 'https://download.pytorch.org/models/densenet121-241335ed.pth',
    'densenet169': 'https://download.pytorch.org/models/densenet169-6f0f7f60.pth',
    'densenet201': 'https://download.pytorch.org/models/densenet201-4c113574.pth',
    'densenet161': 'https://download.pytorch.org/models/densenet161-17b70270.pth',
def _cfg(url=''):
    return {
        'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'features.conv0', 'classifier': 'classifier',
    }


default_cfgs = {
    'densenet121': _cfg(url='https://download.pytorch.org/models/densenet121-241335ed.pth'),
    'densenet169': _cfg(url='https://download.pytorch.org/models/densenet169-6f0f7f60.pth'),
    'densenet201': _cfg(url='https://download.pytorch.org/models/densenet201-4c113574.pth'),
    'densenet161': _cfg(url='https://download.pytorch.org/models/densenet161-17b70270.pth'),
}


@@ -34,59 +44,56 @@ def _filter_pretrained(state_dict):
    return state_dict


def densenet121(pretrained=False, **kwargs):
def densenet121(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
    r"""Densenet-121 model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16), **kwargs)
    default_cfg = default_cfgs['densenet121']
    model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16),
                     num_classes=num_classes, in_chans=in_chans, **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        state_dict = model_zoo.load_url(model_urls['densenet121'])
        model.load_state_dict(_filter_pretrained(state_dict))
        load_pretrained(model, default_cfg, num_classes, in_chans, filter_fn=_filter_pretrained)
    return model


def densenet169(pretrained=False, **kwargs):
def densenet169(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
    r"""Densenet-169 model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32), **kwargs)
    default_cfg = default_cfgs['densenet169']
    model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32),
                     num_classes=num_classes, in_chans=in_chans, **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        state_dict = model_zoo.load_url(model_urls['densenet169'])
        model.load_state_dict(_filter_pretrained(state_dict))
        load_pretrained(model, default_cfg, num_classes, in_chans, filter_fn=_filter_pretrained)
    return model


def densenet201(pretrained=False, **kwargs):
def densenet201(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
    r"""Densenet-201 model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32), **kwargs)
    default_cfg = default_cfgs['densenet201']
    model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32),
                     num_classes=num_classes, in_chans=in_chans, **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        state_dict = model_zoo.load_url(model_urls['densenet201'])
        model.load_state_dict(_filter_pretrained(state_dict))
        load_pretrained(model, default_cfg, num_classes, in_chans, filter_fn=_filter_pretrained)
    return model


def densenet161(pretrained=False, **kwargs):
def densenet161(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
    r"""Densenet-161 model from
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24), **kwargs)
    print(num_classes, in_chans, pretrained)
    default_cfg = default_cfgs['densenet161']
    model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24),
                     num_classes=num_classes, in_chans=in_chans, **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        state_dict = model_zoo.load_url(model_urls['densenet161'])
        model.load_state_dict(_filter_pretrained(state_dict))
        load_pretrained(model, default_cfg, num_classes, in_chans, filter_fn=_filter_pretrained)
    return model


@@ -142,14 +149,15 @@ class DenseNet(nn.Module):
        num_classes (int) - number of classification classes
    """
    def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
                 num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000, global_pool='avg'):
                 num_init_features=64, bn_size=4, drop_rate=0,
                 num_classes=1000, in_chans=3, global_pool='avg'):
        self.global_pool = global_pool
        self.num_classes = num_classes
        super(DenseNet, self).__init__()

        # First convolution
        self.features = nn.Sequential(OrderedDict([
            ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
            ('conv0', nn.Conv2d(in_chans, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
            ('norm0', nn.BatchNorm2d(num_init_features)),
            ('relu0', nn.ReLU(inplace=True)),
            ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
@@ -172,7 +180,7 @@ class DenseNet(nn.Module):
        self.features.add_module('norm5', nn.BatchNorm2d(num_features))

        # Linear layer
        self.classifier = torch.nn.Linear(num_features, num_classes)
        self.classifier = nn.Linear(num_features, num_classes)

        self.num_features = num_features

@@ -184,7 +192,7 @@ class DenseNet(nn.Module):
        self.num_classes = num_classes
        del self.classifier
        if num_classes:
            self.classifier = torch.nn.Linear(self.num_features, num_classes)
            self.classifier = nn.Linear(self.num_features, num_classes)
        else:
            self.classifier = None
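Every converted entrypoint in this file, and in the files below, now follows the same four steps. Paraphrased with placeholder names (some_model/SomeArch are hypothetical):

# Paraphrase of the pattern above; 'some_model'/'SomeArch' are placeholders.
def some_model(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
    default_cfg = default_cfgs['some_model']       # 1. look up per-model metadata
    model = SomeArch(num_classes=num_classes, in_chans=in_chans, **kwargs)  # 2. build
    model.default_cfg = default_cfg                # 3. attach cfg to the instance
    if pretrained:
        load_pretrained(model, default_cfg, num_classes, in_chans)  # 4. cfg-driven load
    return model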
@@ -13,94 +13,108 @@ import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
from collections import OrderedDict

from .adaptive_avgmax_pool import select_adaptive_pool2d
from models.helpers import load_pretrained
from models.adaptive_avgmax_pool import select_adaptive_pool2d
from data.transforms import IMAGENET_DPN_MEAN, IMAGENET_DPN_STD

__all__ = ['DPN', 'dpn68', 'dpn92', 'dpn98', 'dpn131', 'dpn107']


model_urls = {
def _cfg(url=''):
    return {
        'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
        'mean': IMAGENET_DPN_MEAN, 'std': IMAGENET_DPN_STD,
        'first_conv': 'features.conv1_1.conv', 'classifier': 'classifier',
    }


default_cfgs = {
    'dpn68':
        'http://data.lip6.fr/cadene/pretrainedmodels/dpn68-66bebafa7.pth',
        _cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/dpn68-66bebafa7.pth'),
    'dpn68b_extra':
        'http://data.lip6.fr/cadene/pretrainedmodels/'
        'dpn68b_extra-84854c156.pth',
    'dpn92': '',
        _cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/dpn68b_extra-84854c156.pth'),
    'dpn92_extra':
        'http://data.lip6.fr/cadene/pretrainedmodels/'
        'dpn92_extra-b040e4a9b.pth',
        _cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/dpn92_extra-b040e4a9b.pth'),
    'dpn98':
        'http://data.lip6.fr/cadene/pretrainedmodels/dpn98-5b90dec4d.pth',
        _cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/dpn98-5b90dec4d.pth'),
    'dpn131':
        'http://data.lip6.fr/cadene/pretrainedmodels/dpn131-71dfe43e0.pth',
        _cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/dpn131-71dfe43e0.pth'),
    'dpn107_extra':
        'http://data.lip6.fr/cadene/pretrainedmodels/'
        'dpn107_extra-1ac7121e2.pth'
        _cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/dpn107_extra-1ac7121e2.pth')
}


def dpn68(num_classes=1000, pretrained=False):
def dpn68(num_classes=1000, in_chans=3, pretrained=False):
    default_cfg = default_cfgs['dpn68']
    model = DPN(
        small=True, num_init_features=10, k_r=128, groups=32,
        k_sec=(3, 4, 12, 3), inc_sec=(16, 32, 32, 64),
        num_classes=num_classes)
        num_classes=num_classes, in_chans=in_chans)
    model.default_cfg = default_cfg
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['dpn68']))
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model


def dpn68b(num_classes=1000, pretrained=False):
def dpn68b(num_classes=1000, in_chans=3, pretrained=False):
    default_cfg = default_cfgs['dpn68b_extra']
    model = DPN(
        small=True, num_init_features=10, k_r=128, groups=32,
        b=True, k_sec=(3, 4, 12, 3), inc_sec=(16, 32, 32, 64),
        num_classes=num_classes)
        num_classes=num_classes, in_chans=in_chans)
    model.default_cfg = default_cfg
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['dpn68b_extra']))
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model


def dpn92(num_classes=1000, pretrained=False, extra=True):
def dpn92(num_classes=1000, in_chans=3, pretrained=False):
    default_cfg = default_cfgs['dpn92_extra']
    model = DPN(
        num_init_features=64, k_r=96, groups=32,
        k_sec=(3, 4, 20, 3), inc_sec=(16, 32, 24, 128),
        num_classes=num_classes)
        num_classes=num_classes, in_chans=in_chans)
    model.default_cfg = default_cfg
    if pretrained:
        if extra:
            model.load_state_dict(model_zoo.load_url(model_urls['dpn92_extra']))
        else:
            model.load_state_dict(model_zoo.load_url(model_urls['dpn92']))
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model


def dpn98(num_classes=1000, pretrained=False):
def dpn98(num_classes=1000, in_chans=3, pretrained=False):
    default_cfg = default_cfgs['dpn98']
    model = DPN(
        num_init_features=96, k_r=160, groups=40,
        k_sec=(3, 6, 20, 3), inc_sec=(16, 32, 32, 128),
        num_classes=num_classes)
        num_classes=num_classes, in_chans=in_chans)
    model.default_cfg = default_cfg
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['dpn98']))
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model


def dpn131(num_classes=1000, pretrained=False):
def dpn131(num_classes=1000, in_chans=3, pretrained=False):
    default_cfg = default_cfgs['dpn131']
    model = DPN(
        num_init_features=128, k_r=160, groups=40,
        k_sec=(4, 8, 28, 3), inc_sec=(16, 32, 32, 128),
        num_classes=num_classes)
        num_classes=num_classes, in_chans=in_chans)
    model.default_cfg = default_cfg
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['dpn131']))
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model


def dpn107(num_classes=1000, pretrained=False):
def dpn107(num_classes=1000, in_chans=3, pretrained=False):
    default_cfg = default_cfgs['dpn107_extra']
    model = DPN(
        num_init_features=128, k_r=200, groups=50,
        k_sec=(4, 8, 20, 3), inc_sec=(20, 64, 64, 128),
        num_classes=num_classes)
        num_classes=num_classes, in_chans=in_chans)
    model.default_cfg = default_cfg
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['dpn107_extra']))
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model


@@ -128,11 +142,11 @@ class BnActConv2d(nn.Module):


class InputBlock(nn.Module):
    def __init__(self, num_init_features, kernel_size=7,
    def __init__(self, num_init_features, kernel_size=7, in_chans=3,
                 padding=3, activation_fn=nn.ReLU(inplace=True)):
        super(InputBlock, self).__init__()
        self.conv = nn.Conv2d(
            3, num_init_features, kernel_size=kernel_size, stride=2, padding=padding, bias=False)
            in_chans, num_init_features, kernel_size=kernel_size, stride=2, padding=padding, bias=False)
        self.bn = nn.BatchNorm2d(num_init_features, eps=0.001)
        self.act = activation_fn
        self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
@@ -212,7 +226,7 @@ class DualPathBlock(nn.Module):
class DPN(nn.Module):
    def __init__(self, small=False, num_init_features=64, k_r=96, groups=32,
                 b=False, k_sec=(3, 4, 20, 3), inc_sec=(16, 32, 24, 128),
                 num_classes=1000, fc_act=nn.ELU(inplace=True)):
                 num_classes=1000, in_chans=3, fc_act=nn.ELU(inplace=True)):
        super(DPN, self).__init__()
        self.num_classes = num_classes
        self.b = b
@@ -222,9 +236,11 @@ class DPN(nn.Module):

        # conv1
        if small:
            blocks['conv1_1'] = InputBlock(num_init_features, kernel_size=3, padding=1)
            blocks['conv1_1'] = InputBlock(
                num_init_features, in_chans=in_chans, kernel_size=3, padding=1)
        else:
            blocks['conv1_1'] = InputBlock(num_init_features, kernel_size=7, padding=3)
            blocks['conv1_1'] = InputBlock(
                num_init_features, in_chans=in_chans, kernel_size=7, padding=3)

        # conv2
        bw = 64 * bw_factor
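For reference, the DPN normalization constants carried into the cfg above evaluate to:

# Worked values of the DPN constants referenced by the cfg above:
IMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255)  # ~= (0.486, 0.459, 0.408)
IMAGENET_DPN_STD = tuple([1 / (.0167 * 255)] * 3)      # ~= (0.235, 0.235, 0.235)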
models/helpers.py | 89 (new file)
@@ -0,0 +1,89 @@
import torch
import torch.utils.model_zoo as model_zoo
import os
from collections import OrderedDict


def load_checkpoint(model, checkpoint_path):
    if checkpoint_path and os.path.isfile(checkpoint_path):
        print("=> Loading checkpoint '{}'".format(checkpoint_path))
        checkpoint = torch.load(checkpoint_path)
        if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
            new_state_dict = OrderedDict()
            for k, v in checkpoint['state_dict'].items():
                if k.startswith('module'):
                    name = k[7:]  # remove `module.`
                else:
                    name = k
                new_state_dict[name] = v
            model.load_state_dict(new_state_dict)
        else:
            model.load_state_dict(checkpoint)
        print("=> Loaded checkpoint '{}'".format(checkpoint_path))
    else:
        print("=> Error: No checkpoint found at '{}'".format(checkpoint_path))
        raise FileNotFoundError()


def resume_checkpoint(model, checkpoint_path, start_epoch=None):
    start_epoch = 0 if start_epoch is None else start_epoch
    optimizer_state = None
    if os.path.isfile(checkpoint_path):
        print("=> loading checkpoint '{}'".format(checkpoint_path))
        checkpoint = torch.load(checkpoint_path)
        if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
            new_state_dict = OrderedDict()
            for k, v in checkpoint['state_dict'].items():
                if k.startswith('module'):
                    name = k[7:]  # remove `module.`
                else:
                    name = k
                new_state_dict[name] = v
            model.load_state_dict(new_state_dict)
            if 'optimizer' in checkpoint:
                optimizer_state = checkpoint['optimizer']
            print("=> loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
            start_epoch = checkpoint['epoch'] if start_epoch is None else start_epoch
        else:
            model.load_state_dict(checkpoint)
        return optimizer_state, start_epoch
    else:
        print("=> No checkpoint found at '{}'".format(checkpoint_path))
        raise FileNotFoundError()


def load_pretrained(model, default_cfg, num_classes=1000, in_chans=3, filter_fn=None):
    state_dict = model_zoo.load_url(default_cfg['url'])

    if in_chans == 1:
        conv1_name = default_cfg['first_conv']
        print('Converting first conv (%s) from 3 to 1 channel' % conv1_name)
        conv1_weight = state_dict[conv1_name + '.weight']
        state_dict[conv1_name + '.weight'] = conv1_weight.sum(dim=1, keepdim=True)
    elif in_chans != 3:
        assert False, "Invalid in_chans for pretrained weights"

    strict = True
    classifier_name = default_cfg['classifier']
    if num_classes == 1000 and default_cfg['num_classes'] == 1001:
        # special case for imagenet trained models with extra background class in pretrained weights
        classifier_weight = state_dict[classifier_name + '.weight']
        state_dict[classifier_name + '.weight'] = classifier_weight[1:]
        classifier_bias = state_dict[classifier_name + '.bias']
        state_dict[classifier_name + '.bias'] = classifier_bias[1:]
    elif num_classes != default_cfg['num_classes']:
        # completely discard fully connected for all other differences between pretrained and created model
        del state_dict[classifier_name + '.weight']
        del state_dict[classifier_name + '.bias']
        strict = False

    if filter_fn is not None:
        state_dict = filter_fn(state_dict)

    model.load_state_dict(state_dict, strict=strict)
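A concrete consequence of the in_chans handling in load_pretrained: ImageNet weights can seed single-channel models, because the first conv's (out, 3, kH, kW) kernel is summed over the channel dim to (out, 1, kH, kW). Illustrative call using the densenet entrypoint shown earlier:

# default_cfg['first_conv'] ('features.conv0' for DenseNet) names the layer
# whose pretrained weight gets summed from 3 channels down to 1.
model = densenet121(num_classes=1000, in_chans=1, pretrained=True)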
@@ -5,12 +5,18 @@ based upon Google's Tensorflow implementation and pretrained weights (Apache 2.0)
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
import numpy as np
from .adaptive_avgmax_pool import *
from models.helpers import load_pretrained
from models.adaptive_avgmax_pool import *
from data.transforms import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD

model_urls = {
    'imagenet': 'http://data.lip6.fr/cadene/pretrainedmodels/inceptionresnetv2-520b38e4.pth'

default_cfgs = {
    'inception_resnet_v2': {
        'url': 'http://data.lip6.fr/cadene/pretrainedmodels/inceptionresnetv2-520b38e4.pth',
        'num_classes': 1001, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
        'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
        'first_conv': 'conv2d_1a.conv', 'classifier': 'last_linear',
    }
}


@@ -204,12 +210,14 @@ class Block8(nn.Module):


class InceptionResnetV2(nn.Module):
    def __init__(self, num_classes=1001, drop_rate=0., global_pool='avg'):
    def __init__(self, num_classes=1001, in_chans=3, drop_rate=0., global_pool='avg'):
        super(InceptionResnetV2, self).__init__()
        self.drop_rate = drop_rate
        self.global_pool = global_pool
        self.num_classes = num_classes
        self.conv2d_1a = BasicConv2d(3, 32, kernel_size=3, stride=2)
        self.num_features = 1536

        self.conv2d_1a = BasicConv2d(in_chans, 32, kernel_size=3, stride=2)
        self.conv2d_2a = BasicConv2d(32, 32, kernel_size=3, stride=1)
        self.conv2d_2b = BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.maxpool_3a = nn.MaxPool2d(3, stride=2)
@@ -265,29 +273,21 @@ class InceptionResnetV2(nn.Module):
            Block8(scale=0.20)
        )
        self.block8 = Block8(noReLU=True)
        self.conv2d_7b = BasicConv2d(2080, 1536, kernel_size=1, stride=1)
        self.num_features = 1536
        self.last_linear = nn.Linear(1536, num_classes)
        self.conv2d_7b = BasicConv2d(2080, self.num_features, kernel_size=1, stride=1)
        self.last_linear = nn.Linear(self.num_features, num_classes)

    def get_classifier(self):
        return self.classif
        return self.last_linear

    def reset_classifier(self, num_classes, global_pool='avg'):
        self.global_pool = global_pool
        self.num_classes = num_classes
        del self.classif
        del self.last_linear
        if num_classes:
            self.last_linear = torch.nn.Linear(1536, num_classes)
            self.last_linear = torch.nn.Linear(self.num_features, num_classes)
        else:
            self.last_linear = None

    def trim_classifier(self, trim=1):
        self.num_classes -= trim
        new_last_linear = nn.Linear(1536, self.num_classes)
        new_last_linear.weight.data = self.last_linear.weight.data[trim:]
        new_last_linear.bias.data = self.last_linear.bias.data[trim:]
        self.last_linear = new_last_linear

    def forward_features(self, x, pool=True):
        x = self.conv2d_1a(x)
        x = self.conv2d_2a(x)
@@ -318,19 +318,15 @@ class InceptionResnetV2(nn.Module):
        return x


def inception_resnet_v2(pretrained=False, num_classes=1000, **kwargs):
def inception_resnet_v2(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
    r"""InceptionResnetV2 model architecture from the
    `"InceptionV4, Inception-ResNet..." <https://arxiv.org/abs/1602.07261>`_ paper.

    Args:
        pretrained ('string'): If True, returns a model pre-trained on ImageNet
    """
    extra_class = 1 if pretrained else 0
    model = InceptionResnetV2(num_classes=num_classes + extra_class, **kwargs)
    default_cfg = default_cfgs['inception_resnet_v2']
    model = InceptionResnetV2(num_classes=num_classes, in_chans=in_chans, **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        print('Loading pretrained from %s' % model_urls['imagenet'])
        model.load_state_dict(model_zoo.load_url(model_urls['imagenet']))
        model.trim_classifier()
        load_pretrained(model, default_cfg, num_classes, in_chans)

    return model
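The removed extra_class/trim_classifier dance is now handled generically: with default_cfg['num_classes'] == 1001 and a requested 1000, load_pretrained slices the background class off before loading. The equivalent of the deleted trim_classifier(trim=1), expressed on the state dict:

# 'last_linear' comes from default_cfg['classifier'] above.
state_dict['last_linear.weight'] = state_dict['last_linear.weight'][1:]  # (1001, 1536) -> (1000, 1536)
state_dict['last_linear.bias'] = state_dict['last_linear.bias'][1:]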
@@ -5,11 +5,18 @@ based upon Google's Tensorflow implementation and pretrained weights (Apache 2.0)
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
from .adaptive_avgmax_pool import *
from models.helpers import load_pretrained
from models.adaptive_avgmax_pool import *
from data.transforms import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD

model_urls = {
    'imagenet': 'http://webia.lip6.fr/~cadene/Downloads/inceptionv4-97ef9c30.pth'

default_cfgs = {
    'inception_v4': {
        'url': 'http://webia.lip6.fr/~cadene/Downloads/inceptionv4-97ef9c30.pth',
        'num_classes': 1001, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
        'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
        'first_conv': 'features.0.conv', 'classifier': 'classif',
    }
}


@@ -230,13 +237,15 @@ class Inception_C(nn.Module):


class InceptionV4(nn.Module):
    def __init__(self, num_classes=1001, drop_rate=0., global_pool='avg'):
    def __init__(self, num_classes=1001, in_chans=3, drop_rate=0., global_pool='avg'):
        super(InceptionV4, self).__init__()
        self.drop_rate = drop_rate
        self.global_pool = global_pool
        self.num_classes = num_classes
        self.num_features = 1536

        self.features = nn.Sequential(
            BasicConv2d(3, 32, kernel_size=3, stride=2),
            BasicConv2d(in_chans, 32, kernel_size=3, stride=2),
            BasicConv2d(32, 32, kernel_size=3, stride=1),
            BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1),
            Mixed_3a(),
@@ -259,7 +268,7 @@ class InceptionV4(nn.Module):
            Inception_C(),
            Inception_C(),
        )
        self.classif = nn.Linear(1536, num_classes)
        self.classif = nn.Linear(self.num_features, num_classes)

    def get_classifier(self):
        return self.classif
@@ -267,12 +276,12 @@ class InceptionV4(nn.Module):
    def reset_classifier(self, num_classes, global_pool='avg'):
        self.global_pool = global_pool
        self.num_classes = num_classes
        self.classif = nn.Linear(1536, num_classes)
        self.classif = nn.Linear(self.num_features, num_classes)

    def forward_features(self, x, pool=True):
        x = self.features(x)
        if pool:
            x = select_adaptive_pool2d(x, self.global_pool, count_include_pad=False)
            x = select_adaptive_pool2d(x, self.global_pool)
            x = x.view(x.size(0), -1)
        return x

@@ -284,10 +293,12 @@ class InceptionV4(nn.Module):
        return x


def inception_v4(pretrained=False, num_classes=1001, **kwargs):
    model = InceptionV4(num_classes=num_classes, **kwargs)
def inception_v4(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
    default_cfg = default_cfgs['inception_v4']
    model = InceptionV4(num_classes=num_classes, in_chans=in_chans, **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['imagenet']))
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model
@@ -1,155 +1,34 @@
import torch
import os
from collections import OrderedDict

from .inception_v4 import inception_v4
from .inception_resnet_v2 import inception_resnet_v2
from .densenet import densenet161, densenet121, densenet169, densenet201
from .resnet import resnet18, resnet34, resnet50, resnet101, resnet152, \
from models.inception_v4 import inception_v4
from models.inception_resnet_v2 import inception_resnet_v2
from models.densenet import densenet161, densenet121, densenet169, densenet201
from models.resnet import resnet18, resnet34, resnet50, resnet101, resnet152, \
    resnext50_32x4d, resnext101_32x4d, resnext101_64x4d, resnext152_32x4d
from .dpn import dpn68, dpn68b, dpn92, dpn98, dpn131, dpn107
from .senet import seresnet18, seresnet34, seresnet50, seresnet101, seresnet152, \
from models.dpn import dpn68, dpn68b, dpn92, dpn98, dpn131, dpn107
from models.senet import seresnet18, seresnet34, seresnet50, seresnet101, seresnet152, \
    seresnext26_32x4d, seresnext50_32x4d, seresnext101_32x4d
#from .resnext import resnext50, resnext101, resnext152
from .xception import xception
from .pnasnet import pnasnet5large
from models.xception import xception
from models.pnasnet import pnasnet5large

model_config_dict = {
    'resnet18': {
        'model_name': 'resnet18', 'num_classes': 1000, 'input_size': 224, 'normalizer': 'tv'},
    'resnet34': {
        'model_name': 'resnet34', 'num_classes': 1000, 'input_size': 224, 'normalizer': 'tv'},
    'resnet50': {
        'model_name': 'resnet50', 'num_classes': 1000, 'input_size': 224, 'normalizer': 'tv'},
    'resnet101': {
        'model_name': 'resnet101', 'num_classes': 1000, 'input_size': 224, 'normalizer': 'tv'},
    'resnet152': {
        'model_name': 'resnet152', 'num_classes': 1000, 'input_size': 224, 'normalizer': 'tv'},
    'densenet121': {
        'model_name': 'densenet121', 'num_classes': 1000, 'input_size': 224, 'normalizer': 'tv'},
    'densenet169': {
        'model_name': 'densenet169', 'num_classes': 1000, 'input_size': 224, 'normalizer': 'tv'},
    'densenet201': {
        'model_name': 'densenet201', 'num_classes': 1000, 'input_size': 224, 'normalizer': 'tv'},
    'densenet161': {
        'model_name': 'densenet161', 'num_classes': 1000, 'input_size': 224, 'normalizer': 'tv'},
    'dpn107': {
        'model_name': 'dpn107', 'num_classes': 1000, 'input_size': 299, 'normalizer': 'dpn'},
    'dpn92_extra': {
        'model_name': 'dpn92', 'num_classes': 1000, 'input_size': 299, 'normalizer': 'dpn'},
    'dpn92': {
        'model_name': 'dpn92', 'num_classes': 1000, 'input_size': 299, 'normalizer': 'dpn'},
    'dpn68': {
        'model_name': 'dpn68', 'num_classes': 1000, 'input_size': 299, 'normalizer': 'dpn'},
    'dpn68b': {
        'model_name': 'dpn68b', 'num_classes': 1000, 'input_size': 299, 'normalizer': 'dpn'},
    'dpn68b_extra': {
        'model_name': 'dpn68b', 'num_classes': 1000, 'input_size': 299, 'normalizer': 'dpn'},
    'inception_resnet_v2': {
        'model_name': 'inception_resnet_v2', 'num_classes': 1000, 'input_size': 299, 'normalizer': 'le'},
    'xception': {
        'model_name': 'xception', 'num_classes': 1000, 'input_size': 299, 'normalizer': 'le'},
    'pnasnet5large': {
        'model_name': 'pnasnet5large', 'num_classes': 1000, 'input_size': 331, 'normalizer': 'le'}
}
from models.helpers import load_checkpoint


def create_model(
        model_name='resnet50',
        pretrained=None,
        num_classes=1000,
        in_chans=3,
        checkpoint_path='',
        **kwargs):

    if model_name == 'dpn68':
        model = dpn68(num_classes=num_classes, pretrained=pretrained)
    elif model_name == 'dpn68b':
        model = dpn68b(num_classes=num_classes, pretrained=pretrained)
    elif model_name == 'dpn92':
        model = dpn92(num_classes=num_classes, pretrained=pretrained)
    elif model_name == 'dpn98':
        model = dpn98(num_classes=num_classes, pretrained=pretrained)
    elif model_name == 'dpn131':
        model = dpn131(num_classes=num_classes, pretrained=pretrained)
    elif model_name == 'dpn107':
        model = dpn107(num_classes=num_classes, pretrained=pretrained)
    elif model_name == 'resnet18':
        model = resnet18(num_classes=num_classes, pretrained=pretrained, **kwargs)
    elif model_name == 'resnet34':
        model = resnet34(num_classes=num_classes, pretrained=pretrained, **kwargs)
    elif model_name == 'resnet50':
        model = resnet50(num_classes=num_classes, pretrained=pretrained, **kwargs)
    elif model_name == 'resnet101':
        model = resnet101(num_classes=num_classes, pretrained=pretrained, **kwargs)
    elif model_name == 'resnet152':
        model = resnet152(num_classes=num_classes, pretrained=pretrained, **kwargs)
    elif model_name == 'densenet121':
        model = densenet121(num_classes=num_classes, pretrained=pretrained, **kwargs)
    elif model_name == 'densenet161':
        model = densenet161(num_classes=num_classes, pretrained=pretrained, **kwargs)
    elif model_name == 'densenet169':
        model = densenet169(num_classes=num_classes, pretrained=pretrained, **kwargs)
    elif model_name == 'densenet201':
        model = densenet201(num_classes=num_classes, pretrained=pretrained, **kwargs)
    elif model_name == 'inception_resnet_v2':
        model = inception_resnet_v2(num_classes=num_classes, pretrained=pretrained, **kwargs)
    elif model_name == 'inception_v4':
        model = inception_v4(num_classes=num_classes, pretrained=pretrained, **kwargs)
    elif model_name == 'seresnet18':
        model = seresnet18(num_classes=num_classes, pretrained=pretrained, **kwargs)
    elif model_name == 'seresnet34':
        model = seresnet34(num_classes=num_classes, pretrained=pretrained, **kwargs)
    elif model_name == 'seresnet50':
        model = seresnet50(num_classes=num_classes, pretrained=pretrained, **kwargs)
    elif model_name == 'seresnet101':
        model = seresnet101(num_classes=num_classes, pretrained=pretrained, **kwargs)
    elif model_name == 'seresnet152':
        model = seresnet152(num_classes=num_classes, pretrained=pretrained, **kwargs)
    elif model_name == 'seresnext26_32x4d':
        model = seresnext26_32x4d(num_classes=num_classes, pretrained=pretrained, **kwargs)
    elif model_name == 'seresnext50_32x4d':
        model = seresnext50_32x4d(num_classes=num_classes, pretrained=pretrained, **kwargs)
    elif model_name == 'seresnext101_32x4d':
        model = seresnext101_32x4d(num_classes=num_classes, pretrained=pretrained, **kwargs)
    elif model_name == 'resnext50_32x4d':
        model = resnext50_32x4d(num_classes=num_classes, pretrained=pretrained, **kwargs)
    elif model_name == 'resnext101_32x4d':
        model = resnext101_32x4d(num_classes=num_classes, pretrained=pretrained, **kwargs)
    elif model_name == 'resnext101_64x4d':
        model = resnext101_32x4d(num_classes=num_classes, pretrained=pretrained, **kwargs)
    elif model_name == 'resnext152_32x4d':
        model = resnext152_32x4d(num_classes=num_classes, pretrained=pretrained, **kwargs)
    elif model_name == 'xception':
        model = xception(num_classes=num_classes, pretrained=pretrained)
    elif model_name == 'pnasnet5large':
        model = pnasnet5large(num_classes=num_classes, pretrained=pretrained)
    margs = dict(num_classes=num_classes, in_chans=in_chans, pretrained=pretrained)

    if model_name in globals():
        create_fn = globals()[model_name]
        model = create_fn(**margs, **kwargs)
    else:
        assert False and "Invalid model"
        raise RuntimeError('Unknown model (%s)' % model_name)

    if checkpoint_path and not pretrained:
        print(checkpoint_path)
        load_checkpoint(model, checkpoint_path)

    return model


def load_checkpoint(model, checkpoint_path):
    if checkpoint_path and os.path.isfile(checkpoint_path):
        print("=> Loading checkpoint '{}'".format(checkpoint_path))
        checkpoint = torch.load(checkpoint_path)
        if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
            new_state_dict = OrderedDict()
            for k, v in checkpoint['state_dict'].items():
                if k.startswith('module'):
                    name = k[7:]  # remove `module.`
                else:
                    name = k
                new_state_dict[name] = v
            model.load_state_dict(new_state_dict)
        else:
            model.load_state_dict(checkpoint)
        print("=> Loaded checkpoint '{}'".format(checkpoint_path))
        return True
    else:
        print("=> Error: No checkpoint found at '{}'".format(checkpoint_path))
        return False
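Since dispatch now goes through module globals, any entrypoint imported into model_factory.py is addressable by name; the two calls below are equivalent (illustrative arguments):

# Equivalent ways to build a model after this commit:
model_a = create_model('seresnext50_32x4d', num_classes=1000, in_chans=3, pretrained=False)
model_b = seresnext50_32x4d(num_classes=1000, in_chans=3, pretrained=False)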
@@ -3,29 +3,23 @@ from collections import OrderedDict

import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
import torch.nn.functional as F

pretrained_settings = {
from models.helpers import load_pretrained
from models.adaptive_avgmax_pool import SelectAdaptivePool2d

default_cfgs = {
    'pnasnet5large': {
        'imagenet': {
            'url': 'http://data.lip6.fr/cadene/pretrainedmodels/pnasnet5large-bf079911.pth',
            'input_space': 'RGB',
            'input_size': [3, 331, 331],
            'input_range': [0, 1],
            'mean': [0.5, 0.5, 0.5],
            'std': [0.5, 0.5, 0.5],
            'num_classes': 1000
        },
        'imagenet+background': {
            'url': 'http://data.lip6.fr/cadene/pretrainedmodels/pnasnet5large-bf079911.pth',
            'input_space': 'RGB',
            'input_size': [3, 331, 331],
            'input_range': [0, 1],
            'mean': [0.5, 0.5, 0.5],
            'std': [0.5, 0.5, 0.5],
            'num_classes': 1001
        }
    }
        'url': 'http://data.lip6.fr/cadene/pretrainedmodels/pnasnet5large-bf079911.pth',
        'input_size': (3, 331, 331),
        'pool_size': (11, 11),
        'mean': (0.5, 0.5, 0.5),
        'std': (0.5, 0.5, 0.5),
        'crop_pct': 0.8975,
        'num_classes': 1001,
        'first_conv': 'conv_0.conv',
        'classifier': 'last_linear',
    },
}


@@ -288,13 +282,14 @@ class Cell(CellBase):


class PNASNet5Large(nn.Module):
    def __init__(self, num_classes=1001):
    def __init__(self, num_classes=1001, in_chans=3, drop_rate=0.5, global_pool='avg'):
        super(PNASNet5Large, self).__init__()
        self.num_classes = num_classes
        self.num_features = 4320
        self.drop_rate = drop_rate

        self.conv_0 = nn.Sequential(OrderedDict([
            ('conv', nn.Conv2d(3, 96, kernel_size=3, stride=2, bias=False)),
            ('conv', nn.Conv2d(in_chans, 96, kernel_size=3, stride=2, bias=False)),
            ('bn', nn.BatchNorm2d(96, eps=0.001))
        ]))
        self.cell_stem_0 = CellStem0(in_channels_left=96, out_channels_left=54,
@@ -334,18 +329,18 @@ class PNASNet5Large(nn.Module):
        self.cell_11 = Cell(in_channels_left=4320, out_channels_left=864,
                            in_channels_right=4320, out_channels_right=864)
        self.relu = nn.ReLU()
        self.avg_pool = nn.AvgPool2d(11, stride=1, padding=0)
        self.dropout = nn.Dropout(0.5)
        self.last_linear = nn.Linear(self.num_features, num_classes)
        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
        self.last_linear = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)

    def get_classifier(self):
        return self.last_linear

    def reset_classifier(self, num_classes):
    def reset_classifier(self, num_classes, global_pool='avg'):
        self.num_classes = num_classes
        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
        del self.last_linear
        if num_classes:
            self.last_linear = nn.Linear(self.num_features, num_classes)
            self.last_linear = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
        else:
            self.last_linear = None

@@ -367,38 +362,27 @@ class PNASNet5Large(nn.Module):
        x_cell_11 = self.cell_11(x_cell_9, x_cell_10)
        x = self.relu(x_cell_11)
        if pool:
            x = self.avg_pool(x)
            x = self.global_pool(x)
            x = x.view(x.size(0), -1)
        return x

    def forward(self, input):
        x = self.forward_features(input)
        x = self.dropout(x)
        if self.drop_rate > 0:
            x = F.dropout(x, self.drop_rate, training=self.training)
        x = self.last_linear(x)
        return x


def pnasnet5large(num_classes=1001, pretrained='imagenet'):
def pnasnet5large(num_classes=1000, in_chans=3, pretrained='imagenet'):
    r"""PNASNet-5 model architecture from the
    `"Progressive Neural Architecture Search"
    <https://arxiv.org/abs/1712.00559>`_ paper.
    """
    default_cfg = default_cfgs['pnasnet5large']
    model = PNASNet5Large(num_classes=1000, in_chans=in_chans)
    model.default_cfg = default_cfg
    if pretrained:
        settings = pretrained_settings['pnasnet5large']['imagenet']
        assert num_classes == settings[
            'num_classes'], 'num_classes should be {}, but is {}'.format(
            settings['num_classes'], num_classes)
        load_pretrained(model, default_cfg, num_classes, in_chans)

        # both 'imagenet'&'imagenet+background' are loaded from same parameters
        model = PNASNet5Large(num_classes=1001)
        model.load_state_dict(model_zoo.load_url(settings['url']))

        #if pretrained == 'imagenet':
        new_last_linear = nn.Linear(model.last_linear.in_features, 1000)
        new_last_linear.weight.data = model.last_linear.weight.data[1:]
        new_last_linear.bias.data = model.last_linear.bias.data[1:]
        model.last_linear = new_last_linear

    else:
        model = PNASNet5Large(num_classes=num_classes)
    return model
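SelectAdaptivePool2d.feat_mult() scales the classifier's input width; its definition is not in this excerpt. Presumably it returns 1 for plain 'avg'/'max' pooling and more when pooled features are concatenated; under that assumption:

# Assumption: feat_mult() is the channel multiplier of the selected pooling.
pool = SelectAdaptivePool2d(pool_type='avg')
last_linear = nn.Linear(4320 * pool.feat_mult(), num_classes)  # 4320 * 1 for 'avg'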
145
models/resnet.py
145
models/resnet.py
@ -6,17 +6,33 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
import torch.utils.model_zoo as model_zoo
|
||||
from .adaptive_avgmax_pool import SelectAdaptivePool2d
|
||||
from models.helpers import load_pretrained
|
||||
from models.adaptive_avgmax_pool import SelectAdaptivePool2d
|
||||
from data.transforms import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
|
||||
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152']
|
||||
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152',
|
||||
'resnext50_32x4d', 'resnext101_32x4d', 'resnext101_64x4d', 'resnext152_32x4d']
|
||||
|
||||
model_urls = {
|
||||
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
|
||||
'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
|
||||
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
|
||||
'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
|
||||
'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
|
||||
|
||||
def _cfg(url=''):
|
||||
return {
|
||||
'url': url,
|
||||
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
|
||||
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'crop_pct': 0.875,
|
||||
'first_conv': 'conv1', 'classifier': 'fc',
|
||||
}
|
||||
|
||||
|
||||
default_cfgs = {
|
||||
'resnet18': _cfg(url='https://download.pytorch.org/models/resnet18-5c106cde.pth'),
|
||||
'resnet34': _cfg(url='https://download.pytorch.org/models/resnet34-333f7ec4.pth'),
|
||||
'resnet50': _cfg(url='https://download.pytorch.org/models/resnet50-19c8e357.pth'),
|
||||
'resnet101': _cfg(url='https://download.pytorch.org/models/resnet101-5d3b4d8f.pth'),
|
||||
'resnet152': _cfg(url='https://download.pytorch.org/models/resnet152-b121ed2d.pth'),
|
||||
'resnext50_32x4d': _cfg(url=''),
|
||||
'resnext101_32x4d': _cfg(url=''),
|
||||
'resnext101_64x4d': _cfg(url=''),
|
||||
'resnext152_32x4d': _cfg(url=''),
|
||||
}
|
||||
|
||||
|
||||
@ -116,7 +132,7 @@ class Bottleneck(nn.Module):
|
||||
|
||||
class ResNet(nn.Module):
|
||||
|
||||
def __init__(self, block, layers, num_classes=1000,
|
||||
def __init__(self, block, layers, num_classes=1000, in_chans=3,
|
||||
cardinality=1, base_width=64,
|
||||
drop_rate=0.0, block_drop_rate=0.0,
|
||||
global_pool='avg'):
|
||||
@ -127,7 +143,7 @@ class ResNet(nn.Module):
|
||||
self.drop_rate = drop_rate
|
||||
self.expansion = block.expansion
|
||||
super(ResNet, self).__init__()
|
||||
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
|
||||
self.conv1 = nn.Conv2d(in_chans, 64, kernel_size=7, stride=2, padding=3, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(64)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||
@ -197,109 +213,108 @@ class ResNet(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
def resnet18(pretrained=False, **kwargs):
|
||||
def resnet18(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
|
||||
"""Constructs a ResNet-18 model.
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
"""
|
||||
model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
|
||||
default_cfg = default_cfgs['resnet18']
|
||||
model = ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, in_chans=in_chans, **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
if pretrained:
|
||||
model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||
return model
|
||||
|
||||
|
||||
def resnet34(pretrained=False, **kwargs):
|
||||
def resnet34(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
|
||||
"""Constructs a ResNet-34 model.
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
"""
|
||||
model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
|
||||
default_cfg = default_cfgs['resnet34']
|
||||
model = ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
if pretrained:
|
||||
model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||
return model
|
||||
|
||||
|
||||
def resnet50(pretrained=False, **kwargs):
|
||||
def resnet50(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
|
||||
"""Constructs a ResNet-50 model.
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
||||
"""
|
||||
model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
|
||||
default_cfg = default_cfgs['resnet50']
|
||||
model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, **kwargs)
|
||||
model.default_cfg = default_cfg
|
||||
if pretrained:
|
||||
model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
|
||||
load_pretrained(model, default_cfg, num_classes, in_chans)
|
||||
return model
|
||||
|
||||
|
||||
def resnet101(pretrained=False, **kwargs):
|
||||
def resnet101(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
    """Constructs a ResNet-101 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    default_cfg = default_cfgs['resnet101']
    model = ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, in_chans=in_chans, **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model


def resnet152(pretrained=False, **kwargs):
def resnet152(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
    """Constructs a ResNet-152 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
    default_cfg = default_cfgs['resnet152']
    model = ResNet(Bottleneck, [3, 8, 36, 3], num_classes=num_classes, in_chans=in_chans, **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model


def resnext50_32x4d(cardinality=32, base_width=4, pretrained=False, **kwargs):
def resnext50_32x4d(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
    """Constructs a ResNeXt50-32x4d model.

    Args:
        cardinality (int): Cardinality of the aggregated transform
        base_width (int): Base width of the grouped convolution
    """
    default_cfg = default_cfgs['resnext50_32x4d']
    model = ResNet(
        Bottleneck, [3, 4, 6, 3], cardinality=cardinality, base_width=base_width, **kwargs)
        Bottleneck, [3, 4, 6, 3], cardinality=32, base_width=4,
        num_classes=num_classes, in_chans=in_chans, **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model


def resnext101_32x4d(cardinality=32, base_width=4, pretrained=False, **kwargs):
def resnext101_32x4d(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
    """Constructs a ResNeXt101-32x4d model.

    Args:
        cardinality (int): Cardinality of the aggregated transform
        base_width (int): Base width of the grouped convolution
    """
    default_cfg = default_cfgs['resnext101_32x4d']
    model = ResNet(
        Bottleneck, [3, 4, 23, 3], cardinality=cardinality, base_width=base_width, **kwargs)
        Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=4,
        num_classes=num_classes, in_chans=in_chans, **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model


def resnext101_64x4d(cardinality=64, base_width=4, pretrained=False, **kwargs):
def resnext101_64x4d(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
    """Constructs a ResNeXt101-64x4d model.

    Args:
        cardinality (int): Cardinality of the aggregated transform
        base_width (int): Base width of the grouped convolution
    """
    default_cfg = default_cfgs['resnext101_64x4d']
    model = ResNet(
        Bottleneck, [3, 4, 23, 3], cardinality=cardinality, base_width=base_width, **kwargs)
        Bottleneck, [3, 4, 23, 3], cardinality=64, base_width=4,
        num_classes=num_classes, in_chans=in_chans, **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model


def resnext152_32x4d(cardinality=32, base_width=4, pretrained=False, **kwargs):
def resnext152_32x4d(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
    """Constructs a ResNeXt152-32x4d model.

    Args:
        cardinality (int): Cardinality of the aggregated transform
        base_width (int): Base width of the grouped convolution
    """
    default_cfg = default_cfgs['resnext152_32x4d']
    model = ResNet(
        Bottleneck, [3, 8, 36, 3], cardinality=cardinality, base_width=base_width, **kwargs)
        Bottleneck, [3, 8, 36, 3], cardinality=32, base_width=4,
        num_classes=num_classes, in_chans=in_chans, **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model
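Note: every constructor above now defers weight loading to the shared `load_pretrained` helper, driven by the model's `default_cfg`. Below is a rough sketch of what such a helper does, assuming only the cfg keys shown in this commit (`url`, `first_conv`, `classifier`, `num_classes`); the actual implementation lives in models/helpers.py and may differ.

def load_pretrained(model, default_cfg, num_classes=1000, in_chans=3):
    # Sketch only; the real helper is in models/helpers.py.
    state_dict = model_zoo.load_url(default_cfg['url'])
    strict = True
    if in_chans == 1:
        # Collapse the RGB filters of the first conv for single-channel input.
        conv1_name = default_cfg['first_conv']
        conv1_weight = state_dict[conv1_name + '.weight']
        state_dict[conv1_name + '.weight'] = conv1_weight.sum(dim=1, keepdim=True)
    elif in_chans != 3:
        assert False, "Invalid in_chans for pretrained weights"
    if num_classes != default_cfg['num_classes']:
        # Discard the pretrained classifier; the new head keeps its random init.
        classifier_name = default_cfg['classifier']
        del state_dict[classifier_name + '.weight']
        del state_dict[classifier_name + '.bias']
        strict = False
    model.load_state_dict(state_dict, strict=strict)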
models/senet.py
@ -8,21 +8,40 @@ import math
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import model_zoo

from models.helpers import load_pretrained
from models.adaptive_avgmax_pool import SelectAdaptivePool2d
from data.transforms import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

__all__ = ['SENet', 'senet154', 'seresnet50', 'seresnet101', 'seresnet152',
           'seresnext50_32x4d', 'seresnext101_32x4d']

model_urls = {
    'senet154': 'http://data.lip6.fr/cadene/pretrainedmodels/senet154-c7b49a05.pth',
    'seresnet18': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth',
    'seresnet34': 'https://www.dropbox.com/s/q31ccy22aq0fju7/seresnet34-a4004e63.pth?dl=1',
    'seresnet50': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth',
    'seresnet101': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet101-7e38fcc6.pth',
    'seresnet152': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet152-d17c99b7.pth',
    'seresnext50_32x4d': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth',
    'seresnext101_32x4d': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth',
}


def _cfg(url=''):
    return {
        'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'crop_pct': 0.875,
        'first_conv': 'layer0.conv1', 'classifier': 'last_linear',
    }


default_cfgs = {
    'senet154':
        _cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/senet154-c7b49a05.pth'),
    'seresnet18':
        _cfg(url=''),
    'seresnet34':
        _cfg(url='https://www.dropbox.com/s/q31ccy22aq0fju7/seresnet34-a4004e63.pth?dl=1'),
    'seresnet50':
        _cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth'),
    'seresnet101':
        _cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/se_resnet101-7e38fcc6.pth'),
    'seresnet152':
        _cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/se_resnet152-d17c99b7.pth'),
    'seresnext26_32x4d':
        _cfg(url=''),
    'seresnext50_32x4d':
        _cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth'),
    'seresnext101_32x4d':
        _cfg(url='http://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth'),
}
@ -197,7 +216,7 @@ class SEResNetBlock(nn.Module):
class SENet(nn.Module):

    def __init__(self, block, layers, groups, reduction, drop_rate=0.2,
                 inchans=3, inplanes=128, input_3x3=True, downsample_kernel_size=3,
                 in_chans=3, inplanes=128, input_3x3=True, downsample_kernel_size=3,
                 downsample_padding=1, num_classes=1000, global_pool='avg'):
        """
        Parameters
@ -247,7 +266,7 @@ class SENet(nn.Module):
        self.num_classes = num_classes
        if input_3x3:
            layer0_modules = [
                ('conv1', nn.Conv2d(inchans, 64, 3, stride=2, padding=1, bias=False)),
                ('conv1', nn.Conv2d(in_chans, 64, 3, stride=2, padding=1, bias=False)),
                ('bn1', nn.BatchNorm2d(64)),
                ('relu1', nn.ReLU(inplace=True)),
                ('conv2', nn.Conv2d(64, 64, 3, stride=1, padding=1, bias=False)),
@ -260,7 +279,7 @@ class SENet(nn.Module):
        else:
            layer0_modules = [
                ('conv1', nn.Conv2d(
                    inchans, inplanes, kernel_size=7, stride=2, padding=3, bias=False)),
                    in_chans, inplanes, kernel_size=7, stride=2, padding=3, bias=False)),
                ('bn1', nn.BatchNorm2d(inplanes)),
                ('relu1', nn.ReLU(inplace=True)),
            ]
@ -368,99 +387,107 @@ class SENet(nn.Module):
        return x
def _load_pretrained(model, url, inchans=3):
    state_dict = model_zoo.load_url(url)
    if inchans == 1:
        conv1_weight = state_dict['conv1.weight']
        state_dict['conv1.weight'] = conv1_weight.sum(dim=1, keepdim=True)
    elif inchans != 3:
        assert False, "Invalid inchans for pretrained weights"
    model.load_state_dict(state_dict)


def senet154(num_classes=1000, inchans=3, pretrained='imagenet', **kwargs):
    model = SENet(SEBottleneck, [3, 8, 36, 3], groups=64, reduction=16,
                  num_classes=num_classes, **kwargs)
    if pretrained:
        _load_pretrained(model, model_urls['senet154'], inchans)
    return model


def seresnet18(num_classes=1000, inchans=3, pretrained='imagenet', **kwargs):
def seresnet18(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
    default_cfg = default_cfgs['seresnet18']
    model = SENet(SEResNetBlock, [2, 2, 2, 2], groups=1, reduction=16,
                  inplanes=64, input_3x3=False,
                  downsample_kernel_size=1, downsample_padding=0,
                  num_classes=num_classes, **kwargs)
                  num_classes=num_classes, in_chans=in_chans, **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        _load_pretrained(model, model_urls['seresnet18'], inchans)
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model


def seresnet34(num_classes=1000, inchans=3, pretrained='imagenet', **kwargs):
def seresnet34(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
    default_cfg = default_cfgs['seresnet34']
    model = SENet(SEResNetBlock, [3, 4, 6, 3], groups=1, reduction=16,
                  inplanes=64, input_3x3=False,
                  downsample_kernel_size=1, downsample_padding=0,
                  num_classes=num_classes, **kwargs)
                  num_classes=num_classes, in_chans=in_chans, **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        _load_pretrained(model, model_urls['seresnet34'], inchans)
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model


def seresnet50(num_classes=1000, inchans=3, pretrained='imagenet', **kwargs):
def seresnet50(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
    default_cfg = default_cfgs['seresnet50']
    model = SENet(SEResNetBottleneck, [3, 4, 6, 3], groups=1, reduction=16,
                  inplanes=64, input_3x3=False,
                  downsample_kernel_size=1, downsample_padding=0,
                  num_classes=num_classes, **kwargs)
                  num_classes=num_classes, in_chans=in_chans, **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        _load_pretrained(model, model_urls['seresnet50'], inchans)
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model


def seresnet101(num_classes=1000, inchans=3, pretrained='imagenet', **kwargs):
def seresnet101(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
    default_cfg = default_cfgs['seresnet101']
    model = SENet(SEResNetBottleneck, [3, 4, 23, 3], groups=1, reduction=16,
                  inplanes=64, input_3x3=False,
                  downsample_kernel_size=1, downsample_padding=0,
                  num_classes=num_classes, **kwargs)
                  num_classes=num_classes, in_chans=in_chans, **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        _load_pretrained(model, model_urls['seresnet101'], inchans)
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model


def seresnet152(num_classes=1000, inchans=3, pretrained='imagenet', **kwargs):
def seresnet152(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
    default_cfg = default_cfgs['seresnet152']
    model = SENet(SEResNetBottleneck, [3, 8, 36, 3], groups=1, reduction=16,
                  inplanes=64, input_3x3=False,
                  downsample_kernel_size=1, downsample_padding=0,
                  num_classes=num_classes, **kwargs)
                  num_classes=num_classes, in_chans=in_chans, **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        _load_pretrained(model, model_urls['seresnet152'], inchans)
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model


def seresnext26_32x4d(num_classes=1000, inchans=3, pretrained='imagenet', **kwargs):
def senet154(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
    default_cfg = default_cfgs['senet154']
    model = SENet(SEBottleneck, [3, 8, 36, 3], groups=64, reduction=16,
                  num_classes=num_classes, in_chans=in_chans, **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model


def seresnext26_32x4d(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
    default_cfg = default_cfgs['seresnext26_32x4d']
    model = SENet(SEResNeXtBottleneck, [2, 2, 2, 2], groups=32, reduction=16,
                  inplanes=64, input_3x3=False,
                  downsample_kernel_size=1, downsample_padding=0,
                  num_classes=num_classes, **kwargs)
                  num_classes=num_classes, in_chans=in_chans, **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        _load_pretrained(model, model_urls['se_resnext26_32x4d'], inchans)
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model


def seresnext50_32x4d(num_classes=1000, inchans=3, pretrained='imagenet', **kwargs):
def seresnext50_32x4d(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
    default_cfg = default_cfgs['seresnext50_32x4d']
    model = SENet(SEResNeXtBottleneck, [3, 4, 6, 3], groups=32, reduction=16,
                  inplanes=64, input_3x3=False,
                  downsample_kernel_size=1, downsample_padding=0,
                  num_classes=num_classes, **kwargs)
                  num_classes=num_classes, in_chans=in_chans, **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        _load_pretrained(model, model_urls['seresnext50_32x4d'], inchans)
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model


def seresnext101_32x4d(num_classes=1000, inchans=3, pretrained='imagenet', **kwargs):
def seresnext101_32x4d(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
    default_cfg = default_cfgs['seresnext101_32x4d']
    model = SENet(SEResNeXtBottleneck, [3, 4, 23, 3], groups=32, reduction=16,
                  inplanes=64, input_3x3=False,
                  downsample_kernel_size=1, downsample_padding=0,
                  num_classes=num_classes, **kwargs)
                  num_classes=num_classes, in_chans=in_chans, **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        _load_pretrained(model, model_urls['seresnext101_32x4d'], inchans)
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model
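With the rewrite above, every SENet constructor shares the uniform num_classes / in_chans / pretrained interface. A hedged usage example (the dataset is hypothetical, and the pretrained path assumes the cfg URL is reachable):

from models.senet import seresnet50

# Fine-tune on a hypothetical 10-class grayscale dataset.
model = seresnet50(num_classes=10, in_chans=1, pretrained=True)
assert model.default_cfg['classifier'] == 'last_linear'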
@ -25,3 +25,12 @@ class TestTimePoolHead(nn.Module):
        x = adaptive_avgmax_pool2d(x, 1)
        return x.view(x.size(0), -1)


def apply_test_time_pool(model, args):
    test_time_pool = False
    if args.img_size > model.default_cfg['input_size'][-1] and not args.no_test_pool:
        print('Target input size (%d) > pretrained default (%d), using test time pooling' %
              (args.img_size, model.default_cfg['input_size'][-1]))
        model = TestTimePoolHead(model, original_pool=model.default_cfg['pool_size'])
        test_time_pool = True
    return model, test_time_pool
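A short usage sketch for the new helper: args only needs the two fields read above, and the wrap triggers whenever the requested size exceeds the cfg default.

from argparse import Namespace
from models import create_model, apply_test_time_pool

model = create_model('seresnet50', num_classes=1000, pretrained=False)
args = Namespace(img_size=288, no_test_pool=False)
# 288 > the cfg's 224 default, so the model is wrapped and the flag comes back True.
model, test_time_pool = apply_test_time_pool(model, args)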
@ -26,24 +26,24 @@ import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
from torch.nn import init

from models.helpers import load_pretrained
from models.adaptive_avgmax_pool import select_adaptive_pool2d


__all__ = ['xception']

pretrained_config = {
default_cfgs = {
    'xception': {
        'imagenet': {
            'url': 'http://data.lip6.fr/cadene/pretrainedmodels/xception-43020ad28.pth',
            'input_space': 'RGB',
            'input_size': [3, 299, 299],
            'input_range': [0, 1],
            'mean': [0.5, 0.5, 0.5],
            'std': [0.5, 0.5, 0.5],
            'num_classes': 1000,
            'scale': 0.8975
        'url': 'http://data.lip6.fr/cadene/pretrainedmodels/xception-43020ad28.pth',
        'input_size': (3, 299, 299),
        'mean': (0.5, 0.5, 0.5),
        'std': (0.5, 0.5, 0.5),
        'num_classes': 1000,
        'crop_pct': 0.8975,
        'first_conv': 'conv1',
        'classifier': 'fc'
        # The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299
        }
    }
}

@ -120,16 +120,18 @@ class Xception(nn.Module):
    https://arxiv.org/pdf/1610.02357.pdf
    """

    def __init__(self, num_classes=1000):
    def __init__(self, num_classes=1000, in_chans=3, drop_rate=0., global_pool='avg'):
        """ Constructor
        Args:
            num_classes: number of classes
        """
        super(Xception, self).__init__()
        self.drop_rate = drop_rate
        self.global_pool = global_pool
        self.num_classes = num_classes
        self.num_features = 2048

        self.conv1 = nn.Conv2d(3, 32, 3, 2, 0, bias=False)
        self.conv1 = nn.Conv2d(in_chans, 32, 3, 2, 0, bias=False)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu = nn.ReLU(inplace=True)

@ -173,8 +175,9 @@ class Xception(nn.Module):
    def get_classifier(self):
        return self.fc

    def reset_classifier(self, num_classes):
    def reset_classifier(self, num_classes, global_pool='avg'):
        self.num_classes = num_classes
        self.global_pool = global_pool
        del self.fc
        if num_classes:
            self.fc = nn.Linear(self.num_features, num_classes)
@ -212,24 +215,23 @@ class Xception(nn.Module):
        x = self.relu(x)

        if pool:
            x = F.adaptive_avg_pool2d(x, (1, 1))
            x = select_adaptive_pool2d(x, pool_type=self.global_pool)
            x = x.view(x.size(0), -1)
        return x

    def forward(self, input):
        x = self.forward_features(input)
        if self.drop_rate:
            x = F.dropout(x, self.drop_rate, training=self.training)
        x = self.fc(x)
        return x


def xception(num_classes=1000, pretrained=False):
    model = Xception(num_classes=num_classes)
def xception(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
    default_cfg = default_cfgs['xception']
    model = Xception(num_classes=num_classes, in_chans=in_chans, **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        config = pretrained_config['xception']['imagenet']
        assert num_classes == config['num_classes'], \
            "num_classes should be {}, but is {}".format(config['num_classes'], num_classes)

        model = Xception(num_classes=num_classes)
        model.load_state_dict(model_zoo.load_url(config['url']))
        load_pretrained(model, default_cfg, num_classes, in_chans)

    return model
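The reworked reset_classifier brings Xception in line with the other models for transfer learning; a brief, hypothetical use:

from models import create_model

model = create_model('xception', pretrained=False)
model.reset_classifier(num_classes=37, global_pool='avg')  # fresh 37-way head, random init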
@ -1,2 +1,3 @@
from optim.adabound import AdaBound
from optim.nadam import Nadam
from optim.nadam import Nadam
from optim.optim_factory import create_optimizer
optim/optim_factory.py (new file)
@ -0,0 +1,30 @@
from torch import optim as optim
from optim import Nadam, AdaBound


def create_optimizer(args, parameters):
    if args.opt.lower() == 'sgd':
        optimizer = optim.SGD(
            parameters, lr=args.lr,
            momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)
    elif args.opt.lower() == 'adam':
        optimizer = optim.Adam(
            parameters, lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps)
    elif args.opt.lower() == 'nadam':
        optimizer = Nadam(
            parameters, lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps)
    elif args.opt.lower() == 'adabound':
        optimizer = AdaBound(
            parameters, lr=args.lr / 100, weight_decay=args.weight_decay, eps=args.opt_eps,
            final_lr=args.lr)
    elif args.opt.lower() == 'adadelta':
        optimizer = optim.Adadelta(
            parameters, lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps)
    elif args.opt.lower() == 'rmsprop':
        optimizer = optim.RMSprop(
            parameters, lr=args.lr, alpha=0.9, eps=args.opt_eps,
            momentum=args.momentum, weight_decay=args.weight_decay)
    else:
        raise ValueError('Invalid optimizer: %s' % args.opt)
    return optimizer
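A quick usage sketch of the factory, standing in a bare Namespace for the parsed training args (only the fields consumed above are needed):

from argparse import Namespace
import torch.nn as nn
from optim import create_optimizer

model = nn.Linear(10, 2)  # stand-in model for illustration
args = Namespace(opt='sgd', lr=0.01, momentum=0.9, weight_decay=1e-4, opt_eps=1e-8)
optimizer = create_optimizer(args, model.parameters())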
@ -1,4 +1,5 @@
from .cosine_lr import CosineLRScheduler
from .plateau_lr import PlateauLRScheduler
from .step_lr import StepLRScheduler
from .tanh_lr import TanhLRScheduler
from scheduler.cosine_lr import CosineLRScheduler
from scheduler.plateau_lr import PlateauLRScheduler
from scheduler.step_lr import StepLRScheduler
from scheduler.tanh_lr import TanhLRScheduler
from scheduler.scheduler_factory import create_scheduler
scheduler/scheduler_factory.py (new file)
@ -0,0 +1,43 @@
from scheduler.cosine_lr import CosineLRScheduler
from scheduler.plateau_lr import PlateauLRScheduler
from scheduler.tanh_lr import TanhLRScheduler
from scheduler.step_lr import StepLRScheduler


def create_scheduler(args, optimizer):
    num_epochs = args.epochs
    # FIXME expose cycle params of the scheduler config to arguments
    if args.sched == 'cosine':
        lr_scheduler = CosineLRScheduler(
            optimizer,
            t_initial=num_epochs,
            t_mul=1.0,
            lr_min=1e-5,
            decay_rate=args.decay_rate,
            warmup_lr_init=args.warmup_lr,
            warmup_t=args.warmup_epochs,
            cycle_limit=1,
            t_in_epochs=True,
        )
        num_epochs = lr_scheduler.get_cycle_length() + 10
    elif args.sched == 'tanh':
        lr_scheduler = TanhLRScheduler(
            optimizer,
            t_initial=num_epochs,
            t_mul=1.0,
            lr_min=1e-5,
            warmup_lr_init=args.warmup_lr,
            warmup_t=args.warmup_epochs,
            cycle_limit=1,
            t_in_epochs=True,
        )
        num_epochs = lr_scheduler.get_cycle_length() + 10
    else:
        lr_scheduler = StepLRScheduler(
            optimizer,
            decay_t=args.decay_epochs,
            decay_rate=args.decay_rate,
            warmup_lr_init=args.warmup_lr,
            warmup_t=args.warmup_epochs,
        )
    return lr_scheduler, num_epochs
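And the matching usage sketch for the scheduler factory, reusing the optimizer from the previous example; note the cosine and tanh paths pad the returned epoch count by 10 beyond the cycle length:

from argparse import Namespace
from scheduler import create_scheduler

args = Namespace(epochs=100, sched='cosine', decay_rate=0.1, decay_epochs=30,
                 warmup_lr=1e-4, warmup_epochs=3)
# With t_mul=1.0 and cycle_limit=1, num_epochs comes back as args.epochs + 10.
lr_scheduler, num_epochs = create_scheduler(args, optimizer)
for epoch in range(num_epochs):
    # ... train and validate ...
    lr_scheduler.step(epoch)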
train.py
@ -1,7 +1,6 @@
import argparse
import time
from collections import OrderedDict
from datetime import datetime

try:
@ -12,17 +11,14 @@ except ImportError:
    has_apex = False

from data import *
from models import model_factory
from models import create_model, resume_checkpoint
from utils import *
from optim import Nadam, AdaBound
from loss import LabelSmoothingCrossEntropy
import scheduler
from optim import create_optimizer
from scheduler import create_scheduler

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
import torch.distributed as dist
import torchvision.utils

@ -33,6 +29,8 @@ parser.add_argument('data', metavar='DIR',
                    help='path to dataset')
parser.add_argument('--model', default='resnet101', type=str, metavar='MODEL',
                    help='Name of model to train (default: "resnet101")')
parser.add_argument('--num-classes', type=int, default=1000, metavar='N',
                    help='number of label classes (default: 1000)')
parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
                    help='Optimizer (default: "sgd")')
parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
@ -120,10 +118,13 @@ def main():
        r = torch.distributed.get_rank()

    if args.distributed:
        print('Training in distributed mode with %d processes, 1 GPU per process. Process %d.'
              % (args.world_size, r))
        print('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
              % (r, args.world_size))
    else:
        print('Training with a single process with %d GPUs.' % args.num_gpu)
        print('Training with a single process on %d GPUs.' % args.num_gpu)

    # FIXME seed handling for multi-process distributed?
    torch.manual_seed(args.seed)

    output_dir = ''
    if args.local_rank == 0:
@ -137,80 +138,21 @@ def main():
            str(args.img_size)])
        output_dir = get_outdir(output_base, 'train', exp_name)

    batch_size = args.batch_size
    torch.manual_seed(args.seed)

    data_mean, data_std = get_model_meanstd(args.model)

    dataset_train = Dataset(os.path.join(args.data, 'train'))

    loader_train = create_loader(
        dataset_train,
        img_size=args.img_size,
        batch_size=batch_size,
        is_training=True,
        use_prefetcher=True,
        random_erasing=0.3,
        mean=data_mean,
        std=data_std,
        num_workers=args.workers,
        distributed=args.distributed,
    )

    dataset_eval = Dataset(os.path.join(args.data, 'validation'))

    loader_eval = create_loader(
        dataset_eval,
        img_size=args.img_size,
        batch_size=4 * args.batch_size,
        is_training=False,
        use_prefetcher=True,
        mean=data_mean,
        std=data_std,
        num_workers=args.workers,
        distributed=args.distributed,
    )

    model = model_factory.create_model(
    model = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=1000,
        num_classes=args.num_classes,
        drop_rate=args.drop,
        global_pool=args.gp,
        checkpoint_path=args.initial_checkpoint)

    data_mean, data_std = get_mean_and_std(model, args)

    # optionally resume from a checkpoint
    start_epoch = 0 if args.start_epoch is None else args.start_epoch
    start_epoch = 0
    optimizer_state = None
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
                new_state_dict = OrderedDict()
                for k, v in checkpoint['state_dict'].items():
                    if k.startswith('module'):
                        name = k[7:]  # remove `module.`
                    else:
                        name = k
                    new_state_dict[name] = v
                model.load_state_dict(new_state_dict)
                if 'optimizer' in checkpoint:
                    optimizer_state = checkpoint['optimizer']
                print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
                start_epoch = checkpoint['epoch'] if args.start_epoch is None else args.start_epoch
            else:
                model.load_state_dict(checkpoint)
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
            return False

    if args.smoothing:
        train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    else:
        train_loss_fn = nn.CrossEntropyLoss().cuda()
        validate_loss_fn = train_loss_fn
        start_epoch, optimizer_state = resume_checkpoint(model, args.resume, args.start_epoch)
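The inline block removed above is exactly what the factored-out resume_checkpoint helper now encapsulates. A minimal sketch consistent with that deleted logic (the real helper lives in models/helpers.py and may differ in detail):

import torch
from collections import OrderedDict

def resume_checkpoint(model, checkpoint_path, start_epoch=None):
    # Sketch mirroring the removed inline resume logic above.
    optimizer_state = None
    epoch = 0 if start_epoch is None else start_epoch
    checkpoint = torch.load(checkpoint_path)
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        # Strip the 'module.' prefix left behind by (Distributed)DataParallel.
        new_state_dict = OrderedDict(
            (k[7:] if k.startswith('module') else k, v)
            for k, v in checkpoint['state_dict'].items())
        model.load_state_dict(new_state_dict)
        optimizer_state = checkpoint.get('optimizer', None)
        epoch = checkpoint['epoch'] if start_epoch is None else start_epoch
    else:
        model.load_state_dict(checkpoint)
    return epoch, optimizer_state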
    if args.num_gpu > 1:
        if args.amp:
@ -237,9 +179,55 @@ def main():
        model = DDP(model, delay_allreduce=True)

    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    if start_epoch > 0:
        lr_scheduler.step(start_epoch)
    if args.local_rank == 0:
        print('Scheduled epochs: ', num_epochs)

    train_dir = os.path.join(args.data, 'train')
    if not os.path.exists(train_dir):
        print('Error: training folder does not exist at: %s' % train_dir)
        exit(1)
    dataset_train = Dataset(train_dir)

    loader_train = create_loader(
        dataset_train,
        img_size=args.img_size,
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=True,
        random_erasing=0.3,
        mean=data_mean,
        std=data_std,
        num_workers=args.workers,
        distributed=args.distributed,
    )

    eval_dir = os.path.join(args.data, 'validation')
    if not os.path.isdir(eval_dir):
        print('Error: validation folder does not exist at: %s' % eval_dir)
        exit(1)
    dataset_eval = Dataset(eval_dir)

    loader_eval = create_loader(
        dataset_eval,
        img_size=args.img_size,
        batch_size=4 * args.batch_size,
        is_training=False,
        use_prefetcher=True,
        mean=data_mean,
        std=data_std,
        num_workers=args.workers,
        distributed=args.distributed,
    )

    if args.smoothing:
        train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing).cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    else:
        train_loss_fn = nn.CrossEntropyLoss().cuda()
        validate_loss_fn = train_loss_fn

    eval_metric = args.eval_metric
    saver = None
    if output_dir:
@ -429,76 +417,9 @@ def validate(model, loader, loss_fn, args):
    return metrics
def create_optimizer(args, parameters):
    if args.opt.lower() == 'sgd':
        optimizer = optim.SGD(
            parameters, lr=args.lr,
            momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)
    elif args.opt.lower() == 'adam':
        optimizer = optim.Adam(
            parameters, lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps)
    elif args.opt.lower() == 'nadam':
        optimizer = Nadam(
            parameters, lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps)
    elif args.opt.lower() == 'adabound':
        optimizer = AdaBound(
            parameters, lr=args.lr / 100, weight_decay=args.weight_decay, eps=args.opt_eps,
            final_lr=args.lr)
    elif args.opt.lower() == 'adadelta':
        optimizer = optim.Adadelta(
            parameters, lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps)
    elif args.opt.lower() == 'rmsprop':
        optimizer = optim.RMSprop(
            parameters, lr=args.lr, alpha=0.9, eps=args.opt_eps,
            momentum=args.momentum, weight_decay=args.weight_decay)
    else:
        assert False and "Invalid optimizer"
        raise ValueError
    return optimizer


def create_scheduler(args, optimizer):
    num_epochs = args.epochs
    # FIXME expose cycle parms of the scheduler config to arguments
    if args.sched == 'cosine':
        lr_scheduler = scheduler.CosineLRScheduler(
            optimizer,
            t_initial=num_epochs,
            t_mul=1.0,
            lr_min=1e-5,
            decay_rate=args.decay_rate,
            warmup_lr_init=args.warmup_lr,
            warmup_t=args.warmup_epochs,
            cycle_limit=1,
            t_in_epochs=True,
        )
        num_epochs = lr_scheduler.get_cycle_length() + 10
    elif args.sched == 'tanh':
        lr_scheduler = scheduler.TanhLRScheduler(
            optimizer,
            t_initial=num_epochs,
            t_mul=1.0,
            lr_min=1e-5,
            warmup_lr_init=args.warmup_lr,
            warmup_t=args.warmup_epochs,
            cycle_limit=1,
            t_in_epochs=True,
        )
        num_epochs = lr_scheduler.get_cycle_length() + 10
    else:
        lr_scheduler = scheduler.StepLRScheduler(
            optimizer,
            decay_t=args.decay_epochs,
            decay_rate=args.decay_rate,
            warmup_lr_init=args.warmup_lr,
            warmup_t=args.warmup_epochs,
        )
    return lr_scheduler, num_epochs
def reduce_tensor(tensor, n):
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.reduce_op.SUM)
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= n
    return rt
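For context, a helper like the corrected reduce_tensor is typically called when logging metrics under distributed training, along these lines (variable names assumed, not shown in this diff):

if args.distributed:
    reduced_loss = reduce_tensor(loss.data, args.world_size)  # average across processes
else:
    reduced_loss = loss.data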
validate.py
@ -6,13 +6,14 @@ import argparse
import os
import time
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.parallel

from models import create_model, load_checkpoint, TestTimePoolHead
from data import Dataset, create_loader, get_model_meanstd
from models import create_model, apply_test_time_pool
from data import Dataset, create_loader, get_mean_and_std
from utils import accuracy, AverageMeter

torch.backends.cudnn.benchmark = True

parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation')
parser.add_argument('data', metavar='DIR',
@ -25,6 +26,8 @@ parser.add_argument('-b', '--batch-size', default=256, type=int,
                    metavar='N', help='mini-batch size (default: 256)')
parser.add_argument('--img-size', default=224, type=int,
                    metavar='N', help='Input image dimension')
parser.add_argument('--num-classes', type=int, default=1000,
                    help='Number of classes in dataset')
parser.add_argument('--print-freq', '-p', default=10, type=int,
                    metavar='N', help='print frequency (default: 10)')
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
@ -41,25 +44,19 @@ def main():
    args = parser.parse_args()

    # create model
    num_classes = 1000
    model = create_model(
        args.model,
        num_classes=num_classes,
        pretrained=args.pretrained)
        num_classes=args.num_classes,
        in_chans=3,
        pretrained=args.pretrained,
        checkpoint_path=args.checkpoint)

    print('Model %s created, param count: %d' %
          (args.model, sum([m.numel() for m in model.parameters()])))

    # load a checkpoint
    if not args.pretrained:
        if not load_checkpoint(model, args.checkpoint):
            exit(1)
    data_mean, data_std = get_mean_and_std(model, args)

    test_time_pool = False
    # FIXME make this work for networks with default img size != 224 and default pool k != 7
    if args.img_size > 224 and not args.no_test_pool:
        model = TestTimePoolHead(model)
        test_time_pool = True
    model, test_time_pool = apply_test_time_pool(model, args)

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
@ -69,14 +66,11 @@ def main():
    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    cudnn.benchmark = True

    data_mean, data_std = get_model_meanstd(args.model)
    loader = create_loader(
        Dataset(args.data),
        img_size=args.img_size,
        batch_size=args.batch_size,
        use_prefetcher=True,
        use_prefetcher=False,
        mean=data_mean,
        std=data_std,
        num_workers=args.workers,
@ -111,51 +105,17 @@ def main():

        if i % args.print_freq == 0:
            print('Test: [{0}/{1}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f}, {rate_avg:.3f}/s)\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                      i, len(loader), batch_time=batch_time, loss=losses,
                      top1=top1, top5=top5))
                      i, len(loader), batch_time=batch_time,
                      rate_avg=input.size(0) / batch_time.avg,
                      loss=losses, top1=top1, top5=top5))

    print(' * Prec@1 {top1.avg:.3f} ({top1a:.3f}) Prec@5 {top5.avg:.3f} ({top5a:.3f})'.format(
        top1=top1, top1a=100. - top1.avg, top5=top5, top5a=100. - top5.avg))


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


if __name__ == '__main__':
    main()