[Feature] Refactor Estimator for computing FLOPs/Params/Latency. (#230)

* Refactor ModelEstimator:
1. add EvaluatorLoop in engine.runners;
2. add estimator for structures (both subnet & supernet);
3. add layer_counter for each op.

* fix lint

* update estimator:
1. add ResourceEstimator based on BaseEstimator;
2. add notes & examples for ResourceEstimator & EvaluatorLoop usage;
3. fix a bug of latency test.
4. minor changes according to comments.

* add UT & fix a bug caused by UT

* add docstrings & remove old estimator

* update docstrings for op_spec_counters

* rename resource_evaluator_val_loop

* support adding resource attrs of each submodule in a measured model

* fix lint

* refactor estimator file structures

* support estimating resources for spec modules

* rm old UT

* update new estimator UT cases

* fix traversal range of the model

* cancel unit convert in accumulate_sub_module_flops_params

* use estimator_cfg to build ResourceEstimator

* fix a broadcast bug

* delete fixed input_shape

* add assertion and string-format-return when measuring spec_modules

* add UT for estimating spec_modules
pull/237/head
Yang Gao 2022-08-23 15:01:47 +08:00 committed by GitHub
parent 57aec1f730
commit 4b3f8ab69e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
33 changed files with 1524 additions and 550 deletions

View File

@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .hooks import DumpSubnetHook
from .hooks import DumpSubnetHook, EstimateResourcesHook
from .optimizers import SeparateOptimWrapperConstructor
from .runner import (AutoSlimValLoop, DartsEpochBasedTrainLoop,
DartsIterBasedTrainLoop, EvolutionSearchLoop,
@ -10,5 +10,6 @@ __all__ = [
'SeparateOptimWrapperConstructor', 'DumpSubnetHook',
'SingleTeacherDistillValLoop', 'DartsEpochBasedTrainLoop',
'DartsIterBasedTrainLoop', 'SlimmableValLoop', 'EvolutionSearchLoop',
'GreedySamplerTrainLoop', 'AutoSlimValLoop', 'SelfDistillValLoop'
'GreedySamplerTrainLoop', 'AutoSlimValLoop', 'SelfDistillValLoop',
'EstimateResourcesHook'
]

View File

@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .dump_subnet_hook import DumpSubnetHook
from .estimate_resources_hook import EstimateResourcesHook
__all__ = ['DumpSubnetHook']
__all__ = ['DumpSubnetHook', 'EstimateResourcesHook']

View File

@ -0,0 +1,117 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import Any, Dict, Optional, Sequence
import torch
from mmengine.data import BaseDataElement
from mmengine.hooks import Hook
from mmengine.registry import HOOKS
from mmrazor.models.task_modules import ResourceEstimator
DATA_BATCH = Optional[Sequence[dict]]
@HOOKS.register_module()
class EstimateResourcesHook(Hook):
    """Estimate model resources periodically during validation.

    Args:
        interval (int): The estimation period. If ``by_epoch=True``, interval
            indicates epochs, otherwise it indicates iterations.
            Defaults to -1, which means "never".
        by_epoch (bool): Whether to estimate by epoch or by iteration.
            Defaults to True.
        estimator_cfg (Dict[str, Any], optional): Keyword arguments used for
            building the ``ResourceEstimator``. Defaults to None, which is
            treated as an empty dict.

    Example:
        >>> # Add the ``EstimateResourcesHook`` to ``custom_hooks``:
        >>> custom_hooks = [
        ...     dict(type='mmrazor.EstimateResourcesHook',
        ...          interval=1,
        ...          by_epoch=True,
        ...          estimator_cfg=dict(input_shape=(1, 3, 64, 64)))
        ... ]
    """
    out_dir: str

    priority = 'VERY_LOW'

    def __init__(self,
                 interval: int = -1,
                 by_epoch: bool = True,
                 estimator_cfg: Optional[Dict[str, Any]] = None,
                 **kwargs) -> None:
        self.interval = interval
        self.by_epoch = by_epoch
        # ``None`` stands for "no extra arguments"; this avoids sharing one
        # mutable default dict between all hook instances.
        self.estimator = ResourceEstimator(**(estimator_cfg or {}))

    def after_val_epoch(self,
                        runner,
                        metrics: Optional[Dict[str, float]] = None) -> None:
        """Estimate model resources after every n val epochs.

        Does nothing when the hook is configured with ``by_epoch=False``.

        Args:
            runner (Runner): The runner of the training process.
            metrics (Dict[str, float], optional): Evaluation results. Unused.
        """
        if not self.by_epoch:
            return

        if self.every_n_epochs(runner, self.interval):
            self.estimate_resources(runner)

    def after_val_iter(self,
                       runner,
                       batch_idx: int,
                       data_batch: DATA_BATCH = None,
                       outputs: Optional[Sequence[BaseDataElement]] = None) \
            -> None:
        """Estimate model resources after every n val iters.

        Does nothing when the hook is configured with ``by_epoch=True``.

        Args:
            runner (Runner): The runner of the training process.
            batch_idx (int): Index of the current validation batch. Unused.
            data_batch (Sequence[dict], optional): Current batch. Unused.
            outputs (Sequence[BaseDataElement], optional): Model outputs.
                Unused.
        """
        if self.by_epoch:
            return

        # NOTE(review): the period is checked with `every_n_train_iters`
        # and `batch_idx` is ignored, although this callback runs inside the
        # validation loop -- confirm that this is the intended trigger.
        if self.every_n_train_iters(runner, self.interval):
            self.estimate_resources(runner)

    def estimate_resources(self, runner) -> None:
        """Estimate the resources of the runner's model and log the result.

        Args:
            runner (Runner): The runner holding the model and the logger.
        """
        model = runner.model.module if runner.distributed else runner.model

        # TODO confirm the state judgement.
        if hasattr(model, 'is_supernet') and model.is_supernet:
            model = self.export_subnet(model)

        resource_metrics = self.estimator.estimate(model)
        runner.logger.info(f'Estimate model resources: {resource_metrics}')

    def export_subnet(self, model) -> torch.nn.Module:
        """Export current best subnet.

        NOTE: This method is called when it comes to those NAS algorithms that
        require building a supernet for training.

        For those algorithms, measuring subnet resources is more meaningful
        than supernet during validation, therefore this method is required to
        get the current searched subnet from the supernet.

        Args:
            model: The supernet. Its mutables lose their ``arch_weights``
                attribute as a side effect (see the comment below).

        Returns:
            torch.nn.Module: A deep copy of ``model`` with the searched
            subnet loaded.
        """
        # Avoid circular import
        from mmrazor.models.mutables.base_mutable import BaseMutable
        from mmrazor.structures import load_fix_subnet

        # delete non-leaf tensor to get deepcopy(model).
        # NOTE: this removes `arch_weights` from the ORIGINAL model, not
        # only from the copy.
        # TODO solve the hard case.
        for module in model.architecture.modules():
            if isinstance(module, BaseMutable):
                if hasattr(module, 'arch_weights'):
                    delattr(module, 'arch_weights')

        copied_model = copy.deepcopy(model)
        fix_mutable = copied_model.search_subnet()
        load_fix_subnet(copied_model, fix_mutable)

        return copied_model

View File

@ -27,7 +27,7 @@ class AutoSlimValLoop(ValLoop):
# just for convenience
self._model = model
def run(self) -> None:
def run(self):
"""Launch validation."""
self.runner.call_hook('before_val')

View File

@ -1,9 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os
import os.path as osp
import random
import warnings
from typing import Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from mmengine import fileio
@ -13,8 +14,9 @@ from mmengine.runner import EpochBasedTrainLoop
from mmengine.utils import is_list_of
from torch.utils.data import DataLoader
from mmrazor.models.task_modules import ResourceEstimator
from mmrazor.registry import LOOPS
from mmrazor.structures import Candidates, FlopsEstimator, export_fix_subnet
from mmrazor.structures import Candidates, export_fix_subnet, load_fix_subnet
from mmrazor.utils import SupportRandomSubnet
from .utils import crossover
@ -42,6 +44,8 @@ class EvolutionSearchLoop(EpochBasedTrainLoop):
mutate_prob (float): The probability of mutation. Defaults to 0.1.
flops_range (tuple, optional): flops_range to be used for screening
candidates.
estimator_cfg (Dict[str, Any]): Used for building a resource estimator.
Default to dict().
score_key (str): Specify one metric in evaluation results to score
candidates. Defaults to 'accuracy_top-1'.
init_candidates (str, optional): The candidates file path, which is
@ -62,6 +66,7 @@ class EvolutionSearchLoop(EpochBasedTrainLoop):
num_crossover: int = 25,
mutate_prob: float = 0.1,
flops_range: Optional[Tuple[float, float]] = (0., 330 * 1e6),
estimator_cfg: Dict[str, Any] = dict(),
score_key: str = 'accuracy_top-1',
init_candidates: Optional[str] = None) -> None:
super().__init__(runner, dataloader, max_epochs)
@ -80,6 +85,7 @@ class EvolutionSearchLoop(EpochBasedTrainLoop):
self.num_candidates = num_candidates
self.top_k = top_k
self.flops_range = flops_range
self.estimator_cfg = estimator_cfg
self.score_key = score_key
self.num_mutation = num_mutation
self.num_crossover = num_crossover
@ -299,8 +305,13 @@ class EvolutionSearchLoop(EpochBasedTrainLoop):
self.model.set_subnet(random_subnet)
fix_mutable = export_fix_subnet(self.model)
flops: float = FlopsEstimator.get_model_complexity_info(
self.model, fix_mutable=fix_mutable, as_strings=False)[0]
copied_model = copy.deepcopy(self.model)
load_fix_subnet(copied_model, fix_mutable)
estimator = ResourceEstimator(**self.estimator_cfg)
results = estimator.estimate(copied_model)
flops = results['flops']
if self.flops_range[0] < flops < self.flops_range[1]:
return True
else:

View File

@ -38,7 +38,7 @@ class SlimmableValLoop(ValLoop):
# just for convenience
self._model = model
def run(self) -> None:
def run(self):
"""Launch validation."""
self.runner.call_hook('before_val')

View File

@ -1,9 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import math
import os
import random
from abc import abstractmethod
from typing import Dict, List, Optional, Sequence, Tuple, Union
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import torch
from mmengine import fileio
@ -12,8 +13,9 @@ from mmengine.runner import IterBasedTrainLoop
from mmengine.utils import is_list_of
from torch.utils.data import DataLoader
from mmrazor.models.task_modules import ResourceEstimator
from mmrazor.registry import LOOPS
from mmrazor.structures import Candidates, FlopsEstimator, export_fix_subnet
from mmrazor.structures import Candidates, export_fix_subnet, load_fix_subnet
from mmrazor.utils import SupportRandomSubnet
@ -100,7 +102,9 @@ class GreedySamplerTrainLoop(BaseSamplerTrainLoop):
val_interval (int): Validation interval. Defaults to 1000.
score_key (str): Specify one metric in evaluation results to score
candidates. Defaults to 'accuracy_top-1'.
constraints (dict): Constraints to be used for screening candidates.
flops_range (dict): Constraints to be used for screening candidates.
estimator_cfg (Dict[str, Any]): Used for building a resource estimator.
Default to dict().
num_candidates (int): The number of the candidates consist of samples
from supernet and itself. Defaults to 1000.
num_samples (int): The number of sample in each sampling subnet.
@ -135,6 +139,7 @@ class GreedySamplerTrainLoop(BaseSamplerTrainLoop):
val_interval: int = 1000,
score_key: str = 'accuracy_top-1',
flops_range: Optional[Tuple[float, float]] = (0., 330 * 1e6),
estimator_cfg: Dict[str, Any] = dict(),
num_candidates: int = 1000,
num_samples: int = 10,
top_k: int = 5,
@ -158,6 +163,7 @@ class GreedySamplerTrainLoop(BaseSamplerTrainLoop):
self.score_key = score_key
self.flops_range = flops_range
self.estimator_cfg = estimator_cfg
self.num_candidates = num_candidates
self.num_samples = num_samples
self.top_k = top_k
@ -316,8 +322,13 @@ class GreedySamplerTrainLoop(BaseSamplerTrainLoop):
self.model.set_subnet(random_subnet)
fix_mutable = export_fix_subnet(self.model)
flops = FlopsEstimator.get_model_complexity_info(
self.model, fix_mutable=fix_mutable, as_strings=False)[0]
copied_model = copy.deepcopy(self.model)
load_fix_subnet(copied_model, fix_mutable)
estimator = ResourceEstimator(**self.estimator_cfg)
results = estimator.estimate(copied_model)
flops = results['flops']
if self.flops_range[0] < flops < self.flops_range[1]:
return True
else:

View File

@ -6,3 +6,4 @@ from .losses import * # noqa: F401,F403
from .mutables import * # noqa: F401,F403
from .mutators import * # noqa: F401,F403
from .ops import * # noqa: F401,F403
from .task_modules import * # noqa: F401,F403

View File

@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .estimators import ResourceEstimator
__all__ = ['ResourceEstimator']

View File

@ -0,0 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .counters import * # noqa: F401,F403
from .resource_estimator import ResourceEstimator
__all__ = ['ResourceEstimator']

View File

@ -0,0 +1,55 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
from typing import Any, Dict, List, Tuple
import torch.nn
from mmrazor.registry import TASK_UTILS
@TASK_UTILS.register_module()
class BaseEstimator(metaclass=ABCMeta):
    """Abstract base of all estimators that measure model resources.

    The constructor only validates and stores its configuration; concrete
    subclasses implement :meth:`estimate`.

    Args:
        default_shape (tuple): Input data's default shape, for calculating
            resources consume. Must have 3, 4 or 5 entries.
            Defaults to (1, 3, 224, 224).
        units (str): Resource units. Defaults to 'M'.
        disabled_counters (list): List of disabled spec op counters.
            Defaults to None.
        as_strings (bool): Output FLOPs and params counts in a string
            form. Defaults to False.
        measure_inference (bool): Whether to measure infer speed or not.
            Defaults to False.
    """

    def __init__(self,
                 default_shape: Tuple = (1, 3, 224, 224),
                 units: str = 'M',
                 disabled_counters: List[str] = None,
                 as_strings: bool = False,
                 measure_inference: bool = False):
        num_dims = len(default_shape)
        assert 3 <= num_dims <= 5, f'Unsupported shape: {default_shape}'

        self.default_shape = default_shape
        self.units = units
        self.disabled_counters = disabled_counters
        self.as_strings = as_strings
        self.measure_inference = measure_inference

    @abstractmethod
    def estimate(
        self, model: torch.nn.Module, resource_args: Dict[str, Any] = dict()
    ) -> Dict[str, float]:
        """Estimate the resources(flops/params/latency) of the given model.

        Args:
            model: The measured model.
            resource_args (Dict[str, float]): resources information.
                NOTE: resource_args have the same items() as the init cfgs.

        Returns:
            Dict[str, float]: A dict containing the resource results (flops,
            params and latency).
        """

View File

@ -0,0 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .flops_params_counter import (get_model_complexity_info,
params_units_convert)
from .latency_counter import repeat_measure_inference_speed
from .op_counters import * # noqa: F401,F403
__all__ = [
'get_model_complexity_info', 'params_units_convert',
'repeat_measure_inference_speed'
]

View File

@ -0,0 +1,472 @@
# Copyright (c) OpenMMLab. All rights reserved.
import sys
from functools import partial
import torch
import torch.nn as nn
from mmrazor.registry import TASK_UTILS
def get_model_complexity_info(model,
                              input_shape,
                              spec_modules=[],
                              disabled_counters=[],
                              print_per_layer_stat=False,
                              as_strings=False,
                              input_constructor=None,
                              flush=False,
                              ost=sys.stdout):
    """Get complexity information of a model.

    This method can calculate FLOPs and parameter counts of a model with
    corresponding input shape. It can also print complexity information for
    each layer in a model. A layer is counted when a counter named
    ``<LayerClassName>Counter`` is registered in ``TASK_UTILS``, e.g.:

        - Convolutions: ``nn.Conv1d``, ``nn.Conv2d``, ``nn.Conv3d``.
        - Activations: ``nn.ReLU``, ``nn.PReLU``, ``nn.ELU``,
            ``nn.LeakyReLU``, ``nn.ReLU6``.
        - Poolings: ``nn.MaxPool1d``, ``nn.MaxPool2d``, ``nn.MaxPool3d``,
            ``nn.AvgPool1d``, ``nn.AvgPool2d``, ``nn.AvgPool3d``,
            ``nn.AdaptiveMaxPool1d``, ``nn.AdaptiveMaxPool2d``,
            ``nn.AdaptiveMaxPool3d``, ``nn.AdaptiveAvgPool1d``,
            ``nn.AdaptiveAvgPool2d``, ``nn.AdaptiveAvgPool3d``.
        - BatchNorms: ``nn.BatchNorm1d``, ``nn.BatchNorm2d``,
            ``nn.BatchNorm3d``.
        - Linear: ``nn.Linear``.
        - Deconvolution: ``nn.ConvTranspose2d``.
        - Upsample: ``nn.Upsample``.

    NOTE: The model is switched to eval mode and is not switched back.

    Args:
        model (nn.Module): The model for complexity calculation.
        input_shape (tuple): Input shape (including batchsize) used for
            calculation.
        spec_modules (list): A list that contains the names of several spec
            modules, which users want to get resources infos of them.
            e.g., ['backbone', 'head'], ['backbone.layer1']. Default to [].
        disabled_counters (list): Counter types (e.g. 'Conv2dCounter') that
            must not be attached. Default to [].
        print_per_layer_stat (bool): Whether to print complexity information
            for each layer in a model. Default to False.
        as_strings (bool): Output FLOPs and params counts in a string form.
            Default to False.
        input_constructor (None | callable): If specified, it takes a callable
            method that generates input. Its result is unpacked as keyword
            arguments of the model. Otherwise, a tensor with input shape is
            generated to calculate FLOPs. Default to None.
        flush (bool): same as that in :func:`print`. Default to False.
        ost (stream): same as ``file`` param in :func:`print`.
            Default to sys.stdout.

    Returns:
        tuple[float | str] | dict[str, dict]: If ``spec_modules`` is empty,
        a ``(flops, params)`` tuple, formatted as strings ('x GFLOPs',
        'x M') when ``as_strings`` is True and as numbers otherwise.
        If ``spec_modules`` is not empty, a dict mapping each spec module
        name to ``{'flops': ..., 'params': ...}`` with the same formatting
        rule.
    """
    assert type(input_shape) is tuple
    assert len(input_shape) >= 1
    assert isinstance(model, nn.Module)
    flops_params_model = add_flops_params_counting_methods(model)
    flops_params_model.eval()

    spec_modules_resources = dict()
    # The try/finally guarantees that the forward hooks and the per-module
    # counters are removed even when the forward pass or a check fails.
    try:
        flops_params_model.start_flops_params_count(disabled_counters)
        if input_constructor:
            input = input_constructor(input_shape)
            _ = flops_params_model(**input)
        else:
            try:
                first_param = next(flops_params_model.parameters())
            except StopIteration:
                # Avoid StopIteration for models which have no parameters,
                # like `nn.Relu()`, `nn.AvgPool2d`, etc.
                batch = torch.ones(()).new_empty(tuple(input_shape))
            else:
                batch = torch.ones(()).new_empty(
                    tuple(input_shape),
                    dtype=first_param.dtype,
                    device=first_param.device)
            _ = flops_params_model(batch)

        flops_count, params_count = \
            flops_params_model.compute_average_flops_params_cost()

        if print_per_layer_stat:
            print_model_with_flops_params(
                flops_params_model,
                flops_count,
                params_count,
                ost=ost,
                flush=flush)

        if len(spec_modules):
            module_names = [
                name for name, _ in flops_params_model.named_modules()
            ]
            for module in spec_modules:
                assert module in module_names, \
                    f'All modules in spec_modules should be in the measured ' \
                    f'flops_params_model. Got module {module} in spec_modules.'
            accumulate_sub_module_flops_params(flops_params_model)
            for name, module in flops_params_model.named_modules():
                if name in spec_modules:
                    spec_modules_resources[name] = dict()
                    spec_modules_resources[name]['flops'] = module.__flops__
                    spec_modules_resources[name]['params'] = module.__params__
                    if as_strings:
                        spec_modules_resources[name]['flops'] = str(
                            params_units_convert(module.__flops__,
                                                 'G')) + ' GFLOPs'
                        spec_modules_resources[name]['params'] = str(
                            params_units_convert(module.__params__,
                                                 'M')) + ' M'
    finally:
        flops_params_model.stop_flops_params_count()

    if len(spec_modules):
        return spec_modules_resources

    if as_strings:
        flops_string = str(params_units_convert(flops_count, 'G')) + ' GFLOPs'
        params_string = str(params_units_convert(params_count, 'M')) + ' M'
        return flops_string, params_string

    return flops_count, params_count
def params_units_convert(num_params, units='M', precision=3):
    """Convert a raw count (params or FLOPs) into the given units.

    Args:
        num_params (float): Number to be converted.
        units (str): Target units. Options are 'G' (1e9), 'M' (1e6) and
            'K' (1e3). Default to 'M'.
        precision (int): Digit number after the decimal point. Default to 3.

    Returns:
        float: The converted number, rounded to ``precision`` digits.

    Raises:
        ValueError: If ``units`` is not one of 'G', 'M' or 'K'.

    Examples:
        >>> params_units_convert(1e9)
        1000.0
        >>> params_units_convert(2e5)
        0.2
        >>> params_units_convert(4579784.0, 'G')
        0.005
    """
    for unit_name, divisor in (('G', 10.**9), ('M', 10.**6), ('K', 10.**3)):
        if units == unit_name:
            return round(num_params / divisor, precision)
    raise ValueError(f'Unsupported units convert: {units}')
def print_model_with_flops_params(model,
                                  total_flops,
                                  total_params,
                                  units='G',
                                  precision=3,
                                  ost=sys.stdout,
                                  flush=False):
    """Print a model with FLOPs and Params for each layer.

    Every module's ``extra_repr`` is temporarily replaced so that printing
    the model shows the accumulated statistics; the patch is removed again
    before returning. The modules must already carry ``__flops__`` /
    ``__params__`` and the model ``__batch_counter__``.

    Args:
        model (nn.Module): The model to be printed.
        total_flops (float): Total FLOPs of the model.
        total_params (float): Total parameter counts of the model.
        units (str | None): Converted FLOPs units. Default to 'G'.
        precision (int): Digit number after the decimal point. Default to 3.
        ost (stream): same as `file` param in :func:`print`.
            Default to sys.stdout.
        flush (bool): same as that in :func:`print`. Default to False.

    Example:
        >>> class ExampleModel(nn.Module):
        >>>     def __init__(self):
        >>>         super().__init__()
        >>>         self.conv1 = nn.Conv2d(3, 8, 3)
        >>>         self.conv2 = nn.Conv2d(8, 256, 3)
        >>>         self.conv3 = nn.Conv2d(256, 8, 3)
        >>>         self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        >>>         self.flatten = nn.Flatten()
        >>>         self.fc = nn.Linear(8, 1)
        >>>     def forward(self, x):
        >>>         x = self.conv1(x)
        >>>         x = self.conv2(x)
        >>>         x = self.conv3(x)
        >>>         x = self.avg_pool(x)
        >>>         x = self.flatten(x)
        >>>         x = self.fc(x)
        >>>         return x
        >>> model = ExampleModel()
        >>> x = (3, 16, 16)
        to print the complexity information state for each layer, you can use
        >>> get_model_complexity_info(model, x, print_per_layer_stat=True)
        or directly use
        >>> print_model_with_flops_params(model, 4579784.0, 37361)
        ExampleModel(
          0.037 M, 100.000% Params, 0.005 GFLOPs, 100.000% FLOPs,
          (conv1): Conv2d(0.0 M, 0.600% Params, 0.0 GFLOPs, 0.959% FLOPs, 3, 8, kernel_size=(3, 3), stride=(1, 1))  # noqa: E501
          (conv2): Conv2d(0.019 M, 50.020% Params, 0.003 GFLOPs, 58.760% FLOPs, 8, 256, kernel_size=(3, 3), stride=(1, 1))
          (conv3): Conv2d(0.018 M, 49.356% Params, 0.002 GFLOPs, 40.264% FLOPs, 256, 8, kernel_size=(3, 3), stride=(1, 1))
          (avg_pool): AdaptiveAvgPool2d(0.0 M, 0.000% Params, 0.0 GFLOPs, 0.017% FLOPs, output_size=(1, 1))
          (flatten): Flatten(0.0 M, 0.000% Params, 0.0 GFLOPs, 0.000% FLOPs, )
          (fc): Linear(0.0 M, 0.024% Params, 0.0 GFLOPs, 0.000% FLOPs, in_features=8, out_features=1, bias=True)
        )
    """

    def accumulate_params(self):
        # Supported modules carry their own count; containers sum children.
        if is_supported_instance(self):
            return self.__params__
        return sum(m.accumulate_params() for m in self.children())

    def accumulate_flops(self):
        # FLOPs were summed over the whole batch, so normalize per sample.
        if is_supported_instance(self):
            return self.__flops__ / model.__batch_counter__
        return sum(m.accumulate_flops() for m in self.children())

    def flops_repr(self):
        accumulated_num_params = self.accumulate_params()
        accumulated_flops_cost = self.accumulate_flops()
        flops_string = str(
            params_units_convert(
                accumulated_flops_cost, units=units,
                precision=precision)) + ' ' + units + 'FLOPs'
        params_string = str(
            params_units_convert(
                accumulated_num_params, units='M', precision=precision)) + ' M'
        return ', '.join([
            params_string,
            '{:.3%} Params'.format(accumulated_num_params / total_params),
            flops_string,
            '{:.3%} FLOPs'.format(accumulated_flops_cost / total_flops),
            self.original_extra_repr()
        ])

    def add_extra_repr(m):
        m.accumulate_flops = accumulate_flops.__get__(m)
        m.accumulate_params = accumulate_params.__get__(m)
        flops_extra_repr = flops_repr.__get__(m)
        if m.extra_repr != flops_extra_repr:
            m.original_extra_repr = m.extra_repr
            m.extra_repr = flops_extra_repr
            assert m.extra_repr != m.original_extra_repr

    def del_extra_repr(m):
        if hasattr(m, 'original_extra_repr'):
            m.extra_repr = m.original_extra_repr
            del m.original_extra_repr
        if hasattr(m, 'accumulate_flops'):
            del m.accumulate_flops
        # Previously `accumulate_params` was left behind on every module.
        if hasattr(m, 'accumulate_params'):
            del m.accumulate_params

    model.apply(add_extra_repr)
    try:
        print(model, file=ost, flush=flush)
    finally:
        # Always undo the monkey patches, even when printing fails.
        model.apply(del_extra_repr)
def accumulate_sub_module_flops_params(model):
    """Store accumulated FLOPs and params on every module of the model.

    After the call each module in ``model.modules()`` has ``__flops__`` and
    ``__params__`` attributes holding the totals of its supported
    descendants (FLOPs are divided by ``model.__batch_counter__``).

    Args:
        model (nn.Module): The model to be accumulated.
    """

    def _params_of(module):
        if is_supported_instance(module):
            return module.__params__
        return sum(_params_of(child) for child in module.children())

    def _flops_of(module):
        if is_supported_instance(module):
            return module.__flops__ / model.__batch_counter__
        return sum(_flops_of(child) for child in module.children())

    # `modules()` yields parents before their children, so a parent's total
    # is computed before the children's raw counters are overwritten.
    for module in model.modules():
        flops_total = _flops_of(module)
        params_total = _params_of(module)
        module.__flops__ = flops_total
        module.__params__ = params_total
def get_model_parameters_number(model):
    """Count the trainable parameters of a model.

    Args:
        model (nn.Module): The model for parameter number calculation.

    Returns:
        int: Total number of elements over all parameters whose
        ``requires_grad`` is true.
    """
    trainable_sizes = [
        param.numel() for param in model.parameters() if param.requires_grad
    ]
    return sum(trainable_sizes)
def add_flops_params_counting_methods(net_main_module):
    """Bind the counting helpers to a module object and reset its counters.

    The helpers are attached as bound methods of this particular object so
    that each of them receives the module as ``self``.

    Args:
        net_main_module (nn.Module): The root module to be measured.

    Returns:
        nn.Module: The same object, with the counting methods attached.
    """
    counting_methods = (
        start_flops_params_count,
        stop_flops_params_count,
        reset_flops_params_count,
        compute_average_flops_params_cost,
    )
    for method in counting_methods:
        setattr(net_main_module, method.__name__,
                method.__get__(net_main_module))

    net_main_module.reset_flops_params_count()

    return net_main_module
def compute_average_flops_params_cost(self):
    """Compute average FLOPs and Params cost.

    Available as a method after `add_flops_params_counting_methods()` is
    called on a desired net object.

    Returns:
        tuple: ``(flops, params)`` where ``flops`` is the summed FLOPs of all
        supported modules divided by the batch counter and ``params`` is the
        summed params of those modules.
    """
    batches_count = self.__batch_counter__
    counted_modules = [m for m in self.modules() if is_supported_instance(m)]
    flops_sum = sum(m.__flops__ for m in counted_modules)
    params_sum = sum(m.__params__ for m in counted_modules)
    return flops_sum / batches_count, params_sum
def start_flops_params_count(self, disabled_counters):
    """Activate the computation of mean flops and params consumption per image.

    Available as a method after ``add_flops_params_counting_methods()`` is
    called on a desired net object. It should be called before running the
    network.

    Args:
        disabled_counters (list | None): Counter type names that must not be
            attached. ``None`` disables nothing.
    """
    add_batch_counter_hook_function(self)

    def _attach_counter(module):
        if not is_supported_instance(module):
            return
        if hasattr(module, '__flops_params_handle__'):
            # A counter hook is already registered on this module.
            return

        counter_type = get_counter_type(module)
        if disabled_counters is not None and counter_type in disabled_counters:
            return

        counter = TASK_UTILS.build(dict(type=counter_type, _scope_='mmrazor'))
        module.__flops_params_handle__ = module.register_forward_hook(
            counter.add_count_hook)

    self.apply(_attach_counter)
def stop_flops_params_count(self):
    """Stop computing the mean flops and params consumption per image.

    Available as a method after ``add_flops_params_counting_methods()`` is
    called on a desired net object. Removes the batch counter hook from this
    module and the counter hooks from every sub-module.
    """
    remove_batch_counter_hook_function(self)
    self.apply(fn=remove_flops_params_counter_hook_function)
def reset_flops_params_count(self):
    """Reset statistics computed so far.

    Available as a method after `add_flops_params_counting_methods()` is
    called on a desired net object. Zeroes the batch counter of this module
    and the FLOPs/params counters of every supported sub-module.
    """
    add_batch_counter_variables_or_reset(self)
    self.apply(fn=add_flops_params_counter_variable_or_reset)
# ---- Internal functions
def empty_flops_params_counter_hook(module, input, output):
    """Forward hook that leaves the module's counters unchanged.

    Both counters must already exist on ``module``.
    """
    module.__params__ += 0
    module.__flops__ += 0
def add_batch_counter_variables_or_reset(module):
    """Create the module's ``__batch_counter__`` or reset it to zero."""
    setattr(module, '__batch_counter__', 0)
def add_batch_counter_hook_function(module):
    """Register the batch counter forward hook once per module.

    The hook handle is kept on the module as ``__batch_counter_handle__``;
    nothing happens when that attribute already exists.
    """
    if not hasattr(module, '__batch_counter_handle__'):
        module.__batch_counter_handle__ = module.register_forward_hook(
            batch_counter_hook)
def batch_counter_hook(module, input, output):
    """Forward hook that adds the batch size to ``__batch_counter__``.

    The batch size is the length of the first positional input; when the
    module received no positional input a warning is printed and a batch
    size of 1 is assumed.
    """
    if len(input) > 0:
        # Can have multiple inputs, getting the first one
        batch_size = len(input[0])
    else:
        print('Warning! No positional inputs found for a module, '
              'assuming batch size is 1.')
        batch_size = 1
    module.__batch_counter__ += batch_size
def remove_batch_counter_hook_function(module):
    """Remove the batch counter hook of a module, if one is registered."""
    if not hasattr(module, '__batch_counter_handle__'):
        return
    handle = module.__batch_counter_handle__
    handle.remove()
    delattr(module, '__batch_counter_handle__')
def add_flops_params_counter_variable_or_reset(module):
    """Create or zero the ``__flops__``/``__params__`` counters of a module.

    Only supported modules are touched. A warning is printed when either
    counter already exists, since overwriting it may clash with user code.
    """
    if is_supported_instance(module):
        if hasattr(module, '__flops__') or hasattr(module, '__params__'):
            # A space after "module" was missing, which produced messages
            # such as "... for the moduleConv2d ptflops ...".
            print('Warning: variables __flops__ or __params__ are already '
                  'defined for the module ' + type(module).__name__ +
                  ' ptflops can affect your code!')
        module.__flops__ = 0
        module.__params__ = 0
def get_counter_type(module):
    """Return the counter name of a module: its class name + 'Counter'."""
    class_name = module.__class__.__name__
    return f'{class_name}Counter'
def is_supported_instance(module):
    """Tell whether a counter for the module's class is in ``TASK_UTILS``."""
    registered_names = TASK_UTILS._module_dict.keys()
    return get_counter_type(module) in registered_names
def remove_flops_params_counter_hook_function(module):
    """Detach the counter hook of a module and drop its counters.

    Each of ``__flops_params_handle__``, ``__flops__`` and ``__params__`` is
    removed only when present, so the call is safe on any module.
    """
    if hasattr(module, '__flops_params_handle__'):
        module.__flops_params_handle__.remove()
    for attr_name in ('__flops_params_handle__', '__flops__', '__params__'):
        if hasattr(module, attr_name):
            delattr(module, attr_name)

View File

@ -0,0 +1,88 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging
import time
import torch
from mmengine.logging import print_log
def repeat_measure_inference_speed(model,
                                   resource_args,
                                   max_iter: int = 100,
                                   log_interval: int = 100,
                                   repeat_num: int = 1) -> float:
    """Repeat speed measure for multi-times to get more precise results.

    Args:
        model: The model to be measured.
        resource_args: Passed through to ``measure_inference_speed``.
        max_iter (int): Passed through. Defaults to 100.
        log_interval (int): Passed through. Defaults to 100.
        repeat_num (int): How many measurements to run. Defaults to 1.

    Returns:
        float: Time per image in ms (``1000 / fps``). With several repeats
        it is the mean of the per-run values, each rounded to one digit.
    """
    assert repeat_num >= 1

    fps_list = [
        measure_inference_speed(model, resource_args, max_iter, log_interval)
        for _ in range(repeat_num)
    ]

    if repeat_num <= 1:
        return round(1000 / fps_list[0], 1)

    fps_list_ = [round(fps, 1) for fps in fps_list]
    times_per_img_list = [round(1000 / fps, 1) for fps in fps_list]
    mean_fps_ = sum(fps_list_) / len(fps_list_)
    mean_times_per_img = sum(times_per_img_list) / len(times_per_img_list)
    print_log(
        f'Overall fps: {fps_list_}[{mean_fps_:.1f}] img / s, '
        f'times per image: '
        f'{times_per_img_list}[{mean_times_per_img:.1f}] ms/img',
        logger='current',
        level=logging.DEBUG)
    return mean_times_per_img
def measure_inference_speed(model, resource_args, max_iter: int,
                            log_interval: int) -> float:
    """Measure the inference speed of a CUDA model on random inputs.

    Args:
        model: The model to be measured. Its first parameter must live on a
            CUDA device.
        resource_args: Mapping whose ``'input_shape'`` entry is the shape of
            the random input tensor.
        max_iter (int): Upper bound of the iteration index; the loop runs
            for ``i = 1 .. max_iter - 1``. Must exceed the number of warmup
            iterations (5).
        log_interval (int): Interval (in iterations) of the debug logs.

    Returns:
        float: Measured speed in images per second.

    Raises:
        ValueError: If ``max_iter`` leaves no timed iteration. This used to
            end in a ``ZeroDivisionError``.
        NotImplementedError: If the model is not on a CUDA device.
    """
    # the first several iterations may be very slow so skip them
    num_warmup = 5
    if max_iter <= num_warmup:
        raise ValueError(
            f'max_iter must be larger than the number of warmup iterations '
            f'({num_warmup}), but got {max_iter}.')

    if not next(model.parameters()).is_cuda:
        raise NotImplementedError('To use cpu to test latency not supported.')

    pure_inf_time = 0.0
    fps = 0.0

    for i in range(1, max_iter):
        # The input is created before the timer starts, so its generation
        # and transfer are not part of the measured time.
        data = torch.rand(resource_args['input_shape']).cuda()

        torch.cuda.synchronize()
        start_time = time.perf_counter()

        with torch.no_grad():
            model(data)

        torch.cuda.synchronize()
        elapsed = time.perf_counter() - start_time

        if i >= num_warmup:
            pure_inf_time += elapsed
            if (i + 1) % log_interval == 0:
                fps = (i + 1 - num_warmup) / pure_inf_time
                print_log(
                    f'Done image [{i + 1:<3}/ {max_iter}], '
                    f'fps: {fps:.1f} img / s, '
                    f'times per image: {1000 / fps:.1f} ms / img',
                    logger='current',
                    level=logging.DEBUG)

        if (i + 1) == max_iter:
            fps = (i + 1 - num_warmup) / pure_inf_time
            print_log(
                f'Overall fps: {fps:.1f} img / s, '
                f'times per image: {1000 / fps:.1f} ms / img',
                logger='current',
                level=logging.DEBUG)
            break

    torch.cuda.empty_cache()

    return fps

View File

@ -0,0 +1,22 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .activation_layer_counter import (ELUCounter, LeakyReLUCounter,
PReLUCounter, ReLU6Counter, ReLUCounter)
from .base_counter import BaseCounter
from .conv_layer_counter import Conv1dCounter, Conv2dCounter, Conv3dCounter
from .deconv_layer_counter import ConvTranspose2dCounter
from .linear_layer_counter import LinearCounter
from .norm_layer_counter import (BatchNorm1dCounter, BatchNorm2dCounter,
BatchNorm3dCounter, GroupNormCounter,
InstanceNorm1dCounter, InstanceNorm2dCounter,
InstanceNorm3dCounter, LayerNormCounter)
from .pooling_layer_counter import * # noqa: F403, F405, F401
from .upsample_layer_counter import UpsampleCounter
__all__ = [
'ReLUCounter', 'PReLUCounter', 'ELUCounter', 'LeakyReLUCounter',
'ReLU6Counter', 'BatchNorm1dCounter', 'BatchNorm2dCounter',
'BatchNorm3dCounter', 'Conv1dCounter', 'Conv2dCounter', 'Conv3dCounter',
'ConvTranspose2dCounter', 'UpsampleCounter', 'LinearCounter',
'GroupNormCounter', 'InstanceNorm1dCounter', 'InstanceNorm2dCounter',
'InstanceNorm3dCounter', 'LayerNormCounter', 'BaseCounter'
]

View File

@ -0,0 +1,35 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmrazor.registry import TASK_UTILS
from ..flops_params_counter import get_model_parameters_number
from .base_counter import BaseCounter
@TASK_UTILS.register_module()
class ReLUCounter(BaseCounter):
    """FLOPs/params counter for ReLU series activate function."""

    @staticmethod
    def add_count_hook(module, input, output):
        """Add one FLOP per output element and the module's params."""
        num_output_elements = int(output.numel())
        module.__flops__ += num_output_elements
        module.__params__ += get_model_parameters_number(module)
@TASK_UTILS.register_module()
class PReLUCounter(ReLUCounter):
    """FLOPs/params counter for PReLU; counts exactly like ReLUCounter."""
@TASK_UTILS.register_module()
class ELUCounter(ReLUCounter):
    """FLOPs/params counter for ELU; counts exactly like ReLUCounter."""
@TASK_UTILS.register_module()
class LeakyReLUCounter(ReLUCounter):
pass
@TASK_UTILS.register_module()
class ReLU6Counter(ReLUCounter):
pass

View File

@ -0,0 +1,28 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractclassmethod, abstractmethod
class BaseCounter(object, metaclass=ABCMeta):
    """Base class of all op module counters in `TASK_UTILS`.

    In ResourceEstimator, `XXModuleCounter` is responsible for `XXModule`,
    which refers to estimator/flops_params_counter.py::get_counter_type().
    Users can customize a `ModuleACounter` and overwrite the `add_count_hook`
    method with a self-defined module `ModuleA`.
    """

    def __init__(self) -> None:
        pass

    # `abstractmethod` must be the innermost decorator. The deprecated
    # `abstractclassmethod` wrapped in `staticmethod` made
    # `BaseCounter.add_count_hook` resolve to a non-callable classmethod
    # object instead of a plain function.
    @staticmethod
    @abstractmethod
    def add_count_hook(module, input, output):
        """The main method of a `BaseCounter` which defines the way to
        calculate resources(flops/params) of the current module.

        Subclasses must override this (as a ``staticmethod``); the base
        class cannot be instantiated while it is abstract.

        Args:
            module (nn.Module): the module to be tested.
            input (_type_): input_tensor. Plz refer to `torch forward_hook`
            output (_type_): output_tensor. Plz refer to `torch forward_hook`
        """
        pass

View File

@ -0,0 +1,57 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
from mmrazor.registry import TASK_UTILS
from .base_counter import BaseCounter
class ConvCounter(BaseCounter):
    """FLOPs/params counter for Conv module series."""

    @staticmethod
    def add_count_hook(module, input, output):
        """Accumulate convolution FLOPs and params onto ``module``.

        Args:
            module: Conv module providing ``kernel_size``, ``in_channels``,
                ``out_channels``, ``groups`` and ``bias``; its
                ``__flops__`` / ``__params__`` attributes are incremented.
            input: Forward-hook inputs; only the first one is used.
            output: Forward-hook output; dims after the first two are
                treated as the spatial output size.
        """
        # Can have multiple inputs, getting the first one
        input = input[0]

        batch_size = input.shape[0]
        output_dims = list(output.shape[2:])

        kernel_dims = list(module.kernel_size)
        in_channels = module.in_channels
        out_channels = module.out_channels
        groups = module.groups

        filters_per_channel = out_channels / groups
        conv_per_position_flops = int(
            np.prod(kernel_dims)) * in_channels * filters_per_channel

        active_elements_count = batch_size * int(np.prod(output_dims))

        overall_conv_flops = conv_per_position_flops * active_elements_count
        # The weight count equals the per-position multiply count.
        overall_params = conv_per_position_flops

        bias_flops = 0
        if module.bias is not None:
            bias_flops = out_channels * active_elements_count
            overall_params += out_channels

        overall_flops = overall_conv_flops + bias_flops

        # Cast to int like the params below (and like the other counters) so
        # `__flops__` does not silently turn into a float accumulator.
        module.__flops__ += int(overall_flops)
        module.__params__ += int(overall_params)
@TASK_UTILS.register_module()
class Conv1dCounter(ConvCounter):
    """Registered as ``Conv1dCounter``; counting inherited from
    ConvCounter."""


@TASK_UTILS.register_module()
class Conv2dCounter(ConvCounter):
    """Registered as ``Conv2dCounter``; counting inherited from
    ConvCounter."""


@TASK_UTILS.register_module()
class Conv3dCounter(ConvCounter):
    """Registered as ``Conv3dCounter``; counting inherited from
    ConvCounter."""

View File

@ -0,0 +1,38 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmrazor.registry import TASK_UTILS
from ..flops_params_counter import get_model_parameters_number
from .base_counter import BaseCounter
@TASK_UTILS.register_module()
class ConvTranspose2dCounter(BaseCounter):
    """FLOPs/params counter for Deconv module series."""

    @staticmethod
    def add_count_hook(module, input, output):
        """Accumulate transposed-convolution FLOPs and params.

        Args:
            module: Module providing ``kernel_size``, ``in_channels``,
                ``out_channels``, ``groups`` and ``bias``; its
                ``__flops__`` / ``__params__`` attributes are incremented.
            input: Forward-hook inputs; only the first one is used and it is
                unpacked as having exactly two dims after the first two.
            output: Forward-hook output, used for the bias FLOPs.
        """
        # Can have multiple inputs, getting the first one
        input = input[0]

        batch_size = input.shape[0]
        input_height, input_width = input.shape[2:]

        # TODO: use more common representation
        kernel_height, kernel_width = module.kernel_size
        in_channels = module.in_channels
        out_channels = module.out_channels
        groups = module.groups

        filters_per_channel = out_channels // groups
        conv_per_position_flops = (
            kernel_height * kernel_width * in_channels * filters_per_channel)

        # Multiplies are counted per *input* position.
        active_elements_count = batch_size * input_height * input_width
        overall_conv_flops = conv_per_position_flops * active_elements_count

        bias_flops = 0
        if module.bias is not None:
            output_height, output_width = output.shape[2:]
            # One add per output element: height * width (the previous code
            # squared the height, which is wrong for non-square outputs).
            bias_flops = (
                out_channels * batch_size * output_height * output_width)

        overall_flops = overall_conv_flops + bias_flops

        module.__flops__ += int(overall_flops)
        module.__params__ += get_model_parameters_number(module)

View File

@ -0,0 +1,19 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
from mmrazor.registry import TASK_UTILS
from ..flops_params_counter import get_model_parameters_number
from .base_counter import BaseCounter
@TASK_UTILS.register_module()
class LinearCounter(BaseCounter):
    """Count FLOPs and params for Linear-style modules."""

    @staticmethod
    def add_count_hook(module, input, output):
        """Add ``prod(input.shape) * output.shape[-1]`` FLOPs plus params.

        Args:
            module: Module whose ``__flops__`` / ``__params__`` attributes
                are incremented.
            input: Forward-hook inputs; only the first one is used.
            output: Forward-hook output; only its last dim is used.
        """
        first_input = input[0]
        # pytorch checks dimensions, so here we don't care much
        out_features = output.shape[-1]
        input_elements = np.prod(first_input.shape)
        module.__flops__ += int(input_elements * out_features)
        module.__params__ += get_model_parameters_number(module)

View File

@ -0,0 +1,59 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
from mmrazor.registry import TASK_UTILS
from ..flops_params_counter import get_model_parameters_number
from .base_counter import BaseCounter
class BNCounter(BaseCounter):
    """FLOPs/params counter for BatchNormalization series."""

    @staticmethod
    def add_count_hook(module, input, output):
        """Accumulate normalization FLOPs and params onto ``module``.

        One FLOP is counted per input element, doubled when the layer has a
        learnable affine transform.

        Args:
            module: Norm module; its ``__flops__`` / ``__params__``
                attributes are incremented.
            input: Forward-hook inputs; only the first one is used.
            output: Forward-hook output (not used here).
        """
        input = input[0]
        batch_flops = np.prod(input.shape)
        # This hook is shared by the LayerNorm counter, and `nn.LayerNorm`
        # names its flag `elementwise_affine` rather than `affine`; check
        # both so affine LayerNorm layers are not under-counted.
        has_affine = (
            getattr(module, 'affine', False)
            or getattr(module, 'elementwise_affine', False))
        if has_affine:
            batch_flops *= 2
        module.__flops__ += int(batch_flops)
        module.__params__ += get_model_parameters_number(module)
@TASK_UTILS.register_module()
class BatchNorm1dCounter(BNCounter):
    """Registered as ``BatchNorm1dCounter``; counting inherited from
    BNCounter."""


@TASK_UTILS.register_module()
class BatchNorm2dCounter(BNCounter):
    """Registered as ``BatchNorm2dCounter``; counting inherited from
    BNCounter."""


@TASK_UTILS.register_module()
class BatchNorm3dCounter(BNCounter):
    """Registered as ``BatchNorm3dCounter``; counting inherited from
    BNCounter."""


@TASK_UTILS.register_module()
class InstanceNorm1dCounter(BNCounter):
    """Registered as ``InstanceNorm1dCounter``; counting inherited from
    BNCounter."""


@TASK_UTILS.register_module()
class InstanceNorm2dCounter(BNCounter):
    """Registered as ``InstanceNorm2dCounter``; counting inherited from
    BNCounter."""


@TASK_UTILS.register_module()
class InstanceNorm3dCounter(BNCounter):
    """Registered as ``InstanceNorm3dCounter``; counting inherited from
    BNCounter."""


@TASK_UTILS.register_module()
class LayerNormCounter(BNCounter):
    """Registered as ``LayerNormCounter``; counting inherited from
    BNCounter."""


@TASK_UTILS.register_module()
class GroupNormCounter(BNCounter):
    """Registered as ``GroupNormCounter``; counting inherited from
    BNCounter."""

View File

@ -0,0 +1,76 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
from mmrazor.registry import TASK_UTILS
from ..flops_params_counter import get_model_parameters_number
from .base_counter import BaseCounter
class PoolCounter(BaseCounter):
    """Count FLOPs and params for pooling modules."""

    @staticmethod
    def add_count_hook(module, input, output):
        """Add one FLOP per element of the first input, plus params.

        Args:
            module: Module whose ``__flops__`` / ``__params__`` attributes
                are incremented.
            input: Forward-hook inputs; only the first one is used.
            output: Forward-hook output (not used here).
        """
        first_input = input[0]
        element_count = np.prod(first_input.shape)
        module.__flops__ += int(element_count)
        module.__params__ += get_model_parameters_number(module)
@TASK_UTILS.register_module()
class MaxPool1dCounter(PoolCounter):
    """Registered as ``MaxPool1dCounter``; counting inherited from
    PoolCounter."""


@TASK_UTILS.register_module()
class MaxPool2dCounter(PoolCounter):
    """Registered as ``MaxPool2dCounter``; counting inherited from
    PoolCounter."""


@TASK_UTILS.register_module()
class MaxPool3dCounter(PoolCounter):
    """Registered as ``MaxPool3dCounter``; counting inherited from
    PoolCounter."""


@TASK_UTILS.register_module()
class AvgPool1dCounter(PoolCounter):
    """Registered as ``AvgPool1dCounter``; counting inherited from
    PoolCounter."""


@TASK_UTILS.register_module()
class AvgPool2dCounter(PoolCounter):
    """Registered as ``AvgPool2dCounter``; counting inherited from
    PoolCounter."""


@TASK_UTILS.register_module()
class AvgPool3dCounter(PoolCounter):
    """Registered as ``AvgPool3dCounter``; counting inherited from
    PoolCounter."""


@TASK_UTILS.register_module()
class AdaptiveMaxPool1dCounter(PoolCounter):
    """Registered as ``AdaptiveMaxPool1dCounter``; counting inherited from
    PoolCounter."""


@TASK_UTILS.register_module()
class AdaptiveMaxPool2dCounter(PoolCounter):
    """Registered as ``AdaptiveMaxPool2dCounter``; counting inherited from
    PoolCounter."""


@TASK_UTILS.register_module()
class AdaptiveMaxPool3dCounter(PoolCounter):
    """Registered as ``AdaptiveMaxPool3dCounter``; counting inherited from
    PoolCounter."""


@TASK_UTILS.register_module()
class AdaptiveAvgPool1dCounter(PoolCounter):
    """Registered as ``AdaptiveAvgPool1dCounter``; counting inherited from
    PoolCounter."""


@TASK_UTILS.register_module()
class AdaptiveAvgPool2dCounter(PoolCounter):
    """Registered as ``AdaptiveAvgPool2dCounter``; counting inherited from
    PoolCounter."""


@TASK_UTILS.register_module()
class AdaptiveAvgPool3dCounter(PoolCounter):
    """Registered as ``AdaptiveAvgPool3dCounter``; counting inherited from
    PoolCounter."""

View File

@ -0,0 +1,19 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmrazor.registry import TASK_UTILS
from ..flops_params_counter import get_model_parameters_number
from .base_counter import BaseCounter
@TASK_UTILS.register_module()
class UpsampleCounter(BaseCounter):
    """Count FLOPs and params for Upsample modules."""

    @staticmethod
    def add_count_hook(module, input, output):
        """Add the element count of ``output[0]`` as FLOPs, plus params.

        Args:
            module: Module whose ``__flops__`` / ``__params__`` attributes
                are incremented.
            input: Forward-hook inputs (not used here).
            output: Forward-hook output; only ``output[0]`` is measured.
        """
        # NOTE(review): `output[0]` indexes the first dim of the output, so
        # that dim (presumably the batch) does not scale the count - confirm
        # this is intended.
        first_item = output[0]
        element_total = 1
        for extent in first_item.shape:
            element_total *= extent
        module.__flops__ += int(element_total)
        module.__params__ += get_model_parameters_number(module)

View File

@ -0,0 +1,155 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Dict, List, Tuple
import torch.nn
from mmengine.dist import broadcast_object_list, is_main_process
from mmrazor.registry import TASK_UTILS
from .base_estimator import BaseEstimator
from .counters import (get_model_complexity_info, params_units_convert,
repeat_measure_inference_speed)
@TASK_UTILS.register_module()
class ResourceEstimator(BaseEstimator):
"""Estimator for calculating the resources consume.
Args:
default_shape (tuple): Input data's default shape, for calculating
resources consume. Defaults to (1, 3, 224, 224)
units (str): Resource units. Defaults to 'M'.
disabled_counters (list): List of disabled spec op counters.
Defaults to None.
NOTE: disabled_counters contains the op counter class names
in estimator.op_counters that require to be disabled,
such as 'ConvCounter', 'BatchNorm2dCounter', ...
Examples:
>>> # direct calculate resource consume of nn.Conv2d
>>> conv2d = nn.Conv2d(3, 32, 3)
>>> estimator = ResourceEstimator()
>>> estimator.estimate(
... model=conv2d,
... resource_args=dict(input_shape=(1, 3, 64, 64)))
{'flops': 3.444, 'params': 0.001, 'latency': 0.0}
>>> # calculate resources of custom modules
>>> class CustomModule(nn.Module):
...
... def __init__(self) -> None:
... super().__init__()
...
... def forward(self, x):
... return x
...
>>> @TASK_UTILS.register_module()
... class CustomModuleCounter(BaseCounter):
...
... @staticmethod
... def add_count_hook(module, input, output):
... module.__flops__ += 1000000
... module.__params__ += 700000
...
>>> model = CustomModule()
>>> estimator.estimate(
... model=model,
... resource_args=dict(input_shape=(1, 3, 64, 64)))
{'flops': 1.0, 'params': 0.7, 'latency': 0.0}
...
>>> # calculate resources of custom modules with disable_counters
>>> estimator.estimate(
... model=model,
... resource_args=dict(
... input_shape=(1, 3, 64, 64),
... disabled_counters=['CustomModuleCounter']))
{'flops': 0.0, 'params': 0.0, 'latency': 0.0}
>>> # calculate resources of mmrazor.models
NOTE: check 'EstimateResourcesHook' in
mmrazor.engine.hooks.estimate_resources_hook for details.
"""
def __init__(self,
default_shape: Tuple = (1, 3, 224, 224),
units: str = 'M',
disabled_counters: List[str] = [],
as_strings: bool = False,
measure_inference: bool = False):
super().__init__(default_shape, units, disabled_counters, as_strings,
measure_inference)
def estimate(
self, model: torch.nn.Module, resource_args: Dict[str, Any] = dict()
) -> Dict[str, Any]:
"""Estimate the resources(flops/params/latency) of the given model.
Args:
model: The measured model.
resource_args (Dict[str, float]): Args for resources estimation.
NOTE: resource_args have the same items() as the init cfgs.
Returns:
Dict[str, str]): A dict that containing resource results(flops,
params and latency).
"""
resource_metrics = dict()
if is_main_process():
measure_inference = resource_args.pop('measure_inference', False)
if 'input_shape' not in resource_args.keys():
resource_args['input_shape'] = self.default_shape
if 'disabled_counters' not in resource_args.keys():
resource_args['disabled_counters'] = self.disabled_counters
model.eval()
flops, params = get_model_complexity_info(model, **resource_args)
if measure_inference:
latency = repeat_measure_inference_speed(
model, resource_args, max_iter=100, repeat_num=2)
else:
latency = 0.0
as_strings = resource_args.get('as_strings', self.as_strings)
if as_strings and self.units is not None:
raise ValueError('Set units to None, when as_trings=True.')
if self.units is not None:
flops = params_units_convert(flops, self.units)
params = params_units_convert(params, self.units)
resource_metrics.update({
'flops': flops,
'params': params,
'latency': latency
})
results = [resource_metrics]
else:
results = [None] # type: ignore
broadcast_object_list(results)
return results[0]
def estimate_spec_modules(
self, model: torch.nn.Module, resource_args: Dict[str, Any] = dict()
) -> Dict[str, float]:
"""Estimate the resources(flops/params/latency) of the spec modules.
Args:
model: The measured model.
resource_args (Dict[str, float]): Args for resources estimation.
NOTE: resource_args have the same items() as the init cfgs.
Returns:
Dict[str, float]): A dict that containing resource results(flops,
params) of each modules in resource_args['spec_modules'].
"""
assert 'spec_modules' in resource_args, \
'spec_modules is required when calling estimate_spec_modules().'
resource_args.pop('measure_inference', False)
if 'input_shape' not in resource_args.keys():
resource_args['input_shape'] = self.default_shape
if 'disabled_counters' not in resource_args.keys():
resource_args['disabled_counters'] = self.disabled_counters
model.eval()
spec_modules_resources = get_model_complexity_info(
model, **resource_args)
return spec_modules_resources

View File

@ -40,7 +40,7 @@ def build_razor_model_from_cfg(
# TODO relay on mmengine:HAOCHENYE/config_new_feature
if cfg.get('cfg_path', None) and not cfg.get('type', None):
from mmengine.config import get_model
model = get_model(**cfg)
model = get_model(**cfg) # type: ignore
return model
from mmrazor.structures import load_fix_subnet

View File

@ -1,8 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .candidate import Candidates
from .estimators import FlopsEstimator
from .fix_subnet import export_fix_subnet, load_fix_subnet
__all__ = [
'FlopsEstimator', 'load_fix_subnet', 'export_fix_subnet', 'Candidates'
]
__all__ = ['load_fix_subnet', 'export_fix_subnet', 'Candidates']

View File

@ -1,4 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .flops import FlopsEstimator
__all__ = ['FlopsEstimator']

View File

@ -1,270 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import sys
from functools import wraps
from typing import IO, Callable, Dict, Iterable, Optional, Tuple, Type
from mmcv.cnn.utils import flops_counter as mmcv_flops_counter
from mmcv.cnn.utils import get_model_complexity_info
from torch.nn import Module
from mmrazor.utils import ValidFixMutable
from ..fix_subnet import load_fix_subnet
class FlopsEstimator:
    """An estimator to help calculate flops of module.

    FlopsEstimator is based on flops counter in mmcv, it can be directly used
    to calculate flops of module or calculate flops of subnet. Also, it
    provides api for adding flops counter hook to custom modules that are not
    supported by mmcv.

    Examples:
        >>> # direct calculate flops of nn.Conv2d
        >>> conv2d = nn.Conv2d(3, 32, 3)
        >>> FlopsEstimator.get_model_complexity_info(
        ...     conv2d,
        ...     input_shape=[3, 224, 224],
        ...     print_per_layer_stat=False)
        ('0.04 GFLOPs', '896')

        >>> # calculate flops of custom modules
        >>> class FoolAddConstant(nn.Module):
        ...
        ...     def __init__(self, p: float = 0.1) -> None:
        ...         super().__init__()
        ...
        ...         self.register_parameter(
        ...             name='p',
        ...             param=Parameter(torch.tensor(p, dtype=torch.float32)))
        ...
        ...     def forward(self, x: Tensor) -> Tensor:
        ...         return x + self.p
        ...
        >>> def fool_add_constant_flops_counter_hook(
        ...         add_constant_module: nn.Module,
        ...         input: Tensor,
        ...         output: Tensor) -> None:
        ...     add_constant_module.__flops__ = 1e8
        ...
        >>> FlopsEstimator.register_module(
        ...     flops_counter_hook=fool_add_constant_flops_counter_hook,
        ...     module=FoolAddConstant)
        >>> model = FoolAddConstant()
        >>> FlopsEstimator.get_model_complexity_info(
        ...     model=model,
        ...     input_shape=[3, 224, 224],
        ...     print_per_layer_stat=False)
        ('0.1 GFLOPs', '1')

        >>> # calculate subnet flops
        >>> class FoolOneShotModel(nn.Module):
        ...
        ...     def __init__(self) -> None:
        ...         super().__init__()
        ...
        ...         candidates = nn.ModuleDict({
        ...             'conv3x3': nn.Conv2d(3, 32, 3),
        ...             'conv5x5': nn.Conv2d(3, 32, 5)})
        ...         self.op = OneShotMutableOP(candidates)
        ...         self.op.current_choice = 'conv3x3'
        ...
        ...     def forward(self, x: Tensor) -> Tensor:
        ...         return self.op(x)
        ...
        >>> model = FoolOneShotModel()
        >>> fix_subnet = export_fix_subnet(model)
        >>> fix_subnet
        FixSubnet(modules={'op': 'conv3x3'}, channels=None)
        >>> FlopsEstimator.get_model_complexity_info(
        ...     model=model,
        ...     fix_mutable=fix_subnet,
        ...     input_shape=[3, 224, 224],
        ...     print_per_layer_stat=False)
        ('0.04 GFLOPs', '896')
    """

    # Snapshot of mmcv's built-in mapping, taken when the class is created.
    _mmcv_modules_mapping: Dict[Module, Callable] = \
        mmcv_flops_counter.get_modules_mapping()
    # Hooks registered through `register_module`.
    _custom_modules_mapping: Dict[Module, Callable] = {}

    @staticmethod
    def get_model_complexity_info(
            model: Module,
            fix_mutable: Optional[ValidFixMutable] = None,
            input_shape: Iterable[int] = (3, 224, 224),
            input_constructor: Optional[Callable] = None,
            print_per_layer_stat: bool = True,
            as_strings: bool = True,
            flush: bool = False,
            ost: IO = sys.stdout) -> Tuple:
        """Get complexity information of model.

        This method is based on ``get_model_complexity_info`` of mmcv. It can
        calculate FLOPs and parameter counts of a model with corresponding
        input shape. It can also print complexity information for each layer
        in a model.

        Args:
            model (torch.nn.Module): The model for complexity calculation.
                It is deep-copied first, so the given model is not touched.
            fix_mutable (ValidFixMutable, optional): The config of fixed
                subnet. When this argument is specified, the function will
                return complexity information of the subnet. Default: None.
            input_shape (Iterable[int]): Input shape used for calculation.
            print_per_layer_stat (bool): Whether to print complexity
                information for each layer in a model. Default: True.
            as_strings (bool): Output FLOPs and params counts in a string form.
                Default: True.
            input_constructor (Callable, optional): If specified, it takes a
                callable method that generates input. otherwise, it will
                generate a random tensor with input shape to calculate FLOPs.
                Default: None.
            flush (bool): same as that in :func:`print`. Default: False.
            ost (stream): same as ``file`` param in :func:`print`.
                Default: sys.stdout.

        Returns:
            tuple[float | str]: If ``as_strings`` is set to True, it will
                return FLOPs and parameter counts in a string format.
                otherwise, it will return those in a float number format.
        """
        copied_model = copy.deepcopy(model)
        if fix_mutable is not None:
            load_fix_subnet(copied_model, fix_mutable)

        return get_model_complexity_info(
            model=copied_model,
            input_shape=input_shape,
            input_constructor=input_constructor,
            print_per_layer_stat=print_per_layer_stat,
            as_strings=as_strings,
            flush=flush,
            ost=ost)

    @classmethod
    def register_module(cls,
                        flops_counter_hook: Callable,
                        module: Optional[Type[Module]] = None,
                        force: bool = False) -> Optional[Callable]:
        """Register a module with flops_counter_hook.

        Args:
            flops_counter_hook (Callable): The hook that specifies how to
                calculate flops of given module.
            module (torch.nn.Module, optional): Module class to be registered.
                Defaults to None.
            force (bool): Whether to override an existing flops_counter_hook
                with the same module. Default to False.

        Returns:
            Callable, optional: None when ``module`` is given. Otherwise a
            class decorator that registers the decorated module class and
            returns it unchanged.

        Raises:
            TypeError: If ``force`` is not a bool, ``module`` is neither None
                nor a ``torch.nn.Module`` subclass, or ``flops_counter_hook``
                is not callable.
            KeyError: If the module is already registered and ``force`` is
                False.
        """
        if not isinstance(force, bool):
            raise TypeError(f'force must be a boolean, but got {type(force)}')

        if not (module is None or issubclass(module, Module)):
            raise TypeError(
                'module must be None, an subclass of torch.nn.Module, '
                f'but got {module}')

        if not callable(flops_counter_hook):
            raise TypeError('flops_counter_hook must be Callable, '
                            f'but got {type(flops_counter_hook)}')

        if module is not None:
            return cls._register_module(
                module=module,
                flops_counter_hook=flops_counter_hook,
                force=force)

        def _register(module: Type[Module]) -> Type[Module]:
            cls._register_module(
                module=module,
                flops_counter_hook=flops_counter_hook,
                force=force)
            # Hand the class back so that using this as `@decorator` does not
            # rebind the decorated name to None.
            return module

        return _register

    @classmethod
    def remove_custom_module(cls, module: Type[Module]) -> None:
        """Remove a registered module.

        Args:
            module (torch.nn.Module): Module class to be removed.

        Raises:
            KeyError: If ``module`` has no custom hook registered.
        """
        if module not in cls._custom_modules_mapping:
            raise KeyError(f'{module} not in custom module mapping')
        del cls._custom_modules_mapping[module]

    @classmethod
    def clear_custom_module(cls) -> None:
        """Remove all registered modules."""
        cls._custom_modules_mapping.clear()

    @classmethod
    def _register_module(cls,
                         flops_counter_hook: Callable,
                         module: Type[Module],
                         force: bool = False) -> None:
        """Register a module with flops_counter_hook.

        Args:
            flops_counter_hook (Callable): The hook that specifies how to
                calculate flops of given module.
            module (torch.nn.Module): Module class to be registered.
            force (bool): Whether to override an existing flops_counter_hook
                with the same module. Default to False.

        Raises:
            KeyError: If ``module`` is already present in the combined
                mapping and ``force`` is False.
        """
        if not force and module in cls.get_modules_mapping():
            raise KeyError(f'{module} is already registered')
        cls._custom_modules_mapping[module] = flops_counter_hook

    @classmethod
    def get_modules_mapping(cls) -> Dict[Module, Callable]:
        """Get all modules with their corresponding flops counter hook.

        Custom hooks take precedence over mmcv hooks for the same module.

        Returns:
            Dict[Module, Callable]: Modules with their corresponding flops
                counter hook.
        """
        return {**cls._mmcv_modules_mapping, **cls._custom_modules_mapping}

    @classmethod
    def get_custom_modules_mapping(cls) -> Dict[Module, Callable]:
        """Get customed modules with their corresponding flops counter hook.

        Returns:
            Dict[Module, Callable]: A copy of the custom modules mapping with
                their corresponding flops counter hook.
        """
        return {**cls._custom_modules_mapping}

    @classmethod
    def _mmcv_modules_mappings_wrapper(
            cls, mmcv_get_modules_mapping: Callable) -> Callable:
        """Wrapper for ``get_modules_mapping`` function in mmcv.

        Args:
            mmcv_get_modules_mapping (Callable): ``get_modules_mapping``
                function in mmcv.

        Returns:
            Callable: Wrapped ``get_modules_mapping`` function whose result
            also contains the custom modules mapping.
        """

        @wraps(mmcv_get_modules_mapping)
        def wrapper() -> Dict[Module, Callable]:
            mmcv_modules_mapping: Dict[Module, Callable] = \
                mmcv_get_modules_mapping()

            # TODO
            # use | operator
            # | operator only be supported in python 3.9.0 or greater
            return {**mmcv_modules_mapping, **cls._custom_modules_mapping}

        return wrapper
# Executed at import time: replace mmcv's `get_modules_mapping` with the
# wrapped version, whose result also includes the hooks stored in
# `FlopsEstimator._custom_modules_mapping`.
mmcv_flops_counter.get_modules_mapping = \
FlopsEstimator._mmcv_modules_mappings_wrapper(
mmcv_flops_counter.get_modules_mapping)

View File

@ -71,12 +71,13 @@ def register_all_modules(init_default_scope: bool = True) -> None:
DefaultScope.get_instance('mmrazor', scope_name='mmrazor')
return
current_scope = DefaultScope.get_current_instance()
if current_scope.scope_name != 'mmrazor':
warnings.warn('The current default scope '
f'"{current_scope.scope_name}" is not "mmrazor", '
'`register_all_modules` will force the current'
'default scope to be "mmrazor". If this is not '
'expected, please set `init_default_scope=False`.')
if current_scope.scope_name != 'mmrazor': # type: ignore
warnings.warn(
'The current default scope ' # type: ignore
f'"{current_scope.scope_name}" is not '
'"mmrazor", `register_all_modules` will force the current'
'default scope to be "mmrazor". If this is not expected, '
'please set `init_default_scope=False`.')
# avoid name conflict
new_instance_name = f'mmrazor-{datetime.datetime.now()}'
DefaultScope.get_instance(new_instance_name, scope_name='mmrazor')

View File

@ -1,231 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from unittest import TestCase
import pytest
import torch
from torch import Tensor
from torch.nn import Conv2d, Module, Parameter
from mmrazor.models import OneShotMutableModule
from mmrazor.registry import MODELS
from mmrazor.structures import FlopsEstimator, export_fix_subnet
# Mutable-op config with a single candidate: an MBBlock with kernel size 3
# and expand ratio 1. Used by the first row of ARCHSETTING_CFG.
_FIRST_STAGE_MUTABLE = dict(
type='OneShotMutableOP',
candidates=dict(
mb_k3e1=dict(
type='MBBlock',
kernel_size=3,
expand_ratio=1,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU6'))))
# Mutable-op config with three candidates: two MBBlocks (kernel 3 / kernel 5,
# both expand ratio 3) and an Identity. Used by all remaining rows.
_OTHER_STAGE_MUTABLE = dict(
type='OneShotMutableOP',
candidates=dict(
mb_k3e3=dict(
type='MBBlock',
kernel_size=3,
expand_ratio=3,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU6')),
mb_k5e3=dict(
type='MBBlock',
kernel_size=5,
expand_ratio=3,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU6')),
identity=dict(type='Identity')))
# One row per stage of the searchable backbone.
ARCHSETTING_CFG = [
# Parameters to build layers. 4 parameters are needed to construct a
# layer, from left to right: channel, num_blocks, stride, mutable cfg.
[16, 1, 1, _FIRST_STAGE_MUTABLE],
[24, 2, 2, _OTHER_STAGE_MUTABLE],
[32, 3, 2, _OTHER_STAGE_MUTABLE],
[64, 4, 2, _OTHER_STAGE_MUTABLE],
[96, 3, 1, _OTHER_STAGE_MUTABLE],
[160, 3, 2, _OTHER_STAGE_MUTABLE],
[320, 1, 1, _OTHER_STAGE_MUTABLE]
]
# Shared normalization config for the backbone.
NORM_CFG = dict(type='BN')
# Backbone config built through `MODELS.build` in the subnet test below.
BACKBONE_CFG = dict(
type='mmrazor.SearchableMobileNet',
first_channels=32,
last_channels=1280,
widen_factor=1.0,
norm_cfg=NORM_CFG,
arch_setting=ARCHSETTING_CFG)
class FoolAddConstant(Module):
    """Toy module that adds a learnable float32 scalar ``p`` to its input."""

    def __init__(self, p: float = 0.1) -> None:
        super().__init__()
        scalar = torch.tensor(p, dtype=torch.float32)
        self.register_parameter(name='p', param=Parameter(scalar))

    def forward(self, x: Tensor) -> Tensor:
        """Return ``x`` shifted by the parameter ``p``."""
        shifted = x + self.p
        return shifted
class FoolConv2d(Module):
    """Toy module wrapping a single 3 -> 32 channel, 3x3 ``Conv2d``."""

    def __init__(self) -> None:
        super().__init__()
        self.conv2d = Conv2d(in_channels=3, out_channels=32, kernel_size=3)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the wrapped convolution to ``x``."""
        features = self.conv2d(x)
        return features
class FoolConvModule(Module):
    """Toy module chaining ``FoolAddConstant(0.1)`` and ``FoolConv2d``."""

    def __init__(self) -> None:
        super().__init__()
        self.add_constant = FoolAddConstant(0.1)
        self.conv2d = FoolConv2d()

    def forward(self, x: Tensor) -> Tensor:
        """Add the constant first, then run the convolution."""
        shifted = self.add_constant(x)
        return self.conv2d(shifted)
class TestFlopsEstimator(TestCase):
    """Unit tests for ``FlopsEstimator``."""

    def sample_choice(self, model: Module) -> None:
        """Set a sampled choice on every one-shot mutable in ``model``."""
        for submodule in model.modules():
            if not isinstance(submodule, OneShotMutableModule):
                continue
            submodule.current_choice = submodule.sample_choice()

    def test_get_model_complexity_info(self) -> None:
        """A plain conv module reports positive FLOPs and params."""
        counts = FlopsEstimator.get_model_complexity_info(
            FoolConv2d(), as_strings=False)
        flops_count, params_count = counts

        self.assertGreater(flops_count, 0)
        self.assertGreater(params_count, 0)

    def test_register_module(self) -> None:
        """A registered hook adds exactly its FLOPs, both ways of
        registering."""
        template = FoolConvModule()
        base_flops, base_params = FlopsEstimator.get_model_complexity_info(
            copy.deepcopy(template), as_strings=False)

        def fool_add_constant_flops_counter_hook(add_constant_module: Module,
                                                 input: Tensor,
                                                 output: Tensor) -> None:
            add_constant_module.__flops__ = 1e6

        def register_directly() -> None:
            # test register directly
            FlopsEstimator.register_module(
                flops_counter_hook=fool_add_constant_flops_counter_hook,
                module=FoolAddConstant)

        def register_with_decorator() -> None:
            # test register using decorator
            FlopsEstimator.register_module(
                flops_counter_hook=fool_add_constant_flops_counter_hook)(
                    FoolAddConstant)

        for register in (register_directly, register_with_decorator):
            register()
            measured = copy.deepcopy(template)
            flops_after, params_after = \
                FlopsEstimator.get_model_complexity_info(
                    model=measured, as_strings=False)
            self.assertEqual(flops_after - base_flops, 1e6)
            self.assertEqual(params_after - base_params, 0)
            FlopsEstimator.remove_custom_module(FoolAddConstant)

    def test_register_module_wrong_parameter(self) -> None:
        """Invalid arguments raise TypeError; double registration raises
        KeyError unless forced."""

        def fool_flops_counter_hook(module: Module, input: Tensor,
                                    output: Tensor) -> None:
            return

        invalid_calls = (
            dict(flops_counter_hook=fool_flops_counter_hook, force=1),
            dict(flops_counter_hook=fool_flops_counter_hook, module=list),
            dict(flops_counter_hook=123, module=FoolAddConstant),
        )
        for kwargs in invalid_calls:
            with pytest.raises(TypeError):
                FlopsEstimator.register_module(**kwargs)

        # test double register
        FlopsEstimator.register_module(
            flops_counter_hook=fool_flops_counter_hook, module=FoolAddConstant)
        with pytest.raises(KeyError):
            FlopsEstimator.register_module(
                flops_counter_hook=fool_flops_counter_hook,
                module=FoolAddConstant)
        FlopsEstimator.register_module(
            flops_counter_hook=fool_flops_counter_hook,
            module=FoolAddConstant,
            force=True)

        FlopsEstimator.remove_custom_module(FoolAddConstant)

    def test_remove_custom_module(self) -> None:
        """Removing an unregistered module raises; a registered one
        succeeds."""
        with pytest.raises(KeyError):
            FlopsEstimator.remove_custom_module(FoolAddConstant)

        def fool_flops_counter_hook(module: Module, input: Tensor,
                                    output: Tensor) -> None:
            return

        FlopsEstimator.register_module(
            flops_counter_hook=fool_flops_counter_hook, module=FoolAddConstant)
        FlopsEstimator.remove_custom_module(FoolAddConstant)

    def test_clear_custom_module(self) -> None:
        """Clearing leaves the custom mapping empty."""

        def fool_flops_counter_hook(module: Module, input: Tensor,
                                    output: Tensor) -> None:
            return

        for module_cls in (FoolAddConstant, FoolConvModule):
            FlopsEstimator.register_module(
                flops_counter_hook=fool_flops_counter_hook, module=module_cls)

        FlopsEstimator.clear_custom_module()
        self.assertEqual(FlopsEstimator.get_custom_modules_mapping(), {})

    def test_get_model_complexity_info_subnet(self) -> None:
        """Subnet estimation matches in FLOPs and leaves the model alone."""
        model = MODELS.build(BACKBONE_CFG)
        self.sample_choice(model)

        supernet_copy = copy.deepcopy(model)
        flops_count, params_count = FlopsEstimator.get_model_complexity_info(
            supernet_copy, as_strings=False)

        fix_subnet = export_fix_subnet(model)
        subnet_counts = FlopsEstimator.get_model_complexity_info(
            model, fix_subnet, as_strings=False)
        subnet_flops_count, subnet_params_count = subnet_counts

        self.assertEqual(flops_count, subnet_flops_count)
        self.assertGreater(params_count, subnet_params_count)

        # test whether subnet estimate will affect original model
        fresh_copy = copy.deepcopy(model)
        counts_after = FlopsEstimator.get_model_complexity_info(
            fresh_copy, as_strings=False)
        flops_count_after_estimate, params_count_after_estimate = counts_after

        self.assertEqual(flops_count, flops_count_after_estimate)
        self.assertEqual(params_count, params_count_after_estimate)

View File

@ -0,0 +1,197 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from unittest import TestCase
import pytest
import torch
from torch import Tensor
from torch.nn import Conv2d, Module, Parameter
from mmrazor.models import OneShotMutableModule, ResourceEstimator
from mmrazor.models.task_modules.estimators.counters import BaseCounter
from mmrazor.registry import MODELS, TASK_UTILS
from mmrazor.structures import export_fix_subnet, load_fix_subnet
# Mutable-op config with a single candidate: an MBBlock with kernel size 3
# and expand ratio 1. Used by the first row of ARCHSETTING_CFG.
_FIRST_STAGE_MUTABLE = dict(
type='OneShotMutableOP',
candidates=dict(
mb_k3e1=dict(
type='MBBlock',
kernel_size=3,
expand_ratio=1,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU6'))))
# Mutable-op config with three candidates: two MBBlocks (kernel 3 / kernel 5,
# both expand ratio 3) and an Identity. Used by all remaining rows.
_OTHER_STAGE_MUTABLE = dict(
type='OneShotMutableOP',
candidates=dict(
mb_k3e3=dict(
type='MBBlock',
kernel_size=3,
expand_ratio=3,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU6')),
mb_k5e3=dict(
type='MBBlock',
kernel_size=5,
expand_ratio=3,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU6')),
identity=dict(type='Identity')))
# One row per stage of the searchable backbone.
ARCHSETTING_CFG = [
# Parameters to build layers. 4 parameters are needed to construct a
# layer, from left to right: channel, num_blocks, stride, mutable cfg.
[16, 1, 1, _FIRST_STAGE_MUTABLE],
[24, 2, 2, _OTHER_STAGE_MUTABLE],
[32, 3, 2, _OTHER_STAGE_MUTABLE],
[64, 4, 2, _OTHER_STAGE_MUTABLE],
[96, 3, 1, _OTHER_STAGE_MUTABLE],
[160, 3, 2, _OTHER_STAGE_MUTABLE],
[320, 1, 1, _OTHER_STAGE_MUTABLE]
]
# Shared normalization config for the backbone.
NORM_CFG = dict(type='BN')
# Backbone config built through `MODELS.build` in the subnet test below.
BACKBONE_CFG = dict(
type='mmrazor.SearchableMobileNet',
first_channels=32,
last_channels=1280,
widen_factor=1.0,
norm_cfg=NORM_CFG,
arch_setting=ARCHSETTING_CFG)
# Single estimator with default settings, shared by every test below.
estimator = ResourceEstimator()
class FoolAddConstant(Module):
    """Toy module that adds a learnable float32 scalar ``p`` to its input."""

    def __init__(self, p: float = 0.1) -> None:
        super().__init__()
        scalar = torch.tensor(p, dtype=torch.float32)
        self.register_parameter(name='p', param=Parameter(scalar))

    def forward(self, x: Tensor) -> Tensor:
        """Return ``x`` shifted by the parameter ``p``."""
        shifted = x + self.p
        return shifted
@TASK_UTILS.register_module()
class FoolAddConstantCounter(BaseCounter):
    """Test counter adding fixed FLOPs/params amounts on every hook call."""

    @staticmethod
    def add_count_hook(module, input, output):
        """Add 1000000 FLOPs and 700000 params regardless of input/output."""
        fixed_flops, fixed_params = 1000000, 700000
        module.__flops__ += fixed_flops
        module.__params__ += fixed_params
class FoolConv2d(Module):
    """Toy module wrapping a single 3 -> 32 channel, 3x3 ``Conv2d``."""

    def __init__(self) -> None:
        super().__init__()
        self.conv2d = Conv2d(in_channels=3, out_channels=32, kernel_size=3)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the wrapped convolution to ``x``."""
        features = self.conv2d(x)
        return features
class FoolConvModule(Module):
    """Toy module chaining ``FoolAddConstant(0.1)`` and ``FoolConv2d``."""

    def __init__(self) -> None:
        super().__init__()
        self.add_constant = FoolAddConstant(0.1)
        self.conv2d = FoolConv2d()

    def forward(self, x: Tensor) -> Tensor:
        """Add the constant first, then run the convolution."""
        shifted = self.add_constant(x)
        return self.conv2d(shifted)
class TestResourceEstimator(TestCase):
    """Tests for FLOPs/params estimation with the shared ``estimator``."""

    def sample_choice(self, model: Module) -> None:
        """Give every one-shot mutable in ``model`` a newly sampled choice."""
        one_shot_mutables = (
            sub_module for sub_module in model.modules()
            if isinstance(sub_module, OneShotMutableModule))
        for mutable in one_shot_mutables:
            mutable.current_choice = mutable.sample_choice()

    def test_estimate(self) -> None:
        """A plain conv wrapper reports positive FLOPs and params."""
        resource_args = dict(input_shape=(1, 3, 224, 224))
        results = estimator.estimate(
            model=FoolConv2d(), resource_args=resource_args)
        flops, params = results['flops'], results['params']

        self.assertGreater(flops, 0)
        self.assertGreater(params, 0)

    def test_register_module(self) -> None:
        """Totals for ``FoolConvModule`` match the pinned reference values."""
        model = FoolConvModule()
        resource_args = dict(input_shape=(1, 3, 224, 224))
        results = estimator.estimate(model=model, resource_args=resource_args)
        flops, params = results['flops'], results['params']

        self.assertEqual(flops, 45.158)
        self.assertEqual(params, 0.701)

    # NOTE(review): 'sepc' looks like a typo for 'spec'; the name is kept
    # unchanged so existing test IDs stay stable.
    def test_disable_sepc_counter(self) -> None:
        """Disabling the custom counter lowers both reported totals."""
        model = FoolConvModule()
        resource_args = dict(
            input_shape=(1, 3, 224, 224),
            disabled_counters=['FoolAddConstantCounter'])
        rest_results = estimator.estimate(
            model=model, resource_args=resource_args)
        rest_flops, rest_params = rest_results['flops'], rest_results['params']

        self.assertLess(rest_flops, 45.158)
        self.assertLess(rest_params, 0.701)

    def test_estimate_spec_modules(self) -> None:
        """Known sub-modules are measured; unknown names must assert."""
        model = FoolConvModule()
        input_shape = (1, 3, 224, 224)

        results = estimator.estimate_spec_modules(
            model=model,
            resource_args=dict(
                input_shape=input_shape, spec_modules=['add_constant']))
        self.assertGreater(results['add_constant']['flops'], 0)

        # 'backbone' is not an attribute of FoolConvModule.
        with pytest.raises(AssertionError):
            estimator.estimate_spec_modules(
                model=model,
                resource_args=dict(
                    input_shape=input_shape, spec_modules=['backbone']))

    def test_estimate_subnet(self) -> None:
        """Estimates agree before and after fixing the sampled subnet."""
        resource_args = dict(input_shape=(1, 3, 224, 224))
        model = MODELS.build(BACKBONE_CFG)
        self.sample_choice(model)

        # Measure a copy so the sampled model itself is left untouched.
        copied_model = copy.deepcopy(model)
        results = estimator.estimate(
            model=copied_model, resource_args=resource_args)
        flops_count, params_count = results['flops'], results['params']

        # Fix the copy to the exported subnet and measure it again.
        fix_subnet = export_fix_subnet(model)
        load_fix_subnet(copied_model, fix_subnet)
        subnet_results = estimator.estimate(
            model=copied_model, resource_args=resource_args)
        subnet_flops_count = subnet_results['flops']
        subnet_params_count = subnet_results['params']

        self.assertEqual(flops_count, subnet_flops_count)
        self.assertEqual(params_count, subnet_params_count)

        # test whether subnet estimate will affect original model
        fresh_copy = copy.deepcopy(model)
        results_after_estimate = estimator.estimate(
            model=fresh_copy, resource_args=resource_args)
        flops_count_after_estimate = results_after_estimate['flops']
        params_count_after_estimate = results_after_estimate['params']

        self.assertEqual(flops_count, flops_count_after_estimate)
        self.assertEqual(params_count, params_count_after_estimate)

View File

@ -110,9 +110,16 @@ class TestEvolutionSearchLoop(TestCase):
self.assertIsInstance(loop, EvolutionSearchLoop)
self.assertEqual(loop.candidates, fake_candidates)
@patch('mmrazor.engine.runner.evolution_search_loop.export_fix_subnet')
@patch('mmrazor.structures.FlopsEstimator.get_model_complexity_info')
def test_run_epoch(self, mock_flops, mock_export_fix_subnet):
@patch(
'mmrazor.engine.runner.evolution_search_loop.export_fix_subnet',
return_value={
'1': 'choice1',
'2': 'choice2'
})
@patch(
'mmrazor.models.task_modules.ResourceEstimator.estimate',
return_value=dict(flops=50.0, params=1.0))
def test_run_epoch(self, export_fix_subnet, estimate):
# test_run_epoch: distributed == False
loop_cfg = copy.deepcopy(self.train_cfg)
loop_cfg.runner = self.runner
@ -123,8 +130,6 @@ class TestEvolutionSearchLoop(TestCase):
self.runner.epoch = 1
self.runner.distributed = False
self.runner.work_dir = self.temp_dir
fake_subnet = {'1': 'choice1', '2': 'choice2'}
self.runner.model.sample_subnet = MagicMock(return_value=fake_subnet)
loop.run_epoch()
self.assertEqual(len(loop.candidates), 4)
self.assertEqual(len(loop.top_k_candidates), 2)
@ -136,8 +141,6 @@ class TestEvolutionSearchLoop(TestCase):
self.runner.epoch = 1
self.runner.distributed = True
self.runner.work_dir = self.temp_dir
fake_subnet = {'1': 'choice1', '2': 'choice2'}
self.runner.model.sample_subnet = MagicMock(return_value=fake_subnet)
loop.run_epoch()
self.assertEqual(len(loop.candidates), 4)
self.assertEqual(len(loop.top_k_candidates), 2)
@ -150,10 +153,6 @@ class TestEvolutionSearchLoop(TestCase):
self.runner.epoch = 1
self.runner.distributed = True
self.runner.work_dir = self.temp_dir
fake_subnet = {'1': 'choice1', '2': 'choice2'}
loop.model.sample_subnet = MagicMock(return_value=fake_subnet)
mock_flops.return_value = (50., 1)
mock_export_fix_subnet.return_value = fake_subnet
loop.run_epoch()
self.assertEqual(len(loop.candidates), 4)
self.assertEqual(len(loop.top_k_candidates), 2)

View File

@ -191,15 +191,20 @@ class TestGreedySamplerTrainLoop(TestCase):
self.assertEqual(subnet, fake_subnet)
self.assertEqual(len(loop.top_k_candidates), loop.top_k - 1)
@patch('mmrazor.engine.runner.subnet_sampler_loop.export_fix_subnet')
@patch('mmrazor.structures.FlopsEstimator.get_model_complexity_info')
def test_run(self, mock_flops, mock_export_fix_subnet):
@patch(
'mmrazor.engine.runner.evolution_search_loop.export_fix_subnet',
return_value={
'1': 'choice1',
'2': 'choice2'
})
@patch(
'mmrazor.models.task_modules.ResourceEstimator.estimate',
return_value=dict(flops=50.0, params=1.0))
def test_run(self, export_fix_subnet, estimate):
# test run with flops_range=None
cfg = copy.deepcopy(self.iter_based_cfg)
cfg.experiment_name = 'test_run1'
runner = Runner.from_cfg(cfg)
fake_subnet = {'1': 'choice1', '2': 'choice2'}
runner.model.sample_subnet = MagicMock(return_value=fake_subnet)
runner.train()
self.assertEqual(runner.iter, runner.max_iters)
@ -210,10 +215,6 @@ class TestGreedySamplerTrainLoop(TestCase):
cfg.experiment_name = 'test_run2'
cfg.train_cfg.flops_range = (0, 100)
runner = Runner.from_cfg(cfg)
fake_subnet = {'1': 'choice1', '2': 'choice2'}
runner.model.sample_subnet = MagicMock(return_value=fake_subnet)
mock_flops.return_value = (50., 1)
mock_export_fix_subnet.return_value = fake_subnet
runner.train()
self.assertEqual(runner.iter, runner.max_iters)