mirror of https://github.com/alibaba/EasyCV.git
fix mmlab utils
Link: https://code.alibaba-inc.com/pai-vision/EasyCV/codereview/10316166 * fix mmlab utilspull/207/head
parent
0a796ec2f1
commit
ae51a8c341
|
@ -21,24 +21,14 @@ try:
|
|||
from mmcv.runner.hooks import HOOKS
|
||||
import mmdet
|
||||
HOOKS._module_dict.pop('YOLOXLrUpdaterHook', None)
|
||||
from mmdet.models.builder import MODELS as MMMODELS
|
||||
from mmdet.models.builder import BACKBONES as MMBACKBONES
|
||||
from mmdet.models.builder import NECKS as MMNECKS
|
||||
from mmdet.models.builder import HEADS as MMHEADS
|
||||
from mmdet.core import BitmapMasks, PolygonMasks, encode_mask_results
|
||||
from mmdet.core.mask import mask2bbox
|
||||
MM_REGISTRY = {
|
||||
MMDET: {
|
||||
'model': MMMODELS,
|
||||
'backbone': MMBACKBONES,
|
||||
'neck': MMNECKS,
|
||||
'head': MMHEADS
|
||||
}
|
||||
}
|
||||
MM_ORIGINAL_REGISTRY = copy.deepcopy(MM_REGISTRY)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
MM_REGISTRY = None
|
||||
MM_ORIGINAL_REGISTRY = None
|
||||
|
||||
EASYCV_REGISTRY_MAP = {
|
||||
'model': MODELS,
|
||||
'backbone': BACKBONES,
|
||||
|
@ -167,6 +157,24 @@ class MMAdapter:
|
|||
|
||||
@staticmethod
|
||||
def reset_mm_registry():
|
||||
global MM_ORIGINAL_REGISTRY
|
||||
global MM_REGISTRY
|
||||
|
||||
if MM_REGISTRY is None:
|
||||
from mmdet.models.builder import MODELS as MMMODELS
|
||||
from mmdet.models.builder import BACKBONES as MMBACKBONES
|
||||
from mmdet.models.builder import NECKS as MMNECKS
|
||||
from mmdet.models.builder import HEADS as MMHEADS
|
||||
MM_REGISTRY = {
|
||||
MMDET: {
|
||||
'model': MMMODELS,
|
||||
'backbone': MMBACKBONES,
|
||||
'neck': MMNECKS,
|
||||
'head': MMHEADS
|
||||
}
|
||||
}
|
||||
MM_ORIGINAL_REGISTRY = copy.deepcopy(MM_REGISTRY)
|
||||
|
||||
for mmtype, registries in MM_ORIGINAL_REGISTRY.items():
|
||||
for k, ori_v in registries.items():
|
||||
MM_REGISTRY[mmtype][k]._module_dict = copy.deepcopy(
|
||||
|
|
|
@ -12,8 +12,7 @@ from easycv.apis.test import single_gpu_test
|
|||
from easycv.datasets import build_dataloader, build_dataset
|
||||
from easycv.models.builder import build_model
|
||||
from easycv.utils.config_tools import mmcv_config_fromfile
|
||||
from easycv.utils.mmlab_utils import (MM_REGISTRY, MMDET,
|
||||
dynamic_adapt_for_mmlab,
|
||||
from easycv.utils.mmlab_utils import (dynamic_adapt_for_mmlab,
|
||||
remove_adapt_for_mmlab)
|
||||
|
||||
|
||||
|
@ -128,7 +127,17 @@ class MMLabUtilTest(unittest.TestCase):
|
|||
def test_reset(self):
|
||||
model = self._get_model()
|
||||
remove_adapt_for_mmlab(self.cfg)
|
||||
mmdet_registry = MM_REGISTRY[MMDET]
|
||||
from mmdet.models.builder import MODELS as MMMODELS
|
||||
from mmdet.models.builder import BACKBONES as MMBACKBONES
|
||||
from mmdet.models.builder import NECKS as MMNECKS
|
||||
from mmdet.models.builder import HEADS as MMHEADS
|
||||
mmdet_registry = {
|
||||
'model': MMMODELS,
|
||||
'backbone': MMBACKBONES,
|
||||
'neck': MMNECKS,
|
||||
'head': MMHEADS
|
||||
}
|
||||
|
||||
for module, registry in mmdet_registry.items():
|
||||
for k, v in registry.module_dict.items():
|
||||
self.assertTrue('easycv' not in str(v))
|
||||
|
|
Loading…
Reference in New Issue