mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add --torchcompile-mode args to train, validation, inference, benchmark scripts
This commit is contained in:
parent
14d55a7cf3
commit
4d4bdd64a9
@ -120,6 +120,8 @@ parser.add_argument('--fast-norm', default=False, action='store_true',
|
||||
parser.add_argument('--reparam', default=False, action='store_true',
|
||||
help='Reparameterize model')
|
||||
parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs)
|
||||
parser.add_argument('--torchcompile-mode', type=str, default=None,
|
||||
help="torch.compile mode (default: None).")
|
||||
|
||||
# codegen (model compilation) options
|
||||
scripting_group = parser.add_mutually_exclusive_group()
|
||||
@ -224,6 +226,7 @@ class BenchmarkRunner:
|
||||
device='cuda',
|
||||
torchscript=False,
|
||||
torchcompile=None,
|
||||
torchcompile_mode=None,
|
||||
aot_autograd=False,
|
||||
reparam=False,
|
||||
precision='float32',
|
||||
@ -278,7 +281,7 @@ class BenchmarkRunner:
|
||||
elif torchcompile:
|
||||
assert has_compile, 'A version of torch w/ torch.compile() is required, possibly a nightly.'
|
||||
torch._dynamo.reset()
|
||||
self.model = torch.compile(self.model, backend=torchcompile)
|
||||
self.model = torch.compile(self.model, backend=torchcompile, mode=torchcompile_mode)
|
||||
self.compiled = True
|
||||
elif aot_autograd:
|
||||
assert has_functorch, "functorch is needed for --aot-autograd"
|
||||
|
@ -114,6 +114,8 @@ parser.add_argument('--amp-dtype', default='float16', type=str,
|
||||
parser.add_argument('--fuser', default='', type=str,
|
||||
help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
|
||||
parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs)
|
||||
parser.add_argument('--torchcompile-mode', type=str, default=None,
|
||||
help="torch.compile mode (default: None).")
|
||||
|
||||
scripting_group = parser.add_mutually_exclusive_group()
|
||||
scripting_group.add_argument('--torchscript', default=False, action='store_true',
|
||||
@ -216,7 +218,7 @@ def main():
|
||||
elif args.torchcompile:
|
||||
assert has_compile, 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.'
|
||||
torch._dynamo.reset()
|
||||
model = torch.compile(model, backend=args.torchcompile)
|
||||
model = torch.compile(model, backend=args.torchcompile, mode=args.torchcompile_mode)
|
||||
elif args.aot_autograd:
|
||||
assert has_functorch, "functorch is needed for --aot-autograd"
|
||||
model = memory_efficient_fusion(model)
|
||||
|
4
train.py
4
train.py
@ -161,6 +161,8 @@ group.add_argument('--head-init-scale', default=None, type=float,
|
||||
help='Head initialization scale')
|
||||
group.add_argument('--head-init-bias', default=None, type=float,
|
||||
help='Head initialization bias value')
|
||||
group.add_argument('--torchcompile-mode', type=str, default=None,
|
||||
help="torch.compile mode (default: None).")
|
||||
|
||||
# scripting / codegen
|
||||
scripting_group = group.add_mutually_exclusive_group()
|
||||
@ -627,7 +629,7 @@ def main():
|
||||
if args.torchcompile:
|
||||
# torch compile should be done after DDP
|
||||
assert has_compile, 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.'
|
||||
model = torch.compile(model, backend=args.torchcompile)
|
||||
model = torch.compile(model, backend=args.torchcompile, mode=args.torchcompile_mode)
|
||||
|
||||
# create the train and eval datasets
|
||||
if args.data and not args.data_dir:
|
||||
|
@ -139,7 +139,8 @@ parser.add_argument('--fast-norm', default=False, action='store_true',
|
||||
parser.add_argument('--reparam', default=False, action='store_true',
|
||||
help='Reparameterize model')
|
||||
parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs)
|
||||
|
||||
parser.add_argument('--torchcompile-mode', type=str, default=None,
|
||||
help="torch.compile mode (default: None).")
|
||||
|
||||
scripting_group = parser.add_mutually_exclusive_group()
|
||||
scripting_group.add_argument('--torchscript', default=False, action='store_true',
|
||||
@ -246,7 +247,7 @@ def validate(args):
|
||||
elif args.torchcompile:
|
||||
assert has_compile, 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.'
|
||||
torch._dynamo.reset()
|
||||
model = torch.compile(model, backend=args.torchcompile)
|
||||
model = torch.compile(model, backend=args.torchcompile, mode=args.torchcompile_mode)
|
||||
elif args.aot_autograd:
|
||||
assert has_functorch, "functorch is needed for --aot-autograd"
|
||||
model = memory_efficient_fusion(model)
|
||||
|
Loading…
x
Reference in New Issue
Block a user