rename method

This commit is contained in:
grimoire 2022-06-29 14:14:03 +08:00
parent 41a97746cf
commit 8c37f79f06
25 changed files with 111 additions and 111 deletions

View File

@ -57,7 +57,7 @@ def create_calib_input_data(calib_file: str,
apply_marks = cfg_apply_marks(deploy_cfg) apply_marks = cfg_apply_marks(deploy_cfg)
model = task_processor.init_pytorch_model(model_checkpoint) model = task_processor.build_pytorch_model(model_checkpoint)
dataset = task_processor.build_dataset(dataset_cfg, dataset_type) dataset = task_processor.build_dataset(dataset_cfg, dataset_type)
# patch model # patch model

View File

@ -42,7 +42,7 @@ def inference_model(model_cfg: Union[str, mmcv.Config],
from mmdeploy.apis.utils import build_task_processor from mmdeploy.apis.utils import build_task_processor
task_processor = build_task_processor(model_cfg, deploy_cfg, device) task_processor = build_task_processor(model_cfg, deploy_cfg, device)
model = task_processor.init_backend_model(backend_files) model = task_processor.build_backend_model(backend_files)
input_shape = get_input_shape(deploy_cfg) input_shape = get_input_shape(deploy_cfg)
model_inputs, _ = task_processor.create_input(img, input_shape) model_inputs, _ = task_processor.create_input(img, input_shape)

View File

@ -59,7 +59,7 @@ def torch2onnx(img: Any,
from mmdeploy.apis import build_task_processor from mmdeploy.apis import build_task_processor
task_processor = build_task_processor(model_cfg, deploy_cfg, device) task_processor = build_task_processor(model_cfg, deploy_cfg, device)
torch_model = task_processor.init_pytorch_model(model_checkpoint) torch_model = task_processor.build_pytorch_model(model_checkpoint)
data, model_inputs = task_processor.create_input( data, model_inputs = task_processor.create_input(
img, img,
input_shape, input_shape,

View File

@ -41,7 +41,7 @@ def torch2torchscript(img: Any,
from mmdeploy.apis import build_task_processor from mmdeploy.apis import build_task_processor
task_processor = build_task_processor(model_cfg, deploy_cfg, device) task_processor = build_task_processor(model_cfg, deploy_cfg, device)
torch_model = task_processor.init_pytorch_model(model_checkpoint) torch_model = task_processor.build_pytorch_model(model_checkpoint)
_, model_inputs = task_processor.create_input(img, input_shape) _, model_inputs = task_processor.create_input(img, input_shape)
if not isinstance(model_inputs, torch.Tensor): if not isinstance(model_inputs, torch.Tensor):
model_inputs = model_inputs[0] model_inputs = model_inputs[0]

View File

@ -62,9 +62,9 @@ def visualize_model(model_cfg: Union[str, mmcv.Config],
list should be str' list should be str'
if backend == Backend.PYTORCH: if backend == Backend.PYTORCH:
model = task_processor.init_pytorch_model(model[0]) model = task_processor.build_pytorch_model(model[0])
else: else:
model = task_processor.init_backend_model(model) model = task_processor.build_backend_model(model)
model_inputs, _ = task_processor.create_input(img, input_shape) model_inputs, _ = task_processor.create_input(img, input_shape)
with torch.no_grad(): with torch.no_grad():

View File

@ -54,7 +54,7 @@ class BaseTask(metaclass=ABCMeta):
self.visualizer = self.model_cfg.visualizer self.visualizer = self.model_cfg.visualizer
@abstractmethod @abstractmethod
def init_backend_model(self, def build_backend_model(self,
model_files: Sequence[str] = None, model_files: Sequence[str] = None,
**kwargs) -> torch.nn.Module: **kwargs) -> torch.nn.Module:
"""Initialize backend model. """Initialize backend model.
@ -67,7 +67,7 @@ class BaseTask(metaclass=ABCMeta):
""" """
pass pass
def init_pytorch_model(self, def build_pytorch_model(self,
model_checkpoint: Optional[str] = None, model_checkpoint: Optional[str] = None,
cfg_options: Optional[Dict] = None, cfg_options: Optional[Dict] = None,
**kwargs) -> torch.nn.Module: **kwargs) -> torch.nn.Module:

View File

@ -105,7 +105,7 @@ class Classification(BaseTask):
super(Classification, self).__init__(model_cfg, deploy_cfg, device, super(Classification, self).__init__(model_cfg, deploy_cfg, device,
experiment_name) experiment_name)
def init_backend_model(self, def build_backend_model(self,
model_files: Sequence[str] = None, model_files: Sequence[str] = None,
**kwargs) -> torch.nn.Module: **kwargs) -> torch.nn.Module:
"""Initialize backend model. """Initialize backend model.

View File

@ -58,7 +58,7 @@ class ObjectDetection(BaseTask):
device: str) -> None: device: str) -> None:
super().__init__(model_cfg, deploy_cfg, device) super().__init__(model_cfg, deploy_cfg, device)
def init_backend_model(self, def build_backend_model(self,
model_files: Optional[str] = None, model_files: Optional[str] = None,
**kwargs) -> torch.nn.Module: **kwargs) -> torch.nn.Module:
"""Initialize backend model. """Initialize backend model.
@ -74,7 +74,7 @@ class ObjectDetection(BaseTask):
model_files, self.model_cfg, self.deploy_cfg, device=self.device) model_files, self.model_cfg, self.deploy_cfg, device=self.device)
return model.eval() return model.eval()
def init_pytorch_model(self, def build_pytorch_model(self,
model_checkpoint: Optional[str] = None, model_checkpoint: Optional[str] = None,
cfg_options: Optional[Dict] = None, cfg_options: Optional[Dict] = None,
**kwargs) -> torch.nn.Module: **kwargs) -> torch.nn.Module:

View File

@ -23,7 +23,7 @@ class VoxelDetection(BaseTask):
device: str): device: str):
super().__init__(model_cfg, deploy_cfg, device) super().__init__(model_cfg, deploy_cfg, device)
def init_backend_model(self, def build_backend_model(self,
model_files: Sequence[str] = None, model_files: Sequence[str] = None,
**kwargs) -> torch.nn.Module: **kwargs) -> torch.nn.Module:
"""Initialize backend model. """Initialize backend model.
@ -39,7 +39,7 @@ class VoxelDetection(BaseTask):
model_files, self.model_cfg, self.deploy_cfg, device=self.device) model_files, self.model_cfg, self.deploy_cfg, device=self.device)
return model return model
def init_pytorch_model(self, def build_pytorch_model(self,
model_checkpoint: Optional[str] = None, model_checkpoint: Optional[str] = None,
cfg_options: Optional[Dict] = None, cfg_options: Optional[Dict] = None,
**kwargs) -> torch.nn.Module: **kwargs) -> torch.nn.Module:

