mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
Add 4 retinal vessel segmentation benchmark (#315)
* add 4 retinal vessel segmentation configs of UNet * fix flip augmentation * add unet benchmark on 4 medical datasets * fix hrf bug
This commit is contained in:
parent
b9ba9f6ce7
commit
2e479d346b
59
configs/_base_/datasets/chase_db1.py
Normal file
59
configs/_base_/datasets/chase_db1.py
Normal file
@ -0,0 +1,59 @@
|
||||
# dataset settings
|
||||
dataset_type = 'ChaseDB1Dataset'
|
||||
data_root = 'data/CHASE_DB1'
|
||||
img_norm_cfg = dict(
|
||||
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
|
||||
img_scale = (960, 999)
|
||||
crop_size = (128, 128)
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='LoadAnnotations'),
|
||||
dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),
|
||||
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
|
||||
dict(type='RandomFlip', prob=0.5),
|
||||
dict(type='PhotoMetricDistortion'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
|
||||
dict(type='DefaultFormatBundle'),
|
||||
dict(type='Collect', keys=['img', 'gt_semantic_seg'])
|
||||
]
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='MultiScaleFlipAug',
|
||||
img_scale=img_scale,
|
||||
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0],
|
||||
flip=False,
|
||||
transforms=[
|
||||
dict(type='Resize', keep_ratio=True),
|
||||
dict(type='RandomFlip'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='ImageToTensor', keys=['img']),
|
||||
dict(type='Collect', keys=['img'])
|
||||
])
|
||||
]
|
||||
|
||||
data = dict(
|
||||
samples_per_gpu=4,
|
||||
workers_per_gpu=4,
|
||||
train=dict(
|
||||
type='RepeatDataset',
|
||||
times=40000,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
img_dir='images/training',
|
||||
ann_dir='annotations/training',
|
||||
pipeline=train_pipeline)),
|
||||
val=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
img_dir='images/validation',
|
||||
ann_dir='annotations/validation',
|
||||
pipeline=test_pipeline),
|
||||
test=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
img_dir='images/validation',
|
||||
ann_dir='annotations/validation',
|
||||
pipeline=test_pipeline))
|
59
configs/_base_/datasets/drive.py
Normal file
59
configs/_base_/datasets/drive.py
Normal file
@ -0,0 +1,59 @@
|
||||
# dataset settings
|
||||
dataset_type = 'DRIVEDataset'
|
||||
data_root = 'data/DRIVE'
|
||||
img_norm_cfg = dict(
|
||||
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
|
||||
img_scale = (584, 565)
|
||||
crop_size = (64, 64)
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='LoadAnnotations'),
|
||||
dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),
|
||||
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
|
||||
dict(type='RandomFlip', prob=0.5),
|
||||
dict(type='PhotoMetricDistortion'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
|
||||
dict(type='DefaultFormatBundle'),
|
||||
dict(type='Collect', keys=['img', 'gt_semantic_seg'])
|
||||
]
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='MultiScaleFlipAug',
|
||||
img_scale=img_scale,
|
||||
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0],
|
||||
flip=False,
|
||||
transforms=[
|
||||
dict(type='Resize', keep_ratio=True),
|
||||
dict(type='RandomFlip'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='ImageToTensor', keys=['img']),
|
||||
dict(type='Collect', keys=['img'])
|
||||
])
|
||||
]
|
||||
|
||||
data = dict(
|
||||
samples_per_gpu=4,
|
||||
workers_per_gpu=4,
|
||||
train=dict(
|
||||
type='RepeatDataset',
|
||||
times=40000,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
img_dir='images/training',
|
||||
ann_dir='annotations/training',
|
||||
pipeline=train_pipeline)),
|
||||
val=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
img_dir='images/validation',
|
||||
ann_dir='annotations/validation',
|
||||
pipeline=test_pipeline),
|
||||
test=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
img_dir='images/validation',
|
||||
ann_dir='annotations/validation',
|
||||
pipeline=test_pipeline))
|
59
configs/_base_/datasets/hrf.py
Normal file
59
configs/_base_/datasets/hrf.py
Normal file
@ -0,0 +1,59 @@
|
||||
# dataset settings
|
||||
dataset_type = 'HRFDataset'
|
||||
data_root = 'data/HRF'
|
||||
img_norm_cfg = dict(
|
||||
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
|
||||
img_scale = (2336, 3504)
|
||||
crop_size = (256, 256)
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='LoadAnnotations'),
|
||||
dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),
|
||||
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
|
||||
dict(type='RandomFlip', prob=0.5),
|
||||
dict(type='PhotoMetricDistortion'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
|
||||
dict(type='DefaultFormatBundle'),
|
||||
dict(type='Collect', keys=['img', 'gt_semantic_seg'])
|
||||
]
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='MultiScaleFlipAug',
|
||||
img_scale=img_scale,
|
||||
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0],
|
||||
flip=False,
|
||||
transforms=[
|
||||
dict(type='Resize', keep_ratio=True),
|
||||
dict(type='RandomFlip'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='ImageToTensor', keys=['img']),
|
||||
dict(type='Collect', keys=['img'])
|
||||
])
|
||||
]
|
||||
|
||||
data = dict(
|
||||
samples_per_gpu=4,
|
||||
workers_per_gpu=4,
|
||||
train=dict(
|
||||
type='RepeatDataset',
|
||||
times=40000,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
img_dir='images/training',
|
||||
ann_dir='annotations/training',
|
||||
pipeline=train_pipeline)),
|
||||
val=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
img_dir='images/validation',
|
||||
ann_dir='annotations/validation',
|
||||
pipeline=test_pipeline),
|
||||
test=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
img_dir='images/validation',
|
||||
ann_dir='annotations/validation',
|
||||
pipeline=test_pipeline))
|
59
configs/_base_/datasets/stare.py
Normal file
59
configs/_base_/datasets/stare.py
Normal file
@ -0,0 +1,59 @@
|
||||
# dataset settings
|
||||
dataset_type = 'STAREDataset'
|
||||
data_root = 'data/STARE'
|
||||
img_norm_cfg = dict(
|
||||
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
|
||||
img_scale = (605, 700)
|
||||
crop_size = (128, 128)
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='LoadAnnotations'),
|
||||
dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)),
|
||||
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
|
||||
dict(type='RandomFlip', prob=0.5),
|
||||
dict(type='PhotoMetricDistortion'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
|
||||
dict(type='DefaultFormatBundle'),
|
||||
dict(type='Collect', keys=['img', 'gt_semantic_seg'])
|
||||
]
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='MultiScaleFlipAug',
|
||||
img_scale=img_scale,
|
||||
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0],
|
||||
flip=False,
|
||||
transforms=[
|
||||
dict(type='Resize', keep_ratio=True),
|
||||
dict(type='RandomFlip'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='ImageToTensor', keys=['img']),
|
||||
dict(type='Collect', keys=['img'])
|
||||
])
|
||||
]
|
||||
|
||||
data = dict(
|
||||
samples_per_gpu=4,
|
||||
workers_per_gpu=4,
|
||||
train=dict(
|
||||
type='RepeatDataset',
|
||||
times=40000,
|
||||
dataset=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
img_dir='images/training',
|
||||
ann_dir='annotations/training',
|
||||
pipeline=train_pipeline)),
|
||||
val=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
img_dir='images/validation',
|
||||
ann_dir='annotations/validation',
|
||||
pipeline=test_pipeline),
|
||||
test=dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
img_dir='images/validation',
|
||||
ann_dir='annotations/validation',
|
||||
pipeline=test_pipeline))
|
51
configs/_base_/models/unet_s5-d16.py
Normal file
51
configs/_base_/models/unet_s5-d16.py
Normal file
@ -0,0 +1,51 @@
|
||||
# model settings
|
||||
norm_cfg = dict(type='SyncBN', requires_grad=True)
|
||||
model = dict(
|
||||
type='EncoderDecoder',
|
||||
pretrained=None,
|
||||
backbone=dict(
|
||||
type='UNet',
|
||||
in_channels=3,
|
||||
base_channels=64,
|
||||
num_stages=5,
|
||||
strides=(1, 1, 1, 1, 1),
|
||||
enc_num_convs=(2, 2, 2, 2, 2),
|
||||
dec_num_convs=(2, 2, 2, 2),
|
||||
downsamples=(True, True, True, True),
|
||||
enc_dilations=(1, 1, 1, 1, 1),
|
||||
dec_dilations=(1, 1, 1, 1),
|
||||
with_cp=False,
|
||||
conv_cfg=None,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=dict(type='ReLU'),
|
||||
upsample_cfg=dict(type='InterpConv'),
|
||||
norm_eval=False),
|
||||
decode_head=dict(
|
||||
type='FCNHead',
|
||||
in_channels=64,
|
||||
in_index=4,
|
||||
channels=64,
|
||||
num_convs=1,
|
||||
concat_input=False,
|
||||
dropout_ratio=0.1,
|
||||
num_classes=2,
|
||||
norm_cfg=norm_cfg,
|
||||
align_corners=False,
|
||||
loss_decode=dict(
|
||||
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
|
||||
auxiliary_head=dict(
|
||||
type='FCNHead',
|
||||
in_channels=128,
|
||||
in_index=3,
|
||||
channels=64,
|
||||
num_convs=1,
|
||||
concat_input=False,
|
||||
dropout_ratio=0.1,
|
||||
num_classes=2,
|
||||
norm_cfg=norm_cfg,
|
||||
align_corners=False,
|
||||
loss_decode=dict(
|
||||
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)))
|
||||
# model training and testing settings
|
||||
train_cfg = dict()
|
||||
test_cfg = dict(mode='slide', crop_size=256, stride=170)
|
23
configs/unet/README.md
Normal file
23
configs/unet/README.md
Normal file
@ -0,0 +1,23 @@
|
||||
# U-Net: Convolutional Networks for Biomedical Image Segmentation
|
||||
|
||||
## Introduction
|
||||
|
||||
```latex
|
||||
@inproceedings{ronneberger2015u,
|
||||
title={U-net: Convolutional networks for biomedical image segmentation},
|
||||
author={Ronneberger, Olaf and Fischer, Philipp and Brox, Thomas},
|
||||
booktitle={International Conference on Medical image computing and computer-assisted intervention},
|
||||
pages={234--241},
|
||||
year={2015},
|
||||
organization={Springer}
|
||||
}
|
||||
```
|
||||
|
||||
## Results and models
|
||||
|
||||
| Backbone | Head | Dataset | Image Size | Crop Size | Stride | Lr schd | Mem (GB) | Inf time (fps) | Dice | download |
|
||||
|--------|----------|----------|----------|-----------|--------:|----------|----------------|------:|--------------:|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| UNet-S5-D16 | FCN | DRIVE | 584x565 | 64x64 | 42x42 | 40000 | 0.680 | - | 78.67 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/unet/unet_s5-d16_64x64_40k_drive/unet_s5-d16_64x64_40k_drive_20201223_191051-9cd163b8.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/unet/unet_s5-d16_64x64_40k_drive/unet_s5-d16_64x64_40k_drive-20201223_191051.log.json) |
|
||||
| UNet-S5-D16 | FCN | STARE | 605x700 | 128x128 | 85x85 | 40000 | 0.968 | - | 81.02 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/unet/unet_s5-d16_128x128_40k_stare/unet_s5-d16_128x128_40k_stare_20201223_191051-e5439846.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/unet/unet_s5-d16_128x128_40k_stare/unet_s5-d16_128x128_40k_stare-20201223_191051.log.json) |
|
||||
| UNet-S5-D16 | FCN | CHASE_DB1 | 960x999 | 128x128 | 85x85 | 40000 | 0.968 | - | 80.24 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/unet/unet_s5-d16_128x128_40k_chase_db1/unet_s5-d16_128x128_40k_chase_db1_20201223_191051-8b16ca0b.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/unet/unet_s5-d16_128x128_40k_chase_db1/unet_s5-d16_128x128_40k_chase_db1-20201223_191051.log.json) |
|
||||
| UNet-S5-D16 | FCN | HRF | 2336x3504 | 256x256 | 170x170 | 40000 | 2.525 | - | 79.45 | [model](https://download.openmmlab.com/mmsegmentation/v0.5/unet/unet_s5-d16_256x256_40k_hrf/unet_s5-d16_256x256_40k_hrf_20201223_173724-d89cf1ed.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/unet/unet_s5-d16_256x256_40k_hrf/unet_s5-d16_256x256_40k_hrf-20201223_173724.log.json) |
|
6
configs/unet/unet_s5-d16_128x128_40k_chase_db1.py
Normal file
6
configs/unet/unet_s5-d16_128x128_40k_chase_db1.py
Normal file
@ -0,0 +1,6 @@
|
||||
_base_ = [
|
||||
'../_base_/models/unet_s5-d16.py', '../_base_/datasets/chase_db1.py',
|
||||
'../_base_/default_runtime.py', '../_base_/schedules/schedule_40k.py'
|
||||
]
|
||||
test_cfg = dict(crop_size=(128, 128), stride=(85, 85))
|
||||
evaluation = dict(metric='mDice')
|
6
configs/unet/unet_s5-d16_128x128_40k_stare.py
Normal file
6
configs/unet/unet_s5-d16_128x128_40k_stare.py
Normal file
@ -0,0 +1,6 @@
|
||||
_base_ = [
|
||||
'../_base_/models/unet_s5-d16.py', '../_base_/datasets/stare.py',
|
||||
'../_base_/default_runtime.py', '../_base_/schedules/schedule_40k.py'
|
||||
]
|
||||
test_cfg = dict(crop_size=(128, 128), stride=(85, 85))
|
||||
evaluation = dict(metric='mDice')
|
6
configs/unet/unet_s5-d16_256x256_40k_hrf.py
Normal file
6
configs/unet/unet_s5-d16_256x256_40k_hrf.py
Normal file
@ -0,0 +1,6 @@
|
||||
_base_ = [
|
||||
'../_base_/models/unet_s5-d16.py', '../_base_/datasets/hrf.py',
|
||||
'../_base_/default_runtime.py', '../_base_/schedules/schedule_40k.py'
|
||||
]
|
||||
test_cfg = dict(crop_size=(256, 256), stride=(170, 170))
|
||||
evaluation = dict(metric='mDice')
|
6
configs/unet/unet_s5-d16_64x64_40k_drive.py
Normal file
6
configs/unet/unet_s5-d16_64x64_40k_drive.py
Normal file
@ -0,0 +1,6 @@
|
||||
_base_ = [
|
||||
'../_base_/models/unet_s5-d16.py', '../_base_/datasets/drive.py',
|
||||
'../_base_/default_runtime.py', '../_base_/schedules/schedule_40k.py'
|
||||
]
|
||||
test_cfg = dict(crop_size=(64, 64), stride=(42, 42))
|
||||
evaluation = dict(metric='mDice')
|
Loading…
x
Reference in New Issue
Block a user