[Doc] Add EN customized runtime doc in dev-1.x (#2533)
## Motivation

Translate the Chinese version of the customized runtime doc into English: https://github.com/open-mmlab/mmsegmentation/pull/2502.
# Customize Runtime Settings
## Customize hooks
### Step 1: Implement a new hook
MMEngine provides commonly used [hooks](https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/hook.md) for training and testing.
When users have customization requirements, they can follow the examples below.
For example, if a hyper-parameter of the model needs to be changed during training, we can implement a new hook for it:
```python
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Sequence

from mmengine.hooks import Hook
from mmengine.model import is_model_wrapper

from mmseg.registry import HOOKS


@HOOKS.register_module()
class NewHook(Hook):
    """Docstring for NewHook."""

    def __init__(self, a: int, b: int) -> None:
        self.a = a
        self.b = b

    def before_train_iter(self,
                          runner,
                          batch_idx: int,
                          data_batch: Optional[Sequence[dict]] = None) -> None:
        cur_iter = runner.iter
        # unwrap the model when it is in a wrapper (e.g. DistributedDataParallel)
        model = runner.model
        if is_model_wrapper(model):
            model = model.module
        model.hyper_parameter = self.a * cur_iter + self.b
```
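MMEngine's `Hook` base class also provides other call sites such as `before_run`, `after_train_epoch` and `after_val_iter`, which can be overridden with the same registration pattern. Below is a minimal hypothetical sketch (the hook name and the printed message are illustrative, not part of MMSegmentation):

```python
from mmengine.hooks import Hook

from mmseg.registry import HOOKS


@HOOKS.register_module()
class EpochLogHook(Hook):
    """Toy hook that reports each finished training epoch (illustration only)."""

    def after_train_epoch(self, runner) -> None:
        # runner.epoch is the index of the epoch that just finished
        print(f'finished training epoch {runner.epoch}')
```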
### Step 2: Import a new hook
The module defined above needs to be imported into the main namespace first so that it gets registered.
We assume `NewHook` is implemented in `mmseg/engine/hooks/new_hook.py`. There are two ways to import it:
- Import it by modifying `mmseg/engine/hooks/__init__.py`.
  Modules should be imported in `mmseg/engine/hooks/__init__.py` so that the registry can find and add them.
```python
from .new_hook import NewHook

__all__ = [..., 'NewHook']
```

- Import it manually by `custom_imports` in the config file.

```python
custom_imports = dict(imports=['mmseg.engine.hooks.new_hook'], allow_failed_imports=False)
```
### Step 3: Modify config file
Users can set and use customized hooks in training and testing as below.
The execution priority of hooks at the same call site of `Runner` is described [here](https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/hook.md#built-in-hooks).
The default priority of a customized hook is `NORMAL`.
```python
custom_hooks = [
    dict(type='NewHook', a=a_value, b=b_value, priority='ABOVE_NORMAL')
]
```
## Customize optimizer
### Step 1: Implement a new optimizer
We recommend implementing the customized optimizer in `mmseg/engine/optimizers/my_optimizer.py`. Here is an example of a new optimizer `MyOptimizer`, which has parameters `a`, `b` and `c`:
```python
from mmseg.registry import OPTIMIZERS
from torch.optim import Optimizer


@OPTIMIZERS.register_module()
class MyOptimizer(Optimizer):

    def __init__(self, a, b, c):
        ...
```
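The snippet above is only a skeleton. For reference, a minimal runnable version might look like the following; the update rule and the meaning of `a`, `b` and `c` are invented for illustration, and only the registration pattern is the one described in this section:

```python
import torch
from torch.optim import Optimizer

from mmseg.registry import OPTIMIZERS


@OPTIMIZERS.register_module()
class MyOptimizer(Optimizer):

    def __init__(self, params, a, b, c):
        # keep the hyper-parameters as per-parameter-group defaults
        defaults = dict(a=a, b=b, c=c)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        loss = closure() if closure is not None else None
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                # toy update rule: plain gradient descent scaled by `a`
                p.add_(p.grad, alpha=-group['a'])
        return loss
```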
### Step 2: Import a new optimizer
The module defined above needs to be imported into the main namespace first so that it gets registered.
We assume `MyOptimizer` is implemented in `mmseg/engine/optimizers/my_optimizer.py`. There are two ways to import it:

- Import it by modifying `mmseg/engine/optimizers/__init__.py`.
  Modules should be imported in `mmseg/engine/optimizers/__init__.py` so that the registry can find and add them.
```python
from .my_optimizer import MyOptimizer
```

- Import it manually by `custom_imports` in the config file.
```python
custom_imports = dict(imports=['mmseg.engine.optimizers.my_optimizer'], allow_failed_imports=False)
```
### Step 3: Modify config file
Then modify the `optimizer` entry of `optim_wrapper` in the config file. To use the customized `MyOptimizer`, it can be set as:
```python
optim_wrapper = dict(type='OptimWrapper',
                     optimizer=dict(type='MyOptimizer',
                                    a=a_value, b=b_value, c=c_value),
                     clip_grad=None)
```
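Optimizers already implemented in PyTorch (e.g. `SGD`, `Adam`, `AdamW`) are registered by MMEngine automatically and can be configured the same way; only the `optimizer` dict changes. The hyper-parameter values below are just an example:

```python
optim_wrapper = dict(type='OptimWrapper',
                     optimizer=dict(type='AdamW', lr=0.0001, weight_decay=0.05),
                     clip_grad=None)
```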
## Customize optimizer constructor
### Step 1: Implement a new optimizer constructor
The optimizer constructor is used to create the optimizer and optimizer wrapper for model training. It supports fine-grained settings such as specifying different learning rates and weight decays for different model layers.
Here is an example of a customized optimizer constructor:
```python
from mmengine.optim import DefaultOptimWrapperConstructor

from mmseg.registry import OPTIM_WRAPPER_CONSTRUCTORS


@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
class MyOptimizerConstructor(DefaultOptimWrapperConstructor):

    def __init__(self, optim_wrapper_cfg, paramwise_cfg=None):
        ...

    def __call__(self, model):
        ...
        return optim_wrapper
```
The default optimizer constructor is implemented [here](https://github.com/open-mmlab/mmengine/blob/main/mmengine/optim/optimizer/default_constructor.py#L19). It can also serve as the base class of new optimizer constructors.
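As a concrete sketch of the mechanism, the hypothetical constructor below overrides `__call__` to give bias parameters a zero weight decay. The class name and the grouping rule are invented for illustration; the build steps mirror what `DefaultOptimWrapperConstructor` does internally (it keeps the `optimizer` settings in `self.optimizer_cfg` and the wrapper settings in `self.optim_wrapper_cfg`):

```python
from mmengine.optim import DefaultOptimWrapperConstructor
from mmengine.registry import OPTIM_WRAPPERS

from mmseg.registry import OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS


@OPTIM_WRAPPER_CONSTRUCTORS.register_module()
class ZeroBiasDecayConstructor(DefaultOptimWrapperConstructor):
    """Hypothetical constructor: no weight decay on bias parameters."""

    def __call__(self, model):
        if hasattr(model, 'module'):  # unwrap DDP-style wrappers
            model = model.module
        decay, no_decay = [], []
        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue
            (no_decay if name.endswith('.bias') else decay).append(param)

        optimizer_cfg = self.optimizer_cfg.copy()
        # assumes the configured optimizer accepts `weight_decay` (e.g. SGD, AdamW)
        optimizer_cfg['params'] = [
            dict(params=decay),
            dict(params=no_decay, weight_decay=0.),
        ]
        optimizer = OPTIMIZERS.build(optimizer_cfg)

        optim_wrapper_cfg = self.optim_wrapper_cfg.copy()
        optim_wrapper_cfg.setdefault('type', 'OptimWrapper')
        return OPTIM_WRAPPERS.build(
            optim_wrapper_cfg, default_args=dict(optimizer=optimizer))
```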
### Step 2: Import a new optimizer constructor
The module defined above needs to be imported into the main namespace first so that it gets registered.
We assume `MyOptimizerConstructor` is implemented in `mmseg/engine/optimizers/my_optimizer_constructor.py`. There are two ways to import it:
- Import it by modifying `mmseg/engine/optimizers/__init__.py`.
  Modules should be imported in `mmseg/engine/optimizers/__init__.py` so that the registry can find and add them.
```python
from .my_optimizer_constructor import MyOptimizerConstructor
```

- Import it manually by `custom_imports` in the config file.
```python
custom_imports = dict(imports=['mmseg.engine.optimizers.my_optimizer_constructor'], allow_failed_imports=False)
```
### Step 3: Modify config file
Then modify the `constructor` entry of `optim_wrapper` in the config file. To use the customized `MyOptimizerConstructor`, it can be set as:
```python
optim_wrapper = dict(type='OptimWrapper',
                     constructor='MyOptimizerConstructor',
                     clip_grad=None)
```
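Note that a constructor derived from `DefaultOptimWrapperConstructor` still expects the usual `optimizer` settings to build from, so a complete config would typically combine both fields (the hyper-parameter values here are placeholders):

```python
optim_wrapper = dict(type='OptimWrapper',
                     optimizer=dict(type='SGD', lr=0.01, momentum=0.9,
                                    weight_decay=0.0005),
                     constructor='MyOptimizerConstructor',
                     clip_grad=None)
```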