mirror of
https://github.com/open-mmlab/mmrazor.git
synced 2025-06-03 15:02:54 +08:00
* update candidates * update subnet_sampler_loop * update candidate * add readme * rename variable * rename variable * clean * update * add doc string * Revert "[Improvement] Support for candidate multiple dimensional search constraints." * [Improvement] Update Candidate with multi-dim search constraints. (#322) * update doc * add support type * clean code * update candidates * clean * xx * set_resource -> set_score * fix ci bug * py36 lint * fix bug * fix check constrain * py36 ci * redesign candidate * fix pre-commit * update cfg * add build_resource_estimator * fix ci bug * remove runner.epoch in testcase * [Feature] Autoformer architecture and dynamicOPs (#327) * add DynamicSequential * dynamiclayernorm * add dynamic_pathchembed * add DynamicMultiheadAttention and DynamicRelativePosition2D * add channel-level dynamicOP * add autoformer algo * clean notes * adapt channel_mutator * vit fly * fix import * mutable init * remove annotation * add DynamicInputResizer * add unittest for mutables * add OneShotMutableChannelUnit_VIT * clean code * reset unit for vit * remove attr * add autoformer backbone UT * add valuemutator UT * clean code * add autoformer algo UT * update classifier UT * fix test error * ignore * make lint * update * fix lint * mutable_attrs * fix test * fix error * remove DynamicInputResizer * fix test ci * remove InputResizer * rename variables * modify type * Continued improvements of ChannelUnit * fix lint * fix lint * remove OneShotMutableChannelUnit * adjust derived type * combination mixins * clean code * fix sample subnet * search loop fly * more annotations * avoid counter warning and modify batch_augment cfg by gy * restore * source_value_mutables restriction * simply arch_setting api * update * clean * fix ut
45 lines
1.5 KiB
Python
45 lines
1.5 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
from unittest.mock import patch
|
|
|
|
from mmrazor.engine.runner.utils import check_subnet_resources
|
|
|
|
try:
|
|
from mmdet.models.detectors import BaseDetector
|
|
except ImportError:
|
|
from mmrazor.utils import get_placeholder
|
|
BaseDetector = get_placeholder('mmdet')
|
|
|
|
|
|
@patch('mmrazor.models.ResourceEstimator')
@patch('mmrazor.models.SPOS')
def test_check_subnet_resources(mock_model, mock_estimator):
    """Check ``check_subnet_resources`` against resource constraints.

    ``patch`` decorators are applied bottom-up, so ``mock_model`` is the mock
    standing in for ``mmrazor.models.SPOS`` and ``mock_estimator`` is the mock
    standing in for ``mmrazor.models.ResourceEstimator``.

    Cases covered:
        1. Empty ``constraints_range`` -> the subnet passes.
        2. Estimated flops inside ``(0, 330)`` -> the subnet passes.
        3. Estimated flops below the lower bound -> the subnet fails.
        4. Estimated flops above the upper bound -> the subnet fails.
    """
    fake_subnet = {'1': 'choice1', '2': 'choice2'}

    # Case 1: constraints_range is an empty dict (note: ``dict()``, not None).
    constraints_range = dict()
    is_pass, _ = check_subnet_resources(mock_model, fake_subnet,
                                        mock_estimator, constraints_range)
    assert is_pass is True

    # Case 2: a flops constraint is set and the estimate lies inside it.
    # NOTE(review): this assigns the BaseDetector *class*, not an instance.
    # If the detector path in check_subnet_resources is selected with
    # ``isinstance(architecture, BaseDetector)``, this mock will not take that
    # path — confirm the intended branch is actually exercised.
    constraints_range = dict(flops=(0, 330))
    mock_model.architecture = BaseDetector
    mock_estimator.estimate.return_value = {'flops': 50.}
    is_pass, _ = check_subnet_resources(mock_model, fake_subnet,
                                        mock_estimator, constraints_range)
    assert is_pass is True

    # Case 3: same constraint; the estimate (-50) is below the lower bound 0.
    mock_estimator.estimate.return_value = {'flops': -50.}
    is_pass, _ = check_subnet_resources(mock_model, fake_subnet,
                                        mock_estimator, constraints_range)
    assert is_pass is False

    # Case 4: same constraint; the estimate (500) is above the upper bound 330.
    mock_estimator.estimate.return_value = {'flops': 500.}
    is_pass, _ = check_subnet_resources(mock_model, fake_subnet,
                                        mock_estimator, constraints_range)
    assert is_pass is False
|