Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
Update benchmark script to add precision arg. Fix some downstream (DeiT) compat issues with latest changes. Bump version to 0.4.7
commit 288682796f (parent ea9c9550b2)

benchmark.py (92 changed lines)
@@ -19,7 +19,7 @@ from contextlib import suppress
 from functools import partial

 from timm.models import create_model, is_model, list_models
-from timm.optim import create_optimizer
+from timm.optim import create_optimizer_v2
 from timm.data import resolve_data_config
 from timm.utils import AverageMeter, setup_default_logging

@@ -53,6 +53,10 @@ parser.add_argument('--detail', action='store_true', default=False,
                     help='Provide train fwd/bwd/opt breakdown detail if True. Defaults to False')
 parser.add_argument('--results-file', default='', type=str, metavar='FILENAME',
                     help='Output csv file for validation results (summary)')
+parser.add_argument('--num-warm-iter', default=10, type=int,
+                    metavar='N', help='Number of warmup iterations (default: 10)')
+parser.add_argument('--num-bench-iter', default=40, type=int,
+                    metavar='N', help='Number of benchmark iterations (default: 40)')

 # common inference / train args
 parser.add_argument('--model', '-m', metavar='NAME', default='resnet50',
@@ -70,11 +74,9 @@ parser.add_argument('--gp', default=None, type=str, metavar='POOL',
 parser.add_argument('--channels-last', action='store_true', default=False,
                     help='Use channels_last memory layout')
 parser.add_argument('--amp', action='store_true', default=False,
-                    help='Use AMP mixed precision. Defaults to Apex, fallback to native Torch AMP.')
-parser.add_argument('--apex-amp', action='store_true', default=False,
-                    help='Use NVIDIA Apex AMP mixed precision')
-parser.add_argument('--native-amp', action='store_true', default=False,
-                    help='Use Native Torch AMP mixed precision')
+                    help='use PyTorch Native AMP for mixed precision training. Overrides --precision arg.')
+parser.add_argument('--precision', default='float32', type=str,
+                    help='Numeric precision. One of (amp, float32, float16, bfloat16, tf32)')
 parser.add_argument('--torchscript', dest='torchscript', action='store_true',
                     help='convert model torchscript for inference')

@@ -117,28 +119,50 @@ def cuda_timestamp(sync=False, device=None):
     return time.perf_counter()


-def count_params(model):
+def count_params(model: nn.Module):
     return sum([m.numel() for m in model.parameters()])


+def resolve_precision(precision: str):
+    assert precision in ('amp', 'float16', 'bfloat16', 'float32')
+    use_amp = False
+    model_dtype = torch.float32
+    data_dtype = torch.float32
+    if precision == 'amp':
+        use_amp = True
+    elif precision == 'float16':
+        model_dtype = torch.float16
+        data_dtype = torch.float16
+    elif precision == 'bfloat16':
+        model_dtype = torch.bfloat16
+        data_dtype = torch.bfloat16
+    return use_amp, model_dtype, data_dtype
+
+
 class BenchmarkRunner:
-    def __init__(self, model_name, detail=False, device='cuda', torchscript=False, **kwargs):
+    def __init__(
+            self, model_name, detail=False, device='cuda', torchscript=False, precision='float32',
+            num_warm_iter=10, num_bench_iter=50, **kwargs):
         self.model_name = model_name
         self.detail = detail
         self.device = device
+        self.use_amp, self.model_dtype, self.data_dtype = resolve_precision(precision)
+        self.channels_last = kwargs.pop('channels_last', False)
+        self.amp_autocast = torch.cuda.amp.autocast if self.use_amp else suppress

         self.model = create_model(
             model_name,
             num_classes=kwargs.pop('num_classes', None),
             in_chans=3,
             global_pool=kwargs.pop('gp', 'fast'),
-            scriptable=torchscript).to(device=self.device)
+            scriptable=torchscript)
+        self.model.to(
+            device=self.device,
+            dtype=self.model_dtype,
+            memory_format=torch.channels_last if self.channels_last else None)
         self.num_classes = self.model.num_classes
         self.param_count = count_params(self.model)
         _logger.info('Model %s created, param count: %d' % (model_name, self.param_count))

-        self.channels_last = kwargs.pop('channels_last', False)
-        self.use_amp = kwargs.pop('use_amp', '')
-        self.amp_autocast = torch.cuda.amp.autocast if self.use_amp == 'native' else suppress
         if torchscript:
             self.model = torch.jit.script(self.model)

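To make the precision plumbing above concrete, here is a minimal sketch of what the new resolve_precision() helper hands back for each setting (assumes you run it from the repo root so benchmark.py is importable; the comments restate the hunk above, not additional behavior):

from benchmark import resolve_precision  # helper added in the hunk above

# 'amp' keeps dtypes at float32 and lets torch.cuda.amp.autocast do the casting
print(resolve_precision('amp'))       # (True, torch.float32, torch.float32)

# explicit low-precision modes cast both the model weights and the example inputs
print(resolve_precision('float16'))   # (False, torch.float16, torch.float16)
print(resolve_precision('bfloat16'))  # (False, torch.bfloat16, torch.bfloat16)

# default full precision
print(resolve_precision('float32'))   # (False, torch.float32, torch.float32)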
@@ -147,16 +171,17 @@ class BenchmarkRunner:
         self.batch_size = kwargs.pop('batch_size', 256)

         self.example_inputs = None
-        self.num_warm_iter = 10
-        self.num_bench_iter = 50
-        self.log_freq = 10
+        self.num_warm_iter = num_warm_iter
+        self.num_bench_iter = num_bench_iter
+        self.log_freq = num_bench_iter // 5
         if 'cuda' in self.device:
             self.time_fn = partial(cuda_timestamp, device=self.device)
         else:
             self.time_fn = timestamp

     def _init_input(self):
-        self.example_inputs = torch.randn((self.batch_size,) + self.input_size, device=self.device)
+        self.example_inputs = torch.randn(
+            (self.batch_size,) + self.input_size, device=self.device, dtype=self.data_dtype)
         if self.channels_last:
             self.example_inputs = self.example_inputs.contiguous(memory_format=torch.channels_last)

@@ -166,10 +191,6 @@ class InferenceBenchmarkRunner(BenchmarkRunner):
     def __init__(self, model_name, device='cuda', torchscript=False, **kwargs):
         super().__init__(model_name=model_name, device=device, torchscript=torchscript, **kwargs)
         self.model.eval()
-        if self.use_amp == 'apex':
-            self.model = amp.initialize(self.model, opt_level='O1')
-        if self.channels_last:
-            self.model = self.model.to(memory_format=torch.channels_last)

     def run(self):
         def _step():
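For reference, a hedged sketch of driving the runner directly with the new constructor kwargs (class and argument names per the hunks above; assumes a CUDA device, and assumes run() returns a dict containing samples_per_sec, as the train runner's results do later in this diff):

from benchmark import InferenceBenchmarkRunner

runner = InferenceBenchmarkRunner(
    model_name='resnet50',
    precision='float16',    # resolved via resolve_precision() in BenchmarkRunner.__init__
    channels_last=True,     # popped from kwargs, applied to the model and example inputs
    num_warm_iter=10,
    num_bench_iter=40,
)
results = runner.run()
print(results['samples_per_sec'])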
@@ -231,16 +252,11 @@ class TrainBenchmarkRunner(BenchmarkRunner):
         self.loss = nn.CrossEntropyLoss().to(self.device)
         self.target_shape = tuple()

-        self.optimizer = create_optimizer(
+        self.optimizer = create_optimizer_v2(
             self.model,
             opt_name=kwargs.pop('opt', 'sgd'),
             lr=kwargs.pop('lr', 1e-4))

-        if self.use_amp == 'apex':
-            self.model, self.optimizer = amp.initialize(self.model, self.optimizer, opt_level='O1')
-        if self.channels_last:
-            self.model = self.model.to(memory_format=torch.channels_last)
-
     def _gen_target(self, batch_size):
         return torch.empty(
             (batch_size,) + self.target_shape, device=self.device, dtype=torch.long).random_(self.num_classes)
@@ -331,6 +347,7 @@ class TrainBenchmarkRunner(BenchmarkRunner):
             samples_per_sec=round(num_samples / t_run_elapsed, 2),
             step_time=round(1000 * total_step / num_samples, 3),
             batch_size=self.batch_size,
+            img_size=self.input_size[-1],
             param_count=round(self.param_count / 1e6, 2),
         )

@@ -367,23 +384,14 @@ def _try_run(model_name, bench_fn, initial_batch_size, bench_kwargs):

 def benchmark(args):
     if args.amp:
-        if has_native_amp:
-            args.native_amp = True
-        elif has_apex:
-            args.apex_amp = True
-        else:
-            _logger.warning("Neither APEX or Native Torch AMP is available.")
-    if args.native_amp:
-        args.use_amp = 'native'
-        _logger.info('Benchmarking in mixed precision with native PyTorch AMP.')
-    elif args.apex_amp:
-        args.use_amp = 'apex'
-        _logger.info('Benchmarking in mixed precision with NVIDIA APEX AMP.')
-    else:
-        args.use_amp = ''
-        _logger.info('Benchmarking in float32. AMP not enabled.')
+        _logger.warning("Overriding precision to 'amp' since --amp flag set.")
+        args.precision = 'amp'
+    _logger.info(f'Benchmarking in {args.precision} precision. '
+                 f'{"NHWC" if args.channels_last else "NCHW"} layout. '
+                 f'torchscript {"enabled" if args.torchscript else "disabled"}')

     bench_kwargs = vars(args).copy()
+    bench_kwargs.pop('amp')
     model = bench_kwargs.pop('model')
     batch_size = bench_kwargs.pop('batch_size')

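The interplay between the old --amp flag and the new --precision arg reduces to a few lines; a minimal stand-alone sketch of that logic (argparse.Namespace stands in for the parsed args, and only a subset of the real flags is shown):

import argparse

args = argparse.Namespace(amp=True, precision='float32', channels_last=False, torchscript=False)

if args.amp:
    args.precision = 'amp'   # --amp simply overrides --precision, as in benchmark() above

bench_kwargs = vars(args).copy()
bench_kwargs.pop('amp')      # the runner only understands 'precision', so drop the flag
print(bench_kwargs)          # {'precision': 'amp', 'channels_last': False, 'torchscript': False}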
timm/models/regnet.py

@@ -89,7 +89,7 @@ default_cfgs = dict(
     regnety_064=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_064-0a48325c.pth'),
     regnety_080=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_080-e7f3eb93.pth'),
     regnety_120=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_120-721ba79a.pth'),
-    regnety_160=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_160-d64013cd.pth'),
+    regnety_160=_cfg(url='https://dl.fbaipublicfiles.com/deit/regnety_160-a5fe301d.pth'),
     regnety_320=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_320-ba464b29.pth'),
 )

timm/models/vision_transformer.py

@@ -281,8 +281,9 @@ class VisionTransformer(nn.Module):

         # Classifier head(s)
         self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
-        self.head_dist = nn.Linear(self.embed_dim, self.num_classes) \
-            if num_classes > 0 and distilled else nn.Identity()
+        self.head_dist = None
+        if distilled:
+            self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()

         # Weight init
         assert weight_init in ('jax', 'jax_nlhb', 'nlhb', '')
@@ -336,8 +337,8 @@ class VisionTransformer(nn.Module):
     def reset_classifier(self, num_classes, global_pool=''):
         self.num_classes = num_classes
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
-        self.head_dist = nn.Linear(self.embed_dim, self.num_classes) \
-            if num_classes > 0 and self.dist_token is not None else nn.Identity()
+        if self.head_dist is not None:
+            self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()

     def forward_features(self, x):
         x = self.patch_embed(x)
@@ -356,8 +357,8 @@ class VisionTransformer(nn.Module):

     def forward(self, x):
         x = self.forward_features(x)
-        if isinstance(x, tuple):
-            x, x_dist = self.head(x[0]), self.head_dist(x[1])
+        if self.head_dist is not None:
+            x, x_dist = self.head(x[0]), self.head_dist(x[1])  # x must be a tuple
             if self.training and not torch.jit.is_scripting():
                 # during inference, return the average of both classifier predictions
                 return x, x_dist
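The head_dist changes above are what restore DeiT compatibility: distilled models keep a separate distillation head and return a pair of logits while training, while undistilled models have head_dist=None and a single output. A minimal sketch (model names assumed to be registered in this timm version; the inference-time averaging is per the existing forward code, only partly visible in the hunk):

import torch
from timm.models import create_model

deit = create_model('deit_tiny_distilled_patch16_224', pretrained=False)
vit = create_model('vit_base_patch16_224', pretrained=False)

x = torch.randn(2, 3, 224, 224)

deit.train()
cls_out, dist_out = deit(x)      # training: (class-token logits, distillation-token logits)
print(cls_out.shape, dist_out.shape)

deit.eval()
print(deit(x).shape)             # inference: the two predictions are averaged into one tensor

print(vit.head_dist is None)     # True: undistilled ViTs skip the extra head entirely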
timm/models/vision_transformer_hybrid.py

@@ -145,6 +145,12 @@ def vit_base_r50_s16_224_in21k(pretrained=False, **kwargs):
     return model


+@register_model
+def vit_base_resnet50_224_in21k(pretrained=False, **kwargs):
+    # NOTE this is forwarding to model def above for backwards compatibility
+    return vit_base_r50_s16_224_in21k(pretrained=pretrained, **kwargs)
+
+
 @register_model
 def vit_base_r50_s16_384(pretrained=False, **kwargs):
     """ R50+ViT-B/16 hybrid from original paper (https://arxiv.org/abs/2010.11929).
@@ -157,6 +163,12 @@ def vit_base_r50_s16_384(pretrained=False, **kwargs):
     return model


+@register_model
+def vit_base_resnet50_384(pretrained=False, **kwargs):
+    # NOTE this is forwarding to model def above for backwards compatibility
+    return vit_base_r50_s16_384(pretrained=pretrained, **kwargs)
+
+
 @register_model
 def vit_tiny_r_s16_p8_224(pretrained=False, **kwargs):
     """ R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 224 x 224.
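Since the old hybrid names are re-registered as thin forwards, downstream code that still references them keeps working; a quick hedged check (model names taken from the registrations above):

from timm.models import create_model, is_model

print(is_model('vit_base_resnet50_384'))   # True: legacy alias registered above
print(is_model('vit_base_r50_s16_384'))    # True: renamed definition it forwards to

# both names build the same architecture; the alias just calls the new factory fn
legacy = create_model('vit_base_resnet50_384', pretrained=False)
renamed = create_model('vit_base_r50_s16_384', pretrained=False)
print(type(legacy) is type(renamed))       # True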
timm/optim/__init__.py

@@ -10,4 +10,4 @@ from .radam import RAdam
 from .rmsprop_tf import RMSpropTF
 from .sgdp import SGDP

-from .optim_factory import create_optimizer, optimizer_kwargs
+from .optim_factory import create_optimizer, create_optimizer_v2, optimizer_kwargs
timm/optim/optim_factory.py

@@ -55,7 +55,21 @@ def optimizer_kwargs(cfg):
     return kwargs


-def create_optimizer(
+def create_optimizer(args, model, filter_bias_and_bn=True):
+    """ Legacy optimizer factory for backwards compatibility.
+    NOTE: Use create_optimizer_v2 for new code.
+    """
+    opt_args = dict(lr=args.lr, weight_decay=args.weight_decay, momentum=args.momentum)
+    if hasattr(args, 'opt_eps') and args.opt_eps is not None:
+        opt_args['eps'] = args.opt_eps
+    if hasattr(args, 'opt_betas') and args.opt_betas is not None:
+        opt_args['betas'] = args.opt_betas
+    if hasattr(args, 'opt_args') and args.opt_args is not None:
+        opt_args.update(args.opt_args)
+    return create_optimizer_v2(model, opt_name=args.opt, filter_bias_and_bn=filter_bias_and_bn, **opt_args)
+
+
+def create_optimizer_v2(
     model: nn.Module,
     opt_name: str = 'sgd',
     lr: Optional[float] = None,
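Both optimizer entry points remain usable after this change; a minimal sketch contrasting them (the weight_decay and momentum kwargs on create_optimizer_v2 are assumed from its full signature, which the hunk truncates):

import argparse
import torch.nn as nn
from timm.optim import create_optimizer, create_optimizer_v2

model = nn.Linear(10, 2)

# new style: pass the model plus explicit keyword args
opt_new = create_optimizer_v2(model, opt_name='sgd', lr=0.01, momentum=0.9, weight_decay=1e-4)

# legacy style (kept for downstream users such as DeiT): an args namespace,
# internally converted and forwarded to create_optimizer_v2
args = argparse.Namespace(opt='sgd', lr=0.01, momentum=0.9, weight_decay=1e-4)
opt_legacy = create_optimizer(args, model)

print(type(opt_new).__name__, type(opt_legacy).__name__)  # both are torch SGD optimizers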
timm/version.py

@@ -1 +1 @@
-__version__ = '0.4.6'
+__version__ = '0.4.7'
train.py (4 changed lines)
@@ -33,7 +33,7 @@ from timm.models import create_model, safe_model_name, resume_checkpoint, load_c
     convert_splitbn_model, model_parameters
 from timm.utils import *
 from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy, JsdCrossEntropy
-from timm.optim import create_optimizer, optimizer_kwargs
+from timm.optim import create_optimizer_v2, optimizer_kwargs
 from timm.scheduler import create_scheduler
 from timm.utils import ApexScaler, NativeScaler

@@ -389,7 +389,7 @@ def main():
         assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
         model = torch.jit.script(model)

-    optimizer = create_optimizer(model, **optimizer_kwargs(cfg=args))
+    optimizer = create_optimizer_v2(model, **optimizer_kwargs(cfg=args))

     # setup automatic mixed-precision (AMP) loss scaling and op casting
     amp_autocast = suppress  # do nothing