# Migration from MMSegmentation 0.x ## Introduction This guide describes the fundamental differences between MMSegmentation 0.x and MMSegmentation 1.x in terms of behaviors and the APIs, and how these all relate to your migration journey. ## New dependencies MMSegmentation 1.x depends on some new packages, you can prepare a new clean environment and install again according to the [installation tutorial](../get_started.md). Or install the below packages manually. 1. [MMEngine](https://github.com/open-mmlab/mmengine): MMEngine is the core the OpenMMLab 2.0 architecture, and we splited many compentents unrelated to computer vision from MMCV to MMEngine. 2. [MMCV](https://github.com/open-mmlab/mmcv): The computer vision package of OpenMMLab. This is not a new dependency, but you need to upgrade it to **2.0.0** version or above. 3. [MMClassification](https://github.com/open-mmlab/mmclassification)(Optional): The image classification toolbox and benchmark of OpenMMLab. This is not a new dependency, but you need to upgrade it to **1.0.0rc6** version. 4. [MMDetection](https://github.com/open-mmlab/mmdetection)(Optional): The object detection toolbox and benchmark of OpenMMLab. This is not a new dependency, but you need to upgrade it to **3.0.0** version or above. ## Train launch The main improvement of OpenMMLab 2.0 is releasing MMEngine which provides universal and powerful runner for unified interfaces to launch training jobs. Compared with MMSeg0.x, MMSeg1.x provides fewer command line arguments in `tools/train.py`
Function Original New
Loading pre-trained checkpoint --load_from=$CHECKPOINT --cfg-options load_from=$CHECKPOINT
Resuming Train from specific checkpoint --resume-from=$CHECKPOINT --resume=$CHECKPOINT
Resuming Train from the latest checkpoint --auto-resume --resume='auto'
Whether not to evaluate the checkpoint during training --no-validate --cfg-options val_cfg=None val_dataloader=None val_evaluator=None
Training device assignment --gpu-id=$DEVICE_ID -
Whether or not set different seeds for different ranks --diff-seed --cfg-options randomness.diff_rank_seed=True
Whether to set deterministic options for CUDNN backend --deterministic --cfg-options randomness.deterministic=True
## Test launch Similar to training launch, there are only common arguments in tools/test.py of MMSegmentation 1.x. Below is the difference in test scripts, please refer to [this documentation](../user_guides/4_train_test.md) for more details about test launch.
Function 0.x 1.x
Evaluation metrics --eval mIoU --cfg-options test_evaluator.type=IoUMetric
Whether to use test time augmentation --aug-test --tta
Whether save the output results without perform evaluation --format-only --cfg-options test_evaluator.format_only=True
## Configuration file ### Model settings No changes in `model.backbone`, `model.neck`, `model.decode_head` and `model.losses` fields. Add `model.data_preprocessor` field to configure the `DataPreProcessor`, including: - `mean` (Sequence, optional): The pixel mean of R, G, B channels. Defaults to None. - `std` (Sequence, optional): The pixel standard deviation of R, G, B channels. Defaults to None. - `size` (Sequence, optional): Fixed padding size. - `size_divisor` (int, optional): The divisor of padded size. - `seg_pad_val` (float, optional): Padding value of segmentation map. Default: 255. - `padding_mode` (str): Type of padding. Default: 'constant'. - constant: pads with a constant value, this value is specified with pad_val. - `bgr_to_rgb` (bool): whether to convert image from BGR to RGB.Defaults to False. - `rgb_to_bgr` (bool): whether to convert image from RGB to BGR. Defaults to False. **Note:** Please refer [models documentation](../advanced_guides/models.md) for more details. ### Dataset settings Changes in **data**: The original `data` field is split to `train_dataloader`, `val_dataloader` and `test_dataloader`. This allows us to configure them in fine-grained. For example, you can specify different sampler and batch size during training and test. The `samples_per_gpu` is renamed to `batch_size`. The `workers_per_gpu` is renamed to `num_workers`.
Original ```python data = dict( samples_per_gpu=4, workers_per_gpu=4, train=dict(...), val=dict(...), test=dict(...), ) ```
New ```python train_dataloader = dict( batch_size=4, num_workers=4, dataset=dict(...), sampler=dict(type='DefaultSampler', shuffle=True) # necessary ) val_dataloader = dict( batch_size=4, num_workers=4, dataset=dict(...), sampler=dict(type='DefaultSampler', shuffle=False) # necessary ) test_dataloader = val_dataloader ```
Changes in **pipeline** - The original formatting transforms **`ToTensor`**、**`ImageToTensor`**、**`Collect`** are combined as [`PackSegInputs`](mmseg.datasets.transforms.PackSegInputs) - We don't recommend to do **`Normalize`** and **Pad** in the dataset pipeline. Please remove it from pipelines and set it in the `data_preprocessor` field. - The original **`Resize`** in MMSeg 1.x has been changed to **`RandomResize`** and the input arguments `img_scale` is renamed to `scale`, and the default value of `keep_ratio` is modified to False. - The original `test_pipeline` combines single-scale test and multi-scale test together, in MMSeg 1.x we separate it into `test_pipeline` and `tta_pipeline`. **Note:** We move some work of data transforms to the data preprocessor, like normalization, see [the documentation](package.md) for more details. train_pipeline
Original ```python train_pipeline = [ dict(type='LoadImageFromFile'), dict(type='LoadAnnotations', reduce_zero_label=True), dict(type='Resize', img_scale=(2560, 640), ratio_range=(0.5, 2.0)), dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), dict(type='RandomFlip', prob=0.5), dict(type='PhotoMetricDistortion'), dict(type='Normalize', **img_norm_cfg), dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), dict(type='DefaultFormatBundle'), dict(type='Collect', keys=['img', 'gt_semantic_seg']), ] ```
New ```python train_pipeline = [ dict(type='LoadImageFromFile'), dict(type='LoadAnnotations', reduce_zero_label=True), dict( type='RandomResize', scale=(2560, 640), ratio_range=(0.5, 2.0), keep_ratio=True), dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), dict(type='RandomFlip', prob=0.5), dict(type='PhotoMetricDistortion'), dict(type='PackSegInputs') ] ```
test_pipeline
Original ```python test_pipeline = [ dict(type='LoadImageFromFile'), dict( type='MultiScaleFlipAug', img_scale=(2560, 640), # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], flip=False, transforms=[ dict(type='Resize', keep_ratio=True), dict(type='RandomFlip'), dict(type='Normalize', **img_norm_cfg), dict(type='ImageToTensor', keys=['img']), dict(type='Collect', keys=['img']), ]) ] ```
New ```python test_pipeline = [ dict(type='LoadImageFromFile'), dict(type='Resize', scale=(2560, 640), keep_ratio=True), dict(type='LoadAnnotations', reduce_zero_label=True), dict(type='PackSegInputs') ] img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] tta_pipeline = [ dict(type='LoadImageFromFile', backend_args=None), dict( type='TestTimeAug', transforms=[ [ dict(type='Resize', scale_factor=r, keep_ratio=True) for r in img_ratios ], [ dict(type='RandomFlip', prob=0., direction='horizontal'), dict(type='RandomFlip', prob=1., direction='horizontal') ], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')] ]) ] ```
Changes in **`evaluation`**: - The **`evaluation`** field is split to `val_evaluator` and `test_evaluator`. And it won't support `interval` and `save_best` arguments. The `interval` is moved to `train_cfg.val_interval`, and the `save_best` is moved to `default_hooks.checkpoint.save_best`. `pre_eval` has been removed. - `'mIoU'` has been changed to `'IoUMetric'`.
Original ```python evaluation = dict(interval=2000, metric='mIoU', pre_eval=True) ```
New ```python val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU']) test_evaluator = val_evaluator ```
### Optimizer and Schedule settings Changes in **`optimizer`** and **`optimizer_config`**: - Now we use `optim_wrapper` field to specify all configuration about the optimization process. And the `optimizer` is a sub field of `optim_wrapper` now. - `paramwise_cfg` is also a sub field of `optim_wrapper`, instead of `optimizer`. - `optimizer_config` is removed now, and all configurations of it are moved to `optim_wrapper`. - `grad_clip` is renamed to `clip_grad`.
Original ```python optimizer = dict(type='AdamW', lr=0.0001, weight_decay=0.0005) optimizer_config = dict(grad_clip=dict(max_norm=1, norm_type=2)) ```
New ```python optim_wrapper = dict( type='OptimWrapper', optimizer=dict(type='AdamW', lr=0.0001, weight_decay=0.0005), clip_grad=dict(max_norm=1, norm_type=2)) ```
Changes in **`lr_config`**: - The `lr_config` field is removed and we use new `param_scheduler` to replace it. - The `warmup` related arguments are removed, since we use schedulers combination to implement this functionality. The new schedulers combination mechanism is very flexible, and you can use it to design many kinds of learning rate / momentum curves. See [the tutorial](TODO) for more details.
Original ```python lr_config = dict( policy='poly', warmup='linear', warmup_iters=1500, warmup_ratio=1e-6, power=1.0, min_lr=0.0, by_epoch=False) ```
New ```python param_scheduler = [ dict( type='LinearLR', start_factor=1e-6, by_epoch=False, begin=0, end=1500), dict( type='PolyLR', power=1.0, begin=1500, end=160000, eta_min=0.0, by_epoch=False, ) ] ```
Changes in **`runner`**: Most configuration in the original `runner` field is moved to `train_cfg`, `val_cfg` and `test_cfg`, which configure the loop in training, validation and test.
Original ```python runner = dict(type='IterBasedRunner', max_iters=20000) ```
New ```python # The `val_interval` is the original `evaluation.interval`. train_cfg = dict(type='IterBasedTrainLoop', max_iters=20000, val_interval=2000) val_cfg = dict(type='ValLoop') # Use the default validation loop. test_cfg = dict(type='TestLoop') # Use the default test loop. ```
In fact, in OpenMMLab 2.0, we introduced `Loop` to control the behaviors in training, validation and test. The functionalities of `Runner` are also changed. You can find more details of [runner tutorial](https://github.com/open-mmlab/mmengine/blob/main/docs/en/design/runner.md) in [MMEngine](https://github.com/open-mmlab/mmengine/). ### Runtime settings Changes in **`checkpoint_config`** and **`log_config`**: The `checkpoint_config` are moved to `default_hooks.checkpoint` and the `log_config` are moved to `default_hooks.logger`. And we move many hooks settings from the script code to the `default_hooks` field in the runtime configuration. ```python default_hooks = dict( # record the time of every iterations. timer=dict(type='IterTimerHook'), # print log every 50 iterations. logger=dict(type='LoggerHook', interval=50, log_metric_by_epoch=False), # enable the parameter scheduler. param_scheduler=dict(type='ParamSchedulerHook'), # save checkpoint every 2000 iterations. checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=2000), # set sampler seed in distributed environment. sampler_seed=dict(type='DistSamplerSeedHook'), # validation results visualization. visualization=dict(type='SegVisualizationHook')) ``` In addition, we split the original logger to logger and visualizer. The logger is used to record information and the visualizer is used to show the logger in different backends, like terminal and TensorBoard.
Original ```python log_config = dict( interval=100, hooks=[ dict(type='TextLoggerHook'), dict(type='TensorboardLoggerHook'), ]) ```
New ```python default_hooks = dict( ... logger=dict(type='LoggerHook', interval=100), ) vis_backends = [dict(type='LocalVisBackend'), dict(type='TensorboardVisBackend')] visualizer = dict( type='SegLocalVisualizer', vis_backends=vis_backends, name='visualizer') ```
Changes in **`load_from`** and **`resume_from`**: - The `resume_from` is removed. And we use `resume` and `load_from` to replace it. - If `resume=True` and `load_from` is **not None**, resume training from the checkpoint in `load_from`. - If `resume=True` and `load_from` is **None**, try to resume from the latest checkpoint in the work directory. - If `resume=False` and `load_from` is **not None**, only load the checkpoint, not resume training. - If `resume=False` and `load_from` is **None**, do not load nor resume. Changes in **`dist_params`**: The `dist_params` field is a sub field of `env_cfg` now. And there are some new configurations in the `env_cfg`. ```python env_cfg = dict( # whether to enable cudnn benchmark cudnn_benchmark=False, # set multi process parameters mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), # set distributed parameters dist_cfg=dict(backend='nccl'), ) ``` Changes in **`workflow`**: `workflow` related functionalities are removed. New field **`visualizer`**: The visualizer is a new design in OpenMMLab 2.0 architecture. We use a visualizer instance in the runner to handle results & log visualization and save to different backends. See the [visualization tutorial](../user_guides/visualization.md) for more details. New field **`default_scope`**: The start point to search module for all registries. The `default_scope` in MMSegmentation is `mmseg`. See [the registry tutorial](https://github.com/open-mmlab/mmengine/blob/main/docs/en/advanced_tutorials/registry.md) for more details.