mmengine/tests/test_data/test_data_utils.py
Mashiro 8770c6c7fc
[Refactor] Refactor data flow to make the interface more natural (#468)
* [Refactor]: modify interface of Visualizer.add_datasample (#365)

* [Refactor] Refactor data flow: refine `data_preprocessor`. (#359)

* refine data_preprocessor

* remove unused BATCH_DATA alias

* Fix type hints

* rename move_data to cast_data

* [Refactor] Refactor data flow: collate data in `collate_fn` of `DataLoader`  (#323)

* collate data in dataloader

* fix docstring

* refine comment

* fix as comment

* refactor default collate and pseudo collate

* format test file

* fix docstring

* fix as comment

* rename elem to data_item

* minor fix

* fix as comment

* [Refactor] Refactor data flow: `data_batch` argument of `Evaluator.process` is a `dict` (#360)

* refine evaluator and metric

* compatible with new default collate

* replace default collate with pseudo

* Handle data_batch in metric

* fix unit test

* fix unit test

* fix unit test

* minor refine

* make data_batch optional

make data_batch optional

* rename outputs to predictions

* fix ut

* rename predictions to outputs

* fix docstring

* fix docstring

* fix unit test

* make outputs and data_batch to kwargs

* fix unit test

* keep signature of metric

* fix ut

* rename pred_sample arguments to data_sample(Visualizer)

* fix loop and ut

* [refactor]: Refactor model dataflow (#398)

* [Refactor] Refactor data flow: refine `data_preprocessor`. (#359)

* refine data_preprocessor

* remove unused BATCH_DATA alias

* Fix type hints

* rename move_data to cast_data

* refactor model data flow

tmp_commt

tmp commit

* make val_cfg and test_cfg optional

* roll back runner

* pass test mmdet

* fix as comment

fix as comment

fix ci in DataPreprocessor

* fix ut

* fix ut

* fix rebase main

* [Fix]: Fix test val ddp (#462)

* [Fix] Fix docstring and type hint of data flow (#463)

* Fix docstring of data flow

* change signature of hook

* fix unit test

* resolve conflicts

* fix lint
2022-08-24 22:04:55 +08:00

156 lines
6.4 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import numpy as np
import torch
from mmengine.dataset import default_collate, pseudo_collate
from mmengine.structures import BaseDataElement
from mmengine.utils import is_list_of
class TestDataUtils(TestCase):
    """Unit tests for the ``pseudo_collate`` and ``default_collate`` helpers."""

    def test_pseudo_collate(self):
        """Check that ``pseudo_collate`` groups samples per key.

        The identity assertions below verify that the collated containers
        hold the very same objects that were put into the samples.
        """
        # Case 1: a list of dicts whose values are plain tensors.
        input1 = torch.randn(1, 3, 5)
        input2 = torch.randn(1, 3, 5)
        label1 = torch.randn(1)
        label2 = torch.randn(1)
        collated = pseudo_collate([
            dict(inputs=input1, data_sample=label1),
            dict(inputs=input2, data_sample=label2),
        ])
        for key, originals in (('inputs', (input1, input2)),
                               ('data_sample', (label1, label2))):
            for idx, original in enumerate(originals):
                self.assertTrue(torch.allclose(original, collated[key][idx]))

        # Case 2: a list of dicts where `data_sample` is a `BaseDataElement`.
        sample1 = BaseDataElement(label=torch.tensor(1))
        sample2 = BaseDataElement(label=torch.tensor(1))
        collated = pseudo_collate([
            dict(inputs=input1, data_sample=sample1),
            dict(inputs=input2, data_sample=sample2),
        ])
        collated_inputs = collated['inputs']
        collated_samples = collated['data_sample']
        # The inputs must come back as a list holding the original tensors.
        self.assertTrue(is_list_of(collated_inputs, torch.Tensor))
        self.assertIs(input1, collated_inputs[0])
        self.assertIs(input2, collated_inputs[1])
        # The data samples must be the original objects as well.
        self.assertIs(collated_samples[0], sample1)
        self.assertIs(collated_samples[1], sample2)

        # Case 3: a list of tuples, each tuple holding two nested dicts.
        def make_pair():
            # Build fresh dict objects on every call, so the two batch
            # entries do not share any container.
            return (
                dict(
                    inputs=input1,
                    data_sample=sample1,
                    value=1,
                    name='1',
                    nested=dict(data_sample=sample1)),
                dict(
                    inputs=input2,
                    data_sample=sample2,
                    value=2,
                    name='2',
                    nested=dict(data_sample=sample2)),
            )

        collated = pseudo_collate([make_pair(), make_pair()])
        # For every key, gather the collated value of tuple position 0 and 1.
        keys = ('inputs', 'data_sample', 'value', 'name', 'nested')
        per_key = {key: (collated[0][key], collated[1][key]) for key in keys}

        for batch in per_key['inputs']:
            self.assertTrue(is_list_of(batch, torch.Tensor))
        # Both batch entries at one tuple position were built from the same
        # object, so each collated list must contain that object twice.
        for key, originals in (('inputs', (input1, input2)),
                               ('data_sample', (sample1, sample2))):
            for batch, original in zip(per_key[key], originals):
                self.assertIs(batch[0], original)
                self.assertIs(batch[1], original)
        self.assertEqual(per_key['value'][0], [1, 1])
        self.assertEqual(per_key['value'][1], [2, 2])
        self.assertEqual(per_key['name'][0], ['1', '1'])
        self.assertEqual(per_key['name'][1], ['2', '2'])
        for nested, original in zip(per_key['nested'], (sample1, sample2)):
            self.assertIs(nested['data_sample'][0], original)
            self.assertIs(nested['data_sample'][1], original)

    def test_default_collate(self):
        """Check that ``default_collate`` stacks tensors and converts numbers.

        ``default_collate`` shares common logic with ``pseudo_collate``, so
        this test only checks that it can stack batched tensors and turn
        Python ints and numpy arrays into tensors.
        """
        input1 = torch.randn(1, 3, 5)
        input2 = torch.randn(1, 3, 5)

        def make_pair():
            # Fresh dicts and fresh numpy arrays for every batch entry.
            return (
                dict(inputs=input1, value=1, array=np.array(1)),
                dict(inputs=input2, value=2, array=np.array(2)),
            )

        collated = default_collate([make_pair(), make_pair()])
        inputs_0, inputs_1 = collated[0]['inputs'], collated[1]['inputs']
        value_0, value_1 = collated[0]['value'], collated[1]['value']
        array_0, array_1 = collated[0]['array'], collated[1]['array']

        # Two samples of shape (1, 3, 5) are stacked along a new first axis.
        for stacked in (inputs_0, inputs_1):
            self.assertEqual(tuple(stacked.shape), (2, 1, 3, 5))
        for stacked, original in ((inputs_0, input1), (inputs_1, input2)):
            self.assertTrue(
                torch.allclose(stacked, torch.stack([original, original])))

        # Scalars and 0-d arrays are compared against stacked 0-d tensors.
        for converted, number in ((value_0, 1), (value_1, 2),
                                  (array_0, 1), (array_1, 2)):
            expected = torch.stack(
                [torch.tensor(number), torch.tensor(number)])
            self.assertTrue(torch.allclose(converted, expected))