# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# Modified from
# https://github.com/facebookresearch/fvcore/blob/main/tests/test_flop_count.py

# pyre-ignore-all-errors[2,3,14,53]
import typing
import unittest
from collections import Counter, defaultdict
from typing import Any, Dict, Tuple

import torch
import torch.nn as nn
from torch.autograd.function import Function
from torch.nn import functional as F

from mmengine.analysis import FlopAnalyzer, flop_count
from mmengine.analysis.complexity_analysis import _DEFAULT_SUPPORTED_FLOP_OPS
from mmengine.analysis.jit_handles import Handle
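
# A minimal usage sketch of the API under test (illustrative only, not
# asserted anywhere in this file): ``flop_count`` returns a tuple
# ``(flop_dict, skip_dict)`` where ``flop_dict`` maps operator names to
# GFLOPs (raw counts divided by 1e9) and ``skip_dict`` counts operators
# that have no registered handle, e.g.
#
#     model = nn.Linear(8, 4)
#     flops, skipped = flop_count(model, (torch.randn(5, 8), ))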


class _CustomOp(Function):

    @staticmethod
    def forward(ctx, input: torch.Tensor) -> torch.Tensor:
        return input

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        return torch.ones_like(grad_output)


class ThreeNet(nn.Module):
    """A network with three layers.

    This is used for testing a network with more than one operation. The
    network has a convolution layer followed by two fully connected layers.
    """

    def __init__(self, input_dim: int, conv_dim: int, linear_dim: int) -> None:
        super().__init__()
        self.conv = nn.Conv2d(input_dim, conv_dim, 1, 1)
        out_dim = 1
        self.pool = nn.AdaptiveAvgPool2d((out_dim, out_dim))
        self.linear1 = nn.Linear(conv_dim, linear_dim)
        self.linear2 = nn.Linear(linear_dim, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv(x)
        x = self.pool(x)
        x = torch.flatten(x, 1)
        x = self.linear1(x)
        x = self.linear2(x)
        return x


class ConvNet(nn.Module):
    """A network with a single convolution layer.

    This is used for testing flop count for convolution layers.
    """

    def __init__(
        self,
        conv_dim: int,
        input_dim: int,
        output_dim: int,
        kernel_size: int,
        stride: int,
        padding: int,
        groups_num: int,
        transpose: bool = False,
        output_padding: int = 0,
    ) -> None:
        super().__init__()
        if transpose:
            conv_layers = [
                nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d
            ]
            kwargs = {'output_padding': output_padding}
        else:
            conv_layers = [nn.Conv1d, nn.Conv2d, nn.Conv3d]
            assert output_padding == 0, (
                'output_padding is not supported for '
                'un-transposed convolutions.')
            kwargs = {}
        ConvLayer = conv_layers[conv_dim - 1]

        self.conv: nn.Module = ConvLayer(
            input_dim,
            output_dim,
            kernel_size,
            stride,
            padding,
            groups=groups_num,
            **kwargs,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv(x)
        return x


class LinearNet(nn.Module):
    """A network with a single fully connected layer.

    This is used for testing flop count for fully connected layers.
    """

    def __init__(self, input_dim: int, output_dim: int) -> None:
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.linear(x)
        return x


class EinsumNet(nn.Module):
    """A network with a single torch.einsum operation.

    This is used for testing flop count for torch.einsum.
    """

    def __init__(self, equation: str) -> None:
        super().__init__()
        self.eq: str = equation

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        x = torch.einsum(self.eq, x, y)
        return x


class MatmulNet(nn.Module):
    """A network with a single torch.matmul operation.

    This is used for testing flop count for torch.matmul.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        x = torch.matmul(x, y)
        return x


class BMMNet(nn.Module):
    """A network with a single torch.bmm operation.

    This is used for testing flop count for torch.bmm.
    """

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        x = torch.bmm(x, y)
        return x


class CustomNet(nn.Module):
    """A network with a fully connected layer followed by a sigmoid layer.

    This is used for testing customized operation handles.
    """

    def __init__(self, input_dim: int, output_dim: int) -> None:
        super().__init__()
        self.conv = nn.Linear(input_dim, output_dim)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv(x)
        x = self.sigmoid(x)
        return x
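

# Note on conventions, inferred from the ground-truth values used below: the
# default handles count one multiply-accumulate as a single FLOP for linear,
# matmul and convolution ops, and ``flop_count`` reports results in GFLOPs,
# i.e. raw counts divided by 1e9.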
class TestFlopAnalyzer(unittest.TestCase):
    """Unittest for flop_count."""

    def setUp(self) -> None:
        # nn.Linear is traced to a different operator depending on the
        # PyTorch version, so make sure we are testing the right one.
        lin = nn.Linear(10, 10)
        lin_x: torch.Tensor = torch.randn(10, 10)
        trace = torch.jit.trace(lin, (lin_x, ))
        node_kinds = [node.kind() for node in trace.graph.nodes()]
        assert 'aten::addmm' in node_kinds or 'aten::linear' in node_kinds
        if 'aten::addmm' in node_kinds:
            self.lin_op = 'addmm'
        else:
            self.lin_op = 'linear'
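
    # A custom handle (sketch, mirroring the dummy handles defined in the
    # test below) is any callable that takes the traced node's inputs and
    # outputs and returns a Counter of raw flop counts keyed by a short op
    # name; flop_count later divides these values by 1e9.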
    def test_customized_ops(self) -> None:
        """Test the use of customized operation handles.

        The first test checks the case when a new handle for a new operation
        is passed as an argument. The second case checks when a new handle
        for a default operation is passed. The new handle should overwrite
        the default handle.
        """

        # New handle for a new operation.
        def dummy_sigmoid_flop_jit(
                inputs: typing.List[Any],
                outputs: typing.List[Any]) -> typing.Counter[str]:
            """A dummy handle function for sigmoid.

            Note the handle here does not compute actual flop count. This is
            used for test only.
            """
            flop_dict = Counter()  # type: Counter
            flop_dict['sigmoid'] = 10000
            return flop_dict

        batch_size = 10
        input_dim = 5
        output_dim = 4
        custom_net = CustomNet(input_dim, output_dim)
        custom_ops: Dict[str, Handle] = {
            'aten::sigmoid': dummy_sigmoid_flop_jit
        }
        x = torch.rand(batch_size, input_dim)
        flop_dict1, _ = flop_count(custom_net, (x, ), supported_ops=custom_ops)
        flop_sigmoid = 10000 / 1e9
        self.assertEqual(
            flop_dict1['sigmoid'],
            flop_sigmoid,
            'Customized operation handle failed to pass the flop count test.',
        )

        # New handle that overwrites the default addmm handle. So now the new
        # handle counts flops for the fully connected layer.
        def addmm_dummy_flop_jit(
                inputs: typing.List[object],
                outputs: typing.List[object]) -> typing.Counter[str]:
            """A dummy handle function for fully connected layers.

            This overwrites the default handle. Note the handle here does not
            compute actual flop count. This is used for test only.
            """
            flop_dict = Counter()  # type: Counter
            flop_dict[self.lin_op] = 400000
            return flop_dict

        custom_ops2: Dict[str, Handle] = {
            f'aten::{self.lin_op}': addmm_dummy_flop_jit
        }
        flop_dict2, _ = flop_count(
            custom_net, (x, ), supported_ops=custom_ops2)
        flop = 400000 / 1e9
        self.assertEqual(
            flop_dict2[self.lin_op],
            flop,
            'Customized operation handle failed to pass the flop count test.',
        )

    def test_nn(self) -> None:
        """Test a predefined nn.Module directly, without defining a new
        customized network."""
        batch_size = 5
        input_dim = 8
        output_dim = 4
        x = torch.randn(batch_size, input_dim)
        flop_dict, _ = flop_count(nn.Linear(input_dim, output_dim), (x, ))
        gt_flop = batch_size * input_dim * output_dim / 1e9
        gt_dict = defaultdict(float)
        gt_dict[self.lin_op] = gt_flop
        self.assertDictEqual(flop_dict, gt_dict,
                             'nn.Linear failed to pass the flop count test.')

    def test_skip_ops(self) -> None:
        """Test the return of skipped operations."""
        batch_size = 10
        input_dim = 5
        output_dim = 4
        custom_net = CustomNet(input_dim, output_dim)
        x = torch.rand(batch_size, input_dim)
        _, skip_dict = flop_count(custom_net, (x, ))
        gt_dict = Counter()  # type: Counter
        gt_dict['aten::sigmoid'] = 1
        self.assertDictEqual(
            skip_dict, gt_dict,
            'Skipped operations failed to pass the flop count test.')

    def test_linear(self) -> None:
        """Test a network with a single fully connected layer."""
        batch_size = 5
        input_dim = 10
        output_dim = 20
        linear_net = LinearNet(input_dim, output_dim)
        x = torch.randn(batch_size, input_dim)
        flop_dict, _ = flop_count(linear_net, (x, ))
        gt_flop = batch_size * input_dim * output_dim / 1e9
        gt_dict = defaultdict(float)
        gt_dict[self.lin_op] = gt_flop
        self.assertDictEqual(
            flop_dict,
            gt_dict,
            'Fully connected layer failed to pass the flop count test.',
        )

        # Test with more than two input dimensions.
        if self.lin_op != 'linear':
            # Skip this test if nn.Linear doesn't use aten::linear.
            # TODO: Stop skipping when multidimensional aten::matmul
            # flop counting is implemented.
            return
        extra_dim = 5
        x = torch.randn(batch_size, extra_dim, input_dim)
        flop_dict, _ = flop_count(linear_net, (x, ))
        gt_flop = batch_size * input_dim * extra_dim * output_dim / 1e9
        gt_dict = defaultdict(float)
        gt_dict[self.lin_op] = gt_flop
        self.assertDictEqual(
            flop_dict,
            gt_dict,
            'Fully connected layer failed to pass the flop count test.',
        )

    def test_conv(self) -> None:
        """Test a network with a single convolution layer.

        The test cases are: 1) 2D convolution; 2) 2D convolution with a
        change in spatial dimensions; 3) group convolution; 4) depthwise
        convolution; 5) 1D convolution; 6) 3D convolution; 7)-9) transposed
        2D convolutions, including strided and output-padded variants.
        """

        def _test_conv(
            conv_dim: int,
            batch_size: int,
            input_dim: int,
            output_dim: int,
            spatial_dim: int,
            kernel_size: int,
            padding: int,
            stride: int,
            group_size: int,
            transpose: bool = False,
            output_padding: int = 0,
        ) -> None:
            convNet = ConvNet(
                conv_dim,
                input_dim,
                output_dim,
                kernel_size,
                stride,
                padding,
                group_size,
                transpose,
                output_padding,
            )
            assert conv_dim in [
                1, 2, 3
            ], 'Convolution dimension needs to be 1, 2, or 3'
            if conv_dim == 1:
                x = torch.randn(batch_size, input_dim, spatial_dim)
            elif conv_dim == 2:
                x = torch.randn(batch_size, input_dim, spatial_dim,
                                spatial_dim)
            else:
                x = torch.randn(batch_size, input_dim, spatial_dim,
                                spatial_dim, spatial_dim)

            flop_dict, _ = flop_count(convNet, (x, ))
            if transpose:
                spatial_size = spatial_dim
            else:
                spatial_size = (
                    (spatial_dim + 2 * padding) - kernel_size) // stride + 1
            gt_flop = (
                batch_size * input_dim * output_dim * (kernel_size**conv_dim) *
                (spatial_size**conv_dim) / group_size / 1e9)
            gt_dict = defaultdict(float)
            gt_dict['conv'] = gt_flop
            self.assertDictEqual(
                flop_dict,
                gt_dict,
                'Convolution layer failed to pass the flop count test.',
            )
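
        # Worked example for the first case below (illustrative, not asserted
        # separately): with batch=5, C_in=10, C_out=3, k=3, stride=1,
        # padding=1 and a 15x15 input, the output stays 15x15, so the
        # ground-truth MACs are 5 * 10 * 3 * 3**2 * 15**2 = 303750,
        # i.e. gt_flop = 303750 / 1e9 GFLOPs.
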
        # Test flop count for 2d convolution.
        conv_dim1 = 2
        batch_size1 = 5
        input_dim1 = 10
        output_dim1 = 3
        spatial_dim1 = 15
        kernel_size1 = 3
        padding1 = 1
        stride1 = 1
        group_size1 = 1
        _test_conv(
            conv_dim1,
            batch_size1,
            input_dim1,
            output_dim1,
            spatial_dim1,
            kernel_size1,
            padding1,
            stride1,
            group_size1,
        )

        # Test flop count for convolution with spatial change in output.
        conv_dim2 = 2
        batch_size2 = 2
        input_dim2 = 10
        output_dim2 = 6
        spatial_dim2 = 20
        kernel_size2 = 3
        padding2 = 1
        stride2 = 2
        group_size2 = 1
        _test_conv(
            conv_dim2,
            batch_size2,
            input_dim2,
            output_dim2,
            spatial_dim2,
            kernel_size2,
            padding2,
            stride2,
            group_size2,
        )

        # Test flop count for group convolution.
        conv_dim3 = 2
        batch_size3 = 5
        input_dim3 = 16
        output_dim3 = 8
        spatial_dim3 = 15
        kernel_size3 = 5
        padding3 = 2
        stride3 = 1
        group_size3 = 4
        _test_conv(
            conv_dim3,
            batch_size3,
            input_dim3,
            output_dim3,
            spatial_dim3,
            kernel_size3,
            padding3,
            stride3,
            group_size3,
        )

        # Test the special case of group convolution when group = output_dim.
        # This is equivalent to depthwise convolution.
        conv_dim4 = 2
        batch_size4 = 5
        input_dim4 = 16
        output_dim4 = 8
        spatial_dim4 = 15
        kernel_size4 = 3
        padding4 = 1
        stride4 = 1
        group_size4 = output_dim4
        _test_conv(
            conv_dim4,
            batch_size4,
            input_dim4,
            output_dim4,
            spatial_dim4,
            kernel_size4,
            padding4,
            stride4,
            group_size4,
        )

        # Test flop count for 1d convolution.
        conv_dim5 = 1
        batch_size5 = 5
        input_dim5 = 10
        output_dim5 = 3
        spatial_dim5 = 15
        kernel_size5 = 3
        padding5 = 1
        stride5 = 1
        group_size5 = 1
        _test_conv(
            conv_dim5,
            batch_size5,
            input_dim5,
            output_dim5,
            spatial_dim5,
            kernel_size5,
            padding5,
            stride5,
            group_size5,
        )

        # Test flop count for 3d convolution.
        conv_dim6 = 3
        batch_size6 = 5
        input_dim6 = 10
        output_dim6 = 3
        spatial_dim6 = 15
        kernel_size6 = 3
        padding6 = 1
        stride6 = 1
        group_size6 = 1
        _test_conv(
            conv_dim6,
            batch_size6,
            input_dim6,
            output_dim6,
            spatial_dim6,
            kernel_size6,
            padding6,
            stride6,
            group_size6,
        )

        # Test flop count for transposed 2d convolution.
        conv_dim7 = 2
        batch_size7 = 5
        input_dim7 = 10
        output_dim7 = 3
        spatial_dim7 = 15
        kernel_size7 = 3
        padding7 = 1
        stride7 = 1
        group_size7 = 1
        _test_conv(
            conv_dim7,
            batch_size7,
            input_dim7,
            output_dim7,
            spatial_dim7,
            kernel_size7,
            padding7,
            stride7,
            group_size7,
            transpose=True,
        )

        # Test flop count for strided transposed 2d convolution.
        conv_dim8 = 2
        batch_size8 = 5
        input_dim8 = 10
        output_dim8 = 3
        spatial_dim8 = 15
        kernel_size8 = 3
        padding8 = 1
        stride8 = 2
        group_size8 = 1
        _test_conv(
            conv_dim8,
            batch_size8,
            input_dim8,
            output_dim8,
            spatial_dim8,
            kernel_size8,
            padding8,
            stride8,
            group_size8,
            transpose=True,
        )

        # Test flop count for strided transposed 2d convolution
        # w/ output_padding.
        conv_dim9 = 2
        batch_size9 = 5
        input_dim9 = 10
        output_dim9 = 3
        spatial_dim9 = 15
        kernel_size9 = 3
        padding9 = 1
        stride9 = 3
        group_size9 = 1
        output_padding9 = 2
        _test_conv(
            conv_dim9,
            batch_size9,
            input_dim9,
            output_dim9,
            spatial_dim9,
            kernel_size9,
            padding9,
            stride9,
            group_size9,
            transpose=True,
            output_padding=output_padding9,
        )

    def test_matmul(self) -> None:
        """Test flop count for operation matmul."""
        m = 20
        n = 10
        p = 100
        m_net = MatmulNet()
        x = torch.randn(m, n)
        y = torch.randn(n, p)
        flop_dict, _ = flop_count(m_net, (x, y))
        gt_flop = m * n * p / 1e9
        gt_dict = defaultdict(float)
        gt_dict['matmul'] = gt_flop
        self.assertDictEqual(
            flop_dict, gt_dict,
            'Matmul operation failed to pass the flop count test.')
        # Test with single dimension y.
        y = torch.randn(n)
        gt_dict['matmul'] = m * n * 1 / 1e9
        flop_dict, _ = flop_count(m_net, (x, y))
        self.assertDictEqual(
            flop_dict, gt_dict,
            'Matmul operation failed to pass the flop count test.')
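
    # In the broadcast cases below, the expected count is the single-matrix
    # product m * n * p scaled by the number of broadcast batch matrices
    # (e.g. a 2 x 2 batch gives 4 * m * n * p), matching the ground truths
    # used in the test.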
    def test_matmul_broadcast(self) -> None:
        """Test flop count for operation matmul with broadcasting."""
        m = 20
        n = 10
        p = 100
        m_net = MatmulNet()
        x = torch.randn(1, m, n)
        y = torch.randn(1, n, p)
        flop_dict, _ = flop_count(m_net, (x, y))
        gt_flop = m * n * p / 1e9
        gt_dict = defaultdict(float)
        gt_dict['matmul'] = gt_flop
        self.assertDictEqual(
            flop_dict, gt_dict,
            'Matmul operation failed to pass the flop count test.')

        x = torch.randn(2, 2, m, n)
        y = torch.randn(2, 2, n, p)
        flop_dict, _ = flop_count(m_net, (x, y))
        gt_flop = 4 * m * n * p / 1e9
        gt_dict = defaultdict(float)
        gt_dict['matmul'] = gt_flop
        self.assertDictEqual(
            flop_dict, gt_dict,
            'Matmul operation failed to pass the flop count test.')

        x = torch.randn(1, m, n)
        y = torch.randn(n, p)
        flop_dict, _ = flop_count(m_net, (x, y))
        gt_flop = m * n * p / 1e9
        gt_dict = defaultdict(float)
        gt_dict['matmul'] = gt_flop
        self.assertDictEqual(
            flop_dict, gt_dict,
            'Matmul operation failed to pass the flop count test.')

        x = torch.randn(2, m, n)
        y = torch.randn(n, p)
        flop_dict, _ = flop_count(m_net, (x, y))
        gt_flop = 2 * m * n * p / 1e9
        gt_dict = defaultdict(float)
        gt_dict['matmul'] = gt_flop
        self.assertDictEqual(
            flop_dict, gt_dict,
            'Matmul operation failed to pass the flop count test.')

    def test_bmm(self) -> None:
        """Test flop count for operation torch.bmm.

        The case checks torch.bmm with equation nct,ntp->ncp.
        """
        n = 2
        c = 5
        t = 2
        p = 12
        e_net = BMMNet()
        x = torch.randn(n, c, t)
        y = torch.randn(n, t, p)
        flop_dict, _ = flop_count(e_net, (x, y))
        gt_flop = n * t * p * c / 1e9
        gt_dict = defaultdict(float)
        gt_dict['bmm'] = gt_flop
        self.assertDictEqual(
            flop_dict,
            gt_dict,
            'bmm operation nct,ntp->ncp failed to pass the flop count test.',
        )

    def test_einsum(self) -> None:
        """Test flop count for operation torch.einsum.

        The first case checks torch.einsum with equation nct,ncp->ntp. The
        second case checks torch.einsum with equation "ntg,ncg->nct".
        """
        equation = 'nct,ncp->ntp'
        n = 1
        c = 5
        t = 2
        p = 12
        e_net = EinsumNet(equation)
        x = torch.randn(n, c, t)
        y = torch.randn(n, c, p)
        flop_dict, _ = flop_count(e_net, (x, y))
        gt_flop = n * t * p * c / 1e9
        gt_dict = defaultdict(float)
        gt_dict['einsum'] = gt_flop
        self.assertDictEqual(
            flop_dict,
            gt_dict,
            'Einsum operation nct,ncp->ntp failed to pass flop count test.',
        )

        equation = 'ntg,ncg->nct'
        g = 6
        e_net = EinsumNet(equation)
        x = torch.randn(n, t, g)
        y = torch.randn(n, c, g)
        flop_dict, _ = flop_count(e_net, (x, y))
        gt_flop = n * t * g * c / 1e9
        gt_dict = defaultdict(float)
        gt_dict['einsum'] = gt_flop
        self.assertDictEqual(
            flop_dict,
            gt_dict,
            'Einsum operation ntg,ncg->nct failed to pass flop count test.',
        )
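
    # A reading of the batch-norm ground truths below (inferred from the
    # expected values, not from the handle's source): one FLOP per element
    # is charged when running statistics are used (hence BatchNorm1d is put
    # in eval mode), and four per element when batch statistics have to be
    # computed, which is why the 2d and 3d cases use a factor of 4.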
    def test_batchnorm(self) -> None:
        """Test flop count for operation batchnorm.

        The test cases include BatchNorm1d, BatchNorm2d and BatchNorm3d.
        """
        # Test for BatchNorm1d.
        batch_size = 10
        input_dim = 10
        batch_1d = nn.BatchNorm1d(input_dim, affine=False).eval()
        x = torch.randn(batch_size, input_dim)
        flop_dict, _ = flop_count(batch_1d, (x, ))
        gt_flop = batch_size * input_dim / 1e9
        gt_dict = defaultdict(float)
        gt_dict['batch_norm'] = gt_flop
        self.assertDictEqual(
            flop_dict, gt_dict,
            'BatchNorm1d failed to pass the flop count test.')

        # Test for BatchNorm2d.
        batch_size = 10
        input_dim = 10
        spatial_dim_x = 5
        spatial_dim_y = 5
        batch_2d = nn.BatchNorm2d(input_dim, affine=False)
        x = torch.randn(batch_size, input_dim, spatial_dim_x, spatial_dim_y)
        flop_dict, _ = flop_count(batch_2d, (x, ))
        gt_flop = 4 * batch_size * input_dim * spatial_dim_x * \
            spatial_dim_y / 1e9
        gt_dict = defaultdict(float)
        gt_dict['batch_norm'] = gt_flop
        self.assertDictEqual(
            flop_dict, gt_dict,
            'BatchNorm2d failed to pass the flop count test.')

        # Test for BatchNorm3d.
        batch_size = 10
        input_dim = 10
        spatial_dim_x = 5
        spatial_dim_y = 5
        spatial_dim_z = 5
        batch_3d = nn.BatchNorm3d(input_dim, affine=False)
        x = torch.randn(batch_size, input_dim, spatial_dim_x, spatial_dim_y,
                        spatial_dim_z)
        flop_dict, _ = flop_count(batch_3d, (x, ))
        gt_flop = (4 * batch_size * input_dim * spatial_dim_x * spatial_dim_y *
                   spatial_dim_z / 1e9)
        gt_dict = defaultdict(float)
        gt_dict['batch_norm'] = gt_flop
        self.assertDictEqual(
            flop_dict, gt_dict,
            'BatchNorm3d failed to pass the flop count test.')

    def test_threeNet(self) -> None:
        """Test a network with more than one layer.

        The network has a convolution layer followed by two fully connected
        layers.
        """
        batch_size = 4
        input_dim = 2
        conv_dim = 5
        spatial_dim = 10
        linear_dim = 3
        x = torch.randn(batch_size, input_dim, spatial_dim, spatial_dim)
        three_net = ThreeNet(input_dim, conv_dim, linear_dim)
        flop1 = batch_size * conv_dim * input_dim * spatial_dim * \
            spatial_dim / 1e9
        flop_linear1 = batch_size * conv_dim * linear_dim / 1e9
        flop_linear2 = batch_size * linear_dim * 1 / 1e9
        flop2 = flop_linear1 + flop_linear2
        flop_dict, _ = flop_count(three_net, (x, ))
        gt_dict = defaultdict(float)
        gt_dict['conv'] = flop1
        gt_dict[self.lin_op] = flop2
        gt_dict['adaptive_avg_pool2d'] = 2e-6
        self.assertDictEqual(
            flop_dict,
            gt_dict,
            'The three-layer network failed to pass the flop count test.',
        )
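
    # FlopAnalyzer (exercised below) reports raw FLOP counts rather than
    # GFLOPs; ``by_module()`` keys them by submodule name, with the empty
    # string key holding the total for the whole model, as the test below
    # expects.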
    def test_flop_counter_class(self) -> None:
        """Test FlopAnalyzer."""
        batch_size = 4
        input_dim = 2
        conv_dim = 5
        spatial_dim = 10
        linear_dim = 3
        x = torch.randn(batch_size, input_dim, spatial_dim, spatial_dim)
        three_net = ThreeNet(input_dim, conv_dim, linear_dim)
        flop1 = batch_size * conv_dim * input_dim * spatial_dim * spatial_dim
        flop_linear1 = batch_size * conv_dim * linear_dim
        flop_linear2 = batch_size * linear_dim * 1
        flop_counter = FlopAnalyzer(three_net, (x, ))
        gt_dict = Counter({
            'conv': flop1,
            'linear1': flop_linear1,
            'linear2': flop_linear2,
            'pool': flop1 // input_dim,
        })
        gt_dict[''] = sum(gt_dict.values())
        self.assertEqual(flop_counter.by_module(), gt_dict)

    def test_autograd_function(self):
        # Test support for custom autograd functions.

        class Mod(nn.Module):

            def forward(self, x):
                return _CustomOp.apply(x)

        flop = FlopAnalyzer(Mod(), (torch.rand(4, 5), )).set_op_handle(
            'prim::PythonOp._CustomOp', lambda *args, **kwargs: 42)
        self.assertEqual(flop.total(), 42)

    def test_scripted_function(self):
        # Scripted functions are not yet supported. They should produce a
        # warning.

        def func(x):
            return x @ x

        class Mod(nn.Module):

            def forward(self, x):
                f = torch.jit.script(func)
                return f(x * x)

        flop = FlopAnalyzer(Mod(), (torch.rand(5, 5), ))
        _ = flop.total()
        self.assertIn('prim::CallFunction', flop.unsupported_ops())
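

# The handle tests below call the default per-operator counters from
# _DEFAULT_SUPPORTED_FLOP_OPS directly on traced JIT nodes, so the expected
# values are raw FLOP counts rather than the GFLOPs reported by flop_count
# above.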
class TestFlopCountHandles(unittest.TestCase):

    def _count_function(self, func, inputs, name) -> Tuple[Any, Any]:
        tensor_inputs = [x for x in inputs if isinstance(x, torch.Tensor)]

        def f(*args):
            return func(*inputs)

        graph = torch.jit.trace(
            f, tuple(tensor_inputs), check_trace=False).graph
        nodes = [k for k in graph.nodes() if k.kind() == name]
        self.assertEqual(len(nodes), 1)
        node = nodes[0]
        return list(node.inputs()), list(node.outputs())

    def test_batch_norm(self):
        op_name = 'aten::batch_norm'
        counter = _DEFAULT_SUPPORTED_FLOP_OPS[op_name]

        vec = torch.rand(2)
        nodes = self._count_function(
            F.batch_norm, (torch.rand(2, 2, 2, 2), vec, vec, vec, vec),
            op_name)
        self.assertEqual(counter(*nodes), 32)

        nodes = self._count_function(
            F.batch_norm,
            (torch.rand(2, 2, 2, 2), vec, vec, None, None),
            op_name,
        )
        self.assertEqual(counter(*nodes), 16)

        nodes = self._count_function(
            # training=True
            F.batch_norm,
            (torch.rand(2, 2, 2, 2), vec, vec, vec, vec, True),
            op_name,
        )
        self.assertEqual(counter(*nodes), 80)

    def test_group_norm(self):
        op_name = 'aten::group_norm'
        counter = _DEFAULT_SUPPORTED_FLOP_OPS[op_name]

        vec = torch.rand(2)
        nodes = self._count_function(F.group_norm,
                                     (torch.rand(2, 2, 2, 2), 2, vec, vec),
                                     op_name)
        self.assertEqual(counter(*nodes), 80)

        nodes = self._count_function(F.group_norm,
                                     (torch.rand(2, 2, 2, 2), 2, None, None),
                                     op_name)
        self.assertEqual(counter(*nodes), 64)

    def test_upsample(self):
        op_name = 'aten::upsample_bilinear2d'
        counter = _DEFAULT_SUPPORTED_FLOP_OPS[op_name]

        nodes = self._count_function(
            F.interpolate,
            (torch.rand(2, 2, 2, 2), None, 2, 'bilinear', False), op_name)
        self.assertEqual(counter(*nodes), 2**4 * 4 * 4)

    def test_complicated_einsum(self):
        op_name = 'aten::einsum'
        counter = _DEFAULT_SUPPORTED_FLOP_OPS[op_name]

        nodes = self._count_function(
            torch.einsum,
            ('nc,nchw->hw', torch.rand(3, 4), torch.rand(3, 4, 2, 3)),
            op_name,
        )
        self.assertEqual(counter(*nodes), 72.0)

    def test_torch_mm(self):
        for op_name, func in zip(['aten::mm', 'aten::matmul'],
                                 [torch.mm, torch.matmul]):
            counter = _DEFAULT_SUPPORTED_FLOP_OPS[op_name]

            nodes = self._count_function(
                func,
                (torch.rand(3, 4), torch.rand(4, 5)),
                op_name,
            )
            self.assertEqual(counter(*nodes), 60)