224 lines
8.6 KiB
Python
224 lines
8.6 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
# Adapted from https://github.com/facebookresearch/pycls/blob/f8cd962737e33ce9e19b3083a33551da95c2d9c0/pycls/core/net.py # noqa: E501
|
|
# Original licence: Copyright (c) 2019 Facebook, Inc under the Apache License 2.0 # noqa: E501
|
|
|
|
import itertools
|
|
import logging
|
|
from typing import List, Optional, Sequence, Union
|
|
|
|
import mmengine
|
|
import torch
|
|
import torch.nn as nn
|
|
from mmengine.hooks import Hook
|
|
from mmengine.logging import print_log
|
|
from mmengine.model import is_model_wrapper
|
|
from mmengine.runner import EpochBasedTrainLoop, IterBasedTrainLoop, Runner
|
|
from mmengine.utils import ProgressBar
|
|
from torch.functional import Tensor
|
|
from torch.nn import GroupNorm
|
|
from torch.nn.modules.batchnorm import _BatchNorm
|
|
from torch.nn.modules.instancenorm import _InstanceNorm
|
|
from torch.utils.data import DataLoader
|
|
|
|
from mmcls.registry import HOOKS
|
|
|
|
# Type alias used to annotate the ``data_batch`` argument of
# ``after_train_iter``: an optional sequence of dicts.
DATA_BATCH = Optional[Sequence[dict]]
|
|
|
|
|
|
def scaled_all_reduce(tensors: List[Tensor], num_gpus: int) -> List[Tensor]:
    """Sum-reduce tensors across processes and divide by ``num_gpus``.

    Every tensor in ``tensors`` is modified in-place: an asynchronous
    ``torch.distributed.all_reduce`` is issued per tensor, all of them are
    waited on, and the results are multiplied by ``1.0 / num_gpus``. When
    ``num_gpus`` is 1 nothing is done and the input list is returned as is.

    Args:
        tensors (List[torch.Tensor]): The tensors to process.
        num_gpus (int): The number of gpus to use; the reduced values are
            scaled by its inverse.

    Returns:
        List[torch.Tensor]: The same list object holding the processed
        tensors.
    """
    # Single-process case: nothing to reduce.
    if num_gpus == 1:
        return tensors

    # Launch all reductions first so that they can overlap ...
    pending = [
        torch.distributed.all_reduce(item, async_op=True) for item in tensors
    ]
    # ... and only then block until each of them has completed.
    for handle in pending:
        handle.wait()

    # Turn the sums into means over the process group.
    for item in tensors:
        item.mul_(1.0 / num_gpus)
    return tensors
|
|
|
|
|
|
@torch.no_grad()
def update_bn_stats(
        model: nn.Module,
        loader: DataLoader,
        num_samples: int = 8192,
        logger: Optional[Union[logging.Logger, str]] = None) -> None:
    """Computes precise BN stats on training data.

    For every BN layer that is in training mode and owns running-stat
    buffers, the running mean/var are replaced by the plain average of the
    per-batch mean/var over a number of mini-batches, and that average is
    further averaged across all processes.

    Args:
        model (nn.Module): The model whose bn stats will be recomputed. If
            it is a model wrapper, the wrapped ``module`` is used.
        loader (DataLoader): PyTorch dataloader.
        num_samples (int): The number of samples to update the bn stats.
            The number of mini-batches is
            ``num_samples // (batch_size * world_size)``, clamped to at
            least one batch and at most ``len(loader)``.
            Defaults to 8192.
        logger (logging.Logger or str, optional): If the type of logger is
            ``logging.Logger``, we directly use logger to log messages.
            Some special loggers are:

            - "silent": No message will be printed.
            - "current": Use latest created logger to log message.
            - other str: Instance name of logger. The corresponding logger
              will log message if it has been created, otherwise will raise
              a `ValueError`.
            - None: The `print()` method will be used to print log messages.
    """
    if is_model_wrapper(model):
        model = model.module

    # get dist info
    rank, world_size = mmengine.dist.get_dist_info()
    # Compute the number of mini-batches to use, if the size of dataloader is
    # less than num_iters, use all the samples in dataloader.
    num_iter = num_samples // (loader.batch_size * world_size)
    # A ``num_samples`` smaller than one global batch would yield zero
    # iterations, and the zero-initialised accumulators below would then be
    # written into the model. Always consume at least one mini-batch.
    num_iter = max(num_iter, 1)
    num_iter = min(num_iter, len(loader))

    # Retrieve the BN layers. Layers without running-stat buffers (e.g.
    # built with ``track_running_stats=False``) have nothing to recompute
    # and would break ``torch.zeros_like`` below, so they are skipped.
    bn_layers = [
        m for m in model.modules()
        if m.training and isinstance(m, _BatchNorm)
        and m.running_mean is not None and m.running_var is not None
    ]
    if len(bn_layers) == 0:
        print_log('No BN found in model', logger=logger, level=logging.WARNING)
        return
    if num_iter == 0:
        # Only reachable with an empty dataloader: keep the current stats
        # instead of overwriting them with zeros.
        print_log(
            'The dataloader is empty, BN stats are not updated.',
            logger=logger,
            level=logging.WARNING)
        return
    print_log(
        f'{len(bn_layers)} BN found, run {num_iter} iters...', logger=logger)

    # Finds all the other norm layers with training=True.
    other_norm_layers = [
        m for m in model.modules()
        if m.training and isinstance(m, (_InstanceNorm, GroupNorm))
    ]
    if len(other_norm_layers) > 0:
        print_log(
            'IN/GN stats will not be updated in PreciseHook.',
            logger=logger,
            level=logging.INFO)

    # Initialize BN stats storage for computing
    # mean(mean(batch)) and mean(var(batch))
    running_means = [torch.zeros_like(bn.running_mean) for bn in bn_layers]
    running_vars = [torch.zeros_like(bn.running_var) for bn in bn_layers]
    # Remember momentum values
    momentums = [bn.momentum for bn in bn_layers]
    # Set momentum to 1.0 to compute BN stats that reflect the current batch
    for bn in bn_layers:
        bn.momentum = 1.0

    try:
        # Average the BN stats for each BN layer over the batches
        if rank == 0:
            prog_bar = ProgressBar(num_iter)

        for data in itertools.islice(loader, num_iter):
            batch_inputs, data_samples = model.data_preprocessor(data, False)
            model(batch_inputs, data_samples)

            for i, bn in enumerate(bn_layers):
                running_means[i] += bn.running_mean / num_iter
                running_vars[i] += bn.running_var / num_iter
            if rank == 0:
                prog_bar.update()
    finally:
        # Restore the original momentum values even if a forward pass
        # raised, so the layers are never left with momentum == 1.0.
        for bn, momentum in zip(bn_layers, momentums):
            bn.momentum = momentum

    # Sync BN stats across GPUs (no reduction if 1 GPU used)
    running_means = scaled_all_reduce(running_means, world_size)
    running_vars = scaled_all_reduce(running_vars, world_size)
    # Set the averaged BN stats
    for i, bn in enumerate(bn_layers):
        bn.running_mean = running_means[i]
        bn.running_var = running_vars[i]
