mmclassification/mmpretrain/models/utils/data_preprocessor.py

# Copyright (c) OpenMMLab. All rights reserved.
import math
from numbers import Number
from typing import Optional, Sequence
import torch
import torch.nn.functional as F
from mmengine.model import BaseDataPreprocessor, stack_batch
from mmpretrain.registry import MODELS
from mmpretrain.structures import (DataSample, MultiTaskDataSample,
batch_label_to_onehot, cat_batch_labels,
tensor_split)
from .batch_augments import RandomBatchAugment


@MODELS.register_module()
class ClsDataPreprocessor(BaseDataPreprocessor):
"""Image pre-processor for classification tasks.
Comparing with the :class:`mmengine.model.ImgDataPreprocessor`,
2022-06-09 21:48:12 +08:00
1. It won't do normalization if ``mean`` is not specified.
2. It does normalization and color space conversion after stacking batch.
3. It supports batch augmentations like mixup and cutmix.
It provides the data pre-processing as follows
- Collate and move data to the target device.
- Pad inputs to the maximum size of current batch with defined
``pad_value``. The padding size can be divisible by a defined
``pad_size_divisor``
- Stack inputs to batch_inputs.
- Convert inputs from bgr to rgb if the shape of input is (3, H, W).
- Normalize image with defined std and mean.
- Do batch augmentations like Mixup and Cutmix during training.
Args:
mean (Sequence[Number], optional): The pixel mean of R, G, B channels.
Defaults to None.
std (Sequence[Number], optional): The pixel standard deviation of
R, G, B channels. Defaults to None.
pad_size_divisor (int): The size of padded image should be
divisible by ``pad_size_divisor``. Defaults to 1.
pad_value (Number): The padded pixel value. Defaults to 0.
to_rgb (bool): whether to convert image from BGR to RGB.
Defaults to False.
to_onehot (bool): Whether to generate one-hot format gt-labels and set
to data samples. Defaults to False.
num_classes (int, optional): The number of classes. Defaults to None.
batch_augments (dict, optional): The batch augmentations settings,
including "augments" and "probs". For more details, see
:class:`mmpretrain.models.RandomBatchAugment`.
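    Example:
        A minimal usage sketch; the mean/std values below are the common
        ImageNet statistics and, like the input sizes and labels, are
        illustrative only::

            >>> import torch
            >>> from mmpretrain.models.utils import ClsDataPreprocessor
            >>> from mmpretrain.structures import DataSample
            >>> preprocessor = ClsDataPreprocessor(
            ...     mean=[123.675, 116.28, 103.53],
            ...     std=[58.395, 57.12, 57.375],
            ...     to_rgb=True)
            >>> data = {
            ...     'inputs': torch.randint(0, 256, (2, 3, 224, 224)),
            ...     'data_samples': [DataSample().set_gt_label(i) for i in range(2)],
            ... }
            >>> out = preprocessor(data)
            >>> out['inputs'].shape
            torch.Size([2, 3, 224, 224])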
"""
def __init__(self,
mean: Sequence[Number] = None,
std: Sequence[Number] = None,
pad_size_divisor: int = 1,
pad_value: Number = 0,
to_rgb: bool = False,
to_onehot: bool = False,
num_classes: Optional[int] = None,
batch_augments: Optional[dict] = None):
super().__init__()
self.pad_size_divisor = pad_size_divisor
self.pad_value = pad_value
self.to_rgb = to_rgb
self.to_onehot = to_onehot
self.num_classes = num_classes
if mean is not None:
assert std is not None, 'To enable the normalization in ' \
'preprocessing, please specify both `mean` and `std`.'
# Enable the normalization in preprocessing.
self._enable_normalize = True
self.register_buffer('mean',
torch.tensor(mean).view(-1, 1, 1), False)
self.register_buffer('std',
torch.tensor(std).view(-1, 1, 1), False)
else:
self._enable_normalize = False
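        # An illustrative `batch_augments` config (a sketch, not a default;
        # see RandomBatchAugment for the accepted keys):
        #   dict(augments=[dict(type='Mixup', alpha=0.8),
        #                  dict(type='CutMix', alpha=1.0)],
        #        probs=[0.5, 0.5])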
if batch_augments:
self.batch_augments = RandomBatchAugment(**batch_augments)
if not self.to_onehot:
from mmengine.logging import MMLogger
MMLogger.get_current_instance().info(
'Because batch augmentations are enabled, the data '
'preprocessor automatically enables the `to_onehot` '
'option to generate one-hot format labels.')
self.to_onehot = True
else:
self.batch_augments = None
def forward(self, data: dict, training: bool = False) -> dict:
"""Perform normalization, padding, bgr2rgb conversion and batch
augmentation based on ``BaseDataPreprocessor``.
2022-06-09 21:48:12 +08:00
Args:
data (dict): data sampled from dataloader.
2022-06-09 21:48:12 +08:00
training (bool): Whether to enable training time augmentation.
Returns:
dict: Data in the same format as the model input.
2022-06-09 21:48:12 +08:00
"""
inputs = self.cast_data(data['inputs'])
if isinstance(inputs, torch.Tensor):
            # This branch is taken when `default_collate` is used as the
            # collate_fn of the dataloader, so the inputs arrive already
            # batched as one tensor.
# ------ To RGB ------
if self.to_rgb and inputs.size(1) == 3:
inputs = inputs.flip(1)
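                # `flip(1)` reverses the channel dimension of the
                # (N, C, H, W) batch, swapping BGR channel order to RGB.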
# -- Normalization ---
inputs = inputs.float()
if self._enable_normalize:
inputs = (inputs - self.mean) / self.std
# ------ Padding -----
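            # Worked example: with pad_size_divisor=32 and h=225,
            # target_h = ceil(225 / 32) * 32 = 8 * 32 = 256, so pad_h = 31
            # rows of `pad_value` are appended at the bottom (likewise for
            # the width).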
if self.pad_size_divisor > 1:
h, w = inputs.shape[-2:]
target_h = math.ceil(
h / self.pad_size_divisor) * self.pad_size_divisor
target_w = math.ceil(
w / self.pad_size_divisor) * self.pad_size_divisor
pad_h = target_h - h
pad_w = target_w - w
inputs = F.pad(inputs, (0, pad_w, 0, pad_h), 'constant',
self.pad_value)
else:
            # This branch is taken when `pseudo_collate` is used as the
            # collate_fn of the dataloader, so the inputs are a list of
            # tensors that may differ in shape.
processed_inputs = []
for input_ in inputs:
# ------ To RGB ------
if self.to_rgb and input_.size(0) == 3:
input_ = input_.flip(0)
# -- Normalization ---
input_ = input_.float()
if self._enable_normalize:
input_ = (input_ - self.mean) / self.std
processed_inputs.append(input_)
# Combine padding and stack
inputs = stack_batch(processed_inputs, self.pad_size_divisor,
self.pad_value)
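            # `stack_batch` pads each image to the maximum height/width in
            # the batch (rounded up to `pad_size_divisor`) with `pad_value`
            # before stacking, so differently sized inputs become one
            # batch tensor.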
data_samples = data.get('data_samples', None)
sample_item = data_samples[0] if data_samples is not None else None
if isinstance(sample_item, DataSample):
batch_label = None
batch_score = None
if 'gt_label' in sample_item:
gt_labels = [sample.gt_label for sample in data_samples]
batch_label, label_indices = cat_batch_labels(gt_labels)
batch_label = batch_label.to(self.device)
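                # e.g. gt labels [tensor([1]), tensor([0, 2])] become the
                # flat tensor([1, 0, 2]); `label_indices` records the split
                # points so `tensor_split` below can restore the per-sample
                # labels.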
if 'gt_score' in sample_item:
gt_scores = [sample.gt_score for sample in data_samples]
batch_score = torch.stack(gt_scores).to(self.device)
elif self.to_onehot:
assert batch_label is not None, \
'Cannot generate onehot format labels because no labels.'
num_classes = self.num_classes or sample_item.get(
'num_classes')
assert num_classes is not None, \
'Cannot generate one-hot format labels because not set ' \
'`num_classes` in `data_preprocessor`.'
batch_score = batch_label_to_onehot(
batch_label, label_indices, num_classes).to(self.device)
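                # e.g. batch_label tensor([1]) with num_classes=3 yields
                # the one-hot score tensor([[0., 1., 0.]]).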
# ----- Batch Augmentations ----
if training and self.batch_augments is not None:
inputs, batch_score = self.batch_augments(inputs, batch_score)
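            # Batch augments such as Mixup return mixed images together
            # with mixed (soft) scores; the hard `batch_label` is left
            # unchanged for scattering below.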
# ----- scatter labels and scores to data samples ---
if batch_label is not None:
for sample, label in zip(
data_samples, tensor_split(batch_label,
label_indices)):
sample.set_gt_label(label)
if batch_score is not None:
for sample, score in zip(data_samples, batch_score):
sample.set_gt_score(score)
elif isinstance(sample_item, MultiTaskDataSample):
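            # Multi-task samples carry nested per-task labels; they are
            # only moved to the target device here, and per-task label
            # handling is left to the task heads.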
data_samples = self.cast_data(data_samples)
return {'inputs': inputs, 'data_samples': data_samples}