Add --fuser arg to train/validate/benchmark scripts to select jit fuser type
parent
010b486590
commit
f0f9eccda8
|
@ -21,7 +21,7 @@ from functools import partial
|
|||
from timm.models import create_model, is_model, list_models
|
||||
from timm.optim import create_optimizer_v2
|
||||
from timm.data import resolve_data_config
|
||||
from timm.utils import AverageMeter, setup_default_logging
|
||||
from timm.utils import setup_default_logging, set_jit_fuser
|
||||
|
||||
|
||||
has_apex = False
|
||||
|
@ -95,7 +95,8 @@ parser.add_argument('--precision', default='float32', type=str,
|
|||
help='Numeric precision. One of (amp, float32, float16, bfloat16, tf32)')
|
||||
parser.add_argument('--torchscript', dest='torchscript', action='store_true',
|
||||
help='convert model torchscript for inference')
|
||||
|
||||
parser.add_argument('--fuser', default='', type=str,
|
||||
help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
|
||||
|
||||
|
||||
# train optimizer parameters
|
||||
|
@ -186,7 +187,7 @@ def profile_fvcore(model, input_size=(3, 224, 224), batch_size=1, detailed=False
|
|||
class BenchmarkRunner:
|
||||
def __init__(
|
||||
self, model_name, detail=False, device='cuda', torchscript=False, precision='float32',
|
||||
num_warm_iter=10, num_bench_iter=50, use_train_size=False, **kwargs):
|
||||
fuser='', num_warm_iter=10, num_bench_iter=50, use_train_size=False, **kwargs):
|
||||
self.model_name = model_name
|
||||
self.detail = detail
|
||||
self.device = device
|
||||
|
@ -194,6 +195,8 @@ class BenchmarkRunner:
|
|||
self.channels_last = kwargs.pop('channels_last', False)
|
||||
self.amp_autocast = torch.cuda.amp.autocast if self.use_amp else suppress
|
||||
|
||||
if fuser:
|
||||
set_jit_fuser(fuser)
|
||||
self.model = create_model(
|
||||
model_name,
|
||||
num_classes=kwargs.pop('num_classes', None),
|
||||
|
|
|
@ -3,7 +3,7 @@ from .checkpoint_saver import CheckpointSaver
|
|||
from .clip_grad import dispatch_clip_grad
|
||||
from .cuda import ApexScaler, NativeScaler
|
||||
from .distributed import distribute_bn, reduce_tensor
|
||||
from .jit import set_jit_legacy
|
||||
from .jit import set_jit_legacy, set_jit_fuser
|
||||
from .log import setup_default_logging, FormatterNoInfo
|
||||
from .metrics import AverageMeter, accuracy
|
||||
from .misc import natural_key, add_bool_arg
|
||||
|
|
|
@ -2,6 +2,8 @@
|
|||
|
||||
Hacked together by / Copyright 2020 Ross Wightman
|
||||
"""
|
||||
import os
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
|
@ -16,3 +18,33 @@ def set_jit_legacy():
|
|||
torch._C._jit_set_profiling_mode(False)
|
||||
torch._C._jit_override_can_fuse_on_gpu(True)
|
||||
#torch._C._jit_set_texpr_fuser_enabled(True)
|
||||
|
||||
|
||||
def set_jit_fuser(fuser):
|
||||
if fuser == "te":
|
||||
# default fuser should be == 'te'
|
||||
torch._C._jit_set_profiling_executor(True)
|
||||
torch._C._jit_set_profiling_mode(True)
|
||||
torch._C._jit_override_can_fuse_on_cpu(False)
|
||||
torch._C._jit_override_can_fuse_on_gpu(True)
|
||||
torch._C._jit_set_texpr_fuser_enabled(True)
|
||||
elif fuser == "old" or fuser == "legacy":
|
||||
torch._C._jit_set_profiling_executor(False)
|
||||
torch._C._jit_set_profiling_mode(False)
|
||||
torch._C._jit_override_can_fuse_on_gpu(True)
|
||||
torch._C._jit_set_texpr_fuser_enabled(False)
|
||||
elif fuser == "nvfuser" or fuser == "nvf":
|
||||
os.environ['PYTORCH_CUDA_FUSER_DISABLE_FALLBACK'] = '1'
|
||||
os.environ['PYTORCH_CUDA_FUSER_DISABLE_FMA'] = '1'
|
||||
os.environ['PYTORCH_CUDA_FUSER_JIT_OPT_LEVEL'] = '0'
|
||||
torch._C._jit_set_texpr_fuser_enabled(False)
|
||||
torch._C._jit_set_profiling_executor(True)
|
||||
torch._C._jit_set_profiling_mode(True)
|
||||
torch._C._jit_can_fuse_on_cpu()
|
||||
torch._C._jit_can_fuse_on_gpu()
|
||||
torch._C._jit_override_can_fuse_on_cpu(False)
|
||||
torch._C._jit_override_can_fuse_on_gpu(False)
|
||||
torch._C._jit_set_nvfuser_guard_mode(True)
|
||||
torch._C._jit_set_nvfuser_enabled(True)
|
||||
else:
|
||||
assert False, f"Invalid jit fuser ({fuser})"
|
||||
|
|
5
train.py
5
train.py
|
@ -295,6 +295,8 @@ parser.add_argument('--use-multi-epochs-loader', action='store_true', default=Fa
|
|||
help='use the multi-epochs-loader to save time at the beginning of every epoch')
|
||||
parser.add_argument('--torchscript', dest='torchscript', action='store_true',
|
||||
help='convert model torchscript for inference')
|
||||
parser.add_argument('--fuser', default='', type=str,
|
||||
help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
|
||||
parser.add_argument('--log-wandb', action='store_true', default=False,
|
||||
help='log training and validation metrics to wandb')
|
||||
|
||||
|
@ -364,6 +366,9 @@ def main():
|
|||
|
||||
random_seed(args.seed, args.rank)
|
||||
|
||||
if args.fuser:
|
||||
set_jit_fuser(args.fuser)
|
||||
|
||||
model = create_model(
|
||||
args.model,
|
||||
pretrained=args.pretrained,
|
||||
|
|
10
validate.py
10
validate.py
|
@ -21,7 +21,7 @@ from contextlib import suppress
|
|||
|
||||
from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models
|
||||
from timm.data import create_dataset, create_loader, resolve_data_config, RealLabelsImagenet
|
||||
from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_legacy
|
||||
from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_fuser
|
||||
|
||||
has_apex = False
|
||||
try:
|
||||
|
@ -102,8 +102,8 @@ parser.add_argument('--use-ema', dest='use_ema', action='store_true',
|
|||
help='use ema version of weights if present')
|
||||
parser.add_argument('--torchscript', dest='torchscript', action='store_true',
|
||||
help='convert model torchscript for inference')
|
||||
parser.add_argument('--legacy-jit', dest='legacy_jit', action='store_true',
|
||||
help='use legacy jit mode for pytorch 1.5/1.5.1/1.6 to get back fusion performance')
|
||||
parser.add_argument('--fuser', default='', type=str,
|
||||
help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
|
||||
parser.add_argument('--results-file', default='', type=str, metavar='FILENAME',
|
||||
help='Output csv file for validation results (summary)')
|
||||
parser.add_argument('--real-labels', default='', type=str, metavar='FILENAME',
|
||||
|
@ -133,8 +133,8 @@ def validate(args):
|
|||
else:
|
||||
_logger.info('Validating in float32. AMP not enabled.')
|
||||
|
||||
if args.legacy_jit:
|
||||
set_jit_legacy()
|
||||
if args.fuser:
|
||||
set_jit_fuser(args.fuser)
|
||||
|
||||
# create model
|
||||
model = create_model(
|
||||
|
|
Loading…
Reference in New Issue