mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Feature] NaiveVisualizationHook (#98)
* [WIP] testvisualizationhook * add TestNaiveVisualizationHook * fix comment * unpad * batch imdenormalize * fix comment * fix comment
This commit is contained in:
parent
02ceaedb82
commit
3e0c064f49
@ -4,6 +4,7 @@ from .empty_cache_hook import EmptyCacheHook
|
|||||||
from .hook import Hook
|
from .hook import Hook
|
||||||
from .iter_timer_hook import IterTimerHook
|
from .iter_timer_hook import IterTimerHook
|
||||||
from .logger_hook import LoggerHook
|
from .logger_hook import LoggerHook
|
||||||
|
from .naive_visualization_hook import NaiveVisualizationHook
|
||||||
from .optimizer_hook import OptimizerHook
|
from .optimizer_hook import OptimizerHook
|
||||||
from .param_scheduler_hook import ParamSchedulerHook
|
from .param_scheduler_hook import ParamSchedulerHook
|
||||||
from .sampler_seed_hook import DistSamplerSeedHook
|
from .sampler_seed_hook import DistSamplerSeedHook
|
||||||
@ -12,5 +13,5 @@ from .sync_buffer_hook import SyncBuffersHook
|
|||||||
__all__ = [
|
__all__ = [
|
||||||
'Hook', 'IterTimerHook', 'DistSamplerSeedHook', 'ParamSchedulerHook',
|
'Hook', 'IterTimerHook', 'DistSamplerSeedHook', 'ParamSchedulerHook',
|
||||||
'OptimizerHook', 'SyncBuffersHook', 'EmptyCacheHook', 'CheckpointHook',
|
'OptimizerHook', 'SyncBuffersHook', 'EmptyCacheHook', 'CheckpointHook',
|
||||||
'LoggerHook'
|
'LoggerHook', 'NaiveVisualizationHook'
|
||||||
]
|
]
|
||||||
|
71
mmengine/hooks/naive_visualization_hook.py
Normal file
71
mmengine/hooks/naive_visualization_hook.py
Normal file
@ -0,0 +1,71 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import os.path as osp
|
||||||
|
from typing import Any, Optional, Sequence, Tuple
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from mmengine.data import BaseDataSample
|
||||||
|
from mmengine.hooks import Hook
|
||||||
|
from mmengine.registry import HOOKS
|
||||||
|
from mmengine.utils.misc import tensor2imgs
|
||||||
|
|
||||||
|
|
||||||
|
@HOOKS.register_module()
class NaiveVisualizationHook(Hook):
    """Show or Write the predicted results during the process of testing.

    Args:
        interval (int): Visualization interval. Default: 1.
        draw_gt (bool): Whether to draw the ground truth. Default to True.
        draw_pred (bool): Whether to draw the predicted result.
            Default to True.
    """
    priority = 'NORMAL'

    def __init__(self,
                 interval: int = 1,
                 draw_gt: bool = True,
                 draw_pred: bool = True):
        self.draw_gt = draw_gt
        self.draw_pred = draw_pred
        self._interval = interval

    def _unpad(self, input: np.ndarray,
               unpad_shape: Tuple[int, int]) -> np.ndarray:
        """Crop an image to its top-left ``unpad_shape`` region.

        Args:
            input (np.ndarray): Image indexed as (row, column, ...).
            unpad_shape (Tuple[int, int]): Size to keep, given as
                ``(width, height)``.

        Returns:
            np.ndarray: The first ``height`` rows and first ``width``
            columns of ``input``.
        """
        unpad_width, unpad_height = unpad_shape
        # Rows come first in the array, hence height before width.
        return input[:unpad_height, :unpad_width]

    def after_test_iter(
            self,
            runner,
            data_batch: Optional[Sequence[Tuple[Any, BaseDataSample]]] = None,
            outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
        """Show or Write the predicted results.

        Nothing is done unless ``every_n_iters`` reports that the current
        iteration matches the configured interval, and both ``data_batch``
        and ``outputs`` are provided.

        Args:
            runner (Runner): The runner of the testing process.
            data_batch (Sequence[Tuple[Any, BaseDataSample]], optional): Data
                from dataloader. Defaults to None.
            outputs (Sequence[BaseDataSample], optional): Outputs from model.
                Defaults to None.
        """
        if not self.every_n_iters(runner, self._interval):
            return
        # Both arguments default to None; without them there is nothing to
        # draw, so skip instead of failing while unpacking None.
        if data_batch is None or outputs is None:
            return

        inputs, data_samples = data_batch  # type: ignore
        # The normalization settings of the first sample are applied to the
        # whole batch; an absent 'img_norm_cfg' falls back to the defaults of
        # tensor2imgs.
        images = tensor2imgs(inputs,
                             **data_samples[0].get('img_norm_cfg', dict()))
        for image, data_sample, output in zip(
                images,
                data_samples,  # type: ignore
                outputs):  # type: ignore
            # TODO We will implement a function to revert the augmentation
            # in the future.
            ori_shape = (data_sample.ori_width, data_sample.ori_height)
            if 'pad_shape' in data_sample:
                # Drop the padding: crop to the recorded 'scale', or to the
                # original size when no 'scale' was recorded.
                image = self._unpad(image,
                                    data_sample.get('scale', ori_shape))
            # ori_shape is (width, height), the order cv2.resize takes.
            origin_image = cv2.resize(image, ori_shape)
            name = osp.basename(data_sample.img_path)
            runner.writer.add_image(name, origin_image, data_sample, output,
                                    self.draw_gt, self.draw_pred)
|
@ -11,6 +11,8 @@ from inspect import getfullargspec
|
|||||||
from itertools import repeat
|
from itertools import repeat
|
||||||
from typing import Any, Callable, Optional, Sequence, Tuple, Type, Union
|
from typing import Any, Callable, Optional, Sequence, Tuple, Type, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from .parrots_wrapper import _BatchNorm, _InstanceNorm
|
from .parrots_wrapper import _BatchNorm, _InstanceNorm
|
||||||
@ -433,3 +435,46 @@ def is_norm(layer: nn.Module,
|
|||||||
|
|
||||||
all_norm_bases = (_BatchNorm, _InstanceNorm, nn.GroupNorm, nn.LayerNorm)
|
all_norm_bases = (_BatchNorm, _InstanceNorm, nn.GroupNorm, nn.LayerNorm)
|
||||||
return isinstance(layer, all_norm_bases)
|
return isinstance(layer, all_norm_bases)
|
||||||
|
|
||||||
|
|
||||||
|
def tensor2imgs(tensor: torch.Tensor,
                mean: Optional[Sequence[float]] = None,
                std: Optional[Sequence[float]] = None,
                to_bgr: bool = True) -> list:
    """Convert tensor to 3-channel images or 1-channel gray images.

    Each image is denormalized as ``pixel * std + mean`` per channel. The
    values are neither clipped nor cast, so the returned arrays keep the
    numeric type of the computation.

    Args:
        tensor (torch.Tensor): Tensor that contains multiple images, shape (
            N, C, H, W). :math:`C` can be either 3 or 1. If C is 3, the format
            should be RGB.
        mean (Sequence[float], optional): Mean of images. If None,
            (0, 0, 0) will be used for tensor with 3-channel,
            while (0, ) for tensor with 1-channel. Defaults to None.
        std (Sequence[float], optional): Standard deviation of images. If
            None, (1, 1, 1) will be used for tensor with 3-channel,
            while (1, ) for tensor with 1-channel. Defaults to None.
        to_bgr (bool): For the tensor with 3 channel, convert its format to
            BGR. For the tensor with 1 channel, it must be False. Defaults to
            True.

    Returns:
        list[np.ndarray]: A list that contains multiple images, each of
        shape (H, W, C) and C-contiguous.

    Raises:
        AssertionError: If ``tensor`` is not a 4-dimensional tensor, has a
            channel count other than 1 or 3, ``mean``/``std`` do not have
            one entry per channel, or ``to_bgr`` is True for a 1-channel
            tensor.
    """
    # The checks raise AssertionError explicitly rather than using ``assert``
    # so they still run under ``python -O`` while keeping the exception type
    # callers may already rely on.
    if not (torch.is_tensor(tensor) and tensor.ndim == 4):
        raise AssertionError(
            '`tensor` must be a torch.Tensor of shape (N, C, H, W)')
    channels = tensor.size(1)
    if channels not in (1, 3):
        raise AssertionError(
            f'`tensor` must have 1 or 3 channels, but got {channels}')

    if mean is None:
        mean = (0, ) * channels
    if std is None:
        std = (1, ) * channels
    if not len(mean) == len(std) == channels:
        raise AssertionError(
            f'`mean` and `std` must each have {channels} value(s), but got '
            f'{len(mean)} and {len(std)}')
    if channels == 1 and to_bgr:
        raise AssertionError(
            '`to_bgr` must be False for a tensor with 1 channel')

    # Shape (1, C) broadcasts over the trailing channel axis after the
    # permutation to (N, H, W, C).
    mean = tensor.new_tensor(mean).view(1, -1)
    std = tensor.new_tensor(std).view(1, -1)
    tensor = tensor.permute(0, 2, 3, 1) * std + mean
    imgs = tensor.detach().cpu().numpy()
    if to_bgr and channels == 3:
        imgs = imgs[..., ::-1]  # RGB2BGR
    # Reversing the channel axis yields a strided view; hand out contiguous
    # per-image arrays.
    return [np.ascontiguousarray(img) for img in imgs]
|
||||||
|
84
tests/test_hook/test_naive_visualization_hook.py
Normal file
84
tests/test_hook/test_naive_visualization_hook.py
Normal file
@ -0,0 +1,84 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
from unittest.mock import Mock
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from mmengine.data import BaseDataSample
|
||||||
|
from mmengine.hooks import NaiveVisualizationHook
|
||||||
|
|
||||||
|
|
||||||
|
class TestNaiveVisualizationHook:
    """Smoke test for ``NaiveVisualizationHook.after_test_iter``."""

    def test_after_train_iter(self):
        """Call ``after_test_iter`` once per metainfo scenario.

        The runner and its ``writer.add_image`` are mocks and nothing is
        asserted on them, so this only verifies that no scenario raises.
        """
        hook = NaiveVisualizationHook()
        runner = Mock(iter=1)
        runner.writer.add_image = Mock()
        inputs = torch.randn(1, 3, 15, 15)

        metainfo_cases = (
            # test with normalize, resize, pad
            dict(
                img_norm_cfg=dict(
                    mean=(0, 0, 0), std=(0.5, 0.5, 0.5), to_bgr=True),
                scale=(10, 10),
                pad_shape=(15, 15, 3),
                ori_height=5,
                ori_width=5,
                img_path='tmp.jpg'),
            # test with resize, pad
            dict(
                scale=(10, 10),
                pad_shape=(15, 15, 3),
                ori_height=5,
                ori_width=5,
                img_path='tmp.jpg'),
            # test with only resize
            dict(
                scale=(15, 15),
                ori_height=5,
                ori_width=5,
                img_path='tmp.jpg'),
            # test with only pad
            dict(
                pad_shape=(15, 15, 3),
                ori_height=5,
                ori_width=5,
                img_path='tmp.jpg'),
            # test no transform
            dict(ori_height=15, ori_width=15, img_path='tmp.jpg'),
        )

        for metainfo in metainfo_cases:
            gt_datasamples = [BaseDataSample(metainfo=metainfo)]
            pred_datasamples = [BaseDataSample()]
            data_batch = (inputs, gt_datasamples)
            hook.after_test_iter(runner, data_batch, pred_datasamples)
|
Loading…
x
Reference in New Issue
Block a user