[Refactor] Use MMCV MODEL_REGISTRY (#515)

* [Refactor] Use MMCV MODEL_REGISTRY

* fixed args
pull/529/head
Jerry Jiarui XU 2021-04-27 23:51:09 -07:00 committed by GitHub
parent 2da3da47ed
commit d568d06e75
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 14 additions and 34 deletions

View File

@ -1,56 +1,35 @@
import warnings import warnings
from mmcv.utils import Registry, build_from_cfg from mmcv.cnn import MODELS as MMCV_MODELS
from torch import nn from mmcv.utils import Registry
BACKBONES = Registry('backbone') MODELS = Registry('models', parent=MMCV_MODELS)
NECKS = Registry('neck')
HEADS = Registry('head')
LOSSES = Registry('loss')
SEGMENTORS = Registry('segmentor')
BACKBONES = MODELS
def build(cfg, registry, default_args=None): NECKS = MODELS
"""Build a module. HEADS = MODELS
LOSSES = MODELS
Args: SEGMENTORS = MODELS
cfg (dict, list[dict]): The config of modules, is is either a dict
or a list of configs.
registry (:obj:`Registry`): A registry the module belongs to.
default_args (dict, optional): Default arguments to build the module.
Defaults to None.
Returns:
nn.Module: A built nn module.
"""
if isinstance(cfg, list):
modules = [
build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
]
return nn.Sequential(*modules)
else:
return build_from_cfg(cfg, registry, default_args)
def build_backbone(cfg): def build_backbone(cfg):
"""Build backbone.""" """Build backbone."""
return build(cfg, BACKBONES) return BACKBONES.build(cfg)
def build_neck(cfg): def build_neck(cfg):
"""Build neck.""" """Build neck."""
return build(cfg, NECKS) return NECKS.build(cfg)
def build_head(cfg): def build_head(cfg):
"""Build head.""" """Build head."""
return build(cfg, HEADS) return HEADS.build(cfg)
def build_loss(cfg): def build_loss(cfg):
"""Build loss.""" """Build loss."""
return build(cfg, LOSSES) return LOSSES.build(cfg)
def build_segmentor(cfg, train_cfg=None, test_cfg=None): def build_segmentor(cfg, train_cfg=None, test_cfg=None):
@ -63,4 +42,5 @@ def build_segmentor(cfg, train_cfg=None, test_cfg=None):
'train_cfg specified in both outer field and model field ' 'train_cfg specified in both outer field and model field '
assert cfg.get('test_cfg') is None or test_cfg is None, \ assert cfg.get('test_cfg') is None or test_cfg is None, \
'test_cfg specified in both outer field and model field ' 'test_cfg specified in both outer field and model field '
return build(cfg, SEGMENTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg)) return SEGMENTORS.build(
cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg))