mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
rename method
This commit is contained in:
parent
41a97746cf
commit
8c37f79f06
@ -57,7 +57,7 @@ def create_calib_input_data(calib_file: str,
|
||||
|
||||
apply_marks = cfg_apply_marks(deploy_cfg)
|
||||
|
||||
model = task_processor.init_pytorch_model(model_checkpoint)
|
||||
model = task_processor.build_pytorch_model(model_checkpoint)
|
||||
dataset = task_processor.build_dataset(dataset_cfg, dataset_type)
|
||||
|
||||
# patch model
|
||||
|
@ -42,7 +42,7 @@ def inference_model(model_cfg: Union[str, mmcv.Config],
|
||||
from mmdeploy.apis.utils import build_task_processor
|
||||
task_processor = build_task_processor(model_cfg, deploy_cfg, device)
|
||||
|
||||
model = task_processor.init_backend_model(backend_files)
|
||||
model = task_processor.build_backend_model(backend_files)
|
||||
|
||||
input_shape = get_input_shape(deploy_cfg)
|
||||
model_inputs, _ = task_processor.create_input(img, input_shape)
|
||||
|
@ -59,7 +59,7 @@ def torch2onnx(img: Any,
|
||||
from mmdeploy.apis import build_task_processor
|
||||
task_processor = build_task_processor(model_cfg, deploy_cfg, device)
|
||||
|
||||
torch_model = task_processor.init_pytorch_model(model_checkpoint)
|
||||
torch_model = task_processor.build_pytorch_model(model_checkpoint)
|
||||
data, model_inputs = task_processor.create_input(
|
||||
img,
|
||||
input_shape,
|
||||
|
@ -41,7 +41,7 @@ def torch2torchscript(img: Any,
|
||||
from mmdeploy.apis import build_task_processor
|
||||
task_processor = build_task_processor(model_cfg, deploy_cfg, device)
|
||||
|
||||
torch_model = task_processor.init_pytorch_model(model_checkpoint)
|
||||
torch_model = task_processor.build_pytorch_model(model_checkpoint)
|
||||
_, model_inputs = task_processor.create_input(img, input_shape)
|
||||
if not isinstance(model_inputs, torch.Tensor):
|
||||
model_inputs = model_inputs[0]
|
||||
|
@ -62,9 +62,9 @@ def visualize_model(model_cfg: Union[str, mmcv.Config],
|
||||
list should be str'
|
||||
|
||||
if backend == Backend.PYTORCH:
|
||||
model = task_processor.init_pytorch_model(model[0])
|
||||
model = task_processor.build_pytorch_model(model[0])
|
||||
else:
|
||||
model = task_processor.init_backend_model(model)
|
||||
model = task_processor.build_backend_model(model)
|
||||
|
||||
model_inputs, _ = task_processor.create_input(img, input_shape)
|
||||
with torch.no_grad():
|
||||
|
@ -54,9 +54,9 @@ class BaseTask(metaclass=ABCMeta):
|
||||
self.visualizer = self.model_cfg.visualizer
|
||||
|
||||
@abstractmethod
|
||||
def init_backend_model(self,
|
||||
model_files: Sequence[str] = None,
|
||||
**kwargs) -> torch.nn.Module:
|
||||
def build_backend_model(self,
|
||||
model_files: Sequence[str] = None,
|
||||
**kwargs) -> torch.nn.Module:
|
||||
"""Initialize backend model.
|
||||
|
||||
Args:
|
||||
@ -67,10 +67,10 @@ class BaseTask(metaclass=ABCMeta):
|
||||
"""
|
||||
pass
|
||||
|
||||
def init_pytorch_model(self,
|
||||
model_checkpoint: Optional[str] = None,
|
||||
cfg_options: Optional[Dict] = None,
|
||||
**kwargs) -> torch.nn.Module:
|
||||
def build_pytorch_model(self,
|
||||
model_checkpoint: Optional[str] = None,
|
||||
cfg_options: Optional[Dict] = None,
|
||||
**kwargs) -> torch.nn.Module:
|
||||
"""Initialize torch model.
|
||||
|
||||
Args:
|
||||
|
@ -105,9 +105,9 @@ class Classification(BaseTask):
|
||||
super(Classification, self).__init__(model_cfg, deploy_cfg, device,
|
||||
experiment_name)
|
||||
|
||||
def init_backend_model(self,
|
||||
model_files: Sequence[str] = None,
|
||||
**kwargs) -> torch.nn.Module:
|
||||
def build_backend_model(self,
|
||||
model_files: Sequence[str] = None,
|
||||
**kwargs) -> torch.nn.Module:
|
||||
"""Initialize backend model.
|
||||
|
||||
Args:
|
||||
|
@ -58,9 +58,9 @@ class ObjectDetection(BaseTask):
|
||||
device: str) -> None:
|
||||
super().__init__(model_cfg, deploy_cfg, device)
|
||||
|
||||
def init_backend_model(self,
|
||||
model_files: Optional[str] = None,
|
||||
**kwargs) -> torch.nn.Module:
|
||||
def build_backend_model(self,
|
||||
model_files: Optional[str] = None,
|
||||
**kwargs) -> torch.nn.Module:
|
||||
"""Initialize backend model.
|
||||
|
||||
Args:
|
||||
@ -74,10 +74,10 @@ class ObjectDetection(BaseTask):
|
||||
model_files, self.model_cfg, self.deploy_cfg, device=self.device)
|
||||
return model.eval()
|
||||
|
||||
def init_pytorch_model(self,
|
||||
model_checkpoint: Optional[str] = None,
|
||||
cfg_options: Optional[Dict] = None,
|
||||
**kwargs) -> torch.nn.Module:
|
||||
def build_pytorch_model(self,
|
||||
model_checkpoint: Optional[str] = None,
|
||||
cfg_options: Optional[Dict] = None,
|
||||
**kwargs) -> torch.nn.Module:
|
||||
"""Initialize torch model.
|
||||
|
||||
Args:
|
||||
|
@ -23,9 +23,9 @@ class VoxelDetection(BaseTask):
|
||||
device: str):
|
||||
super().__init__(model_cfg, deploy_cfg, device)
|
||||
|
||||
def init_backend_model(self,
|
||||
model_files: Sequence[str] = None,
|
||||
**kwargs) -> torch.nn.Module:
|
||||
def build_backend_model(self,
|
||||
model_files: Sequence[str] = None,
|
||||
**kwargs) -> torch.nn.Module:
|
||||
"""Initialize backend model.
|
||||
|
||||
Args:
|
||||
@ -39,10 +39,10 @@ class VoxelDetection(BaseTask):
|
||||
model_files, self.model_cfg, self.deploy_cfg, device=self.device)
|
||||
return model
|
||||
|
||||
def init_pytorch_model(self,
|
||||
model_checkpoint: Optional[str] = None,
|
||||
cfg_options: Optional[Dict] = None,
|
||||
**kwargs) -> torch.nn.Module:
|
||||
def build_pytorch_model(self,
|
||||
model_checkpoint: Optional[str] = None,
|
||||
cfg_options: Optional[Dict] = None,
|
||||
**kwargs) -> torch.nn.Module:
|
||||
"""Initialize torch model.
|
||||
|
||||
Args:
|
||||
|
@ -76,9 +76,9 @@ class SuperResolution(BaseTask):
|
||||
device: str):
|
||||
super().__init__(model_cfg, deploy_cfg, device)
|
||||
|
||||
def init_backend_model(self,
|
||||
model_files: Sequence[str] = None,
|
||||
**kwargs) -> torch.nn.Module:
|
||||
def build_backend_model(self,
|
||||
model_files: Sequence[str] = None,
|
||||
**kwargs) -> torch.nn.Module:
|
||||
"""Initialize backend model.
|
||||
|
||||
Args:
|
||||
@ -92,9 +92,9 @@ class SuperResolution(BaseTask):
|
||||
model_files, self.model_cfg, self.deploy_cfg, device=self.device)
|
||||
return model
|
||||
|
||||
def init_pytorch_model(self,
|
||||
model_checkpoint: Optional[str] = None,
|
||||
**kwargs) -> torch.nn.Module:
|
||||
def build_pytorch_model(self,
|
||||
model_checkpoint: Optional[str] = None,
|
||||
**kwargs) -> torch.nn.Module:
|
||||
"""Initialize torch model.
|
||||
|
||||
Args:
|
||||
|
@ -63,9 +63,9 @@ class TextDetection(BaseTask):
|
||||
device: str):
|
||||
super(TextDetection, self).__init__(model_cfg, deploy_cfg, device)
|
||||
|
||||
def init_backend_model(self,
|
||||
model_files: Optional[str] = None,
|
||||
**kwargs) -> torch.nn.Module:
|
||||
def build_backend_model(self,
|
||||
model_files: Optional[str] = None,
|
||||
**kwargs) -> torch.nn.Module:
|
||||
"""Initialize backend model.
|
||||
|
||||
Args:
|
||||
@ -79,10 +79,10 @@ class TextDetection(BaseTask):
|
||||
model_files, self.model_cfg, self.deploy_cfg, device=self.device)
|
||||
return model.eval()
|
||||
|
||||
def init_pytorch_model(self,
|
||||
model_checkpoint: Optional[str] = None,
|
||||
cfg_options: Optional[Dict] = None,
|
||||
**kwargs) -> torch.nn.Module:
|
||||
def build_pytorch_model(self,
|
||||
model_checkpoint: Optional[str] = None,
|
||||
cfg_options: Optional[Dict] = None,
|
||||
**kwargs) -> torch.nn.Module:
|
||||
"""Initialize torch model.
|
||||
|
||||
Args:
|
||||
|
@ -75,9 +75,9 @@ class TextRecognition(BaseTask):
|
||||
device: str):
|
||||
super(TextRecognition, self).__init__(model_cfg, deploy_cfg, device)
|
||||
|
||||
def init_backend_model(self,
|
||||
model_files: Optional[str] = None,
|
||||
**kwargs) -> torch.nn.Module:
|
||||
def build_backend_model(self,
|
||||
model_files: Optional[str] = None,
|
||||
**kwargs) -> torch.nn.Module:
|
||||
"""Initialize backend model.
|
||||
|
||||
Args:
|
||||
@ -91,10 +91,10 @@ class TextRecognition(BaseTask):
|
||||
model_files, self.model_cfg, self.deploy_cfg, device=self.device)
|
||||
return model.eval()
|
||||
|
||||
def init_pytorch_model(self,
|
||||
model_checkpoint: Optional[str] = None,
|
||||
cfg_options: Optional[Dict] = None,
|
||||
**kwargs) -> torch.nn.Module:
|
||||
def build_pytorch_model(self,
|
||||
model_checkpoint: Optional[str] = None,
|
||||
cfg_options: Optional[Dict] = None,
|
||||
**kwargs) -> torch.nn.Module:
|
||||
"""Initialize torch model.
|
||||
|
||||
Args:
|
||||
|
@ -86,9 +86,9 @@ class PoseDetection(BaseTask):
|
||||
device: str):
|
||||
super().__init__(model_cfg, deploy_cfg, device)
|
||||
|
||||
def init_backend_model(self,
|
||||
model_files: Sequence[str] = None,
|
||||
**kwargs) -> torch.nn.Module:
|
||||
def build_backend_model(self,
|
||||
model_files: Sequence[str] = None,
|
||||
**kwargs) -> torch.nn.Module:
|
||||
"""Initialize backend model.
|
||||
|
||||
Args:
|
||||
@ -102,9 +102,9 @@ class PoseDetection(BaseTask):
|
||||
model_files, self.model_cfg, self.deploy_cfg, device=self.device)
|
||||
return model.eval()
|
||||
|
||||
def init_pytorch_model(self,
|
||||
model_checkpoint: Optional[str] = None,
|
||||
**kwargs) -> torch.nn.Module:
|
||||
def build_pytorch_model(self,
|
||||
model_checkpoint: Optional[str] = None,
|
||||
**kwargs) -> torch.nn.Module:
|
||||
"""Initialize torch model.
|
||||
|
||||
Args:
|
||||
|
@ -85,9 +85,9 @@ class RotatedDetection(BaseTask):
|
||||
device: str):
|
||||
super(RotatedDetection, self).__init__(model_cfg, deploy_cfg, device)
|
||||
|
||||
def init_backend_model(self,
|
||||
model_files: Optional[str] = None,
|
||||
**kwargs) -> torch.nn.Module:
|
||||
def build_backend_model(self,
|
||||
model_files: Optional[str] = None,
|
||||
**kwargs) -> torch.nn.Module:
|
||||
"""Initialize backend model.
|
||||
|
||||
Args:
|
||||
@ -101,10 +101,10 @@ class RotatedDetection(BaseTask):
|
||||
model_files, self.model_cfg, self.deploy_cfg, device=self.device)
|
||||
return model.eval()
|
||||
|
||||
def init_pytorch_model(self,
|
||||
model_checkpoint: Optional[str] = None,
|
||||
cfg_options: Optional[Dict] = None,
|
||||
**kwargs) -> torch.nn.Module:
|
||||
def build_pytorch_model(self,
|
||||
model_checkpoint: Optional[str] = None,
|
||||
cfg_options: Optional[Dict] = None,
|
||||
**kwargs) -> torch.nn.Module:
|
||||
"""Initialize torch model.
|
||||
|
||||
Args:
|
||||
|
@ -57,9 +57,9 @@ class Segmentation(BaseTask):
|
||||
device: str):
|
||||
super(Segmentation, self).__init__(model_cfg, deploy_cfg, device)
|
||||
|
||||
def init_backend_model(self,
|
||||
model_files: Optional[str] = None,
|
||||
**kwargs) -> torch.nn.Module:
|
||||
def build_backend_model(self,
|
||||
model_files: Optional[str] = None,
|
||||
**kwargs) -> torch.nn.Module:
|
||||
"""Initialize backend model.
|
||||
|
||||
Args:
|
||||
@ -73,10 +73,10 @@ class Segmentation(BaseTask):
|
||||
model_files, self.model_cfg, self.deploy_cfg, device=self.device)
|
||||
return model.eval()
|
||||
|
||||
def init_pytorch_model(self,
|
||||
model_checkpoint: Optional[str] = None,
|
||||
cfg_options: Optional[Dict] = None,
|
||||
**kwargs) -> torch.nn.Module:
|
||||
def build_pytorch_model(self,
|
||||
model_checkpoint: Optional[str] = None,
|
||||
cfg_options: Optional[Dict] = None,
|
||||
**kwargs) -> torch.nn.Module:
|
||||
"""Initialize torch model.
|
||||
|
||||
Args:
|
||||
|
@ -40,7 +40,7 @@ img = np.random.rand(*img_shape, 3)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('from_mmrazor', [True, False, '123', 0])
|
||||
def test_init_pytorch_model(from_mmrazor: Any):
|
||||
def test_build_pytorch_model(from_mmrazor: Any):
|
||||
from mmcls.models.classifiers.base import BaseClassifier
|
||||
if from_mmrazor is False:
|
||||
_task_processor = task_processor
|
||||
@ -66,7 +66,7 @@ def test_init_pytorch_model(from_mmrazor: Any):
|
||||
assert from_mmrazor == _task_processor.from_mmrazor
|
||||
if from_mmrazor:
|
||||
pytest.importorskip('mmrazor', reason='mmrazor is not installed.')
|
||||
model = _task_processor.init_pytorch_model(None)
|
||||
model = _task_processor.build_pytorch_model(None)
|
||||
assert isinstance(model, BaseClassifier)
|
||||
|
||||
|
||||
@ -79,12 +79,12 @@ def backend_model():
|
||||
'output': torch.rand(1, num_classes),
|
||||
})
|
||||
|
||||
yield task_processor.init_backend_model([''])
|
||||
yield task_processor.build_backend_model([''])
|
||||
|
||||
wrapper.recover()
|
||||
|
||||
|
||||
def test_init_backend_model(backend_model):
|
||||
def test_build_backend_model(backend_model):
|
||||
assert isinstance(backend_model, torch.nn.Module)
|
||||
|
||||
|
||||
|
@ -51,7 +51,7 @@ img = np.random.rand(*img_shape, 3)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('from_mmrazor', [True, False, '123', 0])
|
||||
def test_init_pytorch_model(from_mmrazor: Any):
|
||||
def test_build_pytorch_model(from_mmrazor: Any):
|
||||
from mmdet.models import BaseDetector
|
||||
if from_mmrazor is False:
|
||||
_task_processor = task_processor
|
||||
@ -77,7 +77,7 @@ def test_init_pytorch_model(from_mmrazor: Any):
|
||||
assert from_mmrazor == _task_processor.from_mmrazor
|
||||
if from_mmrazor:
|
||||
pytest.importorskip('mmrazor', reason='mmrazor is not installed.')
|
||||
model = _task_processor.init_pytorch_model(None)
|
||||
model = _task_processor.build_pytorch_model(None)
|
||||
assert isinstance(model, BaseDetector)
|
||||
|
||||
|
||||
@ -91,12 +91,12 @@ def backend_model():
|
||||
'labels': torch.rand(1, 10)
|
||||
})
|
||||
|
||||
yield task_processor.init_backend_model([''])
|
||||
yield task_processor.build_backend_model([''])
|
||||
|
||||
wrapper.recover()
|
||||
|
||||
|
||||
def test_init_backend_model(backend_model):
|
||||
def test_build_backend_model(backend_model):
|
||||
from mmdeploy.codebase.mmdet.deploy.object_detection_model import \
|
||||
End2EndModel
|
||||
assert isinstance(backend_model, End2EndModel)
|
||||
@ -131,7 +131,7 @@ def test_create_input(device):
|
||||
|
||||
|
||||
def test_run_inference(backend_model):
|
||||
torch_model = task_processor.init_pytorch_model(None)
|
||||
torch_model = task_processor.build_pytorch_model(None)
|
||||
input_dict, _ = task_processor.create_input(img, input_shape=img_shape)
|
||||
torch_results = task_processor.run_inference(torch_model, input_dict)
|
||||
backend_results = task_processor.run_inference(backend_model, input_dict)
|
||||
|
@ -39,9 +39,9 @@ onnx_file = NamedTemporaryFile(suffix='.onnx').name
|
||||
task_processor = build_task_processor(model_cfg, deploy_cfg, 'cpu')
|
||||
|
||||
|
||||
def test_init_pytorch_model():
|
||||
def test_build_pytorch_model():
|
||||
from mmdet3d.models import Base3DDetector
|
||||
model = task_processor.init_pytorch_model(None)
|
||||
model = task_processor.build_pytorch_model(None)
|
||||
assert isinstance(model, Base3DDetector)
|
||||
|
||||
|
||||
@ -57,12 +57,12 @@ def backend_model():
|
||||
'dir_scores': torch.rand(1, 12, 32, 32)
|
||||
})
|
||||
|
||||
yield task_processor.init_backend_model([''])
|
||||
yield task_processor.build_backend_model([''])
|
||||
|
||||
wrapper.recover()
|
||||
|
||||
|
||||
def test_init_backend_model(backend_model):
|
||||
def test_build_backend_model(backend_model):
|
||||
from mmdeploy.codebase.mmdet3d.deploy.voxel_detection_model import \
|
||||
VoxelDetectionModel
|
||||
assert isinstance(backend_model, VoxelDetectionModel)
|
||||
@ -83,7 +83,7 @@ def test_create_input(device):
|
||||
reason='Only support GPU test', condition=not torch.cuda.is_available())
|
||||
def test_run_inference(backend_model):
|
||||
task_processor.device = 'cuda:0'
|
||||
torch_model = task_processor.init_pytorch_model(None)
|
||||
torch_model = task_processor.build_pytorch_model(None)
|
||||
input_dict, _ = task_processor.create_input(pcd_path)
|
||||
torch_results = task_processor.run_inference(torch_model, input_dict)
|
||||
backend_results = task_processor.run_inference(backend_model, input_dict)
|
||||
@ -98,7 +98,7 @@ def test_run_inference(backend_model):
|
||||
def test_visualize():
|
||||
task_processor.device = 'cuda:0'
|
||||
input_dict, _ = task_processor.create_input(pcd_path)
|
||||
torch_model = task_processor.init_pytorch_model(None)
|
||||
torch_model = task_processor.build_pytorch_model(None)
|
||||
results = task_processor.run_inference(torch_model, input_dict)
|
||||
with TemporaryDirectory() as dir:
|
||||
filename = dir + 'tmp.bin'
|
||||
|
@ -37,8 +37,8 @@ onnx_file = NamedTemporaryFile(suffix='.onnx').name
|
||||
task_processor = build_task_processor(model_cfg, deploy_cfg, 'cpu')
|
||||
|
||||
|
||||
def test_init_pytorch_model():
|
||||
torch_model = task_processor.init_pytorch_model(None)
|
||||
def test_build_pytorch_model():
|
||||
torch_model = task_processor.build_pytorch_model(None)
|
||||
assert torch_model is not None
|
||||
|
||||
|
||||
@ -51,12 +51,12 @@ def backend_model():
|
||||
'output': torch.rand(3, 50, 50),
|
||||
})
|
||||
|
||||
yield task_processor.init_backend_model([''])
|
||||
yield task_processor.build_backend_model([''])
|
||||
|
||||
wrapper.recover()
|
||||
|
||||
|
||||
def test_init_backend_model(backend_model):
|
||||
def test_build_backend_model(backend_model):
|
||||
assert backend_model is not None
|
||||
|
||||
|
||||
|
@ -37,10 +37,10 @@ img_shape = (32, 32)
|
||||
img = np.random.rand(*img_shape, 3).astype(np.uint8)
|
||||
|
||||
|
||||
def test_init_pytorch_model():
|
||||
def test_build_pytorch_model():
|
||||
from mmocr.models.textdet.detectors.single_stage_text_detector import \
|
||||
SingleStageDetector
|
||||
model = task_processor.init_pytorch_model(None)
|
||||
model = task_processor.build_pytorch_model(None)
|
||||
assert isinstance(model, SingleStageDetector)
|
||||
|
||||
|
||||
@ -53,12 +53,12 @@ def backend_model():
|
||||
'output': torch.rand(1, 3, *img_shape),
|
||||
})
|
||||
|
||||
yield task_processor.init_backend_model([''])
|
||||
yield task_processor.build_backend_model([''])
|
||||
|
||||
wrapper.recover()
|
||||
|
||||
|
||||
def test_init_backend_model(backend_model):
|
||||
def test_build_backend_model(backend_model):
|
||||
assert isinstance(backend_model, torch.nn.Module)
|
||||
|
||||
|
||||
|
@ -37,9 +37,9 @@ img_shape = (32, 32)
|
||||
img = np.random.rand(*img_shape, 3).astype(np.uint8)
|
||||
|
||||
|
||||
def test_init_pytorch_model():
|
||||
def test_build_pytorch_model():
|
||||
from mmocr.models.textrecog.recognizer import BaseRecognizer
|
||||
model = task_processor.init_pytorch_model(None)
|
||||
model = task_processor.build_pytorch_model(None)
|
||||
assert isinstance(model, BaseRecognizer)
|
||||
|
||||
|
||||
@ -52,12 +52,12 @@ def backend_model():
|
||||
'output': torch.rand(1, 9, 37),
|
||||
})
|
||||
|
||||
yield task_processor.init_backend_model([''])
|
||||
yield task_processor.build_backend_model([''])
|
||||
|
||||
wrapper.recover()
|
||||
|
||||
|
||||
def test_init_backend_model(backend_model):
|
||||
def test_build_backend_model(backend_model):
|
||||
assert isinstance(backend_model, torch.nn.Module)
|
||||
|
||||
|
||||
|
@ -65,9 +65,9 @@ def test_create_input():
|
||||
assert isinstance(inputs, tuple) and len(inputs) == 2
|
||||
|
||||
|
||||
def test_init_pytorch_model():
|
||||
def test_build_pytorch_model():
|
||||
from mmpose.models.detectors.base import BasePose
|
||||
model = task_processor.init_pytorch_model(None)
|
||||
model = task_processor.build_pytorch_model(None)
|
||||
assert isinstance(model, BasePose)
|
||||
|
||||
|
||||
@ -80,12 +80,12 @@ def backend_model():
|
||||
'output': torch.rand(1, num_output_channels, *heatmap_shape),
|
||||
})
|
||||
|
||||
yield task_processor.init_backend_model([''])
|
||||
yield task_processor.build_backend_model([''])
|
||||
|
||||
wrapper.recover()
|
||||
|
||||
|
||||
def test_init_backend_model(backend_model):
|
||||
def test_build_backend_model(backend_model):
|
||||
assert isinstance(backend_model, torch.nn.Module)
|
||||
|
||||
|
||||
|
@ -48,9 +48,9 @@ img_shape = (32, 32)
|
||||
img = np.random.rand(*img_shape, 3)
|
||||
|
||||
|
||||
def test_init_pytorch_model():
|
||||
def test_build_pytorch_model():
|
||||
from mmrotate.models import RotatedBaseDetector
|
||||
model = task_processor.init_pytorch_model(None)
|
||||
model = task_processor.build_pytorch_model(None)
|
||||
assert isinstance(model, RotatedBaseDetector)
|
||||
|
||||
|
||||
@ -64,12 +64,12 @@ def backend_model():
|
||||
'labels': torch.rand(1, 10)
|
||||
})
|
||||
|
||||
yield task_processor.init_backend_model([''])
|
||||
yield task_processor.build_backend_model([''])
|
||||
|
||||
wrapper.recover()
|
||||
|
||||
|
||||
def test_init_backend_model(backend_model):
|
||||
def test_build_backend_model(backend_model):
|
||||
from mmdeploy.codebase.mmrotate.deploy.rotated_detection_model import \
|
||||
End2EndModel
|
||||
assert isinstance(backend_model, End2EndModel)
|
||||
@ -85,7 +85,7 @@ def test_create_input(device):
|
||||
|
||||
|
||||
def test_run_inference(backend_model):
|
||||
torch_model = task_processor.init_pytorch_model(None)
|
||||
torch_model = task_processor.build_pytorch_model(None)
|
||||
input_dict, _ = task_processor.create_input(img, input_shape=img_shape)
|
||||
torch_results = task_processor.run_inference(torch_model, input_dict)
|
||||
backend_results = task_processor.run_inference(backend_model, input_dict)
|
||||
|
@ -40,7 +40,7 @@ img = np.random.rand(*img_shape, 3)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('from_mmrazor', [True, False, '123', 0])
|
||||
def test_init_pytorch_model(from_mmrazor: Any):
|
||||
def test_build_pytorch_model(from_mmrazor: Any):
|
||||
from mmseg.models.segmentors.base import BaseSegmentor
|
||||
if from_mmrazor is False:
|
||||
_task_processor = task_processor
|
||||
@ -65,7 +65,7 @@ def test_init_pytorch_model(from_mmrazor: Any):
|
||||
assert from_mmrazor == _task_processor.from_mmrazor
|
||||
if from_mmrazor:
|
||||
pytest.importorskip('mmrazor', reason='mmrazor is not installed.')
|
||||
model = _task_processor.init_pytorch_model(None)
|
||||
model = _task_processor.build_pytorch_model(None)
|
||||
assert isinstance(model, BaseSegmentor)
|
||||
|
||||
|
||||
@ -78,12 +78,12 @@ def backend_model():
|
||||
'output': torch.rand(1, 1, *img_shape),
|
||||
})
|
||||
|
||||
yield task_processor.init_backend_model([''])
|
||||
yield task_processor.build_backend_model([''])
|
||||
|
||||
wrapper.recover()
|
||||
|
||||
|
||||
def test_init_backend_model(backend_model):
|
||||
def test_build_backend_model(backend_model):
|
||||
assert isinstance(backend_model, torch.nn.Module)
|
||||
|
||||
|
||||
|
@ -100,7 +100,7 @@ def main():
|
||||
dataloader = task_processor.build_dataloader(test_dataloader)
|
||||
|
||||
# load the model of the backend
|
||||
model = task_processor.init_backend_model(args.model)
|
||||
model = task_processor.build_backend_model(args.model)
|
||||
|
||||
is_device_cpu = (args.device == 'cpu')
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user