mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Remove all prints, change most to logging calls, tweak alignment of batch logs, improve setup.py
This commit is contained in:
parent
1d7f2d93a6
commit
6fc886acaf
@ -30,7 +30,7 @@ I've included a few of my favourite models, but this is not an exhaustive collec
|
||||
* DPN (from [me](https://github.com/rwightman/pytorch-dpn-pretrained), weights hosted by Cadene)
|
||||
* DPN-68, DPN-68b, DPN-92, DPN-98, DPN-131, DPN-107
|
||||
* Generic EfficientNet (from my standalone [GenMobileNet](https://github.com/rwightman/genmobilenet-pytorch)) - A generic model that implements many of the mobile optimized architecture search derived models that utilize similar DepthwiseSeparable and InvertedResidual blocks
|
||||
* EfficientNet (B0-B4) (https://arxiv.org/abs/1905.11946) -- validated, compat with TF weights
|
||||
* EfficientNet (B0-B5) (https://arxiv.org/abs/1905.11946) -- validated, compat with TF weights
|
||||
* MNASNet B1, A1 (Squeeze-Excite), and Small (https://arxiv.org/abs/1807.11626)
|
||||
* MobileNet-V1 (https://arxiv.org/abs/1704.04861)
|
||||
* MobileNet-V2 (https://arxiv.org/abs/1801.04381)
|
||||
@ -187,9 +187,6 @@ To run inference from a checkpoint:
|
||||
|
||||
## TODO
|
||||
A number of additions planned in the future for various projects, incl
|
||||
* Find optimal training hyperparams and create/port pretrained weights for the generic MobileNet variants
|
||||
* Do a model performance (speed + accuracy) benchmarking across all models (make runnable as script)
|
||||
* More training experiments
|
||||
* Make folder/file layout compat with usage as a module
|
||||
* Add usage examples to comments, good hyper params for training
|
||||
* Comments, cleanup and the usual things that get pushed back
|
||||
|
18
inference.py
18
inference.py
@ -8,12 +8,13 @@ from __future__ import print_function
|
||||
import os
|
||||
import time
|
||||
import argparse
|
||||
import logging
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from timm.models import create_model, apply_test_time_pool
|
||||
from timm.data import Dataset, create_loader, resolve_data_config
|
||||
from timm.utils import AverageMeter
|
||||
from timm.utils import AverageMeter, setup_default_logging
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
@ -38,8 +39,8 @@ parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
|
||||
help='Image resize interpolation type (overrides model)')
|
||||
parser.add_argument('--num-classes', type=int, default=1000,
|
||||
help='Number classes in dataset')
|
||||
parser.add_argument('--print-freq', '-p', default=10, type=int,
|
||||
metavar='N', help='print frequency (default: 10)')
|
||||
parser.add_argument('--log-freq', default=10, type=int,
|
||||
metavar='N', help='batch logging frequency (default: 10)')
|
||||
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
|
||||
help='path to latest checkpoint (default: none)')
|
||||
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
|
||||
@ -53,8 +54,8 @@ parser.add_argument('--topk', default=5, type=int,
|
||||
|
||||
|
||||
def main():
|
||||
setup_default_logging()
|
||||
args = parser.parse_args()
|
||||
|
||||
# might as well try to do something useful...
|
||||
args.pretrained = args.pretrained or not args.checkpoint
|
||||
|
||||
@ -66,8 +67,8 @@ def main():
|
||||
pretrained=args.pretrained,
|
||||
checkpoint_path=args.checkpoint)
|
||||
|
||||
print('Model %s created, param count: %d' %
|
||||
(args.model, sum([m.numel() for m in model.parameters()])))
|
||||
logging.info('Model %s created, param count: %d' %
|
||||
(args.model, sum([m.numel() for m in model.parameters()])))
|
||||
|
||||
config = resolve_data_config(model, args)
|
||||
model, test_time_pool = apply_test_time_pool(model, config, args)
|
||||
@ -105,9 +106,8 @@ def main():
|
||||
batch_time.update(time.time() - end)
|
||||
end = time.time()
|
||||
|
||||
if batch_idx % args.print_freq == 0:
|
||||
print('Predict: [{0}/{1}]\t'
|
||||
'Time {batch_time.val:.3f} ({batch_time.avg:.3f})'.format(
|
||||
if batch_idx % args.log_freq == 0:
|
||||
logging.info('Predict: [{0}/{1}] Time {batch_time.val:.3f} ({batch_time.avg:.3f})'.format(
|
||||
batch_idx, len(loader), batch_time=batch_time))
|
||||
|
||||
topk_ids = np.concatenate(topk_ids, axis=0).squeeze()
|
||||
|
16
setup.py
16
setup.py
@ -19,21 +19,27 @@ setup(
|
||||
url='https://github.com/rwightman/pytorch-image-models',
|
||||
author='Ross Wightman',
|
||||
author_email='hello@rwightman.com',
|
||||
classifiers=[ # Optional
|
||||
classifiers=[
|
||||
# How mature is this project? Common values are
|
||||
# 3 - Alpha
|
||||
# 4 - Beta
|
||||
# 5 - Production/Stable
|
||||
'Development Status :: 3 - Alpha',
|
||||
'Intended Audience :: Developers',
|
||||
'Topic :: Software Development :: Build Tools',
|
||||
'License :: OSI Approved :: Apache License',
|
||||
'Intended Audience :: Education',
|
||||
'Intended Audience :: Science/Research',
|
||||
'License :: OSI Approved :: Apache Software License',
|
||||
'Programming Language :: Python :: 3.6',
|
||||
'Programming Language :: Python :: 3.7',
|
||||
'Topic :: Scientific/Engineering',
|
||||
'Topic :: Scientific/Engineering :: Artificial Intelligence',
|
||||
'Topic :: Software Development',
|
||||
'Topic :: Software Development :: Libraries',
|
||||
'Topic :: Software Development :: Libraries :: Python Modules',
|
||||
],
|
||||
|
||||
# Note that this is a string of words separated by whitespace, not a list.
|
||||
keywords='pytorch pretrained models efficientnet mobilenetv3 mnasnet',
|
||||
packages=find_packages(exclude=['convert']),
|
||||
install_requires=['torch', 'torchvision'],
|
||||
install_requires=['torch >= 1.0', 'torchvision'],
|
||||
python_requires='>=3.6',
|
||||
)
|
||||
|
@ -1,3 +1,4 @@
|
||||
import logging
|
||||
from .constants import *
|
||||
|
||||
|
||||
@ -56,9 +57,9 @@ def resolve_data_config(model, args, default_cfg={}, verbose=True):
|
||||
new_config['crop_pct'] = default_cfg['crop_pct']
|
||||
|
||||
if verbose:
|
||||
print('Data processing configuration for current model + dataset:')
|
||||
logging.info('Data processing configuration for current model + dataset:')
|
||||
for n, v in new_config.items():
|
||||
print('\t%s: %s' % (n, str(v)))
|
||||
logging.info('\t%s: %s' % (n, str(v)))
|
||||
|
||||
return new_config
|
||||
|
||||
|
@ -82,7 +82,7 @@ class SelectAdaptivePool2d(nn.Module):
|
||||
self.pool = nn.AdaptiveMaxPool2d(output_size)
|
||||
else:
|
||||
if pool_type != 'avg':
|
||||
print('Invalid pool type %s specified. Defaulting to average pooling.' % pool_type)
|
||||
assert False, 'Invalid pool type: %s' % pool_type
|
||||
self.pool = nn.AdaptiveAvgPool2d(output_size)
|
||||
|
||||
def forward(self, x):
|
||||
|
@ -86,7 +86,6 @@ def densenet161(num_classes=1000, in_chans=3, pretrained=False, **kwargs):
|
||||
r"""Densenet-201 model from
|
||||
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
|
||||
"""
|
||||
print(num_classes, in_chans, pretrained)
|
||||
default_cfg = default_cfgs['densenet161']
|
||||
model = DenseNet(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24),
|
||||
num_classes=num_classes, in_chans=in_chans, **kwargs)
|
||||
|
@ -17,6 +17,7 @@ Hacked together by Ross Wightman
|
||||
|
||||
import math
|
||||
import re
|
||||
import logging
|
||||
from copy import deepcopy
|
||||
|
||||
import torch
|
||||
@ -336,7 +337,7 @@ class _BlockBuilder:
|
||||
ba['act_fn'] = ba['act_fn'] if ba['act_fn'] is not None else self.act_fn
|
||||
assert ba['act_fn'] is not None
|
||||
if self.verbose:
|
||||
print('args:', ba)
|
||||
logging.info(' Args: {}'.format(str(ba)))
|
||||
# could replace this if with lambdas or functools binding if variety increases
|
||||
if bt == 'ir':
|
||||
ba['drop_connect_rate'] = self.drop_connect_rate
|
||||
@ -358,7 +359,7 @@ class _BlockBuilder:
|
||||
# each stack (stage) contains a list of block arguments
|
||||
for block_idx, ba in enumerate(stack_args):
|
||||
if self.verbose:
|
||||
print('block', block_idx, end=', ')
|
||||
logging.info(' Block: {}'.format(block_idx))
|
||||
if block_idx >= 1:
|
||||
# only the first block in any stack/stage can have a stride > 1
|
||||
ba['stride'] = 1
|
||||
@ -370,24 +371,22 @@ class _BlockBuilder:
|
||||
""" Build the blocks
|
||||
Args:
|
||||
in_chs: Number of input-channels passed to first block
|
||||
arch_def: A list of lists, outer list defines stacks (or stages), inner
|
||||
block_args: A list of lists, outer list defines stages, inner
|
||||
list contains strings defining block configuration(s)
|
||||
Return:
|
||||
List of block stacks (each stack wrapped in nn.Sequential)
|
||||
"""
|
||||
if self.verbose:
|
||||
print('Building model trunk with %d stacks (stages)...' % len(block_args))
|
||||
logging.info('Building model trunk with %d stages...' % len(block_args))
|
||||
self.in_chs = in_chs
|
||||
blocks = []
|
||||
# outer list of block_args defines the stacks ('stages' by some conventions)
|
||||
for stack_idx, stack in enumerate(block_args):
|
||||
if self.verbose:
|
||||
print('stack', stack_idx)
|
||||
logging.info('Stack: {}'.format(stack_idx))
|
||||
assert isinstance(stack, list)
|
||||
stack = self._make_stack(stack)
|
||||
blocks.append(stack)
|
||||
if self.verbose:
|
||||
print()
|
||||
return blocks
|
||||
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
import torch
|
||||
import torch.utils.model_zoo as model_zoo
|
||||
import os
|
||||
import logging
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
@ -21,9 +22,9 @@ def load_checkpoint(model, checkpoint_path, use_ema=False):
|
||||
model.load_state_dict(new_state_dict)
|
||||
else:
|
||||
model.load_state_dict(checkpoint)
|
||||
print("=> Loaded {} from checkpoint '{}'".format(state_dict_key or 'weights', checkpoint_path))
|
||||
logging.info("Loaded {} from checkpoint '{}'".format(state_dict_key or 'weights', checkpoint_path))
|
||||
else:
|
||||
print("=> Error: No checkpoint found at '{}'".format(checkpoint_path))
|
||||
logging.error("No checkpoint found at '{}'".format(checkpoint_path))
|
||||
raise FileNotFoundError()
|
||||
|
||||
|
||||
@ -40,27 +41,27 @@ def resume_checkpoint(model, checkpoint_path, start_epoch=None):
|
||||
if 'optimizer' in checkpoint:
|
||||
optimizer_state = checkpoint['optimizer']
|
||||
start_epoch = checkpoint['epoch'] if start_epoch is None else start_epoch
|
||||
print("=> Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
|
||||
logging.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
|
||||
else:
|
||||
model.load_state_dict(checkpoint)
|
||||
start_epoch = 0 if start_epoch is None else start_epoch
|
||||
print("=> Loaded checkpoint '{}'".format(checkpoint_path))
|
||||
logging.info("Loaded checkpoint '{}'".format(checkpoint_path))
|
||||
return optimizer_state, start_epoch
|
||||
else:
|
||||
print("=> Error: No checkpoint found at '{}'".format(checkpoint_path))
|
||||
logging.error("No checkpoint found at '{}'".format(checkpoint_path))
|
||||
raise FileNotFoundError()
|
||||
|
||||
|
||||
def load_pretrained(model, default_cfg, num_classes=1000, in_chans=3, filter_fn=None):
|
||||
if 'url' not in default_cfg or not default_cfg['url']:
|
||||
print("Warning: pretrained model URL is invalid, using random initialization.")
|
||||
logging.warning("Pretrained model URL is invalid, using random initialization.")
|
||||
return
|
||||
|
||||
state_dict = model_zoo.load_url(default_cfg['url'])
|
||||
|
||||
if in_chans == 1:
|
||||
conv1_name = default_cfg['first_conv']
|
||||
print('Converting first conv (%s) from 3 to 1 channel' % conv1_name)
|
||||
logging.info('Converting first conv (%s) from 3 to 1 channel' % conv1_name)
|
||||
conv1_weight = state_dict[conv1_name + '.weight']
|
||||
state_dict[conv1_name + '.weight'] = conv1_weight.sum(dim=1, keepdim=True)
|
||||
elif in_chans != 3:
|
||||
|
@ -1,3 +1,4 @@
|
||||
import logging
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from .adaptive_avgmax_pool import adaptive_avgmax_pool2d
|
||||
@ -31,8 +32,8 @@ def apply_test_time_pool(model, config, args):
|
||||
if not args.no_test_pool and \
|
||||
config['input_size'][-1] > model.default_cfg['input_size'][-1] and \
|
||||
config['input_size'][-2] > model.default_cfg['input_size'][-2]:
|
||||
print('Target input size %s > pretrained default %s, using test time pooling' %
|
||||
(str(config['input_size'][-2:]), str(model.default_cfg['input_size'][-2:])))
|
||||
logging.info('Target input size %s > pretrained default %s, using test time pooling' %
|
||||
(str(config['input_size'][-2:]), str(model.default_cfg['input_size'][-2:])))
|
||||
model = TestTimePoolHead(model, original_pool=model.default_cfg['pool_size'])
|
||||
test_time_pool = True
|
||||
return model, test_time_pool
|
||||
|
@ -50,7 +50,6 @@ class TanhLRScheduler(Scheduler):
|
||||
self.t_in_epochs = t_in_epochs
|
||||
if self.warmup_t:
|
||||
t_v = self.base_values if self.warmup_prefix else self._get_lr(self.warmup_t)
|
||||
print(t_v)
|
||||
self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in t_v]
|
||||
super().update_groups(self.warmup_lr_init)
|
||||
else:
|
||||
|
@ -8,6 +8,7 @@ import shutil
|
||||
import glob
|
||||
import csv
|
||||
import operator
|
||||
import logging
|
||||
import numpy as np
|
||||
from collections import OrderedDict
|
||||
|
||||
@ -18,7 +19,7 @@ def get_state_dict(model):
|
||||
if isinstance(model, ModelEma):
|
||||
return get_state_dict(model.ema)
|
||||
else:
|
||||
return model.module.state_dict() if getattr(model, 'module') else model.state_dict()
|
||||
return model.module.state_dict() if hasattr(model, 'module') else model.state_dict()
|
||||
|
||||
|
||||
class CheckpointSaver:
|
||||
@ -29,7 +30,6 @@ class CheckpointSaver:
|
||||
checkpoint_dir='',
|
||||
recovery_dir='',
|
||||
decreasing=False,
|
||||
verbose=True,
|
||||
max_history=10):
|
||||
|
||||
# state
|
||||
@ -47,7 +47,6 @@ class CheckpointSaver:
|
||||
self.extension = '.pth.tar'
|
||||
self.decreasing = decreasing # a lower metric is better if True
|
||||
self.cmp = operator.lt if decreasing else operator.gt # True if lhs better than rhs
|
||||
self.verbose = verbose
|
||||
self.max_history = max_history
|
||||
assert self.max_history >= 1
|
||||
|
||||
@ -66,11 +65,6 @@ class CheckpointSaver:
|
||||
self.checkpoint_files, key=lambda x: x[1],
|
||||
reverse=not self.decreasing) # sort in descending order if a lower metric is not better
|
||||
|
||||
if self.verbose:
|
||||
print("Current checkpoints:")
|
||||
for c in self.checkpoint_files:
|
||||
print(c)
|
||||
|
||||
if metric is not None and (self.best_metric is None or self.cmp(metric, self.best_metric)):
|
||||
self.best_epoch = epoch
|
||||
self.best_metric = metric
|
||||
@ -100,11 +94,10 @@ class CheckpointSaver:
|
||||
to_delete = self.checkpoint_files[delete_index:]
|
||||
for d in to_delete:
|
||||
try:
|
||||
if self.verbose:
|
||||
print('Cleaning checkpoint: ', d)
|
||||
logging.debug("Cleaning checkpoint: {}".format(d))
|
||||
os.remove(d[0])
|
||||
except Exception as e:
|
||||
print('Exception (%s) while deleting checkpoint' % str(e))
|
||||
logging.error("Exception '{}' while deleting checkpoint".format(e))
|
||||
self.checkpoint_files = self.checkpoint_files[:delete_index]
|
||||
|
||||
def save_recovery(self, model, optimizer, args, epoch, model_ema=None, batch_idx=0):
|
||||
@ -114,11 +107,10 @@ class CheckpointSaver:
|
||||
self._save(save_path, model, optimizer, args, epoch, model_ema)
|
||||
if os.path.exists(self.last_recovery_file):
|
||||
try:
|
||||
if self.verbose:
|
||||
print('Cleaning recovery', self.last_recovery_file)
|
||||
logging.debug("Cleaning recovery: {}".format(self.last_recovery_file))
|
||||
os.remove(self.last_recovery_file)
|
||||
except Exception as e:
|
||||
print("Exception (%s) while removing %s" % (str(e), self.last_recovery_file))
|
||||
logging.error("Exception '{}' while removing {}".format(e, self.last_recovery_file))
|
||||
self.last_recovery_file = self.curr_recovery_file
|
||||
self.curr_recovery_file = save_path
|
||||
|
||||
@ -253,9 +245,9 @@ class ModelEma:
|
||||
name = k
|
||||
new_state_dict[name] = v
|
||||
self.ema.load_state_dict(new_state_dict)
|
||||
print("=> Loaded state_dict_ema")
|
||||
logging.info("Loaded state_dict_ema")
|
||||
else:
|
||||
print("=> Failed to find state_dict_ema, starting from loaded model weights")
|
||||
logging.warning("Failed to find state_dict_ema, starting from loaded model weights")
|
||||
|
||||
def update(self, model):
|
||||
# correct a mismatch in state dict keys
|
||||
@ -269,3 +261,20 @@ class ModelEma:
|
||||
if self.device:
|
||||
model_v = model_v.to(device=self.device)
|
||||
ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v)
|
||||
|
||||
|
||||
class FormatterNoInfo(logging.Formatter):
    """Formatter that prints INFO records as the bare message.

    Records at any other level are rendered with the configured format
    string, which by default prefixes the message with the level name
    (e.g. ``WARNING: ...``).
    """

    def __init__(self, fmt='%(levelname)s: %(message)s'):
        super().__init__(fmt)

    def format(self, record):
        """Return the text to emit for ``record``."""
        # INFO output replaces what used to be plain print() calls, so it is
        # emitted without any level prefix.
        if record.levelno == logging.INFO:
            return str(record.getMessage())
        return super().format(record)


def setup_default_logging(default_level=logging.INFO):
    """Attach a console handler using FormatterNoInfo to the root logger.

    Safe to call more than once: a second call does not add another
    handler (which would duplicate every log line), it only re-applies
    ``default_level`` to the root logger.

    Args:
        default_level: Level to set on the root logger.
    """
    root_logger = logging.root
    already_configured = any(
        isinstance(getattr(handler, 'formatter', None), FormatterNoInfo)
        for handler in root_logger.handlers
    )
    if not already_configured:
        console_handler = logging.StreamHandler()
        console_handler.setFormatter(FormatterNoInfo())
        root_logger.addHandler(console_handler)
    root_logger.setLevel(default_level)
|
||||
|
82
train.py
82
train.py
@ -1,6 +1,7 @@
|
||||
|
||||
import argparse
|
||||
import time
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
try:
|
||||
@ -127,14 +128,14 @@ parser.add_argument("--local_rank", default=0, type=int)
|
||||
|
||||
|
||||
def main():
|
||||
setup_default_logging()
|
||||
args = parser.parse_args()
|
||||
|
||||
args.prefetcher = not args.no_prefetcher
|
||||
args.distributed = False
|
||||
if 'WORLD_SIZE' in os.environ:
|
||||
args.distributed = int(os.environ['WORLD_SIZE']) > 1
|
||||
if args.distributed and args.num_gpu > 1:
|
||||
print('Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.')
|
||||
logging.warning('Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.')
|
||||
args.num_gpu = 1
|
||||
|
||||
args.device = 'cuda:0'
|
||||
@ -144,17 +145,16 @@ def main():
|
||||
args.num_gpu = 1
|
||||
args.device = 'cuda:%d' % args.local_rank
|
||||
torch.cuda.set_device(args.local_rank)
|
||||
torch.distributed.init_process_group(
|
||||
backend='nccl', init_method='env://')
|
||||
torch.distributed.init_process_group(backend='nccl', init_method='env://')
|
||||
args.world_size = torch.distributed.get_world_size()
|
||||
args.rank = torch.distributed.get_rank()
|
||||
assert args.rank >= 0
|
||||
|
||||
if args.distributed:
|
||||
print('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
|
||||
% (args.rank, args.world_size))
|
||||
logging.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
|
||||
% (args.rank, args.world_size))
|
||||
else:
|
||||
print('Training with a single process on %d GPUs.' % args.num_gpu)
|
||||
logging.info('Training with a single process on %d GPUs.' % args.num_gpu)
|
||||
|
||||
torch.manual_seed(args.seed + args.rank)
|
||||
|
||||
@ -169,8 +169,8 @@ def main():
|
||||
bn_eps=args.bn_eps,
|
||||
checkpoint_path=args.initial_checkpoint)
|
||||
|
||||
print('Model %s created, param count: %d' %
|
||||
(args.model, sum([m.numel() for m in model.parameters()])))
|
||||
logging.info('Model %s created, param count: %d' %
|
||||
(args.model, sum([m.numel() for m in model.parameters()])))
|
||||
|
||||
data_config = resolve_data_config(model, args, verbose=args.local_rank == 0)
|
||||
|
||||
@ -182,8 +182,8 @@ def main():
|
||||
|
||||
if args.num_gpu > 1:
|
||||
if args.amp:
|
||||
print('Warning: AMP does not work well with nn.DataParallel, disabling. '
|
||||
'Use distributed mode for multi-GPU AMP.')
|
||||
logging.warning(
|
||||
'AMP does not work well with nn.DataParallel, disabling. Use distributed mode for multi-GPU AMP.')
|
||||
args.amp = False
|
||||
model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
|
||||
else:
|
||||
@ -198,10 +198,10 @@ def main():
|
||||
if has_apex and args.amp:
|
||||
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
|
||||
use_amp = True
|
||||
print('AMP enabled')
|
||||
logging.info('AMP enabled')
|
||||
else:
|
||||
use_amp = False
|
||||
print('AMP disabled')
|
||||
logging.info('AMP disabled')
|
||||
|
||||
model_ema = None
|
||||
if args.model_ema:
|
||||
@ -222,11 +222,11 @@ def main():
|
||||
if start_epoch > 0:
|
||||
lr_scheduler.step(start_epoch)
|
||||
if args.local_rank == 0:
|
||||
print('Scheduled epochs: ', num_epochs)
|
||||
logging.info('Scheduled epochs: {}'.format(num_epochs))
|
||||
|
||||
train_dir = os.path.join(args.data, 'train')
|
||||
if not os.path.exists(train_dir):
|
||||
print('Error: training folder does not exist at: %s' % train_dir)
|
||||
logging.error('Training folder does not exist at: {}'.format(train_dir))
|
||||
exit(1)
|
||||
dataset_train = Dataset(train_dir)
|
||||
|
||||
@ -252,7 +252,7 @@ def main():
|
||||
|
||||
eval_dir = os.path.join(args.data, 'validation')
|
||||
if not os.path.isdir(eval_dir):
|
||||
print('Error: validation folder does not exist at: %s' % eval_dir)
|
||||
logging.error('Validation folder does not exist at: {}'.format(eval_dir))
|
||||
exit(1)
|
||||
dataset_eval = Dataset(eval_dir)
|
||||
|
||||
@ -332,7 +332,7 @@ def main():
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
if best_metric is not None:
|
||||
print('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
|
||||
logging.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
|
||||
|
||||
|
||||
def train_epoch(
|
||||
@ -394,21 +394,22 @@ def train_epoch(
|
||||
losses_m.update(reduced_loss.item(), input.size(0))
|
||||
|
||||
if args.local_rank == 0:
|
||||
print('Train: {} [{}/{} ({:.0f}%)] '
|
||||
'Loss: {loss.val:.6f} ({loss.avg:.4f}) '
|
||||
'Time: {batch_time.val:.3f}s, {rate:.3f}/s '
|
||||
'({batch_time.avg:.3f}s, {rate_avg:.3f}/s) '
|
||||
'LR: {lr:.4f} '
|
||||
'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
|
||||
epoch,
|
||||
batch_idx, len(loader),
|
||||
100. * batch_idx / last_idx,
|
||||
loss=losses_m,
|
||||
batch_time=batch_time_m,
|
||||
rate=input.size(0) * args.world_size / batch_time_m.val,
|
||||
rate_avg=input.size(0) * args.world_size / batch_time_m.avg,
|
||||
lr=lr,
|
||||
data_time=data_time_m))
|
||||
logging.info(
|
||||
'Train: {} [{:>4d}/{} ({:>3.0f}%)] '
|
||||
'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f}) '
|
||||
'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s '
|
||||
'({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
|
||||
'LR: {lr:.3e} '
|
||||
'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
|
||||
epoch,
|
||||
batch_idx, len(loader),
|
||||
100. * batch_idx / last_idx,
|
||||
loss=losses_m,
|
||||
batch_time=batch_time_m,
|
||||
rate=input.size(0) * args.world_size / batch_time_m.val,
|
||||
rate_avg=input.size(0) * args.world_size / batch_time_m.avg,
|
||||
lr=lr,
|
||||
data_time=data_time_m))
|
||||
|
||||
if args.save_images and output_dir:
|
||||
torchvision.utils.save_image(
|
||||
@ -478,14 +479,15 @@ def validate(model, loader, loss_fn, args, log_suffix=''):
|
||||
end = time.time()
|
||||
if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0):
|
||||
log_name = 'Test' + log_suffix
|
||||
print('{0}: [{1}/{2}]\t'
|
||||
'Time {batch_time.val:.3f} ({batch_time.avg:.3f}) '
|
||||
'Loss {loss.val:.4f} ({loss.avg:.4f}) '
|
||||
'Prec@1 {top1.val:.4f} ({top1.avg:.4f}) '
|
||||
'Prec@5 {top5.val:.4f} ({top5.avg:.4f})'.format(
|
||||
log_name, batch_idx, last_idx,
|
||||
batch_time=batch_time_m, loss=losses_m,
|
||||
top1=prec1_m, top5=prec5_m))
|
||||
logging.info(
|
||||
'{0}: [{1:>4d}/{2}] '
|
||||
'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
|
||||
'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
|
||||
'Prec@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) '
|
||||
'Prec@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(
|
||||
log_name, batch_idx, last_idx,
|
||||
batch_time=batch_time_m, loss=losses_m,
|
||||
top1=prec1_m, top5=prec5_m))
|
||||
|
||||
metrics = OrderedDict([('loss', losses_m.avg), ('prec1', prec1_m.avg), ('prec5', prec5_m.avg)])
|
||||
|
||||
|
31
validate.py
31
validate.py
@ -7,6 +7,7 @@ import os
|
||||
import csv
|
||||
import glob
|
||||
import time
|
||||
import logging
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.parallel
|
||||
@ -14,7 +15,7 @@ from collections import OrderedDict
|
||||
|
||||
from timm.models import create_model, apply_test_time_pool, load_checkpoint
|
||||
from timm.data import Dataset, create_loader, resolve_data_config
|
||||
from timm.utils import accuracy, AverageMeter, natural_key
|
||||
from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
@ -37,8 +38,8 @@ parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
|
||||
help='Image resize interpolation type (overrides model)')
|
||||
parser.add_argument('--num-classes', type=int, default=1000,
|
||||
help='Number classes in dataset')
|
||||
parser.add_argument('--print-freq', '-p', default=10, type=int,
|
||||
metavar='N', help='print frequency (default: 10)')
|
||||
parser.add_argument('--log-freq', default=10, type=int,
|
||||
metavar='N', help='batch logging frequency (default: 10)')
|
||||
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
|
||||
help='path to latest checkpoint (default: none)')
|
||||
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
|
||||
@ -68,7 +69,7 @@ def validate(args):
|
||||
load_checkpoint(model, args.checkpoint, args.use_ema)
|
||||
|
||||
param_count = sum([m.numel() for m in model.parameters()])
|
||||
print('Model %s created, param count: %d' % (args.model, param_count))
|
||||
logging.info('Model %s created, param count: %d' % (args.model, param_count))
|
||||
|
||||
data_config = resolve_data_config(model, args)
|
||||
model, test_time_pool = apply_test_time_pool(model, data_config, args)
|
||||
@ -118,28 +119,30 @@ def validate(args):
|
||||
batch_time.update(time.time() - end)
|
||||
end = time.time()
|
||||
|
||||
if i % args.print_freq == 0:
|
||||
print('Test: [{0}/{1}]\t'
|
||||
'Time {batch_time.val:.3f} ({batch_time.avg:.3f}, {rate_avg:.3f}/s) \t'
|
||||
'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
|
||||
'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
|
||||
'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
|
||||
i, len(loader), batch_time=batch_time,
|
||||
rate_avg=input.size(0) / batch_time.avg,
|
||||
loss=losses, top1=top1, top5=top5))
|
||||
if i % args.log_freq == 0:
|
||||
logging.info(
|
||||
'Test: [{0:>4d}/{1}] '
|
||||
'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
|
||||
'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
|
||||
'Prec@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) '
|
||||
'Prec@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(
|
||||
i, len(loader), batch_time=batch_time,
|
||||
rate_avg=input.size(0) / batch_time.avg,
|
||||
loss=losses, top1=top1, top5=top5))
|
||||
|
||||
results = OrderedDict(
|
||||
top1=round(top1.avg, 3), top1_err=round(100 - top1.avg, 3),
|
||||
top5=round(top5.avg, 3), top5_err=round(100 - top5.avg, 3),
|
||||
param_count=round(param_count / 1e6, 2))
|
||||
|
||||
print(' * Prec@1 {:.3f} ({:.3f}) Prec@5 {:.3f} ({:.3f})'.format(
|
||||
logging.info(' * Prec@1 {:.3f} ({:.3f}) Prec@5 {:.3f} ({:.3f})'.format(
|
||||
results['top1'], results['top1_err'], results['top5'], results['top5_err']))
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def main():
|
||||
setup_default_logging()
|
||||
args = parser.parse_args()
|
||||
if args.model == 'all':
|
||||
# validate all models in a list of names with pretrained checkpoints
|
||||
|
Loading…
x
Reference in New Issue
Block a user