mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Fix] resolve conflict betweem adapt and main. (#198)
* [Docs] Refine registry documentation (#186) * [Docs] Refine registry documentation * reslove comments * minor refinement * Refine Visualizer docs (#177) * Refine Visualizer docs * update * update * update featmap * update docs * update visualizer docs * [Refactor] Refine LoggerHook (#155) * rename global accessible and intergration get_sintance and create_instance * move ManagerMixin to utils * fix as docstring and seporate get_instance to get_instance and get_current_instance * fix lint * fix docstring, rename and move test_global_meta * rename LogBuffer to HistoryBuffer, rename MessageHub methods, MessageHub support resume * refine MMLogger timestamp, update unit test * MMLogger add logger_name arguments * Fix docstring * Add LogProcessor and some unit test * update unit test * complete LogProcessor unit test * refine LoggerHook * solve circle import * change default logger_name to mmengine * refactor eta * Fix docstring comment and unitt test * Fix with runner * fix docstring fix docstring * fix docstring * Add by_epoch attribute to LoggerHook and fix docstring * Please mypy and fix comment * remove \ in MMLogger * Fix lint * roll back pre-commit-hook * Fix hook unit test * Fix comments * remove \t in log and add docstring * Fix as comment * should not accept other arguments if corresponding instance has been created * fix logging ddp file saving * fix logging ddp file saving * move log processor to logging * move log processor to logging * remove current datalaoder * fix docstring * fix unit test * add learing rate in messagehub * Support output training/validation/testing message after iterations/epochs * fix docstring * Fix IterBasedRunner log string * Fix IterBasedRunner log string * Support parse validation loss in log processor * [Enhancement] Add PolyParamScheduler, PolyMomentum and PolyLR (#188) * [Enhancement] Add PolyParamScheduler, PolyMomentum and PolyLR * min_lr -> eta_min, refined docstr Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> Co-authored-by: Haian Huang(深度眸) <1286304229@qq.com> Co-authored-by: Tong Gao <gaotongxiao@gmail.com>
This commit is contained in:
parent
fb7d8ccd6b
commit
e0d00c5bdd
@ -262,7 +262,7 @@ class RetinaNet(nn.Module):
|
|||||||
|
|
||||||

|

|
||||||
|
|
||||||
我们可以在 `MMDetection` 中调用 `MMEngine` 中模块。
|
我们可以在 `MMDetection` 中调用 `MMEngine` 中的模块。
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from mmdet.models import MODELS
|
from mmdet.models import MODELS
|
||||||
@ -278,6 +278,29 @@ model = MODELS.build(cfg=dict(type='Conv2d'))
|
|||||||
|
|
||||||
如果不加前缀,`build` 方法首先查找当前节点是否存在该模块,如果存在则返回该模块,否则会继续向上查找父节点甚至祖先节点直到找到该模块,因此,如果当前节点和父节点存在同一模块并且希望调用父节点的模块,我们需要指定 `scope` 前缀。需要注意的是,向上查找父节点甚至祖先节点的**前提是父节点或者祖先节点的模块已通过某种方式被导入进而完成注册**。例如,在上面这个示例中,之所以没有显示导入父节点 `mmengine` 中的 `MODELS`,是因为通过 `from mmdet.models import MODELS` 间接触发 `mmengine.MODELS` 完成模块的注册。
|
如果不加前缀,`build` 方法首先查找当前节点是否存在该模块,如果存在则返回该模块,否则会继续向上查找父节点甚至祖先节点直到找到该模块,因此,如果当前节点和父节点存在同一模块并且希望调用父节点的模块,我们需要指定 `scope` 前缀。需要注意的是,向上查找父节点甚至祖先节点的**前提是父节点或者祖先节点的模块已通过某种方式被导入进而完成注册**。例如,在上面这个示例中,之所以没有显示导入父节点 `mmengine` 中的 `MODELS`,是因为通过 `from mmdet.models import MODELS` 间接触发 `mmengine.MODELS` 完成模块的注册。
|
||||||
|
|
||||||
|
上面展示了如何使用子节点注册器构建模块,但有时候我们希望不填加前缀也能在父节点注册器中构建子节点的模块,目的是提供通用的代码,避免下游算法库重复造轮子,该如何实现呢?
|
||||||
|
|
||||||
|
假设 MMEngine 中有一个 `build_model` 函数,该方法用于构建模型。
|
||||||
|
|
||||||
|
```python
|
||||||
|
from mmengine.registry import MODELS
|
||||||
|
|
||||||
|
def build_model(cfg):
|
||||||
|
model = MODELS.build(cfg)
|
||||||
|
```
|
||||||
|
|
||||||
|
如果我们希望在 MMDetection 中调用该函数构建 MMDetection 注册的模块,那么我们需要先获取一个 scope_name 为 'mmdet' 的 [DefaultScope](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.registry.DefaultScope) 实例,该实例全局唯一。
|
||||||
|
|
||||||
|
```python
|
||||||
|
from mmengine import build_model
|
||||||
|
import mmdet.models # 通过 import 的方式将 mmdet 中的模块导入注册器进而完成注册
|
||||||
|
|
||||||
|
default_scope = DefaultScope.get_instance('my_experiment', scope_name='mmdet')
|
||||||
|
model = build_model(cfg=dict(type='RetinaNet'))
|
||||||
|
```
|
||||||
|
|
||||||
|
获取 `DefaultScope` 实例的目的是使 Registry 的 build 方法会将 DefaultScope 名称(mmdet)注册器节点作为注册器的起点,才能在配置中不填加 mmdet 前缀的情况下在 MMDetection 的注册器节点中找到 RetinaNet 模块,如若不然,程序会报找不到 RetinaNet 错误。
|
||||||
|
|
||||||
### 调用兄弟节点的模块
|
### 调用兄弟节点的模块
|
||||||
|
|
||||||
除了可以调用父节点的模块,也可以调用兄弟节点的模块。
|
除了可以调用父节点的模块,也可以调用兄弟节点的模块。
|
||||||
@ -311,16 +334,7 @@ from mmcls.models import MODELS
|
|||||||
model = MODELS.build(cfg=dict(type='mmdet.RetinaNet'))
|
model = MODELS.build(cfg=dict(type='mmdet.RetinaNet'))
|
||||||
```
|
```
|
||||||
|
|
||||||
调用非本节点的模块需要指定在 `type` 中指定 `scope` 前缀,如果不想指定,我们可以创建一个全局变量 `default_scope` 并将 `scope_name` 设置为 'mmdet',`Registry` 会将 `scope_name` 对应的 `registry` 作为当前 `Registry` 并调用 `build` 方法。
|
调用非本节点或父节点的模块需要在 `type` 中指定 `scope` 前缀。
|
||||||
|
|
||||||
```python
|
|
||||||
from mmengine.registry import DefaultScope, MODELS
|
|
||||||
|
|
||||||
# 调用注册在 mmdet 中的 RetinaNet
|
|
||||||
default_scope = DefaultScope.get_instance(
|
|
||||||
'my_experiment', scope_name='mmdet')
|
|
||||||
model = MODELS.build(cfg=dict(type='RetinaNet'))
|
|
||||||
```
|
|
||||||
|
|
||||||
注册器除了支持两层结构,三层甚至更多层结构也是支持的。
|
注册器除了支持两层结构,三层甚至更多层结构也是支持的。
|
||||||
|
|
||||||
@ -358,10 +372,4 @@ model = MODELS.build(cfg=dict(type='mmcls.ResNet'))
|
|||||||
from mmcls.models import MODELS
|
from mmcls.models import MODELS
|
||||||
# 需要注意前缀的顺序,'detplus.mmdet.ResNet' 是不正确的
|
# 需要注意前缀的顺序,'detplus.mmdet.ResNet' 是不正确的
|
||||||
model = MODELS.build(cfg=dict(type='mmdet.detplus.MetaNet'))
|
model = MODELS.build(cfg=dict(type='mmdet.detplus.MetaNet'))
|
||||||
|
|
||||||
# 如果希望默认从 detplus 构建模型,设置可以 default_scope
|
|
||||||
from mmengine.registry import DefaultScope
|
|
||||||
default_scope = DefaultScope.get_instance(
|
|
||||||
'my_experiment', scope_name='detplus')
|
|
||||||
model = MODELS.build(cfg=dict(type='MetaNet', default_scope='detplus'))
|
|
||||||
```
|
```
|
||||||
|
300
docs/zh_cn/tutorials/visualization.md
Normal file
300
docs/zh_cn/tutorials/visualization.md
Normal file
@ -0,0 +1,300 @@
|
|||||||
|
# 可视化 (Visualization)
|
||||||
|
|
||||||
|
## 概述
|
||||||
|
|
||||||
|
可视化可以给深度学习的模型训练和测试过程提供直观解释。在 OpenMMLab 算法库中,我们期望可视化功能的设计能满足以下需求:
|
||||||
|
|
||||||
|
- 提供丰富的开箱即用可视化功能,能够满足大部分计算机视觉可视化任务
|
||||||
|
- 高扩展性,可视化功能通常多样化,应该能够通过简单扩展实现定制需求
|
||||||
|
- 能够在训练和测试流程的任意点位进行可视化
|
||||||
|
- OpenMMLab 各个算法库具有统一可视化接口,利于用户理解和维护
|
||||||
|
|
||||||
|
基于上述需求,OpenMMLab 2.0 引入了可视化对象 Visualizer 和各个可视化存储后端 VisBackend 如 `LocalVisBackend`、`WandbVisBackend` 和 `TensorboardVisBackend` 等。此处的可视化不仅仅包括图片数据格式,还包括配置内容、标量和模型图等数据的可视化。
|
||||||
|
|
||||||
|
- 为了方便调用,Visualizer 提供的接口实现了绘制和存储的功能。可视化存储后端 VisBackend 作为 Visualizer 的内部属性,会在需要的时候被 Visualizer 调用,将数据存到不同的后端
|
||||||
|
- 考虑到绘制后会希望存储到多个后端,Visualizer 可以配置多个 VisBackend,当用户调用 Visualizer 的存储接口时候,Visualizer 内部会遍历的调用 VisBackend 存储接口
|
||||||
|
|
||||||
|
两者的 UML 关系图如下
|
||||||
|
|
||||||
|
<div align="center">
|
||||||
|
<img src="https://user-images.githubusercontent.com/17425982/163327736-f7cb3b16-ef07-46bc-982a-3cc7495e6c82.png" >
|
||||||
|
</div>
|
||||||
|
|
||||||
|
## 可视化对象 Visualizer
|
||||||
|
|
||||||
|
### 接口说明
|
||||||
|
|
||||||
|
可视化对象 Visualizer 对外提供了所有接口。可以将其接口分成 3 大类,如下所示
|
||||||
|
|
||||||
|
**(1) 绘制相关接口**
|
||||||
|
|
||||||
|
- [draw_bboxes](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.draw_bboxes) 绘制单个或多个边界框
|
||||||
|
- [draw_points](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.draw_points) 绘制单个或多个点
|
||||||
|
- [draw_texts](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.draw_texts) 绘制单个或多个文本框
|
||||||
|
- [draw_lines](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.lines) 绘制单个或多个线段
|
||||||
|
- [draw_circles](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.draw_circles) 绘制单个或多个圆
|
||||||
|
- [draw_polygons](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.draw_polygons) 绘制单个或多个多边形
|
||||||
|
- [draw_binary_masks](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.draw_binary_mask) 绘制单个或多个二值掩码
|
||||||
|
- [draw_featmap](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.draw_featmap) 绘制特征图,静态方法
|
||||||
|
|
||||||
|
上述接口除了 `draw_featmap` 外都可以链式调用,因为该方法调用后可能会导致图片尺寸发生改变。为了避免给用户带来困扰, `draw_featmap` 被设置为静态方法。
|
||||||
|
|
||||||
|
当用户想先绘制边界框,在此基础上绘制文本,绘制线段的时候,可以通过链式调用实现:
|
||||||
|
|
||||||
|
```python
|
||||||
|
visualizer.set_image(image)
|
||||||
|
visualizer.draw_bboxes(...).draw_texts(...).draw_lines(...)
|
||||||
|
visualizer.show() # 可视化绘制结果
|
||||||
|
```
|
||||||
|
|
||||||
|
特征图可视化是一个常见的功能,用户通过调用 `draw_featmap` 可视化特征图,其参数定义为:
|
||||||
|
|
||||||
|
```python
|
||||||
|
@staticmethod
|
||||||
|
def draw_featmap(featmap: torch.Tensor, # 输入格式要求为 CHW
|
||||||
|
overlaid_image: Optional[np.ndarray] = None, # 如果同时输入了 image 数据,则特征图会叠加到 image 上绘制
|
||||||
|
channel_reduction: Optional[str] = 'squeeze_mean', # 多个通道压缩为单通道的策略
|
||||||
|
topk: int = 10, # 可选择激活度最高的 topk 个特征图显示
|
||||||
|
arrangement: Tuple[int, int] = (5, 2), # 多通道展开为多张图时候布局
|
||||||
|
resize_shape:Optional[tuple] = None, # 可以指定 resize_shape 参数来缩放特征图
|
||||||
|
alpha: float = 0.5) -> np.ndarray: # 图片和特征图绘制的叠加比例
|
||||||
|
```
|
||||||
|
|
||||||
|
特征图可视化功能较多,目前不支持 Batch 输入,其功能可以归纳如下
|
||||||
|
|
||||||
|
- 输入的 Tensor 一般是包括多个通道的,channel_reduction 参数可以将多个通道压缩为单通道,然后和图片进行叠加显示
|
||||||
|
- `squeeze_mean` 将输入的 C 维度采用 mean 函数压缩为一个通道,输出维度变成 (1, H, W)
|
||||||
|
- `select_max` 从输入的 C 维度中先在空间维度 sum,维度变成 (C, ),然后选择值最大的通道
|
||||||
|
- `None` 表示不需要压缩,此时可以通过 topk 参数可选择激活度最高的 topk 个特征图显示
|
||||||
|
|
||||||
|
- 在 channel_reduction 参数为 None 的情况下,topk 参数生效,其会按照激活度排序选择 topk 个通道,然后和图片进行叠加显示,并且此时会通过 arrangement 参数指定显示的布局
|
||||||
|
- 如果 topk 不是 -1,则会按照激活度排序选择 topk 个通道显示
|
||||||
|
- 如果 topk = -1,此时通道 C 必须是 1 或者 3 表示输入数据是图片,否则报错提示用户应该设置 `channel_reduction`来压缩通道。
|
||||||
|
|
||||||
|
- 考虑到输入的特征图通常非常小,函数支持输入 `resize_shape` 参数,方便将特征图进行上采样后进行可视化。
|
||||||
|
|
||||||
|
**(2) 存储相关接口**
|
||||||
|
|
||||||
|
- [add_config](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.BaseWriter.add_config) 写配置到特定存储后端
|
||||||
|
- [add_graph](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.BaseWriter.add_graph) 写模型图到特定存储后端
|
||||||
|
- [add_image](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.BaseWriter.add_image) 写图片到特定存储后端
|
||||||
|
- [add_scalar](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.BaseWriter.add_scalar) 写标量到特定存储后端
|
||||||
|
- [add_scalars](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.BaseWriter.add_scalars) 一次性写多个标量到特定存储后端
|
||||||
|
- [add_datasample](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.writer.BaseWriter.add_datasample) 各个下游库绘制 datasample 数据的抽象接口
|
||||||
|
|
||||||
|
以 add 前缀开头的接口表示存储接口。datasample 是 OpenMMLab 2.0 架构中设计的各个下游库统一的抽象数据接口,而 `add_datasample` 接口可以直接处理该数据格式,例如可视化预测结果、可视化 Dataset 或者 DataLoader 输出、可视化中间预测结果等等都可以直接调用下游库重写的 `add_datasample` 接口。
|
||||||
|
|
||||||
|
所有下游库都必须要继承 Visualizer 并实现 `add_datasample` 接口。以 MMDetection 为例,应该继承并通过该接口实现目标检测中所有预置任务的可视化功能,例如目标检测、实例分割、全景分割任务结果的绘制和存储。
|
||||||
|
|
||||||
|
**(3) 其余功能性接口**
|
||||||
|
|
||||||
|
- [set_image](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.set_image) 设置原始图片数据,默认输入图片格式为 RGB
|
||||||
|
- [get_image](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.get_image) 获取绘制后的 Numpy 格式图片数据,默认输出格式为 RGB
|
||||||
|
- [show](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.show) 可视化
|
||||||
|
- [get_backend](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.get_backend) 通过 name 获取特定存储后端
|
||||||
|
- [close](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.Visualizer.close) 关闭所有已经打开的资源,包括 VisBackend
|
||||||
|
|
||||||
|
### 使用样例
|
||||||
|
|
||||||
|
**(1) 在任意位置获取 visualizer**
|
||||||
|
|
||||||
|
为了确保可视化对象 Visualizer 能够在任何地方被调用,设计上将其继承自 `ManagerMixin` 类,转变为全局唯一对象,用户初始化 `Visualizer` 时必须要调用 `visualizer.get_instance()` 方法才能使实例对象具备全局唯一性。一旦实例化完成,后续可以在任意代码位置通过 `Visualizer.get_current_instance()` 来获取可视化对象。
|
||||||
|
|
||||||
|
以 MMDetection 为例,假设 `DetLocalVisualizer` 类继承自 `Visualizer`,并实现了 `add_datasample` 接口。配置文件写法为:
|
||||||
|
|
||||||
|
```python
|
||||||
|
vis_backends = [dict(type='LocalVisBackend')]
|
||||||
|
visualizer = dict(
|
||||||
|
type='DetLocalVisualizer', vis_backends=vis_backends, name='visualizer')
|
||||||
|
```
|
||||||
|
```python
|
||||||
|
# 内部会调用 get_instance() 进行全局唯一实例化
|
||||||
|
VISUALIZERS.build(cfg.visualizer)
|
||||||
|
```
|
||||||
|
|
||||||
|
通过上述代码实例化后,可以在任意位置调用 `get_current_instance` 方法来获取 visualizer
|
||||||
|
|
||||||
|
```python
|
||||||
|
# 任意代码位置获取 visualizer
|
||||||
|
visualizer = Visualizer.get_current_instance()
|
||||||
|
```
|
||||||
|
|
||||||
|
如果用户直接使用了 MMEngine 或者下游库中的 Runner,则无需进行额外的实例化,因为在 Runner 的初始化函数中会自动创建全局唯一的 visualizer。
|
||||||
|
|
||||||
|
**(2) 将数据写入至特定后端**
|
||||||
|
|
||||||
|
在获取到 visualizer 后,可以调用 `add_xxx` 接口将各类数据写入到特定后端
|
||||||
|
|
||||||
|
```python
|
||||||
|
# 绘制 datasample,并保存到本地存储后端
|
||||||
|
visualizer.add_datasample('demo_image', image, gt_sample, pred_sample, step=1)
|
||||||
|
# 直接本地窗口显示,而无需存储
|
||||||
|
visualizer.add_datasample('demo_image', image, gt_sample, pred_sample, show=True)
|
||||||
|
|
||||||
|
# 写图片
|
||||||
|
visualizer.add_image('demo_image', image, step=1)
|
||||||
|
|
||||||
|
# 写模型精度值
|
||||||
|
visualizer.add_scalar('mAP', 0.9, step=1)
|
||||||
|
visualizer.add_scalars({'loss': 1.2, 'acc': 0.8}, step=1)
|
||||||
|
|
||||||
|
# 写配置文件
|
||||||
|
visualizer.add_config(cfg)
|
||||||
|
|
||||||
|
# 写模型图
|
||||||
|
visualizer.add_graph(model, data_batch)
|
||||||
|
```
|
||||||
|
|
||||||
|
**(3) 特征图可视化**
|
||||||
|
|
||||||
|
通过 `channel_reduction` 参数压缩或者选择特征图,并显示到本地窗口
|
||||||
|
|
||||||
|
```python
|
||||||
|
featmap = ... # CHW shape 的 tensor
|
||||||
|
|
||||||
|
# 压缩
|
||||||
|
feat_img = visualizer.draw_featmap(featmap, channel_reduction='squeeze_mean')
|
||||||
|
visualizer.show(feat_img)
|
||||||
|
|
||||||
|
# 选择激活度最高的通道显示
|
||||||
|
feat_img = visualizer.draw_featmap(featmap, channel_reduction='select_max')
|
||||||
|
visualizer.show(feat_img)
|
||||||
|
```
|
||||||
|
|
||||||
|
叠加图片显示
|
||||||
|
|
||||||
|
```python
|
||||||
|
featmap = ... # CHW shape 的 tensor
|
||||||
|
img = ... # 如果 featmap 和 img 空间尺寸不一致,内部会对 featmap 进行插值
|
||||||
|
|
||||||
|
# 压缩
|
||||||
|
feat_img = visualizer.draw_featmap(featmap, img, channel_reduction='squeeze_mean')
|
||||||
|
visualizer.show(feat_img)
|
||||||
|
|
||||||
|
# 选择激活度最高的通道显示
|
||||||
|
feat_img = visualizer.draw_featmap(featmap, img, channel_reduction='select_max')
|
||||||
|
visualizer.show(feat_img)
|
||||||
|
```
|
||||||
|
|
||||||
|
通过 `topk` 参数选择指定个数的通道显示,并显示到本地窗口
|
||||||
|
|
||||||
|
```python
|
||||||
|
featmap= ... # CHW shape 的 tensor
|
||||||
|
|
||||||
|
# topk,并以 2 行 5 列模式显示
|
||||||
|
feat_img = visualizer.draw_featmap(featmap, channel_reduction=None, topk=10, arrangement=(2, 5))
|
||||||
|
visualizer.show(feat_img)
|
||||||
|
|
||||||
|
# topk,并以 5 行 2 列模式显示
|
||||||
|
feat_img = visualizer.draw_featmap(featmap, channel_reduction=None, topk=10, arrangement=(5, 2))
|
||||||
|
visualizer.show(feat_img)
|
||||||
|
```
|
||||||
|
|
||||||
|
通过 `resize_shape` 缩放显示的特征图
|
||||||
|
|
||||||
|
```python
|
||||||
|
featmap = ... # CHW shape 的 tensor
|
||||||
|
|
||||||
|
# 压缩
|
||||||
|
feat_img = visualizer.draw_featmap(featmap, channel_reduction='squeeze_mean', resize_shape=(224, 224))
|
||||||
|
visualizer.show(feat_img)
|
||||||
|
```
|
||||||
|
|
||||||
|
存储特征图到可视化后端
|
||||||
|
|
||||||
|
```python
|
||||||
|
featmap = ... # CHW shape 的 tensor
|
||||||
|
|
||||||
|
# 压缩
|
||||||
|
feat_img = visualizer.draw_featmap(featmap, channel_reduction='squeeze_mean', resize_shape=(224, 224))
|
||||||
|
# 存储
|
||||||
|
visualizer.add_image('feat_image', feat_img)
|
||||||
|
```
|
||||||
|
|
||||||
|
**(4) 远程窗口显示**
|
||||||
|
|
||||||
|
用户可以指定 Wandb 、Tensorboard 或者自定义具备远程窗口显示的后端来保存数据,然后在浏览器上显示。以 Wandb 为例,典型配置为:
|
||||||
|
|
||||||
|
```python
|
||||||
|
vis_backends = [dict(type='WandbVisBackend')]
|
||||||
|
visualizer = dict(
|
||||||
|
type='DetWandbVisualizer', vis_backends=vis_backends, name='visualizer')
|
||||||
|
```
|
||||||
|
|
||||||
|
使用方法和上面完全一致。需要特别注意的是由于 Wandb 绘制的数据无法和 `LocalVisBackend` 后端兼容,所以当 `vis_backends` 存在多个可视化存储后端时候只有 `WandbVisBackend` 才是有效的。
|
||||||
|
|
||||||
|
## 可视化存储后端 VisBackend
|
||||||
|
|
||||||
|
在绘制后可以将绘制后的数据存储到多个可视化存储后端中。为了统一接口调用,MMEngine 提供了统一的抽象类 `BaseVisBackend`,和一些常用的 VisBackend 如 `LocalVisBackend`、`WandbVisBackend` 和 `TensorboardVisBackend`。
|
||||||
|
|
||||||
|
### 接口说明
|
||||||
|
|
||||||
|
BaseVisBackend 定义了对外调用的接口规范,主要接口和属性如下:
|
||||||
|
|
||||||
|
- [add_config](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.vis_backend.BaseVisBackend.add_config) 写配置到特定存储后端
|
||||||
|
- [add_graph](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.vis_backend.BaseVisBackend.add_graph) 写模型图到特定后端
|
||||||
|
- [add_image](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.vis_backend.BaseVisBackend.add_image) 写图片到特定后端
|
||||||
|
- [add_scalar](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.vis_backend.BaseVisBackend.add_scalar) 写标量到特定后端
|
||||||
|
- [add_scalars](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.vis_backend.BaseVisBackend.add_scalars) 一次性写多个标量到特定后端
|
||||||
|
- [close](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.vis_backend.BaseVisBackend.close) 关闭已经打开的资源
|
||||||
|
- [experiment](https://mmengine.readthedocs.io/zh/latest/api.html#mmengine.visualization.vis_backend.BaseVisBackend.experiment) 写后端对象,例如 Wandb 对象和 Tensorboard 对象
|
||||||
|
|
||||||
|
`BaseVisBackend` 定义了 5 个常见的写数据接口,考虑到某些写后端功能非常强大,例如 Wandb,其具备写表格,写视频等等功能,针对这类需求用户可以直接获取 experiment 对象,然后调用写后端对象本身的 API 即可。而 `LocalVisBackend`、`WandbVisBackend` 和 `TensorboardVisBackend` 等都是继承自 `BaseVisBackend`,并根据自身特性实现了对应的存储功能。
|
||||||
|
|
||||||
|
### 使用案例
|
||||||
|
|
||||||
|
一般情况下用户无需操作 VisBackend 对象,只有在当前可视化存储无法满足需求时候,用户会希望直接操作存储后端。以 Wandb 为例,其提供了非常丰富的存储格式,例如存储表格、存储权重等等接口。为了所有后端能够统一接口,我们并没有提供这类常用接口,此时用户可以直接获取 Wandb 对象进行自定义存储。
|
||||||
|
|
||||||
|
```python
|
||||||
|
vis_backends = [dict(type='WandbVisBackend')]
|
||||||
|
visualizer = dict(
|
||||||
|
type='DetWandbVisualizer', vis_backends=vis_backends, name='visualizer')
|
||||||
|
```
|
||||||
|
|
||||||
|
```python
|
||||||
|
# 内部会调用 get_instance() 进行全局唯一实例化
|
||||||
|
VISUALIZERS.build(cfg.visualizer)
|
||||||
|
# 任意代码位置获取 visualizer
|
||||||
|
visualizer = Visualizer.get_current_instance()
|
||||||
|
|
||||||
|
# 扩展 add 功能,例如利用 Wandb 对象绘制表格
|
||||||
|
wandb = visualizer.get_backend('WandbVisBackend').experiment
|
||||||
|
val_table = wandb.Table(data=my_data, columns=column_names)
|
||||||
|
wandb.log({'my_val_table': val_table})
|
||||||
|
```
|
||||||
|
|
||||||
|
一个 visualizer 对象可以接入任意多个 VisBackend。为了方便用户获取任意的 VisBackend,在不指定 name 参数情况下,可以通过类名获取
|
||||||
|
|
||||||
|
```python
|
||||||
|
vis_backends = [dict(type='LocalVisBackend'), dict(type='WandbVisBackend')]
|
||||||
|
visualizer = dict(
|
||||||
|
type='DetLocalVisualizer', vis_backends=vis_backends, name='visualizer')
|
||||||
|
```
|
||||||
|
|
||||||
|
```python
|
||||||
|
# 内部会调用 get_instance() 进行全局唯一实例化
|
||||||
|
VISUALIZERS.build(cfg.visualizer)
|
||||||
|
# 任意代码位置获取 visualizer
|
||||||
|
visualizer = Visualizer.get_current_instance()
|
||||||
|
|
||||||
|
local_vis_backend = visualizer.get_backend('LocalVisBackend')
|
||||||
|
wandb_vis_backend = visualizer.get_backend('WandbVisBackend')
|
||||||
|
```
|
||||||
|
|
||||||
|
当存在多个同名的 VisBackend 时候,用户必须指定唯一的 name 参数,后续可以通过 name 字符串来获取
|
||||||
|
|
||||||
|
```python
|
||||||
|
vis_backends = [dict(type='LocalVisBackend', name='local_vis_backend_1'), dict(type='LocalVisBackend', name='local_vis_backend_2')]
|
||||||
|
visualizer = dict(
|
||||||
|
type='DetLocalVisualizer', vis_backends=vis_backends, name='visualizer')
|
||||||
|
```
|
||||||
|
|
||||||
|
```python
|
||||||
|
# 内部会调用 get_instance() 进行全局唯一实例化
|
||||||
|
VISUALIZERS.build(cfg.visualizer)
|
||||||
|
# 任意代码位置获取 visualizer
|
||||||
|
visualizer = Visualizer.get_current_instance()
|
||||||
|
|
||||||
|
local_vis_backend_1 = visualizer.get_backend('local_vis_backend_1')
|
||||||
|
local_vis_backend_2 = visualizer.get_backend('local_vis_backend_2')
|
||||||
|
```
|
@ -358,11 +358,11 @@ class Hook:
|
|||||||
"""
|
"""
|
||||||
return (runner.epoch + 1) % n == 0 if n > 0 else False
|
return (runner.epoch + 1) % n == 0 if n > 0 else False
|
||||||
|
|
||||||
def every_n_inner_iters(self, inner_iter: int, n: int) -> bool:
|
def every_n_inner_iters(self, batch_idx: int, n: int) -> bool:
|
||||||
"""Test whether current inner iteration can be evenly divided by n.
|
"""Test whether current inner iteration can be evenly divided by n.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
inner_iter (int): Current inner_iter of the training, validation
|
batch_idx (int): Current batch index of the training, validation
|
||||||
or testing loop.
|
or testing loop.
|
||||||
n (int): Whether current inner iteration can be evenly
|
n (int): Whether current inner iteration can be evenly
|
||||||
divided by n.
|
divided by n.
|
||||||
@ -371,7 +371,7 @@ class Hook:
|
|||||||
bool: Whether current inner iteration can be evenly
|
bool: Whether current inner iteration can be evenly
|
||||||
divided by n.
|
divided by n.
|
||||||
"""
|
"""
|
||||||
return (inner_iter + 1) % n == 0 if n > 0 else False
|
return (batch_idx + 1) % n == 0 if n > 0 else False
|
||||||
|
|
||||||
def every_n_iters(self, runner, n: int) -> bool:
|
def every_n_iters(self, runner, n: int) -> bool:
|
||||||
"""Test whether current iteration can be evenly divided by n.
|
"""Test whether current iteration can be evenly divided by n.
|
||||||
@ -395,7 +395,6 @@ class Hook:
|
|||||||
dataloader (Dataloader): The dataloader of the training,
|
dataloader (Dataloader): The dataloader of the training,
|
||||||
validation or testing process.
|
validation or testing process.
|
||||||
batch_idx (int): The index of the current batch in the loop.
|
batch_idx (int): The index of the current batch in the loop.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool: Whether reaches the end of current epoch or not.
|
bool: Whether reaches the end of current epoch or not.
|
||||||
"""
|
"""
|
||||||
@ -418,10 +417,10 @@ class Hook:
|
|||||||
Args:
|
Args:
|
||||||
runner (Runner): The runner of the training, validation or testing
|
runner (Runner): The runner of the training, validation or testing
|
||||||
process.
|
process.
|
||||||
|
mode (str): Current mode of runner. Defaults to 'train'.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool: Whether current iteration is the last iteration.
|
bool: Whether current iteration is the last iteration.
|
||||||
mode (str): Current mode of runner. Defaults to 'train'.
|
|
||||||
"""
|
"""
|
||||||
if mode == 'train':
|
if mode == 'train':
|
||||||
return runner.iter + 1 == runner.train_loop.max_iters
|
return runner.iter + 1 == runner.train_loop.max_iters
|
||||||
|
@ -18,11 +18,25 @@ class IterTimerHook(Hook):
|
|||||||
|
|
||||||
priority = 'NORMAL'
|
priority = 'NORMAL'
|
||||||
|
|
||||||
def _before_epoch(self, runner, mode: str = 'train') -> None:
|
def __init__(self):
|
||||||
"""Record time flag before start a epoch.
|
self.time_sec_tot = 0
|
||||||
|
self.start_iter = 0
|
||||||
|
|
||||||
|
def before_run(self, runner) -> None:
|
||||||
|
"""Synchronize the number of iterations with the runner.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
runner (Runner): The runner of the training process.
|
runner: The runner of the training, validation or testing
|
||||||
|
process.
|
||||||
|
"""
|
||||||
|
self.start_iter = runner.iter
|
||||||
|
|
||||||
|
def _before_epoch(self, runner, mode: str = 'train') -> None:
|
||||||
|
"""Record timestamp before start an epoch.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
runner (Runner): The runner of the training validation and
|
||||||
|
testing process.
|
||||||
mode (str): Current mode of runner. Defaults to 'train'.
|
mode (str): Current mode of runner. Defaults to 'train'.
|
||||||
"""
|
"""
|
||||||
self.t = time.time()
|
self.t = time.time()
|
||||||
@ -32,16 +46,18 @@ class IterTimerHook(Hook):
|
|||||||
batch_idx: int,
|
batch_idx: int,
|
||||||
data_batch: DATA_BATCH = None,
|
data_batch: DATA_BATCH = None,
|
||||||
mode: str = 'train') -> None:
|
mode: str = 'train') -> None:
|
||||||
"""Logging time for loading data and update the time flag.
|
"""Calculating time for loading data and updating "data_time"
|
||||||
|
``HistoryBuffer`` of ``runner.message_hub``.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
runner (Runner): The runner of the training process.
|
runner (Runner): The runner of the training, validation and
|
||||||
|
testing process.
|
||||||
batch_idx (int): The index of the current batch in the loop.
|
batch_idx (int): The index of the current batch in the loop.
|
||||||
data_batch (Sequence[dict], optional): Data from dataloader.
|
data_batch (Sequence[dict], optional): Data from dataloader.
|
||||||
Defaults to None.
|
Defaults to None.
|
||||||
mode (str): Current mode of runner. Defaults to 'train'.
|
mode (str): Current mode of runner. Defaults to 'train'.
|
||||||
"""
|
"""
|
||||||
# TODO: update for new logging system
|
# Update data loading time in `runner.message_hub`.
|
||||||
runner.message_hub.update_scalar(f'{mode}/data_time',
|
runner.message_hub.update_scalar(f'{mode}/data_time',
|
||||||
time.time() - self.t)
|
time.time() - self.t)
|
||||||
|
|
||||||
@ -52,10 +68,12 @@ class IterTimerHook(Hook):
|
|||||||
outputs: Optional[Union[dict,
|
outputs: Optional[Union[dict,
|
||||||
Sequence[BaseDataElement]]] = None,
|
Sequence[BaseDataElement]]] = None,
|
||||||
mode: str = 'train') -> None:
|
mode: str = 'train') -> None:
|
||||||
"""Logging time for a iteration and update the time flag.
|
"""Calculating time for an iteration and updating "time"
|
||||||
|
``HistoryBuffer`` of ``runner.message_hub``.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
runner (Runner): The runner of the training process.
|
runner (Runner): The runner of the training validation and
|
||||||
|
testing process.
|
||||||
batch_idx (int): The index of the current batch in the loop.
|
batch_idx (int): The index of the current batch in the loop.
|
||||||
data_batch (Sequence[dict], optional): Data from dataloader.
|
data_batch (Sequence[dict], optional): Data from dataloader.
|
||||||
Defaults to None.
|
Defaults to None.
|
||||||
@ -63,7 +81,31 @@ class IterTimerHook(Hook):
|
|||||||
to None.
|
to None.
|
||||||
mode (str): Current mode of runner. Defaults to 'train'.
|
mode (str): Current mode of runner. Defaults to 'train'.
|
||||||
"""
|
"""
|
||||||
# TODO: update for new logging system
|
# Update iteration time in `runner.message_hub`.
|
||||||
|
message_hub = runner.message_hub
|
||||||
runner.message_hub.update_scalar(f'{mode}/time', time.time() - self.t)
|
message_hub.update_scalar(f'{mode}/time', time.time() - self.t)
|
||||||
self.t = time.time()
|
self.t = time.time()
|
||||||
|
window_size = runner.log_processor.window_size
|
||||||
|
# Calculate eta every `window_size` iterations. Since test and val
|
||||||
|
# loop will not update runner.iter, use `every_n_innter_iters`to check
|
||||||
|
# the interval.
|
||||||
|
if self.every_n_inner_iters(batch_idx, window_size):
|
||||||
|
iter_time = message_hub.get_scalar(f'{mode}/time').mean(
|
||||||
|
window_size)
|
||||||
|
if mode == 'train':
|
||||||
|
self.time_sec_tot += iter_time * window_size
|
||||||
|
# Calculate average iterative time.
|
||||||
|
time_sec_avg = self.time_sec_tot / (
|
||||||
|
runner.iter - self.start_iter + 1)
|
||||||
|
# Calculate eta.
|
||||||
|
eta_sec = time_sec_avg * (
|
||||||
|
runner.train_loop.max_iters - runner.iter - 1)
|
||||||
|
runner.message_hub.update_info('eta', eta_sec)
|
||||||
|
else:
|
||||||
|
if mode == 'val':
|
||||||
|
cur_dataloader = runner.val_loop.dataloader
|
||||||
|
else:
|
||||||
|
cur_dataloader = runner.test_loop.dataloader
|
||||||
|
|
||||||
|
eta_sec = iter_time * (len(cur_dataloader) - batch_idx - 1)
|
||||||
|
runner.message_hub.update_info('eta', eta_sec)
|
||||||
|
@ -1,14 +1,10 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
import copy
|
|
||||||
import datetime
|
|
||||||
import os
|
import os
|
||||||
import os.path as osp
|
import os.path as osp
|
||||||
from collections import OrderedDict
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, Sequence, Union
|
from typing import Optional, Sequence, Union
|
||||||
|
|
||||||
import torch
|
from mmengine.data import BaseDataElement
|
||||||
|
|
||||||
from mmengine.fileio import FileClient
|
from mmengine.fileio import FileClient
|
||||||
from mmengine.hooks import Hook
|
from mmengine.hooks import Hook
|
||||||
from mmengine.registry import HOOKS
|
from mmengine.registry import HOOKS
|
||||||
@ -19,33 +15,20 @@ DATA_BATCH = Optional[Sequence[dict]]
|
|||||||
|
|
||||||
@HOOKS.register_module()
|
@HOOKS.register_module()
|
||||||
class LoggerHook(Hook):
|
class LoggerHook(Hook):
|
||||||
"""In this logger hook, the information will be printed on the terminal and
|
"""Collect logs from different components of ``Runner`` and write them to
|
||||||
saved in JSON file, tensorboard, wandb .etc.
|
terminal, JSON file, tensorboard and wandb .etc.
|
||||||
|
|
||||||
|
``LoggerHook`` is used to record logs formatted by ``LogProcessor`` during
|
||||||
|
training/validation/testing phase. It is used to control following
|
||||||
|
behaviers:
|
||||||
|
|
||||||
|
- The frequency of logs update in terminal, local, tensorboad wandb.etc.
|
||||||
|
- The frequency of show experiment information in terminal.
|
||||||
|
- The work directory to save logs.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
by_epoch (bool): Whether ``EpochBasedLoop`` is used.
|
|
||||||
Defaults to True.
|
|
||||||
interval (int): Logging interval (every k iterations).
|
interval (int): Logging interval (every k iterations).
|
||||||
Defaults to 10.
|
Defaults to 10.
|
||||||
custom_keys (dict, optional): Defines the keys in the log and which
|
|
||||||
kinds of statistic methods should be used to log them.
|
|
||||||
|
|
||||||
- ``custom_keys`` contains multiple string-dict pairs. In each
|
|
||||||
string-dict pair, the string defines a key name in the log and the
|
|
||||||
dict is a config defines the statistic methods and corresponding
|
|
||||||
arguments used to log the value. For example,
|
|
||||||
``dict(loss=dict(method_name='mean', log_name='global_loss',
|
|
||||||
window_size='global'))`` which means the log key ``loss`` will be
|
|
||||||
counted as global mean and additionally logged as ``global_loss``.
|
|
||||||
If ``log_name`` is not defined in config dict, the original logged
|
|
||||||
key will be overwritten.
|
|
||||||
- The key in ``LoggerHook.fixed_smooth_keys`` cannot be overwritten
|
|
||||||
because ``time`` and ``iter_time`` will be used to calculate
|
|
||||||
estimated time of arrival. If you want to recount the time, you
|
|
||||||
should set ``log_name`` in corresponding values.
|
|
||||||
- For those statistic methods with the ``window_size`` argument,
|
|
||||||
if ``by_epoch`` is set to False, ``windows_size`` should not be
|
|
||||||
`epoch` to statistics log value by epoch.
|
|
||||||
ignore_last (bool): Ignore the log of last iterations in each epoch if
|
ignore_last (bool): Ignore the log of last iterations in each epoch if
|
||||||
the number of remaining iterations is less than :attr:`interval`.
|
the number of remaining iterations is less than :attr:`interval`.
|
||||||
Defaults to True.
|
Defaults to True.
|
||||||
@ -70,64 +53,24 @@ class LoggerHook(Hook):
|
|||||||
Defaults to None.
|
Defaults to None.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> # `log_name` is defined, `loss_mean_window` will be an additional
|
>>> # A simplest LoggerHook config.
|
||||||
>>> # record.
|
>>> logger_hook_cfg = dict(interval=20)
|
||||||
>>> logger_hook_cfg = dict(by_epoch=True,
|
|
||||||
>>> custom_keys=dict(
|
|
||||||
>>> loss=dict(
|
|
||||||
>>> log_name='loss_mean_window',
|
|
||||||
>>> method_name='mean',
|
|
||||||
>>> window_size=10)))
|
|
||||||
>>> # `log_name` is not defined. `loss` will be overwritten by
|
|
||||||
>>> # `global_mean` statistics.
|
|
||||||
>>> logger_hook_cfg = dict(by_epoch=True,
|
|
||||||
>>> custom_keys=dict(
|
|
||||||
>>> loss=dict(
|
|
||||||
>>> method_name='mean',
|
|
||||||
>>> window_size='global')))
|
|
||||||
>>> # `time` cannot be overwritten, `global_time` will be an additional
|
|
||||||
>>> # record.
|
|
||||||
>>> logger_hook_cfg = dict(by_epoch=True,
|
|
||||||
>>> custom_keys=dict(
|
|
||||||
>>> time=dict(
|
|
||||||
>>> log_name='global_time',
|
|
||||||
>>> method='mean',
|
|
||||||
>>> window_size='global')))
|
|
||||||
>>> # Record loss with different statistics methods.
|
|
||||||
>>> logger_hook_cfg = dict(by_epoch=True,
|
|
||||||
>>> custom_keys=dict(loss=[
|
|
||||||
>>> dict(log_name='loss_mean_window',
|
|
||||||
>>> method_name='mean',
|
|
||||||
>>> window_size=10),
|
|
||||||
>>> dict(method_name='mean',
|
|
||||||
>>> window_size='global')]))
|
|
||||||
"""
|
"""
|
||||||
# eta will be calculated by time. `time` and `data_time` should not be
|
|
||||||
# overwritten.
|
|
||||||
fixed_smooth_keys = ('time', 'data_time')
|
|
||||||
priority = 'BELOW_NORMAL'
|
priority = 'BELOW_NORMAL'
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
by_epoch: bool = True,
|
|
||||||
interval: int = 10,
|
interval: int = 10,
|
||||||
custom_keys: Optional[dict] = None,
|
|
||||||
ignore_last: bool = True,
|
ignore_last: bool = True,
|
||||||
interval_exp_name: int = 1000,
|
interval_exp_name: int = 1000,
|
||||||
out_dir: Optional[Union[str, Path]] = None,
|
out_dir: Optional[Union[str, Path]] = None,
|
||||||
out_suffix: Union[Sequence[str], str] = ('.log.json', '.log', '.py'),
|
out_suffix: Union[Sequence[str], str] = ('.log.json', '.log', '.py'),
|
||||||
keep_local=True,
|
keep_local: bool = True,
|
||||||
file_client_args=None,
|
file_client_args: Optional[dict] = None,
|
||||||
):
|
):
|
||||||
self._inner_iter = 0
|
|
||||||
self.by_epoch = by_epoch
|
|
||||||
self.interval = interval
|
self.interval = interval
|
||||||
self.custom_keys = custom_keys if custom_keys is not None else dict()
|
|
||||||
self.ignore_last = ignore_last
|
self.ignore_last = ignore_last
|
||||||
|
|
||||||
self.time_sec_tot = 0
|
|
||||||
self.interval_exp_name = interval_exp_name
|
self.interval_exp_name = interval_exp_name
|
||||||
self._check_custom_keys()
|
|
||||||
|
|
||||||
if out_dir is None and file_client_args is not None:
|
if out_dir is None and file_client_args is not None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -165,14 +108,15 @@ class LoggerHook(Hook):
|
|||||||
|
|
||||||
self.json_log_path = osp.join(runner.work_dir,
|
self.json_log_path = osp.join(runner.work_dir,
|
||||||
f'{runner.timestamp}.log.json')
|
f'{runner.timestamp}.log.json')
|
||||||
self.start_iter = runner.iter
|
self.yaml_log_path = osp.join(runner.work_dir,
|
||||||
|
f'{runner.timestamp}.log.json')
|
||||||
|
|
||||||
def after_train_iter(self,
|
def after_train_iter(self,
|
||||||
runner,
|
runner,
|
||||||
batch_idx: int,
|
batch_idx: int,
|
||||||
data_batch: DATA_BATCH = None,
|
data_batch: DATA_BATCH = None,
|
||||||
outputs: Optional[dict] = None) -> None:
|
outputs: Optional[dict] = None) -> None:
|
||||||
"""Record training logs.
|
"""Record training logs after training iteration.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
runner (Runner): The runner of the training process.
|
runner (Runner): The runner of the training process.
|
||||||
@ -182,33 +126,90 @@ class LoggerHook(Hook):
|
|||||||
outputs (dict, optional): Outputs from model.
|
outputs (dict, optional): Outputs from model.
|
||||||
Defaults to None.
|
Defaults to None.
|
||||||
"""
|
"""
|
||||||
self._inner_iter = batch_idx
|
# Print experiment name every n iterations.
|
||||||
if runner.meta is not None and 'exp_name' in runner.meta:
|
if self.every_n_iters(runner,
|
||||||
if (self.every_n_iters(runner, self.interval_exp_name)) or (
|
self.interval_exp_name) or (self.end_of_epoch(
|
||||||
self.by_epoch and self.end_of_epoch(
|
runner.train_dataloader, batch_idx)):
|
||||||
runner.train_loop.dataloader, batch_idx)):
|
exp_info = f'Exp name: {runner.experiment_name}'
|
||||||
exp_info = f'Exp name: {runner.meta["exp_name"]}'
|
runner.logger.info(exp_info)
|
||||||
runner.logger.info(exp_info)
|
if self.every_n_inner_iters(batch_idx, self.interval):
|
||||||
if self.by_epoch and self.every_n_inner_iters(batch_idx,
|
tag, log_str = runner.log_processor.get_log_after_iter(
|
||||||
self.interval):
|
runner, batch_idx, 'train')
|
||||||
self._log_train(runner)
|
elif (self.end_of_epoch(runner.train_dataloader, batch_idx)
|
||||||
elif not self.by_epoch and self.every_n_iters(runner, self.interval):
|
and not self.ignore_last):
|
||||||
self._log_train(runner)
|
|
||||||
elif self.end_of_epoch(runner.train_loop.dataloader,
|
|
||||||
batch_idx) and not self.ignore_last:
|
|
||||||
# `runner.max_iters` may not be divisible by `self.interval`. if
|
# `runner.max_iters` may not be divisible by `self.interval`. if
|
||||||
# `self.ignore_last==True`, the log of remaining iterations will
|
# `self.ignore_last==True`, the log of remaining iterations will
|
||||||
# be recorded (Epoch [4][1000/1007], the logs of 998-1007
|
# be recorded (Epoch [4][1000/1007], the logs of 998-1007
|
||||||
# iterations will be recorded).
|
# iterations will be recorded).
|
||||||
self._log_train(runner)
|
tag, log_str = runner.log_processor.get_log_after_iter(
|
||||||
|
runner, batch_idx, 'train')
|
||||||
|
else:
|
||||||
|
return
|
||||||
|
runner.logger.info(log_str)
|
||||||
|
# TODO compatible with visualizer.
|
||||||
|
runner.visualizer.add_scalars(tag, step=runner.iter + 1)
|
||||||
|
|
||||||
|
def after_val_iter(
|
||||||
|
self,
|
||||||
|
runner,
|
||||||
|
batch_idx: int,
|
||||||
|
data_batch: DATA_BATCH = None,
|
||||||
|
outputs: Optional[Sequence[BaseDataElement]] = None) -> None:
|
||||||
|
"""Record validation logs after validation iteration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
runner (Runner): The runner of the training process.
|
||||||
|
batch_idx (int): The index of the current batch in the train loop.
|
||||||
|
data_batch (Sequence[Tuple[Any, BaseDataElement]], optional):
|
||||||
|
Data from dataloader. Defaults to None.
|
||||||
|
outputs (sequence, optional): Outputs from model. Defaults to None.
|
||||||
|
"""
|
||||||
|
if self.every_n_inner_iters(batch_idx, self.interval):
|
||||||
|
tag, log_str = runner.log_processor.get_log_after_iter(
|
||||||
|
runner, batch_idx, 'val')
|
||||||
|
runner.logger.info(log_str)
|
||||||
|
|
||||||
|
def after_test_iter(
|
||||||
|
self,
|
||||||
|
runner,
|
||||||
|
batch_idx: int,
|
||||||
|
data_batch: DATA_BATCH = None,
|
||||||
|
outputs: Optional[Sequence[BaseDataElement]] = None) -> None:
|
||||||
|
"""Record testing logs after iteration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
runner (Runner): The runner of the training process.
|
||||||
|
batch_idx (int): The index of the current batch in the train loop.
|
||||||
|
data_batch (Sequence[Tuple[Any, BaseDataElement]], optional):
|
||||||
|
Data from dataloader. Defaults to None.
|
||||||
|
outputs (sequence, optional): Outputs from model. Defaults to None.
|
||||||
|
"""
|
||||||
|
if self.every_n_inner_iters(batch_idx, self.interval):
|
||||||
|
tag, log_str = runner.log_processor.get_log_after_iter(
|
||||||
|
runner, batch_idx, 'test')
|
||||||
|
runner.logger.info(log_str)
|
||||||
|
|
||||||
def after_val_epoch(self, runner) -> None:
|
def after_val_epoch(self, runner) -> None:
|
||||||
"""Record validation logs.
|
"""Record validation logs after validation epoch.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
runner (Runner): The runner of the training process.
|
runner (Runner): The runner of the training process.
|
||||||
"""
|
"""
|
||||||
self._log_val(runner)
|
tag, log_str = runner.log_processor.get_log_after_epoch(
|
||||||
|
runner, len(runner.val_dataloader), 'val')
|
||||||
|
runner.logger.info(log_str)
|
||||||
|
# TODO compatible with visualizer.
|
||||||
|
runner.visualizer.add_scalars(tag, step=runner.iter + 1)
|
||||||
|
|
||||||
|
def after_test_epoch(self, runner) -> None:
|
||||||
|
"""Record testing logs after test epoch.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
runner (Runner): The runner of the training process.
|
||||||
|
"""
|
||||||
|
tag, log_str = runner.log_processor.get_log_after_epoch(
|
||||||
|
runner, len(runner.val_dataloader), 'test')
|
||||||
|
runner.logger.info(log_str)
|
||||||
|
|
||||||
def after_run(self, runner) -> None:
|
def after_run(self, runner) -> None:
|
||||||
"""Copy logs to ``self.out_dir`` if ``self.out_dir is not None``
|
"""Copy logs to ``self.out_dir`` if ``self.out_dir is not None``
|
||||||
@ -233,278 +234,3 @@ class LoggerHook(Hook):
|
|||||||
os.remove(local_filepath)
|
os.remove(local_filepath)
|
||||||
runner.logger.info((f'{local_filepath} was removed due to the '
|
runner.logger.info((f'{local_filepath} was removed due to the '
|
||||||
'`self.keep_local=False`'))
|
'`self.keep_local=False`'))
|
||||||
|
|
||||||
def _log_train(self, runner) -> None:
|
|
||||||
"""Collect and record training logs which start named with "train/*".
|
|
||||||
|
|
||||||
Args:
|
|
||||||
runner (Runner): The runner of the training process.
|
|
||||||
"""
|
|
||||||
tag = self._collect_info(runner, 'train')
|
|
||||||
# The training log default defines `lr`, `momentum`, `time` and
|
|
||||||
# `data_time`. `log_tag` will pop these keys and loop other keys to
|
|
||||||
# `log_str`.
|
|
||||||
log_tag = copy.deepcopy(tag)
|
|
||||||
cur_iter = self._get_iter(runner, inner_iter=True)
|
|
||||||
cur_epoch = self._get_epoch(runner, 'train')
|
|
||||||
|
|
||||||
# Record learning rate and momentum.
|
|
||||||
lr_str_list = []
|
|
||||||
momentum_str_list = []
|
|
||||||
for key, value in tag.items():
|
|
||||||
if key.startswith('lr'):
|
|
||||||
log_tag.pop(key)
|
|
||||||
lr_str_list.append(f'{key}: {value:.3e}')
|
|
||||||
lr_str = ' '.join(lr_str_list)
|
|
||||||
for key, value in tag.items():
|
|
||||||
if key.startswith('momentum'):
|
|
||||||
log_tag.pop(key)
|
|
||||||
momentum_str_list.append(f'{key}: {value:.3e}')
|
|
||||||
momentum_str = ' '.join(momentum_str_list)
|
|
||||||
lr_momentum_str = f'{lr_str} {momentum_str}'
|
|
||||||
# by epoch: Epoch [4][100/1000]
|
|
||||||
# by iter: Iter [100/100000]
|
|
||||||
if self.by_epoch:
|
|
||||||
log_str = f'Epoch [{cur_epoch}]' \
|
|
||||||
f'[{cur_iter}/{len(runner.train_loop.dataloader)}] '
|
|
||||||
else:
|
|
||||||
log_str = f'Iter [{cur_iter}/{runner.train_loop.max_iters}] '
|
|
||||||
log_str += f'{lr_momentum_str}, '
|
|
||||||
# Calculate eta time.
|
|
||||||
self.time_sec_tot += (tag['time'] * self.interval)
|
|
||||||
time_sec_avg = self.time_sec_tot / (runner.iter - self.start_iter + 1)
|
|
||||||
eta_sec = time_sec_avg * (
|
|
||||||
runner.train_loop.max_iters - runner.iter - 1)
|
|
||||||
eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
|
|
||||||
log_str += f'eta: {eta_str}, '
|
|
||||||
log_str += f'time: {tag["time"]:.3f}, ' \
|
|
||||||
f'data_time: {tag["data_time"]:.3f}, '
|
|
||||||
# Pop recorded keys
|
|
||||||
log_tag.pop('time')
|
|
||||||
log_tag.pop('data_time')
|
|
||||||
# statistic memory
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
log_str += f'memory: {self._get_max_memory(runner)}, '
|
|
||||||
# Loop left keys to fill `log_str`.
|
|
||||||
log_items = []
|
|
||||||
for name, val in log_tag.items():
|
|
||||||
if isinstance(val, float):
|
|
||||||
val = f'{val:.4f}'
|
|
||||||
log_items.append(f'{name}: {val}')
|
|
||||||
log_str += ', '.join(log_items)
|
|
||||||
runner.logger.info(log_str)
|
|
||||||
# Write logs to local, tensorboad, and wandb.
|
|
||||||
runner.visualizer.add_scalars(
|
|
||||||
tag, step=runner.iter + 1, file_path=self.json_log_path)
|
|
||||||
|
|
||||||
def _log_val(self, runner) -> None:
|
|
||||||
"""Collect and record training logs which start named with "val/*".
|
|
||||||
|
|
||||||
Args:
|
|
||||||
runner (Runner): The runner of the training process.
|
|
||||||
"""
|
|
||||||
tag = self._collect_info(runner, 'val')
|
|
||||||
# Compatible with function `log` https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/hooks/logger/text.py # noqa E501
|
|
||||||
eval_iter = len(runner.val_loop.dataloader)
|
|
||||||
cur_iter = self._get_iter(runner)
|
|
||||||
cur_epoch = self._get_epoch(runner, 'val')
|
|
||||||
# val/test time
|
|
||||||
# here 1000 is the length of the val dataloader
|
|
||||||
# by epoch: Epoch[val] [4][1000]
|
|
||||||
# by iter: Iter[val] [1000]
|
|
||||||
if self.by_epoch:
|
|
||||||
# runner.epoch += 1 has been done before val workflow
|
|
||||||
log_str = f'Epoch(val) [{cur_epoch}][{eval_iter}] '
|
|
||||||
else:
|
|
||||||
log_str = f'Iter(val) [{eval_iter}] '
|
|
||||||
|
|
||||||
log_items = []
|
|
||||||
for name, val in tag.items():
|
|
||||||
if isinstance(val, float):
|
|
||||||
val = f'{val:.4f}'
|
|
||||||
log_items.append(f'{name}: {val}')
|
|
||||||
log_str += ', '.join(log_items)
|
|
||||||
runner.logger.info(log_str)
|
|
||||||
# Write tag.
|
|
||||||
runner.visualizer.add_scalars(
|
|
||||||
tag, step=cur_iter, file_path=self.json_log_path)
|
|
||||||
|
|
||||||
def _get_window_size(self, runner, window_size: Union[int, str]) \
|
|
||||||
-> int:
|
|
||||||
"""Parse window_size specified in ``self.custom_keys`` to int value.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
runner (Runner): The runner of the training process.
|
|
||||||
window_size (int or str): Smoothing scale of logs.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
int: Smoothing window for statistical methods.
|
|
||||||
"""
|
|
||||||
if isinstance(window_size, int):
|
|
||||||
assert window_size == self.interval, \
|
|
||||||
'The value of windows size must equal to LoggerHook.interval'
|
|
||||||
return window_size
|
|
||||||
elif window_size == 'epoch':
|
|
||||||
return self._inner_iter + 1
|
|
||||||
elif window_size == 'global':
|
|
||||||
return runner.iter + 1
|
|
||||||
else:
|
|
||||||
raise ValueError('window_size should be int, epoch or global, but '
|
|
||||||
f'got invalid {window_size}')
|
|
||||||
|
|
||||||
def _collect_info(self, runner, mode: str) -> dict:
|
|
||||||
"""Collect log information to a dict according to mode.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
runner (Runner): The runner of the training process.
|
|
||||||
mode (str): 'train' or 'val', which means the prefix attached by
|
|
||||||
runner.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: Statistical values of logs.
|
|
||||||
"""
|
|
||||||
tag = OrderedDict()
|
|
||||||
log_buffers = runner.message_hub.log_scalars
|
|
||||||
mode_log_buffers = OrderedDict()
|
|
||||||
# Filter log_buffers which starts with `mode`.
|
|
||||||
for prefix_key, log_buffer in log_buffers.items():
|
|
||||||
if prefix_key.startswith(mode):
|
|
||||||
key = prefix_key.split('/')[-1]
|
|
||||||
mode_log_buffers[key] = log_buffer
|
|
||||||
# Ensure all metric and lr values are latest.
|
|
||||||
for key in mode_log_buffers:
|
|
||||||
# Update the latest learning rate and smoothed time logs.
|
|
||||||
if key in self.fixed_smooth_keys or key.startswith('loss'):
|
|
||||||
tag[key] = mode_log_buffers[key].mean(self.interval)
|
|
||||||
else:
|
|
||||||
tag[key] = mode_log_buffers[key].current()
|
|
||||||
# Update custom keys.
|
|
||||||
if mode == 'train':
|
|
||||||
for log_key, log_cfg in self.custom_keys.items():
|
|
||||||
self._parse_custom_keys(runner, log_key,
|
|
||||||
copy.deepcopy(log_cfg),
|
|
||||||
mode_log_buffers, tag)
|
|
||||||
return tag
|
|
||||||
|
|
||||||
def _parse_custom_keys(self, runner, log_key: str, log_cfg: dict,
|
|
||||||
log_buffers: OrderedDict, tag: OrderedDict) -> None:
|
|
||||||
"""Statistics logs in log_buffers according to custom_keys.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
runner (Runner): The runner of the training process.
|
|
||||||
log_key (str): log key specified in ``self.custom_keys``
|
|
||||||
log_cfg (dict): A config dict for describing the logging
|
|
||||||
statistics method.
|
|
||||||
log_buffers (OrderedDict): All logs for the corresponding phase.
|
|
||||||
tag (OrderedDict): A dict which defines all statistic values of
|
|
||||||
logs.
|
|
||||||
"""
|
|
||||||
if isinstance(log_cfg, list):
|
|
||||||
log_names = set()
|
|
||||||
for cfg in log_cfg:
|
|
||||||
log_name = cfg.get('log_name', None)
|
|
||||||
if log_name in log_names:
|
|
||||||
raise KeyError(f'{cfg["log_name"]} cannot be redefined in '
|
|
||||||
'log_key')
|
|
||||||
if log_name is not None:
|
|
||||||
log_names.add(log_name)
|
|
||||||
self._parse_custom_keys(runner, log_key, cfg, log_buffers, tag)
|
|
||||||
assert len(log_names) == len(log_cfg) - 1, \
|
|
||||||
f'{log_key} cannot be overwritten multiple times, please ' \
|
|
||||||
f'check only one key does not contain `log_name` in {log_cfg}.'
|
|
||||||
elif isinstance(log_cfg, dict):
|
|
||||||
if 'window_size' in log_cfg:
|
|
||||||
log_cfg['window_size'] = \
|
|
||||||
self._get_window_size(runner, log_cfg['window_size'])
|
|
||||||
if 'log_name' in log_cfg:
|
|
||||||
name = log_cfg.pop('log_name')
|
|
||||||
else:
|
|
||||||
name = log_key
|
|
||||||
tag[name] = log_buffers[log_key].statistics(**log_cfg).item()
|
|
||||||
else:
|
|
||||||
raise ValueError('The structure of `LoggerHook.custom key` is '
|
|
||||||
'wrong, please make sure the type of each key is '
|
|
||||||
'dict or list.')
|
|
||||||
|
|
||||||
def _get_max_memory(self, runner) -> int:
|
|
||||||
"""Returns the maximum GPU memory occupied by tensors in megabytes (MB)
|
|
||||||
for a given device.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
runner (Runner): The runner of the training process.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The maximum GPU memory occupied by tensors in megabytes for a given
|
|
||||||
device.
|
|
||||||
"""
|
|
||||||
device = getattr(runner.model, 'output_device', None)
|
|
||||||
mem = torch.cuda.max_memory_allocated(device=device)
|
|
||||||
mem_mb = torch.tensor([int(mem) // (1024 * 1024)],
|
|
||||||
dtype=torch.int,
|
|
||||||
device=device)
|
|
||||||
torch.cuda.reset_peak_memory_stats()
|
|
||||||
return int(mem_mb.item())
|
|
||||||
|
|
||||||
def _check_custom_keys(self) -> None:
|
|
||||||
"""Check the legality of ``self.custom_keys``.
|
|
||||||
|
|
||||||
If ``self.by_epoch==False``, ``window_size`` should not be "epoch". The
|
|
||||||
key of ``self.fixed_smooth_keys`` cannot be overwritten.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def _check_window_size(item):
|
|
||||||
if not self.by_epoch:
|
|
||||||
assert item['window_size'] != 'epoch', \
|
|
||||||
'window_size cannot be epoch if LoggerHook.by_epoch is ' \
|
|
||||||
'False.'
|
|
||||||
|
|
||||||
def _check_fixed_keys(key, item):
|
|
||||||
if key in self.fixed_smooth_keys:
|
|
||||||
assert 'log_name' in item, f'{key} cannot be overwritten by ' \
|
|
||||||
'custom keys!'
|
|
||||||
|
|
||||||
for key, value in self.custom_keys.items():
|
|
||||||
if isinstance(value, Sequence):
|
|
||||||
[(_check_window_size(item), _check_fixed_keys(key, item))
|
|
||||||
for item in value]
|
|
||||||
|
|
||||||
else:
|
|
||||||
_check_window_size(value)
|
|
||||||
_check_fixed_keys(key, value)
|
|
||||||
|
|
||||||
def _get_epoch(self, runner, mode: str) -> int:
|
|
||||||
"""Get epoch according to mode.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
runner (Runner): The runner of the training process.
|
|
||||||
mode (str): Train or val.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
int: The current epoch.
|
|
||||||
"""
|
|
||||||
if mode == 'train':
|
|
||||||
epoch = runner.epoch + 1
|
|
||||||
elif mode == 'val':
|
|
||||||
# normal val mode
|
|
||||||
# runner.epoch += 1 has been done before val workflow
|
|
||||||
epoch = runner.epoch
|
|
||||||
else:
|
|
||||||
raise ValueError(f"runner mode should be 'train' or 'val', "
|
|
||||||
f'but got {runner.mode}')
|
|
||||||
return epoch
|
|
||||||
|
|
||||||
def _get_iter(self, runner, inner_iter=False) -> int:
|
|
||||||
"""Get the current training iteration step.
|
|
||||||
Args:
|
|
||||||
runner (Runner): The runner of the training process.
|
|
||||||
inner_iter (bool): Whether to return the inner iter of an epoch.
|
|
||||||
Defaults to False.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
int: The current global iter or inner iter.
|
|
||||||
"""
|
|
||||||
if self.by_epoch and inner_iter:
|
|
||||||
current_iter = self._inner_iter + 1
|
|
||||||
else:
|
|
||||||
current_iter = runner.iter + 1
|
|
||||||
return current_iter
|
|
||||||
|
@ -84,6 +84,9 @@ class OptimizerHook(Hook):
|
|||||||
we keep ``outputs`` here. Defaults to None.
|
we keep ``outputs`` here. Defaults to None.
|
||||||
"""
|
"""
|
||||||
runner.optimizer.zero_grad()
|
runner.optimizer.zero_grad()
|
||||||
|
runner.message_hub.update_scalar(
|
||||||
|
'train/lr', runner.optimizer.param_groups[0]['lr'])
|
||||||
|
|
||||||
if self.detect_anomalous_params:
|
if self.detect_anomalous_params:
|
||||||
self.detect_anomalous_parameters(runner.outputs['loss'], runner)
|
self.detect_anomalous_parameters(runner.outputs['loss'], runner)
|
||||||
runner.outputs['loss'].backward()
|
runner.outputs['loss'].backward()
|
||||||
|
@ -1,6 +1,9 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
from .history_buffer import HistoryBuffer
|
from .history_buffer import HistoryBuffer
|
||||||
|
from .log_processor import LogProcessor
|
||||||
from .logger import MMLogger, print_log
|
from .logger import MMLogger, print_log
|
||||||
from .message_hub import MessageHub
|
from .message_hub import MessageHub
|
||||||
|
|
||||||
__all__ = ['HistoryBuffer', 'MessageHub', 'MMLogger', 'print_log']
|
__all__ = [
|
||||||
|
'HistoryBuffer', 'MessageHub', 'MMLogger', 'print_log', 'LogProcessor'
|
||||||
|
]
|
||||||
|
409
mmengine/logging/log_processor.py
Normal file
409
mmengine/logging/log_processor.py
Normal file
@ -0,0 +1,409 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import copy
|
||||||
|
import datetime
|
||||||
|
from collections import OrderedDict
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class LogProcessor:
|
||||||
|
"""A log processor used to format log information collected from
|
||||||
|
``runner.message_hub.log_scalars``.
|
||||||
|
|
||||||
|
``LogProcessor`` instance is built by runner and will format
|
||||||
|
``runner.message_hub.log_scalars`` to ``tag`` and ``log_str``, which can
|
||||||
|
directly used by ``LoggerHook`` and ``MMLogger``. Besides, the argument
|
||||||
|
``custom_cfg`` of constructor can control the statistics method of logs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
window_size (int): default smooth interval Defaults to 10.
|
||||||
|
by_epoch (bool): Whether to format logs with epoch stype. Defaults to
|
||||||
|
True.
|
||||||
|
custom_cfg (list[dict], optional): Contains multiple log config dict,
|
||||||
|
in which key means the data source name of log and value means the
|
||||||
|
statistic method and corresponding arguments used to count the
|
||||||
|
data source. Defaults to None
|
||||||
|
- If custom_cfg is None, all logs will be formatted via default
|
||||||
|
methods, such as smoothing loss by default window_size. If
|
||||||
|
custom_cfg is defined as a list of config dict, for example:
|
||||||
|
[dict(data_src=loss, method='mean', log_name='global_loss',
|
||||||
|
window_size='global')]. It means the log item ``loss`` will be
|
||||||
|
counted as global mean and additionally logged as ``global_loss``
|
||||||
|
(defined by ``log_name``). If ``log_name`` is not defined in
|
||||||
|
config dict, the original logged key will be overwritten.
|
||||||
|
|
||||||
|
- The original log item cannot be overwritten twice. Here is
|
||||||
|
an error example:
|
||||||
|
[dict(data_src=loss, method='mean', window_size='global'),
|
||||||
|
dict(data_src=loss, method='mean', window_size='epoch')].
|
||||||
|
Both log config dict in custom_cfg do not have ``log_name`` key,
|
||||||
|
which means the loss item will be overwritten twice.
|
||||||
|
|
||||||
|
- For those statistic methods with the ``window_size`` argument,
|
||||||
|
if ``by_epoch`` is set to False, ``windows_size`` should not be
|
||||||
|
`epoch` to statistics log value by epoch.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> # `log_name` is defined, `loss_large_window` will be an additional
|
||||||
|
>>> # record.
|
||||||
|
>>> log_processor = dict(
|
||||||
|
>>> window_size=10,
|
||||||
|
>>> by_epoch=True,
|
||||||
|
>>> custom_cfg=[dict(data_src='loss',
|
||||||
|
>>> log_name='loss_large_window',
|
||||||
|
>>> method_name='mean',
|
||||||
|
>>> window_size=100)])
|
||||||
|
>>> # `log_name` is not defined. `loss` will be overwritten.
|
||||||
|
>>> log_processor = dict(
|
||||||
|
>>> window_size=10,
|
||||||
|
>>> by_epoch=True,
|
||||||
|
>>> custom_cfg=[dict(data_src='loss',
|
||||||
|
>>> method_name='mean',
|
||||||
|
>>> window_size=100)])
|
||||||
|
>>> # Record loss with different statistics methods.
|
||||||
|
>>> log_processor = dict(
|
||||||
|
>>> window_size=10,
|
||||||
|
>>> by_epoch=True,
|
||||||
|
>>> custom_cfg=[dict(data_src='loss',
|
||||||
|
>>> log_name='loss_large_window',
|
||||||
|
>>> method_name='mean',
|
||||||
|
>>> window_size=100),
|
||||||
|
>>> dict(data_src='loss',
|
||||||
|
>>> method_name='mean',
|
||||||
|
>>> window_size=100)])
|
||||||
|
>>> # Overwrite loss item twice will raise an error.
|
||||||
|
>>> log_processor = dict(
|
||||||
|
>>> window_size=10,
|
||||||
|
>>> by_epoch=True,
|
||||||
|
>>> custom_cfg=[dict(data_src='loss',
|
||||||
|
>>> method_name='mean',
|
||||||
|
>>> window_size=100),
|
||||||
|
>>> dict(data_src='loss',
|
||||||
|
>>> method_name='max',
|
||||||
|
>>> window_size=100)])
|
||||||
|
AssertionError
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
window_size=10,
|
||||||
|
by_epoch=True,
|
||||||
|
custom_cfg: Optional[List[dict]] = None):
|
||||||
|
self.window_size = window_size
|
||||||
|
self.by_epoch = by_epoch
|
||||||
|
self.custom_cfg = custom_cfg if custom_cfg else []
|
||||||
|
self._check_custom_cfg()
|
||||||
|
|
||||||
|
def get_log_after_iter(self, runner, batch_idx: int,
|
||||||
|
mode: str) -> Tuple[dict, str]:
|
||||||
|
"""Format log string after training, validation or testing epoch.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
runner (Runner): The runner of training phase.
|
||||||
|
batch_idx (int): The index of the current batch in the current
|
||||||
|
loop.
|
||||||
|
mode (str): Current mode of runner, train, test or val.
|
||||||
|
|
||||||
|
Return:
|
||||||
|
Tuple(dict, str): Formatted log dict/string which will be
|
||||||
|
recorded by :obj:`runner.message_hub` and :obj:`runner.visualizer`.
|
||||||
|
"""
|
||||||
|
assert mode in ['train', 'test', 'val']
|
||||||
|
current_loop = self._get_cur_loop(runner, mode)
|
||||||
|
cur_iter = self._get_iter(runner, batch_idx=batch_idx)
|
||||||
|
# Overwrite ``window_size`` defined in ``custom_cfg`` to int value.
|
||||||
|
custom_cfg_copy = self._parse_windows_size(runner, batch_idx)
|
||||||
|
# tag is used to write log information to different backends.
|
||||||
|
tag = self._collect_scalars(custom_cfg_copy, runner, mode)
|
||||||
|
# `log_tag` will pop 'lr' and loop other keys to `log_str`.
|
||||||
|
log_tag = copy.deepcopy(tag)
|
||||||
|
# Record learning rate.
|
||||||
|
lr_str_list = []
|
||||||
|
for key, value in tag.items():
|
||||||
|
if key.startswith('lr'):
|
||||||
|
log_tag.pop(key)
|
||||||
|
lr_str_list.append(f'{key}: {value:.3e}')
|
||||||
|
lr_str = ' '.join(lr_str_list)
|
||||||
|
# Format log header.
|
||||||
|
# by_epoch == True
|
||||||
|
# train/val: Epoch [5][5/10] ...
|
||||||
|
# test: Epoch [5/10]
|
||||||
|
# by_epoch == False
|
||||||
|
# train: Epoch [5/10000] ... (divided by `max_iter`)
|
||||||
|
# val/test: Epoch [5/2000] ... (divided by length of dataloader)
|
||||||
|
if self.by_epoch:
|
||||||
|
if mode in ['train', 'val']:
|
||||||
|
cur_epoch = self._get_epoch(runner, mode)
|
||||||
|
log_str = (f'Epoch({mode}) [{cur_epoch}]'
|
||||||
|
f'[{cur_iter}/{len(current_loop.dataloader)}] ')
|
||||||
|
else:
|
||||||
|
log_str = (f'Epoch({mode}) '
|
||||||
|
f'[{cur_iter}/{len(current_loop.dataloader)}] ')
|
||||||
|
else:
|
||||||
|
if mode == 'train':
|
||||||
|
log_str = (f'Iter({mode}) '
|
||||||
|
f'[{cur_iter}/{runner.train_loop.max_iters}] ')
|
||||||
|
else:
|
||||||
|
log_str = (f'Iter({mode}) [{batch_idx+1}'
|
||||||
|
f'/{len(current_loop.dataloader)}] ')
|
||||||
|
# Concatenate lr, momentum string with log header.
|
||||||
|
log_str += f'{lr_str} '
|
||||||
|
# If IterTimerHook used in runner, eta, time, and data_time should be
|
||||||
|
# recorded.
|
||||||
|
if (all(item in tag for item in ['time', 'data_time'])
|
||||||
|
and 'eta' in runner.message_hub.runtime_info):
|
||||||
|
eta = runner.message_hub.get_info('eta')
|
||||||
|
eta_str = str(datetime.timedelta(seconds=int(eta)))
|
||||||
|
log_str += f'eta: {eta_str} '
|
||||||
|
log_str += (f'time: {tag["time"]:.3f} '
|
||||||
|
f'data_time: {tag["data_time"]:.3f} ')
|
||||||
|
# Pop recorded keys
|
||||||
|
log_tag.pop('time')
|
||||||
|
log_tag.pop('data_time')
|
||||||
|
|
||||||
|
# If cuda is available, the max memory occupied should be calculated.
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
log_str += f'memory: {self._get_max_memory(runner)} '
|
||||||
|
# Loop left keys to fill `log_str`.
|
||||||
|
if mode in ('train', 'val'):
|
||||||
|
log_items = []
|
||||||
|
for name, val in log_tag.items():
|
||||||
|
if mode == 'val' and not name.startswith('val/loss'):
|
||||||
|
continue
|
||||||
|
if isinstance(val, float):
|
||||||
|
val = f'{val:.4f}'
|
||||||
|
log_items.append(f'{name}: {val}')
|
||||||
|
log_str += ' '.join(log_items)
|
||||||
|
return tag, log_str
|
||||||
|
|
||||||
|
def get_log_after_epoch(self, runner, batch_idx: int,
|
||||||
|
mode: str) -> Tuple[dict, str]:
|
||||||
|
"""Format log string after validation or testing epoch.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
runner (Runner): The runner of training phase.
|
||||||
|
batch_idx (int): The index of the current batch in the current
|
||||||
|
loop.
|
||||||
|
mode (str): Current mode of runner.
|
||||||
|
|
||||||
|
Return:
|
||||||
|
Tuple(dict, str): Formatted log dict/string which will be
|
||||||
|
recorded by :obj:`runner.message_hub` and :obj:`runner.visualizer`.
|
||||||
|
"""
|
||||||
|
assert mode in [
|
||||||
|
'test', 'val'
|
||||||
|
], ('`_get_metric_log_str` only accept val or test mode, but got '
|
||||||
|
f'{mode}')
|
||||||
|
cur_loop = self._get_cur_loop(runner, mode)
|
||||||
|
dataloader_len = len(cur_loop.dataloader)
|
||||||
|
|
||||||
|
custom_cfg_copy = self._parse_windows_size(runner, batch_idx)
|
||||||
|
# tag is used to write log information to different backends.
|
||||||
|
tag = self._collect_scalars(custom_cfg_copy, runner, mode)
|
||||||
|
# validation log string needs cur epoch/iteration and max
|
||||||
|
# epochs/iterations. test log string only needs length of test
|
||||||
|
# dataloader.
|
||||||
|
cur_iter = self._get_iter(runner, batch_idx)
|
||||||
|
if self.by_epoch:
|
||||||
|
if mode == 'val':
|
||||||
|
cur_epoch = self._get_epoch(runner, mode)
|
||||||
|
log_str = (f'Epoch({mode}) [{cur_epoch}][{dataloader_len}/'
|
||||||
|
f'{dataloader_len}] ')
|
||||||
|
else:
|
||||||
|
log_str = (
|
||||||
|
f'Epoch({mode}) [{dataloader_len}/{dataloader_len}] ')
|
||||||
|
|
||||||
|
else:
|
||||||
|
if mode == 'train':
|
||||||
|
log_str = (f'Iter({mode}) [{cur_iter}/'
|
||||||
|
f'{runner.train_loop.max_iters}] ')
|
||||||
|
else:
|
||||||
|
log_str = (
|
||||||
|
f'Iter({mode}) [{dataloader_len}/{dataloader_len}] ')
|
||||||
|
log_items = []
|
||||||
|
for name, val in tag.items():
|
||||||
|
if name in ('time', 'data_time'):
|
||||||
|
continue
|
||||||
|
if isinstance(val, float):
|
||||||
|
val = f'{val:.4f}'
|
||||||
|
log_items.append(f'{name}: {val}')
|
||||||
|
log_str += ' '.join(log_items)
|
||||||
|
return tag, log_str
|
||||||
|
|
||||||
|
def _collect_scalars(self, custom_cfg: List[dict], runner,
|
||||||
|
mode: str) -> dict:
|
||||||
|
"""Collect log information to compose a dict according to mode.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
custom_cfg (List[dict]): A copy of ``self.custom_cfg`` with int
|
||||||
|
``window_size``.
|
||||||
|
runner (Runner): The runner of the training process.
|
||||||
|
mode (str): 'train' or 'val', which means the prefix attached by
|
||||||
|
runner.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Statistical values of logs.
|
||||||
|
"""
|
||||||
|
tag = OrderedDict()
|
||||||
|
# history_scalars of train/val/test phase.
|
||||||
|
history_scalars = runner.message_hub.log_scalars
|
||||||
|
# corresponding mode history_scalars
|
||||||
|
mode_history_scalars = OrderedDict()
|
||||||
|
# extract log scalars and remove prefix to `mode_history_scalars`
|
||||||
|
# according to mode.
|
||||||
|
for prefix_key, log_buffer in history_scalars.items():
|
||||||
|
if prefix_key.startswith(mode):
|
||||||
|
key = prefix_key.split('/')[-1]
|
||||||
|
mode_history_scalars[key] = log_buffer
|
||||||
|
for key in mode_history_scalars:
|
||||||
|
# Update the latest learning rate and smoothed time logs.
|
||||||
|
if key.startswith('loss'):
|
||||||
|
tag[key] = mode_history_scalars[key].mean(self.window_size)
|
||||||
|
else:
|
||||||
|
# Default statistic method is current.
|
||||||
|
tag[key] = mode_history_scalars[key].current()
|
||||||
|
# Update custom keys.
|
||||||
|
for log_cfg in custom_cfg:
|
||||||
|
data_src = log_cfg.pop('data_src')
|
||||||
|
if 'log_name' in log_cfg:
|
||||||
|
log_name = log_cfg.pop('log_name')
|
||||||
|
else:
|
||||||
|
log_name = data_src
|
||||||
|
# log item in custom_cfg could only exist in train or val
|
||||||
|
# mode.
|
||||||
|
if data_src in mode_history_scalars:
|
||||||
|
tag[log_name] = mode_history_scalars[data_src].statistics(
|
||||||
|
**log_cfg)
|
||||||
|
return tag
|
||||||
|
|
||||||
|
def _check_custom_cfg(self) -> None:
|
||||||
|
"""Check the legality of ``self.custom_cfg``."""
|
||||||
|
|
||||||
|
def _check_window_size():
|
||||||
|
for log_cfg in self.custom_cfg:
|
||||||
|
if not self.by_epoch:
|
||||||
|
assert log_cfg['window_size'] != 'epoch', \
|
||||||
|
'window_size cannot be epoch if LoggerHook.by_epoch' \
|
||||||
|
' is False.'
|
||||||
|
|
||||||
|
def _check_repeated_log_name():
|
||||||
|
check_dict = dict()
|
||||||
|
# The `log_name` of the same data_src should not be repeated.
|
||||||
|
# If `log_name` is not specified, `data_src` will be overwritten.
|
||||||
|
# But only allowed to be overwritten once.
|
||||||
|
for log_cfg in self.custom_cfg:
|
||||||
|
assert 'data_src' in log_cfg
|
||||||
|
data_src = log_cfg['data_src']
|
||||||
|
log_name = log_cfg.get('log_name', data_src)
|
||||||
|
check_dict.setdefault(data_src,
|
||||||
|
dict(log_names=set(), log_counts=0))
|
||||||
|
check_dict[data_src]['log_names'].add(log_name)
|
||||||
|
check_dict[data_src]['log_counts'] += 1
|
||||||
|
assert (len(
|
||||||
|
check_dict[data_src]
|
||||||
|
['log_names']) == check_dict[data_src]['log_counts']), (
|
||||||
|
f'If you want to statistic {data_src} with multiple '
|
||||||
|
'statistics method, please check `log_name` is unique'
|
||||||
|
f'and {data_src} will not be overwritten twice. See '
|
||||||
|
f'more information in the docstring of `LogProcessor`')
|
||||||
|
|
||||||
|
_check_repeated_log_name()
|
||||||
|
_check_window_size()
|
||||||
|
|
||||||
|
def _parse_windows_size(self, runner, batch_idx: int) -> list:
|
||||||
|
"""Parse window_size defined in custom_cfg to int value.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
runner (Runner): The runner of the training process.
|
||||||
|
batch_idx (int): The iteration index of current dataloader.
|
||||||
|
"""
|
||||||
|
custom_cfg_copy = copy.deepcopy(self.custom_cfg)
|
||||||
|
for log_cfg in custom_cfg_copy:
|
||||||
|
window_size = log_cfg.get('window_size', None)
|
||||||
|
if window_size is None or isinstance(window_size, int):
|
||||||
|
continue
|
||||||
|
elif window_size == 'epoch':
|
||||||
|
log_cfg['window_size'] = batch_idx + 1
|
||||||
|
elif window_size == 'global':
|
||||||
|
log_cfg['window_size'] = runner.iter + 1
|
||||||
|
else:
|
||||||
|
raise TypeError(
|
||||||
|
'window_size should be int, epoch or global, but got '
|
||||||
|
f'invalid {window_size}')
|
||||||
|
return custom_cfg_copy
|
||||||
|
|
||||||
|
def _get_max_memory(self, runner) -> int:
|
||||||
|
"""Returns the maximum GPU memory occupied by tensors in megabytes (MB)
|
||||||
|
for a given device.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
runner (Runner): The runner of the training process.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The maximum GPU memory occupied by tensors in megabytes for a given
|
||||||
|
device.
|
||||||
|
"""
|
||||||
|
device = getattr(runner.model, 'output_device', None)
|
||||||
|
mem = torch.cuda.max_memory_allocated(device=device)
|
||||||
|
mem_mb = torch.tensor([int(mem) // (1024 * 1024)],
|
||||||
|
dtype=torch.int,
|
||||||
|
device=device)
|
||||||
|
torch.cuda.reset_peak_memory_stats()
|
||||||
|
return int(mem_mb.item())
|
||||||
|
|
||||||
|
def _get_iter(self, runner, batch_idx: int = None) -> int:
|
||||||
|
"""Get current training iteration step.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
runner (Runner): The runner of the training process.
|
||||||
|
batch_idx (int, optional): The interaction index of current
|
||||||
|
dataloader. Defaults to None.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: The current global iter or inner iter.
|
||||||
|
"""
|
||||||
|
if self.by_epoch and batch_idx:
|
||||||
|
current_iter = batch_idx + 1
|
||||||
|
else:
|
||||||
|
current_iter = runner.iter + 1
|
||||||
|
return current_iter
|
||||||
|
|
||||||
|
def _get_epoch(self, runner, mode: str) -> int:
|
||||||
|
"""Get current epoch according to mode.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
runner (Runner): The runner of the training/validation process.
|
||||||
|
mode (str): Current mode of runner, "train" or "val".
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: The current epoch.
|
||||||
|
"""
|
||||||
|
if mode == 'train':
|
||||||
|
epoch = runner.epoch + 1
|
||||||
|
elif mode == 'val':
|
||||||
|
# normal val mode
|
||||||
|
# runner.epoch += 1 has been done before validation
|
||||||
|
epoch = runner.epoch
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"runner mode should be 'train' or 'val', but got {mode}")
|
||||||
|
return epoch
|
||||||
|
|
||||||
|
def _get_cur_loop(self, runner, mode: str):
|
||||||
|
"""Get current loop according to mode.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
runner (Runner): The runner of the training/validation/testing
|
||||||
|
process.
|
||||||
|
mode (str): Current mode of runner, "train", "val" or test.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
BaseLoop: Current loop of runner.
|
||||||
|
"""
|
||||||
|
# returns type hint will occur circular import
|
||||||
|
if mode == 'train':
|
||||||
|
return runner.train_loop
|
||||||
|
elif mode == 'val':
|
||||||
|
return runner.val_loop
|
||||||
|
else:
|
||||||
|
return runner.test_loop
|
@ -32,15 +32,15 @@ class MMFormatter(logging.Formatter):
|
|||||||
info_prefix = self._get_prefix('INFO', color)
|
info_prefix = self._get_prefix('INFO', color)
|
||||||
debug_prefix = self._get_prefix('DEBUG', color)
|
debug_prefix = self._get_prefix('DEBUG', color)
|
||||||
# Config output format.
|
# Config output format.
|
||||||
self.err_format = f'%(asctime)s - %(name)s - {error_prefix} - ' \
|
self.err_format = (f'%(asctime)s - %(name)s - {error_prefix} - '
|
||||||
f'%(pathname)s - %(funcName)s - %(lineno)d - ' \
|
'%(pathname)s - %(funcName)s - %(lineno)d - '
|
||||||
'%(message)s'
|
'%(message)s')
|
||||||
self.warn_format = f'%(asctime)s - %(name)s - {warn_prefix} - %(' \
|
self.warn_format = (f'%(asctime)s - %(name)s - {warn_prefix} - %('
|
||||||
'message)s'
|
'message)s')
|
||||||
self.info_format = f'%(asctime)s - %(name)s - {info_prefix} - %(' \
|
self.info_format = (f'%(asctime)s - %(name)s - {info_prefix} - %('
|
||||||
'message)s'
|
'message)s')
|
||||||
self.debug_format = f'%(asctime)s - %(name)s - {debug_prefix} - %(' \
|
self.debug_format = (f'%(asctime)s - %(name)s - {debug_prefix} - %('
|
||||||
'message)s'
|
'message)s')
|
||||||
|
|
||||||
def _get_prefix(self, level: str, color: bool) -> str:
|
def _get_prefix(self, level: str, color: bool) -> str:
|
||||||
"""Get the prefix of the target log level.
|
"""Get the prefix of the target log level.
|
||||||
|
@ -1,14 +1,14 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
from .lr_scheduler import (ConstantLR, CosineAnnealingLR, ExponentialLR,
|
from .lr_scheduler import (ConstantLR, CosineAnnealingLR, ExponentialLR,
|
||||||
LinearLR, MultiStepLR, StepLR)
|
LinearLR, MultiStepLR, PolyLR, StepLR)
|
||||||
from .momentum_scheduler import (ConstantMomentum, CosineAnnealingMomentum,
|
from .momentum_scheduler import (ConstantMomentum, CosineAnnealingMomentum,
|
||||||
ExponentialMomentum, LinearMomentum,
|
ExponentialMomentum, LinearMomentum,
|
||||||
MultiStepMomentum, StepMomentum)
|
MultiStepMomentum, PolyMomentum, StepMomentum)
|
||||||
from .param_scheduler import (ConstantParamScheduler,
|
from .param_scheduler import (ConstantParamScheduler,
|
||||||
CosineAnnealingParamScheduler,
|
CosineAnnealingParamScheduler,
|
||||||
ExponentialParamScheduler, LinearParamScheduler,
|
ExponentialParamScheduler, LinearParamScheduler,
|
||||||
MultiStepParamScheduler, StepParamScheduler,
|
MultiStepParamScheduler, PolyParamScheduler,
|
||||||
_ParamScheduler)
|
StepParamScheduler, _ParamScheduler)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'ConstantLR', 'CosineAnnealingLR', 'ExponentialLR', 'LinearLR',
|
'ConstantLR', 'CosineAnnealingLR', 'ExponentialLR', 'LinearLR',
|
||||||
@ -16,5 +16,6 @@ __all__ = [
|
|||||||
'ExponentialMomentum', 'LinearMomentum', 'MultiStepMomentum',
|
'ExponentialMomentum', 'LinearMomentum', 'MultiStepMomentum',
|
||||||
'StepMomentum', 'ConstantParamScheduler', 'CosineAnnealingParamScheduler',
|
'StepMomentum', 'ConstantParamScheduler', 'CosineAnnealingParamScheduler',
|
||||||
'ExponentialParamScheduler', 'LinearParamScheduler',
|
'ExponentialParamScheduler', 'LinearParamScheduler',
|
||||||
'MultiStepParamScheduler', 'StepParamScheduler', '_ParamScheduler'
|
'MultiStepParamScheduler', 'StepParamScheduler', '_ParamScheduler',
|
||||||
|
'PolyParamScheduler', 'PolyLR', 'PolyMomentum'
|
||||||
]
|
]
|
||||||
|
@ -7,7 +7,8 @@ from mmengine.registry import PARAM_SCHEDULERS
|
|||||||
from .param_scheduler import (INF, ConstantParamScheduler,
|
from .param_scheduler import (INF, ConstantParamScheduler,
|
||||||
CosineAnnealingParamScheduler,
|
CosineAnnealingParamScheduler,
|
||||||
ExponentialParamScheduler, LinearParamScheduler,
|
ExponentialParamScheduler, LinearParamScheduler,
|
||||||
MultiStepParamScheduler, StepParamScheduler)
|
MultiStepParamScheduler, PolyParamScheduler,
|
||||||
|
StepParamScheduler)
|
||||||
|
|
||||||
|
|
||||||
@PARAM_SCHEDULERS.register_module()
|
@PARAM_SCHEDULERS.register_module()
|
||||||
@ -294,3 +295,49 @@ class StepLR(StepParamScheduler):
|
|||||||
last_step=last_step,
|
last_step=last_step,
|
||||||
by_epoch=by_epoch,
|
by_epoch=by_epoch,
|
||||||
verbose=verbose)
|
verbose=verbose)
|
||||||
|
|
||||||
|
|
||||||
|
@PARAM_SCHEDULERS.register_module()
|
||||||
|
class PolyLR(PolyParamScheduler):
|
||||||
|
"""Decays the learning rate of each parameter group in a polynomial decay
|
||||||
|
scheme.
|
||||||
|
|
||||||
|
Notice that such decay can happen simultaneously with other changes to the
|
||||||
|
parameter value from outside this scheduler.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
optimizer (Optimizer): Wrapped optimizer.
|
||||||
|
eta_min (float): Minimum learning rate at the end of scheduling.
|
||||||
|
Defaults to 0.
|
||||||
|
power (float): The power of the polynomial. Defaults to 1.0.
|
||||||
|
begin (int): Step at which to start updating the parameters.
|
||||||
|
Defaults to 0.
|
||||||
|
end (int): Step at which to stop updating the parameters.
|
||||||
|
Defaults to INF.
|
||||||
|
last_step (int): The index of last step. Used for resume without
|
||||||
|
state dict. Defaults to -1.
|
||||||
|
by_epoch (bool): Whether the scheduled parameters are updated by
|
||||||
|
epochs. Defaults to True.
|
||||||
|
verbose (bool): Whether to print the value for each update.
|
||||||
|
Defaults to False.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
optimizer: torch.optim.Optimizer,
|
||||||
|
eta_min: float = 0,
|
||||||
|
power: float = 1,
|
||||||
|
begin: int = 0,
|
||||||
|
end: int = INF,
|
||||||
|
last_step: int = -1,
|
||||||
|
by_epoch: bool = True,
|
||||||
|
verbose: bool = False):
|
||||||
|
super().__init__(
|
||||||
|
optimizer,
|
||||||
|
param_name='lr',
|
||||||
|
eta_min=eta_min,
|
||||||
|
power=power,
|
||||||
|
begin=begin,
|
||||||
|
end=end,
|
||||||
|
last_step=last_step,
|
||||||
|
by_epoch=by_epoch,
|
||||||
|
verbose=verbose)
|
||||||
|
@ -7,7 +7,8 @@ from mmengine.registry import PARAM_SCHEDULERS
|
|||||||
from .param_scheduler import (INF, ConstantParamScheduler,
|
from .param_scheduler import (INF, ConstantParamScheduler,
|
||||||
CosineAnnealingParamScheduler,
|
CosineAnnealingParamScheduler,
|
||||||
ExponentialParamScheduler, LinearParamScheduler,
|
ExponentialParamScheduler, LinearParamScheduler,
|
||||||
MultiStepParamScheduler, StepParamScheduler)
|
MultiStepParamScheduler, PolyParamScheduler,
|
||||||
|
StepParamScheduler)
|
||||||
|
|
||||||
|
|
||||||
@PARAM_SCHEDULERS.register_module()
|
@PARAM_SCHEDULERS.register_module()
|
||||||
@ -294,3 +295,49 @@ class StepMomentum(StepParamScheduler):
|
|||||||
last_step=last_step,
|
last_step=last_step,
|
||||||
by_epoch=by_epoch,
|
by_epoch=by_epoch,
|
||||||
verbose=verbose)
|
verbose=verbose)
|
||||||
|
|
||||||
|
|
||||||
|
@PARAM_SCHEDULERS.register_module()
|
||||||
|
class PolyMomentum(PolyParamScheduler):
|
||||||
|
"""Decays the momentum of each parameter group in a polynomial decay
|
||||||
|
scheme.
|
||||||
|
|
||||||
|
Notice that such decay can happen simultaneously with other changes to the
|
||||||
|
parameter value from outside this scheduler.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
optimizer (Optimizer): Wrapped optimizer.
|
||||||
|
eta_min (float): Minimum momentum at the end of scheduling.
|
||||||
|
Defaults to 0.
|
||||||
|
power (float): The power of the polynomial. Defaults to 1.0.
|
||||||
|
begin (int): Step at which to start updating the parameters.
|
||||||
|
Defaults to 0.
|
||||||
|
end (int): Step at which to stop updating the parameters.
|
||||||
|
Defaults to INF.
|
||||||
|
last_step (int): The index of last step. Used for resume without
|
||||||
|
state dict. Defaults to -1.
|
||||||
|
by_epoch (bool): Whether the scheduled parameters are updated by
|
||||||
|
epochs. Defaults to True.
|
||||||
|
verbose (bool): Whether to print the value for each update.
|
||||||
|
Defaults to False.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
optimizer: torch.optim.Optimizer,
|
||||||
|
eta_min: float = 0,
|
||||||
|
power: float = 1,
|
||||||
|
begin: int = 0,
|
||||||
|
end: int = INF,
|
||||||
|
last_step: int = -1,
|
||||||
|
by_epoch: bool = True,
|
||||||
|
verbose: bool = False):
|
||||||
|
super().__init__(
|
||||||
|
optimizer,
|
||||||
|
param_name='momentum',
|
||||||
|
eta_min=eta_min,
|
||||||
|
power=power,
|
||||||
|
begin=begin,
|
||||||
|
end=end,
|
||||||
|
last_step=last_step,
|
||||||
|
by_epoch=by_epoch,
|
||||||
|
verbose=verbose)
|
||||||
|
@ -534,6 +534,7 @@ class LinearParamScheduler(_ParamScheduler):
|
|||||||
|
|
||||||
Notice that such decay can happen simultaneously with other changes to the
|
Notice that such decay can happen simultaneously with other changes to the
|
||||||
parameter value from outside this scheduler.
|
parameter value from outside this scheduler.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
optimizer (Optimizer): Wrapped optimizer.
|
optimizer (Optimizer): Wrapped optimizer.
|
||||||
start_factor (float): The number we multiply parameter value in the
|
start_factor (float): The number we multiply parameter value in the
|
||||||
@ -598,3 +599,64 @@ class LinearParamScheduler(_ParamScheduler):
|
|||||||
(self.end_factor - self.start_factor)))
|
(self.end_factor - self.start_factor)))
|
||||||
for group in self.optimizer.param_groups
|
for group in self.optimizer.param_groups
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@PARAM_SCHEDULERS.register_module()
|
||||||
|
class PolyParamScheduler(_ParamScheduler):
|
||||||
|
"""Decays the parameter value of each parameter group in a polynomial decay
|
||||||
|
scheme.
|
||||||
|
|
||||||
|
Notice that such decay can happen simultaneously with other changes to the
|
||||||
|
parameter value from outside this scheduler.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
optimizer (Optimizer): Wrapped optimizer.
|
||||||
|
eta_min (float): Minimum parameter value at the end of scheduling.
|
||||||
|
Defaults to 0.
|
||||||
|
power (float): The power of the polynomial. Defaults to 1.0.
|
||||||
|
begin (int): Step at which to start updating the parameters.
|
||||||
|
Defaults to 0.
|
||||||
|
end (int): Step at which to stop updating the parameters.
|
||||||
|
Defaults to INF.
|
||||||
|
last_step (int): The index of last step. Used for resume without
|
||||||
|
state dict. Defaults to -1.
|
||||||
|
by_epoch (bool): Whether the scheduled parameters are updated by
|
||||||
|
epochs. Defaults to True.
|
||||||
|
verbose (bool): Whether to print the value for each update.
|
||||||
|
Defaults to False.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
optimizer: Optimizer,
|
||||||
|
param_name: str,
|
||||||
|
eta_min: float = 0,
|
||||||
|
power: float = 1.0,
|
||||||
|
begin: int = 0,
|
||||||
|
end: int = INF,
|
||||||
|
last_step: int = -1,
|
||||||
|
by_epoch: bool = True,
|
||||||
|
verbose: bool = False):
|
||||||
|
|
||||||
|
self.eta_min = eta_min
|
||||||
|
self.power = power
|
||||||
|
self.total_iters = end - begin - 1
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
optimizer,
|
||||||
|
param_name=param_name,
|
||||||
|
begin=begin,
|
||||||
|
end=end,
|
||||||
|
last_step=last_step,
|
||||||
|
by_epoch=by_epoch,
|
||||||
|
verbose=verbose)
|
||||||
|
|
||||||
|
def _get_value(self):
|
||||||
|
|
||||||
|
if self.last_step == 0:
|
||||||
|
return [
|
||||||
|
group[self.param_name] for group in self.optimizer.param_groups
|
||||||
|
]
|
||||||
|
|
||||||
|
return [(group[self.param_name] - self.eta_min) *
|
||||||
|
(1 - 1 / (self.total_iters - self.last_step + 1))**self.power +
|
||||||
|
self.eta_min for group in self.optimizer.param_groups]
|
||||||
|
@ -25,7 +25,7 @@ from mmengine.dist import (broadcast, get_dist_info, init_dist, master_only,
|
|||||||
sync_random_seed)
|
sync_random_seed)
|
||||||
from mmengine.evaluator import Evaluator
|
from mmengine.evaluator import Evaluator
|
||||||
from mmengine.hooks import Hook
|
from mmengine.hooks import Hook
|
||||||
from mmengine.logging import MessageHub, MMLogger
|
from mmengine.logging import LogProcessor, MessageHub, MMLogger
|
||||||
from mmengine.model import is_model_wrapper
|
from mmengine.model import is_model_wrapper
|
||||||
from mmengine.optim import _ParamScheduler, build_optimizer
|
from mmengine.optim import _ParamScheduler, build_optimizer
|
||||||
from mmengine.registry import (DATA_SAMPLERS, DATASETS, HOOKS, LOOPS,
|
from mmengine.registry import (DATA_SAMPLERS, DATASETS, HOOKS, LOOPS,
|
||||||
@ -127,6 +127,8 @@ class Runner:
|
|||||||
non-distributed environment will be launched.
|
non-distributed environment will be launched.
|
||||||
env_cfg (dict): A dict used for setting environment. Defaults to
|
env_cfg (dict): A dict used for setting environment. Defaults to
|
||||||
dict(dist_cfg=dict(backend='nccl')).
|
dict(dist_cfg=dict(backend='nccl')).
|
||||||
|
log_processor (dict, optional): A processor to format logs. Defaults to
|
||||||
|
None.
|
||||||
log_level (int or str): The log level of MMLogger handlers.
|
log_level (int or str): The log level of MMLogger handlers.
|
||||||
Defaults to 'INFO'.
|
Defaults to 'INFO'.
|
||||||
visualizer (Visualizer or dict, optional): A Visualizer object or a
|
visualizer (Visualizer or dict, optional): A Visualizer object or a
|
||||||
@ -151,43 +153,44 @@ class Runner:
|
|||||||
Examples:
|
Examples:
|
||||||
>>> from mmengine import Runner
|
>>> from mmengine import Runner
|
||||||
>>> cfg = dict(
|
>>> cfg = dict(
|
||||||
model=dict(type='ToyModel'),
|
>>> model=dict(type='ToyModel'),
|
||||||
work_dir='path/of/work_dir',
|
>>> work_dir='path/of/work_dir',
|
||||||
train_dataloader=dict(
|
>>> train_dataloader=dict(
|
||||||
dataset=dict(type='ToyDataset'),
|
>>> dataset=dict(type='ToyDataset'),
|
||||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
>>> sampler=dict(type='DefaultSampler', shuffle=True),
|
||||||
batch_size=1,
|
>>> batch_size=1,
|
||||||
num_workers=0),
|
>>> num_workers=0),
|
||||||
val_dataloader=dict(
|
>>> val_dataloader=dict(
|
||||||
dataset=dict(type='ToyDataset'),
|
>>> dataset=dict(type='ToyDataset'),
|
||||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
>>> sampler=dict(type='DefaultSampler', shuffle=False),
|
||||||
batch_size=1,
|
>>> batch_size=1,
|
||||||
num_workers=0),
|
>>> num_workers=0),
|
||||||
test_dataloader=dict(
|
>>> test_dataloader=dict(
|
||||||
dataset=dict(type='ToyDataset'),
|
>>> dataset=dict(type='ToyDataset'),
|
||||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
>>> sampler=dict(type='DefaultSampler', shuffle=False),
|
||||||
batch_size=1,
|
>>> batch_size=1,
|
||||||
num_workers=0),
|
>>> num_workers=0),
|
||||||
optimizer=dict(type='SGD', lr=0.01),
|
>>> optimizer=dict(type='SGD', lr=0.01),
|
||||||
param_scheduler=dict(type='MultiStepLR', milestones=[1, 2]),
|
>>> param_scheduler=dict(type='MultiStepLR', milestones=[1, 2]),
|
||||||
val_evaluator=dict(type='ToyEvaluator'),
|
>>> val_evaluator=dict(type='ToyEvaluator'),
|
||||||
test_evaluator=dict(type='ToyEvaluator'),
|
>>> test_evaluator=dict(type='ToyEvaluator'),
|
||||||
train_cfg=dict(by_epoch=True, max_epochs=3),
|
>>> train_cfg=dict(by_epoch=True, max_epochs=3),
|
||||||
val_cfg=dict(interval=1),
|
>>> val_cfg=dict(interval=1),
|
||||||
test_cfg=dict(),
|
>>> test_cfg=dict(),
|
||||||
custom_hooks=[],
|
>>> custom_hooks=[],
|
||||||
default_hooks=dict(
|
>>> default_hooks=dict(
|
||||||
timer=dict(type='IterTimerHook'),
|
>>> timer=dict(type='IterTimerHook'),
|
||||||
checkpoint=dict(type='CheckpointHook', interval=1),
|
>>> checkpoint=dict(type='CheckpointHook', interval=1),
|
||||||
logger=dict(type='LoggerHook'),
|
>>> logger=dict(type='LoggerHook'),
|
||||||
optimizer=dict(type='OptimizerHook', grad_clip=False),
|
>>> optimizer=dict(type='OptimizerHook', grad_clip=False),
|
||||||
param_scheduler=dict(type='ParamSchedulerHook')),
|
>>> param_scheduler=dict(type='ParamSchedulerHook')),
|
||||||
launcher='none',
|
>>> launcher='none',
|
||||||
env_cfg=dict(dist_cfg=dict(backend='nccl')),
|
>>> env_cfg=dict(dist_cfg=dict(backend='nccl')),
|
||||||
visualizer=dict(type='Visualizer',
|
>>> log_processor=dict(window_size=20),
|
||||||
vis_backends=[dict(type='LocalVisBackend',
|
>>> visualizer=dict(type='Visualizer',
|
||||||
save_dir='temp_dir')])
|
>>> vis_backends=[dict(type='LocalVisBackend',
|
||||||
)
|
>>> save_dir='temp_dir')])
|
||||||
|
>>> )
|
||||||
>>> runner = Runner.from_cfg(cfg)
|
>>> runner = Runner.from_cfg(cfg)
|
||||||
>>> runner.train()
|
>>> runner.train()
|
||||||
>>> runner.test()
|
>>> runner.test()
|
||||||
@ -217,6 +220,7 @@ class Runner:
|
|||||||
resume: bool = False,
|
resume: bool = False,
|
||||||
launcher: str = 'none',
|
launcher: str = 'none',
|
||||||
env_cfg: Dict = dict(dist_cfg=dict(backend='nccl')),
|
env_cfg: Dict = dict(dist_cfg=dict(backend='nccl')),
|
||||||
|
log_processor: Optional[Dict] = None,
|
||||||
log_level: str = 'INFO',
|
log_level: str = 'INFO',
|
||||||
visualizer: Optional[Union[Visualizer, Dict]] = None,
|
visualizer: Optional[Union[Visualizer, Dict]] = None,
|
||||||
default_scope: Optional[str] = None,
|
default_scope: Optional[str] = None,
|
||||||
@ -309,14 +313,16 @@ class Runner:
|
|||||||
self._experiment_name = f'{filename_no_ext}_{self._timestamp}'
|
self._experiment_name = f'{filename_no_ext}_{self._timestamp}'
|
||||||
else:
|
else:
|
||||||
self._experiment_name = self.timestamp
|
self._experiment_name = self.timestamp
|
||||||
|
|
||||||
# Used to reset registries location. See :meth:`Registry.build` for
|
# Used to reset registries location. See :meth:`Registry.build` for
|
||||||
# more details.
|
# more details.
|
||||||
self.default_scope = DefaultScope.get_instance(
|
self.default_scope = DefaultScope.get_instance(
|
||||||
self._experiment_name, scope_name=default_scope)
|
self._experiment_name, scope_name=default_scope)
|
||||||
|
# Build log processor to format message.
|
||||||
|
log_processor = dict() if log_processor is None else log_processor
|
||||||
|
self.log_processor = LogProcessor(**log_processor)
|
||||||
|
# Since `get_instance` could return any subclass of ManagerMixin. The
|
||||||
|
# corresponding attribute needs a type hint.
|
||||||
self.logger = self.build_logger(log_level=log_level)
|
self.logger = self.build_logger(log_level=log_level)
|
||||||
|
|
||||||
# Build `message_hub` for communication among components.
|
# Build `message_hub` for communication among components.
|
||||||
# `message_hub` can store log scalars (loss, learning rate) and
|
# `message_hub` can store log scalars (loss, learning rate) and
|
||||||
# runtime information (iter and epoch). Those components that do not
|
# runtime information (iter and epoch). Those components that do not
|
||||||
@ -387,6 +393,7 @@ class Runner:
|
|||||||
resume=cfg.get('resume', False),
|
resume=cfg.get('resume', False),
|
||||||
launcher=cfg.get('launcher', 'none'),
|
launcher=cfg.get('launcher', 'none'),
|
||||||
env_cfg=cfg.get('env_cfg'), # type: ignore
|
env_cfg=cfg.get('env_cfg'), # type: ignore
|
||||||
|
log_processor=cfg.get('log_processor'),
|
||||||
log_level=cfg.get('log_level', 'INFO'),
|
log_level=cfg.get('log_level', 'INFO'),
|
||||||
visualizer=cfg.get('visualizer'),
|
visualizer=cfg.get('visualizer'),
|
||||||
default_scope=cfg.get('default_scope'),
|
default_scope=cfg.get('default_scope'),
|
||||||
|
@ -1,29 +1,70 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
from unittest.mock import Mock
|
from unittest import TestCase
|
||||||
|
from unittest.mock import MagicMock, Mock, patch
|
||||||
|
|
||||||
from mmengine.hooks import IterTimerHook
|
from mmengine.hooks import IterTimerHook
|
||||||
|
from mmengine.logging import MessageHub
|
||||||
|
|
||||||
|
|
||||||
class TestIterTimerHook:
|
def time_patch():
|
||||||
|
if not hasattr(time_patch, 'time'):
|
||||||
|
time_patch.time = 0
|
||||||
|
else:
|
||||||
|
time_patch.time += 1
|
||||||
|
return time_patch.time
|
||||||
|
|
||||||
|
|
||||||
|
class TestIterTimerHook(TestCase):
|
||||||
|
|
||||||
|
def setUp(self) -> None:
|
||||||
|
self.hook = IterTimerHook()
|
||||||
|
|
||||||
|
def test_init(self):
|
||||||
|
assert self.hook.time_sec_tot == 0
|
||||||
|
assert self.hook.start_iter == 0
|
||||||
|
|
||||||
|
def test_before_run(self):
|
||||||
|
runner = MagicMock()
|
||||||
|
runner.iter = 1
|
||||||
|
self.hook.before_run(runner)
|
||||||
|
assert self.hook.start_iter == 1
|
||||||
|
|
||||||
def test_before_epoch(self):
|
def test_before_epoch(self):
|
||||||
hook = IterTimerHook()
|
|
||||||
runner = Mock()
|
runner = Mock()
|
||||||
hook._before_epoch(runner)
|
self.hook._before_epoch(runner)
|
||||||
assert isinstance(hook.t, float)
|
assert isinstance(self.hook.t, float)
|
||||||
|
|
||||||
|
@patch('time.time', MagicMock(return_value=1))
|
||||||
def test_before_iter(self):
|
def test_before_iter(self):
|
||||||
hook = IterTimerHook()
|
runner = MagicMock()
|
||||||
runner = Mock()
|
|
||||||
runner.log_buffer = dict()
|
runner.log_buffer = dict()
|
||||||
hook._before_epoch(runner)
|
self.hook._before_epoch(runner)
|
||||||
hook._before_iter(runner, 0)
|
for mode in ('train', 'val', 'test'):
|
||||||
runner.message_hub.update_scalar.assert_called()
|
self.hook._before_iter(runner, batch_idx=1, mode=mode)
|
||||||
|
runner.message_hub.update_scalar.assert_called_with(
|
||||||
|
f'{mode}/data_time', 0)
|
||||||
|
|
||||||
|
@patch('time.time', time_patch)
|
||||||
def test_after_iter(self):
|
def test_after_iter(self):
|
||||||
hook = IterTimerHook()
|
runner = MagicMock()
|
||||||
runner = Mock()
|
|
||||||
runner.log_buffer = dict()
|
runner.log_buffer = dict()
|
||||||
hook._before_epoch(runner)
|
runner.log_processor.window_size = 10
|
||||||
hook._after_iter(runner, 0)
|
runner.train_loop.max_iters = 100
|
||||||
|
runner.iter = 0
|
||||||
|
runner.test_loop.dataloader = [0] * 20
|
||||||
|
runner.val_loop.dataloader = [0] * 20
|
||||||
|
self.hook._before_epoch(runner)
|
||||||
|
self.hook.before_run(runner)
|
||||||
|
self.hook._after_iter(runner, batch_idx=1)
|
||||||
runner.message_hub.update_scalar.assert_called()
|
runner.message_hub.update_scalar.assert_called()
|
||||||
|
runner.message_hub.get_log.assert_not_called()
|
||||||
|
runner.message_hub.update_info.assert_not_called()
|
||||||
|
runner.message_hub = MessageHub.get_instance('test_iter_timer_hook')
|
||||||
|
runner.iter = 9
|
||||||
|
# eta = (100 - 10) / 1
|
||||||
|
self.hook._after_iter(runner, batch_idx=89)
|
||||||
|
assert runner.message_hub.get_info('eta') == 90
|
||||||
|
self.hook._after_iter(runner, batch_idx=9, mode='val')
|
||||||
|
assert runner.message_hub.get_info('eta') == 10
|
||||||
|
self.hook._after_iter(runner, batch_idx=19, mode='test')
|
||||||
|
assert runner.message_hub.get_info('eta') == 0
|
||||||
|
@ -1,13 +1,8 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
import datetime
|
|
||||||
import logging
|
|
||||||
import os.path as osp
|
import os.path as osp
|
||||||
import sys
|
from unittest.mock import MagicMock
|
||||||
from collections import OrderedDict
|
|
||||||
from unittest.mock import MagicMock, patch
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
|
||||||
|
|
||||||
from mmengine.fileio.file_client import HardDiskBackend
|
from mmengine.fileio.file_client import HardDiskBackend
|
||||||
from mmengine.hooks import LoggerHook
|
from mmengine.hooks import LoggerHook
|
||||||
@ -17,11 +12,8 @@ class TestLoggerHook:
|
|||||||
|
|
||||||
def test_init(self):
|
def test_init(self):
|
||||||
logger_hook = LoggerHook(out_dir='tmp.txt')
|
logger_hook = LoggerHook(out_dir='tmp.txt')
|
||||||
assert logger_hook.by_epoch
|
|
||||||
assert logger_hook.interval == 10
|
assert logger_hook.interval == 10
|
||||||
assert not logger_hook.custom_keys
|
|
||||||
assert logger_hook.ignore_last
|
assert logger_hook.ignore_last
|
||||||
assert logger_hook.time_sec_tot == 0
|
|
||||||
assert logger_hook.interval_exp_name == 1000
|
assert logger_hook.interval_exp_name == 1000
|
||||||
assert logger_hook.out_suffix == ('.log.json', '.log', '.py')
|
assert logger_hook.out_suffix == ('.log.json', '.log', '.py')
|
||||||
assert logger_hook.keep_local
|
assert logger_hook.keep_local
|
||||||
@ -30,22 +22,7 @@ class TestLoggerHook:
|
|||||||
# out_dir should be None or string or tuple of string.
|
# out_dir should be None or string or tuple of string.
|
||||||
with pytest.raises(TypeError):
|
with pytest.raises(TypeError):
|
||||||
LoggerHook(out_dir=1)
|
LoggerHook(out_dir=1)
|
||||||
# time cannot be overwritten.
|
|
||||||
with pytest.raises(AssertionError):
|
|
||||||
LoggerHook(custom_keys=dict(time=dict(method='max')))
|
|
||||||
LoggerHook(
|
|
||||||
custom_keys=dict(time=[
|
|
||||||
dict(method='max', log_name='time_max'),
|
|
||||||
dict(method='min', log_name='time_min')
|
|
||||||
]))
|
|
||||||
# Epoch window_size cannot be used when `LoggerHook.by_epoch=False`
|
|
||||||
with pytest.raises(AssertionError):
|
|
||||||
LoggerHook(
|
|
||||||
by_epoch=False,
|
|
||||||
custom_keys=dict(
|
|
||||||
time=dict(
|
|
||||||
method='max', log_name='time_max',
|
|
||||||
window_size='epoch')))
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
LoggerHook(file_client_args=dict(enable_mc=True))
|
LoggerHook(file_client_args=dict(enable_mc=True))
|
||||||
|
|
||||||
@ -60,19 +37,22 @@ class TestLoggerHook:
|
|||||||
assert logger_hook.out_dir == osp.join('out_dir', 'work_dir')
|
assert logger_hook.out_dir == osp.join('out_dir', 'work_dir')
|
||||||
assert logger_hook.json_log_path == osp.join('work_dir',
|
assert logger_hook.json_log_path == osp.join('work_dir',
|
||||||
'timestamp.log.json')
|
'timestamp.log.json')
|
||||||
assert logger_hook.start_iter == runner.iter
|
|
||||||
|
|
||||||
def test_after_run(self, tmp_path):
|
def test_after_run(self, tmp_path):
|
||||||
|
# Test
|
||||||
out_dir = tmp_path / 'out_dir'
|
out_dir = tmp_path / 'out_dir'
|
||||||
out_dir.mkdir()
|
out_dir.mkdir()
|
||||||
work_dir = tmp_path / 'work_dir'
|
work_dir = tmp_path / 'work_dir'
|
||||||
work_dir.mkdir()
|
work_dir.mkdir()
|
||||||
work_dir_json = work_dir / 'tmp.log.json'
|
work_dir_json = work_dir / 'tmp.log.json'
|
||||||
json_f = open(work_dir_json, 'w')
|
|
||||||
json_f.close()
|
|
||||||
runner = MagicMock()
|
runner = MagicMock()
|
||||||
runner.work_dir = work_dir
|
runner.work_dir = work_dir
|
||||||
|
# Test without out_dir.
|
||||||
|
logger_hook = LoggerHook()
|
||||||
|
logger_hook.after_run(runner)
|
||||||
|
# Test with out_dir and make sure json file has been moved to out_dir.
|
||||||
|
json_f = open(work_dir_json, 'w')
|
||||||
|
json_f.close()
|
||||||
logger_hook = LoggerHook(out_dir=str(tmp_path), keep_local=False)
|
logger_hook = LoggerHook(out_dir=str(tmp_path), keep_local=False)
|
||||||
logger_hook.out_dir = str(out_dir)
|
logger_hook.out_dir = str(out_dir)
|
||||||
logger_hook.after_run(runner)
|
logger_hook.after_run(runner)
|
||||||
@ -83,276 +63,83 @@ class TestLoggerHook:
|
|||||||
def test_after_train_iter(self):
|
def test_after_train_iter(self):
|
||||||
# Test LoggerHook by iter.
|
# Test LoggerHook by iter.
|
||||||
runner = MagicMock()
|
runner = MagicMock()
|
||||||
runner.iter = 10
|
runner.log_processor.get_log_after_iter = MagicMock(
|
||||||
batch_idx = 5
|
return_value=(dict(), 'log_str'))
|
||||||
logger_hook = LoggerHook(by_epoch=False)
|
logger_hook = LoggerHook()
|
||||||
logger_hook._log_train = MagicMock()
|
logger_hook.after_train_iter(runner, batch_idx=5)
|
||||||
logger_hook.after_train_iter(runner, batch_idx=batch_idx)
|
|
||||||
# `cur_iter=10+1`, which cannot be exact division by
|
# `cur_iter=10+1`, which cannot be exact division by
|
||||||
# `logger_hook.interval`
|
# `logger_hook.interval`
|
||||||
logger_hook._log_train.assert_not_called()
|
runner.log_processor.get_log_after_iter.assert_not_called()
|
||||||
runner.iter = 9
|
logger_hook.after_train_iter(runner, batch_idx=9)
|
||||||
logger_hook.after_train_iter(runner, batch_idx=batch_idx)
|
runner.log_processor.get_log_after_iter.assert_called()
|
||||||
logger_hook._log_train.assert_called()
|
|
||||||
|
|
||||||
# Test LoggerHook by epoch.
|
# Test LoggerHook by epoch.
|
||||||
logger_hook = LoggerHook(by_epoch=True)
|
logger_hook = LoggerHook()
|
||||||
logger_hook._log_train = MagicMock()
|
runner = MagicMock()
|
||||||
# Only `runner.inner_iter` will work.
|
runner.log_processor.get_log_after_iter = MagicMock(
|
||||||
runner.iter = 9
|
return_value=(dict(), 'log_str'))
|
||||||
batch_idx = 10
|
# Only `batch_idx` will work.
|
||||||
logger_hook.after_train_iter(runner, batch_idx=batch_idx)
|
logger_hook.after_train_iter(runner, batch_idx=10)
|
||||||
logger_hook._log_train.assert_not_called()
|
runner.log_processor.get_log_after_iter.assert_not_called()
|
||||||
batch_idx = 9
|
logger_hook.after_train_iter(runner, batch_idx=9)
|
||||||
logger_hook.after_train_iter(runner, batch_idx=batch_idx)
|
runner.log_processor.get_log_after_iter.assert_called()
|
||||||
logger_hook._log_train.assert_called()
|
|
||||||
|
|
||||||
# Test end of the epoch.
|
# Test end of the epoch.
|
||||||
logger_hook = LoggerHook(by_epoch=True, ignore_last=False)
|
runner = MagicMock()
|
||||||
logger_hook._log_train = MagicMock()
|
runner.log_processor.get_log_after_iter = MagicMock(
|
||||||
runner.train_loop.dataloader = [0] * 5
|
return_value=(dict(), 'log_str'))
|
||||||
batch_idx = 4
|
logger_hook = LoggerHook(ignore_last=False)
|
||||||
logger_hook.after_train_iter(runner, batch_idx=batch_idx)
|
runner.train_dataloader = [0] * 5
|
||||||
logger_hook._log_train.assert_called()
|
logger_hook.after_train_iter(runner, batch_idx=4)
|
||||||
|
runner.log_processor.get_log_after_iter.assert_called()
|
||||||
|
|
||||||
# Test print exp_name
|
# Test print exp_name
|
||||||
|
runner = MagicMock()
|
||||||
|
runner.log_processor.get_log_after_iter = MagicMock(
|
||||||
|
return_value=(dict(), 'log_str'))
|
||||||
runner.meta = dict(exp_name='retinanet')
|
runner.meta = dict(exp_name='retinanet')
|
||||||
logger_hook = LoggerHook()
|
|
||||||
runner.logger = MagicMock()
|
runner.logger = MagicMock()
|
||||||
logger_hook._log_train = MagicMock()
|
logger_hook = LoggerHook()
|
||||||
logger_hook.after_train_iter(runner, batch_idx=batch_idx)
|
logger_hook.after_train_iter(runner, batch_idx=999)
|
||||||
runner.logger.info.assert_called_with(
|
runner.logger.info.assert_called()
|
||||||
f'Exp name: {runner.meta["exp_name"]}')
|
|
||||||
|
|
||||||
def test_after_val_epoch(self):
|
def test_after_val_epoch(self):
|
||||||
logger_hook = LoggerHook()
|
logger_hook = LoggerHook()
|
||||||
runner = MagicMock()
|
runner = MagicMock()
|
||||||
logger_hook._log_val = MagicMock()
|
runner.log_processor.get_log_after_epoch = MagicMock(
|
||||||
|
return_value=(dict(), 'string'))
|
||||||
logger_hook.after_val_epoch(runner)
|
logger_hook.after_val_epoch(runner)
|
||||||
logger_hook._log_val.assert_called()
|
runner.log_processor.get_log_after_epoch.assert_called()
|
||||||
|
runner.logger.info.assert_called()
|
||||||
|
runner.visualizer.add_scalars.assert_called()
|
||||||
|
|
||||||
@pytest.mark.parametrize('by_epoch', [True, False])
|
def test_after_test_epoch(self):
|
||||||
def test_log_train(self, by_epoch, capsys):
|
|
||||||
runner = self._setup_runner()
|
|
||||||
runner.meta = dict(exp_name='retinanet')
|
|
||||||
# Prepare LoggerHook
|
|
||||||
logger_hook = LoggerHook(by_epoch=by_epoch)
|
|
||||||
logger_hook._inner_iter = 1
|
|
||||||
logger_hook.writer = MagicMock()
|
|
||||||
logger_hook.time_sec_tot = 1000
|
|
||||||
logger_hook.start_iter = 0
|
|
||||||
logger_hook._get_max_memory = MagicMock(return_value='100')
|
|
||||||
logger_hook.json_log_path = 'tmp.json'
|
|
||||||
|
|
||||||
# Prepare training information.
|
|
||||||
train_infos = dict(
|
|
||||||
lr=0.1, momentum=0.9, time=1.0, data_time=1.0, loss_cls=1.0)
|
|
||||||
logger_hook._collect_info = MagicMock(return_value=train_infos)
|
|
||||||
logger_hook._log_train(runner)
|
|
||||||
# Verify that the correct variables have been written.
|
|
||||||
runner.visualizer.add_scalars.assert_called_with(
|
|
||||||
train_infos, step=11, file_path='tmp.json')
|
|
||||||
# Verify that the correct context have been logged.
|
|
||||||
out, _ = capsys.readouterr()
|
|
||||||
time_avg = logger_hook.time_sec_tot / (
|
|
||||||
runner.iter + 1 - logger_hook.start_iter)
|
|
||||||
eta_second = time_avg * (runner.train_loop.max_iters - runner.iter - 1)
|
|
||||||
eta_str = str(datetime.timedelta(seconds=int(eta_second)))
|
|
||||||
if by_epoch:
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
log_str = 'Epoch [2][2/5] ' \
|
|
||||||
f"lr: {train_infos['lr']:.3e} " \
|
|
||||||
f"momentum: {train_infos['momentum']:.3e}, " \
|
|
||||||
f'eta: {eta_str}, ' \
|
|
||||||
f"time: {train_infos['time']:.3f}, " \
|
|
||||||
f"data_time: {train_infos['data_time']:.3f}, " \
|
|
||||||
f'memory: 100, ' \
|
|
||||||
f"loss_cls: {train_infos['loss_cls']:.4f}\n"
|
|
||||||
else:
|
|
||||||
log_str = 'Epoch [2][2/5] ' \
|
|
||||||
f"lr: {train_infos['lr']:.3e} " \
|
|
||||||
f"momentum: {train_infos['momentum']:.3e}, " \
|
|
||||||
f'eta: {eta_str}, ' \
|
|
||||||
f"time: {train_infos['time']:.3f}, " \
|
|
||||||
f"data_time: {train_infos['data_time']:.3f}, " \
|
|
||||||
f"loss_cls: {train_infos['loss_cls']:.4f}\n"
|
|
||||||
assert out == log_str
|
|
||||||
else:
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
log_str = 'Iter [11/50] ' \
|
|
||||||
f"lr: {train_infos['lr']:.3e} " \
|
|
||||||
f"momentum: {train_infos['momentum']:.3e}, " \
|
|
||||||
f'eta: {eta_str}, ' \
|
|
||||||
f"time: {train_infos['time']:.3f}, " \
|
|
||||||
f"data_time: {train_infos['data_time']:.3f}, " \
|
|
||||||
f'memory: 100, ' \
|
|
||||||
f"loss_cls: {train_infos['loss_cls']:.4f}\n"
|
|
||||||
else:
|
|
||||||
log_str = 'Iter [11/50] ' \
|
|
||||||
f"lr: {train_infos['lr']:.3e} " \
|
|
||||||
f"momentum: {train_infos['momentum']:.3e}, " \
|
|
||||||
f'eta: {eta_str}, ' \
|
|
||||||
f"time: {train_infos['time']:.3f}, " \
|
|
||||||
f"data_time: {train_infos['data_time']:.3f}, " \
|
|
||||||
f"loss_cls: {train_infos['loss_cls']:.4f}\n"
|
|
||||||
assert out == log_str
|
|
||||||
|
|
||||||
@pytest.mark.parametrize('by_epoch', [True, False])
|
|
||||||
def test_log_val(self, by_epoch, capsys):
|
|
||||||
runner = self._setup_runner()
|
|
||||||
# Prepare LoggerHook.
|
|
||||||
logger_hook = LoggerHook(by_epoch=by_epoch)
|
|
||||||
logger_hook.json_log_path = 'tmp.json'
|
|
||||||
metric = dict(accuracy=0.9, data_time=1.0)
|
|
||||||
logger_hook._collect_info = MagicMock(return_value=metric)
|
|
||||||
logger_hook._log_val(runner)
|
|
||||||
# Verify that the correct context have been logged.
|
|
||||||
out, _ = capsys.readouterr()
|
|
||||||
runner.visualizer.add_scalars.assert_called_with(
|
|
||||||
metric, step=11, file_path='tmp.json')
|
|
||||||
if by_epoch:
|
|
||||||
assert out == 'Epoch(val) [1][5] accuracy: 0.9000, ' \
|
|
||||||
'data_time: 1.0000\n'
|
|
||||||
|
|
||||||
else:
|
|
||||||
assert out == 'Iter(val) [5] accuracy: 0.9000, ' \
|
|
||||||
'data_time: 1.0000\n'
|
|
||||||
|
|
||||||
def test_get_window_size(self):
|
|
||||||
runner = self._setup_runner()
|
|
||||||
logger_hook = LoggerHook()
|
|
||||||
logger_hook._inner_iter = 1
|
|
||||||
# Test get window size by name.
|
|
||||||
assert logger_hook._get_window_size(runner, 'epoch') == 2
|
|
||||||
assert logger_hook._get_window_size(runner, 'global') == 11
|
|
||||||
assert logger_hook._get_window_size(runner, 10) == 10
|
|
||||||
# Window size must equal to `logger_hook.interval`.
|
|
||||||
with pytest.raises(AssertionError):
|
|
||||||
logger_hook._get_window_size(runner, 20)
|
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
logger_hook._get_window_size(runner, 'unknwon')
|
|
||||||
|
|
||||||
def test_parse_custom_keys(self):
|
|
||||||
tag = OrderedDict()
|
|
||||||
runner = self._setup_runner()
|
|
||||||
log_buffers = OrderedDict(lr=MagicMock(), loss=MagicMock())
|
|
||||||
cfg_dict = dict(
|
|
||||||
lr=dict(method='min'),
|
|
||||||
loss=[
|
|
||||||
dict(method='min', window_size='global'),
|
|
||||||
dict(method='max', log_name='loss_max')
|
|
||||||
])
|
|
||||||
logger_hook = LoggerHook()
|
|
||||||
for log_key, log_cfg in cfg_dict.items():
|
|
||||||
logger_hook._parse_custom_keys(runner, log_key, log_cfg,
|
|
||||||
log_buffers, tag)
|
|
||||||
assert list(tag) == ['lr', 'loss', 'loss_max']
|
|
||||||
assert log_buffers['lr'].min.assert_called
|
|
||||||
assert log_buffers['loss'].min.assert_called
|
|
||||||
assert log_buffers['loss'].max.assert_called
|
|
||||||
assert log_buffers['loss'].mean.assert_called
|
|
||||||
# `log_name` Cannot be repeated.
|
|
||||||
with pytest.raises(KeyError):
|
|
||||||
cfg_dict = dict(loss=[
|
|
||||||
dict(method='min', window_size='global'),
|
|
||||||
dict(method='max', log_name='loss_max'),
|
|
||||||
dict(method='mean', log_name='loss_max')
|
|
||||||
])
|
|
||||||
logger_hook.custom_keys = cfg_dict
|
|
||||||
for log_key, log_cfg in cfg_dict.items():
|
|
||||||
logger_hook._parse_custom_keys(runner, log_key, log_cfg,
|
|
||||||
log_buffers, tag)
|
|
||||||
# `log_key` cannot be overwritten multiple times.
|
|
||||||
with pytest.raises(AssertionError):
|
|
||||||
cfg_dict = dict(loss=[
|
|
||||||
dict(method='min', window_size='global'),
|
|
||||||
dict(method='max'),
|
|
||||||
])
|
|
||||||
logger_hook.custom_keys = cfg_dict
|
|
||||||
for log_key, log_cfg in cfg_dict.items():
|
|
||||||
logger_hook._parse_custom_keys(runner, log_key, log_cfg,
|
|
||||||
log_buffers, tag)
|
|
||||||
|
|
||||||
def test_collect_info(self):
|
|
||||||
runner = self._setup_runner()
|
|
||||||
logger_hook = LoggerHook(
|
|
||||||
custom_keys=dict(time=dict(method='max', log_name='time_max')))
|
|
||||||
logger_hook._parse_custom_keys = MagicMock()
|
|
||||||
# Collect with prefix.
|
|
||||||
log_buffers = {
|
|
||||||
'train/time': MagicMock(),
|
|
||||||
'lr': MagicMock(),
|
|
||||||
'train/loss_cls': MagicMock(),
|
|
||||||
'val/metric': MagicMock()
|
|
||||||
}
|
|
||||||
runner.message_hub.log_scalars = log_buffers
|
|
||||||
tag = logger_hook._collect_info(runner, mode='train')
|
|
||||||
# Test parse custom_keys
|
|
||||||
logger_hook._parse_custom_keys.assert_called()
|
|
||||||
# Test training key in tag.
|
|
||||||
assert list(tag.keys()) == ['time', 'loss_cls']
|
|
||||||
# Test statistics lr with `current`, loss and time with 'mean'
|
|
||||||
log_buffers['train/time'].mean.assert_called()
|
|
||||||
log_buffers['train/loss_cls'].mean.assert_called()
|
|
||||||
log_buffers['train/loss_cls'].current.assert_not_called()
|
|
||||||
|
|
||||||
tag = logger_hook._collect_info(runner, mode='val')
|
|
||||||
assert list(tag.keys()) == ['metric']
|
|
||||||
log_buffers['val/metric'].current.assert_called()
|
|
||||||
|
|
||||||
@patch('torch.cuda.max_memory_allocated', MagicMock())
|
|
||||||
@patch('torch.cuda.reset_peak_memory_stats', MagicMock())
|
|
||||||
def test_get_max_memory(self):
|
|
||||||
logger_hook = LoggerHook()
|
logger_hook = LoggerHook()
|
||||||
runner = MagicMock()
|
runner = MagicMock()
|
||||||
runner.world_size = 1
|
runner.log_processor.get_log_after_epoch = MagicMock(
|
||||||
runner.model = torch.nn.Linear(1, 1)
|
return_value=(dict(), 'log_str'))
|
||||||
logger_hook._get_max_memory(runner)
|
logger_hook.after_test_epoch(runner)
|
||||||
torch.cuda.max_memory_allocated.assert_called()
|
runner.log_processor.get_log_after_epoch.assert_called()
|
||||||
torch.cuda.reset_peak_memory_stats.assert_called()
|
runner.logger.info.assert_called()
|
||||||
|
|
||||||
def test_get_iter(self):
|
def test_after_val_iter(self):
|
||||||
runner = self._setup_runner()
|
|
||||||
logger_hook = LoggerHook()
|
logger_hook = LoggerHook()
|
||||||
logger_hook._inner_iter = 1
|
|
||||||
# Get global iter when `inner_iter=False`
|
|
||||||
iter = logger_hook._get_iter(runner)
|
|
||||||
assert iter == 11
|
|
||||||
# Get inner iter
|
|
||||||
iter = logger_hook._get_iter(runner, inner_iter=True)
|
|
||||||
assert iter == 2
|
|
||||||
# Still get global iter when `logger_hook.by_epoch==False`
|
|
||||||
logger_hook.by_epoch = False
|
|
||||||
iter = logger_hook._get_iter(runner, inner_iter=True)
|
|
||||||
assert iter == 11
|
|
||||||
|
|
||||||
def test_get_epoch(self):
|
|
||||||
runner = self._setup_runner()
|
|
||||||
logger_hook = LoggerHook()
|
|
||||||
epoch = logger_hook._get_epoch(runner, 'train')
|
|
||||||
assert epoch == 2
|
|
||||||
epoch = logger_hook._get_epoch(runner, 'val')
|
|
||||||
assert epoch == 1
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
logger_hook._get_epoch(runner, 'test')
|
|
||||||
|
|
||||||
def _setup_runner(self):
|
|
||||||
runner = MagicMock()
|
runner = MagicMock()
|
||||||
runner.epoch = 1
|
runner.iter = 0
|
||||||
runner.train_loop.dataloader = [0] * 5
|
runner.log_processor.get_log_after_iter = MagicMock(
|
||||||
runner.val_loop.dataloader = [0] * 5
|
return_value=(dict(), 'log_str'))
|
||||||
runner.test_loop.dataloader = [0] * 5
|
logger_hook.after_val_iter(runner, 1)
|
||||||
runner.iter = 10
|
runner.log_processor.get_log_after_iter.assert_not_called()
|
||||||
runner.train_loop.max_iters = 50
|
logger_hook.after_val_iter(runner, 9)
|
||||||
logger = logging.getLogger()
|
runner.log_processor.get_log_after_iter.assert_called()
|
||||||
logger.setLevel(logging.INFO)
|
|
||||||
for handler in logger.handlers:
|
def test_after_test_iter(self):
|
||||||
if not isinstance(handler, logging.StreamHandler):
|
logger_hook = LoggerHook()
|
||||||
continue
|
runner = MagicMock()
|
||||||
else:
|
runner.iter = 0
|
||||||
logger.addHandler(logging.StreamHandler(stream=sys.stdout))
|
runner.log_processor.get_log_after_iter = MagicMock(
|
||||||
runner.logger = logger
|
return_value=(dict(), 'log_str'))
|
||||||
runner.message_hub = MagicMock()
|
logger_hook.after_test_iter(runner, 1)
|
||||||
runner.composed_wirter = MagicMock()
|
runner.log_processor.get_log_after_iter.assert_not_called()
|
||||||
return runner
|
logger_hook.after_test_iter(runner, 9)
|
||||||
|
runner.log_processor.get_log_after_iter.assert_called()
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
from unittest.mock import Mock
|
from unittest.mock import MagicMock, Mock
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@ -45,7 +45,7 @@ class TestOptimizerHook:
|
|||||||
model = Model()
|
model = Model()
|
||||||
x = torch.rand(1, 1, 3, 3)
|
x = torch.rand(1, 1, 3, 3)
|
||||||
|
|
||||||
dummy_runner = Mock()
|
dummy_runner = MagicMock()
|
||||||
dummy_runner.optimizer.zero_grad = Mock(return_value=None)
|
dummy_runner.optimizer.zero_grad = Mock(return_value=None)
|
||||||
dummy_runner.optimizer.step = Mock(return_value=None)
|
dummy_runner.optimizer.step = Mock(return_value=None)
|
||||||
dummy_runner.model = model
|
dummy_runner.model = model
|
||||||
|
242
tests/test_logging/test_log_processor.py
Normal file
242
tests/test_logging/test_log_processor.py
Normal file
@ -0,0 +1,242 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import copy
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from mmengine.logging import LogProcessor, MessageHub, MMLogger
|
||||||
|
|
||||||
|
|
||||||
|
class TestLogProcessor:
|
||||||
|
|
||||||
|
def test_init(self):
|
||||||
|
log_processor = LogProcessor(
|
||||||
|
window_size=10, by_epoch=True, custom_cfg=None)
|
||||||
|
assert log_processor.by_epoch
|
||||||
|
assert log_processor.window_size == 10
|
||||||
|
assert log_processor.custom_cfg == []
|
||||||
|
|
||||||
|
def test_check_custom_cfg(self):
|
||||||
|
# ``by_epoch==False`` and `window_size='epoch'` in log config will
|
||||||
|
# raise AssertionError.
|
||||||
|
custom_cfg = [dict(data_src='loss', window_size='epoch')]
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
LogProcessor(by_epoch=False, custom_cfg=custom_cfg)
|
||||||
|
# Duplicate log_name will raise AssertionError.
|
||||||
|
custom_cfg = [
|
||||||
|
dict(data_src='loss', log_name='loss_1'),
|
||||||
|
dict(data_src='loss', log_name='loss_1')
|
||||||
|
]
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
LogProcessor(custom_cfg=custom_cfg)
|
||||||
|
# Overwrite loss item twice will raise AssertionError.
|
||||||
|
custom_cfg = [dict(data_src='loss'), dict(data_src='loss')]
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
LogProcessor(custom_cfg=custom_cfg)
|
||||||
|
|
||||||
|
custom_cfg = [
|
||||||
|
dict(data_src='loss_cls', window_size=100, method_name='min'),
|
||||||
|
dict(data_src='loss', log_name='loss_min', method_name='max'),
|
||||||
|
dict(data_src='loss', log_name='loss_max', method_name='max')
|
||||||
|
]
|
||||||
|
LogProcessor(custom_cfg=custom_cfg)
|
||||||
|
|
||||||
|
def test_parse_windows_size(self):
|
||||||
|
log_processor = LogProcessor()
|
||||||
|
# Test parse 'epoch' window_size.
|
||||||
|
log_processor.custom_cfg = [
|
||||||
|
dict(data_src='loss_cls', window_size='epoch')
|
||||||
|
]
|
||||||
|
custom_cfg = log_processor._parse_windows_size(self.runner, 1)
|
||||||
|
assert custom_cfg[0]['window_size'] == 2
|
||||||
|
|
||||||
|
# Test parse 'global' window_size.
|
||||||
|
log_processor.custom_cfg = [
|
||||||
|
dict(data_src='loss_cls', window_size='global')
|
||||||
|
]
|
||||||
|
custom_cfg = log_processor._parse_windows_size(self.runner, 1)
|
||||||
|
assert custom_cfg[0]['window_size'] == 11
|
||||||
|
|
||||||
|
# Test parse int window_size
|
||||||
|
log_processor.custom_cfg = [dict(data_src='loss_cls', window_size=100)]
|
||||||
|
custom_cfg = log_processor._parse_windows_size(self.runner, 1)
|
||||||
|
assert custom_cfg[0]['window_size'] == 100
|
||||||
|
|
||||||
|
# Invalid type window_size will raise TypeError.
|
||||||
|
log_processor.custom_cfg = [dict(data_src='loss_cls', window_size=[])]
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
log_processor._parse_windows_size(custom_cfg, self.runner)
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('by_epoch,mode',
|
||||||
|
([True, 'train'], [False, 'train'], [True, 'val'],
|
||||||
|
[False, 'val'], [True, 'test'], [False, 'test']))
|
||||||
|
def test_get_log_after_iter(self, by_epoch, mode):
|
||||||
|
# Prepare LoggerHook
|
||||||
|
log_processor = LogProcessor(by_epoch=by_epoch)
|
||||||
|
log_processor._get_max_memory = MagicMock(return_value='100')
|
||||||
|
eta = 40
|
||||||
|
self.runner.message_hub.update_info('eta', eta)
|
||||||
|
# Prepare training information.
|
||||||
|
if mode == 'train':
|
||||||
|
train_logs = dict(lr=0.1, time=1.0, data_time=1.0, loss_cls=1.0)
|
||||||
|
else:
|
||||||
|
train_logs = dict(time=1.0, data_time=1.0, loss_cls=1.0)
|
||||||
|
log_processor._collect_scalars = MagicMock(return_value=train_logs)
|
||||||
|
tag, out = log_processor.get_log_after_iter(self.runner, 1, mode)
|
||||||
|
# Verify that the correct context have been logged.
|
||||||
|
cur_loop = log_processor._get_cur_loop(self.runner, mode)
|
||||||
|
if by_epoch:
|
||||||
|
if mode in ['train', 'val']:
|
||||||
|
cur_epoch = log_processor._get_epoch(self.runner, mode)
|
||||||
|
log_str = (f'Epoch({mode}) [{cur_epoch}][2/'
|
||||||
|
f'{len(cur_loop.dataloader)}] ')
|
||||||
|
else:
|
||||||
|
log_str = (f'Epoch({mode}) [2/{len(cur_loop.dataloader)}] ')
|
||||||
|
|
||||||
|
if mode == 'train':
|
||||||
|
log_str += f"lr: {train_logs['lr']:.3e} "
|
||||||
|
else:
|
||||||
|
log_str += ' '
|
||||||
|
|
||||||
|
log_str += (f'eta: 0:00:40 '
|
||||||
|
f"time: {train_logs['time']:.3f} "
|
||||||
|
f"data_time: {train_logs['data_time']:.3f} ")
|
||||||
|
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
log_str += 'memory: 100 '
|
||||||
|
if mode == 'train':
|
||||||
|
log_str += f"loss_cls: {train_logs['loss_cls']:.4f}"
|
||||||
|
assert out == log_str
|
||||||
|
else:
|
||||||
|
if mode == 'train':
|
||||||
|
max_iters = self.runner.train_loop.max_iters
|
||||||
|
log_str = f'Iter({mode}) [11/{max_iters}] '
|
||||||
|
else:
|
||||||
|
max_iters = len(cur_loop.dataloader)
|
||||||
|
log_str = f'Iter({mode}) [2/{max_iters}] '
|
||||||
|
|
||||||
|
if mode == 'train':
|
||||||
|
log_str += f"lr: {train_logs['lr']:.3e} "
|
||||||
|
else:
|
||||||
|
log_str += ' '
|
||||||
|
|
||||||
|
log_str += (f'eta: 0:00:40 '
|
||||||
|
f"time: {train_logs['time']:.3f} "
|
||||||
|
f"data_time: {train_logs['data_time']:.3f} ")
|
||||||
|
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
log_str += 'memory: 100 '
|
||||||
|
|
||||||
|
if mode == 'train':
|
||||||
|
log_str += f"loss_cls: {train_logs['loss_cls']:.4f}"
|
||||||
|
assert out == log_str
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
'by_epoch,mode',
|
||||||
|
([True, 'val'], [False, 'val'], [True, 'test'], [False, 'test']))
|
||||||
|
def test_log_val(self, by_epoch, mode):
|
||||||
|
# Prepare LoggerHook
|
||||||
|
log_processor = LogProcessor(by_epoch=by_epoch)
|
||||||
|
# Prepare validation information.
|
||||||
|
val_logs = dict(accuracy=0.9, data_time=1.0)
|
||||||
|
log_processor._collect_scalars = MagicMock(return_value=val_logs)
|
||||||
|
_, out = log_processor.get_log_after_epoch(self.runner, 2, mode)
|
||||||
|
if by_epoch:
|
||||||
|
if mode == 'test':
|
||||||
|
assert out == 'Epoch(test) [5/5] accuracy: 0.9000'
|
||||||
|
else:
|
||||||
|
assert out == 'Epoch(val) [1][10/10] accuracy: 0.9000'
|
||||||
|
else:
|
||||||
|
if mode == 'test':
|
||||||
|
assert out == 'Iter(test) [5/5] accuracy: 0.9000'
|
||||||
|
else:
|
||||||
|
assert out == 'Iter(val) [10/10] accuracy: 0.9000'
|
||||||
|
|
||||||
|
def test_collect_scalars(self):
|
||||||
|
custom_cfg = [
|
||||||
|
dict(data_src='time', method_name='mean', window_size=100),
|
||||||
|
dict(data_src='time', method_name='max', log_name='time_max')
|
||||||
|
]
|
||||||
|
logger_hook = LogProcessor(custom_cfg=custom_cfg)
|
||||||
|
# Collect with prefix.
|
||||||
|
log_scalars = {
|
||||||
|
'train/time': MagicMock(),
|
||||||
|
'lr': MagicMock(),
|
||||||
|
'train/loss_cls': MagicMock(),
|
||||||
|
'val/metric': MagicMock()
|
||||||
|
}
|
||||||
|
self.runner.message_hub._log_scalars = log_scalars
|
||||||
|
tag = logger_hook._collect_scalars(
|
||||||
|
copy.deepcopy(custom_cfg), self.runner, mode='train')
|
||||||
|
# Test training key in tag.
|
||||||
|
assert list(tag.keys()) == ['time', 'loss_cls', 'time_max']
|
||||||
|
# Test statistics lr with `current`, loss and time with 'mean'
|
||||||
|
log_scalars['train/time'].statistics.assert_called_with(
|
||||||
|
method_name='max')
|
||||||
|
log_scalars['train/loss_cls'].mean.assert_called()
|
||||||
|
|
||||||
|
tag = logger_hook._collect_scalars(
|
||||||
|
copy.deepcopy(custom_cfg), self.runner, mode='val')
|
||||||
|
assert list(tag.keys()) == ['metric']
|
||||||
|
log_scalars['val/metric'].current.assert_called()
|
||||||
|
|
||||||
|
@patch('torch.cuda.max_memory_allocated', MagicMock())
|
||||||
|
@patch('torch.cuda.reset_peak_memory_stats', MagicMock())
|
||||||
|
def test_get_max_memory(self):
|
||||||
|
logger_hook = LogProcessor()
|
||||||
|
runner = MagicMock()
|
||||||
|
runner.world_size = 1
|
||||||
|
runner.model = torch.nn.Linear(1, 1)
|
||||||
|
logger_hook._get_max_memory(runner)
|
||||||
|
torch.cuda.max_memory_allocated.assert_called()
|
||||||
|
torch.cuda.reset_peak_memory_stats.assert_called()
|
||||||
|
|
||||||
|
def test_get_iter(self):
|
||||||
|
log_processor = LogProcessor()
|
||||||
|
# Get global iter when `inner_iter=False`
|
||||||
|
iter = log_processor._get_iter(self.runner)
|
||||||
|
assert iter == 11
|
||||||
|
# Get inner iter
|
||||||
|
iter = log_processor._get_iter(self.runner, 1)
|
||||||
|
assert iter == 2
|
||||||
|
# Still get global iter when `logger_hook.by_epoch==False`
|
||||||
|
log_processor.by_epoch = False
|
||||||
|
iter = log_processor._get_iter(self.runner, 1)
|
||||||
|
assert iter == 11
|
||||||
|
|
||||||
|
def test_get_epoch(self):
|
||||||
|
log_processor = LogProcessor()
|
||||||
|
epoch = log_processor._get_epoch(self.runner, 'train')
|
||||||
|
assert epoch == 2
|
||||||
|
epoch = log_processor._get_epoch(self.runner, 'val')
|
||||||
|
assert epoch == 1
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
log_processor._get_epoch(self.runner, 'test')
|
||||||
|
|
||||||
|
def test_get_cur_loop(self):
|
||||||
|
log_processor = LogProcessor()
|
||||||
|
loop = log_processor._get_cur_loop(self.runner, 'train')
|
||||||
|
assert len(loop.dataloader) == 20
|
||||||
|
loop = log_processor._get_cur_loop(self.runner, 'val')
|
||||||
|
assert len(loop.dataloader) == 10
|
||||||
|
loop = log_processor._get_cur_loop(self.runner, 'test')
|
||||||
|
assert len(loop.dataloader) == 5
|
||||||
|
|
||||||
|
def setup(self):
|
||||||
|
runner = MagicMock()
|
||||||
|
runner.epoch = 1
|
||||||
|
runner.iter = 10
|
||||||
|
runner.train_loop.max_iters = 50
|
||||||
|
runner.train_loop.dataloader = [0] * 20
|
||||||
|
runner.val_loop.dataloader = [0] * 10
|
||||||
|
runner.test_loop.dataloader = [0] * 5
|
||||||
|
logger = MMLogger.get_instance('log_processor_test')
|
||||||
|
runner.logger = logger
|
||||||
|
message_hub = MessageHub.get_instance('log_processor_test')
|
||||||
|
for i in range(10):
|
||||||
|
message_hub.update_scalar('train/loss', 10 - i)
|
||||||
|
for i in range(10):
|
||||||
|
message_hub.update_scalar('val/acc', i * 0.1)
|
||||||
|
runner.message_hub = message_hub
|
||||||
|
self.runner = runner
|
@ -8,7 +8,7 @@ import torch.optim as optim
|
|||||||
|
|
||||||
from mmengine.optim.scheduler import (ConstantLR, CosineAnnealingLR,
|
from mmengine.optim.scheduler import (ConstantLR, CosineAnnealingLR,
|
||||||
ExponentialLR, LinearLR, MultiStepLR,
|
ExponentialLR, LinearLR, MultiStepLR,
|
||||||
StepLR, _ParamScheduler)
|
PolyLR, StepLR, _ParamScheduler)
|
||||||
from mmengine.testing import assert_allclose
|
from mmengine.testing import assert_allclose
|
||||||
|
|
||||||
|
|
||||||
@ -283,6 +283,21 @@ class TestLRScheduler(TestCase):
|
|||||||
scheduler = CosineAnnealingLR(self.optimizer, T_max=t, eta_min=eta_min)
|
scheduler = CosineAnnealingLR(self.optimizer, T_max=t, eta_min=eta_min)
|
||||||
self._test_scheduler_value(scheduler, targets, epochs)
|
self._test_scheduler_value(scheduler, targets, epochs)
|
||||||
|
|
||||||
|
def test_poly_scheduler(self):
|
||||||
|
epochs = 10
|
||||||
|
power = 0.9
|
||||||
|
min_lr = 0.001
|
||||||
|
iters = 4
|
||||||
|
single_targets = [
|
||||||
|
min_lr + (0.05 - min_lr) * (1 - i / iters)**power
|
||||||
|
for i in range(iters)
|
||||||
|
] + [min_lr] * (
|
||||||
|
epochs - iters)
|
||||||
|
targets = [single_targets, [x * epochs for x in single_targets]]
|
||||||
|
scheduler = PolyLR(
|
||||||
|
self.optimizer, power=power, eta_min=min_lr, end=iters + 1)
|
||||||
|
self._test_scheduler_value(scheduler, targets, epochs=10)
|
||||||
|
|
||||||
def _check_scheduler_state_dict(self, construct, construct2, epochs=10):
|
def _check_scheduler_state_dict(self, construct, construct2, epochs=10):
|
||||||
scheduler = construct()
|
scheduler = construct()
|
||||||
for _ in range(epochs):
|
for _ in range(epochs):
|
||||||
@ -331,6 +346,12 @@ class TestLRScheduler(TestCase):
|
|||||||
lambda: LinearLR(self.optimizer, start_factor=0, end_factor=0.3),
|
lambda: LinearLR(self.optimizer, start_factor=0, end_factor=0.3),
|
||||||
epochs=epochs)
|
epochs=epochs)
|
||||||
|
|
||||||
|
def test_poly_scheduler_state_dict(self):
|
||||||
|
self._check_scheduler_state_dict(
|
||||||
|
lambda: PolyLR(self.optimizer, power=0.5, eta_min=0.001),
|
||||||
|
lambda: PolyLR(self.optimizer, power=0.8, eta_min=0.002),
|
||||||
|
epochs=10)
|
||||||
|
|
||||||
def test_multi_scheduler_without_overlap_linear_multi_step(self):
|
def test_multi_scheduler_without_overlap_linear_multi_step(self):
|
||||||
# use Linear in the first 5 epochs and then use MultiStep
|
# use Linear in the first 5 epochs and then use MultiStep
|
||||||
epochs = 12
|
epochs = 12
|
||||||
|
@ -9,8 +9,8 @@ import torch.optim as optim
|
|||||||
from mmengine.optim.scheduler import (ConstantMomentum,
|
from mmengine.optim.scheduler import (ConstantMomentum,
|
||||||
CosineAnnealingMomentum,
|
CosineAnnealingMomentum,
|
||||||
ExponentialMomentum, LinearMomentum,
|
ExponentialMomentum, LinearMomentum,
|
||||||
MultiStepMomentum, StepMomentum,
|
MultiStepMomentum, PolyMomentum,
|
||||||
_ParamScheduler)
|
StepMomentum, _ParamScheduler)
|
||||||
from mmengine.testing import assert_allclose
|
from mmengine.testing import assert_allclose
|
||||||
|
|
||||||
|
|
||||||
@ -284,6 +284,21 @@ class TestMomentumScheduler(TestCase):
|
|||||||
self.optimizer, T_max=t, eta_min=eta_min)
|
self.optimizer, T_max=t, eta_min=eta_min)
|
||||||
self._test_scheduler_value(scheduler, targets, epochs)
|
self._test_scheduler_value(scheduler, targets, epochs)
|
||||||
|
|
||||||
|
def test_poly_scheduler(self):
|
||||||
|
epochs = 10
|
||||||
|
power = 0.9
|
||||||
|
min_lr = 0.001
|
||||||
|
iters = 4
|
||||||
|
single_targets = [
|
||||||
|
min_lr + (0.05 - min_lr) * (1 - i / iters)**power
|
||||||
|
for i in range(iters)
|
||||||
|
] + [min_lr] * (
|
||||||
|
epochs - iters)
|
||||||
|
targets = [single_targets, [x * epochs for x in single_targets]]
|
||||||
|
scheduler = PolyMomentum(
|
||||||
|
self.optimizer, power=power, eta_min=min_lr, end=iters + 1)
|
||||||
|
self._test_scheduler_value(scheduler, targets, epochs=10)
|
||||||
|
|
||||||
def _check_scheduler_state_dict(self, construct, construct2, epochs=10):
|
def _check_scheduler_state_dict(self, construct, construct2, epochs=10):
|
||||||
scheduler = construct()
|
scheduler = construct()
|
||||||
for _ in range(epochs):
|
for _ in range(epochs):
|
||||||
@ -333,6 +348,12 @@ class TestMomentumScheduler(TestCase):
|
|||||||
self.optimizer, start_factor=0, end_factor=0.3),
|
self.optimizer, start_factor=0, end_factor=0.3),
|
||||||
epochs=epochs)
|
epochs=epochs)
|
||||||
|
|
||||||
|
def test_poly_scheduler_state_dict(self):
|
||||||
|
self._check_scheduler_state_dict(
|
||||||
|
lambda: PolyMomentum(self.optimizer, power=0.5, eta_min=0.001),
|
||||||
|
lambda: PolyMomentum(self.optimizer, power=0.8, eta_min=0.002),
|
||||||
|
epochs=10)
|
||||||
|
|
||||||
def test_multi_scheduler_without_overlap_linear_multi_step(self):
|
def test_multi_scheduler_without_overlap_linear_multi_step(self):
|
||||||
# use Linear in the first 5 epochs and then use MultiStep
|
# use Linear in the first 5 epochs and then use MultiStep
|
||||||
epochs = 12
|
epochs = 12
|
||||||
|
@ -6,12 +6,15 @@ import torch
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
|
|
||||||
|
# yapf: disable
|
||||||
from mmengine.optim.scheduler import (ConstantParamScheduler,
|
from mmengine.optim.scheduler import (ConstantParamScheduler,
|
||||||
CosineAnnealingParamScheduler,
|
CosineAnnealingParamScheduler,
|
||||||
ExponentialParamScheduler,
|
ExponentialParamScheduler,
|
||||||
LinearParamScheduler,
|
LinearParamScheduler,
|
||||||
MultiStepParamScheduler,
|
MultiStepParamScheduler,
|
||||||
StepParamScheduler, _ParamScheduler)
|
PolyParamScheduler, StepParamScheduler,
|
||||||
|
_ParamScheduler)
|
||||||
|
# yapf: enable
|
||||||
from mmengine.testing import assert_allclose
|
from mmengine.testing import assert_allclose
|
||||||
|
|
||||||
|
|
||||||
@ -336,6 +339,25 @@ class TestParameterScheduler(TestCase):
|
|||||||
self.optimizer, param_name='lr', T_max=t, eta_min=eta_min)
|
self.optimizer, param_name='lr', T_max=t, eta_min=eta_min)
|
||||||
self._test_scheduler_value(scheduler, targets, epochs)
|
self._test_scheduler_value(scheduler, targets, epochs)
|
||||||
|
|
||||||
|
def test_poly_scheduler(self):
|
||||||
|
epochs = 10
|
||||||
|
power = 0.9
|
||||||
|
min_lr = 0.001
|
||||||
|
iters = 4
|
||||||
|
single_targets = [
|
||||||
|
min_lr + (0.05 - min_lr) * (1 - i / iters)**power
|
||||||
|
for i in range(iters)
|
||||||
|
] + [min_lr] * (
|
||||||
|
epochs - iters)
|
||||||
|
targets = [single_targets, [x * epochs for x in single_targets]]
|
||||||
|
scheduler = PolyParamScheduler(
|
||||||
|
self.optimizer,
|
||||||
|
param_name='lr',
|
||||||
|
power=power,
|
||||||
|
eta_min=min_lr,
|
||||||
|
end=iters + 1)
|
||||||
|
self._test_scheduler_value(scheduler, targets, epochs=10)
|
||||||
|
|
||||||
def _check_scheduler_state_dict(self, construct, construct2, epochs=10):
|
def _check_scheduler_state_dict(self, construct, construct2, epochs=10):
|
||||||
scheduler = construct()
|
scheduler = construct()
|
||||||
for _ in range(epochs):
|
for _ in range(epochs):
|
||||||
@ -402,6 +424,14 @@ class TestParameterScheduler(TestCase):
|
|||||||
end_factor=0.3),
|
end_factor=0.3),
|
||||||
epochs=epochs)
|
epochs=epochs)
|
||||||
|
|
||||||
|
def test_poly_scheduler_state_dict(self):
|
||||||
|
self._check_scheduler_state_dict(
|
||||||
|
lambda: PolyParamScheduler(
|
||||||
|
self.optimizer, param_name='lr', power=0.5, eta_min=0.001),
|
||||||
|
lambda: PolyParamScheduler(
|
||||||
|
self.optimizer, param_name='lr', power=0.8, eta_min=0.002),
|
||||||
|
epochs=10)
|
||||||
|
|
||||||
def test_multi_scheduler_without_overlap_linear_multi_step(self):
|
def test_multi_scheduler_without_overlap_linear_multi_step(self):
|
||||||
# use Linear in the first 5 epochs and then use MultiStep
|
# use Linear in the first 5 epochs and then use MultiStep
|
||||||
epochs = 12
|
epochs = 12
|
||||||
|
@ -222,7 +222,7 @@ class TestRunner(TestCase):
|
|||||||
self.iter_based_cfg.default_hooks = dict(
|
self.iter_based_cfg.default_hooks = dict(
|
||||||
timer=dict(type='IterTimerHook'),
|
timer=dict(type='IterTimerHook'),
|
||||||
checkpoint=dict(type='CheckpointHook', interval=1, by_epoch=False),
|
checkpoint=dict(type='CheckpointHook', interval=1, by_epoch=False),
|
||||||
logger=dict(type='LoggerHook', by_epoch=False),
|
logger=dict(type='LoggerHook'),
|
||||||
optimizer=dict(type='OptimizerHook', grad_clip=None),
|
optimizer=dict(type='OptimizerHook', grad_clip=None),
|
||||||
param_scheduler=dict(type='ParamSchedulerHook'))
|
param_scheduler=dict(type='ParamSchedulerHook'))
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user