[Datasets] Add Mapillary Vistas Datasets to MMSeg Core Package. (#2576)

## [Datasets] Add Mapillary Vistas Datasets to MMSeg Core Package
## Motivation
Add the Mapillary Vistas datasets to the core package.
Old PR: #2484

## Modification
- Add the Mapillary Vistas datasets to the core package.
- Delete `tools/datasets_convert/mapillary.py`; the dataset doesn't need converting.
- Add the `schedule_240k.py` schedule config.
- Add config files:
  ```none
  deeplabv3plus_r101-d8_4xb2-240k_mapillay_v1-512x1024.py
  deeplabv3plus_r101-d8_4xb2-240k_mapillay_v2-512x1024.py
  maskformer_swin-s_4xb2-240k_mapillary_v1-512x1024.py
  maskformer_swin-s_4xb2-240k_mapillary_v2-512x1024.py
  maskformer_r101-d8_4xb2-240k_mapillary_v1-512x1024.py
  maskformer_r101-d8_4xb2-240k_mapillary_v2-512x1024.py
  pspnet_r101-d8_4xb2-240k_mapillay_v1-512x1024.py
  pspnet_r101-d8_4xb2-240k_mapillay_v2-512x1024.py
  ```
- Synchronize the changes to `projects/mapillary_dataset`.
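
For orientation, each new core-package config is a thin wrapper over the `_base_` files. A sketch of what `deeplabv3plus_r101-d8_4xb2-240k_mapillay_v1-512x1024.py` is assumed to look like (its exact contents are not shown in this diff; the layout mirrors the `projects/mapillary_dataset` variant included further down):

```python
# Sketch only -- assumed to mirror the projects/mapillary_dataset config in this PR.
_base_ = [
    '../_base_/models/deeplabv3plus_r50-d8.py',
    '../_base_/datasets/mapillary_v1.py',
    '../_base_/default_runtime.py',
    '../_base_/schedules/schedule_240k.py',
]
crop_size = (512, 1024)
data_preprocessor = dict(size=crop_size)
model = dict(
    data_preprocessor=data_preprocessor,
    pretrained='open-mmlab://resnet101_v1c',
    backbone=dict(depth=101),
    decode_head=dict(num_classes=66),    # Mapillary Vistas v1.2 has 66 classes
    auxiliary_head=dict(num_classes=66))
```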

---------

Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com>
Co-authored-by: xiexinch <xiexinch@outlook.com>
Tianlong Ai 2023-03-15 14:44:38 +08:00 committed by GitHub
parent 447a398c24
commit 8c89ff3dd1
30 changed files with 1090 additions and 519 deletions

View File

@ -181,6 +181,7 @@ Results and models are available in the [model zoo](docs/en/model_zoo.md).
- [x] [Potsdam](https://github.com/open-mmlab/mmsegmentation/blob/1.x/docs/en/user_guides/2_dataset_prepare.md#isprs-potsdam)
- [x] [Vaihingen](https://github.com/open-mmlab/mmsegmentation/blob/1.x/docs/en/user_guides/2_dataset_prepare.md#isprs-vaihingen)
- [x] [iSAID](https://github.com/open-mmlab/mmsegmentation/blob/1.x/docs/en/user_guides/2_dataset_prepare.md#isaid)
- [x] [Mapillary Vistas](https://github.com/open-mmlab/mmsegmentation/blob/1.x/docs/en/user_guides/2_dataset_prepare.md#mapillary-vistas-datasets)
</details>

View File

@ -162,6 +162,7 @@ MMSegmentation 是一个基于 PyTorch 的语义分割开源工具箱。它是 O
- [x] [Potsdam](https://github.com/open-mmlab/mmsegmentation/blob/1.x/docs/zh_cn/dataset_prepare.md#isprs-potsdam)
- [x] [Vaihingen](https://github.com/open-mmlab/mmsegmentation/blob/1.x/docs/zh_cn/dataset_prepare.md#isprs-vaihingen)
- [x] [iSAID](https://github.com/open-mmlab/mmsegmentation/blob/1.x/docs/zh_cn/dataset_prepare.md#isaid)
- [x] [Mapillary Vistas](https://github.com/open-mmlab/mmsegmentation/blob/1.x/docs/en/user_guides/2_dataset_prepare.md#mapillary-vistas-datasets)
</details>

View File

@ -1,5 +1,5 @@
# dataset settings
dataset_type = 'MapillaryDataset_v1_2'
dataset_type = 'MapillaryDataset_v1'
data_root = 'data/mapillary/'
crop_size = (512, 1024)
train_pipeline = [
@ -48,8 +48,7 @@ train_dataloader = dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(
img_path='training/images',
seg_map_path='training/v1.2/labels_mask'),
img_path='training/images', seg_map_path='training/v1.2/labels'),
pipeline=train_pipeline))
val_dataloader = dict(
batch_size=1,
@ -61,7 +60,7 @@ val_dataloader = dict(
data_root=data_root,
data_prefix=dict(
img_path='validation/images',
seg_map_path='validation/v1.2/labels_mask'),
seg_map_path='validation/v1.2/labels'),
pipeline=test_pipeline))
test_dataloader = val_dataloader

View File

@ -0,0 +1,37 @@
# dataset settings
_base_ = './mapillary_v1.py'
metainfo = dict(
classes=('Bird', 'Ground Animal', 'Curb', 'Fence', 'Guard Rail', 'Barrier',
'Wall', 'Bike Lane', 'Crosswalk - Plain', 'Curb Cut', 'Parking',
'Pedestrian Area', 'Rail Track', 'Road', 'Service Lane',
'Sidewalk', 'Bridge', 'Building', 'Tunnel', 'Person', 'Bicyclist',
'Motorcyclist', 'Other Rider', 'Lane Marking - Crosswalk',
'Lane Marking - General', 'Mountain', 'Sand', 'Sky', 'Snow',
'Terrain', 'Vegetation', 'Water', 'Banner', 'Bench', 'Bike Rack',
'Billboard', 'Catch Basin', 'CCTV Camera', 'Fire Hydrant',
'Junction Box', 'Mailbox', 'Manhole', 'Phone Booth', 'Pothole',
'Street Light', 'Pole', 'Traffic Sign Frame', 'Utility Pole',
'Traffic Light', 'Traffic Sign (Back)', 'Traffic Sign (Front)',
'Trash Can', 'Bicycle', 'Boat', 'Bus', 'Car', 'Caravan',
'Motorcycle', 'On Rails', 'Other Vehicle', 'Trailer', 'Truck',
'Wheeled Slow', 'Car Mount', 'Ego Vehicle'),
palette=[[165, 42, 42], [0, 192, 0], [196, 196, 196], [190, 153, 153],
[180, 165, 180], [90, 120, 150], [102, 102, 156], [128, 64, 255],
[140, 140, 200], [170, 170, 170], [250, 170, 160], [96, 96, 96],
[230, 150, 140], [128, 64, 128], [110, 110, 110], [244, 35, 232],
[150, 100, 100], [70, 70, 70], [150, 120, 90], [220, 20, 60],
[255, 0, 0], [255, 0, 100], [255, 0, 200], [200, 128, 128],
[255, 255, 255], [64, 170, 64], [230, 160, 50], [70, 130, 180],
[190, 255, 255], [152, 251, 152], [107, 142, 35], [0, 170, 30],
[255, 255, 128], [250, 0, 30], [100, 140, 180], [220, 220, 220],
[220, 128, 128], [222, 40, 40], [100, 170, 30], [40, 40, 40],
[33, 33, 33], [100, 128, 160], [142, 0, 0], [70, 100, 150],
[210, 170, 100], [153, 153, 153], [128, 128, 128], [0, 0, 80],
[250, 170, 30], [192, 192, 192], [220, 220, 0], [140, 140, 20],
[119, 11, 32], [150, 0, 255], [0, 60, 100], [0, 0, 142],
[0, 0, 90], [0, 0, 230], [0, 80, 100], [128, 64, 64], [0, 0, 110],
[0, 0, 70], [0, 0, 192], [32, 32, 32], [120, 10, 10]])
train_dataloader = dict(dataset=dict(metainfo=metainfo))
val_dataloader = dict(dataset=dict(metainfo=metainfo))
test_dataloader = val_dataloader

View File

@ -1,5 +1,5 @@
# dataset settings
dataset_type = 'MapillaryDataset_v2_0'
dataset_type = 'MapillaryDataset_v2'
data_root = 'data/mapillary/'
crop_size = (512, 1024)
train_pipeline = [
@ -48,8 +48,7 @@ train_dataloader = dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(
img_path='training/images',
seg_map_path='training/v2.0/labels_mask'),
img_path='training/images', seg_map_path='training/v2.0/labels'),
pipeline=train_pipeline))
val_dataloader = dict(
batch_size=1,
@ -61,7 +60,7 @@ val_dataloader = dict(
data_root=data_root,
data_prefix=dict(
img_path='validation/images',
seg_map_path='validation/v2.0/labels_mask'),
seg_map_path='validation/v2.0/labels'),
pipeline=test_pipeline))
test_dataloader = val_dataloader

View File

@ -0,0 +1,25 @@
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer, clip_grad=None)
# learning policy
param_scheduler = [
dict(
type='PolyLR',
eta_min=1e-4,
power=0.9,
begin=0,
end=240000,
by_epoch=False)
]
# training schedule for 240k
train_cfg = dict(
type='IterBasedTrainLoop', max_iters=240000, val_interval=24000)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
default_hooks = dict(
timer=dict(type='IterTimerHook'),
logger=dict(type='LoggerHook', interval=50, log_metric_by_epoch=False),
param_scheduler=dict(type='ParamSchedulerHook'),
checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=24000),
sampler_seed=dict(type='DistSamplerSeedHook'),
visualization=dict(type='SegVisualizationHook'))

View File

@ -124,6 +124,12 @@ Spatial pyramid pooling module or encode-decoder structure are used in deep neur
| DeepLabV3+ | R-18-D8 | 896x896 | 80000 | 6.19 | 24.81 | 61.35 | 62.61 | [config](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/deeplabv3plus/deeplabv3plus_r18-d8_4xb4-80k_isaid-896x896.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r18-d8_4x4_896x896_80k_isaid/deeplabv3plus_r18-d8_4x4_896x896_80k_isaid_20220110_180526-7059991d.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r18-d8_4x4_896x896_80k_isaid/deeplabv3plus_r18-d8_4x4_896x896_80k_isaid_20220110_180526.log.json) |
| DeepLabV3+ | R-50-D8 | 896x896 | 80000 | 21.45 | 8.42 | 67.06 | 68.02 | [config](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/deeplabv3plus/deeplabv3plus_r50-d8_4xb4-80k_isaid-896x896.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r50-d8_4x4_896x896_80k_isaid/deeplabv3plus_r50-d8_4x4_896x896_80k_isaid_20220110_180526-598be439.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r50-d8_4x4_896x896_80k_isaid/deeplabv3plus_r50-d8_4x4_896x896_80k_isaid_20220110_180526.log.json) |
### Mapillary Vistas v1.2
| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | Device | mIoU | mIoU(ms+flip) | config | download |
| ---------- | -------- | --------- | ------: | -------- | -------------- | ------ | ----: | ------------: | ---------------------------------------------------------------------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| DeepLabV3+ | R-50-D8 | 1280x1280 | 300000 | 24.04 | 17.92 | A100 | 47.35 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/deeplabv3plus/deeplabv3plus_r50-d8_4xb2-300k_mapillay_v1_65-1280x1280.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r50-d8_4xb2-300k_mapillay_v1_65-1280x1280/deeplabv3plus_r50-d8_4xb2-300k_mapillay_v1_65-1280x1280_20230301_110504-655f8e43.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r50-d8_4xb2-300k_mapillay_v1_65-1280x1280/deeplabv3plus_r50-d8_4xb2-300k_mapillay_v1_65-1280x1280_20230301_110504.json) |
Note:
- `D-8`/`D-16` here correspond to the output stride 8/16 setting for the DeepLab series.

View File

@ -11,6 +11,7 @@ Collections:
- Potsdam
- Vaihingen
- iSAID
- Mapillary Vistas v1.2
Paper:
URL: https://arxiv.org/abs/1802.02611
Title: Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation
@ -848,3 +849,24 @@ Models:
mIoU(ms+flip): 68.02
Config: configs/deeplabv3plus/deeplabv3plus_r50-d8_4xb4-80k_isaid-896x896.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r50-d8_4x4_896x896_80k_isaid/deeplabv3plus_r50-d8_4x4_896x896_80k_isaid_20220110_180526-598be439.pth
- Name: deeplabv3plus_r50-d8_4xb2-300k_mapillay_v1_65-1280x1280
In Collection: DeepLabV3+
Metadata:
backbone: R-50-D8
crop size: (1280,1280)
lr schd: 300000
inference time (ms/im):
- value: 55.8
hardware: V100
backend: PyTorch
batch size: 1
mode: FP32
resolution: (1280,1280)
Training Memory (GB): 24.04
Results:
- Task: Semantic Segmentation
Dataset: Mapillary Vistas v1.2
Metrics:
mIoU: 47.35
Config: configs/deeplabv3plus/deeplabv3plus_r50-d8_4xb2-300k_mapillay_v1_65-1280x1280.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r50-d8_4xb2-300k_mapillay_v1_65-1280x1280/deeplabv3plus_r50-d8_4xb2-300k_mapillay_v1_65-1280x1280_20230301_110504-655f8e43.pth

View File

@ -0,0 +1,58 @@
_base_ = [
'../_base_/models/deeplabv3plus_r50-d8.py',
'../_base_/datasets/mapillary_v1_65.py',
'../_base_/default_runtime.py',
]
crop_size = (1280, 1280)
data_preprocessor = dict(size=crop_size)
model = dict(
data_preprocessor=data_preprocessor,
pretrained='open-mmlab://resnet50_v1c',
backbone=dict(depth=50),
decode_head=dict(num_classes=65),
auxiliary_head=dict(num_classes=65))
iters = 300000
# optimizer
optimizer = dict(
type='AdamW', lr=0.0001, betas=(0.9, 0.999), weight_decay=0.0001)
# optimizer
optim_wrapper = dict(
type='OptimWrapper',
optimizer=optimizer,
clip_grad=dict(max_norm=0.01, norm_type=2),
paramwise_cfg=dict(
custom_keys={'backbone': dict(lr_mult=0.1, decay_mult=1.0)}))
param_scheduler = [
dict(
type='PolyLR',
eta_min=0,
power=0.9,
begin=0,
end=iters,
by_epoch=False)
]
# training schedule for 300k
train_cfg = dict(
type='IterBasedTrainLoop', max_iters=iters, val_interval=iters // 10)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
default_hooks = dict(
timer=dict(type='IterTimerHook'),
logger=dict(type='LoggerHook', interval=50, log_metric_by_epoch=False),
param_scheduler=dict(type='ParamSchedulerHook'),
checkpoint=dict(
type='CheckpointHook', by_epoch=False, interval=iters // 10),
sampler_seed=dict(type='DistSamplerSeedHook'),
visualization=dict(type='SegVisualizationHook'))
train_dataloader = dict(batch_size=2)
# Default setting for scaling LR automatically
# - `enable` means enable scaling LR automatically
# or not by default.
# - `base_batch_size` = (4 GPUs) x (2 samples per GPU).
auto_scale_lr = dict(enable=False, base_batch_size=8)

View File

@ -154,6 +154,29 @@ mmsegmentation
│ │ │ ├── training
│ │ │ ├── validation
│ │ │ ├── test
│ ├── mapillary
│ │ ├── training
│ │ │ ├── images
│ │ │ ├── v1.2
| │ │ │ ├── instances
| │ │ │ ├── labels
| │   │   │ └── panoptic
│ │ │ ├── v2.0
| │ │ │ ├── instances
| │ │ │ ├── labels
| │ │ │ ├── panoptic
| │   │   │ └── polygons
│ │ ├── validation
│ │ │ ├── images
| │ │ ├── v1.2
| │ │ │ ├── instances
| │ │ │ ├── labels
| │   │   │ └── panoptic
│ │ │ ├── v2.0
| │ │ │ ├── instances
| │ │ │ ├── labels
| │ │ │ ├── panoptic
| │   │   │ └── polygons
```
### Cityscapes
@ -551,3 +574,54 @@ The script will make directory structure below:
```
It includes 400 images for training, 400 images for validation and 400 images for testing which is the same as REFUGE 2018 dataset.
## Mapillary Vistas Datasets
- The dataset can be downloaded [here](https://www.mapillary.com/dataset/vistas) after registration.
- The Mapillary Vistas Dataset stores labels as 8-bit images with a color palette, so no conversion is required (a quick check is sketched at the end of this section).
- Assuming you have put the dataset zip file in `mmsegmentation/data/mapillary`, run the following commands to unzip the dataset.
```bash
cd data/mapillary
unzip An-ZjB1Zm61yAZG0ozTymz8I8NqI4x0MrYrh26dq7kPgfu8vf9ImrdaOAVOFYbJ2pNAgUnVGBmbue9lTgdBOb5BbKXIpFs0fpYWqACbrQDChAA2fdX0zS9PcHu7fY8c-FOvyBVxPNYNFQuM.zip
```
- After unzipping, you will get the Mapillary Vistas Dataset in the structure below. The semantic segmentation mask labels are in the `labels` folders.
```none
mmsegmentation
├── mmseg
├── tools
├── configs
├── data
│ ├── mapillary
│ │ ├── training
│ │ │ ├── images
│ │ │ ├── v1.2
| │ │ │ ├── instances
| │ │ │ ├── labels
| │   │   │ └── panoptic
│ │ │ ├── v2.0
| │ │ │ ├── instances
| │ │ │ ├── labels
| │ │ │ ├── panoptic
| │   │   │ └── polygons
│ │ ├── validation
│ │ │ ├── images
| │ │ ├── v1.2
| │ │ │ ├── instances
| │ │ │ ├── labels
| │   │   │ └── panoptic
│ │ │ ├── v2.0
| │ │ │ ├── instances
| │ │ │ ├── labels
| │ │ │ ├── panoptic
| │   │   │ └── polygons
```
- You can select the dataset version by setting the dataset type to `MapillaryDataset_v1` or `MapillaryDataset_v2` in your configs.
View the Mapillary Vistas dataset config files here: [V1.2](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/_base_/datasets/mapillary_v1.py) and [V2.0](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/_base_/datasets/mapillary_v2.py).
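
A quick way to verify that the labels need no conversion (a minimal sketch, assuming Pillow and NumPy are available; replace `<image_id>` with a real file name):

```python
# Sketch: Mapillary labels are 8-bit palettized PNGs whose pixel values are
# already class indices, so they can be loaded directly as segmentation maps.
import numpy as np
from PIL import Image

label = Image.open('data/mapillary/training/v1.2/labels/<image_id>.png')
print(label.mode)                  # expected 'P' (8-bit palette mode)
print(np.unique(np.array(label)))  # class indices, e.g. values in [0, 65] for v1.2
```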

View File

@ -14,6 +14,7 @@ from .isaid import iSAIDDataset
from .isprs import ISPRSDataset
from .lip import LIPDataset
from .loveda import LoveDADataset
from .mapillary import MapillaryDataset_v1, MapillaryDataset_v2
from .night_driving import NightDrivingDataset
from .pascal_context import PascalContextDataset, PascalContextDataset59
from .potsdam import PotsdamDataset
@ -49,5 +50,6 @@ __all__ = [
'DecathlonDataset', 'LIPDataset', 'ResizeShortestEdge',
'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur',
'BioMedicalRandomGamma', 'BioMedical3DPad', 'RandomRotFlip',
'SynapseDataset', 'REFUGEDataset'
'SynapseDataset', 'REFUGEDataset', 'MapillaryDataset_v1',
'MapillaryDataset_v2'
]

View File

@ -1,10 +1,72 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmseg.datasets.basesegdataset import BaseSegDataset
from mmseg.registry import DATASETS
from .basesegdataset import BaseSegDataset
@DATASETS.register_module()
class MapillaryDataset_v2_0(BaseSegDataset):
class MapillaryDataset_v1(BaseSegDataset):
"""Mapillary Vistas Dataset.
Dataset paper link:
http://ieeexplore.ieee.org/document/8237796/
v1.2 contain 66 object classes.
(37 instance-specific)
v2.0 contain 124 object classes.
(70 instance-specific, 46 stuff, 8 void or crowd).
The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is
fixed to '.png' for Mapillary Vistas Dataset.
"""
METAINFO = dict(
classes=('Bird', 'Ground Animal', 'Curb', 'Fence', 'Guard Rail',
'Barrier', 'Wall', 'Bike Lane', 'Crosswalk - Plain',
'Curb Cut', 'Parking', 'Pedestrian Area', 'Rail Track',
'Road', 'Service Lane', 'Sidewalk', 'Bridge', 'Building',
'Tunnel', 'Person', 'Bicyclist', 'Motorcyclist',
'Other Rider', 'Lane Marking - Crosswalk',
'Lane Marking - General', 'Mountain', 'Sand', 'Sky', 'Snow',
'Terrain', 'Vegetation', 'Water', 'Banner', 'Bench',
'Bike Rack', 'Billboard', 'Catch Basin', 'CCTV Camera',
'Fire Hydrant', 'Junction Box', 'Mailbox', 'Manhole',
'Phone Booth', 'Pothole', 'Street Light', 'Pole',
'Traffic Sign Frame', 'Utility Pole', 'Traffic Light',
'Traffic Sign (Back)', 'Traffic Sign (Front)', 'Trash Can',
'Bicycle', 'Boat', 'Bus', 'Car', 'Caravan', 'Motorcycle',
'On Rails', 'Other Vehicle', 'Trailer', 'Truck',
'Wheeled Slow', 'Car Mount', 'Ego Vehicle', 'Unlabeled'),
palette=[[165, 42, 42], [0, 192, 0], [196, 196, 196], [190, 153, 153],
[180, 165, 180], [90, 120, 150], [102, 102, 156],
[128, 64, 255], [140, 140, 200], [170, 170, 170],
[250, 170, 160], [96, 96, 96],
[230, 150, 140], [128, 64, 128], [110, 110, 110],
[244, 35, 232], [150, 100, 100], [70, 70, 70], [150, 120, 90],
[220, 20, 60], [255, 0, 0], [255, 0, 100], [255, 0, 200],
[200, 128, 128], [255, 255, 255], [64, 170,
64], [230, 160, 50],
[70, 130, 180], [190, 255, 255], [152, 251, 152],
[107, 142, 35], [0, 170, 30], [255, 255, 128], [250, 0, 30],
[100, 140, 180], [220, 220, 220], [220, 128, 128],
[222, 40, 40], [100, 170, 30], [40, 40, 40], [33, 33, 33],
[100, 128, 160], [142, 0, 0], [70, 100, 150], [210, 170, 100],
[153, 153, 153], [128, 128, 128], [0, 0, 80], [250, 170, 30],
[192, 192, 192], [220, 220, 0], [140, 140, 20], [119, 11, 32],
[150, 0, 255], [0, 60, 100], [0, 0, 142], [0, 0, 90],
[0, 0, 230], [0, 80, 100], [128, 64, 64], [0, 0, 110],
[0, 0, 70], [0, 0, 192], [32, 32, 32], [120, 10,
10], [0, 0, 0]])
def __init__(self,
img_suffix='.jpg',
seg_map_suffix='.png',
**kwargs) -> None:
super().__init__(
img_suffix=img_suffix, seg_map_suffix=seg_map_suffix, **kwargs)
@DATASETS.register_module()
class MapillaryDataset_v2(BaseSegDataset):
"""Mapillary Vistas Dataset.
Dataset paper link:

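As context for the `@DATASETS.register_module()` decorator above, a minimal sketch of building the new dataset through the registry (assuming mmseg 1.x's `register_all_modules` helper and the data layout described in the docs):

```python
# Sketch only: build the registered dataset class via the MMEngine registry.
from mmseg.registry import DATASETS
from mmseg.utils import register_all_modules

register_all_modules()  # registers MapillaryDataset_v1 / MapillaryDataset_v2
dataset = DATASETS.build(
    dict(
        type='MapillaryDataset_v1',
        data_root='data/mapillary/',
        data_prefix=dict(
            img_path='training/images',
            seg_map_path='training/v1.2/labels'),
        pipeline=[]))
print(len(dataset))                     # number of training samples found
print(dataset.metainfo['classes'][:5])  # first few class names
```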
View File

@ -126,6 +126,126 @@ def stare_classes():
return ['background', 'vessel']
def mapillary_v1_classes():
"""mapillary_v1 class names for external use."""
return [
'Bird', 'Ground Animal', 'Curb', 'Fence', 'Guard Rail', 'Barrier',
'Wall', 'Bike Lane', 'Crosswalk - Plain', 'Curb Cut', 'Parking',
'Pedestrian Area', 'Rail Track', 'Road', 'Service Lane', 'Sidewalk',
'Bridge', 'Building', 'Tunnel', 'Person', 'Bicyclist', 'Motorcyclist',
'Other Rider', 'Lane Marking - Crosswalk', 'Lane Marking - General',
'Mountain', 'Sand', 'Sky', 'Snow', 'Terrain', 'Vegetation', 'Water',
'Banner', 'Bench', 'Bike Rack', 'Billboard', 'Catch Basin',
'CCTV Camera', 'Fire Hydrant', 'Junction Box', 'Mailbox', 'Manhole',
'Phone Booth', 'Pothole', 'Street Light', 'Pole', 'Traffic Sign Frame',
'Utility Pole', 'Traffic Light', 'Traffic Sign (Back)',
'Traffic Sign (Front)', 'Trash Can', 'Bicycle', 'Boat', 'Bus', 'Car',
'Caravan', 'Motorcycle', 'On Rails', 'Other Vehicle', 'Trailer',
'Truck', 'Wheeled Slow', 'Car Mount', 'Ego Vehicle', 'Unlabeled'
]
def mapillary_v1_palette():
"""mapillary_v1_ palette for external use."""
return [[165, 42, 42], [0, 192, 0], [196, 196, 196], [190, 153, 153],
[180, 165, 180], [90, 120, 150], [102, 102, 156], [128, 64, 255],
[140, 140, 200], [170, 170, 170], [250, 170, 160], [96, 96, 96],
[230, 150, 140], [128, 64, 128], [110, 110, 110], [244, 35, 232],
[150, 100, 100], [70, 70, 70], [150, 120, 90], [220, 20, 60],
[255, 0, 0], [255, 0, 100], [255, 0, 200], [200, 128, 128],
[255, 255, 255], [64, 170, 64], [230, 160, 50], [70, 130, 180],
[190, 255, 255], [152, 251, 152], [107, 142, 35], [0, 170, 30],
[255, 255, 128], [250, 0, 30], [100, 140, 180], [220, 220, 220],
[220, 128, 128], [222, 40, 40], [100, 170, 30], [40, 40, 40],
[33, 33, 33], [100, 128, 160], [142, 0, 0], [70, 100, 150],
[210, 170, 100], [153, 153, 153], [128, 128, 128], [0, 0, 80],
[250, 170, 30], [192, 192, 192], [220, 220, 0], [140, 140, 20],
[119, 11, 32], [150, 0, 255], [0, 60, 100], [0, 0, 142],
[0, 0, 90], [0, 0, 230], [0, 80, 100], [128, 64, 64], [0, 0, 110],
[0, 0, 70], [0, 0, 192], [32, 32, 32], [120, 10, 10], [0, 0, 0]]
def mapillary_v2_classes():
"""mapillary_v2 class names for external use."""
return [
'Bird', 'Ground Animal', 'Ambiguous Barrier', 'Concrete Block', 'Curb',
'Fence', 'Guard Rail', 'Barrier', 'Road Median', 'Road Side',
'Lane Separator', 'Temporary Barrier', 'Wall', 'Bike Lane',
'Crosswalk - Plain', 'Curb Cut', 'Driveway', 'Parking',
'Parking Aisle', 'Pedestrian Area', 'Rail Track', 'Road',
'Road Shoulder', 'Service Lane', 'Sidewalk', 'Traffic Island',
'Bridge', 'Building', 'Garage', 'Tunnel', 'Person', 'Person Group',
'Bicyclist', 'Motorcyclist', 'Other Rider',
'Lane Marking - Dashed Line', 'Lane Marking - Straight Line',
'Lane Marking - Zigzag Line', 'Lane Marking - Ambiguous',
'Lane Marking - Arrow (Left)', 'Lane Marking - Arrow (Other)',
'Lane Marking - Arrow (Right)',
'Lane Marking - Arrow (Split Left or Straight)',
'Lane Marking - Arrow (Split Right or Straight)',
'Lane Marking - Arrow (Straight)', 'Lane Marking - Crosswalk',
'Lane Marking - Give Way (Row)', 'Lane Marking - Give Way (Single)',
'Lane Marking - Hatched (Chevron)',
'Lane Marking - Hatched (Diagonal)', 'Lane Marking - Other',
'Lane Marking - Stop Line', 'Lane Marking - Symbol (Bicycle)',
'Lane Marking - Symbol (Other)', 'Lane Marking - Text',
'Lane Marking (only) - Dashed Line', 'Lane Marking (only) - Crosswalk',
'Lane Marking (only) - Other', 'Lane Marking (only) - Test',
'Mountain', 'Sand', 'Sky', 'Snow', 'Terrain', 'Vegetation', 'Water',
'Banner', 'Bench', 'Bike Rack', 'Catch Basin', 'CCTV Camera',
'Fire Hydrant', 'Junction Box', 'Mailbox', 'Manhole', 'Parking Meter',
'Phone Booth', 'Pothole', 'Signage - Advertisement',
'Signage - Ambiguous', 'Signage - Back', 'Signage - Information',
'Signage - Other', 'Signage - Store', 'Street Light', 'Pole',
'Pole Group', 'Traffic Sign Frame', 'Utility Pole', 'Traffic Cone',
'Traffic Light - General (Single)', 'Traffic Light - Pedestrians',
'Traffic Light - General (Upright)',
'Traffic Light - General (Horizontal)', 'Traffic Light - Cyclists',
'Traffic Light - Other', 'Traffic Sign - Ambiguous',
'Traffic Sign (Back)', 'Traffic Sign - Direction (Back)',
'Traffic Sign - Direction (Front)', 'Traffic Sign (Front)',
'Traffic Sign - Parking', 'Traffic Sign - Temporary (Back)',
'Traffic Sign - Temporary (Front)', 'Trash Can', 'Bicycle', 'Boat',
'Bus', 'Car', 'Caravan', 'Motorcycle', 'On Rails', 'Other Vehicle',
'Trailer', 'Truck', 'Vehicle Group', 'Wheeled Slow', 'Water Valve',
'Car Mount', 'Dynamic', 'Ego Vehicle', 'Ground', 'Static', 'Unlabeled'
]
def mapillary_v2_palette():
"""mapillary_v2_ palette for external use."""
return [[165, 42, 42], [0, 192, 0], [250, 170, 31], [250, 170, 32],
[196, 196, 196], [190, 153, 153], [180, 165, 180], [90, 120, 150],
[250, 170, 33], [250, 170, 34], [128, 128, 128], [250, 170, 35],
[102, 102, 156], [128, 64, 255], [140, 140, 200], [170, 170, 170],
[250, 170, 36], [250, 170, 160], [250, 170, 37], [96, 96, 96],
[230, 150, 140], [128, 64, 128], [110, 110, 110], [110, 110, 110],
[244, 35, 232], [128, 196, 128], [150, 100, 100], [70, 70, 70],
[150, 150, 150], [150, 120, 90], [220, 20, 60], [220, 20, 60],
[255, 0, 0], [255, 0, 100], [255, 0, 200], [255, 255, 255],
[255, 255, 255], [250, 170, 29], [250, 170, 28], [250, 170, 26],
[250, 170, 25], [250, 170, 24], [250, 170, 22], [250, 170, 21],
[250, 170, 20], [255, 255, 255], [250, 170, 19], [250, 170, 18],
[250, 170, 12], [250, 170, 11], [255, 255, 255], [255, 255, 255],
[250, 170, 16], [250, 170, 15], [250, 170, 15], [255, 255, 255],
[255, 255, 255], [255, 255, 255], [255, 255, 255], [64, 170, 64],
[230, 160, 50], [70, 130, 180], [190, 255, 255], [152, 251, 152],
[107, 142, 35], [0, 170, 30], [255, 255, 128], [250, 0, 30],
[100, 140, 180], [220, 128, 128], [222, 40, 40], [100, 170, 30],
[40, 40, 40], [33, 33, 33], [100, 128, 160], [20, 20, 255],
[142, 0, 0], [70, 100, 150], [250, 171, 30], [250, 172, 30],
[250, 173, 30], [250, 174, 30], [250, 175, 30], [250, 176, 30],
[210, 170, 100], [153, 153, 153], [153, 153, 153], [128, 128, 128],
[0, 0, 80], [210, 60, 60], [250, 170, 30], [250, 170, 30],
[250, 170, 30], [250, 170, 30], [250, 170, 30], [250, 170, 30],
[192, 192, 192], [192, 192, 192], [192, 192, 192], [220, 220, 0],
[220, 220, 0], [0, 0, 196], [192, 192, 192], [220, 220, 0],
[140, 140, 20], [119, 11, 32], [150, 0, 255], [0, 60, 100],
[0, 0, 142], [0, 0, 90], [0, 0, 230], [0, 80, 100], [128, 64, 64],
[0, 0, 110], [0, 0, 70], [0, 0, 142], [0, 0, 192], [170, 170, 170],
[32, 32, 32], [111, 74, 0], [120, 10, 10], [81, 0, 81],
[111, 111, 0], [0, 0, 0]]
def cityscapes_palette():
"""Cityscapes palette for external use."""
return [[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156],
@ -313,7 +433,9 @@ dataset_aliases = {
],
'isaid': ['isaid', 'iSAID'],
'stare': ['stare', 'STARE'],
'lip': ['LIP', 'lip']
'lip': ['LIP', 'lip'],
'mapillary_v1': ['mapillary_v1'],
'mapillary_v2': ['mapillary_v2']
}
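
These aliases let the class/palette helpers in the same module resolve the new dataset names. A minimal usage sketch (assuming `get_classes`/`get_palette` are exported from `mmseg.utils` as for the other datasets):

```python
# Sketch: resolve the new aliases through the helpers in class_names.py.
from mmseg.utils import get_classes, get_palette

classes = get_classes('mapillary_v1')   # dispatches to mapillary_v1_classes()
palette = get_palette('mapillary_v1')   # dispatches to mapillary_v1_palette()
assert len(classes) == len(palette) == 66  # v1.2 has 66 classes incl. 'Unlabeled'
```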

View File

@ -34,6 +34,7 @@ Preparing `Mapillary Vistas Dataset` dataset following [Mapillary Vistas Dataset
| │   │   │ └── polygons
│ │ ├── validation
│ │ │ ├── images
│ │ │ ├── v1.2
| │ │ │ ├── instances
| │ │ │ ├── labels
| │ │ │ ├── labels_mask
@ -46,12 +47,12 @@ Preparing `Mapillary Vistas Dataset` dataset following [Mapillary Vistas Dataset
| │   │   │ └── polygons
```
### Training commands with `deeplabv3plus_r101-d8_4xb2-240k_mapillay-512x1024.py`
### Training commands
```bash
# Dataset train commands
# at `mmsegmentation` folder
bash tools/dist_train.sh projects/mapillary_dataset/configs/deeplabv3plus_r101-d8_4xb2-240k_mapillay-512x1024.py 4
bash tools/dist_train.sh projects/mapillary_dataset/configs/deeplabv3plus_r101-d8_4xb2-240k_mapillay_v1-512x1024.py 4
```
## Checklist
@ -66,20 +67,20 @@ bash tools/dist_train.sh projects/mapillary_dataset/configs/deeplabv3plus_r101-d
- [x] A full README
- [ ] Milestone 2: Indicates a successful model implementation.
- [x] Milestone 2: Indicates a successful model implementation.
- [ ] Training-time correctness
- [x] Training-time correctness
- [ ] Milestone 3: Good to be a part of our core package!
- [x] Milestone 3: Good to be a part of our core package!
- [ ] Type hints and docstrings
- [x] Type hints and docstrings
- [ ] Unit tests
- [x] Unit tests
- [ ] Code polishing
- [x] Code polishing
- [ ] Metafile.yml
- [x] Metafile.yml
- [ ] Move your modules into the core package following the codebase's file hierarchy structure.
- [x] Move your modules into the core package following the codebase's file hierarchy structure.
- [ ] Refactor your modules into the core package following the codebase's file hierarchy structure.
- [x] Refactor your modules into the core package following the codebase's file hierarchy structure.

View File

@ -0,0 +1,68 @@
# dataset settings
dataset_type = 'MapillaryDataset_v1'
data_root = 'data/mapillary/'
crop_size = (512, 1024)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations'),
dict(
type='RandomResize',
scale=(2048, 1024),
ratio_range=(0.5, 2.0),
keep_ratio=True),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='PackSegInputs')
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='Resize', scale=(2048, 1024), keep_ratio=True),
# add loading annotation after ``Resize`` because ground truth
# does not need to do resize data transform
dict(type='LoadAnnotations'),
dict(type='PackSegInputs')
]
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
dict(
type='TestTimeAug',
transforms=[
[
dict(type='Resize', scale_factor=r, keep_ratio=True)
for r in img_ratios
],
[
dict(type='RandomFlip', prob=0., direction='horizontal'),
dict(type='RandomFlip', prob=1., direction='horizontal')
], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
])
]
train_dataloader = dict(
batch_size=2,
num_workers=4,
persistent_workers=True,
sampler=dict(type='InfiniteSampler', shuffle=True),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(
img_path='training/images', seg_map_path='training/v1.2/labels'),
pipeline=train_pipeline))
val_dataloader = dict(
batch_size=1,
num_workers=4,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(
img_path='validation/images',
seg_map_path='validation/v1.2/labels'),
pipeline=test_pipeline))
test_dataloader = val_dataloader
val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
test_evaluator = val_evaluator

View File

@ -0,0 +1,37 @@
# dataset settings
_base_ = './mapillary_v1.py'
metainfo = dict(
classes=('Bird', 'Ground Animal', 'Curb', 'Fence', 'Guard Rail', 'Barrier',
'Wall', 'Bike Lane', 'Crosswalk - Plain', 'Curb Cut', 'Parking',
'Pedestrian Area', 'Rail Track', 'Road', 'Service Lane',
'Sidewalk', 'Bridge', 'Building', 'Tunnel', 'Person', 'Bicyclist',
'Motorcyclist', 'Other Rider', 'Lane Marking - Crosswalk',
'Lane Marking - General', 'Mountain', 'Sand', 'Sky', 'Snow',
'Terrain', 'Vegetation', 'Water', 'Banner', 'Bench', 'Bike Rack',
'Billboard', 'Catch Basin', 'CCTV Camera', 'Fire Hydrant',
'Junction Box', 'Mailbox', 'Manhole', 'Phone Booth', 'Pothole',
'Street Light', 'Pole', 'Traffic Sign Frame', 'Utility Pole',
'Traffic Light', 'Traffic Sign (Back)', 'Traffic Sign (Front)',
'Trash Can', 'Bicycle', 'Boat', 'Bus', 'Car', 'Caravan',
'Motorcycle', 'On Rails', 'Other Vehicle', 'Trailer', 'Truck',
'Wheeled Slow', 'Car Mount', 'Ego Vehicle'),
palette=[[165, 42, 42], [0, 192, 0], [196, 196, 196], [190, 153, 153],
[180, 165, 180], [90, 120, 150], [102, 102, 156], [128, 64, 255],
[140, 140, 200], [170, 170, 170], [250, 170, 160], [96, 96, 96],
[230, 150, 140], [128, 64, 128], [110, 110, 110], [244, 35, 232],
[150, 100, 100], [70, 70, 70], [150, 120, 90], [220, 20, 60],
[255, 0, 0], [255, 0, 100], [255, 0, 200], [200, 128, 128],
[255, 255, 255], [64, 170, 64], [230, 160, 50], [70, 130, 180],
[190, 255, 255], [152, 251, 152], [107, 142, 35], [0, 170, 30],
[255, 255, 128], [250, 0, 30], [100, 140, 180], [220, 220, 220],
[220, 128, 128], [222, 40, 40], [100, 170, 30], [40, 40, 40],
[33, 33, 33], [100, 128, 160], [142, 0, 0], [70, 100, 150],
[210, 170, 100], [153, 153, 153], [128, 128, 128], [0, 0, 80],
[250, 170, 30], [192, 192, 192], [220, 220, 0], [140, 140, 20],
[119, 11, 32], [150, 0, 255], [0, 60, 100], [0, 0, 142],
[0, 0, 90], [0, 0, 230], [0, 80, 100], [128, 64, 64], [0, 0, 110],
[0, 0, 70], [0, 0, 192], [32, 32, 32], [120, 10, 10]])
train_dataloader = dict(dataset=dict(metainfo=metainfo))
val_dataloader = dict(dataset=dict(metainfo=metainfo))
test_dataloader = val_dataloader

View File

@ -0,0 +1,68 @@
# dataset settings
dataset_type = 'MapillaryDataset_v2'
data_root = 'data/mapillary/'
crop_size = (512, 1024)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations'),
dict(
type='RandomResize',
scale=(2048, 1024),
ratio_range=(0.5, 2.0),
keep_ratio=True),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='PackSegInputs')
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='Resize', scale=(2048, 1024), keep_ratio=True),
# add loading annotation after ``Resize`` because ground truth
# does not need to do resize data transform
dict(type='LoadAnnotations'),
dict(type='PackSegInputs')
]
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
tta_pipeline = [
dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
dict(
type='TestTimeAug',
transforms=[
[
dict(type='Resize', scale_factor=r, keep_ratio=True)
for r in img_ratios
],
[
dict(type='RandomFlip', prob=0., direction='horizontal'),
dict(type='RandomFlip', prob=1., direction='horizontal')
], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
])
]
train_dataloader = dict(
batch_size=2,
num_workers=4,
persistent_workers=True,
sampler=dict(type='InfiniteSampler', shuffle=True),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(
img_path='training/images', seg_map_path='training/v2.0/labels'),
pipeline=train_pipeline))
val_dataloader = dict(
batch_size=1,
num_workers=4,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(
img_path='validation/images',
seg_map_path='validation/v2.0/labels'),
pipeline=test_pipeline))
test_dataloader = val_dataloader
val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
test_evaluator = val_evaluator

View File

@ -1,103 +0,0 @@
_base_ = ['./_base_/datasets/mapillary_v1_2.py'] # v 1.2 labels
# _base_ = ['./_base_/datasets/mapillary_v2_0.py'] # v2.0 labels
custom_imports = dict(imports=[
'projects.mapillary_dataset.mmseg.datasets.mapillary_v1_2',
'projects.mapillary_dataset.mmseg.datasets.mapillary_v2_0',
])
norm_cfg = dict(type='SyncBN', requires_grad=True)
data_preprocessor = dict(
type='SegDataPreProcessor',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
bgr_to_rgb=True,
pad_val=0,
seg_pad_val=255,
size=(512, 1024))
model = dict(
type='EncoderDecoder',
data_preprocessor=data_preprocessor,
pretrained=None,
backbone=dict(
type='ResNet',
depth=101,
num_stages=4,
out_indices=(0, 1, 2, 3),
dilations=(1, 1, 2, 4),
strides=(1, 2, 1, 1),
norm_cfg=norm_cfg,
norm_eval=False,
style='pytorch',
contract_dilation=True),
decode_head=dict(
type='DepthwiseSeparableASPPHead',
in_channels=2048,
in_index=3,
channels=512,
dilations=(1, 12, 24, 36),
c1_in_channels=256,
c1_channels=48,
dropout_ratio=0.1,
num_classes=66, # v1.2
# num_classes=124, # v2.0
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
auxiliary_head=dict(
type='FCNHead',
in_channels=1024,
in_index=2,
channels=256,
num_convs=1,
concat_input=False,
dropout_ratio=0.1,
num_classes=66, # v1.2
# num_classes=124, # v2.0
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
train_cfg=dict(),
test_cfg=dict(mode='whole'))
default_scope = 'mmseg'
env_cfg = dict(
cudnn_benchmark=True,
mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
dist_cfg=dict(backend='nccl'))
vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
type='SegLocalVisualizer',
vis_backends=[dict(type='LocalVisBackend')],
name='visualizer')
log_processor = dict(by_epoch=False)
log_level = 'INFO'
load_from = None
resume = False
tta_model = dict(type='SegTTAModel')
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
optim_wrapper = dict(
type='OptimWrapper',
optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001),
clip_grad=None)
param_scheduler = [
dict(
type='PolyLR',
eta_min=0.0001,
power=0.9,
begin=0,
end=240000,
by_epoch=False)
]
train_cfg = dict(
type='IterBasedTrainLoop', max_iters=240000, val_interval=24000)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
default_hooks = dict(
timer=dict(type='IterTimerHook'),
logger=dict(type='LoggerHook', interval=50, log_metric_by_epoch=False),
param_scheduler=dict(type='ParamSchedulerHook'),
checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=24000),
sampler_seed=dict(type='DistSamplerSeedHook'),
visualization=dict(type='SegVisualizationHook'))

View File

@ -0,0 +1,17 @@
_base_ = [
'../../../configs/_base_/models/deeplabv3plus_r50-d8.py',
'./_base_/datasets/mapillary_v1.py',
'../../../configs/_base_/default_runtime.py',
'../../../configs/_base_/schedules/schedule_240k.py'
]
custom_imports = dict(
imports=['projects.mapillary_dataset.mmseg.datasets.mapillary'])
crop_size = (512, 1024)
data_preprocessor = dict(size=crop_size)
model = dict(
data_preprocessor=data_preprocessor,
pretrained='open-mmlab://resnet101_v1c',
backbone=dict(depth=101),
decode_head=dict(num_classes=66),
auxiliary_head=dict(num_classes=66))

View File

@ -0,0 +1,16 @@
_base_ = [
'../../../configs/_base_/models/deeplabv3plus_r50-d8.py',
'./_base_/datasets/mapillary_v2.py',
'../../../configs/_base_/default_runtime.py',
'../../../configs/_base_/schedules/schedule_240k.py'
]
custom_imports = dict(
imports=['projects.mapillary_dataset.mmseg.datasets.mapillary'])
crop_size = (512, 1024)
data_preprocessor = dict(size=crop_size)
model = dict(
data_preprocessor=data_preprocessor,
pretrained='open-mmlab://resnet101_v1c',
backbone=dict(depth=101),
decode_head=dict(num_classes=124),
auxiliary_head=dict(num_classes=124))

View File

@ -0,0 +1,16 @@
_base_ = [
'../../../configs/_base_/models/pspnet_r50-d8.py',
'./_base_/datasets/mapillary_v1.py',
'../../../configs/_base_/default_runtime.py',
'../../../configs/_base_/schedules/schedule_240k.py'
]
custom_imports = dict(
imports=['projects.mapillary_dataset.mmseg.datasets.mapillary'])
crop_size = (512, 1024)
data_preprocessor = dict(size=crop_size)
model = dict(
data_preprocessor=data_preprocessor,
pretrained='open-mmlab://resnet101_v1c',
backbone=dict(depth=101),
decode_head=dict(num_classes=66),
auxiliary_head=dict(num_classes=66))

View File

@ -0,0 +1,16 @@
_base_ = [
'../../../configs/_base_/models/pspnet_r50-d8.py',
'./_base_/datasets/mapillary_v2.py',
'../../../configs/_base_/default_runtime.py',
'../../../configs/_base_/schedules/schedule_240k.py'
]
custom_imports = dict(
imports=['projects.mapillary_dataset.mmseg.datasets.mapillary'])
crop_size = (512, 1024)
data_preprocessor = dict(size=crop_size)
model = dict(
data_preprocessor=data_preprocessor,
pretrained='open-mmlab://resnet101_v1c',
backbone=dict(depth=101),
decode_head=dict(num_classes=124),
auxiliary_head=dict(num_classes=124))

View File

@ -1,87 +1,20 @@
## Prepare datasets
It is recommended to symlink the dataset root to `$MMSEGMENTATION/data`.
If your folder structure is different, you may need to change the corresponding paths in config files.
```none
mmsegmentation
├── mmseg
├── tools
├── configs
├── data
│ ├── mapillary
│ │ ├── training
│ │ │ ├── images
│ │ │ ├── v1.2
| │ │ │ ├── instances
| │ │ │ ├── labels
| │ │ │ ├── labels_mask
| │   │   │ └── panoptic
│ │ │ ├── v2.0
| │ │ │ ├── instances
| │ │ │ ├── labels
| │ │ │ ├── labels_mask
| │ │ │ ├── panoptic
| │   │   │ └── polygons
│ │ ├── validation
│ │ │ ├── images
| │ │ │ ├── instances
| │ │ │ ├── labels
| │ │ │ ├── labels_mask
| │   │   │ └── panoptic
│ │ │ ├── v2.0
| │ │ │ ├── instances
| │ │ │ ├── labels
| │ │ │ ├── labels_mask
| │ │ │ ├── panoptic
| │   │   │ └── polygons
```
## Mapillary Vistas Datasets
- The dataset could be download [here](https://www.mapillary.com/dataset/vistas) after registration.
- Assumption you have put the dataset zip file in `mmsegmentation/data`
- Mapillary Vistas Dataset use 8-bit with color-palette to store labels. No conversion operation is required.
- Assumption you have put the dataset zip file in `mmsegmentation/data/mapillary`
- Please run the following commands to unzip dataset.
```bash
cd data
mkdir mapillary
unzip -d mapillary An-ZjB1Zm61yAZG0ozTymz8I8NqI4x0MrYrh26dq7kPgfu8vf9ImrdaOAVOFYbJ2pNAgUnVGBmbue9lTgdBOb5BbKXIpFs0fpYWqACbrQDChAA2fdX0zS9PcHu7fY8c-FOvyBVxPNYNFQuM.zip
cd data/mapillary
unzip An-ZjB1Zm61yAZG0ozTymz8I8NqI4x0MrYrh26dq7kPgfu8vf9ImrdaOAVOFYbJ2pNAgUnVGBmbue9lTgdBOb5BbKXIpFs0fpYWqACbrQDChAA2fdX0zS9PcHu7fY8c-FOvyBVxPNYNFQuM.zip
```
- After unzip, you will get Mapillary Vistas Dataset like this structure.
```none
├── data
│ ├── mapillary
│ │ ├── training
│ │ │ ├── images
│ │ │ ├── v1.2
| │ │ │ ├── instances
| │ │ │ ├── labels
| │   │   │ └── panoptic
│ │ │ ├── v2.0
| │ │ │ ├── instances
| │ │ │ ├── labels
| │ │ │ ├── panoptic
| │   │   │ └── polygons
│ │ ├── validation
│ │ │ ├── images
| │ │ │ ├── instances
| │ │ │ ├── labels
| │   │   │ └── panoptic
│ │ │ ├── v2.0
| │ │ │ ├── instances
| │ │ │ ├── labels
| │ │ │ ├── panoptic
| │   │   │ └── polygons
```
- run following commands to convert RGB labels to mask labels
```bash
# --nproc optional, default 1, whether use multi-progress
# --version optional, 'v1.2', 'v2.0','all', default 'all', choose convert which version labels
# run this command at 'mmsegmentation/projects/Mapillary_dataset' folder
cd mmsegmentation/projects/mapillary_dataset
python tools/dataset_converters/mapillary.py ../../data/mapillary --nproc 8 --version all
```
After then, you will get this structure
- After unzip, you will get Mapillary Vistas Dataset like this structure. Semantic segmentation mask labels in `labels` folder.
```none
mmsegmentation
├── mmseg
@ -94,24 +27,229 @@ mmsegmentation
│ │ │ ├── v1.2
| │ │ │ ├── instances
| │ │ │ ├── labels
| │ │ │ ├── labels_mask
| │   │   │ └── panoptic
│ │ │ ├── v2.0
| │ │ │ ├── instances
| │ │ │ ├── labels
| │ │ │ ├── labels_mask
| │ │ │ ├── panoptic
| │   │   │ └── polygons
│ │ ├── validation
│ │ │ ├── images
| │ │ ├── v1.2
| │ │ │ ├── instances
| │ │ │ ├── labels
| │ │ │ ├── labels_mask
| │   │   │ └── panoptic
│ │ │ ├── v2.0
| │ │ │ ├── instances
| │ │ │ ├── labels
| │ │ │ ├── labels_mask
| │ │ │ ├── panoptic
| │   │   │ └── polygons
```
- You can select the dataset version by setting the dataset type to `MapillaryDataset_v1` or `MapillaryDataset_v2` in your configs.
View the Mapillary Vistas dataset config files here: [V1.2](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/_base_/datasets/mapillary_v1.py) and [V2.0](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/_base_/datasets/mapillary_v2.py).
- **View datasets labels index and palette**
- **Mapillary Vistas Datasets labels information**
**v1.2 information**
```none
There are 66 labels classes in v1.2
0--Bird--[165, 42, 42],
1--Ground Animal--[0, 192, 0],
2--Curb--[196, 196, 196],
3--Fence--[190, 153, 153],
4--Guard Rail--[180, 165, 180],
5--Barrier--[90, 120, 150],
6--Wall--[102, 102, 156],
7--Bike Lane--[128, 64, 255],
8--Crosswalk - Plain--[140, 140, 200],
9--Curb Cut--[170, 170, 170],
10--Parking--[250, 170, 160],
11--Pedestrian Area--[96, 96, 96],
12--Rail Track--[230, 150, 140],
13--Road--[128, 64, 128],
14--Service Lane--[110, 110, 110],
15--Sidewalk--[244, 35, 232],
16--Bridge--[150, 100, 100],
17--Building--[70, 70, 70],
18--Tunnel--[150, 120, 90],
19--Person--[220, 20, 60],
20--Bicyclist--[255, 0, 0],
21--Motorcyclist--[255, 0, 100],
22--Other Rider--[255, 0, 200],
23--Lane Marking - Crosswalk--[200, 128, 128],
24--Lane Marking - General--[255, 255, 255],
25--Mountain--[64, 170, 64],
26--Sand--[230, 160, 50],
27--Sky--[70, 130, 180],
28--Snow--[190, 255, 255],
29--Terrain--[152, 251, 152],
30--Vegetation--[107, 142, 35],
31--Water--[0, 170, 30],
32--Banner--[255, 255, 128],
33--Bench--[250, 0, 30],
34--Bike Rack--[100, 140, 180],
35--Billboard--[220, 220, 220],
36--Catch Basin--[220, 128, 128],
37--CCTV Camera--[222, 40, 40],
38--Fire Hydrant--[100, 170, 30],
39--Junction Box--[40, 40, 40],
40--Mailbox--[33, 33, 33],
41--Manhole--[100, 128, 160],
42--Phone Booth--[142, 0, 0],
43--Pothole--[70, 100, 150],
44--Street Light--[210, 170, 100],
45--Pole--[153, 153, 153],
46--Traffic Sign Frame--[128, 128, 128],
47--Utility Pole--[0, 0, 80],
48--Traffic Light--[250, 170, 30],
49--Traffic Sign (Back)--[192, 192, 192],
50--Traffic Sign (Front)--[220, 220, 0],
51--Trash Can--[140, 140, 20],
52--Bicycle--[119, 11, 32],
53--Boat--[150, 0, 255],
54--Bus--[0, 60, 100],
55--Car--[0, 0, 142],
56--Caravan--[0, 0, 90],
57--Motorcycle--[0, 0, 230],
58--On Rails--[0, 80, 100],
59--Other Vehicle--[128, 64, 64],
60--Trailer--[0, 0, 110],
61--Truck--[0, 0, 70],
62--Wheeled Slow--[0, 0, 192],
63--Car Mount--[32, 32, 32],
64--Ego Vehicle--[120, 10, 10],
65--Unlabeled--[0, 0, 0]
```
**v2.0 information**
```none
There are 124 labels classes in v2.0
0--Bird--[165, 42, 42],
1--Ground Animal--[0, 192, 0],
2--Ambiguous Barrier--[250, 170, 31],
3--Concrete Block--[250, 170, 32],
4--Curb--[196, 196, 196],
5--Fence--[190, 153, 153],
6--Guard Rail--[180, 165, 180],
7--Barrier--[90, 120, 150],
8--Road Median--[250, 170, 33],
9--Road Side--[250, 170, 34],
10--Lane Separator--[128, 128, 128],
11--Temporary Barrier--[250, 170, 35],
12--Wall--[102, 102, 156],
13--Bike Lane--[128, 64, 255],
14--Crosswalk - Plain--[140, 140, 200],
15--Curb Cut--[170, 170, 170],
16--Driveway--[250, 170, 36],
17--Parking--[250, 170, 160],
18--Parking Aisle--[250, 170, 37],
19--Pedestrian Area--[96, 96, 96],
20--Rail Track--[230, 150, 140],
21--Road--[128, 64, 128],
22--Road Shoulder--[110, 110, 110],
23--Service Lane--[110, 110, 110],
24--Sidewalk--[244, 35, 232],
25--Traffic Island--[128, 196, 128],
26--Bridge--[150, 100, 100],
27--Building--[70, 70, 70],
28--Garage--[150, 150, 150],
29--Tunnel--[150, 120, 90],
30--Person--[220, 20, 60],
31--Person Group--[220, 20, 60],
32--Bicyclist--[255, 0, 0],
33--Motorcyclist--[255, 0, 100],
34--Other Rider--[255, 0, 200],
35--Lane Marking - Dashed Line--[255, 255, 255],
36--Lane Marking - Straight Line--[255, 255, 255],
37--Lane Marking - Zigzag Line--[250, 170, 29],
38--Lane Marking - Ambiguous--[250, 170, 28],
39--Lane Marking - Arrow (Left)--[250, 170, 26],
40--Lane Marking - Arrow (Other)--[250, 170, 25],
41--Lane Marking - Arrow (Right)--[250, 170, 24],
42--Lane Marking - Arrow (Split Left or Straight)--[250, 170, 22],
43--Lane Marking - Arrow (Split Right or Straight)--[250, 170, 21],
44--Lane Marking - Arrow (Straight)--[250, 170, 20],
45--Lane Marking - Crosswalk--[255, 255, 255],
46--Lane Marking - Give Way (Row)--[250, 170, 19],
47--Lane Marking - Give Way (Single)--[250, 170, 18],
48--Lane Marking - Hatched (Chevron)--[250, 170, 12],
49--Lane Marking - Hatched (Diagonal)--[250, 170, 11],
50--Lane Marking - Other--[255, 255, 255],
51--Lane Marking - Stop Line--[255, 255, 255],
52--Lane Marking - Symbol (Bicycle)--[250, 170, 16],
53--Lane Marking - Symbol (Other)--[250, 170, 15],
54--Lane Marking - Text--[250, 170, 15],
55--Lane Marking (only) - Dashed Line--[255, 255, 255],
56--Lane Marking (only) - Crosswalk--[255, 255, 255],
57--Lane Marking (only) - Other--[255, 255, 255],
58--Lane Marking (only) - Test--[255, 255, 255],
59--Mountain--[64, 170, 64],
60--Sand--[230, 160, 50],
61--Sky--[70, 130, 180],
62--Snow--[190, 255, 255],
63--Terrain--[152, 251, 152],
64--Vegetation--[107, 142, 35],
65--Water--[0, 170, 30],
66--Banner--[255, 255, 128],
67--Bench--[250, 0, 30],
68--Bike Rack--[100, 140, 180],
69--Catch Basin--[220, 128, 128],
70--CCTV Camera--[222, 40, 40],
71--Fire Hydrant--[100, 170, 30],
72--Junction Box--[40, 40, 40],
73--Mailbox--[33, 33, 33],
74--Manhole--[100, 128, 160],
75--Parking Meter--[20, 20, 255],
76--Phone Booth--[142, 0, 0],
77--Pothole--[70, 100, 150],
78--Signage - Advertisement--[250, 171, 30],
79--Signage - Ambiguous--[250, 172, 30],
80--Signage - Back--[250, 173, 30],
81--Signage - Information--[250, 174, 30],
82--Signage - Other--[250, 175, 30],
83--Signage - Store--[250, 176, 30],
84--Street Light--[210, 170, 100],
85--Pole--[153, 153, 153],
86--Pole Group--[153, 153, 153],
87--Traffic Sign Frame--[128, 128, 128],
88--Utility Pole--[0, 0, 80],
89--Traffic Cone--[210, 60, 60],
90--Traffic Light - General (Single)--[250, 170, 30],
91--Traffic Light - Pedestrians--[250, 170, 30],
92--Traffic Light - General (Upright)--[250, 170, 30],
93--Traffic Light - General (Horizontal)--[250, 170, 30],
94--Traffic Light - Cyclists--[250, 170, 30],
95--Traffic Light - Other--[250, 170, 30],
96--Traffic Sign - Ambiguous--[192, 192, 192],
97--Traffic Sign (Back)--[192, 192, 192],
98--Traffic Sign - Direction (Back)--[192, 192, 192],
99--Traffic Sign - Direction (Front)--[220, 220, 0],
100--Traffic Sign (Front)--[220, 220, 0],
101--Traffic Sign - Parking--[0, 0, 196],
102--Traffic Sign - Temporary (Back)--[192, 192, 192],
103--Traffic Sign - Temporary (Front)--[220, 220, 0],
104--Trash Can--[140, 140, 20],
105--Bicycle--[119, 11, 32],
106--Boat--[150, 0, 255],
107--Bus--[0, 60, 100],
108--Car--[0, 0, 142],
109--Caravan--[0, 0, 90],
110--Motorcycle--[0, 0, 230],
111--On Rails--[0, 80, 100],
112--Other Vehicle--[128, 64, 64],
113--Trailer--[0, 0, 110],
114--Truck--[0, 0, 70],
115--Vehicle Group--[0, 0, 142],
116--Wheeled Slow--[0, 0, 192],
117--Water Valve--[170, 170, 170],
118--Car Mount--[32, 32, 32],
119--Dynamic--[111, 74, 0],
120--Ego Vehicle--[120, 10, 10],
121--Ground--[81, 0, 81],
122--Static--[111, 111, 0],
123--Unlabeled--[0, 0, 0]
```

View File

@ -0,0 +1,177 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmseg.datasets.basesegdataset import BaseSegDataset
# from mmseg.registry import DATASETS
# @DATASETS.register_module()
class MapillaryDataset_v1(BaseSegDataset):
"""Mapillary Vistas Dataset.
Dataset paper link:
http://ieeexplore.ieee.org/document/8237796/
v1.2 contain 66 object classes.
(37 instance-specific)
v2.0 contain 124 object classes.
(70 instance-specific, 46 stuff, 8 void or crowd).
The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is
fixed to '.png' for Mapillary Vistas Dataset.
"""
METAINFO = dict(
classes=('Bird', 'Ground Animal', 'Curb', 'Fence', 'Guard Rail',
'Barrier', 'Wall', 'Bike Lane', 'Crosswalk - Plain',
'Curb Cut', 'Parking', 'Pedestrian Area', 'Rail Track',
'Road', 'Service Lane', 'Sidewalk', 'Bridge', 'Building',
'Tunnel', 'Person', 'Bicyclist', 'Motorcyclist',
'Other Rider', 'Lane Marking - Crosswalk',
'Lane Marking - General', 'Mountain', 'Sand', 'Sky', 'Snow',
'Terrain', 'Vegetation', 'Water', 'Banner', 'Bench',
'Bike Rack', 'Billboard', 'Catch Basin', 'CCTV Camera',
'Fire Hydrant', 'Junction Box', 'Mailbox', 'Manhole',
'Phone Booth', 'Pothole', 'Street Light', 'Pole',
'Traffic Sign Frame', 'Utility Pole', 'Traffic Light',
'Traffic Sign (Back)', 'Traffic Sign (Front)', 'Trash Can',
'Bicycle', 'Boat', 'Bus', 'Car', 'Caravan', 'Motorcycle',
'On Rails', 'Other Vehicle', 'Trailer', 'Truck',
'Wheeled Slow', 'Car Mount', 'Ego Vehicle', 'Unlabeled'),
palette=[[165, 42, 42], [0, 192, 0], [196, 196, 196], [190, 153, 153],
[180, 165, 180], [90, 120, 150], [102, 102, 156],
[128, 64, 255], [140, 140, 200], [170, 170, 170],
[250, 170, 160], [96, 96, 96],
[230, 150, 140], [128, 64, 128], [110, 110, 110],
[244, 35, 232], [150, 100, 100], [70, 70, 70], [150, 120, 90],
[220, 20, 60], [255, 0, 0], [255, 0, 100], [255, 0, 200],
[200, 128, 128], [255, 255, 255], [64, 170,
64], [230, 160, 50],
[70, 130, 180], [190, 255, 255], [152, 251, 152],
[107, 142, 35], [0, 170, 30], [255, 255, 128], [250, 0, 30],
[100, 140, 180], [220, 220, 220], [220, 128, 128],
[222, 40, 40], [100, 170, 30], [40, 40, 40], [33, 33, 33],
[100, 128, 160], [142, 0, 0], [70, 100, 150], [210, 170, 100],
[153, 153, 153], [128, 128, 128], [0, 0, 80], [250, 170, 30],
[192, 192, 192], [220, 220, 0], [140, 140, 20], [119, 11, 32],
[150, 0, 255], [0, 60, 100], [0, 0, 142], [0, 0, 90],
[0, 0, 230], [0, 80, 100], [128, 64, 64], [0, 0, 110],
[0, 0, 70], [0, 0, 192], [32, 32, 32], [120, 10,
10], [0, 0, 0]])
def __init__(self,
img_suffix='.jpg',
seg_map_suffix='.png',
**kwargs) -> None:
super().__init__(
img_suffix=img_suffix, seg_map_suffix=seg_map_suffix, **kwargs)
# @DATASETS.register_module()
class MapillaryDataset_v2(BaseSegDataset):
"""Mapillary Vistas Dataset.
Dataset paper link:
http://ieeexplore.ieee.org/document/8237796/
v1.2 contain 66 object classes.
(37 instance-specific)
v2.0 contain 124 object classes.
(70 instance-specific, 46 stuff, 8 void or crowd).
The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is
fixed to '.png' for Mapillary Vistas Dataset.
"""
METAINFO = dict(
classes=(
'Bird', 'Ground Animal', 'Ambiguous Barrier', 'Concrete Block',
'Curb', 'Fence', 'Guard Rail', 'Barrier', 'Road Median',
'Road Side', 'Lane Separator', 'Temporary Barrier', 'Wall',
'Bike Lane', 'Crosswalk - Plain', 'Curb Cut', 'Driveway',
'Parking', 'Parking Aisle', 'Pedestrian Area', 'Rail Track',
'Road', 'Road Shoulder', 'Service Lane', 'Sidewalk',
'Traffic Island', 'Bridge', 'Building', 'Garage', 'Tunnel',
'Person', 'Person Group', 'Bicyclist', 'Motorcyclist',
'Other Rider', 'Lane Marking - Dashed Line',
'Lane Marking - Straight Line', 'Lane Marking - Zigzag Line',
'Lane Marking - Ambiguous', 'Lane Marking - Arrow (Left)',
'Lane Marking - Arrow (Other)', 'Lane Marking - Arrow (Right)',
'Lane Marking - Arrow (Split Left or Straight)',
'Lane Marking - Arrow (Split Right or Straight)',
'Lane Marking - Arrow (Straight)', 'Lane Marking - Crosswalk',
'Lane Marking - Give Way (Row)',
'Lane Marking - Give Way (Single)',
'Lane Marking - Hatched (Chevron)',
'Lane Marking - Hatched (Diagonal)', 'Lane Marking - Other',
'Lane Marking - Stop Line', 'Lane Marking - Symbol (Bicycle)',
'Lane Marking - Symbol (Other)', 'Lane Marking - Text',
'Lane Marking (only) - Dashed Line',
'Lane Marking (only) - Crosswalk', 'Lane Marking (only) - Other',
'Lane Marking (only) - Test', 'Mountain', 'Sand', 'Sky', 'Snow',
'Terrain', 'Vegetation', 'Water', 'Banner', 'Bench', 'Bike Rack',
'Catch Basin', 'CCTV Camera', 'Fire Hydrant', 'Junction Box',
'Mailbox', 'Manhole', 'Parking Meter', 'Phone Booth', 'Pothole',
'Signage - Advertisement', 'Signage - Ambiguous', 'Signage - Back',
'Signage - Information', 'Signage - Other', 'Signage - Store',
'Street Light', 'Pole', 'Pole Group', 'Traffic Sign Frame',
'Utility Pole', 'Traffic Cone', 'Traffic Light - General (Single)',
'Traffic Light - Pedestrians', 'Traffic Light - General (Upright)',
'Traffic Light - General (Horizontal)', 'Traffic Light - Cyclists',
'Traffic Light - Other', 'Traffic Sign - Ambiguous',
'Traffic Sign (Back)', 'Traffic Sign - Direction (Back)',
'Traffic Sign - Direction (Front)', 'Traffic Sign (Front)',
'Traffic Sign - Parking', 'Traffic Sign - Temporary (Back)',
'Traffic Sign - Temporary (Front)', 'Trash Can', 'Bicycle', 'Boat',
'Bus', 'Car', 'Caravan', 'Motorcycle', 'On Rails', 'Other Vehicle',
'Trailer', 'Truck', 'Vehicle Group', 'Wheeled Slow', 'Water Valve',
'Car Mount', 'Dynamic', 'Ego Vehicle', 'Ground', 'Static',
'Unlabeled'),
palette=[[165, 42, 42], [0, 192, 0], [250, 170, 31], [250, 170, 32],
[196, 196, 196], [190, 153, 153], [180, 165, 180],
[90, 120, 150], [250, 170, 33], [250, 170, 34],
[128, 128, 128], [250, 170, 35], [102, 102, 156],
[128, 64, 255], [140, 140, 200], [170, 170, 170],
[250, 170, 36], [250, 170, 160], [250, 170, 37], [96, 96, 96],
[230, 150, 140], [128, 64, 128], [110, 110, 110],
[110, 110, 110], [244, 35, 232], [128, 196,
128], [150, 100, 100],
[70, 70, 70], [150, 150, 150], [150, 120, 90], [220, 20, 60],
[220, 20, 60], [255, 0, 0], [255, 0, 100], [255, 0, 200],
[255, 255, 255], [255, 255, 255], [250, 170, 29],
[250, 170, 28], [250, 170, 26], [250, 170,
25], [250, 170, 24],
[250, 170, 22], [250, 170, 21], [250, 170,
20], [255, 255, 255],
[250, 170, 19], [250, 170, 18], [250, 170,
12], [250, 170, 11],
[255, 255, 255], [255, 255, 255], [250, 170, 16],
[250, 170, 15], [250, 170, 15], [255, 255, 255],
[255, 255, 255], [255, 255, 255], [255, 255, 255],
[64, 170, 64], [230, 160, 50],
[70, 130, 180], [190, 255, 255], [152, 251, 152],
[107, 142, 35], [0, 170, 30], [255, 255, 128], [250, 0, 30],
[100, 140, 180], [220, 128, 128], [222, 40,
40], [100, 170, 30],
[40, 40, 40], [33, 33, 33], [100, 128, 160], [20, 20, 255],
[142, 0, 0], [70, 100, 150], [250, 171, 30], [250, 172, 30],
[250, 173, 30], [250, 174, 30], [250, 175,
30], [250, 176, 30],
[210, 170, 100], [153, 153, 153], [153, 153, 153],
[128, 128, 128], [0, 0, 80], [210, 60, 60], [250, 170, 30],
[250, 170, 30], [250, 170, 30], [250, 170,
30], [250, 170, 30],
[250, 170, 30], [192, 192, 192], [192, 192, 192],
[192, 192, 192], [220, 220, 0], [220, 220, 0], [0, 0, 196],
[192, 192, 192], [220, 220, 0], [140, 140, 20], [119, 11, 32],
[150, 0, 255], [0, 60, 100], [0, 0, 142], [0, 0, 90],
[0, 0, 230], [0, 80, 100], [128, 64, 64], [0, 0, 110],
[0, 0, 70], [0, 0, 142], [0, 0, 192], [170, 170, 170],
[32, 32, 32], [111, 74, 0], [120, 10, 10], [81, 0, 81],
[111, 111, 0], [0, 0, 0]])
def __init__(self,
img_suffix='.jpg',
seg_map_suffix='.png',
**kwargs) -> None:
super().__init__(
img_suffix=img_suffix, seg_map_suffix=seg_map_suffix, **kwargs)

View File

@ -1,65 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmseg.datasets.basesegdataset import BaseSegDataset
from mmseg.registry import DATASETS
@DATASETS.register_module()
class MapillaryDataset_v1_2(BaseSegDataset):
"""Mapillary Vistas Dataset.
Dataset paper link:
http://ieeexplore.ieee.org/document/8237796/
    v1.2 contains 66 object classes (37 instance-specific).
    v2.0 contains 124 object classes
    (70 instance-specific, 46 stuff, 8 void or crowd).
The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is
fixed to '.png' for Mapillary Vistas Dataset.
"""
METAINFO = dict(
classes=('Bird', 'Ground Animal', 'Curb', 'Fence', 'Guard Rail',
'Barrier', 'Wall', 'Bike Lane', 'Crosswalk - Plain',
'Curb Cut', 'Parking', 'Pedestrian Area', 'Rail Track',
'Road', 'Service Lane', 'Sidewalk', 'Bridge', 'Building',
'Tunnel', 'Person', 'Bicyclist', 'Motorcyclist',
'Other Rider', 'Lane Marking - Crosswalk',
'Lane Marking - General', 'Mountain', 'Sand', 'Sky', 'Snow',
'Terrain', 'Vegetation', 'Water', 'Banner', 'Bench',
'Bike Rack', 'Billboard', 'Catch Basin', 'CCTV Camera',
'Fire Hydrant', 'Junction Box', 'Mailbox', 'Manhole',
'Phone Booth', 'Pothole', 'Street Light', 'Pole',
'Traffic Sign Frame', 'Utility Pole', 'Traffic Light',
'Traffic Sign (Back)', 'Traffic Sign (Front)', 'Trash Can',
'Bicycle', 'Boat', 'Bus', 'Car', 'Caravan', 'Motorcycle',
'On Rails', 'Other Vehicle', 'Trailer', 'Truck',
'Wheeled Slow', 'Car Mount', 'Ego Vehicle', 'Unlabeled'),
palette=[[165, 42, 42], [0, 192, 0], [196, 196, 196], [190, 153, 153],
[180, 165, 180], [90, 120, 150], [102, 102, 156],
[128, 64, 255], [140, 140, 200], [170, 170, 170],
[250, 170, 160], [96, 96, 96],
[230, 150, 140], [128, 64, 128], [110, 110, 110],
[244, 35, 232], [150, 100, 100], [70, 70, 70], [150, 120, 90],
[220, 20, 60], [255, 0, 0], [255, 0, 100], [255, 0, 200],
             [200, 128, 128], [255, 255, 255], [64, 170, 64], [230, 160, 50],
[70, 130, 180], [190, 255, 255], [152, 251, 152],
[107, 142, 35], [0, 170, 30], [255, 255, 128], [250, 0, 30],
[100, 140, 180], [220, 220, 220], [220, 128, 128],
[222, 40, 40], [100, 170, 30], [40, 40, 40], [33, 33, 33],
[100, 128, 160], [142, 0, 0], [70, 100, 150], [210, 170, 100],
[153, 153, 153], [128, 128, 128], [0, 0, 80], [250, 170, 30],
[192, 192, 192], [220, 220, 0], [140, 140, 20], [119, 11, 32],
[150, 0, 255], [0, 60, 100], [0, 0, 142], [0, 0, 90],
[0, 0, 230], [0, 80, 100], [128, 64, 64], [0, 0, 110],
             [0, 0, 70], [0, 0, 192], [32, 32, 32], [120, 10, 10], [0, 0, 0]])
def __init__(self,
img_suffix='.jpg',
seg_map_suffix='.png',
**kwargs) -> None:
super().__init__(
img_suffix=img_suffix, seg_map_suffix=seg_map_suffix, **kwargs)

View File

@ -1,245 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from functools import partial
import mmcv
import numpy as np
from mmengine.utils import (mkdir_or_exist, scandir, track_parallel_progress,
track_progress)
colormap_v1_2 = np.array([[165, 42, 42], [0, 192, 0], [196, 196, 196],
[190, 153, 153], [180, 165, 180], [90, 120, 150],
[102, 102, 156], [128, 64, 255], [140, 140, 200],
[170, 170, 170], [250, 170, 160], [96, 96, 96],
[230, 150, 140], [128, 64, 128], [110, 110, 110],
[244, 35, 232], [150, 100, 100], [70, 70, 70],
[150, 120, 90], [220, 20, 60], [255, 0, 0],
[255, 0, 100], [255, 0, 200], [200, 128, 128],
[255, 255, 255], [64, 170, 64], [230, 160, 50],
[70, 130, 180], [190, 255, 255], [152, 251, 152],
[107, 142, 35], [0, 170, 30], [255, 255, 128],
[250, 0, 30], [100, 140, 180], [220, 220, 220],
[220, 128, 128], [222, 40, 40], [100, 170, 30],
[40, 40, 40], [33, 33, 33], [100, 128, 160],
[142, 0, 0], [70, 100, 150], [210, 170, 100],
[153, 153, 153], [128, 128, 128], [0, 0, 80],
[250, 170, 30], [192, 192, 192], [220, 220, 0],
[140, 140, 20], [119, 11, 32], [150, 0, 255],
[0, 60, 100], [0, 0, 142], [0, 0, 90], [0, 0, 230],
[0, 80, 100], [128, 64, 64], [0, 0, 110], [0, 0, 70],
[0, 0, 192], [32, 32, 32], [120, 10, 10], [0, 0, 0]])
colormap_v2_0 = np.array([[165, 42, 42], [0, 192, 0], [250, 170, 31],
[250, 170, 32], [196, 196, 196], [190, 153, 153],
[180, 165, 180], [90, 120, 150], [250, 170, 33],
[250, 170, 34], [128, 128, 128], [250, 170, 35],
[102, 102, 156], [128, 64, 255], [140, 140, 200],
[170, 170, 170], [250, 170, 36], [250, 170, 160],
[250, 170, 37], [96, 96, 96], [230, 150, 140],
[128, 64, 128], [110, 110, 110], [110, 110, 110],
[244, 35, 232], [128, 196, 128], [150, 100, 100],
[70, 70, 70], [150, 150, 150], [150, 120, 90],
[220, 20, 60], [220, 20, 60], [255, 0, 0],
[255, 0, 100], [255, 0, 200], [255, 255, 255],
[255, 255, 255], [250, 170, 29], [250, 170, 28],
[250, 170, 26], [250, 170, 25], [250, 170, 24],
[250, 170, 22], [250, 170, 21], [250, 170, 20],
[255, 255, 255], [250, 170, 19], [250, 170, 18],
[250, 170, 12], [250, 170, 11], [255, 255, 255],
[255, 255, 255], [250, 170, 16], [250, 170, 15],
[250, 170, 15], [255, 255, 255], [255, 255, 255],
[255, 255, 255], [255, 255, 255], [64, 170, 64],
[230, 160, 50], [70, 130, 180], [190, 255, 255],
[152, 251, 152], [107, 142, 35], [0, 170, 30],
[255, 255, 128], [250, 0, 30], [100, 140, 180],
[220, 128, 128], [222, 40, 40], [100, 170, 30],
[40, 40, 40], [33, 33, 33], [100, 128, 160],
[20, 20, 255], [142, 0, 0], [70, 100, 150],
[250, 171, 30], [250, 172, 30], [250, 173, 30],
[250, 174, 30], [250, 175, 30], [250, 176, 30],
[210, 170, 100], [153, 153, 153], [153, 153, 153],
[128, 128, 128], [0, 0, 80], [210, 60, 60],
[250, 170, 30], [250, 170, 30], [250, 170, 30],
[250, 170, 30], [250, 170, 30], [250, 170, 30],
[192, 192, 192], [192, 192, 192], [192, 192, 192],
[220, 220, 0], [220, 220, 0], [0, 0, 196],
[192, 192, 192], [220, 220, 0], [140, 140, 20],
[119, 11, 32], [150, 0, 255], [0, 60, 100],
[0, 0, 142], [0, 0, 90], [0, 0, 230], [0, 80, 100],
[128, 64, 64], [0, 0, 110], [0, 0, 70], [0, 0, 142],
[0, 0, 192], [170, 170, 170], [32, 32, 32],
[111, 74, 0], [120, 10, 10], [81, 0, 81],
[111, 111, 0], [0, 0, 0]])
def parse_args():
parser = argparse.ArgumentParser(
description='Convert Mapillary dataset to mmsegmentation format')
parser.add_argument('dataset_path', help='Mapillary folder path')
parser.add_argument(
'--version',
default='all',
help="Mapillary labels version, 'v1.2','v2.0','all'")
parser.add_argument('-o', '--out_dir', help='output path')
parser.add_argument(
        '--nproc', default=1, type=int, help='number of processes')
args = parser.parse_args()
return args
def mapillary_colormap2label(colormap: np.ndarray) -> np.ndarray:
    """Create a lookup array of shape (256^3,) that maps each packed RGB
    color of the palette to its label value.
    For example, label 0 (Bird) has the color [165, 42, 42]:
    (165 * 256 + 42) * 256 + 42 = 10824234, so
    `colormap2label[10824234] = 0`.
    During conversion, an RGB pixel [165, 42, 42] is packed to 10824234, and
    `colormap2label[10824234]` directly gives its label value 0.
    Doing this lookup with array indexing is very fast for a whole image.
    Args:
        colormap (np.ndarray): Mapillary Vistas Dataset palette.
    Returns:
        np.ndarray: values are mask labels,
            indices are packed RGB colors.
    """
colormap2label = np.zeros(256**3, dtype=np.longlong)
for i, colormap_ in enumerate(colormap):
colormap2label[(colormap_[0] * 256 + colormap_[1]) * 256 +
colormap_[2]] = i
return colormap2label
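# Illustrative sketch with a hypothetical 3-color palette: the lookup array
# built above maps a packed RGB value straight to its label index, e.g.
#   toy_palette = np.array([[0, 0, 0], [128, 64, 128], [255, 255, 255]])
#   lut = mapillary_colormap2label(toy_palette)
#   lut[(128 * 256 + 64) * 256 + 128]  # -> 1, the index of [128, 64, 128]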
def mapillary_masklabel(rgb_label: np.ndarray,
                        colormap2label: np.ndarray) -> np.ndarray:
    """Compute the mask label of an image through the `colormap2label`
    lookup array built by `mapillary_colormap2label(colormap)`.
    Args:
        rgb_label (np.ndarray): an RGB label image.
        colormap2label (np.ndarray): lookup array built by
            `mapillary_colormap2label(colormap)`.
    Returns:
        np.ndarray: mask label array.
    """
colormap_ = rgb_label.astype('uint32')
idx = np.array((colormap_[:, :, 0] * 256 + colormap_[:, :, 1]) * 256 +
colormap_[:, :, 2]).astype('uint32')
return colormap2label[idx]
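# Illustrative sketch, continuing the hypothetical 3-color palette above:
# every pixel is packed and looked up in one vectorized indexing operation.
#   rgb = np.array([[[128, 64, 128], [255, 255, 255]]], dtype=np.uint8)
#   mapillary_masklabel(rgb, lut)  # -> array([[1, 2]]), one label per pixel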
def RGB2Mask(rgb_label_path: str, colormap2label: np.ndarray) -> None:
    """Mapillary Vistas Dataset provides 8-bit color-palette class-specific
    labels for semantic segmentation, while training requires single-channel
    mask labels.
    This function converts a Mapillary RGB label under
    {training,validation}/{v1.2,v2.0}/labels into a mask label under
    {training,validation}/{v1.2,v2.0}/labels_mask.
    Args:
        rgb_label_path (str): absolute path of the RGB label image.
        colormap2label (np.ndarray): lookup array for the chosen version
            (v1.2 or v2.0).
    """
rgb_label = mmcv.imread(rgb_label_path, channel_order='rgb')
masks_label = mapillary_masklabel(rgb_label, colormap2label)
mmcv.imwrite(
masks_label.astype(np.uint8),
rgb_label_path.replace('labels', 'labels_mask'))
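# For example (hypothetical file name), calling
#   RGB2Mask('data/mapillary/training/v1.2/labels/0001.png', colormap2label_v1_2)
# reads the RGB label and writes the single-channel mask to
#   data/mapillary/training/v1.2/labels_mask/0001.png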
def main():
colormap2label_v1_2 = mapillary_colormap2label(colormap_v1_2)
colormap2label_v2_0 = mapillary_colormap2label(colormap_v2_0)
dataset_path = args.dataset_path
if args.out_dir is None:
out_dir = dataset_path
else:
out_dir = args.out_dir
RGB_labels_path = []
RGB_labels_v1_2_path = []
RGB_labels_v2_0_path = []
print('Scanning labels path....')
for label_path in scandir(dataset_path, suffix='.png', recursive=True):
if 'labels' in label_path:
rgb_label_path = osp.join(dataset_path, label_path)
RGB_labels_path.append(rgb_label_path)
if 'v1.2' in label_path:
RGB_labels_v1_2_path.append(rgb_label_path)
elif 'v2.0' in label_path:
RGB_labels_v2_0_path.append(rgb_label_path)
if args.version == 'all':
        print(f'Totally found {len(RGB_labels_path)} {args.version} RGB labels')
elif args.version == 'v1.2':
print(f'Found {len(RGB_labels_v1_2_path)} {args.version} RGB labels')
elif args.version == 'v2.0':
print(f'Found {len(RGB_labels_v2_0_path)} {args.version} RGB labels')
print('Making directories...')
mkdir_or_exist(osp.join(out_dir, 'training', 'v1.2', 'labels_mask'))
mkdir_or_exist(osp.join(out_dir, 'validation', 'v1.2', 'labels_mask'))
mkdir_or_exist(osp.join(out_dir, 'training', 'v2.0', 'labels_mask'))
mkdir_or_exist(osp.join(out_dir, 'validation', 'v2.0', 'labels_mask'))
    print('Directories have been created.')
if args.nproc > 1:
if args.version == 'all':
print('Converting v1.2 ....')
track_parallel_progress(
partial(RGB2Mask, colormap2label=colormap2label_v1_2),
RGB_labels_v1_2_path,
nproc=args.nproc)
print('Converting v2.0 ....')
track_parallel_progress(
partial(RGB2Mask, colormap2label=colormap2label_v2_0),
RGB_labels_v2_0_path,
nproc=args.nproc)
elif args.version == 'v1.2':
print('Converting v1.2 ....')
track_parallel_progress(
partial(RGB2Mask, colormap2label=colormap2label_v1_2),
RGB_labels_v1_2_path,
nproc=args.nproc)
elif args.version == 'v2.0':
print('Converting v2.0 ....')
track_parallel_progress(
partial(RGB2Mask, colormap2label=colormap2label_v2_0),
RGB_labels_v2_0_path,
nproc=args.nproc)
else:
if args.version == 'all':
print('Converting v1.2 ....')
track_progress(
partial(RGB2Mask, colormap2label=colormap2label_v1_2),
RGB_labels_v1_2_path)
print('Converting v2.0 ....')
track_progress(
partial(RGB2Mask, colormap2label=colormap2label_v2_0),
RGB_labels_v2_0_path)
elif args.version == 'v1.2':
print('Converting v1.2 ....')
track_progress(
partial(RGB2Mask, colormap2label=colormap2label_v1_2),
RGB_labels_v1_2_path)
elif args.version == 'v2.0':
print('Converting v2.0 ....')
track_progress(
partial(RGB2Mask, colormap2label=colormap2label_v2_0),
RGB_labels_v2_0_path)
    print('Converted Mapillary Vistas Dataset RGB labels to mask labels!')
if __name__ == '__main__':
args = parse_args()
main()

Binary file not shown. (new image, 1.2 MiB)

Binary file not shown. (new image, 74 KiB)

Binary file not shown. (new image, 74 KiB)

View File

@ -7,7 +7,8 @@ import pytest
from mmseg.datasets import (ADE20KDataset, BaseSegDataset, CityscapesDataset,
COCOStuffDataset, DecathlonDataset, ISPRSDataset,
LIPDataset, LoveDADataset, PascalVOCDataset,
LIPDataset, LoveDADataset, MapillaryDataset_v1,
MapillaryDataset_v2, PascalVOCDataset,
PotsdamDataset, REFUGEDataset, SynapseDataset,
iSAIDDataset)
from mmseg.registry import DATASETS
@ -27,6 +28,10 @@ def test_classes():
assert list(PotsdamDataset.METAINFO['classes']) == get_classes('potsdam')
assert list(ISPRSDataset.METAINFO['classes']) == get_classes('vaihingen')
assert list(iSAIDDataset.METAINFO['classes']) == get_classes('isaid')
assert list(
MapillaryDataset_v1.METAINFO['classes']) == get_classes('mapillary_v1')
assert list(
MapillaryDataset_v2.METAINFO['classes']) == get_classes('mapillary_v2')
with pytest.raises(ValueError):
get_classes('unsupported')
@ -80,6 +85,10 @@ def test_palette():
assert PotsdamDataset.METAINFO['palette'] == get_palette('potsdam')
assert COCOStuffDataset.METAINFO['palette'] == get_palette('cocostuff')
assert iSAIDDataset.METAINFO['palette'] == get_palette('isaid')
assert list(
MapillaryDataset_v1.METAINFO['palette']) == get_palette('mapillary_v1')
assert list(
MapillaryDataset_v2.METAINFO['palette']) == get_palette('mapillary_v2')
with pytest.raises(ValueError):
get_palette('unsupported')
@ -304,6 +313,19 @@ def test_lip():
assert len(test_dataset) == 1
def test_mapillary():
test_dataset = MapillaryDataset_v1(
pipeline=[],
data_prefix=dict(
img_path=osp.join(
osp.dirname(__file__),
'../data/pseudo_mapillary_dataset/images'),
seg_map_path=osp.join(
osp.dirname(__file__),
'../data/pseudo_mapillary_dataset/v1.2')))
assert len(test_dataset) == 1
@pytest.mark.parametrize('dataset, classes', [
('ADE20KDataset', ('wall', 'building')),
('CityscapesDataset', ('road', 'sidewalk')),