# Migration from MMSegmentation 0.x

## Introduction

This guide describes the fundamental differences between MMSegmentation 0.x and MMSegmentation 1.x in terms of behaviors and APIs, and how these relate to your migration journey.

## New dependencies

MMSegmentation 1.x depends on some new packages. You can prepare a new clean environment and install again according to the [installation tutorial](../get_started.md), or install the packages below manually.

1. [MMEngine](https://github.com/open-mmlab/mmengine): MMEngine is the core of the OpenMMLab 2.0 architecture, and many components unrelated to computer vision have been split out of MMCV into MMEngine.
2. [MMCV](https://github.com/open-mmlab/mmcv): The computer vision package of OpenMMLab. This is not a new dependency, but you need to upgrade it to version **2.0.0** or above.
3. [MMClassification](https://github.com/open-mmlab/mmclassification) (optional): The image classification toolbox and benchmark of OpenMMLab. This is not a new dependency, but you need to upgrade it to version **1.0.0rc6**.
4. [MMDetection](https://github.com/open-mmlab/mmdetection) (optional): The object detection toolbox and benchmark of OpenMMLab. This is not a new dependency, but you need to upgrade it to version **3.0.0** or above.

## Train launch

The main improvement of OpenMMLab 2.0 is the release of MMEngine, which provides a universal and powerful runner with unified interfaces for launching training jobs.

Compared with MMSeg 0.x, MMSeg 1.x provides fewer command line arguments in `tools/train.py`. The table below maps the 0.x arguments to their 1.x equivalents:
| Function | Original | New |
| :--- | :--- | :--- |
| Loading pre-trained checkpoint | `--load_from=$CHECKPOINT` | `--cfg-options load_from=$CHECKPOINT` |
| Resuming training from a specific checkpoint | `--resume-from=$CHECKPOINT` | `--resume=$CHECKPOINT` |
| Resuming training from the latest checkpoint | `--auto-resume` | `--resume='auto'` |
| Whether or not to evaluate the checkpoint during training | `--no-validate` | `--cfg-options val_cfg=None val_dataloader=None val_evaluator=None` |
| Training device assignment | `--gpu-id=$DEVICE_ID` | - |
| Whether or not to set different seeds for different ranks | `--diff-seed` | `--cfg-options randomness.diff_rank_seed=True` |
| Whether to set deterministic options for the CUDNN backend | `--deterministic` | `--cfg-options randomness.deterministic=True` |
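Several of the 1.x replacements above are `--cfg-options` overrides rather than dedicated flags. The sketch below illustrates, in plain Python, what a dotted key such as `randomness.deterministic=True` does to a nested config. It is a simplified stand-in written for this guide: `parse_value` and `apply_cfg_options` are hypothetical names, and MMEngine's own handling (through `DictAction` and `Config.merge_from_dict`) accepts more value syntaxes than `ast.literal_eval` does.

```python
import ast
from typing import Any, Dict, List


def parse_value(text: str) -> Any:
    """Read 'True', 'None', '1e-6' or "['mIoU']" as Python literals and
    keep anything else (e.g. 'IoUMetric') as a plain string."""
    try:
        return ast.literal_eval(text)
    except (ValueError, SyntaxError):
        return text


def apply_cfg_options(cfg: Dict[str, Any], options: List[str]) -> Dict[str, Any]:
    """Apply `dotted.key=value` overrides to a nested config dict in place."""
    for option in options:
        dotted_key, _, raw_value = option.partition('=')
        *parents, leaf = dotted_key.split('.')
        node = cfg
        for key in parents:
            # Create missing intermediate dicts so new keys can be added.
            node = node.setdefault(key, {})
        node[leaf] = parse_value(raw_value)
    return cfg


cfg = {'randomness': {'seed': 0}, 'val_cfg': {'type': 'ValLoop'}}
apply_cfg_options(cfg, ['randomness.deterministic=True', 'val_cfg=None'])
print(cfg)
# {'randomness': {'seed': 0, 'deterministic': True}, 'val_cfg': None}
```

Each override only touches the key it names, which is why a single `--cfg-options` argument can replace several 0.x flags at once.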
## Test launch

The test arguments of `tools/test.py` changed as follows:

| Function | 0.x | 1.x |
| :--- | :--- | :--- |
| Evaluation metrics | `--eval mIoU` | `--cfg-options test_evaluator.type=IoUMetric` |
| Whether to use test time augmentation | `--aug-test` | `--tta` |
| Whether to save the output results without performing evaluation | `--format-only` | `--cfg-options test_evaluator.format_only=True` |
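`IoUMetric` reports the per-class intersection over union (IoU) and its mean over classes (mIoU). The pure-Python sketch below shows that arithmetic on flattened label maps. It is illustrative only: the function names are hypothetical, predictions are assumed to already be class indices, and label 255 is treated as the ignored index (the usual MMSegmentation default). Classes whose union is empty get NaN and are skipped by the mean.

```python
import math
from typing import List, Sequence


def per_class_iou(pred: Sequence[int], gt: Sequence[int], num_classes: int,
                  ignore_index: int = 255) -> List[float]:
    """IoU for each class from flattened prediction / ground-truth labels."""
    intersect = [0] * num_classes
    pred_area = [0] * num_classes
    gt_area = [0] * num_classes
    for p, g in zip(pred, gt):
        if g == ignore_index:
            continue  # ignored pixels count towards no class
        pred_area[p] += 1
        gt_area[g] += 1
        if p == g:
            intersect[g] += 1
    ious = []
    for c in range(num_classes):
        union = pred_area[c] + gt_area[c] - intersect[c]
        ious.append(intersect[c] / union if union > 0 else math.nan)
    return ious


def mean_iou(ious: List[float]) -> float:
    """Average over classes, skipping NaN entries (like numpy.nanmean)."""
    valid = [v for v in ious if not math.isnan(v)]
    return sum(valid) / len(valid) if valid else math.nan


ious = per_class_iou(pred=[0, 1, 1, 1, 0], gt=[0, 0, 1, 1, 255], num_classes=3)
print(ious)            # [0.5, 0.666..., nan]
print(mean_iou(ious))  # 0.5833...
```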
## Configuration file

### Changes in `data`

Original:

```python
data = dict(
    samples_per_gpu=4,
    workers_per_gpu=4,
    train=dict(...),
    val=dict(...),
    test=dict(...),
)
```

New:

```python
train_dataloader = dict(
    batch_size=4,
    num_workers=4,
    dataset=dict(...),
    sampler=dict(type='DefaultSampler', shuffle=True)  # necessary
)

val_dataloader = dict(
    batch_size=4,
    num_workers=4,
    dataset=dict(...),
    sampler=dict(type='DefaultSampler', shuffle=False)  # necessary
)

test_dataloader = val_dataloader
```
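The mapping from the 0.x `data` dict to the 1.x dataloaders is mechanical. The helper below is hypothetical (not part of MMSegmentation) and assumes that the validation and test loaders reuse `samples_per_gpu` and `workers_per_gpu`; 0.x configs that override loader settings per split would need extra handling.

```python
import copy
from typing import Any, Dict


def convert_data_cfg(data: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
    """Hypothetical helper: split a 0.x `data` dict into 1.x dataloaders."""

    def make_loader(dataset_cfg: Dict[str, Any], shuffle: bool) -> Dict[str, Any]:
        return dict(
            batch_size=data['samples_per_gpu'],   # per-GPU batch size
            num_workers=data['workers_per_gpu'],  # per-GPU loader workers
            dataset=copy.deepcopy(dataset_cfg),
            # 1.x needs an explicit sampler; shuffle only for training.
            sampler=dict(type='DefaultSampler', shuffle=shuffle))

    return dict(
        train_dataloader=make_loader(data['train'], shuffle=True),
        val_dataloader=make_loader(data['val'], shuffle=False),
        test_dataloader=make_loader(data['test'], shuffle=False))


old_data = dict(
    samples_per_gpu=4, workers_per_gpu=4,
    train=dict(type='ADE20KDataset'),  # placeholder dataset configs
    val=dict(type='ADE20KDataset'),
    test=dict(type='ADE20KDataset'))
loaders = convert_data_cfg(old_data)
print(loaders['train_dataloader']['sampler'])
# {'type': 'DefaultSampler', 'shuffle': True}
```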
### Changes in `pipeline`

Original:

```python
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', reduce_zero_label=True),
    dict(type='Resize', img_scale=(2560, 640), ratio_range=(0.5, 2.0)),
    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PhotoMetricDistortion'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]
```

New:

```python
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', reduce_zero_label=True),
    dict(
        type='RandomResize',
        scale=(2560, 640),
        ratio_range=(0.5, 2.0),
        keep_ratio=True),
    dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
    dict(type='RandomFlip', prob=0.5),
    dict(type='PhotoMetricDistortion'),
    # Replaces DefaultFormatBundle and Collect.
    dict(type='PackSegInputs')
]
```
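`Normalize` and `Pad` no longer appear in the 1.x train pipeline: in MMSegmentation 1.x, normalization and padding are done batch-wise by the model's `data_preprocessor` (`SegDataPreProcessor`). A sketch of the corresponding config fragment follows. The `crop_size` value and the mean/std numbers are assumptions (the ImageNet RGB statistics commonly used in MMSegmentation configs), so take the real values from your own `img_norm_cfg` and `crop_size`.

```python
# Assumed values; copy them from your own 0.x `img_norm_cfg` / `crop_size`.
crop_size = (640, 640)
data_preprocessor = dict(
    type='SegDataPreProcessor',
    mean=[123.675, 116.28, 103.53],  # was Normalize(mean=...)
    std=[58.395, 57.12, 57.375],     # was Normalize(std=...)
    bgr_to_rgb=True,                 # was Normalize(to_rgb=True)
    pad_val=0,                       # was Pad(pad_val=0)
    seg_pad_val=255,                 # was Pad(seg_pad_val=255)
    size=crop_size)                  # was Pad(size=crop_size)

model = dict(
    type='EncoderDecoder',
    data_preprocessor=data_preprocessor,
    # backbone, decode_head, ... stay as in your existing config
)
```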
Original:

```python
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(2560, 640),
        # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
        flip=False,
        transforms=[
            dict(type='Resize', keep_ratio=True),
            dict(type='RandomFlip'),
            dict(type='Normalize', **img_norm_cfg),
            dict(type='ImageToTensor', keys=['img']),
            dict(type='Collect', keys=['img']),
        ])
]
```

New:

```python
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', scale=(2560, 640), keep_ratio=True),
    # Load annotations after Resize so the ground truth keeps its original size.
    dict(type='LoadAnnotations', reduce_zero_label=True),
    dict(type='PackSegInputs')
]

img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
    dict(type='LoadImageFromFile', backend_args=None),
    dict(
        type='TestTimeAug',
        transforms=[
            [
                dict(type='Resize', scale_factor=r, keep_ratio=True)
                for r in img_ratios
            ],
            [
                dict(type='RandomFlip', prob=0., direction='horizontal'),
                dict(type='RandomFlip', prob=1., direction='horizontal')
            ],
            [dict(type='LoadAnnotations')],
            [dict(type='PackSegInputs')]
        ])
]
```
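`TestTimeAug` combines its inner transform lists as a Cartesian product: each sub-pipeline takes one transform from every list. The plain-Python sketch below only mirrors the `transforms` of the `tta_pipeline` above (it does not call MMCV) and counts how many augmented views each test image produces.

```python
import itertools

img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
# Same nested lists as `tta_pipeline`; the last two lists hold one entry each.
transform_groups = [
    [dict(type='Resize', scale_factor=r, keep_ratio=True) for r in img_ratios],
    [dict(type='RandomFlip', prob=0., direction='horizontal'),
     dict(type='RandomFlip', prob=1., direction='horizontal')],
    [dict(type='LoadAnnotations')],
    [dict(type='PackSegInputs')],
]

# One sub-pipeline per combination: pick one transform from every group.
sub_pipelines = [list(combo) for combo in itertools.product(*transform_groups)]
print(len(sub_pipelines))  # 6 scales x 2 flip settings = 12
```

This is why `img_ratios` and the two `RandomFlip` entries are listed separately: the product replaces the `img_ratios` and `flip` arguments of the 0.x `MultiScaleFlipAug`.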
### Changes in `evaluation`

Original:

```python
evaluation = dict(interval=2000, metric='mIoU', pre_eval=True)
```

New:

```python
val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
test_evaluator = val_evaluator
```
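A hypothetical helper (not an MMSegmentation API) sketching how a 0.x `evaluation` dict splits into a 1.x `val_evaluator` and the `val_interval` that now lives in `train_cfg`. It assumes `pre_eval` can simply be dropped, since `IoUMetric` already collects per-sample statistics as batches are processed.

```python
from typing import Any, Dict, Tuple


def convert_evaluation_cfg(
        evaluation: Dict[str, Any]) -> Tuple[Dict[str, Any], int]:
    """Hypothetical helper: return (val_evaluator, val_interval)."""
    metric = evaluation.get('metric', 'mIoU')
    # 0.x accepted a single metric name or a list of names.
    metrics = [metric] if isinstance(metric, str) else list(metric)
    # `pre_eval` is dropped (assumption): IoUMetric processes batches itself.
    val_evaluator = dict(type='IoUMetric', iou_metrics=metrics)
    return val_evaluator, evaluation['interval']


val_evaluator, val_interval = convert_evaluation_cfg(
    dict(interval=2000, metric='mIoU', pre_eval=True))
print(val_evaluator, val_interval)
# {'type': 'IoUMetric', 'iou_metrics': ['mIoU']} 2000
```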
### Changes in `optimizer` and `optimizer_config`

Original:

```python
optimizer = dict(type='AdamW', lr=0.0001, weight_decay=0.0005)
optimizer_config = dict(grad_clip=dict(max_norm=1, norm_type=2))
```

New:

```python
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(type='AdamW', lr=0.0001, weight_decay=0.0005),
    clip_grad=dict(max_norm=1, norm_type=2))
```
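With `clip_grad` set, MMEngine's `OptimWrapper` clips gradients by their global norm before each optimizer step, delegating to PyTorch's `torch.nn.utils.clip_grad_norm_`. The list-based sketch below reproduces that rule for illustration only; `clip_grad_norm` here is a hypothetical stand-in, not the PyTorch function. If the total norm exceeds `max_norm`, every gradient is scaled by `max_norm / (total_norm + 1e-6)`.

```python
from typing import List


def clip_grad_norm(grads: List[List[float]], max_norm: float,
                   norm_type: float = 2.0) -> float:
    """Scale gradients in place so their global norm is at most `max_norm`.

    Returns the total norm measured before clipping, as PyTorch does.
    """
    total_norm = sum(
        abs(g) ** norm_type for param in grads for g in param) ** (1.0 / norm_type)
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1.0:  # only shrink, never enlarge, gradients
        for param in grads:
            for i, g in enumerate(param):
                param[i] = g * clip_coef
    return total_norm


grads = [[3.0], [4.0]]  # two "parameters" with one gradient entry each
print(clip_grad_norm(grads, max_norm=1, norm_type=2))  # 5.0
print(grads)  # roughly [[0.6], [0.8]]
```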
### Changes in `lr_config`

Original:

```python
lr_config = dict(
    policy='poly',
    warmup='linear',
    warmup_iters=1500,
    warmup_ratio=1e-6,
    power=1.0,
    min_lr=0.0,
    by_epoch=False)
```

New:

```python
param_scheduler = [
    dict(
        type='LinearLR', start_factor=1e-6, by_epoch=False, begin=0, end=1500),
    dict(
        type='PolyLR',
        power=1.0,
        begin=1500,
        end=160000,
        eta_min=0.0,
        by_epoch=False,
    )
]
```
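The new `param_scheduler` chains two schedulers: `LinearLR` warms the learning rate up over iterations 0 to 1500, then `PolyLR` decays it until iteration 160000. The function below is an approximate closed-form sketch of that curve, assuming the base learning rate `1e-4` from the optimizer example; MMEngine updates the value step by step, so the exact numbers at the boundaries may differ by an iteration.

```python
def lr_at(step: int, base_lr: float = 1e-4) -> float:
    """Approximate learning rate at `step` under the param_scheduler above."""
    warmup_end, poly_end = 1500, 160000  # `end` of LinearLR / PolyLR
    start_factor, power, eta_min = 1e-6, 1.0, 0.0
    if step < warmup_end:
        # LinearLR: the factor rises linearly from start_factor to 1.0.
        factor = start_factor + (1.0 - start_factor) * step / warmup_end
        return base_lr * factor
    if step < poly_end:
        # PolyLR: decays from base_lr towards eta_min over [1500, 160000).
        progress = (step - warmup_end) / (poly_end - warmup_end)
        return (base_lr - eta_min) * (1 - progress) ** power + eta_min
    # After `end` the schedulers stop; this sketch just returns eta_min.
    return eta_min


for step in (0, 750, 1500, 80750, 160000):
    print(step, lr_at(step))
```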
### Changes in `runner`

Original:

```python
runner = dict(type='IterBasedRunner', max_iters=20000)
```

New:

```python
# The `val_interval` is the original `evaluation.interval`.
train_cfg = dict(type='IterBasedTrainLoop', max_iters=20000, val_interval=2000)
val_cfg = dict(type='ValLoop')  # Use the default validation loop.
test_cfg = dict(type='TestLoop')  # Use the default test loop.
```
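The runner change is a rename plus the evaluation interval moving into the training loop. The hypothetical helper below (not part of MMSegmentation) sketches this for the two runner types of 0.x; for `EpochBasedRunner` it assumes the interval is counted in epochs, as it was in 0.x.

```python
from typing import Any, Dict


def convert_runner_cfg(runner: Dict[str, Any], val_interval: int) -> Dict[str, Any]:
    """Hypothetical helper: map a 0.x `runner` dict to a 1.x `train_cfg`.

    `val_interval` is the old `evaluation.interval` (iterations for
    IterBasedRunner, epochs for EpochBasedRunner).
    """
    if runner['type'] == 'IterBasedRunner':
        return dict(type='IterBasedTrainLoop', max_iters=runner['max_iters'],
                    val_interval=val_interval)
    if runner['type'] == 'EpochBasedRunner':
        return dict(type='EpochBasedTrainLoop', max_epochs=runner['max_epochs'],
                    val_interval=val_interval)
    raise ValueError(f"unsupported runner type: {runner['type']}")


train_cfg = convert_runner_cfg(dict(type='IterBasedRunner', max_iters=20000), 2000)
print(train_cfg)
# {'type': 'IterBasedTrainLoop', 'max_iters': 20000, 'val_interval': 2000}
```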
### Changes in `log_config`

Original:

```python
log_config = dict(
    interval=100,
    hooks=[
        dict(type='TextLoggerHook'),
        dict(type='TensorboardLoggerHook'),
    ])
```

New:

```python
default_hooks = dict(
    ...
    logger=dict(type='LoggerHook', interval=100),
)

vis_backends = [dict(type='LocalVisBackend'),
                dict(type='TensorboardVisBackend')]
visualizer = dict(
    type='SegLocalVisualizer', vis_backends=vis_backends, name='visualizer')
```
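A hypothetical converter (not part of MMSegmentation or MMEngine) for the logging change. The hook-to-backend table is an assumption for illustration: `TextLoggerHook` is mapped to `LocalVisBackend` as in the example above, `TensorboardLoggerHook` to `TensorboardVisBackend`, and `WandbLoggerHook` to MMEngine's `WandbVisBackend`. Only the `logger` entry of `default_hooks` is produced; other hooks are left to your existing config.

```python
from typing import Any, Dict, Tuple

# Assumed mapping from 0.x logger hooks to MMEngine visualization backends.
HOOK_TO_BACKEND = {
    'TextLoggerHook': 'LocalVisBackend',
    'TensorboardLoggerHook': 'TensorboardVisBackend',
    'WandbLoggerHook': 'WandbVisBackend',
}


def convert_log_config(
        log_config: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Return (default_hooks, visualizer) dicts for a 0.x `log_config`."""
    # The logging interval moves to the LoggerHook inside `default_hooks`.
    default_hooks = dict(
        logger=dict(type='LoggerHook', interval=log_config['interval']))
    vis_backends = []
    for hook in log_config.get('hooks', []):
        backend = HOOK_TO_BACKEND.get(hook['type'])
        if backend is None:
            raise ValueError(f"no known backend for {hook['type']}")
        vis_backends.append(dict(type=backend))
    visualizer = dict(
        type='SegLocalVisualizer', vis_backends=vis_backends, name='visualizer')
    return default_hooks, visualizer


hooks, visualizer = convert_log_config(dict(
    interval=100,
    hooks=[dict(type='TextLoggerHook'), dict(type='TensorboardLoggerHook')]))
print(hooks)
# {'logger': {'type': 'LoggerHook', 'interval': 100}}
```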