fix mmcls ut

This commit is contained in:
grimoire 2022-07-11 19:10:04 +08:00
parent ce036d547a
commit 27a856637c
2 changed files with 10 additions and 9 deletions

View File

@ -98,7 +98,7 @@ def test_visualize(backend_model):
results = backend_model.test_step([input_dict]) results = backend_model.test_step([input_dict])
with TemporaryDirectory() as dir: with TemporaryDirectory() as dir:
filename = dir + '/tmp.jpg' filename = dir + '/tmp.jpg'
task_processor.visualize(img, results[0], filename, '') task_processor.visualize(img, results[0], filename, 'window')
assert os.path.exists(filename) assert os.path.exists(filename)

View File

@ -5,6 +5,7 @@ import torch
from mmengine import Config from mmengine import Config
from mmdeploy.codebase import import_codebase from mmdeploy.codebase import import_codebase
from mmdeploy.core.rewriters.rewriter_manager import RewriterContext
from mmdeploy.utils import Backend, Codebase from mmdeploy.utils import Backend, Codebase
from mmdeploy.utils.test import WrapModel, check_backend, get_rewrite_outputs from mmdeploy.utils.test import WrapModel, check_backend, get_rewrite_outputs
@ -62,7 +63,6 @@ def test_baseclassifier_forward():
from mmcls.models.classifiers import ImageClassifier from mmcls.models.classifiers import ImageClassifier
from mmdeploy.codebase.mmcls import models # noqa from mmdeploy.codebase.mmcls import models # noqa
from mmdeploy.core.rewriters import patch_model
class DummyClassifier(ImageClassifier): class DummyClassifier(ImageClassifier):
@ -75,8 +75,8 @@ def test_baseclassifier_forward():
def head(self, x): def head(self, x):
return x return x
def forward(self, batch_inputs, data_samples, mode): def predict(self, x, data_samples):
return batch_inputs + 1 return x
backbone_cfg = dict( backbone_cfg = dict(
type='ResNet', type='ResNet',
@ -86,11 +86,12 @@ def test_baseclassifier_forward():
style='pytorch') style='pytorch')
model = DummyClassifier(backbone_cfg).eval() model = DummyClassifier(backbone_cfg).eval()
model_output = model(input, None, None) model_output = model(input, None, mode='predict')
model = patch_model(model, {}, bachend='onnxruntime', data_samples=None)
backend_output = model(input)
assert model_output == input + 1 with RewriterContext({}):
backend_output = model(input)
assert model_output == input
assert backend_output == input assert backend_output == input
@ -186,7 +187,7 @@ def test_vision_transformer_backbone__forward(backend_type: Backend):
model_output.reshape(-1), model_output.reshape(-1),
rewrite_output.reshape(-1), rewrite_output.reshape(-1),
rtol=1e-03, rtol=1e-03,
atol=1e-03) atol=1e-02)
@pytest.mark.parametrize( @pytest.mark.parametrize(