fix bug for estimating model latency in NAS
parent
90c7af1fdf
commit
c1a8013c3b
|
@ -35,9 +35,9 @@ def check_subnet_resources(
|
|||
|
||||
model_to_check = sliced_model.architecture # type: ignore
|
||||
if isinstance(model_to_check, BaseDetector):
|
||||
results = estimator.estimate(model=model_to_check.backbone)
|
||||
results = estimator.estimate(model=model_to_check.backbone, measure_latency=True if 'latency' in constraints_range.keys() else False)
|
||||
else:
|
||||
results = estimator.estimate(model=model_to_check)
|
||||
results = estimator.estimate(model=model_to_check, measure_latency=True if 'latency' in constraints_range.keys() else False)
|
||||
|
||||
for k, v in constraints_range.items():
|
||||
if not isinstance(v, (list, tuple)):
|
||||
|
|
|
@ -95,7 +95,8 @@ class ResourceEstimator(BaseEstimator):
|
|||
def estimate(self,
|
||||
model: torch.nn.Module,
|
||||
flops_params_cfg: dict = None,
|
||||
latency_cfg: dict = None) -> Dict[str, Union[float, str]]:
|
||||
latency_cfg: dict = None,
|
||||
measure_latency: bool = False) -> Dict[str, Union[float, str]]:
|
||||
"""Estimate the resources(flops/params/latency) of the given model.
|
||||
|
||||
This method will first parse the merged :attr:`self.flops_params_cfg`
|
||||
|
@ -106,6 +107,7 @@ class ResourceEstimator(BaseEstimator):
|
|||
flops_params_cfg (dict): Cfg for estimating FLOPs and parameters.
|
||||
Default to None.
|
||||
latency_cfg (dict): Cfg for estimating latency. Default to None.
|
||||
measure_latency (bool): Measure latency or not. Default to False.
|
||||
|
||||
NOTE: If the `flops_params_cfg` and `latency_cfg` are both None,
|
||||
this method will only estimate FLOPs/params with default settings.
|
||||
|
|
Loading…
Reference in New Issue