diff --git a/tests/test_codebase/test_mmcls/test_classification.py b/tests/test_codebase/test_mmcls/test_classification.py index 2b3291aa4..067009d49 100644 --- a/tests/test_codebase/test_mmcls/test_classification.py +++ b/tests/test_codebase/test_mmcls/test_classification.py @@ -98,7 +98,7 @@ def test_visualize(backend_model): results = backend_model.test_step([input_dict]) with TemporaryDirectory() as dir: filename = dir + '/tmp.jpg' - task_processor.visualize(img, results[0], filename, '') + task_processor.visualize(img, results[0], filename, 'window') assert os.path.exists(filename) diff --git a/tests/test_codebase/test_mmcls/test_mmcls_models.py b/tests/test_codebase/test_mmcls/test_mmcls_models.py index b43df9a05..64738d91c 100644 --- a/tests/test_codebase/test_mmcls/test_mmcls_models.py +++ b/tests/test_codebase/test_mmcls/test_mmcls_models.py @@ -5,6 +5,7 @@ import torch from mmengine import Config from mmdeploy.codebase import import_codebase +from mmdeploy.core.rewriters.rewriter_manager import RewriterContext from mmdeploy.utils import Backend, Codebase from mmdeploy.utils.test import WrapModel, check_backend, get_rewrite_outputs @@ -62,7 +63,6 @@ def test_baseclassifier_forward(): from mmcls.models.classifiers import ImageClassifier from mmdeploy.codebase.mmcls import models # noqa - from mmdeploy.core.rewriters import patch_model class DummyClassifier(ImageClassifier): @@ -75,8 +75,8 @@ def test_baseclassifier_forward(): def head(self, x): return x - def forward(self, batch_inputs, data_samples, mode): - return batch_inputs + 1 + def predict(self, x, data_samples): + return x backbone_cfg = dict( type='ResNet', @@ -86,11 +86,12 @@ def test_baseclassifier_forward(): style='pytorch') model = DummyClassifier(backbone_cfg).eval() - model_output = model(input, None, None) - model = patch_model(model, {}, bachend='onnxruntime', data_samples=None) - backend_output = model(input) + model_output = model(input, None, mode='predict') - assert model_output == input + 1 + with RewriterContext({}): + backend_output = model(input) + + assert model_output == input assert backend_output == input @@ -186,7 +187,7 @@ def test_vision_transformer_backbone__forward(backend_type: Backend): model_output.reshape(-1), rewrite_output.reshape(-1), rtol=1e-03, - atol=1e-03) + atol=1e-02) @pytest.mark.parametrize(