mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add --fast-norm arg to benchmark.py, train.py, validate.py
This commit is contained in:
parent
769ab4b98a
commit
ff6a919cf5
@ -19,7 +19,7 @@ import torch.nn as nn
|
||||
import torch.nn.parallel
|
||||
|
||||
from timm.data import resolve_data_config
|
||||
from timm.models import create_model, is_model, list_models
|
||||
from timm.models import create_model, is_model, list_models, set_fast_norm
|
||||
from timm.optim import create_optimizer_v2
|
||||
from timm.utils import setup_default_logging, set_jit_fuser, decay_batch_step, check_batch_size_retry
|
||||
|
||||
@ -109,7 +109,8 @@ scripting_group.add_argument('--torchscript', dest='torchscript', action='store_
|
||||
help='convert model torchscript for inference')
|
||||
scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
|
||||
help="Enable AOT Autograd support. (It's recommended to use this option with `--fuser nvfuser` together)")
|
||||
|
||||
scripting_group.add_argument('--fast-norm', default=False, action='store_true',
|
||||
help='enable experimental fast-norm')
|
||||
|
||||
# train optimizer parameters
|
||||
parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
|
||||
@ -598,6 +599,9 @@ def main():
|
||||
model_cfgs = []
|
||||
model_names = []
|
||||
|
||||
if args.fast_norm:
|
||||
set_fast_norm()
|
||||
|
||||
if args.model_list:
|
||||
args.model = ''
|
||||
with open(args.model_list) as f:
|
||||
|
@ -69,5 +69,6 @@ from .helpers import load_checkpoint, resume_checkpoint, model_parameters
|
||||
from .layers import TestTimePoolHead, apply_test_time_pool
|
||||
from .layers import convert_splitbn_model, convert_sync_batchnorm
|
||||
from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit
|
||||
from .layers import set_fast_norm
|
||||
from .registry import register_model, model_entrypoint, list_models, is_model, list_modules, is_model_in_modules,\
|
||||
is_model_pretrained, get_pretrained_cfg, has_pretrained_cfg_key, is_pretrained_cfg_key, get_pretrained_cfg_value
|
||||
|
6
train.py
6
train.py
@ -33,7 +33,7 @@ from timm.data import create_dataset, create_loader, resolve_data_config, Mixup,
|
||||
from timm.loss import JsdCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntropy, \
|
||||
LabelSmoothingCrossEntropy
|
||||
from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint, \
|
||||
convert_splitbn_model, convert_sync_batchnorm, model_parameters
|
||||
convert_splitbn_model, convert_sync_batchnorm, model_parameters, set_fast_norm
|
||||
from timm.optim import create_optimizer_v2, optimizer_kwargs
|
||||
from timm.scheduler import create_scheduler
|
||||
from timm.utils import ApexScaler, NativeScaler
|
||||
@ -135,6 +135,8 @@ scripting_group.add_argument('--aot-autograd', default=False, action='store_true
|
||||
help="Enable AOT Autograd support. (It's recommended to use this option with `--fuser nvfuser` together)")
|
||||
group.add_argument('--fuser', default='', type=str,
|
||||
help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
|
||||
group.add_argument('--fast-norm', default=False, action='store_true',
|
||||
help='enable experimental fast-norm')
|
||||
group.add_argument('--grad-checkpointing', action='store_true', default=False,
|
||||
help='Enable gradient checkpointing through model blocks/stages')
|
||||
|
||||
@ -395,6 +397,8 @@ def main():
|
||||
|
||||
if args.fuser:
|
||||
utils.set_jit_fuser(args.fuser)
|
||||
if args.fast_norm:
|
||||
set_fast_norm()
|
||||
|
||||
model = create_model(
|
||||
args.model,
|
||||
|
@ -20,7 +20,7 @@ import torch.nn.parallel
|
||||
from collections import OrderedDict
|
||||
from contextlib import suppress
|
||||
|
||||
from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models
|
||||
from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models, set_fast_norm
|
||||
from timm.data import create_dataset, create_loader, resolve_data_config, RealLabelsImagenet
|
||||
from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_fuser,\
|
||||
decay_batch_step, check_batch_size_retry
|
||||
@ -117,6 +117,8 @@ scripting_group.add_argument('--aot-autograd', default=False, action='store_true
|
||||
help="Enable AOT Autograd support. (It's recommended to use this option with `--fuser nvfuser` together)")
|
||||
parser.add_argument('--fuser', default='', type=str,
|
||||
help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
|
||||
parser.add_argument('--fast-norm', default=False, action='store_true',
|
||||
help='enable experimental fast-norm')
|
||||
parser.add_argument('--results-file', default='', type=str, metavar='FILENAME',
|
||||
help='Output csv file for validation results (summary)')
|
||||
parser.add_argument('--real-labels', default='', type=str, metavar='FILENAME',
|
||||
@ -150,6 +152,8 @@ def validate(args):
|
||||
|
||||
if args.fuser:
|
||||
set_jit_fuser(args.fuser)
|
||||
if args.fast_norm:
|
||||
set_fast_norm()
|
||||
|
||||
# create model
|
||||
model = create_model(
|
||||
|
Loading…
x
Reference in New Issue
Block a user