mmengine/tests/test_analysis/test_activation_count.py

150 lines
5.1 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# Modified from
# https://github.com/facebookresearch/fvcore/blob/main/tests/test_activation_count.py
# pyre-ignore-all-errors[2]
import typing
import unittest
from collections import Counter, defaultdict
from typing import Any, Dict, List, Tuple
import torch
import torch.nn as nn
from numpy import prod
from mmengine.analysis import ActivationAnalyzer, activation_count
from mmengine.analysis.jit_handles import Handle
class SmallConvNet(nn.Module):
"""A network with three conv layers.
This is used for testing convolution layers for activation count.
"""
def __init__(self, input_dim: int) -> None:
super().__init__()
conv_dim1 = 8
conv_dim2 = 4
conv_dim3 = 2
self.conv1 = nn.Conv2d(input_dim, conv_dim1, 1, 1)
self.conv2 = nn.Conv2d(conv_dim1, conv_dim2, 1, 2)
self.conv3 = nn.Conv2d(conv_dim2, conv_dim3, 1, 2)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.conv1(x)
x = self.conv2(x)
x = self.conv3(x)
return x
def get_gt_activation(self, x: torch.Tensor) -> Tuple[int, int, int]:
x = self.conv1(x)
count1 = prod(list(x.size()))
x = self.conv2(x)
count2 = prod(list(x.size()))
x = self.conv3(x)
count3 = prod(list(x.size()))
return count1, count2, count3
class TestActivationAnalyzer(unittest.TestCase):
"""Unittest for activation_count."""
def setUp(self) -> None:
# nn.Linear uses a different operator based on version, so make sure
# we are testing the right thing.
lin = nn.Linear(10, 10)
lin_x: torch.Tensor = torch.randn(10, 10)
trace = torch.jit.trace(lin, (lin_x, ))
node_kinds = [node.kind() for node in trace.graph.nodes()]
assert 'aten::addmm' in node_kinds or 'aten::linear' in node_kinds
if 'aten::addmm' in node_kinds:
self.lin_op = 'addmm'
else:
self.lin_op = 'linear'
def test_conv2d(self) -> None:
"""Test the activation count for convolutions."""
batch_size = 1
input_dim = 3
spatial_dim = 32
x = torch.randn(batch_size, input_dim, spatial_dim, spatial_dim)
conv_net = SmallConvNet(input_dim)
ac_dict, _ = activation_count(conv_net, (x, ))
gt_count = sum(conv_net.get_gt_activation(x))
gt_dict = defaultdict(float)
gt_dict['conv'] = gt_count / 1e6
self.assertDictEqual(
gt_dict,
ac_dict,
'conv_net with 3 layers failed to pass the activation count test.',
)
def test_linear(self) -> None:
"""Test the activation count for fully connected layer."""
batch_size = 1
input_dim = 10
output_dim = 20
linear = nn.Linear(input_dim, output_dim)
x = torch.randn(batch_size, input_dim)
ac_dict, _ = activation_count(linear, (x, ))
gt_count = batch_size * output_dim
gt_dict = defaultdict(float)
gt_dict[self.lin_op] = gt_count / 1e6
self.assertEqual(gt_dict, ac_dict,
'FC layer failed to pass the activation count test.')
def test_supported_ops(self) -> None:
"""Test the activation count for user provided handles."""
def dummy_handle(inputs: List[Any],
outputs: List[Any]) -> typing.Counter[str]:
return Counter({'conv': 100})
batch_size = 1
input_dim = 3
spatial_dim = 32
x = torch.randn(batch_size, input_dim, spatial_dim, spatial_dim)
conv_net = SmallConvNet(input_dim)
sp_ops: Dict[str, Handle] = {'aten::_convolution': dummy_handle}
ac_dict, _ = activation_count(conv_net, (x, ), sp_ops)
gt_dict = defaultdict(float)
conv_layers = 3
gt_dict['conv'] = 100 * conv_layers / 1e6
self.assertDictEqual(
gt_dict,
ac_dict,
'conv_net with 3 layers failed to pass the activation count test.',
)
def test_activation_count_class(self) -> None:
"""Tests ActivationAnalyzer."""
batch_size = 1
input_dim = 10
output_dim = 20
netLinear = nn.Linear(input_dim, output_dim)
x = torch.randn(batch_size, input_dim)
gt_count = batch_size * output_dim
gt_dict = Counter({
'': gt_count,
})
acts_counter = ActivationAnalyzer(netLinear, (x, ))
self.assertEqual(acts_counter.by_module(), gt_dict)
batch_size = 1
input_dim = 3
spatial_dim = 32
x = torch.randn(batch_size, input_dim, spatial_dim, spatial_dim)
conv_net = SmallConvNet(input_dim)
acts_counter = ActivationAnalyzer(conv_net, (x, ))
gt_counts = conv_net.get_gt_activation(x)
gt_dict = Counter({
'': sum(gt_counts),
'conv1': gt_counts[0],
'conv2': gt_counts[1],
'conv3': gt_counts[2],
})
self.assertDictEqual(gt_dict, acts_counter.by_module())