Mirror of https://github.com/huggingface/pytorch-image-models.git, synced 2025-06-03 15:01:08 +08:00
Add --torchcompile-mode args to train, validation, inference, benchmark scripts

commit 4d4bdd64a9 (parent 14d55a7cf3)
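The new flag exposes torch.compile's `mode` preset; current PyTorch accepts 'default', 'reduce-overhead', and 'max-autotune'. Illustrative invocations (other flags elided; assumes --torchcompile falls back to its default backend when given no value):

    python train.py --torchcompile --torchcompile-mode max-autotune ...
    python validate.py --torchcompile --torchcompile-mode reduce-overhead ...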
benchmark.py

@@ -120,6 +120,8 @@ parser.add_argument('--fast-norm', default=False, action='store_true',
 parser.add_argument('--reparam', default=False, action='store_true',
                     help='Reparameterize model')
 parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs)
+parser.add_argument('--torchcompile-mode', type=str, default=None,
+                    help="torch.compile mode (default: None).")
 
 # codegen (model compilation) options
 scripting_group = parser.add_mutually_exclusive_group()
@@ -224,6 +226,7 @@ class BenchmarkRunner:
             device='cuda',
             torchscript=False,
             torchcompile=None,
+            torchcompile_mode=None,
             aot_autograd=False,
             reparam=False,
             precision='float32',
@@ -278,7 +281,7 @@ class BenchmarkRunner:
         elif torchcompile:
             assert has_compile, 'A version of torch w/ torch.compile() is required, possibly a nightly.'
             torch._dynamo.reset()
-            self.model = torch.compile(self.model, backend=torchcompile)
+            self.model = torch.compile(self.model, backend=torchcompile, mode=torchcompile_mode)
             self.compiled = True
         elif aot_autograd:
             assert has_functorch, "functorch is needed for --aot-autograd"
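For context, a minimal sketch (not part of the commit) of what the plumbed-through mode argument does at the torch.compile call; mode=None, the new flag's default, preserves the previous behaviour:

    import torch

    model = torch.nn.Linear(8, 8)

    # mode selects a compilation preset; PyTorch currently accepts
    # 'default', 'reduce-overhead', and 'max-autotune'. mode=None is
    # equivalent to omitting --torchcompile-mode entirely.
    compiled = torch.compile(model, backend='inductor', mode=None)
    print(compiled(torch.randn(2, 8)).shape)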
inference.py

@@ -114,6 +114,8 @@ parser.add_argument('--amp-dtype', default='float16', type=str,
 parser.add_argument('--fuser', default='', type=str,
                     help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
 parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs)
+parser.add_argument('--torchcompile-mode', type=str, default=None,
+                    help="torch.compile mode (default: None).")
 
 scripting_group = parser.add_mutually_exclusive_group()
 scripting_group.add_argument('--torchscript', default=False, action='store_true',
@@ -216,7 +218,7 @@ def main():
     elif args.torchcompile:
         assert has_compile, 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.'
         torch._dynamo.reset()
-        model = torch.compile(model, backend=args.torchcompile)
+        model = torch.compile(model, backend=args.torchcompile, mode=args.torchcompile_mode)
     elif args.aot_autograd:
         assert has_functorch, "functorch is needed for --aot-autograd"
         model = memory_efficient_fusion(model)
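A note on naming: argparse converts the hyphenated flag into an underscored attribute, which is why the parsers add '--torchcompile-mode' while the call sites read args.torchcompile_mode. A self-contained sketch of that mapping:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--torchcompile-mode', type=str, default=None,
                        help="torch.compile mode (default: None).")

    # argparse maps '--torchcompile-mode' to the attribute 'torchcompile_mode'
    args = parser.parse_args(['--torchcompile-mode', 'reduce-overhead'])
    print(args.torchcompile_mode)  # -> reduce-overhead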
train.py (4 changes)

@@ -161,6 +161,8 @@ group.add_argument('--head-init-scale', default=None, type=float,
                    help='Head initialization scale')
 group.add_argument('--head-init-bias', default=None, type=float,
                    help='Head initialization bias value')
+group.add_argument('--torchcompile-mode', type=str, default=None,
+                   help="torch.compile mode (default: None).")
 
 # scripting / codegen
 scripting_group = group.add_mutually_exclusive_group()
@@ -627,7 +629,7 @@ def main():
     if args.torchcompile:
         # torch compile should be done after DDP
         assert has_compile, 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.'
-        model = torch.compile(model, backend=args.torchcompile)
+        model = torch.compile(model, backend=args.torchcompile, mode=args.torchcompile_mode)
 
     # create the train and eval datasets
     if args.data and not args.data_dir:
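The train.py hunk keeps the existing comment that compilation should happen after the DDP wrap. A single-process sketch of that ordering (assumptions: gloo backend, world size 1, CPU; not from the commit):

    import os
    import torch
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as DDP

    # minimal single-process process group so DDP can be constructed
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29500')
    dist.init_process_group('gloo', rank=0, world_size=1)

    model = torch.nn.Linear(4, 4)
    model = DDP(model)                            # wrap in DDP first...
    model = torch.compile(model, mode='default')  # ...then compile
    print(model(torch.randn(2, 4)).shape)
    dist.destroy_process_group()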
validate.py

@@ -139,7 +139,8 @@ parser.add_argument('--fast-norm', default=False, action='store_true',
 parser.add_argument('--reparam', default=False, action='store_true',
                     help='Reparameterize model')
 parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs)
+parser.add_argument('--torchcompile-mode', type=str, default=None,
+                    help="torch.compile mode (default: None).")
 
 scripting_group = parser.add_mutually_exclusive_group()
 scripting_group.add_argument('--torchscript', default=False, action='store_true',
@@ -246,7 +247,7 @@ def validate(args):
     elif args.torchcompile:
         assert has_compile, 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.'
         torch._dynamo.reset()
-        model = torch.compile(model, backend=args.torchcompile)
+        model = torch.compile(model, backend=args.torchcompile, mode=args.torchcompile_mode)
     elif args.aot_autograd:
         assert has_functorch, "functorch is needed for --aot-autograd"
         model = memory_efficient_fusion(model)
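All four scripts gate the compile path on the same has_compile assert. A sketch of that guard, assuming it is a plain feature check (timm's actual definition may differ):

    import torch

    # feature-detect torch.compile so older torch versions fail early and clearly
    has_compile = hasattr(torch, 'compile')
    assert has_compile, 'A version of torch w/ torch.compile() is required for --compile, possibly a nightly.'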