[Improvement] Update estimator with api revision (#277)
* update estimator usage and fix bugs
* refactor api of estimator & add inner check methods
* fix docstrings
* update search loop and config
* fix lint
* update unittest
* decouple mmdet dependency and fix lint

Co-authored-by: humu789 <humu@pjlab.org.cn>
parent eb25bb7577
commit 4e80037393
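The revised estimator API in one place: a minimal sketch assembled from the docstring
examples further down in this diff. The printed numbers are the docstring's illustrative
values, not guarantees.

import torch.nn as nn

from mmrazor.models.task_modules import ResourceEstimator

conv2d = nn.Conv2d(3, 32, 3)
estimator = ResourceEstimator(input_shape=(1, 3, 64, 64))
results = estimator.estimate(model=conv2d)
# per the docstring: {'flops': 3.444, 'params': 0.001, 'latency': 0.0}

# Per-call overrides are now passed as cfg dicts instead of the old `resource_args`:
results = estimator.estimate(
    model=conv2d, flops_params_cfg=dict(input_shape=(1, 3, 32, 32)))
# per the docstring: {'flops': 0.806, 'params': 0.001, 'latency': 0.0}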
@@ -13,5 +13,5 @@ train_cfg = dict(
     num_mutation=25,
     num_crossover=25,
     mutate_prob=0.1,
-    flops_range=(0., 465 * 1e6),
+    flops_range=(0., 465.),
     score_key='accuracy/top1')
@@ -13,5 +13,5 @@ train_cfg = dict(
     num_mutation=25,
     num_crossover=25,
     mutate_prob=0.1,
-    flops_range=(0., 330 * 1e6),
+    flops_range=(0., 330.),
     score_key='accuracy/top1')
@@ -13,5 +13,5 @@ train_cfg = dict(
     num_mutation=20,
     num_crossover=20,
     mutate_prob=0.1,
-    flops_range=None,
-    score_key='bbox_mAP')
+    flops_range=(0., 300.),
+    score_key='coco/bbox_mAP')
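The three config hunks above switch flops_range from raw FLOP counts to values expressed
in the estimator's default FLOPs unit ('M'), so 465. now means 465 MFLOPs. A hedged
illustration (the other search settings are omitted here):

train_cfg = dict(
    # num_mutation / num_crossover / mutate_prob as in the hunks above
    flops_range=(0., 465.),  # upper bound in MFLOPs; was 465 * 1e6 raw FLOPs
    score_key='accuracy/top1')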
@@ -1,5 +1,4 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-import copy
 import os
 import os.path as osp
 import random
@@ -14,11 +13,11 @@ from mmengine.runner import EpochBasedTrainLoop
 from mmengine.utils import is_list_of
 from torch.utils.data import DataLoader

-from mmrazor.models.task_modules.estimators import get_model_complexity_info
+from mmrazor.models.task_modules import ResourceEstimator
 from mmrazor.registry import LOOPS
-from mmrazor.structures import Candidates, export_fix_subnet, load_fix_subnet
+from mmrazor.structures import Candidates, export_fix_subnet
 from mmrazor.utils import SupportRandomSubnet
-from .utils import crossover
+from .utils import check_subnet_flops, crossover


 @LOOPS.register_module()
@@ -42,10 +41,10 @@ class EvolutionSearchLoop(EpochBasedTrainLoop):
         num_crossover (int): The number of candidates got by crossover.
             Defaults to 25.
         mutate_prob (float): The probability of mutation. Defaults to 0.1.
-        flops_range (tuple, optional): flops_range to be used for screening
-            candidates.
-        spec_modules (list): Used for specify modules need to counter.
-            Defaults to list().
+        flops_range (tuple, optional): It is used for screening candidates.
+        resource_estimator_cfg (dict): The config for building estimator, which
+            is be used to estimate the flops of sampled subnet. Defaults to
+            None, which means default config is used.
         score_key (str): Specify one metric in evaluation results to score
             candidates. Defaults to 'accuracy_top-1'.
         init_candidates (str, optional): The candidates file path, which is
@@ -65,8 +64,8 @@ class EvolutionSearchLoop(EpochBasedTrainLoop):
                 num_mutation: int = 25,
                 num_crossover: int = 25,
                 mutate_prob: float = 0.1,
-                flops_range: Optional[Tuple[float, float]] = (0., 330 * 1e6),
-                spec_modules: List = [],
+                flops_range: Optional[Tuple[float, float]] = (0., 330.),
+                resource_estimator_cfg: Optional[dict] = None,
                 score_key: str = 'accuracy/top1',
                 init_candidates: Optional[str] = None) -> None:
        super().__init__(runner, dataloader, max_epochs)
@@ -85,7 +84,6 @@ class EvolutionSearchLoop(EpochBasedTrainLoop):
         self.num_candidates = num_candidates
         self.top_k = top_k
         self.flops_range = flops_range
-        self.spec_modules = spec_modules
         self.score_key = score_key
         self.num_mutation = num_mutation
         self.num_crossover = num_crossover
@@ -101,6 +99,10 @@ class EvolutionSearchLoop(EpochBasedTrainLoop):
             correct init candidates file'

         self.top_k_candidates = Candidates()
+        if resource_estimator_cfg is None:
+            self.estimator = ResourceEstimator()
+        else:
+            self.estimator = ResourceEstimator(**resource_estimator_cfg)

         if self.runner.distributed:
             self.model = runner.model.module
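With this change each search loop builds its own ResourceEstimator from
resource_estimator_cfg instead of calling the counters directly. A hedged config sketch;
the type string, the numeric values and the omitted dataloader/evaluator fields are
illustrative assumptions, not taken from this commit:

train_cfg = dict(
    type='mmrazor.EvolutionSearchLoop',  # registry scope assumed
    max_epochs=20,
    num_candidates=50,
    flops_range=(0., 330.),  # in the estimator's FLOPs unit ('M')
    resource_estimator_cfg=dict(input_shape=(1, 3, 224, 224)),
    score_key='accuracy/top1')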
@@ -299,17 +301,10 @@ class EvolutionSearchLoop(EpochBasedTrainLoop):
         Returns:
             bool: The result of checking.
         """
-        if self.flops_range is None:
-            return True
+        is_pass = check_subnet_flops(
+            model=self.model,
+            subnet=random_subnet,
+            estimator=self.estimator,
+            flops_range=self.flops_range)

-        self.model.set_subnet(random_subnet)
-        fix_mutable = export_fix_subnet(self.model)
-        copied_model = copy.deepcopy(self.model)
-        load_fix_subnet(copied_model, fix_mutable)
-        flops, _ = get_model_complexity_info(
-            copied_model, spec_modules=self.spec_modules)
-
-        if self.flops_range[0] <= flops <= self.flops_range[1]:
-            return True
-        else:
-            return False
+        return is_pass
@@ -1,5 +1,4 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-import copy
 import math
 import os
 import random
@@ -13,10 +12,11 @@ from mmengine.runner import IterBasedTrainLoop
 from mmengine.utils import is_list_of
 from torch.utils.data import DataLoader

-from mmrazor.models.task_modules.estimators import get_model_complexity_info
+from mmrazor.models.task_modules import ResourceEstimator
 from mmrazor.registry import LOOPS
-from mmrazor.structures import Candidates, export_fix_subnet, load_fix_subnet
+from mmrazor.structures import Candidates
 from mmrazor.utils import SupportRandomSubnet
+from .utils import check_subnet_flops


 class BaseSamplerTrainLoop(IterBasedTrainLoop):
@@ -103,8 +103,9 @@ class GreedySamplerTrainLoop(BaseSamplerTrainLoop):
         score_key (str): Specify one metric in evaluation results to score
             candidates. Defaults to 'accuracy_top-1'.
         flops_range (dict): Constraints to be used for screening candidates.
-        spec_modules (list): Used for specify modules need to counter.
-            Defaults to list().
+        resource_estimator_cfg (dict): The config for building estimator, which
+            is be used to estimate the flops of sampled subnet. Defaults to
+            None, which means default config is used.
         num_candidates (int): The number of the candidates consist of samples
             from supernet and itself. Defaults to 1000.
         num_samples (int): The number of sample in each sampling subnet.
@@ -138,8 +139,8 @@ class GreedySamplerTrainLoop(BaseSamplerTrainLoop):
                 val_begin: int = 1,
                 val_interval: int = 1000,
                 score_key: str = 'accuracy/top1',
-                flops_range: Optional[Tuple[float, float]] = (0., 330 * 1e6),
-                spec_modules: List = [],
+                flops_range: Optional[Tuple[float, float]] = (0., 330),
+                resource_estimator_cfg: Optional[dict] = None,
                 num_candidates: int = 1000,
                 num_samples: int = 10,
                 top_k: int = 5,
@@ -163,7 +164,6 @@ class GreedySamplerTrainLoop(BaseSamplerTrainLoop):

         self.score_key = score_key
         self.flops_range = flops_range
-        self.spec_modules = spec_modules
         self.num_candidates = num_candidates
         self.num_samples = num_samples
         self.top_k = top_k
@@ -177,6 +177,10 @@ class GreedySamplerTrainLoop(BaseSamplerTrainLoop):

         self.candidates = Candidates()
         self.top_k_candidates = Candidates()
+        if resource_estimator_cfg is None:
+            self.estimator = ResourceEstimator()
+        else:
+            self.estimator = ResourceEstimator(**resource_estimator_cfg)

     def run(self) -> None:
         """Launch training."""
@@ -317,20 +321,13 @@ class GreedySamplerTrainLoop(BaseSamplerTrainLoop):
         Returns:
             bool: The result of checking.
         """
-        if self.flops_range is None:
-            return True
+        is_pass = check_subnet_flops(
+            model=self.model,
+            subnet=random_subnet,
+            estimator=self.estimator,
+            flops_range=self.flops_range)

-        self.model.set_subnet(random_subnet)
-        fix_mutable = export_fix_subnet(self.model)
-        copied_model = copy.deepcopy(self.model)
-        load_fix_subnet(copied_model, fix_mutable)
-        flops, _ = get_model_complexity_info(
-            copied_model, spec_modules=self.spec_modules)
-
-        if self.flops_range[0] <= flops <= self.flops_range[1]:
-            return True
-        else:
-            return False
+        return is_pass

     def _save_candidates(self) -> None:
         """Save the candidates to init the next searching."""
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from .check import check_subnet_flops
 from .genetic import crossover

-__all__ = ['crossover']
+__all__ = ['crossover', 'check_subnet_flops']
@@ -0,0 +1,48 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+from typing import Optional, Tuple
+
+import torch.nn as nn
+
+from mmrazor.models import ResourceEstimator
+from mmrazor.structures import export_fix_subnet, load_fix_subnet
+from mmrazor.utils import SupportRandomSubnet
+
+try:
+    from mmdet.models.detectors import BaseDetector
+except ImportError:
+    from mmrazor.utils import get_placeholder
+    BaseDetector = get_placeholder('mmdet')
+
+
+def check_subnet_flops(
+        model: nn.Module,
+        subnet: SupportRandomSubnet,
+        estimator: ResourceEstimator,
+        flops_range: Optional[Tuple[float, float]] = None) -> bool:
+    """Check whether is beyond flops constraints.
+
+    Returns:
+        bool: The result of checking.
+    """
+    if flops_range is None:
+        return True
+
+    assert hasattr(model, 'set_subnet') and hasattr(model, 'architecture')
+    model.set_subnet(subnet)
+    fix_mutable = export_fix_subnet(model)
+    copied_model = copy.deepcopy(model)
+    load_fix_subnet(copied_model, fix_mutable)
+
+    model_to_check = model.architecture
+    if isinstance(model_to_check, BaseDetector):
+        results = estimator.estimate(model=model_to_check.backbone)
+    else:
+        results = estimator.estimate(model=model_to_check)
+
+    flops = results['flops']
+    flops_mix, flops_max = flops_range
+    if flops_mix <= flops <= flops_max:  # type: ignore
+        return True
+    else:
+        return False
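A hedged usage sketch of the new helper. Here `supernet` stands for any model that
exposes `set_subnet()` and an `architecture` attribute (e.g. a NAS algorithm wrapper);
it and its `sample_subnet()` call are assumptions of this example, not part of the diff:

from mmrazor.engine.runner.utils import check_subnet_flops
from mmrazor.models.task_modules import ResourceEstimator

estimator = ResourceEstimator()
subnet = supernet.sample_subnet()  # `supernet` is assumed to exist
if check_subnet_flops(
        model=supernet,
        subnet=subnet,
        estimator=estimator,
        flops_range=(0., 330.)):
    pass  # keep the candidate; it fits the FLOPs budget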
@@ -1,6 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from abc import ABCMeta, abstractmethod
-from typing import Any, Dict, List, Tuple
+from typing import Dict, Tuple, Union

 import torch.nn

@@ -12,44 +12,40 @@ class BaseEstimator(metaclass=ABCMeta):
     """The base class of Estimator, used for estimating model infos.

     Args:
-        default_shape (tuple): Input data's default shape, for calculating
+        input_shape (tuple): Input data's default shape, for calculating
             resources consume. Defaults to (1, 3, 224, 224).
-        units (str): Resource units. Defaults to 'M'.
-        disabled_counters (list): List of disabled spec op counters.
-            Defaults to None.
+        units (dict): A dict including required units. Default to dict().
         as_strings (bool): Output FLOPs and params counts in a string
             form. Default to False.
-        measure_inference (bool): whether to measure infer speed or not.
-            Default to False.
     """

     def __init__(self,
-                 default_shape: Tuple = (1, 3, 224, 224),
-                 units: str = 'M',
-                 disabled_counters: List[str] = None,
-                 as_strings: bool = False,
-                 measure_inference: bool = False):
-        assert len(default_shape) in [3, 4, 5], \
-            f'Unsupported shape: {default_shape}'
-        self.default_shape = default_shape
+                 input_shape: Tuple = (1, 3, 224, 224),
+                 units: Dict = dict(),
+                 as_strings: bool = False):
+        assert len(input_shape) in [
+            3, 4, 5
+        ], ('The length of input_shape must be in [3, 4, 5]. '
+            f'Got `{len(input_shape)}`.')
+        self.input_shape = input_shape
         self.units = units
-        self.disabled_counters = disabled_counters
         self.as_strings = as_strings
-        self.measure_inference = measure_inference

     @abstractmethod
-    def estimate(
-        self, model: torch.nn.Module, resource_args: Dict[str, Any] = dict()
-    ) -> Dict[str, float]:
+    def estimate(self,
+                 model: torch.nn.Module,
+                 flops_params_cfg: dict = None,
+                 latency_cfg: dict = None) -> Dict[str, Union[float, str]]:
        """Estimate the resources(flops/params/latency) of the given model.

        Args:
            model: The measured model.
-            resource_args (Dict[str, float]): resources information.
-                NOTE: resource_args have the same items() as the init cfgs.
+            flops_params_cfg (dict): Cfg for estimating FLOPs and parameters.
+                Default to None.
+            latency_cfg (dict): Cfg for estimating latency. Default to None.

        Returns:
-            Dict[str, float]): A dict that containing resource results(flops,
-                params and latency).
+            Dict[str, Union[float, str]]): A dict that contains the resource
+                results(FLOPs, params and latency).
        """
        pass
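For reference, a toy subclass showing the abstract interface a custom estimator would
now implement; the class, its logic and the import path are illustrative assumptions:

from typing import Dict, Union

import torch.nn

from mmrazor.models.task_modules.estimators import BaseEstimator  # import path assumed


class ParamsOnlyEstimator(BaseEstimator):
    """Toy estimator that only reports parameter counts (in 'M')."""

    def estimate(self,
                 model: torch.nn.Module,
                 flops_params_cfg: dict = None,
                 latency_cfg: dict = None) -> Dict[str, Union[float, str]]:
        params = sum(p.numel() for p in model.parameters()) / 1e6
        return {'flops': 0.0, 'params': params, 'latency': 0.0}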
@@ -1,10 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .flops_params_counter import (get_model_complexity_info,
-                                   params_units_convert)
-from .latency_counter import repeat_measure_inference_speed
+from .flops_params_counter import get_model_flops_params
+from .latency_counter import get_model_latency
 from .op_counters import *  # noqa: F401,F403

-__all__ = [
-    'get_model_complexity_info', 'params_units_convert',
-    'repeat_measure_inference_speed'
-]
+__all__ = ['get_model_flops_params', 'get_model_latency']
@@ -1,6 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import sys
 from functools import partial
+from typing import Dict

 import torch
 import torch.nn as nn
@@ -8,19 +9,21 @@ import torch.nn as nn
 from mmrazor.registry import TASK_UTILS


-def get_model_complexity_info(model,
-                              input_shape=(1, 3, 224, 224),
-                              spec_modules=[],
-                              disabled_counters=[],
-                              print_per_layer_stat=False,
-                              as_strings=False,
-                              input_constructor=None,
-                              flush=False,
-                              ost=sys.stdout):
-    """Get complexity information of a model. This method can calculate FLOPs
-    and parameter counts of a model with corresponding input shape. It can also
-    print complexity information for each layer in a model. Supported layers
-    are listed as below:
+def get_model_flops_params(model,
+                           input_shape=(1, 3, 224, 224),
+                           spec_modules=[],
+                           disabled_counters=[],
+                           print_per_layer_stat=False,
+                           units=dict(flops='M', params='M'),
+                           as_strings=False,
+                           seperate_return: bool = False,
+                           input_constructor=None,
+                           flush=False,
+                           ost=sys.stdout):
+    """Get FLOPs and parameters of a model. This method can calculate FLOPs and
+    parameter counts of a model with corresponding input shape. It can also
+    print FLOPs and params for each layer in a model. Supported layers are
+    listed as below:

     - Convolutions: ``nn.Conv1d``, ``nn.Conv2d``, ``nn.Conv3d``.
     - Activations: ``nn.ReLU``, ``nn.PReLU``, ``nn.ELU``, ``nn.LeakyReLU``,
@@ -39,16 +42,20 @@ def get_model_complexity_info(model,
     Args:
         model (nn.Module): The model for complexity calculation.
         input_shape (tuple): Input shape (including batchsize) used for
-            calculation. Default to (1, 3, 224, 224)
+            calculation. Default to (1, 3, 224, 224).
         spec_modules (list): A list that contains the names of several spec
             modules, which users want to get resources infos of them.
             e.g., ['backbone', 'head'], ['backbone.layer1']. Default to [].
         disabled_counters (list): One can limit which ops' spec would be
             calculated. Default to [].
-        print_per_layer_stat (bool): Whether to print complexity information
+        print_per_layer_stat (bool): Whether to print FLOPs and params
             for each layer in a model. Default to True.
+        units (dict): A dict including converted FLOPs and params units.
+            Default to dict(flops='M', params='M').
         as_strings (bool): Output FLOPs and params counts in a string form.
             Default to True.
+        seperate_return (bool): Whether to return the resource information
+            separately. Default to False.
         input_constructor (None | callable): If specified, it takes a callable
             method that generates input. otherwise, it will generate a random
             tensor with input shape to calculate FLOPs. Default to None.
@@ -60,12 +67,16 @@ def get_model_complexity_info(model,
        tuple[float | str] | dict[str, float]: If `as_strings` is set to True,
            it will return FLOPs and parameter counts in a string format.
            Otherwise, it will return those in a float number format.
-            If len(spec_modules) > 0, it will return a resource info dict with
-            FLOPs and parameter counts of each spec module in float format.
+            NOTE: If seperate_return, it will return a resource info dict with
+            FLOPs & params counts of each spec module in float|string format.
    """
    assert type(input_shape) is tuple
    assert len(input_shape) >= 1
    assert isinstance(model, nn.Module)
+    if seperate_return and not len(spec_modules):
+        raise AssertionError('`seperate_return` can only be set to True when '
+                             '`spec_modules` are not empty.')
+
    flops_params_model = add_flops_params_counting_methods(model)
    flops_params_model.eval()
    flops_params_model.start_flops_params_count(disabled_counters)
@@ -96,34 +107,44 @@ def get_model_complexity_info(model,
            ost=ost,
            flush=flush)

+    if units is not None:
+        flops_count = params_units_convert(flops_count, units['flops'])
+        params_count = params_units_convert(params_count, units['params'])
+
+    if as_strings:
+        flops_suffix = ' ' + units['flops'] + 'FLOPs' if units else ' FLOPs'
+        params_suffix = ' ' + units['params'] if units else ''
+
    if len(spec_modules):
        flops_count, params_count = 0.0, 0.0
        module_names = [name for name, _ in flops_params_model.named_modules()]
        for module in spec_modules:
            assert module in module_names, \
                f'All modules in spec_modules should be in the measured ' \
-                f'flops_params_model. Got module {module} in spec_modules.'
-        spec_modules_resources = dict()
-        accumulate_sub_module_flops_params(flops_params_model)
+                f'flops_params_model. Got module `{module}` in spec_modules.'
+        spec_modules_resources: Dict[str, dict] = dict()
+        accumulate_sub_module_flops_params(flops_params_model, units=units)
        for name, module in flops_params_model.named_modules():
            if name in spec_modules:
                spec_modules_resources[name] = dict()
                spec_modules_resources[name]['flops'] = module.__flops__
                spec_modules_resources[name]['params'] = module.__params__
                flops_count += module.__flops__
                params_count += module.__params__
                if as_strings:
-                    spec_modules_resources[name]['flops'] = str(
-                        params_units_convert(module.__flops__,
-                                             'G')) + ' GFLOPs'
-                    spec_modules_resources[name]['params'] = str(
-                        params_units_convert(module.__params__, 'M')) + ' M'
+                    spec_modules_resources[name]['flops'] = \
+                        str(module.__flops__) + flops_suffix
+                    spec_modules_resources[name]['params'] = \
+                        str(module.__params__) + params_suffix

    flops_params_model.stop_flops_params_count()

+    if len(spec_modules):
+        if seperate_return:
+            return spec_modules_resources
+
    if as_strings:
-        flops_string = str(params_units_convert(flops_count, 'G')) + ' GFLOPs'
-        params_string = str(params_units_convert(params_count, 'M')) + ' M'
+        flops_string = str(flops_count) + flops_suffix
+        params_string = str(params_count) + params_suffix
        return flops_string, params_string

    return flops_count, params_count
@@ -164,7 +185,7 @@ def params_units_convert(num_params, units='M', precision=3):
 def print_model_with_flops_params(model,
                                   total_flops,
                                   total_params,
-                                  units='G',
+                                  units=dict(flops='M', params='M'),
                                   precision=3,
                                   ost=sys.stdout,
                                   flush=False):
@@ -174,7 +195,9 @@ def print_model_with_flops_params(model,
        model (nn.Module): The model to be printed.
        total_flops (float): Total FLOPs of the model.
        total_params (float): Total parameter counts of the model.
-        units (str | None): Converted FLOPs units. Default to 'G'.
+        units (tuple | none): A tuple pair including converted FLOPs & params
+            units. e.g., ('G', 'M') stands for FLOPs as 'G' & params as 'M'.
+            Default to ('M', 'M').
        precision (int): Digit number after the decimal point. Default to 3.
        ost (stream): same as `file` param in :func:`print`.
            Default to sys.stdout.
@@ -200,8 +223,8 @@ def print_model_with_flops_params(model,
        >>> return x
        >>> model = ExampleModel()
        >>> x = (3, 16, 16)
-        to print the complexity information state for each layer, you can use
-        >>> get_model_complexity_info(model, x)
+        to print the FLOPs and params state for each layer, you can use
+        >>> get_model_flops_params(model, x)
        or directly use
        >>> print_model_with_flops_params(model, 4579784.0, 37361)
        ExampleModel(
@@ -241,11 +264,11 @@ def print_model_with_flops_params(model,
        accumulated_flops_cost = self.accumulate_flops()
        flops_string = str(
            params_units_convert(
-                accumulated_flops_cost, units=units,
-                precision=precision)) + ' ' + units + 'FLOPs'
+                accumulated_flops_cost, units['flops'],
+                precision=precision)) + ' ' + units['flops'] + 'FLOPs'
        params_string = str(
-            params_units_convert(
-                accumulated_num_params, units='M', precision=precision)) + ' M'
+            params_units_convert(accumulated_num_params, units['params'],
+                                 precision)) + ' M'
        return ', '.join([
            params_string,
            '{:.3%} Params'.format(accumulated_num_params / total_params),
@@ -277,12 +300,15 @@ def print_model_with_flops_params(model,
    model.apply(del_extra_repr)


-def accumulate_sub_module_flops_params(model):
+def accumulate_sub_module_flops_params(model, units=None):
    """Accumulate FLOPs and params for each module in the model. Each module in
    the model will have the `__flops__` and `__params__` parameters.

    Args:
        model (nn.Module): The model to be accumulated.
+        units (tuple | none): A tuple pair including converted FLOPs & params
+            units. e.g., ('G', 'M') stands for FLOPs as 'G' & params as 'M'.
+            Default to None.
    """

    def accumulate_params(module):
@@ -310,6 +336,9 @@ def accumulate_sub_module_flops_params(model):
        _params = accumulate_params(module)
        module.__flops__ = _flops
        module.__params__ = _params
+        if units is not None:
+            module.__flops__ = params_units_convert(_flops, units['flops'])
+            module.__params__ = params_units_convert(_params, units['params'])


def get_model_parameters_number(model):
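A hedged sketch of calling the renamed counter directly; the import path is assumed and
the returned numbers depend on the measured model:

import torch.nn as nn

from mmrazor.models.task_modules.estimators.counters import \
    get_model_flops_params  # import path assumed

model = nn.Conv2d(3, 32, 3)
flops, params = get_model_flops_params(
    model, input_shape=(1, 3, 224, 224), units=dict(flops='M', params='M'))
flops_str, params_str = get_model_flops_params(
    model, input_shape=(1, 3, 224, 224), as_strings=True)  # e.g. '... MFLOPs', '... M'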
@@ -1,71 +1,89 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import logging
 import time
-from typing import Any, Dict
+from typing import Tuple, Union

 import torch
 from mmengine.logging import print_log


-def repeat_measure_inference_speed(model: torch.nn.Module,
-                                   resource_args: Dict[str, Any],
-                                   max_iter: int = 100,
-                                   num_warmup: int = 5,
-                                   log_interval: int = 100,
-                                   repeat_num: int = 1) -> float:
+def get_model_latency(model: torch.nn.Module,
+                      input_shape: Tuple = (1, 3, 224, 224),
+                      unit: str = 'ms',
+                      as_strings: bool = False,
+                      max_iter: int = 100,
+                      num_warmup: int = 5,
+                      log_interval: int = 100,
+                      repeat_num: int = 1) -> Union[float, str]:
    """Repeat speed measure for multi-times to get more precise results.

    Args:
        model (torch.nn.Module): The measured model.
-        resource_args (Dict[str, float]): resources information.
-        max_iter (Optional[int]): Max iteration num for inference speed test.
+        input_shape (tuple): Input shape (including batchsize) used for
+            calculation. Default to (1, 3, 224, 224).
+        unit (str): Unit of latency in string format. Default to 'ms'.
+        as_strings (bool): Output latency counts in a string form.
+            Default to False.
+        max_iter (Optional[int]): Max iteration num for the measurement.
            Default to 100.
        num_warmup (Optional[int]): Iteration num for warm-up stage.
            Default to 5.
        log_interval (Optional[int]): Interval num for logging the results.
            Default to 100.
        repeat_num (Optional[int]): Num of times to repeat the measurement.
            Default to 1.

    Returns:
-        fps (float): The measured inference speed of the model.
+        latency (Union[float, str]): The measured inference speed of the model.
+            if ``as_strings=True``, it will return latency in string format.
    """
    assert repeat_num >= 1

    fps_list = []

    for _ in range(repeat_num):

        fps_list.append(
-            measure_inference_speed(model, resource_args, max_iter, num_warmup,
-                                    log_interval))
+            _get_model_latency(model, input_shape, max_iter, num_warmup,
+                               log_interval))

+    latency = round(1000 / fps_list[0], 1)
+
    if repeat_num > 1:
-        fps_list_ = [round(fps, 1) for fps in fps_list]
+        _fps_list = [round(fps, 1) for fps in fps_list]
        times_per_img_list = [round(1000 / fps, 1) for fps in fps_list]
-        mean_fps_ = sum(fps_list_) / len(fps_list_)
+        _mean_fps = sum(_fps_list) / len(_fps_list)
        mean_times_per_img = sum(times_per_img_list) / len(times_per_img_list)
        print_log(
-            f'Overall fps: {fps_list_}[{mean_fps_:.1f}] img / s, '
+            f'Overall fps: {_fps_list}[{_mean_fps:.1f}] img / s, '
            f'times per image: '
            f'{times_per_img_list}[{mean_times_per_img:.1f}] ms/img',
            logger='current',
            level=logging.DEBUG)
-        return mean_times_per_img
+        latency = mean_times_per_img
+
+    if as_strings:
+        latency = str(latency) + ' ' + unit  # type: ignore

-    latency = round(1000 / fps_list[0], 1)
    return latency


-def measure_inference_speed(model: torch.nn.Module,
-                            resource_args: Dict[str, Any],
-                            max_iter: int = 100,
-                            num_warmup: int = 5,
-                            log_interval: int = 100) -> float:
+def _get_model_latency(model: torch.nn.Module,
+                       input_shape: Tuple = (1, 3, 224, 224),
+                       max_iter: int = 100,
+                       num_warmup: int = 5,
+                       log_interval: int = 100) -> float:
    """Measure inference speed on GPU devices.

    Args:
        model (torch.nn.Module): The measured model.
-        resource_args (Dict[str, float]): resources information.
-        max_iter (Optional[int]): Max iteration num for inference speed test.
+        input_shape (tuple): Input shape (including batchsize) used for
+            calculation. Default to (1, 3, 224, 224).
+        max_iter (Optional[int]): Max iteration num for the measurement.
            Default to 100.
        num_warmup (Optional[int]): Iteration num for warm-up stage.
            Default to 5.
        log_interval (Optional[int]): Interval num for logging the results.
            Default to 100.

    Returns:
        fps (float): The measured inference speed of the model.
@@ -78,10 +96,11 @@ def measure_inference_speed(model: torch.nn.Module,
        device = 'cuda'
    else:
        raise NotImplementedError('To use cpu to test latency not supported.')

    # benchmark with {max_iter} image and take the average
    for i in range(1, max_iter):
        if device == 'cuda':
-            data = torch.rand(resource_args['input_shape']).cuda()
+            data = torch.rand(input_shape).cuda()
            torch.cuda.synchronize()
            start_time = time.perf_counter()
+
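A hedged sketch of the renamed latency counter, reusing the `model` from the previous
sketch; per the implementation above it requires a CUDA device, and the import path is
assumed:

from mmrazor.models.task_modules.estimators.counters import \
    get_model_latency  # import path assumed

latency = get_model_latency(
    model, input_shape=(1, 3, 224, 224), max_iter=100, num_warmup=5, repeat_num=2)
latency_str = get_model_latency(model, unit='ms', as_strings=True)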
@@ -1,13 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import Any, Dict, List, Tuple
+from typing import Dict, Optional, Tuple, Union

 import torch.nn
-from mmengine.dist import broadcast_object_list, is_main_process

 from mmrazor.registry import TASK_UTILS
 from .base_estimator import BaseEstimator
-from .counters import (get_model_complexity_info, params_units_convert,
-                       repeat_measure_inference_speed)
+from .counters import get_model_flops_params, get_model_latency


 @TASK_UTILS.register_module()
@@ -15,24 +13,30 @@ class ResourceEstimator(BaseEstimator):
     """Estimator for calculating the resources consume.

     Args:
-        default_shape (tuple): Input data's default shape, for calculating
-            resources consume. Defaults to (1, 3, 224, 224)
-        units (str): Resource units. Defaults to 'M'.
-        disabled_counters (list): List of disabled spec op counters.
-            Defaults to None.
-            NOTE: disabled_counters contains the op counter class names
-            in estimator.op_counters that require to be disabled,
-            such as 'ConvCounter', 'BatchNorm2dCounter', ...
+        input_shape (tuple): Input data's default shape, for calculating
+            resources consume. Defaults to (1, 3, 224, 224).
+        units (dict): Dict that contains converted FLOPs/params/latency units.
+            Default to dict(flops='M', params='M', latency='ms').
+        as_strings (bool): Output FLOPs/params/latency counts in a string
+            form. Default to False.
+        flops_params_cfg (dict): Cfg for estimating FLOPs and parameters.
+            Default to None.
+        latency_cfg (dict): Cfg for estimating latency. Default to None.

     Examples:
         >>> # direct calculate resource consume of nn.Conv2d
         >>> conv2d = nn.Conv2d(3, 32, 3)
-        >>> estimator = ResourceEstimator()
-        >>> estimator.estimate(
-        ...     model=conv2d,
-        ...     resource_args=dict(input_shape=(1, 3, 64, 64)))
+        >>> estimator = ResourceEstimator(input_shape=(1, 3, 64, 64))
+        >>> estimator.estimate(model=conv2d)
         {'flops': 3.444, 'params': 0.001, 'latency': 0.0}

+        >>> # direct calculate resource consume of nn.Conv2d
+        >>> conv2d = nn.Conv2d(3, 32, 3)
+        >>> estimator = ResourceEstimator()
+        >>> flops_params_cfg = dict(input_shape=(1, 3, 32, 32))
+        >>> estimator.estimate(model=conv2d, flops_params_cfg)
+        {'flops': 0.806, 'params': 0.001, 'latency': 0.0}
+
         >>> # calculate resources of custom modules
         >>> class CustomModule(nn.Module):
         ...
@@ -51,17 +55,14 @@ class ResourceEstimator(BaseEstimator):
         ...         module.__params__ += 700000
         ...
         >>> model = CustomModule()
-        >>> estimator.estimate(
-        ...     model=model,
-        ...     resource_args=dict(input_shape=(1, 3, 64, 64)))
+        >>> flops_params_cfg = dict(input_shape=(1, 3, 64, 64))
+        >>> estimator.estimate(model=model, flops_params_cfg)
         {'flops': 1.0, 'params': 0.7, 'latency': 0.0}
         ...
         >>> # calculate resources of custom modules with disable_counters
-        >>> estimator.estimate(
-        ...     model=model,
-        ...     resource_args=dict(
-        ...         input_shape=(1, 3, 64, 64),
-        ...         disabled_counters=['CustomModuleCounter']))
+        >>> flops_params_cfg = dict(input_shape=(1, 3, 64, 64),
+        ...                         disabled_counters=['CustomModuleCounter'])
+        >>> estimator.estimate(model=model, flops_params_cfg)
         {'flops': 0.0, 'params': 0.0, 'latency': 0.0}

         >>> # calculate resources of mmrazor.models
@@ -69,87 +70,146 @@ class ResourceEstimator(BaseEstimator):
        mmrazor.engine.hooks.estimate_resources_hook for details.
    """

-    def __init__(self,
-                 default_shape: Tuple = (1, 3, 224, 224),
-                 units: str = 'M',
-                 disabled_counters: List[str] = [],
-                 as_strings: bool = False,
-                 measure_inference: bool = False):
-        super().__init__(default_shape, units, disabled_counters, as_strings,
-                         measure_inference)
+    def __init__(
+        self,
+        input_shape: Tuple = (1, 3, 224, 224),
+        units: Dict = dict(flops='M', params='M', latency='ms'),
+        as_strings: bool = False,
+        flops_params_cfg: Optional[dict] = None,
+        latency_cfg: Optional[dict] = None,
+    ):
+        super().__init__(input_shape, units, as_strings)
+        if not isinstance(units, dict):
+            raise TypeError('units for estimator should be a dict',
+                            f'but got `{type(units)}`')
+        for unit_key in units:
+            if unit_key not in ['flops', 'params', 'latency']:
+                raise KeyError(f'Got invalid key `{unit_key}` in units. ',
+                               'Should be `flops`, `params` or `latency`.')
+        if flops_params_cfg:
+            self.flops_params_cfg = flops_params_cfg
+        else:
+            self.flops_params_cfg = dict()
+        self.latency_cfg = latency_cfg if latency_cfg else dict()

-    def estimate(
-        self, model: torch.nn.Module, resource_args: Dict[str, Any] = dict()
-    ) -> Dict[str, Any]:
+    def estimate(self,
+                 model: torch.nn.Module,
+                 flops_params_cfg: dict = None,
+                 latency_cfg: dict = None) -> Dict[str, Union[float, str]]:
        """Estimate the resources(flops/params/latency) of the given model.

+        This method will first parse the merged :attr:`self.flops_params_cfg`
+        and the :attr:`self.latency_cfg` to check whether the keys are valid.
+
        Args:
            model: The measured model.
-            resource_args (Dict[str, float]): Args for resources estimation.
-                NOTE: resource_args have the same items() as the init cfgs.
+            flops_params_cfg (dict): Cfg for estimating FLOPs and parameters.
+                Default to None.
+            latency_cfg (dict): Cfg for estimating latency. Default to None.

+        NOTE: If the `flops_params_cfg` and `latency_cfg` are both None,
+        this method will only estimate FLOPs/params with default settings.
+
        Returns:
-            Dict[str, str]): A dict that containing resource results(flops,
-                params and latency).
+            Dict[str, Union[float, str]]): A dict that contains the resource
+                results(FLOPs, params and latency).
        """
        resource_metrics = dict()
-        if is_main_process():
-            measure_inference = resource_args.pop('measure_inference', False)
-            if 'input_shape' not in resource_args.keys():
-                resource_args['input_shape'] = self.default_shape
-            if 'disabled_counters' not in resource_args.keys():
-                resource_args['disabled_counters'] = self.disabled_counters
-            model.eval()
-            flops, params = get_model_complexity_info(model, **resource_args)
-            if measure_inference:
-                latency = repeat_measure_inference_speed(
-                    model, resource_args, max_iter=100, repeat_num=2)
-            else:
-                latency = 0.0
-            as_strings = resource_args.get('as_strings', self.as_strings)
-            if as_strings and self.units is not None:
-                raise ValueError('Set units to None, when as_trings=True.')
-            if self.units is not None:
-                flops = params_units_convert(flops, self.units)
-                params = params_units_convert(params, self.units)
-            resource_metrics.update({
-                'flops': flops,
-                'params': params,
-                'latency': latency
-            })
-            results = [resource_metrics]
+        measure_latency = True if latency_cfg else False
+
+        if flops_params_cfg:
+            flops_params_cfg = {**self.flops_params_cfg, **flops_params_cfg}
+            self._check_flops_params_cfg(flops_params_cfg)
+            flops_params_cfg = self._set_default_resource_params(
+                flops_params_cfg)
        else:
-            results = [None]  # type: ignore
+            flops_params_cfg = self.flops_params_cfg

-        broadcast_object_list(results)
+        if latency_cfg:
+            latency_cfg = {**self.latency_cfg, **latency_cfg}
+            self._check_latency_cfg(latency_cfg)
+            latency_cfg = self._set_default_resource_params(latency_cfg)
+        else:
+            latency_cfg = self.latency_cfg

-        return results[0]
+        model.eval()
+        flops, params = get_model_flops_params(model, **flops_params_cfg)
+        if measure_latency:
+            latency = get_model_latency(model, **latency_cfg)
+        else:
+            latency = '0.0 ms' if self.as_strings else 0.0  # type: ignore

-    def estimate_spec_modules(
-        self, model: torch.nn.Module, resource_args: Dict[str, Any] = dict()
-    ) -> Dict[str, float]:
-        """Estimate the resources(flops/params/latency) of the spec modules.
+        resource_metrics.update({
+            'flops': flops,
+            'params': params,
+            'latency': latency
+        })
+        return resource_metrics
+
+    def estimate_separation_modules(
+            self,
+            model: torch.nn.Module,
+            flops_params_cfg: dict = None) -> Dict[str, Union[float, str]]:
+        """Estimate FLOPs and params of the spec modules with separate return.

        Args:
            model: The measured model.
-            resource_args (Dict[str, float]): Args for resources estimation.
-                NOTE: resource_args have the same items() as the init cfgs.
+            flops_params_cfg (dict): Cfg for estimating FLOPs and parameters.
+                Default to None.

        Returns:
-            Dict[str, float]): A dict that containing resource results(flops,
-                params) of each modules in resource_args['spec_modules'].
+            Dict[str, Union[float, str]]): A dict that contains the FLOPs and
+                params results (string | float format) of each modules in the
+                ``flops_params_cfg['spec_modules']``.
        """
-        assert 'spec_modules' in resource_args, \
-            'spec_modules is required when calling estimate_spec_modules().'
+        if flops_params_cfg:
+            flops_params_cfg = {**self.flops_params_cfg, **flops_params_cfg}
+            self._check_flops_params_cfg(flops_params_cfg)
+            flops_params_cfg = self._set_default_resource_params(
+                flops_params_cfg)
+        else:
+            flops_params_cfg = self.flops_params_cfg
+        flops_params_cfg['seperate_return'] = True

-        resource_args.pop('measure_inference', False)
-        if 'input_shape' not in resource_args.keys():
-            resource_args['input_shape'] = self.default_shape
-        if 'disabled_counters' not in resource_args.keys():
-            resource_args['disabled_counters'] = self.disabled_counters
+        assert len(flops_params_cfg['spec_modules']), (
+            'spec_modules can not be empty when calling '
+            f'`estimate_separation_modules` of {self.__class__.__name__} ')

        model.eval()
-        spec_modules_resources = get_model_complexity_info(
-            model, **resource_args)
-
+        spec_modules_resources = get_model_flops_params(
+            model, **flops_params_cfg)
        return spec_modules_resources
+
+    def _check_flops_params_cfg(self, flops_params_cfg: dict) -> None:
+        """Check the legality of ``flops_params_cfg``.
+
+        Args:
+            flops_params_cfg (dict): Cfg for estimating FLOPs and parameters.
+        """
+        for key in flops_params_cfg:
+            if key not in get_model_flops_params.__code__.co_varnames[
+                    1:]:  # type: ignore
+                raise KeyError(f'Got invalid key `{key}` in flops_params_cfg.')
+
+    def _check_latency_cfg(self, latency_cfg: dict) -> None:
+        """Check the legality of ``latency_cfg``.
+
+        Args:
+            latency_cfg (dict): Cfg for estimating latency.
+        """
+        for key in latency_cfg:
+            if key not in get_model_latency.__code__.co_varnames[
+                    1:]:  # type: ignore
+                raise KeyError(f'Got invalid key `{key}` in latency_cfg.')
+
+    def _set_default_resource_params(self, cfg: dict) -> dict:
+        """Set default attributes for the input cfgs.
+
+        Args:
+            cfg (dict): flops_params_cfg or latency_cfg.
+        """
+        default_common_settings = ['input_shape', 'units', 'as_strings']
+        for key in default_common_settings:
+            if key not in cfg:
+                cfg[key] = getattr(self, key)
+        return cfg
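A hedged sketch of the new separate-return path, mirroring the unit tests below; the toy
model and its module name '0' are assumptions of this example:

import torch.nn as nn

from mmrazor.models.task_modules import ResourceEstimator

toy = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU())  # submodules are named '0' and '1'
estimator = ResourceEstimator()
per_module = estimator.estimate_separation_modules(
    model=toy, flops_params_cfg=dict(spec_modules=['0']))
# e.g. {'0': {'flops': <float or str>, 'params': <float or str>}}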
@@ -1,9 +1,14 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import torch
-from mmdet.models import BaseDetector

 from mmrazor.registry import TASK_UTILS

+try:
+    from mmdet.models.detectors import BaseDetector
+except ImportError:
+    from mmrazor.utils import get_placeholder
+    BaseDetector = get_placeholder('mmdet')
+

 # todo: adapt to mmdet 2.0
 @TASK_UTILS.register_module()
@@ -118,9 +118,9 @@ class TestResourceEstimator(TestCase):

     def test_estimate(self) -> None:
         fool_conv2d = FoolConv2d()
+        flops_params_cfg = dict(input_shape=(1, 3, 224, 224))
         results = estimator.estimate(
-            model=fool_conv2d,
-            resource_args=dict(input_shape=(1, 3, 224, 224)))
+            model=fool_conv2d, flops_params_cfg=flops_params_cfg)
         flops_count = results['flops']
         params_count = results['params']

@@ -129,9 +129,9 @@ class TestResourceEstimator(TestCase):

     def test_register_module(self) -> None:
         fool_add_constant = FoolConvModule()
+        flops_params_cfg = dict(input_shape=(1, 3, 224, 224))
         results = estimator.estimate(
-            model=fool_add_constant,
-            resource_args=dict(input_shape=(1, 3, 224, 224)))
+            model=fool_add_constant, flops_params_cfg=flops_params_cfg)
         flops_count = results['flops']
         params_count = results['params']

@@ -140,46 +140,65 @@ class TestResourceEstimator(TestCase):

     def test_disable_sepc_counter(self) -> None:
         fool_add_constant = FoolConvModule()
+        flops_params_cfg = dict(
+            input_shape=(1, 3, 224, 224),
+            disabled_counters=['FoolAddConstantCounter'])
         rest_results = estimator.estimate(
-            model=fool_add_constant,
-            resource_args=dict(
-                input_shape=(1, 3, 224, 224),
-                disabled_counters=['FoolAddConstantCounter']))
+            model=fool_add_constant, flops_params_cfg=flops_params_cfg)
         rest_flops_count = rest_results['flops']
         rest_params_count = rest_results['params']

         self.assertLess(rest_flops_count, 45.158)
         self.assertLess(rest_params_count, 0.701)

-    def test_estimate_spec_modules(self) -> None:
+    def test_estimate_spec_module(self) -> None:
         fool_add_constant = FoolConvModule()
-        results = estimator.estimate_spec_modules(
-            model=fool_add_constant,
-            resource_args=dict(
-                input_shape=(1, 3, 224, 224), spec_modules=['add_constant']))
+        flops_params_cfg = dict(
+            input_shape=(1, 3, 224, 224),
+            spec_modules=['add_constant', 'conv2d'])
+        results = estimator.estimate(
+            model=fool_add_constant, flops_params_cfg=flops_params_cfg)
+        flops_count = results['flops']
+        params_count = results['params']
+
+        self.assertEqual(flops_count, 45.158)
+        self.assertEqual(params_count, 0.701)
+
+    def test_estimate_separation_modules(self) -> None:
+        fool_add_constant = FoolConvModule()
+        flops_params_cfg = dict(
+            input_shape=(1, 3, 224, 224), spec_modules=['add_constant'])
+        results = estimator.estimate_separation_modules(
+            model=fool_add_constant, flops_params_cfg=flops_params_cfg)
         self.assertGreater(results['add_constant']['flops'], 0)

         with pytest.raises(AssertionError):
-            results = estimator.estimate_spec_modules(
-                model=fool_add_constant,
-                resource_args=dict(
-                    input_shape=(1, 3, 224, 224), spec_modules=['backbone']))
+            flops_params_cfg = dict(
+                input_shape=(1, 3, 224, 224), spec_modules=['backbone'])
+            results = estimator.estimate_separation_modules(
+                model=fool_add_constant, flops_params_cfg=flops_params_cfg)
+
+        with pytest.raises(AssertionError):
+            flops_params_cfg = dict(
+                input_shape=(1, 3, 224, 224), spec_modules=[])
+            results = estimator.estimate_separation_modules(
+                model=fool_add_constant, flops_params_cfg=flops_params_cfg)

     def test_estimate_subnet(self) -> None:
-        resource_args = dict(input_shape=(1, 3, 224, 224))
+        flops_params_cfg = dict(input_shape=(1, 3, 224, 224))
         model = MODELS.build(BACKBONE_CFG)
         self.sample_choice(model)
         copied_model = copy.deepcopy(model)

         results = estimator.estimate(
-            model=copied_model, resource_args=resource_args)
+            model=copied_model, flops_params_cfg=flops_params_cfg)
         flops_count = results['flops']
         params_count = results['params']

         fix_subnet = export_fix_subnet(model)
         load_fix_subnet(copied_model, fix_subnet)
         subnet_results = estimator.estimate(
-            model=copied_model, resource_args=resource_args)
+            model=copied_model, flops_params_cfg=flops_params_cfg)
         subnet_flops_count = subnet_results['flops']
         subnet_params_count = subnet_results['params']

@@ -188,8 +207,8 @@ class TestResourceEstimator(TestCase):

         # test whether subnet estimate will affect original model
         copied_model = copy.deepcopy(model)
-        results_after_estimate = \
-            estimator.estimate(model=copied_model, resource_args=resource_args)
+        results_after_estimate = estimator.estimate(
+            model=copied_model, flops_params_cfg=flops_params_cfg)
         flops_count_after_estimate = results_after_estimate['flops']
         params_count_after_estimate = results_after_estimate['params']

@@ -112,10 +112,7 @@ class TestEvolutionSearchLoop(TestCase):
         self.assertEqual(loop.candidates, fake_candidates)

     @patch('mmrazor.engine.runner.evolution_search_loop.export_fix_subnet')
-    @patch(
-        'mmrazor.engine.runner.evolution_search_loop.get_model_complexity_info'
-    )
-    def test_run_epoch(self, mock_flops, mock_export_fix_subnet):
+    def test_run_epoch(self, mock_export_fix_subnet):
         # test_run_epoch: distributed == False
         loop_cfg = copy.deepcopy(self.train_cfg)
         loop_cfg.runner = self.runner
@@ -155,7 +152,7 @@ class TestEvolutionSearchLoop(TestCase):
         self.runner.work_dir = self.temp_dir
         fake_subnet = {'1': 'choice1', '2': 'choice2'}
         loop.model.sample_subnet = MagicMock(return_value=fake_subnet)
-        mock_flops.return_value = (50., 1)
+        loop._check_constraints = MagicMock(return_value=True)
         mock_export_fix_subnet.return_value = fake_subnet
         loop.run_epoch()
         self.assertEqual(len(loop.candidates), 4)
@@ -192,30 +192,15 @@ class TestGreedySamplerTrainLoop(TestCase):
         self.assertEqual(subnet, fake_subnet)
         self.assertEqual(len(loop.top_k_candidates), loop.top_k - 1)

-    @patch('mmrazor.engine.runner.subnet_sampler_loop.export_fix_subnet')
-    @patch(
-        'mmrazor.engine.runner.subnet_sampler_loop.get_model_complexity_info')
-    def test_run(self, mock_flops, mock_export_fix_subnet):
-        # test run with flops_range=None
+    def test_run(self):
+        # test run with _check_constraints
         cfg = copy.deepcopy(self.iter_based_cfg)
         cfg.experiment_name = 'test_run1'
         runner = Runner.from_cfg(cfg)
         fake_subnet = {'1': 'choice1', '2': 'choice2'}
         runner.model.sample_subnet = MagicMock(return_value=fake_subnet)
-        runner.train()
-
-        self.assertEqual(runner.iter, runner.max_iters)
-        assert os.path.exists(os.path.join(self.temp_dir, 'candidates.pkl'))
-
-        # test run with _check_constraints
-        cfg = copy.deepcopy(self.iter_based_cfg)
-        cfg.experiment_name = 'test_run2'
-        cfg.train_cfg.flops_range = (0, 100)
-        runner = Runner.from_cfg(cfg)
-        fake_subnet = {'1': 'choice1', '2': 'choice2'}
-        runner.model.sample_subnet = MagicMock(return_value=fake_subnet)
-        mock_flops.return_value = (50., 1)
-        mock_export_fix_subnet.return_value = fake_subnet
+        loop = runner.build_train_loop(cfg.train_cfg)
+        loop._check_constraints = MagicMock(return_value=True)
         runner.train()

         self.assertEqual(runner.iter, runner.max_iters)
@@ -0,0 +1,40 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from unittest.mock import patch
+
+from mmrazor.engine.runner.utils import check_subnet_flops
+
+try:
+    from mmdet.models.detectors import BaseDetector
+except ImportError:
+    from mmrazor.utils import get_placeholder
+    BaseDetector = get_placeholder('mmdet')
+
+
+@patch('mmrazor.models.ResourceEstimator')
+@patch('mmrazor.models.SPOS')
+def test_check_subnet_flops(mock_model, mock_estimator):
+    # flops_range = None
+    flops_range = None
+    fake_subnet = {'1': 'choice1', '2': 'choice2'}
+    result = check_subnet_flops(mock_model, fake_subnet, mock_estimator,
+                                flops_range)
+    assert result is True
+
+    # flops_range is not None
+    # architecturte is BaseDetector
+    flops_range = (0., 100.)
+    mock_model.architecture = BaseDetector
+    fake_results = {'flops': 50.}
+    mock_estimator.estimate.return_value = fake_results
+    result = check_subnet_flops(mock_model, fake_subnet, mock_estimator,
+                                flops_range)
+    assert result is True
+
+    # flops_range is not None
+    # architecturte is BaseDetector
+    flops_range = (0., 100.)
+    fake_results = {'flops': -50.}
+    mock_estimator.estimate.return_value = fake_results
+    result = check_subnet_flops(mock_model, fake_subnet, mock_estimator,
+                                flops_range)
+    assert result is False