mmpretrain/docs_zh-CN/tutorials/new_dataset.md

141 lines
4.2 KiB
Markdown
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# 教程 2如何添加新数据集
## 通过重新组织数据来自定义数据集
### 将数据集重新组织为已有格式
最简单的方法是将数据集转换为现有的数据集格式 (ImageNet)。
为了训练,根据图片的类别,存放至不同子目录下。训练数据文件夹结构如下所示:
```
imagenet
├── ...
├── train
│ ├── n01440764
│ │ ├── n01440764_10026.JPEG
│ │ ├── n01440764_10027.JPEG
│ │ ├── ...
│ ├── ...
│ ├── n15075141
│ │ ├── n15075141_999.JPEG
│ │ ├── n15075141_9993.JPEG
│ │ ├── ...
```
为了验证,我们提供了一个注释列表。列表的每一行都包含一个文件名及其相应的真实标签。格式如下:
```
ILSVRC2012_val_00000001.JPEG 65
ILSVRC2012_val_00000002.JPEG 970
ILSVRC2012_val_00000003.JPEG 230
ILSVRC2012_val_00000004.JPEG 809
ILSVRC2012_val_00000005.JPEG 516
```
注:真实标签的值应该位于 `[0, 类别数目 - 1]` 之间
### 自定义数据集的示例
用户可以编写一个继承自 `BasesDataset` 的新数据集类,并重载 `load_annotations(self)` 方法,类似 [CIFAR10](https://github.com/open-mmlab/mmclassification/blob/master/mmcls/datasets/cifar.py) 和 [ImageNet](https://github.com/open-mmlab/mmclassification/blob/master/mmcls/datasets/imagenet.py)。
通常,此方法返回一个包含所有样本的列表,其中的每个样本都是一个字典。字典中包含了必要的数据信息,例如 `img``gt_label`
假设我们将要实现一个 `Filelist` 数据集,该数据集将使用文件列表进行训练和测试。注释列表的格式如下:
```
000001.jpg 0
000002.jpg 1
```
我们可以在 `mmcls/datasets/filelist.py` 中创建一个新的数据集类以加载数据。
```python
import mmcv
import numpy as np
from .builder import DATASETS
from .base_dataset import BaseDataset
@DATASETS.register_module()
class Filelist(BaseDataset):
def load_annotations(self):
assert isinstance(self.ann_file, str)
data_infos = []
with open(self.ann_file) as f:
samples = [x.strip().split(' ') for x in f.readlines()]
for filename, gt_label in samples:
info = {'img_prefix': self.data_prefix}
info['img_info'] = {'filename': filename}
info['gt_label'] = np.array(gt_label, dtype=np.int64)
data_infos.append(info)
return data_infos
```
将新的数据集类加入到 `mmcls/datasets/__init__.py` 中:
```python
from .base_dataset import BaseDataset
...
from .filelist import Filelist
__all__ = [
'BaseDataset', ... ,'Filelist'
]
```
然后在配置文件中,为了使用 `Filelist`,用户可以按以下方式修改配置
```python
train = dict(
type='Filelist',
ann_file = 'image_list.txt',
pipeline=train_pipeline
)
```
## 通过混合数据集来自定义数据集
MMClassification 还支持混合数据集以进行训练。目前支持合并和重复数据集。
### 重复数据集
我们使用 `RepeatDataset` 作为一个重复数据集的封装。举个例子,假设原始数据集是 `Dataset_A`,为了重复它,我们需要如下的配置文件:
```python
dataset_A_train = dict(
type='RepeatDataset',
times=N,
dataset=dict( # 这里是 Dataset_A 的原始配置
type='Dataset_A',
...
pipeline=train_pipeline
)
)
```
### 类别平衡数据集
我们使用 `ClassBalancedDataset` 作为根据类别频率对数据集进行重复采样的封装类。进行重复采样的数据集需要实现函数 `self.get_cat_ids(idx)` 以支持 `ClassBalancedDataset`
举个例子,按照 `oversample_thr=1e-3``Dataset_A` 进行重复采样,需要如下的配置文件:
```python
dataset_A_train = dict(
type='ClassBalancedDataset',
oversample_thr=1e-3,
dataset=dict( # 这里是 Dataset_A 的原始配置
type='Dataset_A',
...
pipeline=train_pipeline
)
)
```
更加具体的细节,请参考 [源代码](https://github.com/open-mmlab/mmclassification/tree/master/mmcls/datasets/dataset_wrappers.py)。