mirror of https://github.com/open-mmlab/mmcv.git
Support default/external json for open-mmlab models (#230)
* support default/external json for open-mmlab models * add local * add more test * add docs * add docs * update docs * refactor * add json in MANIFEST * fixed json typopull/300/head
parent
c63e4d57c3
commit
37d8facfad
|
@ -1,2 +1,3 @@
|
|||
include mmcv/video/optflow_warp/*.hpp mmcv/video/optflow_warp/*.pyx
|
||||
include requirements.txt
|
||||
include mmcv/model_zoo/open_mmlab.json
|
||||
|
|
|
@ -13,6 +13,7 @@ Contents
|
|||
utils.md
|
||||
runner.md
|
||||
cnn.md
|
||||
model_zoo.md
|
||||
api.rst
|
||||
|
||||
|
||||
|
|
|
@ -0,0 +1,33 @@
|
|||
## Model Zoo
|
||||
Besides torchvision pre-trained models, we also provide pre-trained models of following CNN:
|
||||
* VGG Caffe
|
||||
* ResNet Caffe
|
||||
* ResNeXt
|
||||
* ResNet with Group Normalization
|
||||
* ResNet with Group Normalization and Weight Standardization
|
||||
* HRNetV2
|
||||
* Res2Net
|
||||
* RegNet
|
||||
|
||||
### Model URLs in JSON
|
||||
The model zoo links in MMCV are managed by JSON files.
|
||||
The json file consists of key-value pair of model name and its url or path.
|
||||
An example json file could be like:
|
||||
```json
|
||||
{
|
||||
"model_a": "https://example.com/models/model_a_9e5bac.pth",
|
||||
"model_b": "pretrain/model_b_ab3ef2c.pth"
|
||||
}
|
||||
```
|
||||
The default links of the pre-trained models hosted on Open-MMLab AWS could be found [here](../mmcv/model_zoo/open_mmlab.json).
|
||||
|
||||
You may override default links by putting `open-mmlab.json` under `MMCV_HOME`. If `MMCV_HOME` is not find in the environment, `~/.cache/mmcv` will be used by default. You may `export MMCV_HOME=/your/path` to use your own path.
|
||||
|
||||
The external json files will be merged into default one. If the same key presents in both external json and default json, the external one will be used.
|
||||
|
||||
### Load Checkpoint
|
||||
The following types are supported for `filename` argument of `mmcv.load_checkpoint()`.
|
||||
* filepath: The filepath of the checkpoint.
|
||||
* `http://xxx` and `https://xxx`: The link to download the checkpoint. The `SHA256` postfix should be contained in the filename.
|
||||
* `torchvison://xxx`: The model links in `torchvision.models`.Please refer to [torchvision](https://pytorch.org/docs/stable/torchvision/models.html) for details.
|
||||
* `open-mmlab://xxx`: The model links or filepath provided in default and additional json files.
|
|
@ -0,0 +1,33 @@
|
|||
{
|
||||
"vgg16_caffe": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/vgg16_caffe-292e1171.pth",
|
||||
"detectron/resnet50_caffe": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/resnet50_caffe-788b5fa3.pth",
|
||||
"detectron2/resnet50_caffe_bgr": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/resnet50_msra-5891d200.pth",
|
||||
"detectron/resnet101_caffe": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/resnet101_caffe-3ad79236.pth",
|
||||
"detectron2/resnet101_caffe_bgr": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/resnet101_msra-6cc46731.pth",
|
||||
"resnext50_32x4d": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/resnext50-32x4d-0ab1a123.pth",
|
||||
"resnext101_32x4d": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/resnext101_32x4d-a5af3160.pth",
|
||||
"resnext101_64x4d": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/resnext101_64x4d-ee2c6f71.pth",
|
||||
"contrib/resnet50_gn": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/resnet50_gn_thangvubk-ad1730dd.pth",
|
||||
"detectron/resnet50_gn": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/resnet50_gn-9186a21c.pth",
|
||||
"detectron/resnet101_gn": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/resnet101_gn-cac0ab98.pth",
|
||||
"jhu/resnet50_gn_ws": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/resnet50_gn_ws-15beedd8.pth",
|
||||
"jhu/resnet101_gn_ws": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/resnet101_gn_ws-3e3c308c.pth",
|
||||
"jhu/resnext50_32x4d_gn_ws": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/resnext50_32x4d_gn_ws-0d87ac85.pth",
|
||||
"jhu/resnext101_32x4d_gn_ws": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/resnext101_32x4d_gn_ws-34ac1a9e.pth",
|
||||
"jhu/resnext50_32x4d_gn": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/resnext50_32x4d_gn-c7e8b754.pth",
|
||||
"jhu/resnext101_32x4d_gn": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/resnext101_32x4d_gn-ac3bb84e.pth",
|
||||
"msra/hrnetv2_w18": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/hrnetv2_w18-00eb2006.pth",
|
||||
"msra/hrnetv2_w32": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/hrnetv2_w32-dc9eeb4f.pth",
|
||||
"msra/hrnetv2_w40": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/hrnetv2_w40-ed0b031c.pth",
|
||||
"bninception_caffe": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/bn_inception_caffe-ed2e8665.pth",
|
||||
"kin400/i3d_r50_f32s2_k400": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/i3d_r50_f32s2_k400-2c57e077.pth",
|
||||
"kin400/nl3d_r50_f32s2_k400": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/nl3d_r50_f32s2_k400-fa7e7caa.pth",
|
||||
"res2net101_v1d_26w_4s": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/res2net101_v1d_26w_4s_mmdetv2-f0a600f9.pth",
|
||||
"regnetx_800mf": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/regnetx_800mf-1f4be4c7.pth",
|
||||
"regnetx_1.6gf": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/regnetx_1.6gf-5791c176.pth",
|
||||
"regnetx_3.2gf": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/regnetx_3.2gf-c2599b0f.pth",
|
||||
"regnetx_4.0gf": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/regnetx_4.0gf-a88f671e.pth",
|
||||
"regnetx_6.4gf": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/regnetx_6.4gf-006af45d.pth",
|
||||
"regnetx_8.0gf": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/regnetx_8.0gf-3c68abe7.pth",
|
||||
"regnetx_12gf": "https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/regnetx_12gf-4c2a3350.pth"
|
||||
}
|
|
@ -12,41 +12,24 @@ import torchvision
|
|||
from torch.utils import model_zoo
|
||||
|
||||
import mmcv
|
||||
from ..fileio import load as load_file
|
||||
from ..utils import mkdir_or_exist
|
||||
from .dist_utils import get_dist_info
|
||||
|
||||
open_mmlab_model_urls = {
|
||||
'vgg16_caffe': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/vgg16_caffe-292e1171.pth', # noqa: E501
|
||||
'resnet50_caffe': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnet50_caffe-788b5fa3.pth', # noqa: E501
|
||||
'resnet50_caffe_bgr': 'https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/resnet50_msra-5891d200.pth', # noqa: E501
|
||||
'resnet101_caffe': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnet101_caffe-3ad79236.pth', # noqa: E501
|
||||
'resnet101_caffe_bgr': 'https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/resnet101_msra-6cc46731.pth', # noqa: E501
|
||||
'resnext50_32x4d': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnext50-32x4d-0ab1a123.pth', # noqa: E501
|
||||
'resnext101_32x4d': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnext101_32x4d-a5af3160.pth', # noqa: E501
|
||||
'resnext101_64x4d': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnext101_64x4d-ee2c6f71.pth', # noqa: E501
|
||||
'contrib/resnet50_gn': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnet50_gn_thangvubk-ad1730dd.pth', # noqa: E501
|
||||
'detectron/resnet50_gn': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnet50_gn-9186a21c.pth', # noqa: E501
|
||||
'detectron/resnet101_gn': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnet101_gn-cac0ab98.pth', # noqa: E501
|
||||
'jhu/resnet50_gn_ws': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnet50_gn_ws-15beedd8.pth', # noqa: E501
|
||||
'jhu/resnet101_gn_ws': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnet101_gn_ws-3e3c308c.pth', # noqa: E501
|
||||
'jhu/resnext50_32x4d_gn_ws': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnext50_32x4d_gn_ws-0d87ac85.pth', # noqa: E501
|
||||
'jhu/resnext101_32x4d_gn_ws': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnext101_32x4d_gn_ws-34ac1a9e.pth', # noqa: E501
|
||||
'jhu/resnext50_32x4d_gn': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnext50_32x4d_gn-c7e8b754.pth', # noqa: E501
|
||||
'jhu/resnext101_32x4d_gn': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/resnext101_32x4d_gn-ac3bb84e.pth', # noqa: E501
|
||||
'msra/hrnetv2_w18': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/hrnetv2_w18-00eb2006.pth', # noqa: E501
|
||||
'msra/hrnetv2_w32': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/hrnetv2_w32-dc9eeb4f.pth', # noqa: E501
|
||||
'msra/hrnetv2_w40': 'https://s3.ap-northeast-2.amazonaws.com/open-mmlab/pretrain/third_party/hrnetv2_w40-ed0b031c.pth', # noqa: E501
|
||||
'bninception_caffe': 'https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/bn_inception_caffe-ed2e8665.pth', # noqa: E501
|
||||
'kin400/i3d_r50_f32s2_k400': 'https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/i3d_r50_f32s2_k400-2c57e077.pth', # noqa: E501
|
||||
'kin400/nl3d_r50_f32s2_k400': 'https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/nl3d_r50_f32s2_k400-fa7e7caa.pth', # noqa: E501
|
||||
'regnetx_800mf': 'https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/regnetx_800mf-1f4be4c7.pth', # noqa: E501
|
||||
'regnetx_1.6gf': 'https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/regnetx_1.6gf-5791c176.pth', # noqa: E501
|
||||
'regnetx_3.2gf': 'https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/regnetx_3.2gf-c2599b0f.pth', # noqa: E501
|
||||
'regnetx_4.0gf': 'https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/regnetx_4.0gf-a88f671e.pth', # noqa: E501
|
||||
'regnetx_6.4gf': 'https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/regnetx_6.4gf-006af45d.pth', # noqa: E501
|
||||
'regnetx_8.0gf': 'https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/regnetx_8.0gf-3c68abe7.pth', # noqa: E501
|
||||
'regnetx_12gf': 'https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/regnetx_12gf-4c2a3350.pth', # noqa: E501
|
||||
'res2net101_v1d_26w_4s': 'https://open-mmlab.s3.ap-northeast-2.amazonaws.com/pretrain/third_party/res2net101_v1d_26w_4s_mmdetv2-f0a600f9.pth', # noqa: E501
|
||||
} # yapf: disable
|
||||
ENV_MMCV_HOME = 'MMCV_HOME'
|
||||
ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
|
||||
DEFAULT_CACHE_DIR = '~/.cache'
|
||||
|
||||
|
||||
def _get_mmcv_home():
|
||||
mmcv_home = os.path.expanduser(
|
||||
os.getenv(
|
||||
ENV_MMCV_HOME,
|
||||
os.path.join(
|
||||
os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'mmcv')))
|
||||
|
||||
mkdir_or_exist(mmcv_home)
|
||||
return mmcv_home
|
||||
|
||||
|
||||
def load_state_dict(module, state_dict, strict=False, logger=None):
|
||||
|
@ -113,17 +96,17 @@ def load_state_dict(module, state_dict, strict=False, logger=None):
|
|||
print(err_msg)
|
||||
|
||||
|
||||
def load_url_dist(url):
|
||||
def load_url_dist(url, model_dir=None):
|
||||
""" In distributed setting, this function only download checkpoint at
|
||||
local rank 0 """
|
||||
rank, world_size = get_dist_info()
|
||||
rank = int(os.environ.get('LOCAL_RANK', rank))
|
||||
if rank == 0:
|
||||
checkpoint = model_zoo.load_url(url)
|
||||
checkpoint = model_zoo.load_url(url, model_dir=model_dir)
|
||||
if world_size > 1:
|
||||
torch.distributed.barrier()
|
||||
if rank > 0:
|
||||
checkpoint = model_zoo.load_url(url)
|
||||
checkpoint = model_zoo.load_url(url, model_dir=model_dir)
|
||||
return checkpoint
|
||||
|
||||
|
||||
|
@ -139,11 +122,27 @@ def get_torchvision_models():
|
|||
return model_urls
|
||||
|
||||
|
||||
def get_external_models():
|
||||
mmcv_home = _get_mmcv_home()
|
||||
default_json_path = osp.join(mmcv.__path__[0], 'model_zoo/open_mmlab.json')
|
||||
default_urls = load_file(default_json_path)
|
||||
assert isinstance(default_urls, dict)
|
||||
external_json_path = osp.join(mmcv_home, 'open_mmlab.json')
|
||||
if osp.exists(external_json_path):
|
||||
external_urls = load_file(external_json_path)
|
||||
assert isinstance(external_urls, dict)
|
||||
default_urls.update(external_urls)
|
||||
|
||||
return default_urls
|
||||
|
||||
|
||||
def _load_checkpoint(filename, map_location=None):
|
||||
"""Load checkpoint from somewhere (modelzoo, file, url).
|
||||
|
||||
Args:
|
||||
filename (str): Either a filepath or URI.
|
||||
filename (str): Accept local filepath, URL, ``torchvision://xxx``,
|
||||
``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
|
||||
details.
|
||||
map_location (str | None): Same as :func:`torch.load`. Default: None.
|
||||
|
||||
Returns:
|
||||
|
@ -162,8 +161,17 @@ def _load_checkpoint(filename, map_location=None):
|
|||
model_name = filename[14:]
|
||||
checkpoint = load_url_dist(model_urls[model_name])
|
||||
elif filename.startswith('open-mmlab://'):
|
||||
model_urls = get_external_models()
|
||||
model_name = filename[13:]
|
||||
checkpoint = load_url_dist(open_mmlab_model_urls[model_name])
|
||||
model_url = model_urls[model_name]
|
||||
# check if is url
|
||||
if model_url.startswith(('http://', 'https://')):
|
||||
checkpoint = load_url_dist(model_url)
|
||||
else:
|
||||
filename = osp.join(_get_mmcv_home(), model_url)
|
||||
if not osp.isfile(filename):
|
||||
raise IOError(f'{filename} is not a checkpoint file')
|
||||
checkpoint = torch.load(filename, map_location=map_location)
|
||||
elif filename.startswith(('http://', 'https://')):
|
||||
checkpoint = load_url_dist(filename)
|
||||
else:
|
||||
|
@ -182,7 +190,9 @@ def load_checkpoint(model,
|
|||
|
||||
Args:
|
||||
model (Module): Module to load checkpoint.
|
||||
filename (str): Either a filepath or URL or modelzoo://xxxxxxx.
|
||||
filename (str): Accept local filepath, URL, ``torchvision://xxx``,
|
||||
``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
|
||||
details.
|
||||
map_location (str): Same as :func:`torch.load`.
|
||||
strict (bool): Whether to allow different params for the model and
|
||||
checkpoint.
|
||||
|
|
|
@ -0,0 +1,5 @@
|
|||
{
|
||||
"test": "test.pth",
|
||||
"val": "val.pth",
|
||||
"train_empty": "train.pth"
|
||||
}
|
Binary file not shown.
Binary file not shown.
|
@ -0,0 +1,4 @@
|
|||
{
|
||||
"train": "https://localhost/train.pth",
|
||||
"test": "https://localhost/test.pth"
|
||||
}
|
|
@ -0,0 +1,96 @@
|
|||
# Copyright (c) Open-MMLab. All rights reserved.
|
||||
import os
|
||||
import os.path as osp
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
import mmcv
|
||||
from mmcv.runner.checkpoint import (DEFAULT_CACHE_DIR, ENV_MMCV_HOME,
|
||||
ENV_XDG_CACHE_HOME, _get_mmcv_home,
|
||||
_load_checkpoint, get_external_models)
|
||||
|
||||
|
||||
@patch('mmcv.__path__', [osp.join(osp.dirname(__file__), 'data/')])
|
||||
def test_set_mmcv_home():
|
||||
os.environ.pop(ENV_MMCV_HOME, None)
|
||||
mmcv_home = osp.join(osp.dirname(__file__), 'data/model_zoo/mmcv_home/')
|
||||
os.environ[ENV_MMCV_HOME] = mmcv_home
|
||||
assert _get_mmcv_home() == mmcv_home
|
||||
|
||||
|
||||
@patch('mmcv.__path__', [osp.join(osp.dirname(__file__), 'data/')])
|
||||
def test_default_mmcv_home():
|
||||
os.environ.pop(ENV_MMCV_HOME, None)
|
||||
os.environ.pop(ENV_XDG_CACHE_HOME, None)
|
||||
assert _get_mmcv_home() == os.path.expanduser(
|
||||
os.path.join(DEFAULT_CACHE_DIR, 'mmcv'))
|
||||
model_urls = get_external_models()
|
||||
assert model_urls == mmcv.load(
|
||||
osp.join(mmcv.__path__[0], 'model_zoo/open_mmlab.json'))
|
||||
|
||||
|
||||
@patch('mmcv.__path__', [osp.join(osp.dirname(__file__), 'data/')])
|
||||
def test_get_external_models():
|
||||
os.environ.pop(ENV_MMCV_HOME, None)
|
||||
mmcv_home = osp.join(osp.dirname(__file__), 'data/model_zoo/mmcv_home/')
|
||||
os.environ[ENV_MMCV_HOME] = mmcv_home
|
||||
ext_urls = get_external_models()
|
||||
assert ext_urls == {
|
||||
'train': 'https://localhost/train.pth',
|
||||
'test': 'test.pth',
|
||||
'val': 'val.pth',
|
||||
'train_empty': 'train.pth'
|
||||
}
|
||||
|
||||
|
||||
def load_url_dist(url):
|
||||
return 'url:' + url
|
||||
|
||||
|
||||
def load(filepath, map_location=None):
|
||||
return 'local:' + filepath
|
||||
|
||||
|
||||
@patch('mmcv.__path__', [osp.join(osp.dirname(__file__), 'data/')])
|
||||
@patch('mmcv.runner.checkpoint.load_url_dist', load_url_dist)
|
||||
@patch('torch.load', load)
|
||||
def test_load_external_url():
|
||||
# test modelzoo://
|
||||
url = _load_checkpoint('modelzoo://resnet50')
|
||||
assert url == 'url:https://download.pytorch.org/models/resnet50-19c8e357' \
|
||||
'.pth'
|
||||
|
||||
# test torchvision://
|
||||
url = _load_checkpoint('torchvision://resnet50')
|
||||
assert url == 'url:https://download.pytorch.org/models/resnet50-19c8e357' \
|
||||
'.pth'
|
||||
|
||||
# test open-mmlab:// with default MMCV_HOME
|
||||
os.environ.pop(ENV_MMCV_HOME, None)
|
||||
os.environ.pop(ENV_XDG_CACHE_HOME, None)
|
||||
url = _load_checkpoint('open-mmlab://train')
|
||||
assert url == 'url:https://localhost/train.pth'
|
||||
|
||||
# test open-mmlab:// with user-defined MMCV_HOME
|
||||
os.environ.pop(ENV_MMCV_HOME, None)
|
||||
mmcv_home = osp.join(osp.dirname(__file__), 'data/model_zoo/mmcv_home')
|
||||
os.environ[ENV_MMCV_HOME] = mmcv_home
|
||||
url = _load_checkpoint('open-mmlab://train')
|
||||
assert url == 'url:https://localhost/train.pth'
|
||||
with pytest.raises(IOError, match='train.pth is not a checkpoint ' 'file'):
|
||||
_load_checkpoint('open-mmlab://train_empty')
|
||||
url = _load_checkpoint('open-mmlab://test')
|
||||
assert url == f'local:{osp.join(_get_mmcv_home(), "test.pth")}'
|
||||
url = _load_checkpoint('open-mmlab://val')
|
||||
assert url == f'local:{osp.join(_get_mmcv_home(), "val.pth")}'
|
||||
|
||||
# test http:// https://
|
||||
url = _load_checkpoint('http://localhost/train.pth')
|
||||
assert url == 'url:http://localhost/train.pth'
|
||||
|
||||
# test local file
|
||||
with pytest.raises(IOError, match='train.pth is not a checkpoint ' 'file'):
|
||||
_load_checkpoint('train.pth')
|
||||
url = _load_checkpoint(osp.join(_get_mmcv_home(), 'test.pth'))
|
||||
assert url == f'local:{osp.join(_get_mmcv_home(), "test.pth")}'
|
Loading…
Reference in New Issue