diff --git a/mmrazor/engine/runner/utils/check.py b/mmrazor/engine/runner/utils/check.py index ad774f64..0b5885d8 100644 --- a/mmrazor/engine/runner/utils/check.py +++ b/mmrazor/engine/runner/utils/check.py @@ -35,9 +35,9 @@ def check_subnet_resources( model_to_check = sliced_model.architecture # type: ignore if isinstance(model_to_check, BaseDetector): - results = estimator.estimate(model=model_to_check.backbone) + results = estimator.estimate(model=model_to_check.backbone, measure_latency=True if 'latency' in constraints_range.keys() else False) else: - results = estimator.estimate(model=model_to_check) + results = estimator.estimate(model=model_to_check, measure_latency=True if 'latency' in constraints_range.keys() else False) for k, v in constraints_range.items(): if not isinstance(v, (list, tuple)): diff --git a/mmrazor/models/task_modules/estimators/resource_estimator.py b/mmrazor/models/task_modules/estimators/resource_estimator.py index ac5292d0..93e3fdb8 100644 --- a/mmrazor/models/task_modules/estimators/resource_estimator.py +++ b/mmrazor/models/task_modules/estimators/resource_estimator.py @@ -95,7 +95,8 @@ class ResourceEstimator(BaseEstimator): def estimate(self, model: torch.nn.Module, flops_params_cfg: dict = None, - latency_cfg: dict = None) -> Dict[str, Union[float, str]]: + latency_cfg: dict = None, + measure_latency: bool = False) -> Dict[str, Union[float, str]]: """Estimate the resources(flops/params/latency) of the given model. This method will first parse the merged :attr:`self.flops_params_cfg` @@ -106,6 +107,7 @@ class ResourceEstimator(BaseEstimator): flops_params_cfg (dict): Cfg for estimating FLOPs and parameters. Default to None. latency_cfg (dict): Cfg for estimating latency. Default to None. + measure_latency (bool): Measure latency or not. Default to False. NOTE: If the `flops_params_cfg` and `latency_cfg` are both None, this method will only estimate FLOPs/params with default settings.