[Enhancement] Enhance get_torchvision_models (#1867)

* enhance get_torchvision_models

* simplify logic

* Dump ckpt in torchvision lower than 0.13.0 to a json file

* add json

* refactor load urls logic

* fix unit test

* change url key to lower letters

* check torchvision version rather than check torch version in unittest

* Fix CI and refine test logic of torchvision version

* add comment

* support compare pre-release version

* support loaad modeel like torchvision

* refine comment.

* fix unit test and comment

* fxi unit test bug

* support get model by lower weights
pull/1896/head
Mashiro 2022-04-18 23:20:48 +08:00 committed by GitHub
parent a5cfcb93ff
commit a80df6874d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 132 additions and 26 deletions

View File

@ -0,0 +1,57 @@
{
"alexnet": "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth",
"densenet121": "https://download.pytorch.org/models/densenet121-a639ec97.pth",
"densenet169": "https://download.pytorch.org/models/densenet169-b2777c0a.pth",
"densenet201": "https://download.pytorch.org/models/densenet201-c1103571.pth",
"densenet161": "https://download.pytorch.org/models/densenet161-8d451a50.pth",
"efficientnet_b0": "https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth",
"efficientnet_b1": "https://download.pytorch.org/models/efficientnet_b1_rwightman-533bc792.pth",
"efficientnet_b2": "https://download.pytorch.org/models/efficientnet_b2_rwightman-bcdf34b7.pth",
"efficientnet_b3": "https://download.pytorch.org/models/efficientnet_b3_rwightman-cf984f9c.pth",
"efficientnet_b4": "https://download.pytorch.org/models/efficientnet_b4_rwightman-7eb33cd5.pth",
"efficientnet_b5": "https://download.pytorch.org/models/efficientnet_b5_lukemelas-b6417697.pth",
"efficientnet_b6": "https://download.pytorch.org/models/efficientnet_b6_lukemelas-c76e70fd.pth",
"efficientnet_b7": "https://download.pytorch.org/models/efficientnet_b7_lukemelas-dcc49843.pth",
"googlenet": "https://download.pytorch.org/models/googlenet-1378be20.pth",
"inception_v3_google": "https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth",
"mobilenet_v2": "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth",
"mobilenet_v3_large": "https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth",
"mobilenet_v3_small": "https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth",
"regnet_y_400mf": "https://download.pytorch.org/models/regnet_y_400mf-c65dace8.pth",
"regnet_y_800mf": "https://download.pytorch.org/models/regnet_y_800mf-1b27b58c.pth",
"regnet_y_1_6gf": "https://download.pytorch.org/models/regnet_y_1_6gf-b11a554e.pth",
"regnet_y_3_2gf": "https://download.pytorch.org/models/regnet_y_3_2gf-b5a9779c.pth",
"regnet_y_8gf": "https://download.pytorch.org/models/regnet_y_8gf-d0d0e4a8.pth",
"regnet_y_16gf": "https://download.pytorch.org/models/regnet_y_16gf-9e6ed7dd.pth",
"regnet_y_32gf": "https://download.pytorch.org/models/regnet_y_32gf-4dee3f7a.pth",
"regnet_x_400mf": "https://download.pytorch.org/models/regnet_x_400mf-adf1edd5.pth",
"regnet_x_800mf": "https://download.pytorch.org/models/regnet_x_800mf-ad17e45c.pth",
"regnet_x_1_6gf": "https://download.pytorch.org/models/regnet_x_1_6gf-e3633e7f.pth",
"regnet_x_3_2gf": "https://download.pytorch.org/models/regnet_x_3_2gf-f342aeae.pth",
"regnet_x_8gf": "https://download.pytorch.org/models/regnet_x_8gf-03ceed89.pth",
"regnet_x_16gf": "https://download.pytorch.org/models/regnet_x_16gf-2007eb11.pth",
"regnet_x_32gf": "https://download.pytorch.org/models/regnet_x_32gf-9d47f8d0.pth",
"resnet18": "https://download.pytorch.org/models/resnet18-f37072fd.pth",
"resnet34": "https://download.pytorch.org/models/resnet34-b627a593.pth",
"resnet50": "https://download.pytorch.org/models/resnet50-0676ba61.pth",
"resnet101": "https://download.pytorch.org/models/resnet101-63fe2227.pth",
"resnet152": "https://download.pytorch.org/models/resnet152-394f9c45.pth",
"resnext50_32x4d": "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
"resnext101_32x8d": "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth",
"wide_resnet50_2": "https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth",
"wide_resnet101_2": "https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth",
"shufflenetv2_x0.5": "https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth",
"shufflenetv2_x1.0": "https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth",
"shufflenetv2_x1.5": null,
"shufflenetv2_x2.0": null,
"squeezenet1_0": "https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth",
"squeezenet1_1": "https://download.pytorch.org/models/squeezenet1_1-b8a52dc0.pth",
"vgg11": "https://download.pytorch.org/models/vgg11-8a719046.pth",
"vgg13": "https://download.pytorch.org/models/vgg13-19584684.pth",
"vgg16": "https://download.pytorch.org/models/vgg16-397923af.pth",
"vgg19": "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth",
"vgg11_bn": "https://download.pytorch.org/models/vgg11_bn-6002323d.pth",
"vgg13_bn": "https://download.pytorch.org/models/vgg13_bn-abd245e5.pth",
"vgg16_bn": "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth",
"vgg19_bn": "https://download.pytorch.org/models/vgg19_bn-c79401a0.pth"
}

View File

@ -18,7 +18,7 @@ import mmcv
from ..fileio import FileClient
from ..fileio import load as load_file
from ..parallel import is_module_wrapper
from ..utils import load_url, mkdir_or_exist
from ..utils import digit_version, load_url, mkdir_or_exist
from .dist_utils import get_dist_info
ENV_MMCV_HOME = 'MMCV_HOME'
@ -106,14 +106,48 @@ def load_state_dict(module, state_dict, strict=False, logger=None):
def get_torchvision_models():
model_urls = dict()
for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__):
if ispkg:
continue
_zoo = import_module(f'torchvision.models.{name}')
if hasattr(_zoo, 'model_urls'):
_urls = getattr(_zoo, 'model_urls')
model_urls.update(_urls)
if digit_version(torchvision.__version__) < digit_version('0.13.0a0'):
model_urls = dict()
# When the version of torchvision is lower than 0.13, the model url is
# not declared in `torchvision.model.__init__.py`, so we need to
# iterate through `torchvision.models.__path__` to get the url for each
# model.
for _, name, ispkg in pkgutil.walk_packages(
torchvision.models.__path__):
if ispkg:
continue
_zoo = import_module(f'torchvision.models.{name}')
if hasattr(_zoo, 'model_urls'):
_urls = getattr(_zoo, 'model_urls')
model_urls.update(_urls)
else:
# Since torchvision bumps to v0.13, the weight loading logic,
# model keys and model urls have been changed. Here the URLs of old
# version is loaded to avoid breaking back compatibility. If the
# torchvision version>=0.13.0, new URLs will be added. Users can get
# the resnet50 checkpoint by setting 'resnet50.imagent1k_v1',
# 'resnet50' or 'ResNet50_Weights.IMAGENET1K_V1' in the config.
json_path = osp.join(mmcv.__path__[0],
'model_zoo/torchvision_0.12.json')
model_urls = mmcv.load(json_path)
for cls_name, cls in torchvision.models.__dict__.items():
# The name of torchvision model weights classes ends with
# `_Weights` such as `ResNet18_Weights`. However, some model weight
# classes, such as `MNASNet0_75_Weights` does not have any urls in
# torchvision 0.13.0 and cannot be iterated. Here we simply check
# `DEFAULT` attribute to ensure the class is not empty.
if (not cls_name.endswith('_Weights')
or not hasattr(cls, 'DEFAULT')):
continue
# Since `cls.DEFAULT` can not be accessed by iterating cls, we set
# default urls explicitly.
cls_key = cls_name.replace('_Weights', '').lower()
model_urls[f'{cls_key}.default'] = cls.DEFAULT.url
for weight_enum in cls:
cls_key = cls_name.replace('_Weights', '').lower()
cls_key = f'{cls_key}.{weight_enum.name.lower()}'
model_urls[cls_key] = weight_enum.url
return model_urls
@ -396,6 +430,11 @@ def load_from_torchvision(filename, map_location=None):
model_name = filename[11:]
else:
model_name = filename[14:]
# Support getting model urls in the same way as torchvision
# `ResNet50_Weights.IMAGENET1K_V1` will be mapped to
# resnet50.imagenet1k_v1.
model_name = model_name.lower().replace('_weights', '')
return load_from_http(model_urls[model_name], map_location=map_location)

