mirror of https://github.com/open-mmlab/mmcv.git
[Feature] support to calculate FLOPs of GN, IN, LN (#897)
* [Feature] support to calculate FLOPs of GN, IN, LN * Update flops_counter.py * Update flops_counter.pypull/902/head
parent
00870b9c4e
commit
97730c2316
|
@ -56,7 +56,8 @@ def get_model_complexity_info(model,
|
|||
``nn.AdaptiveMaxPool3d``, ``nn.AdaptiveAvgPool1d``,
|
||||
``nn.AdaptiveAvgPool2d``, ``nn.AdaptiveAvgPool3d``.
|
||||
- BatchNorms: ``nn.BatchNorm1d``, ``nn.BatchNorm2d``,
|
||||
``nn.BatchNorm3d``.
|
||||
``nn.BatchNorm3d``, ``nn.GroupNorm``, ``nn.InstanceNorm1d``,
|
||||
``InstanceNorm2d``, ``InstanceNorm3d``, ``nn.LayerNorm``.
|
||||
- Linear: ``nn.Linear``.
|
||||
- Deconvolution: ``nn.ConvTranspose2d``.
|
||||
- Upsample: ``nn.Upsample``.
|
||||
|
@ -426,11 +427,12 @@ def pool_flops_counter_hook(module, input, output):
|
|||
module.__flops__ += int(np.prod(input.shape))
|
||||
|
||||
|
||||
def bn_flops_counter_hook(module, input, output):
|
||||
def norm_flops_counter_hook(module, input, output):
|
||||
input = input[0]
|
||||
|
||||
batch_flops = np.prod(input.shape)
|
||||
if module.affine:
|
||||
if (getattr(module, 'affine', False)
|
||||
or getattr(module, 'elementwise_affine', False)):
|
||||
batch_flops *= 2
|
||||
module.__flops__ += int(batch_flops)
|
||||
|
||||
|
@ -577,10 +579,15 @@ def get_modules_mapping():
|
|||
nn.AdaptiveAvgPool2d: pool_flops_counter_hook,
|
||||
nn.AdaptiveMaxPool3d: pool_flops_counter_hook,
|
||||
nn.AdaptiveAvgPool3d: pool_flops_counter_hook,
|
||||
# BNs
|
||||
nn.BatchNorm1d: bn_flops_counter_hook,
|
||||
nn.BatchNorm2d: bn_flops_counter_hook,
|
||||
nn.BatchNorm3d: bn_flops_counter_hook,
|
||||
# normalizations
|
||||
nn.BatchNorm1d: norm_flops_counter_hook,
|
||||
nn.BatchNorm2d: norm_flops_counter_hook,
|
||||
nn.BatchNorm3d: norm_flops_counter_hook,
|
||||
nn.GroupNorm: norm_flops_counter_hook,
|
||||
nn.InstanceNorm1d: norm_flops_counter_hook,
|
||||
nn.InstanceNorm2d: norm_flops_counter_hook,
|
||||
nn.InstanceNorm3d: norm_flops_counter_hook,
|
||||
nn.LayerNorm: norm_flops_counter_hook,
|
||||
# FC
|
||||
nn.Linear: linear_flops_counter_hook,
|
||||
mmcv.cnn.bricks.Linear: linear_flops_counter_hook,
|
||||
|
|
|
@ -32,9 +32,15 @@ gt_results = [
|
|||
{'model': nn.AdaptiveAvgPool1d(2), 'input': (3, 16), 'flops': 48.0, 'params': 0}, # noqa: E501
|
||||
{'model': nn.AdaptiveAvgPool2d(2), 'input': (3, 16, 16), 'flops': 768.0, 'params': 0}, # noqa: E501
|
||||
{'model': nn.AdaptiveAvgPool3d(2), 'input': (3, 3, 16, 16), 'flops': 2304.0, 'params': 0}, # noqa: E501
|
||||
{'model': nn.BatchNorm1d(3, 8), 'input': (3, 16), 'flops': 96.0, 'params': 6.0}, # noqa: E501
|
||||
{'model': nn.BatchNorm2d(3, 8), 'input': (3, 16, 16), 'flops': 1536.0, 'params': 6.0}, # noqa: E501
|
||||
{'model': nn.BatchNorm3d(3, 8), 'input': (3, 3, 16, 16), 'flops': 4608.0, 'params': 6.0}, # noqa: E501
|
||||
{'model': nn.BatchNorm1d(3), 'input': (3, 16), 'flops': 96.0, 'params': 6.0}, # noqa: E501
|
||||
{'model': nn.BatchNorm2d(3), 'input': (3, 16, 16), 'flops': 1536.0, 'params': 6.0}, # noqa: E501
|
||||
{'model': nn.BatchNorm3d(3), 'input': (3, 3, 16, 16), 'flops': 4608.0, 'params': 6.0}, # noqa: E501
|
||||
{'model': nn.GroupNorm(2, 6), 'input': (6, 16, 16), 'flops': 3072.0, 'params': 12.0}, # noqa: E501
|
||||
{'model': nn.InstanceNorm1d(3, affine=True), 'input': (3, 16), 'flops': 96.0, 'params': 6.0}, # noqa: E501
|
||||
{'model': nn.InstanceNorm2d(3, affine=True), 'input': (3, 16, 16), 'flops': 1536.0, 'params': 6.0}, # noqa: E501
|
||||
{'model': nn.InstanceNorm3d(3, affine=True), 'input': (3, 3, 16, 16), 'flops': 4608.0, 'params': 6.0}, # noqa: E501
|
||||
{'model': nn.LayerNorm((3, 16, 16)), 'input': (3, 16, 16), 'flops': 1536.0, 'params': 1536.0}, # noqa: E501
|
||||
{'model': nn.LayerNorm((3, 16, 16), elementwise_affine=False), 'input': (3, 16, 16), 'flops': 768.0, 'params': 0}, # noqa: E501
|
||||
{'model': nn.Linear(1024, 2), 'input': (1024, ), 'flops': 2048.0, 'params': 2050.0}, # noqa: E501
|
||||
{'model': nn.ConvTranspose2d(3, 8, 3), 'input': (3, 16, 16), 'flops': 57888, 'params': 224.0}, # noqa: E501
|
||||
{'model': nn.Upsample((32, 32)), 'input': (3, 16, 16), 'flops': 3072.0, 'params': 0} # noqa: E501
|
||||
|
|
Loading…
Reference in New Issue