diff --git a/mmcv/model_zoo/torchvision_0.12.json b/mmcv/model_zoo/torchvision_0.12.json new file mode 100644 index 000000000..06defe674 --- /dev/null +++ b/mmcv/model_zoo/torchvision_0.12.json @@ -0,0 +1,57 @@ +{ + "alexnet": "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth", + "densenet121": "https://download.pytorch.org/models/densenet121-a639ec97.pth", + "densenet169": "https://download.pytorch.org/models/densenet169-b2777c0a.pth", + "densenet201": "https://download.pytorch.org/models/densenet201-c1103571.pth", + "densenet161": "https://download.pytorch.org/models/densenet161-8d451a50.pth", + "efficientnet_b0": "https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth", + "efficientnet_b1": "https://download.pytorch.org/models/efficientnet_b1_rwightman-533bc792.pth", + "efficientnet_b2": "https://download.pytorch.org/models/efficientnet_b2_rwightman-bcdf34b7.pth", + "efficientnet_b3": "https://download.pytorch.org/models/efficientnet_b3_rwightman-cf984f9c.pth", + "efficientnet_b4": "https://download.pytorch.org/models/efficientnet_b4_rwightman-7eb33cd5.pth", + "efficientnet_b5": "https://download.pytorch.org/models/efficientnet_b5_lukemelas-b6417697.pth", + "efficientnet_b6": "https://download.pytorch.org/models/efficientnet_b6_lukemelas-c76e70fd.pth", + "efficientnet_b7": "https://download.pytorch.org/models/efficientnet_b7_lukemelas-dcc49843.pth", + "googlenet": "https://download.pytorch.org/models/googlenet-1378be20.pth", + "inception_v3_google": "https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth", + "mobilenet_v2": "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth", + "mobilenet_v3_large": "https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth", + "mobilenet_v3_small": "https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth", + "regnet_y_400mf": "https://download.pytorch.org/models/regnet_y_400mf-c65dace8.pth", + "regnet_y_800mf": "https://download.pytorch.org/models/regnet_y_800mf-1b27b58c.pth", + "regnet_y_1_6gf": "https://download.pytorch.org/models/regnet_y_1_6gf-b11a554e.pth", + "regnet_y_3_2gf": "https://download.pytorch.org/models/regnet_y_3_2gf-b5a9779c.pth", + "regnet_y_8gf": "https://download.pytorch.org/models/regnet_y_8gf-d0d0e4a8.pth", + "regnet_y_16gf": "https://download.pytorch.org/models/regnet_y_16gf-9e6ed7dd.pth", + "regnet_y_32gf": "https://download.pytorch.org/models/regnet_y_32gf-4dee3f7a.pth", + "regnet_x_400mf": "https://download.pytorch.org/models/regnet_x_400mf-adf1edd5.pth", + "regnet_x_800mf": "https://download.pytorch.org/models/regnet_x_800mf-ad17e45c.pth", + "regnet_x_1_6gf": "https://download.pytorch.org/models/regnet_x_1_6gf-e3633e7f.pth", + "regnet_x_3_2gf": "https://download.pytorch.org/models/regnet_x_3_2gf-f342aeae.pth", + "regnet_x_8gf": "https://download.pytorch.org/models/regnet_x_8gf-03ceed89.pth", + "regnet_x_16gf": "https://download.pytorch.org/models/regnet_x_16gf-2007eb11.pth", + "regnet_x_32gf": "https://download.pytorch.org/models/regnet_x_32gf-9d47f8d0.pth", + "resnet18": "https://download.pytorch.org/models/resnet18-f37072fd.pth", + "resnet34": "https://download.pytorch.org/models/resnet34-b627a593.pth", + "resnet50": "https://download.pytorch.org/models/resnet50-0676ba61.pth", + "resnet101": "https://download.pytorch.org/models/resnet101-63fe2227.pth", + "resnet152": "https://download.pytorch.org/models/resnet152-394f9c45.pth", + "resnext50_32x4d": "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth", + "resnext101_32x8d": "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth", + "wide_resnet50_2": "https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth", + "wide_resnet101_2": "https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth", + "shufflenetv2_x0.5": "https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth", + "shufflenetv2_x1.0": "https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth", + "shufflenetv2_x1.5": null, + "shufflenetv2_x2.0": null, + "squeezenet1_0": "https://download.pytorch.org/models/squeezenet1_0-b66bff10.pth", + "squeezenet1_1": "https://download.pytorch.org/models/squeezenet1_1-b8a52dc0.pth", + "vgg11": "https://download.pytorch.org/models/vgg11-8a719046.pth", + "vgg13": "https://download.pytorch.org/models/vgg13-19584684.pth", + "vgg16": "https://download.pytorch.org/models/vgg16-397923af.pth", + "vgg19": "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth", + "vgg11_bn": "https://download.pytorch.org/models/vgg11_bn-6002323d.pth", + "vgg13_bn": "https://download.pytorch.org/models/vgg13_bn-abd245e5.pth", + "vgg16_bn": "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth", + "vgg19_bn": "https://download.pytorch.org/models/vgg19_bn-c79401a0.pth" +} diff --git a/mmcv/runner/checkpoint.py b/mmcv/runner/checkpoint.py index 7eaa0816c..835ee725a 100644 --- a/mmcv/runner/checkpoint.py +++ b/mmcv/runner/checkpoint.py @@ -18,7 +18,7 @@ import mmcv from ..fileio import FileClient from ..fileio import load as load_file from ..parallel import is_module_wrapper -from ..utils import load_url, mkdir_or_exist +from ..utils import digit_version, load_url, mkdir_or_exist from .dist_utils import get_dist_info ENV_MMCV_HOME = 'MMCV_HOME' @@ -106,14 +106,48 @@ def load_state_dict(module, state_dict, strict=False, logger=None): def get_torchvision_models(): - model_urls = dict() - for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__): - if ispkg: - continue - _zoo = import_module(f'torchvision.models.{name}') - if hasattr(_zoo, 'model_urls'): - _urls = getattr(_zoo, 'model_urls') - model_urls.update(_urls) + if digit_version(torchvision.__version__) < digit_version('0.13.0a0'): + model_urls = dict() + # When the version of torchvision is lower than 0.13, the model url is + # not declared in `torchvision.model.__init__.py`, so we need to + # iterate through `torchvision.models.__path__` to get the url for each + # model. + for _, name, ispkg in pkgutil.walk_packages( + torchvision.models.__path__): + if ispkg: + continue + _zoo = import_module(f'torchvision.models.{name}') + if hasattr(_zoo, 'model_urls'): + _urls = getattr(_zoo, 'model_urls') + model_urls.update(_urls) + else: + # Since torchvision bumps to v0.13, the weight loading logic, + # model keys and model urls have been changed. Here the URLs of old + # version is loaded to avoid breaking back compatibility. If the + # torchvision version>=0.13.0, new URLs will be added. Users can get + # the resnet50 checkpoint by setting 'resnet50.imagent1k_v1', + # 'resnet50' or 'ResNet50_Weights.IMAGENET1K_V1' in the config. + json_path = osp.join(mmcv.__path__[0], + 'model_zoo/torchvision_0.12.json') + model_urls = mmcv.load(json_path) + for cls_name, cls in torchvision.models.__dict__.items(): + # The name of torchvision model weights classes ends with + # `_Weights` such as `ResNet18_Weights`. However, some model weight + # classes, such as `MNASNet0_75_Weights` does not have any urls in + # torchvision 0.13.0 and cannot be iterated. Here we simply check + # `DEFAULT` attribute to ensure the class is not empty. + if (not cls_name.endswith('_Weights') + or not hasattr(cls, 'DEFAULT')): + continue + # Since `cls.DEFAULT` can not be accessed by iterating cls, we set + # default urls explicitly. + cls_key = cls_name.replace('_Weights', '').lower() + model_urls[f'{cls_key}.default'] = cls.DEFAULT.url + for weight_enum in cls: + cls_key = cls_name.replace('_Weights', '').lower() + cls_key = f'{cls_key}.{weight_enum.name.lower()}' + model_urls[cls_key] = weight_enum.url + return model_urls @@ -396,6 +430,11 @@ def load_from_torchvision(filename, map_location=None): model_name = filename[11:] else: model_name = filename[14:] + + # Support getting model urls in the same way as torchvision + # `ResNet50_Weights.IMAGENET1K_V1` will be mapped to + # resnet50.imagenet1k_v1. + model_name = model_name.lower().replace('_weights', '') return load_from_http(model_urls[model_name], map_location=map_location) diff --git a/tests/test_load_model_zoo.py b/tests/test_load_model_zoo.py index 35492fa8a..904cb9403 100644 --- a/tests/test_load_model_zoo.py +++ b/tests/test_load_model_zoo.py @@ -4,6 +4,7 @@ import os.path as osp from unittest.mock import patch import pytest +import torchvision import mmcv from mmcv.runner.checkpoint import (DEFAULT_CACHE_DIR, ENV_MMCV_HOME, @@ -11,7 +12,7 @@ from mmcv.runner.checkpoint import (DEFAULT_CACHE_DIR, ENV_MMCV_HOME, _load_checkpoint, get_deprecated_model_names, get_external_models) -from mmcv.utils import TORCH_VERSION +from mmcv.utils import digit_version @patch('mmcv.__path__', [osp.join(osp.dirname(__file__), 'data/')]) @@ -77,24 +78,33 @@ def load(filepath, map_location=None): @patch('torch.load', load) def test_load_external_url(): # test modelzoo:// - url = _load_checkpoint('modelzoo://resnet50') - if TORCH_VERSION < '1.9.0': - assert url == ('url:https://download.pytorch.org/models/resnet50-19c8e' - '357.pth') + torchvision_version = torchvision.__version__ + if digit_version(torchvision_version) < digit_version('0.10.0a0'): + assert (_load_checkpoint('modelzoo://resnet50') == + 'url:https://download.pytorch.org/models/resnet50-19c8e' + '357.pth') + assert (_load_checkpoint('torchvision://resnet50') == + 'url:https://download.pytorch.org/models/resnet50-19c8e' + '357.pth') else: - # filename of checkpoint is renamed in torch1.9.0 - assert url == ('url:https://download.pytorch.org/models/resnet50-0676b' - 'a61.pth') + assert (_load_checkpoint('modelzoo://resnet50') == + 'url:https://download.pytorch.org/models/resnet50-0676b' + 'a61.pth') + assert (_load_checkpoint('torchvision://resnet50') == + 'url:https://download.pytorch.org/models/resnet50-0676b' + 'a61.pth') - # test torchvision:// - url = _load_checkpoint('torchvision://resnet50') - if TORCH_VERSION < '1.9.0': - assert url == ('url:https://download.pytorch.org/models/resnet50-19c8e' - '357.pth') - else: - # filename of checkpoint is renamed in torch1.9.0 - assert url == ('url:https://download.pytorch.org/models/resnet50-0676b' - 'a61.pth') + if digit_version(torchvision_version) >= digit_version('0.13.0a0'): + # Test load new format torchvision models. + assert ( + _load_checkpoint('torchvision://resnet50.imagenet1k_v1') == + 'url:https://download.pytorch.org/models/resnet50-0676ba61.pth') + + assert ( + _load_checkpoint('torchvision://ResNet50_Weights.IMAGENET1K_V1') == + 'url:https://download.pytorch.org/models/resnet50-0676ba61.pth') + + _load_checkpoint('torchvision://resnet50.default') # test open-mmlab:// with default MMCV_HOME os.environ.pop(ENV_MMCV_HOME, None)