[Feature] Refactor Estimator for computing FLOPs/Params/Latency. (#230)
* Refactor ModelEstimator: 1. add EvaluatorLoop in engine.runners; 2. add estimator for structures (both subnet & supernet); 3. add layer_counter for each op. * fix lint * update estimator: 1. add ResourceEstimator based on BaseEstimator; 2. add notes & examples for ResourceEstimator & EvaluatorLoop usage; 3. fix a bug of latency test. 4. minor changes according to comments. * add UT & fix a bug caused by UT * add docstrings & remove old estimator * update docstrings for op_spec_counters * rename resource_evaluator_val_loop * support adding resource attrs of each submodule in a measured model * fix lint * refactor estimator file structures * support estimating resources for spec modules * rm old UT * update new estimator UT cases * fix traversal range of the model * cancel unit convert in accumulate_sub_module_flops_params * use estimator_cfg to build ResourceEstimator * fix a broadcast bug * delete fixed input_shape * add assertion and string-format-return when measuring spec_modules * add UT for estimating spec_modulespull/237/head
parent
57aec1f730
commit
4b3f8ab69e
mmrazor
models
task_modules
estimators
registry
structures/subnet
estimators
utils
tests
test_models
test_subnet/test_estimators
test_task_modules/test_estimators
|
@ -1,5 +1,5 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .hooks import DumpSubnetHook
|
||||
from .hooks import DumpSubnetHook, EstimateResourcesHook
|
||||
from .optimizers import SeparateOptimWrapperConstructor
|
||||
from .runner import (AutoSlimValLoop, DartsEpochBasedTrainLoop,
|
||||
DartsIterBasedTrainLoop, EvolutionSearchLoop,
|
||||
|
@ -10,5 +10,6 @@ __all__ = [
|
|||
'SeparateOptimWrapperConstructor', 'DumpSubnetHook',
|
||||
'SingleTeacherDistillValLoop', 'DartsEpochBasedTrainLoop',
|
||||
'DartsIterBasedTrainLoop', 'SlimmableValLoop', 'EvolutionSearchLoop',
|
||||
'GreedySamplerTrainLoop', 'AutoSlimValLoop', 'SelfDistillValLoop'
|
||||
'GreedySamplerTrainLoop', 'AutoSlimValLoop', 'SelfDistillValLoop',
|
||||
'EstimateResourcesHook'
|
||||
]
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .dump_subnet_hook import DumpSubnetHook
|
||||
from .estimate_resources_hook import EstimateResourcesHook
|
||||
|
||||
__all__ = ['DumpSubnetHook']
|
||||
__all__ = ['DumpSubnetHook', 'EstimateResourcesHook']
|
||||
|
|
|
@ -0,0 +1,117 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
from typing import Any, Dict, Optional, Sequence
|
||||
|
||||
import torch
|
||||
from mmengine.data import BaseDataElement
|
||||
from mmengine.hooks import Hook
|
||||
from mmengine.registry import HOOKS
|
||||
|
||||
from mmrazor.models.task_modules import ResourceEstimator
|
||||
|
||||
DATA_BATCH = Optional[Sequence[dict]]
|
||||
|
||||
|
||||
@HOOKS.register_module()
class EstimateResourcesHook(Hook):
    """Estimate model resources periodically during validation.

    Args:
        interval (int): The estimation period. If ``by_epoch=True``, interval
            indicates epochs, otherwise it indicates iterations.
            Defaults to -1, which means "never".
        by_epoch (bool): Whether to estimate by epoch or by iteration.
            Defaults to True.
        estimator_cfg (Dict[str, Any], optional): Keyword arguments used for
            building a ``ResourceEstimator``. ``None`` is treated as an
            empty dict. Defaults to None.
        **kwargs: Accepted and ignored.

    Example:
        >>> # add the `EstimateResourcesHook` in custom_hooks as follows:
        >>> custom_hooks = [
        ...     dict(type='mmrazor.EstimateResourcesHook',
        ...          interval=1,
        ...          by_epoch=True,
        ...          estimator_cfg=dict(input_shape=(1, 3, 64, 64)))
        ... ]
    """
    out_dir: str

    priority = 'VERY_LOW'

    def __init__(self,
                 interval: int = -1,
                 by_epoch: bool = True,
                 estimator_cfg: Optional[Dict[str, Any]] = None,
                 **kwargs) -> None:
        self.interval = interval
        self.by_epoch = by_epoch
        # Avoid a shared mutable default; ``None`` means "no extra config".
        self.estimator = ResourceEstimator(**(estimator_cfg or {}))

    def after_val_epoch(self,
                        runner,
                        metrics: Optional[Dict[str, float]] = None) -> None:
        """Estimate model resources after every n val epochs.

        Does nothing when the hook is configured with ``by_epoch=False``.

        Args:
            runner (Runner): The runner of the training process.
            metrics (Dict[str, float], optional): Unused by this hook.
        """
        if not self.by_epoch:
            return

        if self.every_n_epochs(runner, self.interval):
            self.estimate_resources(runner)

    def after_val_iter(self,
                       runner,
                       batch_idx: int,
                       data_batch: DATA_BATCH = None,
                       outputs: Optional[Sequence[BaseDataElement]] = None) \
            -> None:
        """Estimate model resources after every n val iters.

        Does nothing when the hook is configured with ``by_epoch=True``.

        Args:
            runner (Runner): The runner of the training process.
            batch_idx (int): Unused by this hook.
            data_batch (Sequence[dict], optional): Unused by this hook.
            outputs (Sequence[BaseDataElement], optional): Unused by this
                hook.
        """
        if self.by_epoch:
            return

        # NOTE(review): the period is checked with ``every_n_train_iters``
        # (not with ``batch_idx``), although this runs once per validation
        # batch -- confirm this is the intended trigger.
        if self.every_n_train_iters(runner, self.interval):
            self.estimate_resources(runner)

    def estimate_resources(self, runner) -> None:
        """Estimate model resources (latency/flops/params) and log them.

        The model is unwrapped from ``runner.model.module`` when the runner
        is distributed. If the model reports ``is_supernet``, the estimate is
        made on the subnet returned by :meth:`export_subnet` instead.

        Args:
            runner (Runner): The runner holding the model and the logger.
        """
        model = runner.model.module if runner.distributed else runner.model

        # TODO confirm the state judgement.
        if getattr(model, 'is_supernet', False):
            model = self.export_subnet(model)

        resource_metrics = self.estimator.estimate(model)
        runner.logger.info(f'Estimate model resources: {resource_metrics}')

    def export_subnet(self, model) -> torch.nn.Module:
        """Export current best subnet.

        NOTE: This method is called when it comes to those NAS algorithms that
        require building a supernet for training.

        For those algorithms, measuring subnet resources is more meaningful
        than supernet during validation, therefore this method is required to
        get the current searched subnet from the supernet.

        NOTE: ``arch_weights`` attributes are deleted from the mutables of
        the passed-in ``model`` itself (not of the copy) before copying.

        Args:
            model: The supernet; must provide ``architecture`` and
                ``search_subnet()``.

        Returns:
            torch.nn.Module: A deep copy of ``model`` with the searched
            subnet loaded via ``load_fix_subnet``.
        """
        # Avoid circular import
        from mmrazor.models.mutables.base_mutable import BaseMutable
        from mmrazor.structures import load_fix_subnet

        # delete non-leaf tensor to get deepcopy(model).
        # TODO solve the hard case.
        for module in model.architecture.modules():
            if isinstance(module, BaseMutable) and \
                    hasattr(module, 'arch_weights'):
                delattr(module, 'arch_weights')

        copied_model = copy.deepcopy(model)
        fix_mutable = copied_model.search_subnet()
        load_fix_subnet(copied_model, fix_mutable)

        return copied_model
|
|
@ -27,7 +27,7 @@ class AutoSlimValLoop(ValLoop):
|
|||
# just for convenience
|
||||
self._model = model
|
||||
|
||||
def run(self) -> None:
|
||||
def run(self):
|
||||
"""Launch validation."""
|
||||
self.runner.call_hook('before_val')
|
||||
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
import os
|
||||
import os.path as osp
|
||||
import random
|
||||
import warnings
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from mmengine import fileio
|
||||
|
@ -13,8 +14,9 @@ from mmengine.runner import EpochBasedTrainLoop
|
|||
from mmengine.utils import is_list_of
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from mmrazor.models.task_modules import ResourceEstimator
|
||||
from mmrazor.registry import LOOPS
|
||||
from mmrazor.structures import Candidates, FlopsEstimator, export_fix_subnet
|
||||
from mmrazor.structures import Candidates, export_fix_subnet, load_fix_subnet
|
||||
from mmrazor.utils import SupportRandomSubnet
|
||||
from .utils import crossover
|
||||
|
||||
|
@ -42,6 +44,8 @@ class EvolutionSearchLoop(EpochBasedTrainLoop):
|
|||
mutate_prob (float): The probability of mutation. Defaults to 0.1.
|
||||
flops_range (tuple, optional): flops_range to be used for screening
|
||||
candidates.
|
||||
estimator_cfg (Dict[str, Any]): Used for building a resource estimator.
|
||||
Default to dict().
|
||||
score_key (str): Specify one metric in evaluation results to score
|
||||
candidates. Defaults to 'accuracy_top-1'.
|
||||
init_candidates (str, optional): The candidates file path, which is
|
||||
|
@ -62,6 +66,7 @@ class EvolutionSearchLoop(EpochBasedTrainLoop):
|
|||
num_crossover: int = 25,
|
||||
mutate_prob: float = 0.1,
|
||||
flops_range: Optional[Tuple[float, float]] = (0., 330 * 1e6),
|
||||
estimator_cfg: Dict[str, Any] = dict(),
|
||||
score_key: str = 'accuracy_top-1',
|
||||
init_candidates: Optional[str] = None) -> None:
|
||||
super().__init__(runner, dataloader, max_epochs)
|
||||
|
@ -80,6 +85,7 @@ class EvolutionSearchLoop(EpochBasedTrainLoop):
|
|||
self.num_candidates = num_candidates
|
||||
self.top_k = top_k
|
||||
self.flops_range = flops_range
|
||||
self.estimator_cfg = estimator_cfg
|
||||
self.score_key = score_key
|
||||
self.num_mutation = num_mutation
|
||||
self.num_crossover = num_crossover
|
||||
|
@ -299,8 +305,13 @@ class EvolutionSearchLoop(EpochBasedTrainLoop):
|
|||
|
||||
self.model.set_subnet(random_subnet)
|
||||
fix_mutable = export_fix_subnet(self.model)
|
||||
flops: float = FlopsEstimator.get_model_complexity_info(
|
||||
self.model, fix_mutable=fix_mutable, as_strings=False)[0]
|
||||
copied_model = copy.deepcopy(self.model)
|
||||
load_fix_subnet(copied_model, fix_mutable)
|
||||
|
||||
estimator = ResourceEstimator(**self.estimator_cfg)
|
||||
results = estimator.estimate(copied_model)
|
||||
flops = results['flops']
|
||||
|
||||
if self.flops_range[0] < flops < self.flops_range[1]:
|
||||
return True
|
||||
else:
|
||||
|
|
|
@ -38,7 +38,7 @@ class SlimmableValLoop(ValLoop):
|
|||
# just for convenience
|
||||
self._model = model
|
||||
|
||||
def run(self) -> None:
|
||||
def run(self):
|
||||
"""Launch validation."""
|
||||
self.runner.call_hook('before_val')
|
||||
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
from abc import abstractmethod
|
||||
from typing import Dict, List, Optional, Sequence, Tuple, Union
|
||||
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import torch
|
||||
from mmengine import fileio
|
||||
|
@ -12,8 +13,9 @@ from mmengine.runner import IterBasedTrainLoop
|
|||
from mmengine.utils import is_list_of
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from mmrazor.models.task_modules import ResourceEstimator
|
||||
from mmrazor.registry import LOOPS
|
||||
from mmrazor.structures import Candidates, FlopsEstimator, export_fix_subnet
|
||||
from mmrazor.structures import Candidates, export_fix_subnet, load_fix_subnet
|
||||
from mmrazor.utils import SupportRandomSubnet
|
||||
|
||||
|
||||
|
@ -100,7 +102,9 @@ class GreedySamplerTrainLoop(BaseSamplerTrainLoop):
|
|||
val_interval (int): Validation interval. Defaults to 1000.
|
||||
score_key (str): Specify one metric in evaluation results to score
|
||||
candidates. Defaults to 'accuracy_top-1'.
|
||||
constraints (dict): Constraints to be used for screening candidates.
|
||||
flops_range (dict): Constraints to be used for screening candidates.
|
||||
estimator_cfg (Dict[str, Any]): Used for building a resource estimator.
|
||||
Default to dict().
|
||||
num_candidates (int): The number of the candidates consist of samples
|
||||
from supernet and itself. Defaults to 1000.
|
||||
num_samples (int): The number of sample in each sampling subnet.
|
||||
|
@ -135,6 +139,7 @@ class GreedySamplerTrainLoop(BaseSamplerTrainLoop):
|
|||
val_interval: int = 1000,
|
||||
score_key: str = 'accuracy_top-1',
|
||||
flops_range: Optional[Tuple[float, float]] = (0., 330 * 1e6),
|
||||
estimator_cfg: Dict[str, Any] = dict(),
|
||||
num_candidates: int = 1000,
|
||||
num_samples: int = 10,
|
||||
top_k: int = 5,
|
||||
|
@ -158,6 +163,7 @@ class GreedySamplerTrainLoop(BaseSamplerTrainLoop):
|
|||
|
||||
self.score_key = score_key
|
||||
self.flops_range = flops_range
|
||||
self.estimator_cfg = estimator_cfg
|
||||
self.num_candidates = num_candidates
|
||||
self.num_samples = num_samples
|
||||
self.top_k = top_k
|
||||
|
@ -316,8 +322,13 @@ class GreedySamplerTrainLoop(BaseSamplerTrainLoop):
|
|||
|
||||
self.model.set_subnet(random_subnet)
|
||||
fix_mutable = export_fix_subnet(self.model)
|
||||
flops = FlopsEstimator.get_model_complexity_info(
|
||||
self.model, fix_mutable=fix_mutable, as_strings=False)[0]
|
||||
copied_model = copy.deepcopy(self.model)
|
||||
load_fix_subnet(copied_model, fix_mutable)
|
||||
|
||||
estimator = ResourceEstimator(**self.estimator_cfg)
|
||||
results = estimator.estimate(copied_model)
|
||||
flops = results['flops']
|
||||
|
||||
if self.flops_range[0] < flops < self.flops_range[1]:
|
||||
return True
|
||||
else:
|
||||
|
|
|
@ -6,3 +6,4 @@ from .losses import * # noqa: F401,F403
|
|||
from .mutables import * # noqa: F401,F403
|
||||
from .mutators import * # noqa: F401,F403
|
||||
from .ops import * # noqa: F401,F403
|
||||
from .task_modules import * # noqa: F401,F403
|
||||
|
|
|
@ -0,0 +1,4 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .estimators import ResourceEstimator
|
||||
|
||||
__all__ = ['ResourceEstimator']
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .counters import * # noqa: F401,F403
|
||||
from .resource_estimator import ResourceEstimator
|
||||
|
||||
__all__ = ['ResourceEstimator']
|
|
@ -0,0 +1,55 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import torch.nn
|
||||
|
||||
from mmrazor.registry import TASK_UTILS
|
||||
|
||||
|
||||
@TASK_UTILS.register_module()
class BaseEstimator(metaclass=ABCMeta):
    """Abstract base of estimators that report model resource information.

    Args:
        default_shape (tuple): Input data's default shape, for calculating
            resources consume. Must have 3, 4 or 5 entries.
            Defaults to (1, 3, 224, 224).
        units (str): Resource units. Defaults to 'M'.
        disabled_counters (list): List of disabled spec op counters.
            Defaults to None.
        as_strings (bool): Output FLOPs and params counts in a string
            form. Defaults to False.
        measure_inference (bool): Whether to measure infer speed or not.
            Defaults to False.
    """

    def __init__(self,
                 default_shape: Tuple = (1, 3, 224, 224),
                 units: str = 'M',
                 disabled_counters: List[str] = None,
                 as_strings: bool = False,
                 measure_inference: bool = False):
        num_dims = len(default_shape)
        assert num_dims in (3, 4, 5), \
            f'Unsupported shape: {default_shape}'

        # Plain storage of the configuration; subclasses decide how to use it.
        self.measure_inference = measure_inference
        self.as_strings = as_strings
        self.disabled_counters = disabled_counters
        self.units = units
        self.default_shape = default_shape

    @abstractmethod
    def estimate(
        self, model: torch.nn.Module, resource_args: Dict[str, Any] = dict()
    ) -> Dict[str, float]:
        """Estimate the resources(flops/params/latency) of the given model.

        Args:
            model: The measured model.
            resource_args (Dict[str, float]): resources information.
                NOTE: resource_args have the same items() as the init cfgs.

        Returns:
            Dict[str, float]: A dict containing the resource results (flops,
            params and latency).
        """
|
|
@ -0,0 +1,10 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .flops_params_counter import (get_model_complexity_info,
|
||||
params_units_convert)
|
||||
from .latency_counter import repeat_measure_inference_speed
|
||||
from .op_counters import * # noqa: F401,F403
|
||||
|
||||
__all__ = [
|
||||
'get_model_complexity_info', 'params_units_convert',
|
||||
'repeat_measure_inference_speed'
|
||||
]
|
|
@ -0,0 +1,472 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import sys
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from mmrazor.registry import TASK_UTILS
|
||||
|
||||
|
||||
def get_model_complexity_info(model,
                              input_shape,
                              spec_modules=[],
                              disabled_counters=[],
                              print_per_layer_stat=False,
                              as_strings=False,
                              input_constructor=None,
                              flush=False,
                              ost=sys.stdout):
    """Get complexity information of a model.

    This method runs one forward pass with counting hooks attached and
    reports FLOPs and parameter counts of a model for the given input shape.
    It can also print complexity information for each layer in a model.
    Only modules for which a ``<ClassName>Counter`` is registered are
    counted.

    NOTE: the model is switched to ``eval()`` mode and is not switched back.
    The counting hooks are always removed before returning, also when the
    forward pass raises.

    Args:
        model (nn.Module): The model for complexity calculation.
        input_shape (tuple): Input shape (including batchsize) used for
            calculation.
        spec_modules (list): A list that contains the names of several spec
            modules, which users want to get resources infos of them.
            e.g., ['backbone', 'head'], ['backbone.layer1']. ``None`` is
            treated as an empty list. Default to [].
        disabled_counters (list): Counter type names that should not be
            attached, i.e. ops that are excluded from the calculation.
            Default to [].
        print_per_layer_stat (bool): Whether to print complexity information
            for each layer in a model. Default to False.
        as_strings (bool): Output FLOPs and params counts in a string form.
            Default to False.
        input_constructor (None | callable): If specified, it takes a callable
            method that generates input from ``input_shape``; its result is
            passed to the model as keyword arguments. Otherwise, an
            uninitialized tensor with input shape is used. Default to None.
        flush (bool): same as that in :func:`print`. Default to False.
        ost (stream): same as ``file`` param in :func:`print`.
            Default to sys.stdout.

    Returns:
        tuple[float | str] | dict[str, dict]: ``(flops, params)``. If
        ``as_strings`` is True both are strings (GFLOPs / M), otherwise
        numbers. If ``spec_modules`` is not empty, a dict mapping each
        requested module name to ``{'flops': ..., 'params': ...}`` is
        returned instead, formatted according to ``as_strings``.
    """
    # isinstance (not an exact type check) so tuple subclasses are accepted.
    assert isinstance(input_shape, tuple)
    assert len(input_shape) >= 1
    assert isinstance(model, nn.Module)
    spec_modules = spec_modules or []

    flops_params_model = add_flops_params_counting_methods(model)
    flops_params_model.eval()
    spec_modules_resources = dict()

    try:
        flops_params_model.start_flops_params_count(disabled_counters)
        if input_constructor:
            model_inputs = input_constructor(input_shape)
            _ = flops_params_model(**model_inputs)
        else:
            try:
                batch = torch.ones(()).new_empty(
                    tuple(input_shape),
                    dtype=next(flops_params_model.parameters()).dtype,
                    device=next(flops_params_model.parameters()).device)
            except StopIteration:
                # Avoid StopIteration for models which have no parameters,
                # like `nn.Relu()`, `nn.AvgPool2d`, etc.
                batch = torch.ones(()).new_empty(tuple(input_shape))

            _ = flops_params_model(batch)

        flops_count, params_count = \
            flops_params_model.compute_average_flops_params_cost()

        if print_per_layer_stat:
            print_model_with_flops_params(
                flops_params_model,
                flops_count,
                params_count,
                ost=ost,
                flush=flush)

        if len(spec_modules):
            module_names = {
                name
                for name, _ in flops_params_model.named_modules()
            }
            for module in spec_modules:
                assert module in module_names, \
                    f'All modules in spec_modules should be in the measured ' \
                    f'flops_params_model. Got module {module} in spec_modules.'

            # Must run before the counters are removed: it rewrites
            # ``__flops__`` / ``__params__`` on every module.
            accumulate_sub_module_flops_params(flops_params_model)
            for name, module in flops_params_model.named_modules():
                if name not in spec_modules:
                    continue
                if as_strings:
                    flops = str(
                        params_units_convert(module.__flops__,
                                             'G')) + ' GFLOPs'
                    params = str(
                        params_units_convert(module.__params__, 'M')) + ' M'
                else:
                    flops = module.__flops__
                    params = module.__params__
                spec_modules_resources[name] = dict(
                    flops=flops, params=params)
    finally:
        # Never leave hooks / counters behind on the caller's model.
        flops_params_model.stop_flops_params_count()

    if len(spec_modules):
        return spec_modules_resources

    if as_strings:
        flops_string = str(params_units_convert(flops_count, 'G')) + ' GFLOPs'
        params_string = str(params_units_convert(params_count, 'M')) + ' M'
        return flops_string, params_string

    return flops_count, params_count
|
||||
|
||||
|
||||
def params_units_convert(num_params, units='M', precision=3):
    """Convert a raw count (parameters or FLOPs) to the given units.

    Args:
        num_params (float): The number to be converted.
        units (str): Target units. Options are 'G' (1e9), 'M' (1e6) and
            'K' (1e3). Default to 'M'.
        precision (int): Digit number after the decimal point. Default to 3.

    Returns:
        float: ``num_params`` divided by the unit scale and rounded to
        ``precision`` digits.

    Raises:
        ValueError: If ``units`` is not one of 'G', 'M' or 'K'.

    Examples:
        >>> params_units_convert(1e9)
        1000.0
        >>> params_units_convert(2e5)
        0.2
        >>> params_units_convert(3e9, units='G')
        3.0
    """
    if units == 'G':
        divisor = 10.**9
    elif units == 'M':
        divisor = 10.**6
    elif units == 'K':
        divisor = 10.**3
    else:
        raise ValueError(f'Unsupported units convert: {units}')

    return round(num_params / divisor, precision)
|
||||
|
||||
|
||||
def print_model_with_flops_params(model,
                                  total_flops,
                                  total_params,
                                  units='G',
                                  precision=3,
                                  ost=sys.stdout,
                                  flush=False):
    """Print a model with FLOPs and Params for each layer.

    Every submodule temporarily gets an ``extra_repr`` that prepends its
    accumulated params / FLOPs (absolute and as a share of the totals) to
    its original ``extra_repr``. The temporary attributes are removed again
    after printing.

    Args:
        model (nn.Module): The model to be printed. Its modules must carry
            the ``__flops__`` / ``__params__`` counters and the model must
            carry ``__batch_counter__``.
        total_flops (float): Total FLOPs of the model.
        total_params (float): Total parameter counts of the model.
        units (str): Converted FLOPs units, one of the units accepted by
            ``params_units_convert``. Default to 'G'.
        precision (int): Digit number after the decimal point. Default to 3.
        ost (stream): same as `file` param in :func:`print`.
            Default to sys.stdout.
        flush (bool): same as that in :func:`print`. Default to False.

    Example:
        >>> class ExampleModel(nn.Module):
        >>>     def __init__(self):
        >>>         super().__init__()
        >>>         self.conv1 = nn.Conv2d(3, 8, 3)
        >>>         self.conv2 = nn.Conv2d(8, 256, 3)
        >>>         self.conv3 = nn.Conv2d(256, 8, 3)
        >>>         self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        >>>         self.flatten = nn.Flatten()
        >>>         self.fc = nn.Linear(8, 1)
        >>>     def forward(self, x):
        >>>         x = self.conv1(x)
        >>>         x = self.conv2(x)
        >>>         x = self.conv3(x)
        >>>         x = self.avg_pool(x)
        >>>         x = self.flatten(x)
        >>>         x = self.fc(x)
        >>>         return x
        >>> model = ExampleModel()
        >>> x = (1, 3, 16, 16)
        to print the complexity information state for each layer, you can use
        >>> get_model_complexity_info(model, x, print_per_layer_stat=True)
        which prints each layer in the form
        ExampleModel(
          0.037 M, 100.000% Params, 0.005 GFLOPs, 100.000% FLOPs,
          (conv1): Conv2d(0.0 M, 0.600% Params, 0.0 GFLOPs, 0.959% FLOPs, ...)
          ...
        )
    """

    def ratio(part, whole):
        # A model without params (or FLOPs) has a zero total; report 0%
        # instead of raising ZeroDivisionError.
        return part / whole if whole else 0.0

    def accumulate_params(self):
        if is_supported_instance(self):
            return self.__params__
        total = 0
        for m in self.children():
            total += m.accumulate_params()
        return total

    def accumulate_flops(self):
        if is_supported_instance(self):
            return self.__flops__ / model.__batch_counter__
        total = 0
        for m in self.children():
            total += m.accumulate_flops()
        return total

    def flops_repr(self):
        accumulated_num_params = self.accumulate_params()
        accumulated_flops_cost = self.accumulate_flops()
        flops_string = str(
            params_units_convert(
                accumulated_flops_cost, units=units,
                precision=precision)) + ' ' + units + 'FLOPs'
        params_string = str(
            params_units_convert(
                accumulated_num_params, units='M', precision=precision)) + ' M'
        return ', '.join([
            params_string,
            '{:.3%} Params'.format(
                ratio(accumulated_num_params, total_params)),
            flops_string,
            '{:.3%} FLOPs'.format(ratio(accumulated_flops_cost, total_flops)),
            self.original_extra_repr()
        ])

    def add_extra_repr(m):
        m.accumulate_flops = accumulate_flops.__get__(m)
        m.accumulate_params = accumulate_params.__get__(m)
        flops_extra_repr = flops_repr.__get__(m)
        if m.extra_repr != flops_extra_repr:
            m.original_extra_repr = m.extra_repr
            m.extra_repr = flops_extra_repr
            assert m.extra_repr != m.original_extra_repr

    def del_extra_repr(m):
        if hasattr(m, 'original_extra_repr'):
            m.extra_repr = m.original_extra_repr
            del m.original_extra_repr
        if hasattr(m, 'accumulate_flops'):
            del m.accumulate_flops
        # Also drop the second helper injected by ``add_extra_repr``.
        if hasattr(m, 'accumulate_params'):
            del m.accumulate_params

    model.apply(add_extra_repr)
    print(model, file=ost, flush=flush)
    model.apply(del_extra_repr)
|
||||
|
||||
|
||||
def accumulate_sub_module_flops_params(model):
    """Accumulate FLOPs and params for each module in the model.

    After the call every module in the model has the ``__flops__`` and
    ``__params__`` attributes: a module with a registered counter keeps its
    own params and gets its FLOPs divided by ``model.__batch_counter__``;
    any other module gets the sum over its children.

    All values are computed from the raw counters first and only assigned
    afterwards, so a module that is reachable through several parents is not
    divided by the batch counter more than once.

    Args:
        model (nn.Module): The model to be accumulated.
    """

    def accumulate_params(module):
        if is_supported_instance(module):
            return module.__params__
        return sum(accumulate_params(m) for m in module.children())

    def accumulate_flops(module):
        if is_supported_instance(module):
            return module.__flops__ / model.__batch_counter__
        return sum(accumulate_flops(m) for m in module.children())

    # Pass 1: read the untouched per-op counters.
    accumulated = [(module, accumulate_flops(module),
                    accumulate_params(module)) for module in model.modules()]

    # Pass 2: publish the accumulated values on the modules.
    for module, _flops, _params in accumulated:
        module.__flops__ = _flops
        module.__params__ = _params
|
||||
|
||||
|
||||
def get_model_parameters_number(model):
    """Count the trainable parameters of a model.

    Args:
        model (nn.module): The model for parameter number calculation.

    Returns:
        int: Total ``numel()`` over all parameters of the model that have
        ``requires_grad`` set.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
|
||||
|
||||
|
||||
def add_flops_params_counting_methods(net_main_module):
    """Attach the counting helpers to a module object and reset its counters.

    The helpers are bound to the given object (rather than to its class) so
    that each of them receives that object as ``self``.

    Args:
        net_main_module: The module object to extend.

    Returns:
        The same object, now exposing ``start_flops_params_count``,
        ``stop_flops_params_count``, ``reset_flops_params_count`` and
        ``compute_average_flops_params_cost``.
    """
    counting_methods = (
        start_flops_params_count,
        stop_flops_params_count,
        reset_flops_params_count,
        compute_average_flops_params_cost,
    )
    for method in counting_methods:
        # Each helper is exposed under its own function name.
        setattr(net_main_module, method.__name__,
                method.__get__(net_main_module))

    net_main_module.reset_flops_params_count()

    return net_main_module
|
||||
|
||||
|
||||
def compute_average_flops_params_cost(self):
    """Compute average FLOPs and total Params cost.

    A method which will be available after
    `add_flops_params_counting_methods()` is called on a desired net object.
    Only modules with a registered counter contribute.

    Returns:
        tuple: ``(flops, params)`` where ``flops`` is the summed
        ``__flops__`` divided by ``__batch_counter__`` and ``params`` is the
        summed ``__params__``.
    """
    counted = [m for m in self.modules() if is_supported_instance(m)]
    total_flops = sum(m.__flops__ for m in counted)
    total_params = sum(m.__params__ for m in counted)
    return total_flops / self.__batch_counter__, total_params
|
||||
|
||||
|
||||
def start_flops_params_count(self, disabled_counters):
    """Activate the computation of mean flops and params consumption per image.

    Available after ``add_flops_params_counting_methods()`` is called on a
    desired net object. It should be called before running the network.

    Args:
        disabled_counters (list | None): Counter type names for which no
            counting hook is registered. ``None`` disables nothing.
    """
    add_batch_counter_hook_function(self)

    def _attach_counter(module):
        # Only modules with a registered counter are measured.
        if not is_supported_instance(module):
            return
        # A handle is already present: do not register a second hook.
        if hasattr(module, '__flops_params_handle__'):
            return

        counter_type = get_counter_type(module)
        if (disabled_counters is not None
                and counter_type in disabled_counters):
            return

        counter = TASK_UTILS.build(dict(type=counter_type, _scope_='mmrazor'))
        module.__flops_params_handle__ = module.register_forward_hook(
            counter.add_count_hook)

    self.apply(_attach_counter)
|
||||
|
||||
|
||||
def stop_flops_params_count(self):
    """Stop computing the mean flops and params consumption per image.

    Available after ``add_flops_params_counting_methods()`` is called on a
    desired net object. Removes the per-module counting hooks together with
    the batch counter hook.
    """
    # The two removals touch different attributes, so order is irrelevant.
    self.apply(remove_flops_params_counter_hook_function)
    remove_batch_counter_hook_function(self)
|
||||
|
||||
|
||||
def reset_flops_params_count(self):
    """Reset statistics computed so far.

    Available after `add_flops_params_counting_methods()` is called on a
    desired net object. Resets the per-module counters and the batch
    counter.
    """
    # The two resets touch different attributes, so order is irrelevant.
    self.apply(add_flops_params_counter_variable_or_reset)
    add_batch_counter_variables_or_reset(self)
|
||||
|
||||
|
||||
# ---- Internal functions
|
||||
def empty_flops_params_counter_hook(module, input, output):
    """Forward hook that adds nothing to the module's counters.

    Args:
        module: Object carrying ``__flops__`` and ``__params__``.
        input: Unused.
        output: Unused.
    """
    # Adding zero leaves both values unchanged but still requires the
    # counter attributes to exist.
    module.__flops__ += 0
    module.__params__ += 0
|
||||
|
||||
|
||||
def add_batch_counter_variables_or_reset(module):
    """Create the ``__batch_counter__`` attribute, or reset it, to zero."""
    setattr(module, '__batch_counter__', 0)
|
||||
|
||||
|
||||
def add_batch_counter_hook_function(module):
    """Register ``batch_counter_hook`` on ``module`` once.

    The handle returned by ``register_forward_hook`` is stored on the
    module as ``__batch_counter_handle__``. If that attribute already
    exists, nothing is registered.

    Args:
        module: Object exposing ``register_forward_hook``.
    """
    if not hasattr(module, '__batch_counter_handle__'):
        module.__batch_counter_handle__ = module.register_forward_hook(
            batch_counter_hook)
|
||||
|
||||
|
||||
def batch_counter_hook(module, input, output):
    """Forward hook that adds the batch size to ``__batch_counter__``.

    The batch size is ``len()`` of the first positional input. When there
    are no positional inputs a warning is printed and 1 is assumed.

    Args:
        module: Object carrying a ``__batch_counter__`` attribute.
        input: Positional inputs of the forward call.
        output: Unused; present to match the forward-hook signature.
    """
    if len(input) > 0:
        # Can have multiple inputs, only the first one is inspected.
        samples_in_batch = len(input[0])
    else:
        samples_in_batch = 1
        print('Warning! No positional inputs found for a module, '
              'assuming batch size is 1.')
    module.__batch_counter__ += samples_in_batch
|
||||
|
||||
|
||||
def remove_batch_counter_hook_function(module):
    """Remove the batch-counter hook from ``module`` if one is stored.

    Args:
        module: Object that may carry ``__batch_counter_handle__``.
    """
    if not hasattr(module, '__batch_counter_handle__'):
        return
    hook_handle = module.__batch_counter_handle__
    hook_handle.remove()
    del module.__batch_counter_handle__
|
||||
|
||||
|
||||
def add_flops_params_counter_variable_or_reset(module):
    """Create or reset ``__flops__`` / ``__params__`` on a supported module.

    Modules for which ``is_supported_instance`` is false are left
    untouched. If either attribute already exists a warning is printed
    before both are set to 0.

    Args:
        module: Module to initialise.
    """
    if not is_supported_instance(module):
        return
    if hasattr(module, '__flops__') or hasattr(module, '__params__'):
        # The class name used to be glued to the preceding word
        # ("...the moduleConv2d ptflops..."); keep it space separated.
        print('Warning: variables __flops__ or __params__ are already '
              'defined for the module ' + type(module).__name__ +
              ' ptflops can affect your code!')
    module.__flops__ = 0
    module.__params__ = 0
|
||||
|
||||
|
||||
def get_counter_type(module):
    """Return the counter name for ``module``: its class name + 'Counter'.

    Args:
        module: Any object; only ``module.__class__.__name__`` is read.

    Returns:
        str: e.g. ``'Conv2dCounter'`` for an instance of a class ``Conv2d``.
    """
    return f'{module.__class__.__name__}Counter'
|
||||
|
||||
|
||||
def is_supported_instance(module):
    """Tell whether a counter is registered for ``module``.

    Args:
        module: Object whose counter name (see ``get_counter_type``) is
            looked up in ``TASK_UTILS._module_dict``.

    Returns:
        bool: True when the counter name is a key of the registry dict.
    """
    registered_names = TASK_UTILS._module_dict.keys()
    return get_counter_type(module) in registered_names
|
||||
|
||||
|
||||
def remove_flops_params_counter_hook_function(module):
    """Strip the counting hook and both counters from ``module``.

    Each of ``__flops_params_handle__``, ``__flops__`` and ``__params__``
    is removed only if present; the handle's ``remove()`` is called first.

    Args:
        module: Object that may carry the three attributes.
    """
    if hasattr(module, '__flops_params_handle__'):
        hook_handle = module.__flops_params_handle__
        hook_handle.remove()
        del module.__flops_params_handle__
    for counter_attr in ('__flops__', '__params__'):
        if hasattr(module, counter_attr):
            delattr(module, counter_attr)
|
|
@ -0,0 +1,88 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import logging
|
||||
import time
|
||||
|
||||
import torch
|
||||
from mmengine.logging import print_log
|
||||
|
||||
|
||||
def repeat_measure_inference_speed(model,
                                   resource_args,
                                   max_iter: int = 100,
                                   log_interval: int = 100,
                                   repeat_num: int = 1) -> float:
    """Run ``measure_inference_speed`` several times and report latency.

    Args:
        model: Model forwarded to ``measure_inference_speed``.
        resource_args: Forwarded unchanged to ``measure_inference_speed``.
        max_iter (int): Forwarded iteration count. Defaults to 100.
        log_interval (int): Forwarded logging interval. Defaults to 100.
        repeat_num (int): Number of measurements; must be >= 1.
            Defaults to 1.

    Returns:
        float: Milliseconds per image. With a single run this is
        ``1000 / fps`` rounded to one decimal; with several runs it is the
        mean of the per-run rounded values.
    """
    assert repeat_num >= 1

    measured_fps = [
        measure_inference_speed(model, resource_args, max_iter, log_interval)
        for _ in range(repeat_num)
    ]

    if repeat_num == 1:
        return round(1000 / measured_fps[0], 1)

    rounded_fps = [round(value, 1) for value in measured_fps]
    ms_per_img = [round(1000 / value, 1) for value in measured_fps]
    average_fps = sum(rounded_fps) / len(rounded_fps)
    average_ms = sum(ms_per_img) / len(ms_per_img)
    print_log(
        f'Overall fps: {rounded_fps}[{average_fps:.1f}] img / s, '
        f'times per image: '
        f'{ms_per_img}[{average_ms:.1f}] ms/img',
        logger='current',
        level=logging.DEBUG)
    return average_ms
|
||||
|
||||
|
||||
def measure_inference_speed(model, resource_args, max_iter: int,
                            log_interval: int) -> float:
    """Measure the forward throughput of ``model`` on CUDA.

    Iterations run for ``i`` in ``range(1, max_iter)``; each one feeds a
    fresh ``torch.rand(resource_args['input_shape'])`` tensor moved to
    CUDA and times the forward pass between two ``cuda.synchronize()``
    calls. Iterations with ``i < 5`` are treated as warm-up and are not
    timed into the result.

    Args:
        model: Module whose first parameter must live on a CUDA device.
        resource_args: Mapping providing ``'input_shape'``.
        max_iter (int): Upper bound of the iteration range; must be
            larger than the number of warm-up iterations (5).
        log_interval (int): A debug log line is emitted whenever
            ``(i + 1) % log_interval == 0`` after warm-up.

    Returns:
        float: Images per second over the timed iterations.

    Raises:
        NotImplementedError: If the model is not on CUDA.
        ValueError: If ``max_iter`` leaves no timed iteration. Previously
            this surfaced as a ``ZeroDivisionError``.
    """
    # the first several iterations may be very slow so skip them
    num_warmup = 5
    pure_inf_time = 0.0
    fps = 0.0

    if not next(model.parameters()).is_cuda:
        raise NotImplementedError('To use cpu to test latency not supported.')

    if max_iter <= num_warmup:
        # No iteration would be timed, so fps would be 0 / 0.0.
        raise ValueError(
            f'max_iter must be greater than {num_warmup} warm-up '
            f'iterations, but got {max_iter}.')

    for i in range(1, max_iter):
        data = torch.rand(resource_args['input_shape']).cuda()
        torch.cuda.synchronize()
        start_time = time.perf_counter()

        with torch.no_grad():
            model(data)

        torch.cuda.synchronize()
        elapsed = time.perf_counter() - start_time

        if i >= num_warmup:
            pure_inf_time += elapsed
            if (i + 1) % log_interval == 0:
                fps = (i + 1 - num_warmup) / pure_inf_time
                print_log(
                    f'Done image [{i + 1:<3}/ {max_iter}], '
                    f'fps: {fps:.1f} img / s, '
                    f'times per image: {1000 / fps:.1f} ms / img',
                    logger='current',
                    level=logging.DEBUG)

        if (i + 1) == max_iter:
            # Last iteration: compute the overall figure and stop.
            fps = (i + 1 - num_warmup) / pure_inf_time
            print_log(
                f'Overall fps: {fps:.1f} img / s, '
                f'times per image: {1000 / fps:.1f} ms / img',
                logger='current',
                level=logging.DEBUG)
            break

        torch.cuda.empty_cache()

    return fps
|
|
@ -0,0 +1,22 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .activation_layer_counter import (ELUCounter, LeakyReLUCounter,
|
||||
PReLUCounter, ReLU6Counter, ReLUCounter)
|
||||
from .base_counter import BaseCounter
|
||||
from .conv_layer_counter import Conv1dCounter, Conv2dCounter, Conv3dCounter
|
||||
from .deconv_layer_counter import ConvTranspose2dCounter
|
||||
from .linear_layer_counter import LinearCounter
|
||||
from .norm_layer_counter import (BatchNorm1dCounter, BatchNorm2dCounter,
|
||||
BatchNorm3dCounter, GroupNormCounter,
|
||||
InstanceNorm1dCounter, InstanceNorm2dCounter,
|
||||
InstanceNorm3dCounter, LayerNormCounter)
|
||||
from .pooling_layer_counter import * # noqa: F403, F405, F401
|
||||
from .upsample_layer_counter import UpsampleCounter
|
||||
|
||||
# Public names of the op-counter package, in their original order.
__all__ = [
    'ReLUCounter',
    'PReLUCounter',
    'ELUCounter',
    'LeakyReLUCounter',
    'ReLU6Counter',
    'BatchNorm1dCounter',
    'BatchNorm2dCounter',
    'BatchNorm3dCounter',
    'Conv1dCounter',
    'Conv2dCounter',
    'Conv3dCounter',
    'ConvTranspose2dCounter',
    'UpsampleCounter',
    'LinearCounter',
    'GroupNormCounter',
    'InstanceNorm1dCounter',
    'InstanceNorm2dCounter',
    'InstanceNorm3dCounter',
    'LayerNormCounter',
    'BaseCounter',
]
|
|
@ -0,0 +1,35 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmrazor.registry import TASK_UTILS
|
||||
from ..flops_params_counter import get_model_parameters_number
|
||||
from .base_counter import BaseCounter
|
||||
|
||||
|
||||
@TASK_UTILS.register_module()
class ReLUCounter(BaseCounter):
    """FLOPs/params counter for the ReLU family of activations."""

    @staticmethod
    def add_count_hook(module, input, output):
        """Add one FLOP per output element and the module's param count.

        Args:
            module: Module carrying ``__flops__`` and ``__params__``.
            input: Unused positional inputs of the forward call.
            output: Forward output; only ``output.numel()`` is read.
        """
        module.__flops__ += int(output.numel())
        module.__params__ += get_model_parameters_number(module)


@TASK_UTILS.register_module()
class PReLUCounter(ReLUCounter):
    """Counter looked up for modules whose class is named ``PReLU``."""


@TASK_UTILS.register_module()
class ELUCounter(ReLUCounter):
    """Counter looked up for modules whose class is named ``ELU``."""


@TASK_UTILS.register_module()
class LeakyReLUCounter(ReLUCounter):
    """Counter looked up for modules whose class is named ``LeakyReLU``."""


@TASK_UTILS.register_module()
class ReLU6Counter(ReLUCounter):
    """Counter looked up for modules whose class is named ``ReLU6``."""
|
|
@ -0,0 +1,28 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from abc import ABCMeta, abstractclassmethod, abstractmethod
|
||||
|
||||
|
||||
class BaseCounter(object, metaclass=ABCMeta):
    """Base class of all op module counters in `TASK_UTILS`.

    In ResourceEstimator, `XXModuleCounter` is responsible for `XXModule`,
    which refers to estimator/flops_params_counter.py::get_counter_type().
    Users can customize a `ModuleACounter` and overwrite the `add_count_hook`
    method with a self-defined module `ModuleA`.
    """

    def __init__(self) -> None:
        pass

    # The hook is static, so it is declared with ``abstractmethod`` rather
    # than the deprecated ``abstractclassmethod``. Subclasses must still
    # override it before they can be instantiated.
    @staticmethod
    @abstractmethod
    def add_count_hook(module, input, output):
        """The main method of a `BaseCounter` which defines the way to
        calculate resources(flops/params) of the current module.

        Args:
            module (nn.Module): the module to be tested.
            input (_type_): input_tensor. Plz refer to `torch forward_hook`
            output (_type_): output_tensor. Plz refer to `torch forward_hook`
        """
        pass
|
|
@ -0,0 +1,57 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import numpy as np
|
||||
|
||||
from mmrazor.registry import TASK_UTILS
|
||||
from .base_counter import BaseCounter
|
||||
|
||||
|
||||
class ConvCounter(BaseCounter):
    """FLOPs/params counter for Conv module series."""

    @staticmethod
    def add_count_hook(module, input, output):
        """Accumulate convolution FLOPs and params on ``module``.

        Args:
            module: Conv module exposing ``kernel_size``, ``in_channels``,
                ``out_channels``, ``groups`` and ``bias``, and carrying
                ``__flops__`` / ``__params__`` counters.
            input: Positional inputs; only the first one is used and its
                dim 0 is taken as the batch size.
            output: Forward output; dims from index 2 onwards are taken as
                the output positions.
        """
        # Can have multiple inputs, getting the first one
        input = input[0]

        batch_size = input.shape[0]
        output_dims = list(output.shape[2:])

        kernel_dims = list(module.kernel_size)
        in_channels = module.in_channels
        out_channels = module.out_channels
        groups = module.groups

        # Integer division keeps the counters integral, matching
        # ConvTranspose2dCounter.
        filters_per_channel = out_channels // groups
        conv_per_position_flops = int(
            np.prod(kernel_dims)) * in_channels * filters_per_channel

        active_elements_count = batch_size * int(np.prod(output_dims))

        overall_conv_flops = conv_per_position_flops * active_elements_count
        # The per-position product is also used as the weight count.
        overall_params = conv_per_position_flops

        bias_flops = 0
        if module.bias is not None:
            bias_flops = out_channels * active_elements_count
            overall_params += out_channels

        overall_flops = overall_conv_flops + bias_flops

        module.__flops__ += int(overall_flops)
        module.__params__ += int(overall_params)


@TASK_UTILS.register_module()
class Conv1dCounter(ConvCounter):
    """Counter looked up for modules whose class is named ``Conv1d``."""


@TASK_UTILS.register_module()
class Conv2dCounter(ConvCounter):
    """Counter looked up for modules whose class is named ``Conv2d``."""


@TASK_UTILS.register_module()
class Conv3dCounter(ConvCounter):
    """Counter looked up for modules whose class is named ``Conv3d``."""
|
|
@ -0,0 +1,38 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmrazor.registry import TASK_UTILS
|
||||
from ..flops_params_counter import get_model_parameters_number
|
||||
from .base_counter import BaseCounter
|
||||
|
||||
|
||||
@TASK_UTILS.register_module()
class ConvTranspose2dCounter(BaseCounter):
    """FLOPs/params counter for Deconv module series."""

    @staticmethod
    def add_count_hook(module, input, output):
        """Accumulate transposed-convolution FLOPs and params on ``module``.

        Args:
            module: Module exposing a 2-element ``kernel_size``,
                ``in_channels``, ``out_channels``, ``groups`` and ``bias``,
                and carrying ``__flops__`` / ``__params__`` counters.
            input: Positional inputs; only the first one is used. Its dim 0
                is the batch size and dims from index 2 are unpacked as
                height and width.
            output: Forward output; dims from index 2 are unpacked as
                height and width for the bias term.
        """
        # Can have multiple inputs, getting the first one
        input = input[0]

        batch_size = input.shape[0]
        input_height, input_width = input.shape[2:]

        # TODO: use more common representation
        kernel_height, kernel_width = module.kernel_size
        in_channels = module.in_channels
        out_channels = module.out_channels
        groups = module.groups

        filters_per_channel = out_channels // groups
        conv_per_position_flops = (
            kernel_height * kernel_width * in_channels * filters_per_channel)

        active_elements_count = batch_size * input_height * input_width
        overall_conv_flops = conv_per_position_flops * active_elements_count
        bias_flops = 0
        if module.bias is not None:
            output_height, output_width = output.shape[2:]
            # One addition per output element: height * width, not
            # height * height as before.
            bias_flops = (
                out_channels * batch_size * output_height * output_width)
        overall_flops = overall_conv_flops + bias_flops

        module.__flops__ += int(overall_flops)
        module.__params__ += get_model_parameters_number(module)
|
|
@ -0,0 +1,19 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import numpy as np
|
||||
|
||||
from mmrazor.registry import TASK_UTILS
|
||||
from ..flops_params_counter import get_model_parameters_number
|
||||
from .base_counter import BaseCounter
|
||||
|
||||
|
||||
@TASK_UTILS.register_module()
class LinearCounter(BaseCounter):
    """FLOPs/params counter for Linear operation series."""

    @staticmethod
    def add_count_hook(module, input, output):
        """Add ``prod(input.shape) * output.shape[-1]`` FLOPs to ``module``.

        Args:
            module: Module carrying ``__flops__`` and ``__params__``.
            input: Positional inputs; only the first one is used.
            output: Forward output; only its last dimension is read.
        """
        first_input = input[0]
        # pytorch checks dimensions, so here we don't care much
        out_features = output.shape[-1]
        input_elements = np.prod(first_input.shape)
        module.__flops__ += int(input_elements * out_features)
        module.__params__ += get_model_parameters_number(module)
|
|
@ -0,0 +1,59 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import numpy as np
|
||||
|
||||
from mmrazor.registry import TASK_UTILS
|
||||
from ..flops_params_counter import get_model_parameters_number
|
||||
from .base_counter import BaseCounter
|
||||
|
||||
|
||||
class BNCounter(BaseCounter):
    """FLOPs/params counter for BatchNormalization series."""

    @staticmethod
    def add_count_hook(module, input, output):
        """Add one FLOP per input element, doubled for affine modules.

        Args:
            module: Module carrying ``__flops__`` and ``__params__``; its
                optional ``affine`` attribute doubles the FLOPs when truthy.
            input: Positional inputs; only the first one is used.
            output: Unused forward output.
        """
        first_input = input[0]
        norm_flops = np.prod(first_input.shape)
        has_affine = getattr(module, 'affine', False)
        if has_affine:
            norm_flops *= 2
        module.__flops__ += int(norm_flops)
        module.__params__ += get_model_parameters_number(module)


@TASK_UTILS.register_module()
class BatchNorm1dCounter(BNCounter):
    """Counter looked up for modules whose class is named ``BatchNorm1d``."""


@TASK_UTILS.register_module()
class BatchNorm2dCounter(BNCounter):
    """Counter looked up for modules whose class is named ``BatchNorm2d``."""


@TASK_UTILS.register_module()
class BatchNorm3dCounter(BNCounter):
    """Counter looked up for modules whose class is named ``BatchNorm3d``."""


@TASK_UTILS.register_module()
class InstanceNorm1dCounter(BNCounter):
    """Counter looked up for modules named ``InstanceNorm1d``."""


@TASK_UTILS.register_module()
class InstanceNorm2dCounter(BNCounter):
    """Counter looked up for modules named ``InstanceNorm2d``."""


@TASK_UTILS.register_module()
class InstanceNorm3dCounter(BNCounter):
    """Counter looked up for modules named ``InstanceNorm3d``."""


@TASK_UTILS.register_module()
class LayerNormCounter(BNCounter):
    """Counter looked up for modules whose class is named ``LayerNorm``."""


@TASK_UTILS.register_module()
class GroupNormCounter(BNCounter):
    """Counter looked up for modules whose class is named ``GroupNorm``."""
|
|
@ -0,0 +1,76 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import numpy as np
|
||||
|
||||
from mmrazor.registry import TASK_UTILS
|
||||
from ..flops_params_counter import get_model_parameters_number
|
||||
from .base_counter import BaseCounter
|
||||
|
||||
|
||||
class PoolCounter(BaseCounter):
    """FLOPs/params counter for Pooling series."""

    @staticmethod
    def add_count_hook(module, input, output):
        """Add one FLOP per element of the first input to ``module``.

        Args:
            module: Module carrying ``__flops__`` and ``__params__``.
            input: Positional inputs; only the first one is used.
            output: Unused forward output.
        """
        first_input = input[0]
        pooled_elements = np.prod(first_input.shape)
        module.__flops__ += int(pooled_elements)
        module.__params__ += get_model_parameters_number(module)


@TASK_UTILS.register_module()
class MaxPool1dCounter(PoolCounter):
    """Counter looked up for modules whose class is named ``MaxPool1d``."""


@TASK_UTILS.register_module()
class MaxPool2dCounter(PoolCounter):
    """Counter looked up for modules whose class is named ``MaxPool2d``."""


@TASK_UTILS.register_module()
class MaxPool3dCounter(PoolCounter):
    """Counter looked up for modules whose class is named ``MaxPool3d``."""


@TASK_UTILS.register_module()
class AvgPool1dCounter(PoolCounter):
    """Counter looked up for modules whose class is named ``AvgPool1d``."""


@TASK_UTILS.register_module()
class AvgPool2dCounter(PoolCounter):
    """Counter looked up for modules whose class is named ``AvgPool2d``."""


@TASK_UTILS.register_module()
class AvgPool3dCounter(PoolCounter):
    """Counter looked up for modules whose class is named ``AvgPool3d``."""


@TASK_UTILS.register_module()
class AdaptiveMaxPool1dCounter(PoolCounter):
    """Counter looked up for modules named ``AdaptiveMaxPool1d``."""


@TASK_UTILS.register_module()
class AdaptiveMaxPool2dCounter(PoolCounter):
    """Counter looked up for modules named ``AdaptiveMaxPool2d``."""


@TASK_UTILS.register_module()
class AdaptiveMaxPool3dCounter(PoolCounter):
    """Counter looked up for modules named ``AdaptiveMaxPool3d``."""


@TASK_UTILS.register_module()
class AdaptiveAvgPool1dCounter(PoolCounter):
    """Counter looked up for modules named ``AdaptiveAvgPool1d``."""


@TASK_UTILS.register_module()
class AdaptiveAvgPool2dCounter(PoolCounter):
    """Counter looked up for modules named ``AdaptiveAvgPool2d``."""


@TASK_UTILS.register_module()
class AdaptiveAvgPool3dCounter(PoolCounter):
    """Counter looked up for modules named ``AdaptiveAvgPool3d``."""
|
|
@ -0,0 +1,19 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmrazor.registry import TASK_UTILS
|
||||
from ..flops_params_counter import get_model_parameters_number
|
||||
from .base_counter import BaseCounter
|
||||
|
||||
|
||||
@TASK_UTILS.register_module()
class UpsampleCounter(BaseCounter):
    """FLOPs/params counter for Upsample function."""

    @staticmethod
    def add_count_hook(module, input, output):
        """Add the element count of ``output[0]`` as FLOPs to ``module``.

        Args:
            module: Module carrying ``__flops__`` and ``__params__``.
            input: Unused positional inputs.
            output: Forward output; only ``output[0].shape`` is read.
        """
        # NOTE(review): ``output[0]`` indexes along dim 0 first, so the
        # product below covers the shape of that first entry only - confirm
        # this is the intended count.
        dims = output[0].shape
        element_count = dims[0]
        for extent in dims[1:]:
            element_count *= extent
        module.__flops__ += int(element_count)
        module.__params__ += get_model_parameters_number(module)
|
|
@ -0,0 +1,155 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
import torch.nn
|
||||
from mmengine.dist import broadcast_object_list, is_main_process
|
||||
|
||||
from mmrazor.registry import TASK_UTILS
|
||||
from .base_estimator import BaseEstimator
|
||||
from .counters import (get_model_complexity_info, params_units_convert,
|
||||
repeat_measure_inference_speed)
|
||||
|
||||
|
||||
@TASK_UTILS.register_module()
class ResourceEstimator(BaseEstimator):
    """Estimator for calculating the resources consume.

    Args:
        default_shape (tuple): Input data's default shape, for calculating
            resources consume. Defaults to (1, 3, 224, 224)
        units (str): Resource units. Defaults to 'M'.
        disabled_counters (list): List of disabled spec op counters.
            Defaults to an empty list.
            NOTE: disabled_counters contains the op counter class names
                in estimator.op_counters that require to be disabled,
                such as 'ConvCounter', 'BatchNorm2dCounter', ...
        as_strings (bool): Default for the ``as_strings`` entry of
            ``resource_args``; it may only be True when ``units`` is None.
            Defaults to False.
        measure_inference (bool): Forwarded to ``BaseEstimator``.
            Defaults to False.

    Examples:
        >>> # direct calculate resource consume of nn.Conv2d
        >>> conv2d = nn.Conv2d(3, 32, 3)
        >>> estimator = ResourceEstimator()
        >>> estimator.estimate(
        ...     model=conv2d,
        ...     resource_args=dict(input_shape=(1, 3, 64, 64)))
        {'flops': 3.444, 'params': 0.001, 'latency': 0.0}

        >>> # calculate resources of custom modules
        >>> class CustomModule(nn.Module):
        ...
        ...    def __init__(self) -> None:
        ...        super().__init__()
        ...
        ...    def forward(self, x):
        ...        return x
        ...
        >>> @TASK_UTILS.register_module()
        ... class CustomModuleCounter(BaseCounter):
        ...
        ...    @staticmethod
        ...    def add_count_hook(module, input, output):
        ...        module.__flops__ += 1000000
        ...        module.__params__ += 700000
        ...
        >>> model = CustomModule()
        >>> estimator.estimate(
        ...     model=model,
        ...     resource_args=dict(input_shape=(1, 3, 64, 64)))
        {'flops': 1.0, 'params': 0.7, 'latency': 0.0}
        ...
        >>> # calculate resources of custom modules with disable_counters
        >>> estimator.estimate(
        ...     model=model,
        ...     resource_args=dict(
        ...         input_shape=(1, 3, 64, 64),
        ...         disabled_counters=['CustomModuleCounter']))
        {'flops': 0.0, 'params': 0.0, 'latency': 0.0}

        >>> # calculate resources of mmrazor.models
        NOTE: check 'EstimateResourcesHook' in
              mmrazor.engine.hooks.estimate_resources_hook for details.
    """

    def __init__(self,
                 default_shape: Tuple = (1, 3, 224, 224),
                 units: str = 'M',
                 disabled_counters: List[str] = [],
                 as_strings: bool = False,
                 measure_inference: bool = False):
        # Hand the base class a private copy so that instances never share
        # (and cannot mutate) the default list object.
        if disabled_counters is not None:
            disabled_counters = list(disabled_counters)
        super().__init__(default_shape, units, disabled_counters, as_strings,
                         measure_inference)

    def estimate(
        self, model: torch.nn.Module, resource_args: Dict[str, Any] = dict()
    ) -> Dict[str, Any]:
        """Estimate the resources(flops/params/latency) of the given model.

        The computation runs on the main process only; the result is then
        broadcast with ``broadcast_object_list``. The model is switched to
        eval mode on the main process.

        Args:
            model: The measured model.
            resource_args (Dict[str, float]): Args for resources estimation.
                NOTE: resource_args have the same items() as the init cfgs.
                The passed dict is not modified.

        Returns:
            Dict[str, Any]: A dict that containing resource results(flops,
            params and latency).

        Raises:
            ValueError: If ``as_strings`` is True while ``units`` is set.
        """
        # Work on a copy: popping/inserting keys on the caller's dict (or
        # on the shared default) would leak state into later calls.
        resource_args = dict(resource_args)

        resource_metrics = dict()
        if is_main_process():
            measure_inference = resource_args.pop('measure_inference', False)
            resource_args.setdefault('input_shape', self.default_shape)
            resource_args.setdefault('disabled_counters',
                                     self.disabled_counters)
            model.eval()
            flops, params = get_model_complexity_info(model, **resource_args)
            if measure_inference:
                latency = repeat_measure_inference_speed(
                    model, resource_args, max_iter=100, repeat_num=2)
            else:
                latency = 0.0
            as_strings = resource_args.get('as_strings', self.as_strings)
            if as_strings and self.units is not None:
                raise ValueError('Set units to None, when as_trings=True.')
            if self.units is not None:
                flops = params_units_convert(flops, self.units)
                params = params_units_convert(params, self.units)
            resource_metrics.update({
                'flops': flops,
                'params': params,
                'latency': latency
            })
            results = [resource_metrics]
        else:
            results = [None]  # type: ignore

        broadcast_object_list(results)

        return results[0]

    def estimate_spec_modules(
        self, model: torch.nn.Module, resource_args: Dict[str, Any] = dict()
    ) -> Dict[str, float]:
        """Estimate the resources(flops/params/latency) of the spec modules.

        Any ``measure_inference`` entry is dropped before the remaining
        arguments are forwarded to ``get_model_complexity_info``. The model
        is switched to eval mode.

        Args:
            model: The measured model.
            resource_args (Dict[str, float]): Args for resources estimation.
                NOTE: resource_args have the same items() as the init cfgs.
                Must contain ``'spec_modules'``. The passed dict is not
                modified.

        Returns:
            Dict[str, float]: A dict that containing resource results(flops,
            params) of each modules in resource_args['spec_modules'].
        """
        assert 'spec_modules' in resource_args, \
            'spec_modules is required when calling estimate_spec_modules().'

        # Same reasoning as in ``estimate``: never mutate the caller's dict.
        resource_args = dict(resource_args)
        resource_args.pop('measure_inference', False)
        resource_args.setdefault('input_shape', self.default_shape)
        resource_args.setdefault('disabled_counters', self.disabled_counters)

        model.eval()
        spec_modules_resources = get_model_complexity_info(
            model, **resource_args)

        return spec_modules_resources
|
|
@ -40,7 +40,7 @@ def build_razor_model_from_cfg(
|
|||
# TODO relay on mmengine:HAOCHENYE/config_new_feature
|
||||
if cfg.get('cfg_path', None) and not cfg.get('type', None):
|
||||
from mmengine.config import get_model
|
||||
model = get_model(**cfg)
|
||||
model = get_model(**cfg) # type: ignore
|
||||
return model
|
||||
|
||||
from mmrazor.structures import load_fix_subnet
|
||||
|
|
|
@ -1,8 +1,5 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .candidate import Candidates
|
||||
from .estimators import FlopsEstimator
|
||||
from .fix_subnet import export_fix_subnet, load_fix_subnet
|
||||
|
||||
__all__ = [
|
||||
'FlopsEstimator', 'load_fix_subnet', 'export_fix_subnet', 'Candidates'
|
||||
]
|
||||
__all__ = ['load_fix_subnet', 'export_fix_subnet', 'Candidates']
|
||||
|
|
|
@ -1,4 +0,0 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .flops import FlopsEstimator
|
||||
|
||||
__all__ = ['FlopsEstimator']
|
|
@ -1,270 +0,0 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
import sys
|
||||
from functools import wraps
|
||||
from typing import IO, Callable, Dict, Iterable, Optional, Tuple, Type
|
||||
|
||||
from mmcv.cnn.utils import flops_counter as mmcv_flops_counter
|
||||
from mmcv.cnn.utils import get_model_complexity_info
|
||||
from torch.nn import Module
|
||||
|
||||
from mmrazor.utils import ValidFixMutable
|
||||
from ..fix_subnet import load_fix_subnet
|
||||
|
||||
|
||||
class FlopsEstimator:
|
||||
"""An estimator to help calculate flops of module.
|
||||
|
||||
FlopsEstimator is based on flops counter in mmcv, it can be directly used
|
||||
to calculate flops of module or calculate flops of subnet. Also, it
|
||||
provides api for adding flops counter hook to custom modules that are not
|
||||
supported by mmcv.
|
||||
|
||||
Examples:
|
||||
>>> # direct calculate flops of nn.Conv2d
|
||||
>>> conv2d = nn.Conv2d(3, 32, 3)
|
||||
>>> FlopsEstimator.get_model_complexity_info(
|
||||
... conv2d,
|
||||
... input_shape=[3, 224, 224],
|
||||
... print_per_layer_stat=False)
|
||||
('0.04 GFLOPs', '896')
|
||||
|
||||
>>> # calculate flops of custom modules
|
||||
>>> class FoolAddConstant(nn.Module):
|
||||
...
|
||||
... def __init__(self, p: float = 0.1) -> None:
|
||||
... super().__init__()
|
||||
...
|
||||
... self.register_parameter(
|
||||
... name='p',
|
||||
... param=Parameter(torch.tensor(p, dtype=torch.float32)))
|
||||
...
|
||||
... def forward(self, x: Tensor) -> Tensor:
|
||||
... return x + self.p
|
||||
...
|
||||
>>> def fool_add_constant_flops_counter_hook(
|
||||
... add_constant_module: nn.Module,
|
||||
... input: Tensor,
|
||||
... output: Tensor) -> None:
|
||||
... add_constant_module.__flops__ = 1e8
|
||||
...
|
||||
>>> FlopsEstimator.register_module(
|
||||
... flops_counter_hook=fool_add_constant_flops_counter_hook,
|
||||
... module=FoolAddConstant)
|
||||
>>> model = FoolAddConstant()
|
||||
>>> FlopsEstimator.get_model_complexity_info(
|
||||
... model=model,
|
||||
... input_shape=[3, 224, 224],
|
||||
... print_per_layer_stat=False)
|
||||
('0.1 GFLOPs', '1')
|
||||
|
||||
>>> # calculate subnet flops
|
||||
>>> class FoolOneShotModel(nn.Module):
|
||||
...
|
||||
... def __init__(self) -> None:
|
||||
... super().__init__()
|
||||
...
|
||||
... candidates = nn.ModuleDict({
|
||||
... 'conv3x3': nn.Conv2d(3, 32, 3),
|
||||
... 'conv5x5': nn.Conv2d(3, 32, 5)})
|
||||
... self.op = OneShotMutableOP(candidates)
|
||||
... self.op.current_choice = 'conv3x3'
|
||||
...
|
||||
... def forward(self, x: Tensor) -> Tensor:
|
||||
... return self.op(x)
|
||||
...
|
||||
>>> model = FoolOneShotModel()
|
||||
>>> fix_subnet = export_fix_subnet(model)
|
||||
>>> fix_subnet
|
||||
FixSubnet(modules={'op': 'conv3x3'}, channels=None)
|
||||
>>> FlopsEstimator.get_model_complexity_info(
|
||||
... supernet=model,
|
||||
... fix_subnet=fix_subnet,
|
||||
... input_shape=[3, 224, 224],
|
||||
... print_per_layer_stat=False)
|
||||
('0.04 GFLOPs', '896')
|
||||
"""
|
||||
|
||||
_mmcv_modules_mapping: Dict[Module, Callable] = \
|
||||
mmcv_flops_counter.get_modules_mapping()
|
||||
_custom_modules_mapping: Dict[Module, Callable] = {}
|
||||
|
||||
@staticmethod
|
||||
def get_model_complexity_info(
|
||||
model: Module,
|
||||
fix_mutable: Optional[ValidFixMutable] = None,
|
||||
input_shape: Iterable[int] = (3, 224, 224),
|
||||
input_constructor: Optional[Callable] = None,
|
||||
print_per_layer_stat: bool = True,
|
||||
as_strings: bool = True,
|
||||
flush: bool = False,
|
||||
ost: IO = sys.stdout) -> Tuple:
|
||||
"""Get complexity information of model.
|
||||
|
||||
This method is based on ``get_model_complexity_info`` of mmcv. It can
|
||||
calculate FLOPs and parameter counts of a model with corresponding
|
||||
input shape. It can also print complexity information for each layer
|
||||
in a model.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): The model for complexity calculation.
|
||||
fix_mutable (ValidFixMutable, optional): The config of fixed
|
||||
subnet. When this argument is specified, the function will
|
||||
return complexity information of the subnet. Default: None.
|
||||
input_shape (Iterable[int]): Input shape used for calculation.
|
||||
print_per_layer_stat (bool): Whether to print complexity
|
||||
information for each layer in a model. Default: True.
|
||||
as_strings (bool): Output FLOPs and params counts in a string form.
|
||||
Default: True.
|
||||
input_constructor (Callable, optional): If specified, it takes a
|
||||
callable method that generates input. otherwise, it will
|
||||
generate a random tensor with input shape to calculate FLOPs.
|
||||
Default: None.
|
||||
flush (bool): same as that in :func:`print`. Default: False.
|
||||
ost (stream): same as ``file`` param in :func:`print`.
|
||||
Default: sys.stdout.
|
||||
|
||||
Returns:
|
||||
tuple[float | str]: If ``as_strings`` is set to True, it will
|
||||
return FLOPs and parameter counts in a string format.
|
||||
otherwise, it will return those in a float number format.
|
||||
"""
|
||||
copied_model = copy.deepcopy(model)
|
||||
if fix_mutable is not None:
|
||||
load_fix_subnet(copied_model, fix_mutable)
|
||||
|
||||
return get_model_complexity_info(
|
||||
model=copied_model,
|
||||
input_shape=input_shape,
|
||||
input_constructor=input_constructor,
|
||||
print_per_layer_stat=print_per_layer_stat,
|
||||
as_strings=as_strings,
|
||||
flush=flush,
|
||||
ost=ost)
|
||||
|
||||
@classmethod
|
||||
def register_module(cls,
|
||||
flops_counter_hook: Callable,
|
||||
module: Optional[Type[Module]] = None,
|
||||
force: bool = False) -> Optional[Callable]:
|
||||
"""Register a module with flops_counter_hook.
|
||||
|
||||
Args:
|
||||
flops_counter (Callable): The hook that specifies how to calculate
|
||||
flops of given module.
|
||||
module (torch.nn.Module, optional): Module class to be registered.
|
||||
Defaults to None.
|
||||
force (bool): Whether to override an existing flops_counter_hook
|
||||
with the same module. Default to False.
|
||||
"""
|
||||
if not isinstance(force, bool):
|
||||
raise TypeError(f'force must be a boolean, but got {type(force)}')
|
||||
|
||||
if not (module is None or issubclass(module, Module)):
|
||||
raise TypeError(
|
||||
'module must be None, an subclass of torch.nn.Module, '
|
||||
f'but got {module}')
|
||||
|
||||
if not callable(flops_counter_hook):
|
||||
raise TypeError('flops_counter_hook must be Callable, '
|
||||
f'but got {type(flops_counter_hook)}')
|
||||
|
||||
if module is not None:
|
||||
return cls._register_module(
|
||||
module=module,
|
||||
flops_counter_hook=flops_counter_hook,
|
||||
force=force)
|
||||
|
||||
def _register(module: Type[Module]) -> None:
|
||||
cls._register_module(
|
||||
module=module,
|
||||
flops_counter_hook=flops_counter_hook,
|
||||
force=force)
|
||||
|
||||
return _register
|
||||
|
||||
@classmethod
|
||||
def remove_custom_module(cls, module: Type[Module]) -> None:
|
||||
"""Remove a registered module.
|
||||
|
||||
Args:
|
||||
module (torch.nn.Module): Module class to be removed.
|
||||
"""
|
||||
if module not in cls._custom_modules_mapping:
|
||||
raise KeyError(f'{module} not in custom module mapping')
|
||||
|
||||
del cls._custom_modules_mapping[module]
|
||||
|
||||
@classmethod
|
||||
def clear_custom_module(cls) -> None:
|
||||
"""Remove all registered modules."""
|
||||
cls._custom_modules_mapping.clear()
|
||||
|
||||
@classmethod
|
||||
def _register_module(cls,
|
||||
flops_counter_hook: Callable,
|
||||
module: Type[Module],
|
||||
force: bool = False) -> None:
|
||||
"""Register a module with flops_counter_hook.
|
||||
|
||||
Args:
|
||||
flops_counter (Callable): The hook that specifies how to calculate
|
||||
flops of given module.
|
||||
module (torch.nn.Module, optional): Module class to be registered.
|
||||
Defaults to None.
|
||||
force (bool): Whether to override an existing flops_counter_hook
|
||||
with the same module. Default to False.
|
||||
"""
|
||||
if not force and module in cls.get_modules_mapping():
|
||||
raise KeyError(f'{module} is already registered')
|
||||
cls._custom_modules_mapping[module] = flops_counter_hook
|
||||
|
||||
@classmethod
|
||||
def get_modules_mapping(cls) -> Dict[Module, Callable]:
|
||||
"""Get all modules with their corresponding flops counter hook.
|
||||
|
||||
Returns:
|
||||
Dict[Module, Callable]: Modules with their corresponding flops
|
||||
counter hook.
|
||||
"""
|
||||
return {**cls._mmcv_modules_mapping, **cls._custom_modules_mapping}
|
||||
|
||||
@classmethod
def get_custom_modules_mapping(cls) -> Dict[Module, Callable]:
    """Get custom modules with their corresponding flops counter hook.

    Returns:
        Dict[Module, Callable]: A shallow copy of the custom modules
            mapping, so mutating the result does not change the
            registry.
    """
    return {**cls._custom_modules_mapping}
|
||||
|
||||
@classmethod
def _mmcv_modules_mappings_wrapper(
        cls, mmcv_get_modules_mapping: Callable) -> Callable:
    """Wrapper for ``get_modules_mapping`` function in mmcv.

    Args:
        mmcv_get_modules_mapping (Callable): ``get_modules_mapping``
            function in mmcv.

    Returns:
        Callable: Wrapped ``get_modules_mapping`` function. Each call
            returns the mapping produced by the wrapped function merged
            with the custom modules mapping, custom entries taking
            precedence.
    """

    @wraps(mmcv_get_modules_mapping)
    def wrapper() -> Dict[Module, Callable]:
        mmcv_modules_mapping: Dict[Module, Callable] = \
            mmcv_get_modules_mapping()

        # TODO: switch to the dict ``|`` operator once Python 3.9.0 or
        # greater can be required; it is not supported before that.
        # The custom mapping is read at call time, so registrations made
        # after wrapping are still picked up.
        return {**mmcv_modules_mapping, **cls._custom_modules_mapping}

    return wrapper
|
||||
|
||||
|
||||
# Replace mmcv's ``get_modules_mapping`` at import time with a wrapped
# version that also returns the custom modules registered on
# ``FlopsEstimator``.
mmcv_flops_counter.get_modules_mapping = \
    FlopsEstimator._mmcv_modules_mappings_wrapper(
        mmcv_flops_counter.get_modules_mapping)
|
|
@ -71,12 +71,13 @@ def register_all_modules(init_default_scope: bool = True) -> None:
|
|||
DefaultScope.get_instance('mmrazor', scope_name='mmrazor')
|
||||
return
|
||||
current_scope = DefaultScope.get_current_instance()
|
||||
if current_scope.scope_name != 'mmrazor':
|
||||
warnings.warn('The current default scope '
|
||||
f'"{current_scope.scope_name}" is not "mmrazor", '
|
||||
'`register_all_modules` will force the current'
|
||||
'default scope to be "mmrazor". If this is not '
|
||||
'expected, please set `init_default_scope=False`.')
|
||||
if current_scope.scope_name != 'mmrazor': # type: ignore
|
||||
warnings.warn(
|
||||
'The current default scope ' # type: ignore
|
||||
f'"{current_scope.scope_name}" is not '
|
||||
'"mmrazor", `register_all_modules` will force the current'
|
||||
'default scope to be "mmrazor". If this is not expected, '
|
||||
'please set `init_default_scope=False`.')
|
||||
# avoid name conflict
|
||||
new_instance_name = f'mmrazor-{datetime.datetime.now()}'
|
||||
DefaultScope.get_instance(new_instance_name, scope_name='mmrazor')
|
||||
|
|
|
@ -1,231 +0,0 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
from unittest import TestCase
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.nn import Conv2d, Module, Parameter
|
||||
|
||||
from mmrazor.models import OneShotMutableModule
|
||||
from mmrazor.registry import MODELS
|
||||
from mmrazor.structures import FlopsEstimator, export_fix_subnet
|
||||
|
||||
# Mutable config for the first stage: a single MBBlock candidate.
_FIRST_STAGE_MUTABLE = dict(
    type='OneShotMutableOP',
    candidates=dict(
        mb_k3e1=dict(
            type='MBBlock',
            kernel_size=3,
            expand_ratio=1,
            norm_cfg=dict(type='BN'),
            act_cfg=dict(type='ReLU6'))))

# Mutable config for every other stage: two MBBlock candidates plus an
# identity candidate.
_OTHER_STAGE_MUTABLE = dict(
    type='OneShotMutableOP',
    candidates=dict(
        mb_k3e3=dict(
            type='MBBlock',
            kernel_size=3,
            expand_ratio=3,
            norm_cfg=dict(type='BN'),
            act_cfg=dict(type='ReLU6')),
        mb_k5e3=dict(
            type='MBBlock',
            kernel_size=5,
            expand_ratio=3,
            norm_cfg=dict(type='BN'),
            act_cfg=dict(type='ReLU6')),
        identity=dict(type='Identity')))

ARCHSETTING_CFG = [
    # Parameters to build layers. 4 parameters are needed to construct a
    # layer, from left to right: channel, num_blocks, stride, mutable cfg.
    [16, 1, 1, _FIRST_STAGE_MUTABLE],
    [24, 2, 2, _OTHER_STAGE_MUTABLE],
    [32, 3, 2, _OTHER_STAGE_MUTABLE],
    [64, 4, 2, _OTHER_STAGE_MUTABLE],
    [96, 3, 1, _OTHER_STAGE_MUTABLE],
    [160, 3, 2, _OTHER_STAGE_MUTABLE],
    [320, 1, 1, _OTHER_STAGE_MUTABLE]
]

NORM_CFG = dict(type='BN')
# Backbone config passed to ``MODELS.build`` by the subnet test.
BACKBONE_CFG = dict(
    type='mmrazor.SearchableMobileNet',
    first_channels=32,
    last_channels=1280,
    widen_factor=1.0,
    norm_cfg=NORM_CFG,
    arch_setting=ARCHSETTING_CFG)
|
||||
|
||||
|
||||
class FoolAddConstant(Module):
    """Toy module that adds a single learnable scalar to its input.

    Args:
        p (float): Initial value of the scalar parameter. Defaults to 0.1.
    """

    def __init__(self, p: float = 0.1) -> None:
        super().__init__()

        self.register_parameter(
            name='p', param=Parameter(torch.tensor(p, dtype=torch.float32)))

    def forward(self, x: Tensor) -> Tensor:
        """Return ``x`` plus the scalar parameter ``p``."""
        return x + self.p
|
||||
|
||||
|
||||
class FoolConv2d(Module):
    """Toy module wrapping one ``Conv2d(3, 32, 3)`` layer."""

    def __init__(self) -> None:
        super().__init__()

        self.conv2d = Conv2d(3, 32, 3)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the wrapped convolution to ``x``."""
        return self.conv2d(x)
