mirror of https://github.com/open-mmlab/mmcv.git
[Enhancement] Enhance get_torchvision_models (#1867)
* enhance get_torchvision_models * simplify logic * Dump ckpt in torchvision lower than 0.13.0 to a json file * add json * refactor load urls logic * fix unit test * change url key to lower letters * check torchvision version rather than check torch version in unittest * Fix CI and refine test logic of torchvision version * add comment * support compare pre-release version * support loaad modeel like torchvision * refine comment. * fix unit test and comment * fxi unit test bug * support get model by lower weightspull/1896/head
parent
a5cfcb93ff
commit
a80df6874d
|
@ -0,0 +1,57 @@
|
|||
{
|
||||
"alexnet": "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth",
|
||||
"densenet121": "https://download.pytorch.org/models/densenet121-a639ec97.pth",
|
||||
"densenet169": "https://download.pytorch.org/models/densenet169-b2777c0a.pth",
|
||||
"densenet201": "https://download.pytorch.org/models/densenet201-c1103571.pth",
|
||||
"densenet161": "https://download.pytorch.org/models/densenet161-8d451a50.pth",
|
||||
"efficientnet_b0": "https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth",
|
||||
"efficientnet_b1": "https://download.pytorch.org/models/efficientnet_b1_rwightman-533bc792.pth",
|
||||
"efficientnet_b2": "https://download.pytorch.org/models/efficientnet_b2_rwightman-bcdf34b7.pth",
|
||||
"efficientnet_b3": "https://download.pytorch.org/models/efficientnet_b3_rwightman-cf984f9c.pth",
|
||||
"efficientnet_b4": "https://download.pytorch.org/models/efficientnet_b4_rwightman-7eb33cd5.pth",
|
||||
"efficientnet_b5": "https://download.pytorch.org/models/efficientnet_b5_lukemelas-b6417697.pth",
|
||||
"efficientnet_b6": "https://download.pytorch.org/models/efficientnet_b6_lukemelas-c76e70fd.pth",
|
||||
"efficientnet_b7": "https://download.pytorch.org/models/efficientnet_b7_lukemelas-dcc49843.pth",
|
||||
"googlenet": "https://download.pytorch.org/models/googlenet-1378be20.pth",
|
||||
"inception_v3_google": "https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth",
|
||||
"mobilenet_v2": "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth",
|
||||
"mobilenet_v3_large": "https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth",
|
||||
"mobilenet_v3_small": "https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth",
|
||||
"regnet_y_400mf": "https://download.pytorch.org/models/regnet_y_400mf-c65dace8.pth",
|
||||
"regnet_y_800mf": "https://download.pytorch.org/models/regnet_y_800mf-1b27b58c.pth",
|
||||
"regnet_y_1_6gf": "https://download.pytorch.org/models/regnet_y_1_6gf-b11a554e.pth",
|
||||
"regnet_y_3_2gf": "https://download.pytorch.org/models/regnet_y_3_2gf-b5a9779c.pth",
|
||||
"regnet_y_8gf": "https://download.pytorch.org/models/regnet_y_8gf-d0d0e4a8.pth",
|
||||
"regnet_y_16gf": "https://download.pytorch.org/models/regnet_y_16gf-9e6ed7dd.pth",
|
||||
"regnet_y_32gf": "https://download.pytorch.org/models/regnet_y_32gf-4dee3f7a.pth",
|
||||
"regnet_x_400mf": "https://download.pytorch.org/models/regnet_x_400mf-adf1edd5.pth",
|
||||
"regnet_x_800mf": "https://download.pytorch.org/models/regnet_x_800mf-ad17e45c.pth",
|
||||
"regnet_x_1_6gf": "https://download.pytorch.org/models/regnet_x_1_6gf-e3633e7f.pth",
|
||||
"regnet_x_3_2gf": "https://download.pytorch.org/models/regnet_x_3_2gf-f342aeae.pth",
|
||||
"regnet_x_8gf": "https://download.pytorch.org/models/regnet_x_8gf-03ceed89.pth",
|
||||
"regnet_x_16gf": "https://download.pytorch.org/models/regnet_x_16gf-2007eb11.pth",
|
||||
"regnet_x_32gf": "https://download.pytorch.org/models/regnet_x_32gf-9d47f8d0.pth",
|
||||
"resnet18": "https://download.pytorch.org/models/resnet18-f37072fd.pth",
|
||||
"resnet34": "https://download.pytorch.org/models/resnet34-b627a593.pth",
|
||||
"resnet50": "https://download.pytorch.org/models/resnet50-0676ba61.pth",
|
||||
"resnet101": "https://download.pytorch.org/models/resnet101-63fe2227.pth",
|
||||
"resnet152": "https://download.pytorch.org/models/resnet152-394f9c45.pth",
|
||||
"resnext50_32x4d": "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
|
||||
"resnext101_32x8d": "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth",
|
||||
"wide_resnet50_2": "https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth",
|
||||
"wide_resnet101_2": "https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth",
|
||||
"shufflenetv2_x0.5": "https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth",
|
||||
"shufflenetv2_x1.0": "https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth",
|
||||
"shufflenetv2_x1.5": null,
|
||||
"shufflenetv2_x2.0": null,
|
||||
"squeezenet1_0": "https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth",
|
||||
"squeezenet1_1": "https://download.pytorch.org/models/squeezenet1_1-b8a52dc0.pth",
|
||||
"vgg11": "https://download.pytorch.org/models/vgg11-8a719046.pth",
|
||||
"vgg13": "https://download.pytorch.org/models/vgg13-19584684.pth",
|
||||
"vgg16": "https://download.pytorch.org/models/vgg16-397923af.pth",
|
||||
"vgg19": "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth",
|
||||
"vgg11_bn": "https://download.pytorch.org/models/vgg11_bn-6002323d.pth",
|
||||
"vgg13_bn": "https://download.pytorch.org/models/vgg13_bn-abd245e5.pth",
|
||||
"vgg16_bn": "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth",
|
||||
"vgg19_bn": "https://download.pytorch.org/models/vgg19_bn-c79401a0.pth"
|
||||
}
|
|
@ -18,7 +18,7 @@ import mmcv
|
|||
from ..fileio import FileClient
|
||||
from ..fileio import load as load_file
|
||||
from ..parallel import is_module_wrapper
|
||||
from ..utils import load_url, mkdir_or_exist
|
||||
from ..utils import digit_version, load_url, mkdir_or_exist
|
||||
from .dist_utils import get_dist_info
|
||||
|
||||
ENV_MMCV_HOME = 'MMCV_HOME'
|
||||
|
@ -106,14 +106,48 @@ def load_state_dict(module, state_dict, strict=False, logger=None):
|
|||
|
||||
|
||||
def get_torchvision_models():
|
||||
model_urls = dict()
|
||||
for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__):
|
||||
if ispkg:
|
||||
continue
|
||||
_zoo = import_module(f'torchvision.models.{name}')
|
||||
if hasattr(_zoo, 'model_urls'):
|
||||
_urls = getattr(_zoo, 'model_urls')
|
||||
model_urls.update(_urls)
|
||||
if digit_version(torchvision.__version__) < digit_version('0.13.0a0'):
|
||||
model_urls = dict()
|
||||
# When the version of torchvision is lower than 0.13, the model url is
|
||||
# not declared in `torchvision.model.__init__.py`, so we need to
|
||||
# iterate through `torchvision.models.__path__` to get the url for each
|
||||
# model.
|
||||
for _, name, ispkg in pkgutil.walk_packages(
|
||||
torchvision.models.__path__):
|
||||
if ispkg:
|
||||
continue
|
||||
_zoo = import_module(f'torchvision.models.{name}')
|
||||
if hasattr(_zoo, 'model_urls'):
|
||||
_urls = getattr(_zoo, 'model_urls')
|
||||
model_urls.update(_urls)
|
||||
else:
|
||||
# Since torchvision bumps to v0.13, the weight loading logic,
|
||||
# model keys and model urls have been changed. Here the URLs of old
|
||||
# version is loaded to avoid breaking back compatibility. If the
|
||||
# torchvision version>=0.13.0, new URLs will be added. Users can get
|
||||
# the resnet50 checkpoint by setting 'resnet50.imagent1k_v1',
|
||||
# 'resnet50' or 'ResNet50_Weights.IMAGENET1K_V1' in the config.
|
||||
json_path = osp.join(mmcv.__path__[0],
|
||||
'model_zoo/torchvision_0.12.json')
|
||||
model_urls = mmcv.load(json_path)
|
||||
for cls_name, cls in torchvision.models.__dict__.items():
|
||||
# The name of torchvision model weights classes ends with
|
||||
# `_Weights` such as `ResNet18_Weights`. However, some model weight
|
||||
# classes, such as `MNASNet0_75_Weights` does not have any urls in
|
||||
# torchvision 0.13.0 and cannot be iterated. Here we simply check
|
||||
# `DEFAULT` attribute to ensure the class is not empty.
|
||||
if (not cls_name.endswith('_Weights')
|
||||
or not hasattr(cls, 'DEFAULT')):
|
||||
continue
|
||||
# Since `cls.DEFAULT` can not be accessed by iterating cls, we set
|
||||
# default urls explicitly.
|
||||
cls_key = cls_name.replace('_Weights', '').lower()
|
||||
model_urls[f'{cls_key}.default'] = cls.DEFAULT.url
|
||||
for weight_enum in cls:
|
||||
cls_key = cls_name.replace('_Weights', '').lower()
|
||||
cls_key = f'{cls_key}.{weight_enum.name.lower()}'
|
||||
model_urls[cls_key] = weight_enum.url
|
||||
|
||||
return model_urls
|
||||
|
||||
|
||||
|
@ -396,6 +430,11 @@ def load_from_torchvision(filename, map_location=None):
|
|||
model_name = filename[11:]
|
||||
else:
|
||||
model_name = filename[14:]
|
||||
|
||||
# Support getting model urls in the same way as torchvision
|
||||
# `ResNet50_Weights.IMAGENET1K_V1` will be mapped to
|
||||
# resnet50.imagenet1k_v1.
|
||||
model_name = model_name.lower().replace('_weights', '')
|
||||
return load_from_http(model_urls[model_name], map_location=map_location)
|
||||
|
||||
|
||||
|
|
|
@ -4,6 +4,7 @@ import os.path as osp
|
|||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import torchvision
|
||||
|
||||
import mmcv
|
||||
from mmcv.runner.checkpoint import (DEFAULT_CACHE_DIR, ENV_MMCV_HOME,
|
||||
|
@ -11,7 +12,7 @@ from mmcv.runner.checkpoint import (DEFAULT_CACHE_DIR, ENV_MMCV_HOME,
|
|||
_load_checkpoint,
|
||||
get_deprecated_model_names,
|
||||
get_external_models)
|
||||
from mmcv.utils import TORCH_VERSION
|
||||
from mmcv.utils import digit_version
|
||||
|
||||
|
||||
@patch('mmcv.__path__', [osp.join(osp.dirname(__file__), 'data/')])
|
||||
|
@ -77,24 +78,33 @@ def load(filepath, map_location=None):
|
|||
@patch('torch.load', load)
|
||||
def test_load_external_url():
|
||||
# test modelzoo://
|
||||
url = _load_checkpoint('modelzoo://resnet50')
|
||||
if TORCH_VERSION < '1.9.0':
|
||||
assert url == ('url:https://download.pytorch.org/models/resnet50-19c8e'
|
||||
'357.pth')
|
||||
torchvision_version = torchvision.__version__
|
||||
if digit_version(torchvision_version) < digit_version('0.10.0a0'):
|
||||
assert (_load_checkpoint('modelzoo://resnet50') ==
|
||||
'url:https://download.pytorch.org/models/resnet50-19c8e'
|
||||
'357.pth')
|
||||
assert (_load_checkpoint('torchvision://resnet50') ==
|
||||
'url:https://download.pytorch.org/models/resnet50-19c8e'
|
||||
'357.pth')
|
||||
else:
|
||||
# filename of checkpoint is renamed in torch1.9.0
|
||||
assert url == ('url:https://download.pytorch.org/models/resnet50-0676b'
|
||||
'a61.pth')
|
||||
assert (_load_checkpoint('modelzoo://resnet50') ==
|
||||
'url:https://download.pytorch.org/models/resnet50-0676b'
|
||||
'a61.pth')
|
||||
assert (_load_checkpoint('torchvision://resnet50') ==
|
||||
'url:https://download.pytorch.org/models/resnet50-0676b'
|
||||
'a61.pth')
|
||||
|
||||
# test torchvision://
|
||||
url = _load_checkpoint('torchvision://resnet50')
|
||||
if TORCH_VERSION < '1.9.0':
|
||||
assert url == ('url:https://download.pytorch.org/models/resnet50-19c8e'
|
||||
'357.pth')
|
||||
else:
|
||||
# filename of checkpoint is renamed in torch1.9.0
|
||||
assert url == ('url:https://download.pytorch.org/models/resnet50-0676b'
|
||||
'a61.pth')
|
||||
if digit_version(torchvision_version) >= digit_version('0.13.0a0'):
|
||||
# Test load new format torchvision models.
|
||||
assert (
|
||||
_load_checkpoint('torchvision://resnet50.imagenet1k_v1') ==
|
||||
'url:https://download.pytorch.org/models/resnet50-0676ba61.pth')
|
||||
|
||||
assert (
|
||||
_load_checkpoint('torchvision://ResNet50_Weights.IMAGENET1K_V1') ==
|
||||
'url:https://download.pytorch.org/models/resnet50-0676ba61.pth')
|
||||
|
||||
_load_checkpoint('torchvision://resnet50.default')
|
||||
|
||||
# test open-mmlab:// with default MMCV_HOME
|
||||
os.environ.pop(ENV_MMCV_HOME, None)
|
||||
|
|
Loading…
Reference in New Issue