mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
* [Refactor]: modify interface of Visualizer.add_datasample (#365) * [Refactor] Refactor data flow: refine `data_preprocessor`. (#359) * refine data_preprocessor * remove unused BATCH_DATA alias * Fix type hints * rename move_data to cast_data * [Refactor] Refactor data flow: collate data in `collate_fn` of `DataLoader` (#323) * collate data in dataloader * fix docstring * refine comment * fix as comment * refactor default collate and pseudo collate * format test file * fix docstring * fix as comment * rename elem to data_item * minor fix * fix as comment * [Refactor] Refactor data flow: `data_batch` argument of `Evaluator.process` is a `dict` (#360) * refine evaluator and metric * compatible with new default collate * replace default collate with pseudo * Handle data_batch in metric * fix unit test * fix unit test * fix unit test * minor refine * make data_batch optional make data_batch optional * rename outputs to predictions * fix ut * rename predictions to outputs * fix docstring * fix docstring * fix unit test * make outputs and data_batch to kwargs * fix unit test * keep signature of metric * fix ut * rename pred_sample arguments to data_sample(Visualizer) * fix loop and ut * [refactor]: Refactor model dataflow (#398) * [Refactor] Refactor data flow: refine `data_preprocessor`. (#359) * refine data_preprocessor * remove unused BATCH_DATA alias * Fix type hints * rename move_data to cast_data * refactor model data flow tmp_commt tmp commit * make val_cfg and test_cfg optional * roll back runner * pass test mmdet * fix as comment fix as comment fix ci in DataPreprocessor * fix ut * fix ut * fix rebase main * [Fix]: Fix test val ddp (#462) * [Fix] Fix docstring and type hint of data flow (#463) * Fix docstring of data flow * change signature of hook * fix unit test * resolve conflicts * fix lint
83 lines
3.2 KiB
Python
83 lines
3.2 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import os.path as osp
|
|
from typing import Optional, Sequence, Tuple, Union
|
|
|
|
import cv2
|
|
import numpy as np
|
|
|
|
from mmengine.hooks import Hook
|
|
from mmengine.registry import HOOKS
|
|
from mmengine.utils.dl_utils import tensor2imgs
|
|
|
|
# Accepted types for the ``data_batch`` argument of the hook callbacks below:
# a dict, tuple or list holding the batch, or ``None`` when none is given.
DATA_BATCH = Union[dict, tuple, list, None]
|
|
|
|
|
|
# TODO: The data flow and ``Visualizer.add_datasample`` interfaces were
# refactored; this hook has been adapted on a best-effort basis and still
# needs end-to-end verification against them.
@HOOKS.register_module()
class NaiveVisualizationHook(Hook):
    """Show or write the predicted results during testing.

    Every ``interval`` test iterations, each sample of the batch is converted
    back to an image with ``tensor2imgs``, unpadded (when the sample carries
    ``pad_shape``), resized to its original size and handed to
    ``runner.visualizer.add_datasample``.

    Args:
        interval (int): Visualization interval. Defaults to 1.
        draw_gt (bool): Whether to draw the ground truth. Defaults to True.
        draw_pred (bool): Whether to draw the predicted result.
            Defaults to True.
    """
    priority = 'NORMAL'

    def __init__(self,
                 interval: int = 1,
                 draw_gt: bool = True,
                 draw_pred: bool = True):
        self.draw_gt = draw_gt
        self.draw_pred = draw_pred
        self._interval = interval

    def _unpad(self, input: np.ndarray,
               unpad_shape: Tuple[int, int]) -> np.ndarray:
        """Unpad the input image.

        Args:
            input (np.ndarray): The image to unpad, indexed as
                ``[row, column, ...]``.
            unpad_shape (tuple): The ``(width, height)`` of the image before
                padding.

        Returns:
            np.ndarray: The top-left ``height x width`` region of ``input``,
            i.e. the image before padding.
        """
        unpad_width, unpad_height = unpad_shape
        return input[:unpad_height, :unpad_width]

    @staticmethod
    def _split_batch(
            data_batch: Union[dict, tuple, list]) -> Sequence[dict]:
        """Return ``data_batch`` as a sequence of per-sample dicts.

        A list or tuple is assumed to already hold one dict per sample and is
        returned unchanged. A dict is assumed to map every key to a batched
        sequence with one entry per sample, and is transposed, e.g.
        ``{'inputs': [a, b], 'data_sample': [x, y]}`` becomes
        ``[{'inputs': a, 'data_sample': x}, {'inputs': b, 'data_sample': y}]``.
        Iterating a dict directly would only yield its keys.
        """
        if isinstance(data_batch, dict):
            keys = list(data_batch)
            return [dict(zip(keys, values))
                    for values in zip(*data_batch.values())]
        return data_batch

    def after_test_iter(self,
                        runner,
                        batch_idx: int,
                        data_batch: DATA_BATCH = None,
                        outputs: Optional[Sequence] = None) -> None:
        """Show or write the predicted results.

        Args:
            runner (Runner): The runner of the testing process.
            batch_idx (int): The index of the current batch in the test loop.
            data_batch (dict or tuple or list, optional): Data from
                dataloader, either a sequence of per-sample dicts or a dict
                of batched values. Each sample must provide ``'inputs'`` and
                ``'data_sample'``.
            outputs (Sequence, optional): Outputs from model, one per sample.
        """
        if not self.every_n_inner_iters(batch_idx, self._interval):
            return
        # Both arguments are optional in the hook signature; without them
        # there is nothing to draw (``zip(None, ...)`` would raise).
        if data_batch is None or outputs is None:
            return

        for data, output in zip(self._split_batch(data_batch), outputs):
            data_sample = data['data_sample']
            image = tensor2imgs(data['inputs'],
                                **data_sample.get('img_norm_cfg', dict()))[0]
            # TODO We will implement a function to revert the augmentation
            # in the future.
            # (width, height) order: used as ``dsize`` by ``cv2.resize`` and
            # unpacked the same way by ``_unpad``.
            ori_shape = (data_sample.ori_width, data_sample.ori_height)
            if 'pad_shape' in data_sample:
                image = self._unpad(image,
                                    data_sample.get('scale', ori_shape))
            origin_image = cv2.resize(image, ori_shape)
            name = osp.basename(data_sample.img_path)
            # NOTE(review): assumes the post-#365 signature
            # ``add_datasample(name, image, data_sample=None, draw_gt=True,
            # draw_pred=True, ...)`` where ``data_sample`` replaced
            # ``pred_sample``; keywords keep ``output`` and the draw flags
            # from shifting into the wrong parameters. Drawing ground truth
            # assumes the model output also carries it — confirm.
            runner.visualizer.add_datasample(
                name,
                origin_image,
                data_sample=output,
                draw_gt=self.draw_gt,
                draw_pred=self.draw_pred)
|