mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
fix mmcls ut
This commit is contained in:
parent
ce036d547a
commit
27a856637c
@ -98,7 +98,7 @@ def test_visualize(backend_model):
|
|||||||
results = backend_model.test_step([input_dict])
|
results = backend_model.test_step([input_dict])
|
||||||
with TemporaryDirectory() as dir:
|
with TemporaryDirectory() as dir:
|
||||||
filename = dir + '/tmp.jpg'
|
filename = dir + '/tmp.jpg'
|
||||||
task_processor.visualize(img, results[0], filename, '')
|
task_processor.visualize(img, results[0], filename, 'window')
|
||||||
assert os.path.exists(filename)
|
assert os.path.exists(filename)
|
||||||
|
|
||||||
|
|
||||||
|
@ -5,6 +5,7 @@ import torch
|
|||||||
from mmengine import Config
|
from mmengine import Config
|
||||||
|
|
||||||
from mmdeploy.codebase import import_codebase
|
from mmdeploy.codebase import import_codebase
|
||||||
|
from mmdeploy.core.rewriters.rewriter_manager import RewriterContext
|
||||||
from mmdeploy.utils import Backend, Codebase
|
from mmdeploy.utils import Backend, Codebase
|
||||||
from mmdeploy.utils.test import WrapModel, check_backend, get_rewrite_outputs
|
from mmdeploy.utils.test import WrapModel, check_backend, get_rewrite_outputs
|
||||||
|
|
||||||
@ -62,7 +63,6 @@ def test_baseclassifier_forward():
|
|||||||
from mmcls.models.classifiers import ImageClassifier
|
from mmcls.models.classifiers import ImageClassifier
|
||||||
|
|
||||||
from mmdeploy.codebase.mmcls import models # noqa
|
from mmdeploy.codebase.mmcls import models # noqa
|
||||||
from mmdeploy.core.rewriters import patch_model
|
|
||||||
|
|
||||||
class DummyClassifier(ImageClassifier):
|
class DummyClassifier(ImageClassifier):
|
||||||
|
|
||||||
@ -75,8 +75,8 @@ def test_baseclassifier_forward():
|
|||||||
def head(self, x):
|
def head(self, x):
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def forward(self, batch_inputs, data_samples, mode):
|
def predict(self, x, data_samples):
|
||||||
return batch_inputs + 1
|
return x
|
||||||
|
|
||||||
backbone_cfg = dict(
|
backbone_cfg = dict(
|
||||||
type='ResNet',
|
type='ResNet',
|
||||||
@ -86,11 +86,12 @@ def test_baseclassifier_forward():
|
|||||||
style='pytorch')
|
style='pytorch')
|
||||||
model = DummyClassifier(backbone_cfg).eval()
|
model = DummyClassifier(backbone_cfg).eval()
|
||||||
|
|
||||||
model_output = model(input, None, None)
|
model_output = model(input, None, mode='predict')
|
||||||
model = patch_model(model, {}, bachend='onnxruntime', data_samples=None)
|
|
||||||
backend_output = model(input)
|
|
||||||
|
|
||||||
assert model_output == input + 1
|
with RewriterContext({}):
|
||||||
|
backend_output = model(input)
|
||||||
|
|
||||||
|
assert model_output == input
|
||||||
assert backend_output == input
|
assert backend_output == input
|
||||||
|
|
||||||
|
|
||||||
@ -186,7 +187,7 @@ def test_vision_transformer_backbone__forward(backend_type: Backend):
|
|||||||
model_output.reshape(-1),
|
model_output.reshape(-1),
|
||||||
rewrite_output.reshape(-1),
|
rewrite_output.reshape(-1),
|
||||||
rtol=1e-03,
|
rtol=1e-03,
|
||||||
atol=1e-03)
|
atol=1e-02)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user