View File

@ -76,7 +76,7 @@ class SuperResolution(BaseTask):
device: str): device: str):
super().__init__(model_cfg, deploy_cfg, device) super().__init__(model_cfg, deploy_cfg, device)
def init_backend_model(self, def build_backend_model(self,
model_files: Sequence[str] = None, model_files: Sequence[str] = None,
**kwargs) -> torch.nn.Module: **kwargs) -> torch.nn.Module:
"""Initialize backend model. """Initialize backend model.
@ -92,7 +92,7 @@ class SuperResolution(BaseTask):
model_files, self.model_cfg, self.deploy_cfg, device=self.device) model_files, self.model_cfg, self.deploy_cfg, device=self.device)
return model return model
def init_pytorch_model(self, def build_pytorch_model(self,
model_checkpoint: Optional[str] = None, model_checkpoint: Optional[str] = None,
**kwargs) -> torch.nn.Module: **kwargs) -> torch.nn.Module:
"""Initialize torch model. """Initialize torch model.

View File

@ -63,7 +63,7 @@ class TextDetection(BaseTask):
device: str): device: str):
super(TextDetection, self).__init__(model_cfg, deploy_cfg, device) super(TextDetection, self).__init__(model_cfg, deploy_cfg, device)
def init_backend_model(self, def build_backend_model(self,
model_files: Optional[str] = None, model_files: Optional[str] = None,
**kwargs) -> torch.nn.Module: **kwargs) -> torch.nn.Module:
"""Initialize backend model. """Initialize backend model.
@ -79,7 +79,7 @@ class TextDetection(BaseTask):
model_files, self.model_cfg, self.deploy_cfg, device=self.device) model_files, self.model_cfg, self.deploy_cfg, device=self.device)
return model.eval() return model.eval()
def init_pytorch_model(self, def build_pytorch_model(self,
model_checkpoint: Optional[str] = None, model_checkpoint: Optional[str] = None,
cfg_options: Optional[Dict] = None, cfg_options: Optional[Dict] = None,
**kwargs) -> torch.nn.Module: **kwargs) -> torch.nn.Module:

