[Docs] Simplify hook docs (#428)

* Move the hook design doc to design/hook.md
* Add relative links in docs
* Update docstrings of hooks
* Refine CheckpointHook docs
* Add logging.md link in hook.md
* Resolve review comments and fix typos

parent 5e17dd7cd4
commit 6c607bd26f

@ -0,0 +1,204 @@
# Hook Design

Hook programming is a programming pattern in which mount points are set at one or more locations of a program; when execution reaches a mount point, all methods registered to it at runtime are called automatically. Hook programming increases the flexibility and extensibility of a program: users register custom methods to mount points, and these methods get called without any modification to the program's code.
## Examples of Hooks

Here is a simple example of hooks.

```python
pre_hooks = [(print, 'hello')]
post_hooks = [(print, 'goodbye')]

def main():
    for func, arg in pre_hooks:
        func(arg)
    print('do something here')
    for func, arg in post_hooks:
        func(arg)

main()
```

Output of the program:

```
hello
do something here
goodbye
```

As we can see, `main` calls the functions in the hooks at two locations without any modification to its own code.
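To see the extensibility claim concretely, a user can register extra behavior without touching `main` at all. Continuing the example above (the extra message is illustrative):

```python
# register another method at the pre mount point without modifying main()
pre_hooks.append((print, 'hi from a newly registered hook'))
main()  # now prints the new message before 'do something here'
```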
Hooks are also used everywhere in PyTorch. For example, hooks on neural network modules (nn.Module) can access the forward inputs and outputs as well as the backward inputs and outputs of a module. Take the [`register_forward_hook`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_forward_hook) method as an example: it registers a forward hook on a module, and the hook can access the module's forward inputs and outputs.

Here is a simple example of `register_forward_hook`:

```python
import torch
import torch.nn as nn

def forward_hook_fn(
    module,  # the module the hook is registered on
    input,   # the forward input of the module
    output,  # the forward output of the module
):
    print(f'"forward_hook_fn" is invoked by {module}')
    print('weight:', module.weight.data)
    print('bias:', module.bias.data)
    print('input:', input)
    print('output:', output)

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(3, 1)

    def forward(self, x):
        y = self.fc(x)
        return y

model = Model()
# register forward_hook_fn on every child module of model
for module in model.children():
    module.register_forward_hook(forward_hook_fn)

x = torch.Tensor([[0.0, 1.0, 2.0]])
y = model(x)
```

Output of the program:

```python
"forward_hook_fn" is invoked by Linear(in_features=3, out_features=1, bias=True)
weight: tensor([[-0.4077, 0.0119, -0.3606]])
bias: tensor([-0.2943])
input: (tensor([[0., 1., 2.]]),)
output: tensor([[-1.0036]], grad_fn=<AddmmBackward>)
```

As we can see, the `forward_hook_fn` hook registered on the Linear module is invoked, and it prints the weight, bias, input and output of the Linear module. For more usage of PyTorch hooks, read [nn.Module](https://pytorch.org/docs/stable/generated/torch.nn.Module.html).
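The backward side works the same way. As a minimal sketch continuing the model above (the hook name and printed fields are illustrative), `register_full_backward_hook` gives the hook access to the gradients flowing through the module:

```python
def backward_hook_fn(
    module,       # the module the hook is registered on
    grad_input,   # gradients w.r.t. the module's inputs
    grad_output,  # gradients w.r.t. the module's outputs
):
    print(f'"backward_hook_fn" is invoked by {module}')
    print('grad_output:', grad_output)

model.fc.register_full_backward_hook(backward_hook_fn)
model(x).sum().backward()
```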
## The Design of Hooks in MMEngine

Before introducing the design of hooks in MMEngine, let us briefly review the basic steps of training a model with PyTorch (the example code comes from the [PyTorch Tutorials](https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#sphx-glr-beginner-blitz-cifar10-tutorial-py)):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader

class CustomDataset(Dataset):
    pass

class Net(nn.Module):
    pass

def main():
    transform = transforms.ToTensor()
    train_dataset = CustomDataset(transform=transform, ...)
    val_dataset = CustomDataset(transform=transform, ...)
    test_dataset = CustomDataset(transform=transform, ...)
    train_dataloader = DataLoader(train_dataset, ...)
    val_dataloader = DataLoader(val_dataset, ...)
    test_dataloader = DataLoader(test_dataset, ...)

    net = Net()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

    for i in range(max_epochs):
        for inputs, labels in train_dataloader:
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        with torch.no_grad():
            for inputs, labels in val_dataloader:
                outputs = net(inputs)
                loss = criterion(outputs, labels)

    with torch.no_grad():
        for inputs, labels in test_dataloader:
            outputs = net(inputs)
            accuracy = ...
```
The pseudocode above shows the basic steps of training a model. To add customized logic to it, we would have to keep modifying and extending the `main` function. To improve the flexibility and extensibility of `main`, we can instead insert mount points into it and implement the abstract logic of calling hooks at each point. Customized functionality, such as loading model weights or updating model parameters, can then be added simply by registering hooks at these points.

```python
def main():
    ...
    call_hooks('before_run', hooks)             # logic executed before the task starts
    call_hooks('after_load_checkpoint', hooks)  # logic executed after loading weights
    call_hooks('before_train', hooks)           # logic executed before training starts
    for i in range(max_epochs):
        call_hooks('before_train_epoch', hooks)  # logic executed before each training epoch
        for inputs, labels in train_dataloader:
            call_hooks('before_train_iter', hooks)  # logic executed before the forward pass
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            call_hooks('after_train_iter', hooks)   # logic executed after the forward pass
            loss.backward()
            optimizer.step()
        call_hooks('after_train_epoch', hooks)  # logic executed after each training epoch

        call_hooks('before_val_epoch', hooks)   # logic executed before each validation epoch
        with torch.no_grad():
            for inputs, labels in val_dataloader:
                call_hooks('before_val_iter', hooks)  # executed before the forward pass
                outputs = net(inputs)
                loss = criterion(outputs, labels)
                call_hooks('after_val_iter', hooks)   # executed after the forward pass
        call_hooks('after_val_epoch', hooks)    # executed after each validation epoch

        call_hooks('before_save_checkpoint', hooks)  # logic executed before saving weights
    call_hooks('after_train', hooks)   # logic executed after training ends

    call_hooks('before_test_epoch', hooks)  # logic executed before the test epoch
    with torch.no_grad():
        for inputs, labels in test_dataloader:
            call_hooks('before_test_iter', hooks)  # executed before the forward pass
            outputs = net(inputs)
            accuracy = ...
            call_hooks('after_test_iter', hooks)   # executed after the forward pass
    call_hooks('after_test_epoch', hooks)   # executed after the test epoch

    call_hooks('after_run', hooks)  # logic executed after the task ends
```
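`call_hooks` itself is left abstract above. As a rough illustration of the dispatch it performs, here is a minimal sketch; the helper and the assumed shape of `hooks` are for illustration only and are not MMEngine API:

```python
# Minimal sketch: each hook is assumed to be an object whose methods are
# named after the mount points, e.g. before_train_epoch().
def call_hooks(point, hooks):
    for hook in hooks:
        method = getattr(hook, point, None)
        if method is not None:
            method()  # a real runner would also pass context arguments
```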
In MMEngine, the training process is abstracted into a runner (Runner). Besides initializing the environment, the runner's other role is to call hooks at specific mount points to perform customized logic. For more about the runner, read the [runner documentation](../tutorials/runner.md).

For ease of management, MMEngine defines the mount points as methods of the [hook base class (Hook)](mmengine.hooks.Hook). We only need to subclass the hook base class, implement customized logic at the desired mount points, and register the hook with the runner; the corresponding methods of the hook are then called automatically, as the sketch after the list shows.

There are 22 mount points in the hook base class:

- before_run
- after_run
- before_train
- after_train
- before_train_epoch
- after_train_epoch
- before_train_iter
- after_train_iter
- before_val
- after_val
- before_val_epoch
- after_val_epoch
- before_val_iter
- after_val_iter
- before_test
- after_test
- before_test_epoch
- after_test_epoch
- before_test_iter
- after_test_iter
- before_save_checkpoint
- after_load_checkpoint
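As a minimal sketch of this workflow (the hook name and printed messages are illustrative, not part of MMEngine):

```python
from mmengine.hooks import Hook
from mmengine.registry import HOOKS


@HOOKS.register_module()
class PrintEpochHook(Hook):
    """An illustrative hook that adds logic at two mount points."""

    def before_train_epoch(self, runner):
        print(f'epoch {runner.epoch} is about to start')

    def after_train_epoch(self, runner):
        print(f'epoch {runner.epoch} has finished')


# The hook would be registered through the runner's custom_hooks parameter:
# runner = Runner(custom_hooks=[dict(type='PrintEpochHook')], ...)
```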
You may also want to read the [usage of hooks](../tutorials/hook.md) or the [hook API documentation](mmengine.hooks).

@ -31,6 +31,12 @@

   examples/speed_up_training.md

.. toctree::
   :maxdepth: 1
   :caption: Design

   design/hook.md

.. toctree::
   :maxdepth: 2
   :caption: API Documentation

@ -2,208 +2,9 @@
## Built-in Hooks

MMEngine provides many built-in hooks, divided into two categories: default hooks and custom hooks. The former are registered with the runner by default, while the latter need to be registered by the user.

Each hook has a corresponding priority. At the same mount point, the higher a hook's priority, the earlier the runner calls it; if two hooks share the same priority, they are called in the order in which they were registered. The priority list is as follows:
@ -244,7 +45,7 @@ MMEngine provides many built-in hooks, divided into two categories

The two kinds of hooks are configured differently in the runner: the configs of default hooks are passed to the runner's `default_hooks` parameter, and the configs of custom hooks to the `custom_hooks` parameter, as shown below:

```python
from mmengine.runner import Runner

default_hooks = dict(
    runtime_info=dict(type='RuntimeInfoHook'),
@ -260,131 +61,139 @@ custom_hooks = [
]

runner = Runner(default_hooks=default_hooks, custom_hooks=custom_hooks, ...)
runner.train()
```

The usage of each built-in hook in MMEngine is introduced one by one below.
### CheckpointHook

[CheckpointHook](mmengine.hooks.CheckpointHook) saves the model weights at a given interval. In distributed multi-GPU training, only the master process saves the weights. The main features of `CheckpointHook` are:

- Save weights at an interval, by epoch count or by iteration count
- Keep only the latest several weights
- Save the best weights
- Specify the directory in which to save weights

For other features, read the [CheckpointHook API documentation](mmengine.hooks.CheckpointHook).

The four features above are introduced below.

- Save weights at an interval, by epoch count or by iteration count

Suppose we train for a total of 20 epochs and want to save the weights every 5 epochs; the configuration below implements this.

```python
# the default value of by_epoch is True
default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=5, by_epoch=True))
```

To use the number of iterations as the saving interval instead, set `by_epoch` to False; `interval=5` then means saving the weights every 5 iterations.

```python
default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=5, by_epoch=False))
```

- Keep only the latest several weights

To keep only a fixed number of weights, set the `max_keep_ckpts` parameter: at most `max_keep_ckpts` weights are kept, and once the number of saved weights exceeds it, the earlier weights are deleted.

```python
default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=5, max_keep_ckpts=2))
```

The example above means: if we train for a total of 20 epochs, the model is saved at epochs 5, 10, 15 and 20, but the epoch-5 weights are deleted at epoch 15 and the epoch-10 weights at epoch 20, so in the end only the epoch-15 and epoch-20 weights are kept.

- Save the best weights

To save the best weights evaluated on the validation set during training, set the `save_best` parameter. If it is set to `'auto'`, the current weights are judged by the first evaluation metric on the validation set (the evaluation metrics returned for the validation set form an ordered dict).

```python
default_hooks = dict(checkpoint=dict(type='CheckpointHook', save_best='auto'))
```

You can also set `save_best` directly to an evaluation metric. For example, in a classification task you can set `save_best='top-1'`, and the current weights are then judged by the value of `'top-1'`.

Besides `save_best`, the parameters related to saving the best weights are `rule`, `greater_keys` and `less_keys`, which determine whether a larger or a smaller value of the `save_best` metric is better. For example, with `save_best='top-1'` you can set `rule='greater'` to indicate that a larger value means better weights.
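Put together, a configuration combining the two might look like the sketch below (the `top-1` metric name assumes a classification evaluator that reports it):

```python
# save the weights with the highest top-1 accuracy; larger is better
default_hooks = dict(
    checkpoint=dict(type='CheckpointHook', save_best='top-1', rule='greater'))
```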

- Specify the directory in which to save weights

Weights are saved in the working directory (work_dir) by default, but the path can be changed by setting `out_dir`.

```python
default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=5, out_dir='/path/of/directory'))
```

### LoggerHook

[LoggerHook](mmengine.hooks.LoggerHook) collects logs and writes them to the terminal or to backends such as files and TensorBoard.

If we want the logs to be output (or saved) every 20 iterations, we can set the `interval` parameter:

```python
default_hooks = dict(logger=dict(type='LoggerHook', interval=20))
```

If you are interested in how logs are managed, read [logging](logging.md).

### ParamSchedulerHook

[ParamSchedulerHook](mmengine.hooks.ParamSchedulerHook) iterates over all of the runner's parameter schedulers (Parameter Scheduler) and calls their step method to update the optimizer's parameters. See the [parameter scheduler documentation](../tutorials/param_scheduler.md) for usage. `ParamSchedulerHook` is registered with the runner by default and has no configurable parameters, so it needs no configuration.
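For context, a run in which this hook has a scheduler to step might be configured as in the sketch below; the milestone values are illustrative, and the optimizer wrapper is assumed to be configured elsewhere:

```python
# a scheduler the default ParamSchedulerHook steps during training
param_scheduler = dict(type='MultiStepLR', by_epoch=True, milestones=[8, 11], gamma=0.1)
# runner = Runner(param_scheduler=param_scheduler, ...)
```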

### IterTimerHook

[IterTimerHook](mmengine.hooks.IterTimerHook) records the time spent loading data and the time taken by each iteration. `IterTimerHook` is registered with the runner by default and has no configurable parameters, so it needs no configuration.

### DistSamplerSeedHook

[DistSamplerSeedHook](mmengine.hooks.DistSamplerSeedHook) calls the sampler's step method during distributed training so that the shuffle parameter takes effect. `DistSamplerSeedHook` is registered with the runner by default and has no configurable parameters, so it needs no configuration.

### RuntimeInfoHook

[RuntimeInfoHook](mmengine.hooks.RuntimeInfoHook) updates the current runtime information (such as epoch, iter, max_epochs, max_iters, lr, metrics, etc.) to the message hub at the runner's mount points, so that other modules that cannot access the runner can still obtain this information. `RuntimeInfoHook` is registered with the runner by default and has no configurable parameters, so it needs no configuration.
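As a sketch of how another module might read this information, assuming a run is in progress and the hook has already written these keys:

```python
from mmengine.logging import MessageHub

# read runtime info written by RuntimeInfoHook, from code that
# holds no reference to the runner
message_hub = MessageHub.get_current_instance()
cur_iter = message_hub.get_info('iter')
max_iters = message_hub.get_info('max_iters')
```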

### EMAHook

[EMAHook](mmengine.hooks.EMAHook) applies an exponential moving average (EMA) to the model during training to improve its robustness. Note: the model generated by the exponential moving average is only used for validation and testing and does not affect training.

```python
custom_hooks = [dict(type='EMAHook')]
runner = Runner(custom_hooks=custom_hooks, ...)
runner.train()
```

`EMAHook` uses `ExponentialMovingAverage` by default; `StochasticWeightAverage` and `MomentumAnnealingEMA` are also available. Another averaging strategy can be selected by setting `ema_type`.

```python
custom_hooks = [dict(type='EMAHook', ema_type='StochasticWeightAverage')]
```

For more usage, read the [EMAHook API documentation](mmengine.hooks.EMAHook).

### EmptyCacheHook

[EmptyCacheHook](mmengine.hooks.EmptyCacheHook) calls `torch.cuda.empty_cache()` to release unused GPU memory. The timing of the release can be controlled with the `before_epoch`, `after_iter` and `after_epoch` parameters: the first means before each epoch starts, the second after each iteration, and the third after each epoch ends.

```python
# the release operation is performed at the end of every epoch
custom_hooks = [dict(type='EmptyCacheHook', after_epoch=True)]
runner = Runner(custom_hooks=custom_hooks, ...)
runner.train()
```

### SyncBuffersHook

[SyncBuffersHook](mmengine.hooks.SyncBuffersHook) synchronizes the model's buffers, such as the `running_mean` and `running_var` of BN layers, at the end of each epoch during distributed training.

```python
custom_hooks = [dict(type='SyncBuffersHook')]
runner = Runner(custom_hooks=custom_hooks, ...)
runner.train()
```

## Custom Hooks

If the built-in hooks provided by MMEngine do not meet your needs, you can write a custom hook: simply subclass the hook base class and override the methods of the relevant mount points.
@ -393,7 +202,7 @@

```python
import torch

from mmengine.registry import HOOKS
from mmengine.hooks import Hook
@ -425,20 +234,20 @@ class CheckInvalidLossHook(Hook):
                Defaults to None.
        """
        if self.every_n_train_iters(runner, self.interval):
            assert torch.isfinite(outputs['loss']), \
                'loss become infinite or NaN!'
```

We only need to pass the hook's config to the runner's `custom_hooks` parameter; the hook is registered when the runner is initialized,

```python
from mmengine.runner import Runner

custom_hooks = [
    dict(type='CheckInvalidLossHook', interval=50)
]
runner = Runner(custom_hooks=custom_hooks, ...)  # instantiate the runner, which mainly initializes the environment and builds the modules
runner.train()  # the runner starts training
```

and the loss value is then checked after every forward pass of the model.

@ -450,3 +259,14 @@ custom_hooks = [
    dict(type='CheckInvalidLossHook', interval=50, priority='ABOVE_NORMAL')
]
```

The priority can also be given when defining the class:

```python
@HOOKS.register_module()
class CheckInvalidLossHook(Hook):

    priority = 'ABOVE_NORMAL'
```

You may also want to read the [design of hooks](../design/hook.md) or the [hook API documentation](mmengine.hooks).

@ -87,8 +87,7 @@ class Accuracy(BaseMetric):

    default_prefix = 'ACC'  # set default_prefix

    def process(self, data_batch: Sequence[dict], predictions: Sequence[dict]):
        """Process one batch of data and predictions. The processed
        results should be stored in ``self.results``, which will be used
        to compute the metrics when all batches have been processed.

@ -129,5 +128,4 @@ class Accuracy(BaseMetric):

        # return the evaluation metric results
        return {'accuracy': acc}
```
@ -25,14 +25,14 @@ class CheckpointHook(Hook):
            indicates epochs, otherwise it indicates iterations.
            Defaults to -1, which means "never".
        by_epoch (bool): Saving checkpoints by epoch or by iteration.
            Defaults to True.
        save_optimizer (bool): Whether to save optimizer state_dict in the
            checkpoint. It is usually used for resuming experiments.
            Defaults to True.
        save_param_scheduler (bool): Whether to save param_scheduler state_dict
            in the checkpoint. It is usually used for resuming experiments.
            Defaults to True.
        out_dir (str or Path, optional): The root directory to save checkpoints.
            If not specified, ``runner.work_dir`` will be used by default. If
            specified, the ``out_dir`` will be the concatenation of ``out_dir``
            and the last level directory of ``runner.work_dir``. For example,
@ -59,7 +59,11 @@ class EMAHook(Hook):
        self.enabled_by_epoch = self.begin_epoch > 0

    def before_run(self, runner) -> None:
        """Create an ema copy of the model.

        Args:
            runner (Runner): The runner of the training process.
        """
        model = runner.model
        if is_model_wrapper(model):
            model = model.module
@ -81,39 +85,71 @@ class EMAHook(Hook):
                         batch_idx: int,
                         data_batch: DATA_BATCH = None,
                         outputs: Optional[dict] = None) -> None:
        """Update ema parameter.

        Args:
            runner (Runner): The runner of the training process.
            batch_idx (int): The index of the current batch in the train loop.
            data_batch (Sequence[dict], optional): Data from dataloader.
                Defaults to None.
            outputs (dict, optional): Outputs from model. Defaults to None.
        """
        if self._ema_started(runner):
            self.ema_model.update_parameters(self.src_model)

    def before_val_epoch(self, runner) -> None:
        """We load parameter values from ema model to source model before
        validation.

        Args:
            runner (Runner): The runner of the validation process.
        """
        if self._ema_started(runner):
            self._swap_ema_parameters()

    def after_val_epoch(self,
                        runner,
                        metrics: Optional[Dict[str, float]] = None) -> None:
        """We recover source model's parameter from ema model after validation.

        Args:
            runner (Runner): The runner of the validation process.
            metrics (Dict[str, float], optional): Evaluation results of all
                metrics on validation dataset. The keys are the names of the
                metrics, and the values are corresponding results.
        """
        if self._ema_started(runner):
            self._swap_ema_parameters()

    def before_test_epoch(self, runner) -> None:
        """We load parameter values from ema model to source model before test.

        Args:
            runner (Runner): The runner of the testing process.
        """
        if self._ema_started(runner):
            self._swap_ema_parameters()

    def after_test_epoch(self,
                         runner,
                         metrics: Optional[Dict[str, float]] = None) -> None:
        """We recover source model's parameter from ema model after test.

        Args:
            runner (Runner): The runner of the testing process.
            metrics (Dict[str, float], optional): Evaluation results of all
                metrics on test dataset. The keys are the names of the
                metrics, and the values are corresponding results.
        """
        if self._ema_started(runner):
            self._swap_ema_parameters()

    def before_save_checkpoint(self, runner, checkpoint: dict) -> None:
        """Save ema parameters to checkpoint.

        Args:
            runner (Runner): The runner of the training process.
            checkpoint (dict): The checkpoint to be saved.
        """
        if self._ema_started(runner):
            checkpoint['ema_state_dict'] = self.ema_model.state_dict()
            # Save ema parameters to the source model's state dict so that we
@ -124,7 +160,11 @@ class EMAHook(Hook):
            self._swap_ema_state_dict(checkpoint)

    def after_load_checkpoint(self, runner, checkpoint: dict) -> None:
        """Resume ema parameters from checkpoint.

        Args:
            runner (Runner): The runner of the training process.
            checkpoint (dict): The loaded checkpoint.
        """
        if self._ema_started(runner):
            if 'ema_state_dict' in checkpoint:
                # The original model parameters are actually saved in ema
@ -39,10 +39,10 @@ class LoggerHook(Hook):
        out_dir (str or Path, optional): The root directory to save
            checkpoints. If not specified, ``runner.work_dir`` will be used
            by default. If specified, the ``out_dir`` will be the concatenation
            of ``out_dir`` and the last level directory of ``runner.work_dir``.
            For example, if the input ``out_dir`` is ``./tmp`` and
            ``runner.work_dir`` is ``./work_dir/cur_exp``, then the log will be
            saved in ``./tmp/cur_exp``. Defaults to None.
        out_suffix (Tuple[str] or str): Those files in ``runner._log_dir``
            ending with ``out_suffix`` will be copied to ``out_dir``. Defaults
            to ('.json', '.log', '.py').
@ -18,7 +18,7 @@ class NaiveVisualizationHook(Hook):
    """Show or write the predicted results during the process of testing.

    Args:
        interval (int): Visualization interval. Defaults to 1.
        draw_gt (bool): Whether to draw the ground truth. Defaults to True.
        draw_pred (bool): Whether to draw the predicted result.
            Defaults to True.
@ -3,6 +3,7 @@ from typing import Dict, Optional, Sequence

from ..registry import HOOKS
from ..utils import get_git_hash
from ..version import __version__
from .hook import Hook

DATA_BATCH = Optional[Sequence[dict]]
@ -20,16 +21,24 @@ class RuntimeInfoHook(Hook):

    priority = 'VERY_HIGH'

    def before_run(self, runner) -> None:
        """Update metainfo.

        Args:
            runner (Runner): The runner of the training process.
        """
        metainfo = dict(
            cfg=runner.cfg.pretty_text,
            seed=runner.seed,
            experiment_name=runner.experiment_name,
            mmengine_version=__version__ + get_git_hash())
        runner.message_hub.update_info_dict(metainfo)

    def before_train(self, runner) -> None:
        """Update resumed training state.

        Args:
            runner (Runner): The runner of the training process.
        """
        runner.message_hub.update_info('epoch', runner.epoch)
        runner.message_hub.update_info('iter', runner.iter)
        runner.message_hub.update_info('max_epochs', runner.max_epochs)
@ -39,7 +48,11 @@ class RuntimeInfoHook(Hook):
            'dataset_meta', runner.train_dataloader.dataset.metainfo)

    def before_train_epoch(self, runner) -> None:
        """Update current epoch information before every epoch.

        Args:
            runner (Runner): The runner of the training process.
        """
        runner.message_hub.update_info('epoch', runner.epoch)

    def before_train_iter(self,
@ -47,7 +60,14 @@ class RuntimeInfoHook(Hook):
                          batch_idx: int,
                          data_batch: DATA_BATCH = None) -> None:
        """Update current iter and learning rate information before every
        iteration.

        Args:
            runner (Runner): The runner of the training process.
            batch_idx (int): The index of the current batch in the train loop.
            data_batch (Sequence[dict], optional): Data from dataloader.
                Defaults to None.
        """
        runner.message_hub.update_info('iter', runner.iter)
        lr_dict = runner.optim_wrapper.get_lr()
        assert isinstance(lr_dict, dict), (
@ -65,7 +85,15 @@ class RuntimeInfoHook(Hook):
                         batch_idx: int,
                         data_batch: DATA_BATCH = None,
                         outputs: Optional[dict] = None) -> None:
        """Update ``log_vars`` in model outputs every iteration.

        Args:
            runner (Runner): The runner of the training process.
            batch_idx (int): The index of the current batch in the train loop.
            data_batch (Sequence[dict], optional): Data from dataloader.
                Defaults to None.
            outputs (dict, optional): Outputs from model. Defaults to None.
        """
        if outputs is not None:
            for key, value in outputs.items():
                runner.message_hub.update_scalar(f'train/{key}', value)
@ -14,7 +14,7 @@ class DistSamplerSeedHook(Hook):

    priority = 'NORMAL'

    def before_train_epoch(self, runner) -> None:
        """Set the seed for sampler and batch_sampler.

        Args: