150 lines
5.1 KiB
Python
150 lines
5.1 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
|
# Modified from
|
|
# https://github.com/facebookresearch/fvcore/blob/main/tests/test_activation_count.py
|
|
# pyre-ignore-all-errors[2]
|
|
import typing
|
|
import unittest
|
|
from collections import Counter, defaultdict
|
|
from typing import Any, Dict, List, Tuple
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from numpy import prod
|
|
|
|
from mmengine.analysis import ActivationAnalyzer, activation_count
|
|
from mmengine.analysis.jit_handles import Handle
|
|
|
|
|
|
class SmallConvNet(nn.Module):
    """Toy model built from three chained 1x1 convolutions.

    It is a fixture for checking the activation count of convolution
    layers.

    Args:
        input_dim (int): Number of channels of the input tensor.
    """

    def __init__(self, input_dim: int) -> None:
        super().__init__()
        # Output channels of conv1, conv2 and conv3 respectively.
        widths = (8, 4, 2)
        # Every kernel is 1x1; conv1 has stride 1, conv2 and conv3 stride 2.
        self.conv1 = nn.Conv2d(input_dim, widths[0], 1, 1)
        self.conv2 = nn.Conv2d(widths[0], widths[1], 1, 2)
        self.conv3 = nn.Conv2d(widths[1], widths[2], 1, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``conv1``, ``conv2`` and ``conv3`` to ``x`` in that order.

        Args:
            x (torch.Tensor): Input fed to ``conv1``.

        Returns:
            torch.Tensor: Output of ``conv3``.
        """
        return self.conv3(self.conv2(self.conv1(x)))

    def get_gt_activation(self, x: torch.Tensor) -> Tuple[int, int, int]:
        """Compute the reference activation count of each conv layer.

        The count of a layer is the number of elements of its output
        tensor, i.e. the product of the output's dimensions.

        Args:
            x (torch.Tensor): Input fed to ``conv1``.

        Returns:
            Tuple[int, int, int]: Element counts of the outputs of
            ``conv1``, ``conv2`` and ``conv3``.
        """
        feat1 = self.conv1(x)
        feat2 = self.conv2(feat1)
        feat3 = self.conv3(feat2)
        return (
            prod(list(feat1.size())),
            prod(list(feat2.size())),
            prod(list(feat3.size())),
        )
|
|
|
|
|
|
class TestActivationAnalyzer(unittest.TestCase):
    """Unittest for activation_count."""

    def setUp(self) -> None:
        """Detect which operator ``nn.Linear`` is traced to.

        nn.Linear uses a different operator based on version, so make sure
        we are testing the right thing. The short operator name
        (``'addmm'`` or ``'linear'``) is stored in ``self.lin_op``.
        """
        lin = nn.Linear(10, 10)
        lin_x: torch.Tensor = torch.randn(10, 10)
        trace = torch.jit.trace(lin, (lin_x, ))
        node_kinds = [node.kind() for node in trace.graph.nodes()]
        # A unittest assertion is used instead of a bare ``assert`` so the
        # check is not stripped when Python runs with ``-O``.
        self.assertTrue(
            'aten::addmm' in node_kinds or 'aten::linear' in node_kinds,
            'nn.Linear traced to neither aten::addmm nor aten::linear.',
        )
        self.lin_op = 'addmm' if 'aten::addmm' in node_kinds else 'linear'

    @staticmethod
    def _make_conv_case() -> Tuple[SmallConvNet, torch.Tensor]:
        """Build the conv model and the random input shared by the tests.

        Returns:
            Tuple[SmallConvNet, torch.Tensor]: A ``SmallConvNet`` with 3
            input channels and a random ``(1, 3, 32, 32)`` input tensor.
        """
        batch_size = 1
        input_dim = 3
        spatial_dim = 32
        # The input is drawn before the model is built, as in the
        # original tests, so the order of RNG use is unchanged.
        x = torch.randn(batch_size, input_dim, spatial_dim, spatial_dim)
        return SmallConvNet(input_dim), x

    def test_conv2d(self) -> None:
        """Test the activation count for convolutions."""
        conv_net, x = self._make_conv_case()
        ac_dict, _ = activation_count(conv_net, (x, ))
        gt_count = sum(conv_net.get_gt_activation(x))

        # The expected value is the raw element count divided by 1e6.
        gt_dict = defaultdict(float)
        gt_dict['conv'] = gt_count / 1e6
        self.assertDictEqual(
            gt_dict,
            ac_dict,
            'conv_net with 3 layers failed to pass the activation count test.',
        )

    def test_linear(self) -> None:
        """Test the activation count for fully connected layer."""
        batch_size = 1
        input_dim = 10
        output_dim = 20
        linear = nn.Linear(input_dim, output_dim)
        x = torch.randn(batch_size, input_dim)
        ac_dict, _ = activation_count(linear, (x, ))

        # One activation per output element, keyed by the operator name
        # detected in ``setUp``.
        gt_count = batch_size * output_dim
        gt_dict = defaultdict(float)
        gt_dict[self.lin_op] = gt_count / 1e6
        self.assertDictEqual(
            gt_dict,
            ac_dict,
            'FC layer failed to pass the activation count test.',
        )

    def test_supported_ops(self) -> None:
        """Test the activation count for user provided handles."""

        def dummy_handle(inputs: List[Any],
                         outputs: List[Any]) -> typing.Counter[str]:
            """Report a fixed count of 100 under 'conv' for every call."""
            return Counter({'conv': 100})

        conv_net, x = self._make_conv_case()
        sp_ops: Dict[str, Handle] = {'aten::_convolution': dummy_handle}
        ac_dict, _ = activation_count(conv_net, (x, ), sp_ops)

        # The custom handle is expected to fire once per conv layer.
        conv_layers = 3
        gt_dict = defaultdict(float)
        gt_dict['conv'] = 100 * conv_layers / 1e6
        self.assertDictEqual(
            gt_dict,
            ac_dict,
            'conv_net with 3 layers failed to pass the activation count test.',
        )

    def test_activation_count_class(self) -> None:
        """Tests ActivationAnalyzer."""
        # Case 1: a single linear layer. The only expected entry is the
        # '' key, holding the count of the whole model.
        batch_size = 1
        input_dim = 10
        output_dim = 20
        net_linear = nn.Linear(input_dim, output_dim)
        x = torch.randn(batch_size, input_dim)
        gt_dict = Counter({
            '': batch_size * output_dim,
        })
        acts_counter = ActivationAnalyzer(net_linear, (x, ))
        self.assertDictEqual(gt_dict, acts_counter.by_module())

        # Case 2: the three-layer conv net. Expect one entry per conv
        # submodule plus the '' key holding their sum.
        conv_net, x = self._make_conv_case()
        acts_counter = ActivationAnalyzer(conv_net, (x, ))
        gt_counts = conv_net.get_gt_activation(x)
        gt_dict = Counter({
            '': sum(gt_counts),
            'conv1': gt_counts[0],
            'conv2': gt_counts[1],
            'conv3': gt_counts[2],
        })

        self.assertDictEqual(gt_dict, acts_counter.by_module())
|