Created KITTI dataset for segmentation in autonomous driving scenario (#2730)
Note that this PR is a modified version of the withdrawn PR https://github.com/open-mmlab/mmsegmentation/pull/1748 ## Motivation In the last years, panoptic segmentation has become more into the focus in reseach. Weber et al. [[Link]](http://www.cvlibs.net/publications/Weber2021NEURIPSDATA.pdf) have published a quite nice dataset, which is in the same style like Cityscapes, but for KITTI sequences. Since Cityscapes and KITTI-STEP share the same classes and also a comparable domain (dashcam view), interesting investigations, e.g. about relations in the domain e.t.c. can be done. Note that KITTI-STEP provices panoptic segmentation annotations which are out of scope for mmsegmentation. ## Modification Mostly, I added the new dataset and dataset preparation file. To simplify the first usage of the new dataset, I also added configs for the dataset, segformer and deeplabv3plus. ## BC-breaking (Optional) No BC-breaking ## Use cases (Optional) Researchers want to test their new methods, e.g. for interpretable AI in the context of semantic segmentation. They want to show, that their method is reproducible on comparable datasets. Thus, they can compare Cityscapes and KITTI-STEP. --------- Co-authored-by: CSH <40987381+csatsurnh@users.noreply.github.com> Co-authored-by: csatsurnh <cshan1995@126.com> Co-authored-by: 谢昕辰 <xiexinch@outlook.com>master
parent
83e7cc24ea
commit
a85675c16f
|
@ -0,0 +1,97 @@
|
|||
# KITTI STEP Dataset
|
||||
|
||||
Support **`KITTI STEP Dataset`**
|
||||
|
||||
## Description
|
||||
|
||||
Author: TimoK93
|
||||
|
||||
This project implements **`KITTI STEP Dataset`**
|
||||
|
||||
### Dataset preparing
|
||||
|
||||
After registration, the data images could be download from [KITTI-STEP](http://www.cvlibs.net/datasets/kitti/eval_step.php)
|
||||
|
||||
You may need to follow the following structure for dataset preparation after downloading KITTI-STEP dataset.
|
||||
|
||||
```
|
||||
mmsegmentation
|
||||
├── mmseg
|
||||
├── tools
|
||||
├── configs
|
||||
├── data
|
||||
│ ├── kitti_step
|
||||
│ │ ├── testing
|
||||
│ │ ├── training
|
||||
│ │ ├── panoptic_maps
|
||||
```
|
||||
|
||||
Run the preparation script to generate label files and kitti subsets by executing
|
||||
|
||||
```shell
|
||||
python tools/convert_datasets/kitti_step.py /path/to/kitti_step
|
||||
```
|
||||
|
||||
After executing the script, your directory should look like
|
||||
|
||||
```
|
||||
mmsegmentation
|
||||
├── mmseg
|
||||
├── tools
|
||||
├── configs
|
||||
├── data
|
||||
│ ├── kitti_step
|
||||
│ │ ├── testing
|
||||
│ │ ├── training
|
||||
│ │ ├── panoptic_maps
|
||||
│ │ ├── training_openmmlab
|
||||
│ │ ├── panoptic_maps_openmmlab
|
||||
```
|
||||
|
||||
### Training commands
|
||||
|
||||
```bash
|
||||
# Dataset train commands
|
||||
# at `mmsegmentation` folder
|
||||
bash tools/dist_train.sh projects/kitti_step_dataset/configs/segformer/segformer_mit-b5_368x368_160k_kittistep.py 8
|
||||
```
|
||||
|
||||
### Testing commands
|
||||
|
||||
```bash
|
||||
mim test mmsegmentation projects/kitti_step_dataset/configs/segformer/segformer_mit-b5_368x368_160k_kittistep.py --work-dir work_dirs/segformer_mit-b5_368x368_160k_kittistep --checkpoint ${CHECKPOINT_PATH} --eval mIoU
|
||||
```
|
||||
|
||||
| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | model | log |
|
||||
| --------- | -------- | --------- | ------: | -------- | -------------- | ----: | ------------: | ---------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| Segformer | MIT-B5 | 368x368 | 160000 | - | - | 65.05 | - | [config](configs/segformer/segformer_mit-b5_368x368_160k_kittistep.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b5_368x368_160k_kittistep/segformer_mit-b5_368x368_160k_kittistep_20230506_103002-20797496.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/segformer/segformer_mit-b5_368x368_160k_kittistep/segformer_mit-b5_368x368_160k_kittistep_20230506_103002.log.json) |
|
||||
|
||||
## Checklist
|
||||
|
||||
- [x] Milestone 1: PR-ready, and acceptable to be one of the `projects/`.
|
||||
|
||||
- [x] Finish the code
|
||||
|
||||
- [x] Basic docstrings & proper citation
|
||||
|
||||
- [ ] Test-time correctness
|
||||
|
||||
- [x] A full README
|
||||
|
||||
- [ ] Milestone 2: Indicates a successful model implementation.
|
||||
|
||||
- [ ] Training-time correctness
|
||||
|
||||
- [ ] Milestone 3: Good to be a part of our core package!
|
||||
|
||||
- [ ] Type hints and docstrings
|
||||
|
||||
- [ ] Unit tests
|
||||
|
||||
- [ ] Code polishing
|
||||
|
||||
- [ ] Metafile.yml
|
||||
|
||||
- [ ] Move your modules into the core package following the codebase's file hierarchy structure.
|
||||
|
||||
- [ ] Refactor your modules into the core package following the codebase's file hierarchy structure.
|
|
@ -0,0 +1,54 @@
|
|||
# dataset settings
|
||||
dataset_type = 'KITTISTEPDataset'
|
||||
data_root = 'data/kitti_step/'
|
||||
img_norm_cfg = dict(
|
||||
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
|
||||
crop_size = (368, 368)
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='LoadAnnotations'),
|
||||
dict(type='Resize', img_scale=(1242, 375), ratio_range=(0.5, 2.0)),
|
||||
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
|
||||
dict(type='RandomFlip', prob=0.5),
|
||||
dict(type='PhotoMetricDistortion'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
|
||||
dict(type='DefaultFormatBundle'),
|
||||
dict(type='Collect', keys=['img', 'gt_semantic_seg']),
|
||||
]
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='MultiScaleFlipAug',
|
||||
img_scale=(1242, 375),
|
||||
img_ratios=[1.0],
|
||||
flip=False,
|
||||
transforms=[
|
||||
dict(type='Resize', keep_ratio=True),
|
||||
dict(type='RandomFlip'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='ImageToTensor', keys=['img']),
|
||||
dict(type='Collect', keys=['img']),
|
||||
])
|
||||
]
|
||||
data = dict(
|
||||
samples_per_gpu=2,
|
||||
workers_per_gpu=2,
|
||||
train=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
img_dir='training_openmmlab/image_02/train',
|
||||
ann_dir='panoptic_maps_openmmlab/train',
|
||||
pipeline=train_pipeline),
|
||||
val=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
img_dir='training_openmmlab/image_02/val',
|
||||
ann_dir='panoptic_maps_openmmlab/val',
|
||||
pipeline=test_pipeline),
|
||||
test=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
img_dir='training_openmmlab/image_02/val',
|
||||
ann_dir='panoptic_maps_openmmlab/val',
|
||||
pipeline=test_pipeline))
|
|
@ -0,0 +1,10 @@
|
|||
_base_ = [
|
||||
'../../../../configs/_base_/models/deeplabv3plus_r50-d8.py',
|
||||
'../_base_/datasets/kittistep.py',
|
||||
'../../../../configs/_base_/default_runtime.py',
|
||||
'../../../../configs/_base_/schedules/schedule_80k.py'
|
||||
]
|
||||
model = dict(
|
||||
decode_head=dict(align_corners=True),
|
||||
auxiliary_head=dict(align_corners=True),
|
||||
test_cfg=dict(mode='slide', crop_size=(769, 769), stride=(513, 513)))
|
|
@ -0,0 +1,38 @@
|
|||
_base_ = [
|
||||
'../../../../configs/_base_/models/segformer_mit-b0.py',
|
||||
'../_base_/datasets/kittistep.py',
|
||||
'../../../../configs/_base_/default_runtime.py',
|
||||
'../../../../configs/_base_/schedules/schedule_160k.py'
|
||||
]
|
||||
|
||||
checkpoint = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/segformer/mit_b0_20220624-7e0fe6dd.pth' # noqa
|
||||
|
||||
model = dict(
|
||||
backbone=dict(init_cfg=dict(type='Pretrained', checkpoint=checkpoint)),
|
||||
test_cfg=dict(mode='slide', crop_size=(1024, 1024), stride=(768, 768)))
|
||||
|
||||
# optimizer
|
||||
optimizer = dict(
|
||||
_delete_=True,
|
||||
type='AdamW',
|
||||
lr=0.00006,
|
||||
betas=(0.9, 0.999),
|
||||
weight_decay=0.01,
|
||||
paramwise_cfg=dict(
|
||||
custom_keys={
|
||||
'pos_block': dict(decay_mult=0.),
|
||||
'norm': dict(decay_mult=0.),
|
||||
'head': dict(lr_mult=10.)
|
||||
}))
|
||||
|
||||
lr_config = dict(
|
||||
_delete_=True,
|
||||
policy='poly',
|
||||
warmup='linear',
|
||||
warmup_iters=1500,
|
||||
warmup_ratio=1e-6,
|
||||
power=1.0,
|
||||
min_lr=0.0,
|
||||
by_epoch=False)
|
||||
|
||||
data = dict(samples_per_gpu=2, workers_per_gpu=2)
|
|
@ -0,0 +1,9 @@
|
|||
_base_ = ['./segformer_mit-b0_368x368_160k_kittistep.py']
|
||||
|
||||
checkpoint = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/segformer/mit_b5_20220624-658746d9.pth' # noqa
|
||||
model = dict(
|
||||
backbone=dict(
|
||||
init_cfg=dict(type='Pretrained', checkpoint=checkpoint),
|
||||
embed_dims=64,
|
||||
num_layers=[3, 6, 40, 3]),
|
||||
decode_head=dict(in_channels=[64, 128, 320, 512]))
|
|
@ -0,0 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .kitti_step import KITTISTEPDataset
|
||||
|
||||
__all__ = [
|
||||
'KITTISTEPDataset',
|
||||
]
|
|
@ -0,0 +1,15 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmseg.datasets.builder import DATASETS
|
||||
from mmseg.datasets.cityscapes import CityscapesDataset
|
||||
|
||||
|
||||
@DATASETS.register_module()
|
||||
class KITTISTEPDataset(CityscapesDataset):
|
||||
"""KITTI-STEP dataset."""
|
||||
|
||||
def __init__(self,
|
||||
img_suffix='.png',
|
||||
seg_map_suffix='_labelTrainIds.png',
|
||||
**kwargs):
|
||||
super(KITTISTEPDataset, self).__init__(
|
||||
img_suffix=img_suffix, seg_map_suffix=seg_map_suffix, **kwargs)
|
|
@ -0,0 +1,77 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import argparse
|
||||
import os
|
||||
import os.path as osp
|
||||
import shutil
|
||||
|
||||
import cv2
|
||||
import mmcv
|
||||
|
||||
|
||||
def kitti_to_train_ids(input):
|
||||
src, gt_dir, new_gt_dir = input
|
||||
label_file = src.replace('.png',
|
||||
'_labelTrainIds.png').replace(gt_dir, new_gt_dir)
|
||||
img = cv2.imread(src)
|
||||
dirname = os.path.dirname(label_file)
|
||||
os.makedirs(dirname, exist_ok=True)
|
||||
sem_seg = img[:, :, 2]
|
||||
cv2.imwrite(label_file, sem_seg)
|
||||
|
||||
|
||||
def copy_file(input):
|
||||
src, dst = input
|
||||
if not osp.exists(dst):
|
||||
os.makedirs(osp.dirname(dst), exist_ok=True)
|
||||
shutil.copyfile(src, dst)
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Convert KITTI-STEP annotations to TrainIds')
|
||||
parser.add_argument('kitti_path', help='kitti step data path')
|
||||
parser.add_argument('--gt-dir', default='panoptic_maps', type=str)
|
||||
parser.add_argument('-o', '--out-dir', help='output path')
|
||||
parser.add_argument(
|
||||
'--nproc', default=1, type=int, help='number of process')
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
kitti_path = args.kitti_path
|
||||
out_dir = args.out_dir if args.out_dir else kitti_path
|
||||
mmcv.mkdir_or_exist(out_dir)
|
||||
|
||||
gt_dir = osp.join(kitti_path, args.gt_dir)
|
||||
|
||||
ann_files = []
|
||||
for poly in mmcv.scandir(gt_dir, '.png', recursive=True):
|
||||
poly_file = osp.join(gt_dir, poly)
|
||||
ann_files.append([poly_file, args.gt_dir, args.gt_dir + '_openmmlab'])
|
||||
|
||||
if args.nproc > 1:
|
||||
mmcv.track_parallel_progress(kitti_to_train_ids, ann_files, args.nproc)
|
||||
else:
|
||||
mmcv.track_progress(kitti_to_train_ids, ann_files)
|
||||
|
||||
copy_files = []
|
||||
for f in mmcv.scandir(gt_dir, '.png', recursive=True):
|
||||
original_f = osp.join(gt_dir, f).replace(args.gt_dir + '/train',
|
||||
'training/image_02')
|
||||
new_f = osp.join(gt_dir, f).replace(args.gt_dir,
|
||||
'training_openmmlab/image_02')
|
||||
original_f = original_f.replace(args.gt_dir + '/val',
|
||||
'training/image_02')
|
||||
new_f = new_f.replace(args.gt_dir, 'training_openmmlab/image_02')
|
||||
copy_files.append([original_f, new_f])
|
||||
|
||||
if args.nproc > 1:
|
||||
mmcv.track_parallel_progress(copy_file, copy_files, args.nproc)
|
||||
else:
|
||||
mmcv.track_progress(copy_file, copy_files)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
|
@ -4,3 +4,4 @@ myst-parser
|
|||
sphinx==4.0.2
|
||||
sphinx_copybutton
|
||||
sphinx_markdown_tables
|
||||
urllib3<2.0.0
|
||||
|
|
Loading…
Reference in New Issue