mmengine/tests/test_data/test_data_utils.py

156 lines
6.4 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import numpy as np
import torch
from mmengine.dataset import default_collate, pseudo_collate
from mmengine.structures import BaseDataElement
from mmengine.utils import is_list_of
class TestDataUtils(TestCase):
    """Tests for the ``pseudo_collate`` and ``default_collate`` helpers."""

    def test_pseudo_collate(self):
        """Collated values must be per-key lists of the original objects."""
        # Case 1: a list of dicts whose values are all tensors.
        input1 = torch.randn(1, 3, 5)
        input2 = torch.randn(1, 3, 5)
        label1 = torch.randn(1)
        label2 = torch.randn(1)
        collated = pseudo_collate([
            dict(inputs=input1, data_sample=label1),
            dict(inputs=input2, data_sample=label2),
        ])
        checks = (
            ('inputs', 0, input1),
            ('inputs', 1, input2),
            ('data_sample', 0, label1),
            ('data_sample', 1, label2),
        )
        for key, index, original in checks:
            self.assertTrue(torch.allclose(original, collated[key][index]))

        # Case 2: a list of dicts whose `data_sample` entry is a
        # `BaseDataElement` instead of a tensor.
        data_sample1 = BaseDataElement(label=torch.tensor(1))
        data_sample2 = BaseDataElement(label=torch.tensor(1))
        collated = pseudo_collate([
            dict(inputs=input1, data_sample=data_sample1),
            dict(inputs=input2, data_sample=data_sample2),
        ])
        batch_inputs = collated['inputs']
        batch_data_sample = collated['data_sample']
        # The inputs come back as a list holding the very same tensors.
        self.assertTrue(is_list_of(batch_inputs, torch.Tensor))
        self.assertIs(input1, batch_inputs[0])
        self.assertIs(input2, batch_inputs[1])
        # The data samples are returned by identity as well.
        self.assertIs(batch_data_sample[0], data_sample1)
        self.assertIs(batch_data_sample[1], data_sample2)

        # Case 3: a list of tuples, each tuple holding two nested dicts.
        def make_pair():
            # Build fresh dict objects on every call so that the two batch
            # items never share a container.
            return (
                dict(
                    inputs=input1,
                    data_sample=data_sample1,
                    value=1,
                    name='1',
                    nested=dict(data_sample=data_sample1)),
                dict(
                    inputs=input2,
                    data_sample=data_sample2,
                    value=2,
                    name='2',
                    nested=dict(data_sample=data_sample2)),
            )

        collated = pseudo_collate([make_pair(), make_pair()])
        group_a, group_b = collated[0], collated[1]
        inputs_a, inputs_b = group_a['inputs'], group_b['inputs']
        samples_a, samples_b = group_a['data_sample'], group_b['data_sample']
        values_a, values_b = group_a['value'], group_b['value']
        names_a, names_b = group_a['name'], group_b['name']
        nested_a, nested_b = group_a['nested'], group_b['nested']

        self.assertTrue(is_list_of(inputs_a, torch.Tensor))
        self.assertTrue(is_list_of(inputs_b, torch.Tensor))
        # Position i of the output tuple gathers the i-th dict of every
        # batch item, so both collected entries are the same object.
        identity_checks = (
            (inputs_a, input1),
            (inputs_b, input2),
            (samples_a, data_sample1),
            (samples_b, data_sample2),
        )
        for collected, original in identity_checks:
            self.assertIs(collected[0], original)
            self.assertIs(collected[1], original)
        self.assertEqual(values_a, [1, 1])
        self.assertEqual(values_b, [2, 2])
        self.assertEqual(names_a, ['1', '1'])
        self.assertEqual(names_b, ['2', '2'])
        # Nested dicts are collated key by key too.
        nested_checks = ((nested_a, data_sample1), (nested_b, data_sample2))
        for collected, original in nested_checks:
            self.assertIs(collected['data_sample'][0], original)
            self.assertIs(collected['data_sample'][1], original)

    def test_default_collate(self):
        """Tensors must be stacked; ints and arrays must become tensors."""
        # `default_collate` has common logic with `pseudo_collate`, therefore
        # only test that it can stack batch tensors and convert Python ints
        # and numpy arrays to tensors.
        input1 = torch.randn(1, 3, 5)
        input2 = torch.randn(1, 3, 5)

        def make_pair():
            # Fresh dicts and arrays for every batch item.
            return (
                dict(inputs=input1, value=1, array=np.array(1)),
                dict(inputs=input2, value=2, array=np.array(2)),
            )

        def stacked_scalar(number):
            # Expected result of collating the same scalar twice.
            return torch.stack([torch.tensor(number), torch.tensor(number)])

        collated = default_collate([make_pair(), make_pair()])
        group_a, group_b = collated[0], collated[1]
        stacked_a, stacked_b = group_a['inputs'], group_b['inputs']
        values_a, values_b = group_a['value'], group_b['value']
        arrays_a, arrays_b = group_a['array'], group_b['array']

        # Two (1, 3, 5) tensors are stacked along a new leading batch axis.
        expected_shape = (2, 1, 3, 5)
        self.assertEqual(tuple(stacked_a.shape), expected_shape)
        self.assertEqual(tuple(stacked_b.shape), expected_shape)
        self.assertTrue(
            torch.allclose(stacked_a, torch.stack([input1, input1])))
        self.assertTrue(
            torch.allclose(stacked_b, torch.stack([input2, input2])))
        self.assertTrue(torch.allclose(values_a, stacked_scalar(1)))
        self.assertTrue(torch.allclose(values_b, stacked_scalar(2)))
        self.assertTrue(torch.allclose(arrays_a, stacked_scalar(1)))
        self.assertTrue(torch.allclose(arrays_b, stacked_scalar(2)))