mmengine/mmengine/hooks/sync_buffer_hook.py

# Copyright (c) OpenMMLab. All rights reserved.
# from mmengine.dist import get_dist_info, all_reduce
from collections import OrderedDict
from typing import Generator, List
from unittest.mock import MagicMock, Mock

import torch
from torch._utils import (_flatten_dense_tensors, _take_tensors,
                          _unflatten_dense_tensors)

from mmengine.registry import HOOKS
from .hook import Hook

# TODO, replace with import mmengine.dist as dist
dist = Mock()
dist.IS_DIST = MagicMock(return_value=True)

# TODO, replace with mmengine.dist.get_dist_info
get_dist_info = MagicMock(return_value=(0, 1))

# TODO, replace with mmengine.dist.all_reduce
all_reduce = MagicMock()


# TODO, may need to move to dist.utils after implementing dist module
def _allreduce_coalesced(tensors: List[torch.Tensor],
                         world_size: int,
                         bucket_size_mb: int = -1) -> None:
    """All-reduce a sequence of tensors as a whole.

    Args:
        tensors (List[torch.Tensor]): A sequence of tensors to be
            all-reduced.
        world_size (int): The world size of the process group.
        bucket_size_mb (int): The limit of each chunk in megabytes
            for grouping tensors into chunks. Defaults to -1.
    """
    if bucket_size_mb > 0:
        bucket_size_bytes = bucket_size_mb * 1024 * 1024
        buckets = _take_tensors(tensors, bucket_size_bytes)
    else:
        # No size limit: fall back to one bucket per tensor type.
        buckets = OrderedDict()
        for tensor in tensors:
            tp = tensor.type()
            if tp not in buckets:
                buckets[tp] = []
            buckets[tp].append(tensor)
        buckets = buckets.values()

    for bucket in buckets:
        # Flatten the bucket into one contiguous tensor, all-reduce it once,
        # average by world size, then copy the synced values back in place.
        flat_tensors = _flatten_dense_tensors(bucket)
        all_reduce(flat_tensors)
        flat_tensors.div_(world_size)
        for tensor, synced in zip(
                bucket, _unflatten_dense_tensors(flat_tensors, bucket)):
            tensor.copy_(synced)
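
# A minimal usage sketch (hypothetical, not part of this module): assuming an
# initialized ``torch.distributed`` process group and the real all_reduce from
# mmengine.dist, a model's buffers could be averaged across 4 ranks via
#
#     buffers = [buf.data for buf in model.buffers()]
#     _allreduce_coalesced(buffers, world_size=4, bucket_size_mb=512)
#
# Coalescing replaces one all_reduce call per tensor with one call per bucket,
# which matters for models with many small BN buffers.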


def allreduce_params(params: Generator[torch.Tensor, None, None],
                     coalesce: bool = True,
                     bucket_size_mb: int = -1) -> None:
    """All-reduce parameters.

    Args:
        params (Generator[torch.Tensor, None, None]): List of parameters or
            buffers of a model.
        coalesce (bool, optional): Whether to reduce parameters as a whole.
            Defaults to True.
        bucket_size_mb (int, optional): Size of bucket, the unit is MB.
            Defaults to -1.
    """
    _, world_size = get_dist_info()
    if world_size == 1:
        return
    params_data = [param.data for param in params]
    if coalesce:
        _allreduce_coalesced(params_data, world_size, bucket_size_mb)
    else:
        for tensor in params_data:
            all_reduce(tensor.div_(world_size))
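
# Usage sketch (hypothetical): average a model's buffers across ranks while
# grouping them into buckets of at most 512 MB:
#
#     allreduce_params(model.buffers(), coalesce=True, bucket_size_mb=512)
#
# Note that with the mocked ``get_dist_info`` above, world_size is 1 and the
# call returns immediately; real synchronization needs mmengine.dist.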


@HOOKS.register_module()
class SyncBuffersHook(Hook):
    """Synchronize model buffers such as running_mean and running_var in BN at
    the end of each epoch."""

    priority = 'NORMAL'

    def __init__(self) -> None:
        self.distributed = dist.IS_DIST

    def after_epoch(self, runner: object) -> None:
        """All-reduce model buffers at the end of each epoch.

        Args:
            runner (object): The runner of the training process.
        """
        if self.distributed:
            allreduce_params(runner.model.buffers())  # type: ignore
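
# Usage sketch (hypothetical, depends on the runner's config conventions): the
# hook is registered in ``HOOKS``, so it could be enabled from a config with
#
#     custom_hooks = [dict(type='SyncBuffersHook')]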