jit trace comparisons snuck into torchscript part of validate.py, fixed

pull/1420/head
Ross Wightman 2022-07-31 21:13:56 -07:00
parent 8ad4bdfa06
commit 56596e4e84
1 changed files with 1 additions and 1 deletions

View File

@ -181,7 +181,7 @@ def validate(args):
if args.torchscript: if args.torchscript:
torch.jit.optimized_execution(True) torch.jit.optimized_execution(True)
model = torch.jit.trace(model, example_inputs=torch.randn((args.batch_size,) + data_config['input_size'])) model = torch.jit.script(model)
if args.aot_autograd: if args.aot_autograd:
assert has_functorch, "functorch is needed for --aot-autograd" assert has_functorch, "functorch is needed for --aot-autograd"
model = memory_efficient_fusion(model) model = memory_efficient_fusion(model)