# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import numpy as np
import torch

from mmengine.dataset import default_collate, pseudo_collate
from mmengine.structures import BaseDataElement
from mmengine.utils import is_list_of


class TestDataUtils(TestCase):
    """Tests for the ``pseudo_collate`` and ``default_collate`` functions."""

    def test_pseudo_collate(self):
        # Case 1: every sample is a dict that only holds tensors. The
        # collated values must match the originals position by position.
        input1 = torch.randn(1, 3, 5)
        input2 = torch.randn(1, 3, 5)
        label1 = torch.randn(1)
        label2 = torch.randn(1)
        collated = pseudo_collate([
            dict(inputs=input1, data_sample=label1),
            dict(inputs=input2, data_sample=label2),
        ])
        tensor_checks = (
            (input1, 'inputs', 0),
            (input2, 'inputs', 1),
            (label1, 'data_sample', 0),
            (label2, 'data_sample', 1),
        )
        for expected, key, idx in tensor_checks:
            self.assertTrue(torch.allclose(expected, collated[key][idx]))

        # Case 2: every sample is a dict whose `data_sample` entry is a
        # `BaseDataElement`. The collated lists must hold the very same
        # objects that were passed in (identity, not just equality).
        data_sample1 = BaseDataElement(label=torch.tensor(1))
        data_sample2 = BaseDataElement(label=torch.tensor(1))
        collated = pseudo_collate([
            dict(inputs=input1, data_sample=data_sample1),
            dict(inputs=input2, data_sample=data_sample2),
        ])
        collated_inputs = collated['inputs']
        collated_samples = collated['data_sample']

        # check the collated inputs
        self.assertTrue(is_list_of(collated_inputs, torch.Tensor))
        for idx, tensor in enumerate((input1, input2)):
            self.assertIs(tensor, collated_inputs[idx])

        # check the collated data samples
        for idx, sample in enumerate((data_sample1, data_sample2)):
            self.assertIs(collated_samples[idx], sample)

        # Case 3: every sample is a tuple of two dicts, each of which also
        # carries plain Python values and a nested dict.
        def make_sample():
            # Build fresh dicts on every call so that the two samples in
            # the batch do not share any container object.
            return (
                dict(
                    inputs=input1,
                    data_sample=data_sample1,
                    value=1,
                    name='1',
                    nested=dict(data_sample=data_sample1)),
                dict(
                    inputs=input2,
                    data_sample=data_sample2,
                    value=2,
                    name='2',
                    nested=dict(data_sample=data_sample2)),
            )

        collated = pseudo_collate([make_sample(), make_sample()])

        # One row per tuple position: the collated dict found at that
        # position, followed by the tensor, data sample, value and name
        # that both samples of the batch stored there.
        expectations = (
            (collated[0], input1, data_sample1, 1, '1'),
            (collated[1], input2, data_sample2, 2, '2'),
        )
        batch_size = 2

        for batch, _, _, _, _ in expectations:
            self.assertTrue(is_list_of(batch['inputs'], torch.Tensor))
        for batch, tensor, _, _, _ in expectations:
            for idx in range(batch_size):
                self.assertIs(batch['inputs'][idx], tensor)
        for batch, _, sample, _, _ in expectations:
            for idx in range(batch_size):
                self.assertIs(batch['data_sample'][idx], sample)
        for batch, _, _, value, _ in expectations:
            self.assertEqual(batch['value'], [value, value])
        for batch, _, _, _, name in expectations:
            self.assertEqual(batch['name'], [name, name])
        for batch, _, sample, _, _ in expectations:
            for idx in range(batch_size):
                self.assertIs(batch['nested']['data_sample'][idx], sample)

    def test_default_collate(self):
        # `default_collate` has common logic with `pseudo_collate`, therefore
        # only test that it can stack batch tensors and convert Python ints
        # and numpy arrays to tensors.
        input1 = torch.randn(1, 3, 5)
        input2 = torch.randn(1, 3, 5)

        def make_sample():
            # A fresh tuple of two dicts for each sample in the batch.
            return (
                dict(inputs=input1, value=1, array=np.array(1)),
                dict(inputs=input2, value=2, array=np.array(2)),
            )

        collated = default_collate([make_sample(), make_sample()])

        # One row per tuple position: the collated dict found at that
        # position, plus the tensor and scalar both samples stored there.
        expectations = (
            (collated[0], input1, 1),
            (collated[1], input2, 2),
        )

        # Tensors gain a leading batch dimension of size two.
        for batch, _, _ in expectations:
            self.assertEqual(tuple(batch['inputs'].shape), (2, 1, 3, 5))
        for batch, tensor, _ in expectations:
            stacked = torch.stack([tensor, tensor])
            self.assertTrue(torch.allclose(batch['inputs'], stacked))

        # Python ints and numpy scalars are both compared against a
        # stacked tensor of the same scalar.
        for key in ('value', 'array'):
            for batch, _, scalar in expectations:
                stacked = torch.stack(
                    [torch.tensor(scalar),
                     torch.tensor(scalar)])
                self.assertTrue(torch.allclose(batch[key], stacked))