[Docs] Add docs for custom dataset wrapper (#222)
* add docs for custom dataset wrapper * Update basedataset.mdpull/231/head
parent
22d3b04575
commit
92b94e8e60
|
@ -219,7 +219,7 @@ class ToyVideoDataset(BaseDataset):
|
|||
|
||||
1. 将不满足规范的标注文件转换成满足规范的标注文件,再通过上述方式使用数据集基类。
|
||||
|
||||
2. 实现一个新的数据集类,继承自数据集基类,并且重载数据集基类的 `load_data_list(self, ann_file):` 函数,处理不满足规范的标注文件,并保证返回值为 `list[dict]`,其中每个 `dict` 代表一个数据样本。
|
||||
2. 实现一个新的数据集类,继承自数据集基类,并且重载数据集基类的 `load_data_list(self):` 函数,处理不满足规范的标注文件,并保证返回值为 `list[dict]`,其中每个 `dict` 代表一个数据样本。
|
||||
|
||||
## 数据集基类的其它特性
|
||||
|
||||
|
@ -391,3 +391,92 @@ toy_dataset_repeat = ClassBalancedDataset(dataset=toy_dataset, oversample_thr=1e
|
|||
```
|
||||
|
||||
上述例子将数据集的 `train` 部分以 `oversample_thr=1e-3` 重新采样,具体地,对于数据集中出现频率低于 `1e-3` 的类别,会重复采样该类别对应的样本,否则不重复采样,具体采样策略请参考 `ClassBalancedDataset` API 文档。
|
||||
|
||||
### 自定义数据集类包装
|
||||
|
||||
由于数据集基类实现了懒加载的功能,因此在自定义数据集类包装时,需要遵循一些规则,下面以一个例子的方式来展示如何自定义数据集类包装:
|
||||
|
||||
```python
|
||||
from mmengine.dataset import BaseDataset
|
||||
from mmengine.registry import DATASETS
|
||||
|
||||
|
||||
@DATASETS.register_module()
|
||||
class ExampleDatasetWrapper:
|
||||
|
||||
def __init__(self, dataset, lazy_init = False, ...):
|
||||
# 构建原数据集(self.dataset)
|
||||
if isinstance(dataset, dict):
|
||||
self.dataset = DATASETS.build(dataset)
|
||||
elif isinstance(dataset, BaseDataset):
|
||||
self.dataset = dataset
|
||||
else:
|
||||
raise TypeError(
|
||||
'elements in datasets sequence should be config or '
|
||||
f'`BaseDataset` instance, but got {type(dataset)}')
|
||||
# 记录原数据集的元信息
|
||||
self._metainfo = self.dataset.metainfo
|
||||
|
||||
'''
|
||||
1. 在这里实现一些代码,来记录用于包装数据集的一些超参。
|
||||
'''
|
||||
|
||||
self._fully_initialized = False
|
||||
if not lazy_init:
|
||||
self.full_init()
|
||||
|
||||
def full_init(self):
|
||||
if self._fully_initialized:
|
||||
return
|
||||
|
||||
# 将原数据集完全初始化
|
||||
self.dataset.full_init()
|
||||
|
||||
'''
|
||||
2. 在这里实现一些代码,来包装原数据集。
|
||||
'''
|
||||
|
||||
self._fully_initialized = True
|
||||
|
||||
@force_full_init
|
||||
def _get_ori_dataset_idx(self, idx: int):
|
||||
|
||||
'''
|
||||
3. 在这里实现一些代码,来将包装的索引 `idx` 映射到原数据集的索引 `ori_idx`。
|
||||
'''
|
||||
ori_idx = ...
|
||||
|
||||
return ori_idx
|
||||
|
||||
# 提供与 `self.dataset` 一样的对外接口。
|
||||
@force_full_init
|
||||
def get_data_info(self, idx):
|
||||
sample_idx = self._get_ori_dataset_idx(idx)
|
||||
return self.dataset.get_data_info(sample_idx)
|
||||
|
||||
# 提供与 `self.dataset` 一样的对外接口。
|
||||
def __getitem__(self, idx):
|
||||
if not self._fully_initialized:
|
||||
warnings.warn('Please call `full_init` method manually to '
|
||||
'accelerate the speed.')
|
||||
self.full_init()
|
||||
|
||||
sample_idx = self._get_ori_dataset_idx(idx)
|
||||
return self.dataset[sample_idx]
|
||||
|
||||
# 提供与 `self.dataset` 一样的对外接口。
|
||||
@force_full_init
|
||||
def __len__(self):
|
||||
|
||||
'''
|
||||
4. 在这里实现一些代码,来计算包装数据集之后的长度。
|
||||
'''
|
||||
len_wrapper = ...
|
||||
|
||||
return len_wrapper
|
||||
|
||||
# 提供与 `self.dataset` 一样的对外接口。
|
||||
@property
|
||||
def metainfo(self)
|
||||
return copy.deepcopy(self._metainfo)
|
||||
```
|
||||
|
|
Loading…
Reference in New Issue