|
||||
|
||||
|
||||
class FoolConvModule(Module):
    """Toy module chaining ``FoolAddConstant(0.1)`` and ``FoolConv2d``."""

    def __init__(self) -> None:
        super().__init__()

        self.add_constant = FoolAddConstant(0.1)
        self.conv2d = FoolConv2d()

    def forward(self, x: Tensor) -> Tensor:
        """Add the constant to ``x``, then apply the convolution."""
        x = self.add_constant(x)

        return self.conv2d(x)
|
||||
|
||||
|
||||
class TestFlopsEstimator(TestCase):
    """Unit tests for ``FlopsEstimator``."""

    def sample_choice(self, model: Module) -> None:
        """Set a sampled choice on every ``OneShotMutableModule``."""
        for module in model.modules():
            if isinstance(module, OneShotMutableModule):
                module.current_choice = module.sample_choice()

    def test_get_model_complexity_info(self) -> None:
        """Flops and params of a plain conv model are both positive."""
        fool_conv2d = FoolConv2d()
        flops_count, params_count = FlopsEstimator.get_model_complexity_info(
            fool_conv2d, as_strings=False)

        self.assertGreater(flops_count, 0)
        self.assertGreater(params_count, 0)

    def test_register_module(self) -> None:
        """A registered hook adds its flops, via direct call or decorator."""
        fool_conv_module = FoolConvModule()
        # Every measurement runs on a fresh deep copy of the module.
        copied_module = copy.deepcopy(fool_conv_module)
        flops_count, params_count = FlopsEstimator.get_model_complexity_info(
            copied_module, as_strings=False)

        def fool_add_constant_flops_counter_hook(add_constant_module: Module,
                                                 input: Tensor,
                                                 output: Tensor) -> None:
            add_constant_module.__flops__ = 1e6

        # test register directly
        FlopsEstimator.register_module(
            flops_counter_hook=fool_add_constant_flops_counter_hook,
            module=FoolAddConstant)
        copied_module = copy.deepcopy(fool_conv_module)
        flops_count_after_registered, params_count_after_registered = \
            FlopsEstimator.get_model_complexity_info(
                model=copied_module, as_strings=False)
        self.assertEqual(flops_count_after_registered - flops_count, 1e6)
        self.assertEqual(params_count_after_registered - params_count, 0)
        FlopsEstimator.remove_custom_module(FoolAddConstant)

        # test register using decorator
        FlopsEstimator.register_module(
            flops_counter_hook=fool_add_constant_flops_counter_hook)(
                FoolAddConstant)
        copied_module = copy.deepcopy(fool_conv_module)
        flops_count_after_registered, params_count_after_registered = \
            FlopsEstimator.get_model_complexity_info(
                model=copied_module, as_strings=False)
        self.assertEqual(flops_count_after_registered - flops_count, 1e6)
        self.assertEqual(params_count_after_registered - params_count, 0)

        FlopsEstimator.remove_custom_module(FoolAddConstant)

    def test_register_module_wrong_parameter(self) -> None:
        """Wrongly typed arguments raise TypeError; re-registering raises
        KeyError unless ``force=True``."""

        def fool_flops_counter_hook(module: Module, input: Tensor,
                                    output: Tensor) -> None:
            return

        with pytest.raises(TypeError):
            FlopsEstimator.register_module(
                flops_counter_hook=fool_flops_counter_hook, force=1)
        with pytest.raises(TypeError):
            FlopsEstimator.register_module(
                flops_counter_hook=fool_flops_counter_hook, module=list)
        with pytest.raises(TypeError):
            FlopsEstimator.register_module(
                flops_counter_hook=123, module=FoolAddConstant)

        # test double register
        FlopsEstimator.register_module(
            flops_counter_hook=fool_flops_counter_hook, module=FoolAddConstant)
        with pytest.raises(KeyError):
            FlopsEstimator.register_module(
                flops_counter_hook=fool_flops_counter_hook,
                module=FoolAddConstant)
        FlopsEstimator.register_module(
            flops_counter_hook=fool_flops_counter_hook,
            module=FoolAddConstant,
            force=True)

        FlopsEstimator.remove_custom_module(FoolAddConstant)

    def test_remove_custom_module(self) -> None:
        """Removing an unregistered module raises KeyError; removing a
        registered one succeeds."""
        with pytest.raises(KeyError):
            FlopsEstimator.remove_custom_module(FoolAddConstant)

        def fool_flops_counter_hook(module: Module, input: Tensor,
                                    output: Tensor) -> None:
            return

        FlopsEstimator.register_module(
            flops_counter_hook=fool_flops_counter_hook, module=FoolAddConstant)

        FlopsEstimator.remove_custom_module(FoolAddConstant)

    def test_clear_custom_module(self) -> None:
        """Clearing leaves the custom modules mapping empty."""

        def fool_flops_counter_hook(module: Module, input: Tensor,
                                    output: Tensor) -> None:
            return

        FlopsEstimator.register_module(
            flops_counter_hook=fool_flops_counter_hook, module=FoolAddConstant)
        FlopsEstimator.register_module(
            flops_counter_hook=fool_flops_counter_hook, module=FoolConvModule)

        FlopsEstimator.clear_custom_module()
        self.assertEqual(FlopsEstimator.get_custom_modules_mapping(), {})

    def test_get_model_complexity_info_subnet(self) -> None:
        """Compare supernet and fixed-subnet complexity of the backbone."""
        model = MODELS.build(BACKBONE_CFG)
        self.sample_choice(model)
        copied_model = copy.deepcopy(model)

        flops_count, params_count = FlopsEstimator.get_model_complexity_info(
            copied_model, as_strings=False)

        fix_subnet = export_fix_subnet(model)
        subnet_flops_count, subnet_params_count = \
            FlopsEstimator.get_model_complexity_info(
                model, fix_subnet, as_strings=False)

        self.assertEqual(flops_count, subnet_flops_count)
        self.assertGreater(params_count, subnet_params_count)

        # test whether subnet estimate will affect original model
        copied_model = copy.deepcopy(model)
        flops_count_after_estimate, params_count_after_estimate = \
            FlopsEstimator.get_model_complexity_info(
                copied_model, as_strings=False)

        self.assertEqual(flops_count, flops_count_after_estimate)
        self.assertEqual(params_count, params_count_after_estimate)
