mirror of
https://github.com/open-mmlab/mmclassification.git
synced 2025-06-03 21:53:55 +08:00
* add heads * add losses * fix * remove mim head * add modified backbones and target generators * fix lint * fix lint * add heads * add losses * fix * add data preprocessor from mmselfsup * add ut for data prepocessor * add GatherLayer * add ema * add batch shuffle * add misc * fix lint * update * update docstring
23 lines
677 B
Python
23 lines
677 B
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
from typing import Any, List, Tuple
|
|
|
|
import torch
|
|
from mmengine.dist import all_gather, get_rank
|
|
|
|
|
|
class GatherLayer(torch.autograd.Function):
    """Gather tensors from all processes, supporting backward propagation.

    ``forward`` collects ``input`` from every rank with
    :func:`mmengine.dist.all_gather` and returns the results as a tuple.
    Because it is wrapped in an autograd ``Function``, gradients can flow
    back to ``input`` through ``backward``.

    Usage: ``gathered = GatherLayer.apply(x)``.

    NOTE(review): ``backward`` performs no communication. Each rank only
    propagates the gradient its *own* graph produced for its *own* slice.
    Gradients that other ranks' losses produce with respect to this rank's
    tensor are dropped, not summed. By comparison,
    ``torch.distributed.nn.functional.all_gather`` sums them across ranks.
    Confirm this is intended for the losses that use this layer.
    """

    @staticmethod
    def forward(ctx: Any, input: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        """Gather ``input`` from all ranks.

        Args:
            ctx: Autograd context. ``input`` is saved on it for ``backward``.
            input: The local tensor to share. Presumably it must have the same
                shape on every rank for the underlying collective. TODO:
                confirm this for ``mmengine.dist.all_gather``.

        Returns:
            Tuple[torch.Tensor, ...]: One tensor per rank. ``backward`` treats
            position ``i`` as the contribution of rank ``i``.
        """
        ctx.save_for_backward(input)
        output = all_gather(input)
        return tuple(output)

    @staticmethod
    def backward(ctx: Any, *grads: torch.Tensor) -> torch.Tensor:
        """Return the gradient for this rank's own slice of the gathered tuple.

        Args:
            ctx: Autograd context holding the tensor saved in ``forward``.
            *grads: One incoming gradient per output of ``forward``, i.e.
                one per rank.

        Returns:
            torch.Tensor: Gradient w.r.t. ``input``. It has the same dtype,
            device and shape as the saved ``input``.
        """
        saved_input, = ctx.saved_tensors
        grad_out = torch.zeros_like(saved_input)
        # ``copy_`` replaces the original ``grad_out[:] = ...``. Slice
        # assignment raises IndexError on 0-dim tensors, so gathering a scalar
        # (e.g. a per-rank loss value) used to crash here. Both forms
        # broadcast the source the same way for n-dim tensors.
        grad_out.copy_(grads[get_rank()])
        return grad_out
|