View File

@ -75,7 +75,7 @@ class TextRecognition(BaseTask):
device: str): device: str):
super(TextRecognition, self).__init__(model_cfg, deploy_cfg, device) super(TextRecognition, self).__init__(model_cfg, deploy_cfg, device)
def init_backend_model(self, def build_backend_model(self,
model_files: Optional[str] = None, model_files: Optional[str] = None,
**kwargs) -> torch.nn.Module: **kwargs) -> torch.nn.Module:
"""Initialize backend model. """Initialize backend model.
@ -91,7 +91,7 @@ class TextRecognition(BaseTask):
model_files, self.model_cfg, self.deploy_cfg, device=self.device) model_files, self.model_cfg, self.deploy_cfg, device=self.device)
return model.eval() return model.eval()
def init_pytorch_model(self, def build_pytorch_model(self,
model_checkpoint: Optional[str] = None, model_checkpoint: Optional[str] = None,
cfg_options: Optional[Dict] = None, cfg_options: Optional[Dict] = None,
**kwargs) -> torch.nn.Module: **kwargs) -> torch.nn.Module:

View File

@ -86,7 +86,7 @@ class PoseDetection(BaseTask):
device: str): device: str):
super().__init__(model_cfg, deploy_cfg, device) super().__init__(model_cfg, deploy_cfg, device)
def init_backend_model(self, def build_backend_model(self,
model_files: Sequence[str] = None, model_files: Sequence[str] = None,
**kwargs) -> torch.nn.Module: **kwargs) -> torch.nn.Module:
"""Initialize backend model. """Initialize backend model.
@ -102,7 +102,7 @@ class PoseDetection(BaseTask):
model_files, self.model_cfg, self.deploy_cfg, device=self.device) model_files, self.model_cfg, self.deploy_cfg, device=self.device)
return model.eval() return model.eval()
def init_pytorch_model(self, def build_pytorch_model(self,
model_checkpoint: Optional[str] = None, model_checkpoint: Optional[str] = None,
**kwargs) -> torch.nn.Module: **kwargs) -> torch.nn.Module:
"""Initialize torch model. """Initialize torch model.

View File

@ -85,7 +85,7 @@ class RotatedDetection(BaseTask):
device: str): device: str):
super(RotatedDetection, self).__init__(model_cfg, deploy_cfg, device) super(RotatedDetection, self).__init__(model_cfg, deploy_cfg, device)
def init_backend_model(self, def build_backend_model(self,
model_files: Optional[str] = None, model_files: Optional[str] = None,
**kwargs) -> torch.nn.Module: **kwargs) -> torch.nn.Module:
"""Initialize backend model. """Initialize backend model.
@ -101,7 +101,7 @@ class RotatedDetection(BaseTask):
model_files, self.model_cfg, self.deploy_cfg, device=self.device) model_files, self.model_cfg, self.deploy_cfg, device=self.device)
return model.eval() return model.eval()
def init_pytorch_model(self, def build_pytorch_model(self,
model_checkpoint: Optional[str] = None, model_checkpoint: Optional[str] = None,
cfg_options: Optional[Dict] = None, cfg_options: Optional[Dict] = None,
**kwargs) -> torch.nn.Module: **kwargs) -> torch.nn.Module:

