From 4b3f8ab69e4572fdb746c00d4d267858a709e7f1 Mon Sep 17 00:00:00 2001 From: Yang Gao Date: Tue, 23 Aug 2022 15:01:47 +0800 Subject: [PATCH] [Feature] Refactor Estimator for computing FLOPs/Params/Latency. (#230) * Refactor ModelEstimator: 1. add EvaluatorLoop in engine.runners; 2. add estimator for structures (both subnet & supernet); 3. add layer_counter for each op. * fix lint * update estimator: 1. add ResourceEstimator based on BaseEstimator; 2. add notes & examples for ResourceEstimator & EvaluatorLoop usage; 3. fix a bug of latency test. 4. minor changes according to comments. * add UT & fix a bug caused by UT * add docstrings & remove old estimator * update docstrings for op_spec_counters * rename resource_evaluator_val_loop * support adding resource attrs of each submodule in a measured model * fix lint * refactor estimator file structures * support estimating resources for spec modules * rm old UT * update new estimator UT cases * fix traversal range of the model * cancel unit convert in accumulate_sub_module_flops_params * use estimator_cfg to build ResourceEstimator * fix a broadcast bug * delete fixed input_shape * add assertion and string-format-return when measuring spec_modules * add UT for estimating spec_modules --- mmrazor/engine/__init__.py | 5 +- mmrazor/engine/hooks/__init__.py | 3 +- .../engine/hooks/estimate_resources_hook.py | 117 +++++ mmrazor/engine/runner/autoslim_val_loop.py | 2 +- .../engine/runner/evolution_search_loop.py | 19 +- mmrazor/engine/runner/slimmable_val_loop.py | 2 +- mmrazor/engine/runner/subnet_sampler_loop.py | 21 +- mmrazor/models/__init__.py | 1 + mmrazor/models/task_modules/__init__.py | 4 + .../task_modules/estimators/__init__.py | 5 + .../task_modules/estimators/base_estimator.py | 55 ++ .../estimators/counters/__init__.py | 10 + .../counters/flops_params_counter.py | 472 ++++++++++++++++++ .../estimators/counters/latency_counter.py | 88 ++++ .../counters/op_counters/__init__.py | 22 + .../op_counters/activation_layer_counter.py | 35 ++ .../counters/op_counters/base_counter.py | 28 ++ .../op_counters/conv_layer_counter.py | 57 +++ .../op_counters/deconv_layer_counter.py | 38 ++ .../op_counters/linear_layer_counter.py | 19 + .../op_counters/norm_layer_counter.py | 59 +++ .../op_counters/pooling_layer_counter.py | 76 +++ .../op_counters/upsample_layer_counter.py | 19 + .../estimators/resource_estimator.py | 155 ++++++ mmrazor/registry/registry.py | 2 +- mmrazor/structures/subnet/__init__.py | 5 +- .../structures/subnet/estimators/__init__.py | 4 - mmrazor/structures/subnet/estimators/flops.py | 270 ---------- mmrazor/utils/setup_env.py | 13 +- .../test_subnet/test_estimators/test_flops.py | 231 --------- .../test_estimators/test_flops_params.py | 197 ++++++++ .../test_evolution_search_loop.py | 21 +- .../test_runners/test_subnet_sampler_loop.py | 19 +- 33 files changed, 1524 insertions(+), 550 deletions(-) create mode 100644 mmrazor/engine/hooks/estimate_resources_hook.py create mode 100644 mmrazor/models/task_modules/__init__.py create mode 100644 mmrazor/models/task_modules/estimators/__init__.py create mode 100644 mmrazor/models/task_modules/estimators/base_estimator.py create mode 100644 mmrazor/models/task_modules/estimators/counters/__init__.py create mode 100644 mmrazor/models/task_modules/estimators/counters/flops_params_counter.py create mode 100644 mmrazor/models/task_modules/estimators/counters/latency_counter.py create mode 100644 mmrazor/models/task_modules/estimators/counters/op_counters/__init__.py create mode 100644 
mmrazor/models/task_modules/estimators/counters/op_counters/activation_layer_counter.py
 create mode 100644 mmrazor/models/task_modules/estimators/counters/op_counters/base_counter.py
 create mode 100644 mmrazor/models/task_modules/estimators/counters/op_counters/conv_layer_counter.py
 create mode 100644 mmrazor/models/task_modules/estimators/counters/op_counters/deconv_layer_counter.py
 create mode 100644 mmrazor/models/task_modules/estimators/counters/op_counters/linear_layer_counter.py
 create mode 100644 mmrazor/models/task_modules/estimators/counters/op_counters/norm_layer_counter.py
 create mode 100644 mmrazor/models/task_modules/estimators/counters/op_counters/pooling_layer_counter.py
 create mode 100644 mmrazor/models/task_modules/estimators/counters/op_counters/upsample_layer_counter.py
 create mode 100644 mmrazor/models/task_modules/estimators/resource_estimator.py
 delete mode 100644 mmrazor/structures/subnet/estimators/__init__.py
 delete mode 100644 mmrazor/structures/subnet/estimators/flops.py
 delete mode 100644 tests/test_models/test_subnet/test_estimators/test_flops.py
 create mode 100644 tests/test_models/test_task_modules/test_estimators/test_flops_params.py

diff --git a/mmrazor/engine/__init__.py b/mmrazor/engine/__init__.py
index ce464dfd..fd221b57 100644
--- a/mmrazor/engine/__init__.py
+++ b/mmrazor/engine/__init__.py
@@ -1,5 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .hooks import DumpSubnetHook
+from .hooks import DumpSubnetHook, EstimateResourcesHook
 from .optimizers import SeparateOptimWrapperConstructor
 from .runner import (AutoSlimValLoop, DartsEpochBasedTrainLoop,
                      DartsIterBasedTrainLoop, EvolutionSearchLoop,
@@ -10,5 +10,6 @@ __all__ = [
     'SeparateOptimWrapperConstructor', 'DumpSubnetHook',
     'SingleTeacherDistillValLoop', 'DartsEpochBasedTrainLoop',
     'DartsIterBasedTrainLoop', 'SlimmableValLoop', 'EvolutionSearchLoop',
-    'GreedySamplerTrainLoop', 'AutoSlimValLoop', 'SelfDistillValLoop'
+    'GreedySamplerTrainLoop', 'AutoSlimValLoop', 'SelfDistillValLoop',
+    'EstimateResourcesHook'
 ]
diff --git a/mmrazor/engine/hooks/__init__.py b/mmrazor/engine/hooks/__init__.py
index 48013881..2fc3cc12 100644
--- a/mmrazor/engine/hooks/__init__.py
+++ b/mmrazor/engine/hooks/__init__.py
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .dump_subnet_hook import DumpSubnetHook
+from .estimate_resources_hook import EstimateResourcesHook
 
-__all__ = ['DumpSubnetHook']
+__all__ = ['DumpSubnetHook', 'EstimateResourcesHook']
diff --git a/mmrazor/engine/hooks/estimate_resources_hook.py b/mmrazor/engine/hooks/estimate_resources_hook.py
new file mode 100644
index 00000000..34ebe7ef
--- /dev/null
+++ b/mmrazor/engine/hooks/estimate_resources_hook.py
@@ -0,0 +1,117 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+from typing import Any, Dict, Optional, Sequence
+
+import torch
+from mmengine.data import BaseDataElement
+from mmengine.hooks import Hook
+from mmengine.registry import HOOKS
+
+from mmrazor.models.task_modules import ResourceEstimator
+
+DATA_BATCH = Optional[Sequence[dict]]
+
+
+@HOOKS.register_module()
+class EstimateResourcesHook(Hook):
+    """Estimate model resources periodically.
+
+    Args:
+        interval (int): The period of resource estimation. If
+            ``by_epoch=True``, interval indicates epochs, otherwise it
+            indicates iterations. Defaults to -1, which means "never".
+        by_epoch (bool): Estimate resources by epoch or by iteration.
+            Defaults to True.
+        estimator_cfg (Dict[str, Any]): Used for building a resource estimator.
+            Defaults to dict().
+
+    Example:
+        Add the `EstimateResourcesHook` to `custom_hooks` as follows:
+        custom_hooks = [
+            dict(type='mmrazor.EstimateResourcesHook',
+                 interval=1,
+                 by_epoch=True,
+                 estimator_cfg=dict(input_shape=(1, 3, 64, 64)))
+        ]
+    """
+    out_dir: str
+
+    priority = 'VERY_LOW'
+
+    def __init__(self,
+                 interval: int = -1,
+                 by_epoch: bool = True,
+                 estimator_cfg: Dict[str, Any] = dict(),
+                 **kwargs) -> None:
+        self.interval = interval
+        self.by_epoch = by_epoch
+        self.estimator = ResourceEstimator(**estimator_cfg)
+
+    def after_val_epoch(self,
+                        runner,
+                        metrics: Optional[Dict[str, float]] = None) -> None:
+        """Estimate model resources after every n val epochs.
+
+        Args:
+            runner (Runner): The runner of the training process.
+        """
+        if not self.by_epoch:
+            return
+
+        if self.every_n_epochs(runner, self.interval):
+            self.estimate_resources(runner)
+
+    def after_val_iter(self,
+                       runner,
+                       batch_idx: int,
+                       data_batch: DATA_BATCH = None,
+                       outputs: Optional[Sequence[BaseDataElement]] = None) \
+            -> None:
+        """Estimate model resources after every n val iters.
+
+        Args:
+            runner (Runner): The runner of the training process.
+        """
+        if self.by_epoch:
+            return
+
+        if self.every_n_train_iters(runner, self.interval):
+            self.estimate_resources(runner)
+
+    def estimate_resources(self, runner) -> None:
+        """Estimate model resources: latency/flops/params."""
+        model = runner.model.module if runner.distributed else runner.model
+
+        # TODO confirm the state judgement.
+        if hasattr(model, 'is_supernet') and model.is_supernet:
+            model = self.export_subnet(model)
+
+        resource_metrics = self.estimator.estimate(model)
+        runner.logger.info(f'Estimate model resources: {resource_metrics}')
+
+    def export_subnet(self, model) -> torch.nn.Module:
+        """Export current best subnet.
+
+        NOTE: This method is called when it comes to those NAS algorithms that
+        require building a supernet for training.
+
+        For those algorithms, measuring subnet resources is more meaningful
+        than measuring the supernet during validation, therefore this method
+        is required to get the currently searched subnet from the supernet.
+        """
+        # Avoid circular import
+        from mmrazor.models.mutables.base_mutable import BaseMutable
+        from mmrazor.structures import load_fix_subnet
+
+        # delete non-leaf tensor to get deepcopy(model).
+        # TODO solve the hard case.
+        for module in model.architecture.modules():
+            if isinstance(module, BaseMutable):
+                if hasattr(module, 'arch_weights'):
+                    delattr(module, 'arch_weights')
+
+        copied_model = copy.deepcopy(model)
+        fix_mutable = copied_model.search_subnet()
+        load_fix_subnet(copied_model, fix_mutable)
+
+        return copied_model
diff --git a/mmrazor/engine/runner/autoslim_val_loop.py b/mmrazor/engine/runner/autoslim_val_loop.py
index e583d2b0..5b61266b 100644
--- a/mmrazor/engine/runner/autoslim_val_loop.py
+++ b/mmrazor/engine/runner/autoslim_val_loop.py
@@ -27,7 +27,7 @@ class AutoSlimValLoop(ValLoop):
         # just for convenience
         self._model = model
 
-    def run(self) -> None:
+    def run(self):
         """Launch validation."""
         self.runner.call_hook('before_val')
 
diff --git a/mmrazor/engine/runner/evolution_search_loop.py b/mmrazor/engine/runner/evolution_search_loop.py
index a6636c23..c7d681e3 100644
--- a/mmrazor/engine/runner/evolution_search_loop.py
+++ b/mmrazor/engine/runner/evolution_search_loop.py
@@ -1,9 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import copy import os import os.path as osp import random import warnings -from typing import Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union import torch from mmengine import fileio @@ -13,8 +14,9 @@ from mmengine.runner import EpochBasedTrainLoop from mmengine.utils import is_list_of from torch.utils.data import DataLoader +from mmrazor.models.task_modules import ResourceEstimator from mmrazor.registry import LOOPS -from mmrazor.structures import Candidates, FlopsEstimator, export_fix_subnet +from mmrazor.structures import Candidates, export_fix_subnet, load_fix_subnet from mmrazor.utils import SupportRandomSubnet from .utils import crossover @@ -42,6 +44,8 @@ class EvolutionSearchLoop(EpochBasedTrainLoop): mutate_prob (float): The probability of mutation. Defaults to 0.1. flops_range (tuple, optional): flops_range to be used for screening candidates. + estimator_cfg (Dict[str, Any]): Used for building a resource estimator. + Default to dict(). score_key (str): Specify one metric in evaluation results to score candidates. Defaults to 'accuracy_top-1'. init_candidates (str, optional): The candidates file path, which is @@ -62,6 +66,7 @@ class EvolutionSearchLoop(EpochBasedTrainLoop): num_crossover: int = 25, mutate_prob: float = 0.1, flops_range: Optional[Tuple[float, float]] = (0., 330 * 1e6), + estimator_cfg: Dict[str, Any] = dict(), score_key: str = 'accuracy_top-1', init_candidates: Optional[str] = None) -> None: super().__init__(runner, dataloader, max_epochs) @@ -80,6 +85,7 @@ class EvolutionSearchLoop(EpochBasedTrainLoop): self.num_candidates = num_candidates self.top_k = top_k self.flops_range = flops_range + self.estimator_cfg = estimator_cfg self.score_key = score_key self.num_mutation = num_mutation self.num_crossover = num_crossover @@ -299,8 +305,13 @@ class EvolutionSearchLoop(EpochBasedTrainLoop): self.model.set_subnet(random_subnet) fix_mutable = export_fix_subnet(self.model) - flops: float = FlopsEstimator.get_model_complexity_info( - self.model, fix_mutable=fix_mutable, as_strings=False)[0] + copied_model = copy.deepcopy(self.model) + load_fix_subnet(copied_model, fix_mutable) + + estimator = ResourceEstimator(**self.estimator_cfg) + results = estimator.estimate(copied_model) + flops = results['flops'] + if self.flops_range[0] < flops < self.flops_range[1]: return True else: diff --git a/mmrazor/engine/runner/slimmable_val_loop.py b/mmrazor/engine/runner/slimmable_val_loop.py index b73230ce..f830ffa8 100644 --- a/mmrazor/engine/runner/slimmable_val_loop.py +++ b/mmrazor/engine/runner/slimmable_val_loop.py @@ -38,7 +38,7 @@ class SlimmableValLoop(ValLoop): # just for convenience self._model = model - def run(self) -> None: + def run(self): """Launch validation.""" self.runner.call_hook('before_val') diff --git a/mmrazor/engine/runner/subnet_sampler_loop.py b/mmrazor/engine/runner/subnet_sampler_loop.py index b6cd3be5..12307dc5 100644 --- a/mmrazor/engine/runner/subnet_sampler_loop.py +++ b/mmrazor/engine/runner/subnet_sampler_loop.py @@ -1,9 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. 
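The rewritten `_check_constraints` above follows one pattern: freeze the sampled choices, load them into a deep copy, then estimate. A minimal standalone sketch of that flow (the helper name `check_subnet_flops` is illustrative, not part of this patch, and `flops_range` must be expressed in the units the estimator is configured to return):

import copy

from mmrazor.models.task_modules import ResourceEstimator
from mmrazor.structures import export_fix_subnet, load_fix_subnet


def check_subnet_flops(supernet, flops_range, estimator_cfg=dict()):
    # Freeze the currently sampled choices into a fix_mutable mapping.
    fix_mutable = export_fix_subnet(supernet)
    # Deepcopy first so the supernet itself stays mutable.
    copied_model = copy.deepcopy(supernet)
    load_fix_subnet(copied_model, fix_mutable)
    # 'flops' is reported in the estimator's units ('M' by default).
    results = ResourceEstimator(**estimator_cfg).estimate(copied_model)
    return flops_range[0] < results['flops'] < flops_range[1]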
+import copy
 import math
 import os
 import random
 from abc import abstractmethod
-from typing import Dict, List, Optional, Sequence, Tuple, Union
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
 
 import torch
 from mmengine import fileio
@@ -12,8 +13,9 @@
 from mmengine.runner import IterBasedTrainLoop
 from mmengine.utils import is_list_of
 from torch.utils.data import DataLoader
 
+from mmrazor.models.task_modules import ResourceEstimator
 from mmrazor.registry import LOOPS
-from mmrazor.structures import Candidates, FlopsEstimator, export_fix_subnet
+from mmrazor.structures import Candidates, export_fix_subnet, load_fix_subnet
 from mmrazor.utils import SupportRandomSubnet
 
@@ -100,7 +102,9 @@ class GreedySamplerTrainLoop(BaseSamplerTrainLoop):
         val_interval (int): Validation interval. Defaults to 1000.
         score_key (str): Specify one metric in evaluation results to score
             candidates. Defaults to 'accuracy_top-1'.
-        constraints (dict): Constraints to be used for screening candidates.
+        flops_range (tuple, optional): Constraints for screening candidates.
+        estimator_cfg (Dict[str, Any]): Used for building a resource estimator.
+            Defaults to dict().
         num_candidates (int): The number of the candidates consist of samples
             from supernet and itself. Defaults to 1000.
         num_samples (int): The number of sample in each sampling subnet.
             Defaults to 10.
@@ -135,6 +139,7 @@
                  val_interval: int = 1000,
                  score_key: str = 'accuracy_top-1',
                  flops_range: Optional[Tuple[float, float]] = (0., 330 * 1e6),
+                 estimator_cfg: Dict[str, Any] = dict(),
                  num_candidates: int = 1000,
                  num_samples: int = 10,
                  top_k: int = 5,
@@ -158,6 +163,7 @@
 
         self.score_key = score_key
         self.flops_range = flops_range
+        self.estimator_cfg = estimator_cfg
         self.num_candidates = num_candidates
         self.num_samples = num_samples
         self.top_k = top_k
@@ -316,8 +322,13 @@
         self.model.set_subnet(random_subnet)
         fix_mutable = export_fix_subnet(self.model)
-        flops = FlopsEstimator.get_model_complexity_info(
-            self.model, fix_mutable=fix_mutable, as_strings=False)[0]
+        copied_model = copy.deepcopy(self.model)
+        load_fix_subnet(copied_model, fix_mutable)
+
+        estimator = ResourceEstimator(**self.estimator_cfg)
+        results = estimator.estimate(copied_model)
+        flops = results['flops']
+
         if self.flops_range[0] < flops < self.flops_range[1]:
             return True
         else:
diff --git a/mmrazor/models/__init__.py b/mmrazor/models/__init__.py
index f20ec586..1b83f430 100644
--- a/mmrazor/models/__init__.py
+++ b/mmrazor/models/__init__.py
@@ -6,3 +6,4 @@
 from .losses import *  # noqa: F401,F403
 from .mutables import *  # noqa: F401,F403
 from .mutators import *  # noqa: F401,F403
 from .ops import *  # noqa: F401,F403
+from .task_modules import *  # noqa: F401,F403
diff --git a/mmrazor/models/task_modules/__init__.py b/mmrazor/models/task_modules/__init__.py
new file mode 100644
index 00000000..f5898ff5
--- /dev/null
+++ b/mmrazor/models/task_modules/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .estimators import ResourceEstimator
+
+__all__ = ['ResourceEstimator']
diff --git a/mmrazor/models/task_modules/estimators/__init__.py b/mmrazor/models/task_modules/estimators/__init__.py
new file mode 100644
index 00000000..f1cd00f8
--- /dev/null
+++ b/mmrazor/models/task_modules/estimators/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) OpenMMLab. All rights reserved.
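Since `ResourceEstimator` is registered in `TASK_UTILS` (see `resource_estimator.py` below), it can also be built from a config dict instead of being instantiated directly; a small sketch with illustrative cfg values:

from mmrazor.registry import TASK_UTILS

estimator_cfg = dict(
    type='ResourceEstimator',
    default_shape=(1, 3, 224, 224),
    units='M')
estimator = TASK_UTILS.build(estimator_cfg)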
+from .counters import *  # noqa: F401,F403
+from .resource_estimator import ResourceEstimator
+
+__all__ = ['ResourceEstimator']
diff --git a/mmrazor/models/task_modules/estimators/base_estimator.py b/mmrazor/models/task_modules/estimators/base_estimator.py
new file mode 100644
index 00000000..22a82d10
--- /dev/null
+++ b/mmrazor/models/task_modules/estimators/base_estimator.py
@@ -0,0 +1,55 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from abc import ABCMeta, abstractmethod
+from typing import Any, Dict, List, Tuple
+
+import torch.nn
+
+from mmrazor.registry import TASK_UTILS
+
+
+@TASK_UTILS.register_module()
+class BaseEstimator(metaclass=ABCMeta):
+    """The base class of Estimator, used for estimating model infos.
+
+    Args:
+        default_shape (tuple): Input data's default shape, used for
+            calculating resource consumption. Defaults to (1, 3, 224, 224).
+        units (str): Resource units. Defaults to 'M'.
+        disabled_counters (list): List of disabled spec op counters.
+            Defaults to None.
+        as_strings (bool): Output FLOPs and params counts in a string
+            form. Defaults to False.
+        measure_inference (bool): Whether to measure inference speed or not.
+            Defaults to False.
+    """
+
+    def __init__(self,
+                 default_shape: Tuple = (1, 3, 224, 224),
+                 units: str = 'M',
+                 disabled_counters: List[str] = None,
+                 as_strings: bool = False,
+                 measure_inference: bool = False):
+        assert len(default_shape) in [3, 4, 5], \
+            f'Unsupported shape: {default_shape}'
+        self.default_shape = default_shape
+        self.units = units
+        self.disabled_counters = disabled_counters
+        self.as_strings = as_strings
+        self.measure_inference = measure_inference
+
+    @abstractmethod
+    def estimate(
+        self, model: torch.nn.Module, resource_args: Dict[str, Any] = dict()
+    ) -> Dict[str, float]:
+        """Estimate the resources (flops/params/latency) of the given model.
+
+        Args:
+            model: The measured model.
+            resource_args (Dict[str, float]): Resource information.
+                NOTE: resource_args have the same items() as the init cfgs.
+
+        Returns:
+            Dict[str, float]: A dict containing resource results (flops,
+                params and latency).
+        """
+        pass
diff --git a/mmrazor/models/task_modules/estimators/counters/__init__.py b/mmrazor/models/task_modules/estimators/counters/__init__.py
new file mode 100644
index 00000000..0a6adee4
--- /dev/null
+++ b/mmrazor/models/task_modules/estimators/counters/__init__.py
@@ -0,0 +1,10 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .flops_params_counter import (get_model_complexity_info,
+                                   params_units_convert)
+from .latency_counter import repeat_measure_inference_speed
+from .op_counters import *  # noqa: F401,F403
+
+__all__ = [
+    'get_model_complexity_info', 'params_units_convert',
+    'repeat_measure_inference_speed'
+]
diff --git a/mmrazor/models/task_modules/estimators/counters/flops_params_counter.py b/mmrazor/models/task_modules/estimators/counters/flops_params_counter.py
new file mode 100644
index 00000000..0da47491
--- /dev/null
+++ b/mmrazor/models/task_modules/estimators/counters/flops_params_counter.py
@@ -0,0 +1,472 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import sys
+from functools import partial
+
+import torch
+import torch.nn as nn
+
+from mmrazor.registry import TASK_UTILS
+
+
+def get_model_complexity_info(model,
+                              input_shape,
+                              spec_modules=[],
+                              disabled_counters=[],
+                              print_per_layer_stat=False,
+                              as_strings=False,
+                              input_constructor=None,
+                              flush=False,
+                              ost=sys.stdout):
+    """Get complexity information of a model.
+
+    This method can calculate FLOPs and parameter counts of a model with
+    corresponding input shape. It can also print complexity information for
+    each layer in a model. Supported layers are listed as below:
+
+    - Convolutions: ``nn.Conv1d``, ``nn.Conv2d``, ``nn.Conv3d``.
+    - Activations: ``nn.ReLU``, ``nn.PReLU``, ``nn.ELU``, ``nn.LeakyReLU``,
+      ``nn.ReLU6``.
+    - Poolings: ``nn.MaxPool1d``, ``nn.MaxPool2d``, ``nn.MaxPool3d``,
+      ``nn.AvgPool1d``, ``nn.AvgPool2d``, ``nn.AvgPool3d``,
+      ``nn.AdaptiveMaxPool1d``, ``nn.AdaptiveMaxPool2d``,
+      ``nn.AdaptiveMaxPool3d``, ``nn.AdaptiveAvgPool1d``,
+      ``nn.AdaptiveAvgPool2d``, ``nn.AdaptiveAvgPool3d``.
+    - BatchNorms: ``nn.BatchNorm1d``, ``nn.BatchNorm2d``,
+      ``nn.BatchNorm3d``.
+    - Linear: ``nn.Linear``.
+    - Deconvolution: ``nn.ConvTranspose2d``.
+    - Upsample: ``nn.Upsample``.
+
+    Args:
+        model (nn.Module): The model for complexity calculation.
+        input_shape (tuple): Input shape (including batchsize) used for
+            calculation.
+        spec_modules (list): A list that contains the names of several spec
+            modules, whose resources info users want to get. e.g.,
+            ['backbone', 'head'], ['backbone.layer1']. Defaults to [].
+        disabled_counters (list): One can limit which ops' spec would be
+            calculated. Defaults to [].
+        print_per_layer_stat (bool): Whether to print complexity information
+            for each layer in a model. Defaults to False.
+        as_strings (bool): Output FLOPs and params counts in a string form.
+            Defaults to False.
+        input_constructor (None | callable): If specified, it takes a callable
+            method that generates input. Otherwise, it will generate a random
+            tensor with input shape to calculate FLOPs. Defaults to None.
+        flush (bool): same as that in :func:`print`. Defaults to False.
+        ost (stream): same as ``file`` param in :func:`print`.
+            Defaults to sys.stdout.
+
+    Returns:
+        tuple[float | str] | dict[str, float]: If `as_strings` is set to True,
+            it will return FLOPs and parameter counts in a string format.
+            Otherwise, it will return those in a float number format.
+            If len(spec_modules) > 0, it will return a resource info dict with
+            FLOPs and parameter counts of each spec module in float format.
+    """
+    assert type(input_shape) is tuple
+    assert len(input_shape) >= 1
+    assert isinstance(model, nn.Module)
+    flops_params_model = add_flops_params_counting_methods(model)
+    flops_params_model.eval()
+    flops_params_model.start_flops_params_count(disabled_counters)
+    if input_constructor:
+        input = input_constructor(input_shape)
+        _ = flops_params_model(**input)
+    else:
+        try:
+            batch = torch.ones(()).new_empty(
+                tuple(input_shape),
+                dtype=next(flops_params_model.parameters()).dtype,
+                device=next(flops_params_model.parameters()).device)
+        except StopIteration:
+            # Avoid StopIteration for models which have no parameters,
+            # like `nn.Relu()`, `nn.AvgPool2d`, etc.
+            batch = torch.ones(()).new_empty(tuple(input_shape))
+
+        _ = flops_params_model(batch)
+
+    flops_count, params_count = \
+        flops_params_model.compute_average_flops_params_cost()
+
+    if print_per_layer_stat:
+        print_model_with_flops_params(
+            flops_params_model,
+            flops_count,
+            params_count,
+            ost=ost,
+            flush=flush)
+
+    if len(spec_modules):
+        module_names = [name for name, _ in flops_params_model.named_modules()]
+        for module in spec_modules:
+            assert module in module_names, \
+                f'All modules in spec_modules should be in the measured ' \
+                f'flops_params_model. Got module {module} in spec_modules.'
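+        # Every name in spec_modules has been validated against
+        # named_modules(), so per-module __flops__/__params__ can now be
+        # accumulated and read off for just those submodules.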
+        spec_modules_resources = dict()
+        accumulate_sub_module_flops_params(flops_params_model)
+        for name, module in flops_params_model.named_modules():
+            if name in spec_modules:
+                spec_modules_resources[name] = dict()
+                spec_modules_resources[name]['flops'] = module.__flops__
+                spec_modules_resources[name]['params'] = module.__params__
+                if as_strings:
+                    spec_modules_resources[name]['flops'] = str(
+                        params_units_convert(module.__flops__,
+                                             'G')) + ' GFLOPs'
+                    spec_modules_resources[name]['params'] = str(
+                        params_units_convert(module.__params__, 'M')) + ' M'
+
+    flops_params_model.stop_flops_params_count()
+
+    if len(spec_modules):
+        return spec_modules_resources
+
+    if as_strings:
+        flops_string = str(params_units_convert(flops_count, 'G')) + ' GFLOPs'
+        params_string = str(params_units_convert(params_count, 'M')) + ' M'
+        return flops_string, params_string
+
+    return flops_count, params_count
+
+
+def params_units_convert(num_params, units='M', precision=3):
+    """Convert a parameter count into the given units.
+
+    Args:
+        num_params (float): Parameter number to be converted.
+        units (str): Converted units. Options are 'G', 'M' and 'K'.
+            Defaults to 'M'.
+        precision (int): Digit number after the decimal point. Defaults to 3.
+
+    Returns:
+        float: The converted parameter number.
+
+    Examples:
+        >>> params_units_convert(1e9)
+        1000.0
+        >>> params_units_convert(2e5, units='K')
+        200.0
+        >>> params_units_convert(3e9, units='G')
+        3.0
+    """
+
+    if units == 'G':
+        return round(num_params / 10.**9, precision)
+    elif units == 'M':
+        return round(num_params / 10.**6, precision)
+    elif units == 'K':
+        return round(num_params / 10.**3, precision)
+    else:
+        raise ValueError(f'Unsupported units convert: {units}')
+
+
+def print_model_with_flops_params(model,
+                                  total_flops,
+                                  total_params,
+                                  units='G',
+                                  precision=3,
+                                  ost=sys.stdout,
+                                  flush=False):
+    """Print a model with FLOPs and Params for each layer.
+
+    Args:
+        model (nn.Module): The model to be printed.
+        total_flops (float): Total FLOPs of the model.
+        total_params (float): Total parameter counts of the model.
+        units (str | None): Converted FLOPs units. Defaults to 'G'.
+        precision (int): Digit number after the decimal point. Defaults to 3.
+        ost (stream): same as `file` param in :func:`print`.
+            Defaults to sys.stdout.
+        flush (bool): same as that in :func:`print`. Defaults to False.
+ + Example: + >>> class ExampleModel(nn.Module): + >>> def __init__(self): + >>> super().__init__() + >>> self.conv1 = nn.Conv2d(3, 8, 3) + >>> self.conv2 = nn.Conv2d(8, 256, 3) + >>> self.conv3 = nn.Conv2d(256, 8, 3) + >>> self.avg_pool = nn.AdaptiveAvgPool2d((1, 1)) + >>> self.flatten = nn.Flatten() + >>> self.fc = nn.Linear(8, 1) + >>> def forward(self, x): + >>> x = self.conv1(x) + >>> x = self.conv2(x) + >>> x = self.conv3(x) + >>> x = self.avg_pool(x) + >>> x = self.flatten(x) + >>> x = self.fc(x) + >>> return x + >>> model = ExampleModel() + >>> x = (3, 16, 16) + to print the complexity information state for each layer, you can use + >>> get_model_complexity_info(model, x) + or directly use + >>> print_model_with_flops_params(model, 4579784.0, 37361) + ExampleModel( + 0.037 M, 100.000% Params, 0.005 GFLOPs, 100.000% FLOPs, + (conv1): Conv2d(0.0 M, 0.600% Params, 0.0 GFLOPs, 0.959% FLOPs, 3, 8, kernel_size=(3, 3), stride=(1, 1)) # noqa: E501 + (conv2): Conv2d(0.019 M, 50.020% Params, 0.003 GFLOPs, 58.760% FLOPs, 8, 256, kernel_size=(3, 3), stride=(1, 1)) + (conv3): Conv2d(0.018 M, 49.356% Params, 0.002 GFLOPs, 40.264% FLOPs, 256, 8, kernel_size=(3, 3), stride=(1, 1)) + (avg_pool): AdaptiveAvgPool2d(0.0 M, 0.000% Params, 0.0 GFLOPs, 0.017% FLOPs, output_size=(1, 1)) + (flatten): Flatten(0.0 M, 0.000% Params, 0.0 GFLOPs, 0.000% FLOPs, ) + (fc): Linear(0.0 M, 0.024% Params, 0.0 GFLOPs, 0.000% FLOPs, in_features=8, out_features=1, bias=True) + ) + """ + + def accumulate_params(self): + if is_supported_instance(self): + return self.__params__ + else: + sum = 0 + for m in self.children(): + sum += m.accumulate_params() + return sum + + def accumulate_flops(self): + if is_supported_instance(self): + return self.__flops__ / model.__batch_counter__ + else: + sum = 0 + for m in self.children(): + sum += m.accumulate_flops() + return sum + + def flops_repr(self): + accumulated_num_params = self.accumulate_params() + accumulated_flops_cost = self.accumulate_flops() + flops_string = str( + params_units_convert( + accumulated_flops_cost, units=units, + precision=precision)) + ' ' + units + 'FLOPs' + params_string = str( + params_units_convert( + accumulated_num_params, units='M', precision=precision)) + ' M' + return ', '.join([ + params_string, + '{:.3%} Params'.format(accumulated_num_params / total_params), + flops_string, + '{:.3%} FLOPs'.format(accumulated_flops_cost / total_flops), + self.original_extra_repr() + ]) + + def add_extra_repr(m): + m.accumulate_flops = accumulate_flops.__get__(m) + m.accumulate_params = accumulate_params.__get__(m) + flops_extra_repr = flops_repr.__get__(m) + if m.extra_repr != flops_extra_repr: + m.original_extra_repr = m.extra_repr + m.extra_repr = flops_extra_repr + assert m.extra_repr != m.original_extra_repr + + def del_extra_repr(m): + if hasattr(m, 'original_extra_repr'): + m.extra_repr = m.original_extra_repr + del m.original_extra_repr + if hasattr(m, 'accumulate_flops'): + del m.accumulate_flops + + model.apply(add_extra_repr) + print(model, file=ost, flush=flush) + model.apply(del_extra_repr) + + +def accumulate_sub_module_flops_params(model): + """Accumulate FLOPs and params for each module in the model. Each module in + the model will have the `__flops__` and `__params__` parameters. + + Args: + model (nn.Module): The model to be accumulated. 
+ """ + + def accumulate_params(module): + if is_supported_instance(module): + return module.__params__ + else: + sum = 0 + for m in module.children(): + sum += accumulate_params(m) + return sum + + def accumulate_flops(module): + if is_supported_instance(module): + return module.__flops__ / model.__batch_counter__ + else: + sum = 0 + for m in module.children(): + sum += accumulate_flops(m) + return sum + + for module in model.modules(): + _flops = accumulate_flops(module) + _params = accumulate_params(module) + module.__flops__ = _flops + module.__params__ = _params + + +def get_model_parameters_number(model): + """Calculate parameter number of a model. + + Args: + model (nn.module): The model for parameter number calculation. + Returns: + float: Parameter number of the model. + """ + num_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + return num_params + + +def add_flops_params_counting_methods(net_main_module): + # adding additional methods to the existing module object, + # this is done this way so that each function has access to self object + net_main_module.start_flops_params_count = start_flops_params_count.__get__( # noqa: E501 + net_main_module) + net_main_module.stop_flops_params_count = stop_flops_params_count.__get__( + net_main_module) + net_main_module.reset_flops_params_count = reset_flops_params_count.__get__( # noqa: E501 + net_main_module) + net_main_module.compute_average_flops_params_cost = compute_average_flops_params_cost.__get__( # noqa: E501 + net_main_module) + + net_main_module.reset_flops_params_count() + + return net_main_module + + +def compute_average_flops_params_cost(self): + """Compute average FLOPs and Params cost. + + A method to compute average FLOPs cost, which will be available after + `add_flops_params_counting_methods()` is called on a desired net object. + Returns: + float: Current mean flops consumption per image. + """ + batches_count = self.__batch_counter__ + flops_sum = 0 + params_sum = 0 + for module in self.modules(): + if is_supported_instance(module): + flops_sum += module.__flops__ + params_sum += module.__params__ + return flops_sum / batches_count, params_sum + + +def start_flops_params_count(self, disabled_counters): + """Activate the computation of mean flops and params consumption per image. + + A method to activate the computation of mean flops consumption per image. + which will be available after ``add_flops_params_counting_methods()`` is + called on a desired net object. It should be called before running the + network. + """ + add_batch_counter_hook_function(self) + + def add_flops_params_counter_hook_function(module): + if is_supported_instance(module): + if hasattr(module, '__flops_params_handle__'): + return + + else: + counter_type = get_counter_type(module) + if (disabled_counters is None + or counter_type not in disabled_counters): + counter = TASK_UTILS.build( + dict(type=counter_type, _scope_='mmrazor')) + handle = module.register_forward_hook( + counter.add_count_hook) + + module.__flops_params_handle__ = handle + else: + return + + self.apply(partial(add_flops_params_counter_hook_function)) + + +def stop_flops_params_count(self): + """Stop computing the mean flops and params consumption per image. + + A method to stop computing the mean flops consumption per image, which will + be available after ``add_flops_params_counting_methods()`` is called on a + desired net object. It can be called to pause the computation whenever. 
+ """ + remove_batch_counter_hook_function(self) + self.apply(remove_flops_params_counter_hook_function) + + +def reset_flops_params_count(self): + """Reset statistics computed so far. + + A method to Reset computed statistics, which will be available after + `add_flops_params_counting_methods()` is called on a desired net object. + """ + add_batch_counter_variables_or_reset(self) + self.apply(add_flops_params_counter_variable_or_reset) + + +# ---- Internal functions +def empty_flops_params_counter_hook(module, input, output): + module.__flops__ += 0 + module.__params__ += 0 + + +def add_batch_counter_variables_or_reset(module): + + module.__batch_counter__ = 0 + + +def add_batch_counter_hook_function(module): + if hasattr(module, '__batch_counter_handle__'): + return + + handle = module.register_forward_hook(batch_counter_hook) + module.__batch_counter_handle__ = handle + + +def batch_counter_hook(module, input, output): + batch_size = 1 + if len(input) > 0: + # Can have multiple inputs, getting the first one + input = input[0] + batch_size = len(input) + else: + pass + print('Warning! No positional inputs found for a module, ' + 'assuming batch size is 1.') + module.__batch_counter__ += batch_size + + +def remove_batch_counter_hook_function(module): + if hasattr(module, '__batch_counter_handle__'): + module.__batch_counter_handle__.remove() + del module.__batch_counter_handle__ + + +def add_flops_params_counter_variable_or_reset(module): + if is_supported_instance(module): + if hasattr(module, '__flops__') or hasattr(module, '__params__'): + print('Warning: variables __flops__ or __params__ are already ' + 'defined for the module' + type(module).__name__ + + ' ptflops can affect your code!') + module.__flops__ = 0 + module.__params__ = 0 + + +def get_counter_type(module): + return module.__class__.__name__ + 'Counter' + + +def is_supported_instance(module): + if get_counter_type(module) in TASK_UTILS._module_dict.keys(): + return True + return False + + +def remove_flops_params_counter_hook_function(module): + if hasattr(module, '__flops_params_handle__'): + module.__flops_params_handle__.remove() + del module.__flops_params_handle__ + if hasattr(module, '__flops__'): + del module.__flops__ + if hasattr(module, '__params__'): + del module.__params__ diff --git a/mmrazor/models/task_modules/estimators/counters/latency_counter.py b/mmrazor/models/task_modules/estimators/counters/latency_counter.py new file mode 100644 index 00000000..55a145d0 --- /dev/null +++ b/mmrazor/models/task_modules/estimators/counters/latency_counter.py @@ -0,0 +1,88 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import logging
+import time
+
+import torch
+from mmengine.logging import print_log
+
+
+def repeat_measure_inference_speed(model,
+                                   resource_args,
+                                   max_iter: int = 100,
+                                   log_interval: int = 100,
+                                   repeat_num: int = 1) -> float:
+    """Repeat the speed measurement multiple times and return the average
+    latency (ms/img) for more precise results."""
+    assert repeat_num >= 1
+
+    fps_list = []
+
+    for _ in range(repeat_num):
+
+        fps_list.append(
+            measure_inference_speed(model, resource_args, max_iter,
+                                    log_interval))
+
+    if repeat_num > 1:
+        fps_list_ = [round(fps, 1) for fps in fps_list]
+        times_per_img_list = [round(1000 / fps, 1) for fps in fps_list]
+        mean_fps_ = sum(fps_list_) / len(fps_list_)
+        mean_times_per_img = sum(times_per_img_list) / len(times_per_img_list)
+        print_log(
+            f'Overall fps: {fps_list_}[{mean_fps_:.1f}] img / s, '
+            f'times per image: '
+            f'{times_per_img_list}[{mean_times_per_img:.1f}] ms/img',
+            logger='current',
+            level=logging.DEBUG)
+        return mean_times_per_img
+
+    latency = round(1000 / fps_list[0], 1)
+    return latency
+
+
+def measure_inference_speed(model, resource_args, max_iter: int,
+                            log_interval: int) -> float:
+    # the first several iterations may be very slow, so skip them
+    num_warmup = 5
+    pure_inf_time = 0.0
+    fps = 0.0
+    data = dict()
+    if next(model.parameters()).is_cuda:
+        device = 'cuda'
+    else:
+        raise NotImplementedError(
+            'Measuring latency on CPU is not supported yet.')
+    # benchmark with `max_iter` images and take the average
+    for i in range(1, max_iter):
+        if device == 'cuda':
+            data = torch.rand(resource_args['input_shape']).cuda()
+            torch.cuda.synchronize()
+            start_time = time.perf_counter()
+
+            with torch.no_grad():
+                model(data)
+
+            torch.cuda.synchronize()
+            elapsed = time.perf_counter() - start_time
+
+            if i >= num_warmup:
+                pure_inf_time += elapsed
+                if (i + 1) % log_interval == 0:
+                    fps = (i + 1 - num_warmup) / pure_inf_time
+                    print_log(
+                        f'Done image [{i + 1:<3}/ {max_iter}], '
+                        f'fps: {fps:.1f} img / s, '
+                        f'times per image: {1000 / fps:.1f} ms / img',
+                        logger='current',
+                        level=logging.DEBUG)
+
+            if (i + 1) == max_iter:
+                fps = (i + 1 - num_warmup) / pure_inf_time
+                print_log(
+                    f'Overall fps: {fps:.1f} img / s, '
+                    f'times per image: {1000 / fps:.1f} ms / img',
+                    logger='current',
+                    level=logging.DEBUG)
+                break
+
+    torch.cuda.empty_cache()
+
+    return fps
diff --git a/mmrazor/models/task_modules/estimators/counters/op_counters/__init__.py b/mmrazor/models/task_modules/estimators/counters/op_counters/__init__.py
new file mode 100644
index 00000000..6e33babe
--- /dev/null
+++ b/mmrazor/models/task_modules/estimators/counters/op_counters/__init__.py
@@ -0,0 +1,22 @@
+# Copyright (c) OpenMMLab. All rights reserved.
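A short usage sketch for the latency helper above (CUDA only, as enforced by the `NotImplementedError` in `measure_inference_speed`; the returned value is in ms/img):

import torch.nn as nn

from mmrazor.models.task_modules.estimators.counters import \
    repeat_measure_inference_speed

model = nn.Conv2d(3, 8, 3).cuda()
latency = repeat_measure_inference_speed(
    model,
    resource_args=dict(input_shape=(1, 3, 224, 224)),
    max_iter=100,
    repeat_num=2)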
+from .activation_layer_counter import (ELUCounter, LeakyReLUCounter, + PReLUCounter, ReLU6Counter, ReLUCounter) +from .base_counter import BaseCounter +from .conv_layer_counter import Conv1dCounter, Conv2dCounter, Conv3dCounter +from .deconv_layer_counter import ConvTranspose2dCounter +from .linear_layer_counter import LinearCounter +from .norm_layer_counter import (BatchNorm1dCounter, BatchNorm2dCounter, + BatchNorm3dCounter, GroupNormCounter, + InstanceNorm1dCounter, InstanceNorm2dCounter, + InstanceNorm3dCounter, LayerNormCounter) +from .pooling_layer_counter import * # noqa: F403, F405, F401 +from .upsample_layer_counter import UpsampleCounter + +__all__ = [ + 'ReLUCounter', 'PReLUCounter', 'ELUCounter', 'LeakyReLUCounter', + 'ReLU6Counter', 'BatchNorm1dCounter', 'BatchNorm2dCounter', + 'BatchNorm3dCounter', 'Conv1dCounter', 'Conv2dCounter', 'Conv3dCounter', + 'ConvTranspose2dCounter', 'UpsampleCounter', 'LinearCounter', + 'GroupNormCounter', 'InstanceNorm1dCounter', 'InstanceNorm2dCounter', + 'InstanceNorm3dCounter', 'LayerNormCounter', 'BaseCounter' +] diff --git a/mmrazor/models/task_modules/estimators/counters/op_counters/activation_layer_counter.py b/mmrazor/models/task_modules/estimators/counters/op_counters/activation_layer_counter.py new file mode 100644 index 00000000..f124c0db --- /dev/null +++ b/mmrazor/models/task_modules/estimators/counters/op_counters/activation_layer_counter.py @@ -0,0 +1,35 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmrazor.registry import TASK_UTILS +from ..flops_params_counter import get_model_parameters_number +from .base_counter import BaseCounter + + +@TASK_UTILS.register_module() +class ReLUCounter(BaseCounter): + """FLOPs/params counter for ReLU series activate function.""" + + @staticmethod + def add_count_hook(module, input, output): + active_elements_count = output.numel() + module.__flops__ += int(active_elements_count) + module.__params__ += get_model_parameters_number(module) + + +@TASK_UTILS.register_module() +class PReLUCounter(ReLUCounter): + pass + + +@TASK_UTILS.register_module() +class ELUCounter(ReLUCounter): + pass + + +@TASK_UTILS.register_module() +class LeakyReLUCounter(ReLUCounter): + pass + + +@TASK_UTILS.register_module() +class ReLU6Counter(ReLUCounter): + pass diff --git a/mmrazor/models/task_modules/estimators/counters/op_counters/base_counter.py b/mmrazor/models/task_modules/estimators/counters/op_counters/base_counter.py new file mode 100644 index 00000000..46baee93 --- /dev/null +++ b/mmrazor/models/task_modules/estimators/counters/op_counters/base_counter.py @@ -0,0 +1,28 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from abc import ABCMeta, abstractclassmethod + + +class BaseCounter(object, metaclass=ABCMeta): + """Base class of all op module counters in `TASK_UTILS`. + + In ResourceEstimator, `XXModuleCounter` is responsible for `XXModule`, + which refers to estimator/flops_params_counter.py::get_counter_type(). + Users can customize a `ModuleACounter` and overwrite the `add_count_hook` + method with a self-defined module `ModuleA`. + """ + + def __init__(self) -> None: + pass + + @staticmethod + @abstractclassmethod + def add_count_hook(module, input, output): + """The main method of a `BaseCounter` which defines the way to + calculate resources(flops/params) of the current module. + + Args: + module (nn.Module): the module to be tested. + input (_type_): input_tensor. Plz refer to `torch forward_hook` + output (_type_): output_tensor. 
Plz refer to `torch forward_hook`
+        """
+        pass
diff --git a/mmrazor/models/task_modules/estimators/counters/op_counters/conv_layer_counter.py b/mmrazor/models/task_modules/estimators/counters/op_counters/conv_layer_counter.py
new file mode 100644
index 00000000..0e9c6c77
--- /dev/null
+++ b/mmrazor/models/task_modules/estimators/counters/op_counters/conv_layer_counter.py
@@ -0,0 +1,57 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+
+from mmrazor.registry import TASK_UTILS
+from .base_counter import BaseCounter
+
+
+class ConvCounter(BaseCounter):
+    """FLOPs/params counter for Conv module series."""
+
+    @staticmethod
+    def add_count_hook(module, input, output):
+        # Can have multiple inputs, getting the first one
+        input = input[0]
+
+        batch_size = input.shape[0]
+        output_dims = list(output.shape[2:])
+
+        kernel_dims = list(module.kernel_size)
+        in_channels = module.in_channels
+        out_channels = module.out_channels
+        groups = module.groups
+
+        filters_per_channel = out_channels / groups
+        conv_per_position_flops = int(
+            np.prod(kernel_dims)) * in_channels * filters_per_channel
+
+        active_elements_count = batch_size * int(np.prod(output_dims))
+
+        overall_conv_flops = conv_per_position_flops * active_elements_count
+        overall_params = conv_per_position_flops
+
+        bias_flops = 0
+        if module.bias is not None:
+            bias_flops = out_channels * active_elements_count
+            overall_params += out_channels
+
+        overall_flops = overall_conv_flops + bias_flops
+
+        module.__flops__ += overall_flops
+        module.__params__ += int(overall_params)
+
+
+@TASK_UTILS.register_module()
+class Conv1dCounter(ConvCounter):
+    pass
+
+
+@TASK_UTILS.register_module()
+class Conv2dCounter(ConvCounter):
+    pass
+
+
+@TASK_UTILS.register_module()
+class Conv3dCounter(ConvCounter):
+    pass
diff --git a/mmrazor/models/task_modules/estimators/counters/op_counters/deconv_layer_counter.py b/mmrazor/models/task_modules/estimators/counters/op_counters/deconv_layer_counter.py
new file mode 100644
index 00000000..73604243
--- /dev/null
+++ b/mmrazor/models/task_modules/estimators/counters/op_counters/deconv_layer_counter.py
@@ -0,0 +1,38 @@
+# Copyright (c) OpenMMLab. All rights reserved.
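All op counters follow the pattern documented in `BaseCounter`: subclass it, implement `add_count_hook`, and register under the `<ModuleClassName>Counter` name that `get_counter_type()` derives. A sketch for a hypothetical custom op (`MyGate` and its rough FLOPs model are invented for illustration):

import torch
import torch.nn as nn

from mmrazor.models.task_modules.estimators.counters import BaseCounter
from mmrazor.registry import TASK_UTILS


class MyGate(nn.Module):  # hypothetical custom op

    def forward(self, x):
        return x * torch.sigmoid(x)


@TASK_UTILS.register_module()
class MyGateCounter(BaseCounter):
    """Picked up automatically because the name is 'MyGate' + 'Counter'."""

    @staticmethod
    def add_count_hook(module, input, output):
        # Roughly one multiply plus one sigmoid per output element.
        module.__flops__ += 2 * output.numel()
        module.__params__ += 0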
+from mmrazor.registry import TASK_UTILS
+from ..flops_params_counter import get_model_parameters_number
+from .base_counter import BaseCounter
+
+
+@TASK_UTILS.register_module()
+class ConvTranspose2dCounter(BaseCounter):
+    """FLOPs/params counter for Deconv module series."""
+
+    @staticmethod
+    def add_count_hook(module, input, output):
+        # Can have multiple inputs, getting the first one
+        input = input[0]
+
+        batch_size = input.shape[0]
+        input_height, input_width = input.shape[2:]
+
+        # TODO: use more common representation
+        kernel_height, kernel_width = module.kernel_size
+        in_channels = module.in_channels
+        out_channels = module.out_channels
+        groups = module.groups
+
+        filters_per_channel = out_channels // groups
+        conv_per_position_flops = (
+            kernel_height * kernel_width * in_channels * filters_per_channel)
+
+        active_elements_count = batch_size * input_height * input_width
+        overall_conv_flops = conv_per_position_flops * active_elements_count
+        bias_flops = 0
+        if module.bias is not None:
+            output_height, output_width = output.shape[2:]
+            bias_flops = out_channels * batch_size * output_height * output_width  # noqa: E501
+        overall_flops = overall_conv_flops + bias_flops
+
+        module.__flops__ += int(overall_flops)
+        module.__params__ += get_model_parameters_number(module)
diff --git a/mmrazor/models/task_modules/estimators/counters/op_counters/linear_layer_counter.py b/mmrazor/models/task_modules/estimators/counters/op_counters/linear_layer_counter.py
new file mode 100644
index 00000000..c4f6ac6e
--- /dev/null
+++ b/mmrazor/models/task_modules/estimators/counters/op_counters/linear_layer_counter.py
@@ -0,0 +1,19 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+
+from mmrazor.registry import TASK_UTILS
+from ..flops_params_counter import get_model_parameters_number
+from .base_counter import BaseCounter
+
+
+@TASK_UTILS.register_module()
+class LinearCounter(BaseCounter):
+    """FLOPs/params counter for Linear operation series."""
+
+    @staticmethod
+    def add_count_hook(module, input, output):
+        input = input[0]
+        # pytorch checks dimensions, so here we don't care much
+        output_last_dim = output.shape[-1]
+        module.__flops__ += int(np.prod(input.shape) * output_last_dim)
+        module.__params__ += get_model_parameters_number(module)
diff --git a/mmrazor/models/task_modules/estimators/counters/op_counters/norm_layer_counter.py b/mmrazor/models/task_modules/estimators/counters/op_counters/norm_layer_counter.py
new file mode 100644
index 00000000..5941f7c0
--- /dev/null
+++ b/mmrazor/models/task_modules/estimators/counters/op_counters/norm_layer_counter.py
@@ -0,0 +1,59 @@
+# Copyright (c) OpenMMLab. All rights reserved.
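As a sanity check on the conv hook above: for `Conv2d(3, 8, 3)` on a (1, 3, 16, 16) input, per-position FLOPs are 3*3*3*8 = 216, the 14x14 output gives 196 active positions, and the bias adds 8*196 more, i.e. 216*196 + 1568 = 43904 FLOPs with 216 + 8 = 224 params. A sketch that checks those numbers:

import torch.nn as nn

from mmrazor.models.task_modules.estimators.counters import \
    get_model_complexity_info

conv = nn.Conv2d(3, 8, 3)
flops, params = get_model_complexity_info(
    conv, input_shape=(1, 3, 16, 16), as_strings=False)
assert (flops, params) == (43904, 224)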
+import numpy as np + +from mmrazor.registry import TASK_UTILS +from ..flops_params_counter import get_model_parameters_number +from .base_counter import BaseCounter + + +class BNCounter(BaseCounter): + """FLOPs/params counter for BatchNormalization series.""" + + @staticmethod + def add_count_hook(module, input, output): + input = input[0] + batch_flops = np.prod(input.shape) + if getattr(module, 'affine', False): + batch_flops *= 2 + module.__flops__ += int(batch_flops) + module.__params__ += get_model_parameters_number(module) + + +@TASK_UTILS.register_module() +class BatchNorm1dCounter(BNCounter): + pass + + +@TASK_UTILS.register_module() +class BatchNorm2dCounter(BNCounter): + pass + + +@TASK_UTILS.register_module() +class BatchNorm3dCounter(BNCounter): + pass + + +@TASK_UTILS.register_module() +class InstanceNorm1dCounter(BNCounter): + pass + + +@TASK_UTILS.register_module() +class InstanceNorm2dCounter(BNCounter): + pass + + +@TASK_UTILS.register_module() +class InstanceNorm3dCounter(BNCounter): + pass + + +@TASK_UTILS.register_module() +class LayerNormCounter(BNCounter): + pass + + +@TASK_UTILS.register_module() +class GroupNormCounter(BNCounter): + pass diff --git a/mmrazor/models/task_modules/estimators/counters/op_counters/pooling_layer_counter.py b/mmrazor/models/task_modules/estimators/counters/op_counters/pooling_layer_counter.py new file mode 100644 index 00000000..c4e94cdc --- /dev/null +++ b/mmrazor/models/task_modules/estimators/counters/op_counters/pooling_layer_counter.py @@ -0,0 +1,76 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np + +from mmrazor.registry import TASK_UTILS +from ..flops_params_counter import get_model_parameters_number +from .base_counter import BaseCounter + + +class PoolCounter(BaseCounter): + """FLOPs/params counter for Pooling series.""" + + @staticmethod + def add_count_hook(module, input, output): + input = input[0] + module.__flops__ += int(np.prod(input.shape)) + module.__params__ += get_model_parameters_number(module) + + +@TASK_UTILS.register_module() +class MaxPool1dCounter(PoolCounter): + pass + + +@TASK_UTILS.register_module() +class MaxPool2dCounter(PoolCounter): + pass + + +@TASK_UTILS.register_module() +class MaxPool3dCounter(PoolCounter): + pass + + +@TASK_UTILS.register_module() +class AvgPool1dCounter(PoolCounter): + pass + + +@TASK_UTILS.register_module() +class AvgPool2dCounter(PoolCounter): + pass + + +@TASK_UTILS.register_module() +class AvgPool3dCounter(PoolCounter): + pass + + +@TASK_UTILS.register_module() +class AdaptiveMaxPool1dCounter(PoolCounter): + pass + + +@TASK_UTILS.register_module() +class AdaptiveMaxPool2dCounter(PoolCounter): + pass + + +@TASK_UTILS.register_module() +class AdaptiveMaxPool3dCounter(PoolCounter): + pass + + +@TASK_UTILS.register_module() +class AdaptiveAvgPool1dCounter(PoolCounter): + pass + + +@TASK_UTILS.register_module() +class AdaptiveAvgPool2dCounter(PoolCounter): + pass + + +@TASK_UTILS.register_module() +class AdaptiveAvgPool3dCounter(PoolCounter): + pass diff --git a/mmrazor/models/task_modules/estimators/counters/op_counters/upsample_layer_counter.py b/mmrazor/models/task_modules/estimators/counters/op_counters/upsample_layer_counter.py new file mode 100644 index 00000000..9442ac56 --- /dev/null +++ b/mmrazor/models/task_modules/estimators/counters/op_counters/upsample_layer_counter.py @@ -0,0 +1,19 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
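The analogous check for `BNCounter` above: `BatchNorm2d(8)` on a (1, 8, 14, 14) input touches 1568 elements, doubled by the affine scale-and-shift, giving 3136 FLOPs and 16 params:

import torch.nn as nn

from mmrazor.models.task_modules.estimators.counters import \
    get_model_complexity_info

bn = nn.BatchNorm2d(8)  # affine=True by default
flops, params = get_model_complexity_info(
    bn, input_shape=(1, 8, 14, 14), as_strings=False)
assert (flops, params) == (3136, 16)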
+from mmrazor.registry import TASK_UTILS +from ..flops_params_counter import get_model_parameters_number +from .base_counter import BaseCounter + + +@TASK_UTILS.register_module() +class UpsampleCounter(BaseCounter): + """FLOPs/params counter for Upsample function.""" + + @staticmethod + def add_count_hook(module, input, output): + output_size = output[0] + batch_size = output_size.shape[0] + output_elements_count = batch_size + for val in output_size.shape[1:]: + output_elements_count *= val + module.__flops__ += int(output_elements_count) + module.__params__ += get_model_parameters_number(module) diff --git a/mmrazor/models/task_modules/estimators/resource_estimator.py b/mmrazor/models/task_modules/estimators/resource_estimator.py new file mode 100644 index 00000000..6d434286 --- /dev/null +++ b/mmrazor/models/task_modules/estimators/resource_estimator.py @@ -0,0 +1,155 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Any, Dict, List, Tuple + +import torch.nn +from mmengine.dist import broadcast_object_list, is_main_process + +from mmrazor.registry import TASK_UTILS +from .base_estimator import BaseEstimator +from .counters import (get_model_complexity_info, params_units_convert, + repeat_measure_inference_speed) + + +@TASK_UTILS.register_module() +class ResourceEstimator(BaseEstimator): + """Estimator for calculating the resources consume. + + Args: + default_shape (tuple): Input data's default shape, for calculating + resources consume. Defaults to (1, 3, 224, 224) + units (str): Resource units. Defaults to 'M'. + disabled_counters (list): List of disabled spec op counters. + Defaults to None. + NOTE: disabled_counters contains the op counter class names + in estimator.op_counters that require to be disabled, + such as 'ConvCounter', 'BatchNorm2dCounter', ... + + Examples: + >>> # direct calculate resource consume of nn.Conv2d + >>> conv2d = nn.Conv2d(3, 32, 3) + >>> estimator = ResourceEstimator() + >>> estimator.estimate( + ... model=conv2d, + ... resource_args=dict(input_shape=(1, 3, 64, 64))) + {'flops': 3.444, 'params': 0.001, 'latency': 0.0} + + >>> # calculate resources of custom modules + >>> class CustomModule(nn.Module): + ... + ... def __init__(self) -> None: + ... super().__init__() + ... + ... def forward(self, x): + ... return x + ... + >>> @TASK_UTILS.register_module() + ... class CustomModuleCounter(BaseCounter): + ... + ... @staticmethod + ... def add_count_hook(module, input, output): + ... module.__flops__ += 1000000 + ... module.__params__ += 700000 + ... + >>> model = CustomModule() + >>> estimator.estimate( + ... model=model, + ... resource_args=dict(input_shape=(1, 3, 64, 64))) + {'flops': 1.0, 'params': 0.7, 'latency': 0.0} + ... + >>> # calculate resources of custom modules with disable_counters + >>> estimator.estimate( + ... model=model, + ... resource_args=dict( + ... input_shape=(1, 3, 64, 64), + ... disabled_counters=['CustomModuleCounter'])) + {'flops': 0.0, 'params': 0.0, 'latency': 0.0} + + >>> # calculate resources of mmrazor.models + NOTE: check 'EstimateResourcesHook' in + mmrazor.engine.hooks.estimate_resources_hook for details. 
+ """ + + def __init__(self, + default_shape: Tuple = (1, 3, 224, 224), + units: str = 'M', + disabled_counters: List[str] = [], + as_strings: bool = False, + measure_inference: bool = False): + super().__init__(default_shape, units, disabled_counters, as_strings, + measure_inference) + + def estimate( + self, model: torch.nn.Module, resource_args: Dict[str, Any] = dict() + ) -> Dict[str, Any]: + """Estimate the resources(flops/params/latency) of the given model. + + Args: + model: The measured model. + resource_args (Dict[str, float]): Args for resources estimation. + NOTE: resource_args have the same items() as the init cfgs. + + Returns: + Dict[str, str]): A dict that containing resource results(flops, + params and latency). + """ + resource_metrics = dict() + if is_main_process(): + measure_inference = resource_args.pop('measure_inference', False) + if 'input_shape' not in resource_args.keys(): + resource_args['input_shape'] = self.default_shape + if 'disabled_counters' not in resource_args.keys(): + resource_args['disabled_counters'] = self.disabled_counters + model.eval() + flops, params = get_model_complexity_info(model, **resource_args) + if measure_inference: + latency = repeat_measure_inference_speed( + model, resource_args, max_iter=100, repeat_num=2) + else: + latency = 0.0 + as_strings = resource_args.get('as_strings', self.as_strings) + if as_strings and self.units is not None: + raise ValueError('Set units to None, when as_trings=True.') + if self.units is not None: + flops = params_units_convert(flops, self.units) + params = params_units_convert(params, self.units) + resource_metrics.update({ + 'flops': flops, + 'params': params, + 'latency': latency + }) + results = [resource_metrics] + else: + results = [None] # type: ignore + + broadcast_object_list(results) + + return results[0] + + def estimate_spec_modules( + self, model: torch.nn.Module, resource_args: Dict[str, Any] = dict() + ) -> Dict[str, float]: + """Estimate the resources(flops/params/latency) of the spec modules. + + Args: + model: The measured model. + resource_args (Dict[str, float]): Args for resources estimation. + NOTE: resource_args have the same items() as the init cfgs. + + Returns: + Dict[str, float]): A dict that containing resource results(flops, + params) of each modules in resource_args['spec_modules']. + """ + assert 'spec_modules' in resource_args, \ + 'spec_modules is required when calling estimate_spec_modules().' 
+ + resource_args.pop('measure_inference', False) + if 'input_shape' not in resource_args.keys(): + resource_args['input_shape'] = self.default_shape + if 'disabled_counters' not in resource_args.keys(): + resource_args['disabled_counters'] = self.disabled_counters + + model.eval() + spec_modules_resources = get_model_complexity_info( + model, **resource_args) + + return spec_modules_resources diff --git a/mmrazor/registry/registry.py b/mmrazor/registry/registry.py index c598fcb1..1e066c37 100644 --- a/mmrazor/registry/registry.py +++ b/mmrazor/registry/registry.py @@ -40,7 +40,7 @@ def build_razor_model_from_cfg( # TODO relay on mmengine:HAOCHENYE/config_new_feature if cfg.get('cfg_path', None) and not cfg.get('type', None): from mmengine.config import get_model - model = get_model(**cfg) + model = get_model(**cfg) # type: ignore return model from mmrazor.structures import load_fix_subnet diff --git a/mmrazor/structures/subnet/__init__.py b/mmrazor/structures/subnet/__init__.py index b64cf88e..fa3c9fae 100644 --- a/mmrazor/structures/subnet/__init__.py +++ b/mmrazor/structures/subnet/__init__.py @@ -1,8 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. from .candidate import Candidates -from .estimators import FlopsEstimator from .fix_subnet import export_fix_subnet, load_fix_subnet -__all__ = [ - 'FlopsEstimator', 'load_fix_subnet', 'export_fix_subnet', 'Candidates' -] +__all__ = ['load_fix_subnet', 'export_fix_subnet', 'Candidates'] diff --git a/mmrazor/structures/subnet/estimators/__init__.py b/mmrazor/structures/subnet/estimators/__init__.py deleted file mode 100644 index 51ba019b..00000000 --- a/mmrazor/structures/subnet/estimators/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -from .flops import FlopsEstimator - -__all__ = ['FlopsEstimator'] diff --git a/mmrazor/structures/subnet/estimators/flops.py b/mmrazor/structures/subnet/estimators/flops.py deleted file mode 100644 index 65153988..00000000 --- a/mmrazor/structures/subnet/estimators/flops.py +++ /dev/null @@ -1,270 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import copy -import sys -from functools import wraps -from typing import IO, Callable, Dict, Iterable, Optional, Tuple, Type - -from mmcv.cnn.utils import flops_counter as mmcv_flops_counter -from mmcv.cnn.utils import get_model_complexity_info -from torch.nn import Module - -from mmrazor.utils import ValidFixMutable -from ..fix_subnet import load_fix_subnet - - -class FlopsEstimator: - """An estimator to help calculate flops of module. - - FlopsEstimator is based on flops counter in mmcv, it can be directly used - to calculate flops of module or calculate flops of subnet. Also, it - provides api for adding flops counter hook to custom modules that are not - supported by mmcv. - - Examples: - >>> # direct calculate flops of nn.Conv2d - >>> conv2d = nn.Conv2d(3, 32, 3) - >>> FlopsEstimator.get_model_complexity_info( - ... conv2d, - ... input_shape=[3, 224, 224], - ... print_per_layer_stat=False) - ('0.04 GFLOPs', '896') - - >>> # calculate flops of custom modules - >>> class FoolAddConstant(nn.Module): - ... - ... def __init__(self, p: float = 0.1) -> None: - ... super().__init__() - ... - ... self.register_parameter( - ... name='p', - ... param=Parameter(torch.tensor(p, dtype=torch.float32))) - ... - ... def forward(self, x: Tensor) -> Tensor: - ... return x + self.p - ... - >>> def fool_add_constant_flops_counter_hook( - ... add_constant_module: nn.Module, - ... input: Tensor, - ... 
output: Tensor) -> None: - ... add_constant_module.__flops__ = 1e8 - ... - >>> FlopsEstimator.register_module( - ... flops_counter_hook=fool_add_constant_flops_counter_hook, - ... module=FoolAddConstant) - >>> model = FoolAddConstant() - >>> FlopsEstimator.get_model_complexity_info( - ... model=model, - ... input_shape=[3, 224, 224], - ... print_per_layer_stat=False) - ('0.1 GFLOPs', '1') - - >>> # calculate subnet flops - >>> class FoolOneShotModel(nn.Module): - ... - ... def __init__(self) -> None: - ... super().__init__() - ... - ... candidates = nn.ModuleDict({ - ... 'conv3x3': nn.Conv2d(3, 32, 3), - ... 'conv5x5': nn.Conv2d(3, 32, 5)}) - ... self.op = OneShotMutableOP(candidates) - ... self.op.current_choice = 'conv3x3' - ... - ... def forward(self, x: Tensor) -> Tensor: - ... return self.op(x) - ... - >>> model = FoolOneShotModel() - >>> fix_subnet = export_fix_subnet(model) - >>> fix_subnet - FixSubnet(modules={'op': 'conv3x3'}, channels=None) - >>> FlopsEstimator.get_model_complexity_info( - ... supernet=model, - ... fix_subnet=fix_subnet, - ... input_shape=[3, 224, 224], - ... print_per_layer_stat=False) - ('0.04 GFLOPs', '896') - """ - - _mmcv_modules_mapping: Dict[Module, Callable] = \ - mmcv_flops_counter.get_modules_mapping() - _custom_modules_mapping: Dict[Module, Callable] = {} - - @staticmethod - def get_model_complexity_info( - model: Module, - fix_mutable: Optional[ValidFixMutable] = None, - input_shape: Iterable[int] = (3, 224, 224), - input_constructor: Optional[Callable] = None, - print_per_layer_stat: bool = True, - as_strings: bool = True, - flush: bool = False, - ost: IO = sys.stdout) -> Tuple: - """Get complexity information of model. - - This method is based on ``get_model_complexity_info`` of mmcv. It can - calculate FLOPs and parameter counts of a model with corresponding - input shape. It can also print complexity information for each layer - in a model. - - Args: - model (torch.nn.Module): The model for complexity calculation. - fix_mutable (ValidFixMutable, optional): The config of fixed - subnet. When this argument is specified, the function will - return complexity information of the subnet. Default: None. - input_shape (Iterable[int]): Input shape used for calculation. - print_per_layer_stat (bool): Whether to print complexity - information for each layer in a model. Default: True. - as_strings (bool): Output FLOPs and params counts in a string form. - Default: True. - input_constructor (Callable, optional): If specified, it takes a - callable method that generates input. otherwise, it will - generate a random tensor with input shape to calculate FLOPs. - Default: None. - flush (bool): same as that in :func:`print`. Default: False. - ost (stream): same as ``file`` param in :func:`print`. - Default: sys.stdout. - - Returns: - tuple[float | str]: If ``as_strings`` is set to True, it will - return FLOPs and parameter counts in a string format. - otherwise, it will return those in a float number format. - """ - copied_model = copy.deepcopy(model) - if fix_mutable is not None: - load_fix_subnet(copied_model, fix_mutable) - - return get_model_complexity_info( - model=copied_model, - input_shape=input_shape, - input_constructor=input_constructor, - print_per_layer_stat=print_per_layer_stat, - as_strings=as_strings, - flush=flush, - ost=ost) - - @classmethod - def register_module(cls, - flops_counter_hook: Callable, - module: Optional[Type[Module]] = None, - force: bool = False) -> Optional[Callable]: - """Register a module with flops_counter_hook. 
- - Args: - flops_counter (Callable): The hook that specifies how to calculate - flops of given module. - module (torch.nn.Module, optional): Module class to be registered. - Defaults to None. - force (bool): Whether to override an existing flops_counter_hook - with the same module. Default to False. - """ - if not isinstance(force, bool): - raise TypeError(f'force must be a boolean, but got {type(force)}') - - if not (module is None or issubclass(module, Module)): - raise TypeError( - 'module must be None, an subclass of torch.nn.Module, ' - f'but got {module}') - - if not callable(flops_counter_hook): - raise TypeError('flops_counter_hook must be Callable, ' - f'but got {type(flops_counter_hook)}') - - if module is not None: - return cls._register_module( - module=module, - flops_counter_hook=flops_counter_hook, - force=force) - - def _register(module: Type[Module]) -> None: - cls._register_module( - module=module, - flops_counter_hook=flops_counter_hook, - force=force) - - return _register - - @classmethod - def remove_custom_module(cls, module: Type[Module]) -> None: - """Remove a registered module. - - Args: - module (torch.nn.Module): Module class to be removed. - """ - if module not in cls._custom_modules_mapping: - raise KeyError(f'{module} not in custom module mapping') - - del cls._custom_modules_mapping[module] - - @classmethod - def clear_custom_module(cls) -> None: - """Remove all registered modules.""" - cls._custom_modules_mapping.clear() - - @classmethod - def _register_module(cls, - flops_counter_hook: Callable, - module: Type[Module], - force: bool = False) -> None: - """Register a module with flops_counter_hook. - - Args: - flops_counter (Callable): The hook that specifies how to calculate - flops of given module. - module (torch.nn.Module, optional): Module class to be registered. - Defaults to None. - force (bool): Whether to override an existing flops_counter_hook - with the same module. Default to False. - """ - if not force and module in cls.get_modules_mapping(): - raise KeyError(f'{module} is already registered') - cls._custom_modules_mapping[module] = flops_counter_hook - - @classmethod - def get_modules_mapping(cls) -> Dict[Module, Callable]: - """Get all modules with their corresponding flops counter hook. - - Returns: - Dict[Module, Callable]: Modules with their corresponding flops - counter hook. - """ - return {**cls._mmcv_modules_mapping, **cls._custom_modules_mapping} - - @classmethod - def get_custom_modules_mapping(cls) -> Dict[Module, Callable]: - """Get customed modules with their corresponding flops counter hook. - - Returns: - Dict[Module, Callable]: Modules with their corresponding flops - counter hook. - """ - return {**cls._custom_modules_mapping} - - @classmethod - def _mmcv_modules_mappings_wrapper( - cls, mmcv_get_modules_mapping: Callable) -> Callable: - """Wrapper for ``get_modules_mapping`` function in mmcv. - - Args: - mmcv_get_modules_mapping (Callable): ``get_modules_mapping`` - function in mmcv. - - Returns: - Callable: Wrapped ``get_modules_mapping`` function. 
-        """
-
-        @wraps(mmcv_get_modules_mapping)
-        def wrapper() -> Dict[Module, Callable]:
-            mmcv_modules_mapping: Dict[Module, Callable] = \
-                mmcv_get_modules_mapping()
-
-            # TODO
-            # use | operator
-            # | operator only be supported in python 3.9.0 or greater
-            return {**mmcv_modules_mapping, **cls._custom_modules_mapping}
-
-        return wrapper
-
-
-mmcv_flops_counter.get_modules_mapping = \
-    FlopsEstimator._mmcv_modules_mappings_wrapper(
-        mmcv_flops_counter.get_modules_mapping)
diff --git a/mmrazor/utils/setup_env.py b/mmrazor/utils/setup_env.py
index 82e7c9c4..392658f8 100644
--- a/mmrazor/utils/setup_env.py
+++ b/mmrazor/utils/setup_env.py
@@ -71,12 +71,13 @@ def register_all_modules(init_default_scope: bool = True) -> None:
         DefaultScope.get_instance('mmrazor', scope_name='mmrazor')
         return
     current_scope = DefaultScope.get_current_instance()
-    if current_scope.scope_name != 'mmrazor':
-        warnings.warn('The current default scope '
-                      f'"{current_scope.scope_name}" is not "mmrazor", '
-                      '`register_all_modules` will force the current'
-                      'default scope to be "mmrazor". If this is not '
-                      'expected, please set `init_default_scope=False`.')
+    if current_scope.scope_name != 'mmrazor':  # type: ignore
+        warnings.warn(
+            'The current default scope '  # type: ignore
+            f'"{current_scope.scope_name}" is not '
+            '"mmrazor", `register_all_modules` will force the current '
+            'default scope to be "mmrazor". If this is not expected, '
+            'please set `init_default_scope=False`.')
     # avoid name conflict
     new_instance_name = f'mmrazor-{datetime.datetime.now()}'
     DefaultScope.get_instance(new_instance_name, scope_name='mmrazor')
diff --git a/tests/test_models/test_subnet/test_estimators/test_flops.py b/tests/test_models/test_subnet/test_estimators/test_flops.py
deleted file mode 100644
index c090a611..00000000
--- a/tests/test_models/test_subnet/test_estimators/test_flops.py
+++ /dev/null
@@ -1,231 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import copy
-from unittest import TestCase
-
-import pytest
-import torch
-from torch import Tensor
-from torch.nn import Conv2d, Module, Parameter
-
-from mmrazor.models import OneShotMutableModule
-from mmrazor.registry import MODELS
-from mmrazor.structures import FlopsEstimator, export_fix_subnet
-
-_FIRST_STAGE_MUTABLE = dict(
-    type='OneShotMutableOP',
-    candidates=dict(
-        mb_k3e1=dict(
-            type='MBBlock',
-            kernel_size=3,
-            expand_ratio=1,
-            norm_cfg=dict(type='BN'),
-            act_cfg=dict(type='ReLU6')))))
-
-_OTHER_STAGE_MUTABLE = dict(
-    type='OneShotMutableOP',
-    candidates=dict(
-        mb_k3e3=dict(
-            type='MBBlock',
-            kernel_size=3,
-            expand_ratio=3,
-            norm_cfg=dict(type='BN'),
-            act_cfg=dict(type='ReLU6')),
-        mb_k5e3=dict(
-            type='MBBlock',
-            kernel_size=5,
-            expand_ratio=3,
-            norm_cfg=dict(type='BN'),
-            act_cfg=dict(type='ReLU6')),
-        identity=dict(type='Identity')))
-
-ARCHSETTING_CFG = [
-    # Parameters to build layers. 4 parameters are needed to construct a
-    # layer, from left to right: channel, num_blocks, stride, mutable cfg.
- [16, 1, 1, _FIRST_STAGE_MUTABLE], - [24, 2, 2, _OTHER_STAGE_MUTABLE], - [32, 3, 2, _OTHER_STAGE_MUTABLE], - [64, 4, 2, _OTHER_STAGE_MUTABLE], - [96, 3, 1, _OTHER_STAGE_MUTABLE], - [160, 3, 2, _OTHER_STAGE_MUTABLE], - [320, 1, 1, _OTHER_STAGE_MUTABLE] -] - -NORM_CFG = dict(type='BN') -BACKBONE_CFG = dict( - type='mmrazor.SearchableMobileNet', - first_channels=32, - last_channels=1280, - widen_factor=1.0, - norm_cfg=NORM_CFG, - arch_setting=ARCHSETTING_CFG) - - -class FoolAddConstant(Module): - - def __init__(self, p: float = 0.1) -> None: - super().__init__() - - self.register_parameter( - name='p', param=Parameter(torch.tensor(p, dtype=torch.float32))) - - def forward(self, x: Tensor) -> Tensor: - return x + self.p - - -class FoolConv2d(Module): - - def __init__(self) -> None: - super().__init__() - - self.conv2d = Conv2d(3, 32, 3) - - def forward(self, x: Tensor) -> Tensor: - return self.conv2d(x) - - -class FoolConvModule(Module): - - def __init__(self) -> None: - super().__init__() - - self.add_constant = FoolAddConstant(0.1) - self.conv2d = FoolConv2d() - - def forward(self, x: Tensor) -> Tensor: - x = self.add_constant(x) - - return self.conv2d(x) - - -class TestFlopsEstimator(TestCase): - - def sample_choice(self, model: Module) -> None: - for module in model.modules(): - if isinstance(module, OneShotMutableModule): - module.current_choice = module.sample_choice() - - def test_get_model_complexity_info(self) -> None: - fool_conv2d = FoolConv2d() - flops_count, params_count = FlopsEstimator.get_model_complexity_info( - fool_conv2d, as_strings=False) - - self.assertGreater(flops_count, 0) - self.assertGreater(params_count, 0) - - def test_register_module(self) -> None: - fool_add_constant = FoolConvModule() - copied_module = copy.deepcopy(fool_add_constant) - flops_count, params_count = FlopsEstimator.get_model_complexity_info( - copied_module, as_strings=False) - - def fool_add_constant_flops_counter_hook(add_constant_module: Module, - input: Tensor, - output: Tensor) -> None: - add_constant_module.__flops__ = 1e6 - - # test register directly - FlopsEstimator.register_module( - flops_counter_hook=fool_add_constant_flops_counter_hook, - module=FoolAddConstant) - copied_module = copy.deepcopy(fool_add_constant) - flops_count_after_registered, params_count_after_registered = \ - FlopsEstimator.get_model_complexity_info( - model=copied_module, as_strings=False) - self.assertEqual(flops_count_after_registered - flops_count, 1e6) - self.assertEqual(params_count_after_registered - params_count, 0) - FlopsEstimator.remove_custom_module(FoolAddConstant) - - # test register using decorator - FlopsEstimator.register_module( - flops_counter_hook=fool_add_constant_flops_counter_hook)( - FoolAddConstant) - copied_module = copy.deepcopy(fool_add_constant) - flops_count_after_registered, params_count_after_registered = \ - FlopsEstimator.get_model_complexity_info( - model=copied_module, as_strings=False) - self.assertEqual(flops_count_after_registered - flops_count, 1e6) - self.assertEqual(params_count_after_registered - params_count, 0) - - FlopsEstimator.remove_custom_module(FoolAddConstant) - - def test_register_module_wrong_parameter(self) -> None: - - def fool_flops_counter_hook(module: Module, input: Tensor, - output: Tensor) -> None: - return - - with pytest.raises(TypeError): - FlopsEstimator.register_module( - flops_counter_hook=fool_flops_counter_hook, force=1) - with pytest.raises(TypeError): - FlopsEstimator.register_module( - flops_counter_hook=fool_flops_counter_hook, module=list) - 
with pytest.raises(TypeError): - FlopsEstimator.register_module( - flops_counter_hook=123, module=FoolAddConstant) - - # test double register - FlopsEstimator.register_module( - flops_counter_hook=fool_flops_counter_hook, module=FoolAddConstant) - with pytest.raises(KeyError): - FlopsEstimator.register_module( - flops_counter_hook=fool_flops_counter_hook, - module=FoolAddConstant) - FlopsEstimator.register_module( - flops_counter_hook=fool_flops_counter_hook, - module=FoolAddConstant, - force=True) - - FlopsEstimator.remove_custom_module(FoolAddConstant) - - def test_remove_custom_module(self) -> None: - with pytest.raises(KeyError): - FlopsEstimator.remove_custom_module(FoolAddConstant) - - def fool_flops_counter_hook(module: Module, input: Tensor, - output: Tensor) -> None: - return - - FlopsEstimator.register_module( - flops_counter_hook=fool_flops_counter_hook, module=FoolAddConstant) - - FlopsEstimator.remove_custom_module(FoolAddConstant) - - def test_clear_custom_module(self) -> None: - - def fool_flops_counter_hook(module: Module, input: Tensor, - output: Tensor) -> None: - return - - FlopsEstimator.register_module( - flops_counter_hook=fool_flops_counter_hook, module=FoolAddConstant) - FlopsEstimator.register_module( - flops_counter_hook=fool_flops_counter_hook, module=FoolConvModule) - - FlopsEstimator.clear_custom_module() - self.assertEqual(FlopsEstimator.get_custom_modules_mapping(), {}) - - def test_get_model_complexity_info_subnet(self) -> None: - model = MODELS.build(BACKBONE_CFG) - self.sample_choice(model) - copied_model = copy.deepcopy(model) - - flops_count, params_count = FlopsEstimator.get_model_complexity_info( - copied_model, as_strings=False) - - fix_subnet = export_fix_subnet(model) - subnet_flops_count, subnet_params_count = \ - FlopsEstimator.get_model_complexity_info( - model, fix_subnet, as_strings=False) - - self.assertEqual(flops_count, subnet_flops_count) - self.assertGreater(params_count, subnet_params_count) - - # test whether subnet estimate will affect original model - copied_model = copy.deepcopy(model) - flops_count_after_estimate, params_count_after_estimate = \ - FlopsEstimator.get_model_complexity_info( - copied_model, as_strings=False) - - self.assertEqual(flops_count, flops_count_after_estimate) - self.assertEqual(params_count, params_count_after_estimate) diff --git a/tests/test_models/test_task_modules/test_estimators/test_flops_params.py b/tests/test_models/test_task_modules/test_estimators/test_flops_params.py new file mode 100644 index 00000000..99be89bc --- /dev/null +++ b/tests/test_models/test_task_modules/test_estimators/test_flops_params.py @@ -0,0 +1,197 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import copy
+from unittest import TestCase
+
+import pytest
+import torch
+from torch import Tensor
+from torch.nn import Conv2d, Module, Parameter
+
+from mmrazor.models import OneShotMutableModule, ResourceEstimator
+from mmrazor.models.task_modules.estimators.counters import BaseCounter
+from mmrazor.registry import MODELS, TASK_UTILS
+from mmrazor.structures import export_fix_subnet, load_fix_subnet
+
+_FIRST_STAGE_MUTABLE = dict(
+    type='OneShotMutableOP',
+    candidates=dict(
+        mb_k3e1=dict(
+            type='MBBlock',
+            kernel_size=3,
+            expand_ratio=1,
+            norm_cfg=dict(type='BN'),
+            act_cfg=dict(type='ReLU6')))))
+
+_OTHER_STAGE_MUTABLE = dict(
+    type='OneShotMutableOP',
+    candidates=dict(
+        mb_k3e3=dict(
+            type='MBBlock',
+            kernel_size=3,
+            expand_ratio=3,
+            norm_cfg=dict(type='BN'),
+            act_cfg=dict(type='ReLU6')),
+        mb_k5e3=dict(
+            type='MBBlock',
+            kernel_size=5,
+            expand_ratio=3,
+            norm_cfg=dict(type='BN'),
+            act_cfg=dict(type='ReLU6')),
+        identity=dict(type='Identity')))
+
+ARCHSETTING_CFG = [
+    # Parameters to build layers. 4 parameters are needed to construct a
+    # layer, from left to right: channel, num_blocks, stride, mutable cfg.
+    [16, 1, 1, _FIRST_STAGE_MUTABLE],
+    [24, 2, 2, _OTHER_STAGE_MUTABLE],
+    [32, 3, 2, _OTHER_STAGE_MUTABLE],
+    [64, 4, 2, _OTHER_STAGE_MUTABLE],
+    [96, 3, 1, _OTHER_STAGE_MUTABLE],
+    [160, 3, 2, _OTHER_STAGE_MUTABLE],
+    [320, 1, 1, _OTHER_STAGE_MUTABLE]
+]
+
+NORM_CFG = dict(type='BN')
+BACKBONE_CFG = dict(
+    type='mmrazor.SearchableMobileNet',
+    first_channels=32,
+    last_channels=1280,
+    widen_factor=1.0,
+    norm_cfg=NORM_CFG,
+    arch_setting=ARCHSETTING_CFG)
+
+estimator = ResourceEstimator()
+
+
+class FoolAddConstant(Module):
+
+    def __init__(self, p: float = 0.1) -> None:
+        super().__init__()
+
+        self.register_parameter(
+            name='p', param=Parameter(torch.tensor(p, dtype=torch.float32)))
+
+    def forward(self, x: Tensor) -> Tensor:
+        return x + self.p
+
+
+@TASK_UTILS.register_module()
+class FoolAddConstantCounter(BaseCounter):
+
+    @staticmethod
+    def add_count_hook(module, input, output):
+        module.__flops__ += 1000000
+        module.__params__ += 700000
+
+
+class FoolConv2d(Module):
+
+    def __init__(self) -> None:
+        super().__init__()
+
+        self.conv2d = Conv2d(3, 32, 3)
+
+    def forward(self, x: Tensor) -> Tensor:
+        return self.conv2d(x)
+
+
+class FoolConvModule(Module):
+
+    def __init__(self) -> None:
+        super().__init__()
+
+        self.add_constant = FoolAddConstant(0.1)
+        self.conv2d = FoolConv2d()
+
+    def forward(self, x: Tensor) -> Tensor:
+        x = self.add_constant(x)
+
+        return self.conv2d(x)
+
+
+class TestResourceEstimator(TestCase):
+
+    def sample_choice(self, model: Module) -> None:
+        for module in model.modules():
+            if isinstance(module, OneShotMutableModule):
+                module.current_choice = module.sample_choice()
+
+    def test_estimate(self) -> None:
+        fool_conv2d = FoolConv2d()
+        results = estimator.estimate(
+            model=fool_conv2d,
+            resource_args=dict(input_shape=(1, 3, 224, 224)))
+        flops_count = results['flops']
+        params_count = results['params']
+
+        self.assertGreater(flops_count, 0)
+        self.assertGreater(params_count, 0)
+
+    def test_register_module(self) -> None:
+        fool_add_constant = FoolConvModule()
+        results = estimator.estimate(
+            model=fool_add_constant,
+            resource_args=dict(input_shape=(1, 3, 224, 224)))
+        flops_count = results['flops']
+        params_count = results['params']
+
+        self.assertEqual(flops_count, 45.158)
+        self.assertEqual(params_count, 0.701)
+
+    def test_disable_spec_counter(self) -> None:
+        fool_add_constant = FoolConvModule()
+        rest_results = 
estimator.estimate( + model=fool_add_constant, + resource_args=dict( + input_shape=(1, 3, 224, 224), + disabled_counters=['FoolAddConstantCounter'])) + rest_flops_count = rest_results['flops'] + rest_params_count = rest_results['params'] + + self.assertLess(rest_flops_count, 45.158) + self.assertLess(rest_params_count, 0.701) + + def test_estimate_spec_modules(self) -> None: + fool_add_constant = FoolConvModule() + results = estimator.estimate_spec_modules( + model=fool_add_constant, + resource_args=dict( + input_shape=(1, 3, 224, 224), spec_modules=['add_constant'])) + self.assertGreater(results['add_constant']['flops'], 0) + + with pytest.raises(AssertionError): + results = estimator.estimate_spec_modules( + model=fool_add_constant, + resource_args=dict( + input_shape=(1, 3, 224, 224), spec_modules=['backbone'])) + + def test_estimate_subnet(self) -> None: + resource_args = dict(input_shape=(1, 3, 224, 224)) + model = MODELS.build(BACKBONE_CFG) + self.sample_choice(model) + copied_model = copy.deepcopy(model) + + results = estimator.estimate( + model=copied_model, resource_args=resource_args) + flops_count = results['flops'] + params_count = results['params'] + + fix_subnet = export_fix_subnet(model) + load_fix_subnet(copied_model, fix_subnet) + subnet_results = estimator.estimate( + model=copied_model, resource_args=resource_args) + subnet_flops_count = subnet_results['flops'] + subnet_params_count = subnet_results['params'] + + self.assertEqual(flops_count, subnet_flops_count) + self.assertEqual(params_count, subnet_params_count) + + # test whether subnet estimate will affect original model + copied_model = copy.deepcopy(model) + results_after_estimate = \ + estimator.estimate(model=copied_model, resource_args=resource_args) + flops_count_after_estimate = results_after_estimate['flops'] + params_count_after_estimate = results_after_estimate['params'] + + self.assertEqual(flops_count, flops_count_after_estimate) + self.assertEqual(params_count, params_count_after_estimate) diff --git a/tests/test_runners/test_evolution_search_loop.py b/tests/test_runners/test_evolution_search_loop.py index f90e5c02..9ce59ad0 100644 --- a/tests/test_runners/test_evolution_search_loop.py +++ b/tests/test_runners/test_evolution_search_loop.py @@ -110,9 +110,16 @@ class TestEvolutionSearchLoop(TestCase): self.assertIsInstance(loop, EvolutionSearchLoop) self.assertEqual(loop.candidates, fake_candidates) - @patch('mmrazor.engine.runner.evolution_search_loop.export_fix_subnet') - @patch('mmrazor.structures.FlopsEstimator.get_model_complexity_info') - def test_run_epoch(self, mock_flops, mock_export_fix_subnet): + @patch( + 'mmrazor.engine.runner.evolution_search_loop.export_fix_subnet', + return_value={ + '1': 'choice1', + '2': 'choice2' + }) + @patch( + 'mmrazor.models.task_modules.ResourceEstimator.estimate', + return_value=dict(flops=50.0, params=1.0)) + def test_run_epoch(self, export_fix_subnet, estimate): # test_run_epoch: distributed == False loop_cfg = copy.deepcopy(self.train_cfg) loop_cfg.runner = self.runner @@ -123,8 +130,6 @@ class TestEvolutionSearchLoop(TestCase): self.runner.epoch = 1 self.runner.distributed = False self.runner.work_dir = self.temp_dir - fake_subnet = {'1': 'choice1', '2': 'choice2'} - self.runner.model.sample_subnet = MagicMock(return_value=fake_subnet) loop.run_epoch() self.assertEqual(len(loop.candidates), 4) self.assertEqual(len(loop.top_k_candidates), 2) @@ -136,8 +141,6 @@ class TestEvolutionSearchLoop(TestCase): self.runner.epoch = 1 self.runner.distributed = True 
         self.runner.work_dir = self.temp_dir
-        fake_subnet = {'1': 'choice1', '2': 'choice2'}
-        self.runner.model.sample_subnet = MagicMock(return_value=fake_subnet)
         loop.run_epoch()
         self.assertEqual(len(loop.candidates), 4)
         self.assertEqual(len(loop.top_k_candidates), 2)
@@ -150,10 +153,6 @@ class TestEvolutionSearchLoop(TestCase):
         self.runner.epoch = 1
         self.runner.distributed = True
         self.runner.work_dir = self.temp_dir
-        fake_subnet = {'1': 'choice1', '2': 'choice2'}
-        loop.model.sample_subnet = MagicMock(return_value=fake_subnet)
-        mock_flops.return_value = (50., 1)
-        mock_export_fix_subnet.return_value = fake_subnet
         loop.run_epoch()
         self.assertEqual(len(loop.candidates), 4)
         self.assertEqual(len(loop.top_k_candidates), 2)
diff --git a/tests/test_runners/test_subnet_sampler_loop.py b/tests/test_runners/test_subnet_sampler_loop.py
index 5d6cfc26..15bd76c9 100644
--- a/tests/test_runners/test_subnet_sampler_loop.py
+++ b/tests/test_runners/test_subnet_sampler_loop.py
@@ -191,15 +191,20 @@ class TestGreedySamplerTrainLoop(TestCase):
         self.assertEqual(subnet, fake_subnet)
         self.assertEqual(len(loop.top_k_candidates), loop.top_k - 1)

-    @patch('mmrazor.engine.runner.subnet_sampler_loop.export_fix_subnet')
-    @patch('mmrazor.structures.FlopsEstimator.get_model_complexity_info')
-    def test_run(self, mock_flops, mock_export_fix_subnet):
+    @patch(
+        'mmrazor.engine.runner.subnet_sampler_loop.export_fix_subnet',
+        return_value={
+            '1': 'choice1',
+            '2': 'choice2'
+        })
+    @patch(
+        'mmrazor.models.task_modules.ResourceEstimator.estimate',
+        return_value=dict(flops=50.0, params=1.0))
+    def test_run(self, export_fix_subnet, estimate):
         # test run with flops_range=None
         cfg = copy.deepcopy(self.iter_based_cfg)
         cfg.experiment_name = 'test_run1'
         runner = Runner.from_cfg(cfg)
-        fake_subnet = {'1': 'choice1', '2': 'choice2'}
-        runner.model.sample_subnet = MagicMock(return_value=fake_subnet)
         runner.train()

         self.assertEqual(runner.iter, runner.max_iters)
@@ -210,10 +215,6 @@ class TestGreedySamplerTrainLoop(TestCase):
         cfg.experiment_name = 'test_run2'
         cfg.train_cfg.flops_range = (0, 100)
         runner = Runner.from_cfg(cfg)
-        fake_subnet = {'1': 'choice1', '2': 'choice2'}
-        runner.model.sample_subnet = MagicMock(return_value=fake_subnet)
-        mock_flops.return_value = (50., 1)
-        mock_export_fix_subnet.return_value = fake_subnet
         runner.train()

         self.assertEqual(runner.iter, runner.max_iters)
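
Reviewer note: a minimal usage sketch of the refactored estimator API, mirroring
the behavior exercised in the unit tests above. The toy Sequential model and the
submodule name '0' are illustrative only, not part of this patch:

    import torch.nn as nn
    from mmrazor.models.task_modules import ResourceEstimator

    estimator = ResourceEstimator()
    model = nn.Sequential(nn.Conv2d(3, 32, 3), nn.ReLU())

    # Whole-model resources: flops/params in 'M' units by default;
    # latency stays 0.0 unless measure_inference=True is passed.
    results = estimator.estimate(
        model=model, resource_args=dict(input_shape=(1, 3, 224, 224)))
    print(results['flops'], results['params'], results['latency'])

    # Per-module resources for the listed submodules only ('0' is the
    # first child of the Sequential above).
    spec_results = estimator.estimate_spec_modules(
        model=model,
        resource_args=dict(
            input_shape=(1, 3, 224, 224), spec_modules=['0']))
    print(spec_results['0']['flops'])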