mirror of https://github.com/open-mmlab/mmocr.git
[Fix] Remove dependency on MMCV registry (#1261)
* [Fix] Remove dependency on MMCV registry * fixpull/1292/head
parent
6b6d833be4
commit
0d9b40706c
|
@ -1,59 +0,0 @@
|
||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
|
||||||
import torch.nn as nn
|
|
||||||
from mmcv.cnn import ACTIVATION_LAYERS as MMCV_ACTIVATION_LAYERS
|
|
||||||
from mmcv.cnn import UPSAMPLE_LAYERS as MMCV_UPSAMPLE_LAYERS
|
|
||||||
from mmcv.utils import Registry, build_from_cfg
|
|
||||||
|
|
||||||
UPSAMPLE_LAYERS = Registry('upsample layer', parent=MMCV_UPSAMPLE_LAYERS)
|
|
||||||
ACTIVATION_LAYERS = Registry('activation layer', parent=MMCV_ACTIVATION_LAYERS)
|
|
||||||
|
|
||||||
|
|
||||||
def build_upsample_layer(cfg, *args, **kwargs):
|
|
||||||
"""Build upsample layer.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
cfg (dict): The upsample layer config, which should contain:
|
|
||||||
|
|
||||||
- type (str): Layer type.
|
|
||||||
- scale_factor (int): Upsample ratio, which is not applicable to
|
|
||||||
deconv.
|
|
||||||
- layer args: Args needed to instantiate a upsample layer.
|
|
||||||
args (argument list): Arguments passed to the ``__init__``
|
|
||||||
method of the corresponding conv layer.
|
|
||||||
kwargs (keyword arguments): Keyword arguments passed to the
|
|
||||||
``__init__`` method of the corresponding conv layer.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
nn.Module: Created upsample layer.
|
|
||||||
"""
|
|
||||||
if not isinstance(cfg, dict):
|
|
||||||
raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
|
|
||||||
if 'type' not in cfg:
|
|
||||||
raise KeyError(
|
|
||||||
f'the cfg dict must contain the key "type", but got {cfg}')
|
|
||||||
cfg_ = cfg.copy()
|
|
||||||
|
|
||||||
layer_type = cfg_.pop('type')
|
|
||||||
if layer_type not in UPSAMPLE_LAYERS:
|
|
||||||
raise KeyError(f'Unrecognized upsample type {layer_type}')
|
|
||||||
else:
|
|
||||||
upsample = UPSAMPLE_LAYERS.get(layer_type)
|
|
||||||
|
|
||||||
if upsample is nn.Upsample:
|
|
||||||
cfg_['mode'] = layer_type
|
|
||||||
layer = upsample(*args, **kwargs, **cfg_)
|
|
||||||
return layer
|
|
||||||
|
|
||||||
|
|
||||||
def build_activation_layer(cfg):
|
|
||||||
"""Build activation layer.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
cfg (dict): The activation layer config, which should contain:
|
|
||||||
- type (str): Layer type.
|
|
||||||
- layer args: Args needed to instantiate an activation layer.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
nn.Module: Created activation layer.
|
|
||||||
"""
|
|
||||||
return build_from_cfg(cfg, ACTIVATION_LAYERS)
|
|
|
@ -6,8 +6,6 @@ from mmcv.cnn import ConvModule, build_norm_layer
|
||||||
from mmcv.utils.parrots_wrapper import _BatchNorm
|
from mmcv.utils.parrots_wrapper import _BatchNorm
|
||||||
from mmengine.model import BaseModule
|
from mmengine.model import BaseModule
|
||||||
|
|
||||||
from mmocr.models.builder import (UPSAMPLE_LAYERS, build_activation_layer,
|
|
||||||
build_upsample_layer)
|
|
||||||
from mmocr.registry import MODELS
|
from mmocr.registry import MODELS
|
||||||
|
|
||||||
|
|
||||||
|
@ -81,13 +79,14 @@ class UpConvBlock(nn.Module):
|
||||||
dcn=None,
|
dcn=None,
|
||||||
plugins=None)
|
plugins=None)
|
||||||
if upsample_cfg is not None:
|
if upsample_cfg is not None:
|
||||||
self.upsample = build_upsample_layer(
|
upsample_cfg.update(
|
||||||
cfg=upsample_cfg,
|
dict(
|
||||||
in_channels=in_channels,
|
in_channels=in_channels,
|
||||||
out_channels=skip_channels,
|
out_channels=skip_channels,
|
||||||
with_cp=with_cp,
|
with_cp=with_cp,
|
||||||
norm_cfg=norm_cfg,
|
norm_cfg=norm_cfg,
|
||||||
act_cfg=act_cfg)
|
act_cfg=act_cfg))
|
||||||
|
self.upsample = MODELS.build(upsample_cfg)
|
||||||
else:
|
else:
|
||||||
self.upsample = ConvModule(
|
self.upsample = ConvModule(
|
||||||
in_channels,
|
in_channels,
|
||||||
|
@ -182,7 +181,7 @@ class BasicConvBlock(nn.Module):
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
@UPSAMPLE_LAYERS.register_module()
|
@MODELS.register_module()
|
||||||
class DeconvModule(nn.Module):
|
class DeconvModule(nn.Module):
|
||||||
"""Deconvolution upsample module in decoder for UNet (2X upsample).
|
"""Deconvolution upsample module in decoder for UNet (2X upsample).
|
||||||
|
|
||||||
|
@ -231,7 +230,7 @@ class DeconvModule(nn.Module):
|
||||||
padding=padding)
|
padding=padding)
|
||||||
|
|
||||||
_, norm = build_norm_layer(norm_cfg, out_channels)
|
_, norm = build_norm_layer(norm_cfg, out_channels)
|
||||||
activate = build_activation_layer(act_cfg)
|
activate = MODELS.build(act_cfg)
|
||||||
self.deconv_upsamping = nn.Sequential(deconv, norm, activate)
|
self.deconv_upsamping = nn.Sequential(deconv, norm, activate)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
@ -244,7 +243,7 @@ class DeconvModule(nn.Module):
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
@UPSAMPLE_LAYERS.register_module()
|
@MODELS.register_module()
|
||||||
class InterpConv(nn.Module):
|
class InterpConv(nn.Module):
|
||||||
"""Interpolation upsample module in decoder for UNet.
|
"""Interpolation upsample module in decoder for UNet.
|
||||||
|
|
||||||
|
|
|
@ -35,7 +35,7 @@ class TFEncoderLayer(BaseModule):
|
||||||
d_v=64,
|
d_v=64,
|
||||||
dropout=0.1,
|
dropout=0.1,
|
||||||
qkv_bias=False,
|
qkv_bias=False,
|
||||||
act_cfg=dict(type='mmcv.GELU'),
|
act_cfg=dict(type='mmengine.GELU'),
|
||||||
operation_order=None):
|
operation_order=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.attn = MultiHeadAttention(
|
self.attn = MultiHeadAttention(
|
||||||
|
@ -103,7 +103,7 @@ class TFDecoderLayer(nn.Module):
|
||||||
d_v=64,
|
d_v=64,
|
||||||
dropout=0.1,
|
dropout=0.1,
|
||||||
qkv_bias=False,
|
qkv_bias=False,
|
||||||
act_cfg=dict(type='mmcv.GELU'),
|
act_cfg=dict(type='mmengine.GELU'),
|
||||||
operation_order=None):
|
operation_order=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
|
|
@ -4,7 +4,7 @@ import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from mmocr.models.builder import build_activation_layer
|
from mmocr.registry import MODELS
|
||||||
|
|
||||||
|
|
||||||
class ScaledDotProductAttention(nn.Module):
|
class ScaledDotProductAttention(nn.Module):
|
||||||
|
@ -115,7 +115,7 @@ class PositionwiseFeedForward(nn.Module):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.w_1 = nn.Linear(d_in, d_hid)
|
self.w_1 = nn.Linear(d_in, d_hid)
|
||||||
self.w_2 = nn.Linear(d_hid, d_in)
|
self.w_2 = nn.Linear(d_hid, d_in)
|
||||||
self.act = build_activation_layer(act_cfg)
|
self.act = MODELS.build(act_cfg)
|
||||||
self.dropout = nn.Dropout(dropout)
|
self.dropout = nn.Dropout(dropout)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
|
|
@ -3,11 +3,11 @@ from typing import Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from mmcv.cnn import PLUGIN_LAYERS
|
|
||||||
|
from mmocr.registry import MODELS
|
||||||
|
|
||||||
|
|
||||||
# TODO: Replace PLUGIN_LAYERS with MODELS
|
@MODELS.register_module()
|
||||||
@PLUGIN_LAYERS.register_module()
|
|
||||||
class Maxpool2d(nn.Module):
|
class Maxpool2d(nn.Module):
|
||||||
"""A wrapper around nn.Maxpool2d().
|
"""A wrapper around nn.Maxpool2d().
|
||||||
|
|
||||||
|
@ -36,7 +36,7 @@ class Maxpool2d(nn.Module):
|
||||||
return self.model(x)
|
return self.model(x)
|
||||||
|
|
||||||
|
|
||||||
@PLUGIN_LAYERS.register_module()
|
@MODELS.register_module()
|
||||||
class GCAModule(nn.Module):
|
class GCAModule(nn.Module):
|
||||||
"""GCAModule in MASTER.
|
"""GCAModule in MASTER.
|
||||||
|
|
||||||
|
|
|
@ -4,12 +4,14 @@ from unittest import TestCase
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from mmocr.models.textrecog.backbones import ResNet
|
from mmocr.models.textrecog.backbones import ResNet
|
||||||
|
from mmocr.utils import register_all_modules
|
||||||
|
|
||||||
|
|
||||||
class TestResNet(TestCase):
|
class TestResNet(TestCase):
|
||||||
|
|
||||||
def setUp(self) -> None:
|
def setUp(self) -> None:
|
||||||
self.img = torch.rand(1, 3, 32, 100)
|
self.img = torch.rand(1, 3, 32, 100)
|
||||||
|
register_all_modules()
|
||||||
|
|
||||||
def test_resnet45_aster(self):
|
def test_resnet45_aster(self):
|
||||||
resnet45_aster = ResNet(
|
resnet45_aster = ResNet(
|
||||||
|
|
Loading…
Reference in New Issue