diff --git a/main.py b/main.py index 2d93567..5ff017b 100644 --- a/main.py +++ b/main.py @@ -303,7 +303,7 @@ def main(args): linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0 args.lr = linear_scaled_lr - optimizer = create_optimizer(args, model) + optimizer = create_optimizer(args, model_without_ddp) loss_scaler = NativeScaler() lr_scheduler, _ = create_scheduler(args, optimizer)