mmengine/tests/test_analysis/test_param_count.py
[Feature] Support model complexity computation (#779)
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
Co-authored-by: zhouzaida <zhouzaida@163.com>
2023-02-20 15:00:28 +08:00
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# Modified from
# https://github.com/facebookresearch/fvcore/blob/main/tests/test_param_count.py
import unittest

from torch import nn

from mmengine.analysis.complexity_analysis import (parameter_count,
                                                   parameter_count_table)


class NetWithReuse(nn.Module):

    def __init__(self, reuse: bool = False) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(100, 100, 3)
        self.conv2 = nn.Conv2d(100, 100, 3)
        if reuse:
            # Tie conv2's weight to conv1's so both modules share one tensor.
            self.conv2.weight = self.conv1.weight


class NetWithDupPrefix(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        # 'conv1' is a string prefix of 'conv111'; the two must still be
        # reported as separate modules.
        self.conv1 = nn.Conv2d(100, 100, 3)
        self.conv111 = nn.Conv2d(100, 100, 3)


class TestParamCount(unittest.TestCase):

    def test_param(self) -> None:
        net = NetWithReuse()
        count = parameter_count(net)
        # Each Conv2d(100, 100, 3) has 100 * 100 * 3 * 3 + 100 = 90100 params.
        self.assertEqual(count[''], 180200)
        self.assertEqual(count['conv2'], 90100)

    def test_param_with_reuse(self) -> None:
        net = NetWithReuse(reuse=True)
        count = parameter_count(net)
        # The shared weight is counted once, under conv1, so conv2 keeps
        # only its 100-element bias.
        self.assertEqual(count[''], 90200)
        self.assertEqual(count['conv2'], 100)

    def test_param_with_same_prefix(self) -> None:
        net = NetWithDupPrefix()
        table = parameter_count_table(net)
        c = ['conv111.weight' in line for line in table.split('\n')]
        # It appears only once, even though conv1 is a prefix of conv111.
        self.assertEqual(sum(c), 1)
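The expected counts in these tests follow from simple arithmetic: a `Conv2d(100, 100, 3)` holds 100 * 100 * 3 * 3 = 90000 weights plus 100 biases. The sketch below reproduces the asserted values in plain Python, assuming that `parameter_count` adds each parameter's size to every dot-separated prefix of its name, and that a tied weight is visited only once (as `nn.Module.named_parameters()` does by default). The helper `aggregate_by_prefix` is hypothetical and is not mmengine's implementation.

```python
from collections import defaultdict
from typing import Dict, Iterable, Tuple


def aggregate_by_prefix(params: Iterable[Tuple[str, int]]) -> Dict[str, int]:
    """Hypothetical helper: add each size to every dotted prefix of its name.

    The empty prefix '' collects the model total. Splitting on '.' (rather
    than comparing raw strings) keeps 'conv1' and 'conv111' separate.
    """
    counts: Dict[str, int] = defaultdict(int)
    for name, size in params:
        parts = name.split('.')
        for k in range(len(parts) + 1):
            counts['.'.join(parts[:k])] += size
    return dict(counts)


W = 100 * 100 * 3 * 3  # weights in one Conv2d(100, 100, 3)
B = 100                # one bias per output channel

no_reuse = aggregate_by_prefix([
    ('conv1.weight', W), ('conv1.bias', B),
    ('conv2.weight', W), ('conv2.bias', B),
])

# Assumption: with weight tying, the shared tensor is reported once under
# 'conv1.weight', so conv2 contributes only its bias.
reuse = aggregate_by_prefix([
    ('conv1.weight', W), ('conv1.bias', B),
    ('conv2.bias', B),
])

dup_prefix = aggregate_by_prefix([
    ('conv1.weight', W), ('conv1.bias', B),
    ('conv111.weight', W), ('conv111.bias', B),
])

print(no_reuse[''], no_reuse['conv2'])             # 180200 90100
print(reuse[''], reuse['conv2'])                   # 90200 100
print(dup_prefix['conv1'], dup_prefix['conv111'])  # 90100 90100
```

The last line shows why the prefix test matters: a naive `startswith('conv1')` check would have folded conv111's 90100 parameters into conv1's total.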