[Fix] Initialize nested modules in ddp which define 'init_weights' method (#1045)
parent
fd84c210e5
commit
dc931fd2c0
|
@ -11,6 +11,7 @@ import torch.nn as nn
|
|||
from mmengine.dist import master_only
|
||||
from mmengine.logging import MMLogger, print_log
|
||||
from .weight_init import initialize, update_init_info
|
||||
from .wrappers.utils import is_model_wrapper
|
||||
|
||||
|
||||
class BaseModule(nn.Module, metaclass=ABCMeta):
|
||||
|
@ -123,6 +124,8 @@ class BaseModule(nn.Module, metaclass=ABCMeta):
|
|||
initialize(self, other_cfgs)
|
||||
|
||||
for m in self.children():
|
||||
if is_model_wrapper(m) and not hasattr(m, 'init_weights'):
|
||||
m = m.module
|
||||
if hasattr(m, 'init_weights'):
|
||||
m.init_weights()
|
||||
# users may overload the `init_weights`
|
||||
|
|
|
@ -4,6 +4,7 @@ import logging
|
|||
import os.path as osp
|
||||
import tempfile
|
||||
from unittest import TestCase
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
@ -145,7 +146,6 @@ class TestBaseModule(TestCase):
|
|||
├──conv1d (FooConv1d, weight=3, bias=4)
|
||||
├──reg (nn.Linear, weight=1, bias=2)
|
||||
"""
|
||||
|
||||
self.model.init_weights()
|
||||
|
||||
assert torch.equal(
|
||||
|
@ -222,6 +222,22 @@ class TestBaseModule(TestCase):
|
|||
self.assertTrue((ori_layer_weight != model.linear.linear.weight).any())
|
||||
self.assertTrue((ori_layer_bias != model.linear.linear.bias).any())
|
||||
|
||||
class FakeDDP(nn.Module):
|
||||
|
||||
def __init__(self, module) -> None:
|
||||
super().__init__()
|
||||
self.module = module
|
||||
|
||||
# Test initialization of nested modules in DDPModule which define
|
||||
# `init_weights`.
|
||||
with patch('mmengine.model.base_module.is_model_wrapper',
|
||||
lambda x: isinstance(x, FakeDDP)):
|
||||
model = FOOMODELS.build(model_cfg)
|
||||
model.ddp = FakeDDP(CustomLinear())
|
||||
model.init_weights()
|
||||
self.assertTrue((model.ddp.module.linear.weight == 1).all())
|
||||
self.assertTrue((model.ddp.module.linear.bias == 2).all())
|
||||
|
||||
def test_dump_init_info(self):
|
||||
import os
|
||||
import shutil
|
||||
|
|
Loading…
Reference in New Issue