5.6 KiB
自定义数据处理流程
数据流的设计
在新数据集教程中,我们知道数据集类使用 load_data_list
方法来初始化整个数据集,我们将每个样本的信息保存到一个 dict 中。
通常,为了节省内存,我们只加载 load_data_list
中的图片路径和标签,使用时加载完整的图片内容。此外,我们可能希望在训练时选择样本时进行一些随机数据扩充。几乎所有的数据加载、预处理和格式化操作都可以通过数据管道在 MMPretrain 中进行配置。
数据管道意味着在从数据集中索引样本时如何处理样本字典,它由一系列数据变换组成。每个数据变换都将一个字典作为输入,对其进行处理,并为下一个数据变换输出一个字典。
这是 ImageNet 上 ResNet-50 训练的数据管道示例。
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='RandomResizedCrop', scale=224),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(type='PackInputs'),
]
MMPretrain 中所有可用的数据变换都可以在 数据变换文档 中找到。
修改训练/测试管道
MMPretrain 中的数据管道非常灵活。您几乎可以从配置文件中控制数据预处理的每一步,但另一方面,面对如此多的选项,您可能会感到困惑。
这是图像分类任务的常见做法和指南。
读取
在数据管道的开始,我们通常需要从文件路径加载图像数据。
LoadImageFromFile
通常用于执行此任务。
train_pipeline = [
dict(type='LoadImageFromFile'),
...
]
如果您想从具有特殊格式或特殊位置的文件中加载数据,您可以 实施新的加载变换 并将其添加到数据管道的开头。
增强和其它处理
在训练过程中,我们通常需要做数据增强来避免过拟合。在测试过程中,我们还需要做一些数据处理,比如调整大小和裁剪。这些数据变换将放置在加载过程之后。
这是一个简单的数据扩充方案示例。它会将输入图像随机调整大小并裁剪到指定比例,并随机水平翻转图像。
train_pipeline = [
...
dict(type='RandomResizedCrop', scale=224),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
...
]
这是 Swin-Transformer 训练中使用的大量数据增强配方示例。 为了与官方实施保持一致,它指定 pillow
作为调整大小后端,bicubic
作为调整大小算法。 此外,它添加了 RandAugment
和 RandomErasing
作为额外的数据增强方法。
此配置指定了数据扩充的每个细节,您只需将其复制到您自己的配置文件中即可应用 Swin-Transformer 的数据扩充。
bgr_mean = [103.53, 116.28, 123.675]
bgr_std = [57.375, 57.12, 58.395]
train_pipeline = [
...
dict(type='RandomResizedCrop', scale=224, backend='pillow', interpolation='bicubic'),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(
type='RandAugment',
policies='timm_increasing',
num_policies=2,
total_level=10,
magnitude_level=9,
magnitude_std=0.5,
hparams=dict(
pad_val=[round(x) for x in bgr_mean], interpolation='bicubic')),
dict(
type='RandomErasing',
erase_prob=0.25,
mode='rand',
min_area_ratio=0.02,
max_area_ratio=1 / 3,
fill_color=bgr_mean,
fill_std=bgr_std),
...
]
通常,数据管道中的数据增强部分仅处理图像方面的变换,而不处理图像归一化或混合/剪切混合等变换。 因为我们可以对 batch data 做 image normalization 和 mixup/cutmix 来加速。要配置图像归一化和 mixup/cutmix,请使用 [数据预处理器](mmpretrain.models.utils.data_preprocessor)。
格式化
格式化是从数据信息字典中收集训练数据,并将这些数据转换为模型友好的格式。
在大多数情况下,您可以简单地使用 PackInputs
,它将 NumPy 数组格式的图像转换为 PyTorch 张量,并将 ground truth 类别信息和其他元信息打包为 DataSample
。
train_pipeline = [
...
dict(type='PackInputs'),
]
添加新的数据变换
-
在任何文件中写入一个新的数据转换,例如
my_transform.py
,并将其放在文件夹mmpretrain/datasets/transforms/
中。 数据变换类需要继承mmcv.transforms.BaseTransform
类并覆盖以字典作为输入并返回字典的transform
方法。from mmcv.transforms import BaseTransform from mmpretrain.datasets import TRANSFORMS @TRANSFORMS.register_module() class MyTransform(BaseTransform): def transform(self, results): # Modify the data information dict `results`. return results
-
在
mmpretrain/datasets/transforms/__init__.py
中导入新的变换... from .my_transform import MyTransform __all__ = [ ..., 'MyTransform' ]
-
在配置文件中使用
train_pipeline = [ ... dict(type='MyTransform'), ... ]
数据管道可视化
数据流水线设计完成后,可以使用 可视化工具 查看效果。