[Fix] Fix counter mapping bug (#331)
* fix counter mapping bug * move judgment into get_counter_type & update UTpull/304/head^2
parent
6659a34e33
commit
a6a337b6bc
|
@ -1,8 +1,9 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import sys
|
||||
from functools import partial
|
||||
from typing import Dict
|
||||
from typing import Dict, List
|
||||
|
||||
import mmcv
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
@ -497,9 +498,29 @@ def add_flops_params_counter_variable_or_reset(module):
|
|||
module.__params__ = 0
|
||||
|
||||
|
||||
def get_counter_type(module):
|
||||
"""Get counter type of the module based on the module class name."""
|
||||
return module.__class__.__name__ + 'Counter'
|
||||
def get_counter_type(module) -> str:
|
||||
"""Get counter type of the module based on the module class name.
|
||||
|
||||
If the current module counter_type is not in TASK_UTILS._module_dict,
|
||||
it will search the base classes of the module to see if it matches any
|
||||
base class counter_type.
|
||||
|
||||
Returns:
|
||||
str: Counter type (or the base counter type) of the current module.
|
||||
"""
|
||||
counter_type = module.__class__.__name__ + 'Counter'
|
||||
if counter_type not in TASK_UTILS._module_dict.keys():
|
||||
old_counter_type = counter_type
|
||||
assert nn.Module in module.__class__.mro()
|
||||
for base_cls in module.__class__.mro():
|
||||
if base_cls in get_modules_list():
|
||||
counter_type = base_cls.__name__ + 'Counter'
|
||||
from mmengine import MMLogger
|
||||
logger = MMLogger.get_current_instance()
|
||||
logger.warning(f'`{old_counter_type}` not in op_counters. '
|
||||
f'Using `{counter_type}` instead.')
|
||||
break
|
||||
return counter_type
|
||||
|
||||
|
||||
def is_supported_instance(module):
|
||||
|
@ -518,3 +539,54 @@ def remove_flops_params_counter_hook_function(module):
|
|||
del module.__flops__
|
||||
if hasattr(module, '__params__'):
|
||||
del module.__params__
|
||||
|
||||
|
||||
def get_modules_list() -> List:
|
||||
return [
|
||||
# convolutions
|
||||
nn.Conv1d,
|
||||
nn.Conv2d,
|
||||
nn.Conv3d,
|
||||
mmcv.cnn.bricks.Conv2d,
|
||||
mmcv.cnn.bricks.Conv3d,
|
||||
# activations
|
||||
nn.ReLU,
|
||||
nn.PReLU,
|
||||
nn.ELU,
|
||||
nn.LeakyReLU,
|
||||
nn.ReLU6,
|
||||
# poolings
|
||||
nn.MaxPool1d,
|
||||
nn.AvgPool1d,
|
||||
nn.AvgPool2d,
|
||||
nn.MaxPool2d,
|
||||
nn.MaxPool3d,
|
||||
nn.AvgPool3d,
|
||||
mmcv.cnn.bricks.MaxPool2d,
|
||||
mmcv.cnn.bricks.MaxPool3d,
|
||||
nn.AdaptiveMaxPool1d,
|
||||
nn.AdaptiveAvgPool1d,
|
||||
nn.AdaptiveMaxPool2d,
|
||||
nn.AdaptiveAvgPool2d,
|
||||
nn.AdaptiveMaxPool3d,
|
||||
nn.AdaptiveAvgPool3d,
|
||||
# normalizations
|
||||
nn.BatchNorm1d,
|
||||
nn.BatchNorm2d,
|
||||
nn.BatchNorm3d,
|
||||
nn.GroupNorm,
|
||||
nn.InstanceNorm1d,
|
||||
nn.InstanceNorm2d,
|
||||
nn.InstanceNorm3d,
|
||||
nn.LayerNorm,
|
||||
# FC
|
||||
nn.Linear,
|
||||
mmcv.cnn.bricks.Linear,
|
||||
# Upscale
|
||||
nn.Upsample,
|
||||
nn.UpsamplingNearest2d,
|
||||
nn.UpsamplingBilinear2d,
|
||||
# Deconvolution
|
||||
nn.ConvTranspose2d,
|
||||
mmcv.cnn.bricks.ConvTranspose2d,
|
||||
]
|
||||
|
|
|
@ -4,6 +4,7 @@ from unittest import TestCase
|
|||
|
||||
import pytest
|
||||
import torch
|
||||
from mmcv.cnn.bricks import Conv2dAdaptivePadding
|
||||
from torch import Tensor
|
||||
from torch.nn import Conv2d, Module, Parameter
|
||||
|
||||
|
@ -124,8 +125,17 @@ class TestResourceEstimator(TestCase):
|
|||
flops_count = results['flops']
|
||||
params_count = results['params']
|
||||
|
||||
self.assertGreater(flops_count, 0)
|
||||
self.assertGreater(params_count, 0)
|
||||
self.assertEqual(flops_count, 44.158)
|
||||
self.assertEqual(params_count, 0.001)
|
||||
|
||||
fool_conv2d = Conv2dAdaptivePadding(3, 32, 3)
|
||||
results = estimator.estimate(
|
||||
model=fool_conv2d, flops_params_cfg=flops_params_cfg)
|
||||
flops_count = results['flops']
|
||||
params_count = results['params']
|
||||
|
||||
self.assertEqual(flops_count, 44.958)
|
||||
self.assertEqual(params_count, 0.001)
|
||||
|
||||
def test_register_module(self) -> None:
|
||||
fool_add_constant = FoolConvModule()
|
||||
|
@ -151,6 +161,17 @@ class TestResourceEstimator(TestCase):
|
|||
self.assertLess(rest_flops_count, 45.158)
|
||||
self.assertLess(rest_params_count, 0.701)
|
||||
|
||||
fool_conv2d = Conv2dAdaptivePadding(3, 32, 3)
|
||||
flops_params_cfg = dict(
|
||||
input_shape=(1, 3, 224, 224), disabled_counters=['Conv2dCounter'])
|
||||
rest_results = estimator.estimate(
|
||||
model=fool_conv2d, flops_params_cfg=flops_params_cfg)
|
||||
rest_flops_count = rest_results['flops']
|
||||
rest_params_count = rest_results['params']
|
||||
|
||||
self.assertEqual(rest_flops_count, 0)
|
||||
self.assertEqual(rest_params_count, 0)
|
||||
|
||||
def test_estimate_spec_module(self) -> None:
|
||||
fool_add_constant = FoolConvModule()
|
||||
flops_params_cfg = dict(
|
||||
|
|
Loading…
Reference in New Issue