mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Datasets] Add Mapillary Vistas Datasets to MMSeg Core Package. (#2576)
## [Datasets] Add Mapillary Vistas Datasets to MMSeg Core Package . ## Motivation Add Mapillary Vistas Datasets to core package. Old PR #2484 ## Modification - Add Mapillary Vistas Datasets to core package. - Delete `tools/datasets_convert/mapillary.py` , dataset does't need converting. - Add `schedule_240k.py` config. - Add configs files. ```none deeplabv3plus_r101-d8_4xb2-240k_mapillay_v1-512x1024.py deeplabv3plus_r101-d8_4xb2-240k_mapillay_v2-512x1024.py maskformer_swin-s_4xb2-240k_mapillary_v1-512x1024.py maskformer_swin-s_4xb2-240k_mapillary_v2-512x1024.py maskformer_r101-d8_4xb2-240k_mapillary_v1-512x1024.py maskformer_r101-d8_4xb2-240k_mapillary_v2-512x1024.py pspnet_r101-d8_4xb2-240k_mapillay_v1-512x1024.py pspnet_r101-d8_4xb2-240k_mapillay_v2-512x1024.py ``` - Synchronized changes to `projects/mapillary_datasets` --------- Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com> Co-authored-by: xiexinch <xiexinch@outlook.com>
This commit is contained in:
parent
447a398c24
commit
8c89ff3dd1
@ -181,6 +181,7 @@ Results and models are available in the [model zoo](docs/en/model_zoo.md).
|
||||
- [x] [Potsdam](https://github.com/open-mmlab/mmsegmentation/blob/1.x/docs/en/user_guides/2_dataset_prepare.md#isprs-potsdam)
|
||||
- [x] [Vaihingen](https://github.com/open-mmlab/mmsegmentation/blob/1.x/docs/en/user_guides/2_dataset_prepare.md#isprs-vaihingen)
|
||||
- [x] [iSAID](https://github.com/open-mmlab/mmsegmentation/blob/1.x/docs/en/user_guides/2_dataset_prepare.md#isaid)
|
||||
- [x] [Mapillary Vistas](https://github.com/open-mmlab/mmsegmentation/blob/1.x/docs/en/user_guides/2_dataset_prepare.md#mapillary-vistas-datasets)
|
||||
|
||||
</details>
|
||||
|
||||
|
@ -162,6 +162,7 @@ MMSegmentation 是一个基于 PyTorch 的语义分割开源工具箱。它是 O
|
||||
- [x] [Potsdam](https://github.com/open-mmlab/mmsegmentation/blob/1.x/docs/zh_cn/dataset_prepare.md#isprs-potsdam)
|
||||
- [x] [Vaihingen](https://github.com/open-mmlab/mmsegmentation/blob/1.x/docs/zh_cn/dataset_prepare.md#isprs-vaihingen)
|
||||
- [x] [iSAID](https://github.com/open-mmlab/mmsegmentation/blob/1.x/docs/zh_cn/dataset_prepare.md#isaid)
|
||||
- [x] [Mapillary Vistas](https://github.com/open-mmlab/mmsegmentation/blob/1.x/docs/en/user_guides/2_dataset_prepare.md#mapillary-vistas-datasets)
|
||||
|
||||
</details>
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
# dataset settings
|
||||
dataset_type = 'MapillaryDataset_v1_2'
|
||||
dataset_type = 'MapillaryDataset_v1'
|
||||
data_root = 'data/mapillary/'
|
||||
crop_size = (512, 1024)
|
||||
train_pipeline = [
|
||||
@ -48,8 +48,7 @@ train_dataloader = dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
data_prefix=dict(
|
||||
img_path='training/images',
|
||||
seg_map_path='training/v1.2/labels_mask'),
|
||||
img_path='training/images', seg_map_path='training/v1.2/labels'),
|
||||
pipeline=train_pipeline))
|
||||
val_dataloader = dict(
|
||||
batch_size=1,
|
||||
@ -61,7 +60,7 @@ val_dataloader = dict(
|
||||
data_root=data_root,
|
||||
data_prefix=dict(
|
||||
img_path='validation/images',
|
||||
seg_map_path='validation/v1.2/labels_mask'),
|
||||
seg_map_path='validation/v1.2/labels'),
|
||||
pipeline=test_pipeline))
|
||||
test_dataloader = val_dataloader
|
||||
|
37
configs/_base_/datasets/mapillary_v1_65.py
Normal file
37
configs/_base_/datasets/mapillary_v1_65.py
Normal file
@ -0,0 +1,37 @@
|
||||
# dataset settings
|
||||
_base_ = './mapillary_v1.py'
|
||||
metainfo = dict(
|
||||
classes=('Bird', 'Ground Animal', 'Curb', 'Fence', 'Guard Rail', 'Barrier',
|
||||
'Wall', 'Bike Lane', 'Crosswalk - Plain', 'Curb Cut', 'Parking',
|
||||
'Pedestrian Area', 'Rail Track', 'Road', 'Service Lane',
|
||||
'Sidewalk', 'Bridge', 'Building', 'Tunnel', 'Person', 'Bicyclist',
|
||||
'Motorcyclist', 'Other Rider', 'Lane Marking - Crosswalk',
|
||||
'Lane Marking - General', 'Mountain', 'Sand', 'Sky', 'Snow',
|
||||
'Terrain', 'Vegetation', 'Water', 'Banner', 'Bench', 'Bike Rack',
|
||||
'Billboard', 'Catch Basin', 'CCTV Camera', 'Fire Hydrant',
|
||||
'Junction Box', 'Mailbox', 'Manhole', 'Phone Booth', 'Pothole',
|
||||
'Street Light', 'Pole', 'Traffic Sign Frame', 'Utility Pole',
|
||||
'Traffic Light', 'Traffic Sign (Back)', 'Traffic Sign (Front)',
|
||||
'Trash Can', 'Bicycle', 'Boat', 'Bus', 'Car', 'Caravan',
|
||||
'Motorcycle', 'On Rails', 'Other Vehicle', 'Trailer', 'Truck',
|
||||
'Wheeled Slow', 'Car Mount', 'Ego Vehicle'),
|
||||
palette=[[165, 42, 42], [0, 192, 0], [196, 196, 196], [190, 153, 153],
|
||||
[180, 165, 180], [90, 120, 150], [102, 102, 156], [128, 64, 255],
|
||||
[140, 140, 200], [170, 170, 170], [250, 170, 160], [96, 96, 96],
|
||||
[230, 150, 140], [128, 64, 128], [110, 110, 110], [244, 35, 232],
|
||||
[150, 100, 100], [70, 70, 70], [150, 120, 90], [220, 20, 60],
|
||||
[255, 0, 0], [255, 0, 100], [255, 0, 200], [200, 128, 128],
|
||||
[255, 255, 255], [64, 170, 64], [230, 160, 50], [70, 130, 180],
|
||||
[190, 255, 255], [152, 251, 152], [107, 142, 35], [0, 170, 30],
|
||||
[255, 255, 128], [250, 0, 30], [100, 140, 180], [220, 220, 220],
|
||||
[220, 128, 128], [222, 40, 40], [100, 170, 30], [40, 40, 40],
|
||||
[33, 33, 33], [100, 128, 160], [142, 0, 0], [70, 100, 150],
|
||||
[210, 170, 100], [153, 153, 153], [128, 128, 128], [0, 0, 80],
|
||||
[250, 170, 30], [192, 192, 192], [220, 220, 0], [140, 140, 20],
|
||||
[119, 11, 32], [150, 0, 255], [0, 60, 100], [0, 0, 142],
|
||||
[0, 0, 90], [0, 0, 230], [0, 80, 100], [128, 64, 64], [0, 0, 110],
|
||||
[0, 0, 70], [0, 0, 192], [32, 32, 32], [120, 10, 10]])
|
||||
|
||||
train_dataloader = dict(dataset=dict(metainfo=metainfo))
|
||||
val_dataloader = dict(dataset=dict(metainfo=metainfo))
|
||||
test_dataloader = val_dataloader
|
@ -1,5 +1,5 @@
|
||||
# dataset settings
|
||||
dataset_type = 'MapillaryDataset_v2_0'
|
||||
dataset_type = 'MapillaryDataset_v2'
|
||||
data_root = 'data/mapillary/'
|
||||
crop_size = (512, 1024)
|
||||
train_pipeline = [
|
||||
@ -48,8 +48,7 @@ train_dataloader = dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
data_prefix=dict(
|
||||
img_path='training/images',
|
||||
seg_map_path='training/v2.0/labels_mask'),
|
||||
img_path='training/images', seg_map_path='training/v2.0/labels'),
|
||||
pipeline=train_pipeline))
|
||||
val_dataloader = dict(
|
||||
batch_size=1,
|
||||
@ -61,7 +60,7 @@ val_dataloader = dict(
|
||||
data_root=data_root,
|
||||
data_prefix=dict(
|
||||
img_path='validation/images',
|
||||
seg_map_path='validation/v2.0/labels_mask'),
|
||||
seg_map_path='validation/v2.0/labels'),
|
||||
pipeline=test_pipeline))
|
||||
test_dataloader = val_dataloader
|
||||
|
25
configs/_base_/schedules/schedule_240k.py
Normal file
25
configs/_base_/schedules/schedule_240k.py
Normal file
@ -0,0 +1,25 @@
|
||||
# optimizer
|
||||
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
|
||||
optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer, clip_grad=None)
|
||||
# learning policy
|
||||
param_scheduler = [
|
||||
dict(
|
||||
type='PolyLR',
|
||||
eta_min=1e-4,
|
||||
power=0.9,
|
||||
begin=0,
|
||||
end=240000,
|
||||
by_epoch=False)
|
||||
]
|
||||
# training schedule for 240k
|
||||
train_cfg = dict(
|
||||
type='IterBasedTrainLoop', max_iters=240000, val_interval=24000)
|
||||
val_cfg = dict(type='ValLoop')
|
||||
test_cfg = dict(type='TestLoop')
|
||||
default_hooks = dict(
|
||||
timer=dict(type='IterTimerHook'),
|
||||
logger=dict(type='LoggerHook', interval=50, log_metric_by_epoch=False),
|
||||
param_scheduler=dict(type='ParamSchedulerHook'),
|
||||
checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=24000),
|
||||
sampler_seed=dict(type='DistSamplerSeedHook'),
|
||||
visualization=dict(type='SegVisualizationHook'))
|
@ -124,6 +124,12 @@ Spatial pyramid pooling module or encode-decoder structure are used in deep neur
|
||||
| DeepLabV3+ | R-18-D8 | 896x896 | 80000 | 6.19 | 24.81 | 61.35 | 62.61 | [config](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/deeplabv3plus/deeplabv3plus_r18-d8_4xb4-80k_isaid-896x896.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r18-d8_4x4_896x896_80k_isaid/deeplabv3plus_r18-d8_4x4_896x896_80k_isaid_20220110_180526-7059991d.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r18-d8_4x4_896x896_80k_isaid/deeplabv3plus_r18-d8_4x4_896x896_80k_isaid_20220110_180526.log.json) |
|
||||
| DeepLabV3+ | R-50-D8 | 896x896 | 80000 | 21.45 | 8.42 | 67.06 | 68.02 | [config](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/deeplabv3plus/deeplabv3plus_r50-d8_4xb4-80k_isaid-896x896.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r50-d8_4x4_896x896_80k_isaid/deeplabv3plus_r50-d8_4x4_896x896_80k_isaid_20220110_180526-598be439.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r50-d8_4x4_896x896_80k_isaid/deeplabv3plus_r50-d8_4x4_896x896_80k_isaid_20220110_180526.log.json) |
|
||||
|
||||
### Mapillary Vistas v1.2
|
||||
|
||||
| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | Device | mIoU | mIoU(ms+flip) | config | download |
|
||||
| ---------- | -------- | --------- | ------: | -------- | -------------- | ------ | ----: | ------------: | ---------------------------------------------------------------------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| DeepLabV3+ | R-50-D8 | 1280x1280 | 300000 | 24.04 | 17.92 | A100 | 47.35 | - | [config](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/deeplabv3plus/deeplabv3plus_r50-d8_4xb2-300k_mapillay_v1_65-1280x1280.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r50-d8_4xb2-300k_mapillay_v1_65-1280x1280/deeplabv3plus_r50-d8_4xb2-300k_mapillay_v1_65-1280x1280_20230301_110504-655f8e43.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r50-d8_4xb2-300k_mapillay_v1_65-1280x1280/deeplabv3plus_r50-d8_4xb2-300k_mapillay_v1_65-1280x1280_20230301_110504.json) |
|
||||
|
||||
Note:
|
||||
|
||||
- `D-8`/`D-16` here corresponding to the output stride 8/16 setting for DeepLab series.
|
||||
|
@ -11,6 +11,7 @@ Collections:
|
||||
- Potsdam
|
||||
- Vaihingen
|
||||
- iSAID
|
||||
- Mapillary Vistas v1.2
|
||||
Paper:
|
||||
URL: https://arxiv.org/abs/1802.02611
|
||||
Title: Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation
|
||||
@ -848,3 +849,24 @@ Models:
|
||||
mIoU(ms+flip): 68.02
|
||||
Config: configs/deeplabv3plus/deeplabv3plus_r50-d8_4xb4-80k_isaid-896x896.py
|
||||
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r50-d8_4x4_896x896_80k_isaid/deeplabv3plus_r50-d8_4x4_896x896_80k_isaid_20220110_180526-598be439.pth
|
||||
- Name: deeplabv3plus_r50-d8_4xb2-300k_mapillay_v1_65-1280x1280
|
||||
In Collection: DeepLabV3+
|
||||
Metadata:
|
||||
backbone: R-50-D8
|
||||
crop size: (1280,1280)
|
||||
lr schd: 300000
|
||||
inference time (ms/im):
|
||||
- value: 55.8
|
||||
hardware: V100
|
||||
backend: PyTorch
|
||||
batch size: 1
|
||||
mode: FP32
|
||||
resolution: (1280,1280)
|
||||
Training Memory (GB): 24.04
|
||||
Results:
|
||||
- Task: Semantic Segmentation
|
||||
Dataset: Mapillary Vistas v1.2
|
||||
Metrics:
|
||||
mIoU: 47.35
|
||||
Config: configs/deeplabv3plus/deeplabv3plus_r50-d8_4xb2-300k_mapillay_v1_65-1280x1280.py
|
||||
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/deeplabv3plus/deeplabv3plus_r50-d8_4xb2-300k_mapillay_v1_65-1280x1280/deeplabv3plus_r50-d8_4xb2-300k_mapillay_v1_65-1280x1280_20230301_110504-655f8e43.pth
|
||||
|
@ -0,0 +1,58 @@
|
||||
_base_ = [
|
||||
'../_base_/models/deeplabv3plus_r50-d8.py',
|
||||
'../_base_/datasets/mapillary_v1_65.py',
|
||||
'../_base_/default_runtime.py',
|
||||
]
|
||||
|
||||
crop_size = (1280, 1280)
|
||||
data_preprocessor = dict(size=crop_size)
|
||||
model = dict(
|
||||
data_preprocessor=data_preprocessor,
|
||||
pretrained='open-mmlab://resnet50_v1c',
|
||||
backbone=dict(depth=50),
|
||||
decode_head=dict(num_classes=65),
|
||||
auxiliary_head=dict(num_classes=65))
|
||||
|
||||
iters = 300000
|
||||
# optimizer
|
||||
optimizer = dict(
|
||||
type='AdamW', lr=0.0001, betas=(0.9, 0.999), weight_decay=0.0001)
|
||||
# optimizer
|
||||
optim_wrapper = dict(
|
||||
type='OptimWrapper',
|
||||
optimizer=optimizer,
|
||||
clip_grad=dict(max_norm=0.01, norm_type=2),
|
||||
paramwise_cfg=dict(
|
||||
custom_keys={'backbone': dict(lr_mult=0.1, decay_mult=1.0)}))
|
||||
param_scheduler = [
|
||||
dict(
|
||||
type='PolyLR',
|
||||
eta_min=0,
|
||||
power=0.9,
|
||||
begin=0,
|
||||
end=iters,
|
||||
by_epoch=False)
|
||||
]
|
||||
|
||||
# training schedule for 300k
|
||||
train_cfg = dict(
|
||||
type='IterBasedTrainLoop', max_iters=iters, val_interval=iters // 10)
|
||||
val_cfg = dict(type='ValLoop')
|
||||
test_cfg = dict(type='TestLoop')
|
||||
|
||||
default_hooks = dict(
|
||||
timer=dict(type='IterTimerHook'),
|
||||
logger=dict(type='LoggerHook', interval=50, log_metric_by_epoch=False),
|
||||
param_scheduler=dict(type='ParamSchedulerHook'),
|
||||
checkpoint=dict(
|
||||
type='CheckpointHook', by_epoch=False, interval=iters // 10),
|
||||
sampler_seed=dict(type='DistSamplerSeedHook'),
|
||||
visualization=dict(type='SegVisualizationHook'))
|
||||
|
||||
train_dataloader = dict(batch_size=2)
|
||||
|
||||
# Default setting for scaling LR automatically
|
||||
# - `enable` means enable scaling LR automatically
|
||||
# or not by default.
|
||||
# - `base_batch_size` = (4 GPUs) x (2 samples per GPU).
|
||||
auto_scale_lr = dict(enable=False, base_batch_size=8)
|
@ -154,6 +154,29 @@ mmsegmentation
|
||||
│ │ │ ├── training
|
||||
│ │ │ ├── validation
|
||||
│ │ │ ├── test
|
||||
│ ├── mapillary
|
||||
│ │ ├── training
|
||||
│ │ │ ├── images
|
||||
│ │ │ ├── v1.2
|
||||
| │ │ │ ├── instances
|
||||
| │ │ │ ├── labels
|
||||
| │ │ │ └── panoptic
|
||||
│ │ │ ├── v2.0
|
||||
| │ │ │ ├── instances
|
||||
| │ │ │ ├── labels
|
||||
| │ │ │ ├── panoptic
|
||||
| │ │ │ └── polygons
|
||||
│ │ ├── validation
|
||||
│ │ │ ├── images
|
||||
| │ │ ├── v1.2
|
||||
| │ │ │ ├── instances
|
||||
| │ │ │ ├── labels
|
||||
| │ │ │ └── panoptic
|
||||
│ │ │ ├── v2.0
|
||||
| │ │ │ ├── instances
|
||||
| │ │ │ ├── labels
|
||||
| │ │ │ ├── panoptic
|
||||
| │ │ │ └── polygons
|
||||
```
|
||||
|
||||
### Cityscapes
|
||||
@ -551,3 +574,54 @@ The script will make directory structure below:
|
||||
```
|
||||
|
||||
It includes 400 images for training, 400 images for validation and 400 images for testing which is the same as REFUGE 2018 dataset.
|
||||
|
||||
## Mapillary Vistas Datasets
|
||||
|
||||
- The dataset could be download [here](https://www.mapillary.com/dataset/vistas) after registration.
|
||||
|
||||
- Mapillary Vistas Dataset use 8-bit with color-palette to store labels. No conversion operation is required.
|
||||
|
||||
- Assumption you have put the dataset zip file in `mmsegmentation/data/mapillary`
|
||||
|
||||
- Please run the following commands to unzip dataset.
|
||||
|
||||
```bash
|
||||
cd data/mapillary
|
||||
unzip An-ZjB1Zm61yAZG0ozTymz8I8NqI4x0MrYrh26dq7kPgfu8vf9ImrdaOAVOFYbJ2pNAgUnVGBmbue9lTgdBOb5BbKXIpFs0fpYWqACbrQDChAA2fdX0zS9PcHu7fY8c-FOvyBVxPNYNFQuM.zip
|
||||
```
|
||||
|
||||
- After unzip, you will get Mapillary Vistas Dataset like this structure. Semantic segmentation mask labels in `labels` folder.
|
||||
|
||||
```none
|
||||
mmsegmentation
|
||||
├── mmseg
|
||||
├── tools
|
||||
├── configs
|
||||
├── data
|
||||
│ ├── mapillary
|
||||
│ │ ├── training
|
||||
│ │ │ ├── images
|
||||
│ │ │ ├── v1.2
|
||||
| │ │ │ ├── instances
|
||||
| │ │ │ ├── labels
|
||||
| │ │ │ └── panoptic
|
||||
│ │ │ ├── v2.0
|
||||
| │ │ │ ├── instances
|
||||
| │ │ │ ├── labels
|
||||
| │ │ │ ├── panoptic
|
||||
| │ │ │ └── polygons
|
||||
│ │ ├── validation
|
||||
│ │ │ ├── images
|
||||
| │ │ ├── v1.2
|
||||
| │ │ │ ├── instances
|
||||
| │ │ │ ├── labels
|
||||
| │ │ │ └── panoptic
|
||||
│ │ │ ├── v2.0
|
||||
| │ │ │ ├── instances
|
||||
| │ │ │ ├── labels
|
||||
| │ │ │ ├── panoptic
|
||||
| │ │ │ └── polygons
|
||||
```
|
||||
|
||||
- You could set Datasets version with `MapillaryDataset_v1` and `MapillaryDataset_v2` in your configs.
|
||||
View the Mapillary Vistas Datasets config file here [V1.2](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/_base_/datasets/mapillary_v1.py) and [V2.0](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/_base_/datasets/mapillary_v2.py)
|
||||
|
@ -14,6 +14,7 @@ from .isaid import iSAIDDataset
|
||||
from .isprs import ISPRSDataset
|
||||
from .lip import LIPDataset
|
||||
from .loveda import LoveDADataset
|
||||
from .mapillary import MapillaryDataset_v1, MapillaryDataset_v2
|
||||
from .night_driving import NightDrivingDataset
|
||||
from .pascal_context import PascalContextDataset, PascalContextDataset59
|
||||
from .potsdam import PotsdamDataset
|
||||
@ -49,5 +50,6 @@ __all__ = [
|
||||
'DecathlonDataset', 'LIPDataset', 'ResizeShortestEdge',
|
||||
'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur',
|
||||
'BioMedicalRandomGamma', 'BioMedical3DPad', 'RandomRotFlip',
|
||||
'SynapseDataset', 'REFUGEDataset'
|
||||
'SynapseDataset', 'REFUGEDataset', 'MapillaryDataset_v1',
|
||||
'MapillaryDataset_v2'
|
||||
]
|
||||
|
@ -1,10 +1,72 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmseg.datasets.basesegdataset import BaseSegDataset
|
||||
from mmseg.registry import DATASETS
|
||||
from .basesegdataset import BaseSegDataset
|
||||
|
||||
|
||||
@DATASETS.register_module()
|
||||
class MapillaryDataset_v2_0(BaseSegDataset):
|
||||
class MapillaryDataset_v1(BaseSegDataset):
|
||||
"""Mapillary Vistas Dataset.
|
||||
|
||||
Dataset paper link:
|
||||
http://ieeexplore.ieee.org/document/8237796/
|
||||
|
||||
v1.2 contain 66 object classes.
|
||||
(37 instance-specific)
|
||||
|
||||
v2.0 contain 124 object classes.
|
||||
(70 instance-specific, 46 stuff, 8 void or crowd).
|
||||
|
||||
The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is
|
||||
fixed to '.png' for Mapillary Vistas Dataset.
|
||||
"""
|
||||
METAINFO = dict(
|
||||
classes=('Bird', 'Ground Animal', 'Curb', 'Fence', 'Guard Rail',
|
||||
'Barrier', 'Wall', 'Bike Lane', 'Crosswalk - Plain',
|
||||
'Curb Cut', 'Parking', 'Pedestrian Area', 'Rail Track',
|
||||
'Road', 'Service Lane', 'Sidewalk', 'Bridge', 'Building',
|
||||
'Tunnel', 'Person', 'Bicyclist', 'Motorcyclist',
|
||||
'Other Rider', 'Lane Marking - Crosswalk',
|
||||
'Lane Marking - General', 'Mountain', 'Sand', 'Sky', 'Snow',
|
||||
'Terrain', 'Vegetation', 'Water', 'Banner', 'Bench',
|
||||
'Bike Rack', 'Billboard', 'Catch Basin', 'CCTV Camera',
|
||||
'Fire Hydrant', 'Junction Box', 'Mailbox', 'Manhole',
|
||||
'Phone Booth', 'Pothole', 'Street Light', 'Pole',
|
||||
'Traffic Sign Frame', 'Utility Pole', 'Traffic Light',
|
||||
'Traffic Sign (Back)', 'Traffic Sign (Front)', 'Trash Can',
|
||||
'Bicycle', 'Boat', 'Bus', 'Car', 'Caravan', 'Motorcycle',
|
||||
'On Rails', 'Other Vehicle', 'Trailer', 'Truck',
|
||||
'Wheeled Slow', 'Car Mount', 'Ego Vehicle', 'Unlabeled'),
|
||||
palette=[[165, 42, 42], [0, 192, 0], [196, 196, 196], [190, 153, 153],
|
||||
[180, 165, 180], [90, 120, 150], [102, 102, 156],
|
||||
[128, 64, 255], [140, 140, 200], [170, 170, 170],
|
||||
[250, 170, 160], [96, 96, 96],
|
||||
[230, 150, 140], [128, 64, 128], [110, 110, 110],
|
||||
[244, 35, 232], [150, 100, 100], [70, 70, 70], [150, 120, 90],
|
||||
[220, 20, 60], [255, 0, 0], [255, 0, 100], [255, 0, 200],
|
||||
[200, 128, 128], [255, 255, 255], [64, 170,
|
||||
64], [230, 160, 50],
|
||||
[70, 130, 180], [190, 255, 255], [152, 251, 152],
|
||||
[107, 142, 35], [0, 170, 30], [255, 255, 128], [250, 0, 30],
|
||||
[100, 140, 180], [220, 220, 220], [220, 128, 128],
|
||||
[222, 40, 40], [100, 170, 30], [40, 40, 40], [33, 33, 33],
|
||||
[100, 128, 160], [142, 0, 0], [70, 100, 150], [210, 170, 100],
|
||||
[153, 153, 153], [128, 128, 128], [0, 0, 80], [250, 170, 30],
|
||||
[192, 192, 192], [220, 220, 0], [140, 140, 20], [119, 11, 32],
|
||||
[150, 0, 255], [0, 60, 100], [0, 0, 142], [0, 0, 90],
|
||||
[0, 0, 230], [0, 80, 100], [128, 64, 64], [0, 0, 110],
|
||||
[0, 0, 70], [0, 0, 192], [32, 32, 32], [120, 10,
|
||||
10], [0, 0, 0]])
|
||||
|
||||
def __init__(self,
|
||||
img_suffix='.jpg',
|
||||
seg_map_suffix='.png',
|
||||
**kwargs) -> None:
|
||||
super().__init__(
|
||||
img_suffix=img_suffix, seg_map_suffix=seg_map_suffix, **kwargs)
|
||||
|
||||
|
||||
@DATASETS.register_module()
|
||||
class MapillaryDataset_v2(BaseSegDataset):
|
||||
"""Mapillary Vistas Dataset.
|
||||
|
||||
Dataset paper link:
|
@ -126,6 +126,126 @@ def stare_classes():
|
||||
return ['background', 'vessel']
|
||||
|
||||
|
||||
def mapillary_v1_classes():
|
||||
"""mapillary_v1 class names for external use."""
|
||||
return [
|
||||
'Bird', 'Ground Animal', 'Curb', 'Fence', 'Guard Rail', 'Barrier',
|
||||
'Wall', 'Bike Lane', 'Crosswalk - Plain', 'Curb Cut', 'Parking',
|
||||
'Pedestrian Area', 'Rail Track', 'Road', 'Service Lane', 'Sidewalk',
|
||||
'Bridge', 'Building', 'Tunnel', 'Person', 'Bicyclist', 'Motorcyclist',
|
||||
'Other Rider', 'Lane Marking - Crosswalk', 'Lane Marking - General',
|
||||
'Mountain', 'Sand', 'Sky', 'Snow', 'Terrain', 'Vegetation', 'Water',
|
||||
'Banner', 'Bench', 'Bike Rack', 'Billboard', 'Catch Basin',
|
||||
'CCTV Camera', 'Fire Hydrant', 'Junction Box', 'Mailbox', 'Manhole',
|
||||
'Phone Booth', 'Pothole', 'Street Light', 'Pole', 'Traffic Sign Frame',
|
||||
'Utility Pole', 'Traffic Light', 'Traffic Sign (Back)',
|
||||
'Traffic Sign (Front)', 'Trash Can', 'Bicycle', 'Boat', 'Bus', 'Car',
|
||||
'Caravan', 'Motorcycle', 'On Rails', 'Other Vehicle', 'Trailer',
|
||||
'Truck', 'Wheeled Slow', 'Car Mount', 'Ego Vehicle', 'Unlabeled'
|
||||
]
|
||||
|
||||
|
||||
def mapillary_v1_palette():
|
||||
"""mapillary_v1_ palette for external use."""
|
||||
return [[165, 42, 42], [0, 192, 0], [196, 196, 196], [190, 153, 153],
|
||||
[180, 165, 180], [90, 120, 150], [102, 102, 156], [128, 64, 255],
|
||||
[140, 140, 200], [170, 170, 170], [250, 170, 160], [96, 96, 96],
|
||||
[230, 150, 140], [128, 64, 128], [110, 110, 110], [244, 35, 232],
|
||||
[150, 100, 100], [70, 70, 70], [150, 120, 90], [220, 20, 60],
|
||||
[255, 0, 0], [255, 0, 100], [255, 0, 200], [200, 128, 128],
|
||||
[255, 255, 255], [64, 170, 64], [230, 160, 50], [70, 130, 180],
|
||||
[190, 255, 255], [152, 251, 152], [107, 142, 35], [0, 170, 30],
|
||||
[255, 255, 128], [250, 0, 30], [100, 140, 180], [220, 220, 220],
|
||||
[220, 128, 128], [222, 40, 40], [100, 170, 30], [40, 40, 40],
|
||||
[33, 33, 33], [100, 128, 160], [142, 0, 0], [70, 100, 150],
|
||||
[210, 170, 100], [153, 153, 153], [128, 128, 128], [0, 0, 80],
|
||||
[250, 170, 30], [192, 192, 192], [220, 220, 0], [140, 140, 20],
|
||||
[119, 11, 32], [150, 0, 255], [0, 60, 100], [0, 0, 142],
|
||||
[0, 0, 90], [0, 0, 230], [0, 80, 100], [128, 64, 64], [0, 0, 110],
|
||||
[0, 0, 70], [0, 0, 192], [32, 32, 32], [120, 10, 10], [0, 0, 0]]
|
||||
|
||||
|
||||
def mapillary_v2_classes():
|
||||
"""mapillary_v2 class names for external use."""
|
||||
return [
|
||||
'Bird', 'Ground Animal', 'Ambiguous Barrier', 'Concrete Block', 'Curb',
|
||||
'Fence', 'Guard Rail', 'Barrier', 'Road Median', 'Road Side',
|
||||
'Lane Separator', 'Temporary Barrier', 'Wall', 'Bike Lane',
|
||||
'Crosswalk - Plain', 'Curb Cut', 'Driveway', 'Parking',
|
||||
'Parking Aisle', 'Pedestrian Area', 'Rail Track', 'Road',
|
||||
'Road Shoulder', 'Service Lane', 'Sidewalk', 'Traffic Island',
|
||||
'Bridge', 'Building', 'Garage', 'Tunnel', 'Person', 'Person Group',
|
||||
'Bicyclist', 'Motorcyclist', 'Other Rider',
|
||||
'Lane Marking - Dashed Line', 'Lane Marking - Straight Line',
|
||||
'Lane Marking - Zigzag Line', 'Lane Marking - Ambiguous',
|
||||
'Lane Marking - Arrow (Left)', 'Lane Marking - Arrow (Other)',
|
||||
'Lane Marking - Arrow (Right)',
|
||||
'Lane Marking - Arrow (Split Left or Straight)',
|
||||
'Lane Marking - Arrow (Split Right or Straight)',
|
||||
'Lane Marking - Arrow (Straight)', 'Lane Marking - Crosswalk',
|
||||
'Lane Marking - Give Way (Row)', 'Lane Marking - Give Way (Single)',
|
||||
'Lane Marking - Hatched (Chevron)',
|
||||
'Lane Marking - Hatched (Diagonal)', 'Lane Marking - Other',
|
||||
'Lane Marking - Stop Line', 'Lane Marking - Symbol (Bicycle)',
|
||||
'Lane Marking - Symbol (Other)', 'Lane Marking - Text',
|
||||
'Lane Marking (only) - Dashed Line', 'Lane Marking (only) - Crosswalk',
|
||||
'Lane Marking (only) - Other', 'Lane Marking (only) - Test',
|
||||
'Mountain', 'Sand', 'Sky', 'Snow', 'Terrain', 'Vegetation', 'Water',
|
||||
'Banner', 'Bench', 'Bike Rack', 'Catch Basin', 'CCTV Camera',
|
||||
'Fire Hydrant', 'Junction Box', 'Mailbox', 'Manhole', 'Parking Meter',
|
||||
'Phone Booth', 'Pothole', 'Signage - Advertisement',
|
||||
'Signage - Ambiguous', 'Signage - Back', 'Signage - Information',
|
||||
'Signage - Other', 'Signage - Store', 'Street Light', 'Pole',
|
||||
'Pole Group', 'Traffic Sign Frame', 'Utility Pole', 'Traffic Cone',
|
||||
'Traffic Light - General (Single)', 'Traffic Light - Pedestrians',
|
||||
'Traffic Light - General (Upright)',
|
||||
'Traffic Light - General (Horizontal)', 'Traffic Light - Cyclists',
|
||||
'Traffic Light - Other', 'Traffic Sign - Ambiguous',
|
||||
'Traffic Sign (Back)', 'Traffic Sign - Direction (Back)',
|
||||
'Traffic Sign - Direction (Front)', 'Traffic Sign (Front)',
|
||||
'Traffic Sign - Parking', 'Traffic Sign - Temporary (Back)',
|
||||
'Traffic Sign - Temporary (Front)', 'Trash Can', 'Bicycle', 'Boat',
|
||||
'Bus', 'Car', 'Caravan', 'Motorcycle', 'On Rails', 'Other Vehicle',
|
||||
'Trailer', 'Truck', 'Vehicle Group', 'Wheeled Slow', 'Water Valve',
|
||||
'Car Mount', 'Dynamic', 'Ego Vehicle', 'Ground', 'Static', 'Unlabeled'
|
||||
]
|
||||
|
||||
|
||||
def mapillary_v2_palette():
|
||||
"""mapillary_v2_ palette for external use."""
|
||||
return [[165, 42, 42], [0, 192, 0], [250, 170, 31], [250, 170, 32],
|
||||
[196, 196, 196], [190, 153, 153], [180, 165, 180], [90, 120, 150],
|
||||
[250, 170, 33], [250, 170, 34], [128, 128, 128], [250, 170, 35],
|
||||
[102, 102, 156], [128, 64, 255], [140, 140, 200], [170, 170, 170],
|
||||
[250, 170, 36], [250, 170, 160], [250, 170, 37], [96, 96, 96],
|
||||
[230, 150, 140], [128, 64, 128], [110, 110, 110], [110, 110, 110],
|
||||
[244, 35, 232], [128, 196, 128], [150, 100, 100], [70, 70, 70],
|
||||
[150, 150, 150], [150, 120, 90], [220, 20, 60], [220, 20, 60],
|
||||
[255, 0, 0], [255, 0, 100], [255, 0, 200], [255, 255, 255],
|
||||
[255, 255, 255], [250, 170, 29], [250, 170, 28], [250, 170, 26],
|
||||
[250, 170, 25], [250, 170, 24], [250, 170, 22], [250, 170, 21],
|
||||
[250, 170, 20], [255, 255, 255], [250, 170, 19], [250, 170, 18],
|
||||
[250, 170, 12], [250, 170, 11], [255, 255, 255], [255, 255, 255],
|
||||
[250, 170, 16], [250, 170, 15], [250, 170, 15], [255, 255, 255],
|
||||
[255, 255, 255], [255, 255, 255], [255, 255, 255], [64, 170, 64],
|
||||
[230, 160, 50], [70, 130, 180], [190, 255, 255], [152, 251, 152],
|
||||
[107, 142, 35], [0, 170, 30], [255, 255, 128], [250, 0, 30],
|
||||
[100, 140, 180], [220, 128, 128], [222, 40, 40], [100, 170, 30],
|
||||
[40, 40, 40], [33, 33, 33], [100, 128, 160], [20, 20, 255],
|
||||
[142, 0, 0], [70, 100, 150], [250, 171, 30], [250, 172, 30],
|
||||
[250, 173, 30], [250, 174, 30], [250, 175, 30], [250, 176, 30],
|
||||
[210, 170, 100], [153, 153, 153], [153, 153, 153], [128, 128, 128],
|
||||
[0, 0, 80], [210, 60, 60], [250, 170, 30], [250, 170, 30],
|
||||
[250, 170, 30], [250, 170, 30], [250, 170, 30], [250, 170, 30],
|
||||
[192, 192, 192], [192, 192, 192], [192, 192, 192], [220, 220, 0],
|
||||
[220, 220, 0], [0, 0, 196], [192, 192, 192], [220, 220, 0],
|
||||
[140, 140, 20], [119, 11, 32], [150, 0, 255], [0, 60, 100],
|
||||
[0, 0, 142], [0, 0, 90], [0, 0, 230], [0, 80, 100], [128, 64, 64],
|
||||
[0, 0, 110], [0, 0, 70], [0, 0, 142], [0, 0, 192], [170, 170, 170],
|
||||
[32, 32, 32], [111, 74, 0], [120, 10, 10], [81, 0, 81],
|
||||
[111, 111, 0], [0, 0, 0]]
|
||||
|
||||
|
||||
def cityscapes_palette():
|
||||
"""Cityscapes palette for external use."""
|
||||
return [[128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156],
|
||||
@ -313,7 +433,9 @@ dataset_aliases = {
|
||||
],
|
||||
'isaid': ['isaid', 'iSAID'],
|
||||
'stare': ['stare', 'STARE'],
|
||||
'lip': ['LIP', 'lip']
|
||||
'lip': ['LIP', 'lip'],
|
||||
'mapillary_v1': ['mapillary_v1'],
|
||||
'mapillary_v2': ['mapillary_v2']
|
||||
}
|
||||
|
||||
|
||||
|
@ -34,6 +34,7 @@ Preparing `Mapillary Vistas Dataset` dataset following [Mapillary Vistas Dataset
|
||||
| │ │ │ └── polygons
|
||||
│ │ ├── validation
|
||||
│ │ │ ├── images
|
||||
│ │ │ ├── v1.2
|
||||
| │ │ │ ├── instances
|
||||
| │ │ │ ├── labels
|
||||
| │ │ │ ├── labels_mask
|
||||
@ -46,12 +47,12 @@ Preparing `Mapillary Vistas Dataset` dataset following [Mapillary Vistas Dataset
|
||||
| │ │ │ └── polygons
|
||||
```
|
||||
|
||||
### Training commands with `deeplabv3plus_r101-d8_4xb2-240k_mapillay-512x1024.py`
|
||||
### Training commands
|
||||
|
||||
```bash
|
||||
# Dataset train commands
|
||||
# at `mmsegmentation` folder
|
||||
bash tools/dist_train.sh projects/mapillary_dataset/configs/deeplabv3plus_r101-d8_4xb2-240k_mapillay-512x1024.py 4
|
||||
bash tools/dist_train.sh projects/mapillary_dataset/configs/deeplabv3plus_r101-d8_4xb2-240k_mapillay_v1-512x1024.py 4
|
||||
```
|
||||
|
||||
## Checklist
|
||||
@ -66,20 +67,20 @@ bash tools/dist_train.sh projects/mapillary_dataset/configs/deeplabv3plus_r101-d
|
||||
|
||||
- [x] A full README
|
||||
|
||||
- [ ] Milestone 2: Indicates a successful model implementation.
|
||||
- [x] Milestone 2: Indicates a successful model implementation.
|
||||
|
||||
- [ ] Training-time correctness
|
||||
- [x] Training-time correctness
|
||||
|
||||
- [ ] Milestone 3: Good to be a part of our core package!
|
||||
- [x] Milestone 3: Good to be a part of our core package!
|
||||
|
||||
- [ ] Type hints and docstrings
|
||||
- [x] Type hints and docstrings
|
||||
|
||||
- [ ] Unit tests
|
||||
- [x] Unit tests
|
||||
|
||||
- [ ] Code polishing
|
||||
- [x] Code polishing
|
||||
|
||||
- [ ] Metafile.yml
|
||||
- [x] Metafile.yml
|
||||
|
||||
- [ ] Move your modules into the core package following the codebase's file hierarchy structure.
|
||||
- [x] Move your modules into the core package following the codebase's file hierarchy structure.
|
||||
|
||||
- [ ] Refactor your modules into the core package following the codebase's file hierarchy structure.
|
||||
- [x] Refactor your modules into the core package following the codebase's file hierarchy structure.
|
||||
|
@ -0,0 +1,68 @@
|
||||
# dataset settings
|
||||
dataset_type = 'MapillaryDataset_v1'
|
||||
data_root = 'data/mapillary/'
|
||||
crop_size = (512, 1024)
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='LoadAnnotations'),
|
||||
dict(
|
||||
type='RandomResize',
|
||||
scale=(2048, 1024),
|
||||
ratio_range=(0.5, 2.0),
|
||||
keep_ratio=True),
|
||||
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
|
||||
dict(type='RandomFlip', prob=0.5),
|
||||
dict(type='PhotoMetricDistortion'),
|
||||
dict(type='PackSegInputs')
|
||||
]
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='Resize', scale=(2048, 1024), keep_ratio=True),
|
||||
# add loading annotation after ``Resize`` because ground truth
|
||||
# does not need to do resize data transform
|
||||
dict(type='LoadAnnotations'),
|
||||
dict(type='PackSegInputs')
|
||||
]
|
||||
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
|
||||
tta_pipeline = [
|
||||
dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
|
||||
dict(
|
||||
type='TestTimeAug',
|
||||
transforms=[
|
||||
[
|
||||
dict(type='Resize', scale_factor=r, keep_ratio=True)
|
||||
for r in img_ratios
|
||||
],
|
||||
[
|
||||
dict(type='RandomFlip', prob=0., direction='horizontal'),
|
||||
dict(type='RandomFlip', prob=1., direction='horizontal')
|
||||
], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
|
||||
])
|
||||
]
|
||||
train_dataloader = dict(
|
||||
batch_size=2,
|
||||
num_workers=4,
|
||||
persistent_workers=True,
|
||||
sampler=dict(type='InfiniteSampler', shuffle=True),
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
data_prefix=dict(
|
||||
img_path='training/images', seg_map_path='training/v1.2/labels'),
|
||||
pipeline=train_pipeline))
|
||||
val_dataloader = dict(
|
||||
batch_size=1,
|
||||
num_workers=4,
|
||||
persistent_workers=True,
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
data_prefix=dict(
|
||||
img_path='validation/images',
|
||||
seg_map_path='validation/v1.2/labels'),
|
||||
pipeline=test_pipeline))
|
||||
test_dataloader = val_dataloader
|
||||
|
||||
val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
|
||||
test_evaluator = val_evaluator
|
@ -0,0 +1,37 @@
|
||||
# dataset settings
|
||||
_base_ = './mapillary_v1.py'
|
||||
metainfo = dict(
|
||||
classes=('Bird', 'Ground Animal', 'Curb', 'Fence', 'Guard Rail', 'Barrier',
|
||||
'Wall', 'Bike Lane', 'Crosswalk - Plain', 'Curb Cut', 'Parking',
|
||||
'Pedestrian Area', 'Rail Track', 'Road', 'Service Lane',
|
||||
'Sidewalk', 'Bridge', 'Building', 'Tunnel', 'Person', 'Bicyclist',
|
||||
'Motorcyclist', 'Other Rider', 'Lane Marking - Crosswalk',
|
||||
'Lane Marking - General', 'Mountain', 'Sand', 'Sky', 'Snow',
|
||||
'Terrain', 'Vegetation', 'Water', 'Banner', 'Bench', 'Bike Rack',
|
||||
'Billboard', 'Catch Basin', 'CCTV Camera', 'Fire Hydrant',
|
||||
'Junction Box', 'Mailbox', 'Manhole', 'Phone Booth', 'Pothole',
|
||||
'Street Light', 'Pole', 'Traffic Sign Frame', 'Utility Pole',
|
||||
'Traffic Light', 'Traffic Sign (Back)', 'Traffic Sign (Front)',
|
||||
'Trash Can', 'Bicycle', 'Boat', 'Bus', 'Car', 'Caravan',
|
||||
'Motorcycle', 'On Rails', 'Other Vehicle', 'Trailer', 'Truck',
|
||||
'Wheeled Slow', 'Car Mount', 'Ego Vehicle'),
|
||||
palette=[[165, 42, 42], [0, 192, 0], [196, 196, 196], [190, 153, 153],
|
||||
[180, 165, 180], [90, 120, 150], [102, 102, 156], [128, 64, 255],
|
||||
[140, 140, 200], [170, 170, 170], [250, 170, 160], [96, 96, 96],
|
||||
[230, 150, 140], [128, 64, 128], [110, 110, 110], [244, 35, 232],
|
||||
[150, 100, 100], [70, 70, 70], [150, 120, 90], [220, 20, 60],
|
||||
[255, 0, 0], [255, 0, 100], [255, 0, 200], [200, 128, 128],
|
||||
[255, 255, 255], [64, 170, 64], [230, 160, 50], [70, 130, 180],
|
||||
[190, 255, 255], [152, 251, 152], [107, 142, 35], [0, 170, 30],
|
||||
[255, 255, 128], [250, 0, 30], [100, 140, 180], [220, 220, 220],
|
||||
[220, 128, 128], [222, 40, 40], [100, 170, 30], [40, 40, 40],
|
||||
[33, 33, 33], [100, 128, 160], [142, 0, 0], [70, 100, 150],
|
||||
[210, 170, 100], [153, 153, 153], [128, 128, 128], [0, 0, 80],
|
||||
[250, 170, 30], [192, 192, 192], [220, 220, 0], [140, 140, 20],
|
||||
[119, 11, 32], [150, 0, 255], [0, 60, 100], [0, 0, 142],
|
||||
[0, 0, 90], [0, 0, 230], [0, 80, 100], [128, 64, 64], [0, 0, 110],
|
||||
[0, 0, 70], [0, 0, 192], [32, 32, 32], [120, 10, 10]])
|
||||
|
||||
train_dataloader = dict(dataset=dict(metainfo=metainfo))
|
||||
val_dataloader = dict(dataset=dict(metainfo=metainfo))
|
||||
test_dataloader = val_dataloader
|
@ -0,0 +1,68 @@
|
||||
# dataset settings
|
||||
dataset_type = 'MapillaryDataset_v2'
|
||||
data_root = 'data/mapillary/'
|
||||
crop_size = (512, 1024)
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='LoadAnnotations'),
|
||||
dict(
|
||||
type='RandomResize',
|
||||
scale=(2048, 1024),
|
||||
ratio_range=(0.5, 2.0),
|
||||
keep_ratio=True),
|
||||
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
|
||||
dict(type='RandomFlip', prob=0.5),
|
||||
dict(type='PhotoMetricDistortion'),
|
||||
dict(type='PackSegInputs')
|
||||
]
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='Resize', scale=(2048, 1024), keep_ratio=True),
|
||||
# add loading annotation after ``Resize`` because ground truth
|
||||
# does not need to do resize data transform
|
||||
dict(type='LoadAnnotations'),
|
||||
dict(type='PackSegInputs')
|
||||
]
|
||||
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
|
||||
tta_pipeline = [
|
||||
dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
|
||||
dict(
|
||||
type='TestTimeAug',
|
||||
transforms=[
|
||||
[
|
||||
dict(type='Resize', scale_factor=r, keep_ratio=True)
|
||||
for r in img_ratios
|
||||
],
|
||||
[
|
||||
dict(type='RandomFlip', prob=0., direction='horizontal'),
|
||||
dict(type='RandomFlip', prob=1., direction='horizontal')
|
||||
], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
|
||||
])
|
||||
]
|
||||
train_dataloader = dict(
|
||||
batch_size=2,
|
||||
num_workers=4,
|
||||
persistent_workers=True,
|
||||
sampler=dict(type='InfiniteSampler', shuffle=True),
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
data_prefix=dict(
|
||||
img_path='training/images', seg_map_path='training/v2.0/labels'),
|
||||
pipeline=train_pipeline))
|
||||
val_dataloader = dict(
|
||||
batch_size=1,
|
||||
num_workers=4,
|
||||
persistent_workers=True,
|
||||
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
data_prefix=dict(
|
||||
img_path='validation/images',
|
||||
seg_map_path='validation/v2.0/labels'),
|
||||
pipeline=test_pipeline))
|
||||
test_dataloader = val_dataloader
|
||||
|
||||
val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
|
||||
test_evaluator = val_evaluator
|
@ -1,103 +0,0 @@
|
||||
_base_ = ['./_base_/datasets/mapillary_v1_2.py'] # v 1.2 labels
|
||||
# _base_ = ['./_base_/datasets/mapillary_v2_0.py'] # v2.0 labels
|
||||
custom_imports = dict(imports=[
|
||||
'projects.mapillary_dataset.mmseg.datasets.mapillary_v1_2',
|
||||
'projects.mapillary_dataset.mmseg.datasets.mapillary_v2_0',
|
||||
])
|
||||
|
||||
norm_cfg = dict(type='SyncBN', requires_grad=True)
|
||||
data_preprocessor = dict(
|
||||
type='SegDataPreProcessor',
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
bgr_to_rgb=True,
|
||||
pad_val=0,
|
||||
seg_pad_val=255,
|
||||
size=(512, 1024))
|
||||
|
||||
model = dict(
|
||||
type='EncoderDecoder',
|
||||
data_preprocessor=data_preprocessor,
|
||||
pretrained=None,
|
||||
backbone=dict(
|
||||
type='ResNet',
|
||||
depth=101,
|
||||
num_stages=4,
|
||||
out_indices=(0, 1, 2, 3),
|
||||
dilations=(1, 1, 2, 4),
|
||||
strides=(1, 2, 1, 1),
|
||||
norm_cfg=norm_cfg,
|
||||
norm_eval=False,
|
||||
style='pytorch',
|
||||
contract_dilation=True),
|
||||
decode_head=dict(
|
||||
type='DepthwiseSeparableASPPHead',
|
||||
in_channels=2048,
|
||||
in_index=3,
|
||||
channels=512,
|
||||
dilations=(1, 12, 24, 36),
|
||||
c1_in_channels=256,
|
||||
c1_channels=48,
|
||||
dropout_ratio=0.1,
|
||||
num_classes=66, # v1.2
|
||||
# num_classes=124, # v2.0
|
||||
norm_cfg=norm_cfg,
|
||||
align_corners=False,
|
||||
loss_decode=dict(
|
||||
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
|
||||
auxiliary_head=dict(
|
||||
type='FCNHead',
|
||||
in_channels=1024,
|
||||
in_index=2,
|
||||
channels=256,
|
||||
num_convs=1,
|
||||
concat_input=False,
|
||||
dropout_ratio=0.1,
|
||||
num_classes=66, # v1.2
|
||||
# num_classes=124, # v2.0
|
||||
norm_cfg=norm_cfg,
|
||||
align_corners=False,
|
||||
loss_decode=dict(
|
||||
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
|
||||
train_cfg=dict(),
|
||||
test_cfg=dict(mode='whole'))
|
||||
default_scope = 'mmseg'
|
||||
env_cfg = dict(
|
||||
cudnn_benchmark=True,
|
||||
mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
|
||||
dist_cfg=dict(backend='nccl'))
|
||||
vis_backends = [dict(type='LocalVisBackend')]
|
||||
visualizer = dict(
|
||||
type='SegLocalVisualizer',
|
||||
vis_backends=[dict(type='LocalVisBackend')],
|
||||
name='visualizer')
|
||||
log_processor = dict(by_epoch=False)
|
||||
log_level = 'INFO'
|
||||
load_from = None
|
||||
resume = False
|
||||
tta_model = dict(type='SegTTAModel')
|
||||
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
|
||||
optim_wrapper = dict(
|
||||
type='OptimWrapper',
|
||||
optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001),
|
||||
clip_grad=None)
|
||||
param_scheduler = [
|
||||
dict(
|
||||
type='PolyLR',
|
||||
eta_min=0.0001,
|
||||
power=0.9,
|
||||
begin=0,
|
||||
end=240000,
|
||||
by_epoch=False)
|
||||
]
|
||||
train_cfg = dict(
|
||||
type='IterBasedTrainLoop', max_iters=240000, val_interval=24000)
|
||||
val_cfg = dict(type='ValLoop')
|
||||
test_cfg = dict(type='TestLoop')
|
||||
default_hooks = dict(
|
||||
timer=dict(type='IterTimerHook'),
|
||||
logger=dict(type='LoggerHook', interval=50, log_metric_by_epoch=False),
|
||||
param_scheduler=dict(type='ParamSchedulerHook'),
|
||||
checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=24000),
|
||||
sampler_seed=dict(type='DistSamplerSeedHook'),
|
||||
visualization=dict(type='SegVisualizationHook'))
|
@ -0,0 +1,17 @@
|
||||
_base_ = [
|
||||
'../../../configs/_base_/models/deeplabv3plus_r50-d8.py',
|
||||
'./_base_/datasets/mapillary_v1.py',
|
||||
'../../../configs/_base_/default_runtime.py',
|
||||
'../../../configs/_base_/schedules/schedule_240k.py'
|
||||
]
|
||||
custom_imports = dict(
|
||||
imports=['projects.mapillary_dataset.mmseg.datasets.mapillary'])
|
||||
|
||||
crop_size = (512, 1024)
|
||||
data_preprocessor = dict(size=crop_size)
|
||||
model = dict(
|
||||
data_preprocessor=data_preprocessor,
|
||||
pretrained='open-mmlab://resnet101_v1c',
|
||||
backbone=dict(depth=101),
|
||||
decode_head=dict(num_classes=66),
|
||||
auxiliary_head=dict(num_classes=66))
|
@ -0,0 +1,16 @@
|
||||
_base_ = [
|
||||
'../../../configs/_base_/models/deeplabv3plus_r50-d8.py',
|
||||
'./_base_/datasets/mapillary_v2.py',
|
||||
'../../../configs/_base_/default_runtime.py',
|
||||
'../../../configs/_base_/schedules/schedule_240k.py'
|
||||
]
|
||||
custom_imports = dict(
|
||||
imports=['projects.mapillary_dataset.mmseg.datasets.mapillary'])
|
||||
crop_size = (512, 1024)
|
||||
data_preprocessor = dict(size=crop_size)
|
||||
model = dict(
|
||||
data_preprocessor=data_preprocessor,
|
||||
pretrained='open-mmlab://resnet101_v1c',
|
||||
backbone=dict(depth=101),
|
||||
decode_head=dict(num_classes=124),
|
||||
auxiliary_head=dict(num_classes=124))
|
@ -0,0 +1,16 @@
|
||||
_base_ = [
|
||||
'../../../configs/_base_/models/pspnet_r50-d8.py',
|
||||
'./_base_/datasets/mapillary_v1.py',
|
||||
'../../../configs/_base_/default_runtime.py',
|
||||
'../../../configs/_base_/schedules/schedule_240k.py'
|
||||
]
|
||||
custom_imports = dict(
|
||||
imports=['projects.mapillary_dataset.mmseg.datasets.mapillary'])
|
||||
crop_size = (512, 1024)
|
||||
data_preprocessor = dict(size=crop_size)
|
||||
model = dict(
|
||||
data_preprocessor=data_preprocessor,
|
||||
pretrained='open-mmlab://resnet101_v1c',
|
||||
backbone=dict(depth=101),
|
||||
decode_head=dict(num_classes=66),
|
||||
auxiliary_head=dict(num_classes=66))
|
@ -0,0 +1,16 @@
|
||||
_base_ = [
|
||||
'../../../configs/_base_/models/pspnet_r50-d8.py',
|
||||
'./_base_/datasets/mapillary_v2.py',
|
||||
'../../../configs/_base_/default_runtime.py',
|
||||
'../../../configs/_base_/schedules/schedule_240k.py'
|
||||
]
|
||||
custom_imports = dict(
|
||||
imports=['projects.mapillary_dataset.mmseg.datasets.mapillary'])
|
||||
crop_size = (512, 1024)
|
||||
data_preprocessor = dict(size=crop_size)
|
||||
model = dict(
|
||||
data_preprocessor=data_preprocessor,
|
||||
pretrained='open-mmlab://resnet101_v1c',
|
||||
backbone=dict(depth=101),
|
||||
decode_head=dict(num_classes=124),
|
||||
auxiliary_head=dict(num_classes=124))
|
@ -1,87 +1,20 @@
|
||||
## Prepare datasets
|
||||
|
||||
It is recommended to symlink the dataset root to `$MMSEGMENTATION/data`.
|
||||
If your folder structure is different, you may need to change the corresponding paths in config files.
|
||||
|
||||
```none
|
||||
mmsegmentation
|
||||
├── mmseg
|
||||
├── tools
|
||||
├── configs
|
||||
├── data
|
||||
│ ├── mapillary
|
||||
│ │ ├── training
|
||||
│ │ │ ├── images
|
||||
│ │ │ ├── v1.2
|
||||
| │ │ │ ├── instances
|
||||
| │ │ │ ├── labels
|
||||
| │ │ │ ├── labels_mask
|
||||
| │ │ │ └── panoptic
|
||||
│ │ │ ├── v2.0
|
||||
| │ │ │ ├── instances
|
||||
| │ │ │ ├── labels
|
||||
| │ │ │ ├── labels_mask
|
||||
| │ │ │ ├── panoptic
|
||||
| │ │ │ └── polygons
|
||||
│ │ ├── validation
|
||||
│ │ │ ├── images
|
||||
| │ │ │ ├── instances
|
||||
| │ │ │ ├── labels
|
||||
| │ │ │ ├── labels_mask
|
||||
| │ │ │ └── panoptic
|
||||
│ │ │ ├── v2.0
|
||||
| │ │ │ ├── instances
|
||||
| │ │ │ ├── labels
|
||||
| │ │ │ ├── labels_mask
|
||||
| │ │ │ ├── panoptic
|
||||
| │ │ │ └── polygons
|
||||
```
|
||||
|
||||
## Mapillary Vistas Datasets
|
||||
|
||||
- The dataset could be download [here](https://www.mapillary.com/dataset/vistas) after registration.
|
||||
- Assumption you have put the dataset zip file in `mmsegmentation/data`
|
||||
|
||||
- Mapillary Vistas Dataset use 8-bit with color-palette to store labels. No conversion operation is required.
|
||||
|
||||
- Assumption you have put the dataset zip file in `mmsegmentation/data/mapillary`
|
||||
|
||||
- Please run the following commands to unzip dataset.
|
||||
|
||||
```bash
|
||||
cd data
|
||||
mkdir mapillary
|
||||
unzip -d mapillary An-ZjB1Zm61yAZG0ozTymz8I8NqI4x0MrYrh26dq7kPgfu8vf9ImrdaOAVOFYbJ2pNAgUnVGBmbue9lTgdBOb5BbKXIpFs0fpYWqACbrQDChAA2fdX0zS9PcHu7fY8c-FOvyBVxPNYNFQuM.zip
|
||||
cd data/mapillary
|
||||
unzip An-ZjB1Zm61yAZG0ozTymz8I8NqI4x0MrYrh26dq7kPgfu8vf9ImrdaOAVOFYbJ2pNAgUnVGBmbue9lTgdBOb5BbKXIpFs0fpYWqACbrQDChAA2fdX0zS9PcHu7fY8c-FOvyBVxPNYNFQuM.zip
|
||||
```
|
||||
- After unzip, you will get Mapillary Vistas Dataset like this structure.
|
||||
```none
|
||||
├── data
|
||||
│ ├── mapillary
|
||||
│ │ ├── training
|
||||
│ │ │ ├── images
|
||||
│ │ │ ├── v1.2
|
||||
| │ │ │ ├── instances
|
||||
| │ │ │ ├── labels
|
||||
| │ │ │ └── panoptic
|
||||
│ │ │ ├── v2.0
|
||||
| │ │ │ ├── instances
|
||||
| │ │ │ ├── labels
|
||||
| │ │ │ ├── panoptic
|
||||
| │ │ │ └── polygons
|
||||
│ │ ├── validation
|
||||
│ │ │ ├── images
|
||||
| │ │ │ ├── instances
|
||||
| │ │ │ ├── labels
|
||||
| │ │ │ └── panoptic
|
||||
│ │ │ ├── v2.0
|
||||
| │ │ │ ├── instances
|
||||
| │ │ │ ├── labels
|
||||
| │ │ │ ├── panoptic
|
||||
| │ │ │ └── polygons
|
||||
```
|
||||
- run following commands to convert RGB labels to mask labels
|
||||
```bash
|
||||
# --nproc optional, default 1, whether use multi-progress
|
||||
# --version optional, 'v1.2', 'v2.0','all', default 'all', choose convert which version labels
|
||||
# run this command at 'mmsegmentation/projects/Mapillary_dataset' folder
|
||||
cd mmsegmentation/projects/mapillary_dataset
|
||||
python tools/dataset_converters/mapillary.py ../../data/mapillary --nproc 8 --version all
|
||||
```
|
||||
After then, you will get this structure
|
||||
|
||||
- After unzip, you will get Mapillary Vistas Dataset like this structure. Semantic segmentation mask labels in `labels` folder.
|
||||
|
||||
```none
|
||||
mmsegmentation
|
||||
├── mmseg
|
||||
@ -94,24 +27,229 @@ mmsegmentation
|
||||
│ │ │ ├── v1.2
|
||||
| │ │ │ ├── instances
|
||||
| │ │ │ ├── labels
|
||||
| │ │ │ ├── labels_mask
|
||||
| │ │ │ └── panoptic
|
||||
│ │ │ ├── v2.0
|
||||
| │ │ │ ├── instances
|
||||
| │ │ │ ├── labels
|
||||
| │ │ │ ├── labels_mask
|
||||
| │ │ │ ├── panoptic
|
||||
| │ │ │ └── polygons
|
||||
│ │ ├── validation
|
||||
│ │ │ ├── images
|
||||
| │ │ ├── v1.2
|
||||
| │ │ │ ├── instances
|
||||
| │ │ │ ├── labels
|
||||
| │ │ │ ├── labels_mask
|
||||
| │ │ │ └── panoptic
|
||||
│ │ │ ├── v2.0
|
||||
| │ │ │ ├── instances
|
||||
| │ │ │ ├── labels
|
||||
| │ │ │ ├── labels_mask
|
||||
| │ │ │ ├── panoptic
|
||||
| │ │ │ └── polygons
|
||||
```
|
||||
|
||||
- You could set Datasets version with `MapillaryDataset_v1` and `MapillaryDataset_v2` in your configs.
|
||||
View the Mapillary Vistas Datasets config file here [V1.2](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/_base_/datasets/mapillary_v1.py) and [V2.0](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/_base_/datasets/mapillary_v2.py)
|
||||
|
||||
- **View datasets labels index and palette**
|
||||
|
||||
- **Mapillary Vistas Datasets labels information**
|
||||
**v1.2 information**
|
||||
|
||||
```none
|
||||
There are 66 labels classes in v1.2
|
||||
0--Bird--[165, 42, 42],
|
||||
1--Ground Animal--[0, 192, 0],
|
||||
2--Curb--[196, 196, 196],
|
||||
3--Fence--[190, 153, 153],
|
||||
4--Guard Rail--[180, 165, 180],
|
||||
5--Barrier--[90, 120, 150],
|
||||
6--Wall--[102, 102, 156],
|
||||
7--Bike Lane--[128, 64, 255],
|
||||
8--Crosswalk - Plain--[140, 140, 200],
|
||||
9--Curb Cut--[170, 170, 170],
|
||||
10--Parking--[250, 170, 160],
|
||||
11--Pedestrian Area--[96, 96, 96],
|
||||
12--Rail Track--[230, 150, 140],
|
||||
13--Road--[128, 64, 128],
|
||||
14--Service Lane--[110, 110, 110],
|
||||
15--Sidewalk--[244, 35, 232],
|
||||
16--Bridge--[150, 100, 100],
|
||||
17--Building--[70, 70, 70],
|
||||
18--Tunnel--[150, 120, 90],
|
||||
19--Person--[220, 20, 60],
|
||||
20--Bicyclist--[255, 0, 0],
|
||||
21--Motorcyclist--[255, 0, 100],
|
||||
22--Other Rider--[255, 0, 200],
|
||||
23--Lane Marking - Crosswalk--[200, 128, 128],
|
||||
24--Lane Marking - General--[255, 255, 255],
|
||||
25--Mountain--[64, 170, 64],
|
||||
26--Sand--[230, 160, 50],
|
||||
27--Sky--[70, 130, 180],
|
||||
28--Snow--[190, 255, 255],
|
||||
29--Terrain--[152, 251, 152],
|
||||
30--Vegetation--[107, 142, 35],
|
||||
31--Water--[0, 170, 30],
|
||||
32--Banner--[255, 255, 128],
|
||||
33--Bench--[250, 0, 30],
|
||||
34--Bike Rack--[100, 140, 180],
|
||||
35--Billboard--[220, 220, 220],
|
||||
36--Catch Basin--[220, 128, 128],
|
||||
37--CCTV Camera--[222, 40, 40],
|
||||
38--Fire Hydrant--[100, 170, 30],
|
||||
39--Junction Box--[40, 40, 40],
|
||||
40--Mailbox--[33, 33, 33],
|
||||
41--Manhole--[100, 128, 160],
|
||||
42--Phone Booth--[142, 0, 0],
|
||||
43--Pothole--[70, 100, 150],
|
||||
44--Street Light--[210, 170, 100],
|
||||
45--Pole--[153, 153, 153],
|
||||
46--Traffic Sign Frame--[128, 128, 128],
|
||||
47--Utility Pole--[0, 0, 80],
|
||||
48--Traffic Light--[250, 170, 30],
|
||||
49--Traffic Sign (Back)--[192, 192, 192],
|
||||
50--Traffic Sign (Front)--[220, 220, 0],
|
||||
51--Trash Can--[140, 140, 20],
|
||||
52--Bicycle--[119, 11, 32],
|
||||
53--Boat--[150, 0, 255],
|
||||
54--Bus--[0, 60, 100],
|
||||
55--Car--[0, 0, 142],
|
||||
56--Caravan--[0, 0, 90],
|
||||
57--Motorcycle--[0, 0, 230],
|
||||
58--On Rails--[0, 80, 100],
|
||||
59--Other Vehicle--[128, 64, 64],
|
||||
60--Trailer--[0, 0, 110],
|
||||
61--Truck--[0, 0, 70],
|
||||
62--Wheeled Slow--[0, 0, 192],
|
||||
63--Car Mount--[32, 32, 32],
|
||||
64--Ego Vehicle--[120, 10, 10],
|
||||
65--Unlabeled--[0, 0, 0]
|
||||
```
|
||||
|
||||
**v2.0 information**
|
||||
|
||||
```none
|
||||
There are 124 labels classes in v2.0
|
||||
0--Bird--[165, 42, 42],
|
||||
1--Ground Animal--[0, 192, 0],
|
||||
2--Ambiguous Barrier--[250, 170, 31],
|
||||
3--Concrete Block--[250, 170, 32],
|
||||
4--Curb--[196, 196, 196],
|
||||
5--Fence--[190, 153, 153],
|
||||
6--Guard Rail--[180, 165, 180],
|
||||
7--Barrier--[90, 120, 150],
|
||||
8--Road Median--[250, 170, 33],
|
||||
9--Road Side--[250, 170, 34],
|
||||
10--Lane Separator--[128, 128, 128],
|
||||
11--Temporary Barrier--[250, 170, 35],
|
||||
12--Wall--[102, 102, 156],
|
||||
13--Bike Lane--[128, 64, 255],
|
||||
14--Crosswalk - Plain--[140, 140, 200],
|
||||
15--Curb Cut--[170, 170, 170],
|
||||
16--Driveway--[250, 170, 36],
|
||||
17--Parking--[250, 170, 160],
|
||||
18--Parking Aisle--[250, 170, 37],
|
||||
19--Pedestrian Area--[96, 96, 96],
|
||||
20--Rail Track--[230, 150, 140],
|
||||
21--Road--[128, 64, 128],
|
||||
22--Road Shoulder--[110, 110, 110],
|
||||
23--Service Lane--[110, 110, 110],
|
||||
24--Sidewalk--[244, 35, 232],
|
||||
25--Traffic Island--[128, 196, 128],
|
||||
26--Bridge--[150, 100, 100],
|
||||
27--Building--[70, 70, 70],
|
||||
28--Garage--[150, 150, 150],
|
||||
29--Tunnel--[150, 120, 90],
|
||||
30--Person--[220, 20, 60],
|
||||
31--Person Group--[220, 20, 60],
|
||||
32--Bicyclist--[255, 0, 0],
|
||||
33--Motorcyclist--[255, 0, 100],
|
||||
34--Other Rider--[255, 0, 200],
|
||||
35--Lane Marking - Dashed Line--[255, 255, 255],
|
||||
36--Lane Marking - Straight Line--[255, 255, 255],
|
||||
37--Lane Marking - Zigzag Line--[250, 170, 29],
|
||||
38--Lane Marking - Ambiguous--[250, 170, 28],
|
||||
39--Lane Marking - Arrow (Left)--[250, 170, 26],
|
||||
40--Lane Marking - Arrow (Other)--[250, 170, 25],
|
||||
41--Lane Marking - Arrow (Right)--[250, 170, 24],
|
||||
42--Lane Marking - Arrow (Split Left or Straight)--[250, 170, 22],
|
||||
43--Lane Marking - Arrow (Split Right or Straight)--[250, 170, 21],
|
||||
44--Lane Marking - Arrow (Straight)--[250, 170, 20],
|
||||
45--Lane Marking - Crosswalk--[255, 255, 255],
|
||||
46--Lane Marking - Give Way (Row)--[250, 170, 19],
|
||||
47--Lane Marking - Give Way (Single)--[250, 170, 18],
|
||||
48--Lane Marking - Hatched (Chevron)--[250, 170, 12],
|
||||
49--Lane Marking - Hatched (Diagonal)--[250, 170, 11],
|
||||
50--Lane Marking - Other--[255, 255, 255],
|
||||
51--Lane Marking - Stop Line--[255, 255, 255],
|
||||
52--Lane Marking - Symbol (Bicycle)--[250, 170, 16],
|
||||
53--Lane Marking - Symbol (Other)--[250, 170, 15],
|
||||
54--Lane Marking - Text--[250, 170, 15],
|
||||
55--Lane Marking (only) - Dashed Line--[255, 255, 255],
|
||||
56--Lane Marking (only) - Crosswalk--[255, 255, 255],
|
||||
57--Lane Marking (only) - Other--[255, 255, 255],
|
||||
58--Lane Marking (only) - Test--[255, 255, 255],
|
||||
59--Mountain--[64, 170, 64],
|
||||
60--Sand--[230, 160, 50],
|
||||
61--Sky--[70, 130, 180],
|
||||
62--Snow--[190, 255, 255],
|
||||
63--Terrain--[152, 251, 152],
|
||||
64--Vegetation--[107, 142, 35],
|
||||
65--Water--[0, 170, 30],
|
||||
66--Banner--[255, 255, 128],
|
||||
67--Bench--[250, 0, 30],
|
||||
68--Bike Rack--[100, 140, 180],
|
||||
69--Catch Basin--[220, 128, 128],
|
||||
70--CCTV Camera--[222, 40, 40],
|
||||
71--Fire Hydrant--[100, 170, 30],
|
||||
72--Junction Box--[40, 40, 40],
|
||||
73--Mailbox--[33, 33, 33],
|
||||
74--Manhole--[100, 128, 160],
|
||||
75--Parking Meter--[20, 20, 255],
|
||||
76--Phone Booth--[142, 0, 0],
|
||||
77--Pothole--[70, 100, 150],
|
||||
78--Signage - Advertisement--[250, 171, 30],
|
||||
79--Signage - Ambiguous--[250, 172, 30],
|
||||
80--Signage - Back--[250, 173, 30],
|
||||
81--Signage - Information--[250, 174, 30],
|
||||
82--Signage - Other--[250, 175, 30],
|
||||
83--Signage - Store--[250, 176, 30],
|
||||
84--Street Light--[210, 170, 100],
|
||||
85--Pole--[153, 153, 153],
|
||||
86--Pole Group--[153, 153, 153],
|
||||
87--Traffic Sign Frame--[128, 128, 128],
|
||||
88--Utility Pole--[0, 0, 80],
|
||||
89--Traffic Cone--[210, 60, 60],
|
||||
90--Traffic Light - General (Single)--[250, 170, 30],
|
||||
91--Traffic Light - Pedestrians--[250, 170, 30],
|
||||
92--Traffic Light - General (Upright)--[250, 170, 30],
|
||||
93--Traffic Light - General (Horizontal)--[250, 170, 30],
|
||||
94--Traffic Light - Cyclists--[250, 170, 30],
|
||||
95--Traffic Light - Other--[250, 170, 30],
|
||||
96--Traffic Sign - Ambiguous--[192, 192, 192],
|
||||
97--Traffic Sign (Back)--[192, 192, 192],
|
||||
98--Traffic Sign - Direction (Back)--[192, 192, 192],
|
||||
99--Traffic Sign - Direction (Front)--[220, 220, 0],
|
||||
100--Traffic Sign (Front)--[220, 220, 0],
|
||||
101--Traffic Sign - Parking--[0, 0, 196],
|
||||
102--Traffic Sign - Temporary (Back)--[192, 192, 192],
|
||||
103--Traffic Sign - Temporary (Front)--[220, 220, 0],
|
||||
104--Trash Can--[140, 140, 20],
|
||||
105--Bicycle--[119, 11, 32],
|
||||
106--Boat--[150, 0, 255],
|
||||
107--Bus--[0, 60, 100],
|
||||
108--Car--[0, 0, 142],
|
||||
109--Caravan--[0, 0, 90],
|
||||
110--Motorcycle--[0, 0, 230],
|
||||
111--On Rails--[0, 80, 100],
|
||||
112--Other Vehicle--[128, 64, 64],
|
||||
113--Trailer--[0, 0, 110],
|
||||
114--Truck--[0, 0, 70],
|
||||
115--Vehicle Group--[0, 0, 142],
|
||||
116--Wheeled Slow--[0, 0, 192],
|
||||
117--Water Valve--[170, 170, 170],
|
||||
118--Car Mount--[32, 32, 32],
|
||||
119--Dynamic--[111, 74, 0],
|
||||
120--Ego Vehicle--[120, 10, 10],
|
||||
121--Ground--[81, 0, 81],
|
||||
122--Static--[111, 111, 0],
|
||||
123--Unlabeled--[0, 0, 0]
|
||||
```
|
||||
|
177
projects/mapillary_dataset/mmseg/datasets/mapillary.py
Normal file
177
projects/mapillary_dataset/mmseg/datasets/mapillary.py
Normal file
@ -0,0 +1,177 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmseg.datasets.basesegdataset import BaseSegDataset
|
||||
|
||||
# from mmseg.registry import DATASETS
|
||||
|
||||
|
||||
# @DATASETS.register_module()
|
||||
class MapillaryDataset_v1(BaseSegDataset):
|
||||
"""Mapillary Vistas Dataset.
|
||||
|
||||
Dataset paper link:
|
||||
http://ieeexplore.ieee.org/document/8237796/
|
||||
|
||||
v1.2 contain 66 object classes.
|
||||
(37 instance-specific)
|
||||
|
||||
v2.0 contain 124 object classes.
|
||||
(70 instance-specific, 46 stuff, 8 void or crowd).
|
||||
|
||||
The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is
|
||||
fixed to '.png' for Mapillary Vistas Dataset.
|
||||
"""
|
||||
METAINFO = dict(
|
||||
classes=('Bird', 'Ground Animal', 'Curb', 'Fence', 'Guard Rail',
|
||||
'Barrier', 'Wall', 'Bike Lane', 'Crosswalk - Plain',
|
||||
'Curb Cut', 'Parking', 'Pedestrian Area', 'Rail Track',
|
||||
'Road', 'Service Lane', 'Sidewalk', 'Bridge', 'Building',
|
||||
'Tunnel', 'Person', 'Bicyclist', 'Motorcyclist',
|
||||
'Other Rider', 'Lane Marking - Crosswalk',
|
||||
'Lane Marking - General', 'Mountain', 'Sand', 'Sky', 'Snow',
|
||||
'Terrain', 'Vegetation', 'Water', 'Banner', 'Bench',
|
||||
'Bike Rack', 'Billboard', 'Catch Basin', 'CCTV Camera',
|
||||
'Fire Hydrant', 'Junction Box', 'Mailbox', 'Manhole',
|
||||
'Phone Booth', 'Pothole', 'Street Light', 'Pole',
|
||||
'Traffic Sign Frame', 'Utility Pole', 'Traffic Light',
|
||||
'Traffic Sign (Back)', 'Traffic Sign (Front)', 'Trash Can',
|
||||
'Bicycle', 'Boat', 'Bus', 'Car', 'Caravan', 'Motorcycle',
|
||||
'On Rails', 'Other Vehicle', 'Trailer', 'Truck',
|
||||
'Wheeled Slow', 'Car Mount', 'Ego Vehicle', 'Unlabeled'),
|
||||
palette=[[165, 42, 42], [0, 192, 0], [196, 196, 196], [190, 153, 153],
|
||||
[180, 165, 180], [90, 120, 150], [102, 102, 156],
|
||||
[128, 64, 255], [140, 140, 200], [170, 170, 170],
|
||||
[250, 170, 160], [96, 96, 96],
|
||||
[230, 150, 140], [128, 64, 128], [110, 110, 110],
|
||||
[244, 35, 232], [150, 100, 100], [70, 70, 70], [150, 120, 90],
|
||||
[220, 20, 60], [255, 0, 0], [255, 0, 100], [255, 0, 200],
|
||||
[200, 128, 128], [255, 255, 255], [64, 170,
|
||||
64], [230, 160, 50],
|
||||
[70, 130, 180], [190, 255, 255], [152, 251, 152],
|
||||
[107, 142, 35], [0, 170, 30], [255, 255, 128], [250, 0, 30],
|
||||
[100, 140, 180], [220, 220, 220], [220, 128, 128],
|
||||
[222, 40, 40], [100, 170, 30], [40, 40, 40], [33, 33, 33],
|
||||
[100, 128, 160], [142, 0, 0], [70, 100, 150], [210, 170, 100],
|
||||
[153, 153, 153], [128, 128, 128], [0, 0, 80], [250, 170, 30],
|
||||
[192, 192, 192], [220, 220, 0], [140, 140, 20], [119, 11, 32],
|
||||
[150, 0, 255], [0, 60, 100], [0, 0, 142], [0, 0, 90],
|
||||
[0, 0, 230], [0, 80, 100], [128, 64, 64], [0, 0, 110],
|
||||
[0, 0, 70], [0, 0, 192], [32, 32, 32], [120, 10,
|
||||
10], [0, 0, 0]])
|
||||
|
||||
def __init__(self,
|
||||
img_suffix='.jpg',
|
||||
seg_map_suffix='.png',
|
||||
**kwargs) -> None:
|
||||
super().__init__(
|
||||
img_suffix=img_suffix, seg_map_suffix=seg_map_suffix, **kwargs)
|
||||
|
||||
|
||||
# @DATASETS.register_module()
|
||||
class MapillaryDataset_v2(BaseSegDataset):
|
||||
"""Mapillary Vistas Dataset.
|
||||
|
||||
Dataset paper link:
|
||||
http://ieeexplore.ieee.org/document/8237796/
|
||||
|
||||
v1.2 contain 66 object classes.
|
||||
(37 instance-specific)
|
||||
|
||||
v2.0 contain 124 object classes.
|
||||
(70 instance-specific, 46 stuff, 8 void or crowd).
|
||||
|
||||
The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is
|
||||
fixed to '.png' for Mapillary Vistas Dataset.
|
||||
"""
|
||||
METAINFO = dict(
|
||||
classes=(
|
||||
'Bird', 'Ground Animal', 'Ambiguous Barrier', 'Concrete Block',
|
||||
'Curb', 'Fence', 'Guard Rail', 'Barrier', 'Road Median',
|
||||
'Road Side', 'Lane Separator', 'Temporary Barrier', 'Wall',
|
||||
'Bike Lane', 'Crosswalk - Plain', 'Curb Cut', 'Driveway',
|
||||
'Parking', 'Parking Aisle', 'Pedestrian Area', 'Rail Track',
|
||||
'Road', 'Road Shoulder', 'Service Lane', 'Sidewalk',
|
||||
'Traffic Island', 'Bridge', 'Building', 'Garage', 'Tunnel',
|
||||
'Person', 'Person Group', 'Bicyclist', 'Motorcyclist',
|
||||
'Other Rider', 'Lane Marking - Dashed Line',
|
||||
'Lane Marking - Straight Line', 'Lane Marking - Zigzag Line',
|
||||
'Lane Marking - Ambiguous', 'Lane Marking - Arrow (Left)',
|
||||
'Lane Marking - Arrow (Other)', 'Lane Marking - Arrow (Right)',
|
||||
'Lane Marking - Arrow (Split Left or Straight)',
|
||||
'Lane Marking - Arrow (Split Right or Straight)',
|
||||
'Lane Marking - Arrow (Straight)', 'Lane Marking - Crosswalk',
|
||||
'Lane Marking - Give Way (Row)',
|
||||
'Lane Marking - Give Way (Single)',
|
||||
'Lane Marking - Hatched (Chevron)',
|
||||
'Lane Marking - Hatched (Diagonal)', 'Lane Marking - Other',
|
||||
'Lane Marking - Stop Line', 'Lane Marking - Symbol (Bicycle)',
|
||||
'Lane Marking - Symbol (Other)', 'Lane Marking - Text',
|
||||
'Lane Marking (only) - Dashed Line',
|
||||
'Lane Marking (only) - Crosswalk', 'Lane Marking (only) - Other',
|
||||
'Lane Marking (only) - Test', 'Mountain', 'Sand', 'Sky', 'Snow',
|
||||
'Terrain', 'Vegetation', 'Water', 'Banner', 'Bench', 'Bike Rack',
|
||||
'Catch Basin', 'CCTV Camera', 'Fire Hydrant', 'Junction Box',
|
||||
'Mailbox', 'Manhole', 'Parking Meter', 'Phone Booth', 'Pothole',
|
||||
'Signage - Advertisement', 'Signage - Ambiguous', 'Signage - Back',
|
||||
'Signage - Information', 'Signage - Other', 'Signage - Store',
|
||||
'Street Light', 'Pole', 'Pole Group', 'Traffic Sign Frame',
|
||||
'Utility Pole', 'Traffic Cone', 'Traffic Light - General (Single)',
|
||||
'Traffic Light - Pedestrians', 'Traffic Light - General (Upright)',
|
||||
'Traffic Light - General (Horizontal)', 'Traffic Light - Cyclists',
|
||||
'Traffic Light - Other', 'Traffic Sign - Ambiguous',
|
||||
'Traffic Sign (Back)', 'Traffic Sign - Direction (Back)',
|
||||
'Traffic Sign - Direction (Front)', 'Traffic Sign (Front)',
|
||||
'Traffic Sign - Parking', 'Traffic Sign - Temporary (Back)',
|
||||
'Traffic Sign - Temporary (Front)', 'Trash Can', 'Bicycle', 'Boat',
|
||||
'Bus', 'Car', 'Caravan', 'Motorcycle', 'On Rails', 'Other Vehicle',
|
||||
'Trailer', 'Truck', 'Vehicle Group', 'Wheeled Slow', 'Water Valve',
|
||||
'Car Mount', 'Dynamic', 'Ego Vehicle', 'Ground', 'Static',
|
||||
'Unlabeled'),
|
||||
palette=[[165, 42, 42], [0, 192, 0], [250, 170, 31], [250, 170, 32],
|
||||
[196, 196, 196], [190, 153, 153], [180, 165, 180],
|
||||
[90, 120, 150], [250, 170, 33], [250, 170, 34],
|
||||
[128, 128, 128], [250, 170, 35], [102, 102, 156],
|
||||
[128, 64, 255], [140, 140, 200], [170, 170, 170],
|
||||
[250, 170, 36], [250, 170, 160], [250, 170, 37], [96, 96, 96],
|
||||
[230, 150, 140], [128, 64, 128], [110, 110, 110],
|
||||
[110, 110, 110], [244, 35, 232], [128, 196,
|
||||
128], [150, 100, 100],
|
||||
[70, 70, 70], [150, 150, 150], [150, 120, 90], [220, 20, 60],
|
||||
[220, 20, 60], [255, 0, 0], [255, 0, 100], [255, 0, 200],
|
||||
[255, 255, 255], [255, 255, 255], [250, 170, 29],
|
||||
[250, 170, 28], [250, 170, 26], [250, 170,
|
||||
25], [250, 170, 24],
|
||||
[250, 170, 22], [250, 170, 21], [250, 170,
|
||||
20], [255, 255, 255],
|
||||
[250, 170, 19], [250, 170, 18], [250, 170,
|
||||
12], [250, 170, 11],
|
||||
[255, 255, 255], [255, 255, 255], [250, 170, 16],
|
||||
[250, 170, 15], [250, 170, 15], [255, 255, 255],
|
||||
[255, 255, 255], [255, 255, 255], [255, 255, 255],
|
||||
[64, 170, 64], [230, 160, 50],
|
||||
[70, 130, 180], [190, 255, 255], [152, 251, 152],
|
||||
[107, 142, 35], [0, 170, 30], [255, 255, 128], [250, 0, 30],
|
||||
[100, 140, 180], [220, 128, 128], [222, 40,
|
||||
40], [100, 170, 30],
|
||||
[40, 40, 40], [33, 33, 33], [100, 128, 160], [20, 20, 255],
|
||||
[142, 0, 0], [70, 100, 150], [250, 171, 30], [250, 172, 30],
|
||||
[250, 173, 30], [250, 174, 30], [250, 175,
|
||||
30], [250, 176, 30],
|
||||
[210, 170, 100], [153, 153, 153], [153, 153, 153],
|
||||
[128, 128, 128], [0, 0, 80], [210, 60, 60], [250, 170, 30],
|
||||
[250, 170, 30], [250, 170, 30], [250, 170,
|
||||
30], [250, 170, 30],
|
||||
[250, 170, 30], [192, 192, 192], [192, 192, 192],
|
||||
[192, 192, 192], [220, 220, 0], [220, 220, 0], [0, 0, 196],
|
||||
[192, 192, 192], [220, 220, 0], [140, 140, 20], [119, 11, 32],
|
||||
[150, 0, 255], [0, 60, 100], [0, 0, 142], [0, 0, 90],
|
||||
[0, 0, 230], [0, 80, 100], [128, 64, 64], [0, 0, 110],
|
||||
[0, 0, 70], [0, 0, 142], [0, 0, 192], [170, 170, 170],
|
||||
[32, 32, 32], [111, 74, 0], [120, 10, 10], [81, 0, 81],
|
||||
[111, 111, 0], [0, 0, 0]])
|
||||
|
||||
def __init__(self,
|
||||
img_suffix='.jpg',
|
||||
seg_map_suffix='.png',
|
||||
**kwargs) -> None:
|
||||
super().__init__(
|
||||
img_suffix=img_suffix, seg_map_suffix=seg_map_suffix, **kwargs)
|
@ -1,65 +0,0 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from mmseg.datasets.basesegdataset import BaseSegDataset
|
||||
from mmseg.registry import DATASETS
|
||||
|
||||
|
||||
@DATASETS.register_module()
|
||||
class MapillaryDataset_v1_2(BaseSegDataset):
|
||||
"""Mapillary Vistas Dataset.
|
||||
|
||||
Dataset paper link:
|
||||
http://ieeexplore.ieee.org/document/8237796/
|
||||
|
||||
v1.2 contain 66 object classes.
|
||||
(37 instance-specific)
|
||||
|
||||
v2.0 contain 124 object classes.
|
||||
(70 instance-specific, 46 stuff, 8 void or crowd).
|
||||
|
||||
The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is
|
||||
fixed to '.png' for Mapillary Vistas Dataset.
|
||||
"""
|
||||
METAINFO = dict(
|
||||
classes=('Bird', 'Ground Animal', 'Curb', 'Fence', 'Guard Rail',
|
||||
'Barrier', 'Wall', 'Bike Lane', 'Crosswalk - Plain',
|
||||
'Curb Cut', 'Parking', 'Pedestrian Area', 'Rail Track',
|
||||
'Road', 'Service Lane', 'Sidewalk', 'Bridge', 'Building',
|
||||
'Tunnel', 'Person', 'Bicyclist', 'Motorcyclist',
|
||||
'Other Rider', 'Lane Marking - Crosswalk',
|
||||
'Lane Marking - General', 'Mountain', 'Sand', 'Sky', 'Snow',
|
||||
'Terrain', 'Vegetation', 'Water', 'Banner', 'Bench',
|
||||
'Bike Rack', 'Billboard', 'Catch Basin', 'CCTV Camera',
|
||||
'Fire Hydrant', 'Junction Box', 'Mailbox', 'Manhole',
|
||||
'Phone Booth', 'Pothole', 'Street Light', 'Pole',
|
||||
'Traffic Sign Frame', 'Utility Pole', 'Traffic Light',
|
||||
'Traffic Sign (Back)', 'Traffic Sign (Front)', 'Trash Can',
|
||||
'Bicycle', 'Boat', 'Bus', 'Car', 'Caravan', 'Motorcycle',
|
||||
'On Rails', 'Other Vehicle', 'Trailer', 'Truck',
|
||||
'Wheeled Slow', 'Car Mount', 'Ego Vehicle', 'Unlabeled'),
|
||||
palette=[[165, 42, 42], [0, 192, 0], [196, 196, 196], [190, 153, 153],
|
||||
[180, 165, 180], [90, 120, 150], [102, 102, 156],
|
||||
[128, 64, 255], [140, 140, 200], [170, 170, 170],
|
||||
[250, 170, 160], [96, 96, 96],
|
||||
[230, 150, 140], [128, 64, 128], [110, 110, 110],
|
||||
[244, 35, 232], [150, 100, 100], [70, 70, 70], [150, 120, 90],
|
||||
[220, 20, 60], [255, 0, 0], [255, 0, 100], [255, 0, 200],
|
||||
[200, 128, 128], [255, 255, 255], [64, 170,
|
||||
64], [230, 160, 50],
|
||||
[70, 130, 180], [190, 255, 255], [152, 251, 152],
|
||||
[107, 142, 35], [0, 170, 30], [255, 255, 128], [250, 0, 30],
|
||||
[100, 140, 180], [220, 220, 220], [220, 128, 128],
|
||||
[222, 40, 40], [100, 170, 30], [40, 40, 40], [33, 33, 33],
|
||||
[100, 128, 160], [142, 0, 0], [70, 100, 150], [210, 170, 100],
|
||||
[153, 153, 153], [128, 128, 128], [0, 0, 80], [250, 170, 30],
|
||||
[192, 192, 192], [220, 220, 0], [140, 140, 20], [119, 11, 32],
|
||||
[150, 0, 255], [0, 60, 100], [0, 0, 142], [0, 0, 90],
|
||||
[0, 0, 230], [0, 80, 100], [128, 64, 64], [0, 0, 110],
|
||||
[0, 0, 70], [0, 0, 192], [32, 32, 32], [120, 10,
|
||||
10], [0, 0, 0]])
|
||||
|
||||
def __init__(self,
|
||||
img_suffix='.jpg',
|
||||
seg_map_suffix='.png',
|
||||
**kwargs) -> None:
|
||||
super().__init__(
|
||||
img_suffix=img_suffix, seg_map_suffix=seg_map_suffix, **kwargs)
|
@ -1,245 +0,0 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import argparse
|
||||
import os.path as osp
|
||||
from functools import partial
|
||||
|
||||
import mmcv
|
||||
import numpy as np
|
||||
from mmengine.utils import (mkdir_or_exist, scandir, track_parallel_progress,
|
||||
track_progress)
|
||||
|
||||
colormap_v1_2 = np.array([[165, 42, 42], [0, 192, 0], [196, 196, 196],
|
||||
[190, 153, 153], [180, 165, 180], [90, 120, 150],
|
||||
[102, 102, 156], [128, 64, 255], [140, 140, 200],
|
||||
[170, 170, 170], [250, 170, 160], [96, 96, 96],
|
||||
[230, 150, 140], [128, 64, 128], [110, 110, 110],
|
||||
[244, 35, 232], [150, 100, 100], [70, 70, 70],
|
||||
[150, 120, 90], [220, 20, 60], [255, 0, 0],
|
||||
[255, 0, 100], [255, 0, 200], [200, 128, 128],
|
||||
[255, 255, 255], [64, 170, 64], [230, 160, 50],
|
||||
[70, 130, 180], [190, 255, 255], [152, 251, 152],
|
||||
[107, 142, 35], [0, 170, 30], [255, 255, 128],
|
||||
[250, 0, 30], [100, 140, 180], [220, 220, 220],
|
||||
[220, 128, 128], [222, 40, 40], [100, 170, 30],
|
||||
[40, 40, 40], [33, 33, 33], [100, 128, 160],
|
||||
[142, 0, 0], [70, 100, 150], [210, 170, 100],
|
||||
[153, 153, 153], [128, 128, 128], [0, 0, 80],
|
||||
[250, 170, 30], [192, 192, 192], [220, 220, 0],
|
||||
[140, 140, 20], [119, 11, 32], [150, 0, 255],
|
||||
[0, 60, 100], [0, 0, 142], [0, 0, 90], [0, 0, 230],
|
||||
[0, 80, 100], [128, 64, 64], [0, 0, 110], [0, 0, 70],
|
||||
[0, 0, 192], [32, 32, 32], [120, 10, 10], [0, 0, 0]])
|
||||
|
||||
colormap_v2_0 = np.array([[165, 42, 42], [0, 192, 0], [250, 170, 31],
|
||||
[250, 170, 32], [196, 196, 196], [190, 153, 153],
|
||||
[180, 165, 180], [90, 120, 150], [250, 170, 33],
|
||||
[250, 170, 34], [128, 128, 128], [250, 170, 35],
|
||||
[102, 102, 156], [128, 64, 255], [140, 140, 200],
|
||||
[170, 170, 170], [250, 170, 36], [250, 170, 160],
|
||||
[250, 170, 37], [96, 96, 96], [230, 150, 140],
|
||||
[128, 64, 128], [110, 110, 110], [110, 110, 110],
|
||||
[244, 35, 232], [128, 196, 128], [150, 100, 100],
|
||||
[70, 70, 70], [150, 150, 150], [150, 120, 90],
|
||||
[220, 20, 60], [220, 20, 60], [255, 0, 0],
|
||||
[255, 0, 100], [255, 0, 200], [255, 255, 255],
|
||||
[255, 255, 255], [250, 170, 29], [250, 170, 28],
|
||||
[250, 170, 26], [250, 170, 25], [250, 170, 24],
|
||||
[250, 170, 22], [250, 170, 21], [250, 170, 20],
|
||||
[255, 255, 255], [250, 170, 19], [250, 170, 18],
|
||||
[250, 170, 12], [250, 170, 11], [255, 255, 255],
|
||||
[255, 255, 255], [250, 170, 16], [250, 170, 15],
|
||||
[250, 170, 15], [255, 255, 255], [255, 255, 255],
|
||||
[255, 255, 255], [255, 255, 255], [64, 170, 64],
|
||||
[230, 160, 50], [70, 130, 180], [190, 255, 255],
|
||||
[152, 251, 152], [107, 142, 35], [0, 170, 30],
|
||||
[255, 255, 128], [250, 0, 30], [100, 140, 180],
|
||||
[220, 128, 128], [222, 40, 40], [100, 170, 30],
|
||||
[40, 40, 40], [33, 33, 33], [100, 128, 160],
|
||||
[20, 20, 255], [142, 0, 0], [70, 100, 150],
|
||||
[250, 171, 30], [250, 172, 30], [250, 173, 30],
|
||||
[250, 174, 30], [250, 175, 30], [250, 176, 30],
|
||||
[210, 170, 100], [153, 153, 153], [153, 153, 153],
|
||||
[128, 128, 128], [0, 0, 80], [210, 60, 60],
|
||||
[250, 170, 30], [250, 170, 30], [250, 170, 30],
|
||||
[250, 170, 30], [250, 170, 30], [250, 170, 30],
|
||||
[192, 192, 192], [192, 192, 192], [192, 192, 192],
|
||||
[220, 220, 0], [220, 220, 0], [0, 0, 196],
|
||||
[192, 192, 192], [220, 220, 0], [140, 140, 20],
|
||||
[119, 11, 32], [150, 0, 255], [0, 60, 100],
|
||||
[0, 0, 142], [0, 0, 90], [0, 0, 230], [0, 80, 100],
|
||||
[128, 64, 64], [0, 0, 110], [0, 0, 70], [0, 0, 142],
|
||||
[0, 0, 192], [170, 170, 170], [32, 32, 32],
|
||||
[111, 74, 0], [120, 10, 10], [81, 0, 81],
|
||||
[111, 111, 0], [0, 0, 0]])
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
description='Convert Mapillary dataset to mmsegmentation format')
|
||||
parser.add_argument('dataset_path', help='Mapillary folder path')
|
||||
parser.add_argument(
|
||||
'--version',
|
||||
default='all',
|
||||
help="Mapillary labels version, 'v1.2','v2.0','all'")
|
||||
parser.add_argument('-o', '--out_dir', help='output path')
|
||||
parser.add_argument(
|
||||
'--nproc', default=1, type=int, help='number of process')
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def mapillary_colormap2label(colormap: np.ndarray) -> list:
|
||||
"""Create a `list` shaped (256^3, 1), convert each color palette to a
|
||||
number, which can use to find the correct label value.
|
||||
|
||||
For example labels 0--Bird--[165, 42, 42]
|
||||
(165*256 + 42) * 256 + 42 = 10824234 (This is list's index])
|
||||
`colormap2label[10824234] = 0`
|
||||
|
||||
In converting, if a RGB pixel value is [165, 42, 42],
|
||||
through colormap2label[10824234]-->can quickly find
|
||||
this labels value is 0.
|
||||
Through matrix multiply to compute a img is very fast.
|
||||
|
||||
Args:
|
||||
colormap (np.ndarray): Mapillary Vistas Dataset palette
|
||||
|
||||
Returns:
|
||||
list: values are mask labels,
|
||||
indices are palette's convert results.
|
||||
"""
|
||||
colormap2label = np.zeros(256**3, dtype=np.longlong)
|
||||
for i, colormap_ in enumerate(colormap):
|
||||
colormap2label[(colormap_[0] * 256 + colormap_[1]) * 256 +
|
||||
colormap_[2]] = i
|
||||
return colormap2label
|
||||
|
||||
|
||||
def mapillary_masklabel(rgb_label: np.ndarray,
|
||||
colormap2label: list) -> np.ndarray:
|
||||
"""Computing a img mask label through `colormap2label` get in
|
||||
`mapillary_colormap2label(COLORMAP: np.ndarray)`
|
||||
|
||||
Args:
|
||||
rgb_label (np.array): a RGB labels img.
|
||||
colormap2label (list): get in mapillary_colormap2label(colormap)
|
||||
|
||||
Returns:
|
||||
np.ndarray: mask labels array.
|
||||
"""
|
||||
colormap_ = rgb_label.astype('uint32')
|
||||
idx = np.array((colormap_[:, :, 0] * 256 + colormap_[:, :, 1]) * 256 +
|
||||
colormap_[:, :, 2]).astype('uint32')
|
||||
return colormap2label[idx]
|
||||
|
||||
|
||||
def RGB2Mask(rgb_label_path: str, colormap2label: list) -> None:
|
||||
"""Mapillary Vistas Dataset provide 8-bit with color-palette class-specific
|
||||
labels for semantic segmentation. However, semantic segmentation needs
|
||||
single channel mask labels.
|
||||
|
||||
This code is about converting mapillary RGB labels
|
||||
{traing,validation/v1.2,v2.0/labels} to mask labels
|
||||
{{traing,validation/v1.2,v2.0/labels_mask}
|
||||
|
||||
Args:
|
||||
rgb_label_path (str): image absolute path.
|
||||
dataset_version (str): v1.2 or v2.0 to choose color_map .
|
||||
"""
|
||||
rgb_label = mmcv.imread(rgb_label_path, channel_order='rgb')
|
||||
|
||||
masks_label = mapillary_masklabel(rgb_label, colormap2label)
|
||||
|
||||
mmcv.imwrite(
|
||||
masks_label.astype(np.uint8),
|
||||
rgb_label_path.replace('labels', 'labels_mask'))
|
||||
|
||||
|
||||
def main():
|
||||
colormap2label_v1_2 = mapillary_colormap2label(colormap_v1_2)
|
||||
colormap2label_v2_0 = mapillary_colormap2label(colormap_v2_0)
|
||||
|
||||
dataset_path = args.dataset_path
|
||||
if args.out_dir is None:
|
||||
out_dir = dataset_path
|
||||
else:
|
||||
out_dir = args.out_dir
|
||||
|
||||
RGB_labels_path = []
|
||||
RGB_labels_v1_2_path = []
|
||||
RGB_labels_v2_0_path = []
|
||||
print('Scanning labels path....')
|
||||
for label_path in scandir(dataset_path, suffix='.png', recursive=True):
|
||||
if 'labels' in label_path:
|
||||
rgb_label_path = osp.join(dataset_path, label_path)
|
||||
RGB_labels_path.append(rgb_label_path)
|
||||
if 'v1.2' in label_path:
|
||||
RGB_labels_v1_2_path.append(rgb_label_path)
|
||||
elif 'v2.0' in label_path:
|
||||
RGB_labels_v2_0_path.append(rgb_label_path)
|
||||
|
||||
if args.version == 'all':
|
||||
print(f'Totaly found {len(RGB_labels_path)} {args.version} RGB labels')
|
||||
elif args.version == 'v1.2':
|
||||
print(f'Found {len(RGB_labels_v1_2_path)} {args.version} RGB labels')
|
||||
elif args.version == 'v2.0':
|
||||
print(f'Found {len(RGB_labels_v2_0_path)} {args.version} RGB labels')
|
||||
print('Making directories...')
|
||||
mkdir_or_exist(osp.join(out_dir, 'training', 'v1.2', 'labels_mask'))
|
||||
mkdir_or_exist(osp.join(out_dir, 'validation', 'v1.2', 'labels_mask'))
|
||||
mkdir_or_exist(osp.join(out_dir, 'training', 'v2.0', 'labels_mask'))
|
||||
mkdir_or_exist(osp.join(out_dir, 'validation', 'v2.0', 'labels_mask'))
|
||||
print('Directories Have Made...')
|
||||
|
||||
if args.nproc > 1:
|
||||
if args.version == 'all':
|
||||
print('Converting v1.2 ....')
|
||||
track_parallel_progress(
|
||||
partial(RGB2Mask, colormap2label=colormap2label_v1_2),
|
||||
RGB_labels_v1_2_path,
|
||||
nproc=args.nproc)
|
||||
print('Converting v2.0 ....')
|
||||
track_parallel_progress(
|
||||
partial(RGB2Mask, colormap2label=colormap2label_v2_0),
|
||||
RGB_labels_v2_0_path,
|
||||
nproc=args.nproc)
|
||||
elif args.version == 'v1.2':
|
||||
print('Converting v1.2 ....')
|
||||
track_parallel_progress(
|
||||
partial(RGB2Mask, colormap2label=colormap2label_v1_2),
|
||||
RGB_labels_v1_2_path,
|
||||
nproc=args.nproc)
|
||||
elif args.version == 'v2.0':
|
||||
print('Converting v2.0 ....')
|
||||
track_parallel_progress(
|
||||
partial(RGB2Mask, colormap2label=colormap2label_v2_0),
|
||||
RGB_labels_v2_0_path,
|
||||
nproc=args.nproc)
|
||||
|
||||
else:
|
||||
if args.version == 'all':
|
||||
print('Converting v1.2 ....')
|
||||
track_progress(
|
||||
partial(RGB2Mask, colormap2label=colormap2label_v1_2),
|
||||
RGB_labels_v1_2_path)
|
||||
print('Converting v2.0 ....')
|
||||
track_progress(
|
||||
partial(RGB2Mask, colormap2label=colormap2label_v2_0),
|
||||
RGB_labels_v2_0_path)
|
||||
elif args.version == 'v1.2':
|
||||
print('Converting v1.2 ....')
|
||||
track_progress(
|
||||
partial(RGB2Mask, colormap2label=colormap2label_v1_2),
|
||||
RGB_labels_v1_2_path)
|
||||
elif args.version == 'v2.0':
|
||||
print('Converting v2.0 ....')
|
||||
track_progress(
|
||||
partial(RGB2Mask, colormap2label=colormap2label_v2_0),
|
||||
RGB_labels_v2_0_path)
|
||||
|
||||
print('Have convert Mapillary Vistas Datasets RGB labels to Mask labels!')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = parse_args()
|
||||
main()
|
Binary file not shown.
After Width: | Height: | Size: 1.2 MiB |
Binary file not shown.
After Width: | Height: | Size: 74 KiB |
Binary file not shown.
After Width: | Height: | Size: 74 KiB |
@ -7,7 +7,8 @@ import pytest
|
||||
|
||||
from mmseg.datasets import (ADE20KDataset, BaseSegDataset, CityscapesDataset,
|
||||
COCOStuffDataset, DecathlonDataset, ISPRSDataset,
|
||||
LIPDataset, LoveDADataset, PascalVOCDataset,
|
||||
LIPDataset, LoveDADataset, MapillaryDataset_v1,
|
||||
MapillaryDataset_v2, PascalVOCDataset,
|
||||
PotsdamDataset, REFUGEDataset, SynapseDataset,
|
||||
iSAIDDataset)
|
||||
from mmseg.registry import DATASETS
|
||||
@ -27,6 +28,10 @@ def test_classes():
|
||||
assert list(PotsdamDataset.METAINFO['classes']) == get_classes('potsdam')
|
||||
assert list(ISPRSDataset.METAINFO['classes']) == get_classes('vaihingen')
|
||||
assert list(iSAIDDataset.METAINFO['classes']) == get_classes('isaid')
|
||||
assert list(
|
||||
MapillaryDataset_v1.METAINFO['classes']) == get_classes('mapillary_v1')
|
||||
assert list(
|
||||
MapillaryDataset_v2.METAINFO['classes']) == get_classes('mapillary_v2')
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
get_classes('unsupported')
|
||||
@ -80,6 +85,10 @@ def test_palette():
|
||||
assert PotsdamDataset.METAINFO['palette'] == get_palette('potsdam')
|
||||
assert COCOStuffDataset.METAINFO['palette'] == get_palette('cocostuff')
|
||||
assert iSAIDDataset.METAINFO['palette'] == get_palette('isaid')
|
||||
assert list(
|
||||
MapillaryDataset_v1.METAINFO['palette']) == get_palette('mapillary_v1')
|
||||
assert list(
|
||||
MapillaryDataset_v2.METAINFO['palette']) == get_palette('mapillary_v2')
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
get_palette('unsupported')
|
||||
@ -304,6 +313,19 @@ def test_lip():
|
||||
assert len(test_dataset) == 1
|
||||
|
||||
|
||||
def test_mapillary():
|
||||
test_dataset = MapillaryDataset_v1(
|
||||
pipeline=[],
|
||||
data_prefix=dict(
|
||||
img_path=osp.join(
|
||||
osp.dirname(__file__),
|
||||
'../data/pseudo_mapillary_dataset/images'),
|
||||
seg_map_path=osp.join(
|
||||
osp.dirname(__file__),
|
||||
'../data/pseudo_mapillary_dataset/v1.2')))
|
||||
assert len(test_dataset) == 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize('dataset, classes', [
|
||||
('ADE20KDataset', ('wall', 'building')),
|
||||
('CityscapesDataset', ('road', 'sidewalk')),
|
||||
|
Loading…
x
Reference in New Issue
Block a user