View File

@ -4,6 +4,7 @@ import os.path as osp
from unittest.mock import patch
import pytest
import torchvision
import mmcv
from mmcv.runner.checkpoint import (DEFAULT_CACHE_DIR, ENV_MMCV_HOME,
@ -11,7 +12,7 @@ from mmcv.runner.checkpoint import (DEFAULT_CACHE_DIR, ENV_MMCV_HOME,
_load_checkpoint,
get_deprecated_model_names,
get_external_models)
from mmcv.utils import TORCH_VERSION
from mmcv.utils import digit_version
@patch('mmcv.__path__', [osp.join(osp.dirname(__file__), 'data/')])
@ -77,24 +78,33 @@ def load(filepath, map_location=None):
@patch('torch.load', load)
def test_load_external_url():
# test modelzoo://
url = _load_checkpoint('modelzoo://resnet50')
if TORCH_VERSION < '1.9.0':
assert url == ('url:https://download.pytorch.org/models/resnet50-19c8e'
'357.pth')
torchvision_version = torchvision.__version__
if digit_version(torchvision_version) < digit_version('0.10.0a0'):
assert (_load_checkpoint('modelzoo://resnet50') ==
'url:https://download.pytorch.org/models/resnet50-19c8e'
'357.pth')
assert (_load_checkpoint('torchvision://resnet50') ==
'url:https://download.pytorch.org/models/resnet50-19c8e'
'357.pth')
else:
# filename of checkpoint is renamed in torch1.9.0
assert url == ('url:https://download.pytorch.org/models/resnet50-0676b'
'a61.pth')
assert (_load_checkpoint('modelzoo://resnet50') ==
'url:https://download.pytorch.org/models/resnet50-0676b'
'a61.pth')
assert (_load_checkpoint('torchvision://resnet50') ==
'url:https://download.pytorch.org/models/resnet50-0676b'
'a61.pth')
# test torchvision://
url = _load_checkpoint('torchvision://resnet50')
if TORCH_VERSION < '1.9.0':
assert url == ('url:https://download.pytorch.org/models/resnet50-19c8e'
'357.pth')
else:
# filename of checkpoint is renamed in torch1.9.0
assert url == ('url:https://download.pytorch.org/models/resnet50-0676b'
'a61.pth')
if digit_version(torchvision_version) >= digit_version('0.13.0a0'):
# Test load new format torchvision models.
assert (
_load_checkpoint('torchvision://resnet50.imagenet1k_v1') ==
'url:https://download.pytorch.org/models/resnet50-0676ba61.pth')
assert (
_load_checkpoint('torchvision://ResNet50_Weights.IMAGENET1K_V1') ==
'url:https://download.pytorch.org/models/resnet50-0676ba61.pth')
_load_checkpoint('torchvision://resnet50.default')
# test open-mmlab:// with default MMCV_HOME
os.environ.pop(ENV_MMCV_HOME, None)