mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Fix] Fix inferencer (#3333)
This commit is contained in:
parent
913fe3e91c
commit
f1fa61a48a
@ -59,7 +59,9 @@ class MMSegInferencer(BaseInferencer):
|
|||||||
|
|
||||||
preprocess_kwargs: set = set()
|
preprocess_kwargs: set = set()
|
||||||
forward_kwargs: set = {'mode', 'out_dir'}
|
forward_kwargs: set = {'mode', 'out_dir'}
|
||||||
visualize_kwargs: set = {'show', 'wait_time', 'img_out_dir', 'opacity'}
|
visualize_kwargs: set = {
|
||||||
|
'show', 'wait_time', 'img_out_dir', 'opacity', 'return_vis'
|
||||||
|
}
|
||||||
postprocess_kwargs: set = {'pred_out_dir', 'return_datasample'}
|
postprocess_kwargs: set = {'pred_out_dir', 'return_datasample'}
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
@ -137,6 +139,7 @@ class MMSegInferencer(BaseInferencer):
|
|||||||
inputs: InputsType,
|
inputs: InputsType,
|
||||||
return_datasamples: bool = False,
|
return_datasamples: bool = False,
|
||||||
batch_size: int = 1,
|
batch_size: int = 1,
|
||||||
|
return_vis: bool = False,
|
||||||
show: bool = False,
|
show: bool = False,
|
||||||
wait_time: int = 0,
|
wait_time: int = 0,
|
||||||
out_dir: str = '',
|
out_dir: str = '',
|
||||||
@ -188,11 +191,13 @@ class MMSegInferencer(BaseInferencer):
|
|||||||
wait_time=wait_time,
|
wait_time=wait_time,
|
||||||
img_out_dir=img_out_dir,
|
img_out_dir=img_out_dir,
|
||||||
pred_out_dir=pred_out_dir,
|
pred_out_dir=pred_out_dir,
|
||||||
|
return_vis=return_vis,
|
||||||
**kwargs)
|
**kwargs)
|
||||||
|
|
||||||
def visualize(self,
|
def visualize(self,
|
||||||
inputs: list,
|
inputs: list,
|
||||||
preds: List[dict],
|
preds: List[dict],
|
||||||
|
return_vis: bool = False,
|
||||||
show: bool = False,
|
show: bool = False,
|
||||||
wait_time: int = 0,
|
wait_time: int = 0,
|
||||||
img_out_dir: str = '',
|
img_out_dir: str = '',
|
||||||
@ -213,12 +218,12 @@ class MMSegInferencer(BaseInferencer):
|
|||||||
Returns:
|
Returns:
|
||||||
List[np.ndarray]: Visualization results.
|
List[np.ndarray]: Visualization results.
|
||||||
"""
|
"""
|
||||||
if self.visualizer is None or (not show and img_out_dir == ''):
|
if not show and img_out_dir == '' and not return_vis:
|
||||||
return None
|
return None
|
||||||
|
if self.visualizer is None:
|
||||||
if getattr(self, 'visualizer') is None:
|
|
||||||
raise ValueError('Visualization needs the "visualizer" term'
|
raise ValueError('Visualization needs the "visualizer" term'
|
||||||
'defined in the config, but got None')
|
'defined in the config, but got None.')
|
||||||
|
|
||||||
self.visualizer.set_dataset_meta(**self.model.dataset_meta)
|
self.visualizer.set_dataset_meta(**self.model.dataset_meta)
|
||||||
self.visualizer.alpha = opacity
|
self.visualizer.alpha = opacity
|
||||||
|
|
||||||
@ -250,10 +255,11 @@ class MMSegInferencer(BaseInferencer):
|
|||||||
draw_gt=False,
|
draw_gt=False,
|
||||||
draw_pred=True,
|
draw_pred=True,
|
||||||
out_file=out_file)
|
out_file=out_file)
|
||||||
|
if return_vis:
|
||||||
results.append(self.visualizer.get_image())
|
results.append(self.visualizer.get_image())
|
||||||
self.num_visualized_imgs += 1
|
self.num_visualized_imgs += 1
|
||||||
|
|
||||||
return results
|
return results if return_vis else None
|
||||||
|
|
||||||
def postprocess(self,
|
def postprocess(self,
|
||||||
preds: PredType,
|
preds: PredType,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user