diff --git a/MANIFEST.in b/MANIFEST.in index 477116160..5c7867e8e 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,3 +1,3 @@ include mmcv/video/optflow_warp/*.hpp mmcv/video/optflow_warp/*.pyx include requirements.txt -include mmcv/model_zoo/open_mmlab.json +include mmcv/model_zoo/open_mmlab.json mmcv/model_zoo/deprecated.json diff --git a/mmcv/model_zoo/deprecated.json b/mmcv/model_zoo/deprecated.json new file mode 100644 index 000000000..25cf6f28c --- /dev/null +++ b/mmcv/model_zoo/deprecated.json @@ -0,0 +1,6 @@ +{ + "resnet50_caffe": "detectron/resnet50_caffe", + "resnet50_caffe_bgr": "detectron2/resnet50_caffe_bgr", + "resnet101_caffe": "detectron/resnet101_caffe", + "resnet101_caffe_bgr": "detectron2/resnet101_caffe_bgr" +} diff --git a/mmcv/runner/checkpoint.py b/mmcv/runner/checkpoint.py index 64db22458..86bd767d7 100644 --- a/mmcv/runner/checkpoint.py +++ b/mmcv/runner/checkpoint.py @@ -136,6 +136,15 @@ def get_external_models(): return default_urls +def get_deprecated_model_names(): + deprecate_json_path = osp.join(mmcv.__path__[0], + 'model_zoo/deprecated.json') + deprecate_urls = load_file(deprecate_json_path) + assert isinstance(deprecate_urls, dict) + + return deprecate_urls + + def _load_checkpoint(filename, map_location=None): """Load checkpoint from somewhere (modelzoo, file, url). @@ -163,6 +172,11 @@ def _load_checkpoint(filename, map_location=None): elif filename.startswith('open-mmlab://'): model_urls = get_external_models() model_name = filename[13:] + deprecated_urls = get_deprecated_model_names() + if model_name in deprecated_urls: + warnings.warn(f'open-mmlab://{model_name} is deprecated in favor ' + f'of open-mmlab://{deprecated_urls[model_name]}') + model_name = deprecated_urls[model_name] model_url = model_urls[model_name] # check if is url if model_url.startswith(('http://', 'https://')): diff --git a/tests/data/model_zoo/deprecated.json b/tests/data/model_zoo/deprecated.json new file mode 100644 index 000000000..7c2d3e458 --- /dev/null +++ b/tests/data/model_zoo/deprecated.json @@ -0,0 +1,4 @@ +{ + "train_old": "train", + "test_old": "test" +} \ No newline at end of file diff --git a/tests/test_load_model_zoo.py b/tests/test_load_model_zoo.py index a416a8e33..aa13c4f5c 100644 --- a/tests/test_load_model_zoo.py +++ b/tests/test_load_model_zoo.py @@ -8,7 +8,9 @@ import pytest import mmcv from mmcv.runner.checkpoint import (DEFAULT_CACHE_DIR, ENV_MMCV_HOME, ENV_XDG_CACHE_HOME, _get_mmcv_home, - _load_checkpoint, get_external_models) + _load_checkpoint, + get_deprecated_model_names, + get_external_models) @patch('mmcv.__path__', [osp.join(osp.dirname(__file__), 'data/')]) @@ -44,6 +46,18 @@ def test_get_external_models(): } +@patch('mmcv.__path__', [osp.join(osp.dirname(__file__), 'data/')]) +def test_get_deprecated_models(): + os.environ.pop(ENV_MMCV_HOME, None) + mmcv_home = osp.join(osp.dirname(__file__), 'data/model_zoo/mmcv_home/') + os.environ[ENV_MMCV_HOME] = mmcv_home + dep_urls = get_deprecated_model_names() + assert dep_urls == { + 'train_old': 'train', + 'test_old': 'test', + } + + def load_url_dist(url): return 'url:' + url @@ -72,6 +86,16 @@ def test_load_external_url(): url = _load_checkpoint('open-mmlab://train') assert url == 'url:https://localhost/train.pth' + # test open-mmlab:// with deprecated model name + os.environ.pop(ENV_MMCV_HOME, None) + os.environ.pop(ENV_XDG_CACHE_HOME, None) + with pytest.warns( + Warning, + match='open-mmlab://train_old is deprecated in favor of ' + 'open-mmlab://train'): + url = _load_checkpoint('open-mmlab://train_old') + assert url == 'url:https://localhost/train.pth' + # test open-mmlab:// with user-defined MMCV_HOME os.environ.pop(ENV_MMCV_HOME, None) mmcv_home = osp.join(osp.dirname(__file__), 'data/model_zoo/mmcv_home')