mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
fix mmcls ut
This commit is contained in:
parent
ce036d547a
commit
27a856637c
@ -98,7 +98,7 @@ def test_visualize(backend_model):
|
||||
results = backend_model.test_step([input_dict])
|
||||
with TemporaryDirectory() as dir:
|
||||
filename = dir + '/tmp.jpg'
|
||||
task_processor.visualize(img, results[0], filename, '')
|
||||
task_processor.visualize(img, results[0], filename, 'window')
|
||||
assert os.path.exists(filename)
|
||||
|
||||
|
||||
|
@ -5,6 +5,7 @@ import torch
|
||||
from mmengine import Config
|
||||
|
||||
from mmdeploy.codebase import import_codebase
|
||||
from mmdeploy.core.rewriters.rewriter_manager import RewriterContext
|
||||
from mmdeploy.utils import Backend, Codebase
|
||||
from mmdeploy.utils.test import WrapModel, check_backend, get_rewrite_outputs
|
||||
|
||||
@ -62,7 +63,6 @@ def test_baseclassifier_forward():
|
||||
from mmcls.models.classifiers import ImageClassifier
|
||||
|
||||
from mmdeploy.codebase.mmcls import models # noqa
|
||||
from mmdeploy.core.rewriters import patch_model
|
||||
|
||||
class DummyClassifier(ImageClassifier):
|
||||
|
||||
@ -75,8 +75,8 @@ def test_baseclassifier_forward():
|
||||
def head(self, x):
|
||||
return x
|
||||
|
||||
def forward(self, batch_inputs, data_samples, mode):
|
||||
return batch_inputs + 1
|
||||
def predict(self, x, data_samples):
|
||||
return x
|
||||
|
||||
backbone_cfg = dict(
|
||||
type='ResNet',
|
||||
@ -86,11 +86,12 @@ def test_baseclassifier_forward():
|
||||
style='pytorch')
|
||||
model = DummyClassifier(backbone_cfg).eval()
|
||||
|
||||
model_output = model(input, None, None)
|
||||
model = patch_model(model, {}, bachend='onnxruntime', data_samples=None)
|
||||
model_output = model(input, None, mode='predict')
|
||||
|
||||
with RewriterContext({}):
|
||||
backend_output = model(input)
|
||||
|
||||
assert model_output == input + 1
|
||||
assert model_output == input
|
||||
assert backend_output == input
|
||||
|
||||
|
||||
@ -186,7 +187,7 @@ def test_vision_transformer_backbone__forward(backend_type: Backend):
|
||||
model_output.reshape(-1),
|
||||
rewrite_output.reshape(-1),
|
||||
rtol=1e-03,
|
||||
atol=1e-03)
|
||||
atol=1e-02)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
Loading…
x
Reference in New Issue
Block a user