View File

@ -57,7 +57,7 @@ class Segmentation(BaseTask):
device: str): device: str):
super(Segmentation, self).__init__(model_cfg, deploy_cfg, device) super(Segmentation, self).__init__(model_cfg, deploy_cfg, device)
def init_backend_model(self, def build_backend_model(self,
model_files: Optional[str] = None, model_files: Optional[str] = None,
**kwargs) -> torch.nn.Module: **kwargs) -> torch.nn.Module:
"""Initialize backend model. """Initialize backend model.
@ -73,7 +73,7 @@ class Segmentation(BaseTask):
model_files, self.model_cfg, self.deploy_cfg, device=self.device) model_files, self.model_cfg, self.deploy_cfg, device=self.device)
return model.eval() return model.eval()
def init_pytorch_model(self, def build_pytorch_model(self,
model_checkpoint: Optional[str] = None, model_checkpoint: Optional[str] = None,
cfg_options: Optional[Dict] = None, cfg_options: Optional[Dict] = None,
**kwargs) -> torch.nn.Module: **kwargs) -> torch.nn.Module:

View File

@ -40,7 +40,7 @@ img = np.random.rand(*img_shape, 3)
@pytest.mark.parametrize('from_mmrazor', [True, False, '123', 0]) @pytest.mark.parametrize('from_mmrazor', [True, False, '123', 0])
def test_init_pytorch_model(from_mmrazor: Any): def test_build_pytorch_model(from_mmrazor: Any):
from mmcls.models.classifiers.base import BaseClassifier from mmcls.models.classifiers.base import BaseClassifier
if from_mmrazor is False: if from_mmrazor is False:
_task_processor = task_processor _task_processor = task_processor
@ -66,7 +66,7 @@ def test_init_pytorch_model(from_mmrazor: Any):
assert from_mmrazor == _task_processor.from_mmrazor assert from_mmrazor == _task_processor.from_mmrazor
if from_mmrazor: if from_mmrazor:
pytest.importorskip('mmrazor', reason='mmrazor is not installed.') pytest.importorskip('mmrazor', reason='mmrazor is not installed.')
model = _task_processor.init_pytorch_model(None) model = _task_processor.build_pytorch_model(None)
assert isinstance(model, BaseClassifier) assert isinstance(model, BaseClassifier)
@ -79,12 +79,12 @@ def backend_model():
'output': torch.rand(1, num_classes), 'output': torch.rand(1, num_classes),
}) })
yield task_processor.init_backend_model(['']) yield task_processor.build_backend_model([''])
wrapper.recover() wrapper.recover()
def test_init_backend_model(backend_model): def test_build_backend_model(backend_model):
assert isinstance(backend_model, torch.nn.Module) assert isinstance(backend_model, torch.nn.Module)

View File

@ -51,7 +51,7 @@ img = np.random.rand(*img_shape, 3)
@pytest.mark.parametrize('from_mmrazor', [True, False, '123', 0]) @pytest.mark.parametrize('from_mmrazor', [True, False, '123', 0])
def test_init_pytorch_model(from_mmrazor: Any): def test_build_pytorch_model(from_mmrazor: Any):
from mmdet.models import BaseDetector from mmdet.models import BaseDetector
if from_mmrazor is False: if from_mmrazor is False:
_task_processor = task_processor _task_processor = task_processor
@ -77,7 +77,7 @@ def test_init_pytorch_model(from_mmrazor: Any):
assert from_mmrazor == _task_processor.from_mmrazor assert from_mmrazor == _task_processor.from_mmrazor
if from_mmrazor: if from_mmrazor:
pytest.importorskip('mmrazor', reason='mmrazor is not installed.') pytest.importorskip('mmrazor', reason='mmrazor is not installed.')
model = _task_processor.init_pytorch_model(None) model = _task_processor.build_pytorch_model(None)
assert isinstance(model, BaseDetector) assert isinstance(model, BaseDetector)
@ -91,12 +91,12 @@ def backend_model():
'labels': torch.rand(1, 10) 'labels': torch.rand(1, 10)
}) })
yield task_processor.init_backend_model(['']) yield task_processor.build_backend_model([''])
wrapper.recover() wrapper.recover()
def test_init_backend_model(backend_model): def test_build_backend_model(backend_model):
from mmdeploy.codebase.mmdet.deploy.object_detection_model import \ from mmdeploy.codebase.mmdet.deploy.object_detection_model import \
End2EndModel End2EndModel
assert isinstance(backend_model, End2EndModel) assert isinstance(backend_model, End2EndModel)
@ -131,7 +131,7 @@ def test_create_input(device):
def test_run_inference(backend_model): def test_run_inference(backend_model):
torch_model = task_processor.init_pytorch_model(None) torch_model = task_processor.build_pytorch_model(None)
input_dict, _ = task_processor.create_input(img, input_shape=img_shape) input_dict, _ = task_processor.create_input(img, input_shape=img_shape)
torch_results = task_processor.run_inference(torch_model, input_dict) torch_results = task_processor.run_inference(torch_model, input_dict)
backend_results = task_processor.run_inference(backend_model, input_dict) backend_results = task_processor.run_inference(backend_model, input_dict)

