[Feature] Add ModuleList, Sequential and ModuleDict (#299)

* add module list
* add module list
* fix docstring

parent df0c510444
commit bcab813242
--- a/mmengine/model/__init__.py
+++ b/mmengine/model/__init__.py
@@ -2,7 +2,7 @@
 from .averaged_model import (ExponentialMovingAverage, MomentumAnnealingEMA,
                              StochasticWeightAverage)
 from .base_model import BaseDataPreprocessor, BaseModel, ImgDataPreprocessor
-from .base_module import BaseModule
+from .base_module import BaseModule, ModuleDict, ModuleList, Sequential
 from .utils import detect_anomalous_params, merge_dict, stack_batch
 from .wrappers import (MMDistributedDataParallel,
                        MMSeparateDistributedDataParallel, is_model_wrapper)
@@ -12,5 +12,6 @@ __all__ = [
     'ExponentialMovingAverage', 'MomentumAnnealingEMA', 'BaseModel',
     'BaseDataPreprocessor', 'ImgDataPreprocessor',
     'MMSeparateDistributedDataParallel', 'BaseModule', 'stack_batch',
-    'merge_dict', 'detect_anomalous_params'
+    'merge_dict', 'detect_anomalous_params', 'ModuleList', 'ModuleDict',
+    'Sequential'
 ]
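
Note: with the updated exports and __all__, the three containers become part of the public mmengine.model namespace. A quick sanity check (a usage sketch, not part of the diff):

from mmengine.model import BaseModule, ModuleDict, ModuleList, Sequential

# All three containers subclass BaseModule (see base_module.py below),
# so each accepts an init_cfg keyword in addition to the usual
# torch.nn container arguments.
assert all(issubclass(cls, BaseModule)
           for cls in (ModuleDict, ModuleList, Sequential))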
--- a/mmengine/model/base_module.py
+++ b/mmengine/model/base_module.py
@@ -5,6 +5,7 @@ import warnings
 from abc import ABCMeta
 from collections import defaultdict
 from logging import FileHandler
+from typing import Iterable, Optional
 
 import torch.nn as nn
@@ -165,3 +166,55 @@ class BaseModule(nn.Module, metaclass=ABCMeta):
         if self.init_cfg:
             s += f'\ninit_cfg={self.init_cfg}'
         return s
+
+
+class Sequential(BaseModule, nn.Sequential):
+    """Sequential module in openmmlab.
+
+    Ensures that all modules in ``Sequential`` have a different initialization
+    strategy than the outer model.
+
+    Args:
+        init_cfg (dict, optional): Initialization config dict.
+    """
+
+    def __init__(self, *args, init_cfg: Optional[dict] = None):
+        BaseModule.__init__(self, init_cfg)
+        nn.Sequential.__init__(self, *args)
+
+
+class ModuleList(BaseModule, nn.ModuleList):
+    """ModuleList in openmmlab.
+
+    Ensures that all modules in ``ModuleList`` have a different initialization
+    strategy than the outer model.
+
+    Args:
+        modules (iterable, optional): An iterable of modules to add.
+        init_cfg (dict, optional): Initialization config dict.
+    """
+
+    def __init__(self,
+                 modules: Optional[Iterable] = None,
+                 init_cfg: Optional[dict] = None):
+        BaseModule.__init__(self, init_cfg)
+        nn.ModuleList.__init__(self, modules)
+
+
+class ModuleDict(BaseModule, nn.ModuleDict):
+    """ModuleDict in openmmlab.
+
+    Ensures that all modules in ``ModuleDict`` have a different initialization
+    strategy than the outer model.
+
+    Args:
+        modules (dict, optional): A mapping (dictionary) of (string: module)
+            or an iterable of key-value pairs of type (string, module).
+        init_cfg (dict, optional): Initialization config dict.
+    """
+
+    def __init__(self,
+                 modules: Optional[dict] = None,
+                 init_cfg: Optional[dict] = None):
+        BaseModule.__init__(self, init_cfg)
+        nn.ModuleDict.__init__(self, modules)
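
Note: together these containers let init_cfg flow through nested modules: a container-level init_cfg applies to children that do not carry their own, while a child's own init_cfg takes priority (the "inner init_cfg has higher priority" behavior verified by the tests below). A minimal sketch of the intended composition, not part of the diff; the layer shapes are illustrative:

import torch
import torch.nn as nn
from mmengine.model import ModuleList, Sequential

# A child container that carries its own init_cfg.
child = Sequential(
    nn.Conv1d(4, 1, 4),
    init_cfg=dict(type='Constant', layer='Conv1d', val=0., bias=1.))

# The outer init_cfg covers children without their own cfg;
# the child's own cfg wins for its Conv1d.
model = ModuleList(
    [child, nn.Conv1d(4, 1, 4)],
    init_cfg=dict(type='Constant', layer='Conv1d', val=4., bias=5.))
model.init_weights()

assert torch.equal(model[0][0].weight,
                   torch.full(model[0][0].weight.shape, 0.))  # inner cfg
assert torch.equal(model[1].weight,
                   torch.full(model[1].weight.shape, 4.))  # outer cfg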
--- a/tests/test_model/test_base_module.py
+++ b/tests/test_model/test_base_module.py
@@ -5,7 +5,7 @@ import torch
 from torch import nn
 
 from mmengine.logging.logger import MMLogger
-from mmengine.model.base_module import BaseModule
+from mmengine.model import BaseModule, ModuleDict, ModuleList, Sequential
 from mmengine.registry import Registry, build_from_cfg
 
 COMPONENTS = Registry('component')
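
Note: the FooConv1d and FooConv2d components built from the configs below are registered on COMPONENTS earlier in this test file, outside the visible hunks. A minimal sketch of what such a fixture plausibly looks like (the attribute name conv1d follows the assertions below; channel sizes are assumed; nn and BaseModule are the imports shown above):

@COMPONENTS.register_module()
class FooConv1d(BaseModule):

    def __init__(self, init_cfg=None):
        super().__init__(init_cfg)
        # The tests read modellist[0].conv1d, so the attribute name matters.
        self.conv1d = nn.Conv1d(4, 1, 4)

    def forward(self, x):
        return self.conv1d(x)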
@@ -195,3 +195,164 @@ class TestBaseModule(TestCase):
         assert len(os.listdir(dump_dir)) == 1
         assert os.stat(log_path).st_size != 0
         shutil.rmtree(dump_dir)
+
+
+class TestModuleList(TestCase):
+
+    def test_modulelist_weight_init(self):
+        models_cfg = [
+            dict(
+                type='FooConv1d',
+                init_cfg=dict(
+                    type='Constant', layer='Conv1d', val=0., bias=1.)),
+            dict(
+                type='FooConv2d',
+                init_cfg=dict(
+                    type='Constant', layer='Conv2d', val=2., bias=3.)),
+        ]
+        layers = [build_from_cfg(cfg, COMPONENTS) for cfg in models_cfg]
+        modellist = ModuleList(layers)
+        modellist.init_weights()
+        self.assertTrue(
+            torch.equal(modellist[0].conv1d.weight,
+                        torch.full(modellist[0].conv1d.weight.shape, 0.)))
+        self.assertTrue(
+            torch.equal(modellist[0].conv1d.bias,
+                        torch.full(modellist[0].conv1d.bias.shape, 1.)))
+        self.assertTrue(
+            torch.equal(modellist[1].conv2d.weight,
+                        torch.full(modellist[1].conv2d.weight.shape, 2.)))
+        self.assertTrue(
+            torch.equal(modellist[1].conv2d.bias,
+                        torch.full(modellist[1].conv2d.bias.shape, 3.)))
+        # inner init_cfg has higher priority
+        layers = [build_from_cfg(cfg, COMPONENTS) for cfg in models_cfg]
+        modellist = ModuleList(
+            layers,
+            init_cfg=dict(
+                type='Constant', layer=['Conv1d', 'Conv2d'], val=4., bias=5.))
+        modellist.init_weights()
+        self.assertTrue(
+            torch.equal(modellist[0].conv1d.weight,
+                        torch.full(modellist[0].conv1d.weight.shape, 0.)))
+        self.assertTrue(
+            torch.equal(modellist[0].conv1d.bias,
+                        torch.full(modellist[0].conv1d.bias.shape, 1.)))
+        self.assertTrue(
+            torch.equal(modellist[1].conv2d.weight,
+                        torch.full(modellist[1].conv2d.weight.shape, 2.)))
+        self.assertTrue(
+            torch.equal(modellist[1].conv2d.bias,
+                        torch.full(modellist[1].conv2d.bias.shape, 3.)))
+
+
+class TestModuleDict(TestCase):
+
+    def test_moduledict_weight_init(self):
+        models_cfg = dict(
+            foo_conv_1d=dict(
+                type='FooConv1d',
+                init_cfg=dict(
+                    type='Constant', layer='Conv1d', val=0., bias=1.)),
+            foo_conv_2d=dict(
+                type='FooConv2d',
+                init_cfg=dict(
+                    type='Constant', layer='Conv2d', val=2., bias=3.)),
+        )
+        layers = {
+            name: build_from_cfg(cfg, COMPONENTS)
+            for name, cfg in models_cfg.items()
+        }
+        modeldict = ModuleDict(layers)
+        modeldict.init_weights()
+        self.assertTrue(
+            torch.equal(
+                modeldict['foo_conv_1d'].conv1d.weight,
+                torch.full(modeldict['foo_conv_1d'].conv1d.weight.shape, 0.)))
+        self.assertTrue(
+            torch.equal(
+                modeldict['foo_conv_1d'].conv1d.bias,
+                torch.full(modeldict['foo_conv_1d'].conv1d.bias.shape, 1.)))
+        self.assertTrue(
+            torch.equal(
+                modeldict['foo_conv_2d'].conv2d.weight,
+                torch.full(modeldict['foo_conv_2d'].conv2d.weight.shape, 2.)))
+        self.assertTrue(
+            torch.equal(
+                modeldict['foo_conv_2d'].conv2d.bias,
+                torch.full(modeldict['foo_conv_2d'].conv2d.bias.shape, 3.)))
+        # inner init_cfg has higher priority
+        layers = {
+            name: build_from_cfg(cfg, COMPONENTS)
+            for name, cfg in models_cfg.items()
+        }
+        modeldict = ModuleDict(
+            layers,
+            init_cfg=dict(
+                type='Constant', layer=['Conv1d', 'Conv2d'], val=4., bias=5.))
+        modeldict.init_weights()
+        self.assertTrue(
+            torch.equal(
+                modeldict['foo_conv_1d'].conv1d.weight,
+                torch.full(modeldict['foo_conv_1d'].conv1d.weight.shape, 0.)))
+        self.assertTrue(
+            torch.equal(
+                modeldict['foo_conv_1d'].conv1d.bias,
+                torch.full(modeldict['foo_conv_1d'].conv1d.bias.shape, 1.)))
+        self.assertTrue(
+            torch.equal(
+                modeldict['foo_conv_2d'].conv2d.weight,
+                torch.full(modeldict['foo_conv_2d'].conv2d.weight.shape, 2.)))
+        self.assertTrue(
+            torch.equal(
+                modeldict['foo_conv_2d'].conv2d.bias,
+                torch.full(modeldict['foo_conv_2d'].conv2d.bias.shape, 3.)))
+
+
+class TestSequential(TestCase):
+
+    def test_sequential_model_weight_init(self):
+        seq_model_cfg = [
+            dict(
+                type='FooConv1d',
+                init_cfg=dict(
+                    type='Constant', layer='Conv1d', val=0., bias=1.)),
+            dict(
+                type='FooConv2d',
+                init_cfg=dict(
+                    type='Constant', layer='Conv2d', val=2., bias=3.)),
+        ]
+        layers = [build_from_cfg(cfg, COMPONENTS) for cfg in seq_model_cfg]
+        seq_model = Sequential(*layers)
+        seq_model.init_weights()
+        self.assertTrue(
+            torch.equal(seq_model[0].conv1d.weight,
+                        torch.full(seq_model[0].conv1d.weight.shape, 0.)))
+        self.assertTrue(
+            torch.equal(seq_model[0].conv1d.bias,
+                        torch.full(seq_model[0].conv1d.bias.shape, 1.)))
+        self.assertTrue(
+            torch.equal(seq_model[1].conv2d.weight,
+                        torch.full(seq_model[1].conv2d.weight.shape, 2.)))
+        self.assertTrue(
+            torch.equal(seq_model[1].conv2d.bias,
+                        torch.full(seq_model[1].conv2d.bias.shape, 3.)))
+        # inner init_cfg has higher priority
+        layers = [build_from_cfg(cfg, COMPONENTS) for cfg in seq_model_cfg]
+        seq_model = Sequential(
+            *layers,
+            init_cfg=dict(
+                type='Constant', layer=['Conv1d', 'Conv2d'], val=4., bias=5.))
+        seq_model.init_weights()
+        self.assertTrue(
+            torch.equal(seq_model[0].conv1d.weight,
+                        torch.full(seq_model[0].conv1d.weight.shape, 0.)))
+        self.assertTrue(
+            torch.equal(seq_model[0].conv1d.bias,
+                        torch.full(seq_model[0].conv1d.bias.shape, 1.)))
+        self.assertTrue(
+            torch.equal(seq_model[1].conv2d.weight,
+                        torch.full(seq_model[1].conv2d.weight.shape, 2.)))
+        self.assertTrue(
+            torch.equal(seq_model[1].conv2d.bias,
+                        torch.full(seq_model[1].conv2d.bias.shape, 3.)))