diff --git a/train.py b/train.py index fda3c184..11b783e8 100755 --- a/train.py +++ b/train.py @@ -892,6 +892,7 @@ def main(): optimizer, train_loss_fn, args, + device=device, lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,