[Feature] Add ModuleList Sequential and ModuleDict (#299)

* add module list

* add module list

* fix docstring
Mashiro 2022-06-13 13:51:07 +08:00 committed by GitHub
parent df0c510444
commit bcab813242
3 changed files with 218 additions and 3 deletions

mmengine/model/__init__.py

@@ -2,7 +2,7 @@
 from .averaged_model import (ExponentialMovingAverage, MomentumAnnealingEMA,
                              StochasticWeightAverage)
 from .base_model import BaseDataPreprocessor, BaseModel, ImgDataPreprocessor
-from .base_module import BaseModule
+from .base_module import BaseModule, ModuleDict, ModuleList, Sequential
 from .utils import detect_anomalous_params, merge_dict, stack_batch
 from .wrappers import (MMDistributedDataParallel,
                        MMSeparateDistributedDataParallel, is_model_wrapper)
@@ -12,5 +12,6 @@ __all__ = [
     'ExponentialMovingAverage', 'MomentumAnnealingEMA', 'BaseModel',
     'BaseDataPreprocessor', 'ImgDataPreprocessor',
     'MMSeparateDistributedDataParallel', 'BaseModule', 'stack_batch',
-    'merge_dict', 'detect_anomalous_params'
+    'merge_dict', 'detect_anomalous_params', 'ModuleList', 'ModuleDict',
+    'Sequential'
 ]
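
With the updated exports, the three containers can be imported straight from the package root instead of the base_module submodule. A one-line sketch of the new import path (usage illustration, not part of the commit):

# Hypothetical caller code after this commit.
from mmengine.model import ModuleDict, ModuleList, Sequential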

mmengine/model/base_module.py

@@ -5,6 +5,7 @@ import warnings
 from abc import ABCMeta
 from collections import defaultdict
 from logging import FileHandler
+from typing import Iterable, Optional

 import torch.nn as nn
@@ -165,3 +166,55 @@ class BaseModule(nn.Module, metaclass=ABCMeta):
         if self.init_cfg:
             s += f'\ninit_cfg={self.init_cfg}'
         return s


class Sequential(BaseModule, nn.Sequential):
    """Sequential module in openmmlab.

    Ensures that all modules in ``Sequential`` have a different initialization
    strategy than the outer model.

    Args:
        init_cfg (dict, optional): Initialization config dict.
    """

    def __init__(self, *args, init_cfg: Optional[dict] = None):
        BaseModule.__init__(self, init_cfg)
        nn.Sequential.__init__(self, *args)


class ModuleList(BaseModule, nn.ModuleList):
    """ModuleList in openmmlab.

    Ensures that all modules in ``ModuleList`` have a different initialization
    strategy than the outer model.

    Args:
        modules (iterable, optional): An iterable of modules to add.
        init_cfg (dict, optional): Initialization config dict.
    """

    def __init__(self,
                 modules: Optional[Iterable] = None,
                 init_cfg: Optional[dict] = None):
        BaseModule.__init__(self, init_cfg)
        nn.ModuleList.__init__(self, modules)


class ModuleDict(BaseModule, nn.ModuleDict):
    """ModuleDict in openmmlab.

    Ensures that all modules in ``ModuleDict`` have a different initialization
    strategy than the outer model.

    Args:
        modules (dict, optional): A mapping (dictionary) of (string: module)
            or an iterable of key-value pairs of type (string, module).
        init_cfg (dict, optional): Initialization config dict.
    """

    def __init__(self,
                 modules: Optional[dict] = None,
                 init_cfg: Optional[dict] = None):
        BaseModule.__init__(self, init_cfg)
        nn.ModuleDict.__init__(self, modules)
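
All three containers route ``init_cfg`` through ``BaseModule``, so calling ``init_weights()`` on the container initializes its children, and a child's own ``init_cfg`` takes priority over the container's, as the tests below verify. A minimal usage sketch (not part of the diff; the plain torch conv layers are illustrative, while the ``Constant`` init config mirrors the one used in the tests):

import torch.nn as nn

from mmengine.model import ModuleList, Sequential

# Container-level init_cfg applies to children that do not define their own.
seq = Sequential(
    nn.Conv1d(1, 1, 1),
    nn.Conv2d(1, 1, 1),
    init_cfg=dict(type='Constant', layer=['Conv1d', 'Conv2d'], val=4., bias=5.))
seq.init_weights()  # both conv layers end up with weight=4. and bias=5.

# ModuleList takes an iterable of modules; ModuleDict takes a mapping.
blocks = ModuleList(
    [nn.Conv2d(3, 3, 3) for _ in range(2)],
    init_cfg=dict(type='Constant', layer='Conv2d', val=0., bias=1.))
blocks.init_weights()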

tests/test_model/test_base_module.py

@@ -5,7 +5,7 @@ import torch
 from torch import nn

 from mmengine.logging.logger import MMLogger
-from mmengine.model.base_module import BaseModule
+from mmengine.model import BaseModule, ModuleDict, ModuleList, Sequential
 from mmengine.registry import Registry, build_from_cfg

 COMPONENTS = Registry('component')
@@ -195,3 +195,164 @@ class TestBaseModule(TestCase):
         assert len(os.listdir(dump_dir)) == 1
         assert os.stat(log_path).st_size != 0
         shutil.rmtree(dump_dir)


class TestModuleList(TestCase):

    def test_modulelist_weight_init(self):
models_cfg = [
dict(
type='FooConv1d',
init_cfg=dict(
type='Constant', layer='Conv1d', val=0., bias=1.)),
dict(
type='FooConv2d',
init_cfg=dict(
type='Constant', layer='Conv2d', val=2., bias=3.)),
]
layers = [build_from_cfg(cfg, COMPONENTS) for cfg in models_cfg]
modellist = ModuleList(layers)
modellist.init_weights()
self.assertTrue(
torch.equal(modellist[0].conv1d.weight,
torch.full(modellist[0].conv1d.weight.shape, 0.)))
self.assertTrue(
torch.equal(modellist[0].conv1d.bias,
torch.full(modellist[0].conv1d.bias.shape, 1.)))
self.assertTrue(
torch.equal(modellist[1].conv2d.weight,
torch.full(modellist[1].conv2d.weight.shape, 2.)))
self.assertTrue(
torch.equal(modellist[1].conv2d.bias,
torch.full(modellist[1].conv2d.bias.shape, 3.)))
# inner init_cfg has higher priority
layers = [build_from_cfg(cfg, COMPONENTS) for cfg in models_cfg]
modellist = ModuleList(
layers,
init_cfg=dict(
type='Constant', layer=['Conv1d', 'Conv2d'], val=4., bias=5.))
modellist.init_weights()
self.assertTrue(
torch.equal(modellist[0].conv1d.weight,
torch.full(modellist[0].conv1d.weight.shape, 0.)))
self.assertTrue(
torch.equal(modellist[0].conv1d.bias,
torch.full(modellist[0].conv1d.bias.shape, 1.)))
self.assertTrue(
torch.equal(modellist[1].conv2d.weight,
torch.full(modellist[1].conv2d.weight.shape, 2.)))
self.assertTrue(
torch.equal(modellist[1].conv2d.bias,
torch.full(modellist[1].conv2d.bias.shape, 3.)))


class TestModuleDict(TestCase):

    def test_moduledict_weight_init(self):
models_cfg = dict(
foo_conv_1d=dict(
type='FooConv1d',
init_cfg=dict(
type='Constant', layer='Conv1d', val=0., bias=1.)),
foo_conv_2d=dict(
type='FooConv2d',
init_cfg=dict(
type='Constant', layer='Conv2d', val=2., bias=3.)),
)
layers = {
name: build_from_cfg(cfg, COMPONENTS)
for name, cfg in models_cfg.items()
}
modeldict = ModuleDict(layers)
modeldict.init_weights()
self.assertTrue(
torch.equal(
modeldict['foo_conv_1d'].conv1d.weight,
torch.full(modeldict['foo_conv_1d'].conv1d.weight.shape, 0.)))
self.assertTrue(
torch.equal(
modeldict['foo_conv_1d'].conv1d.bias,
torch.full(modeldict['foo_conv_1d'].conv1d.bias.shape, 1.)))
self.assertTrue(
torch.equal(
modeldict['foo_conv_2d'].conv2d.weight,
torch.full(modeldict['foo_conv_2d'].conv2d.weight.shape, 2.)))
self.assertTrue(
torch.equal(
modeldict['foo_conv_2d'].conv2d.bias,
torch.full(modeldict['foo_conv_2d'].conv2d.bias.shape, 3.)))
# inner init_cfg has higher priority
layers = {
name: build_from_cfg(cfg, COMPONENTS)
for name, cfg in models_cfg.items()
}
modeldict = ModuleDict(
layers,
init_cfg=dict(
type='Constant', layer=['Conv1d', 'Conv2d'], val=4., bias=5.))
modeldict.init_weights()
self.assertTrue(
torch.equal(
modeldict['foo_conv_1d'].conv1d.weight,
torch.full(modeldict['foo_conv_1d'].conv1d.weight.shape, 0.)))
self.assertTrue(
torch.equal(
modeldict['foo_conv_1d'].conv1d.bias,
torch.full(modeldict['foo_conv_1d'].conv1d.bias.shape, 1.)))
self.assertTrue(
torch.equal(
modeldict['foo_conv_2d'].conv2d.weight,
torch.full(modeldict['foo_conv_2d'].conv2d.weight.shape, 2.)))
self.assertTrue(
torch.equal(
modeldict['foo_conv_2d'].conv2d.bias,
torch.full(modeldict['foo_conv_2d'].conv2d.bias.shape, 3.)))


class TestSequential(TestCase):

    def test_sequential_model_weight_init(self):
seq_model_cfg = [
dict(
type='FooConv1d',
init_cfg=dict(
type='Constant', layer='Conv1d', val=0., bias=1.)),
dict(
type='FooConv2d',
init_cfg=dict(
type='Constant', layer='Conv2d', val=2., bias=3.)),
]
layers = [build_from_cfg(cfg, COMPONENTS) for cfg in seq_model_cfg]
seq_model = Sequential(*layers)
seq_model.init_weights()
self.assertTrue(
torch.equal(seq_model[0].conv1d.weight,
torch.full(seq_model[0].conv1d.weight.shape, 0.)))
self.assertTrue(
torch.equal(seq_model[0].conv1d.bias,
torch.full(seq_model[0].conv1d.bias.shape, 1.)))
self.assertTrue(
torch.equal(seq_model[1].conv2d.weight,
torch.full(seq_model[1].conv2d.weight.shape, 2.)))
self.assertTrue(
torch.equal(seq_model[1].conv2d.bias,
torch.full(seq_model[1].conv2d.bias.shape, 3.)))
# inner init_cfg has higher priority
layers = [build_from_cfg(cfg, COMPONENTS) for cfg in seq_model_cfg]
seq_model = Sequential(
*layers,
init_cfg=dict(
type='Constant', layer=['Conv1d', 'Conv2d'], val=4., bias=5.))
seq_model.init_weights()
self.assertTrue(
torch.equal(seq_model[0].conv1d.weight,
torch.full(seq_model[0].conv1d.weight.shape, 0.)))
self.assertTrue(
torch.equal(seq_model[0].conv1d.bias,
torch.full(seq_model[0].conv1d.bias.shape, 1.)))
self.assertTrue(
torch.equal(seq_model[1].conv2d.weight,
torch.full(seq_model[1].conv2d.weight.shape, 2.)))
self.assertTrue(
torch.equal(seq_model[1].conv2d.bias,
torch.full(seq_model[1].conv2d.bias.shape, 3.)))