2020-12-22 12:59:07 +08:00
|
|
|
import sys
|
2020-07-08 23:20:22 +08:00
|
|
|
from collections import OrderedDict
|
2020-12-22 12:59:07 +08:00
|
|
|
from unittest.mock import MagicMock
|
2020-07-08 23:20:22 +08:00
|
|
|
|
2020-12-22 12:59:07 +08:00
|
|
|
import pytest
|
2020-07-08 23:20:22 +08:00
|
|
|
import torch.nn as nn
|
|
|
|
from torch.nn.parallel import DataParallel
|
|
|
|
|
|
|
|
from mmcv.parallel.registry import MODULE_WRAPPERS
|
2020-12-22 12:59:07 +08:00
|
|
|
from mmcv.runner.checkpoint import get_state_dict, load_pavimodel_dist
|
2020-07-08 23:20:22 +08:00
|
|
|
|
|
|
|
|
|
|
|
@MODULE_WRAPPERS.register_module()
class DDPWrapper:
    """Minimal stand-in for a distributed module wrapper.

    It is registered in ``MODULE_WRAPPERS`` and exposes the wrapped network
    as ``module``. The state-dict test expects ``get_state_dict`` on this
    wrapper to yield the inner model's keys with no ``module.`` prefix.
    """

    def __init__(self, module):
        # Keep the wrapped network under the conventional ``module`` name.
        self.module = module
|
|
|
|
|
|
|
|
|
|
|
|
class Block(nn.Module):
    """Tiny conv + batch-norm submodule used to produce nested state-dict keys."""

    def __init__(self):
        super(Block, self).__init__()
        channels = 3
        # Registration order (conv first, then norm) fixes the key order of
        # this module's state dict.
        self.conv = nn.Conv2d(channels, channels, kernel_size=1)
        self.norm = nn.BatchNorm2d(num_features=channels)
|
|
|
|
|
|
|
|
|
|
|
|
class Model(nn.Module):
    """Two-level toy network: a nested ``Block`` followed by a top-level conv."""

    def __init__(self):
        super(Model, self).__init__()
        # ``block`` is registered before ``conv`` so its keys are listed first.
        self.block = Block()
        self.conv = nn.Conv2d(3, 3, kernel_size=1)
|
|
|
|
|
|
|
|
|
2020-12-22 12:59:07 +08:00
|
|
|
class Mockpavimodel:
    """Fake pavi model-cloud entry: carries a ``name`` and a no-op download."""

    def __init__(self, name='fakename'):
        self.name = name

    def download(self, file):
        # Intentionally writes nothing, so nothing exists at ``file`` afterwards.
        return None
|
|
|
|
|
|
|
|
|
2020-07-08 23:20:22 +08:00
|
|
|
def assert_tensor_equal(tensor_a, tensor_b):
    """Assert that two tensors have the same shape and identical elements.

    ``Tensor.eq`` broadcasts its operands. Checking only ``eq(...).all()``
    would therefore accept, e.g., a one-element tensor against any tensor
    filled with that value. The explicit shape check rules this out.

    Args:
        tensor_a: First tensor to compare.
        tensor_b: Second tensor to compare.

    Raises:
        AssertionError: If the shapes differ or any element differs.
    """
    assert tensor_a.shape == tensor_b.shape, (
        'shape mismatch: {} vs {}'.format(
            tuple(tensor_a.shape), tuple(tensor_b.shape)))
    assert tensor_a.eq(tensor_b).all()
|
|
|
|
|
|
|
|
|
|
|
|
def test_get_state_dict():
    """``get_state_dict`` yields the same keys and tensors in three cases.

    The cases are: a bare model, the model inside a registered wrapper, and
    that model with its children wrapped in ``DataParallel``. No wrapper
    prefix may leak into the keys.
    """
    state_dict_keys = {
        'block.conv.weight', 'block.conv.bias', 'block.norm.weight',
        'block.norm.bias', 'block.norm.running_mean', 'block.norm.running_var',
        'block.norm.num_batches_tracked', 'conv.weight', 'conv.bias'
    }

    def expected_tensors(block, conv):
        # Map every expected key to the live tensor it must equal, given the
        # unwrapped ``Block`` and top-level ``Conv2d`` instances.
        return {
            'block.conv.weight': block.conv.weight,
            'block.conv.bias': block.conv.bias,
            'block.norm.weight': block.norm.weight,
            'block.norm.bias': block.norm.bias,
            'block.norm.running_mean': block.norm.running_mean,
            'block.norm.running_var': block.norm.running_var,
            'block.norm.num_batches_tracked': block.norm.num_batches_tracked,
            'conv.weight': conv.weight,
            'conv.bias': conv.bias,
        }

    def check(state_dict, expected):
        # Shared assertions: container type, exact key set, tensor values.
        assert isinstance(state_dict, OrderedDict)
        assert set(state_dict.keys()) == state_dict_keys
        for key, tensor in expected.items():
            assert_tensor_equal(state_dict[key], tensor)

    # 1. Bare model.
    model = Model()
    check(get_state_dict(model), expected_tensors(model.block, model.conv))

    # 2. Model inside a wrapper registered in MODULE_WRAPPERS.
    wrapped_model = DDPWrapper(model)
    inner = wrapped_model.module
    check(
        get_state_dict(wrapped_model),
        expected_tensors(inner.block, inner.conv))

    # 3. Children of the inner model wrapped in DataParallel. Iterate over a
    # snapshot so the module registry is not mutated while it is iterated.
    for name, child in list(inner.named_children()):
        setattr(inner, name, DataParallel(child))
    check(
        get_state_dict(wrapped_model),
        expected_tensors(inner.block.module, inner.conv.module))
|
2020-12-22 12:59:07 +08:00
|
|
|
|
|
|
|
|
|
|
|
def test_load_pavimodel_dist():
    """Loading a pavi checkpoint whose download produces no file raises
    ``FileNotFoundError``.

    The ``pavi`` package is replaced by mocks only for the duration of this
    test. ``patch.dict`` restores ``sys.modules`` on exit, so the fake modules
    neither leak into other tests nor permanently shadow a real ``pavi``.
    """
    from unittest.mock import patch

    pavimodel = Mockpavimodel()
    fake_modules = {'pavi': MagicMock(), 'pavi.modelcloud': MagicMock()}
    with patch.dict(sys.modules, fake_modules):
        import pavi

        # ``pavi`` is the mocked package here. Make ``modelcloud.get`` return
        # a model whose ``download`` writes nothing.
        pavi.modelcloud.get = MagicMock(return_value=pavimodel)
        with pytest.raises(FileNotFoundError):
            # there is no such checkpoint for us to load
            load_pavimodel_dist('MyPaviFolder/checkpoint.pth')
|