diff --git a/mmrazor/models/task_modules/estimators/counters/flops_params_counter.py b/mmrazor/models/task_modules/estimators/counters/flops_params_counter.py index f3120824..df0c867c 100644 --- a/mmrazor/models/task_modules/estimators/counters/flops_params_counter.py +++ b/mmrazor/models/task_modules/estimators/counters/flops_params_counter.py @@ -1,8 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. import sys from functools import partial -from typing import Dict +from typing import Dict, List +import mmcv import torch import torch.nn as nn @@ -497,9 +498,29 @@ def add_flops_params_counter_variable_or_reset(module): module.__params__ = 0 -def get_counter_type(module): - """Get counter type of the module based on the module class name.""" - return module.__class__.__name__ + 'Counter' +def get_counter_type(module) -> str: + """Get counter type of the module based on the module class name. + + If the current module counter_type is not in TASK_UTILS._module_dict, + it will search the base classes of the module to see if it matches any + base class counter_type. + + Returns: + str: Counter type (or the base counter type) of the current module. + """ + counter_type = module.__class__.__name__ + 'Counter' + if counter_type not in TASK_UTILS._module_dict.keys(): + old_counter_type = counter_type + assert nn.Module in module.__class__.mro() + for base_cls in module.__class__.mro(): + if base_cls in get_modules_list(): + counter_type = base_cls.__name__ + 'Counter' + from mmengine import MMLogger + logger = MMLogger.get_current_instance() + logger.warning(f'`{old_counter_type}` not in op_counters. ' + f'Using `{counter_type}` instead.') + break + return counter_type def is_supported_instance(module): @@ -518,3 +539,54 @@ def remove_flops_params_counter_hook_function(module): del module.__flops__ if hasattr(module, '__params__'): del module.__params__ + + +def get_modules_list() -> List: + return [ + # convolutions + nn.Conv1d, + nn.Conv2d, + nn.Conv3d, + mmcv.cnn.bricks.Conv2d, + mmcv.cnn.bricks.Conv3d, + # activations + nn.ReLU, + nn.PReLU, + nn.ELU, + nn.LeakyReLU, + nn.ReLU6, + # poolings + nn.MaxPool1d, + nn.AvgPool1d, + nn.AvgPool2d, + nn.MaxPool2d, + nn.MaxPool3d, + nn.AvgPool3d, + mmcv.cnn.bricks.MaxPool2d, + mmcv.cnn.bricks.MaxPool3d, + nn.AdaptiveMaxPool1d, + nn.AdaptiveAvgPool1d, + nn.AdaptiveMaxPool2d, + nn.AdaptiveAvgPool2d, + nn.AdaptiveMaxPool3d, + nn.AdaptiveAvgPool3d, + # normalizations + nn.BatchNorm1d, + nn.BatchNorm2d, + nn.BatchNorm3d, + nn.GroupNorm, + nn.InstanceNorm1d, + nn.InstanceNorm2d, + nn.InstanceNorm3d, + nn.LayerNorm, + # FC + nn.Linear, + mmcv.cnn.bricks.Linear, + # Upscale + nn.Upsample, + nn.UpsamplingNearest2d, + nn.UpsamplingBilinear2d, + # Deconvolution + nn.ConvTranspose2d, + mmcv.cnn.bricks.ConvTranspose2d, + ] diff --git a/tests/test_models/test_task_modules/test_estimators/test_flops_params.py b/tests/test_models/test_task_modules/test_estimators/test_flops_params.py index 60bcef4b..2acb58e9 100644 --- a/tests/test_models/test_task_modules/test_estimators/test_flops_params.py +++ b/tests/test_models/test_task_modules/test_estimators/test_flops_params.py @@ -4,6 +4,7 @@ from unittest import TestCase import pytest import torch +from mmcv.cnn.bricks import Conv2dAdaptivePadding from torch import Tensor from torch.nn import Conv2d, Module, Parameter @@ -124,8 +125,17 @@ class TestResourceEstimator(TestCase): flops_count = results['flops'] params_count = results['params'] - self.assertGreater(flops_count, 0) - self.assertGreater(params_count, 0) + self.assertEqual(flops_count, 44.158) + self.assertEqual(params_count, 0.001) + + fool_conv2d = Conv2dAdaptivePadding(3, 32, 3) + results = estimator.estimate( + model=fool_conv2d, flops_params_cfg=flops_params_cfg) + flops_count = results['flops'] + params_count = results['params'] + + self.assertEqual(flops_count, 44.958) + self.assertEqual(params_count, 0.001) def test_register_module(self) -> None: fool_add_constant = FoolConvModule() @@ -151,6 +161,17 @@ class TestResourceEstimator(TestCase): self.assertLess(rest_flops_count, 45.158) self.assertLess(rest_params_count, 0.701) + fool_conv2d = Conv2dAdaptivePadding(3, 32, 3) + flops_params_cfg = dict( + input_shape=(1, 3, 224, 224), disabled_counters=['Conv2dCounter']) + rest_results = estimator.estimate( + model=fool_conv2d, flops_params_cfg=flops_params_cfg) + rest_flops_count = rest_results['flops'] + rest_params_count = rest_results['params'] + + self.assertEqual(rest_flops_count, 0) + self.assertEqual(rest_params_count, 0) + def test_estimate_spec_module(self) -> None: fool_add_constant = FoolConvModule() flops_params_cfg = dict(