mmengine/tests/test_analysis/test_param_count.py

52 lines
1.6 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# Modified from
# https://github.com/facebookresearch/fvcore/blob/main/tests/test_param_count.py
import unittest
from torch import nn
from mmengine.analysis.complexity_analysis import (parameter_count,
parameter_count_table)
class NetWithReuse(nn.Module):
    """Toy module with two identical convs that may share one weight.

    Args:
        reuse (bool): If True, ``conv2.weight`` is rebound to the very same
            parameter object as ``conv1.weight``. Defaults to False.
    """

    def __init__(self, reuse: bool = False) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels=100, out_channels=100, kernel_size=3)
        self.conv2 = nn.Conv2d(
            in_channels=100, out_channels=100, kernel_size=3)
        if not reuse:
            return
        # Tie the weights: both layers now reference a single Parameter,
        # while each layer keeps its own bias.
        self.conv2.weight = self.conv1.weight
class NetWithDupPrefix(nn.Module):
    """Toy module whose layer names overlap as strings.

    The attribute name ``conv1`` is a string prefix of ``conv111``; both
    layers are independent ``Conv2d(100, 100, 3)`` instances.
    """

    def __init__(self) -> None:
        super().__init__()
        # Register in the same order as plain attribute assignment would.
        for layer_name in ('conv1', 'conv111'):
            setattr(self, layer_name, nn.Conv2d(100, 100, 3))
class TestParamCount(unittest.TestCase):
    """Tests for ``parameter_count`` and ``parameter_count_table``."""

    def test_param(self) -> None:
        """Two independent convs are both counted in full."""
        net = NetWithReuse()
        count = parameter_count(net)
        # Each Conv2d(100, 100, 3) has 100 * 100 * 3 * 3 weights plus
        # 100 biases = 90100 parameters.
        # NOTE: assertEqual, not assertTrue -- assertTrue(x, msg) would
        # treat the expected number as the failure message and only
        # check that the count is truthy.
        self.assertEqual(count[''], 180200)
        self.assertEqual(count['conv2'], 90100)

    def test_param_with_reuse(self) -> None:
        """A weight shared between two convs is counted only once."""
        net = NetWithReuse(reuse=True)
        count = parameter_count(net)
        # conv1 contributes weight + bias (90100); conv2 contributes only
        # its own bias (100) because its weight is conv1's tensor.
        self.assertEqual(count[''], 90200)
        self.assertEqual(count['conv2'], 100)

    def test_param_with_same_prefix(self) -> None:
        """Layer names sharing a string prefix get separate table rows."""
        net = NetWithDupPrefix()
        table = parameter_count_table(net)
        hits = sum('conv111.weight' in line for line in table.split('\n'))
        # 'conv111.weight' must be listed exactly once, even though
        # 'conv1' is a string prefix of 'conv111'.
        self.assertEqual(hits, 1)