mmengine/tests/test_analysis/test_activation_count.py

150 lines
5.1 KiB
Python
Raw Normal View History

[Feature] Support model complexity computation (#779) * [Feature] Add support model complexity computation * [Fix] fix lint error * [Feature] update print_helper * Update docstring * update api, docs, fix lint * fix lint * update doc and add test * update docstring * update docstring * update test * Update mmengine/analysis/print_helper.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmengine/analysis/print_helper.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmengine/analysis/print_helper.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmengine/analysis/print_helper.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmengine/analysis/print_helper.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmengine/analysis/print_helper.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmengine/analysis/print_helper.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmengine/analysis/print_helper.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmengine/analysis/print_helper.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmengine/analysis/print_helper.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmengine/analysis/print_helper.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmengine/analysis/print_helper.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmengine/analysis/print_helper.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmengine/analysis/print_helper.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmengine/analysis/complexity_analysis.py Co-authored-by: Zaida Zhou 
<58739961+zhouzaida@users.noreply.github.com> * Update docs/en/advanced_tutorials/model_analysis.md Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update docs/en/advanced_tutorials/model_analysis.md Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * update docs * update docs * update docs and docstring * update docs * update test with mmlogger * Update mmengine/analysis/complexity_analysis.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update tests/test_analysis/test_activation_count.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Apply suggestions from code review Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * update test according to review * Apply suggestions from code review Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * fix lint * fix test * Apply suggestions from code review * fix API document * Update analysis.rst * rename variables * minor refinement * Apply suggestions from code review * fix lint * replace tabulate with existing rich * Apply suggestions from code review * indent * Update mmengine/analysis/complexity_analysis.py * Update mmengine/analysis/complexity_analysis.py * Update mmengine/analysis/complexity_analysis.py --------- Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> Co-authored-by: zhouzaida <zhouzaida@163.com>
2023-02-20 15:00:28 +08:00
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# Modified from
# https://github.com/facebookresearch/fvcore/blob/main/tests/test_activation_count.py
# pyre-ignore-all-errors[2]
import typing
import unittest
from collections import Counter, defaultdict
from typing import Any, Dict, List, Tuple
import torch
import torch.nn as nn
from numpy import prod
from mmengine.analysis import ActivationAnalyzer, activation_count
from mmengine.analysis.jit_handles import Handle
class SmallConvNet(nn.Module):
    """A network with three stacked 1x1 conv layers.

    This is used for testing convolution layers for activation count.

    Args:
        input_dim (int): Number of channels of the input tensor.
    """

    def __init__(self, input_dim: int) -> None:
        super().__init__()
        conv_dim1 = 8
        conv_dim2 = 4
        conv_dim3 = 2
        # All kernels are 1x1; conv2 and conv3 use stride 2, so each of them
        # shrinks the spatial size of its input.
        self.conv1 = nn.Conv2d(input_dim, conv_dim1, 1, 1)
        self.conv2 = nn.Conv2d(conv_dim1, conv_dim2, 1, 2)
        self.conv3 = nn.Conv2d(conv_dim2, conv_dim3, 1, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``conv1``, ``conv2`` and ``conv3`` in sequence."""
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        return x

    def get_gt_activation(self, x: torch.Tensor) -> Tuple[int, int, int]:
        """Compute the ground-truth activation count of each conv layer.

        The activation count of a layer is taken to be the number of
        elements in its output tensor.

        Args:
            x (torch.Tensor): Input tensor fed to ``conv1``.

        Returns:
            Tuple[int, int, int]: Output element counts of ``conv1``,
            ``conv2`` and ``conv3`` respectively, as plain Python ints.
        """
        # Only output shapes are needed, so skip building an autograd graph.
        with torch.no_grad():
            x = self.conv1(x)
            # ``prod`` yields a numpy scalar; cast so the values match the
            # annotated ``Tuple[int, int, int]`` return type.
            count1 = int(prod(list(x.size())))
            x = self.conv2(x)
            count2 = int(prod(list(x.size())))
            x = self.conv3(x)
            count3 = int(prod(list(x.size())))
        return count1, count2, count3
class TestActivationAnalyzer(unittest.TestCase):
    """Unittest for activation_count."""

    def setUp(self) -> None:
        """Detect which aten op ``nn.Linear`` is traced to on this torch."""
        # nn.Linear uses a different operator based on version, so make sure
        # we are testing the right thing.
        lin = nn.Linear(10, 10)
        lin_x: torch.Tensor = torch.randn(10, 10)
        trace = torch.jit.trace(lin, (lin_x, ))
        node_kinds = [node.kind() for node in trace.graph.nodes()]
        # A TestCase assertion (unlike a bare ``assert``) is not stripped
        # when Python runs with ``-O``.
        self.assertTrue(
            'aten::addmm' in node_kinds or 'aten::linear' in node_kinds,
            f'nn.Linear was traced to unexpected ops: {node_kinds}')
        self.lin_op = 'addmm' if 'aten::addmm' in node_kinds else 'linear'

    @staticmethod
    def _build_conv_case() -> Tuple['SmallConvNet', torch.Tensor]:
        """Build the ``SmallConvNet`` and input shared by the conv tests.

        Returns:
            Tuple[SmallConvNet, torch.Tensor]: The network and an input of
            shape ``(1, 3, 32, 32)``.
        """
        batch_size = 1
        input_dim = 3
        spatial_dim = 32
        # The input is drawn before the network is built, the same RNG order
        # the tests have always used.
        x = torch.randn(batch_size, input_dim, spatial_dim, spatial_dim)
        return SmallConvNet(input_dim), x

    def test_conv2d(self) -> None:
        """Test the activation count for convolutions."""
        conv_net, x = self._build_conv_case()
        ac_dict, _ = activation_count(conv_net, (x, ))
        gt_count = sum(conv_net.get_gt_activation(x))
        gt_dict = defaultdict(float)
        # activation_count is compared against counts expressed in millions.
        gt_dict['conv'] = gt_count / 1e6
        self.assertDictEqual(
            gt_dict,
            ac_dict,
            'conv_net with 3 layers failed to pass the activation count test.',
        )

    def test_linear(self) -> None:
        """Test the activation count for fully connected layer."""
        batch_size = 1
        input_dim = 10
        output_dim = 20
        linear = nn.Linear(input_dim, output_dim)
        x = torch.randn(batch_size, input_dim)
        ac_dict, _ = activation_count(linear, (x, ))
        gt_count = batch_size * output_dim
        gt_dict = defaultdict(float)
        gt_dict[self.lin_op] = gt_count / 1e6
        self.assertDictEqual(
            gt_dict, ac_dict,
            'FC layer failed to pass the activation count test.')

    def test_supported_ops(self) -> None:
        """Test the activation count for user provided handles."""

        def dummy_handle(inputs: List[Any],
                         outputs: List[Any]) -> typing.Counter[str]:
            # Report a fixed count for every matched node.
            return Counter({'conv': 100})

        conv_net, x = self._build_conv_case()
        sp_ops: Dict[str, Handle] = {'aten::_convolution': dummy_handle}
        ac_dict, _ = activation_count(conv_net, (x, ), sp_ops)
        gt_dict = defaultdict(float)
        # SmallConvNet has three conv layers, each handled by dummy_handle.
        conv_layers = 3
        gt_dict['conv'] = 100 * conv_layers / 1e6
        self.assertDictEqual(
            gt_dict,
            ac_dict,
            'conv_net with 3 layers failed to pass the activation count test.',
        )

    def test_activation_count_class(self) -> None:
        """Tests ActivationAnalyzer."""
        # A single Linear layer: only the root module ('') is reported.
        batch_size = 1
        input_dim = 10
        output_dim = 20
        linear = nn.Linear(input_dim, output_dim)
        x = torch.randn(batch_size, input_dim)
        gt_count = batch_size * output_dim
        gt_dict = Counter({
            '': gt_count,
        })
        acts_counter = ActivationAnalyzer(linear, (x, ))
        self.assertDictEqual(gt_dict, acts_counter.by_module())

        # SmallConvNet: the root total plus one entry per conv submodule.
        conv_net, x = self._build_conv_case()
        acts_counter = ActivationAnalyzer(conv_net, (x, ))
        gt_counts = conv_net.get_gt_activation(x)
        gt_dict = Counter({
            '': sum(gt_counts),
            'conv1': gt_counts[0],
            'conv2': gt_counts[1],
            'conv3': gt_counts[2],
        })
        self.assertDictEqual(gt_dict, acts_counter.by_module())