mirror of https://github.com/huggingface/pytorch-image-models.git
Merge pull request #1294 from xwang233/add-aot-autograd
Add AOT Autograd support
commit db8e33c69f
benchmark.py (20 changed lines)
@@ -51,6 +51,12 @@ except ImportError as e:
     FlopCountAnalysis = None
     has_fvcore_profiling = False
 
+try:
+    from functorch.compile import memory_efficient_fusion
+    has_functorch = True
+except ImportError as e:
+    has_functorch = False
+
 
 torch.backends.cudnn.benchmark = True
 _logger = logging.getLogger('validate')
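Both files guard the functorch import the same way: try it at module load, record the result in `has_functorch`, and only fail at the point where `--aot-autograd` is actually requested. A condensed sketch of the pattern (`maybe_fuse` is a hypothetical helper, not part of this commit):

```python
# Optional-dependency guard: detect functorch at import time instead of
# failing hard, so users who never pass --aot-autograd are unaffected.
try:
    from functorch.compile import memory_efficient_fusion
    has_functorch = True
except ImportError:
    has_functorch = False


def maybe_fuse(model, aot_autograd=False):
    # Fail only when the optional feature is actually requested.
    if aot_autograd:
        assert has_functorch, "functorch is needed for --aot-autograd"
        model = memory_efficient_fusion(model)
    return model
```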
@@ -95,10 +101,13 @@ parser.add_argument('--amp', action='store_true', default=False,
                     help='use PyTorch Native AMP for mixed precision training. Overrides --precision arg.')
 parser.add_argument('--precision', default='float32', type=str,
                     help='Numeric precision. One of (amp, float32, float16, bfloat16, tf32)')
-parser.add_argument('--torchscript', dest='torchscript', action='store_true',
-                    help='convert model torchscript for inference')
 parser.add_argument('--fuser', default='', type=str,
                     help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
+scripting_group = parser.add_mutually_exclusive_group()
+scripting_group.add_argument('--torchscript', dest='torchscript', action='store_true',
+                    help='convert model torchscript for inference')
+scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
+                    help="Enable AOT Autograd support. (It's recommended to use this option with `--fuser nvfuser` together)")
 
 
 # train optimizer parameters
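`--torchscript` and `--aot-autograd` are placed in a mutually exclusive group, so argparse itself rejects a command line that passes both. A minimal sketch of that behavior (flag names from the diff; the rest is illustrative):

```python
import argparse

parser = argparse.ArgumentParser()
scripting_group = parser.add_mutually_exclusive_group()
scripting_group.add_argument('--torchscript', action='store_true',
                             help='convert model torchscript for inference')
scripting_group.add_argument('--aot-autograd', action='store_true',
                             help='Enable AOT Autograd support')

# Either flag alone parses fine; argparse maps --aot-autograd to args.aot_autograd.
args = parser.parse_args(['--aot-autograd'])
print(args.aot_autograd, args.torchscript)  # True False

# Passing both makes parse_args() exit with:
#   error: argument --aot-autograd: not allowed with argument --torchscript
```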
@@ -188,7 +197,7 @@ def profile_fvcore(model, input_size=(3, 224, 224), batch_size=1, detailed=False
 
 class BenchmarkRunner:
     def __init__(
-            self, model_name, detail=False, device='cuda', torchscript=False, precision='float32',
+            self, model_name, detail=False, device='cuda', torchscript=False, aot_autograd=False, precision='float32',
             fuser='', num_warm_iter=10, num_bench_iter=50, use_train_size=False, **kwargs):
         self.model_name = model_name
         self.detail = detail
@@ -220,11 +229,14 @@ class BenchmarkRunner:
         if torchscript:
             self.model = torch.jit.script(self.model)
             self.scripted = True
 
         data_config = resolve_data_config(kwargs, model=self.model, use_test_size=not use_train_size)
         self.input_size = data_config['input_size']
         self.batch_size = kwargs.pop('batch_size', 256)
 
+        if aot_autograd:
+            assert has_functorch, "functorch is needed for --aot-autograd"
+            self.model = memory_efficient_fusion(self.model)
+
         self.example_inputs = None
         self.num_warm_iter = num_warm_iter
         self.num_bench_iter = num_bench_iter
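Inside `BenchmarkRunner`, the fused model is a drop-in replacement for the original module: `memory_efficient_fusion` uses AOT Autograd to trace the forward (and backward) graphs ahead of time and hands them to the active JIT fuser, which is why the help text recommends pairing it with `--fuser nvfuser`. A minimal standalone sketch, assuming a CUDA machine with functorch installed (the toy model is illustrative, not a timm model):

```python
import torch
import torch.nn as nn
from functorch.compile import memory_efficient_fusion

# Stand-in for a timm model; any nn.Module with a single-tensor forward works.
model = nn.Sequential(nn.Linear(128, 256), nn.GELU(), nn.Linear(256, 128)).cuda()

# Returns a module with the same call signature, compiled via AOT Autograd.
fused = memory_efficient_fusion(model)

x = torch.randn(32, 128, device='cuda')
out = fused(x)  # same result as model(x), modulo fusion numerics
print(out.shape)  # torch.Size([32, 128])
```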
train.py (15 changed lines)
@@ -61,6 +61,13 @@ try:
 except ImportError:
     has_wandb = False
 
+try:
+    from functorch.compile import memory_efficient_fusion
+    has_functorch = True
+except ImportError as e:
+    has_functorch = False
+
+
 torch.backends.cudnn.benchmark = True
 _logger = logging.getLogger('train')
 
@@ -123,8 +130,11 @@ group.add_argument('-vb', '--validation-batch-size', type=int, default=None, met
                    help='Validation batch size override (default: None)')
 group.add_argument('--channels-last', action='store_true', default=False,
                    help='Use channels_last memory layout')
-group.add_argument('--torchscript', dest='torchscript', action='store_true',
+scripting_group = group.add_mutually_exclusive_group()
+scripting_group.add_argument('--torchscript', dest='torchscript', action='store_true',
                    help='torch.jit.script the full model')
+scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
+                    help="Enable AOT Autograd support. (It's recommended to use this option with `--fuser nvfuser` together)")
 group.add_argument('--fuser', default='', type=str,
                    help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
 group.add_argument('--grad-checkpointing', action='store_true', default=False,
@@ -445,6 +455,9 @@ def main():
         assert not use_amp == 'apex', 'Cannot use APEX AMP with torchscripted model'
         assert not args.sync_bn, 'Cannot use SyncBatchNorm with torchscripted model'
         model = torch.jit.script(model)
+    if args.aot_autograd:
+        assert has_functorch, "functorch is needed for --aot-autograd"
+        model = memory_efficient_fusion(model)
 
     optimizer = create_optimizer_v2(model, **optimizer_kwargs(cfg=args))
 
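On the training side the wrapped model flows straight into `create_optimizer_v2` and the usual train loop, since the fused module still exposes its parameters and AOT Autograd has already compiled the backward graph. A rough sketch of that flow under the same assumptions (plain SGD stands in for timm's optimizer factory):

```python
import torch
import torch.nn as nn
from functorch.compile import memory_efficient_fusion

model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 10)).cuda()
model = memory_efficient_fusion(model)

# Parameters are still reachable through the wrapper, so optimizer setup is unchanged.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

x = torch.randn(8, 64, device='cuda')
target = torch.randint(0, 10, (8,), device='cuda')

loss = nn.functional.cross_entropy(model(x), target)
loss.backward()        # runs the ahead-of-time compiled backward
optimizer.step()
optimizer.zero_grad()
```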