mirror of https://github.com/open-mmlab/mmcv.git
Remove all module wrapper's module when saving checkpoint (#399)
* fix: remove all module wrapper when saving checkpoint
* refactor: move position of if
* docs: add docstring
* refactor: add _save_to_state_dict from official torch
* docs: modify docstring of _save_to_state_dict
* docs: modify docstring
* feat: add unittest
* feat: add DataParallel to unittest
* fix: a bug when model has batchnorm
* docs: update docstring
parent 27cc439d01
commit 5704613e28
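Why the change is needed, in brief: `torch.nn.Module.state_dict()` knows nothing about module wrappers, so a wrapper nested inside a model leaves its `module.` prefix in every checkpoint key. A minimal sketch of the failure mode (the Sequential/DataParallel nesting is illustrative, not part of this commit):

    import torch.nn as nn
    from torch.nn.parallel import DataParallel

    model = nn.Sequential(nn.Conv2d(3, 3, 1))
    model[0] = DataParallel(model[0])  # a wrapper buried inside a plain module

    # plain state_dict() keeps the wrapper's prefix in every key, so the
    # checkpoint cannot be loaded back into an unwrapped model directly
    assert set(model.state_dict()) == {'0.module.weight', '0.module.bias'}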
@@ -254,6 +254,70 @@ def weights_to_cpu(state_dict):
     return state_dict_cpu


+def _save_to_state_dict(module, destination, prefix, keep_vars):
+    """Saves module state to `destination` dictionary.
+
+    This method is modified from :meth:`torch.nn.Module._save_to_state_dict`.
+
+    Args:
+        module (nn.Module): The module to generate state_dict.
+        destination (dict): A dict where state will be stored.
+        prefix (str): The prefix for parameters and buffers used in this
+            module.
+        keep_vars (bool): Whether to keep the variable property of the
+            parameters instead of detaching them.
+    """
+    for name, param in module._parameters.items():
+        if param is not None:
+            destination[prefix + name] = param if keep_vars else param.detach()
+    for name, buf in module._buffers.items():
+        # remove check of _non_persistent_buffers_set to allow nn.BatchNorm2d
+        if buf is not None:
+            destination[prefix + name] = buf if keep_vars else buf.detach()
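For a single module this helper records the parameters and every buffer under the given prefix. A minimal sketch, importing the private helper directly from mmcv.runner.checkpoint, where this diff defines it:

    from collections import OrderedDict

    import torch.nn as nn

    from mmcv.runner.checkpoint import _save_to_state_dict

    destination = OrderedDict()
    _save_to_state_dict(nn.BatchNorm2d(3), destination, 'bn.', keep_vars=False)
    # parameters and all buffers land in `destination` under the prefix
    assert set(destination) == {
        'bn.weight', 'bn.bias', 'bn.running_mean', 'bn.running_var',
        'bn.num_batches_tracked'
    }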
+
+
+def get_state_dict(module, destination=None, prefix='', keep_vars=False):
+    """Returns a dictionary containing a whole state of the module.
+
+    Both parameters and persistent buffers (e.g. running averages) are
+    included. Keys are corresponding parameter and buffer names.
+
+    This method is modified from :meth:`torch.nn.Module.state_dict` to
+    recursively check parallel module in case that the model has a complicated
+    structure, e.g., nn.Module(nn.Module(DDP)).
+
+    Args:
+        module (nn.Module): The module to generate state_dict.
+        destination (OrderedDict): Returned dict for the state of the
+            module.
+        prefix (str): Prefix of the key.
+        keep_vars (bool): Whether to keep the variable property of the
+            parameters. Default: False.
+
+    Returns:
+        dict: A dictionary containing a whole state of the module.
+    """
+    # recursively check parallel module in case that the model has a
+    # complicated structure, e.g., nn.Module(nn.Module(DDP))
+    if is_module_wrapper(module):
+        module = module.module
+
+    # below is the same as torch.nn.Module.state_dict()
+    if destination is None:
+        destination = OrderedDict()
+        destination._metadata = OrderedDict()
+    destination._metadata[prefix[:-1]] = local_metadata = dict(
+        version=module._version)
+    _save_to_state_dict(module, destination, prefix, keep_vars)
+    for name, child in module._modules.items():
+        if child is not None:
+            get_state_dict(
+                child, destination, prefix + name + '.', keep_vars=keep_vars)
+    for hook in module._state_dict_hooks.values():
+        hook_result = hook(module, destination, prefix, local_metadata)
+        if hook_result is not None:
+            destination = hook_result
+    return destination
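A minimal sketch of the behavior this adds, mirroring the failure mode above (torch's plain DataParallel counts as a module wrapper here, as the test in this commit relies on):

    import torch.nn as nn
    from torch.nn.parallel import DataParallel

    from mmcv.runner.checkpoint import get_state_dict

    model = nn.Sequential(nn.Conv2d(3, 3, 1))
    model[0] = DataParallel(model[0])

    # the nested wrapper is stripped, so keys match the unwrapped model
    assert set(get_state_dict(model)) == {'0.weight', '0.bias'}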
+
+
 def save_checkpoint(model, filename, optimizer=None, meta=None):
     """Save checkpoint to file.

@@ -278,7 +342,7 @@ def save_checkpoint(model, filename, optimizer=None, meta=None):
     checkpoint = {
         'meta': meta,
-        'state_dict': weights_to_cpu(model.state_dict())
+        'state_dict': weights_to_cpu(get_state_dict(model))
     }
     # save optimizer state dict in the checkpoint
     if isinstance(optimizer, Optimizer):
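With get_state_dict in place, a checkpoint saved from a wrapped model carries the same keys as one saved from the bare model. A minimal sketch (the file path is illustrative):

    import torch.nn as nn
    from torch.nn.parallel import DataParallel

    from mmcv.runner import save_checkpoint

    model = DataParallel(nn.Conv2d(3, 3, 1))
    # the stored keys are 'weight'/'bias', not 'module.weight'/'module.bias'
    save_checkpoint(model, '/tmp/conv_checkpoint.pth')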
@@ -0,0 +1,112 @@
+from collections import OrderedDict
+
+import torch.nn as nn
+from torch.nn.parallel import DataParallel
+
+from mmcv.parallel.registry import MODULE_WRAPPERS
+from mmcv.runner.checkpoint import get_state_dict
+
+
+@MODULE_WRAPPERS.register_module()
+class DDPWrapper(object):
+
+    def __init__(self, module):
+        self.module = module
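DDPWrapper stands in for a real distributed wrapper: registration in MODULE_WRAPPERS is all that is_module_wrapper checks for, which is what lets get_state_dict see through it. A minimal sketch of that contract:

    from mmcv.parallel import is_module_wrapper

    # any instance of a registered class counts as a wrapper,
    # even this bare stand-in that is not an nn.Module
    assert is_module_wrapper(DDPWrapper(nn.Linear(2, 2)))
    assert not is_module_wrapper(nn.Linear(2, 2))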
+
+
+class Block(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.conv = nn.Conv2d(3, 3, 1)
+        self.norm = nn.BatchNorm2d(3)
+
+
+class Model(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.block = Block()
+        self.conv = nn.Conv2d(3, 3, 1)
+
+
+def assert_tensor_equal(tensor_a, tensor_b):
+    assert tensor_a.eq(tensor_b).all()
+def test_get_state_dict():
+    state_dict_keys = set([
+        'block.conv.weight', 'block.conv.bias', 'block.norm.weight',
+        'block.norm.bias', 'block.norm.running_mean', 'block.norm.running_var',
+        'block.norm.num_batches_tracked', 'conv.weight', 'conv.bias'
+    ])
+
+    model = Model()
+    state_dict = get_state_dict(model)
+    assert isinstance(state_dict, OrderedDict)
+    assert set(state_dict.keys()) == state_dict_keys
+
+    assert_tensor_equal(state_dict['block.conv.weight'],
+                        model.block.conv.weight)
+    assert_tensor_equal(state_dict['block.conv.bias'], model.block.conv.bias)
+    assert_tensor_equal(state_dict['block.norm.weight'],
+                        model.block.norm.weight)
+    assert_tensor_equal(state_dict['block.norm.bias'], model.block.norm.bias)
+    assert_tensor_equal(state_dict['block.norm.running_mean'],
+                        model.block.norm.running_mean)
+    assert_tensor_equal(state_dict['block.norm.running_var'],
+                        model.block.norm.running_var)
+    assert_tensor_equal(state_dict['block.norm.num_batches_tracked'],
+                        model.block.norm.num_batches_tracked)
+    assert_tensor_equal(state_dict['conv.weight'], model.conv.weight)
+    assert_tensor_equal(state_dict['conv.bias'], model.conv.bias)
+
+    wrapped_model = DDPWrapper(model)
+    state_dict = get_state_dict(wrapped_model)
+    assert isinstance(state_dict, OrderedDict)
+    assert set(state_dict.keys()) == state_dict_keys
+    assert_tensor_equal(state_dict['block.conv.weight'],
+                        wrapped_model.module.block.conv.weight)
+    assert_tensor_equal(state_dict['block.conv.bias'],
+                        wrapped_model.module.block.conv.bias)
+    assert_tensor_equal(state_dict['block.norm.weight'],
+                        wrapped_model.module.block.norm.weight)
+    assert_tensor_equal(state_dict['block.norm.bias'],
+                        wrapped_model.module.block.norm.bias)
+    assert_tensor_equal(state_dict['block.norm.running_mean'],
+                        wrapped_model.module.block.norm.running_mean)
+    assert_tensor_equal(state_dict['block.norm.running_var'],
+                        wrapped_model.module.block.norm.running_var)
+    assert_tensor_equal(state_dict['block.norm.num_batches_tracked'],
+                        wrapped_model.module.block.norm.num_batches_tracked)
+    assert_tensor_equal(state_dict['conv.weight'],
+                        wrapped_model.module.conv.weight)
+    assert_tensor_equal(state_dict['conv.bias'],
+                        wrapped_model.module.conv.bias)
+
+    # wrap each inner module as well, so the structure becomes
+    # DDPWrapper(Model(DataParallel(...)))
+    for name, module in wrapped_model.module._modules.items():
+        module = DataParallel(module)
+        wrapped_model.module._modules[name] = module
+    state_dict = get_state_dict(wrapped_model)
+    assert isinstance(state_dict, OrderedDict)
+    assert set(state_dict.keys()) == state_dict_keys
+    assert_tensor_equal(state_dict['block.conv.weight'],
+                        wrapped_model.module.block.module.conv.weight)
+    assert_tensor_equal(state_dict['block.conv.bias'],
+                        wrapped_model.module.block.module.conv.bias)
+    assert_tensor_equal(state_dict['block.norm.weight'],
+                        wrapped_model.module.block.module.norm.weight)
+    assert_tensor_equal(state_dict['block.norm.bias'],
+                        wrapped_model.module.block.module.norm.bias)
+    assert_tensor_equal(state_dict['block.norm.running_mean'],
+                        wrapped_model.module.block.module.norm.running_mean)
+    assert_tensor_equal(state_dict['block.norm.running_var'],
+                        wrapped_model.module.block.module.norm.running_var)
+    assert_tensor_equal(
+        state_dict['block.norm.num_batches_tracked'],
+        wrapped_model.module.block.module.norm.num_batches_tracked)
+    assert_tensor_equal(state_dict['conv.weight'],
+                        wrapped_model.module.conv.module.weight)
+    assert_tensor_equal(state_dict['conv.bias'],
+                        wrapped_model.module.conv.module.bias)
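The test stops at key and value checks; a short round trip through save_checkpoint makes the point end to end (the path is illustrative, reusing the classes above):

    import torch

    from mmcv.runner import save_checkpoint

    wrapped = DDPWrapper(Model())
    save_checkpoint(wrapped, '/tmp/test_checkpoint.pth')
    checkpoint = torch.load('/tmp/test_checkpoint.pth')
    # no 'module.' prefixes, so a bare Model loads the weights directly
    Model().load_state_dict(checkpoint['state_dict'])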