diff --git a/train.py b/train.py
index c55cbfb3..a47f1b4d 100644
--- a/train.py
+++ b/train.py
@@ -393,7 +393,7 @@ def main():
                 lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
                 use_amp=use_amp, model_ema=model_ema)
 
-            if args.distributed and args.dist_bn and args.dist_bn in ('broadcast', 'reduce'):
+            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                 if args.local_rank == 0:
                     logging.info("Distributing BatchNorm running means and vars")
                 distribute_bn(model, args.world_size, args.dist_bn == 'reduce')
@@ -401,8 +401,8 @@ def main():
             eval_metrics = validate(model, loader_eval, validate_loss_fn, args)
 
             if model_ema is not None and not args.model_ema_force_cpu:
-                if args.distributed and args.reduce_bn:
-                    distribute_bn(model_ema, args.world_size)
+                if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
+                    distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
                 ema_eval_metrics = validate(
                     model_ema.ema, loader_eval, validate_loss_fn, args, log_suffix=' (EMA)')
                 eval_metrics = ema_eval_metrics