mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Fixup validate/inference script args, fix senet init for better test accuracy
This commit is contained in:
parent
b1a5a71151
commit
31055466fc
31
inference.py
31
inference.py
@ -10,10 +10,9 @@ import time
|
|||||||
import argparse
|
import argparse
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.autograd as autograd
|
|
||||||
import torch.utils.data as data
|
import torch.utils.data as data
|
||||||
|
|
||||||
import model_factory
|
from models import create_model, transforms_imagenet_eval
|
||||||
from dataset import Dataset
|
from dataset import Dataset
|
||||||
|
|
||||||
|
|
||||||
@ -32,12 +31,12 @@ parser.add_argument('--img-size', default=224, type=int,
|
|||||||
metavar='N', help='Input image dimension')
|
metavar='N', help='Input image dimension')
|
||||||
parser.add_argument('--print-freq', '-p', default=10, type=int,
|
parser.add_argument('--print-freq', '-p', default=10, type=int,
|
||||||
metavar='N', help='print frequency (default: 10)')
|
metavar='N', help='print frequency (default: 10)')
|
||||||
parser.add_argument('--restore-checkpoint', default='', type=str, metavar='PATH',
|
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
|
||||||
help='path to latest checkpoint (default: none)')
|
help='path to latest checkpoint (default: none)')
|
||||||
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
|
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
|
||||||
help='use pre-trained model')
|
help='use pre-trained model')
|
||||||
parser.add_argument('--multi-gpu', dest='multi_gpu', action='store_true',
|
parser.add_argument('--num-gpu', type=int, default=1,
|
||||||
help='use multiple-gpus')
|
help='Number of GPUS to use')
|
||||||
parser.add_argument('--no-test-pool', dest='test_time_pool', action='store_false',
|
parser.add_argument('--no-test-pool', dest='test_time_pool', action='store_false',
|
||||||
help='use pre-trained model')
|
help='use pre-trained model')
|
||||||
|
|
||||||
@ -47,37 +46,33 @@ def main():
|
|||||||
|
|
||||||
# create model
|
# create model
|
||||||
num_classes = 1000
|
num_classes = 1000
|
||||||
model = model_factory.create_model(
|
model = create_model(
|
||||||
args.model,
|
args.model,
|
||||||
num_classes=num_classes,
|
num_classes=num_classes,
|
||||||
pretrained=args.pretrained,
|
pretrained=args.pretrained,
|
||||||
test_time_pool=args.test_time_pool)
|
test_time_pool=args.test_time_pool)
|
||||||
|
|
||||||
# resume from a checkpoint
|
# resume from a checkpoint
|
||||||
if args.restore_checkpoint and os.path.isfile(args.restore_checkpoint):
|
if args.checkpoint and os.path.isfile(args.checkpoint):
|
||||||
print("=> loading checkpoint '{}'".format(args.restore_checkpoint))
|
print("=> loading checkpoint '{}'".format(args.checkpoint))
|
||||||
checkpoint = torch.load(args.restore_checkpoint)
|
checkpoint = torch.load(args.checkpoint)
|
||||||
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
|
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
|
||||||
model.load_state_dict(checkpoint['state_dict'])
|
model.load_state_dict(checkpoint['state_dict'])
|
||||||
else:
|
else:
|
||||||
model.load_state_dict(checkpoint)
|
model.load_state_dict(checkpoint)
|
||||||
print("=> loaded checkpoint '{}'".format(args.restore_checkpoint))
|
print("=> loaded checkpoint '{}'".format(args.checkpoint))
|
||||||
elif not args.pretrained:
|
elif not args.pretrained:
|
||||||
print("=> no checkpoint found at '{}'".format(args.restore_checkpoint))
|
print("=> no checkpoint found at '{}'".format(args.checkpoint))
|
||||||
exit(1)
|
exit(1)
|
||||||
|
|
||||||
if args.multi_gpu:
|
if args.num_gpu > 1:
|
||||||
model = torch.nn.DataParallel(model).cuda()
|
model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
|
||||||
else:
|
else:
|
||||||
model = model.cuda()
|
model = model.cuda()
|
||||||
|
|
||||||
transforms = model_factory.get_transforms_eval(
|
|
||||||
args.model,
|
|
||||||
args.img_size)
|
|
||||||
|
|
||||||
dataset = Dataset(
|
dataset = Dataset(
|
||||||
args.data,
|
args.data,
|
||||||
transforms)
|
transforms_imagenet_eval(args.model, args.img_size))
|
||||||
|
|
||||||
loader = data.DataLoader(
|
loader = data.DataLoader(
|
||||||
dataset,
|
dataset,
|
||||||
|
@ -105,13 +105,9 @@ pretrained_config = {
|
|||||||
|
|
||||||
|
|
||||||
def _weight_init(m, n='', ll=''):
|
def _weight_init(m, n='', ll=''):
|
||||||
print(m, n, ll)
|
|
||||||
if isinstance(m, nn.Conv2d):
|
if isinstance(m, nn.Conv2d):
|
||||||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||||
elif isinstance(m, nn.BatchNorm2d):
|
elif isinstance(m, nn.BatchNorm2d):
|
||||||
if ll and n == ll:
|
|
||||||
nn.init.constant_(m.weight, 0.)
|
|
||||||
else:
|
|
||||||
nn.init.constant_(m.weight, 1.)
|
nn.init.constant_(m.weight, 1.)
|
||||||
nn.init.constant_(m.bias, 0.)
|
nn.init.constant_(m.bias, 0.)
|
||||||
|
|
||||||
@ -128,9 +124,6 @@ class SEModule(nn.Module):
|
|||||||
channels // reduction, channels, kernel_size=1, padding=0)
|
channels // reduction, channels, kernel_size=1, padding=0)
|
||||||
self.sigmoid = nn.Sigmoid()
|
self.sigmoid = nn.Sigmoid()
|
||||||
|
|
||||||
for m in self.modules():
|
|
||||||
_weight_init(m)
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
module_input = x
|
module_input = x
|
||||||
x = self.avg_pool(x)
|
x = self.avg_pool(x)
|
||||||
@ -191,9 +184,6 @@ class SEBottleneck(Bottleneck):
|
|||||||
self.downsample = downsample
|
self.downsample = downsample
|
||||||
self.stride = stride
|
self.stride = stride
|
||||||
|
|
||||||
for n, m in self.named_modules():
|
|
||||||
_weight_init(m, n, ll='bn3')
|
|
||||||
|
|
||||||
|
|
||||||
class SEResNetBottleneck(Bottleneck):
|
class SEResNetBottleneck(Bottleneck):
|
||||||
"""
|
"""
|
||||||
@ -219,9 +209,6 @@ class SEResNetBottleneck(Bottleneck):
|
|||||||
self.downsample = downsample
|
self.downsample = downsample
|
||||||
self.stride = stride
|
self.stride = stride
|
||||||
|
|
||||||
for n, m in self.named_modules():
|
|
||||||
_weight_init(m, n, ll='bn3')
|
|
||||||
|
|
||||||
|
|
||||||
class SEResNeXtBottleneck(Bottleneck):
|
class SEResNeXtBottleneck(Bottleneck):
|
||||||
"""
|
"""
|
||||||
@ -246,9 +233,6 @@ class SEResNeXtBottleneck(Bottleneck):
|
|||||||
self.downsample = downsample
|
self.downsample = downsample
|
||||||
self.stride = stride
|
self.stride = stride
|
||||||
|
|
||||||
for n, m in self.named_modules():
|
|
||||||
_weight_init(m, n, ll='bn3')
|
|
||||||
|
|
||||||
|
|
||||||
class SEResNetBlock(nn.Module):
|
class SEResNetBlock(nn.Module):
|
||||||
expansion = 1
|
expansion = 1
|
||||||
@ -266,9 +250,6 @@ class SEResNetBlock(nn.Module):
|
|||||||
self.downsample = downsample
|
self.downsample = downsample
|
||||||
self.stride = stride
|
self.stride = stride
|
||||||
|
|
||||||
for n, m in self.named_modules():
|
|
||||||
_weight_init(m, n, ll='bn2')
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
residual = x
|
residual = x
|
||||||
|
|
||||||
@ -405,10 +386,7 @@ class SENet(nn.Module):
|
|||||||
self.dropout = nn.Dropout(dropout_p) if dropout_p is not None else None
|
self.dropout = nn.Dropout(dropout_p) if dropout_p is not None else None
|
||||||
self.last_linear = nn.Linear(512 * block.expansion, num_classes)
|
self.last_linear = nn.Linear(512 * block.expansion, num_classes)
|
||||||
|
|
||||||
for n, m in self.named_children():
|
for m in self.modules():
|
||||||
if n == 'layer0':
|
|
||||||
m.apply(_weight_init)
|
|
||||||
else:
|
|
||||||
_weight_init(m)
|
_weight_init(m)
|
||||||
|
|
||||||
def _make_layer(self, block, planes, blocks, groups, reduction, stride=1,
|
def _make_layer(self, block, planes, blocks, groups, reduction, stride=1,
|
||||||
|
@ -21,7 +21,7 @@ class LeNormalize(object):
|
|||||||
return tensor
|
return tensor
|
||||||
|
|
||||||
|
|
||||||
def transforms_imagenet_train(model_name, img_size=224, scale=(0.1, 1.0), color_jitter=(0.333, 0.333, 0.333)):
|
def transforms_imagenet_train(model_name, img_size=224, scale=(0.1, 1.0), color_jitter=(0.4, 0.4, 0.4)):
|
||||||
if 'dpn' in model_name:
|
if 'dpn' in model_name:
|
||||||
normalize = transforms.Normalize(
|
normalize = transforms.Normalize(
|
||||||
mean=IMAGENET_DPN_MEAN,
|
mean=IMAGENET_DPN_MEAN,
|
||||||
|
4
train.py
4
train.py
@ -180,8 +180,8 @@ def main():
|
|||||||
assert False and "Invalid optimizer"
|
assert False and "Invalid optimizer"
|
||||||
exit(1)
|
exit(1)
|
||||||
|
|
||||||
if optimizer_state is not None:
|
#if optimizer_state is not None:
|
||||||
optimizer.load_state_dict(optimizer_state)
|
# optimizer.load_state_dict(optimizer_state)
|
||||||
|
|
||||||
if args.sched == 'cosine':
|
if args.sched == 'cosine':
|
||||||
lr_scheduler = scheduler.CosineLRScheduler(
|
lr_scheduler = scheduler.CosineLRScheduler(
|
||||||
|
32
validate.py
32
validate.py
@ -12,7 +12,7 @@ import torch.nn.parallel
|
|||||||
import torch.utils.data as data
|
import torch.utils.data as data
|
||||||
|
|
||||||
|
|
||||||
from models import model_factory
|
from models import create_model, transforms_imagenet_eval
|
||||||
from dataset import Dataset
|
from dataset import Dataset
|
||||||
|
|
||||||
|
|
||||||
@ -29,12 +29,12 @@ parser.add_argument('--img-size', default=224, type=int,
|
|||||||
metavar='N', help='Input image dimension')
|
metavar='N', help='Input image dimension')
|
||||||
parser.add_argument('--print-freq', '-p', default=10, type=int,
|
parser.add_argument('--print-freq', '-p', default=10, type=int,
|
||||||
metavar='N', help='print frequency (default: 10)')
|
metavar='N', help='print frequency (default: 10)')
|
||||||
parser.add_argument('--restore-checkpoint', default='', type=str, metavar='PATH',
|
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
|
||||||
help='path to latest checkpoint (default: none)')
|
help='path to latest checkpoint (default: none)')
|
||||||
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
|
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
|
||||||
help='use pre-trained model')
|
help='use pre-trained model')
|
||||||
parser.add_argument('--multi-gpu', dest='multi_gpu', action='store_true',
|
parser.add_argument('--num-gpu', type=int, default=1,
|
||||||
help='use multiple-gpus')
|
help='Number of GPUS to use')
|
||||||
parser.add_argument('--no-test-pool', dest='no_test_pool', action='store_true',
|
parser.add_argument('--no-test-pool', dest='no_test_pool', action='store_true',
|
||||||
help='disable test time pool for DPN models')
|
help='disable test time pool for DPN models')
|
||||||
|
|
||||||
@ -48,7 +48,7 @@ def main():
|
|||||||
|
|
||||||
# create model
|
# create model
|
||||||
num_classes = 1000
|
num_classes = 1000
|
||||||
model = model_factory.create_model(
|
model = create_model(
|
||||||
args.model,
|
args.model,
|
||||||
num_classes=num_classes,
|
num_classes=num_classes,
|
||||||
pretrained=args.pretrained,
|
pretrained=args.pretrained,
|
||||||
@ -57,23 +57,21 @@ def main():
|
|||||||
print('Model %s created, param count: %d' %
|
print('Model %s created, param count: %d' %
|
||||||
(args.model, sum([m.numel() for m in model.parameters()])))
|
(args.model, sum([m.numel() for m in model.parameters()])))
|
||||||
|
|
||||||
print(model)
|
|
||||||
|
|
||||||
# optionally resume from a checkpoint
|
# optionally resume from a checkpoint
|
||||||
if args.restore_checkpoint and os.path.isfile(args.restore_checkpoint):
|
if args.checkpoint and os.path.isfile(args.checkpoint):
|
||||||
print("=> loading checkpoint '{}'".format(args.restore_checkpoint))
|
print("=> loading checkpoint '{}'".format(args.checkpoint))
|
||||||
checkpoint = torch.load(args.restore_checkpoint)
|
checkpoint = torch.load(args.checkpoint)
|
||||||
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
|
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
|
||||||
model.load_state_dict(checkpoint['state_dict'])
|
model.load_state_dict(checkpoint['state_dict'])
|
||||||
else:
|
else:
|
||||||
model.load_state_dict(checkpoint)
|
model.load_state_dict(checkpoint)
|
||||||
print("=> loaded checkpoint '{}'".format(args.restore_checkpoint))
|
print("=> loaded checkpoint '{}'".format(args.checkpoint))
|
||||||
elif not args.pretrained:
|
elif not args.pretrained:
|
||||||
print("=> no checkpoint found at '{}'".format(args.restore_checkpoint))
|
print("=> no checkpoint found at '{}'".format(args.checkpoint))
|
||||||
exit(1)
|
exit(1)
|
||||||
|
|
||||||
if args.multi_gpu:
|
if args.num_gpu > 1:
|
||||||
model = torch.nn.DataParallel(model).cuda()
|
model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
|
||||||
else:
|
else:
|
||||||
model = model.cuda()
|
model = model.cuda()
|
||||||
|
|
||||||
@ -82,13 +80,9 @@ def main():
|
|||||||
|
|
||||||
cudnn.benchmark = True
|
cudnn.benchmark = True
|
||||||
|
|
||||||
transforms = model_factory.get_transforms_eval(
|
|
||||||
args.model,
|
|
||||||
args.img_size)
|
|
||||||
|
|
||||||
dataset = Dataset(
|
dataset = Dataset(
|
||||||
args.data,
|
args.data,
|
||||||
transforms)
|
transforms_imagenet_eval(args.model, args.img_size))
|
||||||
|
|
||||||
loader = data.DataLoader(
|
loader = data.DataLoader(
|
||||||
dataset,
|
dataset,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user