mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)

commit f0f9eccda8
parent 010b486590

Add --fuser arg to train/validate/benchmark scripts to select jit fuser type
benchmark.py

@@ -21,7 +21,7 @@ from functools import partial
 from timm.models import create_model, is_model, list_models
 from timm.optim import create_optimizer_v2
 from timm.data import resolve_data_config
-from timm.utils import AverageMeter, setup_default_logging
+from timm.utils import setup_default_logging, set_jit_fuser


 has_apex = False

@@ -95,7 +95,8 @@ parser.add_argument('--precision', default='float32', type=str,
                     help='Numeric precision. One of (amp, float32, float16, bfloat16, tf32)')
 parser.add_argument('--torchscript', dest='torchscript', action='store_true',
                     help='convert model torchscript for inference')
-
+parser.add_argument('--fuser', default='', type=str,
+                    help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")


 # train optimizer parameters

@@ -186,7 +187,7 @@ def profile_fvcore(model, input_size=(3, 224, 224), batch_size=1, detailed=False
 class BenchmarkRunner:
     def __init__(
             self, model_name, detail=False, device='cuda', torchscript=False, precision='float32',
-            num_warm_iter=10, num_bench_iter=50, use_train_size=False, **kwargs):
+            fuser='', num_warm_iter=10, num_bench_iter=50, use_train_size=False, **kwargs):
         self.model_name = model_name
         self.detail = detail
         self.device = device

@@ -194,6 +195,8 @@ class BenchmarkRunner:
         self.channels_last = kwargs.pop('channels_last', False)
         self.amp_autocast = torch.cuda.amp.autocast if self.use_amp else suppress

+        if fuser:
+            set_jit_fuser(fuser)
         self.model = create_model(
             model_name,
             num_classes=kwargs.pop('num_classes', None),
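For context, the new fuser kwarg is applied before the model is created, so the chosen fuser is already active if the runner later scripts the model. A minimal standalone sketch of the same ordering, outside the BenchmarkRunner class; 'resnet50' is just an assumed example model name, and only set_jit_fuser / create_model come from the diff above:

    import torch
    from timm.models import create_model
    from timm.utils import set_jit_fuser

    # Select the fuser first: '', 'te', 'old'/'legacy', or 'nvfuser'/'nvf'.
    set_jit_fuser('te')

    # 'resnet50' is an example; any timm model name works here.
    model = create_model('resnet50').eval()

    # Scripting after the call means subsequent runs use the selected fuser.
    scripted = torch.jit.script(model)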
timm/utils/__init__.py

@@ -3,7 +3,7 @@ from .checkpoint_saver import CheckpointSaver
 from .clip_grad import dispatch_clip_grad
 from .cuda import ApexScaler, NativeScaler
 from .distributed import distribute_bn, reduce_tensor
-from .jit import set_jit_legacy
+from .jit import set_jit_legacy, set_jit_fuser
 from .log import setup_default_logging, FormatterNoInfo
 from .metrics import AverageMeter, accuracy
 from .misc import natural_key, add_bool_arg
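This re-export is what lets the benchmark/train/validate scripts import the helper from the package root instead of reaching into timm.utils.jit directly:

    # Both helpers resolve via timm/utils/__init__.py after this change.
    from timm.utils import set_jit_fuser, set_jit_legacy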
timm/utils/jit.py

@@ -2,6 +2,8 @@

 Hacked together by / Copyright 2020 Ross Wightman
 """
+import os
+
 import torch



@@ -16,3 +18,33 @@ def set_jit_legacy():
     torch._C._jit_set_profiling_mode(False)
     torch._C._jit_override_can_fuse_on_gpu(True)
     #torch._C._jit_set_texpr_fuser_enabled(True)
+
+
+def set_jit_fuser(fuser):
+    if fuser == "te":
+        # default fuser should be == 'te'
+        torch._C._jit_set_profiling_executor(True)
+        torch._C._jit_set_profiling_mode(True)
+        torch._C._jit_override_can_fuse_on_cpu(False)
+        torch._C._jit_override_can_fuse_on_gpu(True)
+        torch._C._jit_set_texpr_fuser_enabled(True)
+    elif fuser == "old" or fuser == "legacy":
+        torch._C._jit_set_profiling_executor(False)
+        torch._C._jit_set_profiling_mode(False)
+        torch._C._jit_override_can_fuse_on_gpu(True)
+        torch._C._jit_set_texpr_fuser_enabled(False)
+    elif fuser == "nvfuser" or fuser == "nvf":
+        os.environ['PYTORCH_CUDA_FUSER_DISABLE_FALLBACK'] = '1'
+        os.environ['PYTORCH_CUDA_FUSER_DISABLE_FMA'] = '1'
+        os.environ['PYTORCH_CUDA_FUSER_JIT_OPT_LEVEL'] = '0'
+        torch._C._jit_set_texpr_fuser_enabled(False)
+        torch._C._jit_set_profiling_executor(True)
+        torch._C._jit_set_profiling_mode(True)
+        torch._C._jit_can_fuse_on_cpu()
+        torch._C._jit_can_fuse_on_gpu()
+        torch._C._jit_override_can_fuse_on_cpu(False)
+        torch._C._jit_override_can_fuse_on_gpu(False)
+        torch._C._jit_set_nvfuser_guard_mode(True)
+        torch._C._jit_set_nvfuser_enabled(True)
+    else:
+        assert False, f"Invalid jit fuser ({fuser})"
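In rough terms: 'te' keeps the profiling executor and enables the TensorExpr (NNC) fuser, 'old'/'legacy' switches back to the legacy executor and legacy GPU fuser, and 'nvfuser'/'nvf' disables the other fusers and turns NVFuser on (the two _jit_can_fuse_on_*() calls only query state; their return values are discarded). A hedged usage sketch, assuming a CUDA-enabled PyTorch build since the toy module is moved to the GPU; the short warmup loop is there because the profiling executor only fuses after a few profiled runs:

    import torch
    from timm.utils import set_jit_fuser

    set_jit_fuser('nvfuser')  # 'nvf' is accepted as an alias

    net = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.GELU()).cuda().eval()
    net = torch.jit.script(net)

    x = torch.randn(4, 8, device='cuda')
    with torch.no_grad():
        for _ in range(3):  # warm up so the profiling executor can specialize and fuse
            net(x)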
train.py (5 changed lines)

@@ -295,6 +295,8 @@ parser.add_argument('--use-multi-epochs-loader', action='store_true', default=Fa
                     help='use the multi-epochs-loader to save time at the beginning of every epoch')
 parser.add_argument('--torchscript', dest='torchscript', action='store_true',
                     help='convert model torchscript for inference')
+parser.add_argument('--fuser', default='', type=str,
+                    help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
 parser.add_argument('--log-wandb', action='store_true', default=False,
                     help='log training and validation metrics to wandb')


@@ -364,6 +366,9 @@ def main():

     random_seed(args.seed, args.rank)

+    if args.fuser:
+        set_jit_fuser(args.fuser)
+
     model = create_model(
         args.model,
         pretrained=args.pretrained,
validate.py (10 changed lines)

@@ -21,7 +21,7 @@ from contextlib import suppress

 from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models
 from timm.data import create_dataset, create_loader, resolve_data_config, RealLabelsImagenet
-from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_legacy
+from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_fuser

 has_apex = False
 try:

@@ -102,8 +102,8 @@ parser.add_argument('--use-ema', dest='use_ema', action='store_true',
                     help='use ema version of weights if present')
 parser.add_argument('--torchscript', dest='torchscript', action='store_true',
                     help='convert model torchscript for inference')
-parser.add_argument('--legacy-jit', dest='legacy_jit', action='store_true',
-                    help='use legacy jit mode for pytorch 1.5/1.5.1/1.6 to get back fusion performance')
+parser.add_argument('--fuser', default='', type=str,
+                    help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
 parser.add_argument('--results-file', default='', type=str, metavar='FILENAME',
                     help='Output csv file for validation results (summary)')
 parser.add_argument('--real-labels', default='', type=str, metavar='FILENAME',

@@ -133,8 +133,8 @@ def validate(args):
     else:
         _logger.info('Validating in float32. AMP not enabled.')

-    if args.legacy_jit:
-        set_jit_legacy()
+    if args.fuser:
+        set_jit_fuser(args.fuser)

     # create model
     model = create_model(
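Note that validate.py's old --legacy-jit switch is replaced rather than kept alongside the new flag. A sketch of the closest equivalent under the new API; set_jit_fuser('old') flips roughly the same flags that set_jit_legacy() did, plus explicitly disabling the TensorExpr fuser:

    from timm.utils import set_jit_fuser, set_jit_legacy

    set_jit_fuser('old')   # what `--fuser old` now selects; roughly mirrors set_jit_legacy()
    # set_jit_legacy()     # still exported for existing callers of the old helper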