mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Doc] Update add module doc (#2067)
* develop segmentor and remove custom optim * update segmentor example * add data preprocessor * refine intro
This commit is contained in:
parent
b3b7629d38
commit
5d1faeabf0
@ -1,82 +1,12 @@
|
|||||||
# Add New Modules
|
# Add New Modules
|
||||||
|
|
||||||
## Customize optimizer
|
|
||||||
|
|
||||||
Assume you want to add a optimizer named as `MyOptimizer`, which has arguments `a`, `b`, and `c`.
|
|
||||||
You need to first implement the new optimizer in a file, e.g., in `mmseg/engine/optimizers/my_optimizer.py`:
|
|
||||||
|
|
||||||
```python
|
|
||||||
from mmcv.runner import OPTIMIZERS
|
|
||||||
from torch.optim import Optimizer
|
|
||||||
|
|
||||||
|
|
||||||
@OPTIMIZERS.register_module
|
|
||||||
class MyOptimizer(Optimizer):
|
|
||||||
|
|
||||||
def __init__(self, a, b, c)
|
|
||||||
|
|
||||||
```
|
|
||||||
|
|
||||||
Then add this module in `mmseg/engine/optimizers/__init__.py` thus the registry will
|
|
||||||
find the new module and add it:
|
|
||||||
|
|
||||||
```python
|
|
||||||
from .my_optimizer import MyOptimizer
|
|
||||||
```
|
|
||||||
|
|
||||||
Then you can use `MyOptimizer` in `optimizer` field of config files.
|
|
||||||
In the configs, the optimizers are defined by the field `optimizer` like the following:
|
|
||||||
|
|
||||||
```python
|
|
||||||
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
|
|
||||||
```
|
|
||||||
|
|
||||||
To use your own optimizer, the field can be changed as
|
|
||||||
|
|
||||||
```python
|
|
||||||
optimizer = dict(type='MyOptimizer', a=a_value, b=b_value, c=c_value)
|
|
||||||
```
|
|
||||||
|
|
||||||
We already support to use all the optimizers implemented by PyTorch, and the only modification is to change the `optimizer` field of config files.
|
|
||||||
For example, if you want to use `ADAM`, though the performance will drop a lot, the modification could be as the following.
|
|
||||||
|
|
||||||
```python
|
|
||||||
optimizer = dict(type='Adam', lr=0.0003, weight_decay=0.0001)
|
|
||||||
```
|
|
||||||
|
|
||||||
The users can directly set arguments following the [API doc](https://pytorch.org/docs/stable/optim.html?highlight=optim#module-torch.optim) of PyTorch.
|
|
||||||
|
|
||||||
## Customize optimizer constructor
|
|
||||||
|
|
||||||
Some models may have some parameter-specific settings for optimization, e.g. weight decay for BatchNoarm layers.
|
|
||||||
The users can do those fine-grained parameter tuning through customizing optimizer constructor.
|
|
||||||
|
|
||||||
```python
|
|
||||||
from mmseg.registry import OPTIM_WRAPPER_CONSTRUCTORS
|
|
||||||
from .cocktail_optimizer import CocktailOptimizer
|
|
||||||
|
|
||||||
|
|
||||||
@OPTIM_WRAPPER_CONSTRUCTORS.register_module
|
|
||||||
class CocktailOptimizerConstructor(object):
|
|
||||||
|
|
||||||
def __init__(self, optim_wrapper_cfg, paramwise_cfg=None):
|
|
||||||
|
|
||||||
def __call__(self, model):
|
|
||||||
|
|
||||||
return my_optimizer
|
|
||||||
|
|
||||||
```
|
|
||||||
|
|
||||||
## Develop new components
|
## Develop new components
|
||||||
|
|
||||||
There are mainly 2 types of components in MMSegmentation.
|
We can customize all the components introduced at [the model documentation](./models.md), such as **backbone**, **head**, **loss function** and **data preprocessor**.
|
||||||
|
|
||||||
- backbone: usually stacks of convolutional network to extract feature maps, e.g., ResNet, HRNet.
|
|
||||||
- head: the component for semantic segmentation map decoding.
|
|
||||||
|
|
||||||
### Add new backbones
|
### Add new backbones
|
||||||
|
|
||||||
Here we show how to develop new components with an example of MobileNet.
|
Here we show how to develop a new backbone with an example of MobileNet.
|
||||||
|
|
||||||
1. Create a new file `mmseg/models/backbones/mobilenet.py`.
|
1. Create a new file `mmseg/models/backbones/mobilenet.py`.
|
||||||
|
|
||||||
@ -86,7 +16,7 @@ import torch.nn as nn
|
|||||||
from mmseg.registry import MODELS
|
from mmseg.registry import MODELS
|
||||||
|
|
||||||
|
|
||||||
@MODELS.register_module
|
@MODELS.register_module()
|
||||||
class MobileNet(nn.Module):
|
class MobileNet(nn.Module):
|
||||||
|
|
||||||
def __init__(self, arg1, arg2):
|
def __init__(self, arg1, arg2):
|
||||||
@ -119,13 +49,13 @@ model = dict(
|
|||||||
|
|
||||||
### Add new heads
|
### Add new heads
|
||||||
|
|
||||||
In MMSegmentation, we provide a base [BaseDecodeHead](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/mmseg/models/decode_heads/decode_head.py) for all segmentation head.
|
In MMSegmentation, we provide a [BaseDecodeHead](https://github.com/open-mmlab/mmsegmentation/blob/1.x/mmseg/models/decode_heads/decode_head.py#L17) for developing all segmentation heads.
|
||||||
All newly implemented decode heads should be derived from it.
|
All newly implemented decode heads should be derived from it.
|
||||||
Here we show how to develop a new head with the example of [PSPNet](https://arxiv.org/abs/1612.01105) as the following.
|
Here we show how to develop a new head with the example of [PSPNet](https://arxiv.org/abs/1612.01105) as the following.
|
||||||
|
|
||||||
First, add a new decode head in `mmseg/models/decode_heads/psp_head.py`.
|
First, add a new decode head in `mmseg/models/decode_heads/psp_head.py`.
|
||||||
PSPNet implements a decode head for segmentation decode.
|
PSPNet implements a decode head for segmentation decode.
|
||||||
To implement a decode head, basically we need to implement three functions of the new module as the following.
|
To implement a decode head, we need to implement three functions of the new module as the following.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from mmseg.registry import MODELS
|
from mmseg.registry import MODELS
|
||||||
@ -137,12 +67,13 @@ class PSPHead(BaseDecodeHead):
|
|||||||
super(PSPHead, self).__init__(**kwargs)
|
super(PSPHead, self).__init__(**kwargs)
|
||||||
|
|
||||||
def init_weights(self):
|
def init_weights(self):
|
||||||
|
pass
|
||||||
|
|
||||||
def forward(self, inputs):
|
def forward(self, inputs):
|
||||||
|
pass
|
||||||
```
|
```
|
||||||
|
|
||||||
Next, the users need to add the module in the `mmseg/models/decode_heads/__init__.py` thus the corresponding registry could find and load them.
|
Next, the users need to add the module in the `mmseg/models/decode_heads/__init__.py`, thus the corresponding registry could find and load them.
|
||||||
|
|
||||||
To config file of PSPNet is as the following
|
To config file of PSPNet is as the following
|
||||||
|
|
||||||
@ -180,8 +111,8 @@ model = dict(
|
|||||||
### Add new loss
|
### Add new loss
|
||||||
|
|
||||||
Assume you want to add a new loss as `MyLoss` for segmentation decode.
|
Assume you want to add a new loss as `MyLoss` for segmentation decode.
|
||||||
To add a new loss function, the users need implement it in `mmseg/models/losses/my_loss.py`.
|
To add a new loss function, the users need to implement it in `mmseg/models/losses/my_loss.py`.
|
||||||
The decorator `weighted_loss` enable the loss to be weighted for each element.
|
The decorator `weighted_loss` enables the loss to be weighted for each element.
|
||||||
|
|
||||||
```python
|
```python
|
||||||
import torch
|
import torch
|
||||||
@ -196,7 +127,7 @@ def my_loss(pred, target):
|
|||||||
loss = torch.abs(pred - target)
|
loss = torch.abs(pred - target)
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
@LOSSES.register_module
|
@MODELS.register_module()
|
||||||
class MyLoss(nn.Module):
|
class MyLoss(nn.Module):
|
||||||
|
|
||||||
def __init__(self, reduction='mean', loss_weight=1.0):
|
def __init__(self, reduction='mean', loss_weight=1.0):
|
||||||
@ -232,3 +163,98 @@ Then you need to modify the `loss_decode` field in the head.
|
|||||||
```python
|
```python
|
||||||
loss_decode=dict(type='MyLoss', loss_weight=1.0))
|
loss_decode=dict(type='MyLoss', loss_weight=1.0))
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Add new data preprocessor
|
||||||
|
|
||||||
|
In MMSegmentation 1.x versions, we use [SegDataPreProcessor](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/mmseg/models/data_preprocessor.py#L13) to copy data to the target device and preprocess the data into the model input format as default. Here we show how to develop a new data preprocessor.
|
||||||
|
|
||||||
|
1. Create a new file `mmseg/models/my_datapreprocessor.py`.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from mmengine.model import BaseDataPreprocessor
|
||||||
|
|
||||||
|
from mmseg.registry import MODELS
|
||||||
|
|
||||||
|
@MODELS.register_module()
|
||||||
|
class MyDataPreProcessor(BaseDataPreprocessor):
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
def forward(self, data: dict, training: bool=False) -> Dict[str, Any]:
|
||||||
|
# TODO Define the logic for data pre-processing in the forward method
|
||||||
|
pass
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Import your data preprocessor in `mmseg/models/__init__.py`
|
||||||
|
|
||||||
|
```python
|
||||||
|
from .my_datapreprocessor import MyDataPreProcessor
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Use it in your config file.
|
||||||
|
|
||||||
|
```python
|
||||||
|
model = dict(
|
||||||
|
data_preprocessor=dict(type='MyDataPreProcessor)
|
||||||
|
...
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Develop new segmentors
|
||||||
|
|
||||||
|
The segmentor is an algorithmic architecture in which users can customize their algorithms by adding customized components and defining the logic of algorithm execution. Please refer to [the model document](https://github.com/open-mmlab/mmsegmentation/blob/1.x/docs/en/advanced_guides/models.md) for more details.
|
||||||
|
|
||||||
|
Since the [BaseSegmentor](https://github.com/open-mmlab/mmsegmentation/blob/1.x/mmseg/models/segmentors/base.py#L15) in MMSegmentation unifies three modes for a forward process, to develop a new segmentor, users need to overwrite `loss`, `predict` and `_forward` methods corresponding to the `loss`, `predict` and `tensor` modes.
|
||||||
|
|
||||||
|
Here we show how to develop a new segmentor.
|
||||||
|
|
||||||
|
1. Create a new file `mmseg/models/segmentors/my_segmentor.py`.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from typing import Dict, Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from mmseg.registry import MODELS
|
||||||
|
from mmseg.models import BaseSegmentor
|
||||||
|
|
||||||
|
@MODELS.register_module()
|
||||||
|
class MySegmentor(BaseSegmentor):
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
# TODO users should build components of the network here
|
||||||
|
|
||||||
|
def loss(self, inputs: Tensor, data_samples: SampleList) -> dict:
|
||||||
|
"""Calculate losses from a batch of inputs and data samples."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def predict(self, inputs: Tensor, data_samples: OptSampleList=None) -> SampleList:
|
||||||
|
"""Predict results from a batch of inputs and data samples with post-
|
||||||
|
processing."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _forward(self,
|
||||||
|
inputs: Tensor,
|
||||||
|
data_samples: OptSampleList = None) -> Tuple[List[Tensor]]:
|
||||||
|
"""Network forward process.
|
||||||
|
|
||||||
|
Usually includes backbone, neck and head forward without any post-
|
||||||
|
processing.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Import your segmentor in `mmseg/models/segmentors/__init__.py`.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from .my_segmentor import MySegmentor
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Use it in your config file.
|
||||||
|
|
||||||
|
```python
|
||||||
|
model = dict(
|
||||||
|
type='MySegmentor'
|
||||||
|
...
|
||||||
|
)
|
||||||
|
```
|
||||||
|
Loading…
x
Reference in New Issue
Block a user