mmengine/mmengine/hooks/naive_visualization_hook.py
Mashiro 8770c6c7fc
[Refactor] Refactor data flow to make the interface more natural (#468)
* [Refactor]: modify interface of Visualizer.add_datasample (#365)

* [Refactor] Refactor data flow: refine `data_preprocessor`. (#359)

* refine data_preprocessor

* remove unused BATCH_DATA alias

* Fix type hints

* rename move_data to cast_data

* [Refactor] Refactor data flow: collate data in `collate_fn` of `DataLoader`  (#323)

* collate data in dataloader

* fix docstring

* refine comment

* fix as comment

* refactor default collate and pseudo collate

* format test file

* fix docstring

* fix as comment

* rename elem to data_item

* minor fix

* fix as comment

* [Refactor] Refactor data flow: `data_batch` argument of `Evaluator.process` is a `dict` (#360)

* refine evaluator and metric

* compatible with new default collate

* replace default collate with pseudo

* Handle data_batch in metric

* fix unit test

* fix unit test

* fix unit test

* minor refine

* make data_batch optional

* rename outputs to predictions

* fix ut

* rename predictions to outputs

* fix docstring

* fix docstring

* fix unit test

* make outputs and data_batch to kwargs

* fix unit test

* keep signature of metric

* fix ut

* rename pred_sample arguments to data_sample(Visualizer)

* fix loop and ut

* [refactor]: Refactor model dataflow (#398)

* [Refactor] Refactor data flow: refine `data_preprocessor`. (#359)

* refine data_preprocessor

* remove unused BATCH_DATA alias

* Fix type hints

* rename move_data to cast_data

* refactor model data flow

tmp_commt

tmp commit

* make val_cfg and test_cfg optional

* roll back runner

* pass test mmdet

* fix as comment

fix as comment

fix ci in DataPreprocessor

* fix ut

* fix ut

* fix rebase main

* [Fix]: Fix test val ddp (#462)

* [Fix] Fix docstring and type hint of data flow (#463)

* Fix docstring of data flow

* change signature of hook

* fix unit test

* resolve conflicts

* fix lint
2022-08-24 22:04:55 +08:00

83 lines
3.2 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from typing import Optional, Sequence, Tuple, Union

import cv2
import numpy as np

from mmengine.hooks import Hook
from mmengine.registry import HOOKS
from mmengine.utils.dl_utils import tensor2imgs

DATA_BATCH = Optional[Union[dict, tuple, list]]


# TODO: Due to interface changes, the current class
# functions incorrectly
@HOOKS.register_module()
class NaiveVisualizationHook(Hook):
    """Show or write the predicted results during the process of testing.

    Args:
        interval (int): Visualization interval. Defaults to 1.
        draw_gt (bool): Whether to draw the ground truth. Defaults to True.
        draw_pred (bool): Whether to draw the predicted result.
            Defaults to True.
    """
    priority = 'NORMAL'

    def __init__(self,
                 interval: int = 1,
                 draw_gt: bool = True,
                 draw_pred: bool = True):
        self.draw_gt = draw_gt
        self.draw_pred = draw_pred
        self._interval = interval

    def _unpad(self, input: np.ndarray, unpad_shape: Tuple[int,
                                                           int]) -> np.ndarray:
        """Unpad the input image.

        Args:
            input (np.ndarray): The image to unpad.
            unpad_shape (tuple): The shape of the image before padding,
                given as (width, height).

        Returns:
            np.ndarray: The image before padding.
        """
        unpad_width, unpad_height = unpad_shape
        # Rows are indexed first, so height bounds axis 0 and width axis 1.
        unpad_image = input[:unpad_height, :unpad_width]
        return unpad_image

    def after_test_iter(self,
                        runner,
                        batch_idx: int,
                        data_batch: DATA_BATCH = None,
                        outputs: Optional[Sequence] = None) -> None:
        """Show or write the predicted results.

        Args:
            runner (Runner): The runner of the testing process.
            batch_idx (int): The index of the current batch in the test loop.
            data_batch (dict or tuple or list, optional): Data from dataloader.
            outputs (Sequence, optional): Outputs from model.
        """
        if self.every_n_inner_iters(batch_idx, self._interval):
            for data, output in zip(data_batch, outputs):  # type: ignore
                input = data['inputs']
                data_sample = data['data_sample']
                input = tensor2imgs(input,
                                    **data_sample.get('img_norm_cfg',
                                                      dict()))[0]
                # TODO We will implement a function to revert the augmentation
                # in the future.
                ori_shape = (data_sample.ori_width, data_sample.ori_height)
                if 'pad_shape' in data_sample:
                    input = self._unpad(input,
                                        data_sample.get('scale', ori_shape))
                origin_image = cv2.resize(input, ori_shape)
                name = osp.basename(data_sample.img_path)
                runner.visualizer.add_datasample(name, origin_image,
                                                 data_sample, output,
                                                 self.draw_gt, self.draw_pred)
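
The two pieces of logic the hook relies on, the interval check and the `(width, height)` crop in `_unpad`, can be sketched without any dependencies. This is a minimal illustration, not the library's code: `should_visualize` is a hypothetical stand-in for `every_n_inner_iters`, and it assumes the check fires on every `interval`-th batch counting from 1. Nested lists stand in for the `np.ndarray` image.

```python
from typing import List, Tuple

Image = List[List[int]]  # rows of pixel values; stand-in for an H x W array


def should_visualize(batch_idx: int, interval: int) -> bool:
    """Hypothetical stand-in for the interval check used by the hook."""
    # Assumption: fire on every `interval`-th batch, counting from 1.
    return interval > 0 and (batch_idx + 1) % interval == 0


def unpad(image: Image, unpad_shape: Tuple[int, int]) -> Image:
    """Crop padding from the bottom and right edges of an image."""
    unpad_width, unpad_height = unpad_shape  # note: (width, height) order
    # Rows are indexed first, so height limits rows and width limits columns.
    return [row[:unpad_width] for row in image[:unpad_height]]


# A 3 x 4 image (3 rows, 4 columns) whose last row and column are padding.
padded = [
    [1, 2, 3, 0],
    [4, 5, 6, 0],
    [0, 0, 0, 0],
]
cropped = unpad(padded, (3, 2))  # width 3, height 2
visualized = [i for i in range(6) if should_visualize(i, interval=2)]
```

The point to notice is the argument order: the shape is passed as `(width, height)`, which matches what `cv2.resize` expects for its target size, while the slicing applies height to rows first and width to columns second.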