mmengine/tests/test_analysis/test_flop_count.py

930 lines
29 KiB
Python
Raw Normal View History

[Feature] Support model complexity computation (#779) * [Feature] Add support model complexity computation * [Fix] fix lint error * [Feature] update print_helper * Update docstring * update api, docs, fix lint * fix lint * update doc and add test * update docstring * update docstring * update test * Update mmengine/analysis/print_helper.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmengine/analysis/print_helper.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmengine/analysis/print_helper.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmengine/analysis/print_helper.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmengine/analysis/print_helper.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmengine/analysis/print_helper.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmengine/analysis/print_helper.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmengine/analysis/print_helper.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmengine/analysis/print_helper.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmengine/analysis/print_helper.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmengine/analysis/print_helper.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmengine/analysis/print_helper.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmengine/analysis/print_helper.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmengine/analysis/print_helper.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmengine/analysis/complexity_analysis.py Co-authored-by: Zaida Zhou 
<58739961+zhouzaida@users.noreply.github.com> * Update docs/en/advanced_tutorials/model_analysis.md Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update docs/en/advanced_tutorials/model_analysis.md Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * update docs * update docs * update docs and docstring * update docs * update test withj mmlogger * Update mmengine/analysis/complexity_analysis.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update tests/test_analysis/test_activation_count.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Apply suggestions from code review Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * update test according to review * Apply suggestions from code review Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * fix lint * fix test * Apply suggestions from code review * fix API document * Update analysis.rst * rename variables * minor refinement * Apply suggestions from code review * fix lint * replace tabulate with existing rich * Apply suggestions from code review * indent * Update mmengine/analysis/complexity_analysis.py * Update mmengine/analysis/complexity_analysis.py * Update mmengine/analysis/complexity_analysis.py --------- Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> Co-authored-by: zhouzaida <zhouzaida@163.com>
2023-02-20 15:00:28 +08:00
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# Modified from
# https://github.com/facebookresearch/fvcore/blob/main/tests/test_flop_count.py
# pyre-ignore-all-errors[2,3,14,53]
import typing
import unittest
from collections import Counter, defaultdict
from typing import Any, Dict, Tuple
import torch
import torch.nn as nn
from torch.autograd.function import Function
from torch.nn import functional as F
from mmengine.analysis import FlopAnalyzer, flop_count
from mmengine.analysis.complexity_analysis import _DEFAULT_SUPPORTED_FLOP_OPS
from mmengine.analysis.jit_handles import Handle
class _CustomOp(Function):
    """Identity autograd op used to exercise ``prim::PythonOp`` counting.

    The forward pass returns its input unchanged; the backward pass
    returns a tensor of ones shaped like the incoming gradient.
    """

    @staticmethod
    def forward(ctx, input: torch.Tensor) -> torch.Tensor:
        # Identity: output is the input tensor itself.
        return input

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        # Constant gradient of ones matching grad_output's shape.
        ones = torch.ones_like(grad_output)
        return ones
class ThreeNet(nn.Module):
    """A network with three layers.

    Conv2d -> AdaptiveAvgPool2d((1, 1)) -> flatten -> Linear -> Linear.
    Used for testing a network with more than one operation type.
    """

    def __init__(self, input_dim: int, conv_dim: int, linear_dim: int) -> None:
        super().__init__()
        self.conv = nn.Conv2d(input_dim, conv_dim, 1, 1)
        # Pool every spatial map down to a single value.
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.linear1 = nn.Linear(conv_dim, linear_dim)
        self.linear2 = nn.Linear(linear_dim, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        pooled = self.pool(self.conv(x))
        flat = torch.flatten(pooled, 1)
        return self.linear2(self.linear1(flat))
class ConvNet(nn.Module):
    """A network with a single convolution layer.

    This is used for testing flop count for convolution layers.

    Args:
        conv_dim (int): Dimensionality of the convolution (1, 2 or 3).
        input_dim (int): Number of input channels.
        output_dim (int): Number of output channels.
        kernel_size (int): Size of the convolving kernel.
        stride (int): Stride of the convolution.
        padding (int): Zero-padding added to the input.
        groups_num (int): Number of groups for the convolution.
        transpose (bool): Whether to build a transposed convolution.
            Defaults to False.
        output_padding (int): Extra size added to the output shape; only
            meaningful when ``transpose=True``. Defaults to 0.

    Raises:
        AssertionError: If ``output_padding`` is non-zero for a
            non-transposed convolution.
    """

    def __init__(
        self,
        conv_dim: int,
        input_dim: int,
        output_dim: int,
        kernel_size: int,
        stride: int,
        padding: int,
        groups_num: int,
        transpose: bool = False,
        output_padding: int = 0,
    ) -> None:
        super().__init__()
        if transpose:
            conv_layers = [
                nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d
            ]
            kwargs = {'output_padding': output_padding}
        else:
            conv_layers = [nn.Conv1d, nn.Conv2d, nn.Conv3d]
            # Bug fix: the message used to be split across two statements,
            # which left the second half as a dead string expression and
            # truncated the actual assertion message.
            assert output_padding == 0, (
                'output_padding is not supported for '
                'un-transposed convolutions.')
            kwargs = {}
        # Pick the layer class matching the requested dimensionality.
        ConvLayer = conv_layers[conv_dim - 1]
        self.conv: nn.Module = ConvLayer(
            input_dim,
            output_dim,
            kernel_size,
            stride,
            padding,
            groups=groups_num,
            **kwargs,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv(x)
        return x
class LinearNet(nn.Module):
    """A network with a single fully connected layer.

    Used for testing flop count for fully connected layers.
    """

    def __init__(self, input_dim: int, output_dim: int) -> None:
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)
class EinsumNet(nn.Module):
    """A network with a single ``torch.einsum`` operation.

    Used for testing flop count for ``torch.einsum``.
    """

    def __init__(self, equation: str) -> None:
        super().__init__()
        # Einsum equation applied in forward, e.g. 'nct,ncp->ntp'.
        self.eq: str = equation

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        result = torch.einsum(self.eq, x, y)
        return result
class MatmulNet(nn.Module):
    """A network with a single ``torch.matmul`` operation.

    Used for testing flop count for ``torch.matmul``.
    """

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return torch.matmul(x, y)
class BMMNet(nn.Module):
    """A network with a single ``torch.bmm`` operation.

    Used for testing flop count for ``torch.bmm``.
    """

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        product = torch.bmm(x, y)
        return product
class CustomNet(nn.Module):
    """A fully connected layer followed by a sigmoid.

    Used for testing customized operation handles. Note the linear layer
    is (deliberately) stored under the attribute name ``conv``.
    """

    def __init__(self, input_dim: int, output_dim: int) -> None:
        super().__init__()
        self.conv = nn.Linear(input_dim, output_dim)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.sigmoid(self.conv(x))
class TestFlopAnalyzer(unittest.TestCase):
"""Unittest for flop_count."""
def setUp(self) -> None:
# nn.Linear uses a different operator based on version, so make sure
# we are testing the right thing.
lin = nn.Linear(10, 10)
lin_x: torch.Tensor = torch.randn(10, 10)
trace = torch.jit.trace(lin, (lin_x, ))
node_kinds = [node.kind() for node in trace.graph.nodes()]
assert 'aten::addmm' in node_kinds or 'aten::linear' in node_kinds
if 'aten::addmm' in node_kinds:
self.lin_op = 'addmm'
else:
self.lin_op = 'linear'
def test_customized_ops(self) -> None:
"""Test the use of customized operation handles.
The first test checks the case when a new handle for a new operation is
passed as an argument. The second case checks when a new handle for a
default operation is passed. The new handle should overwrite the
default handle.
"""
# New handle for a new operation.
def dummy_sigmoid_flop_jit(
inputs: typing.List[Any],
outputs: typing.List[Any]) -> typing.Counter[str]:
"""A dummy handle function for sigmoid.
Note the handle here does not compute actual flop count. This is
used for test only.
"""
flop_dict = Counter() # type: Counter
flop_dict['sigmoid'] = 10000
return flop_dict
batch_size = 10
input_dim = 5
output_dim = 4
custom_net = CustomNet(input_dim, output_dim)
custom_ops: Dict[str, Handle] = {
'aten::sigmoid': dummy_sigmoid_flop_jit
}
x = torch.rand(batch_size, input_dim)
flop_dict1, _ = flop_count(custom_net, (x, ), supported_ops=custom_ops)
flop_sigmoid = 10000 / 1e9
self.assertEqual(
flop_dict1['sigmoid'],
flop_sigmoid,
'Customized operation handle failed to pass the flop count test.',
)
# New handle that overwrites a default handle addmm. So now the new
# handle counts flops for the fully connected layer.
def addmm_dummy_flop_jit(
inputs: typing.List[object],
outputs: typing.List[object]) -> typing.Counter[str]:
"""A dummy handle function for fully connected layers.
This overwrites the default handle. Note the handle here does not
compute actual flop count. This is used for test only.
"""
flop_dict = Counter() # type: Counter
flop_dict[self.lin_op] = 400000
return flop_dict
custom_ops2: Dict[str, Handle] = {
f'aten::{self.lin_op}': addmm_dummy_flop_jit
}
flop_dict2, _ = flop_count(
custom_net, (x, ), supported_ops=custom_ops2)
flop = 400000 / 1e9
self.assertEqual(
flop_dict2[self.lin_op],
flop,
'Customized operation handle failed to pass the flop count test.',
)
def test_nn(self) -> None:
"""Test a model which is a pre-defined nn.module without defining a new
customized network."""
batch_size = 5
input_dim = 8
output_dim = 4
x = torch.randn(batch_size, input_dim)
flop_dict, _ = flop_count(nn.Linear(input_dim, output_dim), (x, ))
gt_flop = batch_size * input_dim * output_dim / 1e9
gt_dict = defaultdict(float)
gt_dict[self.lin_op] = gt_flop
self.assertDictEqual(flop_dict, gt_dict,
'nn.Linear failed to pass the flop count test.')
def test_skip_ops(self) -> None:
"""Test the return of skipped operations."""
batch_size = 10
input_dim = 5
output_dim = 4
custom_net = CustomNet(input_dim, output_dim)
x = torch.rand(batch_size, input_dim)
_, skip_dict = flop_count(custom_net, (x, ))
gt_dict = Counter() # type: Counter
gt_dict['aten::sigmoid'] = 1
self.assertDictEqual(
skip_dict, gt_dict,
'Skipped operations failed to pass the flop count test.')
def test_linear(self) -> None:
"""Test a network with a single fully connected layer."""
batch_size = 5
input_dim = 10
output_dim = 20
linear_net = LinearNet(input_dim, output_dim)
x = torch.randn(batch_size, input_dim)
flop_dict, _ = flop_count(linear_net, (x, ))
gt_flop = batch_size * input_dim * output_dim / 1e9
gt_dict = defaultdict(float)
gt_dict[self.lin_op] = gt_flop
self.assertDictEqual(
flop_dict,
gt_dict,
'Fully connected layer failed to pass the flop count test.',
)
# Test with #input_dims>2
if self.lin_op != 'linear':
# Skip this test if nn.Linear doesn't use aten::linear
# TODO: Stop skipping when multidimensional aten::matmul
# flop counting is implemented
return
extra_dim = 5
x = torch.randn(batch_size, extra_dim, input_dim)
flop_dict, _ = flop_count(linear_net, (x, ))
gt_flop = batch_size * input_dim * extra_dim * output_dim / 1e9
gt_dict = defaultdict(float)
gt_dict[self.lin_op] = gt_flop
self.assertDictEqual(
flop_dict,
gt_dict,
'Fully connected layer failed to pass the flop count test.',
)
def test_conv(self) -> None:
"""Test a network with a single convolution layer.
The test cases are: 1) 2D convolution; 2) 2D convolution with change in
spatial dimensions; 3) group convolution; 4) depthwise convolution 5)
1d convolution; 6) 3d convolution.
"""
def _test_conv(
conv_dim: int,
batch_size: int,
input_dim: int,
output_dim: int,
spatial_dim: int,
kernel_size: int,
padding: int,
stride: int,
group_size: int,
transpose: bool = False,
output_padding: int = 0,
) -> None:
convNet = ConvNet(
conv_dim,
input_dim,
output_dim,
kernel_size,
stride,
padding,
group_size,
transpose,
output_padding,
)
assert conv_dim in [
1, 2, 3
], 'Convolution dimension needs to be 1, 2, or 3'
if conv_dim == 1:
x = torch.randn(batch_size, input_dim, spatial_dim)
elif conv_dim == 2:
x = torch.randn(batch_size, input_dim, spatial_dim,
spatial_dim)
else:
x = torch.randn(batch_size, input_dim, spatial_dim,
spatial_dim, spatial_dim)
flop_dict, _ = flop_count(convNet, (x, ))
if transpose:
spatial_size = spatial_dim
else:
spatial_size = (
(spatial_dim + 2 * padding) - kernel_size) // stride + 1
gt_flop = (
batch_size * input_dim * output_dim * (kernel_size**conv_dim) *
(spatial_size**conv_dim) / group_size / 1e9)
gt_dict = defaultdict(float)
gt_dict['conv'] = gt_flop
self.assertDictEqual(
flop_dict,
gt_dict,
'Convolution layer failed to pass the flop count test.',
)
# Test flop count for 2d convolution.
conv_dim1 = 2
batch_size1 = 5
input_dim1 = 10
output_dim1 = 3
spatial_dim1 = 15
kernel_size1 = 3
padding1 = 1
stride1 = 1
group_size1 = 1
_test_conv(
conv_dim1,
batch_size1,
input_dim1,
output_dim1,
spatial_dim1,
kernel_size1,
padding1,
stride1,
group_size1,
)
# Test flop count for convolution with spatial change in output.
conv_dim2 = 2
batch_size2 = 2
input_dim2 = 10
output_dim2 = 6
spatial_dim2 = 20
kernel_size2 = 3
padding2 = 1
stride2 = 2
group_size2 = 1
_test_conv(
conv_dim2,
batch_size2,
input_dim2,
output_dim2,
spatial_dim2,
kernel_size2,
padding2,
stride2,
group_size2,
)
# Test flop count for group convolution.
conv_dim3 = 2
batch_size3 = 5
input_dim3 = 16
output_dim3 = 8
spatial_dim3 = 15
kernel_size3 = 5
padding3 = 2
stride3 = 1
group_size3 = 4
_test_conv(
conv_dim3,
batch_size3,
input_dim3,
output_dim3,
spatial_dim3,
kernel_size3,
padding3,
stride3,
group_size3,
)
# Test the special case of group convolution when group = output_dim.
# This is equivalent to depthwise convolution.
conv_dim4 = 2
batch_size4 = 5
input_dim4 = 16
output_dim4 = 8
spatial_dim4 = 15
kernel_size4 = 3
padding4 = 1
stride4 = 1
group_size4 = output_dim4
_test_conv(
conv_dim4,
batch_size4,
input_dim4,
output_dim4,
spatial_dim4,
kernel_size4,
padding4,
stride4,
group_size4,
)
# Test flop count for 1d convolution.
conv_dim5 = 1
batch_size5 = 5
input_dim5 = 10
output_dim5 = 3
spatial_dim5 = 15
kernel_size5 = 3
padding5 = 1
stride5 = 1
group_size5 = 1
_test_conv(
conv_dim5,
batch_size5,
input_dim5,
output_dim5,
spatial_dim5,
kernel_size5,
padding5,
stride5,
group_size5,
)
# Test flop count for 3d convolution.
conv_dim6 = 3
batch_size6 = 5
input_dim6 = 10
output_dim6 = 3
spatial_dim6 = 15
kernel_size6 = 3
padding6 = 1
stride6 = 1
group_size6 = 1
_test_conv(
conv_dim6,
batch_size6,
input_dim6,
output_dim6,
spatial_dim6,
kernel_size6,
padding6,
stride6,
group_size6,
)
# Test flop count for transposed 2d convolution.
conv_dim7 = 2
batch_size7 = 5
input_dim7 = 10
output_dim7 = 3
spatial_dim7 = 15
kernel_size7 = 3
padding7 = 1
stride7 = 1
group_size7 = 1
_test_conv(
conv_dim7,
batch_size7,
input_dim7,
output_dim7,
spatial_dim7,
kernel_size7,
padding7,
stride7,
group_size7,
transpose=True,
)
# Test flop count for strided transposed 2d convolution.
conv_dim8 = 2
batch_size8 = 5
input_dim8 = 10
output_dim8 = 3
spatial_dim8 = 15
kernel_size8 = 3
padding8 = 1
stride8 = 2
group_size8 = 1
_test_conv(
conv_dim8,
batch_size8,
input_dim8,
output_dim8,
spatial_dim8,
kernel_size8,
padding8,
stride8,
group_size8,
transpose=True,
)
# Test flop count for strided transposed 2d convolution
# w/ output_padding.
conv_dim9 = 2
batch_size9 = 5
input_dim9 = 10
output_dim9 = 3
spatial_dim9 = 15
kernel_size9 = 3
padding9 = 1
stride9 = 3
group_size9 = 1
output_padding9 = 2
_test_conv(
conv_dim9,
batch_size9,
input_dim9,
output_dim9,
spatial_dim9,
kernel_size9,
padding9,
stride9,
group_size9,
transpose=True,
output_padding=output_padding9,
)
"""Test flop count for operation matmul."""
m = 20
n = 10
p = 100
m_net = MatmulNet()
x = torch.randn(m, n)
y = torch.randn(n, p)
flop_dict, _ = flop_count(m_net, (x, y))
gt_flop = m * n * p / 1e9
gt_dict = defaultdict(float)
gt_dict['matmul'] = gt_flop
self.assertDictEqual(
flop_dict, gt_dict,
'Matmul operation failed to pass the flop count test.')
# Test with single dimension y
y = torch.randn(n)
gt_dict['matmul'] = m * n * 1 / 1e9
flop_dict, _ = flop_count(m_net, (x, y))
self.assertDictEqual(
flop_dict, gt_dict,
[Feature] Support model complexity computation (#779) * [Feature] Add support model complexity computation * [Fix] fix lint error * [Feature] update print_helper * Update docstring * update api, docs, fix lint * fix lint * update doc and add test * update docstring * update docstring * update test * Update mmengine/analysis/print_helper.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmengine/analysis/print_helper.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmengine/analysis/print_helper.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmengine/analysis/print_helper.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmengine/analysis/print_helper.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmengine/analysis/print_helper.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmengine/analysis/print_helper.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmengine/analysis/print_helper.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmengine/analysis/print_helper.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmengine/analysis/print_helper.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmengine/analysis/print_helper.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmengine/analysis/print_helper.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmengine/analysis/print_helper.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmengine/analysis/print_helper.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmengine/analysis/complexity_analysis.py Co-authored-by: Zaida Zhou 
<58739961+zhouzaida@users.noreply.github.com> * Update docs/en/advanced_tutorials/model_analysis.md Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update docs/en/advanced_tutorials/model_analysis.md Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * update docs * update docs * update docs and docstring * update docs * update test withj mmlogger * Update mmengine/analysis/complexity_analysis.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update tests/test_analysis/test_activation_count.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Apply suggestions from code review Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * update test according to review * Apply suggestions from code review Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * fix lint * fix test * Apply suggestions from code review * fix API document * Update analysis.rst * rename variables * minor refinement * Apply suggestions from code review * fix lint * replace tabulate with existing rich * Apply suggestions from code review * indent * Update mmengine/analysis/complexity_analysis.py * Update mmengine/analysis/complexity_analysis.py * Update mmengine/analysis/complexity_analysis.py --------- Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> Co-authored-by: zhouzaida <zhouzaida@163.com>
2023-02-20 15:00:28 +08:00
'Matmul operation failed to pass the flop count test.')
def test_matmul_broadcast(self) -> None:
"""Test flop count for operation matmul."""
m = 20
n = 10
p = 100
m_net = MatmulNet()
x = torch.randn(1, m, n)
y = torch.randn(1, n, p)
flop_dict, _ = flop_count(m_net, (x, y))
gt_flop = m * n * p / 1e9
gt_dict = defaultdict(float)
gt_dict['matmul'] = gt_flop
self.assertDictEqual(
flop_dict, gt_dict,
'Matmul operation failed to pass the flop count test.')
x = torch.randn(2, 2, m, n)
y = torch.randn(2, 2, n, p)
flop_dict, _ = flop_count(m_net, (x, y))
gt_flop = 4 * m * n * p / 1e9
gt_dict = defaultdict(float)
gt_dict['matmul'] = gt_flop
self.assertDictEqual(
flop_dict, gt_dict,
'Matmul operation failed to pass the flop count test.')
x = torch.randn(1, m, n)
y = torch.randn(n, p)
flop_dict, _ = flop_count(m_net, (x, y))
gt_flop = m * n * p / 1e9
gt_dict = defaultdict(float)
gt_dict['matmul'] = gt_flop
self.assertDictEqual(
flop_dict, gt_dict,
'Matmul operation failed to pass the flop count test.')
x = torch.randn(2, m, n)
y = torch.randn(n, p)
flop_dict, _ = flop_count(m_net, (x, y))
gt_flop = 2 * m * n * p / 1e9
gt_dict = defaultdict(float)
gt_dict['matmul'] = gt_flop
self.assertDictEqual(
flop_dict, gt_dict,
'Matmul operation failed to pass the flop count test.')
def test_bmm(self) -> None:
"""Test flop count for operation torch.bmm.
The case checks torch.bmm with equation nct,ntp->ncp.
"""
n = 2
c = 5
t = 2
p = 12
e_net = BMMNet()
x = torch.randn(n, c, t)
y = torch.randn(n, t, p)
flop_dict, _ = flop_count(e_net, (x, y))
gt_flop = n * t * p * c / 1e9
gt_dict = defaultdict(float)
gt_dict['bmm'] = gt_flop
self.assertDictEqual(
flop_dict,
gt_dict,
'bmm operation nct,ncp->ntp failed to pass the flop count test.',
)
def test_einsum(self) -> None:
"""Test flop count for operation torch.einsum.
The first case checks torch.einsum with equation nct,ncp->ntp. The
second case checks torch.einsum with equation "ntg,ncg->nct".
"""
equation = 'nct,ncp->ntp'
n = 1
c = 5
t = 2
p = 12
e_net = EinsumNet(equation)
x = torch.randn(n, c, t)
y = torch.randn(n, c, p)
flop_dict, _ = flop_count(e_net, (x, y))
gt_flop = n * t * p * c / 1e9
gt_dict = defaultdict(float)
gt_dict['einsum'] = gt_flop
self.assertDictEqual(
flop_dict,
gt_dict,
'Einsum operation nct,ncp->ntp failed to pass flop count test.',
)
equation = 'ntg,ncg->nct'
g = 6
e_net = EinsumNet(equation)
x = torch.randn(n, t, g)
y = torch.randn(n, c, g)
flop_dict, _ = flop_count(e_net, (x, y))
gt_flop = n * t * g * c / 1e9
gt_dict = defaultdict(float)
gt_dict['einsum'] = gt_flop
self.assertDictEqual(
flop_dict,
gt_dict,
'Einsum operation ntg,ncg->nct failed to pass flop count test.',
)
def test_batchnorm(self) -> None:
"""Test flop count for operation batchnorm.
The test cases include BatchNorm1d, BatchNorm2d and BatchNorm3d.
"""
# Test for BatchNorm1d.
batch_size = 10
input_dim = 10
batch_1d = nn.BatchNorm1d(input_dim, affine=False).eval()
x = torch.randn(batch_size, input_dim)
flop_dict, _ = flop_count(batch_1d, (x, ))
gt_flop = batch_size * input_dim / 1e9
gt_dict = defaultdict(float)
gt_dict['batch_norm'] = gt_flop
self.assertDictEqual(
flop_dict, gt_dict,
'BatchNorm1d failed to pass the flop count test.')
# Test for BatchNorm2d.
batch_size = 10
input_dim = 10
spatial_dim_x = 5
spatial_dim_y = 5
batch_2d = nn.BatchNorm2d(input_dim, affine=False)
x = torch.randn(batch_size, input_dim, spatial_dim_x, spatial_dim_y)
flop_dict, _ = flop_count(batch_2d, (x, ))
gt_flop = 4 * batch_size * input_dim * spatial_dim_x * \
spatial_dim_y / 1e9
gt_dict = defaultdict(float)
gt_dict['batch_norm'] = gt_flop
self.assertDictEqual(
flop_dict, gt_dict,
'BatchNorm2d failed to pass the flop count test.')
# Test for BatchNorm3d.
batch_size = 10
input_dim = 10
spatial_dim_x = 5
spatial_dim_y = 5
spatial_dim_z = 5
batch_3d = nn.BatchNorm3d(input_dim, affine=False)
x = torch.randn(batch_size, input_dim, spatial_dim_x, spatial_dim_y,
spatial_dim_z)
flop_dict, _ = flop_count(batch_3d, (x, ))
gt_flop = (4 * batch_size * input_dim * spatial_dim_x * spatial_dim_y *
spatial_dim_z / 1e9)
gt_dict = defaultdict(float)
gt_dict['batch_norm'] = gt_flop
self.assertDictEqual(
flop_dict, gt_dict,
'BatchNorm3d failed to pass the flop count test.')
def test_threeNet(self) -> None:
"""Test a network with more than one layer.
The network has a convolution layer followed by two fully connected
layers.
"""
batch_size = 4
input_dim = 2
conv_dim = 5
spatial_dim = 10
linear_dim = 3
x = torch.randn(batch_size, input_dim, spatial_dim, spatial_dim)
three_net = ThreeNet(input_dim, conv_dim, linear_dim)
flop1 = batch_size * conv_dim * input_dim * spatial_dim * \
spatial_dim / 1e9
flop_linear1 = batch_size * conv_dim * linear_dim / 1e9
flop_linear2 = batch_size * linear_dim * 1 / 1e9
flop2 = flop_linear1 + flop_linear2
flop_dict, _ = flop_count(three_net, (x, ))
gt_dict = defaultdict(float)
gt_dict['conv'] = flop1
gt_dict[self.lin_op] = flop2
gt_dict['adaptive_avg_pool2d'] = 2e-6
self.assertDictEqual(
flop_dict,
gt_dict,
'The three-layer network failed to pass the flop count test.',
)
def test_flop_counter_class(self) -> None:
"""Test FlopAnalyzer."""
batch_size = 4
input_dim = 2
conv_dim = 5
spatial_dim = 10
linear_dim = 3
x = torch.randn(batch_size, input_dim, spatial_dim, spatial_dim)
three_net = ThreeNet(input_dim, conv_dim, linear_dim)
flop1 = batch_size * conv_dim * input_dim * spatial_dim * spatial_dim
flop_linear1 = batch_size * conv_dim * linear_dim
flop_linear2 = batch_size * linear_dim * 1
flop_counter = FlopAnalyzer(three_net, (x, ))
gt_dict = Counter({
'conv': flop1,
'linear1': flop_linear1,
'linear2': flop_linear2,
'pool': flop1 // input_dim,
})
gt_dict[''] = sum(gt_dict.values())
self.assertEqual(flop_counter.by_module(), gt_dict)
def test_autograd_function(self):
# test support on custom autograd function
class Mod(nn.Module):
def forward(self, x):
return _CustomOp.apply(x)
flop = FlopAnalyzer(Mod(), (torch.rand(4, 5), )).set_op_handle(
'prim::PythonOp._CustomOp', lambda *args, **kwargs: 42)
self.assertEqual(flop.total(), 42)
def test_scripted_function(self):
# Scripted function is not yet supported. It should produce a warning
def func(x):
return x @ x
class Mod(nn.Module):
def forward(self, x):
f = torch.jit.script(func)
return f(x * x)
flop = FlopAnalyzer(Mod(), (torch.rand(5, 5), ))
_ = flop.total()
self.assertIn('prim::CallFunction', flop.unsupported_ops())
class TestFlopCountHandles(unittest.TestCase):
    """Tests exercising the default per-operator flop handles directly."""

    def _count_function(self, func, inputs, name) -> Tuple[Any, Any]:
        """Trace ``func(*inputs)`` and return the inputs and outputs of the
        single traced node whose kind equals ``name``."""
        tensor_inputs = [t for t in inputs if isinstance(t, torch.Tensor)]

        def wrapper(*args):
            return func(*inputs)

        graph = torch.jit.trace(
            wrapper, tuple(tensor_inputs), check_trace=False).graph
        matching = [node for node in graph.nodes() if node.kind() == name]
        self.assertEqual(len(matching), 1)
        node = matching[0]
        return list(node.inputs()), list(node.outputs())

    def test_batch_norm(self):
        op_name = 'aten::batch_norm'
        counter = _DEFAULT_SUPPORTED_FLOP_OPS[op_name]
        vec = torch.rand(2)
        nodes = self._count_function(
            F.batch_norm, (torch.rand(2, 2, 2, 2), vec, vec, vec, vec),
            op_name)
        self.assertEqual(counter(*nodes), 32)
        # Without weight/bias.
        nodes = self._count_function(
            F.batch_norm,
            (torch.rand(2, 2, 2, 2), vec, vec, None, None),
            op_name,
        )
        self.assertEqual(counter(*nodes), 16)
        nodes = self._count_function(
            # training=True
            F.batch_norm,
            (torch.rand(2, 2, 2, 2), vec, vec, vec, vec, True),
            op_name,
        )
        self.assertEqual(counter(*nodes), 80)

    def test_group_norm(self):
        op_name = 'aten::group_norm'
        counter = _DEFAULT_SUPPORTED_FLOP_OPS[op_name]
        vec = torch.rand(2)
        nodes = self._count_function(F.group_norm,
                                     (torch.rand(2, 2, 2, 2), 2, vec, vec),
                                     op_name)
        self.assertEqual(counter(*nodes), 80)
        # Without weight/bias.
        nodes = self._count_function(F.group_norm,
                                     (torch.rand(2, 2, 2, 2), 2, None, None),
                                     op_name)
        self.assertEqual(counter(*nodes), 64)

    def test_upsample(self):
        op_name = 'aten::upsample_bilinear2d'
        counter = _DEFAULT_SUPPORTED_FLOP_OPS[op_name]
        nodes = self._count_function(
            F.interpolate,
            (torch.rand(2, 2, 2, 2), None, 2, 'bilinear', False), op_name)
        self.assertEqual(counter(*nodes), 2**4 * 4 * 4)

    def test_complicated_einsum(self):
        op_name = 'aten::einsum'
        counter = _DEFAULT_SUPPORTED_FLOP_OPS[op_name]
        nodes = self._count_function(
            torch.einsum,
            ('nc,nchw->hw', torch.rand(3, 4), torch.rand(3, 4, 2, 3)),
            op_name,
        )
        self.assertEqual(counter(*nodes), 72.0)

    def test_torch_mm(self):
        # Both torch.mm and torch.matmul should be counted identically for
        # plain 2-D inputs.
        for op_name, func in zip(['aten::mm', 'aten::matmul'],
                                 [torch.mm, torch.matmul]):
            counter = _DEFAULT_SUPPORTED_FLOP_OPS[op_name]
            nodes = self._count_function(
                func,
                (torch.rand(3, 4), torch.rand(4, 5)),
                op_name,
            )
            self.assertEqual(counter(*nodes), 60)