6.8 KiB
6.8 KiB
数据变换类的迁移
简介
在 TorchVision 的数据变换类接口约定中,数据变换类需要实现 __call__
方法,而在 OpenMMLab 1.0 的接口约定中,进一步要求
__call__
方法的输出应当是一个字典,在各种数据变换中对这个字典进行增删查改。在 OpenMMLab 2.0 中,为了提升后续的可扩展性,我们将原先的 __call__
方法迁移为 transform
方法,并要求数据变换类应当继承
mmcv.transforms.BaseTransfrom
。具体如何实现一个数据变换类,可以参见文档。
由于在此次更新中,我们将部分共用的数据变换类统一迁移至 MMCV 中,因此本文的将会以 MMClassification v0.23.2、MMDetection v2.25.1 和 MMCV v2.0.0rc0 为例,对比这些数据变换类在新旧版本中功能、用法和实现上的差异。
功能差异
MMClassification (旧) | MMDetection (旧) | MMCV (新) | |
---|---|---|---|
LoadImageFromFile |
从 'img_prefix' 和 'img_info.filename' 字段组合获得文件路径并读取 | 从 'img_prefix' 和 'img_info.filename' 字段组合获得文件路径并读取,支持指定通道顺序 | 从 'img_path' 获得文件路径并读取,支持指定加载失败不报错,支持指定解码后端 |
LoadAnnotations |
无 | 支持读取 bbox,label,mask(包括多边形样式),seg map,转换 bbox 坐标系 | 支持读取 bbox,label,mask(不包括多边形样式),seg map |
Pad |
填充 "img_fields" 中所有字段,不支持指定填充至整数倍 | 填充 "img_fields" 中所有字段,支持指定填充至整数倍 | 填充 "img" 字段,支持指定填充至整数倍 |
CenterCrop |
裁切 "img_fields" 中所有字段,支持以 EfficientNet 方式进行裁切 | 无 | 裁切 "img" 字段的图像,"gt_bboxes" 字段的 bbox,"gt_seg_map" 字段的分割图,"gt_keypoints" 字段的关键点,支持自动填充裁切边缘 |
Normalize |
图像归一化 | 无差异 | 无差异,但 MMEngine 推荐在数据预处理器中进行归一化 |
Resize |
缩放 "img_fields" 中所有字段,允许指定根据某边长等比例缩放 | 功能由 Resize 实现。需要 ratio_range 为 None,img_scale 仅指定一个尺寸,且 multiscale_mode 为 "value" 。 |
缩放 "img" 字段的图像,"gt_bboxes" 字段的 bbox,"gt_seg_map" 字段的分割图,"gt_keypoints" 字段的关键点,支持指定缩放比例,支持等比例缩放图像至指定尺寸内 |
RandomResize |
无 | 功能由 Resize 实现。需要 ratio_range 为 None,img_scale 指定两个尺寸,且 multiscale_mode 为 "range",或 ratio_range 不为 None。
Resize( img_sacle=[(640, 480), (960, 720)], mode="range", ) |
缩放功能同 Resize ,支持从指定尺寸范围或指定比例范围随机采样缩放尺寸。
RandomResize(scale=[(640, 480), (960, 720)]) |
RandomChoiceResize |
无 | 功能由 Resize 实现。需要 ratio_range 为 None,img_scale 指定多个尺寸,且 multiscale_mode 为 "value"。
Resize( img_sacle=[(640, 480), (960, 720)], mode="value", ) |
缩放功能同 Resize ,支持从若干指定尺寸中随机选择缩放尺寸。
RandomChoiceResize(scales=[(640, 480), (960, 720)]) |
RandomGrayscale |
灰度化 "img_fields" 中所有字段,灰度化后保持通道数。 | 无 | 灰度化 "img" 字段,支持指定灰度化权重,支持指定是否在灰度化后保持通道数(默认不保持)。 |
RandomFlip |
翻转 "img_fields" 中所有字段,支持指定水平或垂直翻转。 | 翻转 "img_fields", "bbox_fields", "mask_fields", "seg_fields" 中所有字段,支持指定水平、垂直或对角翻转,支持指定各类翻转概率。 | 翻转 "img", "gt_bboxes", "gt_seg_map", "gt_keypoints" 字段,支持指定水平、垂直或对角翻转,支持指定各类翻转概率。 |
MultiScaleFlipAug |
无 | 用于测试时增强 | TODO |
ToTensor |
将指定字段转换为 torch.Tensor |
无差异 | 无差异 |
ImageToTensor |
将指定字段转换为 torch.Tensor ,并调整通道顺序至 CHW。 |
无差异 | 无差异 |
实现差异
以 RandomFlip
为例,MMCV 的 RandomFlip 相比旧版 MMDetection 的 RandomFlip,需要继承 BaseTransfrom
,将功能实现放在 transforms
方法,并将生成随机结果的部分放在单独的方法中,用 cache_randomness
包装。有关随机方法的包装相关功能,参见相关文档。
- MMDetection (旧)
class RandomFlip:
def __call__(self, results):
"""调用时进行随机翻转"""
...
# 随机选择翻转方向
cur_dir = np.random.choice(direction_list, p=flip_ratio_list)
...
return results
- MMCV
class RandomFlip(BaseTransfrom):
def transform(self, results):
"""调用时进行随机翻转"""
...
cur_dir = self._random_direction()
...
return results
@cache_randomness
def _random_direction(self):
"""随机选择翻转方向"""
...
return np.random.choice(direction_list, p=flip_ratio_list)