diff --git a/mmengine/runner/checkpoint.py b/mmengine/runner/checkpoint.py index 12051eab..efff32d3 100644 --- a/mmengine/runner/checkpoint.py +++ b/mmengine/runner/checkpoint.py @@ -19,7 +19,7 @@ from mmengine.fileio import FileClient, get_file_backend from mmengine.fileio import load as load_file from mmengine.logging import print_log from mmengine.model import BaseTTAModel, is_model_wrapper -from mmengine.utils import deprecated_function, mkdir_or_exist +from mmengine.utils import deprecated_function, digit_version, mkdir_or_exist from mmengine.utils.dl_utils import load_url # `MMENGINE_HOME` is the highest priority directory to save checkpoints @@ -113,14 +113,58 @@ def load_state_dict(module, state_dict, strict=False, logger=None): def get_torchvision_models(): - model_urls = dict() - for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__): - if ispkg: - continue - _zoo = import_module(f'torchvision.models.{name}') - if hasattr(_zoo, 'model_urls'): - _urls = getattr(_zoo, 'model_urls') - model_urls.update(_urls) + if digit_version(torchvision.__version__) < digit_version('0.13.0a0'): + model_urls = dict() + # When the version of torchvision is lower than 0.13, the model url is + # not declared in `torchvision.model.__init__.py`, so we need to + # iterate through `torchvision.models.__path__` to get the url for each + # model. + for _, name, ispkg in pkgutil.walk_packages( + torchvision.models.__path__): + if ispkg: + continue + _zoo = import_module(f'torchvision.models.{name}') + if hasattr(_zoo, 'model_urls'): + _urls = getattr(_zoo, 'model_urls') + model_urls.update(_urls) + else: + # Since torchvision bumps to v0.13, the weight loading logic, + # model keys and model urls have been changed. Here the URLs of old + # version is loaded to avoid breaking back compatibility. If the + # torchvision version>=0.13.0, new URLs will be added. Users can get + # the resnet50 checkpoint by setting 'resnet50.imagent1k_v1', + # 'resnet50' or 'ResNet50_Weights.IMAGENET1K_V1' in the config. + json_path = osp.join(mmengine.__path__[0], 'hub/torchvision_0.12.json') + model_urls = mmengine.load(json_path) + if digit_version(torchvision.__version__) < digit_version('0.14.0a0'): + weights_list = [ + cls for cls_name, cls in torchvision.models.__dict__.items() + if cls_name.endswith('_Weights') + ] + else: + weights_list = [ + torchvision.models.get_model_weights(model) + for model in torchvision.models.list_models(torchvision.models) + ] + + for cls in weights_list: + # The name of torchvision model weights classes ends with + # `_Weights` such as `ResNet18_Weights`. However, some model weight + # classes, such as `MNASNet0_75_Weights` does not have any urls in + # torchvision 0.13.0 and cannot be iterated. Here we simply check + # `DEFAULT` attribute to ensure the class is not empty. + if not hasattr(cls, 'DEFAULT'): + continue + # Since `cls.DEFAULT` can not be accessed by iterating cls, we set + # default urls explicitly. + cls_name = cls.__name__ + cls_key = cls_name.replace('_Weights', '').lower() + model_urls[f'{cls_key}.default'] = cls.DEFAULT.url + for weight_enum in cls: + cls_key = cls_name.replace('_Weights', '').lower() + cls_key = f'{cls_key}.{weight_enum.name.lower()}' + model_urls[cls_key] = weight_enum.url + return model_urls diff --git a/tests/test_runner/test_checkpoint.py b/tests/test_runner/test_checkpoint.py new file mode 100644 index 00000000..fd7d6d28 --- /dev/null +++ b/tests/test_runner/test_checkpoint.py @@ -0,0 +1,411 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os +import sys +import tempfile +from collections import OrderedDict +from tempfile import TemporaryDirectory +from unittest.mock import MagicMock, patch + +import pytest +import torch +import torch.nn as nn +import torch.optim as optim +from torch.nn.parallel import DataParallel + +from mmengine.fileio.file_client import PetrelBackend +from mmengine.registry import MODEL_WRAPPERS +from mmengine.runner.checkpoint import (CheckpointLoader, + _load_checkpoint_with_prefix, + get_state_dict, load_checkpoint, + load_from_local, load_from_pavi, + save_checkpoint) + +sys.modules['petrel_client'] = MagicMock() +sys.modules['petrel_client.client'] = MagicMock() + + +@MODEL_WRAPPERS.register_module() +class DDPWrapper: + + def __init__(self, module): + self.module = module + + +class Block(nn.Module): + + def __init__(self): + super().__init__() + self.conv = nn.Conv2d(3, 3, 1) + self.norm = nn.BatchNorm2d(3) + + +class Model(nn.Module): + + def __init__(self): + super().__init__() + self.block = Block() + self.conv = nn.Conv2d(3, 3, 1) + + +class Mockpavimodel: + + def __init__(self, name='fakename'): + self.name = name + + def download(self, file): + pass + + +def assert_tensor_equal(tensor_a, tensor_b): + assert tensor_a.eq(tensor_b).all() + + +def test_get_state_dict(): + if torch.__version__ == 'parrots': + state_dict_keys = { + 'block.conv.weight', 'block.conv.bias', 'block.norm.weight', + 'block.norm.bias', 'block.norm.running_mean', + 'block.norm.running_var', 'conv.weight', 'conv.bias' + } + else: + state_dict_keys = { + 'block.conv.weight', 'block.conv.bias', 'block.norm.weight', + 'block.norm.bias', 'block.norm.running_mean', + 'block.norm.running_var', 'block.norm.num_batches_tracked', + 'conv.weight', 'conv.bias' + } + + model = Model() + state_dict = get_state_dict(model) + assert isinstance(state_dict, OrderedDict) + assert set(state_dict.keys()) == state_dict_keys + + assert_tensor_equal(state_dict['block.conv.weight'], + model.block.conv.weight) + assert_tensor_equal(state_dict['block.conv.bias'], model.block.conv.bias) + assert_tensor_equal(state_dict['block.norm.weight'], + model.block.norm.weight) + assert_tensor_equal(state_dict['block.norm.bias'], model.block.norm.bias) + assert_tensor_equal(state_dict['block.norm.running_mean'], + model.block.norm.running_mean) + assert_tensor_equal(state_dict['block.norm.running_var'], + model.block.norm.running_var) + if torch.__version__ != 'parrots': + assert_tensor_equal(state_dict['block.norm.num_batches_tracked'], + model.block.norm.num_batches_tracked) + assert_tensor_equal(state_dict['conv.weight'], model.conv.weight) + assert_tensor_equal(state_dict['conv.bias'], model.conv.bias) + + wrapped_model = DDPWrapper(model) + state_dict = get_state_dict(wrapped_model) + assert isinstance(state_dict, OrderedDict) + assert set(state_dict.keys()) == state_dict_keys + assert_tensor_equal(state_dict['block.conv.weight'], + wrapped_model.module.block.conv.weight) + assert_tensor_equal(state_dict['block.conv.bias'], + wrapped_model.module.block.conv.bias) + assert_tensor_equal(state_dict['block.norm.weight'], + wrapped_model.module.block.norm.weight) + assert_tensor_equal(state_dict['block.norm.bias'], + wrapped_model.module.block.norm.bias) + assert_tensor_equal(state_dict['block.norm.running_mean'], + wrapped_model.module.block.norm.running_mean) + assert_tensor_equal(state_dict['block.norm.running_var'], + wrapped_model.module.block.norm.running_var) + if torch.__version__ != 'parrots': + assert_tensor_equal( + state_dict['block.norm.num_batches_tracked'], + wrapped_model.module.block.norm.num_batches_tracked) + assert_tensor_equal(state_dict['conv.weight'], + wrapped_model.module.conv.weight) + assert_tensor_equal(state_dict['conv.bias'], + wrapped_model.module.conv.bias) + + # wrapped inner module + for name, module in wrapped_model.module._modules.items(): + module = DataParallel(module) + wrapped_model.module._modules[name] = module + state_dict = get_state_dict(wrapped_model) + assert isinstance(state_dict, OrderedDict) + assert set(state_dict.keys()) == state_dict_keys + assert_tensor_equal(state_dict['block.conv.weight'], + wrapped_model.module.block.module.conv.weight) + assert_tensor_equal(state_dict['block.conv.bias'], + wrapped_model.module.block.module.conv.bias) + assert_tensor_equal(state_dict['block.norm.weight'], + wrapped_model.module.block.module.norm.weight) + assert_tensor_equal(state_dict['block.norm.bias'], + wrapped_model.module.block.module.norm.bias) + assert_tensor_equal(state_dict['block.norm.running_mean'], + wrapped_model.module.block.module.norm.running_mean) + assert_tensor_equal(state_dict['block.norm.running_var'], + wrapped_model.module.block.module.norm.running_var) + if torch.__version__ != 'parrots': + assert_tensor_equal( + state_dict['block.norm.num_batches_tracked'], + wrapped_model.module.block.module.norm.num_batches_tracked) + assert_tensor_equal(state_dict['conv.weight'], + wrapped_model.module.conv.module.weight) + assert_tensor_equal(state_dict['conv.bias'], + wrapped_model.module.conv.module.bias) + + +def test_load_pavimodel_dist(): + sys.modules['pavi'] = MagicMock() + sys.modules['pavi.modelcloud'] = MagicMock() + pavimodel = Mockpavimodel() + import pavi + pavi.modelcloud.get = MagicMock(return_value=pavimodel) + with pytest.raises(AssertionError): + # test pavi prefix + _ = load_from_pavi('MyPaviFolder/checkpoint.pth') + + with pytest.raises(FileNotFoundError): + # there is not such checkpoint for us to load + _ = load_from_pavi('pavi://checkpoint.pth') + + +def test_load_checkpoint_with_prefix(): + + class FooModule(nn.Module): + + def __init__(self): + super().__init__() + self.linear = nn.Linear(1, 2) + self.conv2d = nn.Conv2d(3, 1, 3) + self.conv2d_2 = nn.Conv2d(3, 2, 3) + + model = FooModule() + nn.init.constant_(model.linear.weight, 1) + nn.init.constant_(model.linear.bias, 2) + nn.init.constant_(model.conv2d.weight, 3) + nn.init.constant_(model.conv2d.bias, 4) + nn.init.constant_(model.conv2d_2.weight, 5) + nn.init.constant_(model.conv2d_2.bias, 6) + + with TemporaryDirectory(): + torch.save(model.state_dict(), 'model.pth') + prefix = 'conv2d' + state_dict = _load_checkpoint_with_prefix(prefix, 'model.pth') + assert torch.equal(model.conv2d.state_dict()['weight'], + state_dict['weight']) + assert torch.equal(model.conv2d.state_dict()['bias'], + state_dict['bias']) + + # test whether prefix is in pretrained model + with pytest.raises(AssertionError): + prefix = 'back' + _load_checkpoint_with_prefix(prefix, 'model.pth') + + +def test_load_checkpoint(): + import os + import re + import tempfile + + class PrefixModel(nn.Module): + + def __init__(self): + super().__init__() + self.backbone = Model() + + pmodel = PrefixModel() + model = Model() + checkpoint_path = os.path.join(tempfile.gettempdir(), 'checkpoint.pth') + + # add prefix + torch.save(model.state_dict(), checkpoint_path) + state_dict = load_checkpoint( + pmodel, checkpoint_path, revise_keys=[(r'^', 'backbone.')]) + for key in pmodel.backbone.state_dict().keys(): + assert torch.equal(pmodel.backbone.state_dict()[key], state_dict[key]) + # strip prefix + torch.save(pmodel.state_dict(), checkpoint_path) + state_dict = load_checkpoint( + model, checkpoint_path, revise_keys=[(r'^backbone\.', '')]) + + for key in state_dict.keys(): + key_stripped = re.sub(r'^backbone\.', '', key) + assert torch.equal(model.state_dict()[key_stripped], state_dict[key]) + os.remove(checkpoint_path) + + +def test_load_checkpoint_metadata(): + + class ModelV1(nn.Module): + + def __init__(self): + super().__init__() + self.block = Block() + self.conv1 = nn.Conv2d(3, 3, 1) + self.conv2 = nn.Conv2d(3, 3, 1) + nn.init.normal_(self.conv1.weight) + nn.init.normal_(self.conv2.weight) + + class ModelV2(nn.Module): + _version = 2 + + def __init__(self): + super().__init__() + self.block = Block() + self.conv0 = nn.Conv2d(3, 3, 1) + self.conv1 = nn.Conv2d(3, 3, 1) + nn.init.normal_(self.conv0.weight) + nn.init.normal_(self.conv1.weight) + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, + *args, **kwargs): + """load checkpoints.""" + + # Names of some parameters in has been changed. + version = local_metadata.get('version', None) + if version is None or version < 2: + state_dict_keys = list(state_dict.keys()) + convert_map = {'conv1': 'conv0', 'conv2': 'conv1'} + for k in state_dict_keys: + for ori_str, new_str in convert_map.items(): + if k.startswith(prefix + ori_str): + new_key = k.replace(ori_str, new_str) + state_dict[new_key] = state_dict[k] + del state_dict[k] + + super()._load_from_state_dict(state_dict, prefix, local_metadata, + *args, **kwargs) + + model_v1 = ModelV1() + model_v1_conv0_weight = model_v1.conv1.weight.detach() + model_v1_conv1_weight = model_v1.conv2.weight.detach() + model_v2 = ModelV2() + model_v2_conv0_weight = model_v2.conv0.weight.detach() + model_v2_conv1_weight = model_v2.conv1.weight.detach() + ckpt_v1_path = os.path.join(tempfile.gettempdir(), 'checkpoint_v1.pth') + ckpt_v2_path = os.path.join(tempfile.gettempdir(), 'checkpoint_v2.pth') + + # Save checkpoint + save_checkpoint(model_v1.state_dict(), ckpt_v1_path) + save_checkpoint(model_v2.state_dict(), ckpt_v2_path) + + # test load v1 model + load_checkpoint(model_v2, ckpt_v1_path) + assert torch.allclose(model_v2.conv0.weight, model_v1_conv0_weight) + assert torch.allclose(model_v2.conv1.weight, model_v1_conv1_weight) + + # test load v2 model + load_checkpoint(model_v2, ckpt_v2_path) + assert torch.allclose(model_v2.conv0.weight, model_v2_conv0_weight) + assert torch.allclose(model_v2.conv1.weight, model_v2_conv1_weight) + + +def test_checkpoint_loader(): + filenames = [ + 'http://xx.xx/xx.pth', 'https://xx.xx/xx.pth', + 'modelzoo://xx.xx/xx.pth', 'torchvision://xx.xx/xx.pth', + 'open-mmlab://xx.xx/xx.pth', 'openmmlab://xx.xx/xx.pth', + 'mmcls://xx.xx/xx.pth', 'pavi://xx.xx/xx.pth', 's3://xx.xx/xx.pth', + 'ss3://xx.xx/xx.pth', ' s3://xx.xx/xx.pth', + 'open-mmlab:s3://xx.xx/xx.pth', 'openmmlab:s3://xx.xx/xx.pth', + 'openmmlabs3://xx.xx/xx.pth', ':s3://xx.xx/xx.path' + ] + fn_names = [ + 'load_from_http', 'load_from_http', 'load_from_torchvision', + 'load_from_torchvision', 'load_from_openmmlab', 'load_from_openmmlab', + 'load_from_mmcls', 'load_from_pavi', 'load_from_ceph', + 'load_from_local', 'load_from_local', 'load_from_ceph', + 'load_from_ceph', 'load_from_local', 'load_from_local' + ] + + for filename, fn_name in zip(filenames, fn_names): + loader = CheckpointLoader._get_checkpoint_loader(filename) + assert loader.__name__ == fn_name + + @CheckpointLoader.register_scheme(prefixes='ftp://') + def load_from_ftp(filename, map_location): + return dict(filename=filename) + + # test register_loader + filename = 'ftp://xx.xx/xx.pth' + loader = CheckpointLoader._get_checkpoint_loader(filename) + assert loader.__name__ == 'load_from_ftp' + + def load_from_ftp1(filename, map_location): + return dict(filename=filename) + + # test duplicate registered error + with pytest.raises(KeyError): + CheckpointLoader.register_scheme('ftp://', load_from_ftp1) + + # test force param + CheckpointLoader.register_scheme('ftp://', load_from_ftp1, force=True) + checkpoint = CheckpointLoader.load_checkpoint(filename) + assert checkpoint['filename'] == filename + + # test print function name + loader = CheckpointLoader._get_checkpoint_loader(filename) + assert loader.__name__ == 'load_from_ftp1' + + # test sort + @CheckpointLoader.register_scheme(prefixes='a/b') + def load_from_ab(filename, map_location): + return dict(filename=filename) + + @CheckpointLoader.register_scheme(prefixes='a/b/c') + def load_from_abc(filename, map_location): + return dict(filename=filename) + + filename = 'a/b/c/d' + loader = CheckpointLoader._get_checkpoint_loader(filename) + assert loader.__name__ == 'load_from_abc' + + +def test_save_checkpoint(tmp_path): + model = Model() + optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9) + # meta is not a dict + with pytest.raises(TypeError): + save_checkpoint(model, '/path/of/your/filename', meta='invalid type') + + # 1. save to disk + filename = str(tmp_path / 'checkpoint1.pth') + save_checkpoint(model.state_dict(), filename) + + filename = str(tmp_path / 'checkpoint2.pth') + checkpoint = dict( + model=model.state_dict(), optimizer=optimizer.state_dict()) + save_checkpoint(checkpoint, filename) + + filename = str(tmp_path / 'checkpoint3.pth') + save_checkpoint( + model.state_dict(), filename, backend_args={'backend': 'local'}) + + filename = str(tmp_path / 'checkpoint4.pth') + save_checkpoint( + model.state_dict(), filename, file_client_args={'backend': 'disk'}) + + # 2. save to petrel oss + with patch.object(PetrelBackend, 'put') as mock_method: + filename = 's3://path/of/your/checkpoint1.pth' + save_checkpoint(model.state_dict(), filename) + mock_method.assert_called() + + with patch.object(PetrelBackend, 'put') as mock_method: + filename = 's3://path//of/your/checkpoint2.pth' + save_checkpoint( + model.state_dict(), + filename, + file_client_args={'backend': 'petrel'}) + mock_method.assert_called() + + +def test_load_from_local(): + import os + home_path = os.path.expanduser('~') + checkpoint_path = os.path.join( + home_path, 'dummy_checkpoint_used_to_test_load_from_local.pth') + model = Model() + save_checkpoint(model.state_dict(), checkpoint_path) + checkpoint = load_from_local( + '~/dummy_checkpoint_used_to_test_load_from_local.pth', + map_location=None) + assert_tensor_equal(checkpoint['block.conv.weight'], + model.block.conv.weight) + os.remove(checkpoint_path)