mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Fix] Fix inferencer (#3333)
This commit is contained in:
parent
913fe3e91c
commit
f1fa61a48a
@ -59,7 +59,9 @@ class MMSegInferencer(BaseInferencer):
|
||||
|
||||
preprocess_kwargs: set = set()
|
||||
forward_kwargs: set = {'mode', 'out_dir'}
|
||||
visualize_kwargs: set = {'show', 'wait_time', 'img_out_dir', 'opacity'}
|
||||
visualize_kwargs: set = {
|
||||
'show', 'wait_time', 'img_out_dir', 'opacity', 'return_vis'
|
||||
}
|
||||
postprocess_kwargs: set = {'pred_out_dir', 'return_datasample'}
|
||||
|
||||
def __init__(self,
|
||||
@ -137,6 +139,7 @@ class MMSegInferencer(BaseInferencer):
|
||||
inputs: InputsType,
|
||||
return_datasamples: bool = False,
|
||||
batch_size: int = 1,
|
||||
return_vis: bool = False,
|
||||
show: bool = False,
|
||||
wait_time: int = 0,
|
||||
out_dir: str = '',
|
||||
@ -188,11 +191,13 @@ class MMSegInferencer(BaseInferencer):
|
||||
wait_time=wait_time,
|
||||
img_out_dir=img_out_dir,
|
||||
pred_out_dir=pred_out_dir,
|
||||
return_vis=return_vis,
|
||||
**kwargs)
|
||||
|
||||
def visualize(self,
|
||||
inputs: list,
|
||||
preds: List[dict],
|
||||
return_vis: bool = False,
|
||||
show: bool = False,
|
||||
wait_time: int = 0,
|
||||
img_out_dir: str = '',
|
||||
@ -213,12 +218,12 @@ class MMSegInferencer(BaseInferencer):
|
||||
Returns:
|
||||
List[np.ndarray]: Visualization results.
|
||||
"""
|
||||
if self.visualizer is None or (not show and img_out_dir == ''):
|
||||
if not show and img_out_dir == '' and not return_vis:
|
||||
return None
|
||||
|
||||
if getattr(self, 'visualizer') is None:
|
||||
if self.visualizer is None:
|
||||
raise ValueError('Visualization needs the "visualizer" term'
|
||||
'defined in the config, but got None')
|
||||
'defined in the config, but got None.')
|
||||
|
||||
self.visualizer.set_dataset_meta(**self.model.dataset_meta)
|
||||
self.visualizer.alpha = opacity
|
||||
|
||||
@ -250,10 +255,11 @@ class MMSegInferencer(BaseInferencer):
|
||||
draw_gt=False,
|
||||
draw_pred=True,
|
||||
out_file=out_file)
|
||||
results.append(self.visualizer.get_image())
|
||||
if return_vis:
|
||||
results.append(self.visualizer.get_image())
|
||||
self.num_visualized_imgs += 1
|
||||
|
||||
return results
|
||||
return results if return_vis else None
|
||||
|
||||
def postprocess(self,
|
||||
preds: PredType,
|
||||
|
Loading…
x
Reference in New Issue
Block a user