fix mmlab utils

Link: https://code.alibaba-inc.com/pai-vision/EasyCV/codereview/10316166

    * fix mmlab utils
pull/207/head
jiangnana.jnn 2022-10-11 13:47:35 +08:00
parent 0a796ec2f1
commit ae51a8c341
2 changed files with 33 additions and 16 deletions

View File

@ -21,24 +21,14 @@ try:
from mmcv.runner.hooks import HOOKS
import mmdet
HOOKS._module_dict.pop('YOLOXLrUpdaterHook', None)
from mmdet.models.builder import MODELS as MMMODELS
from mmdet.models.builder import BACKBONES as MMBACKBONES
from mmdet.models.builder import NECKS as MMNECKS
from mmdet.models.builder import HEADS as MMHEADS
from mmdet.core import BitmapMasks, PolygonMasks, encode_mask_results
from mmdet.core.mask import mask2bbox
MM_REGISTRY = {
MMDET: {
'model': MMMODELS,
'backbone': MMBACKBONES,
'neck': MMNECKS,
'head': MMHEADS
}
}
MM_ORIGINAL_REGISTRY = copy.deepcopy(MM_REGISTRY)
except ImportError:
pass
MM_REGISTRY = None
MM_ORIGINAL_REGISTRY = None
EASYCV_REGISTRY_MAP = {
'model': MODELS,
'backbone': BACKBONES,
@ -167,6 +157,24 @@ class MMAdapter:
@staticmethod
def reset_mm_registry():
global MM_ORIGINAL_REGISTRY
global MM_REGISTRY
if MM_REGISTRY is None:
from mmdet.models.builder import MODELS as MMMODELS
from mmdet.models.builder import BACKBONES as MMBACKBONES
from mmdet.models.builder import NECKS as MMNECKS
from mmdet.models.builder import HEADS as MMHEADS
MM_REGISTRY = {
MMDET: {
'model': MMMODELS,
'backbone': MMBACKBONES,
'neck': MMNECKS,
'head': MMHEADS
}
}
MM_ORIGINAL_REGISTRY = copy.deepcopy(MM_REGISTRY)
for mmtype, registries in MM_ORIGINAL_REGISTRY.items():
for k, ori_v in registries.items():
MM_REGISTRY[mmtype][k]._module_dict = copy.deepcopy(

View File

@ -12,8 +12,7 @@ from easycv.apis.test import single_gpu_test
from easycv.datasets import build_dataloader, build_dataset
from easycv.models.builder import build_model
from easycv.utils.config_tools import mmcv_config_fromfile
from easycv.utils.mmlab_utils import (MM_REGISTRY, MMDET,
dynamic_adapt_for_mmlab,
from easycv.utils.mmlab_utils import (dynamic_adapt_for_mmlab,
remove_adapt_for_mmlab)
@ -128,7 +127,17 @@ class MMLabUtilTest(unittest.TestCase):
def test_reset(self):
model = self._get_model()
remove_adapt_for_mmlab(self.cfg)
mmdet_registry = MM_REGISTRY[MMDET]
from mmdet.models.builder import MODELS as MMMODELS
from mmdet.models.builder import BACKBONES as MMBACKBONES
from mmdet.models.builder import NECKS as MMNECKS
from mmdet.models.builder import HEADS as MMHEADS
mmdet_registry = {
'model': MMMODELS,
'backbone': MMBACKBONES,
'neck': MMNECKS,
'head': MMHEADS
}
for module, registry in mmdet_registry.items():
for k, v in registry.module_dict.items():
self.assertTrue('easycv' not in str(v))