mmselfsup/docs/zh_cn/tutorials/3_new_module.md

205 lines
4.5 KiB
Markdown
Raw Normal View History

# 教程 3添加新的模块
2021-12-15 19:06:36 +08:00
- [教程 3添加新的模块](#教程-3-添加新的模块)
- [添加新的 backbone](#添加新的-backbone)
- [添加新的 Necks](#添加新的-Necks)
- [添加新的损失](#添加新的损失)
- [合并所有改动](#合并所有改动)
2021-12-15 19:06:36 +08:00
在自监督学习领域,每个模型可以被分为以下四个部分:
2021-12-15 19:06:36 +08:00
- backbone用于提取图像特征。
- projection head将 backbone 提取的特征映射到另一空间。
- loss用于模型优化的损失函数。
- memory bank可选一些方法例如 `odc` ),需要额外的 memory bank 用于存储图像特征。
2021-12-15 19:06:36 +08:00
## 添加新的 backbone
2021-12-15 19:06:36 +08:00
假设我们要创建一个自定义的 backbone `CustomizedBackbone`
2021-12-15 19:06:36 +08:00
1.创建新文件 `mmselfsup/models/backbones/customized_backbone.py` 并在其中实现 `CustomizedBackbone`
2021-12-15 19:06:36 +08:00
```python
2021-12-15 19:06:36 +08:00
import torch.nn as nn
from ..builder import BACKBONES
@BACKBONES.register_module()
class CustomizedBackbone(nn.Module):
def __init__(self, **kwargs):
## TODO
def forward(self, x):
## TODO
def init_weights(self, pretrained=None):
## TODO
def train(self, mode=True):
## TODO
```
2.在 `mmselfsup/models/backbones/__init__.py` 中导入自定义的 backbone。
2021-12-15 19:06:36 +08:00
```python
2021-12-15 19:06:36 +08:00
from .customized_backbone import CustomizedBackbone
__all__ = [
..., 'CustomizedBackbone'
]
```
3.在你的配置文件中使用它。
2021-12-15 19:06:36 +08:00
```python
2021-12-15 19:06:36 +08:00
model = dict(
...
backbone=dict(
type='CustomizedBackbone',
...),
...
)
```
## 添加新的 Necks
2021-12-15 19:06:36 +08:00
我们在 `mmselfsup/models/necks` 中包含了所有的 projection heads。假设我们要创建一个 `CustomizedProjHead`
2021-12-15 19:06:36 +08:00
1.创建一个新文件 `mmselfsup/models/necks/customized_proj_head.py` 并在其中实现 `CustomizedProjHead`
2021-12-15 19:06:36 +08:00
```python
2021-12-15 19:06:36 +08:00
import torch.nn as nn
from mmcv.runner import BaseModule
from ..builder import NECKS
@NECKS.register_module()
class CustomizedProjHead(BaseModule):
def __init__(self, *args, **kwargs):
super(CustomizedProjHead, self).__init__(init_cfg)
## TODO
def forward(self, x):
## TODO
```
你需要实现前向函数,该函数从 backbone 中获取特征,并输出映射后的特征。
2021-12-15 19:06:36 +08:00
2.在 `mmselfsup/models/necks/__init__` 中导入 `CustomizedProjHead`
2021-12-15 19:06:36 +08:00
```python
2021-12-15 19:06:36 +08:00
from .customized_proj_head import CustomizedProjHead
__all__ = [
...,
CustomizedProjHead,
...
]
```
3.在你的配置文件中使用它。
2021-12-15 19:06:36 +08:00
```python
2021-12-15 19:06:36 +08:00
model = dict(
...,
neck=dict(
type='CustomizedProjHead',
...),
...)
```
## 添加新的损失
2021-12-15 19:06:36 +08:00
为了增加一个新的损失函数,我们主要在损失模块中实现 `forward` 函数。
2021-12-15 19:06:36 +08:00
1.创建一个新的文件 `mmselfsup/models/heads/customized_head.py` 并在其中实现你自定义的 `CustomizedHead`
2021-12-15 19:06:36 +08:00
```python
2021-12-15 19:06:36 +08:00
import torch
import torch.nn as nn
from mmcv.runner import BaseModule
from ..builder import HEADS
@HEADS.register_module()
class CustomizedHead(BaseModule):
def __init__(self, *args, **kwargs):
super(CustomizedHead, self).__init__()
## TODO
def forward(self, *args, **kwargs):
## TODO
```
2.在 `mmselfsup/models/heads/__init__.py` 中导入该模块。
2021-12-15 19:06:36 +08:00
```python
2021-12-15 19:06:36 +08:00
from .customized_head import CustomizedHead
__all__ = [..., CustomizedHead, ...]
```
3.在你的配置文件中使用它。
2021-12-15 19:06:36 +08:00
```python
2021-12-15 19:06:36 +08:00
model = dict(
...,
head=dict(type='CustomizedHead')
)
```
## 合并所有改动
2021-12-15 19:06:36 +08:00
在创建了上述每个组件后,我们需要创建一个 `CustomizedAlgorithm` 来有逻辑的将他们组织到一起。 `CustomizedAlgorithm` 接收原始图像作为输入,并将损失输出给优化器。
2021-12-15 19:06:36 +08:00
1.创建一个新文件 `mmselfsup/models/algorithms/customized_algorithm.py` 并在其中实现 `CustomizedAlgorithm`
2021-12-15 19:06:36 +08:00
```python
2021-12-15 19:06:36 +08:00
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from ..builder import ALGORITHMS, build_backbone, build_head, build_neck
from ..utils import GatherLayer
from .base import BaseModel
@ALGORITHMS.register_module()
class CustomizedAlgorithm(BaseModel):
def __init__(self, backbone, neck=None, head=None, init_cfg=None):
super(SimCLR, self).__init__(init_cfg)
## TODO
def forward_train(self, img, **kwargs):
## TODO
```
2.在 `mmselfsup/models/algorithms/__init__.py` 中导入该模块。
2021-12-15 19:06:36 +08:00
```python
2021-12-15 19:06:36 +08:00
from .customized_algorithm import CustomizedAlgorithm
__all__ = [..., CustomizedAlgorithm, ...]
```
3.在你的配置文件中使用它。
2021-12-15 19:06:36 +08:00
```python
2021-12-15 19:06:36 +08:00
model = dict(
type='CustomizedAlgorightm',
backbone=...,
neck=...,
head=...)
```