Yixiao Fang c9670173aa
[Refactor] Move and refactor utils from mmselfsup. (#1385)
* add heads
* add losses
* fix
* remove mim head
* add modified backbones and target generators
* fix lint
* fix lint
* add heads
* add losses
* fix
* add data preprocessor from mmselfsup
* add ut for data preprocessor
* add GatherLayer
* add ema
* add batch shuffle
* add misc
* fix lint
* update
* update docstring
2023-02-28 17:04:40 +08:00

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple

import torch
from mmengine.dist import all_gather, broadcast, get_rank


@torch.no_grad()
def batch_shuffle_ddp(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Batch shuffle, for making use of BatchNorm.

    Args:
        x (torch.Tensor): Data in each GPU.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Output of shuffle operation.
            - x_gather[idx_this]: Shuffled data.
            - idx_unshuffle: Index for restoring.
    """
    # gather from all gpus
    batch_size_this = x.shape[0]
    x_gather = torch.cat(all_gather(x), dim=0)
    batch_size_all = x_gather.shape[0]

    num_gpus = batch_size_all // batch_size_this

    # random shuffle index
    idx_shuffle = torch.randperm(batch_size_all)

    # broadcast to all gpus
    broadcast(idx_shuffle, src=0)

    # index for restoring
    idx_unshuffle = torch.argsort(idx_shuffle)

    # shuffled index for this gpu
    gpu_idx = get_rank()
    idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx]

    return x_gather[idx_this], idx_unshuffle
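
# NOTE (illustrative comment, not part of the original file): argsort of a
# permutation yields its inverse permutation, which is why indexing with
# `idx_unshuffle` restores the pre-shuffle order, e.g.:
#   perm = torch.randperm(8)
#   inv = torch.argsort(perm)
#   assert torch.equal(perm[inv], torch.arange(8))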

@torch.no_grad()
def batch_unshuffle_ddp(x: torch.Tensor,
                        idx_unshuffle: torch.Tensor) -> torch.Tensor:
    """Undo batch shuffle.

    Args:
        x (torch.Tensor): Data in each GPU.
        idx_unshuffle (torch.Tensor): Index for restoring.

    Returns:
        torch.Tensor: Output of unshuffle operation.
    """
    # gather from all gpus
    batch_size_this = x.shape[0]
    x_gather = torch.cat(all_gather(x), dim=0)
    batch_size_all = x_gather.shape[0]

    num_gpus = batch_size_all // batch_size_this

    # restored index for this gpu
    gpu_idx = get_rank()
    idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx]

    return x_gather[idx_this]
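
For context, a minimal usage sketch of how these helpers are typically applied in a MoCo-style momentum-encoder forward pass. This is an illustration under stated assumptions, not code from this commit: `forward_key_encoder` and `encoder_k` are hypothetical names, the import path is a guess based on where this file lives, and it presumes one process per GPU under torch.distributed. Shuffling the global batch before the key encoder stops its BatchNorm layers from leaking intra-batch statistics; unshuffling afterwards realigns each key with its query.

import torch
import torch.nn as nn

# Assumed import path; adjust to wherever batch_shuffle.py lives in your tree.
from mmpretrain.models.utils import batch_shuffle_ddp, batch_unshuffle_ddp


@torch.no_grad()
def forward_key_encoder(encoder_k: nn.Module,
                        im_k: torch.Tensor) -> torch.Tensor:
    """Hypothetical key-encoder step using the shuffle/unshuffle pair."""
    # Mix samples across GPUs so each GPU's BatchNorm computes statistics
    # over a random slice of the global batch, not its local batch.
    im_k, idx_unshuffle = batch_shuffle_ddp(im_k)
    k = encoder_k(im_k)
    # Restore the original sample order so keys align with their queries.
    return batch_unshuffle_ddp(k, idx_unshuffle)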