# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""

import torch

from data.prefetcher import data_prefetcher

BN_MODULE_TYPES = (
    torch.nn.BatchNorm1d,
    torch.nn.BatchNorm2d,
    torch.nn.BatchNorm3d,
    torch.nn.SyncBatchNorm,
)
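# All of the types above keep running_mean/running_var buffers, which are what
# this module recalibrates; normalizers without running stats (e.g. GroupNorm,
# LayerNorm) have nothing to recalibrate and are therefore not listed.
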

@torch.no_grad()
def update_bn_stats(model, data_loader, num_iters: int = 200):
    """
    Recompute and update the batch norm stats to make them more precise. During
    training, both the BN stats and the weights change after every iteration, so
    the running average cannot precisely reflect the actual stats of the
    current model.
    In this function, the BN stats are recomputed with fixed weights to make
    the running average more precise. Specifically, it computes the true average
    of per-batch mean/variance instead of the running average.

    Args:
        model (nn.Module): the model whose BN stats will be recomputed.

            Note that:

            1. This function will not alter the training mode of the given model.
               Users are responsible for setting the layers that need
               precise-BN to training mode prior to calling this function.

            2. Be careful if your models contain other stateful layers in
               addition to BN, i.e. layers whose state can change in forward
               iterations. This function will alter their state. If you wish
               them unchanged, you need to either pass in a submodule without
               those layers, or back up the states.
        data_loader (iterator): an iterator that produces data as inputs to the model.
        num_iters (int): number of iterations to compute the stats.
    """
    bn_layers = get_bn_modules(model)

    if len(bn_layers) == 0:
        return

    # In order to make the running stats only reflect the current batch, the
    # momentum is disabled:
    #     bn.running_mean = (1 - momentum) * bn.running_mean + momentum * batch_mean
    # Setting the momentum to 1.0 computes the stats without momentum.
    momentum_actual = [bn.momentum for bn in bn_layers]
    for bn in bn_layers:
        bn.momentum = 1.0

    # Note that running_var actually means "running average of variance"
    running_mean = [torch.zeros_like(bn.running_mean) for bn in bn_layers]
    running_var = [torch.zeros_like(bn.running_var) for bn in bn_layers]
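
    # With momentum = 1.0 the update above reduces to
    #     bn.running_mean = batch_mean
    # so after each forward pass the BN buffers hold exactly that batch's
    # stats; the loop below averages those per-batch stats over num_iters.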

    ind = 0
    num_epoch = num_iters // len(data_loader) + 1
    done = False
    for _ in range(num_epoch):
        prefetcher = data_prefetcher(data_loader)
        batch = prefetcher.next()
        while batch[0] is not None:
            model(batch[0], batch[1])

            for i, bn in enumerate(bn_layers):
                # Accumulate the BN stats with a streaming average:
                #     avg_k = avg_{k-1} + (x_k - avg_{k-1}) / k
                running_mean[i] += (bn.running_mean - running_mean[i]) / (ind + 1)
                running_var[i] += (bn.running_var - running_var[i]) / (ind + 1)
                # We compute the "average of variance" across iterations.

            if ind == (num_iters - 1):
                print(f"update_bn_stats ran for {num_iters} iterations.")
                # A bare `break` would only exit the while loop and the next
                # epoch would keep accumulating, so flag the outer loop too.
                done = True
                break

            ind += 1
            batch = prefetcher.next()
        if done:
            break

    for i, bn in enumerate(bn_layers):
        # Set the precise BN stats and restore the original momentum.
        bn.running_mean = running_mean[i]
        bn.running_var = running_var[i]
        bn.momentum = momentum_actual[i]
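
# A minimal usage sketch (hypothetical names: `build_model`, `cfg`, and
# `train_loader` come from the surrounding training code, not this module):
#
#     model = build_model(cfg)
#     model.train()  # layers that need precise-BN must be in training mode
#     update_bn_stats(model, train_loader, num_iters=200)
#     # ...then evaluate with the recalibrated BN statistics.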

def get_bn_modules(model):
    """
    Find all BatchNorm (BN) modules that are in training mode. See
    BN_MODULE_TYPES at the top of this module for the list of module types
    included in this search.

    Args:
        model (nn.Module): a model possibly containing BN modules.

    Returns:
        list[nn.Module]: all BN modules in the model that are in training mode.
    """
    # Find all the BN layers.
    bn_layers = [
        m
        for m in model.modules()
        if m.training and isinstance(m, BN_MODULE_TYPES)
    ]
    return bn_layers
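
# A quick sanity check of get_bn_modules on a toy model (a sketch; assumes
# only torch.nn):
#
#     import torch.nn as nn
#     net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
#     net.train()
#     assert len(get_bn_modules(net)) == 1
#     net.eval()
#     assert len(get_bn_modules(net)) == 0  # eval-mode BN layers are skipped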