# Copyright (c) OpenMMLab. All rights reserved.
from typing import Sequence

import torch
import torch.nn as nn
from mmcv.cnn.bricks import (Conv2dAdaptivePadding, build_activation_layer,
                             build_norm_layer)
from mmengine.utils import digit_version

from mmcls.registry import MODELS
from .base_backbone import BaseBackbone


class Residual(nn.Module):
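    """A residual wrapper: returns ``fn(x) + x``."""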

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        return self.fn(x) + x


@MODELS.register_module()
class ConvMixer(BaseBackbone):
    """ConvMixer backbone.

    A PyTorch implementation of: `Patches Are All You Need?
    <https://arxiv.org/pdf/2201.09792.pdf>`_

    Modified from the `official repo
    <https://github.com/locuslab/convmixer/blob/main/convmixer.py>`_
    and `timm
    <https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/convmixer.py>`_.

    Args:
        arch (str | dict): The model's architecture. If a string, it should
            be one of the architectures in ``ConvMixer.arch_settings``. If a
            dict, it should include the following four keys:

            - embed_dims (int): The dimensions of patch embedding.
            - depth (int): Number of repetitions of the ConvMixer layer.
            - patch_size (int): The patch size.
            - kernel_size (int): The kernel size of depthwise conv layers.

            Defaults to '768/32'.
        in_channels (int): Number of input image channels. Defaults to 3.
        norm_cfg (dict): The config dict for norm layers.
            Defaults to ``dict(type='BN')``.
        act_cfg (dict): The config dict for activation after each convolution.
            Defaults to ``dict(type='GELU')``.
        out_indices (Sequence | int): Output from which stages.
            Defaults to -1, which means the last stage.
        frozen_stages (int): Stages to be frozen (all parameters fixed).
            Defaults to 0, which means not freezing any parameters.
        init_cfg (dict, optional): Initialization config dict.
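
    Example:
        >>> # Illustrative usage; the import path assumes the standard
        >>> # mmcls package layout.
        >>> import torch
        >>> from mmcls.models import ConvMixer
        >>> model = ConvMixer(arch='768/32')
        >>> inputs = torch.rand(1, 3, 224, 224)
        >>> outputs = model(inputs)
        >>> # With patch_size=7, a 224x224 input gives a 32x32 feature map.
        >>> print(outputs[-1].shape)
        torch.Size([1, 768, 32, 32])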
    """
    arch_settings = {
        '768/32': {
            'embed_dims': 768,
            'depth': 32,
            'patch_size': 7,
            'kernel_size': 7
        },
        '1024/20': {
            'embed_dims': 1024,
            'depth': 20,
            'patch_size': 14,
            'kernel_size': 9
        },
        '1536/20': {
            'embed_dims': 1536,
            'depth': 20,
            'patch_size': 7,
            'kernel_size': 9
        },
    }

    def __init__(self,
                 arch='768/32',
                 in_channels=3,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='GELU'),
                 out_indices=-1,
                 frozen_stages=0,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        if isinstance(arch, str):
            assert arch in self.arch_settings, \
                f'Unavailable arch, please choose from ' \
                f'({set(self.arch_settings)}) or pass a dict.'
            arch = self.arch_settings[arch]
        elif isinstance(arch, dict):
            essential_keys = {
                'embed_dims', 'depth', 'patch_size', 'kernel_size'
            }
            assert essential_keys <= set(arch), \
                f'Custom arch needs a dict with keys {essential_keys}'

        self.embed_dims = arch['embed_dims']
        self.depth = arch['depth']
        self.patch_size = arch['patch_size']
        self.kernel_size = arch['kernel_size']
        self.act = build_activation_layer(act_cfg)

        # check out indices and frozen stages
        if isinstance(out_indices, int):
            out_indices = [out_indices]
        assert isinstance(out_indices, Sequence), \
            f'"out_indices" must be a sequence or int, ' \
            f'got {type(out_indices)} instead.'
        # Copy to a list so negative indices can be rewritten in place even
        # when a tuple is passed.
        out_indices = list(out_indices)
        for i, index in enumerate(out_indices):
            if index < 0:
                out_indices[i] = self.depth + index
                assert out_indices[i] >= 0, f'Invalid out_indices {index}'
        self.out_indices = out_indices
        self.frozen_stages = frozen_stages

        # Set stem layers
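        # The stem performs ConvMixer's patch embedding: a convolution with
        # kernel size and stride both equal to patch_size splits the image
        # into non-overlapping patches and projects each to embed_dims
        # channels.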
        self.stem = nn.Sequential(
            nn.Conv2d(
                in_channels,
                self.embed_dims,
                kernel_size=self.patch_size,
                stride=self.patch_size), self.act,
            build_norm_layer(norm_cfg, self.embed_dims)[1])

        # Set conv2d according to torch version
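        # nn.Conv2d only accepts padding='same' from torch 1.9.0 onwards;
        # for older versions, mmcv's Conv2dAdaptivePadding emulates the same
        # behaviour by computing the padding dynamically in its forward pass.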
        convfunc = nn.Conv2d
        if digit_version(torch.__version__) < digit_version('1.9.0'):
            convfunc = Conv2dAdaptivePadding

        # Repetitions of ConvMixer Layer
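        # Each ConvMixer layer first mixes spatial locations with a residual
        # depthwise convolution (groups=embed_dims), then mixes channels with
        # a pointwise 1x1 convolution; each conv is followed by the
        # activation and a norm layer.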
        self.stages = nn.Sequential(*[
            nn.Sequential(
                Residual(
                    nn.Sequential(
                        convfunc(
                            self.embed_dims,
                            self.embed_dims,
                            self.kernel_size,
                            groups=self.embed_dims,
                            padding='same'), self.act,
                        build_norm_layer(norm_cfg, self.embed_dims)[1])),
                nn.Conv2d(self.embed_dims, self.embed_dims, kernel_size=1),
                self.act,
                build_norm_layer(norm_cfg, self.embed_dims)[1])
            for _ in range(self.depth)
        ])

        self._freeze_stages()

    def forward(self, x):
        x = self.stem(x)
        outs = []
        for i, stage in enumerate(self.stages):
            x = stage(x)
            if i in self.out_indices:
                outs.append(x)

        return tuple(outs)

    def train(self, mode=True):
        super().train(mode)
        self._freeze_stages()

    def _freeze_stages(self):
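        """Set the first ``frozen_stages`` ConvMixer layers to eval mode and
        stop their gradients; the stem itself is never frozen."""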
        for i in range(self.frozen_stages):
            stage = self.stages[i]
            stage.eval()
            for param in stage.parameters():
                param.requires_grad = False