[Fix] Fix batch inference error for Mask R-CNN (#1576)
* sync pr1575 to dev-1.x * only test_img accept list inputpull/1597/head
parent
7f2e8f7ce0
commit
baa86aa4a5
|
@ -11,7 +11,7 @@ from mmdeploy.utils import Backend, get_backend, get_input_shape, load_config
|
|||
def visualize_model(model_cfg: Union[str, mmengine.Config],
|
||||
deploy_cfg: Union[str, mmengine.Config],
|
||||
model: Union[str, Sequence[str]],
|
||||
img: Union[str, np.ndarray],
|
||||
img: Union[str, np.ndarray, Sequence[str]],
|
||||
device: str,
|
||||
backend: Optional[Backend] = None,
|
||||
output_file: Optional[str] = None,
|
||||
|
@ -35,8 +35,9 @@ def visualize_model(model_cfg: Union[str, mmengine.Config],
|
|||
model_cfg (str | mmengine.Config): Model config file or Config object.
|
||||
deploy_cfg (str | mmengine.Config): Deployment config file or Config
|
||||
object.
|
||||
model (str | list[str], BaseSubtask): Input model or file(s).
|
||||
img (str | np.ndarray): Input image file or numpy array for inference.
|
||||
model (str | Sequence[str]): Input model or file(s).
|
||||
img (str | np.ndarray | Sequence[str]): Input image file or numpy array
|
||||
for inference.
|
||||
device (str): A string specifying device type.
|
||||
backend (Backend): Specifying backend type, defaults to `None`.
|
||||
output_file (str): Output file to save visualized image, defaults to
|
||||
|
@ -84,10 +85,13 @@ def visualize_model(model_cfg: Union[str, mmengine.Config],
|
|||
visualize = False
|
||||
|
||||
if visualize is True:
|
||||
task_processor.visualize(
|
||||
image=img,
|
||||
model=model,
|
||||
result=result,
|
||||
output_file=output_file,
|
||||
window_name=backend.value,
|
||||
show_result=show_result)
|
||||
if not isinstance(img, list):
|
||||
img = [img]
|
||||
for single_img in img:
|
||||
task_processor.visualize(
|
||||
image=single_img,
|
||||
model=model,
|
||||
result=result,
|
||||
output_file=output_file,
|
||||
window_name=backend.value,
|
||||
show_result=show_result)
|
||||
|
|
|
@ -106,7 +106,7 @@ def standard_roi_head__predict_mask(self,
|
|||
# expand might lead to static shape, use broadcast instead
|
||||
batch_index = torch.arange(
|
||||
det_bboxes.size(0), device=det_bboxes.device).float().view(
|
||||
-1, 1) + det_bboxes.new_zeros(
|
||||
-1, 1, 1) + det_bboxes.new_zeros(
|
||||
(det_bboxes.size(0), det_bboxes.size(1))).unsqueeze(-1)
|
||||
mask_rois = torch.cat([batch_index, det_bboxes], dim=-1)
|
||||
mask_rois = mask_rois.view(-1, 5)
|
||||
|
|
|
@ -27,7 +27,11 @@ def parse_args():
|
|||
parser.add_argument('checkpoint', help='model checkpoint path')
|
||||
parser.add_argument('img', help='image used to convert model model')
|
||||
parser.add_argument(
|
||||
'--test-img', default=None, help='image used to test model')
|
||||
'--test-img',
|
||||
default=None,
|
||||
type=str,
|
||||
nargs='+',
|
||||
help='image used to test model')
|
||||
parser.add_argument(
|
||||
'--work-dir',
|
||||
default=os.getcwd(),
|
||||
|
|
Loading…
Reference in New Issue