Support default/external json for open-mmlab models (#230)

* support default/external json for open-mmlab models

* add local

* add more test

* add docs

* add docs

* update docs

* refactor

* add json in MANIFEST

* fixed json typo
pull/300/head
Jerry Jiarui XU 2020-05-27 17:12:43 +08:00 committed by GitHub
parent c63e4d57c3
commit 37d8facfad
10 changed files with 222 additions and 39 deletions

View File

@ -1,2 +1,3 @@
include mmcv/video/optflow_warp/*.hpp mmcv/video/optflow_warp/*.pyx
include requirements.txt
include mmcv/model_zoo/open_mmlab.json

View File

@ -13,6 +13,7 @@ Contents
utils.md
runner.md
cnn.md
model_zoo.md
api.rst

33
docs/model_zoo.md 100644
View File

@ -0,0 +1,33 @@
## Model Zoo
Besides torchvision pre-trained models, we also provide pre-trained models of following CNN:
* VGG Caffe
* ResNet Caffe
* ResNeXt
* ResNet with Group Normalization
* ResNet with Group Normalization and Weight Standardization
* HRNetV2
* Res2Net
* RegNet
### Model URLs in JSON
The model zoo links in MMCV are managed by JSON files.
The json file consists of key-value pair of model name and its url or path.
An example json file could be like:
```json
{
"model_a": "https://example.com/models/model_a_9e5bac.pth",
"model_b": "pretrain/model_b_ab3ef2c.pth"
}
```
The default links of the pre-trained models hosted on Open-MMLab AWS could be found [here](../mmcv/model_zoo/open_mmlab.json).
You may override default links by putting `open-mmlab.json` under `MMCV_HOME`. If `MMCV_HOME` is not find in the environment, `~/.cache/mmcv` will be used by default. You may `export MMCV_HOME=/your/path` to use your own path.
The external json files will be merged into default one. If the same key presents in both external json and default json, the external one will be used.
### Load Checkpoint
The following types are supported for `filename` argument of `mmcv.load_checkpoint()`.
* filepath: The filepath of the checkpoint.
* `http://xxx` and `https://xxx`: The link to download the checkpoint. The `SHA256` postfix should be contained in the filename.
* `torchvison://xxx`: The model links in `torchvision.models`.Please refer to [torchvision](https://pytorch.org/docs/stable/torchvision/models.html) for details.
* `open-mmlab://xxx`: The model links or filepath provided in default and additional json files.

View File

@ -0,0 +1,33 @@
{
"vgg16_caffe": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/vgg16_caffe-292e1171.pth",
"detectron/resnet50_caffe": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/resnet50_caffe-788b5fa3.pth",
"detectron2/resnet50_caffe_bgr": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/resnet50_msra-5891d200.pth",
"detectron/resnet101_caffe": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/resnet101_caffe-3ad79236.pth",
"detectron2/resnet101_caffe_bgr": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/resnet101_msra-6cc46731.pth",
"resnext50_32x4d": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/resnext50-32x4d-0ab1a123.pth",
"resnext101_32x4d": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/resnext101_32x4d-a5af3160.pth",
"resnext101_64x4d": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/resnext101_64x4d-ee2c6f71.pth",
"contrib/resnet50_gn": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/resnet50_gn_thangvubk-ad1730dd.pth",
"detectron/resnet50_gn": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/resnet50_gn-9186a21c.pth",
"detectron/resnet101_gn": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/resnet101_gn-cac0ab98.pth",
"jhu/resnet50_gn_ws": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/resnet50_gn_ws-15beedd8.pth",
"jhu/resnet101_gn_ws": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/resnet101_gn_ws-3e3c308c.pth",
"jhu/resnext50_32x4d_gn_ws": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/resnext50_32x4d_gn_ws-0d87ac85.pth",
"jhu/resnext101_32x4d_gn_ws": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/resnext101_32x4d_gn_ws-34ac1a9e.pth",
"jhu/resnext50_32x4d_gn": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/resnext50_32x4d_gn-c7e8b754.pth",
"jhu/resnext101_32x4d_gn": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/resnext101_32x4d_gn-ac3bb84e.pth",
"msra/hrnetv2_w18": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/hrnetv2_w18-00eb2006.pth",
"msra/hrnetv2_w32": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/hrnetv2_w32-dc9eeb4f.pth",
"msra/hrnetv2_w40": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/hrnetv2_w40-ed0b031c.pth",
"bninception_caffe": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/bn_inception_caffe-ed2e8665.pth",
"kin400/i3d_r50_f32s2_k400": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/i3d_r50_f32s2_k400-2c57e077.pth",
"kin400/nl3d_r50_f32s2_k400": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/nl3d_r50_f32s2_k400-fa7e7caa.pth",
"res2net101_v1d_26w_4s": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/res2net101_v1d_26w_4s_mmdetv2-f0a600f9.pth",
"regnetx_800mf": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/regnetx_800mf-1f4be4c7.pth",
"regnetx_1.6gf": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/regnetx_1.6gf-5791c176.pth",
"regnetx_3.2gf": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/regnetx_3.2gf-c2599b0f.pth",
"regnetx_4.0gf": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/regnetx_4.0gf-a88f671e.pth",
"regnetx_6.4gf": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/regnetx_6.4gf-006af45d.pth",
"regnetx_8.0gf": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/regnetx_8.0gf-3c68abe7.pth",
"regnetx_12gf": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/regnetx_12gf-4c2a3350.pth"
}

View File

@ -12,41 +12,24 @@ import torchvision
from torch.utils import model_zoo
import mmcv
from ..fileio import load as load_file
from ..utils import mkdir_or_exist
from .dist_utils import get_dist_info
open_mmlab_model_urls = {
'vgg16_caffe': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/vgg16_caffe-292e1171.pth', # noqa: E501
'resnet50_caffe': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnet50_caffe-788b5fa3.pth', # noqa: E501
'resnet50_caffe_bgr': 'https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/resnet50_msra-5891d200.pth', # noqa: E501
'resnet101_caffe': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnet101_caffe-3ad79236.pth', # noqa: E501
'resnet101_caffe_bgr': 'https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/resnet101_msra-6cc46731.pth', # noqa: E501
'resnext50_32x4d': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnext50-32x4d-0ab1a123.pth', # noqa: E501
'resnext101_32x4d': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnext101_32x4d-a5af3160.pth', # noqa: E501
'resnext101_64x4d': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnext101_64x4d-ee2c6f71.pth', # noqa: E501
'contrib/resnet50_gn': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnet50_gn_thangvubk-ad1730dd.pth', # noqa: E501
'detectron/resnet50_gn': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnet50_gn-9186a21c.pth', # noqa: E501
'detectron/resnet101_gn': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnet101_gn-cac0ab98.pth', # noqa: E501
'jhu/resnet50_gn_ws': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnet50_gn_ws-15beedd8.pth', # noqa: E501
'jhu/resnet101_gn_ws': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnet101_gn_ws-3e3c308c.pth', # noqa: E501
'jhu/resnext50_32x4d_gn_ws': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnext50_32x4d_gn_ws-0d87ac85.pth', # noqa: E501
'jhu/resnext101_32x4d_gn_ws': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnext101_32x4d_gn_ws-34ac1a9e.pth', # noqa: E501
'jhu/resnext50_32x4d_gn': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnext50_32x4d_gn-c7e8b754.pth', # noqa: E501
'jhu/resnext101_32x4d_gn': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnext101_32x4d_gn-ac3bb84e.pth', # noqa: E501
'msra/hrnetv2_w18': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/hrnetv2_w18-00eb2006.pth', # noqa: E501
'msra/hrnetv2_w32': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/hrnetv2_w32-dc9eeb4f.pth', # noqa: E501
'msra/hrnetv2_w40': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/hrnetv2_w40-ed0b031c.pth', # noqa: E501
'bninception_caffe': 'https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/bn_inception_caffe-ed2e8665.pth', # noqa: E501
'kin400/i3d_r50_f32s2_k400': 'https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/i3d_r50_f32s2_k400-2c57e077.pth', # noqa: E501
'kin400/nl3d_r50_f32s2_k400': 'https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/nl3d_r50_f32s2_k400-fa7e7caa.pth', # noqa: E501
'regnetx_800mf': 'https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/regnetx_800mf-1f4be4c7.pth', # noqa: E501
'regnetx_1.6gf': 'https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/regnetx_1.6gf-5791c176.pth', # noqa: E501
'regnetx_3.2gf': 'https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/regnetx_3.2gf-c2599b0f.pth', # noqa: E501
'regnetx_4.0gf': 'https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/regnetx_4.0gf-a88f671e.pth', # noqa: E501
'regnetx_6.4gf': 'https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/regnetx_6.4gf-006af45d.pth', # noqa: E501
'regnetx_8.0gf': 'https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/regnetx_8.0gf-3c68abe7.pth', # noqa: E501
'regnetx_12gf': 'https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/regnetx_12gf-4c2a3350.pth', # noqa: E501
'res2net101_v1d_26w_4s': 'https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/res2net101_v1d_26w_4s_mmdetv2-f0a600f9.pth', # noqa: E501
} # yapf: disable
ENV_MMCV_HOME = 'MMCV_HOME'
ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
DEFAULT_CACHE_DIR = '~/.cache'
def _get_mmcv_home():
mmcv_home = os.path.expanduser(
os.getenv(
ENV_MMCV_HOME,
os.path.join(
os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'mmcv')))
mkdir_or_exist(mmcv_home)
return mmcv_home
def load_state_dict(module, state_dict, strict=False, logger=None):
@ -113,17 +96,17 @@ def load_state_dict(module, state_dict, strict=False, logger=None):
print(err_msg)
def load_url_dist(url):
def load_url_dist(url, model_dir=None):
""" In distributed setting, this function only download checkpoint at
local rank 0 """
rank, world_size = get_dist_info()
rank = int(os.environ.get('LOCAL_RANK', rank))
if rank == 0:
checkpoint = model_zoo.load_url(url)
checkpoint = model_zoo.load_url(url, model_dir=model_dir)
if world_size > 1:
torch.distributed.barrier()
if rank > 0:
checkpoint = model_zoo.load_url(url)
checkpoint = model_zoo.load_url(url, model_dir=model_dir)
return checkpoint
@ -139,11 +122,27 @@ def get_torchvision_models():
return model_urls
def get_external_models():
mmcv_home = _get_mmcv_home()
default_json_path = osp.join(mmcv.__path__[0], 'model_zoo/open_mmlab.json')
default_urls = load_file(default_json_path)
assert isinstance(default_urls, dict)
external_json_path = osp.join(mmcv_home, 'open_mmlab.json')
if osp.exists(external_json_path):
external_urls = load_file(external_json_path)
assert isinstance(external_urls, dict)
default_urls.update(external_urls)
return default_urls
def _load_checkpoint(filename, map_location=None):
"""Load checkpoint from somewhere (modelzoo, file, url).
Args:
filename (str): Either a filepath or URI.
filename (str): Accept local filepath, URL, ``torchvision://xxx``,
``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
details.
map_location (str | None): Same as :func:`torch.load`. Default: None.
Returns:
@ -162,8 +161,17 @@ def _load_checkpoint(filename, map_location=None):
model_name = filename[14:]
checkpoint = load_url_dist(model_urls[model_name])
elif filename.startswith('open-mmlab://'):
model_urls = get_external_models()
model_name = filename[13:]
checkpoint = load_url_dist(open_mmlab_model_urls[model_name])
model_url = model_urls[model_name]
# check if is url
if model_url.startswith(('http://', 'https://')):
checkpoint = load_url_dist(model_url)
else:
filename = osp.join(_get_mmcv_home(), model_url)
if not osp.isfile(filename):
raise IOError(f'{filename} is not a checkpoint file')
checkpoint = torch.load(filename, map_location=map_location)
elif filename.startswith(('http://', 'https://')):
checkpoint = load_url_dist(filename)
else:
@ -182,7 +190,9 @@ def load_checkpoint(model,
Args:
model (Module): Module to load checkpoint.
filename (str): Either a filepath or URL or modelzoo://xxxxxxx.
filename (str): Accept local filepath, URL, ``torchvision://xxx``,
``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
details.
map_location (str): Same as :func:`torch.load`.
strict (bool): Whether to allow different params for the model and
checkpoint.

View File

@ -0,0 +1,5 @@
{
"test": "test.pth",
"val": "val.pth",
"train_empty": "train.pth"
}

Binary file not shown.

Binary file not shown.

View File

@ -0,0 +1,4 @@
{
"train": "https://localhost/train.pth",
"test": "https://localhost/test.pth"
}

View File

@ -0,0 +1,96 @@
# Copyright (c) Open-MMLab. All rights reserved.
import os
import os.path as osp
from unittest.mock import patch
import pytest
import mmcv
from mmcv.runner.checkpoint import (DEFAULT_CACHE_DIR, ENV_MMCV_HOME,
ENV_XDG_CACHE_HOME, _get_mmcv_home,
_load_checkpoint, get_external_models)
@patch('mmcv.__path__', [osp.join(osp.dirname(__file__), 'data/')])
def test_set_mmcv_home():
os.environ.pop(ENV_MMCV_HOME, None)
mmcv_home = osp.join(osp.dirname(__file__), 'data/model_zoo/mmcv_home/')
os.environ[ENV_MMCV_HOME] = mmcv_home
assert _get_mmcv_home() == mmcv_home
@patch('mmcv.__path__', [osp.join(osp.dirname(__file__), 'data/')])
def test_default_mmcv_home():
os.environ.pop(ENV_MMCV_HOME, None)
os.environ.pop(ENV_XDG_CACHE_HOME, None)
assert _get_mmcv_home() == os.path.expanduser(
os.path.join(DEFAULT_CACHE_DIR, 'mmcv'))
model_urls = get_external_models()
assert model_urls == mmcv.load(
osp.join(mmcv.__path__[0], 'model_zoo/open_mmlab.json'))
@patch('mmcv.__path__', [osp.join(osp.dirname(__file__), 'data/')])
def test_get_external_models():
os.environ.pop(ENV_MMCV_HOME, None)
mmcv_home = osp.join(osp.dirname(__file__), 'data/model_zoo/mmcv_home/')
os.environ[ENV_MMCV_HOME] = mmcv_home
ext_urls = get_external_models()
assert ext_urls == {
'train': 'https://localhost/train.pth',
'test': 'test.pth',
'val': 'val.pth',
'train_empty': 'train.pth'
}
def load_url_dist(url):
return 'url:' + url
def load(filepath, map_location=None):
return 'local:' + filepath
@patch('mmcv.__path__', [osp.join(osp.dirname(__file__), 'data/')])
@patch('mmcv.runner.checkpoint.load_url_dist', load_url_dist)
@patch('torch.load', load)
def test_load_external_url():
# test modelzoo://
url = _load_checkpoint('modelzoo://resnet50')
assert url == 'url:https://download.pytorch.org/models/resnet50-19c8e357' \
'.pth'
# test torchvision://
url = _load_checkpoint('torchvision://resnet50')
assert url == 'url:https://download.pytorch.org/models/resnet50-19c8e357' \
'.pth'
# test open-mmlab:// with default MMCV_HOME
os.environ.pop(ENV_MMCV_HOME, None)
os.environ.pop(ENV_XDG_CACHE_HOME, None)
url = _load_checkpoint('open-mmlab://train')
assert url == 'url:https://localhost/train.pth'
# test open-mmlab:// with user-defined MMCV_HOME
os.environ.pop(ENV_MMCV_HOME, None)
mmcv_home = osp.join(osp.dirname(__file__), 'data/model_zoo/mmcv_home')
os.environ[ENV_MMCV_HOME] = mmcv_home
url = _load_checkpoint('open-mmlab://train')
assert url == 'url:https://localhost/train.pth'
with pytest.raises(IOError, match='train.pth is not a checkpoint ' 'file'):
_load_checkpoint('open-mmlab://train_empty')
url = _load_checkpoint('open-mmlab://test')
assert url == f'local:{osp.join(_get_mmcv_home(), "test.pth")}'
url = _load_checkpoint('open-mmlab://val')
assert url == f'local:{osp.join(_get_mmcv_home(), "val.pth")}'
# test http:// https://
url = _load_checkpoint('http://localhost/train.pth')
assert url == 'url:http://localhost/train.pth'
# test local file
with pytest.raises(IOError, match='train.pth is not a checkpoint ' 'file'):
_load_checkpoint('train.pth')
url = _load_checkpoint(osp.join(_get_mmcv_home(), 'test.pth'))
assert url == f'local:{osp.join(_get_mmcv_home(), "test.pth")}'