[Feature] NaiveVisualizationHook (#98)

* [WIP] testvisualizationhook

* add TestNaiveVisualizationHook

* fix comment

* unpad

* batch imdenormalize

* fix comment

* fix comment
This commit is contained in:
liukuikun 2022-03-10 17:22:31 +08:00 committed by GitHub
parent 02ceaedb82
commit 3e0c064f49
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 202 additions and 1 deletions

View File

@ -4,6 +4,7 @@ from .empty_cache_hook import EmptyCacheHook
from .hook import Hook
from .iter_timer_hook import IterTimerHook
from .logger_hook import LoggerHook
from .naive_visualization_hook import NaiveVisualizationHook
from .optimizer_hook import OptimizerHook
from .param_scheduler_hook import ParamSchedulerHook
from .sampler_seed_hook import DistSamplerSeedHook
@ -12,5 +13,5 @@ from .sync_buffer_hook import SyncBuffersHook
__all__ = [
'Hook', 'IterTimerHook', 'DistSamplerSeedHook', 'ParamSchedulerHook',
'OptimizerHook', 'SyncBuffersHook', 'EmptyCacheHook', 'CheckpointHook',
'LoggerHook'
'LoggerHook', 'NaiveVisualizationHook'
]

View File

@ -0,0 +1,71 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from typing import Any, Optional, Sequence, Tuple
import cv2
import numpy as np
from mmengine.data import BaseDataSample
from mmengine.hooks import Hook
from mmengine.registry import HOOKS
from mmengine.utils.misc import tensor2imgs
@HOOKS.register_module()
class NaiveVisualizationHook(Hook):
"""Show or Write the predicted results during the process of testing.
Args:
interval (int): Visualization interval. Default: 1.
draw_gt (bool): Whether to draw the ground truth. Default to True.
draw_pred (bool): Whether to draw the predicted result.
Default to True.
"""
priority = 'NORMAL'
def __init__(self,
interval: int = 1,
draw_gt: bool = True,
draw_pred: bool = True):
self.draw_gt = draw_gt
self.draw_pred = draw_pred
self._interval = interval
def _unpad(self, input: np.ndarray, unpad_shape: Tuple[int,
int]) -> np.ndarray:
unpad_width, unpad_height = unpad_shape
unpad_image = input[:unpad_height, :unpad_width]
return unpad_image
def after_test_iter(
self,
runner,
data_batch: Optional[Sequence[Tuple[Any, BaseDataSample]]] = None,
outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
"""Show or Write the predicted results.
Args:
runner (Runner): The runner of the training process.
data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
from dataloader. Defaults to None.
outputs (Sequence[BaseDataSample], optional): Outputs from model.
Defaults to None.
"""
if self.every_n_iters(runner, self._interval):
inputs, data_samples = data_batch # type: ignore
inputs = tensor2imgs(inputs,
**data_samples[0].get('img_norm_cfg', dict()))
for input, data_sample, output in zip(
inputs,
data_samples, # type: ignore
outputs): # type: ignore
# TODO We will implement a function to revert the augmentation
# in the future.
ori_shape = (data_sample.ori_width, data_sample.ori_height)
if 'pad_shape' in data_sample:
input = self._unpad(input,
data_sample.get('scale', ori_shape))
origin_image = cv2.resize(input, ori_shape)
name = osp.basename(data_sample.img_path)
runner.writer.add_image(name, origin_image, data_sample,
output, self.draw_gt, self.draw_pred)

View File

@ -11,6 +11,8 @@ from inspect import getfullargspec
from itertools import repeat
from typing import Any, Callable, Optional, Sequence, Tuple, Type, Union
import numpy as np
import torch
import torch.nn as nn
from .parrots_wrapper import _BatchNorm, _InstanceNorm
@ -433,3 +435,46 @@ def is_norm(layer: nn.Module,
all_norm_bases = (_BatchNorm, _InstanceNorm, nn.GroupNorm, nn.LayerNorm)
return isinstance(layer, all_norm_bases)
def tensor2imgs(tensor: torch.Tensor,
mean: Optional[Tuple[float, float, float]] = None,
std: Optional[Tuple[float, float, float]] = None,
to_bgr: bool = True):
"""Convert tensor to 3-channel images or 1-channel gray images.
Args:
tensor (torch.Tensor): Tensor that contains multiple images, shape (
N, C, H, W). :math:`C` can be either 3 or 1. If C is 3, the format
should be RGB.
mean (tuple[float], optional): Mean of images. If None,
(0, 0, 0) will be used for tensor with 3-channel,
while (0, ) for tensor with 1-channel. Defaults to None.
std (tuple[float], optional): Standard deviation of images. If None,
(1, 1, 1) will be used for tensor with 3-channel,
while (1, ) for tensor with 1-channel. Defaults to None.
to_bgr (bool): For the tensor with 3 channel, convert its format to
BGR. For the tensor with 1 channel, it must be False. Defaults to
True.
Returns:
list[np.ndarray]: A list that contains multiple images.
"""
assert torch.is_tensor(tensor) and tensor.ndim == 4
channels = tensor.size(1)
assert channels in [1, 3]
if mean is None:
mean = (0, ) * channels
if std is None:
std = (1, ) * channels
assert (channels == len(mean) == len(std) == 3) or \
(channels == len(mean) == len(std) == 1 and not to_bgr)
mean = tensor.new_tensor(mean).view(1, -1)
std = tensor.new_tensor(std).view(1, -1)
tensor = tensor.permute(0, 2, 3, 1) * std + mean
imgs = tensor.detach().cpu().numpy()
if to_bgr and channels == 3:
imgs = imgs[:, :, :, (2, 1, 0)] # RGB2BGR
imgs = [np.ascontiguousarray(img) for img in imgs]
return imgs

View File

@ -0,0 +1,84 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest.mock import Mock
import torch
from mmengine.data import BaseDataSample
from mmengine.hooks import NaiveVisualizationHook
class TestNaiveVisualizationHook:
def test_after_train_iter(self):
naive_visualization_hook = NaiveVisualizationHook()
Runner = Mock(iter=1)
Runner.writer.add_image = Mock()
inputs = torch.randn(1, 3, 15, 15)
# test with normalize, resize, pad
gt_datasamples = [
BaseDataSample(
metainfo=dict(
img_norm_cfg=dict(
mean=(0, 0, 0), std=(0.5, 0.5, 0.5), to_bgr=True),
scale=(10, 10),
pad_shape=(15, 15, 3),
ori_height=5,
ori_width=5,
img_path='tmp.jpg'))
]
pred_datasamples = [BaseDataSample()]
data_batch = (inputs, gt_datasamples)
naive_visualization_hook.after_test_iter(Runner, data_batch,
pred_datasamples)
# test with resize, pad
gt_datasamples = [
BaseDataSample(
metainfo=dict(
scale=(10, 10),
pad_shape=(15, 15, 3),
ori_height=5,
ori_width=5,
img_path='tmp.jpg')),
]
pred_datasamples = [BaseDataSample()]
data_batch = (inputs, gt_datasamples)
naive_visualization_hook.after_test_iter(Runner, data_batch,
pred_datasamples)
# test with only resize
gt_datasamples = [
BaseDataSample(
metainfo=dict(
scale=(15, 15),
ori_height=5,
ori_width=5,
img_path='tmp.jpg')),
]
pred_datasamples = [BaseDataSample()]
data_batch = (inputs, gt_datasamples)
naive_visualization_hook.after_test_iter(Runner, data_batch,
pred_datasamples)
# test with only pad
gt_datasamples = [
BaseDataSample(
metainfo=dict(
pad_shape=(15, 15, 3),
ori_height=5,
ori_width=5,
img_path='tmp.jpg')),
]
pred_datasamples = [BaseDataSample()]
data_batch = (inputs, gt_datasamples)
naive_visualization_hook.after_test_iter(Runner, data_batch,
pred_datasamples)
# test no transform
gt_datasamples = [
BaseDataSample(
metainfo=dict(ori_height=15, ori_width=15,
img_path='tmp.jpg')),
]
pred_datasamples = [BaseDataSample()]
data_batch = (inputs, gt_datasamples)
naive_visualization_hook.after_test_iter(Runner, data_batch,
pred_datasamples)