|
|
@ -0,0 +1,197 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
from unittest import TestCase
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.nn import Conv2d, Module, Parameter
|
||||
|
||||
from mmrazor.models import OneShotMutableModule, ResourceEstimator
|
||||
from mmrazor.models.task_modules.estimators.counters import BaseCounter
|
||||
from mmrazor.registry import MODELS, TASK_UTILS
|
||||
from mmrazor.structures import export_fix_subnet, load_fix_subnet
|
||||
|
||||
# Mutable config for the first stage: a single MBBlock candidate.
_FIRST_STAGE_MUTABLE = dict(
    type='OneShotMutableOP',
    candidates=dict(
        mb_k3e1=dict(
            type='MBBlock',
            kernel_size=3,
            expand_ratio=1,
            norm_cfg=dict(type='BN'),
            act_cfg=dict(type='ReLU6'))))

# Mutable config for every other stage: two MBBlock candidates plus an
# identity candidate.
_OTHER_STAGE_MUTABLE = dict(
    type='OneShotMutableOP',
    candidates=dict(
        mb_k3e3=dict(
            type='MBBlock',
            kernel_size=3,
            expand_ratio=3,
            norm_cfg=dict(type='BN'),
            act_cfg=dict(type='ReLU6')),
        mb_k5e3=dict(
            type='MBBlock',
            kernel_size=5,
            expand_ratio=3,
            norm_cfg=dict(type='BN'),
            act_cfg=dict(type='ReLU6')),
        identity=dict(type='Identity')))

