[Feature] Add ModuleList Sequential and ModuleDict (#299)
* add module list
* fix docstring
parent df0c510444
commit bcab813242
mmengine/model/__init__.py
@@ -2,7 +2,7 @@
 from .averaged_model import (ExponentialMovingAverage, MomentumAnnealingEMA,
                              StochasticWeightAverage)
 from .base_model import BaseDataPreprocessor, BaseModel, ImgDataPreprocessor
-from .base_module import BaseModule
+from .base_module import BaseModule, ModuleDict, ModuleList, Sequential
 from .utils import detect_anomalous_params, merge_dict, stack_batch
 from .wrappers import (MMDistributedDataParallel,
                        MMSeparateDistributedDataParallel, is_model_wrapper)
@@ -12,5 +12,6 @@ __all__ = [
     'ExponentialMovingAverage', 'MomentumAnnealingEMA', 'BaseModel',
     'BaseDataPreprocessor', 'ImgDataPreprocessor',
     'MMSeparateDistributedDataParallel', 'BaseModule', 'stack_batch',
-    'merge_dict', 'detect_anomalous_params'
+    'merge_dict', 'detect_anomalous_params', 'ModuleList', 'ModuleDict',
+    'Sequential'
 ]
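With the `__all__` change above, the three containers become part of the package's public API:

    from mmengine.model import ModuleDict, ModuleList, Sequential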
mmengine/model/base_module.py
@@ -5,6 +5,7 @@ import warnings
 from abc import ABCMeta
 from collections import defaultdict
 from logging import FileHandler
+from typing import Iterable, Optional

 import torch.nn as nn

@@ -165,3 +166,55 @@ class BaseModule(nn.Module, metaclass=ABCMeta):
         if self.init_cfg:
             s += f'\ninit_cfg={self.init_cfg}'
         return s
+
+
+class Sequential(BaseModule, nn.Sequential):
+    """Sequential module in OpenMMLab.
+
+    Ensures that all modules in ``Sequential`` can have an initialization
+    strategy different from the outer model's.
+
+    Args:
+        init_cfg (dict, optional): Initialization config dict.
+    """
+
+    def __init__(self, *args, init_cfg: Optional[dict] = None):
+        BaseModule.__init__(self, init_cfg)
+        nn.Sequential.__init__(self, *args)
+
+
+class ModuleList(BaseModule, nn.ModuleList):
+    """ModuleList in OpenMMLab.
+
+    Ensures that all modules in ``ModuleList`` can have an initialization
+    strategy different from the outer model's.
+
+    Args:
+        modules (iterable, optional): An iterable of modules to add.
+        init_cfg (dict, optional): Initialization config dict.
+    """
+
+    def __init__(self,
+                 modules: Optional[Iterable] = None,
+                 init_cfg: Optional[dict] = None):
+        BaseModule.__init__(self, init_cfg)
+        nn.ModuleList.__init__(self, modules)
+
+
+class ModuleDict(BaseModule, nn.ModuleDict):
+    """ModuleDict in OpenMMLab.
+
+    Ensures that all modules in ``ModuleDict`` can have an initialization
+    strategy different from the outer model's.
+
+    Args:
+        modules (dict, optional): A mapping (dictionary) of (string: module)
+            or an iterable of key-value pairs of type (string, module).
+        init_cfg (dict, optional): Initialization config dict.
+    """
+
+    def __init__(self,
+                 modules: Optional[dict] = None,
+                 init_cfg: Optional[dict] = None):
+        BaseModule.__init__(self, init_cfg)
+        nn.ModuleDict.__init__(self, modules)
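For reference, a minimal usage sketch of the new containers. The layers and init values here are hypothetical; ``Constant`` is mmengine's constant initializer, configured through ``init_cfg`` exactly as in the tests below:

    import torch.nn as nn
    from mmengine.model import ModuleDict, ModuleList, Sequential

    # Each container accepts an optional init_cfg of its own.
    backbone = Sequential(
        nn.Conv2d(3, 16, 3),
        nn.ReLU(),
        init_cfg=dict(type='Constant', layer='Conv2d', val=1., bias=0.))
    heads = ModuleList(
        [nn.Linear(16, 10), nn.Linear(16, 2)],
        init_cfg=dict(type='Constant', layer='Linear', val=0.5, bias=0.))
    branches = ModuleDict(
        dict(cls=nn.Linear(16, 10)),
        init_cfg=dict(type='Constant', layer='Linear', val=0.5, bias=0.))

    # init_weights() is inherited from BaseModule and applies init_cfg
    # to the contained modules.
    backbone.init_weights()
    heads.init_weights()
    branches.init_weights()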
tests/test_model/test_base_module.py
@@ -5,7 +5,7 @@ import torch
 from torch import nn

 from mmengine.logging.logger import MMLogger
-from mmengine.model.base_module import BaseModule
+from mmengine.model import BaseModule, ModuleDict, ModuleList, Sequential
 from mmengine.registry import Registry, build_from_cfg

 COMPONENTS = Registry('component')
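Note: ``FooConv1d`` and ``FooConv2d`` are test components registered in ``COMPONENTS`` earlier in this file (not shown in the hunk). Roughly, they look like this sketch, with the exact channel sizes assumed:

    @COMPONENTS.register_module()
    class FooConv1d(BaseModule):

        def __init__(self, init_cfg=None):
            super().__init__(init_cfg)
            # the attribute name 'conv1d' is what the tests below inspect
            self.conv1d = nn.Conv1d(4, 1, 4)

        def forward(self, x):
            return self.conv1d(x)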
@@ -195,3 +195,164 @@ class TestBaseModule(TestCase):
         assert len(os.listdir(dump_dir)) == 1
         assert os.stat(log_path).st_size != 0
         shutil.rmtree(dump_dir)
+
+
+class TestModuleList(TestCase):
+
+    def test_modulelist_weight_init(self):
+        models_cfg = [
+            dict(
+                type='FooConv1d',
+                init_cfg=dict(
+                    type='Constant', layer='Conv1d', val=0., bias=1.)),
+            dict(
+                type='FooConv2d',
+                init_cfg=dict(
+                    type='Constant', layer='Conv2d', val=2., bias=3.)),
+        ]
+        layers = [build_from_cfg(cfg, COMPONENTS) for cfg in models_cfg]
+        modellist = ModuleList(layers)
+        modellist.init_weights()
+        self.assertTrue(
+            torch.equal(modellist[0].conv1d.weight,
+                        torch.full(modellist[0].conv1d.weight.shape, 0.)))
+        self.assertTrue(
+            torch.equal(modellist[0].conv1d.bias,
+                        torch.full(modellist[0].conv1d.bias.shape, 1.)))
+        self.assertTrue(
+            torch.equal(modellist[1].conv2d.weight,
+                        torch.full(modellist[1].conv2d.weight.shape, 2.)))
+        self.assertTrue(
+            torch.equal(modellist[1].conv2d.bias,
+                        torch.full(modellist[1].conv2d.bias.shape, 3.)))
+        # inner init_cfg has higher priority
+        layers = [build_from_cfg(cfg, COMPONENTS) for cfg in models_cfg]
+        modellist = ModuleList(
+            layers,
+            init_cfg=dict(
+                type='Constant', layer=['Conv1d', 'Conv2d'], val=4., bias=5.))
+        modellist.init_weights()
+        self.assertTrue(
+            torch.equal(modellist[0].conv1d.weight,
+                        torch.full(modellist[0].conv1d.weight.shape, 0.)))
+        self.assertTrue(
+            torch.equal(modellist[0].conv1d.bias,
+                        torch.full(modellist[0].conv1d.bias.shape, 1.)))
+        self.assertTrue(
+            torch.equal(modellist[1].conv2d.weight,
+                        torch.full(modellist[1].conv2d.weight.shape, 2.)))
+        self.assertTrue(
+            torch.equal(modellist[1].conv2d.bias,
+                        torch.full(modellist[1].conv2d.bias.shape, 3.)))
+
+
+class TestModuleDict(TestCase):
+
+    def test_moduledict_weight_init(self):
+        models_cfg = dict(
+            foo_conv_1d=dict(
+                type='FooConv1d',
+                init_cfg=dict(
+                    type='Constant', layer='Conv1d', val=0., bias=1.)),
+            foo_conv_2d=dict(
+                type='FooConv2d',
+                init_cfg=dict(
+                    type='Constant', layer='Conv2d', val=2., bias=3.)),
+        )
+        layers = {
+            name: build_from_cfg(cfg, COMPONENTS)
+            for name, cfg in models_cfg.items()
+        }
+        modeldict = ModuleDict(layers)
+        modeldict.init_weights()
+        self.assertTrue(
+            torch.equal(
+                modeldict['foo_conv_1d'].conv1d.weight,
+                torch.full(modeldict['foo_conv_1d'].conv1d.weight.shape, 0.)))
+        self.assertTrue(
+            torch.equal(
+                modeldict['foo_conv_1d'].conv1d.bias,
+                torch.full(modeldict['foo_conv_1d'].conv1d.bias.shape, 1.)))
+        self.assertTrue(
+            torch.equal(
+                modeldict['foo_conv_2d'].conv2d.weight,
+                torch.full(modeldict['foo_conv_2d'].conv2d.weight.shape, 2.)))
+        self.assertTrue(
+            torch.equal(
+                modeldict['foo_conv_2d'].conv2d.bias,
+                torch.full(modeldict['foo_conv_2d'].conv2d.bias.shape, 3.)))
+        # inner init_cfg has higher priority
+        layers = {
+            name: build_from_cfg(cfg, COMPONENTS)
+            for name, cfg in models_cfg.items()
+        }
+        modeldict = ModuleDict(
+            layers,
+            init_cfg=dict(
+                type='Constant', layer=['Conv1d', 'Conv2d'], val=4., bias=5.))
+        modeldict.init_weights()
+        self.assertTrue(
+            torch.equal(
+                modeldict['foo_conv_1d'].conv1d.weight,
+                torch.full(modeldict['foo_conv_1d'].conv1d.weight.shape, 0.)))
+        self.assertTrue(
+            torch.equal(
+                modeldict['foo_conv_1d'].conv1d.bias,
+                torch.full(modeldict['foo_conv_1d'].conv1d.bias.shape, 1.)))
+        self.assertTrue(
+            torch.equal(
+                modeldict['foo_conv_2d'].conv2d.weight,
+                torch.full(modeldict['foo_conv_2d'].conv2d.weight.shape, 2.)))
+        self.assertTrue(
+            torch.equal(
+                modeldict['foo_conv_2d'].conv2d.bias,
+                torch.full(modeldict['foo_conv_2d'].conv2d.bias.shape, 3.)))
+
+
+class TestSequential(TestCase):
+
+    def test_sequential_model_weight_init(self):
+        seq_model_cfg = [
+            dict(
+                type='FooConv1d',
+                init_cfg=dict(
+                    type='Constant', layer='Conv1d', val=0., bias=1.)),
+            dict(
+                type='FooConv2d',
+                init_cfg=dict(
+                    type='Constant', layer='Conv2d', val=2., bias=3.)),
+        ]
+        layers = [build_from_cfg(cfg, COMPONENTS) for cfg in seq_model_cfg]
+        seq_model = Sequential(*layers)
+        seq_model.init_weights()
+        self.assertTrue(
+            torch.equal(seq_model[0].conv1d.weight,
+                        torch.full(seq_model[0].conv1d.weight.shape, 0.)))
+        self.assertTrue(
+            torch.equal(seq_model[0].conv1d.bias,
+                        torch.full(seq_model[0].conv1d.bias.shape, 1.)))
+        self.assertTrue(
+            torch.equal(seq_model[1].conv2d.weight,
+                        torch.full(seq_model[1].conv2d.weight.shape, 2.)))
+        self.assertTrue(
+            torch.equal(seq_model[1].conv2d.bias,
+                        torch.full(seq_model[1].conv2d.bias.shape, 3.)))
+        # inner init_cfg has higher priority
+        layers = [build_from_cfg(cfg, COMPONENTS) for cfg in seq_model_cfg]
+        seq_model = Sequential(
+            *layers,
+            init_cfg=dict(
+                type='Constant', layer=['Conv1d', 'Conv2d'], val=4., bias=5.))
+        seq_model.init_weights()
+        self.assertTrue(
+            torch.equal(seq_model[0].conv1d.weight,
+                        torch.full(seq_model[0].conv1d.weight.shape, 0.)))
+        self.assertTrue(
+            torch.equal(seq_model[0].conv1d.bias,
+                        torch.full(seq_model[0].conv1d.bias.shape, 1.)))
+        self.assertTrue(
+            torch.equal(seq_model[1].conv2d.weight,
+                        torch.full(seq_model[1].conv2d.weight.shape, 2.)))
+        self.assertTrue(
+            torch.equal(seq_model[1].conv2d.bias,
+                        torch.full(seq_model[1].conv2d.bias.shape, 3.)))
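As the "inner init_cfg has higher priority" cases above verify, a child module's own init_cfg always wins over the container's. A hedged illustration of that rule, reusing the FooConv1d sketch above:

    layer = build_from_cfg(
        dict(
            type='FooConv1d',
            init_cfg=dict(type='Constant', layer='Conv1d', val=0., bias=1.)),
        COMPONENTS)
    wrapper = ModuleList(
        [layer], init_cfg=dict(type='Constant', layer='Conv1d', val=9.))
    wrapper.init_weights()
    # The child's init_cfg (val=0.) is applied, not the container's val=9.
    assert layer.conv1d.weight.eq(0.).all()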