[Fix] Fix counter mapping bug (#331)

* fix counter mapping bug

* move judgment into get_counter_type & update UT
Yang Gao 2022-10-24 11:02:10 +08:00 committed by GitHub
parent 6659a34e33
commit a6a337b6bc
2 changed files with 99 additions and 6 deletions
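For readers skimming the diff, the "counter mapping bug" comes down to the name-based counter lookup: the old `get_counter_type` only concatenated the module's class name with 'Counter', so a module whose exact class has no registered counter (for example `mmcv.cnn.bricks.Conv2dAdaptivePadding`) never fell back to the counter of its base class. Below is a minimal, self-contained sketch of the fallback the patched function performs; `registered_counters` and `counted_base_classes` are illustrative stand-ins for `TASK_UTILS._module_dict` and `get_modules_list()`, not names from the codebase.

# Minimal sketch of the MRO-based fallback (illustration only, not part of
# the commit). Stand-ins: registered_counters ~ TASK_UTILS._module_dict,
# counted_base_classes ~ get_modules_list().
import torch.nn as nn
from mmcv.cnn.bricks import Conv2dAdaptivePadding

registered_counters = {'Conv2dCounter'}   # pretend only this counter exists
counted_base_classes = [nn.Conv2d]        # base classes that have a counter

module = Conv2dAdaptivePadding(3, 32, 3)
counter_type = module.__class__.__name__ + 'Counter'
# 'Conv2dAdaptivePaddingCounter' is not registered, so walk the MRO and use
# the first base class that does have a counter.
if counter_type not in registered_counters:
    for base_cls in module.__class__.mro():
        if base_cls in counted_base_classes:
            counter_type = base_cls.__name__ + 'Counter'
            break
print(counter_type)  # Conv2dCounter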

@@ -1,8 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
import sys
from functools import partial
from typing import Dict
from typing import Dict, List
import mmcv
import torch
import torch.nn as nn
@@ -497,9 +498,29 @@ def add_flops_params_counter_variable_or_reset(module):
module.__params__ = 0
def get_counter_type(module):
"""Get counter type of the module based on the module class name."""
return module.__class__.__name__ + 'Counter'
def get_counter_type(module) -> str:
"""Get counter type of the module based on the module class name.
If the current module counter_type is not in TASK_UTILS._module_dict,
it will search the base classes of the module to see if it matches any
base class counter_type.
Returns:
str: Counter type (or the base counter type) of the current module.
"""
counter_type = module.__class__.__name__ + 'Counter'
if counter_type not in TASK_UTILS._module_dict.keys():
old_counter_type = counter_type
assert nn.Module in module.__class__.mro()
for base_cls in module.__class__.mro():
if base_cls in get_modules_list():
counter_type = base_cls.__name__ + 'Counter'
from mmengine import MMLogger
logger = MMLogger.get_current_instance()
logger.warning(f'`{old_counter_type}` not in op_counters. '
f'Using `{counter_type}` instead.')
break
return counter_type
def is_supported_instance(module):
@@ -518,3 +539,54 @@ def remove_flops_params_counter_hook_function(module):
del module.__flops__
if hasattr(module, '__params__'):
del module.__params__
def get_modules_list() -> List:
return [
# convolutions
nn.Conv1d,
nn.Conv2d,
nn.Conv3d,
mmcv.cnn.bricks.Conv2d,
mmcv.cnn.bricks.Conv3d,
# activations
nn.ReLU,
nn.PReLU,
nn.ELU,
nn.LeakyReLU,
nn.ReLU6,
# poolings
nn.MaxPool1d,
nn.AvgPool1d,
nn.AvgPool2d,
nn.MaxPool2d,
nn.MaxPool3d,
nn.AvgPool3d,
mmcv.cnn.bricks.MaxPool2d,
mmcv.cnn.bricks.MaxPool3d,
nn.AdaptiveMaxPool1d,
nn.AdaptiveAvgPool1d,
nn.AdaptiveMaxPool2d,
nn.AdaptiveAvgPool2d,
nn.AdaptiveMaxPool3d,
nn.AdaptiveAvgPool3d,
# normalizations
nn.BatchNorm1d,
nn.BatchNorm2d,
nn.BatchNorm3d,
nn.GroupNorm,
nn.InstanceNorm1d,
nn.InstanceNorm2d,
nn.InstanceNorm3d,
nn.LayerNorm,
# FC
nn.Linear,
mmcv.cnn.bricks.Linear,
# Upscale
nn.Upsample,
nn.UpsamplingNearest2d,
nn.UpsamplingBilinear2d,
# Deconvolution
nn.ConvTranspose2d,
mmcv.cnn.bricks.ConvTranspose2d,
]
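Two behaviours of the new code that the updated unit test below relies on: the MRO is walked starting from the most derived class, so the closest base class in `get_modules_list()` wins (for `Conv2dAdaptivePadding` that is `nn.Conv2d`, giving `Conv2dCounter`), and an unmatched counter name only triggers a logger warning, not an error, so estimation proceeds with the fallback counter. The test accordingly now expects `Conv2dAdaptivePadding(3, 32, 3)` to be counted (flops 44.958, params 0.001 under the configured input shape) and to drop to zero when `'Conv2dCounter'` is listed in `disabled_counters`.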

@@ -4,6 +4,7 @@ from unittest import TestCase
import pytest
import torch
from mmcv.cnn.bricks import Conv2dAdaptivePadding
from torch import Tensor
from torch.nn import Conv2d, Module, Parameter
@@ -124,8 +125,17 @@ class TestResourceEstimator(TestCase):
flops_count = results['flops']
params_count = results['params']
self.assertGreater(flops_count, 0)
self.assertGreater(params_count, 0)
self.assertEqual(flops_count, 44.158)
self.assertEqual(params_count, 0.001)
fool_conv2d = Conv2dAdaptivePadding(3, 32, 3)
results = estimator.estimate(
model=fool_conv2d, flops_params_cfg=flops_params_cfg)
flops_count = results['flops']
params_count = results['params']
self.assertEqual(flops_count, 44.958)
self.assertEqual(params_count, 0.001)
def test_register_module(self) -> None:
fool_add_constant = FoolConvModule()
@@ -151,6 +161,17 @@ class TestResourceEstimator(TestCase):
self.assertLess(rest_flops_count, 45.158)
self.assertLess(rest_params_count, 0.701)
fool_conv2d = Conv2dAdaptivePadding(3, 32, 3)
flops_params_cfg = dict(
input_shape=(1, 3, 224, 224), disabled_counters=['Conv2dCounter'])
rest_results = estimator.estimate(
model=fool_conv2d, flops_params_cfg=flops_params_cfg)
rest_flops_count = rest_results['flops']
rest_params_count = rest_results['params']
self.assertEqual(rest_flops_count, 0)
self.assertEqual(rest_params_count, 0)
def test_estimate_spec_module(self) -> None:
fool_add_constant = FoolConvModule()
flops_params_cfg = dict(