mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Docs] Add docs for custom dataset wrapper (#222)
* add docs for custom dataset wrapper * Update basedataset.md
This commit is contained in:
parent
22d3b04575
commit
92b94e8e60
@ -219,7 +219,7 @@ class ToyVideoDataset(BaseDataset):
|
|||||||
|
|
||||||
1. 将不满足规范的标注文件转换成满足规范的标注文件,再通过上述方式使用数据集基类。
|
1. 将不满足规范的标注文件转换成满足规范的标注文件,再通过上述方式使用数据集基类。
|
||||||
|
|
||||||
2. 实现一个新的数据集类,继承自数据集基类,并且重载数据集基类的 `load_data_list(self, ann_file):` 函数,处理不满足规范的标注文件,并保证返回值为 `list[dict]`,其中每个 `dict` 代表一个数据样本。
|
2. 实现一个新的数据集类,继承自数据集基类,并且重载数据集基类的 `load_data_list(self):` 函数,处理不满足规范的标注文件,并保证返回值为 `list[dict]`,其中每个 `dict` 代表一个数据样本。
|
||||||
|
|
||||||
## 数据集基类的其它特性
|
## 数据集基类的其它特性
|
||||||
|
|
||||||
@ -391,3 +391,92 @@ toy_dataset_repeat = ClassBalancedDataset(dataset=toy_dataset, oversample_thr=1e
|
|||||||
```
|
```
|
||||||
|
|
||||||
上述例子将数据集的 `train` 部分以 `oversample_thr=1e-3` 重新采样,具体地,对于数据集中出现频率低于 `1e-3` 的类别,会重复采样该类别对应的样本,否则不重复采样,具体采样策略请参考 `ClassBalancedDataset` API 文档。
|
上述例子将数据集的 `train` 部分以 `oversample_thr=1e-3` 重新采样,具体地,对于数据集中出现频率低于 `1e-3` 的类别,会重复采样该类别对应的样本,否则不重复采样,具体采样策略请参考 `ClassBalancedDataset` API 文档。
|
||||||
|
|
||||||
|
### 自定义数据集类包装
|
||||||
|
|
||||||
|
由于数据集基类实现了懒加载的功能,因此在自定义数据集类包装时,需要遵循一些规则,下面以一个例子的方式来展示如何自定义数据集类包装:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from mmengine.dataset import BaseDataset
|
||||||
|
from mmengine.registry import DATASETS
|
||||||
|
|
||||||
|
|
||||||
|
@DATASETS.register_module()
|
||||||
|
class ExampleDatasetWrapper:
|
||||||
|
|
||||||
|
def __init__(self, dataset, lazy_init = False, ...):
|
||||||
|
# 构建原数据集(self.dataset)
|
||||||
|
if isinstance(dataset, dict):
|
||||||
|
self.dataset = DATASETS.build(dataset)
|
||||||
|
elif isinstance(dataset, BaseDataset):
|
||||||
|
self.dataset = dataset
|
||||||
|
else:
|
||||||
|
raise TypeError(
|
||||||
|
'elements in datasets sequence should be config or '
|
||||||
|
f'`BaseDataset` instance, but got {type(dataset)}')
|
||||||
|
# 记录原数据集的元信息
|
||||||
|
self._metainfo = self.dataset.metainfo
|
||||||
|
|
||||||
|
'''
|
||||||
|
1. 在这里实现一些代码,来记录用于包装数据集的一些超参。
|
||||||
|
'''
|
||||||
|
|
||||||
|
self._fully_initialized = False
|
||||||
|
if not lazy_init:
|
||||||
|
self.full_init()
|
||||||
|
|
||||||
|
def full_init(self):
|
||||||
|
if self._fully_initialized:
|
||||||
|
return
|
||||||
|
|
||||||
|
# 将原数据集完全初始化
|
||||||
|
self.dataset.full_init()
|
||||||
|
|
||||||
|
'''
|
||||||
|
2. 在这里实现一些代码,来包装原数据集。
|
||||||
|
'''
|
||||||
|
|
||||||
|
self._fully_initialized = True
|
||||||
|
|
||||||
|
@force_full_init
|
||||||
|
def _get_ori_dataset_idx(self, idx: int):
|
||||||
|
|
||||||
|
'''
|
||||||
|
3. 在这里实现一些代码,来将包装的索引 `idx` 映射到原数据集的索引 `ori_idx`。
|
||||||
|
'''
|
||||||
|
ori_idx = ...
|
||||||
|
|
||||||
|
return ori_idx
|
||||||
|
|
||||||
|
# 提供与 `self.dataset` 一样的对外接口。
|
||||||
|
@force_full_init
|
||||||
|
def get_data_info(self, idx):
|
||||||
|
sample_idx = self._get_ori_dataset_idx(idx)
|
||||||
|
return self.dataset.get_data_info(sample_idx)
|
||||||
|
|
||||||
|
# 提供与 `self.dataset` 一样的对外接口。
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
if not self._fully_initialized:
|
||||||
|
warnings.warn('Please call `full_init` method manually to '
|
||||||
|
'accelerate the speed.')
|
||||||
|
self.full_init()
|
||||||
|
|
||||||
|
sample_idx = self._get_ori_dataset_idx(idx)
|
||||||
|
return self.dataset[sample_idx]
|
||||||
|
|
||||||
|
# 提供与 `self.dataset` 一样的对外接口。
|
||||||
|
@force_full_init
|
||||||
|
def __len__(self):
|
||||||
|
|
||||||
|
'''
|
||||||
|
4. 在这里实现一些代码,来计算包装数据集之后的长度。
|
||||||
|
'''
|
||||||
|
len_wrapper = ...
|
||||||
|
|
||||||
|
return len_wrapper
|
||||||
|
|
||||||
|
# 提供与 `self.dataset` 一样的对外接口。
|
||||||
|
@property
|
||||||
|
def metainfo(self)
|
||||||
|
return copy.deepcopy(self._metainfo)
|
||||||
|
```
|
||||||
|
Loading…
x
Reference in New Issue
Block a user