# mmcv/tests/test_runner/test_checkpoint.py
# From: [Feature] Upload checkpoints and logs to ceph (#1375)

import sys
from collections import OrderedDict
from tempfile import TemporaryDirectory
from unittest.mock import MagicMock, patch

import pytest
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DataParallel

from mmcv.fileio.file_client import PetrelBackend
from mmcv.parallel.registry import MODULE_WRAPPERS
from mmcv.runner.checkpoint import (_load_checkpoint_with_prefix,
                                    get_state_dict, load_checkpoint,
                                    load_from_pavi, save_checkpoint)
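
# `petrel_client` is presumably unavailable in the test environment, so mock
# it out before anything tries to construct a real Petrel client.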
sys.modules['petrel_client'] = MagicMock()
sys.modules['petrel_client.client'] = MagicMock()
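

# Registering DDPWrapper in MODULE_WRAPPERS makes the checkpoint utilities
# unwrap its `.module` attribute, just as they do for DataParallel and
# DistributedDataParallel.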
@MODULE_WRAPPERS.register_module()
class DDPWrapper(object):

    def __init__(self, module):
        self.module = module


class Block(nn.Module):

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, 1)
        self.norm = nn.BatchNorm2d(3)


class Model(nn.Module):

    def __init__(self):
        super().__init__()
        self.block = Block()
        self.conv = nn.Conv2d(3, 3, 1)


class Mockpavimodel(object):

    def __init__(self, name='fakename'):
        self.name = name

    def download(self, file):
        pass


def assert_tensor_equal(tensor_a, tensor_b):
    assert tensor_a.eq(tensor_b).all()
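

# get_state_dict() should return the underlying weights with identical keys
# whether the model is bare, wrapped once, or contains wrapped submodules.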
def test_get_state_dict():
    if torch.__version__ == 'parrots':
        state_dict_keys = set([
            'block.conv.weight', 'block.conv.bias', 'block.norm.weight',
            'block.norm.bias', 'block.norm.running_mean',
            'block.norm.running_var', 'conv.weight', 'conv.bias'
        ])
    else:
        state_dict_keys = set([
            'block.conv.weight', 'block.conv.bias', 'block.norm.weight',
            'block.norm.bias', 'block.norm.running_mean',
            'block.norm.running_var', 'block.norm.num_batches_tracked',
            'conv.weight', 'conv.bias'
        ])

    model = Model()
    state_dict = get_state_dict(model)
    assert isinstance(state_dict, OrderedDict)
    assert set(state_dict.keys()) == state_dict_keys

    assert_tensor_equal(state_dict['block.conv.weight'],
                        model.block.conv.weight)
    assert_tensor_equal(state_dict['block.conv.bias'], model.block.conv.bias)
    assert_tensor_equal(state_dict['block.norm.weight'],
                        model.block.norm.weight)
    assert_tensor_equal(state_dict['block.norm.bias'], model.block.norm.bias)
    assert_tensor_equal(state_dict['block.norm.running_mean'],
                        model.block.norm.running_mean)
    assert_tensor_equal(state_dict['block.norm.running_var'],
                        model.block.norm.running_var)
    if torch.__version__ != 'parrots':
        assert_tensor_equal(state_dict['block.norm.num_batches_tracked'],
                            model.block.norm.num_batches_tracked)
    assert_tensor_equal(state_dict['conv.weight'], model.conv.weight)
    assert_tensor_equal(state_dict['conv.bias'], model.conv.bias)

    wrapped_model = DDPWrapper(model)
    state_dict = get_state_dict(wrapped_model)
    assert isinstance(state_dict, OrderedDict)
    assert set(state_dict.keys()) == state_dict_keys
    assert_tensor_equal(state_dict['block.conv.weight'],
                        wrapped_model.module.block.conv.weight)
    assert_tensor_equal(state_dict['block.conv.bias'],
                        wrapped_model.module.block.conv.bias)
    assert_tensor_equal(state_dict['block.norm.weight'],
                        wrapped_model.module.block.norm.weight)
    assert_tensor_equal(state_dict['block.norm.bias'],
                        wrapped_model.module.block.norm.bias)
    assert_tensor_equal(state_dict['block.norm.running_mean'],
                        wrapped_model.module.block.norm.running_mean)
    assert_tensor_equal(state_dict['block.norm.running_var'],
                        wrapped_model.module.block.norm.running_var)
    if torch.__version__ != 'parrots':
        assert_tensor_equal(
            state_dict['block.norm.num_batches_tracked'],
            wrapped_model.module.block.norm.num_batches_tracked)
    assert_tensor_equal(state_dict['conv.weight'],
                        wrapped_model.module.conv.weight)
    assert_tensor_equal(state_dict['conv.bias'],
                        wrapped_model.module.conv.bias)

    # wrapped inner module
    for name, module in wrapped_model.module._modules.items():
        module = DataParallel(module)
        wrapped_model.module._modules[name] = module
    state_dict = get_state_dict(wrapped_model)
    assert isinstance(state_dict, OrderedDict)
    assert set(state_dict.keys()) == state_dict_keys
    assert_tensor_equal(state_dict['block.conv.weight'],
                        wrapped_model.module.block.module.conv.weight)
    assert_tensor_equal(state_dict['block.conv.bias'],
                        wrapped_model.module.block.module.conv.bias)
    assert_tensor_equal(state_dict['block.norm.weight'],
                        wrapped_model.module.block.module.norm.weight)
    assert_tensor_equal(state_dict['block.norm.bias'],
                        wrapped_model.module.block.module.norm.bias)
    assert_tensor_equal(state_dict['block.norm.running_mean'],
                        wrapped_model.module.block.module.norm.running_mean)
    assert_tensor_equal(state_dict['block.norm.running_var'],
                        wrapped_model.module.block.module.norm.running_var)
    if torch.__version__ != 'parrots':
        assert_tensor_equal(
            state_dict['block.norm.num_batches_tracked'],
            wrapped_model.module.block.module.norm.num_batches_tracked)
    assert_tensor_equal(state_dict['conv.weight'],
                        wrapped_model.module.conv.module.weight)
    assert_tensor_equal(state_dict['conv.bias'],
                        wrapped_model.module.conv.module.bias)
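

# `pavi` is mocked here; load_from_pavi() must reject paths that lack the
# `pavi://` prefix and raise FileNotFoundError when the (mocked) model cloud
# does not actually download a checkpoint file.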
def test_load_pavimodel_dist():
    sys.modules['pavi'] = MagicMock()
    sys.modules['pavi.modelcloud'] = MagicMock()
    pavimodel = Mockpavimodel()
    import pavi
    pavi.modelcloud.get = MagicMock(return_value=pavimodel)
    with pytest.raises(AssertionError):
        # test pavi prefix
        _ = load_from_pavi('MyPaviFolder/checkpoint.pth')

    with pytest.raises(FileNotFoundError):
        # there is no such checkpoint for us to load
        _ = load_from_pavi('pavi://checkpoint.pth')
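

# _load_checkpoint_with_prefix() extracts the sub-dict of weights under the
# given prefix (with the prefix stripped) and asserts if the prefix is absent.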
def test_load_checkpoint_with_prefix():
    import os

    class FooModule(nn.Module):

        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(1, 2)
            self.conv2d = nn.Conv2d(3, 1, 3)
            self.conv2d_2 = nn.Conv2d(3, 2, 3)

    model = FooModule()
    nn.init.constant_(model.linear.weight, 1)
    nn.init.constant_(model.linear.bias, 2)
    nn.init.constant_(model.conv2d.weight, 3)
    nn.init.constant_(model.conv2d.bias, 4)
    nn.init.constant_(model.conv2d_2.weight, 5)
    nn.init.constant_(model.conv2d_2.bias, 6)

    with TemporaryDirectory() as tmp_dir:
        # save into the temporary directory so no file leaks into the cwd
        checkpoint_path = os.path.join(tmp_dir, 'model.pth')
        torch.save(model.state_dict(), checkpoint_path)
        prefix = 'conv2d'
        state_dict = _load_checkpoint_with_prefix(prefix, checkpoint_path)
        assert torch.equal(model.conv2d.state_dict()['weight'],
                           state_dict['weight'])
        assert torch.equal(model.conv2d.state_dict()['bias'],
                           state_dict['bias'])

        # test whether prefix is in pretrained model
        with pytest.raises(AssertionError):
            prefix = 'back'
            _load_checkpoint_with_prefix(prefix, checkpoint_path)
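

# load_checkpoint() applies each (pattern, replacement) pair in revise_keys
# with re.sub to every state-dict key, which lets callers add or strip a
# prefix, as both directions below demonstrate.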
def test_load_checkpoint():
    import os
    import re
    import tempfile

    class PrefixModel(nn.Module):

        def __init__(self):
            super().__init__()
            self.backbone = Model()

    pmodel = PrefixModel()
    model = Model()
    checkpoint_path = os.path.join(tempfile.gettempdir(), 'checkpoint.pth')

    # add prefix
    torch.save(model.state_dict(), checkpoint_path)
    state_dict = load_checkpoint(
        pmodel, checkpoint_path, revise_keys=[(r'^', 'backbone.')])
    for key in pmodel.backbone.state_dict().keys():
        assert torch.equal(pmodel.backbone.state_dict()[key], state_dict[key])

    # strip prefix
    torch.save(pmodel.state_dict(), checkpoint_path)
    state_dict = load_checkpoint(
        model, checkpoint_path, revise_keys=[(r'^backbone\.', '')])
    for key in state_dict.keys():
        key_stripped = re.sub(r'^backbone\.', '', key)
        assert torch.equal(model.state_dict()[key_stripped], state_dict[key])
    os.remove(checkpoint_path)
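

# save_checkpoint() preserves the per-module `_version` in the state-dict
# metadata, which _load_from_state_dict() reads back as
# local_metadata['version'] to migrate keys saved under an old layout.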
def test_load_checkpoint_metadata():
    import os
    import tempfile

    from mmcv.runner import load_checkpoint, save_checkpoint

    class ModelV1(nn.Module):

        def __init__(self):
            super().__init__()
            self.block = Block()
            self.conv1 = nn.Conv2d(3, 3, 1)
            self.conv2 = nn.Conv2d(3, 3, 1)
            nn.init.normal_(self.conv1.weight)
            nn.init.normal_(self.conv2.weight)

    class ModelV2(nn.Module):
        _version = 2

        def __init__(self):
            super().__init__()
            self.block = Block()
            self.conv0 = nn.Conv2d(3, 3, 1)
            self.conv1 = nn.Conv2d(3, 3, 1)
            nn.init.normal_(self.conv0.weight)
            nn.init.normal_(self.conv1.weight)

        def _load_from_state_dict(self, state_dict, prefix, local_metadata,
                                  *args, **kwargs):
            """load checkpoints."""
            # Names of some parameters have been changed between V1 and V2,
            # so remap old keys before delegating to the default loader.
            version = local_metadata.get('version', None)
            if version is None or version < 2:
                state_dict_keys = list(state_dict.keys())
                convert_map = {'conv1': 'conv0', 'conv2': 'conv1'}
                for k in state_dict_keys:
                    for ori_str, new_str in convert_map.items():
                        if k.startswith(prefix + ori_str):
                            new_key = k.replace(ori_str, new_str)
                            state_dict[new_key] = state_dict[k]
                            del state_dict[k]
            super()._load_from_state_dict(state_dict, prefix, local_metadata,
                                          *args, **kwargs)

    model_v1 = ModelV1()
    model_v1_conv0_weight = model_v1.conv1.weight.detach()
    model_v1_conv1_weight = model_v1.conv2.weight.detach()
    model_v2 = ModelV2()
    model_v2_conv0_weight = model_v2.conv0.weight.detach()
    model_v2_conv1_weight = model_v2.conv1.weight.detach()
    ckpt_v1_path = os.path.join(tempfile.gettempdir(), 'checkpoint_v1.pth')
    ckpt_v2_path = os.path.join(tempfile.gettempdir(), 'checkpoint_v2.pth')

    # Save checkpoint
    save_checkpoint(model_v1, ckpt_v1_path)
    save_checkpoint(model_v2, ckpt_v2_path)

    # test load v1 model
    load_checkpoint(model_v2, ckpt_v1_path)
    assert torch.allclose(model_v2.conv0.weight, model_v1_conv0_weight)
    assert torch.allclose(model_v2.conv1.weight, model_v1_conv1_weight)

    # test load v2 model
    load_checkpoint(model_v2, ckpt_v2_path)
    assert torch.allclose(model_v2.conv0.weight, model_v2_conv0_weight)
    assert torch.allclose(model_v2.conv1.weight, model_v2_conv1_weight)
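

# save_checkpoint() copies a model's CLASSES attribute into checkpoint meta;
# wrapped models are unwrapped first, so `module.CLASSES` is found as well.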
def test_load_classes_name():
    import os
    import tempfile

    from mmcv.runner import load_checkpoint, save_checkpoint
    checkpoint_path = os.path.join(tempfile.gettempdir(), 'checkpoint.pth')
    model = Model()
    save_checkpoint(model, checkpoint_path)
    checkpoint = load_checkpoint(model, checkpoint_path)
    assert 'meta' in checkpoint and 'CLASSES' not in checkpoint['meta']

    model.CLASSES = ('class1', 'class2')
    save_checkpoint(model, checkpoint_path)
    checkpoint = load_checkpoint(model, checkpoint_path)
    assert 'meta' in checkpoint and 'CLASSES' in checkpoint['meta']
    assert checkpoint['meta']['CLASSES'] == ('class1', 'class2')

    model = Model()
    wrapped_model = DDPWrapper(model)
    save_checkpoint(wrapped_model, checkpoint_path)
    checkpoint = load_checkpoint(wrapped_model, checkpoint_path)
    assert 'meta' in checkpoint and 'CLASSES' not in checkpoint['meta']

    wrapped_model.module.CLASSES = ('class1', 'class2')
    save_checkpoint(wrapped_model, checkpoint_path)
    checkpoint = load_checkpoint(wrapped_model, checkpoint_path)
    assert 'meta' in checkpoint and 'CLASSES' in checkpoint['meta']
    assert checkpoint['meta']['CLASSES'] == ('class1', 'class2')

    # remove the temp file
    os.remove(checkpoint_path)
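

# CheckpointLoader dispatches on the filename prefix; unknown schemes fall
# back to load_from_local, and registered prefixes are matched longest-first.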
def test_checkpoint_loader():
    import os
    import tempfile

    from mmcv.runner import (CheckpointLoader, _load_checkpoint,
                             save_checkpoint)
    checkpoint_path = os.path.join(tempfile.gettempdir(), 'checkpoint.pth')
    model = Model()
    save_checkpoint(model, checkpoint_path)
    checkpoint = _load_checkpoint(checkpoint_path)
    assert 'meta' in checkpoint and 'CLASSES' not in checkpoint['meta']
    # remove the temp file
    os.remove(checkpoint_path)

    filenames = [
        'http://xx.xx/xx.pth', 'https://xx.xx/xx.pth',
        'modelzoo://xx.xx/xx.pth', 'torchvision://xx.xx/xx.pth',
        'open-mmlab://xx.xx/xx.pth', 'openmmlab://xx.xx/xx.pth',
        'mmcls://xx.xx/xx.pth', 'pavi://xx.xx/xx.pth', 's3://xx.xx/xx.pth',
        'ss3://xx.xx/xx.pth', ' s3://xx.xx/xx.pth'
    ]
    fn_names = [
        'load_from_http', 'load_from_http', 'load_from_torchvision',
        'load_from_torchvision', 'load_from_openmmlab', 'load_from_openmmlab',
        'load_from_mmcls', 'load_from_pavi', 'load_from_ceph',
        'load_from_local', 'load_from_local'
    ]

    for filename, fn_name in zip(filenames, fn_names):
        loader = CheckpointLoader._get_checkpoint_loader(filename)
        assert loader.__name__ == fn_name

    @CheckpointLoader.register_scheme(prefixes='ftp://')
    def load_from_ftp(filename, map_location):
        return dict(filename=filename)

    # test register_loader
    filename = 'ftp://xx.xx/xx.pth'
    loader = CheckpointLoader._get_checkpoint_loader(filename)
    assert loader.__name__ == 'load_from_ftp'

    def load_from_ftp1(filename, map_location):
        return dict(filename=filename)

    # test duplicate registered error
    with pytest.raises(KeyError):
        CheckpointLoader.register_scheme('ftp://', load_from_ftp1)

    # test force param
    CheckpointLoader.register_scheme('ftp://', load_from_ftp1, force=True)
    checkpoint = CheckpointLoader.load_checkpoint(filename)
    assert checkpoint['filename'] == filename

    # the forced registration should now be returned by the loader lookup
    loader = CheckpointLoader._get_checkpoint_loader(filename)
    assert loader.__name__ == 'load_from_ftp1'

    # test sort: the longest matching prefix wins
    @CheckpointLoader.register_scheme(prefixes='a/b')
    def load_from_ab(filename, map_location):
        return dict(filename=filename)

    @CheckpointLoader.register_scheme(prefixes='a/b/c')
    def load_from_abc(filename, map_location):
        return dict(filename=filename)

    filename = 'a/b/c/d'
    loader = CheckpointLoader._get_checkpoint_loader(filename)
    assert loader.__name__ == 'load_from_abc'
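

# Saving to an s3:// path should route through PetrelBackend; patching
# PetrelBackend.put lets the test assert the upload path without real OSS
# access.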
def test_save_checkpoint(tmp_path):
    model = Model()
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    # meta is not a dict
    with pytest.raises(TypeError):
        save_checkpoint(model, '/path/of/your/filename', meta='invalid type')

    # 1. save to disk
    filename = str(tmp_path / 'checkpoint1.pth')
    save_checkpoint(model, filename)

    filename = str(tmp_path / 'checkpoint2.pth')
    save_checkpoint(model, filename, optimizer)

    filename = str(tmp_path / 'checkpoint3.pth')
    save_checkpoint(model, filename, meta={'test': 'test'})

    filename = str(tmp_path / 'checkpoint4.pth')
    save_checkpoint(model, filename, file_client_args={'backend': 'disk'})

    # 2. save to petrel oss
    with patch.object(PetrelBackend, 'put') as mock_method:
        filename = 's3://path/of/your/checkpoint1.pth'
        save_checkpoint(model, filename)
        mock_method.assert_called()

    with patch.object(PetrelBackend, 'put') as mock_method:
        filename = 's3://path//of/your/checkpoint2.pth'
        save_checkpoint(
            model, filename, file_client_args={'backend': 'petrel'})
        mock_method.assert_called()