mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Fixup validate/inference script args, fix senet init for better test accuracy
This commit is contained in:
parent
b1a5a71151
commit
31055466fc
31
inference.py
31
inference.py
@ -10,10 +10,9 @@ import time
|
||||
import argparse
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.autograd as autograd
|
||||
import torch.utils.data as data
|
||||
|
||||
import model_factory
|
||||
from models import create_model, transforms_imagenet_eval
|
||||
from dataset import Dataset
|
||||
|
||||
|
||||
@ -32,12 +31,12 @@ parser.add_argument('--img-size', default=224, type=int,
|
||||
metavar='N', help='Input image dimension')
|
||||
parser.add_argument('--print-freq', '-p', default=10, type=int,
|
||||
metavar='N', help='print frequency (default: 10)')
|
||||
parser.add_argument('--restore-checkpoint', default='', type=str, metavar='PATH',
|
||||
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
|
||||
help='path to latest checkpoint (default: none)')
|
||||
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
|
||||
help='use pre-trained model')
|
||||
parser.add_argument('--multi-gpu', dest='multi_gpu', action='store_true',
|
||||
help='use multiple-gpus')
|
||||
parser.add_argument('--num-gpu', type=int, default=1,
|
||||
help='Number of GPUS to use')
|
||||
parser.add_argument('--no-test-pool', dest='test_time_pool', action='store_false',
|
||||
help='use pre-trained model')
|
||||
|
||||
@ -47,37 +46,33 @@ def main():
|
||||
|
||||
# create model
|
||||
num_classes = 1000
|
||||
model = model_factory.create_model(
|
||||
model = create_model(
|
||||
args.model,
|
||||
num_classes=num_classes,
|
||||
pretrained=args.pretrained,
|
||||
test_time_pool=args.test_time_pool)
|
||||
|
||||
# resume from a checkpoint
|
||||
if args.restore_checkpoint and os.path.isfile(args.restore_checkpoint):
|
||||
print("=> loading checkpoint '{}'".format(args.restore_checkpoint))
|
||||
checkpoint = torch.load(args.restore_checkpoint)
|
||||
if args.checkpoint and os.path.isfile(args.checkpoint):
|
||||
print("=> loading checkpoint '{}'".format(args.checkpoint))
|
||||
checkpoint = torch.load(args.checkpoint)
|
||||
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
|
||||
model.load_state_dict(checkpoint['state_dict'])
|
||||
else:
|
||||
model.load_state_dict(checkpoint)
|
||||
print("=> loaded checkpoint '{}'".format(args.restore_checkpoint))
|
||||
print("=> loaded checkpoint '{}'".format(args.checkpoint))
|
||||
elif not args.pretrained:
|
||||
print("=> no checkpoint found at '{}'".format(args.restore_checkpoint))
|
||||
print("=> no checkpoint found at '{}'".format(args.checkpoint))
|
||||
exit(1)
|
||||
|
||||
if args.multi_gpu:
|
||||
model = torch.nn.DataParallel(model).cuda()
|
||||
if args.num_gpu > 1:
|
||||
model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
|
||||
else:
|
||||
model = model.cuda()
|
||||
|
||||
transforms = model_factory.get_transforms_eval(
|
||||
args.model,
|
||||
args.img_size)
|
||||
|
||||
dataset = Dataset(
|
||||
args.data,
|
||||
transforms)
|
||||
transforms_imagenet_eval(args.model, args.img_size))
|
||||
|
||||
loader = data.DataLoader(
|
||||
dataset,
|
||||
|
@ -105,14 +105,10 @@ pretrained_config = {
|
||||
|
||||
|
||||
def _weight_init(m, n='', ll=''):
|
||||
print(m, n, ll)
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
if ll and n == ll:
|
||||
nn.init.constant_(m.weight, 0.)
|
||||
else:
|
||||
nn.init.constant_(m.weight, 1.)
|
||||
nn.init.constant_(m.weight, 1.)
|
||||
nn.init.constant_(m.bias, 0.)
|
||||
|
||||
|
||||
@ -128,9 +124,6 @@ class SEModule(nn.Module):
|
||||
channels // reduction, channels, kernel_size=1, padding=0)
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
|
||||
for m in self.modules():
|
||||
_weight_init(m)
|
||||
|
||||
def forward(self, x):
|
||||
module_input = x
|
||||
x = self.avg_pool(x)
|
||||
@ -191,9 +184,6 @@ class SEBottleneck(Bottleneck):
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
for n, m in self.named_modules():
|
||||
_weight_init(m, n, ll='bn3')
|
||||
|
||||
|
||||
class SEResNetBottleneck(Bottleneck):
|
||||
"""
|
||||
@ -219,9 +209,6 @@ class SEResNetBottleneck(Bottleneck):
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
for n, m in self.named_modules():
|
||||
_weight_init(m, n, ll='bn3')
|
||||
|
||||
|
||||
class SEResNeXtBottleneck(Bottleneck):
|
||||
"""
|
||||
@ -246,9 +233,6 @@ class SEResNeXtBottleneck(Bottleneck):
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
for n, m in self.named_modules():
|
||||
_weight_init(m, n, ll='bn3')
|
||||
|
||||
|
||||
class SEResNetBlock(nn.Module):
|
||||
expansion = 1
|
||||
@ -266,9 +250,6 @@ class SEResNetBlock(nn.Module):
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
for n, m in self.named_modules():
|
||||
_weight_init(m, n, ll='bn2')
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
|
||||
@ -405,11 +386,8 @@ class SENet(nn.Module):
|
||||
self.dropout = nn.Dropout(dropout_p) if dropout_p is not None else None
|
||||
self.last_linear = nn.Linear(512 * block.expansion, num_classes)
|
||||
|
||||
for n, m in self.named_children():
|
||||
if n == 'layer0':
|
||||
m.apply(_weight_init)
|
||||
else:
|
||||
_weight_init(m)
|
||||
for m in self.modules():
|
||||
_weight_init(m)
|
||||
|
||||
def _make_layer(self, block, planes, blocks, groups, reduction, stride=1,
|
||||
downsample_kernel_size=1, downsample_padding=0):
|
||||
|
@ -21,7 +21,7 @@ class LeNormalize(object):
|
||||
return tensor
|
||||
|
||||
|
||||
def transforms_imagenet_train(model_name, img_size=224, scale=(0.1, 1.0), color_jitter=(0.333, 0.333, 0.333)):
|
||||
def transforms_imagenet_train(model_name, img_size=224, scale=(0.1, 1.0), color_jitter=(0.4, 0.4, 0.4)):
|
||||
if 'dpn' in model_name:
|
||||
normalize = transforms.Normalize(
|
||||
mean=IMAGENET_DPN_MEAN,
|
||||
|
4
train.py
4
train.py
@ -180,8 +180,8 @@ def main():
|
||||
assert False and "Invalid optimizer"
|
||||
exit(1)
|
||||
|
||||
if optimizer_state is not None:
|
||||
optimizer.load_state_dict(optimizer_state)
|
||||
#if optimizer_state is not None:
|
||||
# optimizer.load_state_dict(optimizer_state)
|
||||
|
||||
if args.sched == 'cosine':
|
||||
lr_scheduler = scheduler.CosineLRScheduler(
|
||||
|
32
validate.py
32
validate.py
@ -12,7 +12,7 @@ import torch.nn.parallel
|
||||
import torch.utils.data as data
|
||||
|
||||
|
||||
from models import model_factory
|
||||
from models import create_model, transforms_imagenet_eval
|
||||
from dataset import Dataset
|
||||
|
||||
|
||||
@ -29,12 +29,12 @@ parser.add_argument('--img-size', default=224, type=int,
|
||||
metavar='N', help='Input image dimension')
|
||||
parser.add_argument('--print-freq', '-p', default=10, type=int,
|
||||
metavar='N', help='print frequency (default: 10)')
|
||||
parser.add_argument('--restore-checkpoint', default='', type=str, metavar='PATH',
|
||||
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
|
||||
help='path to latest checkpoint (default: none)')
|
||||
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
|
||||
help='use pre-trained model')
|
||||
parser.add_argument('--multi-gpu', dest='multi_gpu', action='store_true',
|
||||
help='use multiple-gpus')
|
||||
parser.add_argument('--num-gpu', type=int, default=1,
|
||||
help='Number of GPUS to use')
|
||||
parser.add_argument('--no-test-pool', dest='no_test_pool', action='store_true',
|
||||
help='disable test time pool for DPN models')
|
||||
|
||||
@ -48,7 +48,7 @@ def main():
|
||||
|
||||
# create model
|
||||
num_classes = 1000
|
||||
model = model_factory.create_model(
|
||||
model = create_model(
|
||||
args.model,
|
||||
num_classes=num_classes,
|
||||
pretrained=args.pretrained,
|
||||
@ -57,23 +57,21 @@ def main():
|
||||
print('Model %s created, param count: %d' %
|
||||
(args.model, sum([m.numel() for m in model.parameters()])))
|
||||
|
||||
print(model)
|
||||
|
||||
# optionally resume from a checkpoint
|
||||
if args.restore_checkpoint and os.path.isfile(args.restore_checkpoint):
|
||||
print("=> loading checkpoint '{}'".format(args.restore_checkpoint))
|
||||
checkpoint = torch.load(args.restore_checkpoint)
|
||||
if args.checkpoint and os.path.isfile(args.checkpoint):
|
||||
print("=> loading checkpoint '{}'".format(args.checkpoint))
|
||||
checkpoint = torch.load(args.checkpoint)
|
||||
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
|
||||
model.load_state_dict(checkpoint['state_dict'])
|
||||
else:
|
||||
model.load_state_dict(checkpoint)
|
||||
print("=> loaded checkpoint '{}'".format(args.restore_checkpoint))
|
||||
print("=> loaded checkpoint '{}'".format(args.checkpoint))
|
||||
elif not args.pretrained:
|
||||
print("=> no checkpoint found at '{}'".format(args.restore_checkpoint))
|
||||
print("=> no checkpoint found at '{}'".format(args.checkpoint))
|
||||
exit(1)
|
||||
|
||||
if args.multi_gpu:
|
||||
model = torch.nn.DataParallel(model).cuda()
|
||||
if args.num_gpu > 1:
|
||||
model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
|
||||
else:
|
||||
model = model.cuda()
|
||||
|
||||
@ -82,13 +80,9 @@ def main():
|
||||
|
||||
cudnn.benchmark = True
|
||||
|
||||
transforms = model_factory.get_transforms_eval(
|
||||
args.model,
|
||||
args.img_size)
|
||||
|
||||
dataset = Dataset(
|
||||
args.data,
|
||||
transforms)
|
||||
transforms_imagenet_eval(args.model, args.img_size))
|
||||
|
||||
loader = data.DataLoader(
|
||||
dataset,
|
||||
|
Loading…
x
Reference in New Issue
Block a user