mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
* [Refactor]: modify interface of Visualizer.add_datasample (#365) * [Refactor] Refactor data flow: refine `data_preprocessor`. (#359) * refine data_preprocessor * remove unused BATCH_DATA alias * Fix type hints * rename move_data to cast_data * [Refactor] Refactor data flow: collate data in `collate_fn` of `DataLoader` (#323) * acollate data in dataloader * fix docstring * refine comment * fix as comment * refactor default collate and psedo collate * foramt test file * fix docstring * fix as comment * rename elem to data_item * minor fix * fix as comment * [Refactor] Refactor data flow: `data_batch` argument of `Evaluator.process is a `dict` (#360) * refine evaluator and metric * compatible with new default collate * replace default collate with pseudo * Handle data_batch in metric * fix unit test * fix unit test * fix unit test * minor refine * make data_batch optional make data_batch optional * rename outputs to predictions * fix ut * rename predictions to outputs * fix docstring * fix docstring * fix unit test * make outputs and data_batch to kwargs * fix unit test * keep signature of metric * fix ut * rename pred_sample arguments to data_sample(Visualizer) * fix loop and ut * [refactor]: Refactor model dataflow (#398) * [Refactor] Refactor data flow: refine `data_preprocessor`. (#359) * refine data_preprocessor * remove unused BATCH_DATA alias * Fix type hints * rename move_data to cast_data * refactor model data flow tmp_commt tmp commit * make val_cfg and test_cfg optional * roll back runner * pass test mmdet * fix as comment fix as comment fix ci in DataPreprocessor * fix ut * fix ut * fix rebase main * [Fix]: Fix test val ddp (#462) * [Fix] Fix docstring and type hint of data flow (#463) * Fix docstring of data flow * change signature of hook * fix unit test * resolve conflicts * fix lint
156 lines
6.4 KiB
Python
156 lines
6.4 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
from unittest import TestCase
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
from mmengine.dataset import default_collate, pseudo_collate
|
|
from mmengine.structures import BaseDataElement
|
|
from mmengine.utils import is_list_of
|
|
|
|
|
|
class TestDataUtils(TestCase):
    """Unit tests for ``pseudo_collate`` and ``default_collate`` from
    :mod:`mmengine.dataset`."""

    def test_pseudo_collate(self):
        """``pseudo_collate`` groups items by key (for dicts) or by position
        (for tuples) into plain lists, keeping the original objects."""
        # Test with list of dict tensor inputs.
        input1 = torch.randn(1, 3, 5)
        input2 = torch.randn(1, 3, 5)
        label1 = torch.randn(1)
        label2 = torch.randn(1)

        data_batch = [
            dict(inputs=input1, data_sample=label1),
            dict(inputs=input2, data_sample=label2)
        ]
        data_batch = pseudo_collate(data_batch)
        # The mapping type of the input items is preserved.
        self.assertIsInstance(data_batch, dict)
        self.assertEqual(len(data_batch['inputs']), 2)
        self.assertEqual(len(data_batch['data_sample']), 2)
        self.assertTrue(torch.allclose(input1, data_batch['inputs'][0]))
        self.assertTrue(torch.allclose(input2, data_batch['inputs'][1]))
        self.assertTrue(torch.allclose(label1, data_batch['data_sample'][0]))
        self.assertTrue(torch.allclose(label2, data_batch['data_sample'][1]))

        # Test with list of dict, and each element contains `data_sample`
        # inputs
        data_sample1 = BaseDataElement(label=torch.tensor(1))
        data_sample2 = BaseDataElement(label=torch.tensor(1))
        data = [
            dict(inputs=input1, data_sample=data_sample1),
            dict(inputs=input2, data_sample=data_sample2),
        ]
        data_batch = pseudo_collate(data)
        batch_inputs, batch_data_sample = (data_batch['inputs'],
                                           data_batch['data_sample'])
        # check batch_inputs: tensors are gathered into a list, not stacked,
        # and the very same objects are returned (no copy).
        self.assertTrue(is_list_of(batch_inputs, torch.Tensor))
        self.assertIs(input1, batch_inputs[0])
        self.assertIs(input2, batch_inputs[1])

        # check data_sample: data elements are passed through by identity.
        self.assertIs(batch_data_sample[0], data_sample1)
        self.assertIs(batch_data_sample[1], data_sample2)

        # Test with list of tuple, each tuple is a nested dict instance
        data_batch = [(dict(
            inputs=input1,
            data_sample=data_sample1,
            value=1,
            name='1',
            nested=dict(data_sample=data_sample1)),
                       dict(
                           inputs=input2,
                           data_sample=data_sample2,
                           value=2,
                           name='2',
                           nested=dict(data_sample=data_sample2))),
                      (dict(
                          inputs=input1,
                          data_sample=data_sample1,
                          value=1,
                          name='1',
                          nested=dict(data_sample=data_sample1)),
                       dict(
                           inputs=input2,
                           data_sample=data_sample2,
                           value=2,
                           name='2',
                           nested=dict(data_sample=data_sample2)))]
        data_batch = pseudo_collate(data_batch)
        # Tuples are transposed: index 0 collects the first dict of every
        # tuple, index 1 collects the second.
        self.assertEqual(len(data_batch), 2)
        batch_inputs_0 = data_batch[0]['inputs']
        batch_inputs_1 = data_batch[1]['inputs']
        batch_data_sample_0 = data_batch[0]['data_sample']
        batch_data_sample_1 = data_batch[1]['data_sample']
        batch_value_0 = data_batch[0]['value']
        batch_value_1 = data_batch[1]['value']
        batch_name_0 = data_batch[0]['name']
        batch_name_1 = data_batch[1]['name']
        batch_nested_0 = data_batch[0]['nested']
        batch_nested_1 = data_batch[1]['nested']

        self.assertTrue(is_list_of(batch_inputs_0, torch.Tensor))
        self.assertTrue(is_list_of(batch_inputs_1, torch.Tensor))
        self.assertIs(batch_inputs_0[0], input1)
        self.assertIs(batch_inputs_0[1], input1)
        self.assertIs(batch_inputs_1[0], input2)
        self.assertIs(batch_inputs_1[1], input2)

        self.assertIs(batch_data_sample_0[0], data_sample1)
        self.assertIs(batch_data_sample_0[1], data_sample1)
        self.assertIs(batch_data_sample_1[0], data_sample2)
        self.assertIs(batch_data_sample_1[1], data_sample2)

        # Python scalars and strings stay as plain lists (no conversion).
        self.assertEqual(batch_value_0, [1, 1])
        self.assertEqual(batch_value_1, [2, 2])

        self.assertEqual(batch_name_0, ['1', '1'])
        self.assertEqual(batch_name_1, ['2', '2'])

        # Nested dicts are collated recursively.
        self.assertIs(batch_nested_0['data_sample'][0], data_sample1)
        self.assertIs(batch_nested_0['data_sample'][1], data_sample1)
        self.assertIs(batch_nested_1['data_sample'][0], data_sample2)
        self.assertIs(batch_nested_1['data_sample'][1], data_sample2)

    def test_default_collate(self):
        """``default_collate`` stacks tensors and converts Python numbers
        and numpy arrays into batched tensors."""
        # `default_collate` shares its traversal logic with
        # `pseudo_collate`, so only test that it can stack batch tensors and
        # convert ints and numpy arrays to tensors.
        input1 = torch.randn(1, 3, 5)
        input2 = torch.randn(1, 3, 5)
        data_batch = [(
            dict(inputs=input1, value=1, array=np.array(1)),
            dict(inputs=input2, value=2, array=np.array(2)),
        ),
                      (
                          dict(inputs=input1, value=1, array=np.array(1)),
                          dict(inputs=input2, value=2, array=np.array(2)),
                      )]
        data_batch = default_collate(data_batch)
        batch_inputs_0 = data_batch[0]['inputs']
        batch_inputs_1 = data_batch[1]['inputs']
        batch_value_0 = data_batch[0]['value']
        batch_value_1 = data_batch[1]['value']
        batch_array_0 = data_batch[0]['array']
        batch_array_1 = data_batch[1]['array']

        # Tensors are stacked along a new leading batch dimension.
        self.assertEqual(tuple(batch_inputs_0.shape), (2, 1, 3, 5))
        self.assertEqual(tuple(batch_inputs_1.shape), (2, 1, 3, 5))
        self.assertTrue(
            torch.allclose(batch_inputs_0, torch.stack([input1, input1])))
        self.assertTrue(
            torch.allclose(batch_inputs_1, torch.stack([input2, input2])))

        # Python ints are converted into a batched tensor.
        self.assertIsInstance(batch_value_0, torch.Tensor)
        self.assertIsInstance(batch_value_1, torch.Tensor)
        self.assertTrue(
            torch.allclose(batch_value_0,
                           torch.stack([torch.tensor(1),
                                        torch.tensor(1)])))
        self.assertTrue(
            torch.allclose(batch_value_1,
                           torch.stack([torch.tensor(2),
                                        torch.tensor(2)])))

        # numpy arrays are converted into a batched tensor.
        self.assertIsInstance(batch_array_0, torch.Tensor)
        self.assertIsInstance(batch_array_1, torch.Tensor)
        self.assertTrue(
            torch.allclose(batch_array_0,
                           torch.stack([torch.tensor(1),
                                        torch.tensor(1)])))
        self.assertTrue(
            torch.allclose(batch_array_1,
                           torch.stack([torch.tensor(2),
                                        torch.tensor(2)])))
|