mmengine/tests/test_data/test_data_sample.py

299 lines
12 KiB
Python
Raw Normal View History

# Copyright (c) OpenMMLab. All rights reserved.
import random
from functools import partial
from unittest import TestCase
import numpy as np
import pytest
import torch
from mmengine.data import BaseDataElement, BaseDataSample
class TestBaseDataSample(TestCase):
def setup_data(self):
metainfo = dict(
img_id=random.randint(0, 100),
img_shape=(random.randint(400, 600), random.randint(400, 600)))
gt_instances = BaseDataElement(
data=dict(bboxes=torch.rand((5, 4)), labels=torch.rand((5, ))))
pred_instances = BaseDataElement(
data=dict(bboxes=torch.rand((5, 4)), scores=torch.rand((5, ))))
data = dict(gt_instances=gt_instances, pred_instances=pred_instances)
return metainfo, data
def check_key_value(self, instances, metainfo=None, data=None):
# check the existence of keys in metainfo, data, and instances
if metainfo:
for k, v in metainfo.items():
assert k in instances
assert k in instances.keys()
assert k in instances.metainfo_keys()
assert k not in instances.data_keys()
assert instances.get(k) == v
assert getattr(instances, k) == v
if data:
for k, v in data.items():
assert k in instances
assert k in instances.keys()
assert k not in instances.metainfo_keys()
assert k in instances.data_keys()
assert instances.get(k) == v
assert getattr(instances, k) == v
def check_data_device(self, instances, device):
assert instances.device == device
for v in instances.data_values():
if isinstance(v, torch.Tensor):
assert v.device == device
elif isinstance(v, BaseDataElement):
self.check_data_device(v, device)
def check_data_dtype(self, instances, dtype):
for v in instances.data_values():
if isinstance(v, (torch.Tensor, np.ndarray)):
assert isinstance(v, dtype)
if isinstance(v, BaseDataElement):
self.check_data_dtype(v, dtype)
def test_init(self):
# initialization with no data and metainfo
metainfo, data = self.setup_data()
instances = BaseDataSample()
for k in metainfo:
assert k not in instances
assert instances.get(k, None) is None
for k in data:
assert k not in instances
assert instances.get(k, 'abc') == 'abc'
# initialization with kwargs
metainfo, data = self.setup_data()
instances = BaseDataSample(metainfo=metainfo, data=data)
self.check_key_value(instances, metainfo, data)
# initialization with args
metainfo, data = self.setup_data()
instances = BaseDataSample(metainfo, data)
self.check_key_value(instances, metainfo, data)
# initialization with args
metainfo, data = self.setup_data()
instances = BaseDataSample(metainfo=metainfo)
self.check_key_value(instances, metainfo)
instances = BaseDataSample(data=data)
self.check_key_value(instances, data=data)
def test_new(self):
metainfo, data = self.setup_data()
instances = BaseDataSample(metainfo=metainfo, data=data)
# test new() with no arguments
new_instances = instances.new()
assert type(new_instances) == type(instances)
assert id(new_instances.data) != id(instances.data)
assert id(new_instances.bboxes) != id(data)
self.check_key_value(new_instances, metainfo, data)
# test new() with arguments
metainfo, data = self.setup_data()
new_instances = instances.new(metainfo=metainfo, data=data)
assert type(new_instances) == type(instances)
assert id(new_instances.data) != id(instances.data)
assert id(new_instances.data) != id(data)
self.check_key_value(new_instances, metainfo, data)
def test_set_metainfo(self):
metainfo, _ = self.setup_data()
instances = BaseDataSample()
instances.set_metainfo(metainfo)
self.check_key_value(instances, metainfo=metainfo)
# test setting existing keys and new keys
new_metainfo, _ = self.setup_data()
new_metainfo.update(other=123)
instances.set_metainfo(new_metainfo)
self.check_key_value(instances, metainfo=new_metainfo)
def test_set_data(self):
metainfo, data = self.setup_data()
instances = BaseDataSample()
instances.gt_instances = data['gt_instances']
instances.pred_instances = data['pred_instances']
self.check_key_value(instances, data=data)
# a.xx only set data rather than metainfo
instances.img_shape = metainfo['img_shape']
instances.img_id = metainfo['img_id']
self.check_key_value(instances, data=metainfo)
def test_delete_modify(self):
metainfo, data = self.setup_data()
instances = BaseDataSample(metainfo, data)
new_metainfo, new_data = self.setup_data()
instances.gt_instances = new_data['gt_instances']
instances.pred_instances = new_data['pred_instances']
# a.xx only set data rather than metainfo
instances.set_metainfo(new_metainfo)
self.check_key_value(instances, new_metainfo, new_data)
assert instances.gt_instances != data['gt_instances']
assert instances.pred_instances != data['pred_instances']
assert instances.img_id != metainfo['img_id']
assert instances.img_shape != metainfo['img_shape']
del instances.gt_instances
assert instances.pop('pred_instances',
None) == new_data['pred_instances']
# pred_instances has been deleted,
# instances does not have the pred_instances
with self.assertRaises(AttributeError):
del instances.pred_instances
assert 'gt_instances' not in instances
assert 'pred_instances' not in instances
assert instances.pop('gt_instances', None) is None
assert instances.pop('pred_instances', 'abcdef') == 'abcdef'
@pytest.mark.skipif(
not torch.cuda.is_available(), reason='GPU is required!')
def test_cuda(self):
metainfo, data = self.setup_data()
instances = BaseDataSample(metainfo, data)
cuda_instances = instances.cuda()
self.check_data_device(instances, 'cuda:0')
# here we further test to convert from cuda to cpu
cpu_instances = cuda_instances.cpu()
self.check_data_device(cpu_instances, 'cpu')
del cuda_instances
cuda_instances = instances.to('cuda:0')
self.check_data_device(cuda_instances, 'cuda:0')
def test_cpu(self):
metainfo, data = self.setup_data()
instances = BaseDataSample(metainfo, data)
self.check_data_device(instances, 'cpu')
cpu_instances = instances.cpu()
assert cpu_instances.device == 'cpu'
assert cpu_instances.bboxes.device == 'cpu'
assert cpu_instances.scores.device == 'cpu'
def test_numpy_tensor(self):
metainfo, data = self.setup_data()
instances = BaseDataSample(metainfo, data)
np_instances = instances.numpy()
self.check_data_dtype(np_instances, np.ndarray)
tensor_instances = instances.to_tensor()
self.check_data_dtype(tensor_instances, torch.Tensor)
def test_repr(self):
metainfo = dict(img_shape=(800, 1196, 3), pad_shape=(800, 1216, 3))
gt_instances = BaseDataElement(
data=dict(
det_labels=torch.LongTensor([0, 1, 2, 3],
det_scores=torch.Tensor(
[0.01, 0.1, 0.2, 0.3]))))
data = dict(gt_instances=gt_instances)
instances = BaseDataSample(metainfo=metainfo, data=data)
assert repr(instances) == ('<BaseDataSample(\n'
' META INFORMATION\n'
'img_shape: (800, 1196, 3)\n'
'pad_shape: (800, 1216, 3)\n'
' DATA FIELDS\n'
'\tgt_instances: <BaseDataElement(\n'
'\t META INFORMATION\n'
'\timg_shape: (800, 1196, 3)\n'
'\tpad_shape: (800, 1216, 3)\n'
'\t DATA FIELDS\n'
'\tshape of det_labels: torch.Size([4])\n'
'\tshape of det_scores: torch.Size([4])\n'
'\t) at 0x7f84acd10f90>'
') at 0x7f84acd10f90>')
def test_set_get_fields(self):
metainfo, data = self.setup_data()
instances = BaseDataSample(metainfo)
for key, value in data.items():
instances._set_field(value, key, BaseDataElement)
self.check_key_value(instances, data=data)
# test type check
_, data = self.setup_data()
instances = BaseDataSample()
for key, value in data.items():
with self.assertRaises(AssertionError):
instances._set_field(value, key, BaseDataSample)
def test_del_field(self):
metainfo, data = self.setup_data()
instances = BaseDataSample(metainfo)
for key, value in data.items():
instances._set_field(value, key, BaseDataElement)
instances._del_field('gt_instances')
instances._del_field('pred_instances')
# gt_instance has been deleted, instances does not have the gt_instance
with self.assertRaises(AttributeError):
instances._del_field('gt_instances')
assert 'gt_instances' not in instances
assert 'pred_instances' not in instances
def test_inheritance(self):
class DetDataSample(BaseDataSample):
proposals = property(
fget=partial(BaseDataSample._get_field, name='_proposals'),
fset=partial(
BaseDataSample._set_field,
name='_proposals',
dtype=BaseDataElement),
fdel=partial(BaseDataSample._del_field, name='_proposals'),
doc='Region proposals of an image')
gt_instances = property(
fget=partial(BaseDataSample._get_field, name='_gt_instances'),
fset=partial(
BaseDataSample._set_field,
name='_gt_instances',
dtype=BaseDataElement),
fdel=partial(BaseDataSample._del_field, name='_gt_instances'),
doc='Ground truth instances of an image')
pred_instances = property(
fget=partial(
BaseDataSample._get_field, name='_pred_instances'),
fset=partial(
BaseDataSample._set_field,
name='_pred_instances',
dtype=BaseDataElement),
fdel=partial(
BaseDataSample._del_field, name='_pred_instances'),
doc='Predicted instances of an image')
det_sample = DetDataSample()
# test set
proposals = BaseDataElement(data=dict(bboxes=torch.rand((5, 4))))
det_sample.proposals = proposals
assert 'proposals' in det_sample
# test get
assert det_sample.proposals == proposals
# test delete
del det_sample.proposals
assert 'proposals' not in det_sample
# test the data whether meet the requirements
with self.assertRaises(AssertionError):
det_sample.proposals = torch.rand((5, 4))