mirror of
https://github.com/open-mmlab/mmdeploy.git
synced 2025-01-14 08:09:43 +08:00
[Fix] Fix batch inference error for Mask R-CNN (#1576)
* sync pr1575 to dev-1.x * only test_img accept list input
This commit is contained in:
parent
7f2e8f7ce0
commit
baa86aa4a5
@ -11,7 +11,7 @@ from mmdeploy.utils import Backend, get_backend, get_input_shape, load_config
|
|||||||
def visualize_model(model_cfg: Union[str, mmengine.Config],
|
def visualize_model(model_cfg: Union[str, mmengine.Config],
|
||||||
deploy_cfg: Union[str, mmengine.Config],
|
deploy_cfg: Union[str, mmengine.Config],
|
||||||
model: Union[str, Sequence[str]],
|
model: Union[str, Sequence[str]],
|
||||||
img: Union[str, np.ndarray],
|
img: Union[str, np.ndarray, Sequence[str]],
|
||||||
device: str,
|
device: str,
|
||||||
backend: Optional[Backend] = None,
|
backend: Optional[Backend] = None,
|
||||||
output_file: Optional[str] = None,
|
output_file: Optional[str] = None,
|
||||||
@ -35,8 +35,9 @@ def visualize_model(model_cfg: Union[str, mmengine.Config],
|
|||||||
model_cfg (str | mmengine.Config): Model config file or Config object.
|
model_cfg (str | mmengine.Config): Model config file or Config object.
|
||||||
deploy_cfg (str | mmengine.Config): Deployment config file or Config
|
deploy_cfg (str | mmengine.Config): Deployment config file or Config
|
||||||
object.
|
object.
|
||||||
model (str | list[str], BaseSubtask): Input model or file(s).
|
model (str | Sequence[str]): Input model or file(s).
|
||||||
img (str | np.ndarray): Input image file or numpy array for inference.
|
img (str | np.ndarray | Sequence[str]): Input image file or numpy array
|
||||||
|
for inference.
|
||||||
device (str): A string specifying device type.
|
device (str): A string specifying device type.
|
||||||
backend (Backend): Specifying backend type, defaults to `None`.
|
backend (Backend): Specifying backend type, defaults to `None`.
|
||||||
output_file (str): Output file to save visualized image, defaults to
|
output_file (str): Output file to save visualized image, defaults to
|
||||||
@ -84,8 +85,11 @@ def visualize_model(model_cfg: Union[str, mmengine.Config],
|
|||||||
visualize = False
|
visualize = False
|
||||||
|
|
||||||
if visualize is True:
|
if visualize is True:
|
||||||
|
if not isinstance(img, list):
|
||||||
|
img = [img]
|
||||||
|
for single_img in img:
|
||||||
task_processor.visualize(
|
task_processor.visualize(
|
||||||
image=img,
|
image=single_img,
|
||||||
model=model,
|
model=model,
|
||||||
result=result,
|
result=result,
|
||||||
output_file=output_file,
|
output_file=output_file,
|
||||||
|
@ -106,7 +106,7 @@ def standard_roi_head__predict_mask(self,
|
|||||||
# expand might lead to static shape, use broadcast instead
|
# expand might lead to static shape, use broadcast instead
|
||||||
batch_index = torch.arange(
|
batch_index = torch.arange(
|
||||||
det_bboxes.size(0), device=det_bboxes.device).float().view(
|
det_bboxes.size(0), device=det_bboxes.device).float().view(
|
||||||
-1, 1) + det_bboxes.new_zeros(
|
-1, 1, 1) + det_bboxes.new_zeros(
|
||||||
(det_bboxes.size(0), det_bboxes.size(1))).unsqueeze(-1)
|
(det_bboxes.size(0), det_bboxes.size(1))).unsqueeze(-1)
|
||||||
mask_rois = torch.cat([batch_index, det_bboxes], dim=-1)
|
mask_rois = torch.cat([batch_index, det_bboxes], dim=-1)
|
||||||
mask_rois = mask_rois.view(-1, 5)
|
mask_rois = mask_rois.view(-1, 5)
|
||||||
|
@ -27,7 +27,11 @@ def parse_args():
|
|||||||
parser.add_argument('checkpoint', help='model checkpoint path')
|
parser.add_argument('checkpoint', help='model checkpoint path')
|
||||||
parser.add_argument('img', help='image used to convert model model')
|
parser.add_argument('img', help='image used to convert model model')
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--test-img', default=None, help='image used to test model')
|
'--test-img',
|
||||||
|
default=None,
|
||||||
|
type=str,
|
||||||
|
nargs='+',
|
||||||
|
help='image used to test model')
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--work-dir',
|
'--work-dir',
|
||||||
default=os.getcwd(),
|
default=os.getcwd(),
|
||||||
|
Loading…
x
Reference in New Issue
Block a user