|
|
|
|
|
|
@HOOKS.register_module()
class PreciseBNHook(Hook):
    """Precise BN hook.

    Recompute and update the batch norm stats to make them more precise.
    During training both BN stats and the weight are changing after every
    iteration, so the running average can not precisely reflect the actual
    stats of the current model.

    With this hook, the BN stats are recomputed with fixed weights, to make
    the running average more precise. Specifically, it computes the true
    average of per-batch mean/variance instead of the running average. See
    Sec. 3 of the paper
    `Rethinking Batch in BatchNorm <https://arxiv.org/abs/2105.07576>`_
    for details.

    This hook will update BN stats, so it should be executed before
    ``CheckpointHook`` and ``EMAHook``, generally set its priority to
    "ABOVE_NORMAL".

    Args:
        num_samples (int): The number of samples to update the bn stats.
            Defaults to 8192.
        interval (int): Perform precise bn interval. If the train loop is
            `EpochBasedTrainLoop` or `by_epoch=True`, its unit is 'epoch';
            if the train loop is `IterBasedTrainLoop` or `by_epoch=False`,
            its unit is 'iter'. Defaults to 1.
    """

    def __init__(self, num_samples: int = 8192, interval: int = 1) -> None:
        assert interval > 0 and num_samples > 0, (
            "'interval' and 'num_samples' must be bigger than 0.")

        self.num_samples = num_samples
        self.interval = interval

    def _perform_precise_bn(self, runner: Runner) -> None:
        """Run ``update_bn_stats`` on the runner's model and train loader.

        Args:
            runner (obj:`Runner`): The runner of the training process. Its
                ``model``, ``train_loop.dataloader`` and ``logger`` are
                used.
        """
        logger = runner.logger
        print_log(
            f'Running Precise BN for {self.num_samples} samples...',
            logger=logger)
        update_bn_stats(
            runner.model,
            runner.train_loop.dataloader,
            num_samples=self.num_samples,
            logger=logger)
        print_log('Finish Precise BN, BN stats updated.', logger=logger)

    def after_train_epoch(self, runner: Runner) -> None:
        """Recompute precise BN stats after a training epoch when due.

        Only acts when the train loop is an ``EpochBasedTrainLoop`` and
        ``every_n_epochs(runner, self.interval)`` is true.

        Args:
            runner (obj:`Runner`): The runner of the training process.
        """
        # Iteration-based loops are handled in ``after_train_iter``.
        if not isinstance(runner.train_loop, EpochBasedTrainLoop):
            return
        if self.every_n_epochs(runner, self.interval):
            self._perform_precise_bn(runner)

    def after_train_iter(self,
                         runner,
                         batch_idx: int,
                         data_batch: DATA_BATCH = None,
                         outputs: Optional[dict] = None) -> None:
        """Recompute precise BN stats after a training iteration when due.

        Only acts when the train loop is an ``IterBasedTrainLoop`` and
        ``every_n_train_iters(runner, self.interval)`` is true.

        Args:
            runner (obj:`Runner`): The runner of the training process.
            batch_idx (int): The index of the current batch in the train
                loop. Not used by this hook.
            data_batch (Sequence[dict], optional): Data from dataloader.
                Not used by this hook. Defaults to None.
            outputs (dict, optional): Not used by this hook.
                Defaults to None.
        """
        # Epoch-based loops are handled in ``after_train_epoch``.
        if not isinstance(runner.train_loop, IterBasedTrainLoop):
            return
        if self.every_n_train_iters(runner, self.interval):
            self._perform_precise_bn(runner)
|