Mirror of https://github.com/open-mmlab/mmengine.git
[Fix] Unloaded weights will not be initialized when using PretrainedInit (#764)
* Separate init_cfgs to pretrained_cfg and other_cfgs
* Fix unit test
* Update documentation
* Fix render of initialize.md
* Fix as comment
* Rename initialize.md to weight_initialization.md
* Add file
* Fix CI
* Rename weight_initialization.md to initialize.md
* Fix duplicated .md
This commit is contained in:
parent f10b5cefd9
commit 925ac870e2
@@ -236,7 +236,7 @@ runner.train()
 - [Config](https://mmengine.readthedocs.io/en/latest/tutorials/config.html)
 - [BaseDataset](https://mmengine.readthedocs.io/en/latest/advanced_tutorials/basedataset.html)
 - [Data Transform](https://mmengine.readthedocs.io/en/latest/tutorials/data_transform.html)
-- [Initialization](https://mmengine.readthedocs.io/en/latest/advanced_tutorials/initialize.html)
+- [Weight Initialization](https://mmengine.readthedocs.io/en/latest/advanced_tutorials/initialize.html)
 - [Visualization](https://mmengine.readthedocs.io/en/latest/advanced_tutorials/visualization.html)
 - [Abstract Data Element](https://mmengine.readthedocs.io/en/latest/advanced_tutorials/data_element.html)
 - [Distribution Communication](https://mmengine.readthedocs.io/en/latest/advanced_tutorials/distributed.html)
@@ -247,9 +247,8 @@ runner.train()
 - [Config](https://mmengine.readthedocs.io/zh_CN/latest/tutorials/config.html)
 - [BaseDataset](https://mmengine.readthedocs.io/zh_CN/latest/advanced_tutorials/basedataset.html)
-- [Abstract Data Element](https://mmengine.readthedocs.io/zh_CN/latest/advanced_tutorials/data_element.html)
-- [Visualization](https://mmengine.readthedocs.io/zh_CN/latest/advanced_tutorials/visualization.html)
 - [Data Transform](https://mmengine.readthedocs.io/zh_CN/latest/tutorials/data_transform.html)
-- [Initialization](https://mmengine.readthedocs.io/zh_CN/latest/advanced_tutorials/initialize.html)
+- [Weight Initialization](https://mmengine.readthedocs.io/zh_CN/latest/advanced_tutorials/initialize.html)
+- [Visualization](https://mmengine.readthedocs.io/zh_CN/latest/advanced_tutorials/visualization.html)
+- [Abstract Data Element](https://mmengine.readthedocs.io/zh_CN/latest/advanced_tutorials/data_element.html)
 - [Distribution Communication](https://mmengine.readthedocs.io/zh_CN/latest/advanced_tutorials/distributed.html)
@@ -1,4 +1,4 @@
-# Initialization
+# Weight initialization

 Usually, we customize our module based on [nn.Module](https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module) from native PyTorch, and [torch.nn.init](https://pytorch.org/docs/stable/nn.init.html) helps us initialize the parameters of the model easily. To simplify model construction and initialization, MMEngine provides [BaseModule](mmengine.model.BaseModule), which lets us define and initialize a model from its config.

@@ -107,6 +107,10 @@ toy_net.init_weights()
 If `init_cfg` is a `dict`, `type` denotes an initializer registered in `WEIGHT_INITIALIZERS`; `Pretrained` maps to `PretrainedInit`, which loads the target checkpoint.
 All initializers follow the same mapping rule as `Pretrained` -> `PretrainedInit`: strip the suffix `Init` from the class name. The `checkpoint` argument of `PretrainedInit` specifies the path of the checkpoint, which can be a local path or a URL.
+
+```{note}
+`PretrainedInit` has a higher priority than any other initializer. The loaded pretrained weights will overwrite the previously initialized weights.
+```

 ### Commonly used initialization methods

 Similarly, we can use the `Kaiming` initializer just like `Pretrained`. For example, setting `init_cfg=dict(type='Kaiming', layer='Conv2d')` initializes all `Conv2d` modules with Kaiming initialization.
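Taken together, a minimal sketch of how these options combine (the `ToyNet` class, layer sizes, and checkpoint path are illustrative placeholders, not part of the patch):

```python
import torch.nn as nn

from mmengine.model import BaseModule


class ToyNet(BaseModule):
    """Toy model whose initialization is driven entirely by `init_cfg`."""

    def __init__(self, init_cfg=None):
        super().__init__(init_cfg)
        self.conv = nn.Conv2d(3, 8, kernel_size=3)
        self.linear = nn.Linear(8, 2)


# Kaiming-initialize every Conv2d, then load the checkpoint. Because
# `PretrainedInit` has the highest priority, the loaded weights overwrite
# whatever the Kaiming initializer produced for matching parameters.
toy_net = ToyNet(init_cfg=[
    dict(type='Kaiming', layer='Conv2d'),
    dict(type='Pretrained', checkpoint='path/to/ckpt'),  # placeholder: local path or URL
])
toy_net.init_weights()
```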
@@ -1,4 +1,4 @@
-# Initialization
+# Weight initialization

 When building a model with PyTorch, we usually choose [nn.Module](https://pytorch.org/docs/stable/nn.html?highlight=nn%20module#module-torch.nn.modules) as the base class and use PyTorch's initialization module [torch.nn.init](https://pytorch.org/docs/stable/nn.init.html?highlight=kaiming#torch.nn.init.kaiming_normal_) to initialize the model. On top of this, MMEngine abstracts a base module (`BaseModule`) that lets us choose the initialization method through arguments or a config file. MMEngine also provides a series of module initialization functions that make initializing model parameters more convenient and flexible.

@@ -102,6 +102,10 @@ toy_net.init_weights()

 When `init_cfg` is a dict, the `type` field denotes an initializer, which must be registered in the `WEIGHT_INITIALIZERS` [registry](registry.md). We can load pretrained weights by setting `init_cfg=dict(type='Pretrained', checkpoint='path/to/ckpt')`, where `Pretrained` is the alias of the `PretrainedInit` initializer, a mapping maintained by `WEIGHT_INITIALIZERS`; `checkpoint` is an argument of `PretrainedInit` that specifies the path to the weights, which can be a local disk path or a URL.
+
+```{note}
+Among all initializers, `PretrainedInit` has the highest priority. Weights set by the other initializers in `init_cfg` are overwritten by the pretrained weights loaded by `PretrainedInit`.
+```

 ### Commonly used initialization methods

 Similar to the `PretrainedInit` initializer, if we want to apply Kaiming initialization to convolutions, we set `init_cfg=dict(type='Kaiming', layer='Conv2d')`. The model then initializes every module of type `Conv2d` with Kaiming initialization.
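A short sketch of the Kaiming case just described (the `ToyConvNet` name and layer sizes are illustrative assumptions):

```python
import torch.nn as nn

from mmengine.model import BaseModule


class ToyConvNet(BaseModule):

    def __init__(self, init_cfg=None):
        super().__init__(init_cfg)
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3)


# Every module whose type name is 'Conv2d' is Kaiming-initialized;
# other modules keep PyTorch's default initialization.
model = ToyConvNet(init_cfg=dict(type='Kaiming', layer='Conv2d'))
model.init_weights()
```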
@@ -5,7 +5,7 @@ import warnings
 from abc import ABCMeta
 from collections import defaultdict
 from logging import FileHandler
-from typing import Iterable, Optional
+from typing import Iterable, List, Optional, Union

 import torch.nn as nn

@@ -26,11 +26,17 @@ class BaseModule(nn.Module, metaclass=ABCMeta):
     - ``_params_init_info``: Used to track the parameter initialization
       information. This attribute only exists during executing the
       ``init_weights``.
+
+    Note:
+        :obj:`PretrainedInit` has a higher priority than any other
+        initializer. The loaded pretrained weights will overwrite
+        the previous initialized weights.

     Args:
-        init_cfg (dict, optional): Initialization config dict.
+        init_cfg (dict or List[dict], optional): Initialization config dict.
     """

-    def __init__(self, init_cfg=None):
+    def __init__(self, init_cfg: Union[dict, List[dict], None] = None):
         """Initialize BaseModule, inherited from `torch.nn.Module`"""

         # NOTE init_cfg can be defined in different levels, but init_cfg
@@ -100,14 +106,25 @@ class BaseModule(nn.Module, metaclass=ABCMeta):
                     f'initialize {module_name} with init_cfg {self.init_cfg}',
                     logger=logger_name,
                     level=logging.DEBUG)
-                initialize(self, self.init_cfg)
+
+                init_cfgs = self.init_cfg
                 if isinstance(self.init_cfg, dict):
-                    # prevent the parameters of
-                    # the pre-trained model
-                    # from being overwritten by
-                    # the `init_weights`
-                    if self.init_cfg['type'] == 'Pretrained':
-                        return
+                    init_cfgs = [self.init_cfg]
+
+                # PretrainedInit has higher priority than any other init_cfg.
+                # Therefore we initialize `pretrained_cfg` last to overwrite
+                # the previous initialized weights.
+                # See details in https://github.com/open-mmlab/mmengine/issues/691 # noqa E501
+                other_cfgs = []
+                pretrained_cfg = []
+                for init_cfg in init_cfgs:
+                    assert isinstance(init_cfg, dict)
+                    if init_cfg['type'] == 'Pretrained':
+                        pretrained_cfg.append(init_cfg)
+                    else:
+                        other_cfgs.append(init_cfg)
+
+                initialize(self, other_cfgs)

             for m in self.children():
                 if hasattr(m, 'init_weights'):
@@ -118,7 +135,8 @@ class BaseModule(nn.Module, metaclass=ABCMeta):
                         init_info=f'Initialized by '
                                   f'user-defined `init_weights`'
                                   f' in {m.__class__.__name__} ')
-
+            if self.init_cfg and pretrained_cfg:
+                initialize(self, pretrained_cfg)
             self._is_init = True
         else:
             warnings.warn(f'init_weights of {self.__class__.__name__} has '
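What the reordering above changes in practice, as a hedged sketch (module names and temp-file handling are illustrative, not from the patch): before the fix, a bare `Pretrained` init_cfg made `init_weights` return early, so parameters missing from the checkpoint were never initialized; now user-defined initializers run first and the pretrained weights are loaded last.

```python
import os.path as osp
import tempfile

import torch
import torch.nn as nn

from mmengine.model import BaseModule


class Head(BaseModule):
    """Child with its own `init_weights`, mirroring the regression test."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def init_weights(self):
        # Previously skipped when the parent used a 'Pretrained' init_cfg;
        # now it runs first, and the checkpoint is applied afterwards.
        nn.init.constant_(self.fc.weight, 1.0)


class Net(BaseModule):

    def __init__(self, init_cfg=None):
        super().__init__(init_cfg)
        self.backbone = nn.Linear(8, 4)
        self.head = Head()


# Save a checkpoint from one instance, then load it into a fresh one.
ckpt = osp.join(tempfile.mkdtemp(), 'net.pth')
torch.save(Net().state_dict(), ckpt)

net = Net(init_cfg=dict(type='Pretrained', checkpoint=ckpt))
net.init_weights()  # user-defined inits run first, pretrained weights win
```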
@@ -1,9 +1,13 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import copy
 import logging
+import os.path as osp
+import tempfile
 from unittest import TestCase

 import torch
 from torch import nn
+from torch.nn.init import constant_

 from mmengine.logging.logger import MMLogger
 from mmengine.model import BaseModule, ModuleDict, ModuleList, Sequential
@@ -90,6 +94,7 @@ class FooModel(BaseModule):
 class TestBaseModule(TestCase):

     def setUp(self) -> None:
+        self.temp_dir = tempfile.TemporaryDirectory()
         self.BaseModule = BaseModule()
         self.model_cfg = dict(
             type='FooModel',
@@ -110,6 +115,7 @@ class TestBaseModule(TestCase):
         self.logger = MMLogger.get_instance(self._testMethodName)

     def tearDown(self) -> None:
+        self.temp_dir.cleanup()
         logging.shutdown()
         MMLogger._instance_dict.clear()
         return super().tearDown()
@@ -177,6 +183,45 @@ class TestBaseModule(TestCase):
         assert torch.equal(self.model.reg.bias,
                            torch.full(self.model.reg.bias.shape, 2.0))

+        # Test building a model from pretrained weights.
+
+        class CustomLinear(BaseModule):
+
+            def __init__(self, init_cfg=None):
+                super().__init__(init_cfg)
+                self.linear = nn.Linear(1, 1)
+
+            def init_weights(self):
+                constant_(self.linear.weight, 1)
+                constant_(self.linear.bias, 2)
+
+        @FOOMODELS.register_module()
+        class PretrainedModel(FooModel):
+
+            def __init__(self,
+                         component1=None,
+                         component2=None,
+                         component3=None,
+                         component4=None,
+                         init_cfg=None) -> None:
+                super().__init__(component1, component2, component3,
+                                 component4, init_cfg)
+                self.linear = CustomLinear()
+
+        checkpoint_path = osp.join(self.temp_dir.name, 'test.pth')
+        torch.save(self.model.state_dict(), checkpoint_path)
+        model_cfg = copy.deepcopy(self.model_cfg)
+        model_cfg['type'] = 'PretrainedModel'
+        model_cfg['init_cfg'] = dict(
+            type='Pretrained', checkpoint=checkpoint_path)
+        model = FOOMODELS.build(model_cfg)
+        ori_layer_weight = model.linear.linear.weight.clone()
+        ori_layer_bias = model.linear.linear.bias.clone()
+        model.init_weights()
+
+        self.assertTrue((ori_layer_weight != model.linear.linear.weight).any())
+        self.assertTrue((ori_layer_bias != model.linear.linear.bias).any())
+
     def test_dump_init_info(self):
         import os
         import shutil