diff --git a/inference.py b/inference.py
index 4ca5919d..0afb529e 100644
--- a/inference.py
+++ b/inference.py
@@ -10,10 +10,9 @@ import time
 import argparse
 import numpy as np
 import torch
-import torch.autograd as autograd
 import torch.utils.data as data
 
-import model_factory
+from models import create_model, transforms_imagenet_eval
 from dataset import Dataset
 
 
@@ -32,12 +31,12 @@
 parser.add_argument('--img-size', default=224, type=int, metavar='N',
                     help='Input image dimension')
 parser.add_argument('--print-freq', '-p', default=10, type=int,
                     metavar='N', help='print frequency (default: 10)')
-parser.add_argument('--restore-checkpoint', default='', type=str, metavar='PATH',
+parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
                     help='path to latest checkpoint (default: none)')
 parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                     help='use pre-trained model')
-parser.add_argument('--multi-gpu', dest='multi_gpu', action='store_true',
-                    help='use multiple-gpus')
+parser.add_argument('--num-gpu', type=int, default=1,
+                    help='Number of GPUS to use')
 parser.add_argument('--no-test-pool', dest='test_time_pool', action='store_false',
                     help='use pre-trained model')
@@ -47,37 +46,33 @@ def main():
 
     # create model
     num_classes = 1000
-    model = model_factory.create_model(
+    model = create_model(
         args.model,
         num_classes=num_classes,
         pretrained=args.pretrained,
         test_time_pool=args.test_time_pool)
 
     # resume from a checkpoint
-    if args.restore_checkpoint and os.path.isfile(args.restore_checkpoint):
-        print("=> loading checkpoint '{}'".format(args.restore_checkpoint))
-        checkpoint = torch.load(args.restore_checkpoint)
+    if args.checkpoint and os.path.isfile(args.checkpoint):
+        print("=> loading checkpoint '{}'".format(args.checkpoint))
+        checkpoint = torch.load(args.checkpoint)
         if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
             model.load_state_dict(checkpoint['state_dict'])
         else:
             model.load_state_dict(checkpoint)
-        print("=> loaded checkpoint '{}'".format(args.restore_checkpoint))
+        print("=> loaded checkpoint '{}'".format(args.checkpoint))
     elif not args.pretrained:
-        print("=> no checkpoint found at '{}'".format(args.restore_checkpoint))
+        print("=> no checkpoint found at '{}'".format(args.checkpoint))
         exit(1)
 
-    if args.multi_gpu:
-        model = torch.nn.DataParallel(model).cuda()
+    if args.num_gpu > 1:
+        model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
     else:
         model = model.cuda()
 
-    transforms = model_factory.get_transforms_eval(
-        args.model,
-        args.img_size)
-
     dataset = Dataset(
         args.data,
-        transforms)
+        transforms_imagenet_eval(args.model, args.img_size))
 
     loader = data.DataLoader(
         dataset,
diff --git a/models/senet.py b/models/senet.py
index e0907ebf..d16ccf62 100644
--- a/models/senet.py
+++ b/models/senet.py
@@ -105,14 +105,10 @@ pretrained_config = {
 
 
 def _weight_init(m, n='', ll=''):
-    print(m, n, ll)
     if isinstance(m, nn.Conv2d):
         nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
     elif isinstance(m, nn.BatchNorm2d):
-        if ll and n == ll:
-            nn.init.constant_(m.weight, 0.)
-        else:
-            nn.init.constant_(m.weight, 1.)
+        nn.init.constant_(m.weight, 1.)
         nn.init.constant_(m.bias, 0.)
 
 
@@ -128,9 +124,6 @@ class SEModule(nn.Module):
             channels // reduction, channels, kernel_size=1, padding=0)
         self.sigmoid = nn.Sigmoid()
 
-        for m in self.modules():
-            _weight_init(m)
-
     def forward(self, x):
         module_input = x
         x = self.avg_pool(x)
@@ -191,9 +184,6 @@ class SEBottleneck(Bottleneck):
         self.downsample = downsample
         self.stride = stride
 
-        for n, m in self.named_modules():
-            _weight_init(m, n, ll='bn3')
-
 
 class SEResNetBottleneck(Bottleneck):
     """
@@ -219,9 +209,6 @@ class SEResNetBottleneck(Bottleneck):
         self.downsample = downsample
         self.stride = stride
 
-        for n, m in self.named_modules():
-            _weight_init(m, n, ll='bn3')
-
 
 class SEResNeXtBottleneck(Bottleneck):
     """
@@ -246,9 +233,6 @@ class SEResNeXtBottleneck(Bottleneck):
         self.downsample = downsample
         self.stride = stride
 
-        for n, m in self.named_modules():
-            _weight_init(m, n, ll='bn3')
-
 
 class SEResNetBlock(nn.Module):
     expansion = 1
@@ -266,9 +250,6 @@ class SEResNetBlock(nn.Module):
         self.downsample = downsample
         self.stride = stride
 
-        for n, m in self.named_modules():
-            _weight_init(m, n, ll='bn2')
-
     def forward(self, x):
         residual = x
 
@@ -405,11 +386,8 @@ class SENet(nn.Module):
         self.dropout = nn.Dropout(dropout_p) if dropout_p is not None else None
         self.last_linear = nn.Linear(512 * block.expansion, num_classes)
 
-        for n, m in self.named_children():
-            if n == 'layer0':
-                m.apply(_weight_init)
-            else:
-                _weight_init(m)
+        for m in self.modules():
+            _weight_init(m)
 
     def _make_layer(self, block, planes, blocks, groups, reduction, stride=1,
                     downsample_kernel_size=1, downsample_padding=0):
diff --git a/models/transforms.py b/models/transforms.py
index 6d54e891..49aaca57 100644
--- a/models/transforms.py
+++ b/models/transforms.py
@@ -21,7 +21,7 @@ class LeNormalize(object):
         return tensor
 
 
-def transforms_imagenet_train(model_name, img_size=224, scale=(0.1, 1.0), color_jitter=(0.333, 0.333, 0.333)):
+def transforms_imagenet_train(model_name, img_size=224, scale=(0.1, 1.0), color_jitter=(0.4, 0.4, 0.4)):
     if 'dpn' in model_name:
         normalize = transforms.Normalize(
             mean=IMAGENET_DPN_MEAN,
diff --git a/train.py b/train.py
index 82508e7e..43aeeaf5 100644
--- a/train.py
+++ b/train.py
@@ -180,8 +180,8 @@ def main():
         assert False and "Invalid optimizer"
         exit(1)
 
-    if optimizer_state is not None:
-        optimizer.load_state_dict(optimizer_state)
+    #if optimizer_state is not None:
+    #    optimizer.load_state_dict(optimizer_state)
 
     if args.sched == 'cosine':
         lr_scheduler = scheduler.CosineLRScheduler(
diff --git a/validate.py b/validate.py
index 0d4ce999..e08e1d95 100644
--- a/validate.py
+++ b/validate.py
@@ -12,7 +12,7 @@
 import torch.nn.parallel
 import torch.utils.data as data
 
-from models import model_factory
+from models import create_model, transforms_imagenet_eval
 from dataset import Dataset
 
 
@@ -29,12 +29,12 @@
 parser.add_argument('--img-size', default=224, type=int, metavar='N',
                     help='Input image dimension')
 parser.add_argument('--print-freq', '-p', default=10, type=int,
                     metavar='N', help='print frequency (default: 10)')
-parser.add_argument('--restore-checkpoint', default='', type=str, metavar='PATH',
+parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
                     help='path to latest checkpoint (default: none)')
 parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                     help='use pre-trained model')
-parser.add_argument('--multi-gpu', dest='multi_gpu', action='store_true',
-                    help='use multiple-gpus')
+parser.add_argument('--num-gpu', type=int, default=1,
+                    help='Number of GPUS to use')
 parser.add_argument('--no-test-pool', dest='no_test_pool', action='store_true',
                     help='disable test time pool for DPN models')
@@ -48,7 +48,7 @@ def main():
 
     # create model
     num_classes = 1000
-    model = model_factory.create_model(
+    model = create_model(
         args.model,
         num_classes=num_classes,
         pretrained=args.pretrained,
@@ -57,23 +57,21 @@ def main():
     print('Model %s created, param count: %d' %
           (args.model, sum([m.numel() for m in model.parameters()])))
 
-    print(model)
-
     # optionally resume from a checkpoint
-    if args.restore_checkpoint and os.path.isfile(args.restore_checkpoint):
-        print("=> loading checkpoint '{}'".format(args.restore_checkpoint))
-        checkpoint = torch.load(args.restore_checkpoint)
+    if args.checkpoint and os.path.isfile(args.checkpoint):
+        print("=> loading checkpoint '{}'".format(args.checkpoint))
+        checkpoint = torch.load(args.checkpoint)
         if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
             model.load_state_dict(checkpoint['state_dict'])
         else:
             model.load_state_dict(checkpoint)
-        print("=> loaded checkpoint '{}'".format(args.restore_checkpoint))
+        print("=> loaded checkpoint '{}'".format(args.checkpoint))
     elif not args.pretrained:
-        print("=> no checkpoint found at '{}'".format(args.restore_checkpoint))
+        print("=> no checkpoint found at '{}'".format(args.checkpoint))
         exit(1)
 
-    if args.multi_gpu:
-        model = torch.nn.DataParallel(model).cuda()
+    if args.num_gpu > 1:
+        model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
     else:
         model = model.cuda()
 
@@ -82,13 +80,9 @@ def main():
 
     cudnn.benchmark = True
 
-    transforms = model_factory.get_transforms_eval(
-        args.model,
-        args.img_size)
-
     dataset = Dataset(
         args.data,
-        transforms)
+        transforms_imagenet_eval(args.model, args.img_size))
 
     loader = data.DataLoader(
         dataset,
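
Note: the snippet below is not part of the patch. It is a minimal usage sketch of the interface changes above (create_model / transforms_imagenet_eval replacing model_factory, --checkpoint replacing --restore-checkpoint, and the integer --num-gpu replacing the --multi-gpu flag); the model name and checkpoint path are placeholders, and only calls that appear in this diff are assumed to exist.

# Hypothetical usage mirroring inference.py/validate.py after this change.
import torch
from models import create_model, transforms_imagenet_eval

# 'seresnet50' is a placeholder model name; kwargs follow the call pattern in the diff.
model = create_model('seresnet50', num_classes=1000, pretrained=False, test_time_pool=False)

# --checkpoint: accepts either a bare state_dict or a dict wrapping one under 'state_dict'.
checkpoint = torch.load('checkpoint.pth.tar')  # placeholder path
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
    model.load_state_dict(checkpoint['state_dict'])
else:
    model.load_state_dict(checkpoint)

# --num-gpu N: wrap in DataParallel only when more than one GPU is requested.
num_gpu = 2
if num_gpu > 1:
    model = torch.nn.DataParallel(model, device_ids=list(range(num_gpu))).cuda()
else:
    model = model.cuda()

# Evaluation transforms now come directly from the models package.
eval_transforms = transforms_imagenet_eval('seresnet50', 224)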