# Source: mmengine/tests/test_structures/test_instance_data.py
# (file-viewer metadata: 224 lines, 7.6 KiB, Python)

# Copyright (c) OpenMMLab. All rights reserved.
import itertools
import random
from unittest import TestCase
import numpy as np
import pytest
import torch
from mmengine.structures import BaseDataElement, InstanceData
class TmpObject:
    """Minimal list-of-lists container that supports indexing and ``cat``.

    Used by the tests in this module as a custom (non-tensor, non-ndarray)
    data field: it provides ``__len__``, ``__getitem__`` and a static
    ``cat`` method.

    Args:
        tmp (list): A list whose every element is itself a list.
    """

    def __init__(self, tmp) -> None:
        assert isinstance(tmp, list)
        # Every row must itself be a list; an empty ``tmp`` is accepted.
        for t in tmp:
            assert isinstance(t, list)
        self.tmp = tmp

    def __len__(self):
        """Return the number of rows."""
        return len(self.tmp)

    def __getitem__(self, item):
        """Return a new ``TmpObject`` holding the selected row(s).

        An integer index yields a one-row ``TmpObject`` rather than the bare
        row, so the result always has the same type as ``self``.

        Raises:
            IndexError: If ``item`` is an int outside ``[-len, len)``.
        """
        if isinstance(item, int):
            if item >= len(self) or item < -len(self):  # type:ignore
                raise IndexError(f'Index {item} out of range!')
            # keep the dimension: with a step of ``len(self)`` the slice
            # selects exactly one element, starting at ``item``.
            item = slice(item, None, len(self))
        return TmpObject(self.tmp[item])

    @staticmethod
    def cat(tmp_objs):
        """Concatenate the rows of several ``TmpObject`` instances.

        Args:
            tmp_objs (list[TmpObject]): Objects to concatenate, in order.

        Returns:
            TmpObject: ``tmp_objs[0]`` itself when only one object is given,
            otherwise a new object holding all rows.
        """
        assert all(isinstance(results, TmpObject) for results in tmp_objs)
        if len(tmp_objs) == 1:
            return tmp_objs[0]
        tmp_list = [tmp_obj.tmp for tmp_obj in tmp_objs]
        tmp_list = list(itertools.chain(*tmp_list))
        new_data = TmpObject(tmp_list)
        return new_data

    def __repr__(self):
        return str(self.tmp)
class TmpObjectWithoutCat:
    """List-of-lists container that supports indexing but has no ``cat``.

    It offers the same ``__len__``/``__getitem__`` behaviour as a sliceable
    container but intentionally defines no ``cat`` method; the tests in this
    module use it to check the errors raised for such field types.

    Args:
        tmp (list): A list whose every element is itself a list.
    """

    def __init__(self, tmp) -> None:
        assert isinstance(tmp, list)
        # Every row must itself be a list; an empty ``tmp`` is accepted.
        for t in tmp:
            assert isinstance(t, list)
        self.tmp = tmp

    def __len__(self):
        """Return the number of rows."""
        return len(self.tmp)

    def __getitem__(self, item):
        """Return a new ``TmpObjectWithoutCat`` holding the selected row(s).

        An integer index yields a one-row object rather than the bare row.

        Raises:
            IndexError: If ``item`` is an int outside ``[-len, len)``.
        """
        if isinstance(item, int):
            if item >= len(self) or item < -len(self):  # type:ignore
                raise IndexError(f'Index {item} out of range!')
            # keep the dimension: with a step of ``len(self)`` the slice
            # selects exactly one element, starting at ``item``.
            item = slice(item, None, len(self))
        return TmpObjectWithoutCat(self.tmp[item])

    def __repr__(self):
        return str(self.tmp)
