mirror of https://github.com/open-mmlab/mmcv.git
Remove all module wrapper's module when saving checkpoint (#399)
* fix: remove all module wrapper when saving checkpoint
* refactor: move position of if
* docs: add docstring
* refactor: add _save_to_state_dict from official torch
* docs: modify docstring of _save_to_state_dict
* docs: modify docstring
* feat: add unittest
* feat: add DataParallel to unittest
* fix: a bug when model has batchnorm
* docs: update docstring
parent 27cc439d01
commit 5704613e28
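For context, a minimal sketch (illustrative only, not part of this commit) of the problem the change addresses: a model wrapped in a module wrapper such as torch's DataParallel stores every state_dict key with a 'module.' prefix, so a checkpoint saved from the wrapped model does not load back into the bare model.

import torch.nn as nn

# Hypothetical toy model; any nn.Module shows the same effect.
model = nn.Sequential(nn.Conv2d(3, 3, 1))
wrapped = nn.DataParallel(model)

print(list(model.state_dict().keys()))    # ['0.weight', '0.bias']
print(list(wrapped.state_dict().keys()))  # ['module.0.weight', 'module.0.bias']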
@@ -254,6 +254,70 @@ def weights_to_cpu(state_dict):
    return state_dict_cpu


def _save_to_state_dict(module, destination, prefix, keep_vars):
    """Saves module state to `destination` dictionary.

    This method is modified from :meth:`torch.nn.Module._save_to_state_dict`.

    Args:
        module (nn.Module): The module to generate state_dict.
        destination (dict): A dict where state will be stored.
        prefix (str): The prefix for parameters and buffers used in this
            module.
    """
    for name, param in module._parameters.items():
        if param is not None:
            destination[prefix + name] = param if keep_vars else param.detach()
    for name, buf in module._buffers.items():
        # remove check of _non_persistent_buffers_set to allow nn.BatchNorm2d
        if buf is not None:
            destination[prefix + name] = buf if keep_vars else buf.detach()


def get_state_dict(module, destination=None, prefix='', keep_vars=False):
    """Returns a dictionary containing a whole state of the module.

    Both parameters and persistent buffers (e.g. running averages) are
    included. Keys are corresponding parameter and buffer names.

    This method is modified from :meth:`torch.nn.Module.state_dict` to
    recursively check parallel module in case that the model has a complicated
    structure, e.g., nn.Module(nn.Module(DDP)).

    Args:
        module (nn.Module): The module to generate state_dict.
        destination (OrderedDict): Returned dict for the state of the
            module.
        prefix (str): Prefix of the key.
        keep_vars (bool): Whether to keep the variable property of the
            parameters. Default: False.

    Returns:
        dict: A dictionary containing a whole state of the module.
    """
    # recursively check parallel module in case that the model has a
    # complicated structure, e.g., nn.Module(nn.Module(DDP))
    if is_module_wrapper(module):
        module = module.module

    # below is the same as torch.nn.Module.state_dict()
    if destination is None:
        destination = OrderedDict()
        destination._metadata = OrderedDict()
    destination._metadata[prefix[:-1]] = local_metadata = dict(
        version=module._version)
    _save_to_state_dict(module, destination, prefix, keep_vars)
    for name, child in module._modules.items():
        if child is not None:
            get_state_dict(
                child, destination, prefix + name + '.', keep_vars=keep_vars)
    for hook in module._state_dict_hooks.values():
        hook_result = hook(module, destination, prefix, local_metadata)
        if hook_result is not None:
            destination = hook_result
    return destination


def save_checkpoint(model, filename, optimizer=None, meta=None):
    """Save checkpoint to file.
@@ -278,7 +342,7 @@ def save_checkpoint(model, filename, optimizer=None, meta=None):

    checkpoint = {
        'meta': meta,
-        'state_dict': weights_to_cpu(model.state_dict())
+        'state_dict': weights_to_cpu(get_state_dict(model))
    }
    # save optimizer state dict in the checkpoint
    if isinstance(optimizer, Optimizer):
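A short usage sketch of the effect of this change (assumptions: torch's DataParallel is registered in mmcv's MODULE_WRAPPERS, as the new unit test below relies on, and the file path is arbitrary): save_checkpoint now unwraps the wrapper before collecting weights, so the stored keys match the bare model.

import torch
import torch.nn as nn
from mmcv.runner import save_checkpoint

# DataParallel is treated as a module wrapper, so get_state_dict unwraps it.
model = nn.DataParallel(nn.Conv2d(3, 3, 1))
save_checkpoint(model, '/tmp/conv_ckpt.pth')

# Keys are 'weight' and 'bias' (no 'module.' prefix), so the checkpoint
# loads directly into an unwrapped module.
ckpt = torch.load('/tmp/conv_ckpt.pth')
nn.Conv2d(3, 3, 1).load_state_dict(ckpt['state_dict'])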
@@ -0,0 +1,112 @@
from collections import OrderedDict

import torch.nn as nn
from torch.nn.parallel import DataParallel

from mmcv.parallel.registry import MODULE_WRAPPERS
from mmcv.runner.checkpoint import get_state_dict


@MODULE_WRAPPERS.register_module()
class DDPWrapper(object):

    def __init__(self, module):
        self.module = module


class Block(nn.Module):

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, 1)
        self.norm = nn.BatchNorm2d(3)


class Model(nn.Module):

    def __init__(self):
        super().__init__()
        self.block = Block()
        self.conv = nn.Conv2d(3, 3, 1)


def assert_tensor_equal(tensor_a, tensor_b):
    assert tensor_a.eq(tensor_b).all()


def test_get_state_dict():
    state_dict_keys = set([
        'block.conv.weight', 'block.conv.bias', 'block.norm.weight',
        'block.norm.bias', 'block.norm.running_mean',
        'block.norm.running_var', 'block.norm.num_batches_tracked',
        'conv.weight', 'conv.bias'
    ])

    model = Model()
    state_dict = get_state_dict(model)
    assert isinstance(state_dict, OrderedDict)
    assert set(state_dict.keys()) == state_dict_keys

    assert_tensor_equal(state_dict['block.conv.weight'],
                        model.block.conv.weight)
    assert_tensor_equal(state_dict['block.conv.bias'], model.block.conv.bias)
    assert_tensor_equal(state_dict['block.norm.weight'],
                        model.block.norm.weight)
    assert_tensor_equal(state_dict['block.norm.bias'], model.block.norm.bias)
    assert_tensor_equal(state_dict['block.norm.running_mean'],
                        model.block.norm.running_mean)
    assert_tensor_equal(state_dict['block.norm.running_var'],
                        model.block.norm.running_var)
    assert_tensor_equal(state_dict['block.norm.num_batches_tracked'],
                        model.block.norm.num_batches_tracked)
    assert_tensor_equal(state_dict['conv.weight'], model.conv.weight)
    assert_tensor_equal(state_dict['conv.bias'], model.conv.bias)

    wrapped_model = DDPWrapper(model)
    state_dict = get_state_dict(wrapped_model)
    assert isinstance(state_dict, OrderedDict)
    assert set(state_dict.keys()) == state_dict_keys
    assert_tensor_equal(state_dict['block.conv.weight'],
                        wrapped_model.module.block.conv.weight)
    assert_tensor_equal(state_dict['block.conv.bias'],
                        wrapped_model.module.block.conv.bias)
    assert_tensor_equal(state_dict['block.norm.weight'],
                        wrapped_model.module.block.norm.weight)
    assert_tensor_equal(state_dict['block.norm.bias'],
                        wrapped_model.module.block.norm.bias)
    assert_tensor_equal(state_dict['block.norm.running_mean'],
                        wrapped_model.module.block.norm.running_mean)
    assert_tensor_equal(state_dict['block.norm.running_var'],
                        wrapped_model.module.block.norm.running_var)
    assert_tensor_equal(state_dict['block.norm.num_batches_tracked'],
                        wrapped_model.module.block.norm.num_batches_tracked)
    assert_tensor_equal(state_dict['conv.weight'],
                        wrapped_model.module.conv.weight)
    assert_tensor_equal(state_dict['conv.bias'],
                        wrapped_model.module.conv.bias)

    # wrapped inner module
    for name, module in wrapped_model.module._modules.items():
        module = DataParallel(module)
        wrapped_model.module._modules[name] = module
    state_dict = get_state_dict(wrapped_model)
    assert isinstance(state_dict, OrderedDict)
    assert set(state_dict.keys()) == state_dict_keys
    assert_tensor_equal(state_dict['block.conv.weight'],
                        wrapped_model.module.block.module.conv.weight)
    assert_tensor_equal(state_dict['block.conv.bias'],
                        wrapped_model.module.block.module.conv.bias)
    assert_tensor_equal(state_dict['block.norm.weight'],
                        wrapped_model.module.block.module.norm.weight)
    assert_tensor_equal(state_dict['block.norm.bias'],
                        wrapped_model.module.block.module.norm.bias)
    assert_tensor_equal(state_dict['block.norm.running_mean'],
                        wrapped_model.module.block.module.norm.running_mean)
    assert_tensor_equal(state_dict['block.norm.running_var'],
                        wrapped_model.module.block.module.norm.running_var)
    assert_tensor_equal(
        state_dict['block.norm.num_batches_tracked'],
        wrapped_model.module.block.module.norm.num_batches_tracked)
    assert_tensor_equal(state_dict['conv.weight'],
                        wrapped_model.module.conv.module.weight)
    assert_tensor_equal(state_dict['conv.bias'],
                        wrapped_model.module.conv.module.bias)