From 7e123ad6d6978bc447f34a33c1809ca230501fe3 Mon Sep 17 00:00:00 2001
From: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
Date: Fri, 22 Apr 2022 20:45:13 +0800
Subject: [PATCH 1/4] [Docs] Refine registry documentation (#186)
* [Docs] Refine registry documentation
* resolve comments
* minor refinement
---
docs/zh_cn/tutorials/registry.md | 42 +++++++++++++++++++-------------
1 file changed, 25 insertions(+), 17 deletions(-)
diff --git a/docs/zh_cn/tutorials/registry.md b/docs/zh_cn/tutorials/registry.md
index ced73f25..1e0ff3a2 100644
--- a/docs/zh_cn/tutorials/registry.md
+++ b/docs/zh_cn/tutorials/registry.md
@@ -262,7 +262,7 @@ class RetinaNet(nn.Module):

-We can call the modules in `MMEngine` from `MMDetection`.
+We can call the modules defined in `MMEngine` from `MMDetection`.
```python
from mmdet.models import MODELS
@@ -278,6 +278,29 @@ model = MODELS.build(cfg=dict(type='Conv2d'))
If no prefix is given, the `build` method first checks whether the module exists in the current node; if it does, that module is returned, otherwise the search continues upward through the parent node and even ancestor nodes until the module is found. Therefore, if the same module exists in both the current node and a parent node and we want the parent's module, we must specify the `scope` prefix. Note that searching upward through parent or ancestor nodes **requires that the modules of those nodes have already been imported in some way and thereby registered**. For example, in the example above, `MODELS` of the parent `mmengine` is not explicitly imported, because `from mmdet.models import MODELS` indirectly triggers the registration of the modules in `mmengine.MODELS`.
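The lookup order described above can be sketched with a small, self-contained model (the `Registry` class below is a simplified stand-in for `mmengine.registry.Registry`, not its real implementation):

```python
# Simplified model of hierarchical registry lookup; not the real mmengine API.
class Registry:
    def __init__(self, name, parent=None):
        self.name = name
        self.parent = parent
        self.children = {}
        self.module_dict = {}
        if parent is not None:
            parent.children[name] = self

    def register(self, cls):
        self.module_dict[cls.__name__] = cls
        return cls

    def get(self, key):
        # A 'scope.Name' key jumps directly to the named scope.
        if '.' in key:
            scope, real_key = key.split('.', 1)
            node = self if scope == self.name else self._root().children.get(scope)
            return node.module_dict.get(real_key) if node else None
        # Otherwise search the current node first, then walk up the ancestors.
        node = self
        while node is not None:
            if key in node.module_dict:
                return node.module_dict[key]
            node = node.parent
        return None

    def _root(self):
        node = self
        while node.parent is not None:
            node = node.parent
        return node


mmengine_models = Registry('mmengine')
mmdet_models = Registry('mmdet', parent=mmengine_models)

@mmengine_models.register
class Conv2d: ...

@mmdet_models.register
class RetinaNet: ...

# The child registry finds its own module and falls back to the parent.
assert mmdet_models.get('RetinaNet') is RetinaNet
assert mmdet_models.get('Conv2d') is Conv2d
```

The fallback only works upward: the parent registry alone cannot see `RetinaNet` without a scope hint, which motivates the `DefaultScope` mechanism below.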
+The above shows how to build modules with the registry of a child node. But sometimes we want to build a child node's modules through the parent node's registry without adding a prefix, so that common code can be shared and downstream codebases do not reinvent the wheel. How can this be done?
+
+Suppose MMEngine has a `build_model` function that builds a model.
+
+```python
+from mmengine.registry import MODELS
+
+def build_model(cfg):
+    model = MODELS.build(cfg)
+    return model
+```
+
+If we want to call this function in MMDetection to build modules registered by MMDetection, we first need to get a [DefaultScope](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.registry.DefaultScope) instance whose scope_name is 'mmdet'. This instance is globally unique.
+
+```python
+from mmengine import build_model
+from mmengine.registry import DefaultScope
+import mmdet.models  # importing mmdet.models registers the mmdet modules into the registry
+
+default_scope = DefaultScope.get_instance('my_experiment', scope_name='mmdet')
+model = build_model(cfg=dict(type='RetinaNet'))
+```
+
+The purpose of getting a `DefaultScope` instance is to make the build method of Registry use the registry node named by the DefaultScope (mmdet) as the starting point of the search. Only then can the RetinaNet module be found in MMDetection's registry node without adding the mmdet prefix in the config; otherwise, the program reports that RetinaNet cannot be found.
+
### Calling modules of sibling nodes
Besides the modules of parent nodes, the modules of sibling nodes can also be called.
@@ -311,16 +334,7 @@ from mmcls.models import MODELS
model = MODELS.build(cfg=dict(type='mmdet.RetinaNet'))
```
-Calling a module from another node requires specifying the `scope` prefix in `type`. To avoid this, we can create a global variable `default_scope` with `scope_name` set to 'mmdet'; `Registry` will then treat the `registry` corresponding to `scope_name` as the current `Registry` and call its `build` method.
-
-```python
-from mmengine.registry import DefaultScope, MODELS
-
-# call RetinaNet registered in mmdet
-default_scope = DefaultScope.get_instance(
-    'my_experiment', scope_name='mmdet')
-model = MODELS.build(cfg=dict(type='RetinaNet'))
-```
+Calling a module that belongs to neither the current node nor a parent node requires specifying the `scope` prefix in `type`.
Besides two-level structures, the registry also supports hierarchies of three or even more levels.
@@ -358,10 +372,4 @@ model = MODELS.build(cfg=dict(type='mmcls.ResNet'))
from mmcls.models import MODELS
+# note the order of prefixes: 'detplus.mmdet.MetaNet' is incorrect
model = MODELS.build(cfg=dict(type='mmdet.detplus.MetaNet'))
-
-# to build models from detplus by default, we can set default_scope
-from mmengine.registry import DefaultScope
-default_scope = DefaultScope.get_instance(
-    'my_experiment', scope_name='detplus')
-model = MODELS.build(cfg=dict(type='MetaNet', default_scope='detplus'))
```
From 427467937617dd817d19567289607a39272e1406 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Haian=20Huang=28=E6=B7=B1=E5=BA=A6=E7=9C=B8=29?=
<1286304229@qq.com>
Date: Sun, 24 Apr 2022 19:21:10 +0800
Subject: [PATCH 2/4] Refine Visualizer docs (#177)
* Refine Visualizer docs
* update
* update
* update featmap
* update docs
* update visualizer docs
---
docs/zh_cn/tutorials/visualization.md | 477 +++++++++++++-------------
1 file changed, 238 insertions(+), 239 deletions(-)
diff --git a/docs/zh_cn/tutorials/visualization.md b/docs/zh_cn/tutorials/visualization.md
index 80acabcc..4aa7d6ec 100644
--- a/docs/zh_cn/tutorials/visualization.md
+++ b/docs/zh_cn/tutorials/visualization.md
@@ -2,8 +2,6 @@
## Overview
-**(1) General introduction**
-
Visualization provides an intuitive interpretation of the training and testing process of deep learning models. In the OpenMMLab codebases, we expect the visualization design to meet the following requirements:
- Provide rich out-of-the-box visualization capabilities that satisfy most computer vision visualization tasks
@@ -11,100 +9,37 @@
- Support visualization at any point of the training and testing pipeline
- Provide a unified visualization interface across the OpenMMLab codebases, which helps users understand and maintain the code
-Based on the above requirements, OpenMMLab 2.0 introduced the concepts of the drawing object Visualizer and the writer object Writer
+Based on the above requirements, OpenMMLab 2.0 introduces the visualization object Visualizer and the visualization storage backends VisBackend, such as `LocalVisBackend`, `WandbVisBackend` and `TensorboardVisBackend`. Visualization here covers not only image data, but also configs, scalars, model graphs and other data.
-- **Visualizer is responsible for drawing a single image**
+- For ease of use, the interfaces provided by Visualizer implement both drawing and storage. As internal attributes of Visualizer, the VisBackend storage backends are called by Visualizer when needed to save data to different backends
+- Since drawn results often need to be stored to multiple backends, Visualizer can be configured with multiple VisBackends; when a user calls a storage interface of Visualizer, it iterates over the VisBackends and calls their storage interfaces
- MMEngine provides the `Visualizer` class with the Matplotlib library as the drawing backend, which has the following capabilities:
+The UML diagram of the two is shown below
- - A set of basic, task-agnostic drawing methods, such as `draw_bboxes` and `draw_texts`
- - Basic methods that support chained calls, making it easy to overlay drawings
- - Feature-map drawing via `draw_featmap`
+
+

+
- Each downstream codebase can inherit `Visualizer` and implement the required visualization in its `draw` interface; for example, `DetVisualizer` in MMDetection inherits from `Visualizer` and implements drawing of detection boxes, instance masks and semantic segmentation maps in the `draw` interface. The UML diagram of the Visualizer class is shown below
+## Visualization object Visualizer
-
-

-
+### Interface description
-- **Writer is responsible for writing various data to the specified backends**
+The visualization object Visualizer exposes all the interfaces. They can be divided into three groups, as follows
- To unify interface calls, MMEngine provides the abstract class `BaseWriter` and several common Writers: `LocalWriter` for writing data locally, `TensorboardWriter` for writing data to Tensorboard, and `WandbWriter` for writing data to Wandb. Users can also implement custom Writers for custom backends. The written data can be images, model graphs, and scalars such as model accuracy metrics.
+**(1) Drawing interfaces**
- Since multiple Writer objects may exist at the same time during training or testing, e.g. when writing data both locally and remotely, `ComposedWriter` was designed to manage all instantiated Writer objects; it iterates over them and calls their methods. The UML diagram of the Writer class is shown below
-
-

-
+- [draw_bboxes](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.draw_bboxes) draws one or more bounding boxes
+- [draw_points](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.draw_points) draws one or more points
+- [draw_texts](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.draw_texts) draws one or more text boxes
+- [draw_lines](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.lines) draws one or more line segments
+- [draw_circles](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.draw_circles) draws one or more circles
+- [draw_polygons](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.draw_polygons) draws one or more polygons
+- [draw_binary_masks](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.draw_binary_mask) draws one or more binary masks
+- [draw_featmap](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.draw_featmap) draws feature maps; a static method
-**(2) Writer 和 Visualizer 关系**
+All of the above interfaces except `draw_featmap` support chained calls. Because calling `draw_featmap` may change the image size, it is implemented as a static method to avoid confusing users.
-The core function of a Writer object is to write various data (images, model graphs, hyperparameters, scalars such as model accuracy metrics) to a specified backend, which can be local storage, Wandb, Tensorboard, and so on. When writing an image, we usually want to draw the prediction or annotation results onto it first, so the Writer maintains a Visualizer object as one of its attributes. Note that:
-
-- The Visualizer object is only used when the `add_image` interface of the Writer is called; the other interfaces have nothing to do with Visualizer
-- Some Writer backends, e.g. `WandbWriter`, can draw by themselves, so the Visualizer attribute of `WandbWriter` is optional: if a Visualizer object is passed at initialization, it is used in `add_image`, otherwise Wandb's own API draws the image
-- Because their own drawing capabilities are limited, `LocalWriter` and `TensorboardWriter` currently require a Visualizer to draw, so these two Writers must be given a Visualizer (or subclass) object
-
-A brief demonstration of `WandbWriter` is as follows
-
-```python
-# for ease of understanding, BaseWriter is not inherited here
-class WandbWriter:
- def __init__(self, visualizer=None):
- self._visualizer = None
- if visualizer:
-            # example config: visualizer=dict(type='DetVisualizer')
- self._visualizer = VISUALIZERS.build(visualizer)
-
- @property
- def visualizer(self):
- return self._visualizer
-
- def add_image(self, name, image, gt_sample=None, pred_sample=None, draw_gt=True, draw_pred=True, step=0, **kwargs):
-        if self._visualizer:
-            self._visualizer.draw(image, gt_sample, pred_sample, draw_gt, draw_pred)
-            # call the Writer API to write the image to the backend
-            self.wandb.log({name: self.visualizer.get_image()}, ...)
- ...
- else:
-            # call the Writer API to aggregate and write the image to the backend
- ...
-
- def add_scalar(self, name, value, step):
- self.wandb.log({name: value}, ...)
-```
-
-
-## Drawing object Visualizer
-
-The drawing object Visualizer is responsible for drawing a single image, with Matplotlib as the default drawing backend. To unify the visualization interfaces across the OpenMMLab codebases, MMEngine defines the `Visualizer` class, which provides basic drawing capabilities; downstream codebases can inherit `Visualizer` and implement the `draw` interface to meet their own drawing needs.
-
-### Visualizer
-
-`Visualizer` provides basic and general drawing capabilities; the main interfaces are as follows:
-
-**(1) Functional interfaces unrelated to drawing**
-
-- [set_image](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.set_image) sets the raw image data; the default input format is RGB
-- [get_image](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.get_image) gets the drawn image as Numpy data; the default output format is RGB
-- [show](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.show) shows the visualization
-- [register_task](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.register_task) registers drawing functions (its role is described in the *Customizing a Visualizer* section)
-
-**(2) Drawing interfaces**
-
-- [draw](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.draw) the abstract drawing interface for users
-- [draw_featmap](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.draw_featmap) draws feature maps
-- [draw_bboxes](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.draw_bboxes) draws one or more bounding boxes
-- [draw_texts](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.draw_texts) draws one or more text boxes
-- [draw_lines](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.lines) draws one or more line segments
-- [draw_circles](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.draw_circles) draws one or more circles
-- [draw_polygons](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.draw_polygons) draws one or more polygons
-- [draw_binary_masks](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.draw_binary_mask) draws one or more binary masks
-
-Besides calling the basic drawing interfaces of `Visualizer` individually, users can also use chained calls and feature-map visualization. The `draw` function is an abstract interface with no implementation; classes that inherit Visualizer can implement it to provide a unified drawing entry, whereas the `draw_xxx` methods provide the most basic drawing capabilities and usually do not need to be overridden.
-
-**(1) Chained calls**
-
-For example, if the user draws bounding boxes first, then texts and line segments on top, the calls are:
+When users want to draw bounding boxes first and then draw texts and line segments on top, they can use chained calls:
```python
visualizer.set_image(image)
@@ -112,190 +47,254 @@ visualizer.draw_bboxes(...).draw_texts(...).draw_lines(...)
visualizer.show() # visualize the drawn results
```
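Chained calls work because each `draw_xxx` method returns the visualizer itself. A minimal sketch of the pattern, using a hypothetical `ToyVisualizer` rather than the real class:

```python
# Sketch of the chaining pattern: every draw method returns self.
class ToyVisualizer:
    def __init__(self):
        self.ops = []

    def set_image(self, image):
        self.ops = [('set_image', image)]
        return self

    def draw_bboxes(self, bboxes):
        self.ops.append(('bboxes', bboxes))
        return self

    def draw_texts(self, texts):
        self.ops.append(('texts', texts))
        return self

    def draw_lines(self, lines):
        self.ops.append(('lines', lines))
        return self


vis = ToyVisualizer()
vis.set_image('img').draw_bboxes([[0, 0, 1, 1]]).draw_texts(['cat']).draw_lines([[0, 0, 1, 1]])
assert [op for op, _ in vis.ops] == ['set_image', 'bboxes', 'texts', 'lines']
```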
-**(2) Feature-map visualization**
-
-Feature-map visualization is a common need; `draw_featmap` can visualize a feature map directly. Its parameters are defined as:
+Feature-map visualization is a common need; users can visualize feature maps by calling `draw_featmap`, whose parameters are defined as:
```python
@staticmethod
-def draw_featmap(tensor_chw: torch.Tensor, # the input must be in CHW format
-                 image: Optional[np.ndarray] = None, # if image data is also given, the feature map is overlaid on the image
-                 mode: Optional[str] = 'mean', # strategy for compressing multiple channels into a single channel
+def draw_featmap(featmap: torch.Tensor, # the input must be in CHW format
+                 overlaid_image: Optional[np.ndarray] = None, # if image data is also given, the feature map is overlaid on the image
+                 channel_reduction: Optional[str] = 'squeeze_mean', # strategy for compressing multiple channels into a single channel
                  topk: int = 10, # show the topk feature maps with the highest activation
                  arrangement: Tuple[int, int] = (5, 2), # layout when multiple channels are expanded into multiple images
-                 alpha: float = 0.3) -> np.ndarray: # overlay ratio of image and feature map
+                 resize_shape: Optional[tuple] = None, # resize_shape can be specified to scale the feature map
+                 alpha: float = 0.5) -> np.ndarray: # overlay ratio of image and feature map
```
-Feature-map visualization has many options; Batch input is currently not supported
+Feature-map visualization has many options. Batch input is currently not supported, and the functionality can be summarized as follows
-- If mode is not None, topk is ignored: the channels are compressed into a single channel by the mode function and shown as a single image; mode currently only accepts None, 'mean', 'max' and 'min'
-- If mode is None, topk takes effect: if topk is not -1, the topk channels ranked by activation are shown, and the layout can be specified through the arrangement parameter
-- If mode is None and `topk = -1`, the channel number C must be 1 or 3, meaning the input is an image that can be shown directly; otherwise an error prompts the user to set mode to compress the channels
+- The input tensor usually has multiple channels. The channel_reduction parameter compresses them into a single channel, which is then overlaid on the image
+  - `squeeze_mean` compresses the C dimension with the mean function, so the output shape becomes (1, H, W)
+  - `select_max` first sums over the spatial dimensions, obtaining a tensor of shape (C, ), and then selects the channel with the largest value
+  - `None` means no compression; in this case the topk parameter can select the topk feature maps with the highest activation to display
+
+- When channel_reduction is None, the topk parameter takes effect: the topk channels ranked by activation are selected and overlaid on the image, with the layout specified by the arrangement parameter
+  - If topk is not -1, the topk channels with the highest activation are selected and displayed
+  - If topk = -1, the channel number C must be 1 or 3, indicating that the input is an image; otherwise an error prompts the user to set `channel_reduction` to compress the channels.
+
+- Since input feature maps are usually quite small, the function accepts a `resize_shape` parameter so the feature map can be upsampled before visualization.
+
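The two reduction strategies can be sketched in a few lines of NumPy (a simplified model of the behavior described above; the real `draw_featmap` additionally normalizes and colorizes the result):

```python
import numpy as np

# Sketch of the two channel_reduction strategies described above.
def reduce_channels(featmap, channel_reduction):
    if channel_reduction == 'squeeze_mean':
        # compress the C dimension with mean -> (1, H, W)
        return featmap.mean(axis=0, keepdims=True)
    if channel_reduction == 'select_max':
        # sum over spatial dims -> (C,), then pick the strongest channel
        idx = featmap.sum(axis=(1, 2)).argmax()
        return featmap[idx:idx + 1]  # (1, H, W)
    raise ValueError(channel_reduction)

featmap = np.zeros((4, 2, 2))
featmap[2] = 5.0  # channel 2 has the largest activation
assert reduce_channels(featmap, 'select_max')[0, 0, 0] == 5.0
assert reduce_channels(featmap, 'squeeze_mean').shape == (1, 2, 2)
```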
+**(2) Storage interfaces**
+
+- [add_config](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.BaseWriter.add_config) writes the config to a specific storage backend
+- [add_graph](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.BaseWriter.add_graph) writes the model graph to a specific storage backend
+- [add_image](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.BaseWriter.add_image) writes an image to a specific storage backend
+- [add_scalar](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.BaseWriter.add_scalar) writes a scalar to a specific storage backend
+- [add_scalars](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.BaseWriter.add_scalars) writes multiple scalars to a specific storage backend at once
+- [add_datasample](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.BaseWriter.add_datasample) the abstract interface for downstream codebases to draw datasample data
+
+The interfaces beginning with the add prefix are storage interfaces. datasample is the unified abstract data interface designed for downstream codebases in the OpenMMLab 2.0 architecture, and the `add_datasample` interface handles this format directly; e.g. visualizing predictions, visualizing Dataset or DataLoader outputs, and visualizing intermediate predictions can all be done by calling the `add_datasample` interface overridden by the downstream codebase.
+
+All downstream codebases must inherit Visualizer and implement the `add_datasample` interface. Taking MMDetection as an example, it should implement the visualization of all preset detection tasks through this interface, e.g. drawing and storing the results of object detection, instance segmentation and panoptic segmentation.
+
+**(3) Other functional interfaces**
+
+- [set_image](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.set_image) sets the raw image data; the default input format is RGB
+- [get_image](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.get_image) gets the drawn image as Numpy data; the default output format is RGB
+- [show](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.show) shows the visualization
+- [get_backend](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.get_backend) gets a specific storage backend by name
+- [close](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.close) closes all opened resources, including VisBackend
+
+### Usage examples
+
+**(1) Getting the visualizer anywhere**
+
+To ensure the visualization object Visualizer can be accessed anywhere, it inherits from the `ManagerMixin` class and thereby becomes a globally unique object. When initializing a `Visualizer`, the `visualizer.get_instance()` method must be called for the instance to be globally unique. Once instantiated, the visualizer can then be retrieved anywhere in the code via `Visualizer.get_current_instance()`.
+
+Taking MMDetection as an example, suppose the `DetLocalVisualizer` class inherits from `Visualizer` and implements the `add_datasample` interface. The config is written as:
```python
-featmap=visualizer.draw_featmap(tensor_chw,image)
+vis_backends = [dict(type='LocalVisBackend')]
+visualizer = dict(
+ type='DetLocalVisualizer', vis_backends=vis_backends, name='visualizer')
+```
+```python
+# get_instance() is called internally to create the globally unique instance
+VISUALIZERS.build(cfg.visualizer)
```
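The globally unique instance mechanism can be sketched as follows (a toy model of the `ManagerMixin` pattern, not the real mmengine implementation):

```python
# Toy version of the ManagerMixin pattern: named, globally retrievable instances.
class ManagerMixin:
    _instances = {}

    def __init__(self, name):
        self.instance_name = name

    @classmethod
    def get_instance(cls, name, **kwargs):
        # create the named instance once, then always return the same object
        if name not in cls._instances:
            cls._instances[name] = cls(name, **kwargs)
        return cls._instances[name]

    @classmethod
    def get_current_instance(cls):
        # the real implementation returns the most recently created instance
        return next(reversed(cls._instances.values()))


class ToyVisualizer(ManagerMixin):
    pass


ToyVisualizer.get_instance('visualizer')
assert ToyVisualizer.get_current_instance().instance_name == 'visualizer'
```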
-### Customizing a Visualizer
-
-In most cases a custom Visualizer only needs to implement the `get_image` and `draw` interfaces. `draw` is the top-level user-facing interface responsible for all drawing, such as detection boxes, instance masks and semantic segmentation maps. The complexity of the `draw` implementation varies with the task.
-
-Take object detection visualization as an example: bounding boxes (bbox), masks and semantic segmentation maps (seg_map) may need to be drawn at the same time. If all of this were written into the `draw` method, it would be hard to understand and maintain. To solve this, `Visualizer` supports the `register_task` function based on the OpenMMLab 2.0 abstract data interface convention. Suppose MMDetection needs to draw both instances and sem_seg of the prediction results: we can implement the `draw_instances` and `draw_sem_seg` methods in MMDetection's `DetVisualizer` to draw predicted instances and predicted semantic segmentation maps. We want these two drawing functions to be called automatically whenever instances or sem_seg appear in the input data, without the user calling them manually. This can be achieved by decorating `draw_instances` and `draw_sem_seg` with `@Visualizer.register_task`, which stores the string-to-function mapping in `task_dict`; when `draw` is called, the registered functions can then be retrieved via `self.task_dict`. A brief implementation is shown below
+After instantiation by the code above, the visualizer can be obtained anywhere by calling the `get_current_instance` method
```python
-class DetVisualizer(Visualizer):
-
- def draw(self, image, gt_sample=None, pred_sample=None, draw_gt=True, draw_pred=True):
-        # associate the image with the matplotlib layout
- self.set_image(image)
-
- if draw_gt:
-            # self.task_dict internally stores the following:
-            # dict(instances=the draw_instance method, sem_seg=the draw_sem_seg method)
- for task in self.task_dict:
- task_attr = 'gt_' + task
- if task_attr in gt_sample:
- self.task_dict[task](self, gt_sample[task_attr], 'gt')
- if draw_pred:
- for task in self.task_dict:
- task_attr = 'pred_' + task
- if task_attr in pred_sample:
- self.task_dict[task](self, pred_sample[task_attr], 'pred')
-
-    # data_type distinguishes whether the drawn content is annotation or prediction
- @Visualizer.register_task('instances')
- def draw_instance(self, instances, data_type):
- ...
-
-    # data_type distinguishes whether the drawn content is annotation or prediction
- @Visualizer.register_task('sem_seg')
- def draw_sem_seg(self, pixel_data, data_type):
- ...
+# get the visualizer anywhere in the code
+visualizer = Visualizer.get_current_instance()
```
-Note: using the `register_task` decorator is not mandatory; if a user customizes a Visualizer and its `draw` implementation is very simple, `register_task` is unnecessary.
+If users directly use the Runner of MMEngine or of a downstream codebase, no extra instantiation is needed, because a globally unique visualizer is created automatically in the Runner's initialization function.
-In scenarios such as Jupyter notebooks where data need not be written to a backend, users can instantiate the visualizer themselves. A simple example is as follows
+**(2) Writing data to a specific backend**
+
+After obtaining the visualizer, the `add_xxx` interfaces can be called to write various data to a specific backend
```python
-# instantiate the visualizer
-visualizer=dict(type='DetVisualizer')
-visualizer = VISUALIZERS.build(visualizer)
-visualizer.draw(image, datasample)
-visualizer.show() # visualize the drawn results
-```
+# draw the datasample and save it to the local storage backend
+visualizer.add_datasample('demo_image', image, gt_sample, pred_sample, step=1)
+# show in a local window directly, without storing
+visualizer.add_datasample('demo_image', image, gt_sample, pred_sample, show=True)
-## Writer
-
-Visualizer only implements drawing for a single image, but recording key metrics or training hyperparameters during training or testing is also important; this is done by the Writer. To unify interface calls, MMEngine provides the abstract class `BaseWriter` and common Writers such as `LocalWriter`, `TensorboardWriter` and `WandbWriter`.
-
-### BaseWriter
-
-BaseWriter defines the external interface specification; the main interfaces and properties are as follows:
-
-- [add_params](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.BaseWriter.add_params) writes hyperparameters to a specific backend; common training hyperparameters include the initial learning rate LR, the weight decay and the batch size
-- [add_graph](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.BaseWriter.add_graph) writes the model graph to a specific backend
-- [add_image](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.BaseWriter.add_image) writes an image to a specific backend
-- [add_scalar](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.BaseWriter.add_scalar) writes a scalar to a specific backend
-- [add_scalars](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.BaseWriter.add_scalars) writes multiple scalars to a specific backend at once
-- [visualizer](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.BaseWriter.visualizer) the drawing object
-- [experiment](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.BaseWriter.experiment) the writer backend object, e.g. a Wandb object or a Tensorboard object
-
-`BaseWriter` defines five common data-writing interfaces. Some backends are very powerful, e.g. Wandb can also write tables, videos and so on; for such needs users can obtain the experiment object directly and then call the backend's own API.
-
-### LocalWriter, TensorboardWriter and WandbWriter
-
-`LocalWriter` writes data to local disk. If images need to be written to disk, a **Visualizer object must be provided through the initialization parameters**. Typical usage:
-
-```python
-# config file
-writer=dict(type='LocalWriter', save_dir='demo_dir', visualizer=dict(type='DetVisualizer'))
-# instantiate and call
-local_writer=WRITERS.build(writer)
-# write model accuracy values
-local_writer.add_scalar('mAP', 0.9)
-local_writer.add_scalars({'loss': 1.2, 'acc': 0.8})
-# write hyperparameters
-local_writer.add_params(dict(lr=0.1, mode='linear'))
# write images
-local_writer.add_image('demo_image', image, datasample)
-```
-
-Custom drawing needs can be met by accessing the internal visualizer attribute, as shown below
-
-```python
-# config file
-writer=dict(type='LocalWriter', save_dir='demo_dir', visualizer=dict(type='DetVisualizer'))
-# instantiate and call
-local_writer=WRITERS.build(writer)
-# write images
-local_writer.visualizer.draw_bboxes(np.array([0, 0, 1, 1]))
-local_writer.add_image('img', local_writer.visualizer.get_image())
-
-# draw the feature map and save it locally
-featmap_image=local_writer.visualizer.draw_featmap(tensor_chw)
-local_writer.add_image('featmap', featmap_image)
-```
-
-`TensorboardWriter` writes various data to Tensorboard; its usage is very similar to LocalWriter. Note that if images need to be written to Tensorboard, a **Visualizer object must be provided through the initialization parameters**.
-
-`WandbWriter` writes various data to Wandb. Since Wandb itself has powerful image capabilities, the Visualizer object is optional when calling the `add_image` method of `WandbWriter`: if a Visualizer object is specified, its drawing methods are called, otherwise Wandb's own image handling is used.
-
-## ComposedWriter
-
-Since multiple Writers may be needed at the same time during training or testing, e.g. writing both locally and to Wandb, the user-facing `ComposedWriter` class was designed. During training or testing, `ComposedWriter` calls the interfaces of each Writer in turn; its interface is consistent with `BaseWriter`, with the main interfaces being:
-
-- [add_params](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.ComposedWriter.add_params) writes hyperparameters to all added backends; common training hyperparameters include the initial learning rate LR, the weight decay and the batch size
-- [add_graph](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.ComposedWriter.add_graph) writes the model graph to all added backends
-- [add_image](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.ComposedWriter.add_image) writes an image to all added backends
-- [add_scalar](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.ComposedWriter.add_scalar) writes a scalar to all added backends
-- [add_scalars](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.ComposedWriter.add_scalars) writes multiple scalars to all added backends at once
-- [get_writer](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.ComposedWriter.get_writer) gets the Writer at the given index; every Writer contains the experiment and visualizer attributes
-- [get_experiment](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.ComposedWriter.get_experiment) gets the experiment at the given index
-- [get_visualizer](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.ComposedWriter.get_visualizer) gets the visualizer at the given index
-- [close](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.ComposedWriter.close) calls the close method of all Writers
-
-To let users visualize data anywhere in the code, the `ComposedWriter` class inherits from the [globally accessible base class BaseGlobalAccessible](./logging.md/#全局可访问基类baseglobalaccessible). Having inherited it, users can obtain the global object by calling `get_instance` of `ComposedWriter`. Basic usage is as follows
-
-```python
-# create the instance
-writers=[dict(type='LocalWriter', save_dir='temp_dir', visualizer=dict(type='DetVisualizer')), dict(type='WandbWriter')]
-
-ComposedWriter.create_instance('composed_writer', writers=writers)
-```
-
-Once the instance is created, the `ComposedWriter` object can be obtained anywhere in the code
-
-```python
-composed_writer=ComposedWriter.get_instance('composed_writer')
+visualizer.add_image('demo_image', image, step=1)
# write model accuracy values
-composed_writer.add_scalar('mAP', 0.9)
-composed_writer.add_scalars({'loss': 1.2, 'acc': 0.8})
-# write hyperparameters
-composed_writer.add_params(dict(lr=0.1, mode='linear'))
-# write images
-composed_writer.add_image('demo_image', image, datasample)
+visualizer.add_scalar('mAP', 0.9, step=1)
+visualizer.add_scalars({'loss': 1.2, 'acc': 0.8}, step=1)
+
+# write the config file
+visualizer.add_config(cfg)
+
# write the model graph
-composed_writer.add_graph(model, input_array)
+visualizer.add_graph(model, data_batch)
```
-For custom drawing needs or needs not covered by the above interfaces, users can obtain the concrete objects through the `get_xxx` methods and implement them there
+**(3) Feature-map visualization**
+
+Compress or select feature maps with the `channel_reduction` parameter and show them in a local window
```python
-composed_writer=ComposedWriter.get_instance('composed_writer')
+featmap = ... # a tensor of CHW shape
-# draw the feature map; get the visualizer from LocalWriter
-visualizer=composed_writer.get_visualizer(0)
-featmap_image=visualizer.draw_featmap(tensor_chw)
-composed_writer.add_image('featmap', featmap_image)
+# compress the channels
+feat_img = visualizer.draw_featmap(featmap, channel_reduction='squeeze_mean')
+visualizer.show(feat_img)
+
+# select and show the channel with the highest activation
+feat_img = visualizer.draw_featmap(featmap, channel_reduction='select_max')
+visualizer.show(feat_img)
+```
+
+Overlaid on an image
+
+```python
+featmap = ... # a tensor of CHW shape
+img = ... # if the spatial sizes of featmap and img differ, featmap is interpolated internally
+
+# compress the channels
+feat_img = visualizer.draw_featmap(featmap, img, channel_reduction='squeeze_mean')
+visualizer.show(feat_img)
+
+# select and show the channel with the highest activation
+feat_img = visualizer.draw_featmap(featmap, img, channel_reduction='select_max')
+visualizer.show(feat_img)
+```
+
+Select a given number of channels with the `topk` parameter and show them in a local window
+
+```python
+featmap= ... # a tensor of CHW shape
+
+# topk, displayed in a 2-row, 5-column layout
+feat_img = visualizer.draw_featmap(featmap, channel_reduction=None, topk=10, arrangement=(2, 5))
+visualizer.show(feat_img)
+
+# topk, displayed in a 5-row, 2-column layout
+feat_img = visualizer.draw_featmap(featmap, channel_reduction=None, topk=10, arrangement=(5, 2))
+visualizer.show(feat_img)
+```
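The interaction of `topk` and `arrangement` can be pictured with a small NumPy sketch (a hypothetical `topk_grid` helper, not part of the real API):

```python
import numpy as np

# Sketch: pick the topk most activated channels and tile them into a
# (rows, cols) grid, mimicking the arrangement parameter.
def topk_grid(featmap, topk, arrangement):
    rows, cols = arrangement
    assert rows * cols >= topk
    # rank channels by total activation, keep the topk strongest
    order = featmap.sum(axis=(1, 2)).argsort()[::-1][:topk]
    c, h, w = featmap.shape
    grid = np.zeros((rows * h, cols * w))
    for i, ch in enumerate(order):
        r, col = divmod(i, cols)
        grid[r * h:(r + 1) * h, col * w:(col + 1) * w] = featmap[ch]
    return grid

featmap = np.arange(16 * 8 * 8, dtype=float).reshape(16, 8, 8)
assert topk_grid(featmap, 10, (2, 5)).shape == (16, 40)
assert topk_grid(featmap, 10, (5, 2)).shape == (40, 16)
```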
+
+Scale the displayed feature map with `resize_shape`
+
+```python
+featmap = ... # a tensor of CHW shape
+
+# compress the channels
+feat_img = visualizer.draw_featmap(featmap, channel_reduction='squeeze_mean', resize_shape=(224, 224))
+visualizer.show(feat_img)
+```
+
+Store the feature map to the visualization backends
+
+```python
+featmap = ... # a tensor of CHW shape
+
+# compress the channels
+feat_img = visualizer.draw_featmap(featmap, channel_reduction='squeeze_mean', resize_shape=(224, 224))
+# store
+visualizer.add_image('feat_image', feat_img)
+```
+
+**(4) Remote window display**
+
+Users can save data with Wandb, Tensorboard, or a custom backend with remote display capabilities, and then view it in a browser. Taking Wandb as an example, a typical config is:
+
+```python
+vis_backends = [dict(type='WandbVisBackend')]
+visualizer = dict(
+ type='DetWandbVisualizer', vis_backends=vis_backends, name='visualizer')
+```
+
+The usage is exactly the same as above. Note in particular that since the data drawn by Wandb is not compatible with the `LocalVisBackend` backend, only `WandbVisBackend` takes effect when `vis_backends` contains multiple visualization storage backends.
+
+## Visualization storage backend VisBackend
+
+After drawing, the results can be stored to multiple visualization storage backends. To unify interface calls, MMEngine provides the abstract class `BaseVisBackend` and several common VisBackends such as `LocalVisBackend`, `WandbVisBackend` and `TensorboardVisBackend`.
+
+### Interface description
+
+BaseVisBackend defines the external interface specification; the main interfaces and properties are as follows:
+
+- [add_config](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.vis_backend.BaseVisBackend.add_config) writes the config to a specific storage backend
+- [add_graph](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.vis_backend.BaseVisBackend.add_graph) writes the model graph to a specific backend
+- [add_image](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.vis_backend.BaseVisBackend.add_image) writes an image to a specific backend
+- [add_scalar](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.vis_backend.BaseVisBackend.add_scalar) writes a scalar to a specific backend
+- [add_scalars](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.vis_backend.BaseVisBackend.add_scalars) writes multiple scalars to a specific backend at once
+- [close](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.vis_backend.BaseVisBackend.close) closes opened resources
+- [experiment](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.vis_backend.BaseVisBackend.experiment) the writer backend object, e.g. a Wandb object or a Tensorboard object
+
+`BaseVisBackend` defines five common data-writing interfaces. Some backends are very powerful, e.g. Wandb can also write tables, videos and so on; for such needs users can obtain the experiment object directly and then call the backend's own API. `LocalVisBackend`, `WandbVisBackend` and `TensorboardVisBackend` all inherit from `BaseVisBackend` and implement the corresponding storage according to their own characteristics.
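The backend pattern can be sketched as follows (a toy `ToyVisBackend` assuming only the interface contract described above, not the real `BaseVisBackend` implementation):

```python
# Sketch of the VisBackend pattern: a common write interface plus an
# `experiment` escape hatch exposing the underlying backend object.
class ToyVisBackend:
    def __init__(self):
        self._store = {'scalars': {}, 'images': {}}

    @property
    def experiment(self):
        # real backends return e.g. the wandb module or a SummaryWriter here
        return self._store

    def add_scalar(self, name, value, step=0):
        self._store['scalars'].setdefault(name, []).append((step, value))

    def add_image(self, name, image, step=0):
        self._store['images'][name] = (step, image)

    def close(self):
        pass


backend = ToyVisBackend()
backend.add_scalar('mAP', 0.9, step=1)
assert backend.experiment['scalars']['mAP'] == [(1, 0.9)]
```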
+
+### Usage examples
+
+In general, users do not need to manipulate VisBackend objects directly; only when the current visualization storage cannot meet their needs will they want to operate a storage backend directly. Taking Wandb as an example, it provides very rich storage formats, e.g. interfaces for storing tables, weights and so on. To keep the interface unified across all backends, we do not provide such specialized interfaces; instead, users can obtain the Wandb object directly for custom storage.
+
+```python
+vis_backends = [dict(type='WandbVisBackend')]
+visualizer = dict(
+ type='DetWandbVisualizer', vis_backends=vis_backends, name='visualizer')
+```
+
+```python
+# get_instance() is called internally to create the globally unique instance
+VISUALIZERS.build(cfg.visualizer)
+# get the visualizer anywhere in the code
+visualizer = Visualizer.get_current_instance()
# extend the add functionality, e.g. draw a table with the Wandb object
-wandb=composed_writer.get_experiment(1)
+wandb = visualizer.get_backend('WandbVisBackend').experiment
val_table = wandb.Table(data=my_data, columns=column_names)
wandb.log({'my_val_table': val_table})
-
-# multiple Writers exist in the config; use only LocalWriter without modifying the config
-local_writer=composed_writer.get_writer(0)
-local_writer.add_image('demo_image', image, datasample)
+```
+
+A visualizer object can be connected to any number of VisBackends. To make it easy to get any VisBackend, a backend can be retrieved by its class name when the name parameter is not specified
+
+```python
+vis_backends = [dict(type='LocalVisBackend'), dict(type='WandbVisBackend')]
+visualizer = dict(
+ type='DetLocalVisualizer', vis_backends=vis_backends, name='visualizer')
+```
+
+```python
+# get_instance() is called internally to create the globally unique instance
+VISUALIZERS.build(cfg.visualizer)
+# get the visualizer anywhere in the code
+visualizer = Visualizer.get_current_instance()
+
+local_vis_backend = visualizer.get_backend('LocalVisBackend')
+wandb_vis_backend = visualizer.get_backend('WandbVisBackend')
+```
+
+When multiple VisBackends of the same class exist, users must specify a unique name parameter for each of them, and can later retrieve a backend by its name string
+
+```python
+vis_backends = [dict(type='LocalVisBackend', name='local_vis_backend_1'), dict(type='LocalVisBackend', name='local_vis_backend_2')]
+visualizer = dict(
+ type='DetLocalVisualizer', vis_backends=vis_backends, name='visualizer')
+```
+
+```python
+# get_instance() is called internally to create the globally unique instance
+VISUALIZERS.build(cfg.visualizer)
+# get the visualizer anywhere in the code
+visualizer = Visualizer.get_current_instance()
+
+local_vis_backend_1 = visualizer.get_backend('local_vis_backend_1')
+local_vis_backend_2 = visualizer.get_backend('local_vis_backend_2')
```
From e2a2b0438edbdca44414f73db6fc9f0a3bcc6a3a Mon Sep 17 00:00:00 2001
From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com>
Date: Sun, 24 Apr 2022 19:23:28 +0800
Subject: [PATCH 3/4] [Refactor] Refine LoggerHook (#155)
* rename global accessible and integrate get_instance and create_instance
* move ManagerMixin to utils
* fix docstring and separate get_instance into get_instance and get_current_instance
* fix lint
* fix docstring, rename and move test_global_meta
* rename LogBuffer to HistoryBuffer, rename MessageHub methods, MessageHub support resume
* refine MMLogger timestamp, update unit test
* MMLogger add logger_name arguments
* Fix docstring
* Add LogProcessor and some unit test
* update unit test
* complete LogProcessor unit test
* refine LoggerHook
* solve circle import
* change default logger_name to mmengine
* refactor eta
* Fix docstring comment and unit test
* Fix with runner
* fix docstring
fix docstring
* fix docstring
* Add by_epoch attribute to LoggerHook and fix docstring
* Please mypy and fix comment
* remove \ in MMLogger
* Fix lint
* roll back pre-commit-hook
* Fix hook unit test
* Fix comments
* remove \t in log and add docstring
* Fix as comment
* should not accept other arguments if corresponding instance has been created
* fix logging ddp file saving
* fix logging ddp file saving
* move log processor to logging
* move log processor to logging
* remove current datalaoder
* fix docstring
* fix unit test
* add learning rate in messagehub
* Support output training/validation/testing message after iterations/epochs
* fix docstring
* Fix IterBasedRunner log string
* Fix IterBasedRunner log string
* Support parse validation loss in log processor
---
mmengine/hooks/hook.py | 19 +-
mmengine/hooks/iter_timer_hook.py | 64 +++-
mmengine/hooks/logger_hook.py | 461 +++++------------------
mmengine/hooks/optimizer_hook.py | 3 +
mmengine/logging/__init__.py | 5 +-
mmengine/logging/log_processor.py | 409 ++++++++++++++++++++
mmengine/logging/logger.py | 18 +-
mmengine/runner/runner.py | 11 +-
tests/test_hook/test_hook.py | 9 +-
tests/test_hook/test_iter_timer_hook.py | 69 +++-
tests/test_hook/test_logger_hook.py | 347 ++++-------------
tests/test_hook/test_optimizer_hook.py | 4 +-
tests/test_logging/test_log_processor.py | 242 ++++++++++++
tests/test_runner/test_runner.py | 2 +-
14 files changed, 961 insertions(+), 702 deletions(-)
create mode 100644 mmengine/logging/log_processor.py
create mode 100644 tests/test_logging/test_log_processor.py
diff --git a/mmengine/hooks/hook.py b/mmengine/hooks/hook.py
index 84060334..1e6e9370 100644
--- a/mmengine/hooks/hook.py
+++ b/mmengine/hooks/hook.py
@@ -358,11 +358,11 @@ class Hook:
"""
return (runner.epoch + 1) % n == 0 if n > 0 else False
- def every_n_inner_iters(self, inner_iter: int, n: int) -> bool:
+ def every_n_inner_iters(self, batch_idx: int, n: int) -> bool:
"""Test whether current inner iteration can be evenly divided by n.
Args:
- inner_iter (int): Current inner_iter of the training, validation
+ batch_idx (int): Current batch index of the training, validation
or testing loop.
n (int): Whether current inner iteration can be evenly
divided by n.
@@ -371,7 +371,7 @@ class Hook:
bool: Whether current inner iteration can be evenly
divided by n.
"""
- return (inner_iter + 1) % n == 0 if n > 0 else False
+ return (batch_idx + 1) % n == 0 if n > 0 else False
def every_n_iters(self, runner, n: int) -> bool:
"""Test whether current iteration can be evenly divided by n.
@@ -387,19 +387,18 @@ class Hook:
"""
return (runner.iter + 1) % n == 0 if n > 0 else False
- def end_of_epoch(self, runner, batch_idx: int) -> bool:
+ def end_of_epoch(self, dataloader, batch_idx: int) -> bool:
"""Check whether the current iteration reaches the last iteration of
- current dataloader.
+ the dataloader.
Args:
- runner (Runner): The runner of the training, validation or testing
- process.
+ dataloader (Dataloader): The dataloader of the training,
+ validation or testing process.
batch_idx (int): The index of the current batch in the loop.
-
Returns:
bool: Whether reaches the end of current epoch or not.
"""
- return batch_idx + 1 == len(runner.cur_dataloader)
+ return batch_idx + 1 == len(dataloader)
def is_last_train_epoch(self, runner) -> bool:
"""Test whether current epoch is the last train epoch.
@@ -418,10 +417,10 @@ class Hook:
Args:
runner (Runner): The runner of the training, validation or testing
process.
+ mode (str): Current mode of runner. Defaults to 'train'.
Returns:
bool: Whether current iteration is the last iteration.
- mode (str): Current mode of runner. Defaults to 'train'.
"""
if mode == 'train':
return runner.iter + 1 == runner.train_loop.max_iters
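The renamed helpers above now take a zero-based `batch_idx` (and, for `end_of_epoch`, the dataloader itself instead of the runner). A minimal standalone sketch of the two checks, simplified from the patched methods:

```python
def every_n_inner_iters(batch_idx: int, n: int) -> bool:
    # batch_idx is zero-based, so batch 9 is the 10th iteration of the epoch.
    return (batch_idx + 1) % n == 0 if n > 0 else False

def end_of_epoch(dataloader, batch_idx: int) -> bool:
    # The dataloader is passed directly; the hook no longer needs the runner.
    return batch_idx + 1 == len(dataloader)

print(every_n_inner_iters(9, 10))    # True: the 10th batch
print(end_of_epoch(range(100), 99))  # True: last batch of a 100-batch epoch
```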
diff --git a/mmengine/hooks/iter_timer_hook.py b/mmengine/hooks/iter_timer_hook.py
index d281745d..ef7124d5 100644
--- a/mmengine/hooks/iter_timer_hook.py
+++ b/mmengine/hooks/iter_timer_hook.py
@@ -18,11 +18,25 @@ class IterTimerHook(Hook):
priority = 'NORMAL'
- def _before_epoch(self, runner, mode: str = 'train') -> None:
- """Record time flag before start a epoch.
+ def __init__(self):
+ self.time_sec_tot = 0
+ self.start_iter = 0
+
+ def before_run(self, runner) -> None:
+ """Synchronize the number of iterations with the runner.
Args:
- runner (Runner): The runner of the training process.
+ runner (Runner): The runner of the training, validation or testing
+ process.
+ """
+ self.start_iter = runner.iter
+
+ def _before_epoch(self, runner, mode: str = 'train') -> None:
+ """Record timestamp before start an epoch.
+
+ Args:
+ runner (Runner): The runner of the training, validation and
+ testing process.
mode (str): Current mode of runner. Defaults to 'train'.
"""
self.t = time.time()
@@ -32,16 +46,18 @@ class IterTimerHook(Hook):
batch_idx: int,
data_batch: DATA_BATCH = None,
mode: str = 'train') -> None:
- """Logging time for loading data and update the time flag.
+ """Calculating time for loading data and updating "data_time"
+ ``HistoryBuffer`` of ``runner.message_hub``.
Args:
- runner (Runner): The runner of the training process.
+ runner (Runner): The runner of the training, validation and
+ testing process.
batch_idx (int): The index of the current batch in the loop.
data_batch (Sequence[Tuple[Any, BaseDataElement]], optional): Data
from dataloader. Defaults to None.
mode (str): Current mode of runner. Defaults to 'train'.
"""
- # TODO: update for new logging system
+ # Update data loading time in `runner.message_hub`.
runner.message_hub.update_scalar(f'{mode}/data_time',
time.time() - self.t)
@@ -52,10 +68,12 @@ class IterTimerHook(Hook):
outputs: Optional[Union[dict,
Sequence[BaseDataElement]]] = None,
mode: str = 'train') -> None:
- """Logging time for a iteration and update the time flag.
+ """Calculating time for an iteration and updating "time"
+ ``HistoryBuffer`` of ``runner.message_hub``.
Args:
- runner (Runner): The runner of the training process.
+ runner (Runner): The runner of the training, validation and
+ testing process.
batch_idx (int): The index of the current batch in the loop.
data_batch (Sequence[Tuple[Any, BaseDataElement]], optional): Data
from dataloader. Defaults to None.
@@ -63,7 +81,31 @@ class IterTimerHook(Hook):
to None.
mode (str): Current mode of runner. Defaults to 'train'.
"""
- # TODO: update for new logging system
-
- runner.message_hub.update_scalar(f'{mode}/time', time.time() - self.t)
+ # Update iteration time in `runner.message_hub`.
+ message_hub = runner.message_hub
+ message_hub.update_scalar(f'{mode}/time', time.time() - self.t)
self.t = time.time()
+ window_size = runner.log_processor.window_size
+ # Calculate eta every `window_size` iterations. Since the test and
+ # val loops do not update `runner.iter`, use `every_n_inner_iters`
+ # to check the interval.
+ if self.every_n_inner_iters(batch_idx, window_size):
+ iter_time = message_hub.get_scalar(f'{mode}/time').mean(
+ window_size)
+ if mode == 'train':
+ self.time_sec_tot += iter_time * window_size
+ # Calculate the average iteration time.
+ time_sec_avg = self.time_sec_tot / (
+ runner.iter - self.start_iter + 1)
+ # Calculate eta.
+ eta_sec = time_sec_avg * (
+ runner.train_loop.max_iters - runner.iter - 1)
+ runner.message_hub.update_info('eta', eta_sec)
+ else:
+ if mode == 'val':
+ cur_dataloader = runner.val_loop.dataloader
+ else:
+ cur_dataloader = runner.test_loop.dataloader
+
+ eta_sec = iter_time * (len(cur_dataloader) - batch_idx - 1)
+ runner.message_hub.update_info('eta', eta_sec)
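The ETA arithmetic added above can be sketched in isolation. This is a hypothetical standalone version, not the `IterTimerHook` API: training ETA extrapolates the average per-iteration time over the remaining `max_iters`, while val/test ETA only spans the rest of the current dataloader because those loops do not advance `runner.iter`.

```python
def train_eta(time_sec_tot: float, cur_iter: int, start_iter: int,
              max_iters: int) -> float:
    # Average time per iteration so far, extrapolated over remaining iters.
    time_sec_avg = time_sec_tot / (cur_iter - start_iter + 1)
    return time_sec_avg * (max_iters - cur_iter - 1)

def eval_eta(iter_time: float, dataloader_len: int, batch_idx: int) -> float:
    # Only the remaining batches of this epoch count toward the estimate.
    return iter_time * (dataloader_len - batch_idx - 1)

print(train_eta(50.0, 99, 0, 1000))  # 0.5 s/iter over 900 remaining -> 450.0
print(eval_eta(0.5, 200, 99))        # 0.5 s/iter over 100 remaining -> 50.0
```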
diff --git a/mmengine/hooks/logger_hook.py b/mmengine/hooks/logger_hook.py
index aed1d0e0..87b69114 100644
--- a/mmengine/hooks/logger_hook.py
+++ b/mmengine/hooks/logger_hook.py
@@ -1,16 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
-import copy
-import datetime
import os
import os.path as osp
-from collections import OrderedDict
from pathlib import Path
from typing import Any, Optional, Sequence, Tuple, Union
-import torch
-
from mmengine.data import BaseDataElement
-from mmengine.dist import master_only
from mmengine.fileio import FileClient
from mmengine.hooks import Hook
from mmengine.registry import HOOKS
@@ -21,33 +15,20 @@ DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataElement]]]
@HOOKS.register_module()
class LoggerHook(Hook):
- """In this logger hook, the information will be printed on the terminal and
- saved in JSON file, tensorboard, wandb .etc.
+ """Collect logs from different components of ``Runner`` and write them to
+ terminal, JSON file, tensorboard and wandb .etc.
+
+ ``LoggerHook`` is used to record logs formatted by ``LogProcessor`` during
+ training/validation/testing phase. It is used to control following
+ behaviers:
+
+ - The frequency of logs update in terminal, local, tensorboad wandb.etc.
+ - The frequency of show experiment information in terminal.
+ - The work directory to save logs.
Args:
- by_epoch (bool): Whether ``EpochBasedLoop`` is used.
- Defaults to True.
interval (int): Logging interval (every k iterations).
Defaults to 10.
- custom_keys (dict, optional): Defines the keys in the log and which
- kinds of statistic methods should be used to log them.
-
- - ``custom_keys`` contains multiple string-dict pairs. In each
- string-dict pair, the string defines a key name in the log and the
- dict is a config defines the statistic methods and corresponding
- arguments used to log the value. For example,
- ``dict(loss=dict(method_name='mean', log_name='global_loss',
- window_size='global'))`` which means the log key ``loss`` will be
- counted as global mean and additionally logged as ``global_loss``.
- If ``log_name`` is not defined in config dict, the original logged
- key will be overwritten.
- - The key in ``LoggerHook.fixed_smooth_keys`` cannot be overwritten
- because ``time`` and ``iter_time`` will be used to calculate
- estimated time of arrival. If you want to recount the time, you
- should set ``log_name`` in corresponding values.
- - For those statistic methods with the ``window_size`` argument,
- if ``by_epoch`` is set to False, ``windows_size`` should not be
- `epoch` to statistics log value by epoch.
ignore_last (bool): Ignore the log of last iterations in each epoch if
the number of remaining iterations is less than :attr:`interval`.
Defaults to True.
@@ -72,64 +53,24 @@ class LoggerHook(Hook):
Defaults to None.
Examples:
- >>> # `log_name` is defined, `loss_mean_window` will be an additional
- >>> # record.
- >>> logger_hook_cfg = dict(by_epoch=True,
- >>> custom_keys=dict(
- >>> loss=dict(
- >>> log_name='loss_mean_window',
- >>> method_name='mean',
- >>> window_size=10)))
- >>> # `log_name` is not defined. `loss` will be overwritten by
- >>> # `global_mean` statistics.
- >>> logger_hook_cfg = dict(by_epoch=True,
- >>> custom_keys=dict(
- >>> loss=dict(
- >>> method_name='mean',
- >>> window_size='global')))
- >>> # `time` cannot be overwritten, `global_time` will be an additional
- >>> # record.
- >>> logger_hook_cfg = dict(by_epoch=True,
- >>> custom_keys=dict(
- >>> time=dict(
- >>> log_name='global_time',
- >>> method='mean',
- >>> window_size='global')))
- >>> # Record loss with different statistics methods.
- >>> logger_hook_cfg = dict(by_epoch=True,
- >>> custom_keys=dict(loss=[
- >>> dict(log_name='loss_mean_window',
- >>> method_name='mean',
- >>> window_size=10),
- >>> dict(method_name='mean',
- >>> window_size='global')]))
+ >>> # A simplest LoggerHook config.
+ >>> logger_hook_cfg = dict(interval=20)
"""
- # eta will be calculated by time. `time` and `data_time` should not be
- # overwritten.
- fixed_smooth_keys = ('time', 'data_time')
priority = 'BELOW_NORMAL'
def __init__(
self,
- by_epoch: bool = True,
interval: int = 10,
- custom_keys: Optional[dict] = None,
ignore_last: bool = True,
interval_exp_name: int = 1000,
out_dir: Optional[Union[str, Path]] = None,
out_suffix: Union[Sequence[str], str] = ('.log.json', '.log', '.py'),
- keep_local=True,
- file_client_args=None,
+ keep_local: bool = True,
+ file_client_args: Optional[dict] = None,
):
- self._inner_iter = 0
- self.by_epoch = by_epoch
self.interval = interval
- self.custom_keys = custom_keys if custom_keys is not None else dict()
self.ignore_last = ignore_last
-
- self.time_sec_tot = 0
self.interval_exp_name = interval_exp_name
- self._check_custom_keys()
if out_dir is None and file_client_args is not None:
raise ValueError(
@@ -169,7 +110,7 @@ class LoggerHook(Hook):
f'{runner.timestamp}.log.json')
self.yaml_log_path = osp.join(runner.work_dir,
f'{runner.timestamp}.log.json')
- self.start_iter = runner.iter
+ # TODO: Make compatible with Visualizer.
if runner.meta is not None:
runner.writer.add_params(runner.meta, file_path=self.yaml_log_path)
@@ -178,41 +119,100 @@ class LoggerHook(Hook):
batch_idx: int,
data_batch: DATA_BATCH = None,
outputs: Optional[dict] = None) -> None:
- """Record training logs.
+ """Record training logs after training iteration.
Args:
runner (Runner): The runner of the training process.
batch_idx (int): The index of the current batch in the train loop.
- data_batch (Sequence[BaseDataElement], optional): Data from
- dataloader. Defaults to None.
+ data_batch (Sequence[Tuple[Any, BaseDataElement]], optional):
+ Data from dataloader. Defaults to None.
outputs (dict, optional): Outputs from model.
Defaults to None.
"""
- self._inner_iter = batch_idx
- if runner.meta is not None and 'exp_name' in runner.meta:
- if (self.every_n_iters(runner, self.interval_exp_name)) or (
- self.by_epoch and self.end_of_epoch(runner, batch_idx)):
- exp_info = f'Exp name: {runner.meta["exp_name"]}'
- runner.logger.info(exp_info)
- if self.by_epoch and self.every_n_inner_iters(batch_idx,
- self.interval):
- self._log_train(runner)
- elif not self.by_epoch and self.every_n_iters(runner, self.interval):
- self._log_train(runner)
- elif self.end_of_epoch(runner, batch_idx) and not self.ignore_last:
+ # Print experiment name every n iterations.
+ if self.every_n_iters(runner,
+ self.interval_exp_name) or (self.end_of_epoch(
+ runner.train_dataloader, batch_idx)):
+ exp_info = f'Exp name: {runner.experiment_name}'
+ runner.logger.info(exp_info)
+ if self.every_n_inner_iters(batch_idx, self.interval):
+ tag, log_str = runner.log_processor.get_log_after_iter(
+ runner, batch_idx, 'train')
+ elif (self.end_of_epoch(runner.train_dataloader, batch_idx)
+ and not self.ignore_last):
# `runner.max_iters` may not be divisible by `self.interval`. if
# `self.ignore_last==True`, the log of remaining iterations will
# be recorded (Epoch [4][1000/1007], the logs of 998-1007
# iterations will be recorded).
- self._log_train(runner)
+ tag, log_str = runner.log_processor.get_log_after_iter(
+ runner, batch_idx, 'train')
+ else:
+ return
+ runner.logger.info(log_str)
+ # TODO: Make compatible with visualizer.
+ runner.writer.add_scalars(tag, step=runner.iter + 1)
+
+ def after_val_iter(
+ self,
+ runner,
+ batch_idx: int,
+ data_batch: DATA_BATCH = None,
+ outputs: Optional[Sequence[BaseDataElement]] = None) -> None:
+ """Record validation logs after validation iteration.
+
+ Args:
+ runner (Runner): The runner of the validation process.
+ batch_idx (int): The index of the current batch in the val loop.
+ data_batch (Sequence[Tuple[Any, BaseDataElement]], optional):
+ Data from dataloader. Defaults to None.
+ outputs (sequence, optional): Outputs from model. Defaults to None.
+ """
+ if self.every_n_inner_iters(batch_idx, self.interval):
+ tag, log_str = runner.log_processor.get_log_after_iter(
+ runner, batch_idx, 'val')
+ runner.logger.info(log_str)
+
+ def after_test_iter(
+ self,
+ runner,
+ batch_idx: int,
+ data_batch: DATA_BATCH = None,
+ outputs: Optional[Sequence[BaseDataElement]] = None) -> None:
+ """Record testing logs after iteration.
+
+ Args:
+ runner (Runner): The runner of the testing process.
+ batch_idx (int): The index of the current batch in the test loop.
+ data_batch (Sequence[Tuple[Any, BaseDataElement]], optional):
+ Data from dataloader. Defaults to None.
+ outputs (sequence, optional): Outputs from model. Defaults to None.
+ """
+ if self.every_n_inner_iters(batch_idx, self.interval):
+ tag, log_str = runner.log_processor.get_log_after_iter(
+ runner, batch_idx, 'test')
+ runner.logger.info(log_str)
def after_val_epoch(self, runner) -> None:
- """Record validation logs.
+ """Record validation logs after validation epoch.
Args:
runner (Runner): The runner of the training process.
"""
- self._log_val(runner)
+ tag, log_str = runner.log_processor.get_log_after_epoch(
+ runner, len(runner.val_dataloader), 'val')
+ runner.logger.info(log_str)
+ # TODO: Make compatible with visualizer.
+ runner.writer.add_scalars(tag, step=runner.iter + 1)
+
+ def after_test_epoch(self, runner) -> None:
+ """Record testing logs after test epoch.
+
+ Args:
+ runner (Runner): The runner of the testing process.
+ """
+ tag, log_str = runner.log_processor.get_log_after_epoch(
+ runner, len(runner.test_dataloader), 'test')
+ runner.logger.info(log_str)
def after_run(self, runner) -> None:
"""Copy logs to ``self.out_dir`` if ``self.out_dir is not None``
@@ -237,280 +237,3 @@ class LoggerHook(Hook):
os.remove(local_filepath)
runner.logger.info((f'{local_filepath} was removed due to the '
'`self.keep_local=False`'))
-
- @master_only
- def _log_train(self, runner) -> None:
- """Collect and record training logs which start named with "train/*".
-
- Args:
- runner (Runner): The runner of the training process.
- """
- tag = self._collect_info(runner, 'train')
- # The training log default defines `lr`, `momentum`, `time` and
- # `data_time`. `log_tag` will pop these keys and loop other keys to
- # `log_str`.
- log_tag = copy.deepcopy(tag)
- cur_iter = self._get_iter(runner, inner_iter=True)
- cur_epoch = self._get_epoch(runner, 'train')
-
- # Record learning rate and momentum.
- lr_str_list = []
- momentum_str_list = []
- for key, value in tag.items():
- if key.startswith('lr'):
- log_tag.pop(key)
- lr_str_list.append(f'{key}: {value:.3e}')
- lr_str = ' '.join(lr_str_list)
- for key, value in tag.items():
- if key.startswith('momentum'):
- log_tag.pop(key)
- momentum_str_list.append(f'{key}: {value:.3e}')
- momentum_str = ' '.join(momentum_str_list)
- lr_momentum_str = f'{lr_str} {momentum_str}'
- # by epoch: Epoch [4][100/1000]
- # by iter: Iter [100/100000]
- if self.by_epoch:
- log_str = f'Epoch [{cur_epoch}]' \
- f'[{cur_iter}/{len(runner.cur_dataloader)}]\t'
- else:
- log_str = f'Iter [{cur_iter}/{runner.train_loop.max_iters}]\t'
- log_str += f'{lr_momentum_str}, '
- # Calculate eta time.
- self.time_sec_tot += (tag['time'] * self.interval)
- time_sec_avg = self.time_sec_tot / (runner.iter - self.start_iter + 1)
- eta_sec = time_sec_avg * (
- runner.train_loop.max_iters - runner.iter - 1)
- eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
- log_str += f'eta: {eta_str}, '
- log_str += f'time: {tag["time"]:.3f}, ' \
- f'data_time: {tag["data_time"]:.3f}, '
- # Pop recorded keys
- log_tag.pop('time')
- log_tag.pop('data_time')
- # statistic memory
- if torch.cuda.is_available():
- log_str += f'memory: {self._get_max_memory(runner)}, '
- # Loop left keys to fill `log_str`.
- log_items = []
- for name, val in log_tag.items():
- if isinstance(val, float):
- val = f'{val:.4f}'
- log_items.append(f'{name}: {val}')
- log_str += ', '.join(log_items)
- runner.logger.info(log_str)
- # Write logs to local, tensorboad, and wandb.
- runner.writer.add_scalars(
- tag, step=runner.iter + 1, file_path=self.json_log_path)
-
- @master_only
- def _log_val(self, runner) -> None:
- """Collect and record training logs which start named with "val/*".
-
- Args:
- runner (Runner): The runner of the training process.
- """
- tag = self._collect_info(runner, 'val')
- # Compatible with function `log` https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/hooks/logger/text.py # noqa E501
- eval_iter = len(runner.cur_dataloader)
- cur_iter = self._get_iter(runner)
- cur_epoch = self._get_epoch(runner, 'val')
- # val/test time
- # here 1000 is the length of the val dataloader
- # by epoch: Epoch[val] [4][1000]
- # by iter: Iter[val] [1000]
- if self.by_epoch:
- # runner.epoch += 1 has been done before val workflow
- log_str = f'Epoch(val) [{cur_epoch}][{eval_iter}]\t'
- else:
- log_str = f'Iter(val) [{eval_iter}]\t'
-
- log_items = []
- for name, val in tag.items():
- if isinstance(val, float):
- val = f'{val:.4f}'
- log_items.append(f'{name}: {val}')
- log_str += ', '.join(log_items)
- runner.logger.info(log_str)
- # Write tag.
- runner.writer.add_scalars(
- tag, step=cur_iter, file_path=self.json_log_path)
-
- def _get_window_size(self, runner, window_size: Union[int, str]) \
- -> int:
- """Parse window_size specified in ``self.custom_keys`` to int value.
-
- Args:
- runner (Runner): The runner of the training process.
- window_size (int or str): Smoothing scale of logs.
-
- Returns:
- int: Smoothing window for statistical methods.
- """
- if isinstance(window_size, int):
- assert window_size == self.interval, \
- 'The value of windows size must equal to LoggerHook.interval'
- return window_size
- elif window_size == 'epoch':
- return self._inner_iter + 1
- elif window_size == 'global':
- return runner.iter + 1
- else:
- raise ValueError('window_size should be int, epoch or global, but '
- f'got invalid {window_size}')
-
- def _collect_info(self, runner, mode: str) -> dict:
- """Collect log information to a dict according to mode.
-
- Args:
- runner (Runner): The runner of the training process.
- mode (str): 'train' or 'val', which means the prefix attached by
- runner.
-
- Returns:
- dict: Statistical values of logs.
- """
- tag = OrderedDict()
- log_buffers = runner.message_hub.log_scalars
- mode_log_buffers = OrderedDict()
- # Filter log_buffers which starts with `mode`.
- for prefix_key, log_buffer in log_buffers.items():
- if prefix_key.startswith(mode):
- key = prefix_key.split('/')[-1]
- mode_log_buffers[key] = log_buffer
- # Ensure all metric and lr values are latest.
- for key in mode_log_buffers:
- # Update the latest learning rate and smoothed time logs.
- if key in self.fixed_smooth_keys or key.startswith('loss'):
- tag[key] = mode_log_buffers[key].mean(self.interval)
- else:
- tag[key] = mode_log_buffers[key].current()
- # Update custom keys.
- if mode == 'train':
- for log_key, log_cfg in self.custom_keys.items():
- self._parse_custom_keys(runner, log_key,
- copy.deepcopy(log_cfg),
- mode_log_buffers, tag)
- return tag
-
- def _parse_custom_keys(self, runner, log_key: str, log_cfg: dict,
- log_buffers: OrderedDict, tag: OrderedDict) -> None:
- """Statistics logs in log_buffers according to custom_keys.
-
- Args:
- runner (Runner): The runner of the training process.
- log_key (str): log key specified in ``self.custom_keys``
- log_cfg (dict): A config dict for describing the logging
- statistics method.
- log_buffers (OrderedDict): All logs for the corresponding phase.
- tag (OrderedDict): A dict which defines all statistic values of
- logs.
- """
- if isinstance(log_cfg, list):
- log_names = set()
- for cfg in log_cfg:
- log_name = cfg.get('log_name', None)
- if log_name in log_names:
- raise KeyError(f'{cfg["log_name"]} cannot be redefined in '
- 'log_key')
- if log_name is not None:
- log_names.add(log_name)
- self._parse_custom_keys(runner, log_key, cfg, log_buffers, tag)
- assert len(log_names) == len(log_cfg) - 1, \
- f'{log_key} cannot be overwritten multiple times, please ' \
- f'check only one key does not contain `log_name` in {log_cfg}.'
- elif isinstance(log_cfg, dict):
- if 'window_size' in log_cfg:
- log_cfg['window_size'] = \
- self._get_window_size(runner, log_cfg['window_size'])
- if 'log_name' in log_cfg:
- name = log_cfg.pop('log_name')
- else:
- name = log_key
- tag[name] = log_buffers[log_key].statistics(**log_cfg).item()
- else:
- raise ValueError('The structure of `LoggerHook.custom key` is '
- 'wrong, please make sure the type of each key is '
- 'dict or list.')
-
- def _get_max_memory(self, runner) -> int:
- """Returns the maximum GPU memory occupied by tensors in megabytes (MB)
- for a given device.
-
- Args:
- runner (Runner): The runner of the training process.
-
- Returns:
- The maximum GPU memory occupied by tensors in megabytes for a given
- device.
- """
- device = getattr(runner.model, 'output_device', None)
- mem = torch.cuda.max_memory_allocated(device=device)
- mem_mb = torch.tensor([int(mem) // (1024 * 1024)],
- dtype=torch.int,
- device=device)
- torch.cuda.reset_peak_memory_stats()
- return int(mem_mb.item())
-
- def _check_custom_keys(self) -> None:
- """Check the legality of ``self.custom_keys``.
-
- If ``self.by_epoch==False``, ``window_size`` should not be "epoch". The
- key of ``self.fixed_smooth_keys`` cannot be overwritten.
- """
-
- def _check_window_size(item):
- if not self.by_epoch:
- assert item['window_size'] != 'epoch', \
- 'window_size cannot be epoch if LoggerHook.by_epoch is ' \
- 'False.'
-
- def _check_fixed_keys(key, item):
- if key in self.fixed_smooth_keys:
- assert 'log_name' in item, f'{key} cannot be overwritten by ' \
- 'custom keys!'
-
- for key, value in self.custom_keys.items():
- if isinstance(value, Sequence):
- [(_check_window_size(item), _check_fixed_keys(key, item))
- for item in value]
-
- else:
- _check_window_size(value)
- _check_fixed_keys(key, value)
-
- def _get_epoch(self, runner, mode: str) -> int:
- """Get epoch according to mode.
-
- Args:
- runner (Runner): The runner of the training process.
- mode (str): Train or val.
-
- Returns:
- int: The current epoch.
- """
- if mode == 'train':
- epoch = runner.epoch + 1
- elif mode == 'val':
- # normal val mode
- # runner.epoch += 1 has been done before val workflow
- epoch = runner.epoch
- else:
- raise ValueError(f"runner mode should be 'train' or 'val', "
- f'but got {runner.mode}')
- return epoch
-
- def _get_iter(self, runner, inner_iter=False) -> int:
- """Get the current training iteration step.
- Args:
- runner (Runner): The runner of the training process.
- inner_iter (bool): Whether to return the inner iter of an epoch.
- Defaults to False.
-
- Returns:
- int: The current global iter or inner iter.
- """
- if self.by_epoch and inner_iter:
- current_iter = self._inner_iter + 1
- else:
- current_iter = runner.iter + 1
- return current_iter
diff --git a/mmengine/hooks/optimizer_hook.py b/mmengine/hooks/optimizer_hook.py
index ff33b54a..61870de8 100644
--- a/mmengine/hooks/optimizer_hook.py
+++ b/mmengine/hooks/optimizer_hook.py
@@ -86,6 +86,9 @@ class OptimizerHook(Hook):
we keep ``outputs`` here. Defaults to None.
"""
runner.optimizer.zero_grad()
+ runner.message_hub.update_scalar(
+ 'train/lr', runner.optimizer.param_groups[0]['lr'])
+
if self.detect_anomalous_params:
self.detect_anomalous_parameters(runner.outputs['loss'], runner)
runner.outputs['loss'].backward()
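The new `update_scalar` call above is why the learning rate later shows up in log strings: anything pushed under a `train/` prefix is picked up when logs for that mode are collected. A toy message hub, assumed for illustration only (not the mmengine `MessageHub` API), makes the flow concrete:

```python
class MiniMessageHub:
    """Hypothetical stand-in: appends each scalar to a per-key history."""

    def __init__(self):
        self.log_scalars = {}

    def update_scalar(self, key, value):
        self.log_scalars.setdefault(key, []).append(value)

hub = MiniMessageHub()
for lr in (0.1, 0.01):
    hub.update_scalar('train/lr', lr)  # what the hook does once per iteration
print(hub.log_scalars['train/lr'])     # [0.1, 0.01]
```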
diff --git a/mmengine/logging/__init__.py b/mmengine/logging/__init__.py
index ba5533c2..eeac7ff1 100644
--- a/mmengine/logging/__init__.py
+++ b/mmengine/logging/__init__.py
@@ -1,6 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .history_buffer import HistoryBuffer
+from .log_processor import LogProcessor
from .logger import MMLogger, print_log
from .message_hub import MessageHub
-__all__ = ['HistoryBuffer', 'MessageHub', 'MMLogger', 'print_log']
+__all__ = [
+ 'HistoryBuffer', 'MessageHub', 'MMLogger', 'print_log', 'LogProcessor'
+]
diff --git a/mmengine/logging/log_processor.py b/mmengine/logging/log_processor.py
new file mode 100644
index 00000000..cb97286c
--- /dev/null
+++ b/mmengine/logging/log_processor.py
@@ -0,0 +1,409 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+import datetime
+from collections import OrderedDict
+from typing import List, Optional, Tuple
+
+import torch
+
+
+class LogProcessor:
+ """A log processor used to format log information collected from
+ ``runner.message_hub.log_scalars``.
+
+ A ``LogProcessor`` instance is built by the runner and formats
+ ``runner.message_hub.log_scalars`` to ``tag`` and ``log_str``, which can
+ be directly used by ``LoggerHook`` and ``MMLogger``. Besides, the
+ ``custom_cfg`` argument of the constructor can control the statistics
+ method of logs.
+
+ Args:
+ window_size (int): The default smoothing interval. Defaults to 10.
+ by_epoch (bool): Whether to format logs with epoch style. Defaults to
+ True.
+ custom_cfg (list[dict], optional): Contains multiple log config dict,
+ in which key means the data source name of log and value means the
+ statistic method and corresponding arguments used to count the
+ data source. Defaults to None.
+
+ - If custom_cfg is None, all logs will be formatted via default
+ methods, such as smoothing loss by the default window_size. If
+ custom_cfg is defined as a list of config dicts, for example:
+ [dict(data_src='loss', method_name='mean', log_name='global_loss',
+ window_size='global')]. It means the log item ``loss`` will be
+ counted as global mean and additionally logged as ``global_loss``
+ (defined by ``log_name``). If ``log_name`` is not defined in
+ config dict, the original logged key will be overwritten.
+
+ - The original log item cannot be overwritten twice. Here is
+ an error example:
+ [dict(data_src='loss', method_name='mean', window_size='global'),
+ dict(data_src='loss', method_name='mean', window_size='epoch')].
+ Both log config dict in custom_cfg do not have ``log_name`` key,
+ which means the loss item will be overwritten twice.
+
+ - For those statistic methods with the ``window_size`` argument,
+ if ``by_epoch`` is set to False, ``window_size`` should not be
+ ``'epoch'``, since logs cannot be counted by epoch in that case.
+
+ Examples:
+ >>> # `log_name` is defined, `loss_large_window` will be an additional
+ >>> # record.
+ >>> log_processor = dict(
+ >>> window_size=10,
+ >>> by_epoch=True,
+ >>> custom_cfg=[dict(data_src='loss',
+ >>> log_name='loss_large_window',
+ >>> method_name='mean',
+ >>> window_size=100)])
+ >>> # `log_name` is not defined. `loss` will be overwritten.
+ >>> log_processor = dict(
+ >>> window_size=10,
+ >>> by_epoch=True,
+ >>> custom_cfg=[dict(data_src='loss',
+ >>> method_name='mean',
+ >>> window_size=100)])
+ >>> # Record loss with different statistics methods.
+ >>> log_processor = dict(
+ >>> window_size=10,
+ >>> by_epoch=True,
+ >>> custom_cfg=[dict(data_src='loss',
+ >>> log_name='loss_large_window',
+ >>> method_name='mean',
+ >>> window_size=100),
+ >>> dict(data_src='loss',
+ >>> method_name='mean',
+ >>> window_size=100)])
+ >>> # Overwrite loss item twice will raise an error.
+ >>> log_processor = dict(
+ >>> window_size=10,
+ >>> by_epoch=True,
+ >>> custom_cfg=[dict(data_src='loss',
+ >>> method_name='mean',
+ >>> window_size=100),
+ >>> dict(data_src='loss',
+ >>> method_name='max',
+ >>> window_size=100)])
+ AssertionError
+ """
+
+ def __init__(self,
+ window_size=10,
+ by_epoch=True,
+ custom_cfg: Optional[List[dict]] = None):
+ self.window_size = window_size
+ self.by_epoch = by_epoch
+ self.custom_cfg = custom_cfg if custom_cfg else []
+ self._check_custom_cfg()
+
+ def get_log_after_iter(self, runner, batch_idx: int,
+ mode: str) -> Tuple[dict, str]:
+ """Format log string after training, validation or testing epoch.
+
+ Args:
+ runner (Runner): The runner of training phase.
+ batch_idx (int): The index of the current batch in the current
+ loop.
+ mode (str): Current mode of runner, train, test or val.
+
+ Returns:
+ Tuple(dict, str): Formatted log dict/string which will be
+ recorded by :obj:`runner.message_hub` and :obj:`runner.visualizer`.
+ """
+ assert mode in ['train', 'test', 'val']
+ current_loop = self._get_cur_loop(runner, mode)
+ cur_iter = self._get_iter(runner, batch_idx=batch_idx)
+ # Overwrite ``window_size`` defined in ``custom_cfg`` to int value.
+ custom_cfg_copy = self._parse_windows_size(runner, batch_idx)
+ # tag is used to write log information to different backends.
+ tag = self._collect_scalars(custom_cfg_copy, runner, mode)
+ # `log_tag` will pop 'lr' and loop other keys to `log_str`.
+ log_tag = copy.deepcopy(tag)
+ # Record learning rate.
+ lr_str_list = []
+ for key, value in tag.items():
+ if key.startswith('lr'):
+ log_tag.pop(key)
+ lr_str_list.append(f'{key}: {value:.3e}')
+ lr_str = ' '.join(lr_str_list)
+ # Format log header.
+ # by_epoch == True
+ # train/val: Epoch [5][5/10] ...
+ # test: Epoch [5/10]
+ # by_epoch == False
+ # train: Iter [5/10000] ... (divided by `max_iters`)
+ # val/test: Iter [5/2000] ... (divided by length of dataloader)
+ if self.by_epoch:
+ if mode in ['train', 'val']:
+ cur_epoch = self._get_epoch(runner, mode)
+ log_str = (f'Epoch({mode}) [{cur_epoch}]'
+ f'[{cur_iter}/{len(current_loop.dataloader)}] ')
+ else:
+ log_str = (f'Epoch({mode}) '
+ f'[{cur_iter}/{len(current_loop.dataloader)}] ')
+ else:
+ if mode == 'train':
+ log_str = (f'Iter({mode}) '
+ f'[{cur_iter}/{runner.train_loop.max_iters}] ')
+ else:
+ log_str = (f'Iter({mode}) [{batch_idx+1}'
+ f'/{len(current_loop.dataloader)}] ')
+ # Concatenate lr string with log header.
+ log_str += f'{lr_str} '
+ # If IterTimerHook used in runner, eta, time, and data_time should be
+ # recorded.
+ if (all(item in tag for item in ['time', 'data_time'])
+ and 'eta' in runner.message_hub.runtime_info):
+ eta = runner.message_hub.get_info('eta')
+ eta_str = str(datetime.timedelta(seconds=int(eta)))
+ log_str += f'eta: {eta_str} '
+ log_str += (f'time: {tag["time"]:.3f} '
+ f'data_time: {tag["data_time"]:.3f} ')
+ # Pop recorded keys
+ log_tag.pop('time')
+ log_tag.pop('data_time')
+
+ # If cuda is available, the max memory occupied should be calculated.
+ if torch.cuda.is_available():
+ log_str += f'memory: {self._get_max_memory(runner)} '
+ # Loop left keys to fill `log_str`.
+ if mode in ('train', 'val'):
+ log_items = []
+ for name, val in log_tag.items():
+ if mode == 'val' and not name.startswith('loss'):
+ continue
+ if isinstance(val, float):
+ val = f'{val:.4f}'
+ log_items.append(f'{name}: {val}')
+ log_str += ' '.join(log_items)
+ return tag, log_str
+
+ def get_log_after_epoch(self, runner, batch_idx: int,
+ mode: str) -> Tuple[dict, str]:
+ """Format log string after validation or testing epoch.
+
+ Args:
+ runner (Runner): The runner of training phase.
+ batch_idx (int): The index of the current batch in the current
+ loop.
+ mode (str): Current mode of runner.
+
+ Returns:
+ Tuple(dict, str): Formatted log dict/string which will be
+ recorded by :obj:`runner.message_hub` and :obj:`runner.visualizer`.
+ """
+ assert mode in [
+ 'test', 'val'
+ ], ('`get_log_after_epoch` only accepts val or test mode, but got '
+ f'{mode}')
+ cur_loop = self._get_cur_loop(runner, mode)
+ dataloader_len = len(cur_loop.dataloader)
+
+ custom_cfg_copy = self._parse_windows_size(runner, batch_idx)
+ # tag is used to write log information to different backends.
+ tag = self._collect_scalars(custom_cfg_copy, runner, mode)
+ # validation log string needs cur epoch/iteration and max
+ # epochs/iterations. test log string only needs length of test
+ # dataloader.
+ cur_iter = self._get_iter(runner, batch_idx)
+ if self.by_epoch:
+ if mode == 'val':
+ cur_epoch = self._get_epoch(runner, mode)
+ log_str = (f'Epoch({mode}) [{cur_epoch}][{dataloader_len}/'
+ f'{dataloader_len}] ')
+ else:
+ log_str = (
+ f'Epoch({mode}) [{dataloader_len}/{dataloader_len}] ')
+
+ else:
+ if mode == 'train':
+ log_str = (f'Iter({mode}) [{cur_iter}/'
+ f'{runner.train_loop.max_iters}] ')
+ else:
+ log_str = (
+ f'Iter({mode}) [{dataloader_len}/{dataloader_len}] ')
+ log_items = []
+ for name, val in tag.items():
+ if name in ('time', 'data_time'):
+ continue
+ if isinstance(val, float):
+ val = f'{val:.4f}'
+ log_items.append(f'{name}: {val}')
+ log_str += ' '.join(log_items)
+ return tag, log_str
+
+ def _collect_scalars(self, custom_cfg: List[dict], runner,
+ mode: str) -> dict:
+ """Collect log information to compose a dict according to mode.
+
+ Args:
+ custom_cfg (List[dict]): A copy of ``self.custom_cfg`` with int
+ ``window_size``.
+ runner (Runner): The runner of the training process.
+            mode (str): 'train', 'val' or 'test', the prefix attached by
+                the runner to each log name.
+
+ Returns:
+ dict: Statistical values of logs.
+ """
+ tag = OrderedDict()
+ # history_scalars of train/val/test phase.
+ history_scalars = runner.message_hub.log_scalars
+        # history_scalars of the corresponding mode.
+        mode_history_scalars = OrderedDict()
+        # Extract log scalars of the current mode and strip the mode
+        # prefix before storing them in `mode_history_scalars`.
+ for prefix_key, log_buffer in history_scalars.items():
+ if prefix_key.startswith(mode):
+ key = prefix_key.split('/')[-1]
+ mode_history_scalars[key] = log_buffer
+ for key in mode_history_scalars:
+            # Loss logs are smoothed over `window_size`; other logs use
+            # the latest value.
+ if key.startswith('loss'):
+ tag[key] = mode_history_scalars[key].mean(self.window_size)
+ else:
+ # Default statistic method is current.
+ tag[key] = mode_history_scalars[key].current()
+ # Update custom keys.
+ for log_cfg in custom_cfg:
+ data_src = log_cfg.pop('data_src')
+ if 'log_name' in log_cfg:
+ log_name = log_cfg.pop('log_name')
+ else:
+ log_name = data_src
+            # A log item in custom_cfg can only exist in train or val
+            # mode.
+ if data_src in mode_history_scalars:
+ tag[log_name] = mode_history_scalars[data_src].statistics(
+ **log_cfg)
+ return tag
+
+ def _check_custom_cfg(self) -> None:
+ """Check the legality of ``self.custom_cfg``."""
+
+ def _check_window_size():
+ for log_cfg in self.custom_cfg:
+ if not self.by_epoch:
+ assert log_cfg['window_size'] != 'epoch', \
+ 'window_size cannot be epoch if LoggerHook.by_epoch' \
+ ' is False.'
+
+ def _check_repeated_log_name():
+ check_dict = dict()
+ # The `log_name` of the same data_src should not be repeated.
+ # If `log_name` is not specified, `data_src` will be overwritten.
+ # But only allowed to be overwritten once.
+ for log_cfg in self.custom_cfg:
+ assert 'data_src' in log_cfg
+ data_src = log_cfg['data_src']
+ log_name = log_cfg.get('log_name', data_src)
+ check_dict.setdefault(data_src,
+ dict(log_names=set(), log_counts=0))
+ check_dict[data_src]['log_names'].add(log_name)
+ check_dict[data_src]['log_counts'] += 1
+ assert (len(
+ check_dict[data_src]
+ ['log_names']) == check_dict[data_src]['log_counts']), (
+                    f'If you want to compute multiple statistics of '
+                    f'{data_src}, please make sure each `log_name` is '
+                    f'unique and {data_src} will not be overwritten '
+                    'twice. See the docstring of `LogProcessor` for more '
+                    'information.')
+
+ _check_repeated_log_name()
+ _check_window_size()
+
+ def _parse_windows_size(self, runner, batch_idx: int) -> list:
+ """Parse window_size defined in custom_cfg to int value.
+
+ Args:
+ runner (Runner): The runner of the training process.
+ batch_idx (int): The iteration index of current dataloader.
+ """
+ custom_cfg_copy = copy.deepcopy(self.custom_cfg)
+ for log_cfg in custom_cfg_copy:
+ window_size = log_cfg.get('window_size', None)
+ if window_size is None or isinstance(window_size, int):
+ continue
+ elif window_size == 'epoch':
+ log_cfg['window_size'] = batch_idx + 1
+ elif window_size == 'global':
+ log_cfg['window_size'] = runner.iter + 1
+ else:
+                raise TypeError(
+                    "window_size should be int, 'epoch' or 'global', "
+                    f'but got {window_size}')
+ return custom_cfg_copy
+
+ def _get_max_memory(self, runner) -> int:
+ """Returns the maximum GPU memory occupied by tensors in megabytes (MB)
+ for a given device.
+
+ Args:
+ runner (Runner): The runner of the training process.
+
+ Returns:
+            int: The maximum GPU memory occupied by tensors in megabytes
+            for a given device.
+ """
+ device = getattr(runner.model, 'output_device', None)
+ mem = torch.cuda.max_memory_allocated(device=device)
+ mem_mb = torch.tensor([int(mem) // (1024 * 1024)],
+ dtype=torch.int,
+ device=device)
+ torch.cuda.reset_peak_memory_stats()
+ return int(mem_mb.item())
+
+ def _get_iter(self, runner, batch_idx: int = None) -> int:
+ """Get current training iteration step.
+
+ Args:
+ runner (Runner): The runner of the training process.
+            batch_idx (int, optional): The iteration index of the current
+                dataloader. Defaults to None.
+
+ Returns:
+ int: The current global iter or inner iter.
+ """
+        if self.by_epoch and batch_idx is not None:
+ current_iter = batch_idx + 1
+ else:
+ current_iter = runner.iter + 1
+ return current_iter
+
+ def _get_epoch(self, runner, mode: str) -> int:
+ """Get current epoch according to mode.
+
+ Args:
+ runner (Runner): The runner of the training/validation process.
+ mode (str): Current mode of runner, "train" or "val".
+
+ Returns:
+ int: The current epoch.
+ """
+ if mode == 'train':
+ epoch = runner.epoch + 1
+ elif mode == 'val':
+ # normal val mode
+ # runner.epoch += 1 has been done before validation
+ epoch = runner.epoch
+ else:
+ raise ValueError(
+ f"runner mode should be 'train' or 'val', but got {mode}")
+ return epoch
+
+ def _get_cur_loop(self, runner, mode: str):
+ """Get current loop according to mode.
+
+ Args:
+ runner (Runner): The runner of the training/validation/testing
+ process.
+            mode (str): Current mode of runner, "train", "val" or "test".
+
+ Returns:
+ BaseLoop: Current loop of runner.
+ """
+        # Returning the loop type hint here would cause a circular import.
+ if mode == 'train':
+ return runner.train_loop
+ elif mode == 'val':
+ return runner.val_loop
+ else:
+ return runner.test_loop
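
The window-size resolution implemented in `_parse_windows_size` above can be sketched as a standalone function; `parse_window_sizes` below is a hypothetical helper for illustration (not part of MMEngine), assuming only a batch index and a global iteration counter as inputs.

```python
import copy

def parse_window_sizes(custom_cfg, batch_idx, global_iter):
    # Resolve string window_size values ('epoch'/'global') to concrete
    # ints, mirroring the branching in `_parse_windows_size` above.
    resolved = copy.deepcopy(custom_cfg)
    for log_cfg in resolved:
        window_size = log_cfg.get('window_size', None)
        if window_size is None or isinstance(window_size, int):
            continue  # already concrete, nothing to resolve
        elif window_size == 'epoch':
            # Statistics over the iterations finished in the current epoch.
            log_cfg['window_size'] = batch_idx + 1
        elif window_size == 'global':
            # Statistics over every iteration since training started.
            log_cfg['window_size'] = global_iter + 1
        else:
            raise TypeError(
                "window_size should be int, 'epoch' or 'global', "
                f'but got {window_size}')
    return resolved

cfg = [dict(data_src='loss', window_size='epoch'),
       dict(data_src='lr', window_size='global'),
       dict(data_src='time', window_size=100)]
print(parse_window_sizes(cfg, batch_idx=4, global_iter=99))
```

Deep-copying keeps the caller's configuration untouched, which is why the method above can be called once per logging interval.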
diff --git a/mmengine/logging/logger.py b/mmengine/logging/logger.py
index 3ae26524..6066449f 100644
--- a/mmengine/logging/logger.py
+++ b/mmengine/logging/logger.py
@@ -32,15 +32,15 @@ class MMFormatter(logging.Formatter):
info_prefix = self._get_prefix('INFO', color)
debug_prefix = self._get_prefix('DEBUG', color)
# Config output format.
- self.err_format = f'%(asctime)s - %(name)s - {error_prefix} - ' \
- f'%(pathname)s - %(funcName)s - %(lineno)d - ' \
- '%(message)s'
- self.warn_format = f'%(asctime)s - %(name)s - {warn_prefix} - %(' \
- 'message)s'
- self.info_format = f'%(asctime)s - %(name)s - {info_prefix} - %(' \
- 'message)s'
- self.debug_format = f'%(asctime)s - %(name)s - {debug_prefix} - %(' \
- 'message)s'
+ self.err_format = (f'%(asctime)s - %(name)s - {error_prefix} - '
+ '%(pathname)s - %(funcName)s - %(lineno)d - '
+ '%(message)s')
+ self.warn_format = (f'%(asctime)s - %(name)s - {warn_prefix} - %('
+ 'message)s')
+ self.info_format = (f'%(asctime)s - %(name)s - {info_prefix} - %('
+ 'message)s')
+ self.debug_format = (f'%(asctime)s - %(name)s - {debug_prefix} - %('
+ 'message)s')
def _get_prefix(self, level: str, color: bool) -> str:
"""Get the prefix of the target log level.
diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py
index a1b11511..e43c0b08 100644
--- a/mmengine/runner/runner.py
+++ b/mmengine/runner/runner.py
@@ -25,7 +25,7 @@ from mmengine.dist import (broadcast, get_dist_info, init_dist, master_only,
sync_random_seed)
from mmengine.evaluator import Evaluator
from mmengine.hooks import Hook
-from mmengine.logging import MessageHub, MMLogger
+from mmengine.logging import LogProcessor, MessageHub, MMLogger
from mmengine.model import is_model_wrapper
from mmengine.optim import _ParamScheduler, build_optimizer
from mmengine.registry import (DATA_SAMPLERS, DATASETS, HOOKS, LOOPS,
@@ -127,6 +127,8 @@ class Runner:
non-distributed environment will be launched.
env_cfg (dict): A dict used for setting environment. Defaults to
dict(dist_cfg=dict(backend='nccl')).
+ log_processor (dict, optional): A processor to format logs. Defaults to
+ None.
log_level (int or str): The log level of MMLogger handlers.
Defaults to 'INFO'.
writer (ComposedWriter or dict, optional): A ComposedWriter object or a
@@ -184,6 +186,7 @@ class Runner:
param_scheduler=dict(type='ParamSchedulerHook')),
launcher='none',
env_cfg=dict(dist_cfg=dict(backend='nccl')),
+ log_processor=dict(window_size=20),
writer=dict(
name='composed_writer',
writers=[dict(type='LocalWriter', save_dir='temp_dir')])
@@ -218,6 +221,7 @@ class Runner:
launcher: str = 'none',
env_cfg: Dict = dict(dist_cfg=dict(backend='nccl')),
log_level: str = 'INFO',
+ log_processor: Optional[Dict] = None,
writer: Optional[Union[ComposedWriter, Dict]] = None,
default_scope: Optional[str] = None,
randomness: Dict = dict(seed=None),
@@ -310,6 +314,10 @@ class Runner:
else:
self._experiment_name = self.timestamp
+ log_processor = dict() if log_processor is None else log_processor
+ self.log_processor = LogProcessor(**log_processor)
+        # Since `get_instance` could return any subclass of ManagerMixin,
+        # the corresponding attribute needs a type hint.
self.logger = self.build_logger(log_level=log_level)
# Build `message_hub` for communication among components.
# `message_hub` can store log scalars (loss, learning rate) and
@@ -385,6 +393,7 @@ class Runner:
resume=cfg.get('resume', False),
launcher=cfg.get('launcher', 'none'),
env_cfg=cfg.get('env_cfg'), # type: ignore
+ log_processor=cfg.get('log_processor'),
log_level=cfg.get('log_level', 'INFO'),
writer=cfg.get('writer'),
default_scope=cfg.get('default_scope'),
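
The duplicate-name rule that `LogProcessor._check_custom_cfg` enforces (every entry for the same `data_src` must resolve to a distinct `log_name`, since `log_name` falls back to `data_src` when omitted) can be sketched standalone; `check_custom_cfg` below is a hypothetical helper, not MMEngine API.

```python
def check_custom_cfg(custom_cfg):
    # Every entry for the same data_src must resolve to a distinct
    # log_name, otherwise one statistic would silently overwrite another.
    seen = {}
    for log_cfg in custom_cfg:
        data_src = log_cfg['data_src']
        log_name = log_cfg.get('log_name', data_src)
        names = seen.setdefault(data_src, set())
        assert log_name not in names, (
            f'log_name {log_name!r} for {data_src!r} is duplicated and '
            'would be overwritten twice')
        names.add(log_name)

# Legal: the two statistics of 'loss' get distinct names.
check_custom_cfg([dict(data_src='loss', log_name='loss_min'),
                  dict(data_src='loss', log_name='loss_max')])
```

With `[dict(data_src='loss'), dict(data_src='loss')]` both entries fall back to the name `loss`, so the check raises an AssertionError, matching the test cases in `test_check_custom_cfg`.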
diff --git a/tests/test_hook/test_hook.py b/tests/test_hook/test_hook.py
index db80ed4a..771c54f6 100644
--- a/tests/test_hook/test_hook.py
+++ b/tests/test_hook/test_hook.py
@@ -157,18 +157,17 @@ class TestHook:
def test_end_of_epoch(self):
hook = Hook()
- runner = Mock()
# last inner iter
batch_idx = 1
- runner.cur_dataloader.__len__ = Mock(return_value=2)
- runner.cur_dataloader.__len__ = Mock(return_value=2)
- return_val = hook.end_of_epoch(runner, batch_idx)
+ dataloader = Mock()
+ dataloader.__len__ = Mock(return_value=2)
+ return_val = hook.end_of_epoch(dataloader, batch_idx)
assert return_val
# not the last inner iter
batch_idx = 0
- return_val = hook.end_of_epoch(runner, batch_idx)
+ return_val = hook.end_of_epoch(dataloader, batch_idx)
assert not return_val
def test_is_last_train_epoch(self):
diff --git a/tests/test_hook/test_iter_timer_hook.py b/tests/test_hook/test_iter_timer_hook.py
index af149f2f..8d3dfb9d 100644
--- a/tests/test_hook/test_iter_timer_hook.py
+++ b/tests/test_hook/test_iter_timer_hook.py
@@ -1,29 +1,70 @@
# Copyright (c) OpenMMLab. All rights reserved.
-from unittest.mock import Mock
+from unittest import TestCase
+from unittest.mock import MagicMock, Mock, patch
from mmengine.hooks import IterTimerHook
+from mmengine.logging import MessageHub
-class TestIterTimerHook:
+def time_patch():
+ if not hasattr(time_patch, 'time'):
+ time_patch.time = 0
+ else:
+ time_patch.time += 1
+ return time_patch.time
+
+
+class TestIterTimerHook(TestCase):
+
+ def setUp(self) -> None:
+ self.hook = IterTimerHook()
+
+ def test_init(self):
+ assert self.hook.time_sec_tot == 0
+ assert self.hook.start_iter == 0
+
+ def test_before_run(self):
+ runner = MagicMock()
+ runner.iter = 1
+ self.hook.before_run(runner)
+ assert self.hook.start_iter == 1
def test_before_epoch(self):
- hook = IterTimerHook()
runner = Mock()
- hook._before_epoch(runner)
- assert isinstance(hook.t, float)
+ self.hook._before_epoch(runner)
+ assert isinstance(self.hook.t, float)
+ @patch('time.time', MagicMock(return_value=1))
def test_before_iter(self):
- hook = IterTimerHook()
- runner = Mock()
+ runner = MagicMock()
runner.log_buffer = dict()
- hook._before_epoch(runner)
- hook._before_iter(runner, 0)
- runner.message_hub.update_scalar.assert_called()
+ self.hook._before_epoch(runner)
+ for mode in ('train', 'val', 'test'):
+ self.hook._before_iter(runner, batch_idx=1, mode=mode)
+ runner.message_hub.update_scalar.assert_called_with(
+ f'{mode}/data_time', 0)
+ @patch('time.time', time_patch)
def test_after_iter(self):
- hook = IterTimerHook()
- runner = Mock()
+ runner = MagicMock()
runner.log_buffer = dict()
- hook._before_epoch(runner)
- hook._after_iter(runner, 0)
+ runner.log_processor.window_size = 10
+ runner.train_loop.max_iters = 100
+ runner.iter = 0
+ runner.test_loop.dataloader = [0] * 20
+ runner.val_loop.dataloader = [0] * 20
+ self.hook._before_epoch(runner)
+ self.hook.before_run(runner)
+ self.hook._after_iter(runner, batch_idx=1)
runner.message_hub.update_scalar.assert_called()
+ runner.message_hub.get_log.assert_not_called()
+ runner.message_hub.update_info.assert_not_called()
+ runner.message_hub = MessageHub.get_instance('test_iter_timer_hook')
+ runner.iter = 9
+ # eta = (100 - 10) / 1
+ self.hook._after_iter(runner, batch_idx=89)
+ assert runner.message_hub.get_info('eta') == 90
+ self.hook._after_iter(runner, batch_idx=9, mode='val')
+ assert runner.message_hub.get_info('eta') == 10
+ self.hook._after_iter(runner, batch_idx=19, mode='test')
+ assert runner.message_hub.get_info('eta') == 0
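
The `eta = (100 - 10) / 1` comment in the test above follows from the usual ETA formula: average time per finished iteration times the iterations still to run. A minimal sketch, with `eta_seconds` as a hypothetical name for the arithmetic `IterTimerHook` performs:

```python
def eta_seconds(time_sec_tot, cur_iter, start_iter, max_iters):
    # Average time per finished iteration times the number of iterations
    # still to run.
    time_avg = time_sec_tot / (cur_iter + 1 - start_iter)
    return time_avg * (max_iters - cur_iter - 1)

# With 10 iterations done in 10 seconds out of 100 total, 90 remain:
print(eta_seconds(time_sec_tot=10, cur_iter=9, start_iter=0, max_iters=100))
# prints 90.0
```

Subtracting `start_iter` matters for resumed runs: only iterations timed in the current run should contribute to the average.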
diff --git a/tests/test_hook/test_logger_hook.py b/tests/test_hook/test_logger_hook.py
index cac2e45b..3caed5dd 100644
--- a/tests/test_hook/test_logger_hook.py
+++ b/tests/test_hook/test_logger_hook.py
@@ -1,13 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
-import datetime
-import logging
import os.path as osp
-import sys
-from collections import OrderedDict
-from unittest.mock import MagicMock, patch
+from unittest.mock import MagicMock
import pytest
-import torch
from mmengine.fileio.file_client import HardDiskBackend
from mmengine.hooks import LoggerHook
@@ -17,11 +12,8 @@ class TestLoggerHook:
def test_init(self):
logger_hook = LoggerHook(out_dir='tmp.txt')
- assert logger_hook.by_epoch
assert logger_hook.interval == 10
- assert not logger_hook.custom_keys
assert logger_hook.ignore_last
- assert logger_hook.time_sec_tot == 0
assert logger_hook.interval_exp_name == 1000
assert logger_hook.out_suffix == ('.log.json', '.log', '.py')
assert logger_hook.keep_local
@@ -30,22 +22,7 @@ class TestLoggerHook:
# out_dir should be None or string or tuple of string.
with pytest.raises(TypeError):
LoggerHook(out_dir=1)
- # time cannot be overwritten.
- with pytest.raises(AssertionError):
- LoggerHook(custom_keys=dict(time=dict(method='max')))
- LoggerHook(
- custom_keys=dict(time=[
- dict(method='max', log_name='time_max'),
- dict(method='min', log_name='time_min')
- ]))
- # Epoch window_size cannot be used when `LoggerHook.by_epoch=False`
- with pytest.raises(AssertionError):
- LoggerHook(
- by_epoch=False,
- custom_keys=dict(
- time=dict(
- method='max', log_name='time_max',
- window_size='epoch')))
+
with pytest.raises(ValueError):
LoggerHook(file_client_args=dict(enable_mc=True))
@@ -60,20 +37,23 @@ class TestLoggerHook:
assert logger_hook.out_dir == osp.join('out_dir', 'work_dir')
assert logger_hook.json_log_path == osp.join('work_dir',
'timestamp.log.json')
- assert logger_hook.start_iter == runner.iter
runner.writer.add_params.assert_called()
def test_after_run(self, tmp_path):
+        # Prepare out_dir and work_dir.
out_dir = tmp_path / 'out_dir'
out_dir.mkdir()
work_dir = tmp_path / 'work_dir'
work_dir.mkdir()
work_dir_json = work_dir / 'tmp.log.json'
- json_f = open(work_dir_json, 'w')
- json_f.close()
runner = MagicMock()
runner.work_dir = work_dir
-
+ # Test without out_dir.
+ logger_hook = LoggerHook()
+ logger_hook.after_run(runner)
+ # Test with out_dir and make sure json file has been moved to out_dir.
+ json_f = open(work_dir_json, 'w')
+ json_f.close()
logger_hook = LoggerHook(out_dir=str(tmp_path), keep_local=False)
logger_hook.out_dir = str(out_dir)
logger_hook.after_run(runner)
@@ -84,274 +64,83 @@ class TestLoggerHook:
def test_after_train_iter(self):
# Test LoggerHook by iter.
runner = MagicMock()
- runner.iter = 10
- batch_idx = 5
- logger_hook = LoggerHook(by_epoch=False)
- logger_hook._log_train = MagicMock()
- logger_hook.after_train_iter(runner, batch_idx=batch_idx)
+ runner.log_processor.get_log_after_iter = MagicMock(
+ return_value=(dict(), 'log_str'))
+ logger_hook = LoggerHook()
+ logger_hook.after_train_iter(runner, batch_idx=5)
# `cur_iter=10+1`, which cannot be exact division by
# `logger_hook.interval`
- logger_hook._log_train.assert_not_called()
- runner.iter = 9
- logger_hook.after_train_iter(runner, batch_idx=batch_idx)
- logger_hook._log_train.assert_called()
+ runner.log_processor.get_log_after_iter.assert_not_called()
+ logger_hook.after_train_iter(runner, batch_idx=9)
+ runner.log_processor.get_log_after_iter.assert_called()
# Test LoggerHook by epoch.
- logger_hook = LoggerHook(by_epoch=True)
- logger_hook._log_train = MagicMock()
- # Only `runner.inner_iter` will work.
- runner.iter = 9
- batch_idx = 10
- logger_hook.after_train_iter(runner, batch_idx=batch_idx)
- logger_hook._log_train.assert_not_called()
- batch_idx = 9
- logger_hook.after_train_iter(runner, batch_idx=batch_idx)
- logger_hook._log_train.assert_called()
+ logger_hook = LoggerHook()
+ runner = MagicMock()
+ runner.log_processor.get_log_after_iter = MagicMock(
+ return_value=(dict(), 'log_str'))
+ # Only `batch_idx` will work.
+ logger_hook.after_train_iter(runner, batch_idx=10)
+ runner.log_processor.get_log_after_iter.assert_not_called()
+ logger_hook.after_train_iter(runner, batch_idx=9)
+ runner.log_processor.get_log_after_iter.assert_called()
# Test end of the epoch.
- logger_hook = LoggerHook(by_epoch=True, ignore_last=False)
- logger_hook._log_train = MagicMock()
- runner.cur_dataloader = [0] * 5
- batch_idx = 4
- logger_hook.after_train_iter(runner, batch_idx=batch_idx)
- logger_hook._log_train.assert_called()
+ runner = MagicMock()
+ runner.log_processor.get_log_after_iter = MagicMock(
+ return_value=(dict(), 'log_str'))
+ logger_hook = LoggerHook(ignore_last=False)
+ runner.train_dataloader = [0] * 5
+ logger_hook.after_train_iter(runner, batch_idx=4)
+ runner.log_processor.get_log_after_iter.assert_called()
# Test print exp_name
+ runner = MagicMock()
+ runner.log_processor.get_log_after_iter = MagicMock(
+ return_value=(dict(), 'log_str'))
runner.meta = dict(exp_name='retinanet')
- logger_hook = LoggerHook()
runner.logger = MagicMock()
- logger_hook._log_train = MagicMock()
- logger_hook.after_train_iter(runner, batch_idx=batch_idx)
- runner.logger.info.assert_called_with(
- f'Exp name: {runner.meta["exp_name"]}')
+ logger_hook = LoggerHook()
+ logger_hook.after_train_iter(runner, batch_idx=999)
+ runner.logger.info.assert_called()
def test_after_val_epoch(self):
logger_hook = LoggerHook()
runner = MagicMock()
- logger_hook._log_val = MagicMock()
+ runner.log_processor.get_log_after_epoch = MagicMock(
+ return_value=(dict(), 'string'))
logger_hook.after_val_epoch(runner)
- logger_hook._log_val.assert_called()
+ runner.log_processor.get_log_after_epoch.assert_called()
+ runner.logger.info.assert_called()
+ runner.writer.add_scalars.assert_called()
- @pytest.mark.parametrize('by_epoch', [True, False])
- def test_log_train(self, by_epoch, capsys):
- runner = self._setup_runner()
- runner.meta = dict(exp_name='retinanet')
- # Prepare LoggerHook
- logger_hook = LoggerHook(by_epoch=by_epoch)
- logger_hook._inner_iter = 1
- logger_hook.writer = MagicMock()
- logger_hook.time_sec_tot = 1000
- logger_hook.start_iter = 0
- logger_hook._get_max_memory = MagicMock(return_value='100')
- logger_hook.json_log_path = 'tmp.json'
-
- # Prepare training information.
- train_infos = dict(
- lr=0.1, momentum=0.9, time=1.0, data_time=1.0, loss_cls=1.0)
- logger_hook._collect_info = MagicMock(return_value=train_infos)
- logger_hook._log_train(runner)
- # Verify that the correct variables have been written.
- runner.writer.add_scalars.assert_called_with(
- train_infos, step=11, file_path='tmp.json')
- # Verify that the correct context have been logged.
- out, _ = capsys.readouterr()
- time_avg = logger_hook.time_sec_tot / (
- runner.iter + 1 - logger_hook.start_iter)
- eta_second = time_avg * (runner.train_loop.max_iters - runner.iter - 1)
- eta_str = str(datetime.timedelta(seconds=int(eta_second)))
- if by_epoch:
- if torch.cuda.is_available():
- log_str = 'Epoch [2][2/5]\t' \
- f"lr: {train_infos['lr']:.3e} " \
- f"momentum: {train_infos['momentum']:.3e}, " \
- f'eta: {eta_str}, ' \
- f"time: {train_infos['time']:.3f}, " \
- f"data_time: {train_infos['data_time']:.3f}, " \
- f'memory: 100, ' \
- f"loss_cls: {train_infos['loss_cls']:.4f}\n"
- else:
- log_str = 'Epoch [2][2/5]\t' \
- f"lr: {train_infos['lr']:.3e} " \
- f"momentum: {train_infos['momentum']:.3e}, " \
- f'eta: {eta_str}, ' \
- f"time: {train_infos['time']:.3f}, " \
- f"data_time: {train_infos['data_time']:.3f}, " \
- f"loss_cls: {train_infos['loss_cls']:.4f}\n"
- assert out == log_str
- else:
- if torch.cuda.is_available():
- log_str = 'Iter [11/50]\t' \
- f"lr: {train_infos['lr']:.3e} " \
- f"momentum: {train_infos['momentum']:.3e}, " \
- f'eta: {eta_str}, ' \
- f"time: {train_infos['time']:.3f}, " \
- f"data_time: {train_infos['data_time']:.3f}, " \
- f'memory: 100, ' \
- f"loss_cls: {train_infos['loss_cls']:.4f}\n"
- else:
- log_str = 'Iter [11/50]\t' \
- f"lr: {train_infos['lr']:.3e} " \
- f"momentum: {train_infos['momentum']:.3e}, " \
- f'eta: {eta_str}, ' \
- f"time: {train_infos['time']:.3f}, " \
- f"data_time: {train_infos['data_time']:.3f}, " \
- f"loss_cls: {train_infos['loss_cls']:.4f}\n"
- assert out == log_str
-
- @pytest.mark.parametrize('by_epoch', [True, False])
- def test_log_val(self, by_epoch, capsys):
- runner = self._setup_runner()
- # Prepare LoggerHook.
- logger_hook = LoggerHook(by_epoch=by_epoch)
- logger_hook.json_log_path = 'tmp.json'
- metric = dict(accuracy=0.9, data_time=1.0)
- logger_hook._collect_info = MagicMock(return_value=metric)
- logger_hook._log_val(runner)
- # Verify that the correct context have been logged.
- out, _ = capsys.readouterr()
- runner.writer.add_scalars.assert_called_with(
- metric, step=11, file_path='tmp.json')
- if by_epoch:
- assert out == 'Epoch(val) [1][5]\taccuracy: 0.9000, ' \
- 'data_time: 1.0000\n'
-
- else:
- assert out == 'Iter(val) [5]\taccuracy: 0.9000, ' \
- 'data_time: 1.0000\n'
-
- def test_get_window_size(self):
- runner = self._setup_runner()
- logger_hook = LoggerHook()
- logger_hook._inner_iter = 1
- # Test get window size by name.
- assert logger_hook._get_window_size(runner, 'epoch') == 2
- assert logger_hook._get_window_size(runner, 'global') == 11
- assert logger_hook._get_window_size(runner, 10) == 10
- # Window size must equal to `logger_hook.interval`.
- with pytest.raises(AssertionError):
- logger_hook._get_window_size(runner, 20)
-
- with pytest.raises(ValueError):
- logger_hook._get_window_size(runner, 'unknwon')
-
- def test_parse_custom_keys(self):
- tag = OrderedDict()
- runner = self._setup_runner()
- log_buffers = OrderedDict(lr=MagicMock(), loss=MagicMock())
- cfg_dict = dict(
- lr=dict(method='min'),
- loss=[
- dict(method='min', window_size='global'),
- dict(method='max', log_name='loss_max')
- ])
- logger_hook = LoggerHook()
- for log_key, log_cfg in cfg_dict.items():
- logger_hook._parse_custom_keys(runner, log_key, log_cfg,
- log_buffers, tag)
- assert list(tag) == ['lr', 'loss', 'loss_max']
- assert log_buffers['lr'].min.assert_called
- assert log_buffers['loss'].min.assert_called
- assert log_buffers['loss'].max.assert_called
- assert log_buffers['loss'].mean.assert_called
- # `log_name` Cannot be repeated.
- with pytest.raises(KeyError):
- cfg_dict = dict(loss=[
- dict(method='min', window_size='global'),
- dict(method='max', log_name='loss_max'),
- dict(method='mean', log_name='loss_max')
- ])
- logger_hook.custom_keys = cfg_dict
- for log_key, log_cfg in cfg_dict.items():
- logger_hook._parse_custom_keys(runner, log_key, log_cfg,
- log_buffers, tag)
- # `log_key` cannot be overwritten multiple times.
- with pytest.raises(AssertionError):
- cfg_dict = dict(loss=[
- dict(method='min', window_size='global'),
- dict(method='max'),
- ])
- logger_hook.custom_keys = cfg_dict
- for log_key, log_cfg in cfg_dict.items():
- logger_hook._parse_custom_keys(runner, log_key, log_cfg,
- log_buffers, tag)
-
- def test_collect_info(self):
- runner = self._setup_runner()
- logger_hook = LoggerHook(
- custom_keys=dict(time=dict(method='max', log_name='time_max')))
- logger_hook._parse_custom_keys = MagicMock()
- # Collect with prefix.
- log_buffers = {
- 'train/time': MagicMock(),
- 'lr': MagicMock(),
- 'train/loss_cls': MagicMock(),
- 'val/metric': MagicMock()
- }
- runner.message_hub.log_scalars = log_buffers
- tag = logger_hook._collect_info(runner, mode='train')
- # Test parse custom_keys
- logger_hook._parse_custom_keys.assert_called()
- # Test training key in tag.
- assert list(tag.keys()) == ['time', 'loss_cls']
- # Test statistics lr with `current`, loss and time with 'mean'
- log_buffers['train/time'].mean.assert_called()
- log_buffers['train/loss_cls'].mean.assert_called()
- log_buffers['train/loss_cls'].current.assert_not_called()
-
- tag = logger_hook._collect_info(runner, mode='val')
- assert list(tag.keys()) == ['metric']
- log_buffers['val/metric'].current.assert_called()
-
- @patch('torch.cuda.max_memory_allocated', MagicMock())
- @patch('torch.cuda.reset_peak_memory_stats', MagicMock())
- def test_get_max_memory(self):
+ def test_after_test_epoch(self):
logger_hook = LoggerHook()
runner = MagicMock()
- runner.world_size = 1
- runner.model = torch.nn.Linear(1, 1)
- logger_hook._get_max_memory(runner)
- torch.cuda.max_memory_allocated.assert_called()
- torch.cuda.reset_peak_memory_stats.assert_called()
+ runner.log_processor.get_log_after_epoch = MagicMock(
+ return_value=(dict(), 'log_str'))
+ logger_hook.after_test_epoch(runner)
+ runner.log_processor.get_log_after_epoch.assert_called()
+ runner.logger.info.assert_called()
- def test_get_iter(self):
- runner = self._setup_runner()
+ def test_after_val_iter(self):
logger_hook = LoggerHook()
- logger_hook._inner_iter = 1
- # Get global iter when `inner_iter=False`
- iter = logger_hook._get_iter(runner)
- assert iter == 11
- # Get inner iter
- iter = logger_hook._get_iter(runner, inner_iter=True)
- assert iter == 2
- # Still get global iter when `logger_hook.by_epoch==False`
- logger_hook.by_epoch = False
- iter = logger_hook._get_iter(runner, inner_iter=True)
- assert iter == 11
-
- def test_get_epoch(self):
- runner = self._setup_runner()
- logger_hook = LoggerHook()
- epoch = logger_hook._get_epoch(runner, 'train')
- assert epoch == 2
- epoch = logger_hook._get_epoch(runner, 'val')
- assert epoch == 1
- with pytest.raises(ValueError):
- logger_hook._get_epoch(runner, 'test')
-
- def _setup_runner(self):
runner = MagicMock()
- runner.epoch = 1
- runner.cur_dataloader = [0] * 5
- runner.iter = 10
- runner.train_loop.max_iters = 50
- logger = logging.getLogger()
- logger.setLevel(logging.INFO)
- for handler in logger.handlers:
- if not isinstance(handler, logging.StreamHandler):
- continue
- else:
- logger.addHandler(logging.StreamHandler(stream=sys.stdout))
- runner.logger = logger
- runner.message_hub = MagicMock()
- runner.composed_wirter = MagicMock()
- return runner
+ runner.iter = 0
+ runner.log_processor.get_log_after_iter = MagicMock(
+ return_value=(dict(), 'log_str'))
+ logger_hook.after_val_iter(runner, 1)
+ runner.log_processor.get_log_after_iter.assert_not_called()
+ logger_hook.after_val_iter(runner, 9)
+ runner.log_processor.get_log_after_iter.assert_called()
+
+ def test_after_test_iter(self):
+ logger_hook = LoggerHook()
+ runner = MagicMock()
+ runner.iter = 0
+ runner.log_processor.get_log_after_iter = MagicMock(
+ return_value=(dict(), 'log_str'))
+ logger_hook.after_test_iter(runner, 1)
+ runner.log_processor.get_log_after_iter.assert_not_called()
+ logger_hook.after_test_iter(runner, 9)
+ runner.log_processor.get_log_after_iter.assert_called()
diff --git a/tests/test_hook/test_optimizer_hook.py b/tests/test_hook/test_optimizer_hook.py
index 5d04ca3f..dc11ee0f 100644
--- a/tests/test_hook/test_optimizer_hook.py
+++ b/tests/test_hook/test_optimizer_hook.py
@@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
-from unittest.mock import Mock
+from unittest.mock import MagicMock, Mock
import torch
from torch import nn
@@ -45,7 +45,7 @@ class TestOptimizerHook:
model = Model()
x = torch.rand(1, 1, 3, 3)
- dummy_runner = Mock()
+ dummy_runner = MagicMock()
dummy_runner.optimizer.zero_grad = Mock(return_value=None)
dummy_runner.optimizer.step = Mock(return_value=None)
dummy_runner.model = model
diff --git a/tests/test_logging/test_log_processor.py b/tests/test_logging/test_log_processor.py
new file mode 100644
index 00000000..b10cac48
--- /dev/null
+++ b/tests/test_logging/test_log_processor.py
@@ -0,0 +1,242 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+from unittest.mock import MagicMock, patch
+
+import pytest
+import torch
+
+from mmengine.logging import LogProcessor, MessageHub, MMLogger
+
+
+class TestLogProcessor:
+
+ def test_init(self):
+ log_processor = LogProcessor(
+ window_size=10, by_epoch=True, custom_cfg=None)
+ assert log_processor.by_epoch
+ assert log_processor.window_size == 10
+ assert log_processor.custom_cfg == []
+
+ def test_check_custom_cfg(self):
+        # ``by_epoch=False`` combined with ``window_size='epoch'`` in the
+        # log config will raise an AssertionError.
+ custom_cfg = [dict(data_src='loss', window_size='epoch')]
+ with pytest.raises(AssertionError):
+ LogProcessor(by_epoch=False, custom_cfg=custom_cfg)
+        # A duplicate log_name will raise an AssertionError.
+ custom_cfg = [
+ dict(data_src='loss', log_name='loss_1'),
+ dict(data_src='loss', log_name='loss_1')
+ ]
+ with pytest.raises(AssertionError):
+ LogProcessor(custom_cfg=custom_cfg)
+        # Overwriting the loss item twice will raise an AssertionError.
+ custom_cfg = [dict(data_src='loss'), dict(data_src='loss')]
+ with pytest.raises(AssertionError):
+ LogProcessor(custom_cfg=custom_cfg)
+
+ custom_cfg = [
+ dict(data_src='loss_cls', window_size=100, method_name='min'),
+ dict(data_src='loss', log_name='loss_min', method_name='max'),
+ dict(data_src='loss', log_name='loss_max', method_name='max')
+ ]
+ LogProcessor(custom_cfg=custom_cfg)
+
+ def test_parse_windows_size(self):
+ log_processor = LogProcessor()
+ # Test parse 'epoch' window_size.
+ log_processor.custom_cfg = [
+ dict(data_src='loss_cls', window_size='epoch')
+ ]
+ custom_cfg = log_processor._parse_windows_size(self.runner, 1)
+ assert custom_cfg[0]['window_size'] == 2
+
+ # Test parse 'global' window_size.
+ log_processor.custom_cfg = [
+ dict(data_src='loss_cls', window_size='global')
+ ]
+ custom_cfg = log_processor._parse_windows_size(self.runner, 1)
+ assert custom_cfg[0]['window_size'] == 11
+
+ # Test parse int window_size
+ log_processor.custom_cfg = [dict(data_src='loss_cls', window_size=100)]
+ custom_cfg = log_processor._parse_windows_size(self.runner, 1)
+ assert custom_cfg[0]['window_size'] == 100
+
+ # Invalid type window_size will raise TypeError.
+ log_processor.custom_cfg = [dict(data_src='loss_cls', window_size=[])]
+ with pytest.raises(TypeError):
+            log_processor._parse_windows_size(self.runner, 1)
+
+ @pytest.mark.parametrize('by_epoch,mode',
+ ([True, 'train'], [False, 'train'], [True, 'val'],
+ [False, 'val'], [True, 'test'], [False, 'test']))
+ def test_get_log_after_iter(self, by_epoch, mode):
+ # Prepare LoggerHook
+ log_processor = LogProcessor(by_epoch=by_epoch)
+ log_processor._get_max_memory = MagicMock(return_value='100')
+ eta = 40
+ self.runner.message_hub.update_info('eta', eta)
+ # Prepare training information.
+ if mode == 'train':
+ train_logs = dict(lr=0.1, time=1.0, data_time=1.0, loss_cls=1.0)
+ else:
+ train_logs = dict(time=1.0, data_time=1.0, loss_cls=1.0)
+ log_processor._collect_scalars = MagicMock(return_value=train_logs)
+ tag, out = log_processor.get_log_after_iter(self.runner, 1, mode)
+ # Verify that the correct context have been logged.
+ cur_loop = log_processor._get_cur_loop(self.runner, mode)
+ if by_epoch:
+ if mode in ['train', 'val']:
+ cur_epoch = log_processor._get_epoch(self.runner, mode)
+ log_str = (f'Epoch({mode}) [{cur_epoch}][2/'
+ f'{len(cur_loop.dataloader)}] ')
+ else:
+ log_str = (f'Epoch({mode}) [2/{len(cur_loop.dataloader)}] ')
+
+ if mode == 'train':
+ log_str += f"lr: {train_logs['lr']:.3e} "
+ else:
+ log_str += ' '
+
+ log_str += (f'eta: 0:00:40 '
+ f"time: {train_logs['time']:.3f} "
+ f"data_time: {train_logs['data_time']:.3f} ")
+
+ if torch.cuda.is_available():
+ log_str += 'memory: 100 '
+ if mode == 'train':
+ log_str += f"loss_cls: {train_logs['loss_cls']:.4f}"
+ assert out == log_str
+ else:
+ if mode == 'train':
+ max_iters = self.runner.train_loop.max_iters
+ log_str = f'Iter({mode}) [11/{max_iters}] '
+ else:
+ max_iters = len(cur_loop.dataloader)
+ log_str = f'Iter({mode}) [2/{max_iters}] '
+
+ if mode == 'train':
+ log_str += f"lr: {train_logs['lr']:.3e} "
+ else:
+ log_str += ' '
+
+ log_str += (f'eta: 0:00:40 '
+ f"time: {train_logs['time']:.3f} "
+ f"data_time: {train_logs['data_time']:.3f} ")
+
+ if torch.cuda.is_available():
+ log_str += 'memory: 100 '
+
+ if mode == 'train':
+ log_str += f"loss_cls: {train_logs['loss_cls']:.4f}"
+ assert out == log_str
+
+ @pytest.mark.parametrize(
+ 'by_epoch,mode',
+ ([True, 'val'], [False, 'val'], [True, 'test'], [False, 'test']))
+ def test_log_val(self, by_epoch, mode):
+ # Prepare LoggerHook
+ log_processor = LogProcessor(by_epoch=by_epoch)
+ # Prepare validation information.
+ val_logs = dict(accuracy=0.9, data_time=1.0)
+ log_processor._collect_scalars = MagicMock(return_value=val_logs)
+ _, out = log_processor.get_log_after_epoch(self.runner, 2, mode)
+ if by_epoch:
+ if mode == 'test':
+ assert out == 'Epoch(test) [5/5] accuracy: 0.9000'
+ else:
+ assert out == 'Epoch(val) [1][10/10] accuracy: 0.9000'
+ else:
+ if mode == 'test':
+ assert out == 'Iter(test) [5/5] accuracy: 0.9000'
+ else:
+ assert out == 'Iter(val) [10/10] accuracy: 0.9000'
+
+ def test_collect_scalars(self):
+ custom_cfg = [
+ dict(data_src='time', method_name='mean', window_size=100),
+ dict(data_src='time', method_name='max', log_name='time_max')
+ ]
+ logger_hook = LogProcessor(custom_cfg=custom_cfg)
+ # Collect with prefix.
+ log_scalars = {
+ 'train/time': MagicMock(),
+ 'lr': MagicMock(),
+ 'train/loss_cls': MagicMock(),
+ 'val/metric': MagicMock()
+ }
+ self.runner.message_hub._log_scalars = log_scalars
+ tag = logger_hook._collect_scalars(
+ copy.deepcopy(custom_cfg), self.runner, mode='train')
+ # Test training key in tag.
+ assert list(tag.keys()) == ['time', 'loss_cls', 'time_max']
+ # Test statistics: 'time' is collected with 'max' and 'loss_cls' with 'mean'.
+ log_scalars['train/time'].statistics.assert_called_with(
+ method_name='max')
+ log_scalars['train/loss_cls'].mean.assert_called()
+
+ tag = logger_hook._collect_scalars(
+ copy.deepcopy(custom_cfg), self.runner, mode='val')
+ assert list(tag.keys()) == ['metric']
+ log_scalars['val/metric'].current.assert_called()
+
+ @patch('torch.cuda.max_memory_allocated', MagicMock())
+ @patch('torch.cuda.reset_peak_memory_stats', MagicMock())
+ def test_get_max_memory(self):
+ logger_hook = LogProcessor()
+ runner = MagicMock()
+ runner.world_size = 1
+ runner.model = torch.nn.Linear(1, 1)
+ logger_hook._get_max_memory(runner)
+ torch.cuda.max_memory_allocated.assert_called()
+ torch.cuda.reset_peak_memory_stats.assert_called()
+
+ def test_get_iter(self):
+ log_processor = LogProcessor()
+ # Get global iter when no batch index is given.
+ iter = log_processor._get_iter(self.runner)
+ assert iter == 11
+ # Get inner iter
+ iter = log_processor._get_iter(self.runner, 1)
+ assert iter == 2
+ # Still get global iter when `log_processor.by_epoch == False`
+ log_processor.by_epoch = False
+ iter = log_processor._get_iter(self.runner, 1)
+ assert iter == 11
+
+ def test_get_epoch(self):
+ log_processor = LogProcessor()
+ epoch = log_processor._get_epoch(self.runner, 'train')
+ assert epoch == 2
+ epoch = log_processor._get_epoch(self.runner, 'val')
+ assert epoch == 1
+ with pytest.raises(ValueError):
+ log_processor._get_epoch(self.runner, 'test')
+
+ def test_get_cur_loop(self):
+ log_processor = LogProcessor()
+ loop = log_processor._get_cur_loop(self.runner, 'train')
+ assert len(loop.dataloader) == 20
+ loop = log_processor._get_cur_loop(self.runner, 'val')
+ assert len(loop.dataloader) == 10
+ loop = log_processor._get_cur_loop(self.runner, 'test')
+ assert len(loop.dataloader) == 5
+
+ def setup(self):
+ runner = MagicMock()
+ runner.epoch = 1
+ runner.iter = 10
+ runner.train_loop.max_iters = 50
+ runner.train_loop.dataloader = [0] * 20
+ runner.val_loop.dataloader = [0] * 10
+ runner.test_loop.dataloader = [0] * 5
+ logger = MMLogger.get_instance('log_processor_test')
+ runner.logger = logger
+ message_hub = MessageHub.get_instance('log_processor_test')
+ for i in range(10):
+ message_hub.update_scalar('train/loss', 10 - i)
+ for i in range(10):
+ message_hub.update_scalar('val/acc', i * 0.1)
+ runner.message_hub = message_hub
+ self.runner = runner
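The `window_size` assertions in `test_parse_windows_size` above encode a small resolution rule: `'epoch'` resolves to the number of iterations finished in the current epoch, `'global'` to the number finished overall, an explicit int is kept as-is, and anything else raises. A minimal sketch of that rule (illustrative only, not the mmengine implementation; the function name and the `runner_iter`/`batch_idx` parameters are assumptions):

```python
# Sketch of the window_size resolution encoded by the assertions above.
# Not the mmengine implementation; names are illustrative.
def parse_window_size(custom_cfg, runner_iter, batch_idx):
    parsed = []
    for cfg in custom_cfg:
        cfg = dict(cfg)  # do not mutate the caller's config
        ws = cfg.get('window_size')
        if ws == 'epoch':
            cfg['window_size'] = batch_idx + 1    # iterations done this epoch
        elif ws == 'global':
            cfg['window_size'] = runner_iter + 1  # iterations done overall
        elif isinstance(ws, int):
            pass                                  # explicit int kept as-is
        else:
            raise TypeError("window_size must be int, 'epoch' or 'global'")
        parsed.append(cfg)
    return parsed
```

With `runner.iter == 10` and `batch_idx == 1` this reproduces the expected values 2, 11, and 100 from the test.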
diff --git a/tests/test_runner/test_runner.py b/tests/test_runner/test_runner.py
index 7b497c8e..a2576dec 100644
--- a/tests/test_runner/test_runner.py
+++ b/tests/test_runner/test_runner.py
@@ -221,7 +221,7 @@ class TestRunner(TestCase):
self.iter_based_cfg.default_hooks = dict(
timer=dict(type='IterTimerHook'),
checkpoint=dict(type='CheckpointHook', interval=1, by_epoch=False),
- logger=dict(type='LoggerHook', by_epoch=False),
+ logger=dict(type='LoggerHook'),
optimizer=dict(type='OptimizerHook', grad_clip=None),
param_scheduler=dict(type='ParamSchedulerHook'))
From c3aff4fc9afc40c013113f21c5db2489ebda2d59 Mon Sep 17 00:00:00 2001
From: Tong Gao
Date: Mon, 25 Apr 2022 13:44:15 +0800
Subject: [PATCH 4/4] [Enhancement] Add PolyParamScheduler, PolyMomentum and
PolyLR (#188)
* [Enhancement] Add PolyParamScheduler, PolyMomentum and PolyLR
* min_lr -> eta_min, refined docstr
---
mmengine/optim/scheduler/__init__.py | 11 ++--
mmengine/optim/scheduler/lr_scheduler.py | 49 ++++++++++++++-
.../optim/scheduler/momentum_scheduler.py | 49 ++++++++++++++-
mmengine/optim/scheduler/param_scheduler.py | 62 +++++++++++++++++++
.../test_scheduler/test_lr_scheduler.py | 23 ++++++-
.../test_scheduler/test_momentum_scheduler.py | 25 +++++++-
.../test_scheduler/test_param_scheduler.py | 32 +++++++++-
7 files changed, 240 insertions(+), 11 deletions(-)
diff --git a/mmengine/optim/scheduler/__init__.py b/mmengine/optim/scheduler/__init__.py
index f7ea1d57..733ca752 100644
--- a/mmengine/optim/scheduler/__init__.py
+++ b/mmengine/optim/scheduler/__init__.py
@@ -1,14 +1,14 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .lr_scheduler import (ConstantLR, CosineAnnealingLR, ExponentialLR,
- LinearLR, MultiStepLR, StepLR)
+ LinearLR, MultiStepLR, PolyLR, StepLR)
from .momentum_scheduler import (ConstantMomentum, CosineAnnealingMomentum,
ExponentialMomentum, LinearMomentum,
- MultiStepMomentum, StepMomentum)
+ MultiStepMomentum, PolyMomentum, StepMomentum)
from .param_scheduler import (ConstantParamScheduler,
CosineAnnealingParamScheduler,
ExponentialParamScheduler, LinearParamScheduler,
- MultiStepParamScheduler, StepParamScheduler,
- _ParamScheduler)
+ MultiStepParamScheduler, PolyParamScheduler,
+ StepParamScheduler, _ParamScheduler)
__all__ = [
'ConstantLR', 'CosineAnnealingLR', 'ExponentialLR', 'LinearLR',
@@ -16,5 +16,6 @@ __all__ = [
'ExponentialMomentum', 'LinearMomentum', 'MultiStepMomentum',
'StepMomentum', 'ConstantParamScheduler', 'CosineAnnealingParamScheduler',
'ExponentialParamScheduler', 'LinearParamScheduler',
- 'MultiStepParamScheduler', 'StepParamScheduler', '_ParamScheduler'
+ 'MultiStepParamScheduler', 'StepParamScheduler', '_ParamScheduler',
+ 'PolyParamScheduler', 'PolyLR', 'PolyMomentum'
]
diff --git a/mmengine/optim/scheduler/lr_scheduler.py b/mmengine/optim/scheduler/lr_scheduler.py
index 514b8b03..3c774a67 100644
--- a/mmengine/optim/scheduler/lr_scheduler.py
+++ b/mmengine/optim/scheduler/lr_scheduler.py
@@ -7,7 +7,8 @@ from mmengine.registry import PARAM_SCHEDULERS
from .param_scheduler import (INF, ConstantParamScheduler,
CosineAnnealingParamScheduler,
ExponentialParamScheduler, LinearParamScheduler,
- MultiStepParamScheduler, StepParamScheduler)
+ MultiStepParamScheduler, PolyParamScheduler,
+ StepParamScheduler)
@PARAM_SCHEDULERS.register_module()
@@ -294,3 +295,49 @@ class StepLR(StepParamScheduler):
last_step=last_step,
by_epoch=by_epoch,
verbose=verbose)
+
+
+@PARAM_SCHEDULERS.register_module()
+class PolyLR(PolyParamScheduler):
+ """Decays the learning rate of each parameter group in a polynomial decay
+ scheme.
+
+ Notice that such decay can happen simultaneously with other changes to the
+ learning rate from outside this scheduler.
+
+ Args:
+ optimizer (Optimizer): Wrapped optimizer.
+ eta_min (float): Minimum learning rate at the end of scheduling.
+ Defaults to 0.
+ power (float): The power of the polynomial. Defaults to 1.0.
+ begin (int): Step at which to start updating the parameters.
+ Defaults to 0.
+ end (int): Step at which to stop updating the parameters.
+ Defaults to INF.
+ last_step (int): The index of the last step. Used when resuming
+ without a state dict. Defaults to -1.
+ by_epoch (bool): Whether the scheduled parameters are updated by
+ epochs. Defaults to True.
+ verbose (bool): Whether to print the value for each update.
+ Defaults to False.
+ """
+
+ def __init__(self,
+ optimizer: torch.optim.Optimizer,
+ eta_min: float = 0,
+ power: float = 1.0,
+ begin: int = 0,
+ end: int = INF,
+ last_step: int = -1,
+ by_epoch: bool = True,
+ verbose: bool = False):
+ super().__init__(
+ optimizer,
+ param_name='lr',
+ eta_min=eta_min,
+ power=power,
+ begin=begin,
+ end=end,
+ last_step=last_step,
+ by_epoch=by_epoch,
+ verbose=verbose)
diff --git a/mmengine/optim/scheduler/momentum_scheduler.py b/mmengine/optim/scheduler/momentum_scheduler.py
index cc882c3b..fa357eb1 100644
--- a/mmengine/optim/scheduler/momentum_scheduler.py
+++ b/mmengine/optim/scheduler/momentum_scheduler.py
@@ -7,7 +7,8 @@ from mmengine.registry import PARAM_SCHEDULERS
from .param_scheduler import (INF, ConstantParamScheduler,
CosineAnnealingParamScheduler,
ExponentialParamScheduler, LinearParamScheduler,
- MultiStepParamScheduler, StepParamScheduler)
+ MultiStepParamScheduler, PolyParamScheduler,
+ StepParamScheduler)
@PARAM_SCHEDULERS.register_module()
@@ -294,3 +295,49 @@ class StepMomentum(StepParamScheduler):
last_step=last_step,
by_epoch=by_epoch,
verbose=verbose)
+
+
+@PARAM_SCHEDULERS.register_module()
+class PolyMomentum(PolyParamScheduler):
+ """Decays the momentum of each parameter group in a polynomial decay
+ scheme.
+
+ Notice that such decay can happen simultaneously with other changes to the
+ momentum from outside this scheduler.
+
+ Args:
+ optimizer (Optimizer): Wrapped optimizer.
+ eta_min (float): Minimum momentum at the end of scheduling.
+ Defaults to 0.
+ power (float): The power of the polynomial. Defaults to 1.0.
+ begin (int): Step at which to start updating the parameters.
+ Defaults to 0.
+ end (int): Step at which to stop updating the parameters.
+ Defaults to INF.
+ last_step (int): The index of the last step. Used when resuming
+ without a state dict. Defaults to -1.
+ by_epoch (bool): Whether the scheduled parameters are updated by
+ epochs. Defaults to True.
+ verbose (bool): Whether to print the value for each update.
+ Defaults to False.
+ """
+
+ def __init__(self,
+ optimizer: torch.optim.Optimizer,
+ eta_min: float = 0,
+ power: float = 1.0,
+ begin: int = 0,
+ end: int = INF,
+ last_step: int = -1,
+ by_epoch: bool = True,
+ verbose: bool = False):
+ super().__init__(
+ optimizer,
+ param_name='momentum',
+ eta_min=eta_min,
+ power=power,
+ begin=begin,
+ end=end,
+ last_step=last_step,
+ by_epoch=by_epoch,
+ verbose=verbose)
diff --git a/mmengine/optim/scheduler/param_scheduler.py b/mmengine/optim/scheduler/param_scheduler.py
index bbec0556..f40507e5 100644
--- a/mmengine/optim/scheduler/param_scheduler.py
+++ b/mmengine/optim/scheduler/param_scheduler.py
@@ -534,6 +534,7 @@ class LinearParamScheduler(_ParamScheduler):
Notice that such decay can happen simultaneously with other changes to the
parameter value from outside this scheduler.
+
Args:
optimizer (Optimizer): Wrapped optimizer.
start_factor (float): The number we multiply parameter value in the
@@ -598,3 +599,64 @@ class LinearParamScheduler(_ParamScheduler):
(self.end_factor - self.start_factor)))
for group in self.optimizer.param_groups
]
+
+
+@PARAM_SCHEDULERS.register_module()
+class PolyParamScheduler(_ParamScheduler):
+ """Decays the parameter value of each parameter group in a polynomial decay
+ scheme.
+
+ Notice that such decay can happen simultaneously with other changes to the
+ parameter value from outside this scheduler.
+
+ Args:
+ optimizer (Optimizer): Wrapped optimizer.
+ eta_min (float): Minimum parameter value at the end of scheduling.
+ Defaults to 0.
+ power (float): The power of the polynomial. Defaults to 1.0.
+ begin (int): Step at which to start updating the parameters.
+ Defaults to 0.
+ end (int): Step at which to stop updating the parameters.
+ Defaults to INF.
+ last_step (int): The index of the last step. Used when resuming
+ without a state dict. Defaults to -1.
+ by_epoch (bool): Whether the scheduled parameters are updated by
+ epochs. Defaults to True.
+ verbose (bool): Whether to print the value for each update.
+ Defaults to False.
+ """
+
+ def __init__(self,
+ optimizer: Optimizer,
+ param_name: str,
+ eta_min: float = 0,
+ power: float = 1.0,
+ begin: int = 0,
+ end: int = INF,
+ last_step: int = -1,
+ by_epoch: bool = True,
+ verbose: bool = False):
+
+ self.eta_min = eta_min
+ self.power = power
+ self.total_iters = end - begin - 1
+
+ super().__init__(
+ optimizer,
+ param_name=param_name,
+ begin=begin,
+ end=end,
+ last_step=last_step,
+ by_epoch=by_epoch,
+ verbose=verbose)
+
+ def _get_value(self):
+
+ if self.last_step == 0:
+ return [
+ group[self.param_name] for group in self.optimizer.param_groups
+ ]
+
+ return [(group[self.param_name] - self.eta_min) *
+ (1 - 1 / (self.total_iters - self.last_step + 1))**self.power +
+ self.eta_min for group in self.optimizer.param_groups]
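The per-step recursion in `_get_value` telescopes to a closed-form polynomial decay, which is what the tests below check against. A quick standalone verification (plain Python, assuming `begin=0` so `total_iters = end - 1`; function names are illustrative):

```python
# Iterate the per-step recursion from _get_value above, starting at base.
def recursive_values(base, eta_min, power, total_iters, steps):
    values, v = [base], base
    for t in range(1, steps):  # t plays the role of self.last_step
        v = (v - eta_min) * (1 - 1 / (total_iters - t + 1))**power + eta_min
        values.append(v)
    return values

# The factors ((T - t) / (T - t + 1))**power telescope, so the recursion
# reduces to: v_t = (base - eta_min) * (1 - t / T)**power + eta_min
def closed_form(base, eta_min, power, total_iters, steps):
    return [(base - eta_min) * (1 - t / total_iters)**power + eta_min
            for t in range(steps)]
```

With `base=0.05`, `eta_min=0.001`, `power=0.9`, `total_iters=4` (i.e. `end=5`), both routes produce the `single_targets` sequence the tests construct, ending exactly at `eta_min`.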
diff --git a/tests/test_optim/test_scheduler/test_lr_scheduler.py b/tests/test_optim/test_scheduler/test_lr_scheduler.py
index d747b6bd..6e8f337d 100644
--- a/tests/test_optim/test_scheduler/test_lr_scheduler.py
+++ b/tests/test_optim/test_scheduler/test_lr_scheduler.py
@@ -8,7 +8,7 @@ import torch.optim as optim
from mmengine.optim.scheduler import (ConstantLR, CosineAnnealingLR,
ExponentialLR, LinearLR, MultiStepLR,
- StepLR, _ParamScheduler)
+ PolyLR, StepLR, _ParamScheduler)
from mmengine.testing import assert_allclose
@@ -283,6 +283,21 @@ class TestLRScheduler(TestCase):
scheduler = CosineAnnealingLR(self.optimizer, T_max=t, eta_min=eta_min)
self._test_scheduler_value(scheduler, targets, epochs)
+ def test_poly_scheduler(self):
+ epochs = 10
+ power = 0.9
+ min_lr = 0.001
+ iters = 4
+ single_targets = [
+ min_lr + (0.05 - min_lr) * (1 - i / iters)**power
+ for i in range(iters)
+ ] + [min_lr] * (
+ epochs - iters)
+ targets = [single_targets, [x * epochs for x in single_targets]]
+ scheduler = PolyLR(
+ self.optimizer, power=power, eta_min=min_lr, end=iters + 1)
+ self._test_scheduler_value(scheduler, targets, epochs=10)
+
def _check_scheduler_state_dict(self, construct, construct2, epochs=10):
scheduler = construct()
for _ in range(epochs):
@@ -331,6 +346,12 @@ class TestLRScheduler(TestCase):
lambda: LinearLR(self.optimizer, start_factor=0, end_factor=0.3),
epochs=epochs)
+ def test_poly_scheduler_state_dict(self):
+ self._check_scheduler_state_dict(
+ lambda: PolyLR(self.optimizer, power=0.5, eta_min=0.001),
+ lambda: PolyLR(self.optimizer, power=0.8, eta_min=0.002),
+ epochs=10)
+
def test_multi_scheduler_without_overlap_linear_multi_step(self):
# use Linear in the first 5 epochs and then use MultiStep
epochs = 12
diff --git a/tests/test_optim/test_scheduler/test_momentum_scheduler.py b/tests/test_optim/test_scheduler/test_momentum_scheduler.py
index fd63a9b9..97d7af3b 100644
--- a/tests/test_optim/test_scheduler/test_momentum_scheduler.py
+++ b/tests/test_optim/test_scheduler/test_momentum_scheduler.py
@@ -9,8 +9,8 @@ import torch.optim as optim
from mmengine.optim.scheduler import (ConstantMomentum,
CosineAnnealingMomentum,
ExponentialMomentum, LinearMomentum,
- MultiStepMomentum, StepMomentum,
- _ParamScheduler)
+ MultiStepMomentum, PolyMomentum,
+ StepMomentum, _ParamScheduler)
from mmengine.testing import assert_allclose
@@ -284,6 +284,21 @@ class TestMomentumScheduler(TestCase):
self.optimizer, T_max=t, eta_min=eta_min)
self._test_scheduler_value(scheduler, targets, epochs)
+ def test_poly_scheduler(self):
+ epochs = 10
+ power = 0.9
+ min_lr = 0.001
+ iters = 4
+ single_targets = [
+ min_lr + (0.05 - min_lr) * (1 - i / iters)**power
+ for i in range(iters)
+ ] + [min_lr] * (
+ epochs - iters)
+ targets = [single_targets, [x * epochs for x in single_targets]]
+ scheduler = PolyMomentum(
+ self.optimizer, power=power, eta_min=min_lr, end=iters + 1)
+ self._test_scheduler_value(scheduler, targets, epochs=10)
+
def _check_scheduler_state_dict(self, construct, construct2, epochs=10):
scheduler = construct()
for _ in range(epochs):
@@ -333,6 +348,12 @@ class TestMomentumScheduler(TestCase):
self.optimizer, start_factor=0, end_factor=0.3),
epochs=epochs)
+ def test_poly_scheduler_state_dict(self):
+ self._check_scheduler_state_dict(
+ lambda: PolyMomentum(self.optimizer, power=0.5, eta_min=0.001),
+ lambda: PolyMomentum(self.optimizer, power=0.8, eta_min=0.002),
+ epochs=10)
+
def test_multi_scheduler_without_overlap_linear_multi_step(self):
# use Linear in the first 5 epochs and then use MultiStep
epochs = 12
diff --git a/tests/test_optim/test_scheduler/test_param_scheduler.py b/tests/test_optim/test_scheduler/test_param_scheduler.py
index d1467828..c4703392 100644
--- a/tests/test_optim/test_scheduler/test_param_scheduler.py
+++ b/tests/test_optim/test_scheduler/test_param_scheduler.py
@@ -6,12 +6,15 @@ import torch
import torch.nn.functional as F
import torch.optim as optim
+# yapf: disable
from mmengine.optim.scheduler import (ConstantParamScheduler,
CosineAnnealingParamScheduler,
ExponentialParamScheduler,
LinearParamScheduler,
MultiStepParamScheduler,
- StepParamScheduler, _ParamScheduler)
+ PolyParamScheduler, StepParamScheduler,
+ _ParamScheduler)
+# yapf: enable
from mmengine.testing import assert_allclose
@@ -336,6 +339,25 @@ class TestParameterScheduler(TestCase):
self.optimizer, param_name='lr', T_max=t, eta_min=eta_min)
self._test_scheduler_value(scheduler, targets, epochs)
+ def test_poly_scheduler(self):
+ epochs = 10
+ power = 0.9
+ min_lr = 0.001
+ iters = 4
+ single_targets = [
+ min_lr + (0.05 - min_lr) * (1 - i / iters)**power
+ for i in range(iters)
+ ] + [min_lr] * (
+ epochs - iters)
+ targets = [single_targets, [x * epochs for x in single_targets]]
+ scheduler = PolyParamScheduler(
+ self.optimizer,
+ param_name='lr',
+ power=power,
+ eta_min=min_lr,
+ end=iters + 1)
+ self._test_scheduler_value(scheduler, targets, epochs=10)
+
def _check_scheduler_state_dict(self, construct, construct2, epochs=10):
scheduler = construct()
for _ in range(epochs):
@@ -402,6 +424,14 @@ class TestParameterScheduler(TestCase):
end_factor=0.3),
epochs=epochs)
+ def test_poly_scheduler_state_dict(self):
+ self._check_scheduler_state_dict(
+ lambda: PolyParamScheduler(
+ self.optimizer, param_name='lr', power=0.5, eta_min=0.001),
+ lambda: PolyParamScheduler(
+ self.optimizer, param_name='lr', power=0.8, eta_min=0.002),
+ epochs=10)
+
def test_multi_scheduler_without_overlap_linear_multi_step(self):
# use Linear in the first 5 epochs and then use MultiStep
epochs = 12