mirror of https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00

Commit 63e677d03b
Merge branch 'master' of github.com:rwightman/pytorch-models
@@ -20,7 +20,7 @@ def get_model_meanstd(model_name):
     model_name = model_name.lower()
     if 'dpn' in model_name:
         return IMAGENET_DPN_MEAN, IMAGENET_DPN_STD
-    elif 'ception' in model_name:
+    elif 'ception' in model_name or 'nasnet' in model_name:
         return IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
     else:
         return IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

@@ -30,7 +30,7 @@ def get_model_mean(model_name):
     model_name = model_name.lower()
     if 'dpn' in model_name:
         return IMAGENET_DPN_STD
-    elif 'ception' in model_name:
+    elif 'ception' in model_name or 'nasnet' in model_name:
         return IMAGENET_INCEPTION_MEAN
     else:
         return IMAGENET_DEFAULT_MEAN

@@ -40,7 +40,7 @@ def get_model_std(model_name):
     model_name = model_name.lower()
     if 'dpn' in model_name:
         return IMAGENET_DEFAULT_STD
-    elif 'ception' in model_name:
+    elif 'ception' in model_name or 'nasnet' in model_name:
         return IMAGENET_INCEPTION_STD
     else:
         return IMAGENET_DEFAULT_STD
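All three hunks make the same change: model names containing 'nasnet' now resolve to the Inception-style (0.5, 0.5, 0.5) statistics instead of the ImageNet defaults. (Note that the DPN branches of get_model_mean and get_model_std return IMAGENET_DPN_STD and IMAGENET_DEFAULT_STD respectively, which looks like a pre-existing copy-paste slip that this commit leaves untouched.) A minimal sketch of how these helpers are typically consumed when building an eval-time input pipeline; the import path and the transform recipe are illustrative, not the repo's actual loader code:

from torchvision import transforms

from models.model_factory import get_model_meanstd  # assumed import path

def eval_transform(model_name, input_size=224):
    # picks DPN / Inception-style / default ImageNet stats per the logic above
    mean, std = get_model_meanstd(model_name)
    return transforms.Compose([
        transforms.Resize(int(input_size / 0.875)),
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])

# 'pnasnet5large' now hits the 'nasnet' branch instead of the default stats
tf = eval_transform('pnasnet5large', input_size=331)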
@@ -12,6 +12,7 @@ from .senet import seresnet18, seresnet34, seresnet50, seresnet101, seresnet152, \
     seresnext26_32x4d, seresnext50_32x4d, seresnext101_32x4d
 #from .resnext import resnext50, resnext101, resnext152
 from .xception import xception
+from .pnasnet import pnasnet5large

 model_config_dict = {
     'resnet18': {

@@ -48,6 +49,8 @@ model_config_dict = {
         'model_name': 'inception_resnet_v2', 'num_classes': 1000, 'input_size': 299, 'normalizer': 'le'},
     'xception': {
         'model_name': 'xception', 'num_classes': 1000, 'input_size': 299, 'normalizer': 'le'},
+    'pnasnet5large': {
+        'model_name': 'pnasnet5large', 'num_classes': 1000, 'input_size': 331, 'normalizer': 'le'}
 }

@@ -118,6 +121,8 @@ def create_model(
         model = resnext152_32x4d(num_classes=num_classes, pretrained=pretrained, **kwargs)
     elif model_name == 'xception':
         model = xception(num_classes=num_classes, pretrained=pretrained)
+    elif model_name == 'pnasnet5large':
+        model = pnasnet5large(num_classes=num_classes, pretrained=pretrained)
     else:
         assert False and "Invalid model"
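With the registry entry and factory branch above, PNASNet is reachable by name like every other model. One line worth flagging: assert False and "Invalid model" short-circuits to a bare assert False, so the message is never attached; the conventional spelling is assert False, "Invalid model" (the line predates this commit). A usage sketch, assuming create_model forwards num_classes and pretrained as its body suggests:

from models.model_factory import create_model  # assumed import path

# the pnasnet5large entry added above is now reachable through the factory
model = create_model('pnasnet5large', num_classes=1000, pretrained=True)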
@@ -5,7 +5,6 @@ import torch
 import torch.nn as nn
 import torch.utils.model_zoo as model_zoo

-
 pretrained_settings = {
     'pnasnet5large': {
         'imagenet': {

@@ -292,6 +291,8 @@ class PNASNet5Large(nn.Module):
     def __init__(self, num_classes=1001):
         super(PNASNet5Large, self).__init__()
         self.num_classes = num_classes
+        self.num_features = 4320
+
         self.conv_0 = nn.Sequential(OrderedDict([
             ('conv', nn.Conv2d(3, 96, kernel_size=3, stride=2, bias=False)),
             ('bn', nn.BatchNorm2d(96, eps=0.001))

@@ -335,9 +336,20 @@ class PNASNet5Large(nn.Module):
         self.relu = nn.ReLU()
         self.avg_pool = nn.AvgPool2d(11, stride=1, padding=0)
         self.dropout = nn.Dropout(0.5)
-        self.last_linear = nn.Linear(4320, num_classes)
+        self.last_linear = nn.Linear(self.num_features, num_classes)

-    def features(self, x):
+    def get_classifier(self):
+        return self.last_linear
+
+    def reset_classifier(self, num_classes):
+        self.num_classes = num_classes
+        del self.last_linear
+        if num_classes:
+            self.last_linear = nn.Linear(self.num_features, num_classes)
+        else:
+            self.last_linear = None
+
+    def forward_features(self, x, pool=True):
         x_conv_0 = self.conv_0(x)
         x_stem_0 = self.cell_stem_0(x_conv_0)
         x_stem_1 = self.cell_stem_1(x_conv_0, x_stem_0)

@@ -353,19 +365,16 @@ class PNASNet5Large(nn.Module):
         x_cell_9 = self.cell_9(x_cell_7, x_cell_8)
         x_cell_10 = self.cell_10(x_cell_8, x_cell_9)
         x_cell_11 = self.cell_11(x_cell_9, x_cell_10)
-        return x_cell_11
-
-    def logits(self, features):
-        x = self.relu(features)
-        x = self.avg_pool(x)
-        x = x.view(x.size(0), -1)
-        x = self.dropout(x)
-        x = self.last_linear(x)
+        x = self.relu(x_cell_11)
+        if pool:
+            x = self.avg_pool(x)
+            x = x.view(x.size(0), -1)
         return x

     def forward(self, input):
-        x = self.features(input)
-        x = self.logits(x)
+        x = self.forward_features(input)
+        x = self.dropout(x)
+        x = self.last_linear(x)
         return x

@@ -375,7 +384,7 @@ def pnasnet5large(num_classes=1001, pretrained='imagenet'):
     <https://arxiv.org/abs/1712.00559>`_ paper.
     """
     if pretrained:
-        settings = pretrained_settings['pnasnet5large'][pretrained]
+        settings = pretrained_settings['pnasnet5large']['imagenet']
         assert num_classes == settings[
             'num_classes'], 'num_classes should be {}, but is {}'.format(
             settings['num_classes'], num_classes)

@@ -384,18 +393,12 @@ def pnasnet5large(num_classes=1001, pretrained='imagenet'):
         model = PNASNet5Large(num_classes=1001)
         model.load_state_dict(model_zoo.load_url(settings['url']))

-        if pretrained == 'imagenet':
-            new_last_linear = nn.Linear(model.last_linear.in_features, 1000)
-            new_last_linear.weight.data = model.last_linear.weight.data[1:]
-            new_last_linear.bias.data = model.last_linear.bias.data[1:]
-            model.last_linear = new_last_linear
+        #if pretrained == 'imagenet':
+        new_last_linear = nn.Linear(model.last_linear.in_features, 1000)
+        new_last_linear.weight.data = model.last_linear.weight.data[1:]
+        new_last_linear.bias.data = model.last_linear.bias.data[1:]
+        model.last_linear = new_last_linear

         model.input_space = settings['input_space']
         model.input_size = settings['input_size']
         model.input_range = settings['input_range']
-
         model.mean = settings['mean']
         model.std = settings['std']
     else:
         model = PNASNet5Large(num_classes=num_classes)
     return model
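The commit gives PNASNet the same classifier/feature interface as the rest of the factory: get_classifier, reset_classifier, and a forward_features that can skip pooling. A hedged usage sketch (import path assumed; shapes assume the 331x331 input the config registers, and pretrained weights download via model_zoo on first use):

import torch
from models.pnasnet import pnasnet5large  # assumed import path

model = pnasnet5large(num_classes=1000, pretrained='imagenet')
model.eval()

x = torch.randn(1, 3, 331, 331)
feats = model.forward_features(x)              # pooled (1, 4320) embedding
fmap = model.forward_features(x, pool=False)   # unpooled (1, 4320, 11, 11) map

# swap the head for transfer learning; num_features keeps the sizes in sync
model.reset_classifier(num_classes=10)
logits = model(x)                              # (1, 10)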
@@ -127,6 +127,7 @@ class Xception(nn.Module):
         """
         super(Xception, self).__init__()
         self.num_classes = num_classes
+        self.num_features = 2048

         self.conv1 = nn.Conv2d(3, 32, 3, 2, 0, bias=False)
         self.bn1 = nn.BatchNorm2d(32)

@@ -156,10 +157,10 @@ class Xception(nn.Module):
         self.bn3 = nn.BatchNorm2d(1536)

         # do relu here
-        self.conv4 = SeparableConv2d(1536, 2048, 3, 1, 1)
-        self.bn4 = nn.BatchNorm2d(2048)
+        self.conv4 = SeparableConv2d(1536, self.num_features, 3, 1, 1)
+        self.bn4 = nn.BatchNorm2d(self.num_features)

-        self.fc = nn.Linear(2048, num_classes)
+        self.fc = nn.Linear(self.num_features, num_classes)

         # #------- init weights --------
         for m in self.modules():

@@ -169,7 +170,18 @@ class Xception(nn.Module):
                 m.weight.data.fill_(1)
                 m.bias.data.zero_()

-    def forward_features(self, input):
+    def get_classifier(self):
+        return self.fc
+
+    def reset_classifier(self, num_classes):
+        self.num_classes = num_classes
+        del self.fc
+        if num_classes:
+            self.fc = nn.Linear(self.num_features, num_classes)
+        else:
+            self.fc = None
+
+    def forward_features(self, input, pool=True):
         x = self.conv1(input)
         x = self.bn1(x)
         x = self.relu(x)

@@ -197,19 +209,16 @@ class Xception(nn.Module):

         x = self.conv4(x)
         x = self.bn4(x)
-        return x
-
-    def logits(self, features):
-        x = self.relu(features)
-
-        x = F.adaptive_avg_pool2d(x, (1, 1))
-        x = x.view(x.size(0), -1)
-        x = self.last_linear(x)
+        x = self.relu(x)
+
+        if pool:
+            x = F.adaptive_avg_pool2d(x, (1, 1))
+            x = x.view(x.size(0), -1)
         return x

     def forward(self, input):
         x = self.forward_features(input)
-        x = self.logits(x)
+        x = self.fc(x)
         return x

@@ -223,13 +232,4 @@ def xception(num_classes=1000, pretrained=False):
         model = Xception(num_classes=num_classes)
         model.load_state_dict(model_zoo.load_url(config['url']))

-        model.input_space = config['input_space']
-        model.input_size = config['input_size']
-        model.input_range = config['input_range']
-        model.mean = config['mean']
-        model.std = config['std']
-
-    # TODO: ugly
-    model.last_linear = model.fc
-    del model.fc
     return model
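Xception gets the same treatment, and the final hunk drops the "model.last_linear = model.fc; del model.fc" aliasing, so the classifier is consistently model.fc and matches get_classifier. A sketch under the same assumptions (import path assumed; shapes for a 299x299 input):

import torch
from models.xception import xception  # assumed import path

model = xception(num_classes=1000, pretrained=False)
model.eval()

x = torch.randn(1, 3, 299, 299)
vec = model.forward_features(x)                # pooled (1, 2048) embedding
fmap = model.forward_features(x, pool=False)   # spatial (1, 2048, 10, 10) map

model.reset_classifier(num_classes=0)          # num_classes=0 removes the head
assert model.get_classifier() is None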
train.py (24 changed lines)
@@ -93,6 +93,8 @@ parser.add_argument('--amp', action='store_true', default=False,
                     help='use NVIDIA amp for mixed precision training')
 parser.add_argument('--output', default='', type=str, metavar='PATH',
                     help='path to output folder (default: none, current dir)')
+parser.add_argument('--eval-metric', default='prec1', type=str, metavar='EVAL_METRIC',
+                    help='Best metric (default: "prec1"')
 parser.add_argument("--local_rank", default=0, type=int)

@@ -238,10 +240,13 @@ def main():
     if args.local_rank == 0:
         print('Scheduled epochs: ', num_epochs)

+    eval_metric = args.eval_metric
     saver = None
     if output_dir:
-        saver = CheckpointSaver(checkpoint_dir=output_dir)
-    best_loss = None
+        decreasing = True if eval_metric == 'loss' else False
+        saver = CheckpointSaver(checkpoint_dir=output_dir, decreasing=decreasing)
+    best_metric = None
+    best_epoch = None
     try:
         for epoch in range(start_epoch, num_epochs):
             if args.distributed:

@@ -255,15 +260,15 @@ def main():
                 model, loader_eval, validate_loss_fn, args)

             if lr_scheduler is not None:
-                lr_scheduler.step(epoch, eval_metrics['eval_loss'])
+                lr_scheduler.step(epoch, eval_metrics[eval_metric])

             update_summary(
                 epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
-                write_header=best_loss is None)
+                write_header=best_metric is None)

             if saver is not None:
                 # save proper checkpoint with eval metric
-                best_loss = saver.save_checkpoint({
+                best_metric, best_epoch = saver.save_checkpoint({
                     'epoch': epoch + 1,
                     'arch': args.model,
                     'state_dict': model.state_dict(),

@@ -271,11 +276,12 @@ def main():
                     'args': args,
                     },
                 epoch=epoch + 1,
-                metric=eval_metrics['eval_loss'])
+                metric=eval_metrics[eval_metric])

     except KeyboardInterrupt:
         pass
-    print('*** Best loss: {0} (epoch {1})'.format(best_loss[1], best_loss[0]))
+    if best_metric is not None:
+        print('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))


 def train_epoch(

@@ -363,7 +369,7 @@ def train_epoch(

         end = time.time()

-    return OrderedDict([('train_loss', losses_m.avg)])
+    return OrderedDict([('loss', losses_m.avg)])


 def validate(model, loader, loss_fn, args):

@@ -418,7 +424,7 @@ def validate(model, loader, loss_fn, args):
                 batch_time=batch_time_m, loss=losses_m,
                 top1=prec1_m, top5=prec5_m))

-    metrics = OrderedDict([('eval_loss', losses_m.avg), ('eval_prec1', prec1_m.avg)])
+    metrics = OrderedDict([('loss', losses_m.avg), ('prec1', prec1_m.avg), ('prec5', prec5_m.avg)])

     return metrics
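The renames are coordinated: train_epoch and validate now emit bare metric keys, update_summary (in utils.py below) adds the train_/eval_ prefixes for the CSV, and eval_metrics[eval_metric] works for any --eval-metric the validator reports. An illustrative run of that wiring with invented values:

from collections import OrderedDict

train_metrics = OrderedDict([('loss', 0.91)])
eval_metrics = OrderedDict([('loss', 0.84), ('prec1', 71.2), ('prec5', 90.1)])

eval_metric = 'prec1'                     # args.eval_metric default
decreasing = True if eval_metric == 'loss' else False
print(eval_metrics[eval_metric])          # 71.2, fed to scheduler and saver

# what update_summary will write as CSV columns after prefixing
rowd = OrderedDict(epoch=1)
rowd.update([('train_' + k, v) for k, v in train_metrics.items()])
rowd.update([('eval_' + k, v) for k, v in eval_metrics.items()])
print(list(rowd))  # ['epoch', 'train_loss', 'eval_loss', 'eval_prec1', 'eval_prec5']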
utils.py (50 changed lines)
@@ -6,6 +6,7 @@ import os
 import shutil
 import glob
 import csv
+import operator
 from collections import OrderedDict


@@ -16,24 +17,32 @@ class CheckpointSaver:
             recovery_prefix='recovery',
             checkpoint_dir='',
             recovery_dir='',
+            decreasing=False,
+            verbose=True,
             max_history=10):

-        self.checkpoint_files = []
+        # state
+        self.checkpoint_files = []  # (filename, metric) tuples in order of decreasing betterness
+        self.best_epoch = None
         self.best_metric = None
-        self.worst_metric = None
-        self.max_history = max_history
-        assert self.max_history >= 1
         self.curr_recovery_file = ''
         self.last_recovery_file = ''

+        # config
         self.checkpoint_dir = checkpoint_dir
         self.recovery_dir = recovery_dir
         self.save_prefix = checkpoint_prefix
         self.recovery_prefix = recovery_prefix
         self.extension = '.pth.tar'
+        self.decreasing = decreasing  # a lower metric is better if True
+        self.cmp = operator.lt if decreasing else operator.gt  # True if lhs better than rhs
+        self.verbose = verbose
+        self.max_history = max_history
+        assert self.max_history >= 1

     def save_checkpoint(self, state, epoch, metric=None):
-        worst_metric = self.checkpoint_files[-1] if self.checkpoint_files else None
-        if len(self.checkpoint_files) < self.max_history or metric < worst_metric[1]:
+        worst_file = self.checkpoint_files[-1] if self.checkpoint_files else None
+        if len(self.checkpoint_files) < self.max_history or self.cmp(metric, worst_file[1]):
             if len(self.checkpoint_files) >= self.max_history:
                 self._cleanup_checkpoints(1)

@@ -43,16 +52,21 @@ class CheckpointSaver:
             state['metric'] = metric
             torch.save(state, save_path)
             self.checkpoint_files.append((save_path, metric))
-            self.checkpoint_files = sorted(self.checkpoint_files, key=lambda x: x[1])
+            self.checkpoint_files = sorted(
+                self.checkpoint_files, key=lambda x: x[1],
+                reverse=not self.decreasing)  # sort in descending order if a lower metric is not better

-            print("Current checkpoints:")
-            for c in self.checkpoint_files:
-                print(c)
+            if self.verbose:
+                print("Current checkpoints:")
+                for c in self.checkpoint_files:
+                    print(c)

-            if metric is not None and (self.best_metric is None or metric < self.best_metric[1]):
-                self.best_metric = (epoch, metric)
+            if metric is not None and (self.best_metric is None or self.cmp(metric, self.best_metric)):
+                self.best_epoch = epoch
+                self.best_metric = metric
                 shutil.copyfile(save_path, os.path.join(self.checkpoint_dir, 'model_best' + self.extension))

-        return None, None if self.best_metric is None else self.best_metric
+        return (None, None) if self.best_metric is None else (self.best_metric, self.best_epoch)

@@ -62,7 +76,8 @@ class CheckpointSaver:
         to_delete = self.checkpoint_files[delete_index:]
         for d in to_delete:
             try:
-                print('Cleaning checkpoint: ', d)
+                if self.verbose:
+                    print('Cleaning checkpoint: ', d)
                 os.remove(d[0])
             except Exception as e:
                 print('Exception (%s) while deleting checkpoint' % str(e))

@@ -74,7 +89,8 @@ class CheckpointSaver:
         torch.save(state, save_path)
         if os.path.exists(self.last_recovery_file):
             try:
-                print('Cleaning recovery', self.last_recovery_file)
+                if self.verbose:
+                    print('Cleaning recovery', self.last_recovery_file)
                 os.remove(self.last_recovery_file)
             except Exception as e:
                 print("Exception (%s) while removing %s" % (str(e), self.last_recovery_file))

@@ -143,8 +159,8 @@ def get_outdir(path, *paths, inc=False):

 def update_summary(epoch, train_metrics, eval_metrics, filename, write_header=False):
     rowd = OrderedDict(epoch=epoch)
-    rowd.update(train_metrics)
-    rowd.update(eval_metrics)
+    rowd.update([('train_' + k, v) for k, v in train_metrics.items()])
+    rowd.update([('eval_' + k, v) for k, v in eval_metrics.items()])
     with open(filename, mode='a') as cf:
         dw = csv.DictWriter(cf, fieldnames=rowd.keys())
         if write_header:  # first iteration (epoch == 1 can't be used)
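Two details worth noting. First, the comparator makes the top-N checkpoint ranking metric-direction aware. Second, the rewritten return fixes a precedence bug: the old "return None, None if self.best_metric is None else self.best_metric" parses as "(None, (None if ...))", so it always returned None as the first element; the new parenthesized form returns the intended pairs. A toy rerun of the ranking logic outside the class, with invented filenames and metric values:

import operator

decreasing = False                              # e.g. eval_metric == 'prec1'
cmp = operator.lt if decreasing else operator.gt
max_history = 2

checkpoints = [('checkpoint-3.pth.tar', 72.1), ('checkpoint-1.pth.tar', 70.4)]
metric = 71.0                                   # new epoch's eval metric
worst_file = checkpoints[-1]
if len(checkpoints) < max_history or cmp(metric, worst_file[1]):
    if len(checkpoints) >= max_history:
        checkpoints.pop()                       # stands in for _cleanup_checkpoints(1)
    checkpoints.append(('checkpoint-4.pth.tar', metric))
    checkpoints.sort(key=lambda c: c[1], reverse=not decreasing)
print(checkpoints)  # [('checkpoint-3.pth.tar', 72.1), ('checkpoint-4.pth.tar', 71.0)]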