# Copyright (c) OpenMMLab. All rights reserved. import random from functools import partial from unittest import TestCase import numpy as np import pytest import torch from mmengine.data import BaseDataElement, BaseDataSample class TestBaseDataSample(TestCase): def setup_data(self): metainfo = dict( img_id=random.randint(0, 100), img_shape=(random.randint(400, 600), random.randint(400, 600))) gt_instances = BaseDataElement( data=dict(bboxes=torch.rand((5, 4)), labels=torch.rand((5, )))) pred_instances = BaseDataElement( data=dict(bboxes=torch.rand((5, 4)), scores=torch.rand((5, )))) data = dict(gt_instances=gt_instances, pred_instances=pred_instances) return metainfo, data def is_equal(self, x, y): assert type(x) == type(y) if isinstance( x, (int, float, str, list, tuple, dict, set, BaseDataElement)): return x == y elif isinstance(x, (torch.Tensor, np.ndarray)): return (x == y).all() def check_key_value(self, instances, metainfo=None, data=None): # check the existence of keys in metainfo, data, and instances if metainfo: for k, v in metainfo.items(): assert k in instances assert k in instances.keys() assert k in instances.metainfo_keys() assert k not in instances.data_keys() assert self.is_equal(instances.get(k), v) assert self.is_equal(getattr(instances, k), v) if data: for k, v in data.items(): assert k in instances assert k in instances.keys() assert k not in instances.metainfo_keys() assert k in instances.data_keys() assert self.is_equal(instances.get(k), v) assert self.is_equal(getattr(instances, k), v) def check_data_device(self, instances, device): # assert instances.device == device for v in instances.data_values(): if isinstance(v, torch.Tensor): assert v.device == torch.device(device) elif isinstance(v, (BaseDataSample, BaseDataElement)): self.check_data_device(v, device) def check_data_dtype(self, instances, dtype): for v in instances.data_values(): if isinstance(v, (torch.Tensor, np.ndarray)): assert isinstance(v, dtype) if isinstance(v, (BaseDataSample, BaseDataElement)): self.check_data_dtype(v, dtype) def check_requires_grad(self, instances): for v in instances.data_values(): if isinstance(v, torch.Tensor): assert v.requires_grad is False if isinstance(v, (BaseDataSample, BaseDataElement)): self.check_requires_grad(v) def test_init(self): # initialization with no data and metainfo metainfo, data = self.setup_data() instances = BaseDataSample() for k in metainfo: assert k not in instances assert instances.get(k, None) is None for k in data: assert k not in instances assert instances.get(k, 'abc') == 'abc' # initialization with kwargs metainfo, data = self.setup_data() instances = BaseDataSample(metainfo=metainfo, data=data) self.check_key_value(instances, metainfo, data) # initialization with args metainfo, data = self.setup_data() instances = BaseDataSample(metainfo, data) self.check_key_value(instances, metainfo, data) # initialization with args metainfo, data = self.setup_data() instances = BaseDataSample(metainfo=metainfo) self.check_key_value(instances, metainfo) instances = BaseDataSample(data=data) self.check_key_value(instances, data=data) def test_new(self): metainfo, data = self.setup_data() instances = BaseDataSample(metainfo=metainfo, data=data) # test new() with no arguments new_instances = instances.new() assert type(new_instances) == type(instances) # After deepcopy, the address of new data'element will be same as # origin, but when change new data' element will not effect the origin # element and will have new address _, data = self.setup_data() new_instances.set_data(data) assert not self.is_equal(new_instances.gt_instances, instances.gt_instances) self.check_key_value(new_instances, metainfo, data) # test new() with arguments metainfo, data = self.setup_data() new_instances = instances.new(metainfo=metainfo, data=data) assert type(new_instances) == type(instances) assert id(new_instances.gt_instances) != id(instances.gt_instances) _, new_data = self.setup_data() new_instances.set_data(new_data) assert id(new_instances.gt_instances) != id(data['gt_instances']) self.check_key_value(new_instances, metainfo, new_data) metainfo, data = self.setup_data() new_instances = instances.new(metainfo=metainfo) def test_set_metainfo(self): metainfo, _ = self.setup_data() instances = BaseDataSample() instances.set_metainfo(metainfo) self.check_key_value(instances, metainfo=metainfo) # test setting existing keys and new keys new_metainfo, _ = self.setup_data() new_metainfo.update(other=123) instances.set_metainfo(new_metainfo) self.check_key_value(instances, metainfo=new_metainfo) # test have the same key in data # TODO _, data = self.setup_data() instances = BaseDataSample(data=data) _, data = self.setup_data() with self.assertRaises(AttributeError): instances.set_metainfo(data) with self.assertRaises(AssertionError): instances.set_metainfo(123) def test_set_data(self): metainfo, data = self.setup_data() instances = BaseDataSample() instances.gt_instances = data['gt_instances'] instances.pred_instances = data['pred_instances'] self.check_key_value(instances, data=data) # a.xx only set data rather than metainfo instances.img_shape = metainfo['img_shape'] instances.img_id = metainfo['img_id'] self.check_key_value(instances, data=metainfo) # test can not set metainfo with `.` metainfo, data = self.setup_data() instances = BaseDataSample(metainfo, data) with self.assertRaises(AttributeError): instances.img_shape = metainfo['img_shape'] # test set '_metainfo_fields' or '_data_fields' with self.assertRaises(AttributeError): instances._metainfo_fields = 1 with self.assertRaises(AttributeError): instances._data_fields = 1 with self.assertRaises(AssertionError): instances.set_data(123) def test_delete_modify(self): metainfo, data = self.setup_data() instances = BaseDataSample(metainfo, data) new_metainfo, new_data = self.setup_data() instances.gt_instances = new_data['gt_instances'] instances.pred_instances = new_data['pred_instances'] # a.xx only set data rather than metainfo instances.set_metainfo(new_metainfo) self.check_key_value(instances, new_metainfo, new_data) assert not self.is_equal(instances.gt_instances, data['gt_instances']) assert not self.is_equal(instances.pred_instances, data['pred_instances']) assert not self.is_equal(instances.img_id, metainfo['img_id']) assert not self.is_equal(instances.img_shape, metainfo['img_shape']) del instances.gt_instances del instances.img_id assert not self.is_equal( instances.pop('pred_instances', None), data['pred_instances']) with self.assertRaises(AttributeError): del instances.pred_instances assert 'gt_instances' not in instances assert 'pred_instances' not in instances assert 'img_id' not in instances assert instances.pop('gt_instances', None) is None # test pop not exist key without default with self.assertRaises(KeyError): instances.pop('gt_instances') assert instances.pop('pred_instances', 'abcdef') == 'abcdef' assert instances.pop('img_id', None) is None # test pop not exist key without default with self.assertRaises(KeyError): instances.pop('img_id') assert instances.pop('img_shape') == new_metainfo['img_shape'] # test del '_metainfo_fields' or '_data_fields' with self.assertRaises(AttributeError): del instances._metainfo_fields with self.assertRaises(AttributeError): del instances._data_fields @pytest.mark.skipif( not torch.cuda.is_available(), reason='GPU is required!') def test_cuda(self): metainfo, data = self.setup_data() instances = BaseDataSample(metainfo, data) cuda_instances = instances.cuda() self.check_data_device(cuda_instances, 'cuda:0') # here we further test to convert from cuda to cpu cpu_instances = cuda_instances.cpu() self.check_data_device(cpu_instances, 'cpu') del cuda_instances cuda_instances = instances.to('cuda:0') self.check_data_device(cuda_instances, 'cuda:0') _, data = self.setup_data() instances = BaseDataSample(metainfo=data) cuda_instances = instances.cuda() self.check_data_device(cuda_instances, 'cuda:0') # here we further test to convert from cuda to cpu cpu_instances = cuda_instances.cpu() self.check_data_device(cpu_instances, 'cpu') del cuda_instances cuda_instances = instances.to('cuda:0') self.check_data_device(cuda_instances, 'cuda:0') def test_cpu(self): metainfo, data = self.setup_data() instances = BaseDataSample(metainfo, data) self.check_data_device(instances, 'cpu') cpu_instances = instances.cpu() # assert cpu_instances.device == 'cpu' self.check_data_device(cpu_instances, 'cpu') _, data = self.setup_data() instances = BaseDataSample(metainfo=data) self.check_data_device(instances, 'cpu') cpu_instances = instances.cpu() # assert cpu_instances.device == 'cpu' self.check_data_device(cpu_instances, 'cpu') def test_numpy_tensor(self): metainfo, data = self.setup_data() data.update(bboxes=torch.rand((5, 4))) instances = BaseDataSample(metainfo, data) np_instances = instances.numpy() self.check_data_dtype(np_instances, np.ndarray) tensor_instances = np_instances.to_tensor() self.check_data_dtype(tensor_instances, torch.Tensor) _, data = self.setup_data() data.update(bboxes=torch.rand((5, 4))) instances = BaseDataSample(metainfo=data) np_instances = instances.numpy() self.check_data_dtype(np_instances, np.ndarray) tensor_instances = np_instances.to_tensor() self.check_data_dtype(tensor_instances, torch.Tensor) def test_detach(self): metainfo, data = self.setup_data() instances = BaseDataSample(metainfo, data) instances.detach() self.check_requires_grad(instances) _, data = self.setup_data() instances = BaseDataSample(metainfo=data) instances.detach() self.check_requires_grad(instances) def test_repr(self): metainfo = dict(img_shape=(800, 1196, 3)) gt_instances = BaseDataElement( metainfo=metainfo, data=dict(det_labels=torch.LongTensor([0, 1, 2, 3]))) data = dict(gt_instances=gt_instances) sample = BaseDataSample(metainfo=metainfo, data=data) address = hex(id(sample)) address_gt_instances = hex(id(sample.gt_instances)) assert repr(sample) == (f'\n' f'\n) at {address}>') sample = BaseDataSample(data=metainfo, metainfo=data) address = hex(id(sample)) address_gt_instances = hex(id(sample.gt_instances)) assert repr(sample) == (f'\n' f'\n DATA FIELDS \n' f'img_shape: (800, 1196, 3) \n' f'\n) at {address}>') metainfo = dict(bboxes=torch.rand((5, 4))) sample = BaseDataSample(metainfo=metainfo) address = hex(id(sample)) assert repr(sample) == (f'') sample = BaseDataSample(data=metainfo) address = hex(id(sample)) assert repr(sample) == (f'') def test_set_get_fields(self): metainfo, data = self.setup_data() instances = BaseDataSample(metainfo) for key, value in data.items(): instances.set_field(name=key, value=value, dtype=BaseDataElement) self.check_key_value(instances, data=data) # test type check _, data = self.setup_data() instances = BaseDataSample() for key, value in data.items(): with self.assertRaises(AssertionError): instances.set_field( name=key, value=value, dtype=BaseDataSample) def test_del_field(self): metainfo, data = self.setup_data() instances = BaseDataSample(metainfo) for key, value in data.items(): instances.set_field(value=value, name=key, dtype=BaseDataElement) instances.del_field('gt_instances') instances.del_field('pred_instances') # gt_instance has been deleted, instances does not have the gt_instance with self.assertRaises(AttributeError): instances.del_field('gt_instances') assert 'gt_instances' not in instances assert 'pred_instances' not in instances def test_inheritance(self): class DetDataSample(BaseDataSample): proposals = property( fget=partial(BaseDataSample.get_field, name='_proposals'), fset=partial( BaseDataSample.set_field, name='_proposals', dtype=BaseDataElement), fdel=partial(BaseDataSample.del_field, name='_proposals'), doc='Region proposals of an image') gt_instances = property( fget=partial(BaseDataSample.get_field, name='_gt_instances'), fset=partial( BaseDataSample.set_field, name='_gt_instances', dtype=BaseDataElement), fdel=partial(BaseDataSample.del_field, name='_gt_instances'), doc='Ground truth instances of an image') pred_instances = property( fget=partial(BaseDataSample.get_field, name='_pred_instances'), fset=partial( BaseDataSample.set_field, name='_pred_instances', dtype=BaseDataElement), fdel=partial(BaseDataSample.del_field, name='_pred_instances'), doc='Predicted instances of an image') det_sample = DetDataSample() # test set proposals = BaseDataElement(data=dict(bboxes=torch.rand((5, 4)))) det_sample.proposals = proposals assert 'proposals' in det_sample # test get assert det_sample.proposals == proposals # test delete del det_sample.proposals assert 'proposals' not in det_sample # test the data whether meet the requirements with self.assertRaises(AssertionError): det_sample.proposals = torch.rand((5, 4)) def test_values(self): # test_metainfo_values metainfo, data = self.setup_data() instances = BaseDataSample(metainfo, data) assert len(instances.metainfo_values()) == len(metainfo.values()) # test_values assert len( instances.values()) == len(metainfo.values()) + len(data.values()) # test_data_values assert len(instances.data_values()) == len(data.values()) def test_keys(self): # test_metainfo_keys metainfo, data = self.setup_data() instances = BaseDataSample(metainfo, data) assert len(instances.metainfo_keys()) == len(metainfo.keys()) # test_keys assert len(instances.keys()) == len(data.keys()) + len(metainfo.keys()) # test_data_keys assert len(instances.data_keys()) == len(data.keys()) def test_items(self): # test_metainfo_items metainfo, data = self.setup_data() instances = BaseDataSample(metainfo, data) assert len(dict(instances.metainfo_items())) == len( dict(metainfo.items())) # test_items assert len(dict(instances.items())) == len(dict( metainfo.items())) + len(dict(data.items())) # test_data_items assert len(dict(instances.data_items())) == len(dict(data.items()))