ARCHSETTING_CFG = [
    # Parameters to build layers. 4 parameters are needed to construct a
    # layer, from left to right: channel, num_blocks, stride, mutable cfg.
    [16, 1, 1, _FIRST_STAGE_MUTABLE],
    [24, 2, 2, _OTHER_STAGE_MUTABLE],
    [32, 3, 2, _OTHER_STAGE_MUTABLE],
    [64, 4, 2, _OTHER_STAGE_MUTABLE],
    [96, 3, 1, _OTHER_STAGE_MUTABLE],
    [160, 3, 2, _OTHER_STAGE_MUTABLE],
    [320, 1, 1, _OTHER_STAGE_MUTABLE]
]

NORM_CFG = dict(type='BN')
# Backbone config passed to ``MODELS.build`` by the subnet test.
BACKBONE_CFG = dict(
    type='mmrazor.SearchableMobileNet',
    first_channels=32,
    last_channels=1280,
    widen_factor=1.0,
    norm_cfg=NORM_CFG,
    arch_setting=ARCHSETTING_CFG)
|
||||
|
||||
# Module-level ``ResourceEstimator`` built with default arguments and shared
# by the test cases in this file.
estimator = ResourceEstimator()
|
||||
|
||||
|
||||
class FoolAddConstant(Module):
    """Toy module that adds a single learnable scalar to its input.

    Args:
        p (float): Initial value of the scalar parameter. Defaults to 0.1.
    """

    def __init__(self, p: float = 0.1) -> None:
        super().__init__()

        self.register_parameter(
            name='p', param=Parameter(torch.tensor(p, dtype=torch.float32)))

    def forward(self, x: Tensor) -> Tensor:
        """Return ``x`` plus the scalar parameter ``p``."""
        return x + self.p
