diff --git a/mmcv/ops/sync_bn.py b/mmcv/ops/sync_bn.py index e8bdd55c7..941910000 100644 --- a/mmcv/ops/sync_bn.py +++ b/mmcv/ops/sync_bn.py @@ -126,7 +126,8 @@ class SyncBatchNorm(Module): self.momentum = momentum self.affine = affine self.track_running_stats = track_running_stats - self.group = dist.group.WORLD if group is None else group + group = dist.group.WORLD if group is None else group + self.group = group self.group_size = dist.get_world_size(group) if self.affine: self.weight = Parameter(torch.Tensor(num_features))