[Fix] Unloaded weights will not be initialized when using PretrainedInit (#764)

* Separate init_cfgs into pretrained_cfg and other_cfgs

* Fix unit test

* Update documentation

* Fix rendering of initialize.md

* Fix as per review comments

* Rename initialize.md to weight_initialization.md

* Add file

* Fix CI

* Rename weight_initialization.md back to initialize.md

* Fix duplicated .md
Mashiro 2023-01-09 18:46:30 +08:00 committed by GitHub
parent f10b5cefd9
commit 925ac870e2
6 changed files with 86 additions and 16 deletions

View File: README.md

@@ -236,7 +236,7 @@ runner.train()
 - [Config](https://mmengine.readthedocs.io/en/latest/tutorials/config.html)
 - [BaseDataset](https://mmengine.readthedocs.io/en/latest/advanced_tutorials/basedataset.html)
 - [Data Transform](https://mmengine.readthedocs.io/en/latest/tutorials/data_transform.html)
-- [Initialization](https://mmengine.readthedocs.io/en/latest/advanced_tutorials/initialize.html)
+- [Weight Initialization](https://mmengine.readthedocs.io/en/latest/advanced_tutorials/initialize.html)
 - [Visualization](https://mmengine.readthedocs.io/en/latest/advanced_tutorials/visualization.html)
 - [Abstract Data Element](https://mmengine.readthedocs.io/en/latest/advanced_tutorials/data_element.html)
 - [Distribution Communication](https://mmengine.readthedocs.io/en/latest/advanced_tutorials/distributed.html)

View File: README_zh-CN.md

@@ -247,9 +247,8 @@ runner.train()
 - [Config](https://mmengine.readthedocs.io/zh_CN/latest/tutorials/config.html)
 - [BaseDataset](https://mmengine.readthedocs.io/zh_CN/latest/advanced_tutorials/basedataset.html)
-- [Abstract Data Element](https://mmengine.readthedocs.io/zh_CN/latest/advanced_tutorials/data_element.html)
-- [Visualization](https://mmengine.readthedocs.io/zh_CN/latest/advanced_tutorials/visualization.html)
 - [Data Transform](https://mmengine.readthedocs.io/zh_CN/latest/tutorials/data_transform.html)
-- [Initialization](https://mmengine.readthedocs.io/zh_CN/latest/advanced_tutorials/initialize.html)
+- [Weight Initialization](https://mmengine.readthedocs.io/zh_CN/latest/advanced_tutorials/initialize.html)
+- [Visualization](https://mmengine.readthedocs.io/zh_CN/latest/advanced_tutorials/visualization.html)
+- [Abstract Data Element](https://mmengine.readthedocs.io/zh_CN/latest/advanced_tutorials/data_element.html)
 - [Distribution Communication](https://mmengine.readthedocs.io/zh_CN/latest/advanced_tutorials/distributed.html)

View File: docs/en/advanced_tutorials/initialize.md

@@ -1,4 +1,4 @@
-# Initialization
+# Weight initialization

 Usually, we customize our module based on [nn.Module](https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module), which is implemented in native PyTorch, and [torch.nn.init](https://pytorch.org/docs/stable/nn.init.html) helps us initialize the parameters of the model easily. To simplify the process of model construction and initialization, MMEngine designed the [BaseModule](mmengine.model.BaseModule) to help us define and initialize models from a config.
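A minimal sketch of this pattern (the `ToyNet` module and the `Normal` config below are illustrative, in the style of this tutorial, not lines from this diff):

```python
import torch.nn as nn

from mmengine.model import BaseModule


class ToyNet(BaseModule):
    """A toy module whose initialization is driven by ``init_cfg``."""

    def __init__(self, init_cfg=None):
        # BaseModule stores init_cfg; initialization is deferred until
        # init_weights() is called (e.g. by the Runner before training).
        super().__init__(init_cfg=init_cfg)
        self.conv = nn.Conv2d(3, 8, 3)


# Initialize every Conv2d with a normal distribution (mean=0, std=0.01).
toy_net = ToyNet(
    init_cfg=dict(type='Normal', layer='Conv2d', mean=0, std=0.01))
toy_net.init_weights()
```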
@@ -107,6 +107,10 @@ toy_net.init_weights()

 If `init_cfg` is a `dict`, `type` denotes a kind of initializer registered in `WEIGHT_INITIALIZERS`. `Pretrained` refers to `PretrainedInit`, which helps us load the target checkpoint.
 All initializers follow the same mapping convention as `Pretrained` -> `PretrainedInit`: the registered name strips the suffix `Init` from the class name. The `checkpoint` argument of `PretrainedInit` is the path of the checkpoint, which can be a local path or a URL.

+```{note}
+`PretrainedInit` has a higher priority than any other initializer. The loaded pretrained weights will overwrite the previously initialized weights.
+```
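For example (a sketch; the checkpoint path is hypothetical and could equally be a URL):

```python
import torch.nn as nn

from mmengine.model import BaseModule


class ToyNet(BaseModule):

    def __init__(self, init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.linear = nn.Linear(1, 1)


# 'Pretrained' maps to PretrainedInit; './path/to/ckpt.pth' is a
# hypothetical checkpoint saved with torch.save(model.state_dict(), ...).
toy_net = ToyNet(
    init_cfg=dict(type='Pretrained', checkpoint='./path/to/ckpt.pth'))
toy_net.init_weights()  # the checkpoint is loaded here
```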
 ### Commonly used initialization methods

 Similarly, we could use the `Kaiming` initializer just like the `Pretrained` initializer. For example, we could set `init_cfg=dict(type='Kaiming', layer='Conv2d')` to initialize all `Conv2d` modules with Kaiming initialization.
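A sketch of this usage (the two-layer `ConvNet` is illustrative):

```python
import torch.nn as nn

from mmengine.model import BaseModule


class ConvNet(BaseModule):

    def __init__(self, init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.conv1 = nn.Conv2d(3, 16, 3)
        self.conv2 = nn.Conv2d(16, 32, 3)


# Every Conv2d in the module tree gets Kaiming initialization.
conv_net = ConvNet(init_cfg=dict(type='Kaiming', layer='Conv2d'))
conv_net.init_weights()
```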

View File: docs/zh_cn/advanced_tutorials/initialize.md

@@ -1,4 +1,4 @@
-# Initialization
+# Weight Initialization

 When building a model with PyTorch, we usually choose [nn.Module](https://pytorch.org/docs/stable/nn.html?highlight=nn%20module#module-torch.nn.modules) as the base class of the model, together with PyTorch's initialization module [torch.nn.init](https://pytorch.org/docs/stable/nn.init.html?highlight=kaiming#torch.nn.init.kaiming_normal_), to initialize the model. On top of this, MMEngine abstracts the base module `BaseModule`, which lets us choose the initialization method of the model through arguments or a config file. In addition, MMEngine provides a series of module initialization functions that make initializing model parameters more convenient and flexible.
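A sketch of those helper functions (assuming `kaiming_init` and `constant_init` are exported from `mmengine.model`, as in recent MMEngine releases):

```python
import torch.nn as nn

from mmengine.model import constant_init, kaiming_init

conv = nn.Conv2d(3, 16, 3)
kaiming_init(conv)  # Kaiming-initialize the weight, zero the bias

norm = nn.BatchNorm2d(16)
constant_init(norm, val=1)  # fill the weight with 1 and the bias with 0
```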
@@ -102,6 +102,10 @@ toy_net.init_weights()

 When `init_cfg` is a dict, the `type` field denotes an initializer that must be registered in the `WEIGHT_INITIALIZERS` [registry](registry.md). We can load pretrained weights by specifying `init_cfg=dict(type='Pretrained', checkpoint='path/to/ckpt')`, where `Pretrained` is the alias of the `PretrainedInit` initializer (this mapping is maintained by `WEIGHT_INITIALIZERS`), and `checkpoint` is an argument of `PretrainedInit` that specifies the path to the weights, which can be a local path or a URL.

+```{note}
+Among all initializers, `PretrainedInit` has the highest priority. Weights initialized by the other initializers in `init_cfg` will be overwritten by the pretrained weights loaded by `PretrainedInit`.
+```

 ### Commonly used initialization methods

 Similar to the `PretrainedInit` initializer, if we want to apply Kaiming initialization to convolution layers, we can set `init_cfg=dict(type='Kaiming', layer='Conv2d')`. Then, when the model is initialized, every module of type `Conv2d` will be initialized with Kaiming initialization.
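Both styles can be combined by passing a list of configs. A sketch (the checkpoint path is hypothetical) of the behavior this commit fixes: non-pretrained initializers run first and `PretrainedInit` runs last, so parameters missing from the checkpoint keep their Kaiming initialization instead of staying uninitialized:

```python
import torch.nn as nn

from mmengine.model import BaseModule


class ToyNet(BaseModule):

    def __init__(self, init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.conv = nn.Conv2d(3, 8, 3)
        self.linear = nn.Linear(8, 2)


# './path/to/ckpt.pth' is hypothetical. Kaiming init is applied first;
# PretrainedInit is applied last and overwrites only the parameters that
# actually exist in the checkpoint.
toy_net = ToyNet(init_cfg=[
    dict(type='Kaiming', layer='Conv2d'),
    dict(type='Pretrained', checkpoint='./path/to/ckpt.pth'),
])
toy_net.init_weights()
```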

View File: mmengine/model/base_module.py

@@ -5,7 +5,7 @@ import warnings
 from abc import ABCMeta
 from collections import defaultdict
 from logging import FileHandler
-from typing import Iterable, Optional
+from typing import Iterable, List, Optional, Union

 import torch.nn as nn
@@ -26,11 +26,17 @@ class BaseModule(nn.Module, metaclass=ABCMeta):
     - ``_params_init_info``: Used to track the parameter initialization
       information. This attribute only exists during executing the
       ``init_weights``.

+    Note:
+        :obj:`PretrainedInit` has a higher priority than any other
+        initializer. The loaded pretrained weights will overwrite
+        the previously initialized weights.
+
     Args:
-        init_cfg (dict, optional): Initialization config dict.
+        init_cfg (dict or List[dict], optional): Initialization config dict.
     """

-    def __init__(self, init_cfg=None):
+    def __init__(self, init_cfg: Union[dict, List[dict], None] = None):
         """Initialize BaseModule, inherited from `torch.nn.Module`"""

         # NOTE init_cfg can be defined in different levels, but init_cfg
@@ -100,14 +106,25 @@ class BaseModule(nn.Module, metaclass=ABCMeta):
                     f'initialize {module_name} with init_cfg {self.init_cfg}',
                     logger=logger_name,
                     level=logging.DEBUG)
-                initialize(self, self.init_cfg)
+                init_cfgs = self.init_cfg
                 if isinstance(self.init_cfg, dict):
-                    # prevent the parameters of
-                    # the pre-trained model
-                    # from being overwritten by
-                    # the `init_weights`
-                    if self.init_cfg['type'] == 'Pretrained':
-                        return
+                    init_cfgs = [self.init_cfg]
+
+                # PretrainedInit has higher priority than any other init_cfg.
+                # Therefore we initialize `pretrained_cfg` last to overwrite
+                # the previously initialized weights.
+                # See details in https://github.com/open-mmlab/mmengine/issues/691  # noqa: E501
+                other_cfgs = []
+                pretrained_cfg = []
+                for init_cfg in init_cfgs:
+                    assert isinstance(init_cfg, dict)
+                    if init_cfg['type'] == 'Pretrained':
+                        pretrained_cfg.append(init_cfg)
+                    else:
+                        other_cfgs.append(init_cfg)
+
+                initialize(self, other_cfgs)

             for m in self.children():
                 if hasattr(m, 'init_weights'):
@@ -118,7 +135,8 @@ class BaseModule(nn.Module, metaclass=ABCMeta):
                         init_info=f'Initialized by '
                         f'user-defined `init_weights`'
                         f' in {m.__class__.__name__} ')
+            if self.init_cfg and pretrained_cfg:
+                initialize(self, pretrained_cfg)
             self._is_init = True
         else:
             warnings.warn(f'init_weights of {self.__class__.__name__} has '
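The core of the change can be read in isolation. A condensed, standalone sketch of the splitting logic above (not the verbatim mmengine code):

```python
from typing import List, Tuple, Union


def split_init_cfgs(
        init_cfg: Union[dict, List[dict]]) -> Tuple[List[dict], List[dict]]:
    """Split init_cfg into (other_cfgs, pretrained_cfg).

    `other_cfgs` are applied first; `pretrained_cfg` is applied last so the
    loaded checkpoint overwrites the previously initialized weights.
    """
    init_cfgs = [init_cfg] if isinstance(init_cfg, dict) else init_cfg
    other_cfgs: List[dict] = []
    pretrained_cfg: List[dict] = []
    for cfg in init_cfgs:
        assert isinstance(cfg, dict)
        if cfg['type'] == 'Pretrained':
            pretrained_cfg.append(cfg)
        else:
            other_cfgs.append(cfg)
    return other_cfgs, pretrained_cfg


# The checkpoint path is hypothetical; only the ordering matters here.
other_cfgs, pretrained_cfg = split_init_cfgs([
    dict(type='Kaiming', layer='Conv2d'),
    dict(type='Pretrained', checkpoint='./hypothetical.pth'),
])
assert other_cfgs == [dict(type='Kaiming', layer='Conv2d')]
assert pretrained_cfg == [
    dict(type='Pretrained', checkpoint='./hypothetical.pth')
]
```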

View File: tests/test_model/test_base_module.py

@@ -1,9 +1,13 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import copy
 import logging
+import os.path as osp
+import tempfile
 from unittest import TestCase

 import torch
 from torch import nn
+from torch.nn.init import constant_

 from mmengine.logging.logger import MMLogger
 from mmengine.model import BaseModule, ModuleDict, ModuleList, Sequential
@@ -90,6 +94,7 @@ class FooModel(BaseModule):
 class TestBaseModule(TestCase):

     def setUp(self) -> None:
+        self.temp_dir = tempfile.TemporaryDirectory()
         self.BaseModule = BaseModule()
         self.model_cfg = dict(
             type='FooModel',
@@ -110,6 +115,7 @@ class TestBaseModule(TestCase):
         self.logger = MMLogger.get_instance(self._testMethodName)

     def tearDown(self) -> None:
+        self.temp_dir.cleanup()
         logging.shutdown()
         MMLogger._instance_dict.clear()
         return super().tearDown()
@@ -177,6 +183,45 @@ class TestBaseModule(TestCase):
         assert torch.equal(self.model.reg.bias,
                            torch.full(self.model.reg.bias.shape, 2.0))

+        # Test building a model from pretrained weights.
+        class CustomLinear(BaseModule):
+
+            def __init__(self, init_cfg=None):
+                super().__init__(init_cfg)
+                self.linear = nn.Linear(1, 1)
+
+            def init_weights(self):
+                constant_(self.linear.weight, 1)
+                constant_(self.linear.bias, 2)
+
+        @FOOMODELS.register_module()
+        class PretrainedModel(FooModel):
+
+            def __init__(self,
+                         component1=None,
+                         component2=None,
+                         component3=None,
+                         component4=None,
+                         init_cfg=None) -> None:
+                super().__init__(component1, component2, component3,
+                                 component4, init_cfg)
+                self.linear = CustomLinear()
+
+        checkpoint_path = osp.join(self.temp_dir.name, 'test.pth')
+        torch.save(self.model.state_dict(), checkpoint_path)
+        model_cfg = copy.deepcopy(self.model_cfg)
+        model_cfg['type'] = 'PretrainedModel'
+        model_cfg['init_cfg'] = dict(
+            type='Pretrained', checkpoint=checkpoint_path)
+        model = FOOMODELS.build(model_cfg)
+        ori_layer_weight = model.linear.linear.weight.clone()
+        ori_layer_bias = model.linear.linear.bias.clone()
+        model.init_weights()
+        self.assertTrue((ori_layer_weight != model.linear.linear.weight).any())
+        self.assertTrue((ori_layer_bias != model.linear.linear.bias).any())

     def test_dump_init_info(self):
         import os
         import shutil