mirror of https://github.com/open-mmlab/mmcv.git
parent
b2b42cbd95
commit
b11c56603f
|
@ -109,88 +109,99 @@ class SyncBatchNormFunction(Function):
|
|||
None, None, None, None
|
||||
|
||||
|
||||
class SyncBatchNorm(Module):
|
||||
if dist.is_available():
|
||||
|
||||
def __init__(self,
|
||||
num_features,
|
||||
eps=1e-5,
|
||||
momentum=0.1,
|
||||
affine=True,
|
||||
track_running_stats=True,
|
||||
group=dist.group.WORLD):
|
||||
super(SyncBatchNorm, self).__init__()
|
||||
self.num_features = num_features
|
||||
self.eps = eps
|
||||
self.momentum = momentum
|
||||
self.affine = affine
|
||||
self.track_running_stats = track_running_stats
|
||||
self.group = group
|
||||
self.group_size = dist.get_world_size(group)
|
||||
if self.affine:
|
||||
self.weight = Parameter(torch.Tensor(num_features))
|
||||
self.bias = Parameter(torch.Tensor(num_features))
|
||||
else:
|
||||
self.register_parameter('weight', None)
|
||||
self.register_parameter('bias', None)
|
||||
if self.track_running_stats:
|
||||
self.register_buffer('running_mean', torch.zeros(num_features))
|
||||
self.register_buffer('running_var', torch.ones(num_features))
|
||||
self.register_buffer('num_batches_tracked',
|
||||
torch.tensor(0, dtype=torch.long))
|
||||
else:
|
||||
self.register_buffer('running_mean', None)
|
||||
self.register_buffer('running_var', None)
|
||||
self.register_buffer('num_batches_tracked', None)
|
||||
self.reset_parameters()
|
||||
class SyncBatchNorm(Module):
|
||||
|
||||
def reset_running_stats(self):
|
||||
if self.track_running_stats:
|
||||
self.running_mean.zero_()
|
||||
self.running_var.fill_(1)
|
||||
self.num_batches_tracked.zero_()
|
||||
def __init__(self,
|
||||
num_features,
|
||||
eps=1e-5,
|
||||
momentum=0.1,
|
||||
affine=True,
|
||||
track_running_stats=True,
|
||||
group=dist.group.WORLD):
|
||||
super(SyncBatchNorm, self).__init__()
|
||||
self.num_features = num_features
|
||||
self.eps = eps
|
||||
self.momentum = momentum
|
||||
self.affine = affine
|
||||
self.track_running_stats = track_running_stats
|
||||
self.group = group
|
||||
self.group_size = dist.get_world_size(group)
|
||||
if self.affine:
|
||||
self.weight = Parameter(torch.Tensor(num_features))
|
||||
self.bias = Parameter(torch.Tensor(num_features))
|
||||
else:
|
||||
self.register_parameter('weight', None)
|
||||
self.register_parameter('bias', None)
|
||||
if self.track_running_stats:
|
||||
self.register_buffer('running_mean', torch.zeros(num_features))
|
||||
self.register_buffer('running_var', torch.ones(num_features))
|
||||
self.register_buffer('num_batches_tracked',
|
||||
torch.tensor(0, dtype=torch.long))
|
||||
else:
|
||||
self.register_buffer('running_mean', None)
|
||||
self.register_buffer('running_var', None)
|
||||
self.register_buffer('num_batches_tracked', None)
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
self.reset_running_stats()
|
||||
if self.affine:
|
||||
self.weight.data.uniform_() # pytorch use ones_()
|
||||
self.bias.data.zero_()
|
||||
def reset_running_stats(self):
|
||||
if self.track_running_stats:
|
||||
self.running_mean.zero_()
|
||||
self.running_var.fill_(1)
|
||||
self.num_batches_tracked.zero_()
|
||||
|
||||
def forward(self, input):
|
||||
if input.dim() < 2:
|
||||
raise ValueError(
|
||||
f'expected at least 2D input, got {input.dim()}D input')
|
||||
if self.momentum is None:
|
||||
exponential_average_factor = 0.0
|
||||
else:
|
||||
exponential_average_factor = self.momentum
|
||||
def reset_parameters(self):
|
||||
self.reset_running_stats()
|
||||
if self.affine:
|
||||
self.weight.data.uniform_() # pytorch use ones_()
|
||||
self.bias.data.zero_()
|
||||
|
||||
if self.training and self.track_running_stats:
|
||||
if self.num_batches_tracked is not None:
|
||||
self.num_batches_tracked += 1
|
||||
if self.momentum is None: # use cumulative moving average
|
||||
exponential_average_factor = 1.0 / float(
|
||||
self.num_batches_tracked)
|
||||
else: # use exponential moving average
|
||||
exponential_average_factor = self.momentum
|
||||
def forward(self, input):
|
||||
if input.dim() < 2:
|
||||
raise ValueError(
|
||||
f'expected at least 2D input, got {input.dim()}D input')
|
||||
if self.momentum is None:
|
||||
exponential_average_factor = 0.0
|
||||
else:
|
||||
exponential_average_factor = self.momentum
|
||||
|
||||
if self.training or not self.track_running_stats:
|
||||
return SyncBatchNormFunction.apply(input, self.running_mean,
|
||||
self.running_var, self.weight,
|
||||
self.bias,
|
||||
exponential_average_factor,
|
||||
self.eps, self.group,
|
||||
self.group_size)
|
||||
else:
|
||||
return F.batch_norm(input, self.running_mean, self.running_var,
|
||||
self.weight, self.bias, False,
|
||||
exponential_average_factor, self.eps)
|
||||
if self.training and self.track_running_stats:
|
||||
if self.num_batches_tracked is not None:
|
||||
self.num_batches_tracked += 1
|
||||
if self.momentum is None: # use cumulative moving average
|
||||
exponential_average_factor = 1.0 / float(
|
||||
self.num_batches_tracked)
|
||||
else: # use exponential moving average
|
||||
exponential_average_factor = self.momentum
|
||||
|
||||
def __repr__(self):
|
||||
s = self.__class__.__name__
|
||||
s += f'({self.num_features}, '
|
||||
s += f'eps={self.eps}, '
|
||||
s += f'momentum={self.momentum}, '
|
||||
s += f'affine={self.affine}, '
|
||||
s += f'track_running_stats={self.track_running_stats}, '
|
||||
s += f'group_size={self.group_size})'
|
||||
return s
|
||||
if self.training or not self.track_running_stats:
|
||||
return SyncBatchNormFunction.apply(input, self.running_mean,
|
||||
self.running_var,
|
||||
self.weight, self.bias,
|
||||
exponential_average_factor,
|
||||
self.eps, self.group,
|
||||
self.group_size)
|
||||
else:
|
||||
return F.batch_norm(input, self.running_mean, self.running_var,
|
||||
self.weight, self.bias, False,
|
||||
exponential_average_factor, self.eps)
|
||||
|
||||
def __repr__(self):
|
||||
s = self.__class__.__name__
|
||||
s += f'({self.num_features}, '
|
||||
s += f'eps={self.eps}, '
|
||||
s += f'momentum={self.momentum}, '
|
||||
s += f'affine={self.affine}, '
|
||||
s += f'track_running_stats={self.track_running_stats}, '
|
||||
s += f'group_size={self.group_size})'
|
||||
return s
|
||||
|
||||
else:
|
||||
|
||||
class SyncBatchNorm(Module):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
raise NotImplementedError(
|
||||
'SyncBatchNorm is not supported in this OS since the '
|
||||
'distributed package is not available')
|
||||
|
|
Loading…
Reference in New Issue