mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Doc]: refactor docs for basedataset (#318)
This commit is contained in:
parent
44538e56c5
commit
45f5859b50
@ -112,7 +112,7 @@ data
|
|||||||
```python
|
```python
|
||||||
import os.path as osp
|
import os.path as osp
|
||||||
|
|
||||||
from mmengine.data import BaseDataset
|
from mmengine.dataset import BaseDataset
|
||||||
|
|
||||||
|
|
||||||
class ToyDataset(BaseDataset):
|
class ToyDataset(BaseDataset):
|
||||||
@ -125,10 +125,10 @@ class ToyDataset(BaseDataset):
|
|||||||
# }
|
# }
|
||||||
def parse_data_info(self, raw_data_info):
|
def parse_data_info(self, raw_data_info):
|
||||||
data_info = raw_data_info
|
data_info = raw_data_info
|
||||||
img_prefix = self.data_prefix.get('img', None)
|
img_prefix = self.data_prefix.get('img_path', None)
|
||||||
if img_prefix is not None:
|
if img_prefix is not None:
|
||||||
data_info['img_path'] = osp.join(
|
data_info['img_path'] = osp.join(
|
||||||
img_prefix, data_info['img_path')
|
img_prefix, data_info['img_path'])
|
||||||
return data_info
|
return data_info
|
||||||
|
|
||||||
```
|
```
|
||||||
@ -146,7 +146,7 @@ pipeline = [
|
|||||||
|
|
||||||
toy_dataset = ToyDataset(
|
toy_dataset = ToyDataset(
|
||||||
data_root='data/',
|
data_root='data/',
|
||||||
data_prefix=dict(img='train/'),
|
data_prefix=dict(img_path='train/'),
|
||||||
ann_file='annotations/train.json',
|
ann_file='annotations/train.json',
|
||||||
pipeline=pipeline)
|
pipeline=pipeline)
|
||||||
```
|
```
|
||||||
@ -188,7 +188,7 @@ len(toy_dataset)
|
|||||||
在上面的例子中,标注文件的每个原始数据只包含一个训练/测试样本(通常是图像领域)。如果每个原始数据包含若干个训练/测试样本(通常是视频领域),则只需保证 `parse_data_info()` 的返回值为 `list[dict]` 即可:
|
在上面的例子中,标注文件的每个原始数据只包含一个训练/测试样本(通常是图像领域)。如果每个原始数据包含若干个训练/测试样本(通常是视频领域),则只需保证 `parse_data_info()` 的返回值为 `list[dict]` 即可:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from mmengine.data import BaseDataset
|
from mmengine.dataset import BaseDataset
|
||||||
|
|
||||||
|
|
||||||
class ToyVideoDataset(BaseDataset):
|
class ToyVideoDataset(BaseDataset):
|
||||||
@ -238,7 +238,7 @@ pipeline = [
|
|||||||
|
|
||||||
toy_dataset = ToyDataset(
|
toy_dataset = ToyDataset(
|
||||||
data_root='data/',
|
data_root='data/',
|
||||||
data_prefix=dict(img='train/'),
|
data_prefix=dict(img_path='train/'),
|
||||||
ann_file='annotations/train.json',
|
ann_file='annotations/train.json',
|
||||||
pipeline=pipeline,
|
pipeline=pipeline,
|
||||||
# 在这里传入 lazy_init 变量
|
# 在这里传入 lazy_init 变量
|
||||||
@ -279,7 +279,7 @@ pipeline = [
|
|||||||
|
|
||||||
toy_dataset = ToyDataset(
|
toy_dataset = ToyDataset(
|
||||||
data_root='data/',
|
data_root='data/',
|
||||||
data_prefix=dict(img='train/'),
|
data_prefix=dict(img_path='train/'),
|
||||||
ann_file='annotations/train.json',
|
ann_file='annotations/train.json',
|
||||||
pipeline=pipeline,
|
pipeline=pipeline,
|
||||||
# 在这里传入 serialize_data 变量
|
# 在这里传入 serialize_data 变量
|
||||||
@ -297,7 +297,7 @@ toy_dataset = ToyDataset(
|
|||||||
MMEngine 提供了 `ConcatDataset` 包装来拼接多个数据集,使用方法如下:
|
MMEngine 提供了 `ConcatDataset` 包装来拼接多个数据集,使用方法如下:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from mmengine.data import ConcatDataset
|
from mmengine.dataset import ConcatDataset
|
||||||
|
|
||||||
pipeline = [
|
pipeline = [
|
||||||
dict(type='xxx', ...),
|
dict(type='xxx', ...),
|
||||||
@ -307,13 +307,13 @@ pipeline = [
|
|||||||
|
|
||||||
toy_dataset_1 = ToyDataset(
|
toy_dataset_1 = ToyDataset(
|
||||||
data_root='data/',
|
data_root='data/',
|
||||||
data_prefix=dict(img='train/'),
|
data_prefix=dict(img_path='train/'),
|
||||||
ann_file='annotations/train.json',
|
ann_file='annotations/train.json',
|
||||||
pipeline=pipeline)
|
pipeline=pipeline)
|
||||||
|
|
||||||
toy_dataset_2 = ToyDataset(
|
toy_dataset_2 = ToyDataset(
|
||||||
data_root='data/',
|
data_root='data/',
|
||||||
data_prefix=dict(img='val/'),
|
data_prefix=dict(img_path='val/'),
|
||||||
ann_file='annotations/val.json',
|
ann_file='annotations/val.json',
|
||||||
pipeline=pipeline)
|
pipeline=pipeline)
|
||||||
|
|
||||||
@ -328,7 +328,7 @@ toy_dataset_12 = ConcatDataset(datasets=[toy_dataset_1, toy_dataset_2])
|
|||||||
MMEngine 提供了 `RepeatDataset` 包装来重复采样某个数据集若干次,使用方法如下:
|
MMEngine 提供了 `RepeatDataset` 包装来重复采样某个数据集若干次,使用方法如下:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from mmengine.data import RepeatDataset
|
from mmengine.dataset import RepeatDataset
|
||||||
|
|
||||||
pipeline = [
|
pipeline = [
|
||||||
dict(type='xxx', ...),
|
dict(type='xxx', ...),
|
||||||
@ -338,7 +338,7 @@ pipeline = [
|
|||||||
|
|
||||||
toy_dataset = ToyDataset(
|
toy_dataset = ToyDataset(
|
||||||
data_root='data/',
|
data_root='data/',
|
||||||
data_prefix=dict(img='train/'),
|
data_prefix=dict(img_path='train/'),
|
||||||
ann_file='annotations/train.json',
|
ann_file='annotations/train.json',
|
||||||
pipeline=pipeline)
|
pipeline=pipeline)
|
||||||
|
|
||||||
@ -357,16 +357,16 @@ MMEngine 提供了 `ClassBalancedDataset` 包装,来基于数据集中类别
|
|||||||
`ClassBalancedDataset` 包装假设了被包装的数据集类支持 `get_cat_ids(idx)` 方法,`get_cat_ids(idx)` 方法返回一个列表,该列表包含了 `idx` 指定的 `data_info` 包含的样本类别,使用方法如下:
|
`ClassBalancedDataset` 包装假设了被包装的数据集类支持 `get_cat_ids(idx)` 方法,`get_cat_ids(idx)` 方法返回一个列表,该列表包含了 `idx` 指定的 `data_info` 包含的样本类别,使用方法如下:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from mmengine.data import BaseDataset, ClassBalancedDataset
|
from mmengine.dataset import BaseDataset, ClassBalancedDataset
|
||||||
|
|
||||||
class ToyDataset(BaseDataset):
|
class ToyDataset(BaseDataset):
|
||||||
|
|
||||||
def parse_data_info(self, raw_data_info):
|
def parse_data_info(self, raw_data_info):
|
||||||
data_info = raw_data_info
|
data_info = raw_data_info
|
||||||
img_prefix = self.data_prefix.get('img', None)
|
img_prefix = self.data_prefix.get('img_path', None)
|
||||||
if img_prefix is not None:
|
if img_prefix is not None:
|
||||||
data_info['img_path'] = osp.join(
|
data_info['img_path'] = osp.join(
|
||||||
img_prefix, data_info['img_path')
|
img_prefix, data_info['img_path'])
|
||||||
return data_info
|
return data_info
|
||||||
|
|
||||||
# 必须支持的方法,需要返回样本的类别
|
# 必须支持的方法,需要返回样本的类别
|
||||||
@ -382,7 +382,7 @@ pipeline = [
|
|||||||
|
|
||||||
toy_dataset = ToyDataset(
|
toy_dataset = ToyDataset(
|
||||||
data_root='data/',
|
data_root='data/',
|
||||||
data_prefix=dict(img='train/'),
|
data_prefix=dict(img_path='train/'),
|
||||||
ann_file='annotations/train.json',
|
ann_file='annotations/train.json',
|
||||||
pipeline=pipeline)
|
pipeline=pipeline)
|
||||||
|
|
||||||
@ -404,7 +404,7 @@ from mmengine.registry import DATASETS
|
|||||||
@DATASETS.register_module()
|
@DATASETS.register_module()
|
||||||
class ExampleDatasetWrapper:
|
class ExampleDatasetWrapper:
|
||||||
|
|
||||||
def __init__(self, dataset, lazy_init = False, ...):
|
def __init__(self, dataset, lazy_init=False, ...):
|
||||||
# 构建原数据集(self.dataset)
|
# 构建原数据集(self.dataset)
|
||||||
if isinstance(dataset, dict):
|
if isinstance(dataset, dict):
|
||||||
self.dataset = DATASETS.build(dataset)
|
self.dataset = DATASETS.build(dataset)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user