Fixup validate/inference script args, fix senet init for better test accuracy

This commit is contained in:
Ross Wightman 2019-02-22 14:07:50 -08:00
parent b1a5a71151
commit 31055466fc
5 changed files with 32 additions and 65 deletions

View File

@ -10,10 +10,9 @@ import time
import argparse import argparse
import numpy as np import numpy as np
import torch import torch
import torch.autograd as autograd
import torch.utils.data as data import torch.utils.data as data
import model_factory from models import create_model, transforms_imagenet_eval
from dataset import Dataset from dataset import Dataset
@ -32,12 +31,12 @@ parser.add_argument('--img-size', default=224, type=int,
metavar='N', help='Input image dimension') metavar='N', help='Input image dimension')
parser.add_argument('--print-freq', '-p', default=10, type=int, parser.add_argument('--print-freq', '-p', default=10, type=int,
metavar='N', help='print frequency (default: 10)') metavar='N', help='print frequency (default: 10)')
parser.add_argument('--restore-checkpoint', default='', type=str, metavar='PATH', parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
help='path to latest checkpoint (default: none)') help='path to latest checkpoint (default: none)')
parser.add_argument('--pretrained', dest='pretrained', action='store_true', parser.add_argument('--pretrained', dest='pretrained', action='store_true',
help='use pre-trained model') help='use pre-trained model')
parser.add_argument('--multi-gpu', dest='multi_gpu', action='store_true', parser.add_argument('--num-gpu', type=int, default=1,
help='use multiple-gpus') help='Number of GPUS to use')
parser.add_argument('--no-test-pool', dest='test_time_pool', action='store_false', parser.add_argument('--no-test-pool', dest='test_time_pool', action='store_false',
help='use pre-trained model') help='use pre-trained model')
@ -47,37 +46,33 @@ def main():
# create model # create model
num_classes = 1000 num_classes = 1000
model = model_factory.create_model( model = create_model(
args.model, args.model,
num_classes=num_classes, num_classes=num_classes,
pretrained=args.pretrained, pretrained=args.pretrained,
test_time_pool=args.test_time_pool) test_time_pool=args.test_time_pool)
# resume from a checkpoint # resume from a checkpoint
if args.restore_checkpoint and os.path.isfile(args.restore_checkpoint): if args.checkpoint and os.path.isfile(args.checkpoint):
print("=> loading checkpoint '{}'".format(args.restore_checkpoint)) print("=> loading checkpoint '{}'".format(args.checkpoint))
checkpoint = torch.load(args.restore_checkpoint) checkpoint = torch.load(args.checkpoint)
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
model.load_state_dict(checkpoint['state_dict']) model.load_state_dict(checkpoint['state_dict'])
else: else:
model.load_state_dict(checkpoint) model.load_state_dict(checkpoint)
print("=> loaded checkpoint '{}'".format(args.restore_checkpoint)) print("=> loaded checkpoint '{}'".format(args.checkpoint))
elif not args.pretrained: elif not args.pretrained:
print("=> no checkpoint found at '{}'".format(args.restore_checkpoint)) print("=> no checkpoint found at '{}'".format(args.checkpoint))
exit(1) exit(1)
if args.multi_gpu: if args.num_gpu > 1:
model = torch.nn.DataParallel(model).cuda() model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
else: else:
model = model.cuda() model = model.cuda()
transforms = model_factory.get_transforms_eval(
args.model,
args.img_size)
dataset = Dataset( dataset = Dataset(
args.data, args.data,
transforms) transforms_imagenet_eval(args.model, args.img_size))
loader = data.DataLoader( loader = data.DataLoader(
dataset, dataset,

View File

@ -105,13 +105,9 @@ pretrained_config = {
def _weight_init(m, n='', ll=''): def _weight_init(m, n='', ll=''):
print(m, n, ll)
if isinstance(m, nn.Conv2d): if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, nn.BatchNorm2d): elif isinstance(m, nn.BatchNorm2d):
if ll and n == ll:
nn.init.constant_(m.weight, 0.)
else:
nn.init.constant_(m.weight, 1.) nn.init.constant_(m.weight, 1.)
nn.init.constant_(m.bias, 0.) nn.init.constant_(m.bias, 0.)
@ -128,9 +124,6 @@ class SEModule(nn.Module):
channels // reduction, channels, kernel_size=1, padding=0) channels // reduction, channels, kernel_size=1, padding=0)
self.sigmoid = nn.Sigmoid() self.sigmoid = nn.Sigmoid()
for m in self.modules():
_weight_init(m)
def forward(self, x): def forward(self, x):
module_input = x module_input = x
x = self.avg_pool(x) x = self.avg_pool(x)
@ -191,9 +184,6 @@ class SEBottleneck(Bottleneck):
self.downsample = downsample self.downsample = downsample
self.stride = stride self.stride = stride
for n, m in self.named_modules():
_weight_init(m, n, ll='bn3')
class SEResNetBottleneck(Bottleneck): class SEResNetBottleneck(Bottleneck):
""" """
@ -219,9 +209,6 @@ class SEResNetBottleneck(Bottleneck):
self.downsample = downsample self.downsample = downsample
self.stride = stride self.stride = stride
for n, m in self.named_modules():
_weight_init(m, n, ll='bn3')
class SEResNeXtBottleneck(Bottleneck): class SEResNeXtBottleneck(Bottleneck):
""" """
@ -246,9 +233,6 @@ class SEResNeXtBottleneck(Bottleneck):
self.downsample = downsample self.downsample = downsample
self.stride = stride self.stride = stride
for n, m in self.named_modules():
_weight_init(m, n, ll='bn3')
class SEResNetBlock(nn.Module): class SEResNetBlock(nn.Module):
expansion = 1 expansion = 1
@ -266,9 +250,6 @@ class SEResNetBlock(nn.Module):
self.downsample = downsample self.downsample = downsample
self.stride = stride self.stride = stride
for n, m in self.named_modules():
_weight_init(m, n, ll='bn2')
def forward(self, x): def forward(self, x):
residual = x residual = x
@ -405,10 +386,7 @@ class SENet(nn.Module):
self.dropout = nn.Dropout(dropout_p) if dropout_p is not None else None self.dropout = nn.Dropout(dropout_p) if dropout_p is not None else None
self.last_linear = nn.Linear(512 * block.expansion, num_classes) self.last_linear = nn.Linear(512 * block.expansion, num_classes)
for n, m in self.named_children(): for m in self.modules():
if n == 'layer0':
m.apply(_weight_init)
else:
_weight_init(m) _weight_init(m)
def _make_layer(self, block, planes, blocks, groups, reduction, stride=1, def _make_layer(self, block, planes, blocks, groups, reduction, stride=1,

View File

@ -21,7 +21,7 @@ class LeNormalize(object):
return tensor return tensor
def transforms_imagenet_train(model_name, img_size=224, scale=(0.1, 1.0), color_jitter=(0.333, 0.333, 0.333)): def transforms_imagenet_train(model_name, img_size=224, scale=(0.1, 1.0), color_jitter=(0.4, 0.4, 0.4)):
if 'dpn' in model_name: if 'dpn' in model_name:
normalize = transforms.Normalize( normalize = transforms.Normalize(
mean=IMAGENET_DPN_MEAN, mean=IMAGENET_DPN_MEAN,

View File

@ -180,8 +180,8 @@ def main():
assert False and "Invalid optimizer" assert False and "Invalid optimizer"
exit(1) exit(1)
if optimizer_state is not None: #if optimizer_state is not None:
optimizer.load_state_dict(optimizer_state) # optimizer.load_state_dict(optimizer_state)
if args.sched == 'cosine': if args.sched == 'cosine':
lr_scheduler = scheduler.CosineLRScheduler( lr_scheduler = scheduler.CosineLRScheduler(

View File

@ -12,7 +12,7 @@ import torch.nn.parallel
import torch.utils.data as data import torch.utils.data as data
from models import model_factory from models import create_model, transforms_imagenet_eval
from dataset import Dataset from dataset import Dataset
@ -29,12 +29,12 @@ parser.add_argument('--img-size', default=224, type=int,
metavar='N', help='Input image dimension') metavar='N', help='Input image dimension')
parser.add_argument('--print-freq', '-p', default=10, type=int, parser.add_argument('--print-freq', '-p', default=10, type=int,
metavar='N', help='print frequency (default: 10)') metavar='N', help='print frequency (default: 10)')
parser.add_argument('--restore-checkpoint', default='', type=str, metavar='PATH', parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
help='path to latest checkpoint (default: none)') help='path to latest checkpoint (default: none)')
parser.add_argument('--pretrained', dest='pretrained', action='store_true', parser.add_argument('--pretrained', dest='pretrained', action='store_true',
help='use pre-trained model') help='use pre-trained model')
parser.add_argument('--multi-gpu', dest='multi_gpu', action='store_true', parser.add_argument('--num-gpu', type=int, default=1,
help='use multiple-gpus') help='Number of GPUS to use')
parser.add_argument('--no-test-pool', dest='no_test_pool', action='store_true', parser.add_argument('--no-test-pool', dest='no_test_pool', action='store_true',
help='disable test time pool for DPN models') help='disable test time pool for DPN models')
@ -48,7 +48,7 @@ def main():
# create model # create model
num_classes = 1000 num_classes = 1000
model = model_factory.create_model( model = create_model(
args.model, args.model,
num_classes=num_classes, num_classes=num_classes,
pretrained=args.pretrained, pretrained=args.pretrained,
@ -57,23 +57,21 @@ def main():
print('Model %s created, param count: %d' % print('Model %s created, param count: %d' %
(args.model, sum([m.numel() for m in model.parameters()]))) (args.model, sum([m.numel() for m in model.parameters()])))
print(model)
# optionally resume from a checkpoint # optionally resume from a checkpoint
if args.restore_checkpoint and os.path.isfile(args.restore_checkpoint): if args.checkpoint and os.path.isfile(args.checkpoint):
print("=> loading checkpoint '{}'".format(args.restore_checkpoint)) print("=> loading checkpoint '{}'".format(args.checkpoint))
checkpoint = torch.load(args.restore_checkpoint) checkpoint = torch.load(args.checkpoint)
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
model.load_state_dict(checkpoint['state_dict']) model.load_state_dict(checkpoint['state_dict'])
else: else:
model.load_state_dict(checkpoint) model.load_state_dict(checkpoint)
print("=> loaded checkpoint '{}'".format(args.restore_checkpoint)) print("=> loaded checkpoint '{}'".format(args.checkpoint))
elif not args.pretrained: elif not args.pretrained:
print("=> no checkpoint found at '{}'".format(args.restore_checkpoint)) print("=> no checkpoint found at '{}'".format(args.checkpoint))
exit(1) exit(1)
if args.multi_gpu: if args.num_gpu > 1:
model = torch.nn.DataParallel(model).cuda() model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
else: else:
model = model.cuda() model = model.cuda()
@ -82,13 +80,9 @@ def main():
cudnn.benchmark = True cudnn.benchmark = True
transforms = model_factory.get_transforms_eval(
args.model,
args.img_size)
dataset = Dataset( dataset = Dataset(
args.data, args.data,
transforms) transforms_imagenet_eval(args.model, args.img_size))
loader = data.DataLoader( loader = data.DataLoader(
dataset, dataset,