[Fix] Resolve conflict between adapt and main. (#198)

* [Docs] Refine registry documentation (#186)

* [Docs] Refine registry documentation

* resolve comments

* minor refinement

* Refine Visualizer docs (#177)

* Refine Visualizer docs

* update

* update

* update featmap

* update docs

* update visualizer docs

* [Refactor] Refine LoggerHook (#155)

* rename globally accessible and integrate get_instance and create_instance

* move ManagerMixin to utils

* fix docstring and separate get_instance into get_instance and get_current_instance

* fix lint

* fix docstring, rename and move test_global_meta

* rename LogBuffer to HistoryBuffer, rename MessageHub methods, MessageHub support resume

* refine MMLogger timestamp, update unit test

* MMLogger add logger_name arguments

* Fix docstring

* Add LogProcessor and some unit test

* update unit test

* complete LogProcessor unit test

* refine LoggerHook

* solve circular import

* change default logger_name to mmengine

* refactor eta

* Fix docstring comment and unit test

* Fix with runner

* fix docstring

* fix docstring

* Add by_epoch attribute to LoggerHook and fix docstring

* Please mypy and fix comment

* remove \ in MMLogger

* Fix lint

* roll back pre-commit-hook

* Fix hook unit test

* Fix comments

* remove \t in log and add docstring

* Fix as comment

* should not accept other arguments if corresponding instance has been created

* fix logging ddp file saving

* fix logging ddp file saving

* move log processor to logging

* move log processor to logging

* remove current dataloader

* fix docstring

* fix unit test

* add learning rate to messagehub

* Support output training/validation/testing message after iterations/epochs

* fix docstring

* Fix IterBasedRunner log string

* Fix IterBasedRunner log string

* Support parse validation loss in log processor

* [Enhancement] Add PolyParamScheduler, PolyMomentum and PolyLR (#188)

* [Enhancement] Add PolyParamScheduler, PolyMomentum and PolyLR

* min_lr -> eta_min, refined docstring

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
Co-authored-by: Haian Huang(深度眸) <1286304229@qq.com>
Co-authored-by: Tong Gao <gaotongxiao@gmail.com>
Mashiro 2022-04-26 00:37:16 +08:00 committed by GitHub
parent fb7d8ccd6b
commit e0d00c5bdd
22 changed files with 1555 additions and 759 deletions

View File

@@ -262,7 +262,7 @@ class RetinaNet(nn.Module):
![registry](https://user-images.githubusercontent.com/58739961/153880947-1d66ac06-e5ee-448e-8d7d-201e96d1101d.png)
We can call `MMEngine` modules in `MMDetection`.
```python
from mmdet.models import MODELS
@@ -278,6 +278,29 @@ model = MODELS.build(cfg=dict(type='Conv2d'))
If no prefix is given, the `build` method first checks whether the module exists in the current node; if it does, the module is returned, otherwise the search continues upward through the parent node and even ancestor nodes until the module is found. Therefore, if the current node and the parent node both contain a module of the same name and we want the parent's module, we must specify the `scope` prefix. Note that searching upward through parent or ancestor nodes **only works if the modules of the parent or ancestor nodes have already been imported in some way and thus registered**. In the example above, the reason we do not explicitly import `MODELS` from the parent node `mmengine` is that `from mmdet.models import MODELS` indirectly triggers the registration of the modules in `mmengine.MODELS`, as the sketch below illustrates.
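A hedged illustration of this lookup order, assuming as above that `Conv2d` is registered in the parent `mmengine` registry while `RetinaNet` is only registered in `mmdet`:
```python
from mmdet.models import MODELS

# no prefix: search mmdet's node first, then fall back to the parent mmengine
model = MODELS.build(cfg=dict(type='Conv2d'))
# explicit scope prefix: go straight to the named registry node
model = MODELS.build(cfg=dict(type='mmengine.Conv2d'))
```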
The above shows how to build modules with a child node's registry, but sometimes we want to build a child node's module from the parent node's registry without adding a prefix, so that generic code can be shared and downstream libraries do not have to reinvent the wheel. How can this be done?
Suppose MMEngine has a `build_model` function that builds a model.
```python
from mmengine.registry import MODELS

def build_model(cfg):
    model = MODELS.build(cfg)
    return model
```
If we want to call this function in MMDetection to build the modules registered by MMDetection, we first need to get a [DefaultScope](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.registry.DefaultScope) instance whose scope_name is 'mmdet'; this instance is globally unique.
```python
from mmengine import build_model
from mmengine.registry import DefaultScope

import mmdet.models  # importing mmdet.models registers mmdet's modules with the registry

default_scope = DefaultScope.get_instance('my_experiment', scope_name='mmdet')
model = build_model(cfg=dict(type='RetinaNet'))
```
The purpose of getting the `DefaultScope` instance is to make the `build` method of `Registry` use the registry node named by the DefaultScope, i.e. mmdet, as its starting point. Only then can the RetinaNet module be found in MMDetection's registry node without adding the mmdet prefix in the config; otherwise, the program reports that RetinaNet cannot be found, as the sketch below shows.
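A hedged sketch of that failure mode (the exact exception type and message depend on the registry implementation):
```python
from mmengine import build_model  # the hypothetical helper from above
import mmdet.models

# Without DefaultScope.get_instance('my_experiment', scope_name='mmdet'),
# the parent registry cannot resolve the unprefixed type and reports that
# 'RetinaNet' cannot be found.
model = build_model(cfg=dict(type='RetinaNet'))
```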
### Calling modules of sibling nodes
Besides modules of the parent node, modules of sibling nodes can also be called.
@@ -311,16 +334,7 @@ from mmcls.models import MODELS
model = MODELS.build(cfg=dict(type='mmdet.RetinaNet'))
```
Calling a module from another node requires specifying the `scope` prefix in `type`. If we prefer not to, we can create a global variable `default_scope` with `scope_name` set to 'mmdet'; `Registry` will then treat the registry corresponding to `scope_name` as the current `Registry` and call its `build` method.
```python
from mmengine.registry import DefaultScope, MODELS
# call the RetinaNet registered in mmdet
default_scope = DefaultScope.get_instance(
'my_experiment', scope_name='mmdet')
model = MODELS.build(cfg=dict(type='RetinaNet'))
```
Calling a module that belongs to neither the current node nor its parent requires specifying the `scope` prefix in `type`.
Registries are not limited to two-level structures; three or more levels are also supported.
@@ -358,10 +372,4 @@ model = MODELS.build(cfg=dict(type='mmcls.ResNet'))
from mmcls.models import MODELS
# note the order of the prefixes: 'detplus.mmdet.ResNet' is incorrect
model = MODELS.build(cfg=dict(type='mmdet.detplus.MetaNet'))
# to build models from detplus by default, we can set default_scope
from mmengine.registry import DefaultScope
default_scope = DefaultScope.get_instance(
'my_experiment', scope_name='detplus')
model = MODELS.build(cfg=dict(type='MetaNet', default_scope='detplus'))
```

View File

@@ -0,0 +1,300 @@
# Visualization
## Overview
Visualization provides an intuitive explanation of the training and testing process of deep learning models. In the OpenMMLab algorithm libraries, we expect the visualization design to satisfy the following requirements:
- Provide rich out-of-the-box visualization features that cover most computer vision visualization tasks
- High extensibility: visualization needs are diverse, and custom requirements should be achievable through simple extension
- Allow visualization at any point of the training and testing pipeline
- Provide a unified visualization interface across the OpenMMLab libraries, which makes the code easier to understand and maintain
Based on these requirements, OpenMMLab 2.0 introduces the visualizer object Visualizer and a set of visualization storage backends (VisBackend) such as `LocalVisBackend`, `WandbVisBackend` and `TensorboardVisBackend`. Visualization here covers not only image data, but also configs, scalars, model graphs and other kinds of data.
- For ease of use, the interfaces provided by Visualizer implement both drawing and storage. The VisBackend objects are internal attributes of Visualizer and are called by it when needed to store data to different backends
- Since drawn results often need to be stored to several backends, Visualizer can be configured with multiple VisBackend objects; when the user calls a storage interface of Visualizer, it internally iterates over the storage interfaces of the VisBackend objects
The UML relationship between the two is shown below
<div align="center">
<img src="https://user-images.githubusercontent.com/17425982/163327736-f7cb3b16-ef07-46bc-982a-3cc7495e6c82.png" >
</div>
## The visualizer object: Visualizer
### Interface description
The visualizer object Visualizer provides all user-facing interfaces, which can be divided into three categories, as listed below
**(1) Drawing interfaces**
- [draw_bboxes](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.draw_bboxes) draws one or more bounding boxes
- [draw_points](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.draw_points) draws one or more points
- [draw_texts](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.draw_texts) draws one or more text boxes
- [draw_lines](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.lines) draws one or more line segments
- [draw_circles](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.draw_circles) draws one or more circles
- [draw_polygons](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.draw_polygons) draws one or more polygons
- [draw_binary_masks](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.draw_binary_mask) draws one or more binary masks
- [draw_featmap](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.draw_featmap) draws feature maps (a static method)
All of the above interfaces except `draw_featmap` can be chained, because calling `draw_featmap` may change the image size; to avoid confusing users, it is defined as a static method.
When the user wants to draw bounding boxes first, draw text on top of them, and then draw line segments, this can be done with chained calls:
```python
visualizer.set_image(image)
visualizer.draw_bboxes(...).draw_texts(...).draw_lines(...)
visualizer.show() # display the drawn results
```
Feature-map visualization is a common requirement. Users can visualize feature maps by calling `draw_featmap`, whose parameters are defined as:
```python
@staticmethod
def draw_featmap(featmap: torch.Tensor, # the input must be in CHW format
                 overlaid_image: Optional[np.ndarray] = None, # if image data is also given, the feature map is overlaid on the image
                 channel_reduction: Optional[str] = 'squeeze_mean', # strategy for reducing multiple channels to a single channel
                 topk: int = 10, # optionally display the topk feature maps with the highest activation
                 arrangement: Tuple[int, int] = (5, 2), # layout used when multiple channels are expanded into multiple images
                 resize_shape: Optional[tuple] = None, # resize_shape can be specified to rescale the feature map
                 alpha: float = 0.5) -> np.ndarray: # blending ratio between the image and the feature map
```
Feature-map visualization offers many options; batch input is currently not supported. Its behavior can be summarized as follows (see the sketch after this list)
- The input tensor usually has multiple channels; the channel_reduction parameter reduces them to a single channel, which is then overlaid on the image
  - `squeeze_mean` reduces the input C dimension to a single channel with the mean function, so the output shape becomes (1, H, W)
  - `select_max` first sums over the spatial dimensions of the input, producing shape (C, ), then selects the channel with the largest value
  - `None` means no reduction; in this case the topk parameter can select the topk feature maps with the highest activation for display
- When channel_reduction is None, the topk parameter takes effect: the channels are ranked by activation and the topk of them are overlaid on the image, laid out according to the arrangement parameter
  - If topk is not -1, the topk channels ranked by activation are displayed
  - If topk = -1, the channel number C must be 1 or 3, meaning the input data is an image; otherwise an error asks the user to set `channel_reduction` to reduce the channels
- Since input feature maps are usually very small, the `resize_shape` parameter lets the feature map be upsampled before visualization.
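The reduction semantics can be sketched in a few lines of plain PyTorch (a hedged sketch of the documented behavior, not the actual implementation):
```python
import torch

featmap = torch.randn(256, 32, 32)  # a (C, H, W) feature map

# squeeze_mean: average over C, output shape (1, H, W)
squeeze_mean = featmap.mean(dim=0, keepdim=True)

# select_max: sum over the spatial dims to get shape (C,), then keep the
# channel with the largest summed activation
channel_scores = featmap.sum(dim=(1, 2))
select_max = featmap[channel_scores.argmax()][None]

# channel_reduction=None with topk=10: keep the 10 highest-scoring channels
topk_maps = featmap[channel_scores.topk(10).indices]  # (10, H, W)
```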
**(2) Storage interfaces**
- [add_config](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.BaseWriter.add_config) writes a config to the specified storage backend
- [add_graph](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.BaseWriter.add_graph) writes a model graph to the specified storage backend
- [add_image](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.BaseWriter.add_image) writes an image to the specified storage backend
- [add_scalar](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.BaseWriter.add_scalar) writes a scalar to the specified storage backend
- [add_scalars](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.BaseWriter.add_scalars) writes multiple scalars to the specified storage backend at once
- [add_datasample](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.BaseWriter.add_datasample) the abstract interface through which each downstream library draws datasample data
Interfaces prefixed with add are storage interfaces. datasample is the unified abstract data interface designed for downstream libraries in the OpenMMLab 2.0 architecture, and `add_datasample` processes this data format directly: visualizing prediction results, Dataset or DataLoader outputs, intermediate predictions and so on can all be done by calling the `add_datasample` interface overridden by the downstream library.
All downstream libraries must inherit Visualizer and implement `add_datasample`. Taking MMDetection as an example, it should inherit Visualizer and use this interface to draw and store the results of all predefined detection tasks, e.g. object detection, instance segmentation and panoptic segmentation; a hedged sketch of such a subclass follows.
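A hedged sketch of such a subclass (the `bboxes` fields on the samples are made-up names; each downstream library defines its own datasample layout):
```python
import numpy as np

from mmengine.visualization import Visualizer


class ToyDetVisualizer(Visualizer):
    """Minimal downstream visualizer sketch."""

    def add_datasample(self, name: str, image: np.ndarray,
                       gt_sample=None, pred_sample=None,
                       show: bool = False, step: int = 0) -> None:
        self.set_image(image)
        if gt_sample is not None:
            self.draw_bboxes(gt_sample.bboxes)    # ground truth boxes
        if pred_sample is not None:
            self.draw_bboxes(pred_sample.bboxes)  # predictions on top
        if show:
            self.show()
        else:
            # hand the drawn image to every configured VisBackend
            self.add_image(name, self.get_image(), step=step)
```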
**(3) Other utility interfaces**
- [set_image](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.set_image) sets the original image data; the default input image format is RGB
- [get_image](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.get_image) gets the drawn image as a Numpy array; the default output format is RGB
- [show](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.show) displays the results
- [get_backend](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.get_backend) gets a specific storage backend by name
- [close](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.close) closes all opened resources, including the VisBackend objects
### Usage examples
**(1) Getting the visualizer anywhere**
To make sure the visualizer object Visualizer can be accessed from anywhere, it inherits from the `ManagerMixin` class and becomes a globally unique object. When initializing a `Visualizer`, the user must call the `visualizer.get_instance()` method for the instance to be globally unique. Once instantiated, the visualizer can be retrieved anywhere in the code via `Visualizer.get_current_instance()`.
Taking MMDetection as an example, suppose the `DetLocalVisualizer` class inherits from `Visualizer` and implements the `add_datasample` interface. The config is written as:
```python
vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
type='DetLocalVisualizer', vis_backends=vis_backends, name='visualizer')
```
```python
# get_instance() is called internally to create the globally unique instance
VISUALIZERS.build(cfg.visualizer)
```
After instantiating with the code above, the `get_current_instance` method can be called anywhere to get the visualizer
```python
# retrieve the visualizer anywhere in the code
visualizer = Visualizer.get_current_instance()
```
If the user directly uses the Runner of MMEngine or of a downstream library, no extra instantiation is needed, because a globally unique visualizer is created automatically in the Runner's initialization function.
**(2) Writing data to a specific backend**
Once the visualizer has been obtained, the `add_xxx` interfaces can be called to write various kinds of data to specific backends
```python
# draw a datasample and save it to the local storage backend
visualizer.add_datasample('demo_image', image, gt_sample, pred_sample, step=1)
# show it in a local window directly, without storing
visualizer.add_datasample('demo_image', image, gt_sample, pred_sample, show=True)
# write an image
visualizer.add_image('demo_image', image, step=1)
# write model accuracy values
visualizer.add_scalar('mAP', 0.9, step=1)
visualizer.add_scalars({'loss': 1.2, 'acc': 0.8}, step=1)
# write a config
visualizer.add_config(cfg)
# write a model graph
visualizer.add_graph(model, data_batch)
```
**(3) Feature-map visualization**
Reduce or select feature-map channels with the `channel_reduction` parameter and display the result in a local window
```python
featmap = ... # a tensor of shape (C, H, W)
# reduce channels
feat_img = visualizer.draw_featmap(featmap, channel_reduction='squeeze_mean')
visualizer.show(feat_img)
# display the channel with the highest activation
feat_img = visualizer.draw_featmap(featmap, channel_reduction='select_max')
visualizer.show(feat_img)
```
Overlay on an image
```python
featmap = ... # a tensor of shape (C, H, W)
img = ... # if featmap and img have different spatial sizes, featmap is interpolated internally
# reduce channels
feat_img = visualizer.draw_featmap(featmap, img, channel_reduction='squeeze_mean')
visualizer.show(feat_img)
# display the channel with the highest activation
feat_img = visualizer.draw_featmap(featmap, img, channel_reduction='select_max')
visualizer.show(feat_img)
```
Select a given number of channels with the `topk` parameter and display them in a local window
```python
featmap = ... # a tensor of shape (C, H, W)
# topk, displayed in a 2-row, 5-column layout
feat_img = visualizer.draw_featmap(featmap, channel_reduction=None, topk=10, arrangement=(2, 5))
visualizer.show(feat_img)
# topk, displayed in a 5-row, 2-column layout
feat_img = visualizer.draw_featmap(featmap, channel_reduction=None, topk=10, arrangement=(5, 2))
visualizer.show(feat_img)
```
Rescale the displayed feature map with `resize_shape`
```python
featmap = ... # a tensor of shape (C, H, W)
# reduce channels
feat_img = visualizer.draw_featmap(featmap, channel_reduction='squeeze_mean', resize_shape=(224, 224))
visualizer.show(feat_img)
```
Store the feature map to a visualization backend
```python
featmap = ... # a tensor of shape (C, H, W)
# reduce channels
feat_img = visualizer.draw_featmap(featmap, channel_reduction='squeeze_mean', resize_shape=(224, 224))
# store
visualizer.add_image('feat_image', feat_img)
```
**(4) Remote display**
Users can save data with Wandb, Tensorboard or a custom backend with remote-display capability, and then view it in a browser. Taking Wandb as an example, a typical config is:
```python
vis_backends = [dict(type='WandbVisBackend')]
visualizer = dict(
type='DetWandbVisualizer', vis_backends=vis_backends, name='visualizer')
```
Usage is exactly the same as above. Note in particular that since the data drawn by Wandb cannot be made compatible with the `LocalVisBackend` backend, when `vis_backends` contains multiple visualization storage backends, only `WandbVisBackend` takes effect.
## Visualization storage backends: VisBackend
After drawing, the results can be stored to multiple visualization storage backends. To unify the calling interface, MMEngine provides the abstract base class `BaseVisBackend` together with some commonly used VisBackend implementations such as `LocalVisBackend`, `WandbVisBackend` and `TensorboardVisBackend`.
### Interface description
BaseVisBackend defines the external interface specification. The main interfaces and properties are:
- [add_config](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.vis_backend.BaseVisBackend.add_config) writes a config to the specific backend
- [add_graph](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.vis_backend.BaseVisBackend.add_graph) writes a model graph to the specific backend
- [add_image](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.vis_backend.BaseVisBackend.add_image) writes an image to the specific backend
- [add_scalar](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.vis_backend.BaseVisBackend.add_scalar) writes a scalar to the specific backend
- [add_scalars](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.vis_backend.BaseVisBackend.add_scalars) writes multiple scalars to the specific backend at once
- [close](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.vis_backend.BaseVisBackend.close) closes the opened resources
- [experiment](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.vis_backend.BaseVisBackend.experiment) the underlying backend object, e.g. the Wandb object or the Tensorboard object
`BaseVisBackend` defines five common write interfaces. Some backends are very powerful; Wandb, for example, can also write tables, videos and so on. For such needs, the user can fetch the experiment object directly and call the backend object's own API. `LocalVisBackend`, `WandbVisBackend`, `TensorboardVisBackend` and the like all inherit from `BaseVisBackend` and implement the corresponding storage functions according to their own characteristics.
### Usage examples
In general, users do not need to operate on VisBackend objects; only when the current visualization storage cannot satisfy their needs will they want to manipulate a storage backend directly. Taking Wandb as an example, it offers very rich storage formats, e.g. interfaces for storing tables, weights and so on. To keep the interface unified across all backends, we do not expose such interfaces; instead, the user can fetch the Wandb object directly and perform custom storage.
```python
vis_backends = [dict(type='WandbVisBackend')]
visualizer = dict(
type='DetWandbVisualizer', vis_backends=vis_backends, name='visualizer')
```
```python
# get_instance() is called internally to create the globally unique instance
VISUALIZERS.build(cfg.visualizer)
# retrieve the visualizer anywhere in the code
visualizer = Visualizer.get_current_instance()
# extend the add functionality, e.g. draw a table with the Wandb object
wandb = visualizer.get_backend('WandbVisBackend').experiment
val_table = wandb.Table(data=my_data, columns=column_names)
wandb.log({'my_val_table': val_table})
```
A visualizer object can be attached to any number of VisBackend objects. To make it easy to fetch any VisBackend, a backend can be retrieved by its class name when no name parameter was specified
```python
vis_backends = [dict(type='LocalVisBackend'), dict(type='WandbVisBackend')]
visualizer = dict(
type='DetLocalVisualizer', vis_backends=vis_backends, name='visualizer')
```
```python
# get_instance() is called internally to create the globally unique instance
VISUALIZERS.build(cfg.visualizer)
# retrieve the visualizer anywhere in the code
visualizer = Visualizer.get_current_instance()
local_vis_backend = visualizer.get_backend('LocalVisBackend')
wandb_vis_backend = visualizer.get_backend('WandbVisBackend')
```
When several VisBackend objects of the same class are present, the user must give each a unique name parameter; a backend can then be retrieved by its name string
```python
vis_backends = [dict(type='LocalVisBackend', name='local_vis_backend_1'), dict(type='LocalVisBackend', name='local_vis_backend_2')]
visualizer = dict(
type='DetLocalVisualizer', vis_backends=vis_backends, name='visualizer')
```
```python
# get_instance() is called internally to create the globally unique instance
VISUALIZERS.build(cfg.visualizer)
# retrieve the visualizer anywhere in the code
visualizer = Visualizer.get_current_instance()
local_vis_backend_1 = visualizer.get_backend('local_vis_backend_1')
local_vis_backend_2 = visualizer.get_backend('local_vis_backend_2')
```

View File

@@ -358,11 +358,11 @@ class Hook:
"""
return (runner.epoch + 1) % n == 0 if n > 0 else False
def every_n_inner_iters(self, inner_iter: int, n: int) -> bool:
def every_n_inner_iters(self, batch_idx: int, n: int) -> bool:
"""Test whether current inner iteration can be evenly divided by n.
Args:
inner_iter (int): Current inner_iter of the training, validation
batch_idx (int): Current batch index of the training, validation
or testing loop.
n (int): The interval to check, i.e. whether (batch_idx + 1) is
evenly divisible by n.
@@ -371,7 +371,7 @@ class Hook:
bool: Whether current inner iteration can be evenly
divided by n.
"""
return (inner_iter + 1) % n == 0 if n > 0 else False
return (batch_idx + 1) % n == 0 if n > 0 else False
def every_n_iters(self, runner, n: int) -> bool:
"""Test whether current iteration can be evenly divided by n.
@@ -395,7 +395,6 @@ class Hook:
dataloader (Dataloader): The dataloader of the training,
validation or testing process.
batch_idx (int): The index of the current batch in the loop.
Returns:
bool: Whether reaches the end of current epoch or not.
"""
@@ -418,10 +417,10 @@ class Hook:
Args:
runner (Runner): The runner of the training, validation or testing
process.
mode (str): Current mode of runner. Defaults to 'train'.
Returns:
bool: Whether current iteration is the last iteration.
mode (str): Current mode of runner. Defaults to 'train'.
"""
if mode == 'train':
return runner.iter + 1 == runner.train_loop.max_iters

View File

@@ -18,11 +18,25 @@ class IterTimerHook(Hook):
priority = 'NORMAL'
def _before_epoch(self, runner, mode: str = 'train') -> None:
"""Record time flag before start a epoch.
def __init__(self):
self.time_sec_tot = 0
self.start_iter = 0
def before_run(self, runner) -> None:
"""Synchronize the number of iterations with the runner.
Args:
runner (Runner): The runner of the training process.
runner: The runner of the training, validation or testing
process.
"""
self.start_iter = runner.iter
def _before_epoch(self, runner, mode: str = 'train') -> None:
"""Record timestamp before start an epoch.
Args:
runner (Runner): The runner of the training, validation and
testing process.
mode (str): Current mode of runner. Defaults to 'train'.
"""
self.t = time.time()
@@ -32,16 +46,18 @@ class IterTimerHook(Hook):
batch_idx: int,
data_batch: DATA_BATCH = None,
mode: str = 'train') -> None:
"""Logging time for loading data and update the time flag.
"""Calculating time for loading data and updating "data_time"
``HistoryBuffer`` of ``runner.message_hub``.
Args:
runner (Runner): The runner of the training process.
runner (Runner): The runner of the training, validation and
testing process.
batch_idx (int): The index of the current batch in the loop.
data_batch (Sequence[dict], optional): Data from dataloader.
Defaults to None.
mode (str): Current mode of runner. Defaults to 'train'.
"""
# TODO: update for new logging system
# Update data loading time in `runner.message_hub`.
runner.message_hub.update_scalar(f'{mode}/data_time',
time.time() - self.t)
@@ -52,10 +68,12 @@ class IterTimerHook(Hook):
outputs: Optional[Union[dict,
Sequence[BaseDataElement]]] = None,
mode: str = 'train') -> None:
"""Logging time for a iteration and update the time flag.
"""Calculating time for an iteration and updating "time"
``HistoryBuffer`` of ``runner.message_hub``.
Args:
runner (Runner): The runner of the training process.
runner (Runner): The runner of the training, validation and
testing process.
batch_idx (int): The index of the current batch in the loop.
data_batch (Sequence[dict], optional): Data from dataloader.
Defaults to None.
@@ -63,7 +81,31 @@ class IterTimerHook(Hook):
to None.
mode (str): Current mode of runner. Defaults to 'train'.
"""
# TODO: update for new logging system
runner.message_hub.update_scalar(f'{mode}/time', time.time() - self.t)
# Update iteration time in `runner.message_hub`.
message_hub = runner.message_hub
message_hub.update_scalar(f'{mode}/time', time.time() - self.t)
self.t = time.time()
window_size = runner.log_processor.window_size
# Calculate eta every `window_size` iterations. Since the test and val
# loops do not update runner.iter, use `every_n_inner_iters` to check
# the interval.
if self.every_n_inner_iters(batch_idx, window_size):
iter_time = message_hub.get_scalar(f'{mode}/time').mean(
window_size)
if mode == 'train':
self.time_sec_tot += iter_time * window_size
# Calculate average iterative time.
time_sec_avg = self.time_sec_tot / (
runner.iter - self.start_iter + 1)
# Calculate eta.
eta_sec = time_sec_avg * (
runner.train_loop.max_iters - runner.iter - 1)
runner.message_hub.update_info('eta', eta_sec)
else:
if mode == 'val':
cur_dataloader = runner.val_loop.dataloader
else:
cur_dataloader = runner.test_loop.dataloader
eta_sec = iter_time * (len(cur_dataloader) - batch_idx - 1)
runner.message_hub.update_info('eta', eta_sec)
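For intuition, the eta bookkeeping above can be replayed with made-up numbers (a hedged sketch, independent of any runner):
```python
# all values are invented for illustration
window_size = 10
iter_time = 0.5                      # smoothed seconds per iteration
start_iter, cur_iter, max_iters = 0, 9, 1000

time_sec_tot = iter_time * window_size                      # 5.0 s so far
time_sec_avg = time_sec_tot / (cur_iter - start_iter + 1)   # 0.5 s/iter
eta_sec = time_sec_avg * (max_iters - cur_iter - 1)         # seconds left
print(eta_sec)  # 495.0
```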

View File

@@ -1,14 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import datetime
import os
import os.path as osp
from collections import OrderedDict
from pathlib import Path
from typing import Optional, Sequence, Union
import torch
from mmengine.data import BaseDataElement
from mmengine.fileio import FileClient
from mmengine.hooks import Hook
from mmengine.registry import HOOKS
@@ -19,33 +15,20 @@ DATA_BATCH = Optional[Sequence[dict]]
@HOOKS.register_module()
class LoggerHook(Hook):
"""In this logger hook, the information will be printed on the terminal and
saved in JSON file, tensorboard, wandb .etc.
"""Collect logs from different components of ``Runner`` and write them to
terminal, JSON file, tensorboard and wandb .etc.
``LoggerHook`` is used to record logs formatted by ``LogProcessor`` during
training/validation/testing phase. It is used to control the following
behaviors:
- The frequency of log updates in terminal, local files, tensorboard, wandb, etc.
- The frequency of showing experiment information in the terminal.
- The work directory to save logs.
Args:
by_epoch (bool): Whether ``EpochBasedLoop`` is used.
Defaults to True.
interval (int): Logging interval (every k iterations).
Defaults to 10.
custom_keys (dict, optional): Defines the keys in the log and which
kinds of statistic methods should be used to log them.
- ``custom_keys`` contains multiple string-dict pairs. In each
string-dict pair, the string defines a key name in the log and the
dict is a config defines the statistic methods and corresponding
arguments used to log the value. For example,
``dict(loss=dict(method_name='mean', log_name='global_loss',
window_size='global'))`` which means the log key ``loss`` will be
counted as global mean and additionally logged as ``global_loss``.
If ``log_name`` is not defined in config dict, the original logged
key will be overwritten.
- The key in ``LoggerHook.fixed_smooth_keys`` cannot be overwritten
because ``time`` and ``iter_time`` will be used to calculate
estimated time of arrival. If you want to recount the time, you
should set ``log_name`` in corresponding values.
- For those statistic methods with the ``window_size`` argument,
if ``by_epoch`` is set to False, ``window_size`` should not be
'epoch', since log values cannot be counted by epoch.
ignore_last (bool): Ignore the log of last iterations in each epoch if
the number of remaining iterations is less than :attr:`interval`.
Defaults to True.
@@ -70,64 +53,24 @@ class LoggerHook(Hook):
Defaults to None.
Examples:
>>> # `log_name` is defined, `loss_mean_window` will be an additional
>>> # record.
>>> logger_hook_cfg = dict(by_epoch=True,
>>> custom_keys=dict(
>>> loss=dict(
>>> log_name='loss_mean_window',
>>> method_name='mean',
>>> window_size=10)))
>>> # `log_name` is not defined. `loss` will be overwritten by
>>> # `global_mean` statistics.
>>> logger_hook_cfg = dict(by_epoch=True,
>>> custom_keys=dict(
>>> loss=dict(
>>> method_name='mean',
>>> window_size='global')))
>>> # `time` cannot be overwritten, `global_time` will be an additional
>>> # record.
>>> logger_hook_cfg = dict(by_epoch=True,
>>> custom_keys=dict(
>>> time=dict(
>>> log_name='global_time',
>>> method='mean',
>>> window_size='global')))
>>> # Record loss with different statistics methods.
>>> logger_hook_cfg = dict(by_epoch=True,
>>> custom_keys=dict(loss=[
>>> dict(log_name='loss_mean_window',
>>> method_name='mean',
>>> window_size=10),
>>> dict(method_name='mean',
>>> window_size='global')]))
>>> # A simplest LoggerHook config.
>>> logger_hook_cfg = dict(interval=20)
"""
# eta will be calculated by time. `time` and `data_time` should not be
# overwritten.
fixed_smooth_keys = ('time', 'data_time')
priority = 'BELOW_NORMAL'
def __init__(
self,
by_epoch: bool = True,
interval: int = 10,
custom_keys: Optional[dict] = None,
ignore_last: bool = True,
interval_exp_name: int = 1000,
out_dir: Optional[Union[str, Path]] = None,
out_suffix: Union[Sequence[str], str] = ('.log.json', '.log', '.py'),
keep_local=True,
file_client_args=None,
keep_local: bool = True,
file_client_args: Optional[dict] = None,
):
self._inner_iter = 0
self.by_epoch = by_epoch
self.interval = interval
self.custom_keys = custom_keys if custom_keys is not None else dict()
self.ignore_last = ignore_last
self.time_sec_tot = 0
self.interval_exp_name = interval_exp_name
self._check_custom_keys()
if out_dir is None and file_client_args is not None:
raise ValueError(
@@ -165,14 +108,15 @@ class LoggerHook(Hook):
self.json_log_path = osp.join(runner.work_dir,
f'{runner.timestamp}.log.json')
self.start_iter = runner.iter
self.yaml_log_path = osp.join(runner.work_dir,
f'{runner.timestamp}.log.json')
def after_train_iter(self,
runner,
batch_idx: int,
data_batch: DATA_BATCH = None,
outputs: Optional[dict] = None) -> None:
"""Record training logs.
"""Record training logs after training iteration.
Args:
runner (Runner): The runner of the training process.
@@ -182,33 +126,90 @@ class LoggerHook(Hook):
outputs (dict, optional): Outputs from model.
Defaults to None.
"""
self._inner_iter = batch_idx
if runner.meta is not None and 'exp_name' in runner.meta:
if (self.every_n_iters(runner, self.interval_exp_name)) or (
self.by_epoch and self.end_of_epoch(
runner.train_loop.dataloader, batch_idx)):
exp_info = f'Exp name: {runner.meta["exp_name"]}'
runner.logger.info(exp_info)
if self.by_epoch and self.every_n_inner_iters(batch_idx,
self.interval):
self._log_train(runner)
elif not self.by_epoch and self.every_n_iters(runner, self.interval):
self._log_train(runner)
elif self.end_of_epoch(runner.train_loop.dataloader,
batch_idx) and not self.ignore_last:
# Print experiment name every n iterations.
if self.every_n_iters(runner,
self.interval_exp_name) or (self.end_of_epoch(
runner.train_dataloader, batch_idx)):
exp_info = f'Exp name: {runner.experiment_name}'
runner.logger.info(exp_info)
if self.every_n_inner_iters(batch_idx, self.interval):
tag, log_str = runner.log_processor.get_log_after_iter(
runner, batch_idx, 'train')
elif (self.end_of_epoch(runner.train_dataloader, batch_idx)
and not self.ignore_last):
# `runner.max_iters` may not be divisible by `self.interval`. If
# `self.ignore_last==False`, the logs of the remaining iterations will
# be recorded (e.g. for Epoch [4][1000/1007], the logs of iterations
# 998-1007 will be recorded).
self._log_train(runner)
tag, log_str = runner.log_processor.get_log_after_iter(
runner, batch_idx, 'train')
else:
return
runner.logger.info(log_str)
# TODO compatible with visualizer.
runner.visualizer.add_scalars(tag, step=runner.iter + 1)
def after_val_iter(
self,
runner,
batch_idx: int,
data_batch: DATA_BATCH = None,
outputs: Optional[Sequence[BaseDataElement]] = None) -> None:
"""Record validation logs after validation iteration.
Args:
runner (Runner): The runner of the training process.
batch_idx (int): The index of the current batch in the train loop.
data_batch (Sequence[Tuple[Any, BaseDataElement]], optional):
Data from dataloader. Defaults to None.
outputs (sequence, optional): Outputs from model. Defaults to None.
"""
if self.every_n_inner_iters(batch_idx, self.interval):
tag, log_str = runner.log_processor.get_log_after_iter(
runner, batch_idx, 'val')
runner.logger.info(log_str)
def after_test_iter(
self,
runner,
batch_idx: int,
data_batch: DATA_BATCH = None,
outputs: Optional[Sequence[BaseDataElement]] = None) -> None:
"""Record testing logs after iteration.
Args:
runner (Runner): The runner of the training process.
batch_idx (int): The index of the current batch in the train loop.
data_batch (Sequence[Tuple[Any, BaseDataElement]], optional):
Data from dataloader. Defaults to None.
outputs (sequence, optional): Outputs from model. Defaults to None.
"""
if self.every_n_inner_iters(batch_idx, self.interval):
tag, log_str = runner.log_processor.get_log_after_iter(
runner, batch_idx, 'test')
runner.logger.info(log_str)
def after_val_epoch(self, runner) -> None:
"""Record validation logs.
"""Record validation logs after validation epoch.
Args:
runner (Runner): The runner of the training process.
"""
self._log_val(runner)
tag, log_str = runner.log_processor.get_log_after_epoch(
runner, len(runner.val_dataloader), 'val')
runner.logger.info(log_str)
# TODO compatible with visualizer.
runner.visualizer.add_scalars(tag, step=runner.iter + 1)
def after_test_epoch(self, runner) -> None:
"""Record testing logs after test epoch.
Args:
runner (Runner): The runner of the training process.
"""
tag, log_str = runner.log_processor.get_log_after_epoch(
runner, len(runner.val_dataloader), 'test')
runner.logger.info(log_str)
def after_run(self, runner) -> None:
"""Copy logs to ``self.out_dir`` if ``self.out_dir is not None``
@@ -233,278 +234,3 @@ class LoggerHook(Hook):
os.remove(local_filepath)
runner.logger.info((f'{local_filepath} was removed due to the '
'`self.keep_local=False`'))
def _log_train(self, runner) -> None:
"""Collect and record training logs which start named with "train/*".
Args:
runner (Runner): The runner of the training process.
"""
tag = self._collect_info(runner, 'train')
# The training log default defines `lr`, `momentum`, `time` and
# `data_time`. `log_tag` will pop these keys and loop other keys to
# `log_str`.
log_tag = copy.deepcopy(tag)
cur_iter = self._get_iter(runner, inner_iter=True)
cur_epoch = self._get_epoch(runner, 'train')
# Record learning rate and momentum.
lr_str_list = []
momentum_str_list = []
for key, value in tag.items():
if key.startswith('lr'):
log_tag.pop(key)
lr_str_list.append(f'{key}: {value:.3e}')
lr_str = ' '.join(lr_str_list)
for key, value in tag.items():
if key.startswith('momentum'):
log_tag.pop(key)
momentum_str_list.append(f'{key}: {value:.3e}')
momentum_str = ' '.join(momentum_str_list)
lr_momentum_str = f'{lr_str} {momentum_str}'
# by epoch: Epoch [4][100/1000]
# by iter: Iter [100/100000]
if self.by_epoch:
log_str = f'Epoch [{cur_epoch}]' \
f'[{cur_iter}/{len(runner.train_loop.dataloader)}] '
else:
log_str = f'Iter [{cur_iter}/{runner.train_loop.max_iters}] '
log_str += f'{lr_momentum_str}, '
# Calculate eta time.
self.time_sec_tot += (tag['time'] * self.interval)
time_sec_avg = self.time_sec_tot / (runner.iter - self.start_iter + 1)
eta_sec = time_sec_avg * (
runner.train_loop.max_iters - runner.iter - 1)
eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
log_str += f'eta: {eta_str}, '
log_str += f'time: {tag["time"]:.3f}, ' \
f'data_time: {tag["data_time"]:.3f}, '
# Pop recorded keys
log_tag.pop('time')
log_tag.pop('data_time')
# statistic memory
if torch.cuda.is_available():
log_str += f'memory: {self._get_max_memory(runner)}, '
# Loop left keys to fill `log_str`.
log_items = []
for name, val in log_tag.items():
if isinstance(val, float):
val = f'{val:.4f}'
log_items.append(f'{name}: {val}')
log_str += ', '.join(log_items)
runner.logger.info(log_str)
# Write logs to local, tensorboad, and wandb.
runner.visualizer.add_scalars(
tag, step=runner.iter + 1, file_path=self.json_log_path)
def _log_val(self, runner) -> None:
"""Collect and record training logs which start named with "val/*".
Args:
runner (Runner): The runner of the training process.
"""
tag = self._collect_info(runner, 'val')
# Compatible with function `log` https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/hooks/logger/text.py # noqa E501
eval_iter = len(runner.val_loop.dataloader)
cur_iter = self._get_iter(runner)
cur_epoch = self._get_epoch(runner, 'val')
# val/test time
# here 1000 is the length of the val dataloader
# by epoch: Epoch[val] [4][1000]
# by iter: Iter[val] [1000]
if self.by_epoch:
# runner.epoch += 1 has been done before val workflow
log_str = f'Epoch(val) [{cur_epoch}][{eval_iter}] '
else:
log_str = f'Iter(val) [{eval_iter}] '
log_items = []
for name, val in tag.items():
if isinstance(val, float):
val = f'{val:.4f}'
log_items.append(f'{name}: {val}')
log_str += ', '.join(log_items)
runner.logger.info(log_str)
# Write tag.
runner.visualizer.add_scalars(
tag, step=cur_iter, file_path=self.json_log_path)
def _get_window_size(self, runner, window_size: Union[int, str]) \
-> int:
"""Parse window_size specified in ``self.custom_keys`` to int value.
Args:
runner (Runner): The runner of the training process.
window_size (int or str): Smoothing scale of logs.
Returns:
int: Smoothing window for statistical methods.
"""
if isinstance(window_size, int):
assert window_size == self.interval, \
'The value of windows size must equal to LoggerHook.interval'
return window_size
elif window_size == 'epoch':
return self._inner_iter + 1
elif window_size == 'global':
return runner.iter + 1
else:
raise ValueError('window_size should be int, epoch or global, but '
f'got invalid {window_size}')
def _collect_info(self, runner, mode: str) -> dict:
"""Collect log information to a dict according to mode.
Args:
runner (Runner): The runner of the training process.
mode (str): 'train' or 'val', which means the prefix attached by
runner.
Returns:
dict: Statistical values of logs.
"""
tag = OrderedDict()
log_buffers = runner.message_hub.log_scalars
mode_log_buffers = OrderedDict()
# Filter log_buffers which starts with `mode`.
for prefix_key, log_buffer in log_buffers.items():
if prefix_key.startswith(mode):
key = prefix_key.split('/')[-1]
mode_log_buffers[key] = log_buffer
# Ensure all metric and lr values are latest.
for key in mode_log_buffers:
# Update the latest learning rate and smoothed time logs.
if key in self.fixed_smooth_keys or key.startswith('loss'):
tag[key] = mode_log_buffers[key].mean(self.interval)
else:
tag[key] = mode_log_buffers[key].current()
# Update custom keys.
if mode == 'train':
for log_key, log_cfg in self.custom_keys.items():
self._parse_custom_keys(runner, log_key,
copy.deepcopy(log_cfg),
mode_log_buffers, tag)
return tag
def _parse_custom_keys(self, runner, log_key: str, log_cfg: dict,
log_buffers: OrderedDict, tag: OrderedDict) -> None:
"""Statistics logs in log_buffers according to custom_keys.
Args:
runner (Runner): The runner of the training process.
log_key (str): log key specified in ``self.custom_keys``
log_cfg (dict): A config dict for describing the logging
statistics method.
log_buffers (OrderedDict): All logs for the corresponding phase.
tag (OrderedDict): A dict which defines all statistic values of
logs.
"""
if isinstance(log_cfg, list):
log_names = set()
for cfg in log_cfg:
log_name = cfg.get('log_name', None)
if log_name in log_names:
raise KeyError(f'{cfg["log_name"]} cannot be redefined in '
'log_key')
if log_name is not None:
log_names.add(log_name)
self._parse_custom_keys(runner, log_key, cfg, log_buffers, tag)
assert len(log_names) == len(log_cfg) - 1, \
f'{log_key} cannot be overwritten multiple times, please ' \
f'check only one key does not contain `log_name` in {log_cfg}.'
elif isinstance(log_cfg, dict):
if 'window_size' in log_cfg:
log_cfg['window_size'] = \
self._get_window_size(runner, log_cfg['window_size'])
if 'log_name' in log_cfg:
name = log_cfg.pop('log_name')
else:
name = log_key
tag[name] = log_buffers[log_key].statistics(**log_cfg).item()
else:
raise ValueError('The structure of `LoggerHook.custom key` is '
'wrong, please make sure the type of each key is '
'dict or list.')
def _get_max_memory(self, runner) -> int:
"""Returns the maximum GPU memory occupied by tensors in megabytes (MB)
for a given device.
Args:
runner (Runner): The runner of the training process.
Returns:
The maximum GPU memory occupied by tensors in megabytes for a given
device.
"""
device = getattr(runner.model, 'output_device', None)
mem = torch.cuda.max_memory_allocated(device=device)
mem_mb = torch.tensor([int(mem) // (1024 * 1024)],
dtype=torch.int,
device=device)
torch.cuda.reset_peak_memory_stats()
return int(mem_mb.item())
def _check_custom_keys(self) -> None:
"""Check the legality of ``self.custom_keys``.
If ``self.by_epoch==False``, ``window_size`` should not be "epoch". The
key of ``self.fixed_smooth_keys`` cannot be overwritten.
"""
def _check_window_size(item):
if not self.by_epoch:
assert item['window_size'] != 'epoch', \
'window_size cannot be epoch if LoggerHook.by_epoch is ' \
'False.'
def _check_fixed_keys(key, item):
if key in self.fixed_smooth_keys:
assert 'log_name' in item, f'{key} cannot be overwritten by ' \
'custom keys!'
for key, value in self.custom_keys.items():
if isinstance(value, Sequence):
[(_check_window_size(item), _check_fixed_keys(key, item))
for item in value]
else:
_check_window_size(value)
_check_fixed_keys(key, value)
def _get_epoch(self, runner, mode: str) -> int:
"""Get epoch according to mode.
Args:
runner (Runner): The runner of the training process.
mode (str): Train or val.
Returns:
int: The current epoch.
"""
if mode == 'train':
epoch = runner.epoch + 1
elif mode == 'val':
# normal val mode
# runner.epoch += 1 has been done before val workflow
epoch = runner.epoch
else:
raise ValueError(f"runner mode should be 'train' or 'val', "
f'but got {runner.mode}')
return epoch
def _get_iter(self, runner, inner_iter=False) -> int:
"""Get the current training iteration step.
Args:
runner (Runner): The runner of the training process.
inner_iter (bool): Whether to return the inner iter of an epoch.
Defaults to False.
Returns:
int: The current global iter or inner iter.
"""
if self.by_epoch and inner_iter:
current_iter = self._inner_iter + 1
else:
current_iter = runner.iter + 1
return current_iter

View File

@@ -84,6 +84,9 @@ class OptimizerHook(Hook):
we keep ``outputs`` here. Defaults to None.
"""
runner.optimizer.zero_grad()
runner.message_hub.update_scalar(
'train/lr', runner.optimizer.param_groups[0]['lr'])
if self.detect_anomalous_params:
self.detect_anomalous_parameters(runner.outputs['loss'], runner)
runner.outputs['loss'].backward()
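The scalar recorded here is what LogProcessor later picks up from the message hub. A hedged standalone sketch of that round trip, using only the MessageHub calls that appear elsewhere in this commit:
```python
from mmengine.logging import MessageHub

message_hub = MessageHub.get_instance('demo')
message_hub.update_scalar('train/lr', 0.01)
# log scalars are kept as HistoryBuffers; current() returns the latest value
print(message_hub.get_scalar('train/lr').current())  # 0.01
```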

View File

@@ -1,6 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .history_buffer import HistoryBuffer
from .log_processor import LogProcessor
from .logger import MMLogger, print_log
from .message_hub import MessageHub
__all__ = ['HistoryBuffer', 'MessageHub', 'MMLogger', 'print_log']
__all__ = [
'HistoryBuffer', 'MessageHub', 'MMLogger', 'print_log', 'LogProcessor'
]
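A hedged sketch of the HistoryBuffer API used throughout this commit (update/mean/current, assuming the signatures used elsewhere in the diff):
```python
from mmengine.logging import HistoryBuffer

buf = HistoryBuffer()
for value in (1.0, 2.0, 3.0, 4.0):
    buf.update(value)
print(buf.current())  # 4.0, the most recent value
print(buf.mean(2))    # 3.5, mean over the last two values
```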

View File

@@ -0,0 +1,409 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import datetime
from collections import OrderedDict
from typing import List, Optional, Tuple
import torch
class LogProcessor:
"""A log processor used to format log information collected from
``runner.message_hub.log_scalars``.
``LogProcessor`` instance is built by runner and will format
``runner.message_hub.log_scalars`` to ``tag`` and ``log_str``, which can be
directly used by ``LoggerHook`` and ``MMLogger``. Besides, the argument
``custom_cfg`` of constructor can control the statistics method of logs.
Args:
window_size (int): Default smoothing interval. Defaults to 10.
by_epoch (bool): Whether to format logs with epoch style. Defaults to
True.
custom_cfg (list[dict], optional): Contains multiple log config dict,
in which key means the data source name of log and value means the
statistic method and corresponding arguments used to count the
data source. Defaults to None
- If custom_cfg is None, all logs will be formatted via default
methods, such as smoothing loss by default window_size. If
custom_cfg is defined as a list of config dict, for example:
[dict(data_src=loss, method='mean', log_name='global_loss',
window_size='global')]. It means the log item ``loss`` will be
counted as global mean and additionally logged as ``global_loss``
(defined by ``log_name``). If ``log_name`` is not defined in
config dict, the original logged key will be overwritten.
- The original log item cannot be overwritten twice. Here is
an error example:
[dict(data_src=loss, method='mean', window_size='global'),
dict(data_src=loss, method='mean', window_size='epoch')].
Both log config dict in custom_cfg do not have ``log_name`` key,
which means the loss item will be overwritten twice.
- For those statistic methods with the ``window_size`` argument,
if ``by_epoch`` is set to False, ``window_size`` should not be
'epoch', since log values cannot be counted by epoch.
Examples:
>>> # `log_name` is defined, `loss_large_window` will be an additional
>>> # record.
>>> log_processor = dict(
>>> window_size=10,
>>> by_epoch=True,
>>> custom_cfg=[dict(data_src='loss',
>>> log_name='loss_large_window',
>>> method_name='mean',
>>> window_size=100)])
>>> # `log_name` is not defined. `loss` will be overwritten.
>>> log_processor = dict(
>>> window_size=10,
>>> by_epoch=True,
>>> custom_cfg=[dict(data_src='loss',
>>> method_name='mean',
>>> window_size=100)])
>>> # Record loss with different statistics methods.
>>> log_processor = dict(
>>> window_size=10,
>>> by_epoch=True,
>>> custom_cfg=[dict(data_src='loss',
>>> log_name='loss_large_window',
>>> method_name='mean',
>>> window_size=100),
>>> dict(data_src='loss',
>>> method_name='mean',
>>> window_size=100)])
>>> # Overwrite loss item twice will raise an error.
>>> log_processor = dict(
>>> window_size=10,
>>> by_epoch=True,
>>> custom_cfg=[dict(data_src='loss',
>>> method_name='mean',
>>> window_size=100),
>>> dict(data_src='loss',
>>> method_name='max',
>>> window_size=100)])
AssertionError
"""
def __init__(self,
window_size=10,
by_epoch=True,
custom_cfg: Optional[List[dict]] = None):
self.window_size = window_size
self.by_epoch = by_epoch
self.custom_cfg = custom_cfg if custom_cfg else []
self._check_custom_cfg()
def get_log_after_iter(self, runner, batch_idx: int,
mode: str) -> Tuple[dict, str]:
"""Format log string after training, validation or testing epoch.
Args:
runner (Runner): The runner of training phase.
batch_idx (int): The index of the current batch in the current
loop.
mode (str): Current mode of runner, train, test or val.
Return:
Tuple(dict, str): Formatted log dict/string which will be
recorded by :obj:`runner.message_hub` and :obj:`runner.visualizer`.
"""
assert mode in ['train', 'test', 'val']
current_loop = self._get_cur_loop(runner, mode)
cur_iter = self._get_iter(runner, batch_idx=batch_idx)
# Overwrite ``window_size`` defined in ``custom_cfg`` to int value.
custom_cfg_copy = self._parse_windows_size(runner, batch_idx)
# tag is used to write log information to different backends.
tag = self._collect_scalars(custom_cfg_copy, runner, mode)
# `log_tag` will pop 'lr' and loop other keys to `log_str`.
log_tag = copy.deepcopy(tag)
# Record learning rate.
lr_str_list = []
for key, value in tag.items():
if key.startswith('lr'):
log_tag.pop(key)
lr_str_list.append(f'{key}: {value:.3e}')
lr_str = ' '.join(lr_str_list)
# Format log header.
# by_epoch == True
# train/val: Epoch [5][5/10] ...
# test: Epoch [5/10]
# by_epoch == False
# train: Epoch [5/10000] ... (divided by `max_iter`)
# val/test: Epoch [5/2000] ... (divided by length of dataloader)
if self.by_epoch:
if mode in ['train', 'val']:
cur_epoch = self._get_epoch(runner, mode)
log_str = (f'Epoch({mode}) [{cur_epoch}]'
f'[{cur_iter}/{len(current_loop.dataloader)}] ')
else:
log_str = (f'Epoch({mode}) '
f'[{cur_iter}/{len(current_loop.dataloader)}] ')
else:
if mode == 'train':
log_str = (f'Iter({mode}) '
f'[{cur_iter}/{runner.train_loop.max_iters}] ')
else:
log_str = (f'Iter({mode}) [{batch_idx+1}'
f'/{len(current_loop.dataloader)}] ')
# Concatenate lr, momentum string with log header.
log_str += f'{lr_str} '
# If IterTimerHook used in runner, eta, time, and data_time should be
# recorded.
if (all(item in tag for item in ['time', 'data_time'])
and 'eta' in runner.message_hub.runtime_info):
eta = runner.message_hub.get_info('eta')
eta_str = str(datetime.timedelta(seconds=int(eta)))
log_str += f'eta: {eta_str} '
log_str += (f'time: {tag["time"]:.3f} '
f'data_time: {tag["data_time"]:.3f} ')
# Pop recorded keys
log_tag.pop('time')
log_tag.pop('data_time')
# If cuda is available, the max memory occupied should be calculated.
if torch.cuda.is_available():
log_str += f'memory: {self._get_max_memory(runner)} '
# Loop left keys to fill `log_str`.
if mode in ('train', 'val'):
log_items = []
for name, val in log_tag.items():
if mode == 'val' and not name.startswith('val/loss'):
continue
if isinstance(val, float):
val = f'{val:.4f}'
log_items.append(f'{name}: {val}')
log_str += ' '.join(log_items)
return tag, log_str
def get_log_after_epoch(self, runner, batch_idx: int,
mode: str) -> Tuple[dict, str]:
"""Format log string after validation or testing epoch.
Args:
runner (Runner): The runner of training phase.
batch_idx (int): The index of the current batch in the current
loop.
mode (str): Current mode of runner.
Return:
Tuple(dict, str): Formatted log dict/string which will be
recorded by :obj:`runner.message_hub` and :obj:`runner.visualizer`.
"""
assert mode in [
'test', 'val'
], ('`_get_metric_log_str` only accept val or test mode, but got '
f'{mode}')
cur_loop = self._get_cur_loop(runner, mode)
dataloader_len = len(cur_loop.dataloader)
custom_cfg_copy = self._parse_windows_size(runner, batch_idx)
# tag is used to write log information to different backends.
tag = self._collect_scalars(custom_cfg_copy, runner, mode)
# validation log string needs cur epoch/iteration and max
# epochs/iterations. test log string only needs length of test
# dataloader.
cur_iter = self._get_iter(runner, batch_idx)
if self.by_epoch:
if mode == 'val':
cur_epoch = self._get_epoch(runner, mode)
log_str = (f'Epoch({mode}) [{cur_epoch}][{dataloader_len}/'
f'{dataloader_len}] ')
else:
log_str = (
f'Epoch({mode}) [{dataloader_len}/{dataloader_len}] ')
else:
if mode == 'train':
log_str = (f'Iter({mode}) [{cur_iter}/'
f'{runner.train_loop.max_iters}] ')
else:
log_str = (
f'Iter({mode}) [{dataloader_len}/{dataloader_len}] ')
log_items = []
for name, val in tag.items():
if name in ('time', 'data_time'):
continue
if isinstance(val, float):
val = f'{val:.4f}'
log_items.append(f'{name}: {val}')
log_str += ' '.join(log_items)
return tag, log_str
def _collect_scalars(self, custom_cfg: List[dict], runner,
mode: str) -> dict:
"""Collect log information to compose a dict according to mode.
Args:
custom_cfg (List[dict]): A copy of ``self.custom_cfg`` with int
``window_size``.
runner (Runner): The runner of the training process.
mode (str): 'train' or 'val', which means the prefix attached by
runner.
Returns:
dict: Statistical values of logs.
"""
tag = OrderedDict()
# history_scalars of train/val/test phase.
history_scalars = runner.message_hub.log_scalars
# corresponding mode history_scalars
mode_history_scalars = OrderedDict()
# extract log scalars and remove prefix to `mode_history_scalars`
# according to mode.
for prefix_key, log_buffer in history_scalars.items():
if prefix_key.startswith(mode):
key = prefix_key.split('/')[-1]
mode_history_scalars[key] = log_buffer
for key in mode_history_scalars:
# Update the latest learning rate and smoothed time logs.
if key.startswith('loss'):
tag[key] = mode_history_scalars[key].mean(self.window_size)
else:
# Default statistic method is current.
tag[key] = mode_history_scalars[key].current()
# Update custom keys.
for log_cfg in custom_cfg:
data_src = log_cfg.pop('data_src')
if 'log_name' in log_cfg:
log_name = log_cfg.pop('log_name')
else:
log_name = data_src
# log item in custom_cfg could only exist in train or val
# mode.
if data_src in mode_history_scalars:
tag[log_name] = mode_history_scalars[data_src].statistics(
**log_cfg)
return tag
def _check_custom_cfg(self) -> None:
"""Check the legality of ``self.custom_cfg``."""
def _check_window_size():
for log_cfg in self.custom_cfg:
if not self.by_epoch:
assert log_cfg['window_size'] != 'epoch', \
'window_size cannot be epoch if LoggerHook.by_epoch' \
' is False.'
def _check_repeated_log_name():
check_dict = dict()
# The `log_name` of the same data_src should not be repeated.
# If `log_name` is not specified, `data_src` will be overwritten.
# But only allowed to be overwritten once.
for log_cfg in self.custom_cfg:
assert 'data_src' in log_cfg
data_src = log_cfg['data_src']
log_name = log_cfg.get('log_name', data_src)
check_dict.setdefault(data_src,
dict(log_names=set(), log_counts=0))
check_dict[data_src]['log_names'].add(log_name)
check_dict[data_src]['log_counts'] += 1
assert (len(
check_dict[data_src]
['log_names']) == check_dict[data_src]['log_counts']), (
f'If you want to statistic {data_src} with multiple '
'statistics method, please check `log_name` is unique '
f'and {data_src} will not be overwritten twice. See '
f'more information in the docstring of `LogProcessor`')
_check_repeated_log_name()
_check_window_size()
def _parse_windows_size(self, runner, batch_idx: int) -> list:
"""Parse window_size defined in custom_cfg to int value.
Args:
runner (Runner): The runner of the training process.
batch_idx (int): The iteration index of current dataloader.
"""
custom_cfg_copy = copy.deepcopy(self.custom_cfg)
for log_cfg in custom_cfg_copy:
window_size = log_cfg.get('window_size', None)
if window_size is None or isinstance(window_size, int):
continue
elif window_size == 'epoch':
log_cfg['window_size'] = batch_idx + 1
elif window_size == 'global':
log_cfg['window_size'] = runner.iter + 1
else:
raise TypeError(
'window_size should be int, epoch or global, but got '
f'invalid {window_size}')
return custom_cfg_copy
def _get_max_memory(self, runner) -> int:
"""Returns the maximum GPU memory occupied by tensors in megabytes (MB)
for a given device.
Args:
runner (Runner): The runner of the training process.
Returns:
The maximum GPU memory occupied by tensors in megabytes for a given
device.
"""
device = getattr(runner.model, 'output_device', None)
mem = torch.cuda.max_memory_allocated(device=device)
mem_mb = torch.tensor([int(mem) // (1024 * 1024)],
dtype=torch.int,
device=device)
torch.cuda.reset_peak_memory_stats()
return int(mem_mb.item())
def _get_iter(self, runner, batch_idx: int = None) -> int:
"""Get current training iteration step.
Args:
runner (Runner): The runner of the training process.
batch_idx (int, optional): The iteration index of the current
dataloader. Defaults to None.
Returns:
int: The current global iter or inner iter.
"""
if self.by_epoch and batch_idx:
current_iter = batch_idx + 1
else:
current_iter = runner.iter + 1
return current_iter
def _get_epoch(self, runner, mode: str) -> int:
"""Get current epoch according to mode.
Args:
runner (Runner): The runner of the training/validation process.
mode (str): Current mode of runner, "train" or "val".
Returns:
int: The current epoch.
"""
if mode == 'train':
epoch = runner.epoch + 1
elif mode == 'val':
# normal val mode
# runner.epoch += 1 has been done before validation
epoch = runner.epoch
else:
raise ValueError(
f"runner mode should be 'train' or 'val', but got {mode}")
return epoch
def _get_cur_loop(self, runner, mode: str):
"""Get current loop according to mode.
Args:
runner (Runner): The runner of the training/validation/testing
process.
mode (str): Current mode of runner, "train", "val" or "test".
Returns:
BaseLoop: Current loop of runner.
"""
# a concrete return type hint would cause a circular import
if mode == 'train':
return runner.train_loop
elif mode == 'val':
return runner.val_loop
else:
return runner.test_loop

View File

@@ -32,15 +32,15 @@ class MMFormatter(logging.Formatter):
info_prefix = self._get_prefix('INFO', color)
debug_prefix = self._get_prefix('DEBUG', color)
# Config output format.
self.err_format = f'%(asctime)s - %(name)s - {error_prefix} - ' \
f'%(pathname)s - %(funcName)s - %(lineno)d - ' \
'%(message)s'
self.warn_format = f'%(asctime)s - %(name)s - {warn_prefix} - %(' \
'message)s'
self.info_format = f'%(asctime)s - %(name)s - {info_prefix} - %(' \
'message)s'
self.debug_format = f'%(asctime)s - %(name)s - {debug_prefix} - %(' \
'message)s'
self.err_format = (f'%(asctime)s - %(name)s - {error_prefix} - '
'%(pathname)s - %(funcName)s - %(lineno)d - '
'%(message)s')
self.warn_format = (f'%(asctime)s - %(name)s - {warn_prefix} - %('
'message)s')
self.info_format = (f'%(asctime)s - %(name)s - {info_prefix} - %('
'message)s')
self.debug_format = (f'%(asctime)s - %(name)s - {debug_prefix} - %('
'message)s')
def _get_prefix(self, level: str, color: bool) -> str:
"""Get the prefix of the target log level.

View File

@@ -1,14 +1,14 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .lr_scheduler import (ConstantLR, CosineAnnealingLR, ExponentialLR,
LinearLR, MultiStepLR, StepLR)
LinearLR, MultiStepLR, PolyLR, StepLR)
from .momentum_scheduler import (ConstantMomentum, CosineAnnealingMomentum,
ExponentialMomentum, LinearMomentum,
MultiStepMomentum, StepMomentum)
MultiStepMomentum, PolyMomentum, StepMomentum)
from .param_scheduler import (ConstantParamScheduler,
CosineAnnealingParamScheduler,
ExponentialParamScheduler, LinearParamScheduler,
MultiStepParamScheduler, StepParamScheduler,
_ParamScheduler)
MultiStepParamScheduler, PolyParamScheduler,
StepParamScheduler, _ParamScheduler)
__all__ = [
'ConstantLR', 'CosineAnnealingLR', 'ExponentialLR', 'LinearLR',
@@ -16,5 +16,6 @@ __all__ = [
'ExponentialMomentum', 'LinearMomentum', 'MultiStepMomentum',
'StepMomentum', 'ConstantParamScheduler', 'CosineAnnealingParamScheduler',
'ExponentialParamScheduler', 'LinearParamScheduler',
'MultiStepParamScheduler', 'StepParamScheduler', '_ParamScheduler'
'MultiStepParamScheduler', 'StepParamScheduler', '_ParamScheduler',
'PolyParamScheduler', 'PolyLR', 'PolyMomentum'
]

View File

@@ -7,7 +7,8 @@ from mmengine.registry import PARAM_SCHEDULERS
from .param_scheduler import (INF, ConstantParamScheduler,
CosineAnnealingParamScheduler,
ExponentialParamScheduler, LinearParamScheduler,
MultiStepParamScheduler, StepParamScheduler)
MultiStepParamScheduler, PolyParamScheduler,
StepParamScheduler)
@PARAM_SCHEDULERS.register_module()
@@ -294,3 +295,49 @@ class StepLR(StepParamScheduler):
last_step=last_step,
by_epoch=by_epoch,
verbose=verbose)
@PARAM_SCHEDULERS.register_module()
class PolyLR(PolyParamScheduler):
"""Decays the learning rate of each parameter group in a polynomial decay
scheme.
Notice that such decay can happen simultaneously with other changes to the
parameter value from outside this scheduler.
Args:
optimizer (Optimizer): Wrapped optimizer.
eta_min (float): Minimum learning rate at the end of scheduling.
Defaults to 0.
power (float): The power of the polynomial. Defaults to 1.0.
begin (int): Step at which to start updating the parameters.
Defaults to 0.
end (int): Step at which to stop updating the parameters.
Defaults to INF.
last_step (int): The index of last step. Used for resume without
state dict. Defaults to -1.
by_epoch (bool): Whether the scheduled parameters are updated by
epochs. Defaults to True.
verbose (bool): Whether to print the value for each update.
Defaults to False.
"""
def __init__(self,
optimizer: torch.optim.Optimizer,
eta_min: float = 0,
power: float = 1,
begin: int = 0,
end: int = INF,
last_step: int = -1,
by_epoch: bool = True,
verbose: bool = False):
super().__init__(
optimizer,
param_name='lr',
eta_min=eta_min,
power=power,
begin=begin,
end=end,
last_step=last_step,
by_epoch=by_epoch,
verbose=verbose)
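A hedged usage sketch of the new scheduler (hyperparameters are arbitrary, and it assumes PolyLR is re-exported from mmengine.optim like the other schedulers):
```python
import torch

from mmengine.optim import PolyLR

model = torch.nn.Linear(2, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# decay lr polynomially towards eta_min over steps 0..100
scheduler = PolyLR(optimizer, eta_min=1e-4, power=0.9, end=100, by_epoch=False)

for _ in range(10):
    optimizer.step()
    scheduler.step()
```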

View File

@@ -7,7 +7,8 @@ from mmengine.registry import PARAM_SCHEDULERS
from .param_scheduler import (INF, ConstantParamScheduler,
CosineAnnealingParamScheduler,
ExponentialParamScheduler, LinearParamScheduler,
MultiStepParamScheduler, StepParamScheduler)
MultiStepParamScheduler, PolyParamScheduler,
StepParamScheduler)
@PARAM_SCHEDULERS.register_module()
@@ -294,3 +295,49 @@ class StepMomentum(StepParamScheduler):
last_step=last_step,
by_epoch=by_epoch,
verbose=verbose)
@PARAM_SCHEDULERS.register_module()
class PolyMomentum(PolyParamScheduler):
"""Decays the momentum of each parameter group in a polynomial decay
scheme.
Notice that such decay can happen simultaneously with other changes to the
parameter value from outside this scheduler.
Args:
optimizer (Optimizer): Wrapped optimizer.
eta_min (float): Minimum momentum at the end of scheduling.
Defaults to 0.
power (float): The power of the polynomial. Defaults to 1.0.
begin (int): Step at which to start updating the parameters.
Defaults to 0.
end (int): Step at which to stop updating the parameters.
Defaults to INF.
        last_step (int): The index of the last step. Used for resuming
            without a state dict. Defaults to -1.
by_epoch (bool): Whether the scheduled parameters are updated by
epochs. Defaults to True.
verbose (bool): Whether to print the value for each update.
Defaults to False.
"""
def __init__(self,
optimizer: torch.optim.Optimizer,
eta_min: float = 0,
                 power: float = 1.0,
begin: int = 0,
end: int = INF,
last_step: int = -1,
by_epoch: bool = True,
verbose: bool = False):
super().__init__(
optimizer,
param_name='momentum',
eta_min=eta_min,
power=power,
begin=begin,
end=end,
last_step=last_step,
by_epoch=by_epoch,
verbose=verbose)
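The momentum variant is used the same way; a sketch assuming an optimizer whose param groups define a `momentum` entry:

```python
import torch

from mmengine.optim.scheduler import PolyMomentum

optimizer = torch.optim.SGD(
    torch.nn.Linear(1, 1).parameters(), lr=0.01, momentum=0.9)
# Anneal momentum from 0.9 towards eta_min with the same polynomial rule.
scheduler = PolyMomentum(optimizer, eta_min=0.0001, power=0.9, end=5)
```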


@ -534,6 +534,7 @@ class LinearParamScheduler(_ParamScheduler):
Notice that such decay can happen simultaneously with other changes to the
parameter value from outside this scheduler.
Args:
optimizer (Optimizer): Wrapped optimizer.
start_factor (float): The number we multiply parameter value in the
@ -598,3 +599,64 @@ class LinearParamScheduler(_ParamScheduler):
(self.end_factor - self.start_factor)))
for group in self.optimizer.param_groups
]
@PARAM_SCHEDULERS.register_module()
class PolyParamScheduler(_ParamScheduler):
"""Decays the parameter value of each parameter group in a polynomial decay
scheme.
Notice that such decay can happen simultaneously with other changes to the
parameter value from outside this scheduler.
Args:
optimizer (Optimizer): Wrapped optimizer.
eta_min (float): Minimum parameter value at the end of scheduling.
Defaults to 0.
power (float): The power of the polynomial. Defaults to 1.0.
begin (int): Step at which to start updating the parameters.
Defaults to 0.
end (int): Step at which to stop updating the parameters.
Defaults to INF.
        last_step (int): The index of the last step. Used for resuming
            without a state dict. Defaults to -1.
by_epoch (bool): Whether the scheduled parameters are updated by
epochs. Defaults to True.
verbose (bool): Whether to print the value for each update.
Defaults to False.
"""
def __init__(self,
optimizer: Optimizer,
param_name: str,
eta_min: float = 0,
power: float = 1.0,
begin: int = 0,
end: int = INF,
last_step: int = -1,
by_epoch: bool = True,
verbose: bool = False):
self.eta_min = eta_min
self.power = power
self.total_iters = end - begin - 1
super().__init__(
optimizer,
param_name=param_name,
begin=begin,
end=end,
last_step=last_step,
by_epoch=by_epoch,
verbose=verbose)
def _get_value(self):
if self.last_step == 0:
return [
group[self.param_name] for group in self.optimizer.param_groups
]
return [(group[self.param_name] - self.eta_min) *
(1 - 1 / (self.total_iters - self.last_step + 1))**self.power +
self.eta_min for group in self.optimizer.param_groups]
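The recursive update in `_get_value` telescopes to the familiar closed-form polynomial decay. A standalone sanity check with hypothetical numbers (pure Python, not part of the change):

```python
# Closed form: value_t = eta_min + (base - eta_min) * (1 - t / T) ** power
eta_min, base, power, total_iters = 0.001, 0.05, 0.9, 4

value = base
for t in range(1, total_iters + 1):
    # One step of the recursion above (last_step == t).
    value = (value - eta_min) * (
        1 - 1 / (total_iters - t + 1))**power + eta_min
    closed = eta_min + (base - eta_min) * (1 - t / total_iters)**power
    assert abs(value - closed) < 1e-12
```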


@ -25,7 +25,7 @@ from mmengine.dist import (broadcast, get_dist_info, init_dist, master_only,
sync_random_seed)
from mmengine.evaluator import Evaluator
from mmengine.hooks import Hook
from mmengine.logging import MessageHub, MMLogger
from mmengine.logging import LogProcessor, MessageHub, MMLogger
from mmengine.model import is_model_wrapper
from mmengine.optim import _ParamScheduler, build_optimizer
from mmengine.registry import (DATA_SAMPLERS, DATASETS, HOOKS, LOOPS,
@ -127,6 +127,8 @@ class Runner:
non-distributed environment will be launched.
env_cfg (dict): A dict used for setting environment. Defaults to
dict(dist_cfg=dict(backend='nccl')).
log_processor (dict, optional): A processor to format logs. Defaults to
None.
log_level (int or str): The log level of MMLogger handlers.
Defaults to 'INFO'.
visualizer (Visualizer or dict, optional): A Visualizer object or a
@ -151,43 +153,44 @@ class Runner:
Examples:
>>> from mmengine import Runner
>>> cfg = dict(
model=dict(type='ToyModel'),
work_dir='path/of/work_dir',
train_dataloader=dict(
dataset=dict(type='ToyDataset'),
sampler=dict(type='DefaultSampler', shuffle=True),
batch_size=1,
num_workers=0),
val_dataloader=dict(
dataset=dict(type='ToyDataset'),
sampler=dict(type='DefaultSampler', shuffle=False),
batch_size=1,
num_workers=0),
test_dataloader=dict(
dataset=dict(type='ToyDataset'),
sampler=dict(type='DefaultSampler', shuffle=False),
batch_size=1,
num_workers=0),
optimizer=dict(type='SGD', lr=0.01),
param_scheduler=dict(type='MultiStepLR', milestones=[1, 2]),
val_evaluator=dict(type='ToyEvaluator'),
test_evaluator=dict(type='ToyEvaluator'),
train_cfg=dict(by_epoch=True, max_epochs=3),
val_cfg=dict(interval=1),
test_cfg=dict(),
custom_hooks=[],
default_hooks=dict(
timer=dict(type='IterTimerHook'),
checkpoint=dict(type='CheckpointHook', interval=1),
logger=dict(type='LoggerHook'),
optimizer=dict(type='OptimizerHook', grad_clip=False),
param_scheduler=dict(type='ParamSchedulerHook')),
launcher='none',
env_cfg=dict(dist_cfg=dict(backend='nccl')),
visualizer=dict(type='Visualizer',
vis_backends=[dict(type='LocalVisBackend',
save_dir='temp_dir')])
)
>>> model=dict(type='ToyModel'),
>>> work_dir='path/of/work_dir',
>>> train_dataloader=dict(
>>> dataset=dict(type='ToyDataset'),
>>> sampler=dict(type='DefaultSampler', shuffle=True),
>>> batch_size=1,
>>> num_workers=0),
>>> val_dataloader=dict(
>>> dataset=dict(type='ToyDataset'),
>>> sampler=dict(type='DefaultSampler', shuffle=False),
>>> batch_size=1,
>>> num_workers=0),
>>> test_dataloader=dict(
>>> dataset=dict(type='ToyDataset'),
>>> sampler=dict(type='DefaultSampler', shuffle=False),
>>> batch_size=1,
>>> num_workers=0),
>>> optimizer=dict(type='SGD', lr=0.01),
>>> param_scheduler=dict(type='MultiStepLR', milestones=[1, 2]),
>>> val_evaluator=dict(type='ToyEvaluator'),
>>> test_evaluator=dict(type='ToyEvaluator'),
>>> train_cfg=dict(by_epoch=True, max_epochs=3),
>>> val_cfg=dict(interval=1),
>>> test_cfg=dict(),
>>> custom_hooks=[],
>>> default_hooks=dict(
>>> timer=dict(type='IterTimerHook'),
>>> checkpoint=dict(type='CheckpointHook', interval=1),
>>> logger=dict(type='LoggerHook'),
>>> optimizer=dict(type='OptimizerHook', grad_clip=False),
>>> param_scheduler=dict(type='ParamSchedulerHook')),
>>> launcher='none',
>>> env_cfg=dict(dist_cfg=dict(backend='nccl')),
>>> log_processor=dict(window_size=20),
>>> visualizer=dict(type='Visualizer',
>>> vis_backends=[dict(type='LocalVisBackend',
>>> save_dir='temp_dir')])
>>> )
>>> runner = Runner.from_cfg(cfg)
>>> runner.train()
>>> runner.test()
@ -217,6 +220,7 @@ class Runner:
resume: bool = False,
launcher: str = 'none',
env_cfg: Dict = dict(dist_cfg=dict(backend='nccl')),
log_processor: Optional[Dict] = None,
log_level: str = 'INFO',
visualizer: Optional[Union[Visualizer, Dict]] = None,
default_scope: Optional[str] = None,
@ -309,14 +313,16 @@ class Runner:
self._experiment_name = f'{filename_no_ext}_{self._timestamp}'
else:
self._experiment_name = self.timestamp
        # Used to reset the location of registries. See :meth:`Registry.build`
        # for more details.
self.default_scope = DefaultScope.get_instance(
self._experiment_name, scope_name=default_scope)
# Build log processor to format message.
log_processor = dict() if log_processor is None else log_processor
self.log_processor = LogProcessor(**log_processor)
        # Since `get_instance` could return any subclass of ManagerMixin, the
        # corresponding attribute needs a type hint.
self.logger = self.build_logger(log_level=log_level)
# Build `message_hub` for communication among components.
# `message_hub` can store log scalars (loss, learning rate) and
# runtime information (iter and epoch). Those components that do not
@ -387,6 +393,7 @@ class Runner:
resume=cfg.get('resume', False),
launcher=cfg.get('launcher', 'none'),
env_cfg=cfg.get('env_cfg'), # type: ignore
log_processor=cfg.get('log_processor'),
log_level=cfg.get('log_level', 'INFO'),
visualizer=cfg.get('visualizer'),
default_scope=cfg.get('default_scope'),


@ -1,29 +1,70 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest.mock import Mock
from unittest import TestCase
from unittest.mock import MagicMock, Mock, patch
from mmengine.hooks import IterTimerHook
from mmengine.logging import MessageHub
class TestIterTimerHook:
def time_patch():
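    # Fake clock used to patch time.time: returns 0 on the first call and
    # advances one second on every call after that.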
if not hasattr(time_patch, 'time'):
time_patch.time = 0
else:
time_patch.time += 1
return time_patch.time
class TestIterTimerHook(TestCase):
def setUp(self) -> None:
self.hook = IterTimerHook()
def test_init(self):
assert self.hook.time_sec_tot == 0
assert self.hook.start_iter == 0
def test_before_run(self):
runner = MagicMock()
runner.iter = 1
self.hook.before_run(runner)
assert self.hook.start_iter == 1
def test_before_epoch(self):
hook = IterTimerHook()
runner = Mock()
hook._before_epoch(runner)
assert isinstance(hook.t, float)
self.hook._before_epoch(runner)
assert isinstance(self.hook.t, float)
@patch('time.time', MagicMock(return_value=1))
def test_before_iter(self):
hook = IterTimerHook()
runner = Mock()
runner = MagicMock()
runner.log_buffer = dict()
hook._before_epoch(runner)
hook._before_iter(runner, 0)
runner.message_hub.update_scalar.assert_called()
self.hook._before_epoch(runner)
for mode in ('train', 'val', 'test'):
self.hook._before_iter(runner, batch_idx=1, mode=mode)
runner.message_hub.update_scalar.assert_called_with(
f'{mode}/data_time', 0)
@patch('time.time', time_patch)
def test_after_iter(self):
hook = IterTimerHook()
runner = Mock()
runner = MagicMock()
runner.log_buffer = dict()
hook._before_epoch(runner)
hook._after_iter(runner, 0)
runner.log_processor.window_size = 10
runner.train_loop.max_iters = 100
runner.iter = 0
runner.test_loop.dataloader = [0] * 20
runner.val_loop.dataloader = [0] * 20
self.hook._before_epoch(runner)
self.hook.before_run(runner)
self.hook._after_iter(runner, batch_idx=1)
runner.message_hub.update_scalar.assert_called()
runner.message_hub.get_log.assert_not_called()
runner.message_hub.update_info.assert_not_called()
runner.message_hub = MessageHub.get_instance('test_iter_timer_hook')
runner.iter = 9
# eta = (100 - 10) / 1
self.hook._after_iter(runner, batch_idx=89)
assert runner.message_hub.get_info('eta') == 90
self.hook._after_iter(runner, batch_idx=9, mode='val')
assert runner.message_hub.get_info('eta') == 10
self.hook._after_iter(runner, batch_idx=19, mode='test')
assert runner.message_hub.get_info('eta') == 0
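Reading the assertions above: the patched clock advances one second per call, so every iteration appears to take exactly 1 s and the expected ETA reduces to counting the remaining iterations. A sketch of the presumed arithmetic (not library code):

```python
avg_iter_time = 1  # seconds per iteration under the patched clock
eta_train = (100 - 10) * avg_iter_time  # max_iters=100, 10 iters done -> 90
eta_val = (20 - 10) * avg_iter_time     # 20 val batches, 10 done -> 10
eta_test = (20 - 20) * avg_iter_time    # all 20 test batches done -> 0
```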


@ -1,13 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
import datetime
import logging
import os.path as osp
import sys
from collections import OrderedDict
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock
import pytest
import torch
from mmengine.fileio.file_client import HardDiskBackend
from mmengine.hooks import LoggerHook
@ -17,11 +12,8 @@ class TestLoggerHook:
def test_init(self):
logger_hook = LoggerHook(out_dir='tmp.txt')
assert logger_hook.by_epoch
assert logger_hook.interval == 10
assert not logger_hook.custom_keys
assert logger_hook.ignore_last
assert logger_hook.time_sec_tot == 0
assert logger_hook.interval_exp_name == 1000
assert logger_hook.out_suffix == ('.log.json', '.log', '.py')
assert logger_hook.keep_local
@ -30,22 +22,7 @@ class TestLoggerHook:
# out_dir should be None or string or tuple of string.
with pytest.raises(TypeError):
LoggerHook(out_dir=1)
# time cannot be overwritten.
with pytest.raises(AssertionError):
LoggerHook(custom_keys=dict(time=dict(method='max')))
LoggerHook(
custom_keys=dict(time=[
dict(method='max', log_name='time_max'),
dict(method='min', log_name='time_min')
]))
# Epoch window_size cannot be used when `LoggerHook.by_epoch=False`
with pytest.raises(AssertionError):
LoggerHook(
by_epoch=False,
custom_keys=dict(
time=dict(
method='max', log_name='time_max',
window_size='epoch')))
with pytest.raises(ValueError):
LoggerHook(file_client_args=dict(enable_mc=True))
@ -60,19 +37,22 @@ class TestLoggerHook:
assert logger_hook.out_dir == osp.join('out_dir', 'work_dir')
assert logger_hook.json_log_path == osp.join('work_dir',
'timestamp.log.json')
assert logger_hook.start_iter == runner.iter
def test_after_run(self, tmp_path):
# Test
out_dir = tmp_path / 'out_dir'
out_dir.mkdir()
work_dir = tmp_path / 'work_dir'
work_dir.mkdir()
work_dir_json = work_dir / 'tmp.log.json'
json_f = open(work_dir_json, 'w')
json_f.close()
runner = MagicMock()
runner.work_dir = work_dir
# Test without out_dir.
logger_hook = LoggerHook()
logger_hook.after_run(runner)
# Test with out_dir and make sure json file has been moved to out_dir.
json_f = open(work_dir_json, 'w')
json_f.close()
logger_hook = LoggerHook(out_dir=str(tmp_path), keep_local=False)
logger_hook.out_dir = str(out_dir)
logger_hook.after_run(runner)
@ -83,276 +63,83 @@ class TestLoggerHook:
def test_after_train_iter(self):
# Test LoggerHook by iter.
runner = MagicMock()
runner.iter = 10
batch_idx = 5
logger_hook = LoggerHook(by_epoch=False)
logger_hook._log_train = MagicMock()
logger_hook.after_train_iter(runner, batch_idx=batch_idx)
runner.log_processor.get_log_after_iter = MagicMock(
return_value=(dict(), 'log_str'))
logger_hook = LoggerHook()
logger_hook.after_train_iter(runner, batch_idx=5)
        # `cur_iter=10+1` is not evenly divisible by
        # `logger_hook.interval`, so nothing should be logged yet.
logger_hook._log_train.assert_not_called()
runner.iter = 9
logger_hook.after_train_iter(runner, batch_idx=batch_idx)
logger_hook._log_train.assert_called()
runner.log_processor.get_log_after_iter.assert_not_called()
logger_hook.after_train_iter(runner, batch_idx=9)
runner.log_processor.get_log_after_iter.assert_called()
# Test LoggerHook by epoch.
logger_hook = LoggerHook(by_epoch=True)
logger_hook._log_train = MagicMock()
# Only `runner.inner_iter` will work.
runner.iter = 9
batch_idx = 10
logger_hook.after_train_iter(runner, batch_idx=batch_idx)
logger_hook._log_train.assert_not_called()
batch_idx = 9
logger_hook.after_train_iter(runner, batch_idx=batch_idx)
logger_hook._log_train.assert_called()
logger_hook = LoggerHook()
runner = MagicMock()
runner.log_processor.get_log_after_iter = MagicMock(
return_value=(dict(), 'log_str'))
# Only `batch_idx` will work.
logger_hook.after_train_iter(runner, batch_idx=10)
runner.log_processor.get_log_after_iter.assert_not_called()
logger_hook.after_train_iter(runner, batch_idx=9)
runner.log_processor.get_log_after_iter.assert_called()
# Test end of the epoch.
logger_hook = LoggerHook(by_epoch=True, ignore_last=False)
logger_hook._log_train = MagicMock()
runner.train_loop.dataloader = [0] * 5
batch_idx = 4
logger_hook.after_train_iter(runner, batch_idx=batch_idx)
logger_hook._log_train.assert_called()
runner = MagicMock()
runner.log_processor.get_log_after_iter = MagicMock(
return_value=(dict(), 'log_str'))
logger_hook = LoggerHook(ignore_last=False)
runner.train_dataloader = [0] * 5
logger_hook.after_train_iter(runner, batch_idx=4)
runner.log_processor.get_log_after_iter.assert_called()
# Test print exp_name
runner = MagicMock()
runner.log_processor.get_log_after_iter = MagicMock(
return_value=(dict(), 'log_str'))
runner.meta = dict(exp_name='retinanet')
logger_hook = LoggerHook()
runner.logger = MagicMock()
logger_hook._log_train = MagicMock()
logger_hook.after_train_iter(runner, batch_idx=batch_idx)
runner.logger.info.assert_called_with(
f'Exp name: {runner.meta["exp_name"]}')
logger_hook = LoggerHook()
logger_hook.after_train_iter(runner, batch_idx=999)
runner.logger.info.assert_called()
def test_after_val_epoch(self):
logger_hook = LoggerHook()
runner = MagicMock()
logger_hook._log_val = MagicMock()
runner.log_processor.get_log_after_epoch = MagicMock(
return_value=(dict(), 'string'))
logger_hook.after_val_epoch(runner)
logger_hook._log_val.assert_called()
runner.log_processor.get_log_after_epoch.assert_called()
runner.logger.info.assert_called()
runner.visualizer.add_scalars.assert_called()
@pytest.mark.parametrize('by_epoch', [True, False])
def test_log_train(self, by_epoch, capsys):
runner = self._setup_runner()
runner.meta = dict(exp_name='retinanet')
# Prepare LoggerHook
logger_hook = LoggerHook(by_epoch=by_epoch)
logger_hook._inner_iter = 1
logger_hook.writer = MagicMock()
logger_hook.time_sec_tot = 1000
logger_hook.start_iter = 0
logger_hook._get_max_memory = MagicMock(return_value='100')
logger_hook.json_log_path = 'tmp.json'
# Prepare training information.
train_infos = dict(
lr=0.1, momentum=0.9, time=1.0, data_time=1.0, loss_cls=1.0)
logger_hook._collect_info = MagicMock(return_value=train_infos)
logger_hook._log_train(runner)
# Verify that the correct variables have been written.
runner.visualizer.add_scalars.assert_called_with(
train_infos, step=11, file_path='tmp.json')
        # Verify that the correct context has been logged.
out, _ = capsys.readouterr()
time_avg = logger_hook.time_sec_tot / (
runner.iter + 1 - logger_hook.start_iter)
eta_second = time_avg * (runner.train_loop.max_iters - runner.iter - 1)
eta_str = str(datetime.timedelta(seconds=int(eta_second)))
if by_epoch:
if torch.cuda.is_available():
log_str = 'Epoch [2][2/5] ' \
f"lr: {train_infos['lr']:.3e} " \
f"momentum: {train_infos['momentum']:.3e}, " \
f'eta: {eta_str}, ' \
f"time: {train_infos['time']:.3f}, " \
f"data_time: {train_infos['data_time']:.3f}, " \
f'memory: 100, ' \
f"loss_cls: {train_infos['loss_cls']:.4f}\n"
else:
log_str = 'Epoch [2][2/5] ' \
f"lr: {train_infos['lr']:.3e} " \
f"momentum: {train_infos['momentum']:.3e}, " \
f'eta: {eta_str}, ' \
f"time: {train_infos['time']:.3f}, " \
f"data_time: {train_infos['data_time']:.3f}, " \
f"loss_cls: {train_infos['loss_cls']:.4f}\n"
assert out == log_str
else:
if torch.cuda.is_available():
log_str = 'Iter [11/50] ' \
f"lr: {train_infos['lr']:.3e} " \
f"momentum: {train_infos['momentum']:.3e}, " \
f'eta: {eta_str}, ' \
f"time: {train_infos['time']:.3f}, " \
f"data_time: {train_infos['data_time']:.3f}, " \
f'memory: 100, ' \
f"loss_cls: {train_infos['loss_cls']:.4f}\n"
else:
log_str = 'Iter [11/50] ' \
f"lr: {train_infos['lr']:.3e} " \
f"momentum: {train_infos['momentum']:.3e}, " \
f'eta: {eta_str}, ' \
f"time: {train_infos['time']:.3f}, " \
f"data_time: {train_infos['data_time']:.3f}, " \
f"loss_cls: {train_infos['loss_cls']:.4f}\n"
assert out == log_str
@pytest.mark.parametrize('by_epoch', [True, False])
def test_log_val(self, by_epoch, capsys):
runner = self._setup_runner()
# Prepare LoggerHook.
logger_hook = LoggerHook(by_epoch=by_epoch)
logger_hook.json_log_path = 'tmp.json'
metric = dict(accuracy=0.9, data_time=1.0)
logger_hook._collect_info = MagicMock(return_value=metric)
logger_hook._log_val(runner)
        # Verify that the correct context has been logged.
out, _ = capsys.readouterr()
runner.visualizer.add_scalars.assert_called_with(
metric, step=11, file_path='tmp.json')
if by_epoch:
assert out == 'Epoch(val) [1][5] accuracy: 0.9000, ' \
'data_time: 1.0000\n'
else:
assert out == 'Iter(val) [5] accuracy: 0.9000, ' \
'data_time: 1.0000\n'
def test_get_window_size(self):
runner = self._setup_runner()
logger_hook = LoggerHook()
logger_hook._inner_iter = 1
# Test get window size by name.
assert logger_hook._get_window_size(runner, 'epoch') == 2
assert logger_hook._get_window_size(runner, 'global') == 11
assert logger_hook._get_window_size(runner, 10) == 10
        # Window size must be equal to `logger_hook.interval`.
with pytest.raises(AssertionError):
logger_hook._get_window_size(runner, 20)
with pytest.raises(ValueError):
            logger_hook._get_window_size(runner, 'unknown')
def test_parse_custom_keys(self):
tag = OrderedDict()
runner = self._setup_runner()
log_buffers = OrderedDict(lr=MagicMock(), loss=MagicMock())
cfg_dict = dict(
lr=dict(method='min'),
loss=[
dict(method='min', window_size='global'),
dict(method='max', log_name='loss_max')
])
logger_hook = LoggerHook()
for log_key, log_cfg in cfg_dict.items():
logger_hook._parse_custom_keys(runner, log_key, log_cfg,
log_buffers, tag)
assert list(tag) == ['lr', 'loss', 'loss_max']
assert log_buffers['lr'].min.assert_called
assert log_buffers['loss'].min.assert_called
assert log_buffers['loss'].max.assert_called
assert log_buffers['loss'].mean.assert_called
        # `log_name` cannot be repeated.
with pytest.raises(KeyError):
cfg_dict = dict(loss=[
dict(method='min', window_size='global'),
dict(method='max', log_name='loss_max'),
dict(method='mean', log_name='loss_max')
])
logger_hook.custom_keys = cfg_dict
for log_key, log_cfg in cfg_dict.items():
logger_hook._parse_custom_keys(runner, log_key, log_cfg,
log_buffers, tag)
# `log_key` cannot be overwritten multiple times.
with pytest.raises(AssertionError):
cfg_dict = dict(loss=[
dict(method='min', window_size='global'),
dict(method='max'),
])
logger_hook.custom_keys = cfg_dict
for log_key, log_cfg in cfg_dict.items():
logger_hook._parse_custom_keys(runner, log_key, log_cfg,
log_buffers, tag)
def test_collect_info(self):
runner = self._setup_runner()
logger_hook = LoggerHook(
custom_keys=dict(time=dict(method='max', log_name='time_max')))
logger_hook._parse_custom_keys = MagicMock()
# Collect with prefix.
log_buffers = {
'train/time': MagicMock(),
'lr': MagicMock(),
'train/loss_cls': MagicMock(),
'val/metric': MagicMock()
}
runner.message_hub.log_scalars = log_buffers
tag = logger_hook._collect_info(runner, mode='train')
# Test parse custom_keys
logger_hook._parse_custom_keys.assert_called()
# Test training key in tag.
assert list(tag.keys()) == ['time', 'loss_cls']
        # Test that loss and time are collected with `mean`.
log_buffers['train/time'].mean.assert_called()
log_buffers['train/loss_cls'].mean.assert_called()
log_buffers['train/loss_cls'].current.assert_not_called()
tag = logger_hook._collect_info(runner, mode='val')
assert list(tag.keys()) == ['metric']
log_buffers['val/metric'].current.assert_called()
@patch('torch.cuda.max_memory_allocated', MagicMock())
@patch('torch.cuda.reset_peak_memory_stats', MagicMock())
def test_get_max_memory(self):
def test_after_test_epoch(self):
logger_hook = LoggerHook()
runner = MagicMock()
runner.world_size = 1
runner.model = torch.nn.Linear(1, 1)
logger_hook._get_max_memory(runner)
torch.cuda.max_memory_allocated.assert_called()
torch.cuda.reset_peak_memory_stats.assert_called()
runner.log_processor.get_log_after_epoch = MagicMock(
return_value=(dict(), 'log_str'))
logger_hook.after_test_epoch(runner)
runner.log_processor.get_log_after_epoch.assert_called()
runner.logger.info.assert_called()
def test_get_iter(self):
runner = self._setup_runner()
def test_after_val_iter(self):
logger_hook = LoggerHook()
logger_hook._inner_iter = 1
# Get global iter when `inner_iter=False`
iter = logger_hook._get_iter(runner)
assert iter == 11
# Get inner iter
iter = logger_hook._get_iter(runner, inner_iter=True)
assert iter == 2
# Still get global iter when `logger_hook.by_epoch==False`
logger_hook.by_epoch = False
iter = logger_hook._get_iter(runner, inner_iter=True)
assert iter == 11
def test_get_epoch(self):
runner = self._setup_runner()
logger_hook = LoggerHook()
epoch = logger_hook._get_epoch(runner, 'train')
assert epoch == 2
epoch = logger_hook._get_epoch(runner, 'val')
assert epoch == 1
with pytest.raises(ValueError):
logger_hook._get_epoch(runner, 'test')
def _setup_runner(self):
runner = MagicMock()
runner.epoch = 1
runner.train_loop.dataloader = [0] * 5
runner.val_loop.dataloader = [0] * 5
runner.test_loop.dataloader = [0] * 5
runner.iter = 10
runner.train_loop.max_iters = 50
logger = logging.getLogger()
logger.setLevel(logging.INFO)
for handler in logger.handlers:
if not isinstance(handler, logging.StreamHandler):
continue
else:
logger.addHandler(logging.StreamHandler(stream=sys.stdout))
runner.logger = logger
runner.message_hub = MagicMock()
        runner.composed_writer = MagicMock()
return runner
runner.iter = 0
runner.log_processor.get_log_after_iter = MagicMock(
return_value=(dict(), 'log_str'))
logger_hook.after_val_iter(runner, 1)
runner.log_processor.get_log_after_iter.assert_not_called()
logger_hook.after_val_iter(runner, 9)
runner.log_processor.get_log_after_iter.assert_called()
def test_after_test_iter(self):
logger_hook = LoggerHook()
runner = MagicMock()
runner.iter = 0
runner.log_processor.get_log_after_iter = MagicMock(
return_value=(dict(), 'log_str'))
logger_hook.after_test_iter(runner, 1)
runner.log_processor.get_log_after_iter.assert_not_called()
logger_hook.after_test_iter(runner, 9)
runner.log_processor.get_log_after_iter.assert_called()


@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest.mock import Mock
from unittest.mock import MagicMock, Mock
import torch
from torch import nn
@ -45,7 +45,7 @@ class TestOptimizerHook:
model = Model()
x = torch.rand(1, 1, 3, 3)
dummy_runner = Mock()
dummy_runner = MagicMock()
dummy_runner.optimizer.zero_grad = Mock(return_value=None)
dummy_runner.optimizer.step = Mock(return_value=None)
dummy_runner.model = model


@ -0,0 +1,242 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from unittest.mock import MagicMock, patch
import pytest
import torch
from mmengine.logging import LogProcessor, MessageHub, MMLogger
class TestLogProcessor:
def test_init(self):
log_processor = LogProcessor(
window_size=10, by_epoch=True, custom_cfg=None)
assert log_processor.by_epoch
assert log_processor.window_size == 10
assert log_processor.custom_cfg == []
def test_check_custom_cfg(self):
# ``by_epoch==False`` and `window_size='epoch'` in log config will
# raise AssertionError.
custom_cfg = [dict(data_src='loss', window_size='epoch')]
with pytest.raises(AssertionError):
LogProcessor(by_epoch=False, custom_cfg=custom_cfg)
# Duplicate log_name will raise AssertionError.
custom_cfg = [
dict(data_src='loss', log_name='loss_1'),
dict(data_src='loss', log_name='loss_1')
]
with pytest.raises(AssertionError):
LogProcessor(custom_cfg=custom_cfg)
        # Overwriting the loss item twice will raise an AssertionError.
custom_cfg = [dict(data_src='loss'), dict(data_src='loss')]
with pytest.raises(AssertionError):
LogProcessor(custom_cfg=custom_cfg)
custom_cfg = [
dict(data_src='loss_cls', window_size=100, method_name='min'),
dict(data_src='loss', log_name='loss_min', method_name='max'),
dict(data_src='loss', log_name='loss_max', method_name='max')
]
LogProcessor(custom_cfg=custom_cfg)
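Putting these rules together: at most one entry may overwrite a given `data_src` in place, and every additional entry on the same source needs a unique `log_name`. A sketch of a valid configuration (field semantics inferred from these tests):

```python
from mmengine.logging import LogProcessor

log_processor = LogProcessor(
    window_size=10,
    by_epoch=True,
    custom_cfg=[
        # Overwrite the default smoothed `loss` with a global mean.
        dict(data_src='loss', method_name='mean', window_size='global'),
        # Also log the windowed maximum of the same source under a new name.
        dict(data_src='loss', log_name='loss_max', method_name='max'),
    ])
```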
def test_parse_windows_size(self):
log_processor = LogProcessor()
# Test parse 'epoch' window_size.
log_processor.custom_cfg = [
dict(data_src='loss_cls', window_size='epoch')
]
custom_cfg = log_processor._parse_windows_size(self.runner, 1)
assert custom_cfg[0]['window_size'] == 2
# Test parse 'global' window_size.
log_processor.custom_cfg = [
dict(data_src='loss_cls', window_size='global')
]
custom_cfg = log_processor._parse_windows_size(self.runner, 1)
assert custom_cfg[0]['window_size'] == 11
# Test parse int window_size
log_processor.custom_cfg = [dict(data_src='loss_cls', window_size=100)]
custom_cfg = log_processor._parse_windows_size(self.runner, 1)
assert custom_cfg[0]['window_size'] == 100
        # A window_size of invalid type will raise a TypeError.
log_processor.custom_cfg = [dict(data_src='loss_cls', window_size=[])]
with pytest.raises(TypeError):
            log_processor._parse_windows_size(self.runner, 1)
@pytest.mark.parametrize('by_epoch,mode',
([True, 'train'], [False, 'train'], [True, 'val'],
[False, 'val'], [True, 'test'], [False, 'test']))
def test_get_log_after_iter(self, by_epoch, mode):
# Prepare LoggerHook
log_processor = LogProcessor(by_epoch=by_epoch)
log_processor._get_max_memory = MagicMock(return_value='100')
eta = 40
self.runner.message_hub.update_info('eta', eta)
# Prepare training information.
if mode == 'train':
train_logs = dict(lr=0.1, time=1.0, data_time=1.0, loss_cls=1.0)
else:
train_logs = dict(time=1.0, data_time=1.0, loss_cls=1.0)
log_processor._collect_scalars = MagicMock(return_value=train_logs)
tag, out = log_processor.get_log_after_iter(self.runner, 1, mode)
        # Verify that the correct context has been logged.
cur_loop = log_processor._get_cur_loop(self.runner, mode)
if by_epoch:
if mode in ['train', 'val']:
cur_epoch = log_processor._get_epoch(self.runner, mode)
log_str = (f'Epoch({mode}) [{cur_epoch}][2/'
f'{len(cur_loop.dataloader)}] ')
else:
log_str = (f'Epoch({mode}) [2/{len(cur_loop.dataloader)}] ')
if mode == 'train':
log_str += f"lr: {train_logs['lr']:.3e} "
else:
log_str += ' '
log_str += (f'eta: 0:00:40 '
f"time: {train_logs['time']:.3f} "
f"data_time: {train_logs['data_time']:.3f} ")
if torch.cuda.is_available():
log_str += 'memory: 100 '
if mode == 'train':
log_str += f"loss_cls: {train_logs['loss_cls']:.4f}"
assert out == log_str
else:
if mode == 'train':
max_iters = self.runner.train_loop.max_iters
log_str = f'Iter({mode}) [11/{max_iters}] '
else:
max_iters = len(cur_loop.dataloader)
log_str = f'Iter({mode}) [2/{max_iters}] '
if mode == 'train':
log_str += f"lr: {train_logs['lr']:.3e} "
else:
log_str += ' '
log_str += (f'eta: 0:00:40 '
f"time: {train_logs['time']:.3f} "
f"data_time: {train_logs['data_time']:.3f} ")
if torch.cuda.is_available():
log_str += 'memory: 100 '
if mode == 'train':
log_str += f"loss_cls: {train_logs['loss_cls']:.4f}"
assert out == log_str
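For reference, with the mocked values in this test the assembled training line (epoch-based mode, CUDA available) would read roughly:

```text
Epoch(train) [2][2/20] lr: 1.000e-01 eta: 0:00:40 time: 1.000 data_time: 1.000 memory: 100 loss_cls: 1.0000
```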
@pytest.mark.parametrize(
'by_epoch,mode',
([True, 'val'], [False, 'val'], [True, 'test'], [False, 'test']))
def test_log_val(self, by_epoch, mode):
# Prepare LoggerHook
log_processor = LogProcessor(by_epoch=by_epoch)
# Prepare validation information.
val_logs = dict(accuracy=0.9, data_time=1.0)
log_processor._collect_scalars = MagicMock(return_value=val_logs)
_, out = log_processor.get_log_after_epoch(self.runner, 2, mode)
if by_epoch:
if mode == 'test':
assert out == 'Epoch(test) [5/5] accuracy: 0.9000'
else:
assert out == 'Epoch(val) [1][10/10] accuracy: 0.9000'
else:
if mode == 'test':
assert out == 'Iter(test) [5/5] accuracy: 0.9000'
else:
assert out == 'Iter(val) [10/10] accuracy: 0.9000'
def test_collect_scalars(self):
custom_cfg = [
dict(data_src='time', method_name='mean', window_size=100),
dict(data_src='time', method_name='max', log_name='time_max')
]
logger_hook = LogProcessor(custom_cfg=custom_cfg)
# Collect with prefix.
log_scalars = {
'train/time': MagicMock(),
'lr': MagicMock(),
'train/loss_cls': MagicMock(),
'val/metric': MagicMock()
}
self.runner.message_hub._log_scalars = log_scalars
tag = logger_hook._collect_scalars(
copy.deepcopy(custom_cfg), self.runner, mode='train')
# Test training key in tag.
assert list(tag.keys()) == ['time', 'loss_cls', 'time_max']
        # Test that time uses the custom `max` statistic and loss the
        # default `mean`.
log_scalars['train/time'].statistics.assert_called_with(
method_name='max')
log_scalars['train/loss_cls'].mean.assert_called()
tag = logger_hook._collect_scalars(
copy.deepcopy(custom_cfg), self.runner, mode='val')
assert list(tag.keys()) == ['metric']
log_scalars['val/metric'].current.assert_called()
@patch('torch.cuda.max_memory_allocated', MagicMock())
@patch('torch.cuda.reset_peak_memory_stats', MagicMock())
def test_get_max_memory(self):
logger_hook = LogProcessor()
runner = MagicMock()
runner.world_size = 1
runner.model = torch.nn.Linear(1, 1)
logger_hook._get_max_memory(runner)
torch.cuda.max_memory_allocated.assert_called()
torch.cuda.reset_peak_memory_stats.assert_called()
def test_get_iter(self):
log_processor = LogProcessor()
# Get global iter when `inner_iter=False`
iter = log_processor._get_iter(self.runner)
assert iter == 11
# Get inner iter
iter = log_processor._get_iter(self.runner, 1)
assert iter == 2
# Still get global iter when `logger_hook.by_epoch==False`
log_processor.by_epoch = False
iter = log_processor._get_iter(self.runner, 1)
assert iter == 11
def test_get_epoch(self):
log_processor = LogProcessor()
epoch = log_processor._get_epoch(self.runner, 'train')
assert epoch == 2
epoch = log_processor._get_epoch(self.runner, 'val')
assert epoch == 1
with pytest.raises(ValueError):
log_processor._get_epoch(self.runner, 'test')
def test_get_cur_loop(self):
log_processor = LogProcessor()
loop = log_processor._get_cur_loop(self.runner, 'train')
assert len(loop.dataloader) == 20
loop = log_processor._get_cur_loop(self.runner, 'val')
assert len(loop.dataloader) == 10
loop = log_processor._get_cur_loop(self.runner, 'test')
assert len(loop.dataloader) == 5
def setup(self):
runner = MagicMock()
runner.epoch = 1
runner.iter = 10
runner.train_loop.max_iters = 50
runner.train_loop.dataloader = [0] * 20
runner.val_loop.dataloader = [0] * 10
runner.test_loop.dataloader = [0] * 5
logger = MMLogger.get_instance('log_processor_test')
runner.logger = logger
message_hub = MessageHub.get_instance('log_processor_test')
for i in range(10):
message_hub.update_scalar('train/loss', 10 - i)
for i in range(10):
message_hub.update_scalar('val/acc', i * 0.1)
runner.message_hub = message_hub
self.runner = runner


@ -8,7 +8,7 @@ import torch.optim as optim
from mmengine.optim.scheduler import (ConstantLR, CosineAnnealingLR,
ExponentialLR, LinearLR, MultiStepLR,
StepLR, _ParamScheduler)
PolyLR, StepLR, _ParamScheduler)
from mmengine.testing import assert_allclose
@ -283,6 +283,21 @@ class TestLRScheduler(TestCase):
scheduler = CosineAnnealingLR(self.optimizer, T_max=t, eta_min=eta_min)
self._test_scheduler_value(scheduler, targets, epochs)
def test_poly_scheduler(self):
epochs = 10
power = 0.9
min_lr = 0.001
iters = 4
single_targets = [
min_lr + (0.05 - min_lr) * (1 - i / iters)**power
for i in range(iters)
] + [min_lr] * (
epochs - iters)
targets = [single_targets, [x * epochs for x in single_targets]]
scheduler = PolyLR(
self.optimizer, power=power, eta_min=min_lr, end=iters + 1)
self._test_scheduler_value(scheduler, targets, epochs=10)
def _check_scheduler_state_dict(self, construct, construct2, epochs=10):
scheduler = construct()
for _ in range(epochs):
@ -331,6 +346,12 @@ class TestLRScheduler(TestCase):
lambda: LinearLR(self.optimizer, start_factor=0, end_factor=0.3),
epochs=epochs)
def test_poly_scheduler_state_dict(self):
self._check_scheduler_state_dict(
lambda: PolyLR(self.optimizer, power=0.5, eta_min=0.001),
lambda: PolyLR(self.optimizer, power=0.8, eta_min=0.002),
epochs=10)
def test_multi_scheduler_without_overlap_linear_multi_step(self):
# use Linear in the first 5 epochs and then use MultiStep
epochs = 12


@ -9,8 +9,8 @@ import torch.optim as optim
from mmengine.optim.scheduler import (ConstantMomentum,
CosineAnnealingMomentum,
ExponentialMomentum, LinearMomentum,
MultiStepMomentum, StepMomentum,
_ParamScheduler)
MultiStepMomentum, PolyMomentum,
StepMomentum, _ParamScheduler)
from mmengine.testing import assert_allclose
@ -284,6 +284,21 @@ class TestMomentumScheduler(TestCase):
self.optimizer, T_max=t, eta_min=eta_min)
self._test_scheduler_value(scheduler, targets, epochs)
def test_poly_scheduler(self):
epochs = 10
power = 0.9
min_lr = 0.001
iters = 4
single_targets = [
min_lr + (0.05 - min_lr) * (1 - i / iters)**power
for i in range(iters)
] + [min_lr] * (
epochs - iters)
targets = [single_targets, [x * epochs for x in single_targets]]
scheduler = PolyMomentum(
self.optimizer, power=power, eta_min=min_lr, end=iters + 1)
self._test_scheduler_value(scheduler, targets, epochs=10)
def _check_scheduler_state_dict(self, construct, construct2, epochs=10):
scheduler = construct()
for _ in range(epochs):
@ -333,6 +348,12 @@ class TestMomentumScheduler(TestCase):
self.optimizer, start_factor=0, end_factor=0.3),
epochs=epochs)
def test_poly_scheduler_state_dict(self):
self._check_scheduler_state_dict(
lambda: PolyMomentum(self.optimizer, power=0.5, eta_min=0.001),
lambda: PolyMomentum(self.optimizer, power=0.8, eta_min=0.002),
epochs=10)
def test_multi_scheduler_without_overlap_linear_multi_step(self):
# use Linear in the first 5 epochs and then use MultiStep
epochs = 12


@ -6,12 +6,15 @@ import torch
import torch.nn.functional as F
import torch.optim as optim
# yapf: disable
from mmengine.optim.scheduler import (ConstantParamScheduler,
CosineAnnealingParamScheduler,
ExponentialParamScheduler,
LinearParamScheduler,
MultiStepParamScheduler,
StepParamScheduler, _ParamScheduler)
PolyParamScheduler, StepParamScheduler,
_ParamScheduler)
# yapf: enable
from mmengine.testing import assert_allclose
@ -336,6 +339,25 @@ class TestParameterScheduler(TestCase):
self.optimizer, param_name='lr', T_max=t, eta_min=eta_min)
self._test_scheduler_value(scheduler, targets, epochs)
def test_poly_scheduler(self):
epochs = 10
power = 0.9
min_lr = 0.001
iters = 4
single_targets = [
min_lr + (0.05 - min_lr) * (1 - i / iters)**power
for i in range(iters)
] + [min_lr] * (
epochs - iters)
targets = [single_targets, [x * epochs for x in single_targets]]
scheduler = PolyParamScheduler(
self.optimizer,
param_name='lr',
power=power,
eta_min=min_lr,
end=iters + 1)
self._test_scheduler_value(scheduler, targets, epochs=10)
def _check_scheduler_state_dict(self, construct, construct2, epochs=10):
scheduler = construct()
for _ in range(epochs):
@ -402,6 +424,14 @@ class TestParameterScheduler(TestCase):
end_factor=0.3),
epochs=epochs)
def test_poly_scheduler_state_dict(self):
self._check_scheduler_state_dict(
lambda: PolyParamScheduler(
self.optimizer, param_name='lr', power=0.5, eta_min=0.001),
lambda: PolyParamScheduler(
self.optimizer, param_name='lr', power=0.8, eta_min=0.002),
epochs=10)
def test_multi_scheduler_without_overlap_linear_multi_step(self):
# use Linear in the first 5 epochs and then use MultiStep
epochs = 12


@ -222,7 +222,7 @@ class TestRunner(TestCase):
self.iter_based_cfg.default_hooks = dict(
timer=dict(type='IterTimerHook'),
checkpoint=dict(type='CheckpointHook', interval=1, by_epoch=False),
logger=dict(type='LoggerHook', by_epoch=False),
logger=dict(type='LoggerHook'),
optimizer=dict(type='OptimizerHook', grad_clip=None),
param_scheduler=dict(type='ParamSchedulerHook'))