# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# Modified from
# https://github.com/facebookresearch/fvcore/blob/main/tests/test_flop_count.py
# pyre-ignore-all-errors[2,3,14,53]
import typing
import unittest
from collections import Counter, defaultdict
from typing import Any, Dict, Tuple

import torch
import torch.nn as nn
from torch.autograd.function import Function
from torch.nn import functional as F

from mmengine.analysis import FlopAnalyzer, flop_count
from mmengine.analysis.complexity_analysis import _DEFAULT_SUPPORTED_FLOP_OPS
from mmengine.analysis.jit_handles import Handle


class _CustomOp(Function):

    @staticmethod
    def forward(ctx, input: torch.Tensor) -> torch.Tensor:
        return input

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        return torch.ones_like(grad_output)


class ThreeNet(nn.Module):
    """A network with three layers.

    This is used for testing a network with more than one operation. The
    network has a convolution layer followed by two fully connected layers.
    """

    def __init__(self, input_dim: int, conv_dim: int,
                 linear_dim: int) -> None:
        super().__init__()
        self.conv = nn.Conv2d(input_dim, conv_dim, 1, 1)
        out_dim = 1
        self.pool = nn.AdaptiveAvgPool2d((out_dim, out_dim))
        self.linear1 = nn.Linear(conv_dim, linear_dim)
        self.linear2 = nn.Linear(linear_dim, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv(x)
        x = self.pool(x)
        x = torch.flatten(x, 1)
        x = self.linear1(x)
        x = self.linear2(x)
        return x


class ConvNet(nn.Module):
    """A network with a single convolution layer.

    This is used for testing flop count for convolution layers.
    """

    def __init__(
        self,
        conv_dim: int,
        input_dim: int,
        output_dim: int,
        kernel_size: int,
        stride: int,
        padding: int,
        groups_num: int,
        transpose: bool = False,
        output_padding: int = 0,
    ) -> None:
        super().__init__()
        if transpose:
            conv_layers = [
                nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d
            ]
            kwargs = {'output_padding': output_padding}
        else:
            conv_layers = [nn.Conv1d, nn.Conv2d, nn.Conv3d]
            assert (output_padding == 0), (
                'output_padding is not supported for'
                ' un-transposed convolutions.')
            kwargs = {}
        ConvLayer = conv_layers[conv_dim - 1]
        self.conv: nn.Module = ConvLayer(
            input_dim,
            output_dim,
            kernel_size,
            stride,
            padding,
            groups=groups_num,
            **kwargs,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv(x)
        return x


class LinearNet(nn.Module):
    """A network with a single fully connected layer.

    This is used for testing flop count for fully connected layers.
    """

    def __init__(self, input_dim: int, output_dim: int) -> None:
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.linear(x)
        return x


class EinsumNet(nn.Module):
    """A network with a single torch.einsum operation.

    This is used for testing flop count for torch.einsum.
    """

    def __init__(self, equation: str) -> None:
        super().__init__()
        self.eq: str = equation

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        x = torch.einsum(self.eq, x, y)
        return x


class MatmulNet(nn.Module):
    """A network with a single torch.matmul operation.

    This is used for testing flop count for torch.matmul.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        x = torch.matmul(x, y)
        return x


class BMMNet(nn.Module):
    """A network with a single torch.bmm operation.

    This is used for testing flop count for torch.bmm.
""" def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: x = torch.bmm(x, y) return x class CustomNet(nn.Module): """A network with a fully connected layer followed by a sigmoid layer. This is used for testing customized operation handles. """ def __init__(self, input_dim: int, output_dim: int) -> None: super().__init__() self.conv = nn.Linear(input_dim, output_dim) self.sigmoid = nn.Sigmoid() def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.conv(x) x = self.sigmoid(x) return x class TestFlopAnalyzer(unittest.TestCase): """Unittest for flop_count.""" def setUp(self) -> None: # nn.Linear uses a different operator based on version, so make sure # we are testing the right thing. lin = nn.Linear(10, 10) lin_x: torch.Tensor = torch.randn(10, 10) trace = torch.jit.trace(lin, (lin_x, )) node_kinds = [node.kind() for node in trace.graph.nodes()] assert 'aten::addmm' in node_kinds or 'aten::linear' in node_kinds if 'aten::addmm' in node_kinds: self.lin_op = 'addmm' else: self.lin_op = 'linear' def test_customized_ops(self) -> None: """Test the use of customized operation handles. The first test checks the case when a new handle for a new operation is passed as an argument. The second case checks when a new handle for a default operation is passed. The new handle should overwrite the default handle. """ # New handle for a new operation. def dummy_sigmoid_flop_jit( inputs: typing.List[Any], outputs: typing.List[Any]) -> typing.Counter[str]: """A dummy handle function for sigmoid. Note the handle here does not compute actual flop count. This is used for test only. """ flop_dict = Counter() # type: Counter flop_dict['sigmoid'] = 10000 return flop_dict batch_size = 10 input_dim = 5 output_dim = 4 custom_net = CustomNet(input_dim, output_dim) custom_ops: Dict[str, Handle] = { 'aten::sigmoid': dummy_sigmoid_flop_jit } x = torch.rand(batch_size, input_dim) flop_dict1, _ = flop_count(custom_net, (x, ), supported_ops=custom_ops) flop_sigmoid = 10000 / 1e9 self.assertEqual( flop_dict1['sigmoid'], flop_sigmoid, 'Customized operation handle failed to pass the flop count test.', ) # New handle that overwrites a default handle addmm. So now the new # handle counts flops for the fully connected layer. def addmm_dummy_flop_jit( inputs: typing.List[object], outputs: typing.List[object]) -> typing.Counter[str]: """A dummy handle function for fully connected layers. This overwrites the default handle. Note the handle here does not compute actual flop count. This is used for test only. 
""" flop_dict = Counter() # type: Counter flop_dict[self.lin_op] = 400000 return flop_dict custom_ops2: Dict[str, Handle] = { f'aten::{self.lin_op}': addmm_dummy_flop_jit } flop_dict2, _ = flop_count( custom_net, (x, ), supported_ops=custom_ops2) flop = 400000 / 1e9 self.assertEqual( flop_dict2[self.lin_op], flop, 'Customized operation handle failed to pass the flop count test.', ) def test_nn(self) -> None: """Test a model which is a pre-defined nn.module without defining a new customized network.""" batch_size = 5 input_dim = 8 output_dim = 4 x = torch.randn(batch_size, input_dim) flop_dict, _ = flop_count(nn.Linear(input_dim, output_dim), (x, )) gt_flop = batch_size * input_dim * output_dim / 1e9 gt_dict = defaultdict(float) gt_dict[self.lin_op] = gt_flop self.assertDictEqual(flop_dict, gt_dict, 'nn.Linear failed to pass the flop count test.') def test_skip_ops(self) -> None: """Test the return of skipped operations.""" batch_size = 10 input_dim = 5 output_dim = 4 custom_net = CustomNet(input_dim, output_dim) x = torch.rand(batch_size, input_dim) _, skip_dict = flop_count(custom_net, (x, )) gt_dict = Counter() # type: Counter gt_dict['aten::sigmoid'] = 1 self.assertDictEqual( skip_dict, gt_dict, 'Skipped operations failed to pass the flop count test.') def test_linear(self) -> None: """Test a network with a single fully connected layer.""" batch_size = 5 input_dim = 10 output_dim = 20 linear_net = LinearNet(input_dim, output_dim) x = torch.randn(batch_size, input_dim) flop_dict, _ = flop_count(linear_net, (x, )) gt_flop = batch_size * input_dim * output_dim / 1e9 gt_dict = defaultdict(float) gt_dict[self.lin_op] = gt_flop self.assertDictEqual( flop_dict, gt_dict, 'Fully connected layer failed to pass the flop count test.', ) # Test with #input_dims>2 if self.lin_op != 'linear': # Skip this test if nn.Linear doesn't use aten::linear # TODO: Stop skipping when multidimensional aten::matmul # flop counting is implemented return extra_dim = 5 x = torch.randn(batch_size, extra_dim, input_dim) flop_dict, _ = flop_count(linear_net, (x, )) gt_flop = batch_size * input_dim * extra_dim * output_dim / 1e9 gt_dict = defaultdict(float) gt_dict[self.lin_op] = gt_flop self.assertDictEqual( flop_dict, gt_dict, 'Fully connected layer failed to pass the flop count test.', ) def test_conv(self) -> None: """Test a network with a single convolution layer. The test cases are: 1) 2D convolution; 2) 2D convolution with change in spatial dimensions; 3) group convolution; 4) depthwise convolution; 5) 1d convolution; 6) 3d convolution. 
""" def _test_conv( conv_dim: int, batch_size: int, input_dim: int, output_dim: int, spatial_dim: int, kernel_size: int, padding: int, stride: int, group_size: int, transpose: bool = False, output_padding: int = 0, ) -> None: convNet = ConvNet( conv_dim, input_dim, output_dim, kernel_size, stride, padding, group_size, transpose, output_padding, ) assert conv_dim in [ 1, 2, 3 ], 'Convolution dimension needs to be 1, 2, or 3' if conv_dim == 1: x = torch.randn(batch_size, input_dim, spatial_dim) elif conv_dim == 2: x = torch.randn(batch_size, input_dim, spatial_dim, spatial_dim) else: x = torch.randn(batch_size, input_dim, spatial_dim, spatial_dim, spatial_dim) flop_dict, _ = flop_count(convNet, (x, )) if transpose: spatial_size = spatial_dim else: spatial_size = ( (spatial_dim + 2 * padding) - kernel_size) // stride + 1 gt_flop = ( batch_size * input_dim * output_dim * (kernel_size**conv_dim) * (spatial_size**conv_dim) / group_size / 1e9) gt_dict = defaultdict(float) gt_dict['conv'] = gt_flop self.assertDictEqual( flop_dict, gt_dict, 'Convolution layer failed to pass the flop count test.', ) # Test flop count for 2d convolution. conv_dim1 = 2 batch_size1 = 5 input_dim1 = 10 output_dim1 = 3 spatial_dim1 = 15 kernel_size1 = 3 padding1 = 1 stride1 = 1 group_size1 = 1 _test_conv( conv_dim1, batch_size1, input_dim1, output_dim1, spatial_dim1, kernel_size1, padding1, stride1, group_size1, ) # Test flop count for convolution with spatial change in output. conv_dim2 = 2 batch_size2 = 2 input_dim2 = 10 output_dim2 = 6 spatial_dim2 = 20 kernel_size2 = 3 padding2 = 1 stride2 = 2 group_size2 = 1 _test_conv( conv_dim2, batch_size2, input_dim2, output_dim2, spatial_dim2, kernel_size2, padding2, stride2, group_size2, ) # Test flop count for group convolution. conv_dim3 = 2 batch_size3 = 5 input_dim3 = 16 output_dim3 = 8 spatial_dim3 = 15 kernel_size3 = 5 padding3 = 2 stride3 = 1 group_size3 = 4 _test_conv( conv_dim3, batch_size3, input_dim3, output_dim3, spatial_dim3, kernel_size3, padding3, stride3, group_size3, ) # Test the special case of group convolution when group = output_dim. # This is equivalent to depthwise convolution. conv_dim4 = 2 batch_size4 = 5 input_dim4 = 16 output_dim4 = 8 spatial_dim4 = 15 kernel_size4 = 3 padding4 = 1 stride4 = 1 group_size4 = output_dim4 _test_conv( conv_dim4, batch_size4, input_dim4, output_dim4, spatial_dim4, kernel_size4, padding4, stride4, group_size4, ) # Test flop count for 1d convolution. conv_dim5 = 1 batch_size5 = 5 input_dim5 = 10 output_dim5 = 3 spatial_dim5 = 15 kernel_size5 = 3 padding5 = 1 stride5 = 1 group_size5 = 1 _test_conv( conv_dim5, batch_size5, input_dim5, output_dim5, spatial_dim5, kernel_size5, padding5, stride5, group_size5, ) # Test flop count for 3d convolution. conv_dim6 = 3 batch_size6 = 5 input_dim6 = 10 output_dim6 = 3 spatial_dim6 = 15 kernel_size6 = 3 padding6 = 1 stride6 = 1 group_size6 = 1 _test_conv( conv_dim6, batch_size6, input_dim6, output_dim6, spatial_dim6, kernel_size6, padding6, stride6, group_size6, ) # Test flop count for transposed 2d convolution. conv_dim7 = 2 batch_size7 = 5 input_dim7 = 10 output_dim7 = 3 spatial_dim7 = 15 kernel_size7 = 3 padding7 = 1 stride7 = 1 group_size7 = 1 _test_conv( conv_dim7, batch_size7, input_dim7, output_dim7, spatial_dim7, kernel_size7, padding7, stride7, group_size7, transpose=True, ) # Test flop count for strided transposed 2d convolution. 
        # Test flop count for strided transposed 2d convolution.
        conv_dim8 = 2
        batch_size8 = 5
        input_dim8 = 10
        output_dim8 = 3
        spatial_dim8 = 15
        kernel_size8 = 3
        padding8 = 1
        stride8 = 2
        group_size8 = 1
        _test_conv(
            conv_dim8,
            batch_size8,
            input_dim8,
            output_dim8,
            spatial_dim8,
            kernel_size8,
            padding8,
            stride8,
            group_size8,
            transpose=True,
        )

        # Test flop count for strided transposed 2d convolution
        # w/ output_padding.
        conv_dim9 = 2
        batch_size9 = 5
        input_dim9 = 10
        output_dim9 = 3
        spatial_dim9 = 15
        kernel_size9 = 3
        padding9 = 1
        stride9 = 3
        group_size9 = 1
        output_padding9 = 2
        _test_conv(
            conv_dim9,
            batch_size9,
            input_dim9,
            output_dim9,
            spatial_dim9,
            kernel_size9,
            padding9,
            stride9,
            group_size9,
            transpose=True,
            output_padding=output_padding9,
        )

    def test_matmul(self) -> None:
        """Test flop count for operation matmul."""
        m = 20
        n = 10
        p = 100
        m_net = MatmulNet()
        x = torch.randn(m, n)
        y = torch.randn(n, p)
        flop_dict, _ = flop_count(m_net, (x, y))
        gt_flop = m * n * p / 1e9
        gt_dict = defaultdict(float)
        gt_dict['matmul'] = gt_flop
        self.assertDictEqual(
            flop_dict, gt_dict,
            'Matmul operation failed to pass the flop count test.')

        # Test with single dimension y.
        y = torch.randn(n)
        gt_dict['matmul'] = m * n * 1 / 1e9
        flop_dict, _ = flop_count(m_net, (x, y))
        self.assertDictEqual(
            flop_dict, gt_dict,
            'Matmul operation failed to pass the flop count test.')

    def test_matmul_broadcast(self) -> None:
        """Test flop count for operation matmul with broadcasting."""
        m = 20
        n = 10
        p = 100
        m_net = MatmulNet()
        x = torch.randn(1, m, n)
        y = torch.randn(1, n, p)
        flop_dict, _ = flop_count(m_net, (x, y))
        gt_flop = m * n * p / 1e9
        gt_dict = defaultdict(float)
        gt_dict['matmul'] = gt_flop
        self.assertDictEqual(
            flop_dict, gt_dict,
            'Matmul operation failed to pass the flop count test.')

        x = torch.randn(2, 2, m, n)
        y = torch.randn(2, 2, n, p)
        flop_dict, _ = flop_count(m_net, (x, y))
        gt_flop = 4 * m * n * p / 1e9
        gt_dict = defaultdict(float)
        gt_dict['matmul'] = gt_flop
        self.assertDictEqual(
            flop_dict, gt_dict,
            'Matmul operation failed to pass the flop count test.')

        x = torch.randn(1, m, n)
        y = torch.randn(n, p)
        flop_dict, _ = flop_count(m_net, (x, y))
        gt_flop = m * n * p / 1e9
        gt_dict = defaultdict(float)
        gt_dict['matmul'] = gt_flop
        self.assertDictEqual(
            flop_dict, gt_dict,
            'Matmul operation failed to pass the flop count test.')

        x = torch.randn(2, m, n)
        y = torch.randn(n, p)
        flop_dict, _ = flop_count(m_net, (x, y))
        gt_flop = 2 * m * n * p / 1e9
        gt_dict = defaultdict(float)
        gt_dict['matmul'] = gt_flop
        self.assertDictEqual(
            flop_dict, gt_dict,
            'Matmul operation failed to pass the flop count test.')

    def test_bmm(self) -> None:
        """Test flop count for operation torch.bmm.

        The case checks torch.bmm with equation nct,ntp->ncp.
        """
        n = 2
        c = 5
        t = 2
        p = 12
        e_net = BMMNet()
        x = torch.randn(n, c, t)
        y = torch.randn(n, t, p)
        flop_dict, _ = flop_count(e_net, (x, y))
        gt_flop = n * t * p * c / 1e9
        gt_dict = defaultdict(float)
        gt_dict['bmm'] = gt_flop
        self.assertDictEqual(
            flop_dict,
            gt_dict,
            'bmm operation nct,ntp->ncp failed to pass the flop count test.',
        )

    def test_einsum(self) -> None:
        """Test flop count for operation torch.einsum.

        The first case checks torch.einsum with equation nct,ncp->ntp. The
        second case checks torch.einsum with equation "ntg,ncg->nct".
        """
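        # For an einsum such as nct,ncp->ntp the expected value below is the
        # number of multiply-accumulates n * c * t * p (reported in GFlops),
        # i.e. one flop per term of every output element's sum over the
        # contracted index c.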
""" equation = 'nct,ncp->ntp' n = 1 c = 5 t = 2 p = 12 e_net = EinsumNet(equation) x = torch.randn(n, c, t) y = torch.randn(n, c, p) flop_dict, _ = flop_count(e_net, (x, y)) gt_flop = n * t * p * c / 1e9 gt_dict = defaultdict(float) gt_dict['einsum'] = gt_flop self.assertDictEqual( flop_dict, gt_dict, 'Einsum operation nct,ncp->ntp failed to pass flop count test.', ) equation = 'ntg,ncg->nct' g = 6 e_net = EinsumNet(equation) x = torch.randn(n, t, g) y = torch.randn(n, c, g) flop_dict, _ = flop_count(e_net, (x, y)) gt_flop = n * t * g * c / 1e9 gt_dict = defaultdict(float) gt_dict['einsum'] = gt_flop self.assertDictEqual( flop_dict, gt_dict, 'Einsum operation ntg,ncg->nct failed to pass flop count test.', ) def test_batchnorm(self) -> None: """Test flop count for operation batchnorm. The test cases include BatchNorm1d, BatchNorm2d and BatchNorm3d. """ # Test for BatchNorm1d. batch_size = 10 input_dim = 10 batch_1d = nn.BatchNorm1d(input_dim, affine=False).eval() x = torch.randn(batch_size, input_dim) flop_dict, _ = flop_count(batch_1d, (x, )) gt_flop = batch_size * input_dim / 1e9 gt_dict = defaultdict(float) gt_dict['batch_norm'] = gt_flop self.assertDictEqual( flop_dict, gt_dict, 'BatchNorm1d failed to pass the flop count test.') # Test for BatchNorm2d. batch_size = 10 input_dim = 10 spatial_dim_x = 5 spatial_dim_y = 5 batch_2d = nn.BatchNorm2d(input_dim, affine=False) x = torch.randn(batch_size, input_dim, spatial_dim_x, spatial_dim_y) flop_dict, _ = flop_count(batch_2d, (x, )) gt_flop = 4 * batch_size * input_dim * spatial_dim_x * \ spatial_dim_y / 1e9 gt_dict = defaultdict(float) gt_dict['batch_norm'] = gt_flop self.assertDictEqual( flop_dict, gt_dict, 'BatchNorm2d failed to pass the flop count test.') # Test for BatchNorm3d. batch_size = 10 input_dim = 10 spatial_dim_x = 5 spatial_dim_y = 5 spatial_dim_z = 5 batch_3d = nn.BatchNorm3d(input_dim, affine=False) x = torch.randn(batch_size, input_dim, spatial_dim_x, spatial_dim_y, spatial_dim_z) flop_dict, _ = flop_count(batch_3d, (x, )) gt_flop = (4 * batch_size * input_dim * spatial_dim_x * spatial_dim_y * spatial_dim_z / 1e9) gt_dict = defaultdict(float) gt_dict['batch_norm'] = gt_flop self.assertDictEqual( flop_dict, gt_dict, 'BatchNorm3d failed to pass the flop count test.') def test_threeNet(self) -> None: """Test a network with more than one layer. The network has a convolution layer followed by two fully connected layers. 
""" batch_size = 4 input_dim = 2 conv_dim = 5 spatial_dim = 10 linear_dim = 3 x = torch.randn(batch_size, input_dim, spatial_dim, spatial_dim) three_net = ThreeNet(input_dim, conv_dim, linear_dim) flop1 = batch_size * conv_dim * input_dim * spatial_dim * \ spatial_dim / 1e9 flop_linear1 = batch_size * conv_dim * linear_dim / 1e9 flop_linear2 = batch_size * linear_dim * 1 / 1e9 flop2 = flop_linear1 + flop_linear2 flop_dict, _ = flop_count(three_net, (x, )) gt_dict = defaultdict(float) gt_dict['conv'] = flop1 gt_dict[self.lin_op] = flop2 gt_dict['adaptive_avg_pool2d'] = 2e-6 self.assertDictEqual( flop_dict, gt_dict, 'The three-layer network failed to pass the flop count test.', ) def test_flop_counter_class(self) -> None: """Test FlopAnalyzer.""" batch_size = 4 input_dim = 2 conv_dim = 5 spatial_dim = 10 linear_dim = 3 x = torch.randn(batch_size, input_dim, spatial_dim, spatial_dim) three_net = ThreeNet(input_dim, conv_dim, linear_dim) flop1 = batch_size * conv_dim * input_dim * spatial_dim * spatial_dim flop_linear1 = batch_size * conv_dim * linear_dim flop_linear2 = batch_size * linear_dim * 1 flop_counter = FlopAnalyzer(three_net, (x, )) gt_dict = Counter({ 'conv': flop1, 'linear1': flop_linear1, 'linear2': flop_linear2, 'pool': flop1 // input_dim, }) gt_dict[''] = sum(gt_dict.values()) self.assertEqual(flop_counter.by_module(), gt_dict) def test_autograd_function(self): # test support on custom autograd function class Mod(nn.Module): def forward(self, x): return _CustomOp.apply(x) flop = FlopAnalyzer(Mod(), (torch.rand(4, 5), )).set_op_handle( 'prim::PythonOp._CustomOp', lambda *args, **kwargs: 42) self.assertEqual(flop.total(), 42) def test_scripted_function(self): # Scripted function is not yet supported. It should produce a warning def func(x): return x @ x class Mod(nn.Module): def forward(self, x): f = torch.jit.script(func) return f(x * x) flop = FlopAnalyzer(Mod(), (torch.rand(5, 5), )) _ = flop.total() self.assertIn('prim::CallFunction', flop.unsupported_ops()) class TestFlopCountHandles(unittest.TestCase): def _count_function(self, func, inputs, name) -> Tuple[Any, Any]: tensor_inputs = [x for x in inputs if isinstance(x, torch.Tensor)] def f(*args): return func(*inputs) graph = torch.jit.trace( f, tuple(tensor_inputs), check_trace=False).graph nodes = [k for k in graph.nodes() if k.kind() == name] self.assertEqual(len(nodes), 1) node = nodes[0] return list(node.inputs()), list(node.outputs()) def test_batch_norm(self): op_name = 'aten::batch_norm' counter = _DEFAULT_SUPPORTED_FLOP_OPS[op_name] vec = torch.rand(2) nodes = self._count_function( F.batch_norm, (torch.rand(2, 2, 2, 2), vec, vec, vec, vec), op_name) self.assertEqual(counter(*nodes), 32) nodes = self._count_function( F.batch_norm, (torch.rand(2, 2, 2, 2), vec, vec, None, None), op_name, ) self.assertEqual(counter(*nodes), 16) nodes = self._count_function( # training=True F.batch_norm, (torch.rand(2, 2, 2, 2), vec, vec, vec, vec, True), op_name, ) self.assertEqual(counter(*nodes), 80) def test_group_norm(self): op_name = 'aten::group_norm' counter = _DEFAULT_SUPPORTED_FLOP_OPS[op_name] vec = torch.rand(2) nodes = self._count_function(F.group_norm, (torch.rand(2, 2, 2, 2), 2, vec, vec), op_name) self.assertEqual(counter(*nodes), 80) nodes = self._count_function(F.group_norm, (torch.rand(2, 2, 2, 2), 2, None, None), op_name) self.assertEqual(counter(*nodes), 64) def test_upsample(self): op_name = 'aten::upsample_bilinear2d' counter = _DEFAULT_SUPPORTED_FLOP_OPS[op_name] nodes = self._count_function( F.interpolate, 
        self.assertEqual(counter(*nodes), 2**4 * 4 * 4)

    def test_complicated_einsum(self):
        op_name = 'aten::einsum'
        counter = _DEFAULT_SUPPORTED_FLOP_OPS[op_name]
        nodes = self._count_function(
            torch.einsum,
            ('nc,nchw->hw', torch.rand(3, 4), torch.rand(3, 4, 2, 3)),
            op_name,
        )
        self.assertEqual(counter(*nodes), 72.0)

    def test_torch_mm(self):
        for op_name, func in zip(['aten::mm', 'aten::matmul'],
                                 [torch.mm, torch.matmul]):
            counter = _DEFAULT_SUPPORTED_FLOP_OPS[op_name]
            nodes = self._count_function(
                func,
                (torch.rand(3, 4), torch.rand(4, 5)),
                op_name,
            )
            self.assertEqual(counter(*nodes), 60)
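# Optional convenience entry point (assumed not to conflict with the
# project's pytest-based test runner): allows running this module directly
# with `python`.
if __name__ == '__main__':
    unittest.main()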