mirror of https://github.com/open-mmlab/mmcv.git
64 lines
1.8 KiB
Python
64 lines
1.8 KiB
Python
import torch.nn as nn
|
|
|
|
import mmcv
|
|
from mmcv.cnn import MODELS, build_model_from_cfg
|
|
|
|
|
|
def test_build_model_from_cfg():
|
|
BACKBONES = mmcv.Registry('backbone', build_func=build_model_from_cfg)
|
|
|
|
@BACKBONES.register_module()
|
|
class ResNet(nn.Module):
|
|
|
|
def __init__(self, depth, stages=4):
|
|
super().__init__()
|
|
self.depth = depth
|
|
self.stages = stages
|
|
|
|
def forward(self, x):
|
|
return x
|
|
|
|
@BACKBONES.register_module()
|
|
class ResNeXt(nn.Module):
|
|
|
|
def __init__(self, depth, stages=4):
|
|
super().__init__()
|
|
self.depth = depth
|
|
self.stages = stages
|
|
|
|
def forward(self, x):
|
|
return x
|
|
|
|
cfg = dict(type='ResNet', depth=50)
|
|
model = BACKBONES.build(cfg)
|
|
assert isinstance(model, ResNet)
|
|
assert model.depth == 50 and model.stages == 4
|
|
|
|
cfg = dict(type='ResNeXt', depth=50, stages=3)
|
|
model = BACKBONES.build(cfg)
|
|
assert isinstance(model, ResNeXt)
|
|
assert model.depth == 50 and model.stages == 3
|
|
|
|
cfg = [
|
|
dict(type='ResNet', depth=50),
|
|
dict(type='ResNeXt', depth=50, stages=3)
|
|
]
|
|
model = BACKBONES.build(cfg)
|
|
assert isinstance(model, nn.Sequential)
|
|
assert isinstance(model[0], ResNet)
|
|
assert model[0].depth == 50 and model[0].stages == 4
|
|
assert isinstance(model[1], ResNeXt)
|
|
assert model[1].depth == 50 and model[1].stages == 3
|
|
|
|
# test inherit `build_func` from parent
|
|
NEW_MODELS = mmcv.Registry('models', parent=MODELS, scope='new')
|
|
assert NEW_MODELS.build_func is build_model_from_cfg
|
|
|
|
# test specify `build_func`
|
|
def pseudo_build(cfg):
|
|
return cfg
|
|
|
|
NEW_MODELS = mmcv.Registry(
|
|
'models', parent=MODELS, build_func=pseudo_build)
|
|
assert NEW_MODELS.build_func is pseudo_build
|