fix bug for estimating model latency in NAS

pull/568/head
李泽贤 2023-08-17 20:20:23 +08:00
parent 90c7af1fdf
commit c1a8013c3b
2 changed files with 5 additions and 3 deletions

View File

@ -35,9 +35,9 @@ def check_subnet_resources(
model_to_check = sliced_model.architecture # type: ignore
if isinstance(model_to_check, BaseDetector):
results = estimator.estimate(model=model_to_check.backbone)
results = estimator.estimate(model=model_to_check.backbone, measure_latency=True if 'latency' in constraints_range.keys() else False)
else:
results = estimator.estimate(model=model_to_check)
results = estimator.estimate(model=model_to_check, measure_latency=True if 'latency' in constraints_range.keys() else False)
for k, v in constraints_range.items():
if not isinstance(v, (list, tuple)):

View File

@ -95,7 +95,8 @@ class ResourceEstimator(BaseEstimator):
def estimate(self,
model: torch.nn.Module,
flops_params_cfg: dict = None,
latency_cfg: dict = None) -> Dict[str, Union[float, str]]:
latency_cfg: dict = None,
measure_latency: bool = False) -> Dict[str, Union[float, str]]:
"""Estimate the resources(flops/params/latency) of the given model.
This method will first parse the merged :attr:`self.flops_params_cfg`
@ -106,6 +107,7 @@ class ResourceEstimator(BaseEstimator):
flops_params_cfg (dict): Cfg for estimating FLOPs and parameters.
Default to None.
latency_cfg (dict): Cfg for estimating latency. Default to None.
measure_latency (bool): Measure latency or not. Default to False.
NOTE: If the `flops_params_cfg` and `latency_cfg` are both None,
this method will only estimate FLOPs/params with default settings.