# Add New Modules
## Develop new components
We can customize all the components introduced in [the model documentation](./models.md), such as **backbone**, **head**, **loss function** and **data preprocessor**.
### Add new backbones
Here we show how to develop a new backbone, using MobileNet as an example.
1. Create a new file `mmseg/models/backbones/mobilenet.py`.
```python
import torch.nn as nn

from mmseg.registry import MODELS


@MODELS.register_module()
class MobileNet(nn.Module):

    def __init__(self, arg1, arg2):
        # nn.Module subclasses must initialize the parent class first
        super().__init__()
        # TODO build the layers of the backbone here

    def forward(self, x):  # should return a tuple
        pass

    def init_weights(self, pretrained=None):
        pass
```
2. Import the module in `mmseg/models/backbones/__init__.py`.
```python
from .mobilenet import MobileNet
```
3. Use it in your config file.
```python
model = dict(
    ...
    backbone=dict(
        type='MobileNet',
        arg1=xxx,
        arg2=xxx),
    ...
)
```
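The `@MODELS.register_module()` decorator in step 1 and the `type='MobileNet'` key in step 3 are connected through a registry: registering maps a class name to the class, and building looks that name up and passes the remaining config keys to the constructor. This is also why step 2 matters, since the decorator only runs once the module is imported. The sketch below is a hypothetical, stripped-down `ToyRegistry` in plain Python that illustrates the idea; it is not MMEngine's `Registry`, which additionally handles scopes, parent registries and more.

```python
from typing import Any, Callable, Dict, Type


class ToyRegistry:
    """Hypothetical minimal stand-in for a registry such as ``MODELS``."""

    def __init__(self, name: str) -> None:
        self.name = name
        self._modules: Dict[str, Type] = {}

    def register_module(self) -> Callable[[Type], Type]:
        # Returns a decorator that records the class under its own name.
        def _register(cls: Type) -> Type:
            if cls.__name__ in self._modules:
                raise KeyError(f'{cls.__name__} is already registered in {self.name}')
            self._modules[cls.__name__] = cls
            return cls
        return _register

    def build(self, cfg: Dict[str, Any]) -> Any:
        # Copy the config so the caller's dict is left untouched.
        args = dict(cfg)
        cls_name = args.pop('type')
        if cls_name not in self._modules:
            raise KeyError(f'{cls_name} is not registered in {self.name}')
        return self._modules[cls_name](**args)


TOY_MODELS = ToyRegistry('toy_models')


@TOY_MODELS.register_module()
class MobileNet:
    # Placeholder class; only stores its constructor arguments.
    def __init__(self, arg1, arg2):
        self.arg1 = arg1
        self.arg2 = arg2


backbone = TOY_MODELS.build(dict(type='MobileNet', arg1=1, arg2=2))
print(type(backbone).__name__, backbone.arg1, backbone.arg2)  # MobileNet 1 2
```

In the toy, building an unregistered `type` raises a `KeyError`, which is analogous to the error you typically run into when the import in step 2 is missing.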
### Add new heads
In MMSegmentation, we provide a [BaseDecodeHead](https://github.com/open-mmlab/mmsegmentation/blob/main/mmseg/models/decode_heads/decode_head.py#L17) for developing all segmentation heads.
All newly implemented decode heads should be derived from it.
Here we show how to develop a new head, using [PSPNet](https://arxiv.org/abs/1612.01105) as an example.
First, add a new decode head in `mmseg/models/decode_heads/psp_head.py`.
PSPNet implements a decode head for segmentation decoding.
To implement a decode head, we need to implement three methods of the new module, `__init__`, `init_weights` and `forward`, as follows.
```python
from mmseg.registry import MODELS
# BaseDecodeHead lives in the same package (decode_heads/decode_head.py)
from .decode_head import BaseDecodeHead


@MODELS.register_module()
class PSPHead(BaseDecodeHead):

    def __init__(self, pool_scales=(1, 2, 3, 6), **kwargs):
        super(PSPHead, self).__init__(**kwargs)

    def init_weights(self):
        pass

    def forward(self, inputs):
        pass
```
Next, add the module to `mmseg/models/decode_heads/__init__.py` so that the corresponding registry can find and load it.
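The `pool_scales=(1, 2, 3, 6)` argument controls PSPNet's pyramid pooling: to our understanding, the head average-pools the backbone feature map into 1x1, 2x2, 3x3 and 6x6 grids (MMSegmentation uses `nn.AdaptiveAvgPool2d` for this). The plain-Python sketch below reproduces only that pooling step on a nested-list "feature map"; the function names are hypothetical and the real head works on batched tensors.

```python
from typing import List

Grid = List[List[float]]


def _ceil_div(a: int, b: int) -> int:
    # Integer ceiling division, avoiding float rounding.
    return -(-a // b)


def adaptive_avg_pool2d(x: Grid, out_size: int) -> Grid:
    """Average-pool an H x W grid into out_size x out_size bins.

    Bin i covers rows [floor(i*H/S), ceil((i+1)*H/S)), the rule adaptive
    average pooling uses, so bins may overlap when H is not divisible by S.
    """
    h, w = len(x), len(x[0])
    out: Grid = []
    for i in range(out_size):
        r0, r1 = (i * h) // out_size, _ceil_div((i + 1) * h, out_size)
        row = []
        for j in range(out_size):
            c0, c1 = (j * w) // out_size, _ceil_div((j + 1) * w, out_size)
            cells = [x[r][c] for r in range(r0, r1) for c in range(c0, c1)]
            row.append(sum(cells) / len(cells))
        out.append(row)
    return out


# A 6 x 6 "feature map" with values 0..35, pooled at the default scales.
feature = [[float(r * 6 + c) for c in range(6)] for r in range(6)]
pyramid = {s: adaptive_avg_pool2d(feature, s) for s in (1, 2, 3, 6)}
print(pyramid[1])  # [[17.5]]
```

In the actual PSP head, each pooled grid is, as far as we know, passed through a 1x1 convolution, resized back to the input resolution and concatenated with the original features before a final bottleneck convolution.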
The config file of PSPNet is as follows:
```python
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
    type='EncoderDecoder',
    pretrained='pretrain_model/resnet50_v1c_trick-2cccc1ad.pth',
    backbone=dict(
        type='ResNetV1c',
        depth=50,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        dilations=(1, 1, 2, 4),
        strides=(1, 2, 1, 1),
        norm_cfg=norm_cfg,
        norm_eval=False,
        style='pytorch',
        contract_dilation=True),
    decode_head=dict(
        type='PSPHead',
        in_channels=2048,
        in_index=3,
        channels=512,
        pool_scales=(1, 2, 3, 6),
        dropout_ratio=0.1,
        num_classes=19,
        norm_cfg=norm_cfg,
        align_corners=False,
        loss_decode=dict(
            type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)))
```
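In this config, `strides=(1, 2, 1, 1)` together with `dilations=(1, 1, 2, 4)` keeps the last two ResNet stages at the same resolution and uses dilated convolutions instead of further downsampling, so the head reads a comparatively large feature map from `in_index=3` (the last of the four `out_indices`, with 2048 channels for ResNet-50). A quick arithmetic check, assuming the stem downsamples by 4 (a stride-2 convolution followed by stride-2 max pooling) and an input whose sides are divisible by the total stride; the helper names are hypothetical:

```python
from typing import Sequence, Tuple


def output_stride(stage_strides: Sequence[int], stem_stride: int = 4) -> int:
    """Total downsampling factor of a ResNet-style backbone.

    Assumes the stem downsamples by `stem_stride` and each stage by its
    own stride.
    """
    total = stem_stride
    for s in stage_strides:
        total *= s
    return total


def feature_size(img_hw: Tuple[int, int], stride: int) -> Tuple[int, int]:
    # Exact only when both sides are divisible by `stride`.
    h, w = img_hw
    return h // stride, w // stride


dilated = output_stride((1, 2, 1, 1))   # strides from the PSPNet config above
standard = output_stride((1, 2, 2, 2))  # a plain, non-dilated ResNet
print(dilated, standard)                 # 8 32
print(feature_size((512, 1024), dilated))  # (64, 128)
```

For a 512x1024 input, the decode head therefore works on a 64x128 map rather than the 16x32 map a non-dilated ResNet would produce.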
### Add new loss
Assume you want to add a new loss named `MyLoss` for segmentation decoding.
To add a new loss function, implement it in `mmseg/models/losses/my_loss.py`.
The decorator `weighted_loss` enables the loss to be weighted for each element.
```python
import torch
import torch.nn as nn

from mmseg.registry import MODELS
from .utils import weighted_loss


@weighted_loss
def my_loss(pred, target):
    assert pred.size() == target.size() and target.numel() > 0
    loss = torch.abs(pred - target)
    return loss


@MODELS.register_module()
class MyLoss(nn.Module):

    def __init__(self, reduction='mean', loss_weight=1.0, loss_name='loss_my'):
        super(MyLoss, self).__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight
        # 'loss_my' is an illustrative default name
        self._loss_name = loss_name

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None,
                **kwargs):
        # extra keyword arguments passed by the decode head
        # (e.g. ignore_index) are accepted and ignored here
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        loss = self.loss_weight * my_loss(
            pred, target, weight, reduction=reduction, avg_factor=avg_factor)
        return loss

    @property
    def loss_name(self):
        # decode heads use this as the key of the loss dict; keep the
        # `loss_` prefix so the term is included in back-propagation
        return self._loss_name
```
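With `@weighted_loss`, `my_loss` only has to return element-wise values; the wrapper then applies the optional per-element `weight` and reduces the result according to `reduction` and `avg_factor`. The plain-Python sketch below shows those semantics as we understand them (including the rule that `avg_factor` replaces the element count when averaging and cannot be combined with `'sum'`). The names `toy_weight_reduce` and `toy_l1_loss` are hypothetical; the real helper operates on tensors and may add a small epsilon when dividing.

```python
from typing import List, Optional, Union


def toy_weight_reduce(loss: List[float],
                      weight: Optional[List[float]] = None,
                      reduction: str = 'mean',
                      avg_factor: Optional[float] = None
                      ) -> Union[float, List[float]]:
    """Apply element-wise weights, then reduce (hypothetical sketch)."""
    if weight is not None:
        loss = [l * w for l, w in zip(loss, weight)]
    if avg_factor is None:
        if reduction == 'none':
            return loss
        if reduction == 'mean':
            return sum(loss) / len(loss)
        if reduction == 'sum':
            return sum(loss)
        raise ValueError(f'unknown reduction {reduction!r}')
    # With avg_factor, 'mean' divides by it instead of the element count.
    if reduction == 'mean':
        return sum(loss) / avg_factor
    if reduction == 'none':
        return loss
    raise ValueError('avg_factor can not be used with reduction="sum"')


def toy_l1_loss(pred, target, weight=None, reduction='mean', avg_factor=None):
    # Element-wise L1, mirroring my_loss above, followed by weight + reduce.
    assert len(pred) == len(target) and len(target) > 0
    elementwise = [abs(p - t) for p, t in zip(pred, target)]
    return toy_weight_reduce(elementwise, weight, reduction, avg_factor)


print(toy_l1_loss([1.0, 2.0, 3.0], [0.0, 2.0, 5.0]))  # 1.0
```

Keeping the reduction logic in one shared decorator means every new loss only defines its element-wise formula.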
Then add it to `mmseg/models/losses/__init__.py`.
```python
from .my_loss import MyLoss, my_loss
```
To use it, modify the corresponding `loss_xxx` field; for a decode head, this is the `loss_decode` field.
`loss_weight` can be used to balance multiple losses.
```python
loss_decode=dict(type='MyLoss', loss_weight=1.0)
```
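`loss_weight` is most useful when a head is trained with several losses. To our understanding, MMSegmentation 1.x accepts a list of loss configs in `loss_decode`; each term is scaled by its own `loss_weight` and reported under its `loss_name`, and names should start with `loss_` for the term to be included in back-propagation. A config fragment combining the built-in `CrossEntropyLoss` and `DiceLoss` (the weights are illustrative, not tuned values):

```python
# Assign this list to `loss_decode` inside `decode_head=dict(...)`.
loss_decode = [
    dict(type='CrossEntropyLoss', loss_name='loss_ce', loss_weight=1.0),
    dict(type='DiceLoss', loss_name='loss_dice', loss_weight=3.0),
]
```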
### Add new data preprocessor
In MMSegmentation 1.x versions, we use [SegDataPreProcessor](https://github.com/open-mmlab/mmsegmentation/blob/main/mmseg/models/data_preprocessor.py#L13) by default to copy data to the target device and preprocess it into the model input format. Here we show how to develop a new data preprocessor.
1. Create a new file `mmseg/models/my_datapreprocessor.py`.
```python
from typing import Any, Dict

from mmengine.model import BaseDataPreprocessor

from mmseg.registry import MODELS


@MODELS.register_module()
class MyDataPreProcessor(BaseDataPreprocessor):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def forward(self, data: dict, training: bool = False) -> Dict[str, Any]:
        # TODO Define the logic for data pre-processing in the forward method
        pass
```
2. Import your data preprocessor in `mmseg/models/__init__.py`.
```python
from .my_datapreprocessor import MyDataPreProcessor
```
3. Use it in your config file.
```python
model = dict(
    data_preprocessor=dict(type='MyDataPreProcessor'),
    ...
)
```
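The `forward` method in step 1 is left as a TODO. As a rough guide, the default `SegDataPreProcessor` (to our understanding) normalizes images with `mean`/`std` and then pads images and segmentation maps to a common size, for example a multiple of `size_divisor`, filling labels with `seg_pad_val` (255 by default) so that padded pixels are ignored by the loss. The plain-Python sketch below mimics those two steps on a single-channel image stored as nested lists; all function names are hypothetical, and a real implementation works on batched tensors on the target device.

```python
from typing import List, Tuple

Grid = List[List[float]]


def normalize(img: Grid, mean: float, std: float) -> Grid:
    return [[(v - mean) / std for v in row] for row in img]


def pad_to_divisor(grid: Grid, divisor: int, pad_val: float) -> Grid:
    h, w = len(grid), len(grid[0])
    new_h = -(-h // divisor) * divisor  # round up to a multiple of divisor
    new_w = -(-w // divisor) * divisor
    padded = [row + [pad_val] * (new_w - w) for row in grid]
    padded += [[pad_val] * new_w for _ in range(new_h - h)]
    return padded


def preprocess(img: Grid, seg: Grid, mean: float, std: float, divisor: int,
               pad_val: float = 0.0, seg_pad_val: int = 255) -> Tuple[Grid, Grid]:
    # Normalize the image first, then pad image and label to the same size;
    # the label is padded with an ignore value so padding adds no loss.
    img = pad_to_divisor(normalize(img, mean, std), divisor, pad_val)
    seg = pad_to_divisor(seg, divisor, seg_pad_val)
    return img, seg


img = [[1.0, 3.0, 5.0], [1.0, 1.0, 1.0], [3.0, 3.0, 3.0]]
seg = [[0, 1, 1], [0, 0, 1], [1, 1, 0]]
out_img, out_seg = preprocess(img, seg, mean=1.0, std=2.0, divisor=4)
print(out_img[0], out_seg[0])  # [0.0, 1.0, 2.0, 0.0] [0, 1, 1, 255]
```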
## Develop new segmentors
The segmentor is an algorithmic architecture in which users can customize their algorithms by adding customized components and defining the logic of algorithm execution. Please refer to [the model document](./models.md) for more details.
Since [BaseSegmentor](https://github.com/open-mmlab/mmsegmentation/blob/main/mmseg/models/segmentors/base.py#L15) in MMSegmentation unifies three modes of the forward process, users developing a new segmentor need to override the `loss`, `predict` and `_forward` methods, which correspond to the `loss`, `predict` and `tensor` modes respectively.
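To our understanding, `BaseSegmentor.forward` takes a `mode` argument and routes the call to `loss`, `predict` or `_forward`, so a new segmentor only has to fill in those three methods. The hypothetical `ToySegmentor` below mimics that dispatch in plain Python, using an identity "network" whose inputs are already per-pixel class scores; `predict` adds the post-processing step (argmax over classes) that `_forward` deliberately skips.

```python
from typing import List, Optional

Scores = List[List[List[float]]]  # [H][W][num_classes]
Labels = List[List[int]]          # [H][W]


class ToySegmentor:
    """Hypothetical segmentor illustrating the three forward modes."""

    def forward(self, inputs: Scores, data_samples: Optional[Labels] = None,
                mode: str = 'tensor'):
        # One entry point, three behaviours selected by `mode`.
        if mode == 'loss':
            return self.loss(inputs, data_samples)
        if mode == 'predict':
            return self.predict(inputs, data_samples)
        if mode == 'tensor':
            return self._forward(inputs, data_samples)
        raise RuntimeError(f'Invalid mode "{mode}". '
                           'Only supports loss, predict and tensor mode')

    def _forward(self, inputs: Scores, data_samples=None) -> Scores:
        # A real segmentor runs backbone, neck and head here; the toy
        # treats `inputs` as ready-made per-pixel class scores.
        return inputs

    def predict(self, inputs: Scores, data_samples=None) -> Labels:
        # Raw network output plus post-processing: argmax over classes.
        scores = self._forward(inputs, data_samples)
        return [[max(range(len(px)), key=px.__getitem__) for px in row]
                for row in scores]

    def loss(self, inputs: Scores, data_samples: Labels) -> dict:
        # Toy "loss": fraction of pixels whose prediction is wrong.
        pred = self.predict(inputs)
        pairs = [(p, t) for pr, tr in zip(pred, data_samples)
                 for p, t in zip(pr, tr)]
        wrong = sum(p != t for p, t in pairs)
        return {'loss_toy': wrong / len(pairs)}


seg = ToySegmentor()
scores = [[[0.1, 0.9], [0.8, 0.2]]]  # a 1 x 2 image with 2 classes
print(seg.forward(scores, mode='predict'))                      # [[1, 0]]
print(seg.forward(scores, data_samples=[[1, 1]], mode='loss'))  # {'loss_toy': 0.5}
```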
Here we show how to develop a new segmentor.
1. Create a new file `mmseg/models/segmentors/my_segmentor.py`.
```python
from typing import List, Tuple

from torch import Tensor

from mmseg.registry import MODELS
from mmseg.utils import OptSampleList, SampleList
# import from the sibling module to avoid a circular import of mmseg.models
from .base import BaseSegmentor


@MODELS.register_module()
class MySegmentor(BaseSegmentor):

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # TODO users should build components of the network here

    def loss(self, inputs: Tensor, data_samples: SampleList) -> dict:
        """Calculate losses from a batch of inputs and data samples."""
        pass

    def predict(self,
                inputs: Tensor,
                data_samples: OptSampleList = None) -> SampleList:
        """Predict results from a batch of inputs and data samples with post-
        processing."""
        pass

    def _forward(self,
                 inputs: Tensor,
                 data_samples: OptSampleList = None) -> Tuple[List[Tensor]]:
        """Network forward process.

        Usually includes backbone, neck and head forward without any post-
        processing.
        """
        pass
```
2. Import your segmentor in `mmseg/models/segmentors/__init__.py`.
```python
from .my_segmentor import MySegmentor
```
3. Use it in your config file.
```python
model = dict(
    type='MySegmentor',
    ...
)
```