Mirror of https://github.com/open-mmlab/mmengine.git (synced 2025-06-03 21:54:44 +08:00)
[Feature] NaiveVisualizationHook (#98)
* [WIP] testvisualizationhook
* add TestNaiveVisualizationHook
* fix comment
* unpad
* batch imdenormalize
* fix comment
* fix comment
parent 02ceaedb82
commit 3e0c064f49
mmengine/hooks/__init__.py

@@ -4,6 +4,7 @@ from .empty_cache_hook import EmptyCacheHook
 from .hook import Hook
 from .iter_timer_hook import IterTimerHook
 from .logger_hook import LoggerHook
+from .naive_visualization_hook import NaiveVisualizationHook
 from .optimizer_hook import OptimizerHook
 from .param_scheduler_hook import ParamSchedulerHook
 from .sampler_seed_hook import DistSamplerSeedHook
@@ -12,5 +13,5 @@ from .sync_buffer_hook import SyncBuffersHook
 __all__ = [
     'Hook', 'IterTimerHook', 'DistSamplerSeedHook', 'ParamSchedulerHook',
     'OptimizerHook', 'SyncBuffersHook', 'EmptyCacheHook', 'CheckpointHook',
-    'LoggerHook'
+    'LoggerHook', 'NaiveVisualizationHook'
 ]
mmengine/hooks/naive_visualization_hook.py (new file, 71 lines)
@@ -0,0 +1,71 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from typing import Any, Optional, Sequence, Tuple

import cv2
import numpy as np

from mmengine.data import BaseDataSample
from mmengine.hooks import Hook
from mmengine.registry import HOOKS
from mmengine.utils.misc import tensor2imgs


@HOOKS.register_module()
class NaiveVisualizationHook(Hook):
    """Show or write the predicted results during the process of testing.

    Args:
        interval (int): Visualization interval. Defaults to 1.
        draw_gt (bool): Whether to draw the ground truth. Defaults to True.
        draw_pred (bool): Whether to draw the predicted result.
            Defaults to True.
    """
    priority = 'NORMAL'

    def __init__(self,
                 interval: int = 1,
                 draw_gt: bool = True,
                 draw_pred: bool = True):
        self.draw_gt = draw_gt
        self.draw_pred = draw_pred
        self._interval = interval

    def _unpad(self, input: np.ndarray,
               unpad_shape: Tuple[int, int]) -> np.ndarray:
        unpad_width, unpad_height = unpad_shape
        unpad_image = input[:unpad_height, :unpad_width]
        return unpad_image

    def after_test_iter(
            self,
            runner,
            data_batch: Optional[Sequence[Tuple[Any, BaseDataSample]]] = None,
            outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
        """Show or write the predicted results.

        Args:
            runner (Runner): The runner of the testing process.
            data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
                from dataloader. Defaults to None.
            outputs (Sequence[BaseDataSample], optional): Outputs from model.
                Defaults to None.
        """
        if self.every_n_iters(runner, self._interval):
            inputs, data_samples = data_batch  # type: ignore
            inputs = tensor2imgs(inputs,
                                 **data_samples[0].get('img_norm_cfg', dict()))
            for input, data_sample, output in zip(
                    inputs,
                    data_samples,  # type: ignore
                    outputs):  # type: ignore
                # TODO We will implement a function to revert the augmentation
                # in the future.
                ori_shape = (data_sample.ori_width, data_sample.ori_height)
                if 'pad_shape' in data_sample:
                    input = self._unpad(input,
                                        data_sample.get('scale', ori_shape))
                origin_image = cv2.resize(input, ori_shape)
                name = osp.basename(data_sample.img_path)
                runner.writer.add_image(name, origin_image, data_sample,
                                        output, self.draw_gt, self.draw_pred)
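Since the class is registered with @HOOKS.register_module(), it can be built from a config dict through the HOOKS registry or constructed directly. A minimal usage sketch follows; the argument values are illustrative and not part of this commit, and the registry call assumes the usual Registry.build(cfg) interface that dispatches on the 'type' key.

from mmengine.registry import HOOKS
from mmengine.hooks import NaiveVisualizationHook

# Build from a config dict via the registry (assumed Registry.build
# interface; 'type' selects the registered class).
vis_hook = HOOKS.build(dict(type='NaiveVisualizationHook', interval=10))

# Or construct it directly; draw_gt / draw_pred control what is drawn.
vis_hook = NaiveVisualizationHook(interval=1, draw_gt=True, draw_pred=False)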
mmengine/utils/misc.py

@@ -11,6 +11,8 @@ from inspect import getfullargspec
from itertools import repeat
from typing import Any, Callable, Optional, Sequence, Tuple, Type, Union

import numpy as np
import torch
import torch.nn as nn

from .parrots_wrapper import _BatchNorm, _InstanceNorm

@@ -433,3 +435,46 @@ def is_norm(layer: nn.Module,

    all_norm_bases = (_BatchNorm, _InstanceNorm, nn.GroupNorm, nn.LayerNorm)
    return isinstance(layer, all_norm_bases)


def tensor2imgs(tensor: torch.Tensor,
                mean: Optional[Tuple[float, float, float]] = None,
                std: Optional[Tuple[float, float, float]] = None,
                to_bgr: bool = True):
    """Convert tensor to 3-channel images or 1-channel gray images.

    Args:
        tensor (torch.Tensor): Tensor that contains multiple images, shape (
            N, C, H, W). :math:`C` can be either 3 or 1. If C is 3, the format
            should be RGB.
        mean (tuple[float], optional): Mean of images. If None,
            (0, 0, 0) will be used for a 3-channel tensor,
            while (0, ) will be used for a 1-channel tensor. Defaults to None.
        std (tuple[float], optional): Standard deviation of images. If None,
            (1, 1, 1) will be used for a 3-channel tensor,
            while (1, ) will be used for a 1-channel tensor. Defaults to None.
        to_bgr (bool): For a tensor with 3 channels, convert its format to
            BGR. For a tensor with 1 channel, it must be False. Defaults to
            True.

    Returns:
        list[np.ndarray]: A list that contains multiple images.
    """

    assert torch.is_tensor(tensor) and tensor.ndim == 4
    channels = tensor.size(1)
    assert channels in [1, 3]
    if mean is None:
        mean = (0, ) * channels
    if std is None:
        std = (1, ) * channels
    assert (channels == len(mean) == len(std) == 3) or \
        (channels == len(mean) == len(std) == 1 and not to_bgr)
    mean = tensor.new_tensor(mean).view(1, -1)
    std = tensor.new_tensor(std).view(1, -1)
    tensor = tensor.permute(0, 2, 3, 1) * std + mean
    imgs = tensor.detach().cpu().numpy()
    if to_bgr and channels == 3:
        imgs = imgs[:, :, :, (2, 1, 0)]  # RGB2BGR
    imgs = [np.ascontiguousarray(img) for img in imgs]
    return imgs
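A short sketch of how tensor2imgs is used (the mean/std values below are illustrative, not taken from this commit): a normalized (N, C, H, W) RGB tensor is de-normalized channel-wise and returned as a list of H x W x C numpy arrays, converted to BGR by default so they can be passed straight to OpenCV functions.

import torch

from mmengine.utils.misc import tensor2imgs

# A batch of two normalized RGB images, shape (N, C, H, W).
batch = torch.randn(2, 3, 32, 32)

# De-normalize with the same mean/std used during preprocessing
# (illustrative values) and convert RGB -> BGR for OpenCV.
imgs = tensor2imgs(batch,
                   mean=(123.675, 116.28, 103.53),
                   std=(58.395, 57.12, 57.375),
                   to_bgr=True)

assert len(imgs) == 2
assert imgs[0].shape == (32, 32, 3)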
tests/test_hook/test_naive_visualization_hook.py (new file, 84 lines)
@@ -0,0 +1,84 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest.mock import Mock

import torch

from mmengine.data import BaseDataSample
from mmengine.hooks import NaiveVisualizationHook


class TestNaiveVisualizationHook:

    def test_after_train_iter(self):
        naive_visualization_hook = NaiveVisualizationHook()
        Runner = Mock(iter=1)
        Runner.writer.add_image = Mock()
        inputs = torch.randn(1, 3, 15, 15)
        # test with normalize, resize, pad
        gt_datasamples = [
            BaseDataSample(
                metainfo=dict(
                    img_norm_cfg=dict(
                        mean=(0, 0, 0), std=(0.5, 0.5, 0.5), to_bgr=True),
                    scale=(10, 10),
                    pad_shape=(15, 15, 3),
                    ori_height=5,
                    ori_width=5,
                    img_path='tmp.jpg'))
        ]
        pred_datasamples = [BaseDataSample()]
        data_batch = (inputs, gt_datasamples)
        naive_visualization_hook.after_test_iter(Runner, data_batch,
                                                 pred_datasamples)
        # test with resize, pad
        gt_datasamples = [
            BaseDataSample(
                metainfo=dict(
                    scale=(10, 10),
                    pad_shape=(15, 15, 3),
                    ori_height=5,
                    ori_width=5,
                    img_path='tmp.jpg')),
        ]
        pred_datasamples = [BaseDataSample()]
        data_batch = (inputs, gt_datasamples)
        naive_visualization_hook.after_test_iter(Runner, data_batch,
                                                 pred_datasamples)
        # test with only resize
        gt_datasamples = [
            BaseDataSample(
                metainfo=dict(
                    scale=(15, 15),
                    ori_height=5,
                    ori_width=5,
                    img_path='tmp.jpg')),
        ]
        pred_datasamples = [BaseDataSample()]
        data_batch = (inputs, gt_datasamples)
        naive_visualization_hook.after_test_iter(Runner, data_batch,
                                                 pred_datasamples)

        # test with only pad
        gt_datasamples = [
            BaseDataSample(
                metainfo=dict(
                    pad_shape=(15, 15, 3),
                    ori_height=5,
                    ori_width=5,
                    img_path='tmp.jpg')),
        ]
        pred_datasamples = [BaseDataSample()]
        data_batch = (inputs, gt_datasamples)
        naive_visualization_hook.after_test_iter(Runner, data_batch,
                                                 pred_datasamples)

        # test no transform
        gt_datasamples = [
            BaseDataSample(
                metainfo=dict(ori_height=15, ori_width=15,
                              img_path='tmp.jpg')),
        ]
        pred_datasamples = [BaseDataSample()]
        data_batch = (inputs, gt_datasamples)
        naive_visualization_hook.after_test_iter(Runner, data_batch,
                                                 pred_datasamples)
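Because the runner's writer is a Mock in these tests, a follow-up assertion could also verify that the hook forwarded an image to the writer. This is a sketch and not part of the committed test:

# After any of the after_test_iter calls above, the mocked writer
# should have been invoked by the hook.
Runner.writer.add_image.assert_called()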