diff --git a/docs/zh_cn/tutorials/visualization.md b/docs/zh_cn/tutorials/visualization.md
deleted file mode 100644
index 80acabcc..00000000
--- a/docs/zh_cn/tutorials/visualization.md
+++ /dev/null
@@ -1,301 +0,0 @@
-# Visualization
-
-## Overview
-
-**(1) General introduction**
-
-Visualization provides an intuitive view of the training and testing process of deep learning models. In the OpenMMLab codebases, we expect the visualization design to satisfy the following requirements:
-
-- Provide rich out-of-the-box visualization features that cover most computer vision visualization tasks
-- High extensibility: visualization needs vary widely, so custom requirements should be achievable through simple extensions
-- Support visualization at any point of the training and testing pipeline
-- A unified visualization interface across all OpenMMLab codebases, which makes the code easier to understand and maintain
-
-Based on these requirements, OpenMMLab 2.0 introduces the concepts of the drawing object Visualizer and the writer object Writer:
-
-- **The Visualizer is responsible for drawing on a single image**
-
-  MMEngine provides a `Visualizer` class that uses Matplotlib as the drawing backend. It has the following capabilities:
-
-  - A set of task-agnostic basic drawing methods, such as `draw_bboxes` and `draw_texts`
-  - All basic methods support chained calls, which makes it easy to overlay drawings
-  - Feature-map visualization via `draw_featmap`
-
-  Each downstream codebase can inherit `Visualizer` and implement the required visualization in the `draw` interface. For example, `DetVisualizer` in MMDetection inherits from `Visualizer` and implements drawing of detection boxes, instance masks and semantic segmentation maps in its `draw` interface. The UML diagram of the Visualizer class is shown below.
-
-
-<!-- figure: UML diagram of the Visualizer class -->
-
-- **The Writer is responsible for writing various types of data to a specified backend**
-
-  To unify the calling interface, MMEngine provides the abstract base class `BaseWriter` together with several commonly used Writers: `LocalWriter` writes data to local disk, `TensorboardWriter` writes data to Tensorboard, and `WandbWriter` writes data to Wandb. Users can also implement a custom Writer for a custom backend. The written data can be images, model graphs, scalars such as accuracy metrics, and so on.
-
-  Since multiple Writer objects may exist at the same time during training or testing, for example when writing data both locally and to a remote backend, the `ComposedWriter` is designed to manage all Writer objects instantiated at runtime: it keeps track of every Writer and dispatches each call to all of them in turn. The UML diagram of the Writer class is shown below.
-
-<!-- figure: UML diagram of the Writer class -->
-
-**(2) Relationship between Writer and Visualizer**
-
-The core responsibility of a Writer is to write various kinds of data, such as images, model graphs, hyper-parameters and accuracy metrics, to a specified backend, which can be local storage, Wandb, Tensorboard, and so on. When writing images, we usually want to draw the prediction or annotation results on the image first and then write it, so the Writer holds a Visualizer object internally as one of its attributes. Note that:
-
-- The Visualizer object is only (possibly) used when calling the Writer's `add_image` method; the other interfaces have nothing to do with the Visualizer
-- Some Writer backends, such as `WandbWriter`, already provide drawing capabilities themselves, so the Visualizer attribute of `WandbWriter` is optional. If a Visualizer object is passed at initialization, it is used in `add_image`; otherwise the Wandb API itself draws the image
-- `LocalWriter` and `TensorboardWriter` have very limited built-in drawing capabilities, so drawing is currently delegated to the Visualizer, which means these two Writers must be given a Visualizer (or subclass) object
-
-A simplified sketch of `WandbWriter` is shown below:
-
-```python
-# For simplicity, this does not inherit from BaseWriter
-class WandbWriter:
-    def __init__(self, visualizer=None):
-        self._visualizer = None
-        if visualizer:
-            # example config: visualizer=dict(type='DetVisualizer')
-            self._visualizer = VISUALIZERS.build(visualizer)
-
-    @property
-    def visualizer(self):
-        return self._visualizer
-
-    def add_image(self, name, image, gt_sample=None, pred_sample=None, draw_gt=True, draw_pred=True, step=0, **kwargs):
-        if self._visualizer:
-            self._visualizer.draw(image, gt_sample, pred_sample, draw_gt, draw_pred)
-            # call the Wandb API to write the drawn image to the backend
-            self.wandb.log({name: self.visualizer.get_image()}, ...)
-            ...
-        else:
-            # let Wandb process and write the image to the backend itself
-            ...
-
-    def add_scalar(self, name, value, step):
-        self.wandb.log({name: value}, ...)
-```
-
-
-## The drawing object Visualizer
-
-The drawing object Visualizer is responsible for all kinds of drawing on a single image, with Matplotlib as the default drawing backend. To unify the visualization interfaces of the OpenMMLab codebases, MMEngine defines the `Visualizer` class, which provides the basic drawing functionality. Downstream libraries can inherit `Visualizer` and implement the `draw` interface to meet their own drawing needs.
-
-### Visualizer
-
-`Visualizer` provides basic and general-purpose drawing functionality. Its main interfaces are:
-
-**(1) Functional interfaces unrelated to drawing**
-
-- [set_image](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.set_image) sets the original image data; the expected input format is RGB
-- [get_image](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.get_image) gets the drawn image as a Numpy array; the output format is RGB
-- [show](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.show) displays the result
-- [register_task](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.register_task) registers drawing functions (its role is described in the *Customizing the Visualizer* section)
-
-**(2) Drawing interfaces**
-
-- [draw](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.draw) the abstract drawing interface called by users
-- [draw_featmap](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.draw_featmap) draws feature maps
-- [draw_bboxes](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.draw_bboxes) draws one or more bounding boxes
-- [draw_texts](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.draw_texts) draws one or more text boxes
-- [draw_lines](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.draw_lines) draws one or more line segments
-- [draw_circles](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.draw_circles) draws one or more circles
-- [draw_polygons](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.draw_polygons) draws one or more polygons
-- [draw_binary_masks](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.draw_binary_masks) draws one or more binary masks
-
-Besides calling the basic drawing interfaces of `Visualizer` individually, chained calls and feature-map visualization are also supported. The `draw` method is an abstract interface with no implementation; classes inheriting from Visualizer can implement it to expose a unified drawing entry point, while the `draw_xxx` methods provide the most basic drawing primitives and usually do not need to be overridden.
-
-**(1) Chained calls**
-
-For example, to first draw bounding boxes, then draw text on top of them, and then draw line segments, the calls are:
-
-```python
-visualizer.set_image(image)
-visualizer.draw_bboxes(...).draw_texts(...).draw_lines(...)
-visualizer.show()  # show the drawn result
-```
-
-**(2) Feature-map visualization**
-
-Feature-map visualization is a common need. Calling `draw_featmap` visualizes a feature map directly. Its parameters are defined as:
-
-```python
-@staticmethod
-def draw_featmap(tensor_chw: torch.Tensor,  # input tensor in CHW format
-                 image: Optional[np.ndarray] = None,  # if an image is also given, the feature map is overlaid on it
-                 mode: Optional[str] = 'mean',  # strategy for compressing multiple channels into one
-                 topk: int = 10,  # show the topk feature maps with the highest activation
-                 arrangement: Tuple[int, int] = (5, 2),  # layout when multiple channels are expanded into multiple images
-                 alpha: float = 0.3) -> np.ndarray:  # blending ratio between the image and the feature map
-```
-
-Feature-map visualization offers several options, and batched input is currently not supported; a sketch of the `topk` branch follows the basic example below.
-
-- If mode is not None, topk is ignored. The channels are compressed into a single channel with the mode function and displayed as a single image. Currently mode only accepts None, 'mean', 'max' and 'min'
-- If mode is None, topk takes effect. If topk is not -1, the topk channels with the highest activation are selected and displayed, and the layout can be specified with the arrangement parameter
-- If mode is None and `topk = -1`, the number of channels C must be 1 or 3, meaning the input is an image that can be displayed directly; otherwise an error is raised telling the user to set mode to compress the channels
-
-```python
-featmap = visualizer.draw_featmap(tensor_chw, image)
-```
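-
-A hedged sketch of the `topk` branch, assuming a CHW tensor named `featmap_tensor` and an RGB `image` of matching size (both names are illustrative):
-
-```python
-# Select the 8 most activated channels and lay them out in a 4x2 grid.
-# mode must be None for topk to take effect.
-topk_vis = visualizer.draw_featmap(featmap_tensor,
-                                   image,
-                                   mode=None,
-                                   topk=8,
-                                   arrangement=(4, 2),
-                                   alpha=0.5)
-```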
-
-### Customizing the Visualizer
-
-In most cases a custom Visualizer only needs to implement the `get_image` and `draw` interfaces. `draw` is the highest-level entry point called by users and is responsible for all drawing, such as detection boxes, instance masks and semantic segmentation maps. Depending on the task, the complexity of the `draw` implementation varies.
-
-Take object detection as an example: bounding boxes, masks and semantic segmentation maps may all need to be drawn, and putting all of this into the `draw` method would make it hard to understand and maintain. To solve this, `Visualizer` supports the `register_task` function based on the OpenMMLab 2.0 abstract data interface convention. Suppose MMDetection needs to draw both the instances and the sem_seg in the prediction results. We can implement two methods, `draw_instances` and `draw_sem_seg`, in MMDetection's `DetVisualizer` to draw predicted instances and predicted semantic segmentation maps. We want these two methods to be called automatically whenever instances or sem_seg are present in the input data, without the user calling them manually. This can be achieved by decorating `draw_instances` and `draw_sem_seg` with `@Visualizer.register_task`, which stores the string-to-function mapping in `task_dict`, so that the registered functions can be retrieved through `self.task_dict` inside the `draw` method. A simplified implementation is shown below.
-
-```python
-class DetVisualizer(Visualizer):
-
-    def draw(self, image, gt_sample=None, pred_sample=None, draw_gt=True, draw_pred=True):
-        # associate the image with the matplotlib figure
-        self.set_image(image)
-
-        if draw_gt:
-            # self.task_dict stores the following mapping:
-            # dict(instances=draw_instance method, sem_seg=draw_sem_seg method)
-            for task in self.task_dict:
-                task_attr = 'gt_' + task
-                if task_attr in gt_sample:
-                    self.task_dict[task](self, gt_sample[task_attr], 'gt')
-        if draw_pred:
-            for task in self.task_dict:
-                task_attr = 'pred_' + task
-                if task_attr in pred_sample:
-                    self.task_dict[task](self, pred_sample[task_attr], 'pred')
-
-    # data_type distinguishes whether the drawn content is annotation or prediction
-    @Visualizer.register_task('instances')
-    def draw_instance(self, instances, data_type):
-        ...
-
-    # data_type distinguishes whether the drawn content is annotation or prediction
-    @Visualizer.register_task('sem_seg')
-    def draw_sem_seg(self, pixel_data, data_type):
-        ...
-```
-
-Note: using the `register_task` decorator is not mandatory. If you define a custom Visualizer whose `draw` implementation is very simple, there is no need to consider `register_task`.
-
-When working in a Jupyter notebook, or in any other situation where data does not need to be written to a backend, users can instantiate the visualizer themselves. A simple example:
-
-```python
-# instantiate the visualizer
-visualizer = dict(type='DetVisualizer')
-visualizer = VISUALIZERS.build(visualizer)
-visualizer.draw(image, datasample)
-visualizer.show()  # show the drawn result
-```
-
-## The writer object Writer
-
-The Visualizer only draws on a single image, but recording key metrics and training hyper-parameters during training or testing is also important; this is handled by the writer object Writer. To unify the calling interface, MMEngine provides the abstract base class `BaseWriter` together with several commonly used Writers such as `LocalWriter`, `TensorboardWriter` and `WandbWriter`.
-
-### BaseWriter
-
-BaseWriter defines the external interface specification. Its main interfaces and properties are:
-
-- [add_params](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.BaseWriter.add_params) writes hyper-parameters to the backend, such as the initial learning rate, weight decay and batch size
-- [add_graph](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.BaseWriter.add_graph) writes the model graph to the backend
-- [add_image](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.BaseWriter.add_image) writes an image to the backend
-- [add_scalar](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.BaseWriter.add_scalar) writes a scalar to the backend
-- [add_scalars](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.BaseWriter.add_scalars) writes multiple scalars to the backend at once
-- [visualizer](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.BaseWriter.visualizer) the drawing object
-- [experiment](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.BaseWriter.experiment) the backend object, e.g. the Wandb or Tensorboard object
-
-`BaseWriter` defines 5 common data-writing interfaces. Some backends, such as Wandb, are much more capable and can also write tables, videos and so on; for such needs users can directly obtain the experiment object and call the backend's own API, as sketched below.
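-
-A hedged sketch, assuming a `WandbWriter` instance named `wandb_writer` and hypothetical table data `my_data` / `column_names`:
-
-```python
-# experiment returns the underlying wandb module, so any wandb API can be used
-wandb = wandb_writer.experiment
-val_table = wandb.Table(data=my_data, columns=column_names)
-wandb.log({'my_val_table': val_table})
-```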
-
-### LocalWriter, TensorboardWriter and WandbWriter
-
-`LocalWriter` writes data to local disk. If you need to write images to disk, **a Visualizer object must be provided through the initialization parameters**. Typical usage:
-
-```python
-# config
-writer = dict(type='LocalWriter', save_dir='demo_dir', visualizer=dict(type='DetVisualizer'))
-# instantiate and call
-local_writer = WRITERS.build(writer)
-# write model accuracy values
-local_writer.add_scalar('mAP', 0.9)
-local_writer.add_scalars({'loss': 1.2, 'acc': 0.8})
-# write hyper-parameters
-local_writer.add_params(dict(lr=0.1, mode='linear'))
-# write an image
-local_writer.add_image('demo_image', image, datasample)
-```
-
-For custom drawing needs, the internal visualizer attribute can be used, as shown below:
-
-```python
-# config
-writer = dict(type='LocalWriter', save_dir='demo_dir', visualizer=dict(type='DetVisualizer'))
-# instantiate and call
-local_writer = WRITERS.build(writer)
-# write an image
-local_writer.visualizer.draw_bboxes(np.array([0, 0, 1, 1]))
-local_writer.add_image('img', local_writer.visualizer.get_image())
-
-# draw a feature map and save it locally
-featmap_image = local_writer.visualizer.draw_featmap(tensor_chw)
-local_writer.add_image('featmap', featmap_image)
-```
-
-`TensorboardWriter` writes various types of data to Tensorboard, and its usage is very similar to LocalWriter. Note that if you need to write images to Tensorboard, **a Visualizer object must be provided through the initialization parameters**; a hedged config sketch follows.
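-
-A minimal config sketch, assuming a registered `DetVisualizer` (the directory and names below are illustrative):
-
-```python
-writer = dict(type='TensorboardWriter', save_dir='demo_dir',
-              visualizer=dict(type='DetVisualizer'))
-tensorboard_writer = WRITERS.build(writer)
-tensorboard_writer.add_scalar('mAP', 0.9)
-```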
-
-`WandbWriter` writes various types of data to Wandb. Since Wandb itself has powerful image handling, the Visualizer object is optional when calling `add_image` on `WandbWriter`: if a Visualizer object is specified, its drawing methods are used; otherwise Wandb's own image processing is used. A hedged sketch of both variants follows.
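-
-A minimal sketch of both variants, assuming the same illustrative `DetVisualizer` config as above:
-
-```python
-# without a Visualizer: Wandb handles the image itself
-wandb_writer = WRITERS.build(dict(type='WandbWriter'))
-wandb_writer.add_image('demo_image', image, datasample)
-
-# with a Visualizer: drawing is delegated to DetVisualizer
-wandb_writer = WRITERS.build(
-    dict(type='WandbWriter', visualizer=dict(type='DetVisualizer')))
-wandb_writer.add_image('demo_image', image, datasample)
-```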
-
-## The composed writer ComposedWriter
-
-Since multiple Writers may need to be called at the same time during training or testing, for example writing both locally and to Wandb, the user-facing `ComposedWriter` class is provided. During training or testing, `ComposedWriter` calls the interface of each Writer in turn. Its interface matches `BaseWriter`, with the main methods being:
-
-- [add_params](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.ComposedWriter.add_params) writes hyper-parameters to all added backends, such as the initial learning rate, weight decay and batch size
-- [add_graph](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.ComposedWriter.add_graph) writes the model graph to all added backends
-- [add_image](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.ComposedWriter.add_image) writes an image to all added backends
-- [add_scalar](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.ComposedWriter.add_scalar) writes a scalar to all added backends
-- [add_scalars](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.ComposedWriter.add_scalars) writes multiple scalars to all added backends at once
-- [get_writer](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.ComposedWriter.get_writer) gets the Writer at the given index; every Writer contains the experiment and visualizer attributes
-- [get_experiment](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.ComposedWriter.get_experiment) gets the experiment at the given index
-- [get_visualizer](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.ComposedWriter.get_visualizer) gets the visualizer at the given index
-- [close](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.ComposedWriter.close) calls the close method of every Writer
-
-To allow data visualization from anywhere in the code, the `ComposedWriter` class inherits from the [globally accessible base class BaseGlobalAccessible](./logging.md/#全局可访问基类baseglobalaccessible). Once it inherits from this base class, the global object can be obtained by calling `get_instance` on `ComposedWriter`. Basic usage is as follows:
-
-```python
-# create the instance
-writers = [dict(type='LocalWriter', save_dir='temp_dir', visualizer=dict(type='DetVisualizer')), dict(type='WandbWriter')]
-
-ComposedWriter.create_instance('composed_writer', writers=writers)
-```
-
-Once the instance has been created, the `ComposedWriter` object can be obtained anywhere in the code:
-
-```python
-composed_writer = ComposedWriter.get_instance('composed_writer')
-
-# write model accuracy values
-composed_writer.add_scalar('mAP', 0.9)
-composed_writer.add_scalars({'loss': 1.2, 'acc': 0.8})
-# write hyper-parameters
-composed_writer.add_params(dict(lr=0.1, mode='linear'))
-# write an image
-composed_writer.add_image('demo_image', image, datasample)
-# write the model graph
-composed_writer.add_graph(model, input_array)
-```
-
-For custom drawing needs, or needs that the above interfaces cannot cover, users can obtain the concrete objects through the `get_xxx` methods:
-
-```python
-composed_writer = ComposedWriter.get_instance('composed_writer')
-
-# draw a feature map using the visualizer held by the LocalWriter
-visualizer = composed_writer.get_visualizer(0)
-featmap_image = visualizer.draw_featmap(tensor_chw)
-composed_writer.add_image('featmap', featmap_image)
-
-# extend the add functionality, e.g. use the Wandb object to log a table
-wandb = composed_writer.get_experiment(1)
-val_table = wandb.Table(data=my_data, columns=column_names)
-wandb.log({'my_val_table': val_table})
-
-# the config contains multiple Writers; use only the LocalWriter without changing the config
-local_writer = composed_writer.get_writer(0)
-local_writer.add_image('demo_image', image, datasample)
-```
diff --git a/mmengine/data/instance_data.py b/mmengine/data/instance_data.py
index 76d5e996..2c4932f7 100644
--- a/mmengine/data/instance_data.py
+++ b/mmengine/data/instance_data.py
@@ -7,6 +7,9 @@ import torch
from .base_data_element import BaseDataElement
+IndexType = Union[str, slice, int, torch.LongTensor, torch.cuda.LongTensor,
+ torch.BoolTensor, torch.cuda.BoolTensor, np.long, np.bool]
+
# Modified from
# https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/data_structures/instance_data.py # noqa
@@ -87,9 +90,7 @@ class InstanceData(BaseDataElement):
f'{len(self)} '
super().__setattr__(name, value)
- def __getitem__(
- self, item: Union[str, slice, int, torch.LongTensor, torch.BoolTensor]
- ) -> 'InstanceData':
+ def __getitem__(self, item: IndexType) -> 'InstanceData':
"""
Args:
item (str, obj:`slice`,
@@ -102,7 +103,8 @@ class InstanceData(BaseDataElement):
assert len(self) > 0, ' This is a empty instance'
assert isinstance(
- item, (str, slice, int, torch.LongTensor, torch.BoolTensor))
+ item, (str, slice, int, torch.LongTensor, torch.cuda.LongTensor,
+ torch.BoolTensor, torch.cuda.BoolTensor, np.bool, np.long))
if isinstance(item, str):
return getattr(self, item)
@@ -118,7 +120,7 @@ class InstanceData(BaseDataElement):
if isinstance(item, torch.Tensor):
assert item.dim() == 1, 'Only support to get the' \
' values along the first dimension.'
- if isinstance(item, torch.BoolTensor):
+ if isinstance(item, (torch.BoolTensor, torch.cuda.BoolTensor)):
assert len(item) == len(self), f'The shape of the' \
f' input(BoolTensor)) ' \
f'{len(item)} ' \
@@ -136,7 +138,8 @@ class InstanceData(BaseDataElement):
elif isinstance(v, list):
r_list = []
# convert to indexes from boolTensor
- if isinstance(item, torch.BoolTensor):
+ if isinstance(item,
+ (torch.BoolTensor, torch.cuda.BoolTensor)):
indexes = torch.nonzero(item).view(-1)
else:
indexes = item
diff --git a/mmengine/hooks/logger_hook.py b/mmengine/hooks/logger_hook.py
index 427e1ff9..aca48652 100644
--- a/mmengine/hooks/logger_hook.py
+++ b/mmengine/hooks/logger_hook.py
@@ -165,11 +165,7 @@ class LoggerHook(Hook):
self.json_log_path = osp.join(runner.work_dir,
f'{runner.timestamp}.log.json')
- self.yaml_log_path = osp.join(runner.work_dir,
- f'{runner.timestamp}.log.json')
self.start_iter = runner.iter
- if runner.meta is not None:
- runner.writer.add_params(runner.meta, file_path=self.yaml_log_path)
def after_train_iter(self,
runner,
@@ -298,7 +294,7 @@ class LoggerHook(Hook):
log_str += ', '.join(log_items)
runner.logger.info(log_str)
# Write logs to local, tensorboard, and wandb.
- runner.writer.add_scalars(
+ runner.visualizer.add_scalars(
tag, step=runner.iter + 1, file_path=self.json_log_path)
def _log_val(self, runner) -> None:
@@ -330,7 +326,7 @@ class LoggerHook(Hook):
log_str += ', '.join(log_items)
runner.logger.info(log_str)
# Write tag.
- runner.writer.add_scalars(
+ runner.visualizer.add_scalars(
tag, step=cur_iter, file_path=self.json_log_path)
def _get_window_size(self, runner, window_size: Union[int, str]) \
diff --git a/mmengine/hooks/naive_visualization_hook.py b/mmengine/hooks/naive_visualization_hook.py
index e8bd3834..2819563a 100644
--- a/mmengine/hooks/naive_visualization_hook.py
+++ b/mmengine/hooks/naive_visualization_hook.py
@@ -11,6 +11,8 @@ from mmengine.registry import HOOKS
from mmengine.utils.misc import tensor2imgs
+# TODO: Due to interface changes, the current class
+# functions incorrectly
@HOOKS.register_module()
class NaiveVisualizationHook(Hook):
"""Show or Write the predicted results during the process of testing.
@@ -68,5 +70,6 @@ class NaiveVisualizationHook(Hook):
data_sample.get('scale', ori_shape))
origin_image = cv2.resize(input, ori_shape)
name = osp.basename(data_sample.img_path)
- runner.writer.add_image(name, origin_image, data_sample,
- output, self.draw_gt, self.draw_pred)
+ runner.visualizer.add_datasample(name, origin_image,
+ data_sample, output,
+ self.draw_gt, self.draw_pred)
diff --git a/mmengine/logging/message_hub.py b/mmengine/logging/message_hub.py
index 75a2a4bc..1a6f4e67 100644
--- a/mmengine/logging/message_hub.py
+++ b/mmengine/logging/message_hub.py
@@ -1,5 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
-import copy
from collections import OrderedDict
from typing import Any, Union
@@ -103,7 +102,8 @@ class MessageHub(ManagerMixin):
Returns:
OrderedDict: A copy of all runtime information.
"""
- return copy.deepcopy(self._runtime_info)
+ # return copy.deepcopy(self._runtime_info)
+ return self._runtime_info
def get_log(self, key: str) -> LogBuffer:
"""Get ``LogBuffer`` instance by key.
@@ -136,7 +136,10 @@ class MessageHub(ManagerMixin):
if key not in self.runtime_info:
raise KeyError(f'{key} is not found in Messagehub.log_buffers: '
f'instance name is: {MessageHub.instance_name}')
- return copy.deepcopy(self._runtime_info[key])
+
+ # TODO: There are restrictions on objects that can be saved
+ # return copy.deepcopy(self._runtime_info[key])
+ return self._runtime_info[key]
def _get_valid_value(self, key: str,
value: Union[torch.Tensor, np.ndarray, int, float])\
diff --git a/mmengine/registry/__init__.py b/mmengine/registry/__init__.py
index ead8cb0a..56c65b80 100644
--- a/mmengine/registry/__init__.py
+++ b/mmengine/registry/__init__.py
@@ -4,12 +4,12 @@ from .registry import Registry, build_from_cfg
from .root import (DATA_SAMPLERS, DATASETS, HOOKS, LOOPS, METRICS,
MODEL_WRAPPERS, MODELS, OPTIMIZER_CONSTRUCTORS, OPTIMIZERS,
PARAM_SCHEDULERS, RUNNER_CONSTRUCTORS, RUNNERS, TASK_UTILS,
- TRANSFORMS, VISUALIZERS, WEIGHT_INITIALIZERS, WRITERS)
+ TRANSFORMS, VISBACKENDS, VISUALIZERS, WEIGHT_INITIALIZERS)
__all__ = [
'Registry', 'build_from_cfg', 'RUNNERS', 'RUNNER_CONSTRUCTORS', 'HOOKS',
'DATASETS', 'DATA_SAMPLERS', 'TRANSFORMS', 'MODELS', 'WEIGHT_INITIALIZERS',
'OPTIMIZERS', 'OPTIMIZER_CONSTRUCTORS', 'TASK_UTILS', 'PARAM_SCHEDULERS',
- 'METRICS', 'MODEL_WRAPPERS', 'LOOPS', 'WRITERS', 'VISUALIZERS',
+ 'METRICS', 'MODEL_WRAPPERS', 'LOOPS', 'VISBACKENDS', 'VISUALIZERS',
'DefaultScope'
]
diff --git a/mmengine/registry/registry.py b/mmengine/registry/registry.py
index 3ee7d4d6..7e2b4845 100644
--- a/mmengine/registry/registry.py
+++ b/mmengine/registry/registry.py
@@ -6,7 +6,7 @@ from collections.abc import Callable
from typing import Any, Dict, List, Optional, Tuple, Type, Union
from ..config import Config, ConfigDict
-from ..utils import is_seq_of
+from ..utils import ManagerMixin, is_seq_of
from .default_scope import DefaultScope
@@ -88,7 +88,13 @@ def build_from_cfg(
f'type must be a str or valid type, but got {type(obj_type)}')
try:
- return obj_cls(**args) # type: ignore
+ # If `obj_cls` inherits from `ManagerMixin`, it should be instantiated
+ # by `ManagerMixin.get_instance` to ensure that it can be accessed
+ # globally.
+ if issubclass(obj_cls, ManagerMixin):
+ return obj_cls.get_instance(**args) # type: ignore
+ else:
+ return obj_cls(**args) # type: ignore
except Exception as e:
# Normal TypeError does not print class name.
raise type(e)(f'{obj_cls.__name__}: {e}') # type: ignore
diff --git a/mmengine/registry/root.py b/mmengine/registry/root.py
index 571d55cb..62d72f70 100644
--- a/mmengine/registry/root.py
+++ b/mmengine/registry/root.py
@@ -43,5 +43,5 @@ TASK_UTILS = Registry('task util')
# manage visualizer
VISUALIZERS = Registry('visualizer')
-# manage writer
-WRITERS = Registry('writer')
+# manage visualizer backend
+VISBACKENDS = Registry('vis_backend')
diff --git a/mmengine/runner/loops.py b/mmengine/runner/loops.py
index d791c52c..f1821311 100644
--- a/mmengine/runner/loops.py
+++ b/mmengine/runner/loops.py
@@ -27,6 +27,14 @@ class EpochBasedTrainLoop(BaseLoop):
super().__init__(runner, dataloader)
self._max_epochs = max_epochs
self._max_iters = max_epochs * len(self.dataloader)
+ if hasattr(self.dataloader.dataset, 'metainfo'):
+ self.runner.visualizer.dataset_meta = \
+ self.dataloader.dataset.metainfo
+ else:
+ warnings.warn(
+ f'Dataset {self.dataloader.dataset.__class__.__name__} has no '
+ 'metainfo. ``dataset_meta`` in visualizer will be '
+ 'None.')
@property
def max_epochs(self):
@@ -100,6 +108,14 @@ class IterBasedTrainLoop(BaseLoop):
max_iters: int) -> None:
super().__init__(runner, dataloader)
self._max_iters = max_iters
+ if hasattr(self.dataloader.dataset, 'metainfo'):
+ self.runner.visualizer.dataset_meta = \
+ self.dataloader.dataset.metainfo
+ else:
+ warnings.warn(
+ f'Dataset {self.dataloader.dataset.__class__.__name__} has no '
+ 'metainfo. ``dataset_meta`` in visualizer will be '
+ 'None.')
self.dataloader = iter(self.dataloader)
@property
@@ -176,11 +192,13 @@ class ValLoop(BaseLoop):
self.evaluator = evaluator # type: ignore
if hasattr(self.dataloader.dataset, 'metainfo'):
self.evaluator.dataset_meta = self.dataloader.dataset.metainfo
+ self.runner.visualizer.dataset_meta = \
+ self.dataloader.dataset.metainfo
else:
warnings.warn(
f'Dataset {self.dataloader.dataset.__class__.__name__} has no '
- 'metainfo. ``dataset_meta`` in evaluator and metric will be '
- 'None.')
+ 'metainfo. ``dataset_meta`` in evaluator, metric and '
+ 'visualizer will be None.')
self.interval = interval
def run(self):
@@ -240,11 +258,13 @@ class TestLoop(BaseLoop):
self.evaluator = evaluator # type: ignore
if hasattr(self.dataloader.dataset, 'metainfo'):
self.evaluator.dataset_meta = self.dataloader.dataset.metainfo
+ self.runner.visualizer.dataset_meta = \
+ self.dataloader.dataset.metainfo
else:
warnings.warn(
f'Dataset {self.dataloader.dataset.__class__.__name__} has no '
- 'metainfo. ``dataset_meta`` in evaluator and metric will be '
- 'None.')
+ 'metainfo. ``dataset_meta`` in evaluator, metric and '
+ 'visualizer will be None.')
def run(self) -> None:
"""Launch test."""
diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py
index 54e1e710..86a57f36 100644
--- a/mmengine/runner/runner.py
+++ b/mmengine/runner/runner.py
@@ -30,10 +30,10 @@ from mmengine.model import is_model_wrapper
from mmengine.optim import _ParamScheduler, build_optimizer
from mmengine.registry import (DATA_SAMPLERS, DATASETS, HOOKS, LOOPS,
MODEL_WRAPPERS, MODELS, PARAM_SCHEDULERS,
- DefaultScope)
+ VISUALIZERS, DefaultScope)
from mmengine.utils import (TORCH_VERSION, digit_version,
find_latest_checkpoint, is_list_of, symlink)
-from mmengine.visualization import ComposedWriter
+from mmengine.visualization import Visualizer
from .base_loop import BaseLoop
from .checkpoint import (_load_checkpoint, _load_checkpoint_to_model,
get_state_dict, save_checkpoint, weights_to_cpu)
@@ -129,8 +129,8 @@ class Runner:
dict(dist_cfg=dict(backend='nccl')).
log_level (int or str): The log level of MMLogger handlers.
Defaults to 'INFO'.
- writer (ComposedWriter or dict, optional): A ComposedWriter object or a
- dict build ComposedWriter object. Defaults to None. If not
+ visualizer (Visualizer or dict, optional): A Visualizer object or a
+ dict to build a Visualizer object. Defaults to None. If not
specified, default config will be used.
default_scope (str, optional): Used to reset registries location.
Defaults to None.
@@ -184,9 +184,9 @@ class Runner:
param_scheduler=dict(type='ParamSchedulerHook')),
launcher='none',
env_cfg=dict(dist_cfg=dict(backend='nccl')),
- writer=dict(
- name='composed_writer',
- writers=[dict(type='LocalWriter', save_dir='temp_dir')])
+ visualizer=dict(type='Visualizer',
+ vis_backends=[dict(type='LocalVisBackend',
+ save_dir='temp_dir')])
)
>>> runner = Runner.from_cfg(cfg)
>>> runner.train()
@@ -218,7 +218,7 @@ class Runner:
launcher: str = 'none',
env_cfg: Dict = dict(dist_cfg=dict(backend='nccl')),
log_level: str = 'INFO',
- writer: Optional[Union[ComposedWriter, Dict]] = None,
+ visualizer: Optional[Union[Visualizer, Dict]] = None,
default_scope: Optional[str] = None,
randomness: Dict = dict(seed=None),
experiment_name: Optional[str] = None,
@@ -310,16 +310,17 @@ class Runner:
else:
self._experiment_name = self.timestamp
- self.logger = self.build_logger(log_level=log_level)
- # message hub used for component interaction
- self.message_hub = self.build_message_hub()
- # writer used for writing log or visualizing all kinds of data
- self.writer = self.build_writer(writer)
# Used to reset registries location. See :meth:`Registry.build` for
# more details.
self.default_scope = DefaultScope.get_instance(
self._experiment_name, scope_name=default_scope)
+ self.logger = self.build_logger(log_level=log_level)
+ # message hub used for component interaction
+ self.message_hub = self.build_message_hub()
+ # visualizer used for writing log or visualizing all kinds of data
+ self.visualizer = self.build_visualizer(visualizer)
+
self._load_from = load_from
self._resume = resume
# flag to mark whether checkpoint has been loaded or resumed
@@ -378,7 +379,7 @@ class Runner:
launcher=cfg.get('launcher', 'none'),
env_cfg=cfg.get('env_cfg'), # type: ignore
log_level=cfg.get('log_level', 'INFO'),
- writer=cfg.get('writer'),
+ visualizer=cfg.get('visualizer'),
default_scope=cfg.get('default_scope'),
randomness=cfg.get('randomness', dict(seed=None)),
experiment_name=cfg.get('experiment_name'),
@@ -623,37 +624,42 @@ class Runner:
return MessageHub.get_instance(**message_hub)
- def build_writer(
- self,
- writer: Optional[Union[ComposedWriter,
- Dict]] = None) -> ComposedWriter:
- """Build a global asscessable ComposedWriter.
+ def build_visualizer(
+ self,
+ visualizer: Optional[Union[Visualizer,
+ Dict]] = None) -> Visualizer:
+ """Build a global asscessable Visualizer.
Args:
- writer (ComposedWriter or dict, optional): A ComposedWriter object
- or a dict to build ComposedWriter object. If ``writer`` is a
- ComposedWriter object, just returns itself. If not specified,
- default config will be used to build ComposedWriter object.
+ visualizer (Visualizer or dict, optional): A Visualizer object
+ or a dict to build a Visualizer object. If ``visualizer`` is a
+ Visualizer object, it is returned directly. If not specified,
+ the default config will be used to build a Visualizer object.
Defaults to None.
Returns:
- ComposedWriter: A ComposedWriter object build from ``writer``.
+ Visualizer: A Visualizer object built from ``visualizer``.
"""
- if isinstance(writer, ComposedWriter):
- return writer
- elif writer is None:
- writer = dict(
+ if visualizer is None:
+ visualizer = dict(
name=self._experiment_name,
- writers=[dict(type='LocalWriter', save_dir=self._work_dir)])
- elif isinstance(writer, dict):
- # ensure writer containing name key
- writer.setdefault('name', self._experiment_name)
+ vis_backends=[
+ dict(type='LocalVisBackend', save_dir=self._work_dir)
+ ])
+ return Visualizer.get_instance(**visualizer)
+
+ if isinstance(visualizer, Visualizer):
+ return visualizer
+
+ if isinstance(visualizer, dict):
+ # ensure visualizer containing name key
+ visualizer.setdefault('name', self._experiment_name)
+ visualizer.setdefault('save_dir', self._work_dir)
+ return VISUALIZERS.build(visualizer)
else:
raise TypeError(
- 'writer should be ComposedWriter object, a dict or None, '
- f'but got {writer}')
-
- return ComposedWriter.get_instance(**writer)
+ 'visualizer should be Visualizer object, a dict or None, '
+ f'but got {visualizer}')
def build_model(self, model: Union[nn.Module, Dict]) -> nn.Module:
"""Build model.
diff --git a/mmengine/visualization/__init__.py b/mmengine/visualization/__init__.py
index 892c3daa..6c8b0bb5 100644
--- a/mmengine/visualization/__init__.py
+++ b/mmengine/visualization/__init__.py
@@ -1,9 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
+from .vis_backend import (BaseVisBackend, LocalVisBackend,
+ TensorboardVisBackend, WandbVisBackend)
from .visualizer import Visualizer
-from .writer import (BaseWriter, ComposedWriter, LocalWriter,
- TensorboardWriter, WandbWriter)
__all__ = [
- 'Visualizer', 'BaseWriter', 'LocalWriter', 'WandbWriter',
- 'TensorboardWriter', 'ComposedWriter'
+ 'Visualizer', 'BaseVisBackend', 'LocalVisBackend', 'WandbVisBackend',
+ 'TensorboardVisBackend'
]
diff --git a/mmengine/visualization/utils.py b/mmengine/visualization/utils.py
index 97803ce2..a0033dac 100644
--- a/mmengine/visualization/utils.py
+++ b/mmengine/visualization/utils.py
@@ -1,6 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
-from typing import Any, List, Tuple, Type, Union
+from typing import Any, List, Optional, Tuple, Type, Union
+
+import cv2
+import matplotlib
import numpy as np
import torch
@@ -84,3 +87,60 @@ def check_type_and_length(name: str, value: Any,
"""
check_type(name, value, valid_type)
check_length(name, value, valid_length)
+
+
+def color_val_matplotlib(colors):
+ """Convert various input in RGB order to normalized RGB matplotlib color
+ tuples,
+ Args:
+ color (:obj:`mmcv.Color`/str/tuple/int/ndarray): Color inputs
+ Returns:
+ tuple[float]: A tuple of 3 normalized floats indicating RGB channels.
+ """
+ if isinstance(colors, str):
+ return colors
+ elif isinstance(colors, tuple):
+ assert len(colors) == 3
+ for channel in colors:
+ assert 0 <= channel <= 255
+ colors = [channel / 255 for channel in colors]
+ return tuple(colors)
+ elif isinstance(colors, list):
+ colors = [color_val_matplotlib(color) for color in colors]
+ return colors
+ else:
+ raise TypeError(f'Invalid type for color: {type(colors)}')
+
+
+def str_color_to_rgb(color):
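+ """Convert a color string to an RGB tuple with values in [0, 255]."""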
+ color = matplotlib.colors.to_rgb(color)
+ color = tuple([int(c * 255) for c in color])
+ return color
+
+
+def convert_overlay_heatmap(feat_map: Union[np.ndarray, torch.Tensor],
+ img: Optional[np.ndarray] = None,
+ alpha: float = 0.5) -> np.ndarray:
+ """Convert feat_map to heatmap and overlay on image, if image is not None.
+
+ Args:
+ feat_map (np.ndarray, torch.Tensor): The feat_map to convert
+ with shape (H, W), where H is the image height and W is
+ the image width.
+ img (np.ndarray, optional): The origin image. The format
+ should be RGB. Defaults to None.
+ alpha (float): The transparency of origin image. Defaults to 0.5.
+
+ Returns:
+ np.ndarray: heatmap
+ """
+ if isinstance(feat_map, torch.Tensor):
+ feat_map = feat_map.detach().cpu().numpy()
+ norm_img = np.zeros(feat_map.shape)
+ norm_img = cv2.normalize(feat_map, norm_img, 0, 255, cv2.NORM_MINMAX)
+ norm_img = np.asarray(norm_img, dtype=np.uint8)
+ heat_img = cv2.applyColorMap(norm_img, cv2.COLORMAP_JET)
+ heat_img = cv2.cvtColor(heat_img, cv2.COLOR_BGR2RGB)
+ if img is not None:
+ heat_img = cv2.addWeighted(img, alpha, heat_img, 1 - alpha, 0)
+ return heat_img
diff --git a/mmengine/visualization/vis_backend.py b/mmengine/visualization/vis_backend.py
new file mode 100644
index 00000000..13de36d8
--- /dev/null
+++ b/mmengine/visualization/vis_backend.py
@@ -0,0 +1,494 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+import os.path as osp
+import time
+from abc import ABCMeta, abstractmethod
+from typing import Any, Optional, Sequence, Union
+
+import cv2
+import numpy as np
+import torch
+
+from mmengine.config import Config
+from mmengine.fileio import dump
+from mmengine.registry import VISBACKENDS
+from mmengine.utils import TORCH_VERSION
+
+
+class BaseVisBackend(metaclass=ABCMeta):
+ """Base class for vis backend.
+
+ All backends must inherit ``BaseVisBackend`` and implement
+ the required functions.
+
+ Args:
+ save_dir (str, optional): The root directory to save
+ the files produced by the backend. Default to None.
+ """
+
+ def __init__(self, save_dir: Optional[str] = None):
+ self._save_dir = save_dir
+ if self._save_dir:
+ timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
+ self._save_dir = osp.join(self._save_dir,
+ f'vis_data_{timestamp}') # type: ignore
+
+ @property
+ @abstractmethod
+ def experiment(self) -> Any:
+ """Return the experiment object associated with this writer.
+
+ The experiment attribute can get the visualizer backend, such as wandb,
+ tensorboard. If you want to write other data, such as writing a table,
+ you can directly get the visualizer backend through experiment.
+ """
+ pass
+
+ def add_config(self, config: Config, **kwargs) -> None:
+ """Record a set of parameters.
+
+ Args:
+ config (Config): The Config object
+ """
+ pass
+
+ def add_graph(self, model: torch.nn.Module, data_batch: Sequence[dict],
+ **kwargs) -> None:
+ """Record graph.
+
+ Args:
+ model (torch.nn.Module): Model to draw.
+ data_batch (Sequence[dict]): Batch of data from dataloader.
+ """
+ pass
+
+ def add_image(self,
+ name: str,
+ image: np.ndarray,
+ step: int = 0,
+ **kwargs) -> None:
+ """Record image.
+
+ Args:
+ name (str): The unique identifier for the image to save.
+ image (np.ndarray, optional): The image to be saved. The format
+ should be RGB. Default to None.
+ step (int): Global step value to record. Default to 0.
+ """
+ pass
+
+ def add_scalar(self,
+ name: str,
+ value: Union[int, float],
+ step: int = 0,
+ **kwargs) -> None:
+ """Record scalar.
+
+ Args:
+ name (str): The unique identifier for the scalar to save.
+ value (float, int): Value to save.
+ step (int): Global step value to record. Default to 0.
+ """
+ pass
+
+ def add_scalars(self,
+ scalar_dict: dict,
+ step: int = 0,
+ file_path: Optional[str] = None,
+ **kwargs) -> None:
+ """Record scalars' data.
+
+ Args:
+ scalar_dict (dict): Key-value pair storing the tag and
+ corresponding values.
+ step (int): Global step value to record. Default to 0.
+ file_path (str, optional): The scalar's data will be
+ saved to the `file_path` file at the same time
+ if the `file_path` parameter is specified.
+ Default to None.
+ """
+ pass
+
+ def close(self) -> None:
+ """close an opened object."""
+ pass
+
+
+@VISBACKENDS.register_module()
+class LocalVisBackend(BaseVisBackend):
+ """Local vis backend class.
+
+ It can write images, configs, scalars, etc.
+ to the local hard disk.
+
+ Examples:
+ >>> from mmengine.visualization import LocalVisBackend
+ >>> import numpy as np
+ >>> local_vis_backend = LocalVisBackend(save_dir='temp_dir')
+ >>> img=np.random.randint(0, 256, size=(10, 10, 3))
+ >>> local_vis_backend.add_image('img', img)
+ >>> local_vis_backend.add_scalar('mAP', 0.6)
+ >>> local_vis_backend.add_scalars({'loss': 0.1, 'acc': 0.8})
+
+ Args:
+ save_dir (str, optional): The root directory to save the files
+ produced by the writer. If it is none, it means no data
+ is stored. Default None.
+ img_save_dir (str): The directory to save images.
+ Default to 'vis_image'.
+ config_save_file (str): The file to save the config.
+ Default to 'config.py'.
+ scalar_save_file (str): The file to save scalar values.
+ Default to 'scalars.json'.
+ """
+
+ def __init__(self,
+ save_dir: Optional[str] = None,
+ img_save_dir: str = 'vis_image',
+ config_save_file: str = 'config.py',
+ scalar_save_file: str = 'scalars.json'):
+ assert config_save_file.split('.')[-1] == 'py'
+ assert scalar_save_file.split('.')[-1] == 'json'
+ super(LocalVisBackend, self).__init__(save_dir)
+ if self._save_dir is not None:
+ os.makedirs(self._save_dir, exist_ok=True) # type: ignore
+ self._img_save_dir = osp.join(
+ self._save_dir, # type: ignore
+ img_save_dir)
+ self._scalar_save_file = osp.join(
+ self._save_dir, # type: ignore
+ scalar_save_file)
+ self._config_save_file = osp.join(
+ self._save_dir, # type: ignore
+ config_save_file)
+
+ @property
+ def experiment(self) -> 'LocalVisBackend':
+ """Return the experiment object associated with this visualizer
+ backend."""
+ return self
+
+ def add_config(self, config: Config, **kwargs) -> None:
+ # TODO
+ assert isinstance(config, Config)
+
+ def add_image(self,
+ name: str,
+ image: np.ndarray = None,
+ step: int = 0,
+ **kwargs) -> None:
+ """Record image to disk.
+
+ Args:
+ name (str): The unique identifier for the image to save.
+ image (np.ndarray, optional): The image to be saved. The format
+ should be RGB. Default to None.
+ step (int): Global step value to record. Default to 0.
+ """
+
+ drawn_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
+ os.makedirs(self._img_save_dir, exist_ok=True)
+ save_file_name = f'{name}_{step}.png'
+ cv2.imwrite(osp.join(self._img_save_dir, save_file_name), drawn_image)
+
+ def add_scalar(self,
+ name: str,
+ value: Union[int, float],
+ step: int = 0,
+ **kwargs) -> None:
+ """Add scalar data to disk.
+
+ Args:
+ name (str): The unique identifier for the scalar to save.
+ value (float, int): Value to save.
+ step (int): Global step value to record. Default to 0.
+ """
+ self._dump({name: value, 'step': step}, self._scalar_save_file, 'json')
+
+ def add_scalars(self,
+ scalar_dict: dict,
+ step: int = 0,
+ file_path: Optional[str] = None,
+ **kwargs) -> None:
+ """Record scalars. The scalar dict will be written to the default and
+ specified files if ``file_path`` is specified.
+
+ Args:
+ scalar_dict (dict): Key-value pair storing the tag and
+ corresponding values.
+ step (int): Global step value to record. Default to 0.
+ file_path (str, optional): The scalar's data will be
+ saved to the ``file_path`` file at the same time
+ if the ``file_path`` parameter is specified.
+ Default to None.
+ """
+ assert isinstance(scalar_dict, dict)
+ scalar_dict.setdefault('step', step)
+ if file_path is not None:
+ assert file_path.split('.')[-1] == 'json'
+ new_save_file_path = osp.join(
+ self._save_dir, # type: ignore
+ file_path)
+ assert new_save_file_path != self._scalar_save_file, \
+ '"file_path" and "scalar_save_file" have the same name, ' \
+ 'please set "file_path" to another value'
+ self._dump(scalar_dict, new_save_file_path, 'json')
+ self._dump(scalar_dict, self._scalar_save_file, 'json')
+
+ def _dump(self, value_dict: dict, file_path: str,
+ file_format: str) -> None:
+ """dump dict to file.
+
+ Args:
+ value_dict (dict) : Save dict data.
+ file_path (str): The file path to save data.
+ file_format (str): The file format to save data.
+ """
+ with open(file_path, 'a+') as f:
+ dump(value_dict, f, file_format=file_format)
+ f.write('\n')
+
+
+@VISBACKENDS.register_module()
+class WandbVisBackend(BaseVisBackend):
+ """Write various types of data to wandb.
+
+ Examples:
+ >>> from mmengine.visualization import WandbVisBackend
+ >>> import numpy as np
+ >>> wandb_vis_backend = WandbVisBackend()
+ >>> img=np.random.randint(0, 256, size=(10, 10, 3))
+ >>> wandb_vis_backend.add_image('img', img)
+ >>> wandb_vis_backend.add_scalar('mAP', 0.6)
+ >>> wandb_vis_backend.add_scalars({'loss': 0.1, 'acc': 0.8})
+
+ Args:
+ init_kwargs (dict, optional): wandb initialization
+ input parameters. Default to None.
+ commit (bool, optional): Save the metrics dict to the wandb server
+ and increment the step. If false `wandb.log` just
+ updates the current metrics dict with the row argument
+ and metrics won't be saved until `wandb.log` is called
+ with `commit=True`. Default to True.
+ save_dir (str, optional): The root directory to save the files
+ produced by the writer. Default to None.
+ """
+
+ def __init__(self,
+ init_kwargs: Optional[dict] = None,
+ commit: Optional[bool] = True,
+ save_dir: Optional[str] = None):
+ super(WandbVisBackend, self).__init__(save_dir)
+ self._commit = commit
+ self._wandb = self._setup_env(init_kwargs)
+
+ @property
+ def experiment(self):
+ """Return wandb object.
+
+ The experiment attribute provides access to the wandb backend. To
+ write other data, such as a table, get the wandb backend through
+ ``experiment`` and call its own API directly.
+ """
+ return self._wandb
+
+ def _setup_env(self, init_kwargs: Optional[dict] = None) -> Any:
+ """Setup env.
+
+ Args:
+ init_kwargs (dict): The init args.
+
+ Return:
+ :obj:`wandb`
+ """
+ try:
+ import wandb
+ except ImportError:
+ raise ImportError(
+ 'Please run "pip install wandb" to install wandb')
+ if init_kwargs:
+ wandb.init(**init_kwargs)
+ else:
+ wandb.init()
+
+ return wandb
+
+ def add_config(self, config: Config, **kwargs) -> None:
+ # TODO
+ pass
+
+ def add_image(self,
+ name: str,
+ image: np.ndarray = None,
+ step: int = 0,
+ **kwargs) -> None:
+ """Record image to wandb.
+
+ Args:
+ name (str): The unique identifier for the image to save.
+ image (np.ndarray, optional): The image to be saved. The format
+ should be RGB. Default to None.
+ step (int): Global step value to record. Default to 0.
+ """
+ self._wandb.log({name: image}, commit=self._commit, step=step)
+
+ def add_scalar(self,
+ name: str,
+ value: Union[int, float],
+ step: int = 0,
+ **kwargs) -> None:
+ """Record scalar data to wandb.
+
+ Args:
+ name (str): The unique identifier for the scalar to save.
+ value (float, int): Value to save.
+ step (int): Global step value to record. Default to 0.
+ """
+ self._wandb.log({name: value}, commit=self._commit, step=step)
+
+ def add_scalars(self,
+ scalar_dict: dict,
+ step: int = 0,
+ file_path: Optional[str] = None,
+ **kwargs) -> None:
+ """Record scalar's data to wandb.
+
+ Args:
+ scalar_dict (dict): Key-value pair storing the tag and
+ corresponding values.
+ step (int): Global step value to record. Default to 0.
+ file_path (str, optional): Useless parameter. Just for
+ interface unification. Default to None.
+ """
+ self._wandb.log(scalar_dict, commit=self._commit, step=step)
+
+ def close(self) -> None:
+ """close an opened wandb object."""
+ if hasattr(self, '_wandb'):
+ self._wandb.join()
+
+
+@VISBACKENDS.register_module()
+class TensorboardVisBackend(BaseVisBackend):
+ """Tensorboard class. It can write images, config, scalars, etc. to a
+ tensorboard file.
+
+ Its drawing function is provided by Visualizer.
+
+ Examples:
+ >>> from mmengine.visualization import TensorboardVisBackend
+ >>> import numpy as np
+ >>> tensorboard_visualizer = TensorboardVisBackend(save_dir='temp_dir')
+ >>> img=np.random.randint(0, 256, size=(10, 10, 3))
+ >>> tensorboard_visualizer.add_image('img', img)
+ >>> tensorboard_visualizer.add_scalar('mAP', 0.6)
+ >>> tensorboard_visualizer.add_scalars({'loss': 0.1, 'acc': 0.8})
+
+ Args:
+ save_dir (str): The root directory to save the files
+ produced by the backend.
+ log_dir (str): Save directory location. Default to 'tf_logs'.
+ """
+
+ def __init__(self,
+ save_dir: Optional[str] = None,
+ log_dir: str = 'tf_logs'):
+ super(TensorboardVisBackend, self).__init__(save_dir)
+ if save_dir is not None:
+ self._tensorboard = self._setup_env(log_dir)
+
+ def _setup_env(self, log_dir: str):
+ """Setup env.
+
+ Args:
+ log_dir (str): Save directory location.
+
+ Return:
+ :obj:`SummaryWriter`
+ """
+ if TORCH_VERSION == 'parrots':
+ try:
+ from tensorboardX import SummaryWriter
+ except ImportError:
+ raise ImportError('Please install tensorboardX to use '
+ 'TensorboardVisBackend.')
+ else:
+ try:
+ from torch.utils.tensorboard import SummaryWriter
+ except ImportError:
+ raise ImportError(
+ 'Please run "pip install future tensorboard" to install '
+ 'the dependencies to use torch.utils.tensorboard '
+ '(applicable to PyTorch 1.1 or higher)')
+ if self._save_dir is None:
+ return SummaryWriter(f'./{log_dir}')
+ else:
+ self.log_dir = osp.join(self._save_dir, log_dir) # type: ignore
+ return SummaryWriter(self.log_dir)
+
+ @property
+ def experiment(self):
+ """Return Tensorboard object."""
+ return self._tensorboard
+
+ def add_config(self, config: Config, **kwargs) -> None:
+ # TODO
+ pass
+
+ def add_image(self,
+ name: str,
+ image: np.ndarray,
+ step: int = 0,
+ **kwargs) -> None:
+ """Record image to tensorboard.
+
+ Args:
+ name (str): The unique identifier for the image to save.
+ image (np.ndarray, optional): The image to be saved. The format
+ should be RGB. Default to None.
+ step (int): Global step value to record. Default to 0.
+ """
+ self._tensorboard.add_image(name, image, step, dataformats='HWC')
+
+ def add_scalar(self,
+ name: str,
+ value: Union[int, float],
+ step: int = 0,
+ **kwargs) -> None:
+ """Record scalar data to summary.
+
+ Args:
+ name (str): The unique identifier for the scalar to save.
+ value (float, int): Value to save.
+ step (int): Global step value to record. Default to 0.
+ """
+ self._tensorboard.add_scalar(name, value, step)
+
+ def add_scalars(self,
+ scalar_dict: dict,
+ step: int = 0,
+ file_path: Optional[str] = None,
+ **kwargs) -> None:
+ """Record scalar's data to summary.
+
+ Args:
+ scalar_dict (dict): Key-value pair storing the tag and
+ corresponding values.
+ step (int): Global step value to record. Default to 0.
+ file_path (str, optional): Useless parameter. Just for
+ interface unification. Default to None.
+ """
+ assert isinstance(scalar_dict, dict)
+ assert 'step' not in scalar_dict, 'Please set it directly ' \
+ 'through the step parameter'
+ for key, value in scalar_dict.items():
+ self.add_scalar(key, value, step)
+
+ def close(self):
+ """close an opened tensorboard object."""
+ if hasattr(self, '_tensorboard'):
+ self._tensorboard.close()
diff --git a/mmengine/visualization/visualizer.py b/mmengine/visualization/visualizer.py
index ae6ff113..d8025616 100644
--- a/mmengine/visualization/visualizer.py
+++ b/mmengine/visualization/visualizer.py
@@ -1,25 +1,32 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
-from typing import Callable, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Sequence, Tuple, Union
import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
+from matplotlib.backend_bases import CloseEvent
from matplotlib.backends.backend_agg import FigureCanvasAgg
from matplotlib.collections import (LineCollection, PatchCollection,
PolyCollection)
from matplotlib.figure import Figure
from matplotlib.patches import Circle
+from mmengine.config import Config
from mmengine.data import BaseDataElement
-from mmengine.registry import VISUALIZERS
-from .utils import (check_type, check_type_and_length, tensor2ndarray,
- value2list)
+from mmengine.registry import VISBACKENDS, VISUALIZERS
+from mmengine.utils import ManagerMixin
+from mmengine.visualization.utils import (check_type, check_type_and_length,
+ color_val_matplotlib,
+ convert_overlay_heatmap,
+ str_color_to_rgb, tensor2ndarray,
+ value2list)
+from mmengine.visualization.vis_backend import BaseVisBackend
@VISUALIZERS.register_module()
-class Visualizer:
+class Visualizer(ManagerMixin):
"""MMEngine provides a Visualizer class that uses the ``Matplotlib``
library as the backend. It has the following functions:
@@ -67,15 +74,15 @@ class Visualizer:
>>> # Basic drawing methods
>>> vis = Visualizer(metadata=metadata, image=image)
- >>> vis.draw_bboxes(np.array([0, 0, 1, 1]), edgecolors='g')
+ >>> vis.draw_bboxes(np.array([0, 0, 1, 1]), edge_colors='g')
>>> vis.draw_bboxes(bbox=np.array([[1, 1, 2, 2], [2, 2, 3, 3]]),
- edgecolors=['g', 'r'], is_filling=True)
+ edge_colors=['g', 'r'], is_filling=True)
>>> vis.draw_lines(x_datas=np.array([1, 3]),
y_datas=np.array([1, 3]),
- colors='r', linewidths=1)
+ colors='r', line_widths=1)
>>> vis.draw_lines(x_datas=np.array([[1, 3], [2, 4]]),
y_datas=np.array([[1, 3], [2, 4]]),
- colors=['r', 'r'], linewidths=[1, 2])
+ colors=['r', 'r'], line_widths=[1, 2])
>>> vis.draw_texts(text='MMEngine',
position=np.array([2, 2]),
colors='b')
@@ -87,10 +94,10 @@ class Visualizer:
radius=np.array[1, 2], colors=['g', 'r'],
is_filling=True)
>>> vis.draw_polygons(np.array([0, 0, 1, 0, 1, 1, 0, 1]),
- edgecolors='g')
+ edge_colors='g')
>>> vis.draw_polygons(bbox=[np.array([0, 0, 1, 0, 1, 1, 0, 1],
np.array([2, 2, 3, 2, 3, 3, 2, 3]]),
- edgecolors=['g', 'r'], is_filling=True)
+ edge_colors=['g', 'r'], is_filling=True)
>>> vis.draw_binary_masks(binary_mask, alpha=0.6)
>>> # chain calls
@@ -106,80 +113,99 @@ class Visualizer:
>>> # inherit
>>> class DetVisualizer2(Visualizer):
- >>> @Visualizer.register_task('instances')
- >>> def draw_instance(self,
- >>> instances: 'BaseDataInstance',
- >>> data_type: Type):
- >>> pass
- >>> def draw(self,
+ >>> def add_datasample(self,
>>> image: Optional[np.ndarray] = None,
>>> gt_sample: Optional['BaseDataElement'] = None,
>>> pred_sample: Optional['BaseDataElement'] = None,
>>> show_gt: bool = True,
- >>> show_pred: bool = True) -> None:
+ >>> show_pred: bool = True,
+ >>> show:bool = True) -> None:
>>> pass
"""
- task_dict: dict = {}
- def __init__(self,
- image: Optional[np.ndarray] = None,
- metadata: Optional[dict] = None) -> None:
- self._metadata = metadata
+ def __init__(
+ self,
+ name='visualizer',
+ image: Optional[np.ndarray] = None,
+ vis_backends: Optional[Dict] = None,
+ save_dir: Optional[str] = None,
+ fig_save_cfg=dict(frameon=False),
+ fig_show_cfg=dict(frameon=False, num='show')
+ ) -> None:
+ super().__init__(name)
+ self._dataset_meta: Union[None, dict] = None
+ self._vis_backends: Union[Dict, Dict[str, 'BaseVisBackend']] = dict()
+ if vis_backends:
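+ # Backend configs must either all contain a 'name' key or none
+ # of them; mixing named and unnamed backends would make the
+ # keys of ``self._vis_backends`` ambiguous.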
+ with_name = False
+ without_name = False
+ for vis_backend in vis_backends:
+ if 'name' in vis_backend:
+ with_name = True
+ else:
+ without_name = True
+ if with_name and without_name:
+ raise AssertionError
+
+ for vis_backend in vis_backends:
+ name = vis_backend.pop('name', vis_backend['type'])
+ assert name not in self._vis_backends
+ vis_backend.setdefault('save_dir', save_dir)
+ self._vis_backends[name] = VISBACKENDS.build(vis_backend)
+
+ self.is_inline = 'inline' in plt.get_backend()
+
+ self.fig_save = None
+ self.fig_show = None
+ self.fig_save_num = fig_save_cfg.get('num', None)
+ self.fig_show_num = fig_show_cfg.get('num', None)
+ self.fig_save_cfg = fig_save_cfg
+ self.fig_show_cfg = fig_show_cfg
+
+ (self.fig_save, self.ax_save,
+ self.fig_save_num) = self._initialize_fig(fig_save_cfg)
+ self.dpi = self.fig_save.get_dpi()
if image is not None:
- self._setup_fig(image)
+ self.set_image(image)
- def draw(self,
- image: Optional[np.ndarray] = None,
- gt_sample: Optional['BaseDataElement'] = None,
- pred_sample: Optional['BaseDataElement'] = None,
- draw_gt: bool = True,
- draw_pred: bool = True) -> None:
- pass
+ @property
+ def dataset_meta(self) -> Optional[dict]:
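+ """Optional[dict]: Meta info of the dataset."""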
+ return self._dataset_meta
- def show(self, wait_time: int = 0) -> None:
+ @dataset_meta.setter
+ def dataset_meta(self, dataset_meta: dict) -> None:
+ self._dataset_meta = dataset_meta
+
+ def show(self,
+ drawn_img: Optional[np.ndarray] = None,
+ win_name: str = 'image',
+ wait_time: int = 0,
+ continue_key=' ') -> None:
"""Show the drawn image.
Args:
wait_time (int, optional): Delay in milliseconds. 0 is the special
value that means "forever". Defaults to 0.
"""
- if wait_time == 0:
- plt.show()
- else:
- plt.show(block=False)
- plt.pause(wait_time)
+ if self.is_inline:
+ return
+ if self.fig_show is None or not plt.fignum_exists(self.fig_show_num):
+ (self.fig_show, self.ax_show,
+ self.fig_show_num) = self._initialize_fig(self.fig_show_cfg)
+ img = self.get_image() if drawn_img is None else drawn_img
+ # dpi = self.fig_show.get_dpi()
+ # height, width = img.shape[:2]
+ # self.fig_show.set_size_inches((width + 1e-2) / dpi,
+ # (height + 1e-2) / dpi)
+ self.ax_show.cla()
+ self.ax_show.axis(False)
+ # self.ax_show.set_title(win_name)
+ # self.fig_show.set_label(win_name)
- def close(self) -> None:
- """Close the figure."""
- plt.close(self.fig)
-
- @classmethod
- def register_task(cls, task_name: str, force: bool = False) -> Callable:
- """Register a function.
-
- A record will be added to ``task_dict``, whose key is the task_name
- and value is the decorated function.
-
- Args:
- cls (type): Module class to be registered.
- task_name (str or list of str, optional): The module name to be
- registered.
- force (bool): Whether to override an existing function with the
- same name. Defaults to False.
- """
-
- def _register(task_func):
-
- if (task_name not in cls.task_dict) or force:
- cls.task_dict[task_name] = task_func
- else:
- raise KeyError(
- f'"{task_name}" is already registered in task_dict, '
- 'add "force=True" if you want to override it')
- return task_func
-
- return _register
+ # Refresh canvas, necessary for Qt5 backend.
+ self.ax_show.imshow(img)
+ self.fig_show.canvas.draw() # type: ignore
+ self._wait_continue(timeout=wait_time, continue_key=continue_key)
def set_image(self, image: np.ndarray) -> None:
"""Set the image to draw.
@@ -188,7 +214,23 @@ class Visualizer:
image (np.ndarray): The image to draw.
"""
assert image is not None
- self._setup_fig(image)
+ image = image.astype('uint8')
+ self._image = image
+ self.width, self.height = image.shape[1], image.shape[0]
+ self._default_font_size = max(
+ np.sqrt(self.height * self.width) // 90, 10)
+
+ # add a small 1e-2 to avoid precision lost due to matplotlib's
+ # truncation (https://github.com/matplotlib/matplotlib/issues/15363)
+ self.fig_save.set_size_inches( # type: ignore
+ (self.width + 1e-2) / self.dpi, (self.height + 1e-2) / self.dpi)
+ # self.canvas = mpl.backends.backend_cairo.FigureCanvasCairo(fig)
+ self.ax_save.cla()
+ self.ax_save.axis(False)
+ self.ax_save.imshow(
+ image,
+ extent=(0, self.width, self.height, 0),
+ interpolation='none')
def get_image(self) -> np.ndarray:
"""Get the drawn image. The format is RGB.
@@ -197,43 +239,24 @@ class Visualizer:
np.ndarray: the drawn image which channel is rgb.
"""
assert self._image is not None, 'Please set image using `set_image`'
- canvas = self.canvas
+ canvas = self.fig_save.canvas # type: ignore
s, (width, height) = canvas.print_to_buffer()
buffer = np.frombuffer(s, dtype='uint8')
img_rgba = buffer.reshape(height, width, 4)
rgb, alpha = np.split(img_rgba, [3], axis=2)
return rgb.astype('uint8')
- def _setup_fig(self, image: np.ndarray) -> None:
- """Set the image to draw.
+ def _initialize_fig(self, fig_cfg):
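+ """Build a matplotlib figure and axes from ``fig_cfg``.
+
+ Returns the figure, its axes and the figure number.
+ """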
+ fig = plt.figure(**fig_cfg)
+ ax = fig.add_subplot()
+ ax.axis(False)
- Args:
- image (np.ndarray): The image to draw.The format
- should be RGB.
- """
- image = image.astype('uint8')
- self._image = image
- self.width, self.height = image.shape[1], image.shape[0]
- self._default_font_size = max(
- np.sqrt(self.height * self.width) // 90, 10)
- fig = plt.figure(frameon=False)
+ # remove white edges by set subplot margin
+ fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
+ return fig, ax, fig.number
- self.dpi = fig.get_dpi()
- # add a small 1e-2 to avoid precision lost due to matplotlib's
- # truncation (https://github.com/matplotlib/matplotlib/issues/15363)
- fig.set_size_inches((self.width + 1e-2) / self.dpi,
- (self.height + 1e-2) / self.dpi)
- self.canvas = fig.canvas
- # self.canvas = mpl.backends.backend_cairo.FigureCanvasCairo(fig)
- plt.subplots_adjust(left=0, right=1, bottom=0, top=1)
- plt.axis('off')
- ax = plt.gca()
- self.fig = fig
- self.ax = ax
- self.ax.imshow(
- image,
- extent=(0, self.width, self.height, 0),
- interpolation='none')
+ def get_backend(self, name) -> 'BaseVisBackend':
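+ """Get the vis backend registered under ``name``, or None if absent."""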
+ return self._vis_backends.get(name) # type: ignore
def _is_posion_valid(self, position: np.ndarray) -> bool:
"""Judge whether the position is in image.
@@ -251,14 +274,86 @@ class Visualizer:
(position[..., 1] >= 0).all()
return flag
+ def _wait_continue(self, timeout: int = 0, continue_key=' ') -> int:
+ """Show the image and wait for the user's input.
+
+ This implementation refers to
+ https://github.com/matplotlib/matplotlib/blob/v3.5.x/lib/matplotlib/_blocking_input.py
+
+ Args:
+ timeout (int): If positive, continue after ``timeout`` seconds.
+ Defaults to 0.
+ continue_key (str): The key for users to continue. Defaults to
+ the space key.
+
+ Returns:
+ int: If zero, means time out or the user pressed ``continue_key``,
+ and if one, means the user closed the show figure.
+ """ # noqa: E501
+ if self.is_inline:
+ # If use inline backend, interactive input and timeout is no use.
+ return 0
+
+ if self.fig_show.canvas.manager: # type: ignore
+ # Ensure that the figure is shown
+ self.fig_show.show() # type: ignore
+
+ while True:
+
+ # Connect the events to the handler function call.
+ event = None
+
+ def handler(ev):
+ # Set external event variable
+ nonlocal event
+ # Qt backend may fire two events at the same time,
+ # use a condition to avoid missing close event.
+ event = ev if not isinstance(event, CloseEvent) else event
+ self.fig_show.canvas.stop_event_loop()
+
+ cids = [
+ self.fig_show.canvas.mpl_connect(name, handler) # type: ignore
+ for name in ('key_press_event', 'close_event')
+ ]
+
+ try:
+ self.fig_show.canvas.start_event_loop(timeout) # type: ignore
+ finally: # Run even on exception like ctrl-c.
+ # Disconnect the callbacks.
+ for cid in cids:
+ self.fig_show.canvas.mpl_disconnect(cid) # type: ignore
+
+ if isinstance(event, CloseEvent):
+ return 1 # Quit for close.
+ elif event is None or event.key == continue_key:
+ return 0 # Quit for continue.
+
+ def draw_points(self,
+ positions: Union[np.ndarray, torch.Tensor],
+ colors: Union[str, tuple, List[str], List[tuple]] = 'g',
+ marker: Optional[str] = None,
+ sizes: Optional[Union[np.ndarray, torch.Tensor]] = None):
+ check_type('positions', positions, (np.ndarray, torch.Tensor))
+ positions = tensor2ndarray(positions)
+
+ if len(positions.shape) == 1:
+ positions = positions[None]
+ assert positions.shape[-1] == 2, (
+ 'The shape of `positions` should be (N, 2), '
+ f'but got {positions.shape}')
+ colors = color_val_matplotlib(colors)
+ self.ax_save.scatter(
+ positions[:, 0], positions[:, 1], c=colors, s=sizes, marker=marker)
+ return self
+
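# Usage sketch (illustrative, not part of the patch), continuing the
# set_image/get_image setup above (`vis`, `img` and numpy already in scope):
points = np.array([[10, 10], [20, 30], [40, 15]])  # (N, 2) points in (x, y)
vis.draw_points(points, colors='r', sizes=np.array([20, 20, 20]))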
def draw_texts(
self,
texts: Union[str, List[str]],
positions: Union[np.ndarray, torch.Tensor],
font_sizes: Optional[Union[int, List[int]]] = None,
- colors: Union[str, List[str]] = 'g',
- verticalalignments: Union[str, List[str]] = 'top',
- horizontalalignments: Union[str, List[str]] = 'left',
+ colors: Union[str, tuple, List[str], List[tuple]] = 'g',
+ vertical_alignments: Union[str, List[str]] = 'top',
+ horizontal_alignments: Union[str, List[str]] = 'left',
font_families: Union[str, List[str]] = 'sans-serif',
rotations: Union[int, str, List[Union[int, str]]] = 0,
bboxes: Optional[Union[dict, List[dict]]] = None) -> 'Visualizer':
@@ -273,29 +368,29 @@ class Visualizer:
texts. ``font_sizes`` can have the same length with texts or
just single value. If ``font_sizes`` is single value, all the
texts will have the same font size. Defaults to None.
- colors (Union[str, List[str]]): The colors of texts. ``colors``
- can have the same length with texts or just single value.
- If ``colors`` is single value, all the texts will have the same
- colors. Reference to
+ colors (Union[str, tuple, List[str], List[tuple]]): The colors
+ of texts. ``colors`` can have the same length with texts or
+ just single value. If ``colors`` is single value, all the
+ texts will have the same colors. Reference to
https://matplotlib.org/stable/gallery/color/named_colors.html
 for more details. Defaults to 'g'.
- verticalalignments (Union[str, List[str]]): The verticalalignment
+ vertical_alignments (Union[str, List[str]]): The verticalalignment
of texts. verticalalignment controls whether the y positional
argument for the text indicates the bottom, center or top side
of the text bounding box.
- ``verticalalignments`` can have the same length with
- texts or just single value. If ``verticalalignments`` is single
- value, all the texts will have the same verticalalignment.
- verticalalignment can be 'center' or 'top', 'bottom' or
- 'baseline'. Defaults to 'top'.
- horizontalalignments (Union[str, List[str]]): The
+ ``vertical_alignments`` can have the same length with
+ texts or just single value. If ``vertical_alignments`` is
+ single value, all the texts will have the same
+ verticalalignment. verticalalignment can be 'center' or
+ 'top', 'bottom' or 'baseline'. Defaults to 'top'.
+ horizontal_alignments (Union[str, List[str]]): The
horizontalalignment of texts. Horizontalalignment controls
whether the x positional argument for the text indicates the
left, center or right side of the text bounding box.
- ``horizontalalignments`` can have
+ ``horizontal_alignments`` can have
the same length with texts or just single value.
- If ``horizontalalignments`` is single value, all the texts will
- have the same horizontalalignment. Horizontalalignment
+ If ``horizontal_alignments`` is single value, all the texts
+ will have the same horizontalalignment. Horizontalalignment
can be 'center','right' or 'left'. Defaults to 'left'.
font_families (Union[str, List[str]]): The font family of
texts. ``font_families`` can have the same length with texts or
@@ -335,19 +430,22 @@ class Visualizer:
if font_sizes is None:
font_sizes = self._default_font_size
- check_type_and_length('font_sizes', font_sizes, (int, list), num_text)
- font_sizes = value2list(font_sizes, int, num_text)
+ check_type_and_length('font_sizes', font_sizes, (int, float, list),
+ num_text)
+ font_sizes = value2list(font_sizes, (int, float), num_text)
- check_type_and_length('colors', colors, (str, list), num_text)
- colors = value2list(colors, str, num_text)
+ check_type_and_length('colors', colors, (str, tuple, list), num_text)
+ colors = value2list(colors, (str, tuple), num_text)
+ colors = color_val_matplotlib(colors)
- check_type_and_length('verticalalignments', verticalalignments,
+ check_type_and_length('vertical_alignments', vertical_alignments,
(str, list), num_text)
- verticalalignments = value2list(verticalalignments, str, num_text)
+ vertical_alignments = value2list(vertical_alignments, str, num_text)
- check_type_and_length('horizontalalignments', horizontalalignments,
+ check_type_and_length('horizontal_alignments', horizontal_alignments,
(str, list), num_text)
- horizontalalignments = value2list(horizontalalignments, str, num_text)
+ horizontal_alignments = value2list(horizontal_alignments, str,
+ num_text)
check_type_and_length('rotations', rotations, (int, list), num_text)
rotations = value2list(rotations, int, num_text)
@@ -363,14 +461,14 @@ class Visualizer:
bboxes = value2list(bboxes, dict, num_text)
for i in range(num_text):
- self.ax.text(
+ self.ax_save.text(
positions[i][0],
positions[i][1],
texts[i],
size=font_sizes[i], # type: ignore
bbox=bboxes[i], # type: ignore
- verticalalignment=verticalalignments[i],
- horizontalalignment=horizontalalignments[i],
+ verticalalignment=vertical_alignments[i],
+ horizontalalignment=horizontal_alignments[i],
family=font_families[i],
color=colors[i])
return self
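# Usage sketch (illustrative, not part of the patch); keyword names follow the
# renamed arguments introduced in this patch, and `vis` is the Visualizer set
# up in the sketch above:
vis.draw_texts(
    ['cat', 'dog'],
    positions=np.array([[5, 5], [30, 40]]),
    font_sizes=[12, 12],
    colors=['r', 'b'],
    vertical_alignments='top',
    horizontal_alignments='left')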
@@ -379,9 +477,9 @@ class Visualizer:
self,
x_datas: Union[np.ndarray, torch.Tensor],
y_datas: Union[np.ndarray, torch.Tensor],
- colors: Union[str, List[str]] = 'g',
- linestyles: Union[str, List[str]] = '-',
- linewidths: Union[Union[int, float], List[Union[int, float]]] = 1
+ colors: Union[str, tuple, List[str], List[tuple]] = 'g',
+ line_styles: Union[str, List[str]] = '-',
+ line_widths: Union[Union[int, float], List[Union[int, float]]] = 2
) -> 'Visualizer':
"""Draw single or multiple line segments.
@@ -390,24 +488,24 @@ class Visualizer:
 each line's start and end points.
 y_datas (Union[np.ndarray, torch.Tensor]): The y coordinate of
 each line's start and end points.
- colors (Union[str, List[str]]): The colors of lines. ``colors``
- can have the same length with lines or just single value.
- If ``colors`` is single value, all the lines will have the same
- colors. Reference to
+ colors (Union[str, tuple, List[str], List[tuple]]): The colors of
+ lines. ``colors`` can have the same length with lines or just
+ single value. If ``colors`` is single value, all the lines
+ will have the same colors. Reference to
https://matplotlib.org/stable/gallery/color/named_colors.html
for more details. Defaults to 'g'.
- linestyles (Union[str, List[str]]): The linestyle
- of lines. ``linestyles`` can have the same length with
- texts or just single value. If ``linestyles`` is single
+ line_styles (Union[str, List[str]]): The linestyle
+ of lines. ``line_styles`` can have the same length as the
+ lines or just a single value. If ``line_styles`` is a single
value, all the lines will have the same linestyle.
Reference to
https://matplotlib.org/stable/api/collections_api.html?highlight=collection#matplotlib.collections.AsteriskPolygonCollection.set_linestyle
for more details. Defaults to '-'.
- linewidths (Union[Union[int, float], List[Union[int, float]]]): The
- linewidth of lines. ``linewidths`` can have
+ line_widths (Union[Union[int, float], List[Union[int, float]]]):
+ The linewidth of lines. ``line_widths`` can have
the same length with lines or just single value.
- If ``linewidths`` is single value, all the lines will
- have the same linewidth. Defaults to 1.
+ If ``line_widths`` is single value, all the lines will
+ have the same linewidth. Defaults to 2.
"""
check_type('x_datas', x_datas, (np.ndarray, torch.Tensor))
x_datas = tensor2ndarray(x_datas)
@@ -421,31 +519,31 @@ class Visualizer:
if len(x_datas.shape) == 1:
x_datas = x_datas[None]
y_datas = y_datas[None]
-
+ colors = color_val_matplotlib(colors)
lines = np.concatenate(
(x_datas.reshape(-1, 2, 1), y_datas.reshape(-1, 2, 1)), axis=-1)
if not self._is_posion_valid(lines):
-
warnings.warn(
'Warning: The line is out of bounds,'
' the drawn line may not be in the image', UserWarning)
line_collect = LineCollection(
lines.tolist(),
colors=colors,
- linestyles=linestyles,
- linewidths=linewidths)
- self.ax.add_collection(line_collect)
+ linestyles=line_styles,
+ linewidths=line_widths)
+ self.ax_save.add_collection(line_collect)
return self
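# Usage sketch (illustrative, not part of the patch): one line per row, where
# x_datas/y_datas hold each line's start and end coordinates (`vis` as above):
x_datas = np.array([[5, 60], [10, 50]])  # (N, 2): x_start, x_end per line
y_datas = np.array([[5, 60], [50, 10]])  # (N, 2): y_start, y_end per line
vis.draw_lines(x_datas, y_datas, colors='g', line_styles='-', line_widths=2)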
- def draw_circles(self,
- center: Union[np.ndarray, torch.Tensor],
- radius: Union[np.ndarray, torch.Tensor],
- alpha: Union[float, int] = 0.8,
- edgecolors: Union[str, List[str]] = 'g',
- linestyles: Union[str, List[str]] = '-',
- linewidths: Union[Union[int, float],
- List[Union[int, float]]] = 1,
- is_filling: bool = False) -> 'Visualizer':
+ def draw_circles(
+ self,
+ center: Union[np.ndarray, torch.Tensor],
+ radius: Union[np.ndarray, torch.Tensor],
+ alpha: Union[float, int] = 0.8,
+ edge_colors: Union[str, tuple, List[str], List[tuple]] = 'g',
+ line_styles: Union[str, List[str]] = '-',
+ line_widths: Union[Union[int, float], List[Union[int, float]]] = 2,
+ face_colors: Union[str, tuple, List[str], List[tuple]] = 'none'
+ ) -> 'Visualizer':
"""Draw single or multiple circles.
Args:
@@ -453,24 +551,24 @@ class Visualizer:
each line' start and end points.
 radius (Union[np.ndarray, torch.Tensor]): The radius of each
 circle.
- edgecolors (Union[str, List[str]]): The colors of circles.
- ``colors`` can have the same length with lines or just single
- value. If ``colors`` is single value, all the lines will have
- the same colors. Reference to
+ edge_colors (Union[str, tuple, List[str], List[tuple]]): The
+ colors of circles. ``colors`` can have the same length with
+ lines or just single value. If ``colors`` is single value,
+ all the lines will have the same colors. Reference to
https://matplotlib.org/stable/gallery/color/named_colors.html
 for more details. Defaults to 'g'.
- linestyles (Union[str, List[str]]): The linestyle
- of lines. ``linestyles`` can have the same length with
- texts or just single value. If ``linestyles`` is single
+ line_styles (Union[str, List[str]]): The linestyle
+ of lines. ``line_styles`` can have the same length as the
+ lines or just a single value. If ``line_styles`` is a single
value, all the lines will have the same linestyle.
Reference to
https://matplotlib.org/stable/api/collections_api.html?highlight=collection#matplotlib.collections.AsteriskPolygonCollection.set_linestyle
for more details. Defaults to '-'.
- linewidths (Union[Union[int, float], List[Union[int, float]]]): The
- linewidth of lines. ``linewidths`` can have
+ line_widths (Union[Union[int, float], List[Union[int, float]]]):
+ The linewidth of lines. ``line_widths`` can have
the same length with lines or just single value.
- If ``linewidths`` is single value, all the lines will
- have the same linewidth. Defaults to 1.
+ If ``line_widths`` is single value, all the lines will
+ have the same linewidth. Defaults to 2.
is_filling (bool): Whether to fill all the circles. Defaults to
False.
"""
@@ -493,57 +591,59 @@ class Visualizer:
center = center.tolist()
radius = radius.tolist()
+ edge_colors = color_val_matplotlib(edge_colors)
+ face_colors = color_val_matplotlib(face_colors)
circles = []
for i in range(len(center)):
circles.append(Circle(tuple(center[i]), radius[i]))
- if is_filling:
- p = PatchCollection(circles, alpha=alpha, facecolor=edgecolors)
- else:
- if isinstance(linewidths, (int, float)):
- linewidths = [linewidths] * len(circles)
- linewidths = [
- min(max(linewidth, 1), self._default_font_size / 4)
- for linewidth in linewidths
- ]
- p = PatchCollection(
- circles,
- alpha=alpha,
- facecolor='none',
- edgecolor=edgecolors,
- linewidth=linewidths,
- linestyles=linestyles)
- self.ax.add_collection(p)
+
+ if isinstance(line_widths, (int, float)):
+ line_widths = [line_widths] * len(circles)
+ line_widths = [
+ min(max(linewidth, 1), self._default_font_size / 4)
+ for linewidth in line_widths
+ ]
+ p = PatchCollection(
+ circles,
+ alpha=alpha,
+ facecolors=face_colors,
+ edgecolors=edge_colors,
+ linewidths=line_widths,
+ linestyles=line_styles)
+ self.ax_save.add_collection(p)
return self
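# Usage sketch (illustrative, not part of the patch): outlined circles by
# default; pass a color via face_colors to fill them (this replaces the
# removed is_filling flag). `vis` is the Visualizer from the sketch above:
centers = np.array([[32, 32], [16, 48]])
radii = np.array([10, 5])
vis.draw_circles(centers, radii, edge_colors='g', face_colors='none')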
- def draw_bboxes(self,
- bboxes: Union[np.ndarray, torch.Tensor],
- alpha: Union[int, float] = 0.8,
- edgecolors: Union[str, List[str]] = 'g',
- linestyles: Union[str, List[str]] = '-',
- linewidths: Union[Union[int, float],
- List[Union[int, float]]] = 1,
- is_filling: bool = False) -> 'Visualizer':
+ def draw_bboxes(
+ self,
+ bboxes: Union[np.ndarray, torch.Tensor],
+ alpha: Union[int, float] = 0.8,
+ edge_colors: Union[str, tuple, List[str], List[tuple]] = 'g',
+ line_styles: Union[str, List[str]] = '-',
+ line_widths: Union[Union[int, float], List[Union[int, float]]] = 2,
+ face_colors: Union[str, tuple, List[str], List[tuple]] = 'none'
+ ) -> 'Visualizer':
"""Draw single or multiple bboxes.
Args:
bboxes (Union[np.ndarray, torch.Tensor]): The bboxes to draw with
the format of(x1,y1,x2,y2).
- edgecolors (Union[str, List[str]]): The colors of bboxes.
- ``colors`` can have the same length with lines or just single
- value. If ``colors`` is single value, all the lines will have
- the same colors. Refer to `matplotlib.colors` for full list of
- formats that are accepted.. Defaults to 'g'.
- linestyles (Union[str, List[str]]): The linestyle
- of lines. ``linestyles`` can have the same length with
- texts or just single value. If ``linestyles`` is single
+ edge_colors (Union[str, tuple, List[str], List[tuple]]): The
+ colors of bboxes. ``colors`` can have the same length with
+ lines or just single value. If ``colors`` is single value, all
+ the lines will have the same colors. Refer to `matplotlib.
+ colors` for full list of formats that are accepted.
+ Defaults to 'g'.
+ line_styles (Union[str, List[str]]): The linestyle
+ of lines. ``line_styles`` can have the same length as the
+ lines or just a single value. If ``line_styles`` is a single
value, all the lines will have the same linestyle.
Reference to
https://matplotlib.org/stable/api/collections_api.html?highlight=collection#matplotlib.collections.AsteriskPolygonCollection.set_linestyle
for more details. Defaults to '-'.
- linewidths (Union[Union[int, float], List[Union[int, float]]]): The
- linewidth of lines. ``linewidths`` can have
+ line_widths (Union[Union[int, float], List[Union[int, float]]]):
+ The linewidth of lines. ``line_widths`` can have
the same length with lines or just single value.
- If ``linewidths`` is single value, all the lines will
+ If ``line_widths`` is single value, all the lines will
have the same linewidth. Defaults to 1.
is_filling (bool): Whether to fill all the bboxes. Defaults to
False.
@@ -570,47 +670,51 @@ class Visualizer:
return self.draw_polygons(
poly,
alpha=alpha,
- edgecolors=edgecolors,
- linestyles=linestyles,
- linewidths=linewidths,
- is_filling=is_filling)
+ edge_colors=edge_colors,
+ line_styles=line_styles,
+ line_widths=line_widths,
+ face_colors=face_colors)
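# Usage sketch (illustrative, not part of the patch): boxes are
# (x1, y1, x2, y2); filling is now controlled by face_colors rather than the
# removed is_filling flag. `vis` is the Visualizer from the sketch above:
bboxes = np.array([[8, 8, 30, 30], [20, 25, 50, 45]])
vis.draw_bboxes(bboxes, edge_colors=['g', 'r'], line_widths=2)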
- def draw_polygons(self,
- polygons: Union[Union[np.ndarray, torch.Tensor],
- List[Union[np.ndarray, torch.Tensor]]],
- alpha: Union[int, float] = 0.8,
- edgecolors: Union[str, List[str]] = 'g',
- linestyles: Union[str, List[str]] = '-',
- linewidths: Union[Union[int, float],
- List[Union[int, float]]] = 1.0,
- is_filling: bool = False) -> 'Visualizer':
+ def draw_polygons(
+ self,
+ polygons: Union[Union[np.ndarray, torch.Tensor],
+ List[Union[np.ndarray, torch.Tensor]]],
+ alpha: Union[int, float] = 0.8,
+ edge_colors: Union[str, tuple, List[str], List[tuple]] = 'g',
+ line_styles: Union[str, List[str]] = '-',
+ line_widths: Union[Union[int, float], List[Union[int, float]]] = 2,
+ face_colors: Union[str, tuple, List[str], List[tuple]] = 'none'
+ ) -> 'Visualizer':
"""Draw single or multiple bboxes.
Args:
polygons (Union[Union[np.ndarray, torch.Tensor],
List[Union[np.ndarray, torch.Tensor]]]): The polygons to draw
with the format of (x1,y1,x2,y2,...,xn,yn).
- edgecolors (Union[str, List[str]]): The colors of polygons.
- ``colors`` can have the same length with lines or just single
- value. If ``colors`` is single value, all the lines will have
- the same colors. Refer to `matplotlib.colors` for full list of
- formats that are accepted.. Defaults to 'g.
- linestyles (Union[str, List[str]]): The linestyle
- of lines. ``linestyles`` can have the same length with
- texts or just single value. If ``linestyles`` is single
+ edge_colors (Union[str, tuple, List[str], List[tuple]]): The
+ colors of polygons. ``colors`` can have the same length with
+ lines or just single value. If ``colors`` is single value,
+ all the lines will have the same colors. Refer to
+ `matplotlib.colors` for full list of formats that are accepted.
+ Defaults to 'g'.
+ line_styles (Union[str, List[str]]): The linestyle
+ of lines. ``line_styles`` can have the same length as the
+ lines or just a single value. If ``line_styles`` is a single
value, all the lines will have the same linestyle.
Reference to
https://matplotlib.org/stable/api/collections_api.html?highlight=collection#matplotlib.collections.AsteriskPolygonCollection.set_linestyle
for more details. Defaults to '-'.
- linewidths (Union[Union[int, float], List[Union[int, float]]]): The
- linewidth of lines. ``linewidths`` can have
+ line_widths (Union[Union[int, float], List[Union[int, float]]]):
+ The linewidth of lines. ``line_widths`` can have
the same length with lines or just single value.
- If ``linewidths`` is single value, all the lines will
- have the same linewidth. Defaults to 1.
+ If ``line_widths`` is single value, all the lines will
+ have the same linewidth. Defaults to 2.
is_filling (bool): Whether to fill all the polygons. Defaults to
False.
"""
check_type('polygons', polygons, (list, np.ndarray, torch.Tensor))
+ edge_colors = color_val_matplotlib(edge_colors)
+ face_colors = color_val_matplotlib(face_colors)
if isinstance(polygons, (np.ndarray, torch.Tensor)):
polygons = [polygons]
@@ -625,32 +729,29 @@ class Visualizer:
warnings.warn(
'Warning: The polygon is out of bounds,'
' the drawn polygon may not be in the image', UserWarning)
- if is_filling:
- polygon_collection = PolyCollection(
- polygons, alpha=alpha, facecolor=edgecolors)
- else:
- if isinstance(linewidths, (int, float)):
- linewidths = [linewidths] * len(polygons)
- linewidths = [
- min(max(linewidth, 1), self._default_font_size / 4)
- for linewidth in linewidths
- ]
- polygon_collection = PolyCollection(
- polygons,
- alpha=alpha,
- facecolor='none',
- linestyles=linestyles,
- edgecolors=edgecolors,
- linewidths=linewidths)
+ if isinstance(line_widths, (int, float)):
+ line_widths = [line_widths] * len(polygons)
+ line_widths = [
+ min(max(linewidth, 1), self._default_font_size / 4)
+ for linewidth in line_widths
+ ]
+ polygon_collection = PolyCollection(
+ polygons,
+ alpha=alpha,
+ facecolor=face_colors,
+ linestyles=line_styles,
+ edgecolors=edge_colors,
+ linewidths=line_widths)
- self.ax.add_collection(polygon_collection)
+ self.ax_save.add_collection(polygon_collection)
return self
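# Usage sketch (illustrative, not part of the patch); each polygon is assumed
# to be an (M, 2) array of vertices, matching what draw_bboxes forwards here.
# `vis` is the Visualizer from the sketch above:
triangle = np.array([[10, 10], [40, 12], [25, 35]])
vis.draw_polygons([triangle], edge_colors='b', face_colors='none', alpha=0.8)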
def draw_binary_masks(
- self,
- binary_masks: Union[np.ndarray, torch.Tensor],
- colors: np.ndarray = np.array([0, 255, 0]),
- alphas: Union[float, List[float]] = 0.5) -> 'Visualizer':
+ self,
+ binary_masks: Union[np.ndarray, torch.Tensor],
+ alphas: Union[float, List[float]] = 0.8,
+ colors: Union[str, tuple, List[str],
+ List[tuple]] = 'g') -> 'Visualizer':
"""Draw single or multiple binary masks.
Args:
@@ -677,12 +778,24 @@ class Visualizer:
binary_masks = binary_masks[None]
assert img.shape[:2] == binary_masks.shape[
 1:], '`binary_masks` must have the same shape as the image'
- assert isinstance(colors, np.ndarray)
- if colors.ndim == 1:
- colors = np.tile(colors, (binary_masks.shape[0], 1))
- assert colors.shape == (binary_masks.shape[0], 3)
+ binary_mask_len = binary_masks.shape[0]
+
+ check_type_and_length('colors', colors, (str, tuple, list),
+ binary_mask_len)
+ colors = value2list(colors, (str, tuple), binary_mask_len)
+ colors = [
+ str_color_to_rgb(color) if isinstance(color, str) else color
+ for color in colors
+ ]
+ for color in colors:
+ assert len(color) == 3
+ for channel in color:
+ assert 0 <= channel <= 255 # type: ignore
+ colors = np.array(colors)
+ if colors.ndim == 1: # type: ignore
+ colors = np.tile(colors, (binary_mask_len, 1))
if isinstance(alphas, float):
- alphas = [alphas] * binary_masks.shape[0]
+ alphas = [alphas] * binary_mask_len
for binary_mask, color, alpha in zip(binary_masks, colors, alphas):
binary_mask_complement = cv2.bitwise_not(binary_mask)
@@ -692,8 +805,8 @@ class Visualizer:
img_complement = cv2.bitwise_and(
img, img, mask=binary_mask_complement)
rgb = rgb + img_complement
- img = cv2.addWeighted(img, alpha, rgb, 1 - alpha, 0)
- self.ax.imshow(
+ img = cv2.addWeighted(img, 1 - alpha, rgb, alpha, 0)
+ self.ax_save.imshow(
img,
extent=(0, self.width, self.height, 0),
interpolation='nearest')
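# Usage sketch (illustrative, not part of the patch); the mask dtype and
# (H, W) shape requirements are assumed from the assertions in this hunk, and
# `vis`/`img` come from the set_image sketch above:
mask = np.zeros(img.shape[:2], dtype=bool)
mask[10:30, 10:30] = True
vis.draw_binary_masks(mask, colors='g', alphas=0.6)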
@@ -705,7 +818,7 @@ class Visualizer:
mode: str = 'mean',
topk: int = 10,
arrangement: Tuple[int, int] = (5, 2),
- alpha: float = 0.3) -> np.ndarray:
+ alpha: float = 0.8) -> np.ndarray:
"""Draw featmap. If img is not None, the final image will be the
 weighted sum of img and featmap. It supports the modes:
@@ -738,37 +851,6 @@ class Visualizer:
Returns:
np.ndarray: featmap.
"""
-
- def concat_heatmap(feat_map: Union[np.ndarray, torch.Tensor],
- img: Optional[np.ndarray] = None,
- alpha: float = 0.5) -> np.ndarray:
- """Convert feat_map to heatmap and sum to image, if image is not
- None.
-
- Args:
- feat_map (np.ndarray, torch.Tensor): The feat_map to convert
- with of shape (H, W), where H is the image height and W is
- the image width.
- img (np.ndarray, optional): The origin image. The format
- should be RGB. Defaults to None.
- alphas (Union[int, List[int]]): The transparency of origin
- image. Defaults to 0.5.
-
- Returns:
- np.ndarray: heatmap
- """
- if isinstance(feat_map, torch.Tensor):
- feat_map = feat_map.detach().cpu().numpy()
- norm_img = np.zeros(feat_map.shape)
- norm_img = cv2.normalize(feat_map, norm_img, 0, 255,
- cv2.NORM_MINMAX)
- norm_img = np.asarray(norm_img, dtype=np.uint8)
- heat_img = cv2.applyColorMap(norm_img, cv2.COLORMAP_JET)
- heat_img = cv2.cvtColor(heat_img, cv2.COLOR_BGR2RGB)
- if img is not None:
- heat_img = cv2.addWeighted(img, alpha, heat_img, 1 - alpha, 0)
- return heat_img
-
assert isinstance(
tensor_chw,
torch.Tensor), (f'`tensor_chw` should be {torch.Tensor} '
@@ -785,11 +867,9 @@ class Visualizer:
], (f'Mode only support "mean", "max", "min", but got {mode}')
if mode == 'max':
feat_map, _ = torch.max(tensor_chw, dim=0)
- elif mode == 'min':
- feat_map, _ = torch.min(tensor_chw, dim=0)
elif mode == 'mean':
feat_map = torch.mean(tensor_chw, dim=0)
- return concat_heatmap(feat_map, image, alpha)
+ return convert_overlay_heatmap(feat_map, image, alpha)
if topk <= 0:
tensor_chw_channel = tensor_chw.shape[0]
@@ -801,16 +881,15 @@ class Visualizer:
' mode parameter or set topk greater than 0 to solve '
'the error')
if tensor_chw_channel == 1:
- return concat_heatmap(tensor_chw[0], image, alpha)
+ return convert_overlay_heatmap(tensor_chw[0], image, alpha)
else:
tensor_chw = tensor_chw.permute(1, 2, 0).numpy()
- norm_img = np.zeros(tensor_chw.shape)
norm_img = cv2.normalize(tensor_chw, None, 0, 255,
cv2.NORM_MINMAX)
heat_img = np.asarray(norm_img, dtype=np.uint8)
if image is not None:
- heat_img = cv2.addWeighted(image, alpha, heat_img,
- 1 - alpha, 0)
+ heat_img = cv2.addWeighted(image, 1 - alpha, heat_img,
+ alpha, 0)
return heat_img
else:
row, col = arrangement
@@ -833,9 +912,133 @@ class Visualizer:
for i in range(topk):
axes = fig.add_subplot(row, col, i + 1)
axes.axis('off')
- axes.imshow(concat_heatmap(topk_tensor[i], image, alpha))
+ axes.imshow(
+ convert_overlay_heatmap(topk_tensor[i], image, alpha))
s, (width, height) = canvas.print_to_buffer()
buffer = np.frombuffer(s, dtype='uint8')
img_rgba = buffer.reshape(height, width, 4)
rgb, alpha = np.split(img_rgba, [3], axis=2)
return rgb.astype('uint8')
+
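# Usage sketch (illustrative, not part of the patch): collapse a CHW feature
# map into a heatmap overlaid on the image. The argument order before `mode`
# (tensor_chw, then the optional image) is assumed from the body of this hunk;
# note that 'min' stays in the assertion although its branch is removed here.
import torch
feat = torch.rand(8, 64, 64)                                  # (C, H, W)
heatmap = vis.draw_featmap(feat, img, mode='mean', alpha=0.8)  # RGB ndarray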
+ def add_config(self, config: Config, **kwargs):
+ """Record parameters.
+
+ Args:
+ config (Config): The Config object.
+ """
+ for vis_backend in self._vis_backends.values():
+ vis_backend.add_config(config, **kwargs) # type: ignore
+
+ def add_graph(self, model: torch.nn.Module, data_batch: Sequence[dict],
+ **kwargs) -> None:
+ """Record graph data.
+
+ Args:
+ model (torch.nn.Module): Model to draw.
+ data_batch (Sequence[dict]): Batch of data from dataloader.
+ """
+ for vis_backend in self._vis_backends.values():
+ vis_backend.add_graph(model, data_batch, **kwargs) # type: ignore
+
+ def add_image(self, name: str, image: np.ndarray, step: int = 0) -> None:
+ """Record image.
+
+ Args:
+ name (str): The unique identifier for the image to save.
+ image (np.ndarray): The image to be saved. The format
+ should be RGB.
+ step (int): Global step value to record. Defaults to 0.
+ """
+ for vis_backend in self._vis_backends.values():
+ vis_backend.add_image(name, image, step) # type: ignore
+
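# Usage sketch (illustrative, not part of the patch): after drawing, push the
# rendered image to every configured vis backend. This assumes the Visualizer
# was created with one or more `vis_backends`, which is not shown in this diff:
vis.add_image('demo_image', vis.get_image(), step=0)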
+ def add_scalar(self,
+ name: str,
+ value: Union[int, float],
+ step: int = 0,
+ **kwargs) -> None:
+ """Record scalar data.
+
+ Args:
+ name (str): The unique identifier for the scalar to save.
+ value (float, int): Value to save.
+ step (int): Global step value to record. Defaults to 0.
+ """
+ for vis_backend in self._vis_backends.values():
+ vis_backend.add_scalar(name, value, step, **kwargs) # type: ignore
+
+ def add_scalars(self,
+ scalar_dict: dict,
+ step: int = 0,
+ file_path: Optional[str] = None,
+ **kwargs) -> None:
+ """Record scalars' data.
+
+ Args:
+ scalar_dict (dict): Key-value pair storing the tag and
+ corresponding values.
+ step (int): Global step value to record. Defaults to 0.
+ file_path (str, optional): The scalar's data will be
+ saved to the ``file_path`` file at the same time
+ if the ``file_path`` parameter is specified.
+ Defaults to None.
+ """
+ for vis_backend in self._vis_backends.values():
+ vis_backend.add_scalars( # type: ignore
+ scalar_dict, step, file_path, **kwargs)
+
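# Usage sketch (illustrative, not part of the patch): scalars fan out to every
# configured backend; `file_path` additionally mirrors the dict to a file on
# backends that support it. `vis` is the Visualizer from the sketch above:
vis.add_scalar('mAP', 0.55, step=10)
vis.add_scalars({'loss': 0.12, 'acc': 0.93}, step=10)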
+ def add_datasample(self,
+ name,
+ image: np.ndarray,
+ gt_sample: Optional['BaseDataElement'] = None,
+ pred_sample: Optional['BaseDataElement'] = None,
+ draw_gt: bool = True,
+ draw_pred: bool = True,
+ show: bool = False,
+ wait_time: int = 0,
+ step: int = 0) -> None:
+ pass
+
+ def close(self) -> None:
+ """close an opened object."""
+ plt.close(self.fig_save)
+ if self.fig_show is not None:
+ plt.close(self.fig_show)
+ for vis_backend in self._vis_backends.values():
+ vis_backend.close() # type: ignore
+
+ @classmethod
+ def get_instance(cls, name: str, **kwargs) -> 'Visualizer':
+ """Make subclass can get latest created instance by
+ ``Visualizer.get_current_instance()``.
+
+ A downstream codebase may need to get the latest created instance
+ without knowing the specific Visualizer type. For example, MMDetection
+ builds the visualizer in the runner, while a component that cannot
+ access the runner still wants to get the latest created visualizer.
+ Such a component does not know which type of visualizer has been
+ built and therefore cannot get the target instance. To solve this,
+ :class:`Visualizer` overrides :meth:`get_instance` so that a subclass
+ additionally registers the created instance to :attr:`_instance_dict`,
+ and :meth:`get_current_instance` returns the latest created subclass
+ instance.
+
+ Examples:
+ >>> class DetLocalVisualizer(Visualizer):
+ >>> def __init__(self, name):
+ >>> super().__init__(name)
+ >>>
+ >>> visualizer1 = DetLocalVisualizer.get_instance('name1')
+ >>> visualizer2 = Visualizer.get_current_instance()
+ >>> visualizer3 = DetLocalVisualizer.get_current_instance()
+ >>> assert id(visualizer1) == id(visualizer2) == id(visualizer3)
+
+ Args:
+ name (str): Name of the instance.
+
+ Returns:
+ Visualizer: The instance corresponding to ``name``.
+ """
+ instance = super().get_instance(name, **kwargs)
+ Visualizer._instance_dict[name] = instance
+ return instance
diff --git a/mmengine/visualization/writer.py b/mmengine/visualization/writer.py
deleted file mode 100644
index 72217ac8..00000000
--- a/mmengine/visualization/writer.py
+++ /dev/null
@@ -1,823 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import os
-import os.path as osp
-import time
-from abc import ABCMeta, abstractmethod
-from typing import Any, List, Optional, Union
-
-import cv2
-import numpy as np
-import torch
-
-from mmengine.data import BaseDataElement
-from mmengine.dist import master_only
-from mmengine.fileio import dump
-from mmengine.registry import VISUALIZERS, WRITERS
-from mmengine.utils import TORCH_VERSION, ManagerMixin
-from .visualizer import Visualizer
-
-
-class BaseWriter(metaclass=ABCMeta):
- """Base class for writer.
-
- Each writer can inherit ``BaseWriter`` and implement
- the required functions.
-
- Args:
- visualizer (dict, :obj:`Visualizer`, optional):
- Visualizer instance or dictionary. Default to None.
- save_dir (str, optional): The root directory to save
- the files produced by the writer. Default to None.
- """
-
- def __init__(self,
- visualizer: Optional[Union[dict, 'Visualizer']] = None,
- save_dir: Optional[str] = None):
- self._save_dir = save_dir
- if self._save_dir:
- timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
- self._save_dir = osp.join(
- self._save_dir, f'write_data_{timestamp}') # type: ignore
- self._visualizer = visualizer
- if visualizer:
- if isinstance(visualizer, dict):
- self._visualizer = VISUALIZERS.build(visualizer)
- else:
- assert isinstance(visualizer, Visualizer), \
- 'visualizer should be an instance of Visualizer, ' \
- f'but got {type(visualizer)}'
-
- @property
- def visualizer(self) -> 'Visualizer':
- """Return the visualizer object.
-
- You can get the drawing backend through the visualizer property for
- custom drawing.
- """
- return self._visualizer # type: ignore
-
- @property
- @abstractmethod
- def experiment(self) -> Any:
- """Return the experiment object associated with this writer.
-
- The experiment attribute can get the write backend, such as wandb,
- tensorboard. If you want to write other data, such as writing a table,
- you can directly get the write backend through experiment.
- """
- pass
-
- def add_params(self, params_dict: dict, **kwargs) -> None:
- """Record a set of parameters.
-
- Args:
- params_dict (dict): Each key-value pair in the dictionary is the
- name of the parameters and it's corresponding value.
- """
- pass
-
- def add_graph(self, model: torch.nn.Module,
- input_tensor: Union[torch.Tensor,
- List[torch.Tensor]], **kwargs) -> None:
- """Record graph.
-
- Args:
- model (torch.nn.Module): Model to draw.
- input_tensor (torch.Tensor, list[torch.Tensor]): A variable
- or a tuple of variables to be fed.
- """
- pass
-
- def add_image(self,
- name: str,
- image: Optional[np.ndarray] = None,
- gt_sample: Optional['BaseDataElement'] = None,
- pred_sample: Optional['BaseDataElement'] = None,
- draw_gt: bool = True,
- draw_pred: bool = True,
- step: int = 0,
- **kwargs) -> None:
- """Record image.
-
- Args:
- name (str): The unique identifier for the image to save.
- image (np.ndarray, optional): The image to be saved. The format
- should be RGB. Default to None.
- gt_sample (:obj:`BaseDataElement`, optional): The ground truth data
- structure of OpenMMlab. Default to None.
- pred_sample (:obj:`BaseDataElement`, optional): The predicted
- result data structure of OpenMMlab. Default to None.
- draw_gt (bool): Whether to draw the ground truth. Default: True.
- draw_pred (bool): Whether to draw the predicted result.
- Default to True.
- step (int): Global step value to record. Default to 0.
- """
- pass
-
- def add_scalar(self,
- name: str,
- value: Union[int, float],
- step: int = 0,
- **kwargs) -> None:
- """Record scalar.
-
- Args:
- name (str): The unique identifier for the scalar to save.
- value (float, int): Value to save.
- step (int): Global step value to record. Default to 0.
- """
- pass
-
- def add_scalars(self,
- scalar_dict: dict,
- step: int = 0,
- file_path: Optional[str] = None,
- **kwargs) -> None:
- """Record scalars' data.
-
- Args:
- scalar_dict (dict): Key-value pair storing the tag and
- corresponding values.
- step (int): Global step value to record. Default to 0.
- file_path (str, optional): The scalar's data will be
- saved to the `file_path` file at the same time
- if the `file_path` parameter is specified.
- Default to None.
- """
- pass
-
- def close(self) -> None:
- """close an opened object."""
- pass
-
-
-@WRITERS.register_module()
-class LocalWriter(BaseWriter):
- """Local write class.
-
- It can write image, hyperparameters, scalars, etc.
- to the local hard disk. You can get the drawing backend
- through the visualizer property for custom drawing.
-
- Examples:
- >>> from mmengine.visualization import LocalWriter
- >>> import numpy as np
- >>> local_writer = LocalWriter(dict(type='DetVisualizer'),\
- save_dir='temp_dir')
- >>> img=np.random.randint(0, 256, size=(10, 10, 3))
- >>> local_writer.add_image('img', img)
- >>> local_writer.add_scaler('mAP', 0.6)
- >>> local_writer.add_scalars({'loss': [1, 2, 3], 'acc': 0.8})
- >>> local_writer.add_params(dict(lr=0.1, mode='linear'))
-
- >>> local_writer.visualizer.draw_bboxes(np.array([0, 0, 1, 1]), \
- edgecolors='g')
- >>> local_writer.add_image('img', \
- local_writer.visualizer.get_image())
-
- Args:
- save_dir (str): The root directory to save the files
- produced by the writer.
- visualizer (dict, :obj:`Visualizer`, optional): Visualizer
- instance or dictionary. Default to None
- img_save_dir (str): The directory to save images.
- Default to 'writer_image'.
- params_save_file (str): The file to save parameters.
- Default to 'parameters.yaml'.
- scalar_save_file (str): The file to save scalar values.
- Default to 'scalars.json'.
- img_show (bool): Whether to show the image when calling add_image.
- Default to False.
- """
-
- def __init__(self,
- save_dir: str,
- visualizer: Optional[Union[dict, 'Visualizer']] = None,
- img_save_dir: str = 'writer_image',
- params_save_file: str = 'parameters.yaml',
- scalar_save_file: str = 'scalars.json',
- img_show: bool = False):
- assert params_save_file.split('.')[-1] == 'yaml'
- assert scalar_save_file.split('.')[-1] == 'json'
- super(LocalWriter, self).__init__(visualizer, save_dir)
- os.makedirs(self._save_dir, exist_ok=True) # type: ignore
- self._img_save_dir = osp.join(
- self._save_dir, # type: ignore
- img_save_dir)
- self._scalar_save_file = osp.join(
- self._save_dir, # type: ignore
- scalar_save_file)
- self._params_save_file = osp.join(
- self._save_dir, # type: ignore
- params_save_file)
- self._img_show = img_show
-
- @property
- def experiment(self) -> 'LocalWriter':
- """Return the experiment object associated with this writer."""
- return self
-
- def add_params(self, params_dict: dict, **kwargs) -> None:
- """Record parameters to disk.
-
- Args:
- params_dict (dict): The dict of parameters to save.
- """
- assert isinstance(params_dict, dict)
- self._dump(params_dict, self._params_save_file, 'yaml')
-
- def add_image(self,
- name: str,
- image: Optional[np.ndarray] = None,
- gt_sample: Optional['BaseDataElement'] = None,
- pred_sample: Optional['BaseDataElement'] = None,
- draw_gt: bool = True,
- draw_pred: bool = True,
- step: int = 0,
- **kwargs) -> None:
- """Record image to disk.
-
- Args:
- name (str): The unique identifier for the image to save.
- image (np.ndarray, optional): The image to be saved. The format
- should be RGB. Default to None.
- gt_sample (:obj:`BaseDataElement`, optional): The ground truth data
- structure of OpenMMlab. Default to None.
- pred_sample (:obj:`BaseDataElement`, optional): The predicted
- result data structure of OpenMMlab. Default to None.
- draw_gt (bool): Whether to draw the ground truth. Default to True.
- draw_pred (bool): Whether to draw the predicted result.
- Default to True.
- step (int): Global step value to record. Default to 0.
- """
- assert self.visualizer, 'Please instantiate the visualizer ' \
- 'object with initialization parameters.'
- self.visualizer.draw(image, gt_sample, pred_sample, draw_gt, draw_pred)
- if self._img_show:
- self.visualizer.show()
- else:
- drawn_image = cv2.cvtColor(self.visualizer.get_image(),
- cv2.COLOR_RGB2BGR)
- os.makedirs(self._img_save_dir, exist_ok=True)
- save_file_name = f'{name}_{step}.png'
- cv2.imwrite(
- osp.join(self._img_save_dir, save_file_name), drawn_image)
-
- def add_scalar(self,
- name: str,
- value: Union[int, float],
- step: int = 0,
- **kwargs) -> None:
- """Add scalar data to disk.
-
- Args:
- name (str): The unique identifier for the scalar to save.
- value (float, int): Value to save.
- step (int): Global step value to record. Default to 0.
- """
- self._dump({name: value, 'step': step}, self._scalar_save_file, 'json')
-
- def add_scalars(self,
- scalar_dict: dict,
- step: int = 0,
- file_path: Optional[str] = None,
- **kwargs) -> None:
- """Record scalars. The scalar dict will be written to the default and
- specified files if ``file_name`` is specified.
-
- Args:
- scalar_dict (dict): Key-value pair storing the tag and
- corresponding values.
- step (int): Global step value to record. Default to 0.
- file_path (str, optional): The scalar's data will be
- saved to the ``file_path`` file at the same time
- if the ``file_path`` parameter is specified.
- Default to None.
- """
- assert isinstance(scalar_dict, dict)
- scalar_dict.setdefault('step', step)
- if file_path is not None:
- assert file_path.split('.')[-1] == 'json'
- new_save_file_path = osp.join(
- self._save_dir, # type: ignore
- file_path)
- assert new_save_file_path != self._scalar_save_file, \
- '"file_path" and "scalar_save_file" have the same name, ' \
- 'please set "file_path" to another value'
- self._dump(scalar_dict, new_save_file_path, 'json')
- self._dump(scalar_dict, self._scalar_save_file, 'json')
-
- def _dump(self, value_dict: dict, file_path: str,
- file_format: str) -> None:
- """dump dict to file.
-
- Args:
- value_dict (dict) : Save dict data.
- file_path (str): The file path to save data.
- file_format (str): The file format to save data.
- """
- with open(file_path, 'a+') as f:
- dump(value_dict, f, file_format=file_format)
- f.write('\n')
-
-
-@WRITERS.register_module()
-class WandbWriter(BaseWriter):
- """Write various types of data to wandb.
-
- Examples:
- >>> from mmengine.visualization import WandbWriter
- >>> import numpy as np
- >>> wandb_writer = WandbWriter(dict(type='DetVisualizer'))
- >>> img=np.random.randint(0, 256, size=(10, 10, 3))
- >>> wandb_writer.add_image('img', img)
- >>> wandb_writer.add_scaler('mAP', 0.6)
- >>> wandb_writer.add_scalars({'loss': [1, 2, 3],'acc': 0.8})
- >>> wandb_writer.add_params(dict(lr=0.1, mode='linear'))
-
- >>> wandb_writer.visualizer.draw_bboxes(np.array([0, 0, 1, 1]), \
- edgecolors='g')
- >>> wandb_writer.add_image('img', \
- wandb_writer.visualizer.get_image())
-
- >>> wandb_writer = WandbWriter()
- >>> assert wandb_writer.visualizer is None
- >>> wandb_writer.add_image('img', img)
-
- Args:
- init_kwargs (dict, optional): wandb initialization
- input parameters. Default to None.
- commit: (bool, optional) Save the metrics dict to the wandb server
- and increment the step. If false `wandb.log` just
- updates the current metrics dict with the row argument
- and metrics won't be saved until `wandb.log` is called
- with `commit=True`. Default to True.
- visualizer (dict, :obj:`Visualizer`, optional):
- Visualizer instance or dictionary. Default to None.
- save_dir (str, optional): The root directory to save the files
- produced by the writer. Default to None.
- """
-
- def __init__(self,
- init_kwargs: Optional[dict] = None,
- commit: Optional[bool] = True,
- visualizer: Optional[Union[dict, 'Visualizer']] = None,
- save_dir: Optional[str] = None):
- super(WandbWriter, self).__init__(visualizer, save_dir)
- self._commit = commit
- self._wandb = self._setup_env(init_kwargs)
-
- @property
- def experiment(self):
- """Return wandb object.
-
- The experiment attribute can get the wandb backend, If you want to
- write other data, such as writing a table, you can directly get the
- wandb backend through experiment.
- """
- return self._wandb
-
- def _setup_env(self, init_kwargs: Optional[dict] = None) -> Any:
- """Setup env.
-
- Args:
- init_kwargs (dict): The init args.
-
- Return:
- :obj:`wandb`
- """
- try:
- import wandb
- except ImportError:
- raise ImportError(
- 'Please run "pip install wandb" to install wandb')
- if init_kwargs:
- wandb.init(**init_kwargs)
- else:
- wandb.init()
-
- return wandb
-
- def add_params(self, params_dict: dict, **kwargs) -> None:
- """Record a set of parameters to be compared in wandb.
-
- Args:
- params_dict (dict): Each key-value pair in the dictionary
- is the name of the parameters and it's
- corresponding value.
- """
- assert isinstance(params_dict, dict)
- self._wandb.log(params_dict, commit=self._commit)
-
- def add_image(self,
- name: str,
- image: Optional[np.ndarray] = None,
- gt_sample: Optional['BaseDataElement'] = None,
- pred_sample: Optional['BaseDataElement'] = None,
- draw_gt: bool = True,
- draw_pred: bool = True,
- step: int = 0,
- **kwargs) -> None:
- """Record image to wandb.
-
- Args:
- name (str): The unique identifier for the image to save.
- image (np.ndarray, optional): The image to be saved. The format
- should be RGB. Default to None.
- gt_sample (:obj:`BaseDataElement`, optional): The ground truth data
- structure of OpenMMlab. Default to None.
- pred_sample (:obj:`BaseDataElement`, optional): The predicted
- result data structure of OpenMMlab. Default to None.
- draw_gt (bool): Whether to draw the ground truth. Default: True.
- draw_pred (bool): Whether to draw the predicted result.
- Default to True.
- step (int): Global step value to record. Default to 0.
- """
- if self.visualizer:
- self.visualizer.draw(image, gt_sample, pred_sample, draw_gt,
- draw_pred)
- self._wandb.log({name: self.visualizer.get_image()},
- commit=self._commit,
- step=step)
- else:
- self.add_image_to_wandb(name, image, gt_sample, pred_sample,
- draw_gt, draw_pred, step, **kwargs)
-
- def add_scalar(self,
- name: str,
- value: Union[int, float],
- step: int = 0,
- **kwargs) -> None:
- """Record scalar data to wandb.
-
- Args:
- name (str): The unique identifier for the scalar to save.
- value (float, int): Value to save.
- step (int): Global step value to record. Default to 0.
- """
- self._wandb.log({name: value}, commit=self._commit, step=step)
-
- def add_scalars(self,
- scalar_dict: dict,
- step: int = 0,
- file_path: Optional[str] = None,
- **kwargs) -> None:
- """Record scalar's data to wandb.
-
- Args:
- scalar_dict (dict): Key-value pair storing the tag and
- corresponding values.
- step (int): Global step value to record. Default to 0.
- file_path (str, optional): Useless parameter. Just for
- interface unification. Default to None.
- """
- self._wandb.log(scalar_dict, commit=self._commit, step=step)
-
- def add_image_to_wandb(self,
- name: str,
- image: np.ndarray,
- gt_sample: Optional['BaseDataElement'] = None,
- pred_sample: Optional['BaseDataElement'] = None,
- draw_gt: bool = True,
- draw_pred: bool = True,
- step: int = 0,
- **kwargs) -> None:
- """Record image to wandb.
-
- Args:
- name (str): The unique identifier for the image to save.
- image (np.ndarray): The image to be saved. The format
- should be BGR.
- gt_sample (:obj:`BaseDataElement`, optional): The ground truth data
- structure of OpenMMlab. Default to None.
- pred_sample (:obj:`BaseDataElement`, optional): The predicted
- result data structure of OpenMMlab. Default to None.
- draw_gt (bool): Whether to draw the ground truth. Default to True.
- draw_pred (bool): Whether to draw the predicted result.
- Default to True.
- step (int): Global step value to record. Default to 0.
- """
- raise NotImplementedError()
-
- def close(self) -> None:
- """close an opened wandb object."""
- if hasattr(self, '_wandb'):
- self._wandb.join()
-
-
-@WRITERS.register_module()
-class TensorboardWriter(BaseWriter):
- """Tensorboard write class. It can write images, hyperparameters, scalars,
- etc. to a tensorboard file.
-
- Its drawing function is provided by Visualizer.
-
- Examples:
- >>> from mmengine.visualization import TensorboardWriter
- >>> import numpy as np
- >>> tensorboard_writer = TensorboardWriter(dict(type='DetVisualizer'),\
- save_dir='temp_dir')
- >>> img=np.random.randint(0, 256, size=(10, 10, 3))
- >>> tensorboard_writer.add_image('img', img)
- >>> tensorboard_writer.add_scaler('mAP', 0.6)
- >>> tensorboard_writer.add_scalars({'loss': 0.1,'acc':0.8})
- >>> tensorboard_writer.add_params(dict(lr=0.1, mode='linear'))
-
- >>> tensorboard_writer.visualizer.draw_bboxes(np.array([0, 0, 1, 1]), \
- edgecolors='g')
- >>> tensorboard_writer.add_image('img', \
- tensorboard_writer.visualizer.get_image())
-
- Args:
- save_dir (str): The root directory to save the files
- produced by the writer.
- visualizer (dict, :obj:`Visualizer`, optional): Visualizer instance
- or dictionary. Default to None.
- log_dir (str): Save directory location. Default to 'tf_writer'.
- """
-
- def __init__(self,
- save_dir: str,
- visualizer: Optional[Union[dict, 'Visualizer']] = None,
- log_dir: str = 'tf_logs'):
- super(TensorboardWriter, self).__init__(visualizer, save_dir)
- self._tensorboard = self._setup_env(log_dir)
-
- def _setup_env(self, log_dir: str):
- """Setup env.
-
- Args:
- log_dir (str): Save directory location. Default 'tf_writer'.
-
- Return:
- :obj:`SummaryWriter`
- """
- if TORCH_VERSION == 'parrots':
- try:
- from tensorboardX import SummaryWriter
- except ImportError:
- raise ImportError('Please install tensorboardX to use '
- 'TensorboardLoggerHook.')
- else:
- try:
- from torch.utils.tensorboard import SummaryWriter
- except ImportError:
- raise ImportError(
- 'Please run "pip install future tensorboard" to install '
- 'the dependencies to use torch.utils.tensorboard '
- '(applicable to PyTorch 1.1 or higher)')
-
- self.log_dir = osp.join(self._save_dir, log_dir) # type: ignore
- return SummaryWriter(self.log_dir)
-
- @property
- def experiment(self):
- """Return Tensorboard object."""
- return self._tensorboard
-
- def add_graph(self, model: torch.nn.Module,
- input_tensor: Union[torch.Tensor,
- List[torch.Tensor]], **kwargs) -> None:
- """Record graph data to tensorboard.
-
- Args:
- model (torch.nn.Module): Model to draw.
- input_tensor (torch.Tensor, list[torch.Tensor]): A variable
- or a tuple of variables to be fed.
- """
- if isinstance(input_tensor, list):
- for array in input_tensor:
- assert array.ndim == 4
- assert isinstance(array, torch.Tensor)
- else:
- assert isinstance(input_tensor,
- torch.Tensor) and input_tensor.ndim == 4
- self._tensorboard.add_graph(model, input_tensor)
-
- def add_params(self, params_dict: dict, **kwargs) -> None:
- """Record a set of hyperparameters to be compared in TensorBoard.
-
- Args:
- params_dict (dict): Each key-value pair in the dictionary is the
- name of the hyper parameter and it's corresponding value.
- The type of the value can be one of `bool`, `string`,
- `float`, `int`, or `None`.
- """
- assert isinstance(params_dict, dict)
- self._tensorboard.add_hparams(params_dict, {})
-
- def add_image(self,
- name: str,
- image: Optional[np.ndarray] = None,
- gt_sample: Optional['BaseDataElement'] = None,
- pred_sample: Optional['BaseDataElement'] = None,
- draw_gt: bool = True,
- draw_pred: bool = True,
- step: int = 0,
- **kwargs) -> None:
- """Record image to tensorboard.
-
- Args:
- name (str): The unique identifier for the image to save.
- image (np.ndarray, optional): The image to be saved. The format
- should be RGB. Default to None.
- gt_sample (:obj:`BaseDataElement`, optional): The ground truth data
- structure of OpenMMlab. Default to None.
- pred_sample (:obj:`BaseDataElement`, optional): The predicted
- result data structure of OpenMMlab. Default to None.
- draw_gt (bool): Whether to draw the ground truth. Default to True.
- draw_pred (bool): Whether to draw the predicted result.
- Default to True.
- step (int): Global step value to record. Default to 0.
- """
- assert self.visualizer, 'Please instantiate the visualizer ' \
- 'object with initialization parameters.'
- self.visualizer.draw(image, gt_sample, pred_sample, draw_gt, draw_pred)
- self._tensorboard.add_image(
- name, self.visualizer.get_image(), step, dataformats='HWC')
-
- def add_scalar(self,
- name: str,
- value: Union[int, float],
- step: int = 0,
- **kwargs) -> None:
- """Record scalar data to summary.
-
- Args:
- name (str): The unique identifier for the scalar to save.
- value (float, int): Value to save.
- step (int): Global step value to record. Default to 0.
- """
- self._tensorboard.add_scalar(name, value, step)
-
- def add_scalars(self,
- scalar_dict: dict,
- step: int = 0,
- file_path: Optional[str] = None,
- **kwargs) -> None:
- """Record scalar's data to summary.
-
- Args:
- scalar_dict (dict): Key-value pair storing the tag and
- corresponding values.
- step (int): Global step value to record. Default to 0.
- file_path (str, optional): Useless parameter. Just for
- interface unification. Default to None.
- """
- assert isinstance(scalar_dict, dict)
- assert 'step' not in scalar_dict, 'Please set it directly ' \
- 'through the step parameter'
- for key, value in scalar_dict.items():
- self.add_scalar(key, value, step)
-
- def close(self):
- """close an opened tensorboard object."""
- if hasattr(self, '_tensorboard'):
- self._tensorboard.close()
-
-
-class ComposedWriter(ManagerMixin):
- """Wrapper class to compose multiple a subclass of :class:`BaseWriter`
- instances. By inheriting ManagerMixin, it can be accessed anywhere once
- instantiated.
-
- Examples:
- >>> from mmengine.visualization import ComposedWriter
- >>> import numpy as np
- >>> composed_writer= ComposedWriter.get_instance( \
- 'composed_writer', writers=[dict(type='LocalWriter', \
- visualizer=dict(type='DetVisualizer'), \
- save_dir='temp_dir'), dict(type='WandbWriter')])
- >>> img=np.random.randint(0, 256, size=(10, 10, 3))
- >>> composed_writer.add_image('img', img)
- >>> composed_writer.add_scalar('mAP', 0.6)
- >>> composed_writer.add_scalars({'loss': 0.1,'acc':0.8})
- >>> composed_writer.add_params(dict(lr=0.1, mode='linear'))
-
- Args:
- name (str): The name of the instance. Defaults: 'composed_writer'.
- writers (list, optional): The writers to compose. Default to None
- """
-
- def __init__(self,
- name: str = 'composed_writer',
- writers: Optional[List[Union[dict, 'BaseWriter']]] = None):
- super().__init__(name)
- self._writers = []
- if writers is not None:
- assert isinstance(writers, list)
- for writer in writers:
- if isinstance(writer, dict):
- self._writers.append(WRITERS.build(writer))
- else:
- assert isinstance(writer, BaseWriter), \
- f'writer should be an instance of a subclass of ' \
- f'BaseWriter, but got {type(writer)}'
- self._writers.append(writer)
-
- def __len__(self):
- return len(self._writers)
-
- def get_writer(self, index: int) -> 'BaseWriter':
- """Returns the writer object corresponding to the specified index."""
- return self._writers[index]
-
- def get_experiment(self, index: int) -> Any:
- """Returns the writer's experiment object corresponding to the
- specified index."""
- return self._writers[index].experiment
-
- def get_visualizer(self, index: int) -> 'Visualizer':
- """Returns the writer's visualizer object corresponding to the
- specified index."""
- return self._writers[index].visualizer
-
- def add_params(self, params_dict: dict, **kwargs):
- """Record parameters.
-
- Args:
- params_dict (dict): The dictionary of parameters to save.
- """
- for writer in self._writers:
- writer.add_params(params_dict, **kwargs)
-
- def add_graph(self, model: torch.nn.Module,
- input_array: Union[torch.Tensor,
- List[torch.Tensor]], **kwargs) -> None:
- """Record graph data.
-
- Args:
- model (torch.nn.Module): Model to draw.
- input_array (torch.Tensor, list[torch.Tensor]): A variable
- or a tuple of variables to be fed.
- """
- for writer in self._writers:
- writer.add_graph(model, input_array, **kwargs)
-
- def add_image(self,
- name: str,
- image: Optional[np.ndarray] = None,
- gt_sample: Optional['BaseDataElement'] = None,
- pred_sample: Optional['BaseDataElement'] = None,
- draw_gt: bool = True,
- draw_pred: bool = True,
- step: int = 0,
- **kwargs) -> None:
- """Record image.
-
- Args:
- name (str): The unique identifier for the image to save.
- image (np.ndarray, optional): The image to be saved. The format
- should be RGB. Default to None.
- gt_sample (:obj:`BaseDataElement`, optional): The ground truth data
- structure of OpenMMlab. Default to None.
- pred_sample (:obj:`BaseDataElement`, optional): The predicted
- result data structure of OpenMMlab. Default to None.
- draw_gt (bool): Whether to draw the ground truth. Default to True.
- draw_pred (bool): Whether to draw the predicted result.
- Default to True.
- step (int): Global step value to record. Default to 0.
- """
- for writer in self._writers:
- writer.add_image(name, image, gt_sample, pred_sample, draw_gt,
- draw_pred, step, **kwargs)
-
- def add_scalar(self,
- name: str,
- value: Union[int, float],
- step: int = 0,
- **kwargs) -> None:
- """Record scalar data.
-
- Args:
- name (str): The unique identifier for the scalar to save.
- value (float, int): Value to save.
- step (int): Global step value to record. Default to 0.
- """
- for writer in self._writers:
- writer.add_scalar(name, value, step, **kwargs)
-
- @master_only
- def add_scalars(self,
- scalar_dict: dict,
- step: int = 0,
- file_path: Optional[str] = None,
- **kwargs) -> None:
- """Record scalars' data.
-
- Args:
- scalar_dict (dict): Key-value pair storing the tag and
- corresponding values.
- step (int): Global step value to record. Default to 0.
- file_path (str, optional): The scalar's data will be
- saved to the `file_path` file at the same time
- if the `file_path` parameter is specified.
- Default to None.
- """
- for writer in self._writers:
- writer.add_scalars(scalar_dict, step, file_path, **kwargs)
-
- def close(self) -> None:
- """close an opened object."""
- for writer in self._writers:
- writer.close()
diff --git a/tests/test_hook/test_logger_hook.py b/tests/test_hook/test_logger_hook.py
index 70e93d1c..1f9a3b76 100644
--- a/tests/test_hook/test_logger_hook.py
+++ b/tests/test_hook/test_logger_hook.py
@@ -61,7 +61,6 @@ class TestLoggerHook:
assert logger_hook.json_log_path == osp.join('work_dir',
'timestamp.log.json')
assert logger_hook.start_iter == runner.iter
- runner.writer.add_params.assert_called()
def test_after_run(self, tmp_path):
out_dir = tmp_path / 'out_dir'
@@ -151,7 +150,7 @@ class TestLoggerHook:
logger_hook._collect_info = MagicMock(return_value=train_infos)
logger_hook._log_train(runner)
# Verify that the correct variables have been written.
- runner.writer.add_scalars.assert_called_with(
+ runner.visualizer.add_scalars.assert_called_with(
train_infos, step=11, file_path='tmp.json')
# Verify that the correct context have been logged.
out, _ = capsys.readouterr()
@@ -209,7 +208,7 @@ class TestLoggerHook:
logger_hook._log_val(runner)
# Verify that the correct context have been logged.
out, _ = capsys.readouterr()
- runner.writer.add_scalars.assert_called_with(
+ runner.visualizer.add_scalars.assert_called_with(
metric, step=11, file_path='tmp.json')
if by_epoch:
assert out == 'Epoch(val) [1][5] accuracy: 0.9000, ' \
diff --git a/tests/test_hook/test_naive_visualization_hook.py b/tests/test_hook/test_naive_visualization_hook.py
index 0bbe47df..e06dd281 100644
--- a/tests/test_hook/test_naive_visualization_hook.py
+++ b/tests/test_hook/test_naive_visualization_hook.py
@@ -12,7 +12,7 @@ class TestNaiveVisualizationHook:
def test_after_train_iter(self):
naive_visualization_hook = NaiveVisualizationHook()
runner = Mock(iter=1)
- runner.writer.add_image = Mock()
+ runner.visualizer.add_image = Mock()
inputs = torch.randn(1, 3, 15, 15)
batch_idx = 10
# test with normalize, resize, pad
diff --git a/tests/test_registry/test_registry.py b/tests/test_registry/test_registry.py
index 76f1d7ce..cc7d4156 100644
--- a/tests/test_registry/test_registry.py
+++ b/tests/test_registry/test_registry.py
@@ -5,6 +5,7 @@ import pytest
from mmengine.config import Config, ConfigDict # type: ignore
from mmengine.registry import DefaultScope, Registry, build_from_cfg
+from mmengine.utils import ManagerMixin
class TestRegistry:
@@ -482,3 +483,17 @@ def test_build_from_cfg(cfg_type):
"")):
cfg = cfg_type(dict(type='ResNet', depth=50))
model = build_from_cfg(cfg, 'BACKBONES')
+
+ VISUALIZER = Registry('visualizer')
+
+ @VISUALIZER.register_module()
+ class Visualizer(ManagerMixin):
+
+ def __init__(self, name):
+ super().__init__(name)
+
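+ # get_current_instance should fail before the object is built and succeed afterwards.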
+ with pytest.raises(RuntimeError):
+ Visualizer.get_current_instance()
+ cfg = dict(type='Visualizer', name='visualizer')
+ build_from_cfg(cfg, VISUALIZER)
+ Visualizer.get_current_instance()
diff --git a/tests/test_runner/test_runner.py b/tests/test_runner/test_runner.py
index b066838a..3e694be0 100644
--- a/tests/test_runner/test_runner.py
+++ b/tests/test_runner/test_runner.py
@@ -25,7 +25,7 @@ from mmengine.runner import (BaseLoop, EpochBasedTrainLoop, IterBasedTrainLoop,
Runner, TestLoop, ValLoop)
from mmengine.runner.priority import Priority, get_priority
from mmengine.utils import is_list_of
-from mmengine.visualization.writer import ComposedWriter
+from mmengine.visualization import Visualizer
@MODELS.register_module()
@@ -308,24 +308,24 @@ class TestRunner(TestCase):
self.assertFalse(runner.distributed)
self.assertFalse(runner.deterministic)
- # 1.5 message_hub, logger and writer
+ # 1.5 message_hub, logger and visualizer
# they are all not specified
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_init12'
runner = Runner(**cfg)
self.assertIsInstance(runner.logger, MMLogger)
self.assertIsInstance(runner.message_hub, MessageHub)
- self.assertIsInstance(runner.writer, ComposedWriter)
+ self.assertIsInstance(runner.visualizer, Visualizer)
# they are all specified
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_init13'
cfg.log_level = 'INFO'
- cfg.writer = dict(name='test_writer')
+ cfg.visualizer = None
runner = Runner(**cfg)
self.assertIsInstance(runner.logger, MMLogger)
self.assertIsInstance(runner.message_hub, MessageHub)
- self.assertIsInstance(runner.writer, ComposedWriter)
+ self.assertIsInstance(runner.visualizer, Visualizer)
assert runner.distributed is False
assert runner.seed is not None
@@ -446,32 +446,34 @@ class TestRunner(TestCase):
with self.assertRaisesRegex(TypeError, 'message_hub should be'):
runner.build_message_hub('invalid-type')
- def test_build_writer(self):
- self.epoch_based_cfg.experiment_name = 'test_build_writer1'
+ def test_build_visualizer(self):
+ self.epoch_based_cfg.experiment_name = 'test_build_visualizer1'
runner = Runner.from_cfg(self.epoch_based_cfg)
- self.assertIsInstance(runner.writer, ComposedWriter)
- self.assertEqual(runner.experiment_name, runner.writer.instance_name)
+ self.assertIsInstance(runner.visualizer, Visualizer)
+ self.assertEqual(runner.experiment_name,
+ runner.visualizer.instance_name)
- # input is a ComposedWriter object
+ # input is a Visualizer object
self.assertEqual(
- id(runner.build_writer(runner.writer)), id(runner.writer))
+ id(runner.build_visualizer(runner.visualizer)),
+ id(runner.visualizer))
# input is a dict
- writer_cfg = dict(name='test_build_writer2')
- writer = runner.build_writer(writer_cfg)
- self.assertIsInstance(writer, ComposedWriter)
- self.assertEqual(writer.instance_name, 'test_build_writer2')
+ visualizer_cfg = dict(type='Visualizer', name='test_build_visualizer2')
+ visualizer = runner.build_visualizer(visualizer_cfg)
+ self.assertIsInstance(visualizer, Visualizer)
+ self.assertEqual(visualizer.instance_name, 'test_build_visualizer2')
# input is a dict but does not contain name key
- runner._experiment_name = 'test_build_writer3'
- writer_cfg = dict()
- writer = runner.build_writer(writer_cfg)
- self.assertIsInstance(writer, ComposedWriter)
- self.assertEqual(writer.instance_name, 'test_build_writer3')
+ runner._experiment_name = 'test_build_visualizer3'
+ visualizer_cfg = None
+ visualizer = runner.build_visualizer(visualizer_cfg)
+ self.assertIsInstance(visualizer, Visualizer)
+ self.assertEqual(visualizer.instance_name, 'test_build_visualizer3')
# input is not a valid type
- with self.assertRaisesRegex(TypeError, 'writer should be'):
- runner.build_writer('invalid-type')
+ with self.assertRaisesRegex(TypeError, 'visualizer should be'):
+ runner.build_visualizer('invalid-type')
def test_default_scope(self):
TOY_SCHEDULERS = Registry(
diff --git a/tests/test_visualizer/test_vis_backend.py b/tests/test_visualizer/test_vis_backend.py
new file mode 100644
index 00000000..da662a65
--- /dev/null
+++ b/tests/test_visualizer/test_vis_backend.py
@@ -0,0 +1,200 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+import shutil
+import sys
+from unittest.mock import MagicMock
+
+import numpy as np
+import pytest
+
+from mmengine.fileio import load
+from mmengine.registry import VISBACKENDS
+from mmengine.visualization import (LocalVisBackend, TensorboardVisBackend,
+ WandbVisBackend)
+
+
+class TestLocalVisBackend:
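+ """LocalVisBackend writes images and scalar records under its save_dir."""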
+
+ def test_init(self):
+
+ # 'config_save_file' format must be py
+ with pytest.raises(AssertionError):
+ LocalVisBackend('temp_dir', config_save_file='a.txt')
+
+ # 'scalar_save_file' format must be json
+ with pytest.raises(AssertionError):
+ LocalVisBackend('temp_dir', scalar_save_file='a.yaml')
+
+ local_vis_backend = LocalVisBackend('temp_dir')
+ assert os.path.exists(local_vis_backend._save_dir)
+ shutil.rmtree('temp_dir')
+
+ local_vis_backend = VISBACKENDS.build(
+ dict(type='LocalVisBackend', save_dir='temp_dir'))
+ assert os.path.exists(local_vis_backend._save_dir)
+ shutil.rmtree('temp_dir')
+
+ def test_experiment(self):
+ local_vis_backend = LocalVisBackend('temp_dir')
+ assert local_vis_backend.experiment == local_vis_backend
+ shutil.rmtree('temp_dir')
+
+ def test_add_config(self):
+ local_vis_backend = LocalVisBackend('temp_dir')
+
+ # 'params_dict' must be dict
+ with pytest.raises(AssertionError):
+ local_vis_backend.add_config(['lr', 0])
+
+ # TODO
+
+ shutil.rmtree('temp_dir')
+
+ def test_add_image(self):
+ image = np.random.randint(0, 256, size=(10, 10, 3)).astype(np.uint8)
+ local_vis_backend = LocalVisBackend('temp_dir')
+ local_vis_backend.add_image('img', image)
+ assert os.path.exists(
+ os.path.join(local_vis_backend._img_save_dir, 'img_0.png'))
+
+ local_vis_backend.add_image('img', image, step=2)
+ assert os.path.exists(
+ os.path.join(local_vis_backend._img_save_dir, 'img_2.png'))
+
+ shutil.rmtree('temp_dir')
+
+ def test_add_scalar(self):
+ local_vis_backend = LocalVisBackend('temp_dir')
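+ # Scalars are appended to the scalar save file as one JSON object per line.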
+ local_vis_backend.add_scalar('map', 0.9)
+ out_dict = load(local_vis_backend._scalar_save_file, 'json')
+ assert out_dict == {'map': 0.9, 'step': 0}
+ shutil.rmtree('temp_dir')
+
+ # test append mode
+ local_vis_backend = LocalVisBackend('temp_dir')
+ local_vis_backend.add_scalar('map', 0.9, step=0)
+ local_vis_backend.add_scalar('map', 0.95, step=1)
+ with open(local_vis_backend._scalar_save_file) as f:
+ out_dict = f.read()
+ assert out_dict == '{"map": 0.9, "step": 0}\n{"map": ' \
+ '0.95, "step": 1}\n'
+ shutil.rmtree('temp_dir')
+
+ def test_add_scalars(self):
+ local_vis_backend = LocalVisBackend('temp_dir')
+ input_dict = {'map': 0.7, 'acc': 0.9}
+ local_vis_backend.add_scalars(input_dict)
+ out_dict = load(local_vis_backend._scalar_save_file, 'json')
+ assert out_dict == {'map': 0.7, 'acc': 0.9, 'step': 0}
+
+ # test append mode
+ local_vis_backend.add_scalars({'map': 0.8, 'acc': 0.8}, step=1)
+ with open(local_vis_backend._scalar_save_file) as f:
+ out_dict = f.read()
+ assert out_dict == '{"map": 0.7, "acc": 0.9, ' \
+ '"step": 0}\n{"map": 0.8, "acc": 0.8, "step": 1}\n'
+
+ # test file_path
+ local_vis_backend = LocalVisBackend('temp_dir')
+ local_vis_backend.add_scalars(input_dict, file_path='temp.json')
+ assert os.path.exists(local_vis_backend._scalar_save_file)
+ assert os.path.exists(
+ os.path.join(local_vis_backend._save_dir, 'temp.json'))
+
+ # file_path and scalar_save_file cannot be the same
+ with pytest.raises(AssertionError):
+ local_vis_backend.add_scalars(input_dict, file_path='scalars.json')
+
+ shutil.rmtree('temp_dir')
+
+
+class TestTensorboardVisBackend:
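+ # Replace the tensorboard modules with mocks so the tests run without TensorBoard installed.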
+ sys.modules['torch.utils.tensorboard'] = MagicMock()
+ sys.modules['tensorboardX'] = MagicMock()
+
+ def test_init(self):
+
+ TensorboardVisBackend('temp_dir')
+ VISBACKENDS.build(
+ dict(type='TensorboardVisBackend', save_dir='temp_dir'))
+
+ def test_experiment(self):
+ tensorboard_vis_backend = TensorboardVisBackend('temp_dir')
+ assert (tensorboard_vis_backend.experiment ==
+ tensorboard_vis_backend._tensorboard)
+
+ def test_add_graph(self):
+ # TODO
+ pass
+
+ def test_add_config(self):
+ # TODO
+ pass
+
+ def test_add_image(self):
+ image = np.random.randint(0, 256, size=(10, 10, 3)).astype(np.uint8)
+
+ tensorboard_vis_backend = TensorboardVisBackend('temp_dir')
+ tensorboard_vis_backend.add_image('img', image)
+
+ tensorboard_vis_backend.add_image('img', image, step=2)
+
+ def test_add_scalar(self):
+ tensorboard_vis_backend = TensorboardVisBackend('temp_dir')
+ tensorboard_vis_backend.add_scalar('map', 0.9)
+ # test append mode
+ tensorboard_vis_backend.add_scalar('map', 0.9, step=0)
+ tensorboard_vis_backend.add_scalar('map', 0.95, step=1)
+
+ def test_add_scalars(self):
+ tensorboard_vis_backend = TensorboardVisBackend('temp_dir')
+ # 'step' must be passed as the step argument; it cannot be a key inside scalar_dict
+ with pytest.raises(AssertionError):
+ tensorboard_vis_backend.add_scalars({
+ 'map': 0.7,
+ 'acc': 0.9,
+ 'step': 1
+ })
+
+ input_dict = {'map': 0.7, 'acc': 0.9}
+ tensorboard_vis_backend.add_scalars(input_dict)
+ # test append mode
+ tensorboard_vis_backend.add_scalars({'map': 0.8, 'acc': 0.8}, step=1)
+
+
+class TestWandbVisBackend:
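+ # Replace wandb with a mock so the tests run without wandb installed.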
+ sys.modules['wandb'] = MagicMock()
+
+ def test_init(self):
+ WandbVisBackend()
+ VISBACKENDS.build(dict(type='WandbVisBackend', save_dir='temp_dir'))
+
+ def test_experiment(self):
+ wandb_vis_backend = WandbVisBackend()
+ assert wandb_vis_backend.experiment == wandb_vis_backend._wandb
+
+ def test_add_config(self):
+ # TODO
+ pass
+
+ def test_add_image(self):
+ image = np.random.randint(0, 256, size=(10, 10, 3)).astype(np.uint8)
+
+ wandb_vis_backend = WandbVisBackend()
+ wandb_vis_backend.add_image('img', image)
+
+ wandb_vis_backend.add_image('img', image, step=2)
+
+ def test_add_scalar(self):
+ wandb_vis_backend = WandbVisBackend()
+ wandb_vis_backend.add_scalar('map', 0.9)
+ # test append mode
+ wandb_vis_backend.add_scalar('map', 0.9, step=0)
+ wandb_vis_backend.add_scalar('map', 0.95, step=1)
+
+ def test_add_scalars(self):
+ wandb_vis_backend = WandbVisBackend()
+ input_dict = {'map': 0.7, 'acc': 0.9}
+ wandb_vis_backend.add_scalars(input_dict)
+ # test append mode
+ wandb_vis_backend.add_scalars({'map': 0.8, 'acc': 0.8}, step=1)
diff --git a/tests/test_visualizer/test_visualizer.py b/tests/test_visualizer/test_visualizer.py
index 5a7da41b..ce3de94d 100644
--- a/tests/test_visualizer/test_visualizer.py
+++ b/tests/test_visualizer/test_visualizer.py
@@ -1,16 +1,64 @@
# Copyright (c) OpenMMLab. All rights reserved.
-from typing import Optional
+import copy
+from typing import Any, List, Optional, Union
from unittest import TestCase
import matplotlib.pyplot as plt
import numpy as np
import pytest
import torch
+import torch.nn as nn
-from mmengine.data import BaseDataElement
+from mmengine import VISBACKENDS
from mmengine.visualization import Visualizer
+@VISBACKENDS.register_module()
+class MockVisBackend:
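+ """Minimal fake backend that records which write methods were called."""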
+
+ def __init__(self, save_dir: Optional[str] = None):
+ self._save_dir = save_dir
+ self._close = False
+
+ @property
+ def experiment(self) -> Any:
+ return self
+
+ def add_config(self, params_dict: dict, **kwargs) -> None:
+ self._add_config = True
+
+ def add_graph(self, model: torch.nn.Module,
+ input_tensor: Union[torch.Tensor,
+ List[torch.Tensor]], **kwargs) -> None:
+
+ self._add_graph = True
+
+ def add_image(self,
+ name: str,
+ image: np.ndarray,
+ step: int = 0,
+ **kwargs) -> None:
+ self._add_image = True
+
+ def add_scalar(self,
+ name: str,
+ value: Union[int, float],
+ step: int = 0,
+ **kwargs) -> None:
+ self._add_scalar = True
+
+ def add_scalars(self,
+ scalar_dict: dict,
+ step: int = 0,
+ file_path: Optional[str] = None,
+ **kwargs) -> None:
+ self._add_scalars = True
+
+ def close(self) -> None:
+ """close an opened object."""
+ self._close = True
+
+
class TestVisualizer(TestCase):
def setUp(self):
@@ -21,11 +69,27 @@ class TestVisualizer(TestCase):
"""
self.image = np.random.randint(
0, 256, size=(10, 10, 3)).astype('uint8')
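+ # Two mock backends are configured so the tests can check that calls fan out to every backend.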
+ self.vis_backend_cfg = [
+ dict(type='MockVisBackend', name='mock1', save_dir='tmp'),
+ dict(type='MockVisBackend', name='mock2', save_dir='tmp')
+ ]
def test_init(self):
visualizer = Visualizer(image=self.image)
visualizer.get_image()
+ visualizer = Visualizer(
+ vis_backends=copy.deepcopy(self.vis_backend_cfg))
+ assert isinstance(visualizer.get_backend('mock1'), MockVisBackend)
+ assert len(visualizer._vis_backends) == 2
+
+ # test global
+ visualizer = Visualizer.get_instance(
+ 'visualizer', vis_backends=copy.deepcopy(self.vis_backend_cfg))
+ assert len(visualizer._vis_backends) == 2
+ visualizer_any = Visualizer.get_instance('visualizer')
+ assert visualizer_any == visualizer
+
def test_set_image(self):
visualizer = Visualizer()
visualizer.set_image(self.image)
@@ -45,7 +109,7 @@ class TestVisualizer(TestCase):
visualizer.draw_bboxes(torch.tensor([1, 1, 1, 2]))
bboxes = torch.tensor([[1, 1, 2, 2], [1, 2, 2, 2.5]])
visualizer.draw_bboxes(
- bboxes, alpha=0.5, edgecolors='b', linestyles='-')
+ bboxes, alpha=0.5, edge_colors=(255, 0, 0), line_styles='-')
bboxes = bboxes.numpy()
visualizer.draw_bboxes(bboxes)
@@ -66,19 +130,26 @@ class TestVisualizer(TestCase):
visualizer.draw_bboxes([1, 1, 2, 2])
def test_close(self):
- visualizer = Visualizer(image=self.image)
- fig_num = visualizer.fig.number
+ visualizer = Visualizer(
+ image=self.image, vis_backends=copy.deepcopy(self.vis_backend_cfg))
+ fig_num = visualizer.fig_save_num
assert fig_num in plt.get_fignums()
+ for name in ['mock1', 'mock2']:
+ assert visualizer.get_backend(name)._close is False
visualizer.close()
assert fig_num not in plt.get_fignums()
+ for name in ['mock1', 'mock2']:
+ assert visualizer.get_backend(name)._close is True
def test_draw_texts(self):
visualizer = Visualizer(image=self.image)
# only support tensor and numpy
- visualizer.draw_texts('text1', positions=torch.tensor([5, 5]))
+ visualizer.draw_texts(
+ 'text1', positions=torch.tensor([5, 5]), colors=(0, 255, 0))
visualizer.draw_texts(['text1', 'text2'],
- positions=torch.tensor([[5, 5], [3, 3]]))
+ positions=torch.tensor([[5, 5], [3, 3]]),
+ colors=[(255, 0, 0), (255, 0, 0)])
visualizer.draw_texts('text1', positions=np.array([5, 5]))
visualizer.draw_texts(['text1', 'text2'],
positions=np.array([[5, 5], [3, 3]]))
@@ -111,11 +182,11 @@ class TestVisualizer(TestCase):
with pytest.raises(AssertionError):
visualizer.draw_texts(['text1', 'test2'],
positions=torch.tensor([[5, 5], [3, 3]]),
- verticalalignments=['top'])
+ vertical_alignments=['top'])
with pytest.raises(AssertionError):
visualizer.draw_texts(['text1', 'test2'],
positions=torch.tensor([[5, 5], [3, 3]]),
- horizontalalignments=['left'])
+ horizontal_alignments=['left'])
with pytest.raises(AssertionError):
visualizer.draw_texts(['text1', 'test2'],
positions=torch.tensor([[5, 5], [3, 3]]),
@@ -140,8 +211,8 @@ class TestVisualizer(TestCase):
x_datas=np.array([[1, 5], [2, 4]]),
y_datas=np.array([[2, 6], [4, 7]]),
colors='r',
- linestyles=['-', '-.'],
- linewidths=[1, 2])
+ line_styles=['-', '-.'],
+ line_widths=[1, 2])
# test out of bounds
with pytest.warns(
UserWarning,
@@ -171,19 +242,20 @@ class TestVisualizer(TestCase):
visualizer.draw_circles(
torch.tensor([[1, 5], [2, 6]]), radius=torch.tensor([1, 2]))
- # test filling
+ # test face_colors
visualizer.draw_circles(
torch.tensor([[1, 5], [2, 6]]),
radius=torch.tensor([1, 2]),
- is_filling=True)
+ face_colors=(255, 0, 0),
+ edge_colors=(255, 0, 0))
# test config
visualizer.draw_circles(
torch.tensor([[1, 5], [2, 6]]),
radius=torch.tensor([1, 2]),
- edgecolors=['g', 'r'],
- linestyles=['-', '-.'],
- linewidths=[1, 2])
+ edge_colors=['g', 'r'],
+ line_styles=['-', '-.'],
+ line_widths=[1, 2])
# test out of bounds
with pytest.warns(
@@ -220,15 +292,16 @@ class TestVisualizer(TestCase):
np.array([[1, 1], [2, 2], [3, 4]]),
torch.tensor([[1, 1], [2, 2], [3, 4]])
],
- is_filling=True)
+ face_colors=(255, 0, 0),
+ edge_colors=(255, 0, 0))
visualizer.draw_polygons(
polygons=[
np.array([[1, 1], [2, 2], [3, 4]]),
torch.tensor([[1, 1], [2, 2], [3, 4]])
],
- edgecolors=['r', 'g'],
- linestyles='-',
- linewidths=[2, 1])
+ edge_colors=['r', 'g'],
+ line_styles='-',
+ line_widths=[2, 1])
# test out of bounds
with pytest.warns(
@@ -242,7 +315,10 @@ class TestVisualizer(TestCase):
visualizer = Visualizer(image=self.image)
visualizer.draw_binary_masks(binary_mask)
visualizer.draw_binary_masks(torch.from_numpy(binary_mask))
-
+ # test multiple binary masks with per-mask colors
+ binary_mask = np.random.randint(0, 2, size=(2, 10, 10)).astype(bool)
+ visualizer = Visualizer(image=self.image)
+ visualizer.draw_binary_masks(binary_mask, colors=['r', (0, 255, 0)])
# test the error that the size of mask and image are different.
with pytest.raises(AssertionError):
binary_mask = np.random.randint(0, 2, size=(8, 10)).astype(np.bool)
@@ -269,7 +345,7 @@ class TestVisualizer(TestCase):
visualizer.draw_featmap(torch.randn(1, 1, 3, 3))
# test mode parameter
- # mode only supports 'mean' and 'max' and 'min
+ # mode only supports 'mean' and 'max'
with pytest.raises(AssertionError):
visualizer.draw_featmap(torch.randn(2, 3, 3), mode='xx')
# test tensor_chw and img have difference height and width
@@ -289,7 +365,6 @@ class TestVisualizer(TestCase):
visualizer.draw_featmap(torch.randn(6, 3, 3), mode='mean')
visualizer.draw_featmap(torch.randn(1, 3, 3), mode='mean')
visualizer.draw_featmap(torch.randn(6, 3, 3), mode='max')
- visualizer.draw_featmap(torch.randn(6, 3, 3), mode='min')
visualizer.draw_featmap(torch.randn(6, 3, 3), mode='max', topk=10)
visualizer.draw_featmap(torch.randn(1, 3, 3), mode=None, topk=-1)
visualizer.draw_featmap(
@@ -325,57 +400,76 @@ class TestVisualizer(TestCase):
draw_polygons(torch.tensor([[1, 1], [2, 2], [3, 4]])). \
draw_binary_masks(binary_mask)
- def test_register_task(self):
+ def test_get_backend(self):
+ visualizer = Visualizer(
+ image=self.image, vis_backends=copy.deepcopy(self.vis_backend_cfg))
+ for name in ['mock1', 'mock2']:
+ assert isinstance(visualizer.get_backend(name), MockVisBackend)
- class DetVisualizer(Visualizer):
+ def test_add_config(self):
+ visualizer = Visualizer(
+ vis_backends=copy.deepcopy(self.vis_backend_cfg))
- @Visualizer.register_task('instances')
- def draw_instance(self, instances, data_type):
- pass
+ params_dict = dict(lr=0.1, wd=0.2, mode='linear')
+ visualizer.add_config(params_dict)
+ for name in ['mock1', 'mock2']:
+ assert visualizer.get_backend(name)._add_config is True
- assert len(Visualizer.task_dict) == 1
- assert 'instances' in Visualizer.task_dict
+ def test_add_graph(self):
+ visualizer = Visualizer(
+ vis_backends=copy.deepcopy(self.vis_backend_cfg))
- # test registration of the same names.
- with pytest.raises(
- KeyError,
- match=('"instances" is already registered in task_dict, '
- 'add "force=True" if you want to override it')):
+ class Model(nn.Module):
- class DetVisualizer1(Visualizer):
+ def __init__(self):
+ super().__init__()
+ self.conv = nn.Conv2d(1, 2, 1)
- @Visualizer.register_task('instances')
- def draw_instance1(self, instances, data_type):
- pass
+ def forward(self, x, y=None):
+ return self.conv(x)
- @Visualizer.register_task('instances')
- def draw_instance2(self, instances, data_type):
- pass
+ visualizer.add_graph(Model(), np.zeros([1, 1, 3, 3]))
+ for name in ['mock1', 'mock2']:
+ assert visualizer.get_backend(name)._add_graph is True
- Visualizer.task_dict = dict()
+ def test_add_image(self):
+ image = np.random.randint(0, 256, size=(10, 10, 3)).astype(np.uint8)
+ visualizer = Visualizer(
+ vis_backends=copy.deepcopy(self.vis_backend_cfg))
- class DetVisualizer2(Visualizer):
+ visualizer.add_image('img', image)
+ for name in ['mock1', 'mock2']:
+ assert visualizer.get_backend(name)._add_image is True
- @Visualizer.register_task('instances')
- def draw_instance1(self, instances, data_type):
- pass
+ def test_add_scalar(self):
+ visualizer = Visualizer(
+ vis_backends=copy.deepcopy(self.vis_backend_cfg))
+ visualizer.add_scalar('map', 0.9, step=0)
+ for name in ['mock1', 'mock2']:
+ assert visualizer.get_backend(name)._add_scalar is True
- @Visualizer.register_task('instances', force=True)
- def draw_instance2(self, instances, data_type):
- pass
+ def test_add_scalars(self):
+ visualizer = Visualizer(
+ vis_backends=copy.deepcopy(self.vis_backend_cfg))
+ input_dict = {'map': 0.7, 'acc': 0.9}
+ visualizer.add_scalars(input_dict)
+ for name in ['mock1', 'mock2']:
+ assert visualizer.get_backend(name)._add_scalars is True
- def draw(self,
- image: Optional[np.ndarray] = None,
- gt_sample: Optional['BaseDataElement'] = None,
- pred_sample: Optional['BaseDataElement'] = None,
- draw_gt: bool = True,
- draw_pred: bool = True) -> None:
- return super().draw(image, gt_sample, pred_sample, draw_gt,
- draw_pred)
+ def test_get_instance(self):
- det_visualizer = DetVisualizer2()
- det_visualizer.draw(gt_sample={}, pred_sample={})
- assert len(det_visualizer.task_dict) == 1
- assert 'instances' in det_visualizer.task_dict
- assert det_visualizer.task_dict[
- 'instances'].__name__ == 'draw_instance2'
+ class DetLocalVisualizer(Visualizer):
+
+ def __init__(self, name):
+ super().__init__(name)
+
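+ # An instance created on a subclass should also be the current instance of the base class.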
+ visualizer1 = DetLocalVisualizer.get_instance('name1')
+ visualizer2 = Visualizer.get_current_instance()
+ visualizer3 = DetLocalVisualizer.get_current_instance()
+ assert id(visualizer1) == id(visualizer2) == id(visualizer3)
+
+
+if __name__ == '__main__':
+ t = TestVisualizer()
+ t.setUp()
+ t.test_init()
diff --git a/tests/test_visualizer/test_writer.py b/tests/test_visualizer/test_writer.py
deleted file mode 100644
index 447a246d..00000000
--- a/tests/test_visualizer/test_writer.py
+++ /dev/null
@@ -1,484 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import os
-import shutil
-import sys
-from unittest.mock import MagicMock, Mock, patch
-
-import numpy as np
-import pytest
-import torch
-import torch.nn as nn
-
-from mmengine.fileio import load
-from mmengine.registry import VISUALIZERS, WRITERS
-from mmengine.visualization import (ComposedWriter, LocalWriter,
- TensorboardWriter, WandbWriter)
-
-
-def draw(self, image, gt_sample, pred_sample, show_gt=True, show_pred=True):
- self.set_image(image)
-
-
-class TestLocalWriter:
-
- def test_init(self):
- # visuailzer must be a dictionary or an instance
- # of Visualizer and its subclasses
- with pytest.raises(AssertionError):
- LocalWriter('temp_dir', [dict(type='Visualizer')])
-
- # 'params_save_file' format must be yaml
- with pytest.raises(AssertionError):
- LocalWriter('temp_dir', params_save_file='a.txt')
-
- # 'scalar_save_file' format must be json
- with pytest.raises(AssertionError):
- LocalWriter('temp_dir', scalar_save_file='a.yaml')
-
- local_writer = LocalWriter('temp_dir')
- assert os.path.exists(local_writer._save_dir)
- shutil.rmtree('temp_dir')
-
- local_writer = WRITERS.build(
- dict(
- type='LocalWriter',
- visualizer=dict(type='Visualizer'),
- save_dir='temp_dir'))
- assert os.path.exists(local_writer._save_dir)
- shutil.rmtree('temp_dir')
-
- def test_experiment(self):
- local_writer = LocalWriter('temp_dir')
- assert local_writer.experiment == local_writer
- shutil.rmtree('temp_dir')
-
- def test_add_params(self):
- local_writer = LocalWriter('temp_dir')
-
- # 'params_dict' must be dict
- with pytest.raises(AssertionError):
- local_writer.add_params(['lr', 0])
-
- params_dict = dict(lr=0.1, wd=[1.0, 0.1, 0.001], mode='linear')
- local_writer.add_params(params_dict)
- out_dict = load(local_writer._params_save_file, 'yaml')
- assert out_dict == params_dict
- shutil.rmtree('temp_dir')
-
- @patch('mmengine.visualization.visualizer.Visualizer.draw', draw)
- def test_add_image(self):
- image = np.random.randint(0, 256, size=(10, 10, 3)).astype(np.uint8)
-
- # The visuailzer parameter must be set when
- # the local_writer object is instantiated and
- # the `add_image` method is called.
- with pytest.raises(AssertionError):
- local_writer = LocalWriter('temp_dir')
- local_writer.add_image('img', image)
-
- local_writer = LocalWriter('temp_dir', dict(type='Visualizer'))
- local_writer.add_image('img', image)
- assert os.path.exists(
- os.path.join(local_writer._img_save_dir, 'img_0.png'))
-
- bboxes = np.array([[1, 1, 2, 2], [1, 1.5, 1, 2.5]])
- local_writer.visualizer.draw_bboxes(bboxes)
- local_writer.add_image(
- 'img', local_writer.visualizer.get_image(), step=2)
- assert os.path.exists(
- os.path.join(local_writer._img_save_dir, 'img_2.png'))
-
- visuailzer = VISUALIZERS.build(dict(type='Visualizer'))
- local_writer = LocalWriter('temp_dir', visuailzer)
- local_writer.add_image('img', image)
- assert os.path.exists(
- os.path.join(local_writer._img_save_dir, 'img_0.png'))
-
- shutil.rmtree('temp_dir')
-
- def test_add_scalar(self):
- local_writer = LocalWriter('temp_dir')
- local_writer.add_scalar('map', 0.9)
- out_dict = load(local_writer._scalar_save_file, 'json')
- assert out_dict == {'map': 0.9, 'step': 0}
- shutil.rmtree('temp_dir')
-
- # test append mode
- local_writer = LocalWriter('temp_dir')
- local_writer.add_scalar('map', 0.9, step=0)
- local_writer.add_scalar('map', 0.95, step=1)
- with open(local_writer._scalar_save_file) as f:
- out_dict = f.read()
- assert out_dict == '{"map": 0.9, "step": 0}\n{"map": ' \
- '0.95, "step": 1}\n'
- shutil.rmtree('temp_dir')
-
- def test_add_scalars(self):
- local_writer = LocalWriter('temp_dir')
- input_dict = {'map': 0.7, 'acc': 0.9}
- local_writer.add_scalars(input_dict)
- out_dict = load(local_writer._scalar_save_file, 'json')
- assert out_dict == {'map': 0.7, 'acc': 0.9, 'step': 0}
-
- # test append mode
- local_writer.add_scalars({'map': 0.8, 'acc': 0.8}, step=1)
- with open(local_writer._scalar_save_file) as f:
- out_dict = f.read()
- assert out_dict == '{"map": 0.7, "acc": 0.9, ' \
- '"step": 0}\n{"map": 0.8, "acc": 0.8, "step": 1}\n'
-
- # test file_path
- local_writer = LocalWriter('temp_dir')
- local_writer.add_scalars(input_dict, file_path='temp.json')
- assert os.path.exists(local_writer._scalar_save_file)
- assert os.path.exists(
- os.path.join(local_writer._save_dir, 'temp.json'))
-
- # file_path and scalar_save_file cannot be the same
- with pytest.raises(AssertionError):
- local_writer.add_scalars(input_dict, file_path='scalars.json')
-
- shutil.rmtree('temp_dir')
-
-
-class TestTensorboardWriter:
- sys.modules['torch.utils.tensorboard'] = MagicMock()
- sys.modules['tensorboardX'] = MagicMock()
-
- def test_init(self):
- # visuailzer must be a dictionary or an instance
- # of Visualizer and its subclasses
- with pytest.raises(AssertionError):
- LocalWriter('temp_dir', [dict(type='Visualizer')])
-
- TensorboardWriter('temp_dir')
- WRITERS.build(
- dict(
- type='TensorboardWriter',
- visualizer=dict(type='Visualizer'),
- save_dir='temp_dir'))
-
- def test_experiment(self):
- tensorboard_writer = TensorboardWriter('temp_dir')
- assert tensorboard_writer.experiment == tensorboard_writer._tensorboard
-
- def test_add_graph(self):
-
- class Model(nn.Module):
-
- def __init__(self):
- super().__init__()
- self.conv = nn.Conv2d(1, 2, 1)
-
- def forward(self, x, y=None):
- return self.conv(x)
-
- tensorboard_writer = TensorboardWriter('temp_dir')
-
- # input must be tensor
- with pytest.raises(AssertionError):
- tensorboard_writer.add_graph(Model(), np.zeros([1, 1, 3, 3]))
-
- # input must be 4d tensor
- with pytest.raises(AssertionError):
- tensorboard_writer.add_graph(Model(), torch.zeros([1, 3, 3]))
-
- # If the input is a list, the inner element must be a 4d tensor
- with pytest.raises(AssertionError):
- tensorboard_writer.add_graph(
- Model(), [torch.zeros([1, 1, 3, 3]),
- torch.zeros([1, 3, 3])])
-
- tensorboard_writer.add_graph(Model(), torch.zeros([1, 1, 3, 3]))
- tensorboard_writer.add_graph(
- Model(), [torch.zeros([1, 1, 3, 3]),
- torch.zeros([1, 1, 3, 3])])
-
- def test_add_params(self):
- tensorboard_writer = TensorboardWriter('temp_dir')
-
- # 'params_dict' must be dict
- with pytest.raises(AssertionError):
- tensorboard_writer.add_params(['lr', 0])
-
- params_dict = dict(lr=0.1, wd=0.2, mode='linear')
- tensorboard_writer.add_params(params_dict)
-
- @patch('mmengine.visualization.visualizer.Visualizer.draw', draw)
- def test_add_image(self):
- image = np.random.randint(0, 256, size=(10, 10, 3)).astype(np.uint8)
-
- # The visuailzer parameter must be set when
- # the local_writer object is instantiated and
- # the `add_image` method is called.
- with pytest.raises(AssertionError):
- tensorboard_writer = TensorboardWriter('temp_dir')
- tensorboard_writer.add_image('img', image)
-
- tensorboard_writer = TensorboardWriter('temp_dir',
- dict(type='Visualizer'))
- tensorboard_writer.add_image('img', image)
-
- bboxes = np.array([[1, 1, 2, 2], [1, 1.5, 1, 2.5]])
- tensorboard_writer.visualizer.draw_bboxes(bboxes)
- tensorboard_writer.add_image(
- 'img', tensorboard_writer.visualizer.get_image(), step=2)
-
- visuailzer = VISUALIZERS.build(dict(type='Visualizer'))
- tensorboard_writer = TensorboardWriter('temp_dir', visuailzer)
- tensorboard_writer.add_image('img', image)
-
- def test_add_scalar(self):
- tensorboard_writer = TensorboardWriter('temp_dir')
- tensorboard_writer.add_scalar('map', 0.9)
- # test append mode
- tensorboard_writer.add_scalar('map', 0.9, step=0)
- tensorboard_writer.add_scalar('map', 0.95, step=1)
-
- def test_add_scalars(self):
- tensorboard_writer = TensorboardWriter('temp_dir')
- # The step value must be passed through the parameter
- with pytest.raises(AssertionError):
- tensorboard_writer.add_scalars({'map': 0.7, 'acc': 0.9, 'step': 1})
-
- input_dict = {'map': 0.7, 'acc': 0.9}
- tensorboard_writer.add_scalars(input_dict)
- # test append mode
- tensorboard_writer.add_scalars({'map': 0.8, 'acc': 0.8}, step=1)
-
-
-class TestWandbWriter:
- sys.modules['wandb'] = MagicMock()
-
- def test_init(self):
- WandbWriter()
- WRITERS.build(
- dict(
- type='WandbWriter',
- visualizer=dict(type='Visualizer'),
- save_dir='temp_dir'))
-
- def test_experiment(self):
- wandb_writer = WandbWriter()
- assert wandb_writer.experiment == wandb_writer._wandb
-
- def test_add_params(self):
- wandb_writer = WandbWriter()
-
- # 'params_dict' must be dict
- with pytest.raises(AssertionError):
- wandb_writer.add_params(['lr', 0])
-
- params_dict = dict(lr=0.1, wd=0.2, mode='linear')
- wandb_writer.add_params(params_dict)
-
- @patch('mmengine.visualization.visualizer.Visualizer.draw', draw)
- @patch('mmengine.visualization.writer.WandbWriter.add_image_to_wandb',
- Mock)
- def test_add_image(self):
- image = np.random.randint(0, 256, size=(10, 10, 3)).astype(np.uint8)
-
- wandb_writer = WandbWriter()
- wandb_writer.add_image('img', image)
-
- wandb_writer = WandbWriter(visualizer=dict(type='Visualizer'))
- bboxes = np.array([[1, 1, 2, 2], [1, 1.5, 1, 2.5]])
- wandb_writer.visualizer.set_image(image)
- wandb_writer.visualizer.draw_bboxes(bboxes)
- wandb_writer.add_image(
- 'img', wandb_writer.visualizer.get_image(), step=2)
-
- visuailzer = VISUALIZERS.build(dict(type='Visualizer'))
- wandb_writer = WandbWriter(visualizer=visuailzer)
- wandb_writer.add_image('img', image)
-
- def test_add_scalar(self):
- wandb_writer = WandbWriter()
- wandb_writer.add_scalar('map', 0.9)
- # test append mode
- wandb_writer.add_scalar('map', 0.9, step=0)
- wandb_writer.add_scalar('map', 0.95, step=1)
-
- def test_add_scalars(self):
- wandb_writer = WandbWriter()
- input_dict = {'map': 0.7, 'acc': 0.9}
- wandb_writer.add_scalars(input_dict)
- # test append mode
- wandb_writer.add_scalars({'map': 0.8, 'acc': 0.8}, step=1)
-
-
-class TestComposedWriter:
- sys.modules['torch.utils.tensorboard'] = MagicMock()
- sys.modules['tensorboardX'] = MagicMock()
- sys.modules['wandb'] = MagicMock()
-
- def test_init(self):
-
- class A:
- pass
-
- # The writers inner element must be a dictionary or a
- # subclass of Writer.
- with pytest.raises(AssertionError):
- ComposedWriter(writers=[A()])
-
- composed_writer = ComposedWriter(writers=[
- WandbWriter(),
- dict(
- type='TensorboardWriter',
- visualizer=dict(type='Visualizer'),
- save_dir='temp_dir')
- ])
- assert len(composed_writer._writers) == 2
-
- # test global
- composed_writer = ComposedWriter.get_instance(
- 'composed_writer',
- writers=[
- WandbWriter(),
- dict(
- type='TensorboardWriter',
- visualizer=dict(type='Visualizer'),
- save_dir='temp_dir')
- ])
- assert len(composed_writer._writers) == 2
- composed_writer_any = ComposedWriter.get_instance('composed_writer')
- assert composed_writer_any == composed_writer
-
- def test_get_writer(self):
- composed_writer = ComposedWriter(writers=[
- WandbWriter(),
- dict(
- type='TensorboardWriter',
- visualizer=dict(type='Visualizer'),
- save_dir='temp_dir')
- ])
- assert isinstance(composed_writer.get_writer(0), WandbWriter)
- assert isinstance(composed_writer.get_writer(1), TensorboardWriter)
-
- def test_get_experiment(self):
- composed_writer = ComposedWriter(writers=[
- WandbWriter(),
- dict(
- type='TensorboardWriter',
- visualizer=dict(type='Visualizer'),
- save_dir='temp_dir')
- ])
- assert composed_writer.get_experiment(
- 0) == composed_writer._writers[0].experiment
- assert composed_writer.get_experiment(
- 1) == composed_writer._writers[1].experiment
-
- def test_get_visualizer(self):
- composed_writer = ComposedWriter(writers=[
- WandbWriter(),
- dict(
- type='TensorboardWriter',
- visualizer=dict(type='Visualizer'),
- save_dir='temp_dir')
- ])
- assert composed_writer.get_visualizer(
- 0) == composed_writer._writers[0].visualizer
- assert composed_writer.get_visualizer(
- 1) == composed_writer._writers[1].visualizer
-
- def test_add_params(self):
- composed_writer = ComposedWriter(writers=[
- WandbWriter(),
- dict(
- type='TensorboardWriter',
- visualizer=dict(type='Visualizer'),
- save_dir='temp_dir')
- ])
-
- # 'params_dict' must be dict
- with pytest.raises(AssertionError):
- composed_writer.add_params(['lr', 0])
-
- params_dict = dict(lr=0.1, wd=0.2, mode='linear')
- composed_writer.add_params(params_dict)
-
- def test_add_graph(self):
- composed_writer = ComposedWriter(writers=[
- WandbWriter(),
- dict(
- type='TensorboardWriter',
- visualizer=dict(type='Visualizer'),
- save_dir='temp_dir')
- ])
-
- class Model(nn.Module):
-
- def __init__(self):
- super().__init__()
- self.conv = nn.Conv2d(1, 2, 1)
-
- def forward(self, x, y=None):
- return self.conv(x)
-
- # input must be tensor
- with pytest.raises(AssertionError):
- composed_writer.add_graph(Model(), np.zeros([1, 1, 3, 3]))
-
- # input must be 4d tensor
- with pytest.raises(AssertionError):
- composed_writer.add_graph(Model(), torch.zeros([1, 3, 3]))
-
- # If the input is a list, the inner element must be a 4d tensor
- with pytest.raises(AssertionError):
- composed_writer.add_graph(
- Model(), [torch.zeros([1, 1, 3, 3]),
- torch.zeros([1, 3, 3])])
-
- composed_writer.add_graph(Model(), torch.zeros([1, 1, 3, 3]))
- composed_writer.add_graph(
- Model(), [torch.zeros([1, 1, 3, 3]),
- torch.zeros([1, 1, 3, 3])])
-
- @patch('mmengine.visualization.visualizer.Visualizer.draw', draw)
- @patch('mmengine.visualization.writer.WandbWriter.add_image_to_wandb',
- Mock)
- def test_add_image(self):
- composed_writer = ComposedWriter(writers=[
- WandbWriter(),
- dict(
- type='TensorboardWriter',
- visualizer=dict(type='Visualizer'),
- save_dir='temp_dir')
- ])
-
- image = np.random.randint(0, 256, size=(10, 10, 3)).astype(np.uint8)
- composed_writer.add_image('img', image)
-
- bboxes = np.array([[1, 1, 2, 2], [1, 1.5, 1, 2.5]])
- composed_writer.get_writer(1).visualizer.draw_bboxes(bboxes)
- composed_writer.get_writer(1).add_image(
- 'img',
- composed_writer.get_writer(1).visualizer.get_image(),
- step=2)
-
- def test_add_scalar(self):
- composed_writer = ComposedWriter(writers=[
- WandbWriter(),
- dict(
- type='TensorboardWriter',
- visualizer=dict(type='Visualizer'),
- save_dir='temp_dir')
- ])
- composed_writer.add_scalar('map', 0.9)
- # test append mode
- composed_writer.add_scalar('map', 0.9, step=0)
- composed_writer.add_scalar('map', 0.95, step=1)
-
- def test_add_scalars(self):
- composed_writer = ComposedWriter(writers=[
- WandbWriter(),
- dict(
- type='TensorboardWriter',
- visualizer=dict(type='Visualizer'),
- save_dir='temp_dir')
- ])
- input_dict = {'map': 0.7, 'acc': 0.9}
- composed_writer.add_scalars(input_dict)
- # test append mode
- composed_writer.add_scalars({'map': 0.8, 'acc': 0.8}, step=1)