[Doc]: refactor docs for basedataset (#318)

pull/321/head
Tao Gong 2022-06-21 14:58:10 +08:00 committed by GitHub
parent 44538e56c5
commit 45f5859b50
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 17 additions and 17 deletions

View File

@ -112,7 +112,7 @@ data
```python
import os.path as osp
from mmengine.data import BaseDataset
from mmengine.dataset import BaseDataset
class ToyDataset(BaseDataset):
@ -125,10 +125,10 @@ class ToyDataset(BaseDataset):
# }
def parse_data_info(self, raw_data_info):
data_info = raw_data_info
img_prefix = self.data_prefix.get('img', None)
img_prefix = self.data_prefix.get('img_path', None)
if img_prefix is not None:
data_info['img_path'] = osp.join(
img_prefix, data_info['img_path')
img_prefix, data_info['img_path'])
return data_info
```
@ -146,7 +146,7 @@ pipeline = [
toy_dataset = ToyDataset(
data_root='data/',
data_prefix=dict(img='train/'),
data_prefix=dict(img_path='train/'),
ann_file='annotations/train.json',
pipeline=pipeline)
```
@ -188,7 +188,7 @@ len(toy_dataset)
在上面的例子中,标注文件的每个原始数据只包含一个训练/测试样本(通常是图像领域)。如果每个原始数据包含若干个训练/测试样本(通常是视频领域),则只需保证 `parse_data_info()` 的返回值为 `list[dict]` 即可:
```python
from mmengine.data import BaseDataset
from mmengine.dataset import BaseDataset
class ToyVideoDataset(BaseDataset):
@ -238,7 +238,7 @@ pipeline = [
toy_dataset = ToyDataset(
data_root='data/',
data_prefix=dict(img='train/'),
data_prefix=dict(img_path='train/'),
ann_file='annotations/train.json',
pipeline=pipeline,
# 在这里传入 lazy_init 变量
@ -279,7 +279,7 @@ pipeline = [
toy_dataset = ToyDataset(
data_root='data/',
data_prefix=dict(img='train/'),
data_prefix=dict(img_path='train/'),
ann_file='annotations/train.json',
pipeline=pipeline,
# 在这里传入 serialize_data 变量
@ -297,7 +297,7 @@ toy_dataset = ToyDataset(
MMEngine 提供了 `ConcatDataset` 包装来拼接多个数据集,使用方法如下:
```python
from mmengine.data import ConcatDataset
from mmengine.dataset import ConcatDataset
pipeline = [
dict(type='xxx', ...),
@ -307,13 +307,13 @@ pipeline = [
toy_dataset_1 = ToyDataset(
data_root='data/',
data_prefix=dict(img='train/'),
data_prefix=dict(img_path='train/'),
ann_file='annotations/train.json',
pipeline=pipeline)
toy_dataset_2 = ToyDataset(
data_root='data/',
data_prefix=dict(img='val/'),
data_prefix=dict(img_path='val/'),
ann_file='annotations/val.json',
pipeline=pipeline)
@ -328,7 +328,7 @@ toy_dataset_12 = ConcatDataset(datasets=[toy_dataset_1, toy_dataset_2])
MMEngine 提供了 `RepeatDataset` 包装来重复采样某个数据集若干次,使用方法如下:
```python
from mmengine.data import RepeatDataset
from mmengine.dataset import RepeatDataset
pipeline = [
dict(type='xxx', ...),
@ -338,7 +338,7 @@ pipeline = [
toy_dataset = ToyDataset(
data_root='data/',
data_prefix=dict(img='train/'),
data_prefix=dict(img_path='train/'),
ann_file='annotations/train.json',
pipeline=pipeline)
@ -357,16 +357,16 @@ MMEngine 提供了 `ClassBalancedDataset` 包装,来基于数据集中类别
`ClassBalancedDataset` 包装假设了被包装的数据集类支持 `get_cat_ids(idx)` 方法,`get_cat_ids(idx)` 方法返回一个列表,该列表包含了 `idx` 指定的 `data_info` 包含的样本类别,使用方法如下:
```python
from mmengine.data import BaseDataset, ClassBalancedDataset
from mmengine.dataset import BaseDataset, ClassBalancedDataset
class ToyDataset(BaseDataset):
def parse_data_info(self, raw_data_info):
data_info = raw_data_info
img_prefix = self.data_prefix.get('img', None)
img_prefix = self.data_prefix.get('img_path', None)
if img_prefix is not None:
data_info['img_path'] = osp.join(
img_prefix, data_info['img_path')
img_prefix, data_info['img_path'])
return data_info
# 必须支持的方法,需要返回样本的类别
@ -382,7 +382,7 @@ pipeline = [
toy_dataset = ToyDataset(
data_root='data/',
data_prefix=dict(img='train/'),
data_prefix=dict(img_path='train/'),
ann_file='annotations/train.json',
pipeline=pipeline)
@ -404,7 +404,7 @@ from mmengine.registry import DATASETS
@DATASETS.register_module()
class ExampleDatasetWrapper:
def __init__(self, dataset, lazy_init = False, ...):
def __init__(self, dataset, lazy_init=False, ...):
# 构建原数据集self.dataset
if isinstance(dataset, dict):
self.dataset = DATASETS.build(dataset)