[Docs] translate initialize docs (#687)
* translate initialize * Fix as comment * minor refine * Fix as comment * Apply suggestions from code review Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>pull/792/head
parent
b142774a42
commit
e4cf5fda10
|
@ -1,3 +1,325 @@
|
|||
# Initialization
|
||||
|
||||
Coming soon. Please refer to [chinese documentation](https://mmengine.readthedocs.io/zh_CN/latest/advanced_tutorials/initialize.html).
|
||||
Usually, we'll customize our module based on [nn.Module](https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module), which is implemented by Native PyTorch. Also, [torch.nn.init](https://pytorch.org/docs/stable/nn.init.html) could help us initialize the parameters of the model easily. To simplify the process of model construction and initialization, MMEngine designed the [BaseModule](mmengine.model.BaseModule) to help us define and initialize the model from config easily.
|
||||
|
||||
## Initialize the model from config
|
||||
|
||||
The core function of `BaseModule` is that it could help us to initialize the model from config. Subclasses inherited from `BaseModule` could define the `init_cfg` in the `__init__` function, and we can choose the method of initialization by configuring `init_cfg`.
|
||||
|
||||
Currently, we support the following initialization methods:
|
||||
|
||||
| Initializer | Registered name | Function |
|
||||
| :-------------------------------------------------------------------------------------------------------- | :-------------: | :--------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| [ConstantInit](../api/generated/mmengine.model.ConstantInit.html#mmengine.model.ConstantInit) | Constant | Initialize the weight and bias with a constant, commonly used for Convolution |
|
||||
| [XavierInit](../api/generated/mmengine.model.XavierInit.html#mmengine.model.XavierInit) | Xavier | Initialize the weight by `Xavier` initialization, and initialize the bias with a constant |
|
||||
| [NormalInit](../api/generated/mmengine.model.NormalInit.html#mmengine.model.NormalInit) | Normal | Initialize the weight by normal distribution, and initialize the bias with a constant |
|
||||
| [TruncNormalInit](../api/generated/mmengine.model.TruncNormalInit.html#mmengine.model.TruncNormalInit) | TruncNormal | Initialize the weight by truncated normal distribution, and initialize the bias with a constant,commonly used for Transformer |
|
||||
| [UniformInit](../api/generated/mmengine.model.UniformInit.html#mmengine.model.UniformInit) | Uniform | Initialize the weight by uniform distribution, and initialize the bias with a constant,commonly used for convolution |
|
||||
| [KaimingInit](../api/generated/mmengine.model.KaimingInit.html#mmengine.model.KaimingInit) | Kaiming | Initialize the weight by `Kaiming` initialization, and initialize the bias with a constant. Commonly used for convolution |
|
||||
| [Caffe2XavierInit](../api/generated/mmengine.model.Caffe2XavierInit.html#mmengine.model.Caffe2XavierInit) | Caffe2Xavier | `Xavier` initialization in Caffe2, and `Kaiming` initialization in PyTorh with `fan_in` and `normal` mode. Commonly used for convolution |
|
||||
| [PretrainedInit](../api/generated/mmengine.model.PretrainedInit.html#mmengine.model.PretrainedInit) | Pretrained | Initialize the model with the pretrained model |
|
||||
|
||||
### Initialize the model with pretrained model
|
||||
|
||||
Defining the `ToyNet` as below:
|
||||
|
||||
```python
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from mmengine.model import BaseModule
|
||||
|
||||
|
||||
class ToyNet(BaseModule):
|
||||
|
||||
def __init__(self, init_cfg=None):
|
||||
super().__init__(init_cfg)
|
||||
self.conv1 = nn.Linear(1, 1)
|
||||
|
||||
|
||||
# Save the checkpoint.
|
||||
toy_net = ToyNet()
|
||||
torch.save(toy_net.state_dict(), './pretrained.pth')
|
||||
pretrained = './pretrained.pth'
|
||||
|
||||
toy_net = ToyNet(init_cfg=dict(type='Pretrained', checkpoint=pretrained))
|
||||
```
|
||||
|
||||
and then we can configure the `init_cfg` to make it load the pretrained model by calling `initi_weights()` after its construction.
|
||||
|
||||
```python
|
||||
# Initialize the model with the saved checkpoint.
|
||||
toy_net.init_weights()
|
||||
```
|
||||
|
||||
```
|
||||
08/19 16:50:24 - mmengine - INFO - load model from: ./pretrained.pth
|
||||
08/19 16:50:24 - mmengine - INFO - local loads checkpoint from path: ./pretrained.pth
|
||||
```
|
||||
|
||||
If `init_cfg` is a `dict`, `type` means a kind of initializer registered in `WEIGHT_INITIALIZERS`. The `Pretrained` means `PretrainedInit`, which could help us to load the target checkpoint.
|
||||
All initializers have the same mapping relationship like `Pretrained` -> `PretrainedInit`, which strips the suffix `Init` of the class name. The `checkpoint` argument of `PretrainedInit` means the path of the checkpoint. It could be a local path or a URL.
|
||||
|
||||
### Commonly used initialization methods
|
||||
|
||||
Similarly, we could use the `Kaiming` initialization just like `Pretrained` initializer. For example, we could make `init_cfg=dict(type='Kaiming', layer='Conv2d')` to initialize all `Conv2d` module with `Kaiming` initialization.
|
||||
|
||||
Sometimes we need to initialize the model with different initialization methods for different modules. For example, we could initialize the `Conv2d` module with `Kaiming` initialization and initialize the `Linear` module with `Xavier` initialization. We could make `init_cfg=dict(type='Kaiming', layer='Conv2d')`:
|
||||
|
||||
```python
|
||||
import torch.nn as nn
|
||||
|
||||
from mmengine.model import BaseModule
|
||||
|
||||
|
||||
class ToyNet(BaseModule):
|
||||
|
||||
def __init__(self, init_cfg=None):
|
||||
super().__init__(init_cfg)
|
||||
self.linear = nn.Linear(1, 1)
|
||||
self.conv = nn.Conv2d(1, 1, 1)
|
||||
|
||||
|
||||
# Apply `Kaiming` initialization to `Conv2d` module and `Xavier` initialization to `Linear` module.
|
||||
toy_net = ToyNet(
|
||||
init_cfg=[
|
||||
dict(type='Kaiming', layer='Conv2d'),
|
||||
dict(type='Xavier', layer='Linear')
|
||||
], )
|
||||
toy_net.init_weights()
|
||||
```
|
||||
|
||||
```
|
||||
08/19 16:50:24 - mmengine - INFO -
|
||||
linear.weight - torch.Size([1, 1]):
|
||||
XavierInit: gain=1, distribution=normal, bias=0
|
||||
|
||||
08/19 16:50:24 - mmengine - INFO -
|
||||
linear.bias - torch.Size([1]):
|
||||
XavierInit: gain=1, distribution=normal, bias=0
|
||||
|
||||
08/19 16:50:24 - mmengine - INFO -
|
||||
conv.weight - torch.Size([1, 1, 1, 1]):
|
||||
KaimingInit: a=0, mode=fan_out, nonlinearity=relu, distribution =normal, bias=0
|
||||
|
||||
08/19 16:50:24 - mmengine - INFO -
|
||||
conv.bias - torch.Size([1]):
|
||||
KaimingInit: a=0, mode=fan_out, nonlinearity=relu, distribution =normal, bias=0
|
||||
```
|
||||
|
||||
`layer` could also be a list, each element of which means a type of applied module.
|
||||
|
||||
```python
|
||||
# Apply Kaiming initialization to `Conv2d` and `Linear` module.
|
||||
toy_net = ToyNet(init_cfg=[dict(type='Kaiming', layer=['Conv2d', 'Linear'])], )
|
||||
toy_net.init_weights()
|
||||
```
|
||||
|
||||
```
|
||||
08/19 16:50:24 - mmengine - INFO -
|
||||
linear.weight - torch.Size([1, 1]):
|
||||
KaimingInit: a=0, mode=fan_out, nonlinearity=relu, distribution =normal, bias=0
|
||||
|
||||
08/19 16:50:24 - mmengine - INFO -
|
||||
linear.bias - torch.Size([1]):
|
||||
KaimingInit: a=0, mode=fan_out, nonlinearity=relu, distribution =normal, bias=0
|
||||
|
||||
08/19 16:50:24 - mmengine - INFO -
|
||||
conv.weight - torch.Size([1, 1, 1, 1]):
|
||||
KaimingInit: a=0, mode=fan_out, nonlinearity=relu, distribution =normal, bias=0
|
||||
|
||||
08/19 16:50:24 - mmengine - INFO -
|
||||
conv.bias - torch.Size([1]):
|
||||
KaimingInit: a=0, mode=fan_out, nonlinearity=relu, distribution =normal, bias=0
|
||||
```
|
||||
|
||||
### More fine-grained initialization
|
||||
|
||||
Sometimes we need to initialize the same type of module with different types of initialization. For example, we've defined `conv1` and `conv2` submodules, and we want to initialize the `conv1` with `Kaiming` initialization and `conv2` with `Xavier` initialization. We could configure the init_cfg with `override`:
|
||||
|
||||
```python
|
||||
import torch.nn as nn
|
||||
|
||||
from mmengine.model import BaseModule
|
||||
|
||||
|
||||
class ToyNet(BaseModule):
|
||||
|
||||
def __init__(self, init_cfg=None):
|
||||
super().__init__(init_cfg)
|
||||
self.conv1 = nn.Conv2d(1, 1, 1)
|
||||
self.conv2 = nn.Conv2d(1, 1, 1)
|
||||
|
||||
|
||||
# Apllly `Kaiming` initialization to `conv1` and `Xavier` initialization to `conv2`.
|
||||
toy_net = ToyNet(
|
||||
init_cfg=[
|
||||
dict(
|
||||
type='Kaiming',
|
||||
layer=['Conv2d'],
|
||||
override=dict(name='conv2', type='Xavier')),
|
||||
], )
|
||||
toy_net.init_weights()
|
||||
```
|
||||
|
||||
```
|
||||
08/19 16:50:24 - mmengine - INFO -
|
||||
conv1.weight - torch.Size([1, 1, 1, 1]):
|
||||
KaimingInit: a=0, mode=fan_out, nonlinearity=relu, distribution =normal, bias=0
|
||||
|
||||
08/19 16:50:24 - mmengine - INFO -
|
||||
conv1.bias - torch.Size([1]):
|
||||
KaimingInit: a=0, mode=fan_out, nonlinearity=relu, distribution =normal, bias=0
|
||||
|
||||
08/19 16:50:24 - mmengine - INFO -
|
||||
conv2.weight - torch.Size([1, 1, 1, 1]):
|
||||
XavierInit: gain=1, distribution=normal, bias=0
|
||||
|
||||
08/19 16:50:24 - mmengine - INFO -
|
||||
conv2.bias - torch.Size([1]):
|
||||
KaimingInit: a=0, mode=fan_out, nonlinearity=relu, distribution =normal, bias=0
|
||||
```
|
||||
|
||||
`override` could be understood as an nested `init_cfg`, which could also be a `list` or `dict`, and we should also set "`type`" for it. The difference is that we must set `name` in `override` to specify the applied scope for submodule. As the example above, we set `name='conv2'` to specify that the `Xavier` initialization is applied to all submodules of `toy_net.conv2`.
|
||||
|
||||
### Customize the initialization method
|
||||
|
||||
Although the `init_cfg` could control the initialization method for different modules, we would have to register a new initialization method to `WEIGHT_INITIALIZERS` if we want to customize initialization process. It is not convenient right? Actually, we could also override the `init_weights` method to customize the initialization process.
|
||||
|
||||
Assuming we've defined the following modules:
|
||||
|
||||
- `ToyConv` inherit from `nn.Module`, implements `init_weights`which initialize `custom_weight`(`parameter` of `ToyConv`) with 1 and initialize `custom_bias` with 0
|
||||
|
||||
- `ToyNet` defines a `ToyConv` submodule.
|
||||
|
||||
`ToyNet.init_weights` will call `init_weights` of all submodules sequentially.
|
||||
|
||||
```python
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from mmengine.model import BaseModule
|
||||
|
||||
|
||||
class ToyConv(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.custom_weight = nn.Parameter(torch.empty(1, 1, 1, 1))
|
||||
self.custom_bias = nn.Parameter(torch.empty(1))
|
||||
|
||||
def init_weights(self):
|
||||
with torch.no_grad():
|
||||
self.custom_weight = self.custom_weight.fill_(1)
|
||||
self.custom_bias = self.custom_bias.fill_(0)
|
||||
|
||||
|
||||
class ToyNet(BaseModule):
|
||||
|
||||
def __init__(self, init_cfg=None):
|
||||
super().__init__(init_cfg)
|
||||
self.conv1 = nn.Conv2d(1, 1, 1)
|
||||
self.conv2 = nn.Conv2d(1, 1, 1)
|
||||
self.custom_conv = ToyConv()
|
||||
|
||||
|
||||
toy_net = ToyNet(
|
||||
init_cfg=[
|
||||
dict(
|
||||
type='Kaiming',
|
||||
layer=['Conv2d'],
|
||||
override=dict(name='conv2', type='Xavier'))
|
||||
])
|
||||
|
||||
toy_net.init_weights()
|
||||
```
|
||||
|
||||
```
|
||||
08/19 16:50:24 - mmengine - INFO -
|
||||
conv1.weight - torch.Size([1, 1, 1, 1]):
|
||||
KaimingInit: a=0, mode=fan_out, nonlinearity=relu, distribution =normal, bias=0
|
||||
|
||||
08/19 16:50:24 - mmengine - INFO -
|
||||
conv1.bias - torch.Size([1]):
|
||||
KaimingInit: a=0, mode=fan_out, nonlinearity=relu, distribution =normal, bias=0
|
||||
|
||||
08/19 16:50:24 - mmengine - INFO -
|
||||
conv2.weight - torch.Size([1, 1, 1, 1]):
|
||||
XavierInit: gain=1, distribution=normal, bias=0
|
||||
|
||||
08/19 16:50:24 - mmengine - INFO -
|
||||
conv2.bias - torch.Size([1]):
|
||||
KaimingInit: a=0, mode=fan_out, nonlinearity=relu, distribution =normal, bias=0
|
||||
|
||||
08/19 16:50:24 - mmengine - INFO -
|
||||
custom_conv.custom_weight - torch.Size([1, 1, 1, 1]):
|
||||
Initialized by user-defined `init_weights` in ToyConv
|
||||
|
||||
08/19 16:50:24 - mmengine - INFO -
|
||||
custom_conv.custom_bias - torch.Size([1]):
|
||||
Initialized by user-defined `init_weights` in ToyConv
|
||||
```
|
||||
|
||||
### Conclusion
|
||||
|
||||
**1. Configure `init_cfg` to initialize model**
|
||||
|
||||
- Commonly used for the initialization of `Conv2d`, `Linear` and other underlying module. All initialization methods should be managed by `WEIGHT_INITIALIZERS`
|
||||
- Dynamic initialization controlled by `init_cfg`
|
||||
|
||||
**2. Customize `init_weights`**
|
||||
|
||||
- Compared to configuring the `init_cfg`, implementing the `init_weights` is simpler and does not require registration. However, it is not as flexible as `init_cfg`, and it is not possible to initialize the module dynamically.
|
||||
|
||||
```{note}
|
||||
- The priorify of init_weights is higher than `init_cfg`
|
||||
- Runner will call `init_weights` in Runner.train()
|
||||
```
|
||||
|
||||
### Ininitailize module with function
|
||||
|
||||
As mentioned in prior \[section\](#Customize the initialization method), we could customize our initialization in `init_weights`. To make it more convenient to initialize modules, MMEngine provides a series of **module initialization functions** to initialize the whole module based on `torch.nn.init`. For example, we want to initialize the weights of the convolutional layer with normal distribution and initialize the bias of the convolutional layer with a constant. The implementation of `torch.nn.init` is as follows:
|
||||
|
||||
```python
|
||||
from torch.nn.init import normal_, constant_
|
||||
import torch.nn as nn
|
||||
|
||||
model = nn.Conv2d(1, 1, 1)
|
||||
normal_(model.weight, mean=0, std=0.01)
|
||||
constant_(model.bias, val=0)
|
||||
```
|
||||
|
||||
```
|
||||
Parameter containing:
|
||||
tensor([0.], requires_grad=True)
|
||||
```
|
||||
|
||||
The above process is actually a standard process for initializing a convolutional module with normal distribution, so MMEngine simplifies this by implementing a series of common **module** initialization functions. Compared with `torch.nn.init`, the module initialization functions could accept the convolution module directly:
|
||||
|
||||
```python
|
||||
from mmengine.model import normal_init
|
||||
|
||||
normal_init(model, mean=0, std=0.01, bias=0)
|
||||
```
|
||||
|
||||
Similarly, we could also use [Kaiming](http://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf) initialization and [Xavier](http://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf) initialization:
|
||||
|
||||
```python
|
||||
from mmengine.model import kaiming_init, xavier_init
|
||||
|
||||
kaiming_init(model)
|
||||
xavier_init(model)
|
||||
```
|
||||
|
||||
Currently, MMEngine provide the following initialization function:
|
||||
|
||||
| initialization function | function |
|
||||
| :----------------------------------------------------------------------------------------------------------------- | :--------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| [constant_init](../api/generated/mmengine.model.constant_init.html#mmengine.model.constant_init) | Initialize the weight and bias with a constant, commonly used for Convolution |
|
||||
| [xavier_init](../api/generated/mmengine.model.xavier_init.html#mmengine.model.xavier_init) | Initialize the weight by `Xavier` initialization, and initialize the bias with a constant |
|
||||
| [normal_init](../api/generated/mmengine.model.normal_init.html#mmengine.model.normal_init) | Initialize the weight by normal distribution, and initialize the bias with a constant |
|
||||
| [trunc_normal_init](../api/generated/mmengine.model.trunc_normal_init.html#mmengine.model.trunc_normal_init) | Initialize the weight by truncated normal distribution, and initialize the bias with a constant,commonly used for Transformer |
|
||||
| [uniform_init](../api/generated/mmengine.model.uniform_init.html#mmengine.model.uniform_init) | Initialize the weight by uniform distribution, and initialize the bias with a constant,commonly used for convolution |
|
||||
| [kaiming_init](../api/generated/mmengine.model.kaiming_init.html#mmengine.model.kaiming_init) | Initialize the weight by `Kaiming` initialization, and initialize the bias with a constant. Commonly used for convolution |
|
||||
| [caffe2_xavier_init](../api/generated/mmengine.model.caffe2_xavier_init.html#mmengine.model.caffe2_xavier_init) | `Xavier` initialization in Caffe2, and `Kaiming` initialization in PyTorh with `fan_in` and `normal` mode. Commonly used for convolution |
|
||||
| [bias_init_with_prob](../api/generated/mmengine.model.bias_init_with_prob.html#mmengine.model.bias_init_with_prob) | Initialize the bias with the probability |
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# 初始化
|
||||
|
||||
基于 Pytorch 构建模型时,我们通常会选择 [nn.Module](https://pytorch.org/docs/stable/nn.html?highlight=nn%20module#module-torch.nn.modules) 作为模型的基类,搭配使用 Pytorch 的初始化模块 [torch.nn.init](https://pytorch.org/docs/stable/nn.init.html?highlight=kaiming#torch.nn.init.kaiming_normal_),完成模型的初始化。`MMEngine` 在此基础上抽象出基础模块(BaseModule),让我们能够通过传参或配置文件来选择模型的初始化方式。此外,`MMEngine` 还提供了一系列模块初始化函数,让我们能够更加方便灵活地初始化模型参数。
|
||||
基于 Pytorch 构建模型时,我们通常会选择 [nn.Module](https://pytorch.org/docs/stable/nn.html?highlight=nn%20module#module-torch.nn.modules) 作为模型的基类,搭配使用 Pytorch 的初始化模块 [torch.nn.init](https://pytorch.org/docs/stable/nn.init.html?highlight=kaiming#torch.nn.init.kaiming_normal_),完成模型的初始化。MMEngine 在此基础上抽象出基础模块(BaseModule),让我们能够通过传参或配置文件来选择模型的初始化方式。此外,`MMEngine` 还提供了一系列模块初始化函数,让我们能够更加方便灵活地初始化模型参数。
|
||||
|
||||
## 配置式初始化
|
||||
|
||||
|
@ -60,7 +60,7 @@ toy_net.init_weights()
|
|||
和使用 `PretrainedInit` 初始化器类似,如果我们想对卷积做 `Kaiming` 初始化,需要令 `init_cfg=dict(type='Kaiming', layer='Conv2d')`。这样模型初始化时,就会以 `Kaiming` 初始化的方式来初始化类型为 `Conv2d` 的模块。
|
||||
|
||||
有时候我们可能需要用不同的初始化方式去初始化不同类型的模块,例如对卷积使用 `Kaiming` 初始化,对线性层使用 `Xavier`
|
||||
初始化。此时我们可以使 `init_cfg` 成为一个列表,,其中的每一个元素都表示对某些层使用特定的初始化方式。
|
||||
初始化。此时我们可以使 `init_cfg` 成为一个列表,其中的每一个元素都表示对某些层使用特定的初始化方式。
|
||||
|
||||
```python
|
||||
import torch.nn as nn
|
||||
|
@ -129,7 +129,7 @@ conv.bias - torch.Size([1]):
|
|||
KaimingInit: a=0, mode=fan_out, nonlinearity=relu, distribution =normal, bias=0
|
||||
```
|
||||
|
||||
#### 更细粒度的初始化
|
||||
### 更细粒度的初始化
|
||||
|
||||
有时同一类型的不同模块有不同初始化方式,例如现在有 `conv1` 和 `conv2` 两个模块,他们的类型均为 `Conv2d`
|
||||
。我们需要对 conv1 进行 `Kaiming` 初始化,conv2 进行 `Xavier` 初始化,则可以通过配置 `override` 参数来满足这样的需求:
|
||||
|
@ -148,7 +148,7 @@ class ToyNet(BaseModule):
|
|||
self.conv2 = nn.Conv2d(1, 1, 1)
|
||||
|
||||
|
||||
# 对 conv1 做卷积初始化,对 从 conv2 做 Xavier 初始化
|
||||
# 对 conv1 做 Kaiming 初始化,对 从 conv2 做 Xavier 初始化
|
||||
toy_net = ToyNet(
|
||||
init_cfg=[
|
||||
dict(
|
||||
|
@ -190,7 +190,8 @@ KaimingInit: a=0, mode=fan_out, nonlinearity=relu, distribution =normal, bias=0
|
|||
假设我们定义了以下模块:
|
||||
|
||||
- 继承自 `nn.Module` 的 `ToyConv`,实现了 `init_weights` 方法,让 `custom_weight` 初始化为 1,`custom_bias` 初始化为 0
|
||||
- 继承自模块基类的模型 `ToyNet`,且含有 `ToyConv` 子模块。
|
||||
|
||||
- 继承自模块基类的模型 `ToyNet`,且含有 `ToyConv` 子模块
|
||||
|
||||
我们在调用 `ToyNet` 的 `init_weights` 方法时,会链式的调用的子模块 `ToyConv` 的 `init_weights` 方法,实现自定义模块的初始化。
|
||||
|
||||
|
|
Loading…
Reference in New Issue