diff --git a/mmcv/ops/sync_bn.py b/mmcv/ops/sync_bn.py index 5c476cadc..ca709b795 100644 --- a/mmcv/ops/sync_bn.py +++ b/mmcv/ops/sync_bn.py @@ -109,88 +109,99 @@ class SyncBatchNormFunction(Function): None, None, None, None -class SyncBatchNorm(Module): +if dist.is_available(): - def __init__(self, - num_features, - eps=1e-5, - momentum=0.1, - affine=True, - track_running_stats=True, - group=dist.group.WORLD): - super(SyncBatchNorm, self).__init__() - self.num_features = num_features - self.eps = eps - self.momentum = momentum - self.affine = affine - self.track_running_stats = track_running_stats - self.group = group - self.group_size = dist.get_world_size(group) - if self.affine: - self.weight = Parameter(torch.Tensor(num_features)) - self.bias = Parameter(torch.Tensor(num_features)) - else: - self.register_parameter('weight', None) - self.register_parameter('bias', None) - if self.track_running_stats: - self.register_buffer('running_mean', torch.zeros(num_features)) - self.register_buffer('running_var', torch.ones(num_features)) - self.register_buffer('num_batches_tracked', - torch.tensor(0, dtype=torch.long)) - else: - self.register_buffer('running_mean', None) - self.register_buffer('running_var', None) - self.register_buffer('num_batches_tracked', None) - self.reset_parameters() + class SyncBatchNorm(Module): - def reset_running_stats(self): - if self.track_running_stats: - self.running_mean.zero_() - self.running_var.fill_(1) - self.num_batches_tracked.zero_() + def __init__(self, + num_features, + eps=1e-5, + momentum=0.1, + affine=True, + track_running_stats=True, + group=dist.group.WORLD): + super(SyncBatchNorm, self).__init__() + self.num_features = num_features + self.eps = eps + self.momentum = momentum + self.affine = affine + self.track_running_stats = track_running_stats + self.group = group + self.group_size = dist.get_world_size(group) + if self.affine: + self.weight = Parameter(torch.Tensor(num_features)) + self.bias = Parameter(torch.Tensor(num_features)) + else: + self.register_parameter('weight', None) + self.register_parameter('bias', None) + if self.track_running_stats: + self.register_buffer('running_mean', torch.zeros(num_features)) + self.register_buffer('running_var', torch.ones(num_features)) + self.register_buffer('num_batches_tracked', + torch.tensor(0, dtype=torch.long)) + else: + self.register_buffer('running_mean', None) + self.register_buffer('running_var', None) + self.register_buffer('num_batches_tracked', None) + self.reset_parameters() - def reset_parameters(self): - self.reset_running_stats() - if self.affine: - self.weight.data.uniform_() # pytorch use ones_() - self.bias.data.zero_() + def reset_running_stats(self): + if self.track_running_stats: + self.running_mean.zero_() + self.running_var.fill_(1) + self.num_batches_tracked.zero_() - def forward(self, input): - if input.dim() < 2: - raise ValueError( - f'expected at least 2D input, got {input.dim()}D input') - if self.momentum is None: - exponential_average_factor = 0.0 - else: - exponential_average_factor = self.momentum + def reset_parameters(self): + self.reset_running_stats() + if self.affine: + self.weight.data.uniform_() # pytorch use ones_() + self.bias.data.zero_() - if self.training and self.track_running_stats: - if self.num_batches_tracked is not None: - self.num_batches_tracked += 1 - if self.momentum is None: # use cumulative moving average - exponential_average_factor = 1.0 / float( - self.num_batches_tracked) - else: # use exponential moving average - exponential_average_factor = self.momentum + def forward(self, input): + if input.dim() < 2: + raise ValueError( + f'expected at least 2D input, got {input.dim()}D input') + if self.momentum is None: + exponential_average_factor = 0.0 + else: + exponential_average_factor = self.momentum - if self.training or not self.track_running_stats: - return SyncBatchNormFunction.apply(input, self.running_mean, - self.running_var, self.weight, - self.bias, - exponential_average_factor, - self.eps, self.group, - self.group_size) - else: - return F.batch_norm(input, self.running_mean, self.running_var, - self.weight, self.bias, False, - exponential_average_factor, self.eps) + if self.training and self.track_running_stats: + if self.num_batches_tracked is not None: + self.num_batches_tracked += 1 + if self.momentum is None: # use cumulative moving average + exponential_average_factor = 1.0 / float( + self.num_batches_tracked) + else: # use exponential moving average + exponential_average_factor = self.momentum - def __repr__(self): - s = self.__class__.__name__ - s += f'({self.num_features}, ' - s += f'eps={self.eps}, ' - s += f'momentum={self.momentum}, ' - s += f'affine={self.affine}, ' - s += f'track_running_stats={self.track_running_stats}, ' - s += f'group_size={self.group_size})' - return s + if self.training or not self.track_running_stats: + return SyncBatchNormFunction.apply(input, self.running_mean, + self.running_var, + self.weight, self.bias, + exponential_average_factor, + self.eps, self.group, + self.group_size) + else: + return F.batch_norm(input, self.running_mean, self.running_var, + self.weight, self.bias, False, + exponential_average_factor, self.eps) + + def __repr__(self): + s = self.__class__.__name__ + s += f'({self.num_features}, ' + s += f'eps={self.eps}, ' + s += f'momentum={self.momentum}, ' + s += f'affine={self.affine}, ' + s += f'track_running_stats={self.track_running_stats}, ' + s += f'group_size={self.group_size})' + return s + +else: + + class SyncBatchNorm(Module): + + def __init__(self, *args, **kwargs): + raise NotImplementedError( + 'SyncBatchNorm is not supported in this OS since the ' + 'distributed package is not available')