From 5ee08767f208daea6e88da56534a23369d755e8f Mon Sep 17 00:00:00 2001 From: Wenwei Zhang <40779233+ZwwWayne@users.noreply.github.com> Date: Fri, 14 May 2021 23:36:56 +0800 Subject: [PATCH] inherits mmcv registry (#252) --- mmcls/models/builder.py | 44 +++++++++++++++++++---------------------- 1 file changed, 20 insertions(+), 24 deletions(-) diff --git a/mmcls/models/builder.py b/mmcls/models/builder.py index e2cd04754..7b61742df 100644 --- a/mmcls/models/builder.py +++ b/mmcls/models/builder.py @@ -1,38 +1,34 @@ -import torch.nn as nn -from mmcv.utils import Registry, build_from_cfg +from mmcv.cnn import MODELS as MMCV_MODELS +from mmcv.utils import Registry -BACKBONES = Registry('backbone') -CLASSIFIERS = Registry('classifier') -HEADS = Registry('head') -NECKS = Registry('neck') -LOSSES = Registry('loss') +MODELS = Registry('models', parent=MMCV_MODELS) - -def build(cfg, registry, default_args=None): - if isinstance(cfg, list): - modules = [ - build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg - ] - return nn.Sequential(*modules) - else: - return build_from_cfg(cfg, registry, default_args) +BACKBONES = MODELS +NECKS = MODELS +HEADS = MODELS +LOSSES = MODELS +CLASSIFIERS = MODELS def build_backbone(cfg): - return build(cfg, BACKBONES) - - -def build_head(cfg): - return build(cfg, HEADS) + """Build backbone.""" + return BACKBONES.build(cfg) def build_neck(cfg): - return build(cfg, NECKS) + """Build neck.""" + return NECKS.build(cfg) + + +def build_head(cfg): + """Build head.""" + return HEADS.build(cfg) def build_loss(cfg): - return build(cfg, LOSSES) + """Build loss.""" + return LOSSES.build(cfg) def build_classifier(cfg): - return build(cfg, CLASSIFIERS) + return CLASSIFIERS.build(cfg)