Ma Zerun 85b1eae7f1
Bump to v1.0.0rc0 (#1007)
* Update docs.

* Update requirements.

* Update config readme and docstring.

* Update CONTRIBUTING.md

* Update README

* Update requirements/mminstall.txt

Co-authored-by: Yifei Yang <2744335995@qq.com>

* Update MMEngine docs link and add to readthedocs requirement.

Co-authored-by: Yifei Yang <2744335995@qq.com>
2022-08-31 23:57:51 +08:00

74 lines
2.5 KiB
Markdown
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# 添加新数据集
用户可以编写一个继承自 [`BasesDataset`](https://mmclassification.readthedocs.io/zh_CN/latest/_modules/mmcls/datasets/base_dataset.html#BaseDataset) 的新数据集类,并重载 `load_data_list(self)` 方法,类似 [CIFAR10](https://github.com/open-mmlab/mmclassification/blob/master/mmcls/datasets/cifar.py) 和 [ImageNet](https://github.com/open-mmlab/mmclassification/blob/master/mmcls/datasets/imagenet.py)。
通常,此方法返回一个包含所有样本的列表,其中的每个样本都是一个字典。字典中包含了必要的数据信息,例如 `img``gt_label`
假设我们将要实现一个 `Filelist` 数据集,该数据集将使用文件列表进行训练和测试。注释列表的格式如下:
```text
000001.jpg 0
000002.jpg 1
...
```
## 1. 创建数据集类
我们可以在 `mmcls/datasets/filelist.py` 中创建一个新的数据集类以加载数据。
```python
from mmcls.registry import DATASETS
from .base_dataset import BaseDataset
@DATASETS.register_module()
class Filelist(BaseDataset):
def load_data_list(self):
assert isinstance(self.ann_file, str)
data_list = []
with open(self.ann_file) as f:
samples = [x.strip().split(' ') for x in f.readlines()]
for filename, gt_label in samples:
img_path = add_prefix(filename, self.img_prefix)
info = {'img_path': img_path, 'gt_label': int(gt_label)}
data_list.append(info)
return data_list
```
## 2. 添加进 MMCls 库
将新的数据集类加入到 `mmcls/datasets/__init__.py` 中:
```python
from .base_dataset import BaseDataset
...
from .filelist import Filelist
__all__ = [
'BaseDataset', ... ,'Filelist'
]
```
### 3. 修改相关配置文件
然后在配置文件中,为了使用 `Filelist`,用户可以按以下方式修改配置
```python
train_dataloader = dict(
...
dataset=dict(
type='Filelist',
ann_file='image_list.txt',
pipeline=train_pipeline,
)
)
```
所有继承 [`BaseDataset`](https://github.com/open-mmlab/mmclassification/blob/master/mmcls/datasets/base_dataset.py) 的数据集类都具有**懒加载**以及**节省内存**的特性,可以参考相关文档 [mmengine.basedataset](https://github.com/open-mmlab/mmengine/blob/main/docs/zh_cn/tutorials/basedataset.md)。
```{note}
如果数据样本时获取的字典中,只包含了 'img_path' 不包含 'img' 则在 pipeline 中必须包含 'LoadImgFromFile'。
```