[Fix] Fix batch inference error for Mask R-CNN (#1576)

* sync pr1575 to dev-1.x

* only test_img accept list input
pull/1597/head
hanrui1sensetime 2022-12-30 14:27:41 +08:00 committed by GitHub
parent 7f2e8f7ce0
commit baa86aa4a5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 20 additions and 12 deletions

View File

@ -11,7 +11,7 @@ from mmdeploy.utils import Backend, get_backend, get_input_shape, load_config
def visualize_model(model_cfg: Union[str, mmengine.Config],
deploy_cfg: Union[str, mmengine.Config],
model: Union[str, Sequence[str]],
img: Union[str, np.ndarray],
img: Union[str, np.ndarray, Sequence[str]],
device: str,
backend: Optional[Backend] = None,
output_file: Optional[str] = None,
@ -35,8 +35,9 @@ def visualize_model(model_cfg: Union[str, mmengine.Config],
model_cfg (str | mmengine.Config): Model config file or Config object.
deploy_cfg (str | mmengine.Config): Deployment config file or Config
object.
model (str | list[str], BaseSubtask): Input model or file(s).
img (str | np.ndarray): Input image file or numpy array for inference.
model (str | Sequence[str]): Input model or file(s).
img (str | np.ndarray | Sequence[str]): Input image file or numpy array
for inference.
device (str): A string specifying device type.
backend (Backend): Specifying backend type, defaults to `None`.
output_file (str): Output file to save visualized image, defaults to
@ -84,10 +85,13 @@ def visualize_model(model_cfg: Union[str, mmengine.Config],
visualize = False
if visualize is True:
task_processor.visualize(
image=img,
model=model,
result=result,
output_file=output_file,
window_name=backend.value,
show_result=show_result)
if not isinstance(img, list):
img = [img]
for single_img in img:
task_processor.visualize(
image=single_img,
model=model,
result=result,
output_file=output_file,
window_name=backend.value,
show_result=show_result)

View File

@ -106,7 +106,7 @@ def standard_roi_head__predict_mask(self,
# expand might lead to static shape, use broadcast instead
batch_index = torch.arange(
det_bboxes.size(0), device=det_bboxes.device).float().view(
-1, 1) + det_bboxes.new_zeros(
-1, 1, 1) + det_bboxes.new_zeros(
(det_bboxes.size(0), det_bboxes.size(1))).unsqueeze(-1)
mask_rois = torch.cat([batch_index, det_bboxes], dim=-1)
mask_rois = mask_rois.view(-1, 5)

View File

@ -27,7 +27,11 @@ def parse_args():
parser.add_argument('checkpoint', help='model checkpoint path')
parser.add_argument('img', help='image used to convert model model')
parser.add_argument(
'--test-img', default=None, help='image used to test model')
'--test-img',
default=None,
type=str,
nargs='+',
help='image used to test model')
parser.add_argument(
'--work-dir',
default=os.getcwd(),