mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Merge branch 'feature/AB/logger' of https://github.com/antoinebrl/pytorch-image-models into logger
This commit is contained in:
commit
1998bd3180
@ -17,6 +17,8 @@ from timm.data import Dataset, create_loader, resolve_data_config
|
|||||||
from timm.utils import AverageMeter, setup_default_logging
|
from timm.utils import AverageMeter, setup_default_logging
|
||||||
|
|
||||||
torch.backends.cudnn.benchmark = True
|
torch.backends.cudnn.benchmark = True
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description='PyTorch ImageNet Inference')
|
parser = argparse.ArgumentParser(description='PyTorch ImageNet Inference')
|
||||||
parser.add_argument('data', metavar='DIR',
|
parser.add_argument('data', metavar='DIR',
|
||||||
@ -67,7 +69,7 @@ def main():
|
|||||||
pretrained=args.pretrained,
|
pretrained=args.pretrained,
|
||||||
checkpoint_path=args.checkpoint)
|
checkpoint_path=args.checkpoint)
|
||||||
|
|
||||||
logging.info('Model %s created, param count: %d' %
|
logger.info('Model %s created, param count: %d' %
|
||||||
(args.model, sum([m.numel() for m in model.parameters()])))
|
(args.model, sum([m.numel() for m in model.parameters()])))
|
||||||
|
|
||||||
config = resolve_data_config(vars(args), model=model)
|
config = resolve_data_config(vars(args), model=model)
|
||||||
@ -107,7 +109,7 @@ def main():
|
|||||||
end = time.time()
|
end = time.time()
|
||||||
|
|
||||||
if batch_idx % args.log_freq == 0:
|
if batch_idx % args.log_freq == 0:
|
||||||
logging.info('Predict: [{0}/{1}] Time {batch_time.val:.3f} ({batch_time.avg:.3f})'.format(
|
logger.info('Predict: [{0}/{1}] Time {batch_time.val:.3f} ({batch_time.avg:.3f})'.format(
|
||||||
batch_idx, len(loader), batch_time=batch_time))
|
batch_idx, len(loader), batch_time=batch_time))
|
||||||
|
|
||||||
topk_ids = np.concatenate(topk_ids, axis=0).squeeze()
|
topk_ids = np.concatenate(topk_ids, axis=0).squeeze()
|
||||||
|
@ -2,6 +2,9 @@ import logging
|
|||||||
from .constants import *
|
from .constants import *
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def resolve_data_config(args, default_cfg={}, model=None, verbose=True):
|
def resolve_data_config(args, default_cfg={}, model=None, verbose=True):
|
||||||
new_config = {}
|
new_config = {}
|
||||||
default_cfg = default_cfg
|
default_cfg = default_cfg
|
||||||
@ -65,8 +68,8 @@ def resolve_data_config(args, default_cfg={}, model=None, verbose=True):
|
|||||||
new_config['crop_pct'] = default_cfg['crop_pct']
|
new_config['crop_pct'] = default_cfg['crop_pct']
|
||||||
|
|
||||||
if verbose:
|
if verbose:
|
||||||
logging.info('Data processing configuration for current model + dataset:')
|
logger.info('Data processing configuration for current model + dataset:')
|
||||||
for n, v in new_config.items():
|
for n, v in new_config.items():
|
||||||
logging.info('\t%s: %s' % (n, str(v)))
|
logger.info('\t%s: %s' % (n, str(v)))
|
||||||
|
|
||||||
return new_config
|
return new_config
|
||||||
|
@ -12,6 +12,9 @@ from .layers import CondConv2d, get_condconv_initializer
|
|||||||
__all__ = ["EfficientNetBuilder", "decode_arch_def", "efficientnet_init_weights"]
|
__all__ = ["EfficientNetBuilder", "decode_arch_def", "efficientnet_init_weights"]
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _parse_ksize(ss):
|
def _parse_ksize(ss):
|
||||||
if ss.isdigit():
|
if ss.isdigit():
|
||||||
return int(ss)
|
return int(ss)
|
||||||
@ -248,7 +251,7 @@ class EfficientNetBuilder:
|
|||||||
ba['drop_path_rate'] = drop_path_rate
|
ba['drop_path_rate'] = drop_path_rate
|
||||||
ba['se_kwargs'] = self.se_kwargs
|
ba['se_kwargs'] = self.se_kwargs
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
logging.info(' InvertedResidual {}, Args: {}'.format(block_idx, str(ba)))
|
logger.info(' InvertedResidual {}, Args: {}'.format(block_idx, str(ba)))
|
||||||
if ba.get('num_experts', 0) > 0:
|
if ba.get('num_experts', 0) > 0:
|
||||||
block = CondConvResidual(**ba)
|
block = CondConvResidual(**ba)
|
||||||
else:
|
else:
|
||||||
@ -257,17 +260,17 @@ class EfficientNetBuilder:
|
|||||||
ba['drop_path_rate'] = drop_path_rate
|
ba['drop_path_rate'] = drop_path_rate
|
||||||
ba['se_kwargs'] = self.se_kwargs
|
ba['se_kwargs'] = self.se_kwargs
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
logging.info(' DepthwiseSeparable {}, Args: {}'.format(block_idx, str(ba)))
|
logger.info(' DepthwiseSeparable {}, Args: {}'.format(block_idx, str(ba)))
|
||||||
block = DepthwiseSeparableConv(**ba)
|
block = DepthwiseSeparableConv(**ba)
|
||||||
elif bt == 'er':
|
elif bt == 'er':
|
||||||
ba['drop_path_rate'] = drop_path_rate
|
ba['drop_path_rate'] = drop_path_rate
|
||||||
ba['se_kwargs'] = self.se_kwargs
|
ba['se_kwargs'] = self.se_kwargs
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
logging.info(' EdgeResidual {}, Args: {}'.format(block_idx, str(ba)))
|
logger.info(' EdgeResidual {}, Args: {}'.format(block_idx, str(ba)))
|
||||||
block = EdgeResidual(**ba)
|
block = EdgeResidual(**ba)
|
||||||
elif bt == 'cn':
|
elif bt == 'cn':
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
logging.info(' ConvBnAct {}, Args: {}'.format(block_idx, str(ba)))
|
logger.info(' ConvBnAct {}, Args: {}'.format(block_idx, str(ba)))
|
||||||
block = ConvBnAct(**ba)
|
block = ConvBnAct(**ba)
|
||||||
else:
|
else:
|
||||||
assert False, 'Uknkown block type (%s) while building model.' % bt
|
assert False, 'Uknkown block type (%s) while building model.' % bt
|
||||||
@ -285,7 +288,7 @@ class EfficientNetBuilder:
|
|||||||
List of block stacks (each stack wrapped in nn.Sequential)
|
List of block stacks (each stack wrapped in nn.Sequential)
|
||||||
"""
|
"""
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
logging.info('Building model trunk with %d stages...' % len(model_block_args))
|
logger.info('Building model trunk with %d stages...' % len(model_block_args))
|
||||||
self.in_chs = in_chs
|
self.in_chs = in_chs
|
||||||
total_block_count = sum([len(x) for x in model_block_args])
|
total_block_count = sum([len(x) for x in model_block_args])
|
||||||
total_block_idx = 0
|
total_block_idx = 0
|
||||||
@ -297,7 +300,7 @@ class EfficientNetBuilder:
|
|||||||
for stage_idx, stage_block_args in enumerate(model_block_args):
|
for stage_idx, stage_block_args in enumerate(model_block_args):
|
||||||
last_stack = stage_idx == (len(model_block_args) - 1)
|
last_stack = stage_idx == (len(model_block_args) - 1)
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
logging.info('Stack: {}'.format(stage_idx))
|
logger.info('Stack: {}'.format(stage_idx))
|
||||||
assert isinstance(stage_block_args, list)
|
assert isinstance(stage_block_args, list)
|
||||||
|
|
||||||
blocks = []
|
blocks = []
|
||||||
@ -306,7 +309,7 @@ class EfficientNetBuilder:
|
|||||||
last_block = block_idx == (len(stage_block_args) - 1)
|
last_block = block_idx == (len(stage_block_args) - 1)
|
||||||
extract_features = '' # No features extracted
|
extract_features = '' # No features extracted
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
logging.info(' Block: {}'.format(block_idx))
|
logger.info(' Block: {}'.format(block_idx))
|
||||||
|
|
||||||
# Sort out stride, dilation, and feature extraction details
|
# Sort out stride, dilation, and feature extraction details
|
||||||
assert block_args['stride'] in (1, 2)
|
assert block_args['stride'] in (1, 2)
|
||||||
@ -336,7 +339,7 @@ class EfficientNetBuilder:
|
|||||||
next_dilation = current_dilation * block_args['stride']
|
next_dilation = current_dilation * block_args['stride']
|
||||||
block_args['stride'] = 1
|
block_args['stride'] = 1
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
logging.info(' Converting stride to dilation to maintain output_stride=={}'.format(
|
logger.info(' Converting stride to dilation to maintain output_stride=={}'.format(
|
||||||
self.output_stride))
|
self.output_stride))
|
||||||
else:
|
else:
|
||||||
current_stride = next_output_stride
|
current_stride = next_output_stride
|
||||||
|
@ -8,6 +8,9 @@ from collections import OrderedDict
|
|||||||
from timm.models.layers.conv2d_same import Conv2dSame
|
from timm.models.layers.conv2d_same import Conv2dSame
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def load_state_dict(checkpoint_path, use_ema=False):
|
def load_state_dict(checkpoint_path, use_ema=False):
|
||||||
if checkpoint_path and os.path.isfile(checkpoint_path):
|
if checkpoint_path and os.path.isfile(checkpoint_path):
|
||||||
checkpoint = torch.load(checkpoint_path, map_location='cpu')
|
checkpoint = torch.load(checkpoint_path, map_location='cpu')
|
||||||
@ -24,10 +27,10 @@ def load_state_dict(checkpoint_path, use_ema=False):
|
|||||||
state_dict = new_state_dict
|
state_dict = new_state_dict
|
||||||
else:
|
else:
|
||||||
state_dict = checkpoint
|
state_dict = checkpoint
|
||||||
logging.info("Loaded {} from checkpoint '{}'".format(state_dict_key, checkpoint_path))
|
logger.info("Loaded {} from checkpoint '{}'".format(state_dict_key, checkpoint_path))
|
||||||
return state_dict
|
return state_dict
|
||||||
else:
|
else:
|
||||||
logging.error("No checkpoint found at '{}'".format(checkpoint_path))
|
logger.error("No checkpoint found at '{}'".format(checkpoint_path))
|
||||||
raise FileNotFoundError()
|
raise FileNotFoundError()
|
||||||
|
|
||||||
|
|
||||||
@ -55,13 +58,13 @@ def resume_checkpoint(model, checkpoint_path):
|
|||||||
resume_epoch = checkpoint['epoch']
|
resume_epoch = checkpoint['epoch']
|
||||||
if 'version' in checkpoint and checkpoint['version'] > 1:
|
if 'version' in checkpoint and checkpoint['version'] > 1:
|
||||||
resume_epoch += 1 # start at the next epoch, old checkpoints incremented before save
|
resume_epoch += 1 # start at the next epoch, old checkpoints incremented before save
|
||||||
logging.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
|
logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
|
||||||
else:
|
else:
|
||||||
model.load_state_dict(checkpoint)
|
model.load_state_dict(checkpoint)
|
||||||
logging.info("Loaded checkpoint '{}'".format(checkpoint_path))
|
logger.info("Loaded checkpoint '{}'".format(checkpoint_path))
|
||||||
return other_state, resume_epoch
|
return other_state, resume_epoch
|
||||||
else:
|
else:
|
||||||
logging.error("No checkpoint found at '{}'".format(checkpoint_path))
|
logger.error("No checkpoint found at '{}'".format(checkpoint_path))
|
||||||
raise FileNotFoundError()
|
raise FileNotFoundError()
|
||||||
|
|
||||||
|
|
||||||
@ -69,14 +72,14 @@ def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=Non
|
|||||||
if cfg is None:
|
if cfg is None:
|
||||||
cfg = getattr(model, 'default_cfg')
|
cfg = getattr(model, 'default_cfg')
|
||||||
if cfg is None or 'url' not in cfg or not cfg['url']:
|
if cfg is None or 'url' not in cfg or not cfg['url']:
|
||||||
logging.warning("Pretrained model URL is invalid, using random initialization.")
|
logger.warning("Pretrained model URL is invalid, using random initialization.")
|
||||||
return
|
return
|
||||||
|
|
||||||
state_dict = model_zoo.load_url(cfg['url'], progress=False, map_location='cpu')
|
state_dict = model_zoo.load_url(cfg['url'], progress=False, map_location='cpu')
|
||||||
|
|
||||||
if in_chans == 1:
|
if in_chans == 1:
|
||||||
conv1_name = cfg['first_conv']
|
conv1_name = cfg['first_conv']
|
||||||
logging.info('Converting first conv (%s) from 3 to 1 channel' % conv1_name)
|
logger.info('Converting first conv (%s) from 3 to 1 channel' % conv1_name)
|
||||||
conv1_weight = state_dict[conv1_name + '.weight']
|
conv1_weight = state_dict[conv1_name + '.weight']
|
||||||
state_dict[conv1_name + '.weight'] = conv1_weight.sum(dim=1, keepdim=True)
|
state_dict[conv1_name + '.weight'] = conv1_weight.sum(dim=1, keepdim=True)
|
||||||
elif in_chans != 3:
|
elif in_chans != 3:
|
||||||
|
@ -10,6 +10,9 @@ import torch.nn.functional as F
|
|||||||
from .adaptive_avgmax_pool import adaptive_avgmax_pool2d
|
from .adaptive_avgmax_pool import adaptive_avgmax_pool2d
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class TestTimePoolHead(nn.Module):
|
class TestTimePoolHead(nn.Module):
|
||||||
def __init__(self, base, original_pool=7):
|
def __init__(self, base, original_pool=7):
|
||||||
super(TestTimePoolHead, self).__init__()
|
super(TestTimePoolHead, self).__init__()
|
||||||
@ -40,7 +43,7 @@ def apply_test_time_pool(model, config, args):
|
|||||||
if not args.no_test_pool and \
|
if not args.no_test_pool and \
|
||||||
config['input_size'][-1] > model.default_cfg['input_size'][-1] and \
|
config['input_size'][-1] > model.default_cfg['input_size'][-1] and \
|
||||||
config['input_size'][-2] > model.default_cfg['input_size'][-2]:
|
config['input_size'][-2] > model.default_cfg['input_size'][-2]:
|
||||||
logging.info('Target input size %s > pretrained default %s, using test time pooling' %
|
logger.info('Target input size %s > pretrained default %s, using test time pooling' %
|
||||||
(str(config['input_size'][-2:]), str(model.default_cfg['input_size'][-2:])))
|
(str(config['input_size'][-2:]), str(model.default_cfg['input_size'][-2:])))
|
||||||
model = TestTimePoolHead(model, original_pool=model.default_cfg['pool_size'])
|
model = TestTimePoolHead(model, original_pool=model.default_cfg['pool_size'])
|
||||||
test_time_pool = True
|
test_time_pool = True
|
||||||
|
@ -21,6 +21,9 @@ except ImportError:
|
|||||||
from torch import distributed as dist
|
from torch import distributed as dist
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def unwrap_model(model):
|
def unwrap_model(model):
|
||||||
if isinstance(model, ModelEma):
|
if isinstance(model, ModelEma):
|
||||||
return unwrap_model(model.ema)
|
return unwrap_model(model.ema)
|
||||||
@ -84,7 +87,7 @@ class CheckpointSaver:
|
|||||||
checkpoints_str = "Current checkpoints:\n"
|
checkpoints_str = "Current checkpoints:\n"
|
||||||
for c in self.checkpoint_files:
|
for c in self.checkpoint_files:
|
||||||
checkpoints_str += ' {}\n'.format(c)
|
checkpoints_str += ' {}\n'.format(c)
|
||||||
logging.info(checkpoints_str)
|
logger.info(checkpoints_str)
|
||||||
|
|
||||||
if metric is not None and (self.best_metric is None or self.cmp(metric, self.best_metric)):
|
if metric is not None and (self.best_metric is None or self.cmp(metric, self.best_metric)):
|
||||||
self.best_epoch = epoch
|
self.best_epoch = epoch
|
||||||
@ -121,10 +124,10 @@ class CheckpointSaver:
|
|||||||
to_delete = self.checkpoint_files[delete_index:]
|
to_delete = self.checkpoint_files[delete_index:]
|
||||||
for d in to_delete:
|
for d in to_delete:
|
||||||
try:
|
try:
|
||||||
logging.debug("Cleaning checkpoint: {}".format(d))
|
logger.debug("Cleaning checkpoint: {}".format(d))
|
||||||
os.remove(d[0])
|
os.remove(d[0])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error("Exception '{}' while deleting checkpoint".format(e))
|
logger.error("Exception '{}' while deleting checkpoint".format(e))
|
||||||
self.checkpoint_files = self.checkpoint_files[:delete_index]
|
self.checkpoint_files = self.checkpoint_files[:delete_index]
|
||||||
|
|
||||||
def save_recovery(self, model, optimizer, args, epoch, model_ema=None, use_amp=False, batch_idx=0):
|
def save_recovery(self, model, optimizer, args, epoch, model_ema=None, use_amp=False, batch_idx=0):
|
||||||
@ -134,10 +137,10 @@ class CheckpointSaver:
|
|||||||
self._save(save_path, model, optimizer, args, epoch, model_ema, use_amp=use_amp)
|
self._save(save_path, model, optimizer, args, epoch, model_ema, use_amp=use_amp)
|
||||||
if os.path.exists(self.last_recovery_file):
|
if os.path.exists(self.last_recovery_file):
|
||||||
try:
|
try:
|
||||||
logging.debug("Cleaning recovery: {}".format(self.last_recovery_file))
|
logger.debug("Cleaning recovery: {}".format(self.last_recovery_file))
|
||||||
os.remove(self.last_recovery_file)
|
os.remove(self.last_recovery_file)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error("Exception '{}' while removing {}".format(e, self.last_recovery_file))
|
logger.error("Exception '{}' while removing {}".format(e, self.last_recovery_file))
|
||||||
self.last_recovery_file = self.curr_recovery_file
|
self.last_recovery_file = self.curr_recovery_file
|
||||||
self.curr_recovery_file = save_path
|
self.curr_recovery_file = save_path
|
||||||
|
|
||||||
@ -279,9 +282,9 @@ class ModelEma:
|
|||||||
name = k
|
name = k
|
||||||
new_state_dict[name] = v
|
new_state_dict[name] = v
|
||||||
self.ema.load_state_dict(new_state_dict)
|
self.ema.load_state_dict(new_state_dict)
|
||||||
logging.info("Loaded state_dict_ema")
|
logger.info("Loaded state_dict_ema")
|
||||||
else:
|
else:
|
||||||
logging.warning("Failed to find state_dict_ema, starting from loaded model weights")
|
logger.warning("Failed to find state_dict_ema, starting from loaded model weights")
|
||||||
|
|
||||||
def update(self, model):
|
def update(self, model):
|
||||||
# correct a mismatch in state dict keys
|
# correct a mismatch in state dict keys
|
||||||
|
37
train.py
37
train.py
@ -40,6 +40,7 @@ import torch.nn as nn
|
|||||||
import torchvision.utils
|
import torchvision.utils
|
||||||
|
|
||||||
torch.backends.cudnn.benchmark = True
|
torch.backends.cudnn.benchmark = True
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
# The first arg parser parses out only the --config argument, this argument is used to
|
# The first arg parser parses out only the --config argument, this argument is used to
|
||||||
@ -232,7 +233,7 @@ def main():
|
|||||||
if 'WORLD_SIZE' in os.environ:
|
if 'WORLD_SIZE' in os.environ:
|
||||||
args.distributed = int(os.environ['WORLD_SIZE']) > 1
|
args.distributed = int(os.environ['WORLD_SIZE']) > 1
|
||||||
if args.distributed and args.num_gpu > 1:
|
if args.distributed and args.num_gpu > 1:
|
||||||
logging.warning('Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.')
|
logger.warning('Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.')
|
||||||
args.num_gpu = 1
|
args.num_gpu = 1
|
||||||
|
|
||||||
args.device = 'cuda:0'
|
args.device = 'cuda:0'
|
||||||
@ -248,10 +249,10 @@ def main():
|
|||||||
assert args.rank >= 0
|
assert args.rank >= 0
|
||||||
|
|
||||||
if args.distributed:
|
if args.distributed:
|
||||||
logging.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
|
logger.info('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
|
||||||
% (args.rank, args.world_size))
|
% (args.rank, args.world_size))
|
||||||
else:
|
else:
|
||||||
logging.info('Training with a single process on %d GPUs.' % args.num_gpu)
|
logger.info('Training with a single process on %d GPUs.' % args.num_gpu)
|
||||||
|
|
||||||
torch.manual_seed(args.seed + args.rank)
|
torch.manual_seed(args.seed + args.rank)
|
||||||
|
|
||||||
@ -270,7 +271,7 @@ def main():
|
|||||||
checkpoint_path=args.initial_checkpoint)
|
checkpoint_path=args.initial_checkpoint)
|
||||||
|
|
||||||
if args.local_rank == 0:
|
if args.local_rank == 0:
|
||||||
logging.info('Model %s created, param count: %d' %
|
logger.info('Model %s created, param count: %d' %
|
||||||
(args.model, sum([m.numel() for m in model.parameters()])))
|
(args.model, sum([m.numel() for m in model.parameters()])))
|
||||||
|
|
||||||
data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)
|
data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0)
|
||||||
@ -286,7 +287,7 @@ def main():
|
|||||||
|
|
||||||
if args.num_gpu > 1:
|
if args.num_gpu > 1:
|
||||||
if args.amp:
|
if args.amp:
|
||||||
logging.warning(
|
logger.warning(
|
||||||
'AMP does not work well with nn.DataParallel, disabling. Use distributed mode for multi-GPU AMP.')
|
'AMP does not work well with nn.DataParallel, disabling. Use distributed mode for multi-GPU AMP.')
|
||||||
args.amp = False
|
args.amp = False
|
||||||
model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
|
model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
|
||||||
@ -300,7 +301,7 @@ def main():
|
|||||||
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
|
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
|
||||||
use_amp = True
|
use_amp = True
|
||||||
if args.local_rank == 0:
|
if args.local_rank == 0:
|
||||||
logging.info('NVIDIA APEX {}. AMP {}.'.format(
|
logger.info('NVIDIA APEX {}. AMP {}.'.format(
|
||||||
'installed' if has_apex else 'not installed', 'on' if use_amp else 'off'))
|
'installed' if has_apex else 'not installed', 'on' if use_amp else 'off'))
|
||||||
|
|
||||||
# optionally resume from a checkpoint
|
# optionally resume from a checkpoint
|
||||||
@ -311,11 +312,11 @@ def main():
|
|||||||
if resume_state and not args.no_resume_opt:
|
if resume_state and not args.no_resume_opt:
|
||||||
if 'optimizer' in resume_state:
|
if 'optimizer' in resume_state:
|
||||||
if args.local_rank == 0:
|
if args.local_rank == 0:
|
||||||
logging.info('Restoring Optimizer state from checkpoint')
|
logger.info('Restoring Optimizer state from checkpoint')
|
||||||
optimizer.load_state_dict(resume_state['optimizer'])
|
optimizer.load_state_dict(resume_state['optimizer'])
|
||||||
if use_amp and 'amp' in resume_state and 'load_state_dict' in amp.__dict__:
|
if use_amp and 'amp' in resume_state and 'load_state_dict' in amp.__dict__:
|
||||||
if args.local_rank == 0:
|
if args.local_rank == 0:
|
||||||
logging.info('Restoring NVIDIA AMP state from checkpoint')
|
logger.info('Restoring NVIDIA AMP state from checkpoint')
|
||||||
amp.load_state_dict(resume_state['amp'])
|
amp.load_state_dict(resume_state['amp'])
|
||||||
del resume_state
|
del resume_state
|
||||||
|
|
||||||
@ -337,16 +338,16 @@ def main():
|
|||||||
else:
|
else:
|
||||||
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
|
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
|
||||||
if args.local_rank == 0:
|
if args.local_rank == 0:
|
||||||
logging.info(
|
logger.info(
|
||||||
'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
|
'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
|
||||||
'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')
|
'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error('Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1')
|
logger.error('Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1')
|
||||||
if has_apex:
|
if has_apex:
|
||||||
model = DDP(model, delay_allreduce=True)
|
model = DDP(model, delay_allreduce=True)
|
||||||
else:
|
else:
|
||||||
if args.local_rank == 0:
|
if args.local_rank == 0:
|
||||||
logging.info("Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP.")
|
logger.info("Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP.")
|
||||||
model = DDP(model, device_ids=[args.local_rank]) # can use device str in Torch >= 1.1
|
model = DDP(model, device_ids=[args.local_rank]) # can use device str in Torch >= 1.1
|
||||||
# NOTE: EMA model does not need to be wrapped by DDP
|
# NOTE: EMA model does not need to be wrapped by DDP
|
||||||
|
|
||||||
@ -361,11 +362,11 @@ def main():
|
|||||||
lr_scheduler.step(start_epoch)
|
lr_scheduler.step(start_epoch)
|
||||||
|
|
||||||
if args.local_rank == 0:
|
if args.local_rank == 0:
|
||||||
logging.info('Scheduled epochs: {}'.format(num_epochs))
|
logger.info('Scheduled epochs: {}'.format(num_epochs))
|
||||||
|
|
||||||
train_dir = os.path.join(args.data, 'train')
|
train_dir = os.path.join(args.data, 'train')
|
||||||
if not os.path.exists(train_dir):
|
if not os.path.exists(train_dir):
|
||||||
logging.error('Training folder does not exist at: {}'.format(train_dir))
|
logger.error('Training folder does not exist at: {}'.format(train_dir))
|
||||||
exit(1)
|
exit(1)
|
||||||
dataset_train = Dataset(train_dir)
|
dataset_train = Dataset(train_dir)
|
||||||
|
|
||||||
@ -404,7 +405,7 @@ def main():
|
|||||||
if not os.path.isdir(eval_dir):
|
if not os.path.isdir(eval_dir):
|
||||||
eval_dir = os.path.join(args.data, 'validation')
|
eval_dir = os.path.join(args.data, 'validation')
|
||||||
if not os.path.isdir(eval_dir):
|
if not os.path.isdir(eval_dir):
|
||||||
logging.error('Validation folder does not exist at: {}'.format(eval_dir))
|
logger.error('Validation folder does not exist at: {}'.format(eval_dir))
|
||||||
exit(1)
|
exit(1)
|
||||||
dataset_eval = Dataset(eval_dir)
|
dataset_eval = Dataset(eval_dir)
|
||||||
|
|
||||||
@ -468,7 +469,7 @@ def main():
|
|||||||
|
|
||||||
if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
|
if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
|
||||||
if args.local_rank == 0:
|
if args.local_rank == 0:
|
||||||
logging.info("Distributing BatchNorm running means and vars")
|
logger.info("Distributing BatchNorm running means and vars")
|
||||||
distribute_bn(model, args.world_size, args.dist_bn == 'reduce')
|
distribute_bn(model, args.world_size, args.dist_bn == 'reduce')
|
||||||
|
|
||||||
eval_metrics = validate(model, loader_eval, validate_loss_fn, args)
|
eval_metrics = validate(model, loader_eval, validate_loss_fn, args)
|
||||||
@ -499,7 +500,7 @@ def main():
|
|||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
pass
|
pass
|
||||||
if best_metric is not None:
|
if best_metric is not None:
|
||||||
logging.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
|
logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
|
||||||
|
|
||||||
|
|
||||||
def train_epoch(
|
def train_epoch(
|
||||||
@ -559,7 +560,7 @@ def train_epoch(
|
|||||||
losses_m.update(reduced_loss.item(), input.size(0))
|
losses_m.update(reduced_loss.item(), input.size(0))
|
||||||
|
|
||||||
if args.local_rank == 0:
|
if args.local_rank == 0:
|
||||||
logging.info(
|
logger.info(
|
||||||
'Train: {} [{:>4d}/{} ({:>3.0f}%)] '
|
'Train: {} [{:>4d}/{} ({:>3.0f}%)] '
|
||||||
'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f}) '
|
'Loss: {loss.val:>9.6f} ({loss.avg:>6.4f}) '
|
||||||
'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s '
|
'Time: {batch_time.val:.3f}s, {rate:>7.2f}/s '
|
||||||
@ -647,7 +648,7 @@ def validate(model, loader, loss_fn, args, log_suffix=''):
|
|||||||
end = time.time()
|
end = time.time()
|
||||||
if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0):
|
if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0):
|
||||||
log_name = 'Test' + log_suffix
|
log_name = 'Test' + log_suffix
|
||||||
logging.info(
|
logger.info(
|
||||||
'{0}: [{1:>4d}/{2}] '
|
'{0}: [{1:>4d}/{2}] '
|
||||||
'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
|
'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
|
||||||
'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
|
'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
|
||||||
|
10
validate.py
10
validate.py
@ -30,6 +30,8 @@ from timm.data import Dataset, DatasetTar, create_loader, resolve_data_config
|
|||||||
from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging
|
from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging
|
||||||
|
|
||||||
torch.backends.cudnn.benchmark = True
|
torch.backends.cudnn.benchmark = True
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation')
|
parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation')
|
||||||
parser.add_argument('data', metavar='DIR',
|
parser.add_argument('data', metavar='DIR',
|
||||||
@ -97,7 +99,7 @@ def validate(args):
|
|||||||
load_checkpoint(model, args.checkpoint, args.use_ema)
|
load_checkpoint(model, args.checkpoint, args.use_ema)
|
||||||
|
|
||||||
param_count = sum([m.numel() for m in model.parameters()])
|
param_count = sum([m.numel() for m in model.parameters()])
|
||||||
logging.info('Model %s created, param count: %d' % (args.model, param_count))
|
logger.info('Model %s created, param count: %d' % (args.model, param_count))
|
||||||
|
|
||||||
data_config = resolve_data_config(vars(args), model=model)
|
data_config = resolve_data_config(vars(args), model=model)
|
||||||
model, test_time_pool = apply_test_time_pool(model, data_config, args)
|
model, test_time_pool = apply_test_time_pool(model, data_config, args)
|
||||||
@ -170,7 +172,7 @@ def validate(args):
|
|||||||
end = time.time()
|
end = time.time()
|
||||||
|
|
||||||
if i % args.log_freq == 0:
|
if i % args.log_freq == 0:
|
||||||
logging.info(
|
logger.info(
|
||||||
'Test: [{0:>4d}/{1}] '
|
'Test: [{0:>4d}/{1}] '
|
||||||
'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
|
'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
|
||||||
'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
|
'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
|
||||||
@ -188,7 +190,7 @@ def validate(args):
|
|||||||
cropt_pct=crop_pct,
|
cropt_pct=crop_pct,
|
||||||
interpolation=data_config['interpolation'])
|
interpolation=data_config['interpolation'])
|
||||||
|
|
||||||
logging.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
|
logger.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
|
||||||
results['top1'], results['top1_err'], results['top5'], results['top5_err']))
|
results['top1'], results['top1_err'], results['top5'], results['top5_err']))
|
||||||
|
|
||||||
return results
|
return results
|
||||||
@ -218,7 +220,7 @@ def main():
|
|||||||
|
|
||||||
if len(model_cfgs):
|
if len(model_cfgs):
|
||||||
results_file = args.results_file or './results-all.csv'
|
results_file = args.results_file or './results-all.csv'
|
||||||
logging.info('Running bulk validation on these pretrained models: {}'.format(', '.join(model_names)))
|
logger.info('Running bulk validation on these pretrained models: {}'.format(', '.join(model_names)))
|
||||||
results = []
|
results = []
|
||||||
try:
|
try:
|
||||||
start_batch_size = args.batch_size
|
start_batch_size = args.batch_size
|
||||||
|
Loading…
x
Reference in New Issue
Block a user