CodeCamp #1555[Feature] Support Mapillary Vistas Dataset (#2484)
## Support `Mapillary Vistas Dataset` ## Motivation Support **`Mapillary Vistas Dataset`** Dataset Paper link : https://ieeexplore.ieee.org/document/9878466/ Download and more information view https://www.mapillary.com/dataset/vistas ``` @InProceedings{Neuhold_2017_ICCV, author = {Neuhold, Gerhard and Ollmann, Tobias and Rota Bulo, Samuel and Kontschieder, Peter}, title = {The Mapillary Vistas Dataset for Semantic Understanding of Street Scenes}, booktitle = {Proceedings of the IEEE International Conference on Computer Vision (ICCV)}, month = {Oct}, year = {2017} } ``` ## Modification Add `Mapillary_dataset` in `mmsegmentation/projects` Add `configs/_base_/mapillary_v1_2.py` and `configs/_base_/mapillary_v2_0.py` Add `configs/deeplabv3plus_r18-d8_4xb2-80k_mapillay-512x1024.py` to test training and testing on Mapillary datasets Add `docs/en/user_guides/2_dataset_prepare.md` , add Mapillary Vistas Dataset Preparing and Structure. Add `tools/dataset_converters/mapillary.py` to convert RGB labels to Mask labels. Co-authored-by: 谢昕辰 <xiexinch@outlook.com>pull/2516/head
parent
f678a5c974
commit
e394e2aa28
|
@ -0,0 +1,85 @@
|
||||||
|
# Mapillary Vistas Dataset
|
||||||
|
|
||||||
|
Support **`Mapillary Vistas Dataset`**
|
||||||
|
|
||||||
|
## Description
|
||||||
|
|
||||||
|
Author: AI-Tianlong
|
||||||
|
|
||||||
|
This project implements **`Mapillary Vistas Dataset`**
|
||||||
|
|
||||||
|
### Dataset preparing
|
||||||
|
|
||||||
|
Preparing `Mapillary Vistas Dataset` dataset following [Mapillary Vistas Dataset Preparing Guide](https://github.com/open-mmlab/mmsegmentation/tree/dev-1.x/projects/mapillary_dataset/docs/en/user_guides/2_dataset_prepare.md)
|
||||||
|
|
||||||
|
```none
|
||||||
|
mmsegmentation
|
||||||
|
├── mmseg
|
||||||
|
├── tools
|
||||||
|
├── configs
|
||||||
|
├── data
|
||||||
|
│ ├── mapillary
|
||||||
|
│ │ ├── training
|
||||||
|
│ │ │ ├── images
|
||||||
|
│ │ │ ├── v1.2
|
||||||
|
| │ │ │ ├── instances
|
||||||
|
| │ │ │ ├── labels
|
||||||
|
| │ │ │ ├── labels_mask
|
||||||
|
| │ │ │ └── panoptic
|
||||||
|
│ │ │ ├── v2.0
|
||||||
|
| │ │ │ ├── instances
|
||||||
|
| │ │ │ ├── labels
|
||||||
|
| │ │ │ ├── labels_mask
|
||||||
|
| │ │ │ ├── panoptic
|
||||||
|
| │ │ │ └── polygons
|
||||||
|
│ │ ├── validation
|
||||||
|
│ │ │ ├── images
|
||||||
|
| │ │ │ ├── instances
|
||||||
|
| │ │ │ ├── labels
|
||||||
|
| │ │ │ ├── labels_mask
|
||||||
|
| │ │ │ └── panoptic
|
||||||
|
│ │ │ ├── v2.0
|
||||||
|
| │ │ │ ├── instances
|
||||||
|
| │ │ │ ├── labels
|
||||||
|
| │ │ │ ├── labels_mask
|
||||||
|
| │ │ │ ├── panoptic
|
||||||
|
| │ │ │ └── polygons
|
||||||
|
```
|
||||||
|
|
||||||
|
### Training commands with `deeplabv3plus_r101-d8_4xb2-240k_mapillay-512x1024.py`
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Dataset train commands
|
||||||
|
# at `mmsegmentation` folder
|
||||||
|
bash tools/dist_train.sh projects/mapillary_dataset/configs/deeplabv3plus_r101-d8_4xb2-240k_mapillay-512x1024.py 4
|
||||||
|
```
|
||||||
|
|
||||||
|
## Checklist
|
||||||
|
|
||||||
|
- [x] Milestone 1: PR-ready, and acceptable to be one of the `projects/`.
|
||||||
|
|
||||||
|
- [x] Finish the code
|
||||||
|
|
||||||
|
- [x] Basic docstrings & proper citation
|
||||||
|
|
||||||
|
- [ ] Test-time correctness
|
||||||
|
|
||||||
|
- [x] A full README
|
||||||
|
|
||||||
|
- [ ] Milestone 2: Indicates a successful model implementation.
|
||||||
|
|
||||||
|
- [ ] Training-time correctness
|
||||||
|
|
||||||
|
- [ ] Milestone 3: Good to be a part of our core package!
|
||||||
|
|
||||||
|
- [ ] Type hints and docstrings
|
||||||
|
|
||||||
|
- [ ] Unit tests
|
||||||
|
|
||||||
|
- [ ] Code polishing
|
||||||
|
|
||||||
|
- [ ] Metafile.yml
|
||||||
|
|
||||||
|
- [ ] Move your modules into the core package following the codebase's file hierarchy structure.
|
||||||
|
|
||||||
|
- [ ] Refactor your modules into the core package following the codebase's file hierarchy structure.
|
|
@ -0,0 +1,69 @@
|
||||||
|
# dataset settings
|
||||||
|
dataset_type = 'MapillaryDataset_v1_2'
|
||||||
|
data_root = 'data/mapillary/'
|
||||||
|
crop_size = (512, 1024)
|
||||||
|
train_pipeline = [
|
||||||
|
dict(type='LoadImageFromFile'),
|
||||||
|
dict(type='LoadAnnotations'),
|
||||||
|
dict(
|
||||||
|
type='RandomResize',
|
||||||
|
scale=(2048, 1024),
|
||||||
|
ratio_range=(0.5, 2.0),
|
||||||
|
keep_ratio=True),
|
||||||
|
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
|
||||||
|
dict(type='RandomFlip', prob=0.5),
|
||||||
|
dict(type='PhotoMetricDistortion'),
|
||||||
|
dict(type='PackSegInputs')
|
||||||
|
]
|
||||||
|
test_pipeline = [
|
||||||
|
dict(type='LoadImageFromFile'),
|
||||||
|
dict(type='Resize', scale=(2048, 1024), keep_ratio=True),
|
||||||
|
# add loading annotation after ``Resize`` because ground truth
|
||||||
|
# does not need to do resize data transform
|
||||||
|
dict(type='LoadAnnotations'),
|
||||||
|
dict(type='PackSegInputs')
|
||||||
|
]
|
||||||
|
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
|
||||||
|
tta_pipeline = [
|
||||||
|
dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
|
||||||
|
dict(
|
||||||
|
type='TestTimeAug',
|
||||||
|
transforms=[
|
||||||
|
[
|
||||||
|
dict(type='Resize', scale_factor=r, keep_ratio=True)
|
||||||
|
for r in img_ratios
|
||||||
|
],
|
||||||
|
[
|
||||||
|
dict(type='RandomFlip', prob=0., direction='horizontal'),
|
||||||
|
dict(type='RandomFlip', prob=1., direction='horizontal')
|
||||||
|
], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
|
||||||
|
])
|
||||||
|
]
|
||||||
|
train_dataloader = dict(
|
||||||
|
batch_size=2,
|
||||||
|
num_workers=4,
|
||||||
|
persistent_workers=True,
|
||||||
|
sampler=dict(type='InfiniteSampler', shuffle=True),
|
||||||
|
dataset=dict(
|
||||||
|
type=dataset_type,
|
||||||
|
data_root=data_root,
|
||||||
|
data_prefix=dict(
|
||||||
|
img_path='training/images',
|
||||||
|
seg_map_path='training/v1.2/labels_mask'),
|
||||||
|
pipeline=train_pipeline))
|
||||||
|
val_dataloader = dict(
|
||||||
|
batch_size=1,
|
||||||
|
num_workers=4,
|
||||||
|
persistent_workers=True,
|
||||||
|
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||||
|
dataset=dict(
|
||||||
|
type=dataset_type,
|
||||||
|
data_root=data_root,
|
||||||
|
data_prefix=dict(
|
||||||
|
img_path='validation/images',
|
||||||
|
seg_map_path='validation/v1.2/labels_mask'),
|
||||||
|
pipeline=test_pipeline))
|
||||||
|
test_dataloader = val_dataloader
|
||||||
|
|
||||||
|
val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
|
||||||
|
test_evaluator = val_evaluator
|
|
@ -0,0 +1,69 @@
|
||||||
|
# dataset settings
|
||||||
|
dataset_type = 'MapillaryDataset_v2_0'
|
||||||
|
data_root = 'data/mapillary/'
|
||||||
|
crop_size = (512, 1024)
|
||||||
|
train_pipeline = [
|
||||||
|
dict(type='LoadImageFromFile'),
|
||||||
|
dict(type='LoadAnnotations'),
|
||||||
|
dict(
|
||||||
|
type='RandomResize',
|
||||||
|
scale=(2048, 1024),
|
||||||
|
ratio_range=(0.5, 2.0),
|
||||||
|
keep_ratio=True),
|
||||||
|
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
|
||||||
|
dict(type='RandomFlip', prob=0.5),
|
||||||
|
dict(type='PhotoMetricDistortion'),
|
||||||
|
dict(type='PackSegInputs')
|
||||||
|
]
|
||||||
|
test_pipeline = [
|
||||||
|
dict(type='LoadImageFromFile'),
|
||||||
|
dict(type='Resize', scale=(2048, 1024), keep_ratio=True),
|
||||||
|
# add loading annotation after ``Resize`` because ground truth
|
||||||
|
# does not need to do resize data transform
|
||||||
|
dict(type='LoadAnnotations'),
|
||||||
|
dict(type='PackSegInputs')
|
||||||
|
]
|
||||||
|
img_ratios = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75]
|
||||||
|
tta_pipeline = [
|
||||||
|
dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
|
||||||
|
dict(
|
||||||
|
type='TestTimeAug',
|
||||||
|
transforms=[
|
||||||
|
[
|
||||||
|
dict(type='Resize', scale_factor=r, keep_ratio=True)
|
||||||
|
for r in img_ratios
|
||||||
|
],
|
||||||
|
[
|
||||||
|
dict(type='RandomFlip', prob=0., direction='horizontal'),
|
||||||
|
dict(type='RandomFlip', prob=1., direction='horizontal')
|
||||||
|
], [dict(type='LoadAnnotations')], [dict(type='PackSegInputs')]
|
||||||
|
])
|
||||||
|
]
|
||||||
|
train_dataloader = dict(
|
||||||
|
batch_size=2,
|
||||||
|
num_workers=4,
|
||||||
|
persistent_workers=True,
|
||||||
|
sampler=dict(type='InfiniteSampler', shuffle=True),
|
||||||
|
dataset=dict(
|
||||||
|
type=dataset_type,
|
||||||
|
data_root=data_root,
|
||||||
|
data_prefix=dict(
|
||||||
|
img_path='training/images',
|
||||||
|
seg_map_path='training/v2.0/labels_mask'),
|
||||||
|
pipeline=train_pipeline))
|
||||||
|
val_dataloader = dict(
|
||||||
|
batch_size=1,
|
||||||
|
num_workers=4,
|
||||||
|
persistent_workers=True,
|
||||||
|
sampler=dict(type='DefaultSampler', shuffle=False),
|
||||||
|
dataset=dict(
|
||||||
|
type=dataset_type,
|
||||||
|
data_root=data_root,
|
||||||
|
data_prefix=dict(
|
||||||
|
img_path='validation/images',
|
||||||
|
seg_map_path='validation/v2.0/labels_mask'),
|
||||||
|
pipeline=test_pipeline))
|
||||||
|
test_dataloader = val_dataloader
|
||||||
|
|
||||||
|
val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
|
||||||
|
test_evaluator = val_evaluator
|
|
@ -0,0 +1,103 @@
|
||||||
|
_base_ = ['./_base_/datasets/mapillary_v1_2.py'] # v 1.2 labels
|
||||||
|
# _base_ = ['./_base_/datasets/mapillary_v2_0.py'] # v2.0 labels
|
||||||
|
custom_imports = dict(imports=[
|
||||||
|
'projects.mapillary_dataset.mmseg.datasets.mapillary_v1_2',
|
||||||
|
'projects.mapillary_dataset.mmseg.datasets.mapillary_v2_0',
|
||||||
|
])
|
||||||
|
|
||||||
|
norm_cfg = dict(type='SyncBN', requires_grad=True)
|
||||||
|
data_preprocessor = dict(
|
||||||
|
type='SegDataPreProcessor',
|
||||||
|
mean=[123.675, 116.28, 103.53],
|
||||||
|
std=[58.395, 57.12, 57.375],
|
||||||
|
bgr_to_rgb=True,
|
||||||
|
pad_val=0,
|
||||||
|
seg_pad_val=255,
|
||||||
|
size=(512, 1024))
|
||||||
|
|
||||||
|
model = dict(
|
||||||
|
type='EncoderDecoder',
|
||||||
|
data_preprocessor=data_preprocessor,
|
||||||
|
pretrained=None,
|
||||||
|
backbone=dict(
|
||||||
|
type='ResNet',
|
||||||
|
depth=101,
|
||||||
|
num_stages=4,
|
||||||
|
out_indices=(0, 1, 2, 3),
|
||||||
|
dilations=(1, 1, 2, 4),
|
||||||
|
strides=(1, 2, 1, 1),
|
||||||
|
norm_cfg=norm_cfg,
|
||||||
|
norm_eval=False,
|
||||||
|
style='pytorch',
|
||||||
|
contract_dilation=True),
|
||||||
|
decode_head=dict(
|
||||||
|
type='DepthwiseSeparableASPPHead',
|
||||||
|
in_channels=2048,
|
||||||
|
in_index=3,
|
||||||
|
channels=512,
|
||||||
|
dilations=(1, 12, 24, 36),
|
||||||
|
c1_in_channels=256,
|
||||||
|
c1_channels=48,
|
||||||
|
dropout_ratio=0.1,
|
||||||
|
num_classes=66, # v1.2
|
||||||
|
# num_classes=124, # v2.0
|
||||||
|
norm_cfg=norm_cfg,
|
||||||
|
align_corners=False,
|
||||||
|
loss_decode=dict(
|
||||||
|
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
|
||||||
|
auxiliary_head=dict(
|
||||||
|
type='FCNHead',
|
||||||
|
in_channels=1024,
|
||||||
|
in_index=2,
|
||||||
|
channels=256,
|
||||||
|
num_convs=1,
|
||||||
|
concat_input=False,
|
||||||
|
dropout_ratio=0.1,
|
||||||
|
num_classes=66, # v1.2
|
||||||
|
# num_classes=124, # v2.0
|
||||||
|
norm_cfg=norm_cfg,
|
||||||
|
align_corners=False,
|
||||||
|
loss_decode=dict(
|
||||||
|
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
|
||||||
|
train_cfg=dict(),
|
||||||
|
test_cfg=dict(mode='whole'))
|
||||||
|
default_scope = 'mmseg'
|
||||||
|
env_cfg = dict(
|
||||||
|
cudnn_benchmark=True,
|
||||||
|
mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
|
||||||
|
dist_cfg=dict(backend='nccl'))
|
||||||
|
vis_backends = [dict(type='LocalVisBackend')]
|
||||||
|
visualizer = dict(
|
||||||
|
type='SegLocalVisualizer',
|
||||||
|
vis_backends=[dict(type='LocalVisBackend')],
|
||||||
|
name='visualizer')
|
||||||
|
log_processor = dict(by_epoch=False)
|
||||||
|
log_level = 'INFO'
|
||||||
|
load_from = None
|
||||||
|
resume = False
|
||||||
|
tta_model = dict(type='SegTTAModel')
|
||||||
|
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
|
||||||
|
optim_wrapper = dict(
|
||||||
|
type='OptimWrapper',
|
||||||
|
optimizer=dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001),
|
||||||
|
clip_grad=None)
|
||||||
|
param_scheduler = [
|
||||||
|
dict(
|
||||||
|
type='PolyLR',
|
||||||
|
eta_min=0.0001,
|
||||||
|
power=0.9,
|
||||||
|
begin=0,
|
||||||
|
end=240000,
|
||||||
|
by_epoch=False)
|
||||||
|
]
|
||||||
|
train_cfg = dict(
|
||||||
|
type='IterBasedTrainLoop', max_iters=240000, val_interval=24000)
|
||||||
|
val_cfg = dict(type='ValLoop')
|
||||||
|
test_cfg = dict(type='TestLoop')
|
||||||
|
default_hooks = dict(
|
||||||
|
timer=dict(type='IterTimerHook'),
|
||||||
|
logger=dict(type='LoggerHook', interval=50, log_metric_by_epoch=False),
|
||||||
|
param_scheduler=dict(type='ParamSchedulerHook'),
|
||||||
|
checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=24000),
|
||||||
|
sampler_seed=dict(type='DistSamplerSeedHook'),
|
||||||
|
visualization=dict(type='SegVisualizationHook'))
|
|
@ -0,0 +1,117 @@
|
||||||
|
## Prepare datasets
|
||||||
|
|
||||||
|
It is recommended to symlink the dataset root to `$MMSEGMENTATION/data`.
|
||||||
|
If your folder structure is different, you may need to change the corresponding paths in config files.
|
||||||
|
|
||||||
|
```none
|
||||||
|
mmsegmentation
|
||||||
|
├── mmseg
|
||||||
|
├── tools
|
||||||
|
├── configs
|
||||||
|
├── data
|
||||||
|
│ ├── mapillary
|
||||||
|
│ │ ├── training
|
||||||
|
│ │ │ ├── images
|
||||||
|
│ │ │ ├── v1.2
|
||||||
|
| │ │ │ ├── instances
|
||||||
|
| │ │ │ ├── labels
|
||||||
|
| │ │ │ ├── labels_mask
|
||||||
|
| │ │ │ └── panoptic
|
||||||
|
│ │ │ ├── v2.0
|
||||||
|
| │ │ │ ├── instances
|
||||||
|
| │ │ │ ├── labels
|
||||||
|
| │ │ │ ├── labels_mask
|
||||||
|
| │ │ │ ├── panoptic
|
||||||
|
| │ │ │ └── polygons
|
||||||
|
│ │ ├── validation
|
||||||
|
│ │ │ ├── images
|
||||||
|
| │ │ │ ├── instances
|
||||||
|
| │ │ │ ├── labels
|
||||||
|
| │ │ │ ├── labels_mask
|
||||||
|
| │ │ │ └── panoptic
|
||||||
|
│ │ │ ├── v2.0
|
||||||
|
| │ │ │ ├── instances
|
||||||
|
| │ │ │ ├── labels
|
||||||
|
| │ │ │ ├── labels_mask
|
||||||
|
| │ │ │ ├── panoptic
|
||||||
|
| │ │ │ └── polygons
|
||||||
|
```
|
||||||
|
|
||||||
|
## Mapillary Vistas Datasets
|
||||||
|
|
||||||
|
- The dataset could be download [here](https://www.mapillary.com/dataset/vistas) after registration.
|
||||||
|
- Assumption you have put the dataset zip file in `mmsegmentation/data`
|
||||||
|
- Please run the following commands to unzip dataset.
|
||||||
|
```bash
|
||||||
|
cd data
|
||||||
|
mkdir mapillary
|
||||||
|
unzip -d mapillary An-ZjB1Zm61yAZG0ozTymz8I8NqI4x0MrYrh26dq7kPgfu8vf9ImrdaOAVOFYbJ2pNAgUnVGBmbue9lTgdBOb5BbKXIpFs0fpYWqACbrQDChAA2fdX0zS9PcHu7fY8c-FOvyBVxPNYNFQuM.zip
|
||||||
|
```
|
||||||
|
- After unzip, you will get Mapillary Vistas Dataset like this structure.
|
||||||
|
```none
|
||||||
|
├── data
|
||||||
|
│ ├── mapillary
|
||||||
|
│ │ ├── training
|
||||||
|
│ │ │ ├── images
|
||||||
|
│ │ │ ├── v1.2
|
||||||
|
| │ │ │ ├── instances
|
||||||
|
| │ │ │ ├── labels
|
||||||
|
| │ │ │ └── panoptic
|
||||||
|
│ │ │ ├── v2.0
|
||||||
|
| │ │ │ ├── instances
|
||||||
|
| │ │ │ ├── labels
|
||||||
|
| │ │ │ ├── panoptic
|
||||||
|
| │ │ │ └── polygons
|
||||||
|
│ │ ├── validation
|
||||||
|
│ │ │ ├── images
|
||||||
|
| │ │ │ ├── instances
|
||||||
|
| │ │ │ ├── labels
|
||||||
|
| │ │ │ └── panoptic
|
||||||
|
│ │ │ ├── v2.0
|
||||||
|
| │ │ │ ├── instances
|
||||||
|
| │ │ │ ├── labels
|
||||||
|
| │ │ │ ├── panoptic
|
||||||
|
| │ │ │ └── polygons
|
||||||
|
```
|
||||||
|
- run following commands to convert RGB labels to mask labels
|
||||||
|
```bash
|
||||||
|
# --nproc optional, default 1, whether use multi-progress
|
||||||
|
# --version optional, 'v1.2', 'v2.0','all', default 'all', choose convert which version labels
|
||||||
|
# run this command at 'mmsegmentation/projects/Mapillary_dataset' folder
|
||||||
|
cd mmsegmentation/projects/mapillary_dataset
|
||||||
|
python tools/dataset_converters/mapillary.py ../../data/mapillary --nproc 8 --version all
|
||||||
|
```
|
||||||
|
After then, you will get this structure
|
||||||
|
```none
|
||||||
|
mmsegmentation
|
||||||
|
├── mmseg
|
||||||
|
├── tools
|
||||||
|
├── configs
|
||||||
|
├── data
|
||||||
|
│ ├── mapillary
|
||||||
|
│ │ ├── training
|
||||||
|
│ │ │ ├── images
|
||||||
|
│ │ │ ├── v1.2
|
||||||
|
| │ │ │ ├── instances
|
||||||
|
| │ │ │ ├── labels
|
||||||
|
| │ │ │ ├── labels_mask
|
||||||
|
| │ │ │ └── panoptic
|
||||||
|
│ │ │ ├── v2.0
|
||||||
|
| │ │ │ ├── instances
|
||||||
|
| │ │ │ ├── labels
|
||||||
|
| │ │ │ ├── labels_mask
|
||||||
|
| │ │ │ ├── panoptic
|
||||||
|
| │ │ │ └── polygons
|
||||||
|
│ │ ├── validation
|
||||||
|
│ │ │ ├── images
|
||||||
|
| │ │ │ ├── instances
|
||||||
|
| │ │ │ ├── labels
|
||||||
|
| │ │ │ ├── labels_mask
|
||||||
|
| │ │ │ └── panoptic
|
||||||
|
│ │ │ ├── v2.0
|
||||||
|
| │ │ │ ├── instances
|
||||||
|
| │ │ │ ├── labels
|
||||||
|
| │ │ │ ├── labels_mask
|
||||||
|
| │ │ │ ├── panoptic
|
||||||
|
| │ │ │ └── polygons
|
||||||
|
```
|
|
@ -0,0 +1,65 @@
|
||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
from mmseg.datasets.basesegdataset import BaseSegDataset
|
||||||
|
from mmseg.registry import DATASETS
|
||||||
|
|
||||||
|
|
||||||
|
@DATASETS.register_module()
|
||||||
|
class MapillaryDataset_v1_2(BaseSegDataset):
|
||||||
|
"""Mapillary Vistas Dataset.
|
||||||
|
|
||||||
|
Dataset paper link:
|
||||||
|
http://ieeexplore.ieee.org/document/8237796/
|
||||||
|
|
||||||
|
v1.2 contain 66 object classes.
|
||||||
|
(37 instance-specific)
|
||||||
|
|
||||||
|
v2.0 contain 124 object classes.
|
||||||
|
(70 instance-specific, 46 stuff, 8 void or crowd).
|
||||||
|
|
||||||
|
The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is
|
||||||
|
fixed to '.png' for Mapillary Vistas Dataset.
|
||||||
|
"""
|
||||||
|
METAINFO = dict(
|
||||||
|
classes=('Bird', 'Ground Animal', 'Curb', 'Fence', 'Guard Rail',
|
||||||
|
'Barrier', 'Wall', 'Bike Lane', 'Crosswalk - Plain',
|
||||||
|
'Curb Cut', 'Parking', 'Pedestrian Area', 'Rail Track',
|
||||||
|
'Road', 'Service Lane', 'Sidewalk', 'Bridge', 'Building',
|
||||||
|
'Tunnel', 'Person', 'Bicyclist', 'Motorcyclist',
|
||||||
|
'Other Rider', 'Lane Marking - Crosswalk',
|
||||||
|
'Lane Marking - General', 'Mountain', 'Sand', 'Sky', 'Snow',
|
||||||
|
'Terrain', 'Vegetation', 'Water', 'Banner', 'Bench',
|
||||||
|
'Bike Rack', 'Billboard', 'Catch Basin', 'CCTV Camera',
|
||||||
|
'Fire Hydrant', 'Junction Box', 'Mailbox', 'Manhole',
|
||||||
|
'Phone Booth', 'Pothole', 'Street Light', 'Pole',
|
||||||
|
'Traffic Sign Frame', 'Utility Pole', 'Traffic Light',
|
||||||
|
'Traffic Sign (Back)', 'Traffic Sign (Front)', 'Trash Can',
|
||||||
|
'Bicycle', 'Boat', 'Bus', 'Car', 'Caravan', 'Motorcycle',
|
||||||
|
'On Rails', 'Other Vehicle', 'Trailer', 'Truck',
|
||||||
|
'Wheeled Slow', 'Car Mount', 'Ego Vehicle', 'Unlabeled'),
|
||||||
|
palette=[[165, 42, 42], [0, 192, 0], [196, 196, 196], [190, 153, 153],
|
||||||
|
[180, 165, 180], [90, 120, 150], [102, 102, 156],
|
||||||
|
[128, 64, 255], [140, 140, 200], [170, 170, 170],
|
||||||
|
[250, 170, 160], [96, 96, 96],
|
||||||
|
[230, 150, 140], [128, 64, 128], [110, 110, 110],
|
||||||
|
[244, 35, 232], [150, 100, 100], [70, 70, 70], [150, 120, 90],
|
||||||
|
[220, 20, 60], [255, 0, 0], [255, 0, 100], [255, 0, 200],
|
||||||
|
[200, 128, 128], [255, 255, 255], [64, 170,
|
||||||
|
64], [230, 160, 50],
|
||||||
|
[70, 130, 180], [190, 255, 255], [152, 251, 152],
|
||||||
|
[107, 142, 35], [0, 170, 30], [255, 255, 128], [250, 0, 30],
|
||||||
|
[100, 140, 180], [220, 220, 220], [220, 128, 128],
|
||||||
|
[222, 40, 40], [100, 170, 30], [40, 40, 40], [33, 33, 33],
|
||||||
|
[100, 128, 160], [142, 0, 0], [70, 100, 150], [210, 170, 100],
|
||||||
|
[153, 153, 153], [128, 128, 128], [0, 0, 80], [250, 170, 30],
|
||||||
|
[192, 192, 192], [220, 220, 0], [140, 140, 20], [119, 11, 32],
|
||||||
|
[150, 0, 255], [0, 60, 100], [0, 0, 142], [0, 0, 90],
|
||||||
|
[0, 0, 230], [0, 80, 100], [128, 64, 64], [0, 0, 110],
|
||||||
|
[0, 0, 70], [0, 0, 192], [32, 32, 32], [120, 10,
|
||||||
|
10], [0, 0, 0]])
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
img_suffix='.jpg',
|
||||||
|
seg_map_suffix='.png',
|
||||||
|
**kwargs) -> None:
|
||||||
|
super().__init__(
|
||||||
|
img_suffix=img_suffix, seg_map_suffix=seg_map_suffix, **kwargs)
|
|
@ -0,0 +1,114 @@
|
||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
from mmseg.datasets.basesegdataset import BaseSegDataset
|
||||||
|
from mmseg.registry import DATASETS
|
||||||
|
|
||||||
|
|
||||||
|
@DATASETS.register_module()
|
||||||
|
class MapillaryDataset_v2_0(BaseSegDataset):
|
||||||
|
"""Mapillary Vistas Dataset.
|
||||||
|
|
||||||
|
Dataset paper link:
|
||||||
|
http://ieeexplore.ieee.org/document/8237796/
|
||||||
|
|
||||||
|
v1.2 contain 66 object classes.
|
||||||
|
(37 instance-specific)
|
||||||
|
|
||||||
|
v2.0 contain 124 object classes.
|
||||||
|
(70 instance-specific, 46 stuff, 8 void or crowd).
|
||||||
|
|
||||||
|
The ``img_suffix`` is fixed to '.jpg' and ``seg_map_suffix`` is
|
||||||
|
fixed to '.png' for Mapillary Vistas Dataset.
|
||||||
|
"""
|
||||||
|
METAINFO = dict(
|
||||||
|
classes=(
|
||||||
|
'Bird', 'Ground Animal', 'Ambiguous Barrier', 'Concrete Block',
|
||||||
|
'Curb', 'Fence', 'Guard Rail', 'Barrier', 'Road Median',
|
||||||
|
'Road Side', 'Lane Separator', 'Temporary Barrier', 'Wall',
|
||||||
|
'Bike Lane', 'Crosswalk - Plain', 'Curb Cut', 'Driveway',
|
||||||
|
'Parking', 'Parking Aisle', 'Pedestrian Area', 'Rail Track',
|
||||||
|
'Road', 'Road Shoulder', 'Service Lane', 'Sidewalk',
|
||||||
|
'Traffic Island', 'Bridge', 'Building', 'Garage', 'Tunnel',
|
||||||
|
'Person', 'Person Group', 'Bicyclist', 'Motorcyclist',
|
||||||
|
'Other Rider', 'Lane Marking - Dashed Line',
|
||||||
|
'Lane Marking - Straight Line', 'Lane Marking - Zigzag Line',
|
||||||
|
'Lane Marking - Ambiguous', 'Lane Marking - Arrow (Left)',
|
||||||
|
'Lane Marking - Arrow (Other)', 'Lane Marking - Arrow (Right)',
|
||||||
|
'Lane Marking - Arrow (Split Left or Straight)',
|
||||||
|
'Lane Marking - Arrow (Split Right or Straight)',
|
||||||
|
'Lane Marking - Arrow (Straight)', 'Lane Marking - Crosswalk',
|
||||||
|
'Lane Marking - Give Way (Row)',
|
||||||
|
'Lane Marking - Give Way (Single)',
|
||||||
|
'Lane Marking - Hatched (Chevron)',
|
||||||
|
'Lane Marking - Hatched (Diagonal)', 'Lane Marking - Other',
|
||||||
|
'Lane Marking - Stop Line', 'Lane Marking - Symbol (Bicycle)',
|
||||||
|
'Lane Marking - Symbol (Other)', 'Lane Marking - Text',
|
||||||
|
'Lane Marking (only) - Dashed Line',
|
||||||
|
'Lane Marking (only) - Crosswalk', 'Lane Marking (only) - Other',
|
||||||
|
'Lane Marking (only) - Test', 'Mountain', 'Sand', 'Sky', 'Snow',
|
||||||
|
'Terrain', 'Vegetation', 'Water', 'Banner', 'Bench', 'Bike Rack',
|
||||||
|
'Catch Basin', 'CCTV Camera', 'Fire Hydrant', 'Junction Box',
|
||||||
|
'Mailbox', 'Manhole', 'Parking Meter', 'Phone Booth', 'Pothole',
|
||||||
|
'Signage - Advertisement', 'Signage - Ambiguous', 'Signage - Back',
|
||||||
|
'Signage - Information', 'Signage - Other', 'Signage - Store',
|
||||||
|
'Street Light', 'Pole', 'Pole Group', 'Traffic Sign Frame',
|
||||||
|
'Utility Pole', 'Traffic Cone', 'Traffic Light - General (Single)',
|
||||||
|
'Traffic Light - Pedestrians', 'Traffic Light - General (Upright)',
|
||||||
|
'Traffic Light - General (Horizontal)', 'Traffic Light - Cyclists',
|
||||||
|
'Traffic Light - Other', 'Traffic Sign - Ambiguous',
|
||||||
|
'Traffic Sign (Back)', 'Traffic Sign - Direction (Back)',
|
||||||
|
'Traffic Sign - Direction (Front)', 'Traffic Sign (Front)',
|
||||||
|
'Traffic Sign - Parking', 'Traffic Sign - Temporary (Back)',
|
||||||
|
'Traffic Sign - Temporary (Front)', 'Trash Can', 'Bicycle', 'Boat',
|
||||||
|
'Bus', 'Car', 'Caravan', 'Motorcycle', 'On Rails', 'Other Vehicle',
|
||||||
|
'Trailer', 'Truck', 'Vehicle Group', 'Wheeled Slow', 'Water Valve',
|
||||||
|
'Car Mount', 'Dynamic', 'Ego Vehicle', 'Ground', 'Static',
|
||||||
|
'Unlabeled'),
|
||||||
|
palette=[[165, 42, 42], [0, 192, 0], [250, 170, 31], [250, 170, 32],
|
||||||
|
[196, 196, 196], [190, 153, 153], [180, 165, 180],
|
||||||
|
[90, 120, 150], [250, 170, 33], [250, 170, 34],
|
||||||
|
[128, 128, 128], [250, 170, 35], [102, 102, 156],
|
||||||
|
[128, 64, 255], [140, 140, 200], [170, 170, 170],
|
||||||
|
[250, 170, 36], [250, 170, 160], [250, 170, 37], [96, 96, 96],
|
||||||
|
[230, 150, 140], [128, 64, 128], [110, 110, 110],
|
||||||
|
[110, 110, 110], [244, 35, 232], [128, 196,
|
||||||
|
128], [150, 100, 100],
|
||||||
|
[70, 70, 70], [150, 150, 150], [150, 120, 90], [220, 20, 60],
|
||||||
|
[220, 20, 60], [255, 0, 0], [255, 0, 100], [255, 0, 200],
|
||||||
|
[255, 255, 255], [255, 255, 255], [250, 170, 29],
|
||||||
|
[250, 170, 28], [250, 170, 26], [250, 170,
|
||||||
|
25], [250, 170, 24],
|
||||||
|
[250, 170, 22], [250, 170, 21], [250, 170,
|
||||||
|
20], [255, 255, 255],
|
||||||
|
[250, 170, 19], [250, 170, 18], [250, 170,
|
||||||
|
12], [250, 170, 11],
|
||||||
|
[255, 255, 255], [255, 255, 255], [250, 170, 16],
|
||||||
|
[250, 170, 15], [250, 170, 15], [255, 255, 255],
|
||||||
|
[255, 255, 255], [255, 255, 255], [255, 255, 255],
|
||||||
|
[64, 170, 64], [230, 160, 50],
|
||||||
|
[70, 130, 180], [190, 255, 255], [152, 251, 152],
|
||||||
|
[107, 142, 35], [0, 170, 30], [255, 255, 128], [250, 0, 30],
|
||||||
|
[100, 140, 180], [220, 128, 128], [222, 40,
|
||||||
|
40], [100, 170, 30],
|
||||||
|
[40, 40, 40], [33, 33, 33], [100, 128, 160], [20, 20, 255],
|
||||||
|
[142, 0, 0], [70, 100, 150], [250, 171, 30], [250, 172, 30],
|
||||||
|
[250, 173, 30], [250, 174, 30], [250, 175,
|
||||||
|
30], [250, 176, 30],
|
||||||
|
[210, 170, 100], [153, 153, 153], [153, 153, 153],
|
||||||
|
[128, 128, 128], [0, 0, 80], [210, 60, 60], [250, 170, 30],
|
||||||
|
[250, 170, 30], [250, 170, 30], [250, 170,
|
||||||
|
30], [250, 170, 30],
|
||||||
|
[250, 170, 30], [192, 192, 192], [192, 192, 192],
|
||||||
|
[192, 192, 192], [220, 220, 0], [220, 220, 0], [0, 0, 196],
|
||||||
|
[192, 192, 192], [220, 220, 0], [140, 140, 20], [119, 11, 32],
|
||||||
|
[150, 0, 255], [0, 60, 100], [0, 0, 142], [0, 0, 90],
|
||||||
|
[0, 0, 230], [0, 80, 100], [128, 64, 64], [0, 0, 110],
|
||||||
|
[0, 0, 70], [0, 0, 142], [0, 0, 192], [170, 170, 170],
|
||||||
|
[32, 32, 32], [111, 74, 0], [120, 10, 10], [81, 0, 81],
|
||||||
|
[111, 111, 0], [0, 0, 0]])
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
img_suffix='.jpg',
|
||||||
|
seg_map_suffix='.png',
|
||||||
|
**kwargs) -> None:
|
||||||
|
super().__init__(
|
||||||
|
img_suffix=img_suffix, seg_map_suffix=seg_map_suffix, **kwargs)
|
|
@ -0,0 +1,245 @@
|
||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import argparse
|
||||||
|
import os.path as osp
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
import mmcv
|
||||||
|
import numpy as np
|
||||||
|
from mmengine.utils import (mkdir_or_exist, scandir, track_parallel_progress,
|
||||||
|
track_progress)
|
||||||
|
|
||||||
|
colormap_v1_2 = np.array([[165, 42, 42], [0, 192, 0], [196, 196, 196],
|
||||||
|
[190, 153, 153], [180, 165, 180], [90, 120, 150],
|
||||||
|
[102, 102, 156], [128, 64, 255], [140, 140, 200],
|
||||||
|
[170, 170, 170], [250, 170, 160], [96, 96, 96],
|
||||||
|
[230, 150, 140], [128, 64, 128], [110, 110, 110],
|
||||||
|
[244, 35, 232], [150, 100, 100], [70, 70, 70],
|
||||||
|
[150, 120, 90], [220, 20, 60], [255, 0, 0],
|
||||||
|
[255, 0, 100], [255, 0, 200], [200, 128, 128],
|
||||||
|
[255, 255, 255], [64, 170, 64], [230, 160, 50],
|
||||||
|
[70, 130, 180], [190, 255, 255], [152, 251, 152],
|
||||||
|
[107, 142, 35], [0, 170, 30], [255, 255, 128],
|
||||||
|
[250, 0, 30], [100, 140, 180], [220, 220, 220],
|
||||||
|
[220, 128, 128], [222, 40, 40], [100, 170, 30],
|
||||||
|
[40, 40, 40], [33, 33, 33], [100, 128, 160],
|
||||||
|
[142, 0, 0], [70, 100, 150], [210, 170, 100],
|
||||||
|
[153, 153, 153], [128, 128, 128], [0, 0, 80],
|
||||||
|
[250, 170, 30], [192, 192, 192], [220, 220, 0],
|
||||||
|
[140, 140, 20], [119, 11, 32], [150, 0, 255],
|
||||||
|
[0, 60, 100], [0, 0, 142], [0, 0, 90], [0, 0, 230],
|
||||||
|
[0, 80, 100], [128, 64, 64], [0, 0, 110], [0, 0, 70],
|
||||||
|
[0, 0, 192], [32, 32, 32], [120, 10, 10], [0, 0, 0]])
|
||||||
|
|
||||||
|
colormap_v2_0 = np.array([[165, 42, 42], [0, 192, 0], [250, 170, 31],
|
||||||
|
[250, 170, 32], [196, 196, 196], [190, 153, 153],
|
||||||
|
[180, 165, 180], [90, 120, 150], [250, 170, 33],
|
||||||
|
[250, 170, 34], [128, 128, 128], [250, 170, 35],
|
||||||
|
[102, 102, 156], [128, 64, 255], [140, 140, 200],
|
||||||
|
[170, 170, 170], [250, 170, 36], [250, 170, 160],
|
||||||
|
[250, 170, 37], [96, 96, 96], [230, 150, 140],
|
||||||
|
[128, 64, 128], [110, 110, 110], [110, 110, 110],
|
||||||
|
[244, 35, 232], [128, 196, 128], [150, 100, 100],
|
||||||
|
[70, 70, 70], [150, 150, 150], [150, 120, 90],
|
||||||
|
[220, 20, 60], [220, 20, 60], [255, 0, 0],
|
||||||
|
[255, 0, 100], [255, 0, 200], [255, 255, 255],
|
||||||
|
[255, 255, 255], [250, 170, 29], [250, 170, 28],
|
||||||
|
[250, 170, 26], [250, 170, 25], [250, 170, 24],
|
||||||
|
[250, 170, 22], [250, 170, 21], [250, 170, 20],
|
||||||
|
[255, 255, 255], [250, 170, 19], [250, 170, 18],
|
||||||
|
[250, 170, 12], [250, 170, 11], [255, 255, 255],
|
||||||
|
[255, 255, 255], [250, 170, 16], [250, 170, 15],
|
||||||
|
[250, 170, 15], [255, 255, 255], [255, 255, 255],
|
||||||
|
[255, 255, 255], [255, 255, 255], [64, 170, 64],
|
||||||
|
[230, 160, 50], [70, 130, 180], [190, 255, 255],
|
||||||
|
[152, 251, 152], [107, 142, 35], [0, 170, 30],
|
||||||
|
[255, 255, 128], [250, 0, 30], [100, 140, 180],
|
||||||
|
[220, 128, 128], [222, 40, 40], [100, 170, 30],
|
||||||
|
[40, 40, 40], [33, 33, 33], [100, 128, 160],
|
||||||
|
[20, 20, 255], [142, 0, 0], [70, 100, 150],
|
||||||
|
[250, 171, 30], [250, 172, 30], [250, 173, 30],
|
||||||
|
[250, 174, 30], [250, 175, 30], [250, 176, 30],
|
||||||
|
[210, 170, 100], [153, 153, 153], [153, 153, 153],
|
||||||
|
[128, 128, 128], [0, 0, 80], [210, 60, 60],
|
||||||
|
[250, 170, 30], [250, 170, 30], [250, 170, 30],
|
||||||
|
[250, 170, 30], [250, 170, 30], [250, 170, 30],
|
||||||
|
[192, 192, 192], [192, 192, 192], [192, 192, 192],
|
||||||
|
[220, 220, 0], [220, 220, 0], [0, 0, 196],
|
||||||
|
[192, 192, 192], [220, 220, 0], [140, 140, 20],
|
||||||
|
[119, 11, 32], [150, 0, 255], [0, 60, 100],
|
||||||
|
[0, 0, 142], [0, 0, 90], [0, 0, 230], [0, 80, 100],
|
||||||
|
[128, 64, 64], [0, 0, 110], [0, 0, 70], [0, 0, 142],
|
||||||
|
[0, 0, 192], [170, 170, 170], [32, 32, 32],
|
||||||
|
[111, 74, 0], [120, 10, 10], [81, 0, 81],
|
||||||
|
[111, 111, 0], [0, 0, 0]])
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description='Convert Mapillary dataset to mmsegmentation format')
|
||||||
|
parser.add_argument('dataset_path', help='Mapillary folder path')
|
||||||
|
parser.add_argument(
|
||||||
|
'--version',
|
||||||
|
default='all',
|
||||||
|
help="Mapillary labels version, 'v1.2','v2.0','all'")
|
||||||
|
parser.add_argument('-o', '--out_dir', help='output path')
|
||||||
|
parser.add_argument(
|
||||||
|
'--nproc', default=1, type=int, help='number of process')
|
||||||
|
args = parser.parse_args()
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
def mapillary_colormap2label(colormap: np.ndarray) -> list:
|
||||||
|
"""Create a `list` shaped (256^3, 1), convert each color palette to a
|
||||||
|
number, which can use to find the correct label value.
|
||||||
|
|
||||||
|
For example labels 0--Bird--[165, 42, 42]
|
||||||
|
(165*256 + 42) * 256 + 42 = 10824234 (This is list's index])
|
||||||
|
`colormap2label[10824234] = 0`
|
||||||
|
|
||||||
|
In converting, if a RGB pixel value is [165, 42, 42],
|
||||||
|
through colormap2label[10824234]-->can quickly find
|
||||||
|
this labels value is 0.
|
||||||
|
Through matrix multiply to compute a img is very fast.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
colormap (np.ndarray): Mapillary Vistas Dataset palette
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: values are mask labels,
|
||||||
|
indexes are palette's convert results.、
|
||||||
|
"""
|
||||||
|
colormap2label = np.zeros(256**3, dtype=np.longlong)
|
||||||
|
for i, colormap_ in enumerate(colormap):
|
||||||
|
colormap2label[(colormap_[0] * 256 + colormap_[1]) * 256 +
|
||||||
|
colormap_[2]] = i
|
||||||
|
return colormap2label
|
||||||
|
|
||||||
|
|
||||||
|
def mapillary_masklabel(rgb_label: np.ndarray,
|
||||||
|
colormap2label: list) -> np.ndarray:
|
||||||
|
"""Computing a img mask label through `colormap2label` get in
|
||||||
|
`mapillary_colormap2label(COLORMAP: np.ndarray)`
|
||||||
|
|
||||||
|
Args:
|
||||||
|
rgb_label (np.array): a RGB labels img.
|
||||||
|
colormap2label (list): get in mapillary_colormap2label(colormap)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
np.ndarray: mask labels array.
|
||||||
|
"""
|
||||||
|
colormap_ = rgb_label.astype('uint32')
|
||||||
|
idx = np.array((colormap_[:, :, 0] * 256 + colormap_[:, :, 1]) * 256 +
|
||||||
|
colormap_[:, :, 2]).astype('uint32')
|
||||||
|
return colormap2label[idx]
|
||||||
|
|
||||||
|
|
||||||
|
def RGB2Mask(rgb_label_path: str, colormap2label: list) -> None:
|
||||||
|
"""Mapillary Vistas Dataset provide 8-bit with color-palette class-specific
|
||||||
|
labels for semantic segmentation. However, semantic segmentation needs
|
||||||
|
single channel mask labels.
|
||||||
|
|
||||||
|
This code is about converting mapillary RGB labels
|
||||||
|
{traing,validation/v1.2,v2.0/labels} to mask labels
|
||||||
|
{{traing,validation/v1.2,v2.0/labels_mask}
|
||||||
|
|
||||||
|
Args:
|
||||||
|
rgb_label_path (str): image absolute path.
|
||||||
|
dataset_version (str): v1.2 or v2.0 to choose color_map .
|
||||||
|
"""
|
||||||
|
rgb_label = mmcv.imread(rgb_label_path, channel_order='rgb')
|
||||||
|
|
||||||
|
masks_label = mapillary_masklabel(rgb_label, colormap2label)
|
||||||
|
|
||||||
|
mmcv.imwrite(
|
||||||
|
masks_label.astype(np.uint8),
|
||||||
|
rgb_label_path.replace('labels', 'labels_mask'))
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
colormap2label_v1_2 = mapillary_colormap2label(colormap_v1_2)
|
||||||
|
colormap2label_v2_0 = mapillary_colormap2label(colormap_v2_0)
|
||||||
|
|
||||||
|
dataset_path = args.dataset_path
|
||||||
|
if args.out_dir is None:
|
||||||
|
out_dir = dataset_path
|
||||||
|
else:
|
||||||
|
out_dir = args.out_dir
|
||||||
|
|
||||||
|
RGB_labels_path = []
|
||||||
|
RGB_labels_v1_2_path = []
|
||||||
|
RGB_labels_v2_0_path = []
|
||||||
|
print('Scanning labels path....')
|
||||||
|
for label_path in scandir(dataset_path, suffix='.png', recursive=True):
|
||||||
|
if 'labels' in label_path:
|
||||||
|
rgb_label_path = osp.join(dataset_path, label_path)
|
||||||
|
RGB_labels_path.append(rgb_label_path)
|
||||||
|
if 'v1.2' in label_path:
|
||||||
|
RGB_labels_v1_2_path.append(rgb_label_path)
|
||||||
|
elif 'v2.0' in label_path:
|
||||||
|
RGB_labels_v2_0_path.append(rgb_label_path)
|
||||||
|
|
||||||
|
if args.version == 'all':
|
||||||
|
print(f'Totaly found {len(RGB_labels_path)} {args.version} RGB labels')
|
||||||
|
elif args.version == 'v1.2':
|
||||||
|
print(f'Found {len(RGB_labels_v1_2_path)} {args.version} RGB labels')
|
||||||
|
elif args.version == 'v2.0':
|
||||||
|
print(f'Found {len(RGB_labels_v2_0_path)} {args.version} RGB labels')
|
||||||
|
print('Making directories...')
|
||||||
|
mkdir_or_exist(osp.join(out_dir, 'training', 'v1.2', 'labels_mask'))
|
||||||
|
mkdir_or_exist(osp.join(out_dir, 'validation', 'v1.2', 'labels_mask'))
|
||||||
|
mkdir_or_exist(osp.join(out_dir, 'training', 'v2.0', 'labels_mask'))
|
||||||
|
mkdir_or_exist(osp.join(out_dir, 'validation', 'v2.0', 'labels_mask'))
|
||||||
|
print('Directories Have Made...')
|
||||||
|
|
||||||
|
if args.nproc > 1:
|
||||||
|
if args.version == 'all':
|
||||||
|
print('Converting v1.2 ....')
|
||||||
|
track_parallel_progress(
|
||||||
|
partial(RGB2Mask, colormap2label=colormap2label_v1_2),
|
||||||
|
RGB_labels_v1_2_path,
|
||||||
|
nproc=args.nproc)
|
||||||
|
print('Converting v2.0 ....')
|
||||||
|
track_parallel_progress(
|
||||||
|
partial(RGB2Mask, colormap2label=colormap2label_v2_0),
|
||||||
|
RGB_labels_v2_0_path,
|
||||||
|
nproc=args.nproc)
|
||||||
|
elif args.version == 'v1.2':
|
||||||
|
print('Converting v1.2 ....')
|
||||||
|
track_parallel_progress(
|
||||||
|
partial(RGB2Mask, colormap2label=colormap2label_v1_2),
|
||||||
|
RGB_labels_v1_2_path,
|
||||||
|
nproc=args.nproc)
|
||||||
|
elif args.version == 'v2.0':
|
||||||
|
print('Converting v2.0 ....')
|
||||||
|
track_parallel_progress(
|
||||||
|
partial(RGB2Mask, colormap2label=colormap2label_v2_0),
|
||||||
|
RGB_labels_v2_0_path,
|
||||||
|
nproc=args.nproc)
|
||||||
|
|
||||||
|
else:
|
||||||
|
if args.version == 'all':
|
||||||
|
print('Converting v1.2 ....')
|
||||||
|
track_progress(
|
||||||
|
partial(RGB2Mask, colormap2label=colormap2label_v1_2),
|
||||||
|
RGB_labels_v1_2_path)
|
||||||
|
print('Converting v2.0 ....')
|
||||||
|
track_progress(
|
||||||
|
partial(RGB2Mask, colormap2label=colormap2label_v2_0),
|
||||||
|
RGB_labels_v2_0_path)
|
||||||
|
elif args.version == 'v1.2':
|
||||||
|
print('Converting v1.2 ....')
|
||||||
|
track_progress(
|
||||||
|
partial(RGB2Mask, colormap2label=colormap2label_v1_2),
|
||||||
|
RGB_labels_v1_2_path)
|
||||||
|
elif args.version == 'v2.0':
|
||||||
|
print('Converting v2.0 ....')
|
||||||
|
track_progress(
|
||||||
|
partial(RGB2Mask, colormap2label=colormap2label_v2_0),
|
||||||
|
RGB_labels_v2_0_path)
|
||||||
|
|
||||||
|
print('Have convert Mapillary Vistas Datasets RGB labels to Mask labels!')
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
args = parse_args()
|
||||||
|
main()
|
Loading…
Reference in New Issue