[Fix] Init weights after build model. (#164)
* [Fix] Init weights after build model. * add unit tests and docstringpull/165/head
parent
87da7599ae
commit
ab8b51682f
|
@ -658,6 +658,10 @@ class Runner:
|
||||||
def build_model(self, model: Union[nn.Module, Dict]) -> nn.Module:
|
def build_model(self, model: Union[nn.Module, Dict]) -> nn.Module:
|
||||||
"""Build model.
|
"""Build model.
|
||||||
|
|
||||||
|
If ``model`` is a dict, it will be used to build a nn.Module object
|
||||||
|
and initialize the weights if it has ``init_weights`` method.
|
||||||
|
Else, if ``model`` is a nn.Module object it will be returned directly.
|
||||||
|
|
||||||
An example of ``model``::
|
An example of ``model``::
|
||||||
|
|
||||||
model = dict(type='ResNet')
|
model = dict(type='ResNet')
|
||||||
|
@ -673,7 +677,11 @@ class Runner:
|
||||||
if isinstance(model, nn.Module):
|
if isinstance(model, nn.Module):
|
||||||
return model
|
return model
|
||||||
elif isinstance(model, dict):
|
elif isinstance(model, dict):
|
||||||
return MODELS.build(model)
|
model = MODELS.build(model)
|
||||||
|
# init weights
|
||||||
|
if hasattr(model, 'init_weights'):
|
||||||
|
model.init_weights()
|
||||||
|
return model
|
||||||
else:
|
else:
|
||||||
raise TypeError('model should be a nn.Module object or dict, '
|
raise TypeError('model should be a nn.Module object or dict, '
|
||||||
f'but got {model}')
|
f'but got {model}')
|
||||||
|
|
|
@ -511,6 +511,25 @@ class TestRunner(TestCase):
|
||||||
model = runner.build_model(dict(type='ToyModel1'))
|
model = runner.build_model(dict(type='ToyModel1'))
|
||||||
self.assertIsInstance(model, ToyModel1)
|
self.assertIsInstance(model, ToyModel1)
|
||||||
|
|
||||||
|
# test init weights
|
||||||
|
@MODELS.register_module()
|
||||||
|
class ToyModel2(ToyModel):
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.initiailzed = False
|
||||||
|
|
||||||
|
def init_weights(self):
|
||||||
|
self.initiailzed = True
|
||||||
|
|
||||||
|
model = runner.build_model(dict(type='ToyModel2'))
|
||||||
|
self.assertTrue(model.initiailzed)
|
||||||
|
|
||||||
|
# test init weights with model object
|
||||||
|
_model = ToyModel2()
|
||||||
|
model = runner.build_model(_model)
|
||||||
|
self.assertFalse(model.initiailzed)
|
||||||
|
|
||||||
def test_wrap_model(self):
|
def test_wrap_model(self):
|
||||||
# TODO: test on distributed environment
|
# TODO: test on distributed environment
|
||||||
# custom model wrapper
|
# custom model wrapper
|
||||||
|
|
Loading…
Reference in New Issue