mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
Update docs of data element (#180)
* refine docs of data element * update * resolve comments
This commit is contained in:
parent
66e528830b
commit
7367df7ea7
@ -12,7 +12,7 @@
|
||||
tutorials/hook.md
|
||||
tutorials/optimizer.md
|
||||
tutorials/param_scheduler.md
|
||||
tutorials/abstract_data_interface.md
|
||||
tutorials/data_element.md
|
||||
tutorials/basedataset.md
|
||||
tutorials/evaluator.md
|
||||
tutorials/distributed.md
|
||||
|
@ -30,15 +30,16 @@ for img, data_sample in dataloader:
|
||||
|
||||
## 设计
|
||||
|
||||
一个算法库中的数据可以被归类成具有不同性质的数据元素。一个训练样本(如一张图片)的所有数据元素构成了一个训练样本的完整数据,称为样本数据。相应地,MMEngine 为数据元素和样本数据分别定义了一种封装。
|
||||
一个算法库中的数据可以被归类成具有不同性质的数据元素 (data element)。一张图片的检测框标注、模型在这张图片上预测出的检测框、以及一张图片的所有标注信息(包含检测框、语义分割图等)都可以被抽象成数据元素。因此,MMEngine 定义了数据元素的基类 `BaseDataElement` 和它所提供的基本的增/删/改/查等基本功能。基于 MMEngine 的算法库可以定义由 `BaseDataElement` 派生而来的数据元素封装,作为该库的组件之间的抽象数据接口。
|
||||
|
||||
1. 数据元素的封装: 数据元素指的是某一算法任务上的预测数据或标注,例如检测框,实例掩码,语义分割掩码等。因为标注数据和预测数据往往具有相似的性质(例如模型的预测框和标注框具有相同的性质),MMEngine 使用相同的抽象数据接口来封装预测数据和标注数据,并推荐使用命名来区分他们,如使用 `gt_instances` 和 `pred_instances` 来区分标注和预测的实例数据。另外,我们将数据元素区分为实例级别,像素级别,和标签级别。这些类型各有自己的特点,因此,MMEngine 定义了数据元素的基类 `BaseDataElement`,并由此派生出了 3 类数据结构来封装不同类型的标注数据或者模型的预测结果:`InstanceData`, `PixelData`, 和 `LabelData`。这些接口将被用于模型内各个模块之间的数据传递。
|
||||
一种典型数据元素是某一算法任务上的预测数据或标注:例如检测框,实例掩码,语义分割掩码,和图像标签等。这些数据元素可以进一步区分为实例级别,像素级别,和标签级别。这些类型各有自己的特点。因此,MMEngine 基于 `BaseDataElement` 派生出了 3 类数据结构来封装不同类型的标注数据或者模型的预测结果:`InstanceData`, `PixelData`, 和 `LabelData`。这些接口可以被用于模型内各个模块之间的数据传递。因为标注数据和预测数据往往具有相似的性质(例如模型的预测框和标注框具有相同的性质),MMEngine 使用相同的抽象数据接口来封装预测数据和标注数据,并推荐使用命名来区分他们,如使用 `gt_instances` 和 `pred_instances` 来区分标注和预测的实例数据。
|
||||
|
||||
2. 样本数据的封装:一个训练样本(例如一张图片)的所有标注和预测构成了一个样本数据。一般情况下,一张图片可以同时有多种类型的标注和/或预测(例如,同时拥有像素级别的语义分割标注和实例级别的检测框标注)。因此,MMEngine 定义了 `BaseDataSample`作为样本数据封装的基类。也就是说,**`BaseDataSample` 的属性会是各种类型的数据元素**,OpenMMLab 算法库将基于 `BaseDataSample` 实现自己的抽象数据接口,来封装一个算法库中单个样本的所有相关数据,作为 dataset,model,visualizer,和 evaluator 组件之间的数据接口。
|
||||
算法库中另一种常见的数据元素是一个训练样本(例如一张图片)的所有标注和预测构成的数据元素。一般情况下,一张图片可以同时有多种类型的标注和/或预测(例如,同时拥有像素级别的语义分割标注和实例级别的检测框标注)。 一个训练样本(例如一张图片)的所有标注和预测经常在 dataset,model,visualizer,和 evaluator 组件之间被传递。为了简化组件之间的接口,我们可以将他们当作一个大的数据元素并对他们进行封装,这类数据元素在 OpenMMLab 算法库中一般被称为 `XXDataSample`。
|
||||
因此,类似于 `nn.Module` 的派生类内部可以拥有类型为 `nn.Module` 的属性,`BaseDataElement` 也允许封装 `BaseDataElement` 作为它的属性。这样的类一般在算法库中封装一个样本的全体数据, 并且**它的属性一般会是各种类型的数据元素**。例如,MMDetection 由 `BaseDataElement` 派生出了 `DetDataSample` 来封装该算法库中一个样本的标注与预测的全部数据元素,`DetDataSample` 的属性一般是 `InstanceData`。
|
||||
|
||||
两种类型的封装和他们的继承关系如下图所示
|
||||
他们的关系如下图所示
|
||||
|
||||

|
||||

|
||||
|
||||
为了保证抽象数据接口内数据的完整性,抽象数据接口内部有两种数据,除了被封装的数据(data)本身,还有一种是数据的元信息(metainfo),例如图片大小和 ID 等。
|
||||
两种类型的抽象数据接口都可以作为 Python 类去使用和操作他们的属性。同时,因为他们封装的数据大多是 Tensor,他们也提供了类似 Tensor 的基础操作。
|
||||
@ -48,17 +49,19 @@ for img, data_sample in dataloader:
|
||||
### BaseDataElement
|
||||
|
||||
MMEngine 为数据元素的封装提供了一个基类 `BaseDataElement`。
|
||||
基于 `BaseDataElement`,MMEngine 还实现了 `InstanceData`, `PixelData`, `LabelData` 和 `GeneralData` 四个典型的子类,封装了实例级别,像素级别,标签级别和其他普通的数据元素,并针对他们的数据特性支持了一些额外的功能。
|
||||
基于 `BaseDataElement`,MMEngine 还实现了 `InstanceData`, `PixelData`, `LabelData` 三个典型的子类,封装了实例级别,像素级别,标签级别的数据元素,并针对他们的数据特性支持了一些额外的功能。
|
||||
|
||||
1. `InstanceData`:封装检测框、框对应的标签和实例掩码、甚至关键点等实例级别数据,`InstanceData` 假定它封装的数据具有相同的长度 N,N 代表实例的个数,并基于此假定对数据进行校验、支持对实例进行索引和拼接。
|
||||
2. `PixelData`:封装逐像素级别的数据,如语义分割图和深度图等。`PixelData` 假定它封装的数据有相同的长度和宽度,第一和第二维为图片的长宽,第三维为通道数。`PixelData` 基于此假定对数据进行校验、支持对实例进行空间维度的索引和各维度的拼接。
|
||||
3. `LabelData`:封装标签数据,如场景分类标签等。
|
||||
4. `GeneralData`:`BaseDataElement` 的等价类。虽然 `BaseDataElement` 可以作为独立的模块被使用,但是我们不推荐用户直接使用基类。因此,MMEngine 额外实现了 `GeneralData` 。`GeneralData` 保持了和 `InstanceData`, `PixelData`, 以及 `LabelData` 一致的命名习惯和继承层次。它拥有和 `BaseDataElement` 完全一样的功能和接口,对数据元素没有任何假定,仅支持最基本的增删改查功能。我们推荐用户在实际应用过程中使用 `GeneralData` 而非 `BaseDataElement` 来保持使用的一致性,在开发过程中继承 `BaseDataElement` 来保持继承层次的统一。在下文中,为了阐明数据元素封装的基本用法,我们还是使用 `BaseDataElement` 来进行描述和用例展示。
|
||||
|
||||
`BaseDataElement` 中存在两种类型的数据,一种是 `data` 类型,如标注框、框的标签、和实例掩码等;另一种是 `metainfo` 类型,包含数据的元信息以确保数据的完整性,如 `img_shape`, `img_id` 等数据所在图片的一些基本信息,方便可视化等情况下对数据进行恢复和使用。用户在创建 `BaseDataElement` 的过程中需要对这两类属性的数据进行显式地区分和声明。
|
||||
|
||||
#### 1. 数据元素的创建
|
||||
|
||||
`BaseDataElement` 的 data 参数可以直接通过 `key=value` 的方式自由添加,metainfo 的字段需要显式通过关键字 `metainfo` 指定。
|
||||
`BaseDataElement` 支持 `from_dict` 接口,支持从 dict 构建 `BaseDataElement`。
|
||||
|
||||
```python
|
||||
# 可以声明一个空的 object
|
||||
gt_instances = BaseDataElement()
|
||||
@ -69,40 +72,47 @@ img_id = 0 # 图像的 ID
|
||||
H = 800 # 图像的高度
|
||||
W = 1333 # 图像的宽度
|
||||
|
||||
# 显式声明 BaseDataElement 的参数 metainfo 和 data
|
||||
gt_instances = BaseDataElement(
|
||||
metainfo=dict(img_id=img_id, img_shape=(H, W)),
|
||||
data=dict(bboxes=bboxes, scores=scores))
|
||||
# 直接设置 BaseDataElement 的 data 参数
|
||||
gt_instances = BaseDataElement(bboxes=bboxes, scores=scores)
|
||||
|
||||
# 不显式声明的时候,传入字典将设置 BaseDataElement 的参数 metainfo
|
||||
gt_instances = BaseDataElement(dict(img_id=img_id, img_shape=(H, W)))
|
||||
# 显式声明来设置 BaseDataElement 的参数 metainfo
|
||||
gt_instances = BaseDataElement(
|
||||
bboxes=bboxes,
|
||||
scores=scores,
|
||||
metainfo=dict(img_id=img_id, img_shape=(H, W)))
|
||||
|
||||
# 通过 from_dict,传入字典将设置 BaseDataElement 的参数 data
|
||||
BaseDataElement.from_dict(dict(bboxes=bboxes, scores=scores))
|
||||
BaseDataElement.from_dict(
|
||||
dict(bboxes=bboxes, scores=scores),
|
||||
metainfo=dict(img_id=img_id, img_shape=(H, W)))
|
||||
```
|
||||
|
||||
#### 2. `new` 函数
|
||||
#### 2. `new` 与 `clone` 函数
|
||||
|
||||
用户可以使用 `new()` 函数通过已有的数据接口创建一个具有相同状态和数据的抽象数据接口。用户可以在创建新 `BaseDataElement` 时设置 metainfo 和 data,使得新的 BaseDataElement 有相同的状态但是不同的数据。
|
||||
也可以直接使用 `new()` 来获得一份深拷贝。
|
||||
也可以直接使用 `clone()` 来获得一份深拷贝,`clone()` 函数的行为与 PyTorch 中 Tensor 的 `clone()` 参数保持一致。
|
||||
|
||||
```python
|
||||
gt_instances = BaseDataElement()
|
||||
|
||||
# 可以在创建新 `BaseDataElement` 时设置 metainfo 和 data,使得新的 BaseDataElement 有不同的数据但是数据在相同的 device 上
|
||||
gt_instances1 = gt_instance.new(
|
||||
bboxes=torch.rand((5, 4)),
|
||||
scores=torch.rand((5,)),
|
||||
metainfo=dict(img_id=1, img_shape=(640, 640)),
|
||||
data=dict(bboxes=torch.rand((5, 4)), scores=torch.rand((5,)))
|
||||
)
|
||||
|
||||
# 也可以声明一个新的 object,新的 object 会拥有和 gt_instance 相同的 data 和 metainfo 内容
|
||||
gt_instances2 = gt_instances1.new()
|
||||
# 也可以通过 `clone` 构建一个新的 object,新的 object 会拥有和 gt_instance 相同的 data 和 metainfo 内容以及状态。
|
||||
gt_instances2 = gt_instances1.clone()
|
||||
```
|
||||
|
||||
#### 3. 属性的增加与查询
|
||||
|
||||
用户可以像增加类属性那样增加 `BaseDataElement` 的属性,此时数据会被**当作 data 类型**增加到 `BaseDataElement` 中。
|
||||
如果需要增加 metainfo 属性,用户应当使用 `set_metainfo`。
|
||||
用户可以通过 `metainfo_keys`,`metainfo_values`,和`metainfo_items` 来访问只存在于 metainfo 中的键值,
|
||||
也可以通过 `data_keys`,`data_values`,和 `data_items` 来访问只存在于 data 中的键值。
|
||||
用户还能通过 `keys`,`values`, `items` 来访问 `BaseDataElement` 的所有的属性并且不区分他们的类型。
|
||||
用户可以可以通过 `keys`,`values`,和 `items` 来访问只存在于 data 中的键值,也可以通过 `metainfo_keys`,`metainfo_values`,和`metainfo_items` 来访问只存在于 metainfo 中的键值。
|
||||
用户还能通过 `all_keys`,`all_values`, `all_items` 来访问 `BaseDataElement` 的所有的属性并且不区分他们的类型。
|
||||
|
||||
**注意:**
|
||||
|
||||
@ -117,19 +127,19 @@ assert 'img_shape' in gt_instances.metainfo_keys()
|
||||
# 'img_shape' 是 gt_instances 的属性
|
||||
assert 'img_shape' in gt_instances
|
||||
# img_shape 不是 gt_instances 的 data 字段
|
||||
assert 'img_shape' not in gt_instances.data_keys()
|
||||
# 通过 keys 来访问所有属性
|
||||
assert 'img_shape' in gt_instances.keys()
|
||||
assert 'img_shape' not in gt_instances.all_keys()
|
||||
# 通过 all_keys 来访问所有属性
|
||||
assert 'img_shape' in gt_instances.all_keys()
|
||||
# 访问类属性一样访问 'img_shape'
|
||||
print(gt_instances.img_shape)
|
||||
|
||||
# 直接设置 gt_instance 的 scores 属性,默认该数据属于 data
|
||||
gt_instances.scores = torch.rand((5,))
|
||||
assert 'scores' in gt_instances.data_keys()
|
||||
assert 'scores' in gt_instances.items()
|
||||
# 'scores' 是 gt_instances 的属性
|
||||
assert 'scores' in gt_instances
|
||||
# 通过 keys 来访问所有属性
|
||||
assert 'scores' in gt_instances.keys()
|
||||
# 通过 all_keys 来访问所有属性
|
||||
assert 'scores' in gt_instances.all_keys()
|
||||
# scores 不是 gt_instances 的 metainfo 字段
|
||||
assert 'scores' not in gt_instances.metainfo_keys()
|
||||
# 访问类属性一样访问 'scores'
|
||||
@ -137,23 +147,23 @@ print(gt_instances.scores)
|
||||
|
||||
# 设置 gt_instances 的 data 字段 bboxes
|
||||
gt_instances.bboxes = torch.rand((5, 4))
|
||||
assert 'bboxes' in gt_instances.data_keys()
|
||||
assert 'bboxes' in gt_instances.items()
|
||||
# 'bboxes' 是 gt_instances 的属性
|
||||
assert 'bboxes' in gt_instances
|
||||
# 通过 keys 来访问所有属性
|
||||
assert 'bboxes' in gt_instances.keys()
|
||||
# 通过 all_keys 来访问所有属性
|
||||
assert 'bboxes' in gt_instances.all_keys()
|
||||
# bboxes 不是 gt_instances 的 metainfo 字段
|
||||
assert 'bboxes' not in gt_instances.metainfo_keys()
|
||||
# 访问类属性一样访问 'bboxes'
|
||||
print(gt_instances.bboxes)
|
||||
|
||||
for k, v in gt_instances.items():
|
||||
for k, v in gt_instances.all_items():
|
||||
print(f'{k}: {v}') # 包含 img_shapes, img_id, bboxes,scores
|
||||
|
||||
for k, v in gt_instances.metainfo_items():
|
||||
print(f'{k}: {v}') # 包含 img_shapes, img_id
|
||||
|
||||
for k, v in gt_instances.data_items():
|
||||
for k, v in gt_instances.items():
|
||||
print(f'{k}: {v}') # 包含 bboxes,scores
|
||||
```
|
||||
|
||||
@ -164,8 +174,9 @@ for k, v in gt_instances.data_items():
|
||||
|
||||
```python
|
||||
gt_instances = BaseDataElement(
|
||||
bboxes=torch.rand((6, 4)), scores=torch.rand((6,)),
|
||||
metainfo=dict(img_id=0, img_shape=(640, 640)),
|
||||
data=dict(bboxes=torch.rand((6, 4)), scores=torch.rand((6,))))
|
||||
)
|
||||
|
||||
# 对类的属性进行修改
|
||||
gt_instances.img_shape = (1280, 1280)
|
||||
@ -216,7 +227,7 @@ np_instances = cpu_instances.numpy()
|
||||
|
||||
#### 6. 属性的展示
|
||||
|
||||
`BaseDataElement` 还实现了 `__nice__` 和 `__repr__`,因此,用户可以直接通过 `print` 函数看到其中的所有数据信息。
|
||||
`BaseDataElement` 还实现了 `__repr__`,因此,用户可以直接通过 `print` 函数看到其中的所有数据信息。
|
||||
同时,为了便捷开发者 debug,`BaseDataElement` 中的属性都会添加进 `__dict__` 中,方便用户在 IDE 界面可以直观看到 `BaseDataElement` 中的内容。
|
||||
一个完整的属性展示如下
|
||||
|
||||
@ -236,74 +247,64 @@ shape of det_scores: torch.Size([4])
|
||||
) at 0x7f84acd10f90>
|
||||
```
|
||||
|
||||
### BaseDataSample
|
||||
### DataSample
|
||||
|
||||
MMEngine 为样本数据的封装提供了一个基类 `BaseDataSample`,OpenMMLab 的每个算法库都应该继承 `BaseDataSample` 实现自己的样本数据封装,并规约和校验该算法库中的常见字段。算法库自己实现的样本数据封装会作为该算法库内 dataset,visualizer,evaluator,model 组件之间的数据接口进行流通。
|
||||
`BaseDataSample` 虽然可以作为一个模块被单独使用,但是我们不推荐 `BaseDataSample` 这种用法。
|
||||
|
||||
`BaseDataSample` 内部依然区分 metainfo 和 data,并且支持像类一样对其属性进行设置和调整,为了保证用户体验的一致性,`BaseDataSample` 的外部接口用法和 `BaseDataElement` 保持一致。
|
||||
|
||||
同时,由于 `BaseDataSample` 作为基类一般不会直接使用,为了方便下游算法库快速定义其子类,并对子类的属性进行规约和校验。
|
||||
`BaseDataSample` 额外提供了一套内部接口 `get_field`, `del_field` 和 `set_field` 来便利它的子类快捷地定义和规约 data 属性的增删改查。
|
||||
`set_field` 不会被当作外部接口直接使用,而是被用来定义属性(property) 的 `setter` 并提供基本的类型校验。
|
||||
|
||||
一个简单粗略的实现和用例如下。
|
||||
基于 `BaseDataElement`,下游算法库可以定义 `DetDataSample`,并且定义 3 个 property:proposals,gt_instances,pred_instances,并约束他们的类型。
|
||||
|
||||
```python
|
||||
from abc import ABC
|
||||
from functools import partial
|
||||
class DetDataSample(BaseDataElement):
|
||||
|
||||
@property
|
||||
def proposals(self):
|
||||
return self._proposals
|
||||
|
||||
class BaseDataSample(ABC):
|
||||
@proposals.setter
|
||||
def proposals(self, value):
|
||||
self.set_field(value, '_proposals', dtype=InstanceData)
|
||||
|
||||
def __init__(self, metainfo=dict(), data=dict()):
|
||||
self._data_fields = set()
|
||||
self._metainfo_fields = set()
|
||||
@proposals.deleter
|
||||
def proposals(self):
|
||||
del self._proposals
|
||||
|
||||
# 其他功能实现
|
||||
...
|
||||
@property
|
||||
def gt_instances(self):
|
||||
"""Ground truth instances of an image"""
|
||||
return self._gt_instances
|
||||
|
||||
def get_field(self, name):
|
||||
return getattr(self, name)
|
||||
@gt_instances.setter
|
||||
def gt_instances(self, value):
|
||||
self.set_field(value, '_gt_instances', dtype=InstanceData)
|
||||
|
||||
def set_field(self, val, name, dtype):
|
||||
assert isinstance(val, dtype)
|
||||
super().__setattr__(name, val)
|
||||
self._data_fields.add(name)
|
||||
@gt_instances.deleter
|
||||
def gt_instances(self):
|
||||
del self._gt_instances
|
||||
|
||||
def del_field(self, name):
|
||||
super().__delattr__(name)
|
||||
self._data_fields.remove(name)
|
||||
@property
|
||||
def pred_instances(self):
|
||||
"""Predicted instances of an image"""
|
||||
return self._pred_instances
|
||||
|
||||
```
|
||||
@pred_instances.setter
|
||||
def pred_instances(self, value):
|
||||
self.set_field(value, '_pred_instances', dtype=InstanceData)
|
||||
|
||||
基于 `BaseDataSample`,下游算法库可以定义 `DetDataSample`,并且使用 `BaseDataSample` 中的接口,快速定义 3 个 property:proposals,gt_instances,pred_instances,并约束他们的类型。
|
||||
@pred_instances.deleter
|
||||
def pred_instances(self):
|
||||
del self._pred_instances
|
||||
|
||||
```python
|
||||
class DetDataSample(BaseDataSample):
|
||||
@property
|
||||
def proposals(self):
|
||||
"""Region proposals"""
|
||||
return self._proposals
|
||||
|
||||
proposals = property(
|
||||
# 定义了 get 方法,通过 name '_proposals' 来访问实际维护的变量
|
||||
fget=partial(BaseDataSample.get_field, name='_proposals'),
|
||||
# 定义了 set 方法,将实际维护的变量设置为 '_proposals',并在设置的时候检查类型是否是 dtype 定义的类型 InstanceData
|
||||
fset=partial(BaseDataSample.set_field, name='_proposals', dtype=InstanceData),
|
||||
fdel=partial(BaseDataSample.del_field, name='_proposals'),
|
||||
doc='Region proposals of an image'
|
||||
)
|
||||
@proposals.setter
|
||||
def proposals(self, value):
|
||||
self.set_field(value, '_proposals', dtype=InstanceData)
|
||||
|
||||
gt_instances = property(
|
||||
fget=partial(BaseDataSample.get_field, name='_gt_instances'),
|
||||
fset=partial(BaseDataSample.set_field, name='_gt_instances', dtype=InstanceData),
|
||||
fdel=partial(BaseDataSample.del_field, name='_gt_instances'),
|
||||
doc='Ground truth instances of an image'
|
||||
)
|
||||
@proposals.deleter
|
||||
def proposalss(self):
|
||||
del self._proposals
|
||||
|
||||
pred_instances = property(
|
||||
fget=partial(BaseDataSample.get_field, name='_pred_instances'),
|
||||
fset=partial(BaseDataSample.set_field, name='_pred_instances', dtype=InstanceData),
|
||||
fdel=partial(BaseDataSample.del_field, name='_pred_instances'),
|
||||
doc='Predicted instances of an image'
|
||||
)
|
||||
```
|
||||
|
||||
`DetDataSample` 的用法如下所示,在数据类型不符合要求的时候(例如用 `torch.Tensor` 而非 `InstanceData` 定义 proposals 时) ,`DetDataSample` 就会报错。
|
@ -44,4 +44,4 @@ PyTorch 提供了一套基础的通信原语用于多进程之间张量的通信
|
||||
- [all_reduce_dict](https://mmengine.readthedocs.io/zh/latest//api.html#mmengine.dist.all_reduce_dict):对 dict 中的内容进行 all_reduce 操作,基于 broadcast 和 all_reduce 接口实现
|
||||
- [all_gather_object](https://mmengine.readthedocs.io/zh/latest//api.html#mmengine.dist.all_gather_object):基于 all_gather 实现对任意可以 Python 序列化对象的 all_tather 操作
|
||||
- [gather_object](https://mmengine.readthedocs.io/zh/latest//api.html#mmengine.dist.gather_object):将 group 里每个 rank 的 data gather 到一个目标 rank,且支持多种方式
|
||||
- [collect_results](https://mmengine.readthedocs.io/zh/latest//api.html#mmengine.dist.collect_results):支持基于 CPU 或者 GPU 对不同进程间的列表数据进行收集·
|
||||
- [collect_results](https://mmengine.readthedocs.io/zh/latest//api.html#mmengine.dist.collect_results):支持基于 CPU 或者 GPU 对不同进程间的列表数据进行收集
|
||||
|
Loading…
x
Reference in New Issue
Block a user