[Feature] Add GDAL backend and Support LEVIR-CD Dataset (#2903)
## Motivation

To support reading multiple remote sensing image formats, this PR adds a GDAL backend; see https://gdal.org/drivers/raster/index.html for the supported raster drivers. Byte, UInt16, Int16, UInt32, Int32, Float32, Float64, CInt16, CInt32, CFloat32 and CFloat64 are supported for reading and writing. It also supports feeding two images for change detection tasks and adds the LEVIR-CD dataset.

## Modification

- Add `LoadSingleRSImageFromFile` in `mmseg/datasets/transforms/loading.py`: load a single remote sensing image for segmentation tasks.
- Add `LoadMultipleRSImageFromFile` in `mmseg/datasets/transforms/loading.py`: load two remote sensing images for change detection tasks.
- Add `ConcatCDInput` in `mmseg/datasets/transforms/transforms.py`: concatenate the two separately augmented images into a single input.
- Add `BaseCDDataset` in `mmseg/datasets/basesegdataset.py`: base class for datasets used in change detection tasks.

The dataset and model configs below show how these pieces are wired together for LEVIR-CD.

---------

Co-authored-by: xiexinch <xiexinch@outlook.com>
parent: 77836e6231
commit: 77591b9e7b
@@ -0,0 +1,59 @@
# dataset settings
dataset_type = 'LEVIRCDDataset'
data_root = r'data/LEVIRCD'

albu_train_transforms = [
    dict(type='RandomBrightnessContrast', p=0.2),
    dict(type='HorizontalFlip', p=0.5),
    dict(type='VerticalFlip', p=0.5)
]

train_pipeline = [
    dict(type='LoadMultipleRSImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(type='Albu', transforms=albu_train_transforms),
    dict(type='ConcatCDInput'),
    dict(type='PackSegInputs')
]
test_pipeline = [
    dict(type='LoadMultipleRSImageFromFile'),
    dict(type='LoadAnnotations'),
    dict(type='ConcatCDInput'),
    dict(type='PackSegInputs')
]

tta_pipeline = [
    dict(type='LoadMultipleRSImageFromFile'),
    dict(
        type='TestTimeAug',
        transforms=[[dict(type='LoadAnnotations')],
                    [dict(type='ConcatCDInput')],
                    [dict(type='PackSegInputs')]])
]
train_dataloader = dict(
    batch_size=4,
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type='InfiniteSampler', shuffle=True),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_prefix=dict(
            img_path='train/A',
            img_path2='train/B',
            seg_map_path='train/label'),
        pipeline=train_pipeline))
val_dataloader = dict(
    batch_size=1,
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_prefix=dict(
            img_path='test/A', img_path2='test/B', seg_map_path='test/label'),
        pipeline=test_pipeline))
test_dataloader = val_dataloader
val_evaluator = dict(type='IoUMetric', iou_metrics=['mIoU'])
test_evaluator = val_evaluator
@@ -0,0 +1,56 @@
_base_ = [
    '../_base_/models/upernet_swin.py', '../_base_/datasets/levir_256x256.py',
    '../_base_/default_runtime.py', '../_base_/schedules/schedule_20k.py'
]
crop_size = (256, 256)
norm_cfg = dict(type='BN', requires_grad=True)
data_preprocessor = dict(
    size=crop_size,
    type='SegDataPreProcessor',
    mean=[123.675, 116.28, 103.53, 123.675, 116.28, 103.53],
    std=[58.395, 57.12, 57.375, 58.395, 57.12, 57.375])

model = dict(
    data_preprocessor=data_preprocessor,
    backbone=dict(
        in_channels=6,
        embed_dims=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        window_size=7,
        use_abs_pos_embed=False,
        drop_path_rate=0.3,
        patch_norm=True),
    decode_head=dict(in_channels=[96, 192, 384, 768], num_classes=2),
    auxiliary_head=dict(in_channels=384, num_classes=2))

# AdamW optimizer, no weight decay for position embedding & layer norm
# in backbone
optim_wrapper = dict(
    _delete_=True,
    type='OptimWrapper',
    optimizer=dict(
        type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01),
    paramwise_cfg=dict(
        custom_keys={
            'absolute_pos_embed': dict(decay_mult=0.),
            'relative_position_bias_table': dict(decay_mult=0.),
            'norm': dict(decay_mult=0.)
        }))

param_scheduler = [
    dict(
        type='LinearLR', start_factor=1e-6, by_epoch=False, begin=0, end=1500),
    dict(
        type='PolyLR',
        eta_min=0.0,
        power=1.0,
        begin=1500,
        end=20000,
        by_epoch=False,
    )
]

train_dataloader = dict(batch_size=4)
val_dataloader = dict(batch_size=1)
test_dataloader = val_dataloader
@@ -195,6 +195,16 @@ Run it with
docker run --gpus all --shm-size=8g -it -v {DATA_DIR}:/mmsegmentation/data mmsegmentation
```

### Optional Dependencies

#### Install GDAL

[GDAL](https://gdal.org/) is a translator library for raster and vector geospatial data formats. Install GDAL to read complex formats and extremely large remote sensing images.

```shell
conda install GDAL
```
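
A quick way to confirm the installation works is to open a raster and convert it to a NumPy array, which mirrors what the new `LoadSingleRSImageFromFile` transform does internally. This is a minimal sketch; the file path is a placeholder:

```python
# Minimal sketch: read a multi-band raster with GDAL into a channel-last array.
# 'path/to/image.tif' is a placeholder; use any raster GDAL can open.
import numpy as np
from osgeo import gdal

ds = gdal.Open('path/to/image.tif')
if ds is None:
    raise RuntimeError('GDAL could not open the file')
img = np.einsum('ijk->jki', ds.ReadAsArray())  # (bands, H, W) -> (H, W, bands)
print(img.shape, img.dtype)
```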

## Trouble shooting

If you have some issues during the installation, please first view the [FAQ](notes/faq.md) page.
@@ -620,3 +620,33 @@ It includes 400 images for training, 400 images for validation and 400 images fo

- You could set Datasets version with `MapillaryDataset_v1` and `MapillaryDataset_v2` in your configs.
  View the Mapillary Vistas Datasets config file here [V1.2](https://github.com/open-mmlab/mmsegmentation/blob/main/configs/_base_/datasets/mapillary_v1.py) and [V2.0](https://github.com/open-mmlab/mmsegmentation/blob/main/configs/_base_/datasets/mapillary_v2.py)

## LEVIR-CD

[LEVIR-CD](https://justchenhao.github.io/LEVIR/) is a large-scale remote sensing dataset for building change detection.

Download the dataset from [here](https://justchenhao.github.io/LEVIR/).

The supplementary version of the dataset can be requested on the [homepage](https://github.com/S2Looking/Dataset).

Please download the supplementary version of the dataset and unzip `LEVIR-CD+.zip`; its contents are organized as follows:

```none
│   ├── LEVIR-CD+
│   │   ├── train
│   │   │   ├── A
│   │   │   ├── B
│   │   │   ├── label
│   │   ├── test
│   │   │   ├── A
│   │   │   ├── B
│   │   │   ├── label
```

For the LEVIR-CD dataset, run the following command to crop the images without overlap:

```shell
python tools/dataset_converters/levircd.py --dataset-path /path/to/LEVIR-CD+ --out_dir /path/to/LEVIR-CD
```

The size of the cropped images is 256x256, which is consistent with the original paper.
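
After cropping, each split should contain matching `A`, `B` and `label` folders of 256x256 tiles. A small optional sanity check (the root path is a placeholder for the `--out_dir` used above):

```python
# Optional sanity check: the A/B/label folders of each split should hold
# the same set of tile names. '/path/to/LEVIR-CD' is a placeholder.
import os

root = '/path/to/LEVIR-CD'
for split in ('train', 'test'):
    a = sorted(os.listdir(os.path.join(root, split, 'A')))
    b = sorted(os.listdir(os.path.join(root, split, 'B')))
    label = sorted(os.listdir(os.path.join(root, split, 'label')))
    assert a == b == label, f'{split}: A/B/label file lists differ'
    print(split, len(a), 'tiles')
```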
@@ -194,6 +194,16 @@ docker build -t mmsegmentation docker/
docker run --gpus all --shm-size=8g -it -v {DATA_DIR}:/mmsegmentation/data mmsegmentation
```

### Optional Dependencies

#### Install GDAL

[GDAL](https://gdal.org/) is a translator library for raster and vector geospatial data formats. Installing GDAL allows reading complex formats and extremely large remote sensing images.

```shell
conda install GDAL
```

## Troubleshooting

If you run into other problems during installation, please check the [FAQ](notes/faq.md) page first. If you cannot find an answer there, you may also open an [issue](https://github.com/open-mmlab/mmsegmentation/issues/new/choose) on GitHub.
@@ -616,3 +616,33 @@ python tools/convert_datasets/refuge.py --raw_data_root=/path/to/refuge/REFUGE2/

- You can set the dataset version with `MapillaryDataset_v1` and `MapillaryDataset_v2` in your configs.
  View the Mapillary Vistas dataset config files here: [V1.2](https://github.com/open-mmlab/mmsegmentation/blob/main/configs/_base_/datasets/mapillary_v1.py) and [V2.0](https://github.com/open-mmlab/mmsegmentation/blob/main/configs/_base_/datasets/mapillary_v2.py)

## LEVIR-CD

[LEVIR-CD](https://justchenhao.github.io/LEVIR/) is a large-scale remote sensing dataset for building change detection.

The dataset can be requested at the [homepage](https://justchenhao.github.io/LEVIR/).

The supplementary version of the dataset can be requested at the [homepage](https://github.com/S2Looking/Dataset).

Please download the supplementary version of the dataset and unzip `LEVIR-CD+.zip`; its contents are organized as follows:

```none
│   ├── LEVIR-CD+
│   │   ├── train
│   │   │   ├── A
│   │   │   ├── B
│   │   │   ├── label
│   │   ├── test
│   │   │   ├── A
│   │   │   ├── B
│   │   │   ├── label
```

For the LEVIR-CD dataset, run the following command to crop the images without overlap:

```shell
python tools/dataset_converters/levircd.py --dataset-path /path/to/LEVIR-CD+ --out_dir /path/to/LEVIR-CD
```

The size of the cropped images is 256x256, which is consistent with the original paper.
@@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
# yapf: disable
from .ade import ADE20KDataset
-from .basesegdataset import BaseSegDataset
+from .basesegdataset import BaseCDDataset, BaseSegDataset
from .chase_db1 import ChaseDB1Dataset
from .cityscapes import CityscapesDataset
from .coco_stuff import COCOStuffDataset
@@ -12,6 +12,7 @@ from .drive import DRIVEDataset
from .hrf import HRFDataset
from .isaid import iSAIDDataset
from .isprs import ISPRSDataset
+from .levir import LEVIRCDDataset
from .lip import LIPDataset
from .loveda import LoveDADataset
from .mapillary import MapillaryDataset_v1, MapillaryDataset_v2
@@ -25,13 +26,15 @@ from .synapse import SynapseDataset
from .transforms import (CLAHE, AdjustGamma, Albu, BioMedical3DPad,
                         BioMedical3DRandomCrop, BioMedical3DRandomFlip,
                         BioMedicalGaussianBlur, BioMedicalGaussianNoise,
-                         BioMedicalRandomGamma, GenerateEdge, LoadAnnotations,
-                         LoadBiomedicalAnnotation, LoadBiomedicalData,
-                         LoadBiomedicalImageFromFile, LoadImageFromNDArray,
-                         PackSegInputs, PhotoMetricDistortion, RandomCrop,
-                         RandomCutOut, RandomMosaic, RandomRotate,
-                         RandomRotFlip, Rerange, ResizeShortestEdge,
-                         ResizeToMultiple, RGB2Gray, SegRescale)
+                         BioMedicalRandomGamma, ConcatCDInput, GenerateEdge,
+                         LoadAnnotations, LoadBiomedicalAnnotation,
+                         LoadBiomedicalData, LoadBiomedicalImageFromFile,
+                         LoadImageFromNDArray, LoadMultipleRSImageFromFile,
+                         LoadSingleRSImageFromFile, PackSegInputs,
+                         PhotoMetricDistortion, RandomCrop, RandomCutOut,
+                         RandomMosaic, RandomRotate, RandomRotFlip, Rerange,
+                         ResizeShortestEdge, ResizeToMultiple, RGB2Gray,
+                         SegRescale)
from .voc import PascalVOCDataset

# yapf: enable
@@ -51,5 +54,7 @@ __all__ = [
    'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur',
    'BioMedicalRandomGamma', 'BioMedical3DPad', 'RandomRotFlip',
    'SynapseDataset', 'REFUGEDataset', 'MapillaryDataset_v1',
-    'MapillaryDataset_v2', 'Albu'
+    'MapillaryDataset_v2', 'Albu', 'LEVIRCDDataset',
+    'LoadMultipleRSImageFromFile', 'LoadSingleRSImageFromFile',
+    'ConcatCDInput', 'BaseCDDataset'
]
@@ -266,3 +266,284 @@ class BaseSegDataset(BaseDataset):
                data_list.append(data_info)
        data_list = sorted(data_list, key=lambda x: x['img_path'])
        return data_list


@DATASETS.register_module()
class BaseCDDataset(BaseDataset):
    """Custom dataset for change detection. An example of file structure is as
    followed.

    .. code-block:: none

        ├── data
        │   ├── my_dataset
        │   │   ├── img_dir
        │   │   │   ├── train
        │   │   │   │   ├── xxx{img_suffix}
        │   │   │   │   ├── yyy{img_suffix}
        │   │   │   │   ├── zzz{img_suffix}
        │   │   │   ├── val
        │   │   ├── img_dir2
        │   │   │   ├── train
        │   │   │   │   ├── xxx{img_suffix}
        │   │   │   │   ├── yyy{img_suffix}
        │   │   │   │   ├── zzz{img_suffix}
        │   │   │   ├── val
        │   │   ├── ann_dir
        │   │   │   ├── train
        │   │   │   │   ├── xxx{seg_map_suffix}
        │   │   │   │   ├── yyy{seg_map_suffix}
        │   │   │   │   ├── zzz{seg_map_suffix}
        │   │   │   ├── val

    The image names in img_dir and img_dir2 should be consistent.
    The img/gt_semantic_seg pair of BaseCDDataset should be the same
    except for the suffix. A valid img/gt_semantic_seg filename pair should
    be like ``xxx{img_suffix}`` and ``xxx{seg_map_suffix}`` (the extension is
    also included in the suffix). If split is given, then ``xxx`` is
    specified in the txt file. Otherwise, all files in ``img_dir/`` and
    ``ann_dir`` will be loaded. Please refer to
    ``docs/en/tutorials/new_dataset.md`` for more details.

    Args:
        ann_file (str): Annotation file path. Defaults to ''.
        metainfo (dict, optional): Meta information for dataset, such as
            specify classes to load. Defaults to None.
        data_root (str, optional): The root directory for ``data_prefix`` and
            ``ann_file``. Defaults to None.
        data_prefix (dict, optional): Prefix for training data. Defaults to
            dict(img_path=None, img_path2=None, seg_map_path=None).
        img_suffix (str): Suffix of images. Default: '.jpg'
        img_suffix2 (str): Suffix of the second images. Default: '.jpg'
        seg_map_suffix (str): Suffix of segmentation maps. Default: '.png'
        filter_cfg (dict, optional): Config for filter data. Defaults to None.
        indices (int or Sequence[int], optional): Support using first few
            data in annotation file to facilitate training/testing on a
            smaller dataset. Defaults to None which means using all
            ``data_infos``.
        serialize_data (bool, optional): Whether to hold memory using
            serialized objects, when enabled, data loader workers can use
            shared RAM from master process instead of making a copy. Defaults
            to True.
        pipeline (list, optional): Processing pipeline. Defaults to [].
        test_mode (bool, optional): ``test_mode=True`` means in test phase.
            Defaults to False.
        lazy_init (bool, optional): Whether to load annotation during
            instantiation. In some cases, such as visualization, only the meta
            information of the dataset is needed, which is not necessary to
            load annotation file. ``Basedataset`` can skip load annotations to
            save time by set ``lazy_init=True``. Defaults to False.
        max_refetch (int, optional): If ``Basedataset.prepare_data`` get a
            None img. The maximum extra number of cycles to get a valid
            image. Defaults to 1000.
        ignore_index (int): The label index to be ignored. Default: 255
        reduce_zero_label (bool): Whether to mark label zero as ignored.
            Default to False.
        backend_args (dict, Optional): Arguments to instantiate a file backend.
            See https://mmengine.readthedocs.io/en/latest/api/fileio.htm
            for details. Defaults to None.
            Notes: mmcv>=2.0.0rc4, mmengine>=0.2.0 required.
    """
    METAINFO: dict = dict()

    def __init__(self,
                 ann_file: str = '',
                 img_suffix='.jpg',
                 img_suffix2='.jpg',
                 seg_map_suffix='.png',
                 metainfo: Optional[dict] = None,
                 data_root: Optional[str] = None,
                 data_prefix: dict = dict(
                     img_path='', img_path2='', seg_map_path=''),
                 filter_cfg: Optional[dict] = None,
                 indices: Optional[Union[int, Sequence[int]]] = None,
                 serialize_data: bool = True,
                 pipeline: List[Union[dict, Callable]] = [],
                 test_mode: bool = False,
                 lazy_init: bool = False,
                 max_refetch: int = 1000,
                 ignore_index: int = 255,
                 reduce_zero_label: bool = False,
                 backend_args: Optional[dict] = None) -> None:

        self.img_suffix = img_suffix
        self.img_suffix2 = img_suffix2
        self.seg_map_suffix = seg_map_suffix
        self.ignore_index = ignore_index
        self.reduce_zero_label = reduce_zero_label
        self.backend_args = backend_args.copy() if backend_args else None

        self.data_root = data_root
        self.data_prefix = copy.copy(data_prefix)
        self.ann_file = ann_file
        self.filter_cfg = copy.deepcopy(filter_cfg)
        self._indices = indices
        self.serialize_data = serialize_data
        self.test_mode = test_mode
        self.max_refetch = max_refetch
        self.data_list: List[dict] = []
        self.data_bytes: np.ndarray

        # Set meta information.
        self._metainfo = self._load_metainfo(copy.deepcopy(metainfo))

        # Get label map for custom classes
        new_classes = self._metainfo.get('classes', None)
        self.label_map = self.get_label_map(new_classes)
        self._metainfo.update(
            dict(
                label_map=self.label_map,
                reduce_zero_label=self.reduce_zero_label))

        # Update palette based on label map or generate palette
        # if it is not defined
        updated_palette = self._update_palette()
        self._metainfo.update(dict(palette=updated_palette))

        # Join paths.
        if self.data_root is not None:
            self._join_prefix()

        # Build pipeline.
        self.pipeline = Compose(pipeline)
        # Full initialize the dataset.
        if not lazy_init:
            self.full_init()

        if test_mode:
            assert self._metainfo.get('classes') is not None, \
                'dataset metainfo `classes` should be specified when testing'

    @classmethod
    def get_label_map(cls,
                      new_classes: Optional[Sequence] = None
                      ) -> Union[Dict, None]:
        """Require label mapping.

        The ``label_map`` is a dictionary, its keys are the old label ids and
        its values are the new label ids, and is used for changing pixel
        labels in load_annotations. If and only if old classes in cls.METAINFO
        is not equal to new classes in self._metainfo and neither of them is
        None, `label_map` is not None.

        Args:
            new_classes (list, tuple, optional): The new classes name from
                metainfo. Default to None.

        Returns:
            dict, optional: The mapping from old classes in cls.METAINFO to
                new classes in self._metainfo
        """
        old_classes = cls.METAINFO.get('classes', None)
        if (new_classes is not None and old_classes is not None
                and list(new_classes) != list(old_classes)):

            label_map = {}
            if not set(new_classes).issubset(cls.METAINFO['classes']):
                raise ValueError(
                    f'new classes {new_classes} is not a '
                    f'subset of classes {old_classes} in METAINFO.')
            for i, c in enumerate(old_classes):
                if c not in new_classes:
                    label_map[i] = 255
                else:
                    label_map[i] = new_classes.index(c)
            return label_map
        else:
            return None

    def _update_palette(self) -> list:
        """Update palette after loading metainfo.

        If length of palette is equal to classes, just return the palette.
        If palette is not defined, it will randomly generate a palette.
        If classes is updated by customer, it will return the subset of
        palette.

        Returns:
            Sequence: Palette for current dataset.
        """
        palette = self._metainfo.get('palette', [])
        classes = self._metainfo.get('classes', [])
        # palette matches classes
        if len(palette) == len(classes):
            return palette

        if len(palette) == 0:
            # Get random state before set seed, and restore
            # random state later.
            # It will prevent loss of randomness, as the palette
            # may be different in each iteration if not specified.
            # See: https://github.com/open-mmlab/mmdetection/issues/5844
            state = np.random.get_state()
            np.random.seed(42)
            # random palette
            new_palette = np.random.randint(
                0, 255, size=(len(classes), 3)).tolist()
            np.random.set_state(state)
        elif len(palette) >= len(classes) and self.label_map is not None:
            new_palette = []
            # return subset of palette
            for old_id, new_id in sorted(
                    self.label_map.items(), key=lambda x: x[1]):
                if new_id != 255:
                    new_palette.append(palette[old_id])
            new_palette = type(palette)(new_palette)
        else:
            raise ValueError('palette does not match classes '
                             f'as metainfo is {self._metainfo}.')
        return new_palette

    def load_data_list(self) -> List[dict]:
        """Load annotation from directory or annotation file.

        Returns:
            list[dict]: All data info of dataset.
        """
        data_list = []
        img_dir = self.data_prefix.get('img_path', None)
        img_dir2 = self.data_prefix.get('img_path2', None)
        ann_dir = self.data_prefix.get('seg_map_path', None)
        if osp.isfile(self.ann_file):
            lines = mmengine.list_from_file(
                self.ann_file, backend_args=self.backend_args)
            for line in lines:
                img_name = line.strip()
                if '.' in osp.basename(img_name):
                    img_name, img_ext = osp.splitext(img_name)
                    self.img_suffix = img_ext
                    self.img_suffix2 = img_ext
                data_info = dict(
                    img_path=osp.join(img_dir, img_name + self.img_suffix),
                    img_path2=osp.join(img_dir2, img_name + self.img_suffix2))

                if ann_dir is not None:
                    seg_map = img_name + self.seg_map_suffix
                    data_info['seg_map_path'] = osp.join(ann_dir, seg_map)
                data_info['label_map'] = self.label_map
                data_info['reduce_zero_label'] = self.reduce_zero_label
                data_info['seg_fields'] = []
                data_list.append(data_info)
        else:
            for img in fileio.list_dir_or_file(
                    dir_path=img_dir,
                    list_dir=False,
                    suffix=self.img_suffix,
                    recursive=True,
                    backend_args=self.backend_args):
                if '.' in osp.basename(img):
                    img, img_ext = osp.splitext(img)
                    self.img_suffix = img_ext
                    self.img_suffix2 = img_ext
                data_info = dict(
                    img_path=osp.join(img_dir, img + self.img_suffix),
                    img_path2=osp.join(img_dir2, img + self.img_suffix2))
                if ann_dir is not None:
                    seg_map = img + self.seg_map_suffix
                    data_info['seg_map_path'] = osp.join(ann_dir, seg_map)
                data_info['label_map'] = self.label_map
                data_info['reduce_zero_label'] = self.reduce_zero_label
                data_info['seg_fields'] = []
                data_list.append(data_info)
        data_list = sorted(data_list, key=lambda x: x['img_path'])
        return data_list
@@ -0,0 +1,31 @@
# Copyright (c) OpenMMLab. All rights reserved.

from mmseg.registry import DATASETS
from .basesegdataset import BaseCDDataset


@DATASETS.register_module()
class LEVIRCDDataset(BaseCDDataset):
    """LEVIR-CD dataset for building change detection.

    In the change maps, 0 stands for unchanged (background) and 1 stands for
    changed pixels, so ``reduce_zero_label`` defaults to False. The
    ``img_suffix``, ``img_suffix2`` and ``seg_map_suffix`` all default
    to '.png'.
    """

    METAINFO = dict(
        classes=('background', 'changed'),
        palette=[[0, 0, 0], [255, 255, 255]])

    def __init__(self,
                 img_suffix='.png',
                 img_suffix2='.png',
                 seg_map_suffix='.png',
                 reduce_zero_label=False,
                 **kwargs) -> None:
        super().__init__(
            img_suffix=img_suffix,
            img_suffix2=img_suffix2,
            seg_map_suffix=seg_map_suffix,
            reduce_zero_label=reduce_zero_label,
            **kwargs)
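
For reference, a minimal sketch of instantiating the new dataset class directly, assuming the `data/LEVIRCD` layout used by the `levir_256x256.py` config in this PR:

```python
# Hedged sketch: build the dataset outside a full config.
# The directory layout follows the levir_256x256.py config added in this PR.
from mmseg.datasets import LEVIRCDDataset

dataset = LEVIRCDDataset(
    data_root='data/LEVIRCD',
    data_prefix=dict(
        img_path='train/A', img_path2='train/B', seg_map_path='train/label'),
    pipeline=[])
print(len(dataset))              # number of image pairs
print(dataset.get_data_info(0))  # img_path, img_path2 and seg_map_path entries
```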
@@ -2,12 +2,13 @@
from .formatting import PackSegInputs
from .loading import (LoadAnnotations, LoadBiomedicalAnnotation,
                      LoadBiomedicalData, LoadBiomedicalImageFromFile,
-                      LoadImageFromNDArray)
+                      LoadImageFromNDArray, LoadMultipleRSImageFromFile,
+                      LoadSingleRSImageFromFile)
# yapf: disable
from .transforms import (CLAHE, AdjustGamma, Albu, BioMedical3DPad,
                         BioMedical3DRandomCrop, BioMedical3DRandomFlip,
                         BioMedicalGaussianBlur, BioMedicalGaussianNoise,
-                         BioMedicalRandomGamma, GenerateEdge,
+                         BioMedicalRandomGamma, ConcatCDInput, GenerateEdge,
                         PhotoMetricDistortion, RandomCrop, RandomCutOut,
                         RandomMosaic, RandomRotate, RandomRotFlip, Rerange,
                         ResizeShortestEdge, ResizeToMultiple, RGB2Gray,
@@ -22,5 +23,6 @@ __all__ = [
    'LoadBiomedicalAnnotation', 'LoadBiomedicalData', 'GenerateEdge',
    'ResizeShortestEdge', 'BioMedicalGaussianNoise', 'BioMedicalGaussianBlur',
    'BioMedical3DRandomFlip', 'BioMedicalRandomGamma', 'BioMedical3DPad',
-    'RandomRotFlip', 'Albu'
+    'RandomRotFlip', 'Albu', 'LoadSingleRSImageFromFile', 'ConcatCDInput',
+    'LoadMultipleRSImageFromFile'
]
@@ -12,6 +12,11 @@ from mmcv.transforms import LoadImageFromFile
from mmseg.registry import TRANSFORMS
from mmseg.utils import datafrombytes

try:
    from osgeo import gdal
except ImportError:
    gdal = None


@TRANSFORMS.register_module()
class LoadAnnotations(MMCV_LoadAnnotations):
@@ -493,3 +498,130 @@ class InferencerLoader(BaseTransform):
        if 'img' in inputs:
            return self.from_ndarray(inputs)
        return self.from_file(inputs)


@TRANSFORMS.register_module()
class LoadSingleRSImageFromFile(BaseTransform):
    """Load a Remote Sensing image from file.

    Required Keys:

    - img_path

    Modified Keys:

    - img
    - img_shape
    - ori_shape

    Args:
        to_float32 (bool): Whether to convert the loaded image to a float32
            numpy array. If set to False, the loaded image is a float64 array.
            Defaults to True.
    """

    def __init__(self, to_float32: bool = True):
        self.to_float32 = to_float32

        if gdal is None:
            raise RuntimeError('gdal is not installed')

    def transform(self, results: Dict) -> Dict:
        """Functions to load image.

        Args:
            results (dict): Result dict from :obj:``mmcv.BaseDataset``.

        Returns:
            dict: The dict contains loaded image and meta information.
        """

        filename = results['img_path']
        ds = gdal.Open(filename)
        if ds is None:
            raise Exception(f'Unable to open file: {filename}')
        img = np.einsum('ijk->jki', ds.ReadAsArray())

        if self.to_float32:
            img = img.astype(np.float32)

        results['img'] = img
        results['img_shape'] = img.shape[:2]
        results['ori_shape'] = img.shape[:2]
        return results

    def __repr__(self):
        repr_str = (f'{self.__class__.__name__}('
                    f'to_float32={self.to_float32})')
        return repr_str


@TRANSFORMS.register_module()
class LoadMultipleRSImageFromFile(BaseTransform):
    """Load two Remote Sensing images from file.

    Required Keys:

    - img_path
    - img_path2

    Modified Keys:

    - img
    - img2
    - img_shape
    - ori_shape

    Args:
        to_float32 (bool): Whether to convert the loaded images to float32
            numpy arrays. If set to False, the loaded images are float64
            arrays. Defaults to True.
    """

    def __init__(self, to_float32: bool = True):
        if gdal is None:
            raise RuntimeError('gdal is not installed')
        self.to_float32 = to_float32

    def transform(self, results: Dict) -> Dict:
        """Functions to load images.

        Args:
            results (dict): Result dict from :obj:``mmcv.BaseDataset``.

        Returns:
            dict: The dict contains loaded images and meta information.
        """

        filename = results['img_path']
        filename2 = results['img_path2']

        ds = gdal.Open(filename)
        ds2 = gdal.Open(filename2)

        if ds is None:
            raise Exception(f'Unable to open file: {filename}')
        if ds2 is None:
            raise Exception(f'Unable to open file: {filename2}')

        img = np.einsum('ijk->jki', ds.ReadAsArray())
        img2 = np.einsum('ijk->jki', ds2.ReadAsArray())

        if self.to_float32:
            img = img.astype(np.float32)
            img2 = img2.astype(np.float32)

        if img.shape != img2.shape:
            raise Exception(f'Image shapes do not match:'
                            f' {img.shape} vs {img2.shape}')

        results['img'] = img
        results['img2'] = img2
        results['img_shape'] = img.shape[:2]
        results['ori_shape'] = img.shape[:2]
        return results

    def __repr__(self):
        repr_str = (f'{self.__class__.__name__}('
                    f'to_float32={self.to_float32})')
        return repr_str
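
As an illustration, a hedged sketch of running the loader on its own; the two paths are placeholders for a co-registered, multi-band image pair readable by GDAL:

```python
# Hedged sketch: both paths are placeholders for a co-registered pair of
# multi-band rasters (e.g. the A/B tiles of one LEVIR-CD sample).
from mmseg.datasets.transforms import LoadMultipleRSImageFromFile

loader = LoadMultipleRSImageFromFile(to_float32=True)
results = loader(dict(img_path='path/to/A/tile.png',
                      img_path2='path/to/B/tile.png'))
print(results['img'].shape, results['img2'].shape)  # (H, W, C) each
```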
@@ -197,7 +197,7 @@ class CLAHE(BaseTransform):

    def __repr__(self):
        repr_str = self.__class__.__name__
-        repr_str += f'(clip_limit={self.clip_limit}, '\
+        repr_str += f'(clip_limit={self.clip_limit}, ' \
                    f'tile_grid_size={self.tile_grid_size})'
        return repr_str

@@ -1162,8 +1162,8 @@ class RandomMosaic(BaseTransform):
            x1_c, y1_c, x2_c, y2_c = crop_coord

            # crop and paste image
-            mosaic_seg[y1_p:y2_p, x1_p:x2_p] = gt_seg_i[y1_c:y2_c,
-                                                        x1_c:x2_c]
+            mosaic_seg[y1_p:y2_p, x1_p:x2_p] = \
+                gt_seg_i[y1_c:y2_c, x1_c:x2_c]

            results[key] = mosaic_seg

@@ -1771,9 +1771,9 @@ class BioMedicalGaussianBlur(BaseTransform):
        repr_str += f'(prob={self.prob}, '
        repr_str += f'prob_per_channel={self.prob_per_channel}, '
        repr_str += f'sigma_range={self.sigma_range}, '
-        repr_str += 'different_sigma_per_channel='\
+        repr_str += 'different_sigma_per_channel=' \
                    f'{self.different_sigma_per_channel}, '
-        repr_str += 'different_sigma_per_axis='\
+        repr_str += 'different_sigma_per_axis=' \
                    f'{self.different_sigma_per_axis})'
        return repr_str

@@ -2291,3 +2291,33 @@ class Albu(BaseTransform):
    def __repr__(self):
        repr_str = self.__class__.__name__ + f'(transforms={self.transforms})'
        return repr_str


@TRANSFORMS.register_module()
class ConcatCDInput(BaseTransform):
    """Concat images for change detection.

    Required Keys:

    - img
    - img2

    Args:
        input_keys (tuple): Input image keys for change detection.
            Default: ('img', 'img2').
    """

    def __init__(self, input_keys=('img', 'img2')):
        self.input_keys = input_keys

    def transform(self, results: dict) -> dict:
        img = []
        for input_key in self.input_keys:
            img.append(results.pop(input_key))
        results['img'] = np.concatenate(img, axis=2)
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(input_keys={self.input_keys})'
        return repr_str
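
A self-contained sketch of what the concatenation produces: two 3-channel arrays become one 6-channel input, which is why the Swin config above sets `in_channels=6`:

```python
# Self-contained sketch: ConcatCDInput stacks the two temporal images along
# the channel axis, turning two (H, W, 3) arrays into one (H, W, 6) input.
import numpy as np
from mmseg.datasets.transforms import ConcatCDInput

results = dict(
    img=np.random.rand(256, 256, 3).astype(np.float32),
    img2=np.random.rand(256, 256, 3).astype(np.float32))
results = ConcatCDInput()(results)
print(results['img'].shape)  # (256, 256, 6); 'img2' has been consumed
```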
@@ -92,7 +92,11 @@ def test_config_data_pipeline():
        del config_mod.train_pipeline[0]
        del config_mod.test_pipeline[0]
        # remove loading annotation in test pipeline
-        del config_mod.test_pipeline[-2]
+        load_anno_idx = -1
+        for i in range(len(config_mod.test_pipeline)):
+            if config_mod.test_pipeline[i].type == 'LoadAnnotations':
+                load_anno_idx = i
+        del config_mod.test_pipeline[load_anno_idx]

        train_pipeline = Compose(config_mod.train_pipeline)
        test_pipeline = Compose(config_mod.test_pipeline)
@@ -110,22 +114,27 @@ def test_config_data_pipeline():
            ori_shape=img.shape,
            gt_seg_map=seg)
        results['seg_fields'] = ['gt_seg_map']

        _check_concat_cd_input(config_mod, results)
        print(f'Test training data pipeline: \n{train_pipeline!r}')
        output_results = train_pipeline(results)
        assert output_results is not None

        results = dict(
            filename='test_img.png',
            ori_filename='test_img.png',
            img=img,
            img_shape=img.shape,
            ori_shape=img.shape)
        _check_concat_cd_input(config_mod, results)
        print(f'Test testing data pipeline: \n{test_pipeline!r}')
        output_results = test_pipeline(results)
        assert output_results is not None


def _check_concat_cd_input(config_mod: Config, results: dict):
    keys = []
    pipeline = config_mod.train_pipeline.copy()
    pipeline.extend(config_mod.test_pipeline)
    for t in pipeline:
        keys.append(t.type)
    if 'ConcatCDInput' in keys:
        results.update({'img2': results['img']})


def _check_decode_head(decode_head_cfg, decode_head):
    if isinstance(decode_head_cfg, list):
        assert isinstance(decode_head, nn.ModuleList)
@@ -0,0 +1,99 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import glob
import math
import os
import os.path as osp

import mmcv
import numpy as np
from mmengine.utils import ProgressBar


def parse_args():
    parser = argparse.ArgumentParser(
        description='Convert LEVIR-CD dataset to mmsegmentation format')
    parser.add_argument('--dataset-path', help='LEVIR-CD folder path')
    parser.add_argument('-o', '--out_dir', help='output path')
    parser.add_argument(
        '--clip_size',
        type=int,
        help='clipped size of image after preparation',
        default=256)
    parser.add_argument(
        '--stride_size',
        type=int,
        help='stride of clipping original images',
        default=256)
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    input_folder = args.dataset_path
    png_files = glob.glob(
        os.path.join(input_folder, '**/*.png'), recursive=True)
    output_folder = args.out_dir
    prog_bar = ProgressBar(len(png_files))
    for png_file in png_files:
        new_path = os.path.join(
            output_folder,
            os.path.relpath(os.path.dirname(png_file), input_folder))
        os.makedirs(os.path.dirname(new_path), exist_ok=True)
        label = False
        if 'label' in png_file:
            label = True
        clip_big_image(png_file, new_path, args, label)
        prog_bar.update()


def clip_big_image(image_path, clip_save_dir, args, to_label=False):
    image = mmcv.imread(image_path)

    h, w, c = image.shape
    clip_size = args.clip_size
    stride_size = args.stride_size

    # number of patch rows/cols: one extra row/col is added when the strided
    # grid would not reach the bottom/right border of the image
    num_rows = math.ceil((h - clip_size) / stride_size) if math.ceil(
        (h - clip_size) /
        stride_size) * stride_size + clip_size >= h else math.ceil(
            (h - clip_size) / stride_size) + 1
    num_cols = math.ceil((w - clip_size) / stride_size) if math.ceil(
        (w - clip_size) /
        stride_size) * stride_size + clip_size >= w else math.ceil(
            (w - clip_size) / stride_size) + 1

    x, y = np.meshgrid(np.arange(num_cols + 1), np.arange(num_rows + 1))
    xmin = x * clip_size
    ymin = y * clip_size

    xmin = xmin.ravel()
    ymin = ymin.ravel()
    # shift windows that would cross the border back inside the image
    xmin_offset = np.where(xmin + clip_size > w, w - xmin - clip_size,
                           np.zeros_like(xmin))
    ymin_offset = np.where(ymin + clip_size > h, h - ymin - clip_size,
                           np.zeros_like(ymin))
    boxes = np.stack([
        xmin + xmin_offset, ymin + ymin_offset,
        np.minimum(xmin + clip_size, w),
        np.minimum(ymin + clip_size, h)
    ],
                     axis=1)

    if to_label:
        image[image == 255] = 1
        image = image[:, :, 0]
    for box in boxes:
        start_x, start_y, end_x, end_y = box
        clipped_image = image[start_y:end_y, start_x:end_x] \
            if to_label else image[start_y:end_y, start_x:end_x, :]
        idx = osp.basename(image_path).split('.')[0]
        mmcv.imwrite(
            clipped_image.astype(np.uint8),
            osp.join(clip_save_dir,
                     f'{idx}_{start_x}_{start_y}_{end_x}_{end_y}.png'))


if __name__ == '__main__':
    main()