# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) ByteDance, Inc. and its affiliates. All rights reserved.
# Modified from https://github.com/keyu-tian/SparK/blob/main/encoder.py
import torch
import torch.nn as nn

from mmpretrain.registry import MODELS


class SparseHelper:
    """The helper to compute sparse operations with PyTorch, such as sparse
    convolution, sparse batch norm, etc."""

    _cur_active: torch.Tensor = None

    @staticmethod
    def _get_active_map_or_index(H: int,
                                 returning_active_map: bool = True
                                 ) -> torch.Tensor:
        """Get the current active map, upsampled to shape (B, 1, H, H), or
        its non-zero index format."""
        # _cur_active with shape (B, 1, f, f)
        downsample_ratio = H // SparseHelper._cur_active.shape[-1]
        active_ex = SparseHelper._cur_active.repeat_interleave(
            downsample_ratio, 2).repeat_interleave(downsample_ratio, 3)
        return active_ex if returning_active_map else active_ex.squeeze(
            1).nonzero(as_tuple=True)
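    # For example (hypothetical values): a (1, 1, 2, 2) mask upsampled to
    # H = 4 lets each mask entry cover a 2x2 patch of the feature map:
    #
    #   SparseHelper._cur_active = torch.tensor([[[[1., 0.],
    #                                              [0., 1.]]]])
    #   SparseHelper._get_active_map_or_index(H=4)
    #   # -> (1, 1, 4, 4) map with ones on the top-left and
    #   #    bottom-right 2x2 blocks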

    @staticmethod
    def sp_conv_forward(self, x: torch.Tensor) -> torch.Tensor:
        """Sparse convolution forward function."""
        x = super(type(self), self).forward(x)

        # (b, c, h, w) *= (b, 1, h, w), mask the output of conv
        x *= SparseHelper._get_active_map_or_index(
            H=x.shape[2], returning_active_map=True)
        return x
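    # Note: `sp_conv_forward` and `sp_bn_forward` are assigned as the
    # unbound `forward` of the subclasses below, so `super(type(self), self)`
    # dispatches to the dense parent op (e.g. `nn.Conv2d.forward`) before
    # the active-position mask is re-applied.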

    @staticmethod
    def sp_bn_forward(self, x: torch.Tensor) -> torch.Tensor:
        """Sparse batch norm forward function."""
        active_index = SparseHelper._get_active_map_or_index(
            H=x.shape[2], returning_active_map=False)

        # (b, c, h, w) -> (b, h, w, c)
        x_permuted = x.permute(0, 2, 3, 1)

        # select the features on non-masked positions to form flattened
        # features with shape (n, c)
        x_flattened = x_permuted[active_index]

        # use BN1d to normalize this flattened feature (n, c)
        x_flattened = super(type(self), self).forward(x_flattened)

        # generate output
        output = torch.zeros_like(x_permuted, dtype=x_flattened.dtype)
        output[active_index] = x_flattened

        # (b, h, w, c) -> (b, c, h, w)
        output = output.permute(0, 3, 1, 2)
        return output
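    # A minimal sketch of the effect (hypothetical shapes): with a
    # (2, 1, 7, 7) mask, `SparseBatchNorm2d(64)` gathers the active
    # positions of a (2, 64, 28, 28) input into an (n, 64) matrix,
    # normalizes it with the inherited `nn.BatchNorm1d`, then scatters the
    # result back, leaving masked positions at zero.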


class SparseConv2d(nn.Conv2d):
    """hack: override the forward function.

    See `sp_conv_forward` above for more details.
    """
    forward = SparseHelper.sp_conv_forward


class SparseMaxPooling(nn.MaxPool2d):
    """hack: override the forward function.

    See `sp_conv_forward` above for more details.
    """
    forward = SparseHelper.sp_conv_forward


class SparseAvgPooling(nn.AvgPool2d):
    """hack: override the forward function.

    See `sp_conv_forward` above for more details.
    """
    forward = SparseHelper.sp_conv_forward
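# The pooling wrappers above reuse `sp_conv_forward`: like convolution,
# pooling is a local operator, so running the dense op and re-applying the
# mask keeps the inactive positions at zero.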


@MODELS.register_module()
class SparseBatchNorm2d(nn.BatchNorm1d):
    """hack: override the forward function.

    See `sp_bn_forward` above for more details.
    """
    forward = SparseHelper.sp_bn_forward


@MODELS.register_module()
class SparseSyncBatchNorm2d(nn.SyncBatchNorm):
    """hack: override the forward function.

    See `sp_bn_forward` above for more details.
    """
    forward = SparseHelper.sp_bn_forward
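# Both batch norm variants are registered in MODELS, so (assuming the usual
# mmengine norm-layer building) a config could reference them by name, e.g.
# a hypothetical `norm_cfg=dict(type='SparseSyncBatchNorm2d')`.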


@MODELS.register_module('SparseLN2d')
class SparseLayerNorm2D(nn.LayerNorm):
    """Implementation of sparse LayerNorm on channels for 2d images."""

    def forward(self,
                x: torch.Tensor,
                data_format='channel_first') -> torch.Tensor:
        """Sparse layer norm forward function with 2D data.

        Args:
            x (torch.Tensor): The input tensor.
            data_format (str): The format of the input tensor. If
                ``"channel_first"``, the shape of the input tensor should be
                (B, C, H, W). If ``"channel_last"``, the shape of the input
                tensor should be (B, H, W, C). Defaults to "channel_first".
        """
        assert x.dim() == 4, (
            'LayerNorm2d only supports inputs with shape '
            f'(N, C, H, W), but got tensor with shape {x.shape}')
        if data_format == 'channel_last':
            index = SparseHelper._get_active_map_or_index(
                H=x.shape[1], returning_active_map=False)

            # select the features on non-masked positions to form flattened
            # features with shape (n, c)
            x_flattened = x[index]
            # use LayerNorm to normalize this flattened feature (n, c)
            x_flattened = super().forward(x_flattened)

            # generate output
            x = torch.zeros_like(x, dtype=x_flattened.dtype)
            x[index] = x_flattened
        elif data_format == 'channel_first':
            index = SparseHelper._get_active_map_or_index(
                H=x.shape[2], returning_active_map=False)
            x_permuted = x.permute(0, 2, 3, 1)

            # select the features on non-masked positions to form flattened
            # features with shape (n, c)
            x_flattened = x_permuted[index]
            # use LayerNorm to normalize this flattened feature (n, c)
            x_flattened = super().forward(x_flattened)

            # generate output
            x = torch.zeros_like(x_permuted, dtype=x_flattened.dtype)
            x[index] = x_flattened
            x = x.permute(0, 3, 1, 2).contiguous()
        else:
            raise NotImplementedError(
                f'Unsupported data_format: {data_format}')
        return x
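
# A minimal end-to-end sketch (hypothetical values). The active mask is set
# once per batch; every sparse layer then reads it from
# `SparseHelper._cur_active`:
#
#   mask = (torch.rand(2, 1, 7, 7) > 0.5).float()    # (B, 1, f, f)
#   SparseHelper._cur_active = mask
#   conv = SparseConv2d(3, 8, kernel_size=3, padding=1)
#   out = conv(torch.randn(2, 3, 28, 28))            # masked (2, 8, 28, 28)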