170 lines
5.8 KiB
Python
170 lines
5.8 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import math
|
|
from typing import Optional
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from mmengine.model import BaseModule
|
|
|
|
from mmpretrain.registry import MODELS
|
|
from ..utils import build_norm_layer
|
|
|
|
|
|
def is_pow2n(x):
|
|
return x > 0 and (x & (x - 1) == 0)
|
|
|
|
|
|
class ConvBlock2x(BaseModule):
|
|
"""The definition of convolution block."""
|
|
|
|
def __init__(self,
|
|
in_channels: int,
|
|
out_channels: int,
|
|
mid_channels: int,
|
|
norm_cfg: dict,
|
|
act_cfg: dict,
|
|
last_act: bool,
|
|
init_cfg: Optional[dict] = None) -> None:
|
|
super().__init__(init_cfg=init_cfg)
|
|
|
|
self.conv1 = nn.Conv2d(in_channels, mid_channels, 3, 1, 1, bias=False)
|
|
self.norm1 = build_norm_layer(norm_cfg, mid_channels)
|
|
self.activate1 = MODELS.build(act_cfg)
|
|
|
|
self.conv2 = nn.Conv2d(mid_channels, out_channels, 3, 1, 1, bias=False)
|
|
self.norm2 = build_norm_layer(norm_cfg, out_channels)
|
|
self.activate2 = MODELS.build(act_cfg) if last_act else nn.Identity()
|
|
|
|
def forward(self, x: torch.Tensor):
|
|
out = self.conv1(x)
|
|
out = self.norm1(out)
|
|
out = self.activate1(out)
|
|
|
|
out = self.conv2(out)
|
|
out = self.norm2(out)
|
|
out = self.activate2(out)
|
|
return out
|
|
|
|
|
|
class DecoderConvModule(BaseModule):
|
|
"""The convolution module of decoder with upsampling."""
|
|
|
|
def __init__(self,
|
|
in_channels: int,
|
|
out_channels: int,
|
|
mid_channels: int,
|
|
kernel_size: int = 4,
|
|
scale_factor: int = 2,
|
|
num_conv_blocks: int = 1,
|
|
norm_cfg: dict = dict(type='SyncBN'),
|
|
act_cfg: dict = dict(type='ReLU6'),
|
|
last_act: bool = True,
|
|
init_cfg: Optional[dict] = None):
|
|
super().__init__(init_cfg=init_cfg)
|
|
|
|
assert (kernel_size - scale_factor >= 0) and\
|
|
(kernel_size - scale_factor) % 2 == 0,\
|
|
f'kernel_size should be greater than or equal to scale_factor '\
|
|
f'and (kernel_size - scale_factor) should be even numbers, '\
|
|
f'while the kernel size is {kernel_size} and scale_factor is '\
|
|
f'{scale_factor}.'
|
|
|
|
padding = (kernel_size - scale_factor) // 2
|
|
self.upsample = nn.ConvTranspose2d(
|
|
in_channels,
|
|
in_channels,
|
|
kernel_size=kernel_size,
|
|
stride=scale_factor,
|
|
padding=padding,
|
|
bias=True)
|
|
|
|
conv_blocks_list = [
|
|
ConvBlock2x(
|
|
in_channels=in_channels,
|
|
out_channels=out_channels,
|
|
mid_channels=mid_channels,
|
|
norm_cfg=norm_cfg,
|
|
last_act=last_act,
|
|
act_cfg=act_cfg) for _ in range(num_conv_blocks)
|
|
]
|
|
self.conv_blocks = nn.Sequential(*conv_blocks_list)
|
|
|
|
def forward(self, x):
|
|
x = self.upsample(x)
|
|
return self.conv_blocks(x)
|
|
|
|
|
|
@MODELS.register_module()
|
|
class SparKLightDecoder(BaseModule):
|
|
"""The decoder for SparK, which upsamples the feature maps.
|
|
|
|
Args:
|
|
feature_dim (int): The dimension of feature map.
|
|
upsample_ratio (int): The ratio of upsample, equal to downsample_raito
|
|
of the algorithm.
|
|
mid_channels (int): The middle channel of `DecoderConvModule`. Defaults
|
|
to 0.
|
|
kernel_size (int): The kernel size of `ConvTranspose2d` in
|
|
`DecoderConvModule`. Defaults to 4.
|
|
scale_factor (int): The scale_factor of `ConvTranspose2d` in
|
|
`DecoderConvModule`. Defaults to 2.
|
|
num_conv_blocks (int): The number of convolution blocks in
|
|
`DecoderConvModule`. Defaults to 1.
|
|
norm_cfg (dict): Normalization config. Defaults to dict(type='SyncBN').
|
|
act_cfg (dict): Activation config. Defaults to dict(type='ReLU6').
|
|
last_act (bool): Whether apply the last activation in
|
|
`DecoderConvModule`. Defaults to False.
|
|
init_cfg (dict or list[dict], optional): Initialization config dict.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
feature_dim: int,
|
|
upsample_ratio: int,
|
|
mid_channels: int = 0,
|
|
kernel_size: int = 4,
|
|
scale_factor: int = 2,
|
|
num_conv_blocks: int = 1,
|
|
norm_cfg: dict = dict(type='SyncBN'),
|
|
act_cfg: dict = dict(type='ReLU6'),
|
|
last_act: bool = False,
|
|
init_cfg: Optional[dict] = [
|
|
dict(type='Kaiming', layer=['Conv2d', 'ConvTranspose2d']),
|
|
dict(type='TruncNormal', std=0.02, layer=['Linear']),
|
|
dict(
|
|
type='Constant',
|
|
val=1,
|
|
layer=['_BatchNorm', 'LayerNorm', 'SyncBatchNorm'])
|
|
],
|
|
):
|
|
super().__init__(init_cfg=init_cfg)
|
|
self.feature_dim = feature_dim
|
|
|
|
assert is_pow2n(upsample_ratio)
|
|
n = round(math.log2(upsample_ratio))
|
|
channels = [feature_dim // 2**i for i in range(n + 1)]
|
|
|
|
self.decoder = nn.ModuleList([
|
|
DecoderConvModule(
|
|
in_channels=c_in,
|
|
out_channels=c_out,
|
|
mid_channels=c_in if mid_channels == 0 else mid_channels,
|
|
kernel_size=kernel_size,
|
|
scale_factor=scale_factor,
|
|
num_conv_blocks=num_conv_blocks,
|
|
norm_cfg=norm_cfg,
|
|
act_cfg=act_cfg,
|
|
last_act=last_act)
|
|
for (c_in, c_out) in zip(channels[:-1], channels[1:])
|
|
])
|
|
self.proj = nn.Conv2d(
|
|
channels[-1], 3, kernel_size=1, stride=1, bias=True)
|
|
|
|
def forward(self, to_dec):
|
|
x = 0
|
|
for i, d in enumerate(self.decoder):
|
|
if i < len(to_dec) and to_dec[i] is not None:
|
|
x = x + to_dec[i]
|
|
x = self.decoder[i](x)
|
|
return self.proj(x)
|