BC of model zoo: add deprecate urls (#301)

* add deprecate urls

* add deprecate urls

* warning test

* rename to deprecated.json
This commit is contained in:
Jerry Jiarui XU 2020-05-27 22:09:06 +08:00 committed by GitHub
parent 37d8facfad
commit d6411b7fff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 50 additions and 2 deletions

View File

@ -1,3 +1,3 @@
include mmcv/video/optflow_warp/*.hpp mmcv/video/optflow_warp/*.pyx
include requirements.txt
include mmcv/model_zoo/open_mmlab.json
include mmcv/model_zoo/open_mmlab.json mmcv/model_zoo/deprecated.json

View File

@ -0,0 +1,6 @@
{
"resnet50_caffe": "detectron/resnet50_caffe",
"resnet50_caffe_bgr": "detectron2/resnet50_caffe_bgr",
"resnet101_caffe": "detectron/resnet101_caffe",
"resnet101_caffe_bgr": "detectron2/resnet101_caffe_bgr"
}

View File

@ -136,6 +136,15 @@ def get_external_models():
return default_urls
def get_deprecated_model_names():
deprecate_json_path = osp.join(mmcv.__path__[0],
'model_zoo/deprecated.json')
deprecate_urls = load_file(deprecate_json_path)
assert isinstance(deprecate_urls, dict)
return deprecate_urls
def _load_checkpoint(filename, map_location=None):
"""Load checkpoint from somewhere (modelzoo, file, url).
@ -163,6 +172,11 @@ def _load_checkpoint(filename, map_location=None):
elif filename.startswith('open-mmlab://'):
model_urls = get_external_models()
model_name = filename[13:]
deprecated_urls = get_deprecated_model_names()
if model_name in deprecated_urls:
warnings.warn(f'open-mmlab://{model_name} is deprecated in favor '
f'of open-mmlab://{deprecated_urls[model_name]}')
model_name = deprecated_urls[model_name]
model_url = model_urls[model_name]
# check if is url
if model_url.startswith(('http://', 'https://')):

View File

@ -0,0 +1,4 @@
{
"train_old": "train",
"test_old": "test"
}

View File

@ -8,7 +8,9 @@ import pytest
import mmcv
from mmcv.runner.checkpoint import (DEFAULT_CACHE_DIR, ENV_MMCV_HOME,
ENV_XDG_CACHE_HOME, _get_mmcv_home,
_load_checkpoint, get_external_models)
_load_checkpoint,
get_deprecated_model_names,
get_external_models)
@patch('mmcv.__path__', [osp.join(osp.dirname(__file__), 'data/')])
@ -44,6 +46,18 @@ def test_get_external_models():
}
@patch('mmcv.__path__', [osp.join(osp.dirname(__file__), 'data/')])
def test_get_deprecated_models():
os.environ.pop(ENV_MMCV_HOME, None)
mmcv_home = osp.join(osp.dirname(__file__), 'data/model_zoo/mmcv_home/')
os.environ[ENV_MMCV_HOME] = mmcv_home
dep_urls = get_deprecated_model_names()
assert dep_urls == {
'train_old': 'train',
'test_old': 'test',
}
def load_url_dist(url):
return 'url:' + url
@ -72,6 +86,16 @@ def test_load_external_url():
url = _load_checkpoint('open-mmlab://train')
assert url == 'url:https://localhost/train.pth'
# test open-mmlab:// with deprecated model name
os.environ.pop(ENV_MMCV_HOME, None)
os.environ.pop(ENV_XDG_CACHE_HOME, None)
with pytest.warns(
Warning,
match='open-mmlab://train_old is deprecated in favor of '
'open-mmlab://train'):
url = _load_checkpoint('open-mmlab://train_old')
assert url == 'url:https://localhost/train.pth'
# test open-mmlab:// with user-defined MMCV_HOME
os.environ.pop(ENV_MMCV_HOME, None)
mmcv_home = osp.join(osp.dirname(__file__), 'data/model_zoo/mmcv_home')