View File

@ -39,9 +39,9 @@ onnx_file = NamedTemporaryFile(suffix='.onnx').name
task_processor = build_task_processor(model_cfg, deploy_cfg, 'cpu') task_processor = build_task_processor(model_cfg, deploy_cfg, 'cpu')
def test_init_pytorch_model(): def test_build_pytorch_model():
from mmdet3d.models import Base3DDetector from mmdet3d.models import Base3DDetector
model = task_processor.init_pytorch_model(None) model = task_processor.build_pytorch_model(None)
assert isinstance(model, Base3DDetector) assert isinstance(model, Base3DDetector)
@ -57,12 +57,12 @@ def backend_model():
'dir_scores': torch.rand(1, 12, 32, 32) 'dir_scores': torch.rand(1, 12, 32, 32)
}) })
yield task_processor.init_backend_model(['']) yield task_processor.build_backend_model([''])
wrapper.recover() wrapper.recover()
def test_init_backend_model(backend_model): def test_build_backend_model(backend_model):
from mmdeploy.codebase.mmdet3d.deploy.voxel_detection_model import \ from mmdeploy.codebase.mmdet3d.deploy.voxel_detection_model import \
VoxelDetectionModel VoxelDetectionModel
assert isinstance(backend_model, VoxelDetectionModel) assert isinstance(backend_model, VoxelDetectionModel)
@ -83,7 +83,7 @@ def test_create_input(device):
reason='Only support GPU test', condition=not torch.cuda.is_available()) reason='Only support GPU test', condition=not torch.cuda.is_available())
def test_run_inference(backend_model): def test_run_inference(backend_model):
task_processor.device = 'cuda:0' task_processor.device = 'cuda:0'
torch_model = task_processor.init_pytorch_model(None) torch_model = task_processor.build_pytorch_model(None)
input_dict, _ = task_processor.create_input(pcd_path) input_dict, _ = task_processor.create_input(pcd_path)
torch_results = task_processor.run_inference(torch_model, input_dict) torch_results = task_processor.run_inference(torch_model, input_dict)
backend_results = task_processor.run_inference(backend_model, input_dict) backend_results = task_processor.run_inference(backend_model, input_dict)
@ -98,7 +98,7 @@ def test_run_inference(backend_model):
def test_visualize(): def test_visualize():
task_processor.device = 'cuda:0' task_processor.device = 'cuda:0'
input_dict, _ = task_processor.create_input(pcd_path) input_dict, _ = task_processor.create_input(pcd_path)
torch_model = task_processor.init_pytorch_model(None) torch_model = task_processor.build_pytorch_model(None)
results = task_processor.run_inference(torch_model, input_dict) results = task_processor.run_inference(torch_model, input_dict)
with TemporaryDirectory() as dir: with TemporaryDirectory() as dir:
filename = dir + 'tmp.bin' filename = dir + 'tmp.bin'

