# Copyright (c) OpenMMLab. All rights reserved.
import os
import re
import sys
import tempfile
from collections import OrderedDict
from tempfile import TemporaryDirectory
from unittest.mock import MagicMock, patch

import pytest
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DataParallel

from mmengine.fileio.file_client import PetrelBackend
from mmengine.registry import MODEL_WRAPPERS
from mmengine.runner.checkpoint import (CheckpointLoader,
                                        _load_checkpoint_with_prefix,
                                        get_state_dict, load_checkpoint,
                                        load_from_local, load_from_pavi,
                                        load_state_dict, save_checkpoint)
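

# The fixtures below keep these tests self-contained: DDPWrapper is a
# minimal model wrapper registered in MODEL_WRAPPERS (so get_state_dict
# unwraps its .module), Block and Model form a small Conv+BN network, and
# Mockpavimodel stands in for a pavi model object so that no cluster
# access is needed.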
@MODEL_WRAPPERS.register_module()
class DDPWrapper:

    def __init__(self, module):
        self.module = module


class Block(nn.Module):

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, 1)
        self.norm = nn.BatchNorm2d(3)


class Model(nn.Module):

    def __init__(self):
        super().__init__()
        self.block = Block()
        self.conv = nn.Conv2d(3, 3, 1)


class Mockpavimodel:

    def __init__(self, name='fakename'):
        self.name = name

    def download(self, file):
        pass


def assert_tensor_equal(tensor_a, tensor_b):
    assert tensor_a.eq(tensor_b).all()
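

# get_state_dict should return the weights of the innermost module:
# wrapper prefixes such as 'module.' from DDPWrapper or DataParallel must
# not leak into the saved keys, even when wrappers are nested.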
def test_get_state_dict():
    if torch.__version__ == 'parrots':
        state_dict_keys = {
            'block.conv.weight', 'block.conv.bias', 'block.norm.weight',
            'block.norm.bias', 'block.norm.running_mean',
            'block.norm.running_var', 'conv.weight', 'conv.bias'
        }
    else:
        state_dict_keys = {
            'block.conv.weight', 'block.conv.bias', 'block.norm.weight',
            'block.norm.bias', 'block.norm.running_mean',
            'block.norm.running_var', 'block.norm.num_batches_tracked',
            'conv.weight', 'conv.bias'
        }

    model = Model()
    state_dict = get_state_dict(model)
    assert isinstance(state_dict, OrderedDict)
    assert set(state_dict.keys()) == state_dict_keys

    assert_tensor_equal(state_dict['block.conv.weight'],
                        model.block.conv.weight)
    assert_tensor_equal(state_dict['block.conv.bias'], model.block.conv.bias)
    assert_tensor_equal(state_dict['block.norm.weight'],
                        model.block.norm.weight)
    assert_tensor_equal(state_dict['block.norm.bias'], model.block.norm.bias)
    assert_tensor_equal(state_dict['block.norm.running_mean'],
                        model.block.norm.running_mean)
    assert_tensor_equal(state_dict['block.norm.running_var'],
                        model.block.norm.running_var)
    if torch.__version__ != 'parrots':
        assert_tensor_equal(state_dict['block.norm.num_batches_tracked'],
                            model.block.norm.num_batches_tracked)
    assert_tensor_equal(state_dict['conv.weight'], model.conv.weight)
    assert_tensor_equal(state_dict['conv.bias'], model.conv.bias)

    wrapped_model = DDPWrapper(model)
    state_dict = get_state_dict(wrapped_model)
    assert isinstance(state_dict, OrderedDict)
    assert set(state_dict.keys()) == state_dict_keys
    assert_tensor_equal(state_dict['block.conv.weight'],
                        wrapped_model.module.block.conv.weight)
    assert_tensor_equal(state_dict['block.conv.bias'],
                        wrapped_model.module.block.conv.bias)
    assert_tensor_equal(state_dict['block.norm.weight'],
                        wrapped_model.module.block.norm.weight)
    assert_tensor_equal(state_dict['block.norm.bias'],
                        wrapped_model.module.block.norm.bias)
    assert_tensor_equal(state_dict['block.norm.running_mean'],
                        wrapped_model.module.block.norm.running_mean)
    assert_tensor_equal(state_dict['block.norm.running_var'],
                        wrapped_model.module.block.norm.running_var)
    if torch.__version__ != 'parrots':
        assert_tensor_equal(
            state_dict['block.norm.num_batches_tracked'],
            wrapped_model.module.block.norm.num_batches_tracked)
    assert_tensor_equal(state_dict['conv.weight'],
                        wrapped_model.module.conv.weight)
    assert_tensor_equal(state_dict['conv.bias'],
                        wrapped_model.module.conv.bias)

    # wrap the inner modules with DataParallel as well
    for name, module in wrapped_model.module._modules.items():
        module = DataParallel(module)
        wrapped_model.module._modules[name] = module
    state_dict = get_state_dict(wrapped_model)
    assert isinstance(state_dict, OrderedDict)
    assert set(state_dict.keys()) == state_dict_keys
    assert_tensor_equal(state_dict['block.conv.weight'],
                        wrapped_model.module.block.module.conv.weight)
    assert_tensor_equal(state_dict['block.conv.bias'],
                        wrapped_model.module.block.module.conv.bias)
    assert_tensor_equal(state_dict['block.norm.weight'],
                        wrapped_model.module.block.module.norm.weight)
    assert_tensor_equal(state_dict['block.norm.bias'],
                        wrapped_model.module.block.module.norm.bias)
    assert_tensor_equal(state_dict['block.norm.running_mean'],
                        wrapped_model.module.block.module.norm.running_mean)
    assert_tensor_equal(state_dict['block.norm.running_var'],
                        wrapped_model.module.block.module.norm.running_var)
    if torch.__version__ != 'parrots':
        assert_tensor_equal(
            state_dict['block.norm.num_batches_tracked'],
            wrapped_model.module.block.module.norm.num_batches_tracked)
    assert_tensor_equal(state_dict['conv.weight'],
                        wrapped_model.module.conv.module.weight)
    assert_tensor_equal(state_dict['conv.bias'],
                        wrapped_model.module.conv.module.bias)
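

# patch.dict(sys.modules, ...) makes 'import pavi' inside the test resolve
# to a MagicMock, so no real pavi installation is required.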
@patch.dict(sys.modules, {'pavi': MagicMock()})
def test_load_pavimodel_dist():
    pavimodel = Mockpavimodel()
    import pavi
    pavi.modelcloud.get = MagicMock(return_value=pavimodel)
    with pytest.raises(AssertionError):
        # a filename without the 'pavi://' prefix should be rejected
        _ = load_from_pavi('MyPaviFolder/checkpoint.pth')

    with pytest.raises(FileNotFoundError):
        # there is no such checkpoint for us to load
        _ = load_from_pavi('pavi://checkpoint.pth')
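

# _load_checkpoint_with_prefix should return only the sub-dict of weights
# whose keys start with '<prefix>.', with that prefix stripped.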
def test_load_checkpoint_with_prefix():

    class FooModule(nn.Module):

        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(1, 2)
            self.conv2d = nn.Conv2d(3, 1, 3)
            self.conv2d_2 = nn.Conv2d(3, 2, 3)

    model = FooModule()
    nn.init.constant_(model.linear.weight, 1)
    nn.init.constant_(model.linear.bias, 2)
    nn.init.constant_(model.conv2d.weight, 3)
    nn.init.constant_(model.conv2d.bias, 4)
    nn.init.constant_(model.conv2d_2.weight, 5)
    nn.init.constant_(model.conv2d_2.bias, 6)

    with TemporaryDirectory() as tmp_dir:
        # save the checkpoint inside the temporary directory so the test
        # leaves no file behind
        checkpoint_path = os.path.join(tmp_dir, 'model.pth')
        torch.save(model.state_dict(), checkpoint_path)
        prefix = 'conv2d'
        state_dict = _load_checkpoint_with_prefix(prefix, checkpoint_path)
        assert torch.equal(model.conv2d.state_dict()['weight'],
                           state_dict['weight'])
        assert torch.equal(model.conv2d.state_dict()['bias'],
                           state_dict['bias'])

        # test whether the prefix is in the pretrained model
        with pytest.raises(AssertionError):
            prefix = 'back'
            _load_checkpoint_with_prefix(prefix, checkpoint_path)
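

# revise_keys is a list of (pattern, replacement) pairs applied to every
# checkpoint key with re.sub before loading, e.g.
# re.sub(r'^backbone\.', '', 'backbone.conv.weight') -> 'conv.weight'.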
def test_load_checkpoint():

    class PrefixModel(nn.Module):

        def __init__(self):
            super().__init__()
            self.backbone = Model()

    pmodel = PrefixModel()
    model = Model()
    checkpoint_path = os.path.join(tempfile.gettempdir(), 'checkpoint.pth')

    # add prefix
    torch.save(model.state_dict(), checkpoint_path)
    state_dict = load_checkpoint(
        pmodel, checkpoint_path, revise_keys=[(r'^', 'backbone.')])
    for key in pmodel.backbone.state_dict().keys():
        assert torch.equal(pmodel.backbone.state_dict()[key], state_dict[key])
    # strip prefix
    torch.save(pmodel.state_dict(), checkpoint_path)
    state_dict = load_checkpoint(
        model, checkpoint_path, revise_keys=[(r'^backbone\.', '')])

    for key in state_dict.keys():
        key_stripped = re.sub(r'^backbone\.', '', key)
        assert torch.equal(model.state_dict()[key_stripped], state_dict[key])
    os.remove(checkpoint_path)
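

# PyTorch records a class's _version in state_dict()._metadata; a module
# can override _load_from_state_dict and inspect that version to migrate
# keys from older checkpoints, which is what ModelV2 does below.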
def test_load_checkpoint_metadata():

    class ModelV1(nn.Module):

        def __init__(self):
            super().__init__()
            self.block = Block()
            self.conv1 = nn.Conv2d(3, 3, 1)
            self.conv2 = nn.Conv2d(3, 3, 1)
            nn.init.normal_(self.conv1.weight)
            nn.init.normal_(self.conv2.weight)

    class ModelV2(nn.Module):
        _version = 2

        def __init__(self):
            super().__init__()
            self.block = Block()
            self.conv0 = nn.Conv2d(3, 3, 1)
            self.conv1 = nn.Conv2d(3, 3, 1)
            nn.init.normal_(self.conv0.weight)
            nn.init.normal_(self.conv1.weight)

        def _load_from_state_dict(self, state_dict, prefix, local_metadata,
                                  *args, **kwargs):
            """Load checkpoints."""

            # Names of some parameters have been changed since version 2,
            # so remap old keys before the parent class consumes them.
            version = local_metadata.get('version', None)
            if version is None or version < 2:
                state_dict_keys = list(state_dict.keys())
                convert_map = {'conv1': 'conv0', 'conv2': 'conv1'}
                for k in state_dict_keys:
                    for ori_str, new_str in convert_map.items():
                        if k.startswith(prefix + ori_str):
                            new_key = k.replace(ori_str, new_str)
                            state_dict[new_key] = state_dict[k]
                            del state_dict[k]

            super()._load_from_state_dict(state_dict, prefix, local_metadata,
                                          *args, **kwargs)

    model_v1 = ModelV1()
    model_v1_conv0_weight = model_v1.conv1.weight.detach()
    model_v1_conv1_weight = model_v1.conv2.weight.detach()
    model_v2 = ModelV2()
    model_v2_conv0_weight = model_v2.conv0.weight.detach()
    model_v2_conv1_weight = model_v2.conv1.weight.detach()
    ckpt_v1_path = os.path.join(tempfile.gettempdir(), 'checkpoint_v1.pth')
    ckpt_v2_path = os.path.join(tempfile.gettempdir(), 'checkpoint_v2.pth')

    # save checkpoints for both versions
    save_checkpoint(model_v1.state_dict(), ckpt_v1_path)
    save_checkpoint(model_v2.state_dict(), ckpt_v2_path)

    # test loading a v1 checkpoint: old keys should be remapped
    load_checkpoint(model_v2, ckpt_v1_path)
    assert torch.allclose(model_v2.conv0.weight, model_v1_conv0_weight)
    assert torch.allclose(model_v2.conv1.weight, model_v1_conv1_weight)

    # test loading a v2 checkpoint: keys are used as-is
    load_checkpoint(model_v2, ckpt_v2_path)
    assert torch.allclose(model_v2.conv0.weight, model_v2_conv0_weight)
    assert torch.allclose(model_v2.conv1.weight, model_v2_conv1_weight)
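

# CheckpointLoader dispatches on the filename prefix; registered prefixes
# are kept sorted so that the longest matching prefix decides which loader
# runs, and anything without a known prefix falls back to load_from_local.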
@patch.dict(sys.modules, {'petrel_client': MagicMock()})
def test_checkpoint_loader():
    filenames = [
        'http://xx.xx/xx.pth', 'https://xx.xx/xx.pth',
        'modelzoo://xx.xx/xx.pth', 'torchvision://xx.xx/xx.pth',
        'open-mmlab://xx.xx/xx.pth', 'openmmlab://xx.xx/xx.pth',
        'mmcls://xx.xx/xx.pth', 'pavi://xx.xx/xx.pth', 's3://xx.xx/xx.pth',
        'ss3://xx.xx/xx.pth', ' s3://xx.xx/xx.pth',
        'open-mmlab:s3://xx.xx/xx.pth', 'openmmlab:s3://xx.xx/xx.pth',
        'openmmlabs3://xx.xx/xx.pth', ':s3://xx.xx/xx.path'
    ]
    fn_names = [
        'load_from_http', 'load_from_http', 'load_from_torchvision',
        'load_from_torchvision', 'load_from_openmmlab', 'load_from_openmmlab',
        'load_from_mmcls', 'load_from_pavi', 'load_from_ceph',
        'load_from_local', 'load_from_local', 'load_from_ceph',
        'load_from_ceph', 'load_from_local', 'load_from_local'
    ]

    for filename, fn_name in zip(filenames, fn_names):
        loader = CheckpointLoader._get_checkpoint_loader(filename)
        assert loader.__name__ == fn_name

    @CheckpointLoader.register_scheme(prefixes='ftp://')
    def load_from_ftp(filename, map_location):
        return dict(filename=filename)

    # test registering a new scheme
    filename = 'ftp://xx.xx/xx.pth'
    loader = CheckpointLoader._get_checkpoint_loader(filename)
    assert loader.__name__ == 'load_from_ftp'

    def load_from_ftp1(filename, map_location):
        return dict(filename=filename)

    # registering a duplicate prefix raises an error
    with pytest.raises(KeyError):
        CheckpointLoader.register_scheme('ftp://', load_from_ftp1)

    # test the force param
    CheckpointLoader.register_scheme('ftp://', load_from_ftp1, force=True)
    checkpoint = CheckpointLoader.load_checkpoint(filename)
    assert checkpoint['filename'] == filename

    # the force-registered loader is now the one that is returned
    loader = CheckpointLoader._get_checkpoint_loader(filename)
    assert loader.__name__ == 'load_from_ftp1'

    # test sorting: the longest matching prefix wins
    @CheckpointLoader.register_scheme(prefixes='a/b')
    def load_from_ab(filename, map_location):
        return dict(filename=filename)

    @CheckpointLoader.register_scheme(prefixes='a/b/c')
    def load_from_abc(filename, map_location):
        return dict(filename=filename)

    filename = 'a/b/c/d'
    loader = CheckpointLoader._get_checkpoint_loader(filename)
    assert loader.__name__ == 'load_from_abc'
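

# save_checkpoint routes file writing through mmengine's fileio backends:
# backend_args (and the legacy file_client_args) select the backend, and
# petrel writes are mocked below so no OSS access actually happens.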
def test_save_checkpoint(tmp_path):
    model = Model()
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    # meta is not a dict
    with pytest.raises(TypeError):
        save_checkpoint(model, '/path/of/your/filename', meta='invalid type')

    # 1. save to disk
    filename = str(tmp_path / 'checkpoint1.pth')
    save_checkpoint(model.state_dict(), filename)

    filename = str(tmp_path / 'checkpoint2.pth')
    checkpoint = dict(
        model=model.state_dict(), optimizer=optimizer.state_dict())
    save_checkpoint(checkpoint, filename)

    filename = str(tmp_path / 'checkpoint3.pth')
    save_checkpoint(
        model.state_dict(), filename, backend_args={'backend': 'local'})

    filename = str(tmp_path / 'checkpoint4.pth')
    save_checkpoint(
        model.state_dict(), filename, file_client_args={'backend': 'disk'})

    # 2. save to petrel oss
    with patch.object(PetrelBackend, 'put') as mock_method:
        filename = 's3://path/of/your/checkpoint1.pth'
        save_checkpoint(model.state_dict(), filename)
        mock_method.assert_called()

    with patch.object(PetrelBackend, 'put') as mock_method:
        filename = 's3://path//of/your/checkpoint2.pth'
        save_checkpoint(
            model.state_dict(),
            filename,
            file_client_args={'backend': 'petrel'})
        mock_method.assert_called()
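

# load_from_local should expand '~' to the user's home directory, so the
# checkpoint saved under home_path is found via the tilde path.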
def test_load_from_local():
    home_path = os.path.expanduser('~')
    checkpoint_path = os.path.join(
        home_path, 'dummy_checkpoint_used_to_test_load_from_local.pth')
    model = Model()
    save_checkpoint(model.state_dict(), checkpoint_path)
    checkpoint = load_from_local(
        '~/dummy_checkpoint_used_to_test_load_from_local.pth',
        map_location=None)
    assert_tensor_equal(checkpoint['block.conv.weight'],
                        model.block.conv.weight)
    os.remove(checkpoint_path)
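

# mmengine's load_state_dict only logs incompatible keys that survive the
# module's _load_state_dict_post_hooks; a hook that removes an entry from
# incompatible_keys therefore silences the warning entirely.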
def test_load_state_dict_post_hooks():
    module = Block()

    state_dict = {
        'conv.weight': torch.empty((3, 3, 1, 1), dtype=torch.float32),
        'conv.bias': torch.empty((3, ), dtype=torch.float32),
        'norm.weight': torch.empty([3], dtype=torch.float32),
        'norm.bias': torch.empty([3], dtype=torch.float32),
        'norm.running_mean': torch.empty([3], dtype=torch.float32),
        'norm.running_var': torch.empty([3], dtype=torch.float32),
    }
    state_dict.pop('norm.running_var')

    with patch('mmengine.runner.checkpoint.print_log') as mock:
        load_state_dict(module, state_dict, strict=False)
        mock.assert_called_once()

    def post_hook(_, incompatible_keys):
        incompatible_keys.missing_keys.remove('norm.running_var')

    module._load_state_dict_post_hooks = {0: post_hook}

    with patch('mmengine.runner.checkpoint.print_log') as mock:
        load_state_dict(module, state_dict, strict=False)
        mock.assert_not_called()