CodeCamp #1555[Feature] Support Mapillary Vistas Dataset (#2484)

## Support `Mapillary Vistas Dataset`

## Motivation

Support **`Mapillary Vistas Dataset`**.
Dataset paper link: https://ieeexplore.ieee.org/document/9878466/
Download and more information: https://www.mapillary.com/dataset/vistas
```
@InProceedings{Neuhold_2017_ICCV,
author = {Neuhold, Gerhard and Ollmann, Tobias and Rota Bulo, Samuel and Kontschieder, Peter},
title = {The Mapillary Vistas Dataset for Semantic Understanding of Street Scenes},
booktitle = {Proceedings of the IEEE International Conference on Computer Vision (ICCV)},
month = {Oct},
year = {2017}
}
```

## Modification

- Add `Mapillary_dataset` in `mmsegmentation/projects`.
- Add `configs/_base_/mapillary_v1_2.py` and `configs/_base_/mapillary_v2_0.py`.
- Add `configs/deeplabv3plus_r18-d8_4xb2-80k_mapillay-512x1024.py` to test training and testing on the Mapillary datasets.
- Add `docs/en/user_guides/2_dataset_prepare.md`, which describes how to prepare the Mapillary Vistas Dataset and its directory structure.
- Add `tools/dataset_converters/mapillary.py` to convert RGB labels to mask labels.

Co-authored-by: 谢昕辰 <xiexinch@outlook.com>
Tianlong Ai 2023-01-20 14:25:51 +08:00 committed by GitHub
parent f678a5c974
commit e394e2aa28
8 changed files with 867 additions and 0 deletions

@ -0,0 +1,85 @@
# Mapillary Vistas Dataset
Support **`Mapillary Vistas Dataset`**
## Description
Author: AI-Tianlong
This project implements **`Mapillary Vistas Dataset`**
### Dataset preparing
Prepare the `Mapillary Vistas Dataset` by following the [Mapillary Vistas Dataset Preparing Guide](https://github.com/open-mmlab/mmsegmentation/tree/dev-1.x/projects/mapillary_dataset/docs/en/user_guides/2_dataset_prepare.md).
```none
mmsegmentation
├── mmseg
├── tools
├── configs
├── data
│   ├── mapillary
│   │   ├── training
│   │   │   ├── images
│   │   │   ├── v1.2
│   │   │   │   ├── instances
│   │   │   │   ├── labels
│   │   │   │   ├── labels_mask
│   │   │   │   └── panoptic
│   │   │   ├── v2.0
│   │   │   │   ├── instances
│   │   │   │   ├── labels
│   │   │   │   ├── labels_mask
│   │   │   │   ├── panoptic
│   │   │   │   └── polygons
│   │   ├── validation
│   │   │   ├── images
│   │   │   ├── v1.2
│   │   │   │   ├── instances
│   │   │   │   ├── labels
│   │   │   │   ├── labels_mask
│   │   │   │   └── panoptic
│   │   │   ├── v2.0
│   │   │   │   ├── instances
│   │   │   │   ├── labels
│   │   │   │   ├── labels_mask
│   │   │   │   ├── panoptic
│   │   │   │   └── polygons
```
### Training commands with `deeplabv3plus_r101-d8_4xb2-240k_mapillay-512x1024.py`
```bash
# Training command for the dataset
# Run it from the `mmsegmentation` root folder
bash tools/dist_train.sh projects/mapillary_dataset/configs/deeplabv3plus_r101-d8_4xb2-240k_mapillay-512x1024.py 4
```
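The config used above trains on v1.2 labels by default; its comments indicate that switching to v2.0 means changing the base dataset file and setting `num_classes` to 124 in both heads. The fields involved can be sketched as follows. The helper name `label_version_overrides` is hypothetical and not part of this project or of mmsegmentation; the class counts and file names are taken from the config in this PR.

```python
# Class counts as stated in the project config: 66 for v1.2, 124 for v2.0.
NUM_CLASSES = {'v1.2': 66, 'v2.0': 124}


def label_version_overrides(version: str) -> dict:
    """Return the config fields that differ between label versions.

    Hypothetical helper for illustration only.
    """
    if version not in NUM_CLASSES:
        raise ValueError(f'unknown label version: {version!r}')
    tag = version.replace('.', '_')  # 'v1.2' -> 'v1_2'
    num_classes = NUM_CLASSES[version]
    return {
        '_base_': [f'./_base_/datasets/mapillary_{tag}.py'],
        'model': {
            'decode_head': {'num_classes': num_classes},
            'auxiliary_head': {'num_classes': num_classes},
        },
    }


overrides = label_version_overrides('v2.0')
```

The returned dictionary only lists what to edit by hand in the config file; it is not wired into any training entry point.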
## Checklist
- [x] Milestone 1: PR-ready, and acceptable to be one of the `projects/`.
  - [x] Finish the code
  - [x] Basic docstrings & proper citation
  - [ ] Test-time correctness
  - [x] A full README
- [ ] Milestone 2: Indicates a successful model implementation.
  - [ ] Training-time correctness
- [ ] Milestone 3: Good to be a part of our core package!
  - [ ] Type hints and docstrings
  - [ ] Unit tests
  - [ ] Code polishing
  - [ ] Metafile.yml
- [ ] Move your modules into the core package following the codebase's file hierarchy structure.
- [ ] Refactor your modules into the core package following the codebase's file hierarchy structure.

@ -0,0 +1,69 @@
# dataset settings
dataset_type = 'MapillaryDataset_v1_2'
data_root = 'data/mapillary/'
crop_size = (512, 1024)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations'),
dict(
type='RandomResize',
scale=(2048, 1024),
ratio_range=(0.5, 2.0),
keep_ratio=True),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='PackSegInputs')
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='Resize', scale=(2048, 1024), keep_ratio=True),
# add loading annotation after ``Resize`` because ground truth
# does not need to do resize data transform
dict(type='LoadAnnotations'),
dict(type='PackSegInputs')
]
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
dict(
type='TestTimeAug',
transforms=[
[
dict(type='Resize', scale_factor=r, keep_ratio=True)
for r in img_ratios
],
[
dict(type='RandomFlip', prob=0., direction='horizontal'),
dict(type='RandomFlip', prob=1., direction='horizontal')
], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
])
]
train_dataloader = dict(
batch_size=2,
num_workers=4,
persistent_workers=True,
sampler=dict(type='InfiniteSampler', shuffle=True),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(
img_path='training/images',
seg_map_path='training/v1.2/labels_mask'),
pipeline=train_pipeline))
val_dataloader = dict(
batch_size=1,
num_workers=4,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(
img_path='validation/images',
seg_map_path='validation/v1.2/labels_mask'),
pipeline=test_pipeline))
test_dataloader = val_dataloader
val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
test_evaluator = val_evaluator

@ -0,0 +1,69 @@
# dataset settings
dataset_type = 'MapillaryDataset_v2_0'
data_root = 'data/mapillary/'
crop_size = (512, 1024)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations'),
dict(
type='RandomResize',
scale=(2048, 1024),
ratio_range=(0.5, 2.0),
keep_ratio=True),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='PackSegInputs')
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='Resize', scale=(2048, 1024), keep_ratio=True),
# add loading annotation after ``Resize`` because ground truth
# does not need to do resize data transform
dict(type='LoadAnnotations'),
dict(type='PackSegInputs')
]
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
dict(
type='TestTimeAug',
transforms=[
[
dict(type='Resize', scale_factor=r, keep_ratio=True)
for r in img_ratios
],
[
dict(type='RandomFlip', prob=0., direction='horizontal'),
dict(type='RandomFlip', prob=1., direction='horizontal')
], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
])
]
train_dataloader = dict(
batch_size=2,
num_workers=4,
persistent_workers=True,
sampler=dict(type='InfiniteSampler', shuffle=True),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(
img_path='training/images',
seg_map_path='training/v2.0/labels_mask'),
pipeline=train_pipeline))
val_dataloader = dict(
batch_size=1,
num_workers=4,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(
img_path='validation/images',
seg_map_path='validation/v2.0/labels_mask'),
pipeline=test_pipeline))
test_dataloader = val_dataloader
val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
test_evaluator = val_evaluator

@ -0,0 +1,103 @@
_base_ = ['./_base_/datasets/mapillary_v1_2.py'] # v 1.2 labels
# _base_ = ['./_base_/datasets/mapillary_v2_0.py'] # v2.0 labels
custom_imports = dict(imports=[
'projects.mapillary_dataset.mmseg.datasets.mapillary_v1_2',
'projects.mapillary_dataset.mmseg.datasets.mapillary_v2_0',
])
norm_cfg = dict(type='SyncBN', requires_grad=True)
data_preprocessor = dict(
type='SegDataPreProcessor',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
bgr_to_rgb=True,
pad_val=0,
seg_pad_val=255,
size=(512, 1024))
model = dict(
type='EncoderDecoder',
data_preprocessor=data_preprocessor,
pretrained=None,
backbone=dict(
type='ResNet',
depth=101,
num_stages=4,
out_indices=(0, 1, 2, 3),
dilations=(1, 1, 2, 4),
strides=(1, 2, 1, 1),
norm_cfg=norm_cfg,
norm_eval=False,
style='pytorch',
contract_dilation=True),
decode_head=dict(
type='DepthwiseSeparableASPPHead',
in_channels=2048,
in_index=3,
channels=512,
dilations=(1, 12, 24, 36),
c1_in_channels=256,
c1_channels=48,
dropout_ratio=0.1,
num_classes=66, # v1.2
# num_classes=124, # v2.0
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
auxiliary_head=dict(
type='FCNHead',
in_channels=1024,
in_index=2,
channels=256,
num_convs=1,
concat_input=False,
dropout_ratio=0.1,
num_classes=66, # v1.2
# num_classes=124, # v2.0
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
train_cfg=dict(),
test_cfg=dict(mode='whole'))
default_scope = 'mmseg'
env_cfg = dict(
cudnn_benchmark=True,
mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
dist_cfg=dict(backend='nccl'))
vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
type='SegLocalVisualizer',
vis_backends=[dict(type='LocalVisBackend')],
name='visualizer')
log_processor = dict(by_epoch=False)
log_level = 'INFO'
load_from = None
resume = False
tta_model = dict(type='SegTTAModel')
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
optim_wrapper = dict(
type='OptimWrapper',
optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001),
clip_grad=None)
param_scheduler = [
dict(
type='PolyLR',
eta_min=0.0001,
power=0.9,
begin=0,
end=240000,
by_epoch=False)
]
train_cfg = dict(
type='IterBasedTrainLoop', max_iters=240000, val_interval=24000)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
default_hooks = dict(
timer=dict(type='IterTimerHook'),
logger=dict(type='LoggerHook', interval=50, log_metric_by_epoch=False),
param_scheduler=dict(type='ParamSchedulerHook'),
checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=24000),
sampler_seed=dict(type='DistSamplerSeedHook'),
visualization=dict(type='SegVisualizationHook'))

@ -0,0 +1,117 @@
## Prepare datasets
It is recommended to symlink the dataset root to `$MMSEGMENTATION/data`.
If your folder structure is different, you may need to change the corresponding paths in config files.
```none
mmsegmentation
├── mmseg
├── tools
├── configs
├── data
│   ├── mapillary
│   │   ├── training
│   │   │   ├── images
│   │   │   ├── v1.2
│   │   │   │   ├── instances
│   │   │   │   ├── labels
│   │   │   │   ├── labels_mask
│   │   │   │   └── panoptic
│   │   │   ├── v2.0
│   │   │   │   ├── instances
│   │   │   │   ├── labels
│   │   │   │   ├── labels_mask
│   │   │   │   ├── panoptic
│   │   │   │   └── polygons
│   │   ├── validation
│   │   │   ├── images
│   │   │   ├── v1.2
│   │   │   │   ├── instances
│   │   │   │   ├── labels
│   │   │   │   ├── labels_mask
│   │   │   │   └── panoptic
│   │   │   ├── v2.0
│   │   │   │   ├── instances
│   │   │   │   ├── labels
│   │   │   │   ├── labels_mask
│   │   │   │   ├── panoptic
│   │   │   │   └── polygons
```
## Mapillary Vistas Datasets
- The dataset can be downloaded [here](https://www.mapillary.com/dataset/vistas) after registration.
- This guide assumes you have put the dataset zip file in `mmsegmentation/data`.
- Run the following commands to unzip the dataset.
```bash
cd data
mkdir mapillary
unzip -d mapillary An-ZjB1Zm61yAZG0ozTymz8I8NqI4x0MrYrh26dq7kPgfu8vf9ImrdaOAVOFYbJ2pNAgUnVGBmbue9lTgdBOb5BbKXIpFs0fpYWqACbrQDChAA2fdX0zS9PcHu7fY8c-FOvyBVxPNYNFQuM.zip
```
- After unzipping, you will get the Mapillary Vistas Dataset in the following structure.
```none
├── data
│   ├── mapillary
│   │   ├── training
│   │   │   ├── images
│   │   │   ├── v1.2
│   │   │   │   ├── instances
│   │   │   │   ├── labels
│   │   │   │   └── panoptic
│   │   │   ├── v2.0
│   │   │   │   ├── instances
│   │   │   │   ├── labels
│   │   │   │   ├── panoptic
│   │   │   │   └── polygons
│   │   ├── validation
│   │   │   ├── images
│   │   │   ├── v1.2
│   │   │   │   ├── instances
│   │   │   │   ├── labels
│   │   │   │   └── panoptic
│   │   │   ├── v2.0
│   │   │   │   ├── instances
│   │   │   │   ├── labels
│   │   │   │   ├── panoptic
│   │   │   │   └── polygons
```
- Run the following commands to convert the RGB labels to mask labels.
```bash
# --nproc: optional, default 1, number of processes used for the conversion
# --version: optional, one of 'v1.2', 'v2.0', 'all'; default 'all'; selects which label version to convert
# run this command from the 'mmsegmentation/projects/mapillary_dataset' folder
cd mmsegmentation/projects/mapillary_dataset
python tools/dataset_converters/mapillary.py ../../data/mapillary --nproc 8 --version all
```
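The converter works by packing each RGB triple into one integer, `(R * 256 + G) * 256 + B`, and using that integer to look up the class index. A minimal, dependency-free sketch of the idea is shown here. It uses a dictionary instead of the NumPy lookup array in `tools/dataset_converters/mapillary.py`, and the function names and the `unknown=255` fallback are assumptions made for illustration (the converter's array is zero-initialised, so it has no such fallback).

```python
def rgb_to_index(r: int, g: int, b: int) -> int:
    """Pack an 8-bit RGB triple into a single integer key."""
    return (r * 256 + g) * 256 + b


def build_lookup(palette):
    """Map each packed palette colour to its class index."""
    return {rgb_to_index(*color): i for i, color in enumerate(palette)}


def rgb_label_to_mask(rgb_rows, lookup, unknown=255):
    """Convert rows of RGB pixels into rows of class indexes.

    Colours missing from the palette get the `unknown` value
    (an assumption of this sketch, not the converter's behaviour).
    """
    return [[lookup.get(rgb_to_index(*px), unknown) for px in row]
            for row in rgb_rows]


# First three v1.2 palette entries: Bird, Ground Animal, Curb.
palette = [(165, 42, 42), (0, 192, 0), (196, 196, 196)]
lookup = build_lookup(palette)
mask = rgb_label_to_mask(
    [[(165, 42, 42), (0, 192, 0)], [(196, 196, 196), (1, 2, 3)]], lookup)
```

Note that when a colour appears more than once in a palette, a lookup built this way resolves it to the last class that uses it.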
After the conversion finishes, you will get this structure:
```none
mmsegmentation
├── mmseg
├── tools
├── configs
├── data
│   ├── mapillary
│   │   ├── training
│   │   │   ├── images
│   │   │   ├── v1.2
│   │   │   │   ├── instances
│   │   │   │   ├── labels
│   │   │   │   ├── labels_mask
│   │   │   │   └── panoptic
│   │   │   ├── v2.0
│   │   │   │   ├── instances
│   │   │   │   ├── labels
│   │   │   │   ├── labels_mask
│   │   │   │   ├── panoptic
│   │   │   │   └── polygons
│   │   ├── validation
│   │   │   ├── images
│   │   │   ├── v1.2
│   │   │   │   ├── instances
│   │   │   │   ├── labels
│   │   │   │   ├── labels_mask
│   │   │   │   └── panoptic
│   │   │   ├── v2.0
│   │   │   │   ├── instances
│   │   │   │   ├── labels
│   │   │   │   ├── labels_mask
│   │   │   │   ├── panoptic
│   │   │   │   └── polygons
```
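Before training, it can help to confirm that the `labels_mask` folders referenced by the dataset configs exist. The following is a small sketch using only the standard library; the function name `missing_mask_dirs` is hypothetical, and the default splits and versions mirror the directory layout in this guide.

```python
from pathlib import Path


def missing_mask_dirs(data_root,
                      splits=('training', 'validation'),
                      versions=('v1.2', 'v2.0')):
    """Return the expected `labels_mask` directories that do not exist."""
    root = Path(data_root)
    expected = [
        root / split / version / 'labels_mask'
        for split in splits for version in versions
    ]
    return [str(path) for path in expected if not path.is_dir()]
```

Calling `missing_mask_dirs('data/mapillary')` from the `mmsegmentation` folder should return an empty list once both label versions have been converted.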

@ -0,0 +1,65 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmseg.datasets.basesegdataset import BaseSegDataset
from mmseg.registry import DATASETS


@DATASETS.register_module()
class MapillaryDataset_v1_2(BaseSegDataset):
    """Mapillary Vistas Dataset.

    Dataset paper link:
    http://ieeexplore.ieee.org/document/8237796/

    v1.2 contains 66 object classes.
    (37 instance-specific)

    v2.0 contains 124 object classes.
    (70 instance-specific, 46 stuff, 8 void or crowd).

    The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is
    fixed to '.png' for Mapillary Vistas Dataset.
    """
    METAINFO = dict(
        classes=('Bird', 'Ground Animal', 'Curb', 'Fence', 'Guard Rail',
                 'Barrier', 'Wall', 'Bike Lane', 'Crosswalk - Plain',
                 'Curb Cut', 'Parking', 'Pedestrian Area', 'Rail Track',
                 'Road', 'Service Lane', 'Sidewalk', 'Bridge', 'Building',
                 'Tunnel', 'Person', 'Bicyclist', 'Motorcyclist',
                 'Other Rider', 'Lane Marking - Crosswalk',
                 'Lane Marking - General', 'Mountain', 'Sand', 'Sky', 'Snow',
                 'Terrain', 'Vegetation', 'Water', 'Banner', 'Bench',
                 'Bike Rack', 'Billboard', 'Catch Basin', 'CCTV Camera',
                 'Fire Hydrant', 'Junction Box', 'Mailbox', 'Manhole',
                 'Phone Booth', 'Pothole', 'Street Light', 'Pole',
                 'Traffic Sign Frame', 'Utility Pole', 'Traffic Light',
                 'Traffic Sign (Back)', 'Traffic Sign (Front)', 'Trash Can',
                 'Bicycle', 'Boat', 'Bus', 'Car', 'Caravan', 'Motorcycle',
                 'On Rails', 'Other Vehicle', 'Trailer', 'Truck',
                 'Wheeled Slow', 'Car Mount', 'Ego Vehicle', 'Unlabeled'),
        palette=[[165, 42, 42], [0, 192, 0], [196, 196, 196], [190, 153, 153],
                 [180, 165, 180], [90, 120, 150], [102, 102, 156],
                 [128, 64, 255], [140, 140, 200], [170, 170, 170],
                 [250, 170, 160], [96, 96, 96],
                 [230, 150, 140], [128, 64, 128], [110, 110, 110],
                 [244, 35, 232], [150, 100, 100], [70, 70, 70], [150, 120, 90],
                 [220, 20, 60], [255, 0, 0], [255, 0, 100], [255, 0, 200],
                 [200, 128, 128], [255, 255, 255], [64, 170, 64],
                 [230, 160, 50],
                 [70, 130, 180], [190, 255, 255], [152, 251, 152],
                 [107, 142, 35], [0, 170, 30], [255, 255, 128], [250, 0, 30],
                 [100, 140, 180], [220, 220, 220], [220, 128, 128],
                 [222, 40, 40], [100, 170, 30], [40, 40, 40], [33, 33, 33],
                 [100, 128, 160], [142, 0, 0], [70, 100, 150], [210, 170, 100],
                 [153, 153, 153], [128, 128, 128], [0, 0, 80], [250, 170, 30],
                 [192, 192, 192], [220, 220, 0], [140, 140, 20], [119, 11, 32],
                 [150, 0, 255], [0, 60, 100], [0, 0, 142], [0, 0, 90],
                 [0, 0, 230], [0, 80, 100], [128, 64, 64], [0, 0, 110],
                 [0, 0, 70], [0, 0, 192], [32, 32, 32], [120, 10, 10],
                 [0, 0, 0]])

    def __init__(self,
                 img_suffix='.jpg',
                 seg_map_suffix='.png',
                 **kwargs) -> None:
        super().__init__(
            img_suffix=img_suffix, seg_map_suffix=seg_map_suffix, **kwargs)

@ -0,0 +1,114 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmseg.datasets.basesegdataset import BaseSegDataset
from mmseg.registry import DATASETS


@DATASETS.register_module()
class MapillaryDataset_v2_0(BaseSegDataset):
    """Mapillary Vistas Dataset.

    Dataset paper link:
    http://ieeexplore.ieee.org/document/8237796/

    v1.2 contains 66 object classes.
    (37 instance-specific)

    v2.0 contains 124 object classes.
    (70 instance-specific, 46 stuff, 8 void or crowd).

    The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is
    fixed to '.png' for Mapillary Vistas Dataset.
    """
    METAINFO = dict(
        classes=(
            'Bird', 'Ground Animal', 'Ambiguous Barrier', 'Concrete Block',
            'Curb', 'Fence', 'Guard Rail', 'Barrier', 'Road Median',
            'Road Side', 'Lane Separator', 'Temporary Barrier', 'Wall',
            'Bike Lane', 'Crosswalk - Plain', 'Curb Cut', 'Driveway',
            'Parking', 'Parking Aisle', 'Pedestrian Area', 'Rail Track',
            'Road', 'Road Shoulder', 'Service Lane', 'Sidewalk',
            'Traffic Island', 'Bridge', 'Building', 'Garage', 'Tunnel',
            'Person', 'Person Group', 'Bicyclist', 'Motorcyclist',
            'Other Rider', 'Lane Marking - Dashed Line',
            'Lane Marking - Straight Line', 'Lane Marking - Zigzag Line',
            'Lane Marking - Ambiguous', 'Lane Marking - Arrow (Left)',
            'Lane Marking - Arrow (Other)', 'Lane Marking - Arrow (Right)',
            'Lane Marking - Arrow (Split Left or Straight)',
            'Lane Marking - Arrow (Split Right or Straight)',
            'Lane Marking - Arrow (Straight)', 'Lane Marking - Crosswalk',
            'Lane Marking - Give Way (Row)',
            'Lane Marking - Give Way (Single)',
            'Lane Marking - Hatched (Chevron)',
            'Lane Marking - Hatched (Diagonal)', 'Lane Marking - Other',
            'Lane Marking - Stop Line', 'Lane Marking - Symbol (Bicycle)',
            'Lane Marking - Symbol (Other)', 'Lane Marking - Text',
            'Lane Marking (only) - Dashed Line',
            'Lane Marking (only) - Crosswalk', 'Lane Marking (only) - Other',
            'Lane Marking (only) - Test', 'Mountain', 'Sand', 'Sky', 'Snow',
            'Terrain', 'Vegetation', 'Water', 'Banner', 'Bench', 'Bike Rack',
            'Catch Basin', 'CCTV Camera', 'Fire Hydrant', 'Junction Box',
            'Mailbox', 'Manhole', 'Parking Meter', 'Phone Booth', 'Pothole',
            'Signage - Advertisement', 'Signage - Ambiguous', 'Signage - Back',
            'Signage - Information', 'Signage - Other', 'Signage - Store',
            'Street Light', 'Pole', 'Pole Group', 'Traffic Sign Frame',
            'Utility Pole', 'Traffic Cone', 'Traffic Light - General (Single)',
            'Traffic Light - Pedestrians', 'Traffic Light - General (Upright)',
            'Traffic Light - General (Horizontal)', 'Traffic Light - Cyclists',
            'Traffic Light - Other', 'Traffic Sign - Ambiguous',
            'Traffic Sign (Back)', 'Traffic Sign - Direction (Back)',
            'Traffic Sign - Direction (Front)', 'Traffic Sign (Front)',
            'Traffic Sign - Parking', 'Traffic Sign - Temporary (Back)',
            'Traffic Sign - Temporary (Front)', 'Trash Can', 'Bicycle', 'Boat',
            'Bus', 'Car', 'Caravan', 'Motorcycle', 'On Rails', 'Other Vehicle',
            'Trailer', 'Truck', 'Vehicle Group', 'Wheeled Slow', 'Water Valve',
            'Car Mount', 'Dynamic', 'Ego Vehicle', 'Ground', 'Static',
            'Unlabeled'),
        palette=[[165, 42, 42], [0, 192, 0], [250, 170, 31], [250, 170, 32],
                 [196, 196, 196], [190, 153, 153], [180, 165, 180],
                 [90, 120, 150], [250, 170, 33], [250, 170, 34],
                 [128, 128, 128], [250, 170, 35], [102, 102, 156],
                 [128, 64, 255], [140, 140, 200], [170, 170, 170],
                 [250, 170, 36], [250, 170, 160], [250, 170, 37], [96, 96, 96],
                 [230, 150, 140], [128, 64, 128], [110, 110, 110],
                 [110, 110, 110], [244, 35, 232], [128, 196, 128],
                 [150, 100, 100],
                 [70, 70, 70], [150, 150, 150], [150, 120, 90], [220, 20, 60],
                 [220, 20, 60], [255, 0, 0], [255, 0, 100], [255, 0, 200],
                 [255, 255, 255], [255, 255, 255], [250, 170, 29],
                 [250, 170, 28], [250, 170, 26], [250, 170, 25],
                 [250, 170, 24],
                 [250, 170, 22], [250, 170, 21], [250, 170, 20],
                 [255, 255, 255],
                 [250, 170, 19], [250, 170, 18], [250, 170, 12],
                 [250, 170, 11],
                 [255, 255, 255], [255, 255, 255], [250, 170, 16],
                 [250, 170, 15], [250, 170, 15], [255, 255, 255],
                 [255, 255, 255], [255, 255, 255], [255, 255, 255],
                 [64, 170, 64], [230, 160, 50],
                 [70, 130, 180], [190, 255, 255], [152, 251, 152],
                 [107, 142, 35], [0, 170, 30], [255, 255, 128], [250, 0, 30],
                 [100, 140, 180], [220, 128, 128], [222, 40, 40],
                 [100, 170, 30],
                 [40, 40, 40], [33, 33, 33], [100, 128, 160], [20, 20, 255],
                 [142, 0, 0], [70, 100, 150], [250, 171, 30], [250, 172, 30],
                 [250, 173, 30], [250, 174, 30], [250, 175, 30],
                 [250, 176, 30],
                 [210, 170, 100], [153, 153, 153], [153, 153, 153],
                 [128, 128, 128], [0, 0, 80], [210, 60, 60], [250, 170, 30],
                 [250, 170, 30], [250, 170, 30], [250, 170, 30],
                 [250, 170, 30],
                 [250, 170, 30], [192, 192, 192], [192, 192, 192],
                 [192, 192, 192], [220, 220, 0], [220, 220, 0], [0, 0, 196],
                 [192, 192, 192], [220, 220, 0], [140, 140, 20], [119, 11, 32],
                 [150, 0, 255], [0, 60, 100], [0, 0, 142], [0, 0, 90],
                 [0, 0, 230], [0, 80, 100], [128, 64, 64], [0, 0, 110],
                 [0, 0, 70], [0, 0, 142], [0, 0, 192], [170, 170, 170],
                 [32, 32, 32], [111, 74, 0], [120, 10, 10], [81, 0, 81],
                 [111, 111, 0], [0, 0, 0]])

    def __init__(self,
                 img_suffix='.jpg',
                 seg_map_suffix='.png',
                 **kwargs) -> None:
        super().__init__(
            img_suffix=img_suffix, seg_map_suffix=seg_map_suffix, **kwargs)

@ -0,0 +1,245 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from functools import partial
import mmcv
import numpy as np
from mmengine.utils import (mkdir_or_exist, scandir, track_parallel_progress,
track_progress)
colormap_v1_2 = np.array([[165, 42, 42], [0, 192, 0], [196, 196, 196],
[190, 153, 153], [180, 165, 180], [90, 120, 150],
[102, 102, 156], [128, 64, 255], [140, 140, 200],
[170, 170, 170], [250, 170, 160], [96, 96, 96],
[230, 150, 140], [128, 64, 128], [110, 110, 110],
[244, 35, 232], [150, 100, 100], [70, 70, 70],
[150, 120, 90], [220, 20, 60], [255, 0, 0],
[255, 0, 100], [255, 0, 200], [200, 128, 128],
[255, 255, 255], [64, 170, 64], [230, 160, 50],
[70, 130, 180], [190, 255, 255], [152, 251, 152],
[107, 142, 35], [0, 170, 30], [255, 255, 128],
[250, 0, 30], [100, 140, 180], [220, 220, 220],
[220, 128, 128], [222, 40, 40], [100, 170, 30],
[40, 40, 40], [33, 33, 33], [100, 128, 160],
[142, 0, 0], [70, 100, 150], [210, 170, 100],
[153, 153, 153], [128, 128, 128], [0, 0, 80],
[250, 170, 30], [192, 192, 192], [220, 220, 0],
[140, 140, 20], [119, 11, 32], [150, 0, 255],
[0, 60, 100], [0, 0, 142], [0, 0, 90], [0, 0, 230],
[0, 80, 100], [128, 64, 64], [0, 0, 110], [0, 0, 70],
[0, 0, 192], [32, 32, 32], [120, 10, 10], [0, 0, 0]])
colormap_v2_0 = np.array([[165, 42, 42], [0, 192, 0], [250, 170, 31],
[250, 170, 32], [196, 196, 196], [190, 153, 153],
[180, 165, 180], [90, 120, 150], [250, 170, 33],
[250, 170, 34], [128, 128, 128], [250, 170, 35],
[102, 102, 156], [128, 64, 255], [140, 140, 200],
[170, 170, 170], [250, 170, 36], [250, 170, 160],
[250, 170, 37], [96, 96, 96], [230, 150, 140],
[128, 64, 128], [110, 110, 110], [110, 110, 110],
[244, 35, 232], [128, 196, 128], [150, 100, 100],
[70, 70, 70], [150, 150, 150], [150, 120, 90],
[220, 20, 60], [220, 20, 60], [255, 0, 0],
[255, 0, 100], [255, 0, 200], [255, 255, 255],
[255, 255, 255], [250, 170, 29], [250, 170, 28],
[250, 170, 26], [250, 170, 25], [250, 170, 24],
[250, 170, 22], [250, 170, 21], [250, 170, 20],
[255, 255, 255], [250, 170, 19], [250, 170, 18],
[250, 170, 12], [250, 170, 11], [255, 255, 255],
[255, 255, 255], [250, 170, 16], [250, 170, 15],
[250, 170, 15], [255, 255, 255], [255, 255, 255],
[255, 255, 255], [255, 255, 255], [64, 170, 64],
[230, 160, 50], [70, 130, 180], [190, 255, 255],
[152, 251, 152], [107, 142, 35], [0, 170, 30],
[255, 255, 128], [250, 0, 30], [100, 140, 180],
[220, 128, 128], [222, 40, 40], [100, 170, 30],
[40, 40, 40], [33, 33, 33], [100, 128, 160],
[20, 20, 255], [142, 0, 0], [70, 100, 150],
[250, 171, 30], [250, 172, 30], [250, 173, 30],
[250, 174, 30], [250, 175, 30], [250, 176, 30],
[210, 170, 100], [153, 153, 153], [153, 153, 153],
[128, 128, 128], [0, 0, 80], [210, 60, 60],
[250, 170, 30], [250, 170, 30], [250, 170, 30],
[250, 170, 30], [250, 170, 30], [250, 170, 30],
[192, 192, 192], [192, 192, 192], [192, 192, 192],
[220, 220, 0], [220, 220, 0], [0, 0, 196],
[192, 192, 192], [220, 220, 0], [140, 140, 20],
[119, 11, 32], [150, 0, 255], [0, 60, 100],
[0, 0, 142], [0, 0, 90], [0, 0, 230], [0, 80, 100],
[128, 64, 64], [0, 0, 110], [0, 0, 70], [0, 0, 142],
[0, 0, 192], [170, 170, 170], [32, 32, 32],
[111, 74, 0], [120, 10, 10], [81, 0, 81],
[111, 111, 0], [0, 0, 0]])


def parse_args():
    parser = argparse.ArgumentParser(
        description='Convert Mapillary dataset to mmsegmentation format')
    parser.add_argument('dataset_path', help='Mapillary folder path')
    parser.add_argument(
        '--version',
        default='all',
        # Reject unsupported values instead of silently converting nothing.
        choices=['v1.2', 'v2.0', 'all'],
        help="Mapillary labels version, 'v1.2','v2.0','all'")
    parser.add_argument('-o', '--out_dir', help='output path')
    parser.add_argument(
        '--nproc', default=1, type=int, help='number of processes')
    args = parser.parse_args()
    return args


def mapillary_colormap2label(colormap: np.ndarray) -> np.ndarray:
    """Create a lookup array of length 256^3 that maps each packed RGB
    palette color to its label value.

    For example, label 0 (Bird) has the color [165, 42, 42], so
    (165 * 256 + 42) * 256 + 42 = 10824234 is its index in the array and
    ``colormap2label[10824234] = 0``.

    During conversion, a pixel with the RGB value [165, 42, 42] is looked
    up through ``colormap2label[10824234]``, which quickly yields its label
    value 0. Computing the indexes of a whole image with array arithmetic
    is very fast.

    Args:
        colormap (np.ndarray): Mapillary Vistas Dataset palette.

    Returns:
        np.ndarray: values are mask labels, indexes are the packed
            palette colors.
    """
    colormap2label = np.zeros(256**3, dtype=np.longlong)
    for i, colormap_ in enumerate(colormap):
        colormap2label[(colormap_[0] * 256 + colormap_[1]) * 256 +
                       colormap_[2]] = i
    return colormap2label


def mapillary_masklabel(rgb_label: np.ndarray,
                        colormap2label: np.ndarray) -> np.ndarray:
    """Compute the mask label of an image through the ``colormap2label``
    array returned by ``mapillary_colormap2label(colormap)``.

    Args:
        rgb_label (np.ndarray): an RGB label image.
        colormap2label (np.ndarray): lookup array returned by
            ``mapillary_colormap2label(colormap)``.

    Returns:
        np.ndarray: mask labels array.
    """
    colormap_ = rgb_label.astype('uint32')
    idx = np.array((colormap_[:, :, 0] * 256 + colormap_[:, :, 1]) * 256 +
                   colormap_[:, :, 2]).astype('uint32')
    return colormap2label[idx]


def RGB2Mask(rgb_label_path: str, colormap2label: np.ndarray) -> None:
    """Mapillary Vistas Dataset provides 8-bit color-palette class-specific
    labels for semantic segmentation. However, semantic segmentation needs
    single-channel mask labels.

    This function converts Mapillary RGB labels
    {training,validation}/{v1.2,v2.0}/labels to mask labels
    {training,validation}/{v1.2,v2.0}/labels_mask.

    Args:
        rgb_label_path (str): image absolute path.
        colormap2label (np.ndarray): lookup array for the label version
            (v1.2 or v2.0) of the image.
    """
    rgb_label = mmcv.imread(rgb_label_path, channel_order='rgb')
    masks_label = mapillary_masklabel(rgb_label, colormap2label)

    mmcv.imwrite(
        masks_label.astype(np.uint8),
        rgb_label_path.replace('labels', 'labels_mask'))


def main():
    colormap2label_v1_2 = mapillary_colormap2label(colormap_v1_2)
    colormap2label_v2_0 = mapillary_colormap2label(colormap_v2_0)

    dataset_path = args.dataset_path
    if args.out_dir is None:
        out_dir = dataset_path
    else:
        out_dir = args.out_dir

    RGB_labels_path = []
    RGB_labels_v1_2_path = []
    RGB_labels_v2_0_path = []
    print('Scanning labels path....')
    for label_path in scandir(dataset_path, suffix='.png', recursive=True):
        if 'labels' in label_path:
            rgb_label_path = osp.join(dataset_path, label_path)
            RGB_labels_path.append(rgb_label_path)
            if 'v1.2' in label_path:
                RGB_labels_v1_2_path.append(rgb_label_path)
            elif 'v2.0' in label_path:
                RGB_labels_v2_0_path.append(rgb_label_path)

    if args.version == 'all':
        print(f'Totally found {len(RGB_labels_path)} {args.version} RGB labels')
    elif args.version == 'v1.2':
        print(f'Found {len(RGB_labels_v1_2_path)} {args.version} RGB labels')
    elif args.version == 'v2.0':
        print(f'Found {len(RGB_labels_v2_0_path)} {args.version} RGB labels')

    print('Making directories...')
    mkdir_or_exist(osp.join(out_dir, 'training', 'v1.2', 'labels_mask'))
    mkdir_or_exist(osp.join(out_dir, 'validation', 'v1.2', 'labels_mask'))
    mkdir_or_exist(osp.join(out_dir, 'training', 'v2.0', 'labels_mask'))
    mkdir_or_exist(osp.join(out_dir, 'validation', 'v2.0', 'labels_mask'))
    print('Directories have been made.')

    # Collect the label versions selected by --version, then convert each
    # one either in parallel (--nproc > 1) or sequentially.
    jobs = []
    if args.version in ('all', 'v1.2'):
        jobs.append(('v1.2', colormap2label_v1_2, RGB_labels_v1_2_path))
    if args.version in ('all', 'v2.0'):
        jobs.append(('v2.0', colormap2label_v2_0, RGB_labels_v2_0_path))

    for version, colormap2label, label_paths in jobs:
        print(f'Converting {version} ....')
        convert = partial(RGB2Mask, colormap2label=colormap2label)
        if args.nproc > 1:
            track_parallel_progress(convert, label_paths, nproc=args.nproc)
        else:
            track_progress(convert, label_paths)

    print('Converted Mapillary Vistas Dataset RGB labels to mask labels!')


if __name__ == '__main__':
    args = parse_args()
    main()