View File

@ -37,8 +37,8 @@ onnx_file = NamedTemporaryFile(suffix='.onnx').name
task_processor = build_task_processor(model_cfg, deploy_cfg, 'cpu') task_processor = build_task_processor(model_cfg, deploy_cfg, 'cpu')
def test_init_pytorch_model(): def test_build_pytorch_model():
torch_model = task_processor.init_pytorch_model(None) torch_model = task_processor.build_pytorch_model(None)
assert torch_model is not None assert torch_model is not None
@ -51,12 +51,12 @@ def backend_model():
'output': torch.rand(3, 50, 50), 'output': torch.rand(3, 50, 50),
}) })
yield task_processor.init_backend_model(['']) yield task_processor.build_backend_model([''])
wrapper.recover() wrapper.recover()
def test_init_backend_model(backend_model): def test_build_backend_model(backend_model):
assert backend_model is not None assert backend_model is not None

View File

@ -37,10 +37,10 @@ img_shape = (32, 32)
img = np.random.rand(*img_shape, 3).astype(np.uint8) img = np.random.rand(*img_shape, 3).astype(np.uint8)
def test_init_pytorch_model(): def test_build_pytorch_model():
from mmocr.models.textdet.detectors.single_stage_text_detector import \ from mmocr.models.textdet.detectors.single_stage_text_detector import \
SingleStageDetector SingleStageDetector
model = task_processor.init_pytorch_model(None) model = task_processor.build_pytorch_model(None)
assert isinstance(model, SingleStageDetector) assert isinstance(model, SingleStageDetector)
@ -53,12 +53,12 @@ def backend_model():
'output': torch.rand(1, 3, *img_shape), 'output': torch.rand(1, 3, *img_shape),
}) })
yield task_processor.init_backend_model(['']) yield task_processor.build_backend_model([''])
wrapper.recover() wrapper.recover()
def test_init_backend_model(backend_model): def test_build_backend_model(backend_model):
assert isinstance(backend_model, torch.nn.Module) assert isinstance(backend_model, torch.nn.Module)

View File

@ -37,9 +37,9 @@ img_shape = (32, 32)
img = np.random.rand(*img_shape, 3).astype(np.uint8) img = np.random.rand(*img_shape, 3).astype(np.uint8)
def test_init_pytorch_model(): def test_build_pytorch_model():
from mmocr.models.textrecog.recognizer import BaseRecognizer from mmocr.models.textrecog.recognizer import BaseRecognizer
model = task_processor.init_pytorch_model(None) model = task_processor.build_pytorch_model(None)
assert isinstance(model, BaseRecognizer) assert isinstance(model, BaseRecognizer)
@ -52,12 +52,12 @@ def backend_model():
'output': torch.rand(1, 9, 37), 'output': torch.rand(1, 9, 37),
}) })
yield task_processor.init_backend_model(['']) yield task_processor.build_backend_model([''])
wrapper.recover() wrapper.recover()
def test_init_backend_model(backend_model): def test_build_backend_model(backend_model):
assert isinstance(backend_model, torch.nn.Module) assert isinstance(backend_model, torch.nn.Module)

View File

@ -65,9 +65,9 @@ def test_create_input():
assert isinstance(inputs, tuple) and len(inputs) == 2 assert isinstance(inputs, tuple) and len(inputs) == 2
def test_init_pytorch_model(): def test_build_pytorch_model():
from mmpose.models.detectors.base import BasePose from mmpose.models.detectors.base import BasePose
model = task_processor.init_pytorch_model(None) model = task_processor.build_pytorch_model(None)
assert isinstance(model, BasePose) assert isinstance(model, BasePose)
@ -80,12 +80,12 @@ def backend_model():
'output': torch.rand(1, num_output_channels, *heatmap_shape), 'output': torch.rand(1, num_output_channels, *heatmap_shape),
}) })
yield task_processor.init_backend_model(['']) yield task_processor.build_backend_model([''])
wrapper.recover() wrapper.recover()
def test_init_backend_model(backend_model): def test_build_backend_model(backend_model):
assert isinstance(backend_model, torch.nn.Module) assert isinstance(backend_model, torch.nn.Module)

