[Feature] Add ModuleList, Sequential and ModuleDict (#299)

* add module list

* fix docstring
Mashiro 2022-06-13 13:51:07 +08:00 committed by GitHub
parent df0c510444
commit bcab813242
3 changed files with 218 additions and 3 deletions

mmengine/model/__init__.py

@@ -2,7 +2,7 @@
 from .averaged_model import (ExponentialMovingAverage, MomentumAnnealingEMA,
                              StochasticWeightAverage)
 from .base_model import BaseDataPreprocessor, BaseModel, ImgDataPreprocessor
-from .base_module import BaseModule
+from .base_module import BaseModule, ModuleDict, ModuleList, Sequential
 from .utils import detect_anomalous_params, merge_dict, stack_batch
 from .wrappers import (MMDistributedDataParallel,
                        MMSeparateDistributedDataParallel, is_model_wrapper)
@@ -12,5 +12,6 @@ __all__ = [
     'ExponentialMovingAverage', 'MomentumAnnealingEMA', 'BaseModel',
     'BaseDataPreprocessor', 'ImgDataPreprocessor',
     'MMSeparateDistributedDataParallel', 'BaseModule', 'stack_batch',
-    'merge_dict', 'detect_anomalous_params'
+    'merge_dict', 'detect_anomalous_params', 'ModuleList', 'ModuleDict',
+    'Sequential'
 ]
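With this re-export, the new containers become part of the public mmengine.model API, so downstream code no longer needs to reach into the submodule; a one-line sketch:

from mmengine.model import ModuleDict, ModuleList, Sequential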

mmengine/model/base_module.py

@@ -5,6 +5,7 @@ import warnings
 from abc import ABCMeta
 from collections import defaultdict
 from logging import FileHandler
+from typing import Iterable, Optional

 import torch.nn as nn
@@ -165,3 +166,55 @@ class BaseModule(nn.Module, metaclass=ABCMeta):
         if self.init_cfg:
             s += f'\ninit_cfg={self.init_cfg}'
         return s
+
+
+class Sequential(BaseModule, nn.Sequential):
+    """Sequential module in OpenMMLab.
+
+    Ensures that all modules in ``Sequential`` have a different
+    initialization strategy than the outer model.
+
+    Args:
+        init_cfg (dict, optional): Initialization config dict.
+    """
+
+    def __init__(self, *args, init_cfg: Optional[dict] = None):
+        BaseModule.__init__(self, init_cfg)
+        nn.Sequential.__init__(self, *args)
+
+
+class ModuleList(BaseModule, nn.ModuleList):
+    """ModuleList in OpenMMLab.
+
+    Ensures that all modules in ``ModuleList`` have a different
+    initialization strategy than the outer model.
+
+    Args:
+        modules (iterable, optional): An iterable of modules to add.
+        init_cfg (dict, optional): Initialization config dict.
+    """
+
+    def __init__(self,
+                 modules: Optional[Iterable] = None,
+                 init_cfg: Optional[dict] = None):
+        BaseModule.__init__(self, init_cfg)
+        nn.ModuleList.__init__(self, modules)
+
+
+class ModuleDict(BaseModule, nn.ModuleDict):
+    """ModuleDict in OpenMMLab.
+
+    Ensures that all modules in ``ModuleDict`` have a different
+    initialization strategy than the outer model.
+
+    Args:
+        modules (dict, optional): A mapping (dictionary) of (string: module)
+            or an iterable of key-value pairs of type (string, module).
+        init_cfg (dict, optional): Initialization config dict.
+    """
+
+    def __init__(self,
+                 modules: Optional[dict] = None,
+                 init_cfg: Optional[dict] = None):
+        BaseModule.__init__(self, init_cfg)
+        nn.ModuleDict.__init__(self, modules)
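Taken together, the three containers behave like their torch.nn counterparts but additionally carry an init_cfg that BaseModule.init_weights() understands. A minimal usage sketch; the layer shapes and initializer types below are illustrative assumptions, not part of this commit:

import torch.nn as nn

from mmengine.model import ModuleDict, ModuleList, Sequential

# Each container takes the usual torch.nn constructor arguments plus an
# optional init_cfg keyword.
backbone = Sequential(
    nn.Conv2d(3, 16, 3),
    nn.ReLU(),
    init_cfg=dict(type='Kaiming', layer='Conv2d'))
necks = ModuleList(
    [nn.Conv2d(16, 16, 1) for _ in range(2)],
    init_cfg=dict(type='Xavier', layer='Conv2d'))
heads = ModuleDict(
    dict(cls=nn.Conv2d(16, 10, 1), reg=nn.Conv2d(16, 4, 1)),
    init_cfg=dict(type='Normal', layer='Conv2d', std=0.01))

# init_weights() is inherited from BaseModule and applies init_cfg; it is
# also invoked automatically when the container sits inside a larger
# BaseModule whose init_weights() is called.
backbone.init_weights()
necks.init_weights()
heads.init_weights()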

tests/test_model/test_base_module.py

@@ -5,7 +5,7 @@ import torch
 from torch import nn

 from mmengine.logging.logger import MMLogger
-from mmengine.model.base_module import BaseModule
+from mmengine.model import BaseModule, ModuleDict, ModuleList, Sequential
 from mmengine.registry import Registry, build_from_cfg

 COMPONENTS = Registry('component')
@@ -195,3 +195,164 @@ class TestBaseModule(TestCase):
         assert len(os.listdir(dump_dir)) == 1
         assert os.stat(log_path).st_size != 0
         shutil.rmtree(dump_dir)
+
+
+class TestModuleList(TestCase):
+
+    def test_modulelist_weight_init(self):
+        models_cfg = [
+            dict(
+                type='FooConv1d',
+                init_cfg=dict(
+                    type='Constant', layer='Conv1d', val=0., bias=1.)),
+            dict(
+                type='FooConv2d',
+                init_cfg=dict(
+                    type='Constant', layer='Conv2d', val=2., bias=3.)),
+        ]
+        layers = [build_from_cfg(cfg, COMPONENTS) for cfg in models_cfg]
+        modellist = ModuleList(layers)
+        modellist.init_weights()
+        self.assertTrue(
+            torch.equal(modellist[0].conv1d.weight,
+                        torch.full(modellist[0].conv1d.weight.shape, 0.)))
+        self.assertTrue(
+            torch.equal(modellist[0].conv1d.bias,
+                        torch.full(modellist[0].conv1d.bias.shape, 1.)))
+        self.assertTrue(
+            torch.equal(modellist[1].conv2d.weight,
+                        torch.full(modellist[1].conv2d.weight.shape, 2.)))
+        self.assertTrue(
+            torch.equal(modellist[1].conv2d.bias,
+                        torch.full(modellist[1].conv2d.bias.shape, 3.)))
+
+        # inner init_cfg has higher priority
+        layers = [build_from_cfg(cfg, COMPONENTS) for cfg in models_cfg]
+        modellist = ModuleList(
+            layers,
+            init_cfg=dict(
+                type='Constant', layer=['Conv1d', 'Conv2d'], val=4., bias=5.))
+        modellist.init_weights()
+        self.assertTrue(
+            torch.equal(modellist[0].conv1d.weight,
+                        torch.full(modellist[0].conv1d.weight.shape, 0.)))
+        self.assertTrue(
+            torch.equal(modellist[0].conv1d.bias,
+                        torch.full(modellist[0].conv1d.bias.shape, 1.)))
+        self.assertTrue(
+            torch.equal(modellist[1].conv2d.weight,
+                        torch.full(modellist[1].conv2d.weight.shape, 2.)))
+        self.assertTrue(
+            torch.equal(modellist[1].conv2d.bias,
+                        torch.full(modellist[1].conv2d.bias.shape, 3.)))
+
+
+class TestModuleDict(TestCase):
+
+    def test_moduledict_weight_init(self):
+        models_cfg = dict(
+            foo_conv_1d=dict(
+                type='FooConv1d',
+                init_cfg=dict(
+                    type='Constant', layer='Conv1d', val=0., bias=1.)),
+            foo_conv_2d=dict(
+                type='FooConv2d',
+                init_cfg=dict(
+                    type='Constant', layer='Conv2d', val=2., bias=3.)),
+        )
+        layers = {
+            name: build_from_cfg(cfg, COMPONENTS)
+            for name, cfg in models_cfg.items()
+        }
+        modeldict = ModuleDict(layers)
+        modeldict.init_weights()
+        self.assertTrue(
+            torch.equal(
+                modeldict['foo_conv_1d'].conv1d.weight,
+                torch.full(modeldict['foo_conv_1d'].conv1d.weight.shape, 0.)))
+        self.assertTrue(
+            torch.equal(
+                modeldict['foo_conv_1d'].conv1d.bias,
+                torch.full(modeldict['foo_conv_1d'].conv1d.bias.shape, 1.)))
+        self.assertTrue(
+            torch.equal(
+                modeldict['foo_conv_2d'].conv2d.weight,
+                torch.full(modeldict['foo_conv_2d'].conv2d.weight.shape, 2.)))
+        self.assertTrue(
+            torch.equal(
+                modeldict['foo_conv_2d'].conv2d.bias,
+                torch.full(modeldict['foo_conv_2d'].conv2d.bias.shape, 3.)))
+
+        # inner init_cfg has higher priority
+        layers = {
+            name: build_from_cfg(cfg, COMPONENTS)
+            for name, cfg in models_cfg.items()
+        }
+        modeldict = ModuleDict(
+            layers,
+            init_cfg=dict(
+                type='Constant', layer=['Conv1d', 'Conv2d'], val=4., bias=5.))
+        modeldict.init_weights()
+        self.assertTrue(
+            torch.equal(
+                modeldict['foo_conv_1d'].conv1d.weight,
+                torch.full(modeldict['foo_conv_1d'].conv1d.weight.shape, 0.)))
+        self.assertTrue(
+            torch.equal(
+                modeldict['foo_conv_1d'].conv1d.bias,
+                torch.full(modeldict['foo_conv_1d'].conv1d.bias.shape, 1.)))
+        self.assertTrue(
+            torch.equal(
+                modeldict['foo_conv_2d'].conv2d.weight,
+                torch.full(modeldict['foo_conv_2d'].conv2d.weight.shape, 2.)))
+        self.assertTrue(
+            torch.equal(
+                modeldict['foo_conv_2d'].conv2d.bias,
+                torch.full(modeldict['foo_conv_2d'].conv2d.bias.shape, 3.)))
+
+
+class TestSequential(TestCase):
+
+    def test_sequential_model_weight_init(self):
+        seq_model_cfg = [
+            dict(
+                type='FooConv1d',
+                init_cfg=dict(
+                    type='Constant', layer='Conv1d', val=0., bias=1.)),
+            dict(
+                type='FooConv2d',
+                init_cfg=dict(
+                    type='Constant', layer='Conv2d', val=2., bias=3.)),
+        ]
+        layers = [build_from_cfg(cfg, COMPONENTS) for cfg in seq_model_cfg]
+        seq_model = Sequential(*layers)
+        seq_model.init_weights()
+        self.assertTrue(
+            torch.equal(seq_model[0].conv1d.weight,
+                        torch.full(seq_model[0].conv1d.weight.shape, 0.)))
+        self.assertTrue(
+            torch.equal(seq_model[0].conv1d.bias,
+                        torch.full(seq_model[0].conv1d.bias.shape, 1.)))
+        self.assertTrue(
+            torch.equal(seq_model[1].conv2d.weight,
+                        torch.full(seq_model[1].conv2d.weight.shape, 2.)))
+        self.assertTrue(
+            torch.equal(seq_model[1].conv2d.bias,
+                        torch.full(seq_model[1].conv2d.bias.shape, 3.)))
+
+        # inner init_cfg has higher priority
+        layers = [build_from_cfg(cfg, COMPONENTS) for cfg in seq_model_cfg]
+        seq_model = Sequential(
+            *layers,
+            init_cfg=dict(
+                type='Constant', layer=['Conv1d', 'Conv2d'], val=4., bias=5.))
+        seq_model.init_weights()
+        self.assertTrue(
+            torch.equal(seq_model[0].conv1d.weight,
+                        torch.full(seq_model[0].conv1d.weight.shape, 0.)))
+        self.assertTrue(
+            torch.equal(seq_model[0].conv1d.bias,
+                        torch.full(seq_model[0].conv1d.bias.shape, 1.)))
+        self.assertTrue(
+            torch.equal(seq_model[1].conv2d.weight,
+                        torch.full(seq_model[1].conv2d.weight.shape, 2.)))
+        self.assertTrue(
+            torch.equal(seq_model[1].conv2d.bias,
+                        torch.full(seq_model[1].conv2d.bias.shape, 3.)))
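The "inner init_cfg has higher priority" branches above pin down the override order: BaseModule.init_weights() first applies the container's own init_cfg, then lets each child that is itself a BaseModule apply its own, so an inner init_cfg wins over the outer one. A minimal sketch of that rule; the Head class here is a hypothetical stand-in, not part of this commit:

import torch.nn as nn

from mmengine.model import BaseModule, ModuleList


class Head(BaseModule):

    def __init__(self, init_cfg=None):
        super().__init__(init_cfg)
        self.linear = nn.Linear(4, 4)


heads = ModuleList(
    [
        # Carries its own init_cfg, which overrides the container's.
        Head(init_cfg=dict(type='Constant', layer='Linear', val=1.)),
        # No init_cfg of its own, so the container's setting applies.
        Head(),
    ],
    init_cfg=dict(type='Constant', layer='Linear', val=2.))
heads.init_weights()
# Expected: heads[0].linear.weight is all 1., heads[1].linear.weight all 2.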