mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
* [Feature] Add support model complexity computation * [Fix] fix lint error * [Feature] update print_helper * Update docstring * update api, docs, fix lint * fix lint * update doc and add test * update docstring * update docstring * update test * Update mmengine/analysis/print_helper.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmengine/analysis/print_helper.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmengine/analysis/print_helper.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmengine/analysis/print_helper.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmengine/analysis/print_helper.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmengine/analysis/print_helper.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmengine/analysis/print_helper.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmengine/analysis/print_helper.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmengine/analysis/print_helper.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmengine/analysis/print_helper.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmengine/analysis/print_helper.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmengine/analysis/print_helper.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmengine/analysis/print_helper.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmengine/analysis/print_helper.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update mmengine/analysis/complexity_analysis.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * 
Update docs/en/advanced_tutorials/model_analysis.md Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update docs/en/advanced_tutorials/model_analysis.md Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * update docs * update docs * update docs and docstring * update docs * update test withj mmlogger * Update mmengine/analysis/complexity_analysis.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update tests/test_analysis/test_activation_count.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Apply suggestions from code review Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * update test according to review * Apply suggestions from code review Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * fix lint * fix test * Apply suggestions from code review * fix API document * Update analysis.rst * rename variables * minor refinement * Apply suggestions from code review * fix lint * replace tabulate with existing rich * Apply suggestions from code review * indent * Update mmengine/analysis/complexity_analysis.py * Update mmengine/analysis/complexity_analysis.py * Update mmengine/analysis/complexity_analysis.py --------- Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> Co-authored-by: zhouzaida <zhouzaida@163.com>
52 lines
1.6 KiB
Python
52 lines
1.6 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
|
|
# Modified from
|
|
# https://github.com/facebookresearch/fvcore/blob/main/tests/test_param_count.py
|
|
|
|
import unittest
|
|
|
|
from torch import nn
|
|
|
|
from mmengine.analysis.complexity_analysis import (parameter_count,
|
|
parameter_count_table)
|
|
|
|
|
|
class NetWithReuse(nn.Module):
    """Toy module holding two ``Conv2d(100, 100, 3)`` layers.

    Args:
        reuse (bool): If True, ``conv2.weight`` is rebound to the very same
            parameter object as ``conv1.weight``, so the two layers share
            one weight tensor. Defaults to False.
    """

    def __init__(self, reuse: bool = False) -> None:
        super().__init__()
        # Build both layers first (conv1, then conv2) and register them in
        # that order.
        first, second = (nn.Conv2d(100, 100, 3) for _ in range(2))
        self.conv1 = first
        self.conv2 = second
        if not reuse:
            return
        # Tie the weights: conv2 keeps its own bias but borrows conv1's
        # weight parameter.
        self.conv2.weight = self.conv1.weight
|
|
|
|
|
|
class NetWithDupPrefix(nn.Module):
    """Toy module whose child names share a prefix.

    It registers two ``Conv2d(100, 100, 3)`` layers named ``conv1`` and
    ``conv111``; the first name is a string prefix of the second.
    """

    def __init__(self) -> None:
        super().__init__()
        # ``setattr`` goes through ``nn.Module.__setattr__`` exactly like a
        # plain attribute assignment, so both layers get registered.
        for layer_name in ('conv1', 'conv111'):
            setattr(self, layer_name, nn.Conv2d(100, 100, 3))
|
|
|
|
|
|
class TestParamCount(unittest.TestCase):
    """Unit tests for ``parameter_count`` and ``parameter_count_table``."""

    def test_param(self) -> None:
        """Two independent conv layers are both counted in full."""
        net = NetWithReuse()
        count = parameter_count(net)
        # Each Conv2d(100, 100, 3) has 100 * 100 * 3 * 3 weights plus
        # 100 biases = 90100 parameters.
        # NOTE: ``assertTrue(value, expected)`` would treat ``expected`` as
        # the failure message and only check truthiness, so the values must
        # be compared with ``assertEqual``.
        self.assertEqual(count[''], 180200)
        self.assertEqual(count['conv2'], 90100)

    def test_param_with_reuse(self) -> None:
        """A weight shared between layers is expected to be counted once."""
        net = NetWithReuse(reuse=True)
        count = parameter_count(net)
        # Expected: conv1 contributes 90100, conv2 only its 100 biases
        # because its weight is the same parameter as conv1's.
        self.assertEqual(count[''], 90200)
        self.assertEqual(count['conv2'], 100)

    def test_param_with_same_prefix(self) -> None:
        """``conv111.weight`` is listed exactly once in the table."""
        net = NetWithDupPrefix()
        table = parameter_count_table(net)
        hits = ['conv111.weight' in line for line in table.split('\n')]
        # It appears only once, even though 'conv1' is a prefix of
        # 'conv111'.
        self.assertEqual(sum(hits), 1)
|