fast-reid/fastreid/utils/precision_bn.py

# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import itertools
import torch

BN_MODULE_TYPES = (
    torch.nn.BatchNorm1d,
    torch.nn.BatchNorm2d,
    torch.nn.BatchNorm3d,
    torch.nn.SyncBatchNorm,
)


@torch.no_grad()
def update_bn_stats(model, data_loader, num_iters: int = 200):
"""
Recompute and update the batch norm stats to make them more precise. During
training both BN stats and the weight are changing after every iteration, so
the running average can not precisely reflect the actual stats of the
current model.
In this function, the BN stats are recomputed with fixed weights, to make
the running average more precise. Specifically, it computes the true average
of per-batch mean/variance instead of the running average.
Args:
model (nn.Module): the model whose bn stats will be recomputed.
Note that:
1. This function will not alter the training mode of the given model.
Users are responsible for setting the layers that needs
precise-BN to training mode, prior to calling this function.
2. Be careful if your models contain other stateful layers in
addition to BN, i.e. layers whose state can change in forward
iterations. This function will alter their state. If you wish
them unchanged, you need to either pass in a submodule without
those layers, or backup the states.
data_loader (iterator): an iterator. Produce data as inputs to the model.
num_iters (int): number of iterations to compute the stats.
"""
    bn_layers = get_bn_modules(model)
    if len(bn_layers) == 0:
        return

    # In order to make the running stats only reflect the current batch, the
    # momentum is disabled:
    #   bn.running_mean = (1 - momentum) * bn.running_mean + momentum * batch_mean
    # With momentum = 1.0, running_mean/running_var hold exactly the stats of
    # the last processed batch.
    momentum_actual = [bn.momentum for bn in bn_layers]
    for bn in bn_layers:
        bn.momentum = 1.0

    # Note that running_var actually means "running average of variance"
    running_mean = [torch.zeros_like(bn.running_mean) for bn in bn_layers]
    running_var = [torch.zeros_like(bn.running_var) for bn in bn_layers]

    for ind, inputs in enumerate(itertools.islice(data_loader, num_iters)):
        inputs['targets'].fill_(-1)
        with torch.no_grad():  # No need to backward
            model(inputs)
        for i, bn in enumerate(bn_layers):
            # Accumulates the bn stats.
            running_mean[i] += (bn.running_mean - running_mean[i]) / (ind + 1)
            running_var[i] += (bn.running_var - running_var[i]) / (ind + 1)
            # We compute the "average of variance" across iterations.

    assert ind == num_iters - 1, (
        "update_bn_stats is meant to run for {} iterations, "
        "but the dataloader stops at {} iterations.".format(num_iters, ind)
    )

    for i, bn in enumerate(bn_layers):
        # Sets the precise bn stats.
        bn.running_mean = running_mean[i]
        bn.running_var = running_var[i]
        bn.momentum = momentum_actual[i]
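

# A minimal usage sketch (commented out, not part of the library API): the names
# `model` and `train_loader` below are placeholders for a network whose forward
# accepts the dict-style batches produced by the loader (carrying a 'targets'
# key, as `update_bn_stats` assumes above).
#
#   model.train()                                         # precise-BN needs BN layers in training mode
#   update_bn_stats(model, train_loader, num_iters=200)
#   model.eval()                                          # evaluate with the recomputed BN stats

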
def get_bn_modules(model):
"""
Find all BatchNorm (BN) modules that are in training mode. See
fvcore.precise_bn.BN_MODULE_TYPES for a list of all modules that are
included in this search.
Args:
model (nn.Module): a model possibly containing BN modules.
Returns:
list[nn.Module]: all BN modules in the model.
"""
    # Finds all the bn layers.
    bn_layers = [
        m for m in model.modules() if m.training and isinstance(m, BN_MODULE_TYPES)
    ]
    return bn_layers
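

# A self-contained sanity-check sketch, guarded so it only runs when this file is
# executed directly. `_ToyModel` and the list-based loader are hypothetical
# stand-ins (not fastreid components); they only mirror the dict-style inputs
# that `update_bn_stats` expects, i.e. batches carrying a 'targets' key.
if __name__ == "__main__":
    import torch.nn as nn

    class _ToyModel(nn.Module):
        # Hypothetical wrapper consuming {'images': ..., 'targets': ...} batches.
        def __init__(self):
            super().__init__()
            self.net = nn.Sequential(
                nn.Conv2d(3, 8, 3, padding=1),
                nn.BatchNorm2d(8),
                nn.ReLU(),
            )

        def forward(self, inputs):
            return self.net(inputs['images'])

    model = _ToyModel().train()
    print("BN modules in training mode:", len(get_bn_modules(model)))

    # A plain list of dict batches is enough, since update_bn_stats only
    # needs an iterable of inputs.
    loader = [
        {'images': torch.randn(4, 3, 16, 16), 'targets': torch.zeros(4, dtype=torch.long)}
        for _ in range(10)
    ]
    update_bn_stats(model, loader, num_iters=10)
    model.eval()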