2022-09-14 20:39:49 +08:00
|
|
|
# Copyright (c) OpenMMLab. All rights reserved.
|
|
|
|
from unittest.mock import patch
|
|
|
|
|
2022-11-14 13:01:04 +08:00
|
|
|
from mmrazor.engine.runner.utils import check_subnet_resources
|
2022-09-14 20:39:49 +08:00
|
|
|
|
|
|
|
# mmdet is treated as an optional dependency: use its real ``BaseDetector``
# when the package can be imported.
try:
    from mmdet.models.detectors import BaseDetector
except ImportError:
    # Without mmdet, keep the name ``BaseDetector`` defined by binding it to
    # whatever ``get_placeholder('mmdet')`` returns, so this module can still
    # be imported.
    from mmrazor.utils import get_placeholder

    BaseDetector = get_placeholder('mmdet')
|
|
|
|
|
|
|
|
|
|
|
|
@patch('mmrazor.models.ResourceEstimator')
@patch('mmrazor.models.SPOS')
def test_check_subnet_resources(mock_model, mock_estimator):
    """Check ``check_subnet_resources`` under three constraint scenarios.

    Stacked ``patch`` decorators inject their mocks bottom-up, so
    ``mock_model`` is the patched ``mmrazor.models.SPOS`` and
    ``mock_estimator`` is the patched ``mmrazor.models.ResourceEstimator``.

    Scenarios covered:
        1. Empty ``constraints_range``: the subnet is expected to pass.
        2. Estimated flops inside the allowed range: expected to pass.
        3. Estimated flops outside the allowed range: expected to fail.
    """
    subnet = {'1': 'choice1', '2': 'choice2'}

    def subnet_passes(constraints):
        # Only the pass/fail flag matters here; the second element of the
        # returned pair is ignored by every scenario.
        passed, _ = check_subnet_resources(mock_model, subnet,
                                           mock_estimator, constraints)
        return passed

    # Scenario 1: constraints_range is an empty dict.
    assert subnet_passes(dict()) is True

    # Scenario 2: constraints_range is non-empty and the model's
    # ``architecture`` attribute is set to ``BaseDetector`` (the name bound
    # at module level). The mocked estimate of 50 flops lies within (0, 330).
    mock_model.architecture = BaseDetector
    mock_estimator.estimate.return_value = {'flops': 50.}
    assert subnet_passes(dict(flops=(0, 330))) is True

    # Scenario 3: same constraints, with ``architecture`` still set from
    # scenario 2. The mocked estimate of -50 flops lies outside (0, 330).
    mock_estimator.estimate.return_value = {'flops': -50.}
    assert subnet_passes(dict(flops=(0, 330))) is False
|