|
||||
|
||||
|
||||
@TASK_UTILS.register_module()
class FoolAddConstantCounter(BaseCounter):
    """Counter registered in ``TASK_UTILS`` that adds fixed fake costs."""

    @staticmethod
    def add_count_hook(module, input, output):
        """Add 1000000 to ``module.__flops__`` and 700000 to
        ``module.__params__``; ``input`` and ``output`` are unused."""
        module.__flops__ += 1000000
        module.__params__ += 700000
|
||||
|
||||
|
||||
class FoolConv2d(Module):
    """Toy module wrapping one ``Conv2d(3, 32, 3)`` layer."""

    def __init__(self) -> None:
        super().__init__()

        self.conv2d = Conv2d(3, 32, 3)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the wrapped convolution to ``x``."""
        return self.conv2d(x)
|
||||
|
||||
|
||||
class FoolConvModule(Module):
    """Toy module chaining ``FoolAddConstant(0.1)`` and ``FoolConv2d``."""

    def __init__(self) -> None:
        super().__init__()

        self.add_constant = FoolAddConstant(0.1)
        self.conv2d = FoolConv2d()

    def forward(self, x: Tensor) -> Tensor:
        """Add the constant to ``x``, then apply the convolution."""
        x = self.add_constant(x)

        return self.conv2d(x)
|
||||
|
||||
|
||||
class TestResourceEstimator(TestCase):
    """Unit tests for ``ResourceEstimator`` using the shared ``estimator``."""

    def sample_choice(self, model: Module) -> None:
        """Set a sampled choice on every ``OneShotMutableModule``."""
        for module in model.modules():
            if isinstance(module, OneShotMutableModule):
                module.current_choice = module.sample_choice()

    def test_estimate(self) -> None:
        """Flops and params of a plain conv model are both positive."""
        fool_conv2d = FoolConv2d()
        results = estimator.estimate(
            model=fool_conv2d,
            resource_args=dict(input_shape=(1, 3, 224, 224)))
        flops_count = results['flops']
        params_count = results['params']

        self.assertGreater(flops_count, 0)
        self.assertGreater(params_count, 0)

    def test_register_module(self) -> None:
        """Estimate ``FoolConvModule`` and check the exact totals."""
        fool_conv_module = FoolConvModule()
        results = estimator.estimate(
            model=fool_conv_module,
            resource_args=dict(input_shape=(1, 3, 224, 224)))
        flops_count = results['flops']
        params_count = results['params']

        self.assertEqual(flops_count, 45.158)
        self.assertEqual(params_count, 0.701)

    def test_disable_sepc_counter(self) -> None:
        """Disabling ``FoolAddConstantCounter`` lowers both totals."""
        fool_conv_module = FoolConvModule()
        rest_results = estimator.estimate(
            model=fool_conv_module,
            resource_args=dict(
                input_shape=(1, 3, 224, 224),
                disabled_counters=['FoolAddConstantCounter']))
        rest_flops_count = rest_results['flops']
        rest_params_count = rest_results['params']

        # Thresholds are the totals asserted in ``test_register_module``.
        self.assertLess(rest_flops_count, 45.158)
        self.assertLess(rest_params_count, 0.701)

    def test_estimate_spec_modules(self) -> None:
        """Per-module results are keyed by name; an unknown module name
        raises AssertionError."""
        fool_conv_module = FoolConvModule()
        results = estimator.estimate_spec_modules(
            model=fool_conv_module,
            resource_args=dict(
                input_shape=(1, 3, 224, 224), spec_modules=['add_constant']))
        self.assertGreater(results['add_constant']['flops'], 0)

        # ``backbone`` is not a submodule of ``FoolConvModule``.
        with pytest.raises(AssertionError):
            estimator.estimate_spec_modules(
                model=fool_conv_module,
                resource_args=dict(
                    input_shape=(1, 3, 224, 224), spec_modules=['backbone']))

    def test_estimate_subnet(self) -> None:
        """Estimates are unchanged after loading the exported fix subnet."""
        resource_args = dict(input_shape=(1, 3, 224, 224))
        model = MODELS.build(BACKBONE_CFG)
        self.sample_choice(model)
        copied_model = copy.deepcopy(model)

        results = estimator.estimate(
            model=copied_model, resource_args=resource_args)
        flops_count = results['flops']
        params_count = results['params']

        fix_subnet = export_fix_subnet(model)
        load_fix_subnet(copied_model, fix_subnet)
        subnet_results = estimator.estimate(
            model=copied_model, resource_args=resource_args)
        subnet_flops_count = subnet_results['flops']
        subnet_params_count = subnet_results['params']

        self.assertEqual(flops_count, subnet_flops_count)
        self.assertEqual(params_count, subnet_params_count)

        # test whether subnet estimate will affect original model
        copied_model = copy.deepcopy(model)
        results_after_estimate = \
            estimator.estimate(model=copied_model, resource_args=resource_args)
        flops_count_after_estimate = results_after_estimate['flops']
        params_count_after_estimate = results_after_estimate['params']

        self.assertEqual(flops_count, flops_count_after_estimate)
        self.assertEqual(params_count, params_count_after_estimate)
|
|
@ -110,9 +110,16 @@ class TestEvolutionSearchLoop(TestCase):
|
|||
self.assertIsInstance(loop, EvolutionSearchLoop)
|
||||
self.assertEqual(loop.candidates, fake_candidates)
|
||||
|
||||
@patch('mmrazor.engine.runner.evolution_search_loop.export_fix_subnet')
|
||||
@patch('mmrazor.structures.FlopsEstimator.get_model_complexity_info')
|
||||
def test_run_epoch(self, mock_flops, mock_export_fix_subnet):
|
||||
@patch(
|
||||
'mmrazor.engine.runner.evolution_search_loop.export_fix_subnet',
|
||||
return_value={
|
||||
'1': 'choice1',
|
||||
'2': 'choice2'
|
||||
})
|
||||
@patch(
|
||||
'mmrazor.models.task_modules.ResourceEstimator.estimate',
|
||||
return_value=dict(flops=50.0, params=1.0))
|
||||
def test_run_epoch(self, export_fix_subnet, estimate):
|
||||
# test_run_epoch: distributed == False
|
||||
loop_cfg = copy.deepcopy(self.train_cfg)
|
||||
loop_cfg.runner = self.runner
|
||||
|
@ -123,8 +130,6 @@ class TestEvolutionSearchLoop(TestCase):
|
|||
self.runner.epoch = 1
|
||||
self.runner.distributed = False
|
||||
self.runner.work_dir = self.temp_dir
|
||||
fake_subnet = {'1': 'choice1', '2': 'choice2'}
|
||||
self.runner.model.sample_subnet = MagicMock(return_value=fake_subnet)
|
||||
loop.run_epoch()
|
||||
self.assertEqual(len(loop.candidates), 4)
|
||||
self.assertEqual(len(loop.top_k_candidates), 2)
|
||||
|
@ -136,8 +141,6 @@ class TestEvolutionSearchLoop(TestCase):
|
|||
self.runner.epoch = 1
|
||||
self.runner.distributed = True
|
||||
self.runner.work_dir = self.temp_dir
|
||||
fake_subnet = {'1': 'choice1', '2': 'choice2'}
|
||||
self.runner.model.sample_subnet = MagicMock(return_value=fake_subnet)
|
||||
loop.run_epoch()
|
||||
self.assertEqual(len(loop.candidates), 4)
|
||||
self.assertEqual(len(loop.top_k_candidates), 2)
|
||||
|
@ -150,10 +153,6 @@ class TestEvolutionSearchLoop(TestCase):
|
|||
self.runner.epoch = 1
|
||||
self.runner.distributed = True
|
||||
self.runner.work_dir = self.temp_dir
|
||||
fake_subnet = {'1': 'choice1', '2': 'choice2'}
|
||||
loop.model.sample_subnet = MagicMock(return_value=fake_subnet)
|
||||
mock_flops.return_value = (50., 1)
|
||||
mock_export_fix_subnet.return_value = fake_subnet
|
||||
loop.run_epoch()
|
||||
self.assertEqual(len(loop.candidates), 4)
|
||||
self.assertEqual(len(loop.top_k_candidates), 2)
|
||||
|
|
|
@ -191,15 +191,20 @@ class TestGreedySamplerTrainLoop(TestCase):
|
|||
self.assertEqual(subnet, fake_subnet)
|
||||
self.assertEqual(len(loop.top_k_candidates), loop.top_k - 1)
|
||||
|
||||
@patch('mmrazor.engine.runner.subnet_sampler_loop.export_fix_subnet')
|
||||
@patch('mmrazor.structures.FlopsEstimator.get_model_complexity_info')
|
||||
def test_run(self, mock_flops, mock_export_fix_subnet):
|
||||
@patch(
|
||||
'mmrazor.engine.runner.evolution_search_loop.export_fix_subnet',
|
||||
return_value={
|
||||
'1': 'choice1',
|
||||
'2': 'choice2'
|
||||
})
|
||||
@patch(
|
||||
'mmrazor.models.task_modules.ResourceEstimator.estimate',
|
||||
return_value=dict(flops=50.0, params=1.0))
|
||||
def test_run(self, export_fix_subnet, estimate):
|
||||
# test run with flops_range=None
|
||||
cfg = copy.deepcopy(self.iter_based_cfg)
|
||||
cfg.experiment_name = 'test_run1'
|
||||
runner = Runner.from_cfg(cfg)
|
||||
fake_subnet = {'1': 'choice1', '2': 'choice2'}
|
||||
runner.model.sample_subnet = MagicMock(return_value=fake_subnet)
|
||||
runner.train()
|
||||
|
||||
self.assertEqual(runner.iter, runner.max_iters)
|
||||
|
@ -210,10 +215,6 @@ class TestGreedySamplerTrainLoop(TestCase):
|
|||
cfg.experiment_name = 'test_run2'
|
||||
cfg.train_cfg.flops_range = (0, 100)
|
||||
runner = Runner.from_cfg(cfg)
|
||||
fake_subnet = {'1': 'choice1', '2': 'choice2'}
|
||||
runner.model.sample_subnet = MagicMock(return_value=fake_subnet)
|
||||
mock_flops.return_value = (50., 1)
|
||||
mock_export_fix_subnet.return_value = fake_subnet
|
||||
runner.train()
|
||||
|
||||
self.assertEqual(runner.iter, runner.max_iters)
|
||||
|
|
Loading…
Reference in New Issue