mmcv/tests/test_cnn/test_fuse_conv_bn.py
Cao Yuhang 926ac07bb8
Move fuse conv bn to mmcv (#382)
* move fuse conv bn to mmcv

* update doc

* update test conv bn

* rename

* fix doc and variable name

* change func name
2020-07-08 19:03:15 +08:00

16 lines
490 B
Python

import copy

import torch
import torch.nn as nn

from mmcv.cnn import ConvModule, fuse_conv_bn


def test_fuse_conv_bn():
    """Check that fusing Conv+BN pairs does not change the inference output.

    The reference output comes from an untouched deep copy of the network.
    ``fuse_conv_bn`` may rewrite its argument in place and return the same
    object. In that case, comparing ``modules(inputs)`` with
    ``fused_modules(inputs)`` after fusing would compare the fused network
    with itself and pass trivially.
    """
    torch.manual_seed(0)  # reproducible inputs and BN parameters
    inputs = torch.rand((1, 3, 5, 5))
    modules = nn.ModuleList()
    # A leading BN with no conv before it. Presumably this checks that
    # fuse_conv_bn leaves such a BN alone -- TODO confirm intent.
    modules.append(nn.BatchNorm2d(3))
    modules.append(ConvModule(3, 5, 3, norm_cfg=dict(type='BN')))
    modules.append(ConvModule(5, 5, 3, norm_cfg=dict(type='BN')))
    modules = nn.Sequential(*modules)

    # With PyTorch's default BN parameters (mean 0, var 1, weight 1, bias 0),
    # BN is almost an identity in eval mode. Use non-default values so that a
    # fusion which ignored the mean, variance, scale or shift would change
    # the output.
    with torch.no_grad():
        for m in modules.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.running_mean.uniform_(-1, 1)
                m.running_var.uniform_(0.5, 1.5)
                if m.affine:
                    m.weight.uniform_(0.5, 1.5)
                    m.bias.uniform_(-0.5, 0.5)

    # Folding BN into a conv bakes in the running statistics, so compare in
    # eval mode. In train mode, BN would normalize with batch statistics.
    modules.eval()
    # Fuse a deep copy so `modules` stays an unfused reference, whether or
    # not fuse_conv_bn mutates its argument.
    fused_modules = fuse_conv_bn(copy.deepcopy(modules))

    with torch.no_grad():
        expected = modules(inputs)
        actual = fused_modules(inputs)
    # Folding reorders floating-point arithmetic, so require closeness
    # rather than bit-exact equality.
    assert torch.allclose(expected, actual, atol=1e-5)