Remove all module wrappers when saving checkpoint (#399)

* fix: remove all module wrapper when saving checkpoint

* refactor: move position of if

* docs: add docstring

* refactor: add _save_to_state_dict from official torch

* docs: modify docstring of _save_to_state_dict

* docs: modify docstring

* feat: add unittest

* feat: add DataParallel to unittest

* fix: a bug when model has batchnorm

* docs: update docstring
Harry 2020-07-08 23:20:22 +08:00 committed by GitHub
parent 27cc439d01
commit 5704613e28
2 changed files with 177 additions and 1 deletion
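For context (not part of the commit itself), a minimal sketch of the problem being fixed: PyTorch's parallel wrappers keep the real network under a `module` attribute, so a wrapped model's `state_dict()` keys carry a `module.` prefix that the bare model cannot load back. The new `get_state_dict` below strips such wrappers, including nested ones, before the checkpoint is written. The toy model here is made up for illustration.

import torch.nn as nn
from torch.nn.parallel import DataParallel

model = nn.Sequential(nn.Conv2d(3, 3, 1))
wrapped = DataParallel(model)

print(list(model.state_dict().keys()))
# ['0.weight', '0.bias']
print(list(wrapped.state_dict().keys()))
# ['module.0.weight', 'module.0.bias'] -- will not load into the bare model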


@@ -254,6 +254,70 @@ def weights_to_cpu(state_dict):
    return state_dict_cpu
def _save_to_state_dict(module, destination, prefix, keep_vars):
    """Saves module state to `destination` dictionary.

    This method is modified from :meth:`torch.nn.Module._save_to_state_dict`.

    Args:
        module (nn.Module): The module to generate state_dict.
        destination (dict): A dict where state will be stored.
        prefix (str): The prefix for parameters and buffers used in this
            module.
        keep_vars (bool): Whether to keep the variable property of the
            parameters.
    """
    for name, param in module._parameters.items():
        if param is not None:
            destination[prefix + name] = param if keep_vars else param.detach()
    for name, buf in module._buffers.items():
        # remove check of _non_persistent_buffers_set to allow nn.BatchNorm2d
        if buf is not None:
            destination[prefix + name] = buf if keep_vars else buf.detach()


def get_state_dict(module, destination=None, prefix='', keep_vars=False):
    """Returns a dictionary containing a whole state of the module.

    Both parameters and persistent buffers (e.g. running averages) are
    included. Keys are corresponding parameter and buffer names.

    This method is modified from :meth:`torch.nn.Module.state_dict` to
    recursively check parallel modules in case the model has a complicated
    structure, e.g., nn.Module(nn.Module(DDP)).

    Args:
        module (nn.Module): The module to generate state_dict.
        destination (OrderedDict): Returned dict for the state of the
            module.
        prefix (str): Prefix of the key.
        keep_vars (bool): Whether to keep the variable property of the
            parameters. Default: False.

    Returns:
        dict: A dictionary containing a whole state of the module.
    """
    # recursively check parallel modules in case the model has a
    # complicated structure, e.g., nn.Module(nn.Module(DDP))
    if is_module_wrapper(module):
        module = module.module

    # below is the same as torch.nn.Module.state_dict()
    if destination is None:
        destination = OrderedDict()
        destination._metadata = OrderedDict()
    destination._metadata[prefix[:-1]] = local_metadata = dict(
        version=module._version)
    _save_to_state_dict(module, destination, prefix, keep_vars)
    for name, child in module._modules.items():
        if child is not None:
            get_state_dict(
                child, destination, prefix + name + '.', keep_vars=keep_vars)
    for hook in module._state_dict_hooks.values():
        hook_result = hook(module, destination, prefix, local_metadata)
        if hook_result is not None:
            destination = hook_result
    return destination


def save_checkpoint(model, filename, optimizer=None, meta=None):
    """Save checkpoint to file.
@@ -278,7 +342,7 @@ def save_checkpoint(model, filename, optimizer=None, meta=None):
    checkpoint = {
        'meta': meta,
-        'state_dict': weights_to_cpu(model.state_dict())
+        'state_dict': weights_to_cpu(get_state_dict(model))
    }
    # save optimizer state dict in the checkpoint
    if isinstance(optimizer, Optimizer):
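A hedged usage sketch of the change above (the toy model and file path are made up; `save_checkpoint` is assumed to be importable from the same `mmcv.runner.checkpoint` module shown in this diff): because `save_checkpoint` now calls `get_state_dict(model)`, a checkpoint written from a wrapped model carries bare-model keys and loads straight into the unwrapped network.

import torch
import torch.nn as nn
from torch.nn.parallel import DataParallel

from mmcv.runner.checkpoint import save_checkpoint

# Toy model for illustration only.
model = nn.Sequential(nn.Conv2d(3, 3, 1), nn.BatchNorm2d(3))
wrapped = DataParallel(model)

save_checkpoint(wrapped, '/tmp/demo_ckpt.pth')
ckpt = torch.load('/tmp/demo_ckpt.pth')

# No 'module.' prefix in the keys, so the bare model accepts them directly.
model.load_state_dict(ckpt['state_dict'])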


@@ -0,0 +1,112 @@
from collections import OrderedDict

import torch.nn as nn
from torch.nn.parallel import DataParallel

from mmcv.parallel.registry import MODULE_WRAPPERS
from mmcv.runner.checkpoint import get_state_dict


@MODULE_WRAPPERS.register_module()
class DDPWrapper(object):

    def __init__(self, module):
        self.module = module


class Block(nn.Module):

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, 1)
        self.norm = nn.BatchNorm2d(3)


class Model(nn.Module):

    def __init__(self):
        super().__init__()
        self.block = Block()
        self.conv = nn.Conv2d(3, 3, 1)


def assert_tensor_equal(tensor_a, tensor_b):
    assert tensor_a.eq(tensor_b).all()


def test_get_state_dict():
    state_dict_keys = set([
        'block.conv.weight', 'block.conv.bias', 'block.norm.weight',
        'block.norm.bias', 'block.norm.running_mean', 'block.norm.running_var',
        'block.norm.num_batches_tracked', 'conv.weight', 'conv.bias'
    ])

    model = Model()
    state_dict = get_state_dict(model)
    assert isinstance(state_dict, OrderedDict)
    assert set(state_dict.keys()) == state_dict_keys
    assert_tensor_equal(state_dict['block.conv.weight'],
                        model.block.conv.weight)
    assert_tensor_equal(state_dict['block.conv.bias'], model.block.conv.bias)
    assert_tensor_equal(state_dict['block.norm.weight'],
                        model.block.norm.weight)
    assert_tensor_equal(state_dict['block.norm.bias'], model.block.norm.bias)
    assert_tensor_equal(state_dict['block.norm.running_mean'],
                        model.block.norm.running_mean)
    assert_tensor_equal(state_dict['block.norm.running_var'],
                        model.block.norm.running_var)
    assert_tensor_equal(state_dict['block.norm.num_batches_tracked'],
                        model.block.norm.num_batches_tracked)
    assert_tensor_equal(state_dict['conv.weight'], model.conv.weight)
    assert_tensor_equal(state_dict['conv.bias'], model.conv.bias)

    wrapped_model = DDPWrapper(model)
    state_dict = get_state_dict(wrapped_model)
    assert isinstance(state_dict, OrderedDict)
    assert set(state_dict.keys()) == state_dict_keys
    assert_tensor_equal(state_dict['block.conv.weight'],
                        wrapped_model.module.block.conv.weight)
    assert_tensor_equal(state_dict['block.conv.bias'],
                        wrapped_model.module.block.conv.bias)
    assert_tensor_equal(state_dict['block.norm.weight'],
                        wrapped_model.module.block.norm.weight)
    assert_tensor_equal(state_dict['block.norm.bias'],
                        wrapped_model.module.block.norm.bias)
    assert_tensor_equal(state_dict['block.norm.running_mean'],
                        wrapped_model.module.block.norm.running_mean)
    assert_tensor_equal(state_dict['block.norm.running_var'],
                        wrapped_model.module.block.norm.running_var)
    assert_tensor_equal(state_dict['block.norm.num_batches_tracked'],
                        wrapped_model.module.block.norm.num_batches_tracked)
    assert_tensor_equal(state_dict['conv.weight'],
                        wrapped_model.module.conv.weight)
    assert_tensor_equal(state_dict['conv.bias'],
                        wrapped_model.module.conv.bias)

    # wrapped inner module
    for name, module in wrapped_model.module._modules.items():
        module = DataParallel(module)
        wrapped_model.module._modules[name] = module
    state_dict = get_state_dict(wrapped_model)
    assert isinstance(state_dict, OrderedDict)
    assert set(state_dict.keys()) == state_dict_keys
    assert_tensor_equal(state_dict['block.conv.weight'],
                        wrapped_model.module.block.module.conv.weight)
    assert_tensor_equal(state_dict['block.conv.bias'],
                        wrapped_model.module.block.module.conv.bias)
    assert_tensor_equal(state_dict['block.norm.weight'],
                        wrapped_model.module.block.module.norm.weight)
    assert_tensor_equal(state_dict['block.norm.bias'],
                        wrapped_model.module.block.module.norm.bias)
    assert_tensor_equal(state_dict['block.norm.running_mean'],
                        wrapped_model.module.block.module.norm.running_mean)
    assert_tensor_equal(state_dict['block.norm.running_var'],
                        wrapped_model.module.block.module.norm.running_var)
    assert_tensor_equal(
        state_dict['block.norm.num_batches_tracked'],
        wrapped_model.module.block.module.norm.num_batches_tracked)
    assert_tensor_equal(state_dict['conv.weight'],
                        wrapped_model.module.conv.module.weight)
    assert_tensor_equal(state_dict['conv.bias'],
                        wrapped_model.module.conv.module.bias)
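For reference, the test above can register `DDPWrapper` in `MODULE_WRAPPERS` because `is_module_wrapper`, used by `get_state_dict`, decides whether something is a wrapper by checking its type against the registered wrapper classes. The function below is only a rough sketch of that idea, not necessarily mmcv's exact implementation.

def is_module_wrapper_sketch(module):
    # Sketch: a module counts as a wrapper when its type is registered in
    # MODULE_WRAPPERS (DataParallel, DistributedDataParallel, or custom
    # wrappers such as DDPWrapper above).
    module_wrappers = tuple(MODULE_WRAPPERS.module_dict.values())
    return isinstance(module, module_wrappers)

This is also why wrapping the inner submodules with DataParallel in the last part of the test still yields unprefixed keys: `get_state_dict` recurses into each child and unwraps any registered wrapper it meets before collecting parameters.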