mirror of https://github.com/open-mmlab/mmocr.git
[Fix] Remove dependency on MMCV registry (#1261)
* [Fix] Remove dependency on MMCV registry * fixpull/1292/head
parent
6b6d833be4
commit
0d9b40706c
|
@ -1,59 +0,0 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import ACTIVATION_LAYERS as MMCV_ACTIVATION_LAYERS
|
||||
from mmcv.cnn import UPSAMPLE_LAYERS as MMCV_UPSAMPLE_LAYERS
|
||||
from mmcv.utils import Registry, build_from_cfg
|
||||
|
||||
UPSAMPLE_LAYERS = Registry('upsample layer', parent=MMCV_UPSAMPLE_LAYERS)
|
||||
ACTIVATION_LAYERS = Registry('activation layer', parent=MMCV_ACTIVATION_LAYERS)
|
||||
|
||||
|
||||
def build_upsample_layer(cfg, *args, **kwargs):
|
||||
"""Build upsample layer.
|
||||
|
||||
Args:
|
||||
cfg (dict): The upsample layer config, which should contain:
|
||||
|
||||
- type (str): Layer type.
|
||||
- scale_factor (int): Upsample ratio, which is not applicable to
|
||||
deconv.
|
||||
- layer args: Args needed to instantiate a upsample layer.
|
||||
args (argument list): Arguments passed to the ``__init__``
|
||||
method of the corresponding conv layer.
|
||||
kwargs (keyword arguments): Keyword arguments passed to the
|
||||
``__init__`` method of the corresponding conv layer.
|
||||
|
||||
Returns:
|
||||
nn.Module: Created upsample layer.
|
||||
"""
|
||||
if not isinstance(cfg, dict):
|
||||
raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
|
||||
if 'type' not in cfg:
|
||||
raise KeyError(
|
||||
f'the cfg dict must contain the key "type", but got {cfg}')
|
||||
cfg_ = cfg.copy()
|
||||
|
||||
layer_type = cfg_.pop('type')
|
||||
if layer_type not in UPSAMPLE_LAYERS:
|
||||
raise KeyError(f'Unrecognized upsample type {layer_type}')
|
||||
else:
|
||||
upsample = UPSAMPLE_LAYERS.get(layer_type)
|
||||
|
||||
if upsample is nn.Upsample:
|
||||
cfg_['mode'] = layer_type
|
||||
layer = upsample(*args, **kwargs, **cfg_)
|
||||
return layer
|
||||
|
||||
|
||||
def build_activation_layer(cfg):
|
||||
"""Build activation layer.
|
||||
|
||||
Args:
|
||||
cfg (dict): The activation layer config, which should contain:
|
||||
- type (str): Layer type.
|
||||
- layer args: Args needed to instantiate an activation layer.
|
||||
|
||||
Returns:
|
||||
nn.Module: Created activation layer.
|
||||
"""
|
||||
return build_from_cfg(cfg, ACTIVATION_LAYERS)
|
|
@ -6,8 +6,6 @@ from mmcv.cnn import ConvModule, build_norm_layer
|
|||
from mmcv.utils.parrots_wrapper import _BatchNorm
|
||||
from mmengine.model import BaseModule
|
||||
|
||||
from mmocr.models.builder import (UPSAMPLE_LAYERS, build_activation_layer,
|
||||
build_upsample_layer)
|
||||
from mmocr.registry import MODELS
|
||||
|
||||
|
||||
|
@ -81,13 +79,14 @@ class UpConvBlock(nn.Module):
|
|||
dcn=None,
|
||||
plugins=None)
|
||||
if upsample_cfg is not None:
|
||||
self.upsample = build_upsample_layer(
|
||||
cfg=upsample_cfg,
|
||||
in_channels=in_channels,
|
||||
out_channels=skip_channels,
|
||||
with_cp=with_cp,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
upsample_cfg.update(
|
||||
dict(
|
||||
in_channels=in_channels,
|
||||
out_channels=skip_channels,
|
||||
with_cp=with_cp,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg))
|
||||
self.upsample = MODELS.build(upsample_cfg)
|
||||
else:
|
||||
self.upsample = ConvModule(
|
||||
in_channels,
|
||||
|
@ -182,7 +181,7 @@ class BasicConvBlock(nn.Module):
|
|||
return out
|
||||
|
||||
|
||||
@UPSAMPLE_LAYERS.register_module()
|
||||
@MODELS.register_module()
|
||||
class DeconvModule(nn.Module):
|
||||
"""Deconvolution upsample module in decoder for UNet (2X upsample).
|
||||
|
||||
|
@ -231,7 +230,7 @@ class DeconvModule(nn.Module):
|
|||
padding=padding)
|
||||
|
||||
_, norm = build_norm_layer(norm_cfg, out_channels)
|
||||
activate = build_activation_layer(act_cfg)
|
||||
activate = MODELS.build(act_cfg)
|
||||
self.deconv_upsamping = nn.Sequential(deconv, norm, activate)
|
||||
|
||||
def forward(self, x):
|
||||
|
@ -244,7 +243,7 @@ class DeconvModule(nn.Module):
|
|||
return out
|
||||
|
||||
|
||||
@UPSAMPLE_LAYERS.register_module()
|
||||
@MODELS.register_module()
|
||||
class InterpConv(nn.Module):
|
||||
"""Interpolation upsample module in decoder for UNet.
|
||||
|
||||
|
|
|
@ -35,7 +35,7 @@ class TFEncoderLayer(BaseModule):
|
|||
d_v=64,
|
||||
dropout=0.1,
|
||||
qkv_bias=False,
|
||||
act_cfg=dict(type='mmcv.GELU'),
|
||||
act_cfg=dict(type='mmengine.GELU'),
|
||||
operation_order=None):
|
||||
super().__init__()
|
||||
self.attn = MultiHeadAttention(
|
||||
|
@ -103,7 +103,7 @@ class TFDecoderLayer(nn.Module):
|
|||
d_v=64,
|
||||
dropout=0.1,
|
||||
qkv_bias=False,
|
||||
act_cfg=dict(type='mmcv.GELU'),
|
||||
act_cfg=dict(type='mmengine.GELU'),
|
||||
operation_order=None):
|
||||
super().__init__()
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@ import torch
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from mmocr.models.builder import build_activation_layer
|
||||
from mmocr.registry import MODELS
|
||||
|
||||
|
||||
class ScaledDotProductAttention(nn.Module):
|
||||
|
@ -115,7 +115,7 @@ class PositionwiseFeedForward(nn.Module):
|
|||
super().__init__()
|
||||
self.w_1 = nn.Linear(d_in, d_hid)
|
||||
self.w_2 = nn.Linear(d_hid, d_in)
|
||||
self.act = build_activation_layer(act_cfg)
|
||||
self.act = MODELS.build(act_cfg)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
def forward(self, x):
|
||||
|
|
|
@ -3,11 +3,11 @@ from typing import Tuple, Union
|
|||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import PLUGIN_LAYERS
|
||||
|
||||
from mmocr.registry import MODELS
|
||||
|
||||
|
||||
# TODO: Replace PLUGIN_LAYERS with MODELS
|
||||
@PLUGIN_LAYERS.register_module()
|
||||
@MODELS.register_module()
|
||||
class Maxpool2d(nn.Module):
|
||||
"""A wrapper around nn.Maxpool2d().
|
||||
|
||||
|
@ -36,7 +36,7 @@ class Maxpool2d(nn.Module):
|
|||
return self.model(x)
|
||||
|
||||
|
||||
@PLUGIN_LAYERS.register_module()
|
||||
@MODELS.register_module()
|
||||
class GCAModule(nn.Module):
|
||||
"""GCAModule in MASTER.
|
||||
|
||||
|
|
|
@ -4,12 +4,14 @@ from unittest import TestCase
|
|||
import torch
|
||||
|
||||
from mmocr.models.textrecog.backbones import ResNet
|
||||
from mmocr.utils import register_all_modules
|
||||
|
||||
|
||||
class TestResNet(TestCase):
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.img = torch.rand(1, 3, 32, 100)
|
||||
register_all_modules()
|
||||
|
||||
def test_resnet45_aster(self):
|
||||
resnet45_aster = ResNet(
|
||||
|
|
Loading…
Reference in New Issue