[Fix] Init weights after build model. (#164)
* [Fix] Init weights after build model. * add unit tests and docstringpull/165/head
parent
87da7599ae
commit
ab8b51682f
|
@ -658,6 +658,10 @@ class Runner:
|
|||
def build_model(self, model: Union[nn.Module, Dict]) -> nn.Module:
|
||||
"""Build model.
|
||||
|
||||
If ``model`` is a dict, it will be used to build a nn.Module object
|
||||
and initialize the weights if it has ``init_weights`` method.
|
||||
Else, if ``model`` is a nn.Module object it will be returned directly.
|
||||
|
||||
An example of ``model``::
|
||||
|
||||
model = dict(type='ResNet')
|
||||
|
@ -673,7 +677,11 @@ class Runner:
|
|||
if isinstance(model, nn.Module):
|
||||
return model
|
||||
elif isinstance(model, dict):
|
||||
return MODELS.build(model)
|
||||
model = MODELS.build(model)
|
||||
# init weights
|
||||
if hasattr(model, 'init_weights'):
|
||||
model.init_weights()
|
||||
return model
|
||||
else:
|
||||
raise TypeError('model should be a nn.Module object or dict, '
|
||||
f'but got {model}')
|
||||
|
|
|
@ -511,6 +511,25 @@ class TestRunner(TestCase):
|
|||
model = runner.build_model(dict(type='ToyModel1'))
|
||||
self.assertIsInstance(model, ToyModel1)
|
||||
|
||||
# test init weights
|
||||
@MODELS.register_module()
|
||||
class ToyModel2(ToyModel):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.initiailzed = False
|
||||
|
||||
def init_weights(self):
|
||||
self.initiailzed = True
|
||||
|
||||
model = runner.build_model(dict(type='ToyModel2'))
|
||||
self.assertTrue(model.initiailzed)
|
||||
|
||||
# test init weights with model object
|
||||
_model = ToyModel2()
|
||||
model = runner.build_model(_model)
|
||||
self.assertFalse(model.initiailzed)
|
||||
|
||||
def test_wrap_model(self):
|
||||
# TODO: test on distributed environment
|
||||
# custom model wrapper
|
||||
|
|
Loading…
Reference in New Issue