class TestInstanceData(TestCase):
    """Tests for ``InstanceData``: field assignment, indexing, ``cat`` and
    ``len``."""

    def setup_data(self):
        """Build an ``InstanceData`` whose data fields all have length 5.

        The fields cover a tensor (``bboxes``), an ndarray (``labels``), a
        list (``kps``, ``instances_infos``), a tuple (``ids``), a str
        (``name_ids``) and a custom ``TmpObject`` (``polygons``).

        Returns:
            InstanceData: The freshly constructed instance.
        """
        metainfo = dict(
            img_id=random.randint(0, 100),
            img_shape=(random.randint(400, 600), random.randint(400, 600)))
        instances_infos = [1] * 5
        bboxes = torch.rand((5, 4))
        labels = np.random.rand(5)
        kps = [[1, 1], [2, 2], [3, 3], [4, 4], [5, 5]]
        ids = (1, 2, 3, 4, 5)
        name_ids = '12345'
        polygons = TmpObject(np.arange(25).reshape((5, -1)).tolist())
        instance_data = InstanceData(
            metainfo=metainfo,
            bboxes=bboxes,
            labels=labels,
            polygons=polygons,
            kps=kps,
            ids=ids,
            name_ids=name_ids,
            instances_infos=instances_infos)
        return instance_data

    def test_set_data(self):
        """Setting private fields or a wrong-length field must fail."""
        instance_data = self.setup_data()

        # test set '_metainfo_fields' or '_data_fields'
        with self.assertRaises(AttributeError):
            instance_data._metainfo_fields = 1
        with self.assertRaises(AttributeError):
            instance_data._data_fields = 1

        # The data length in InstanceData must be the same
        with self.assertRaises(AssertionError):
            instance_data.keypoints = torch.rand((17, 2))

        instance_data.keypoints = torch.rand((5, 2))
        assert 'keypoints' in instance_data

    def test_getitem(self):
        """Index with int, slice, tensors, lists and ndarrays."""
        instance_data = InstanceData()
        # length must be greater than 0
        with self.assertRaises(IndexError):
            instance_data[1]

        instance_data = self.setup_data()
        assert len(instance_data) == 5
        slice_instance_data = instance_data[:2]
        assert len(slice_instance_data) == 2
        slice_instance_data = instance_data[1]
        assert len(slice_instance_data) == 1

        # assert the index should in 0 ~ len(instance_data) -1
        with pytest.raises(IndexError):
            instance_data[5]

        # isinstance(str, slice, int, torch.LongTensor, torch.BoolTensor)
        item = torch.Tensor([1, 2, 3, 4])  # float
        with pytest.raises(AssertionError):
            instance_data[item]

        # when input is a bool tensor, the shape of
        # the input at index 0 should equal to
        # the value length in instance_data_field
        # (here the bool tensor has 4 elements while the fields have 5)
        with pytest.raises(AssertionError):
            instance_data[item.bool()]

        # test LongTensor
        long_tensor = torch.randint(5, (2, ))
        long_index_instance_data = instance_data[long_tensor]
        assert len(long_index_instance_data) == len(long_tensor)

        # test BoolTensor
        bool_tensor = torch.rand(5) > 0.5
        bool_index_instance_data = instance_data[bool_tensor]
        assert len(bool_index_instance_data) == bool_tensor.sum()
        # ``torch.rand`` never exceeds 1, so this mask is all False
        bool_tensor = torch.rand(5) > 1
        empty_instance_data = instance_data[bool_tensor]
        assert len(empty_instance_data) == bool_tensor.sum()

        # test list index
        list_index = [1, 2]
        list_index_instance_data = instance_data[list_index]
        assert len(list_index_instance_data) == len(list_index)

        # test list bool
        list_bool = [True, False, True, False, False]
        list_bool_instance_data = instance_data[list_bool]
        assert len(list_bool_instance_data) == 2

        # test numpy
        long_numpy = np.random.randint(5, size=2)
        long_numpy_instance_data = instance_data[long_numpy]
        assert len(long_numpy_instance_data) == len(long_numpy)

        bool_numpy = np.random.rand(5) > 0.5
        bool_numpy_instance_data = instance_data[bool_numpy]
        assert len(bool_numpy_instance_data) == bool_numpy.sum()

        # without cat: a field type lacking ``cat`` cannot be bool-indexed
        instance_data.polygons = TmpObjectWithoutCat(
            np.arange(25).reshape((5, -1)).tolist())
        bool_numpy = np.random.rand(5) > 0.5
        with pytest.raises(
                ValueError,
                match=('The type of `polygons` is '
                       f'`{type(instance_data.polygons)}`, '
                       'which has no attribute of `cat`, so it does not '
                       'support slice with `bool`')):
            instance_data[bool_numpy]

    def test_cat(self):
        """Concatenate instances and check the rejected inputs."""
        instance_data_1 = self.setup_data()
        instance_data_2 = self.setup_data()
        cat_instance_data = InstanceData.cat(
            [instance_data_1, instance_data_2])
        assert len(cat_instance_data) == 10

        # All inputs must be InstanceData
        instance_data_2 = BaseDataElement(
            bboxes=torch.rand((5, 4)), labels=torch.rand((5, )))
        with self.assertRaises(AssertionError):
            InstanceData.cat([instance_data_1, instance_data_2])

        # Input List length must be greater than 0
        with self.assertRaises(AssertionError):
            InstanceData.cat([])

        # Concatenating with an empty selection (all-False mask); this call
        # is only checked for not raising, its result is not asserted.
        instance_data_2 = instance_data_1.clone()
        instance_data_2 = instance_data_2[torch.zeros(5) > 0.5]
        cat_instance_data = InstanceData.cat(
            [instance_data_1, instance_data_2])

        cat_instance_data = InstanceData.cat([instance_data_1])
        assert len(cat_instance_data) == 5

        # test custom data cat
        instance_data_1.polygons = TmpObjectWithoutCat(
            np.arange(25).reshape((5, -1)).tolist())
        instance_data_2 = instance_data_1.clone()
        with pytest.raises(
                ValueError,
                match=('The type of `polygons` is '
                       f'`{type(instance_data_1.polygons)}` '
                       'which has no attribute of `cat`')):
            InstanceData.cat([instance_data_1, instance_data_2])

    def test_len(self):
        """``len`` is 5 for the fixture and 0 for an empty instance."""
        instance_data = self.setup_data()
        assert len(instance_data) == 5
        instance_data = InstanceData()
        assert len(instance_data) == 0