View File

@ -48,9 +48,9 @@ img_shape = (32, 32)
img = np.random.rand(*img_shape, 3) img = np.random.rand(*img_shape, 3)
def test_init_pytorch_model(): def test_build_pytorch_model():
from mmrotate.models import RotatedBaseDetector from mmrotate.models import RotatedBaseDetector
model = task_processor.init_pytorch_model(None) model = task_processor.build_pytorch_model(None)
assert isinstance(model, RotatedBaseDetector) assert isinstance(model, RotatedBaseDetector)
@ -64,12 +64,12 @@ def backend_model():
'labels': torch.rand(1, 10) 'labels': torch.rand(1, 10)
}) })
yield task_processor.init_backend_model(['']) yield task_processor.build_backend_model([''])
wrapper.recover() wrapper.recover()
def test_init_backend_model(backend_model): def test_build_backend_model(backend_model):
from mmdeploy.codebase.mmrotate.deploy.rotated_detection_model import \ from mmdeploy.codebase.mmrotate.deploy.rotated_detection_model import \
End2EndModel End2EndModel
assert isinstance(backend_model, End2EndModel) assert isinstance(backend_model, End2EndModel)
@ -85,7 +85,7 @@ def test_create_input(device):
def test_run_inference(backend_model): def test_run_inference(backend_model):
torch_model = task_processor.init_pytorch_model(None) torch_model = task_processor.build_pytorch_model(None)
input_dict, _ = task_processor.create_input(img, input_shape=img_shape) input_dict, _ = task_processor.create_input(img, input_shape=img_shape)
torch_results = task_processor.run_inference(torch_model, input_dict) torch_results = task_processor.run_inference(torch_model, input_dict)
backend_results = task_processor.run_inference(backend_model, input_dict) backend_results = task_processor.run_inference(backend_model, input_dict)

View File

@ -40,7 +40,7 @@ img = np.random.rand(*img_shape, 3)
@pytest.mark.parametrize('from_mmrazor', [True, False, '123', 0]) @pytest.mark.parametrize('from_mmrazor', [True, False, '123', 0])
def test_init_pytorch_model(from_mmrazor: Any): def test_build_pytorch_model(from_mmrazor: Any):
from mmseg.models.segmentors.base import BaseSegmentor from mmseg.models.segmentors.base import BaseSegmentor
if from_mmrazor is False: if from_mmrazor is False:
_task_processor = task_processor _task_processor = task_processor
@ -65,7 +65,7 @@ def test_init_pytorch_model(from_mmrazor: Any):
assert from_mmrazor == _task_processor.from_mmrazor assert from_mmrazor == _task_processor.from_mmrazor
if from_mmrazor: if from_mmrazor:
pytest.importorskip('mmrazor', reason='mmrazor is not installed.') pytest.importorskip('mmrazor', reason='mmrazor is not installed.')
model = _task_processor.init_pytorch_model(None) model = _task_processor.build_pytorch_model(None)
assert isinstance(model, BaseSegmentor) assert isinstance(model, BaseSegmentor)
@ -78,12 +78,12 @@ def backend_model():
'output': torch.rand(1, 1, *img_shape), 'output': torch.rand(1, 1, *img_shape),
}) })
yield task_processor.init_backend_model(['']) yield task_processor.build_backend_model([''])
wrapper.recover() wrapper.recover()
def test_init_backend_model(backend_model): def test_build_backend_model(backend_model):
assert isinstance(backend_model, torch.nn.Module) assert isinstance(backend_model, torch.nn.Module)

View File

@ -100,7 +100,7 @@ def main():
dataloader = task_processor.build_dataloader(test_dataloader) dataloader = task_processor.build_dataloader(test_dataloader)
# load the model of the backend # load the model of the backend
model = task_processor.init_backend_model(args.model) model = task_processor.build_backend_model(args.model)
is_device_cpu = (args.device == 'cpu') is_device_cpu = (args.device == 'cpu')