[Fix] Fix inferencer (#3333)

This commit is contained in:
谢昕辰 2023-09-18 19:24:03 +08:00 committed by GitHub
parent 913fe3e91c
commit f1fa61a48a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -59,7 +59,9 @@ class MMSegInferencer(BaseInferencer):
preprocess_kwargs: set = set() preprocess_kwargs: set = set()
forward_kwargs: set = {'mode', 'out_dir'} forward_kwargs: set = {'mode', 'out_dir'}
visualize_kwargs: set = {'show', 'wait_time', 'img_out_dir', 'opacity'} visualize_kwargs: set = {
'show', 'wait_time', 'img_out_dir', 'opacity', 'return_vis'
}
postprocess_kwargs: set = {'pred_out_dir', 'return_datasample'} postprocess_kwargs: set = {'pred_out_dir', 'return_datasample'}
def __init__(self, def __init__(self,
@ -137,6 +139,7 @@ class MMSegInferencer(BaseInferencer):
inputs: InputsType, inputs: InputsType,
return_datasamples: bool = False, return_datasamples: bool = False,
batch_size: int = 1, batch_size: int = 1,
return_vis: bool = False,
show: bool = False, show: bool = False,
wait_time: int = 0, wait_time: int = 0,
out_dir: str = '', out_dir: str = '',
@ -188,11 +191,13 @@ class MMSegInferencer(BaseInferencer):
wait_time=wait_time, wait_time=wait_time,
img_out_dir=img_out_dir, img_out_dir=img_out_dir,
pred_out_dir=pred_out_dir, pred_out_dir=pred_out_dir,
return_vis=return_vis,
**kwargs) **kwargs)
def visualize(self, def visualize(self,
inputs: list, inputs: list,
preds: List[dict], preds: List[dict],
return_vis: bool = False,
show: bool = False, show: bool = False,
wait_time: int = 0, wait_time: int = 0,
img_out_dir: str = '', img_out_dir: str = '',
@ -213,12 +218,12 @@ class MMSegInferencer(BaseInferencer):
Returns: Returns:
List[np.ndarray]: Visualization results. List[np.ndarray]: Visualization results.
""" """
if self.visualizer is None or (not show and img_out_dir == ''): if not show and img_out_dir == '' and not return_vis:
return None return None
if self.visualizer is None:
if getattr(self, 'visualizer') is None:
raise ValueError('Visualization needs the "visualizer" term' raise ValueError('Visualization needs the "visualizer" term'
'defined in the config, but got None') 'defined in the config, but got None.')
self.visualizer.set_dataset_meta(**self.model.dataset_meta) self.visualizer.set_dataset_meta(**self.model.dataset_meta)
self.visualizer.alpha = opacity self.visualizer.alpha = opacity
@ -250,10 +255,11 @@ class MMSegInferencer(BaseInferencer):
draw_gt=False, draw_gt=False,
draw_pred=True, draw_pred=True,
out_file=out_file) out_file=out_file)
if return_vis:
results.append(self.visualizer.get_image()) results.append(self.visualizer.get_image())
self.num_visualized_imgs += 1 self.num_visualized_imgs += 1
return results return results if return_vis else None
def postprocess(self, def postprocess(self,
preds: PredType, preds: PredType,