mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Feature] Add InstanceData (#160)
* [Enhancement] refactor base data elment * fix comment * fix comment * fix pop not existing key without error * add instance_data * update * refine code * add refer Co-authored-by: liukuikun <641417025@qq.com>
This commit is contained in:
parent
1927bc7726
commit
dc594e75bf
@ -1,9 +1,10 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
from .base_data_element import BaseDataElement
|
from .base_data_element import BaseDataElement
|
||||||
|
from .instance_data import InstanceData
|
||||||
from .sampler import DefaultSampler, InfiniteSampler
|
from .sampler import DefaultSampler, InfiniteSampler
|
||||||
from .utils import pseudo_collate, worker_init_fn
|
from .utils import pseudo_collate, worker_init_fn
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'BaseDataElement', 'DefaultSampler', 'InfiniteSampler', 'worker_init_fn',
|
'BaseDataElement', 'DefaultSampler', 'InfiniteSampler', 'worker_init_fn',
|
||||||
'pseudo_collate'
|
'pseudo_collate', 'InstanceData'
|
||||||
]
|
]
|
||||||
|
206
mmengine/data/instance_data.py
Normal file
206
mmengine/data/instance_data.py
Normal file
@ -0,0 +1,206 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import itertools
|
||||||
|
from typing import List, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from .base_data_element import BaseDataElement
|
||||||
|
|
||||||
|
|
||||||
|
# Modified from
|
||||||
|
# https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/data_structures/instance_data.py # noqa
|
||||||
|
class InstanceData(BaseDataElement):
|
||||||
|
"""Data structure for instance-level annnotations or predictions.
|
||||||
|
|
||||||
|
Subclass of :class:`BaseDataElement`. All value in `data_fields`
|
||||||
|
should have the same length. This design refer to
|
||||||
|
https://github.com/facebookresearch/detectron2/blob/master/detectron2/structures/instances.py # noqa E501
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
>>> from mmengine.data import InstanceData
|
||||||
|
>>> import numpy as np
|
||||||
|
>>> img_meta = dict(img_shape=(800, 1196, 3), pad_shape=(800, 1216, 3))
|
||||||
|
>>> instance_data = InstanceData(metainfo=img_meta)
|
||||||
|
>>> 'img_shape' in instance_data
|
||||||
|
True
|
||||||
|
>>> instance_data.det_labels = torch.LongTensor([2, 3])
|
||||||
|
>>> instance_data["det_scores"] = torch.Tensor([0.8, 0.7])
|
||||||
|
>>> instance_data.bboxes = torch.rand((2, 4))
|
||||||
|
>>> len(instance_data)
|
||||||
|
4
|
||||||
|
>>> print(instance_data)
|
||||||
|
<InstanceData(
|
||||||
|
|
||||||
|
META INFORMATION
|
||||||
|
pad_shape: (800, 1196, 3)
|
||||||
|
img_shape: (800, 1216, 3)
|
||||||
|
|
||||||
|
DATA FIELDS
|
||||||
|
det_labels: tensor([2, 3])
|
||||||
|
det_scores: tensor([0.8, 0.7000])
|
||||||
|
bboxes: tensor([[0.4997, 0.7707, 0.0595, 0.4188],
|
||||||
|
[0.8101, 0.3105, 0.5123, 0.6263]])
|
||||||
|
) at 0x7fb492de6280>
|
||||||
|
>>> sorted_results = instance_data[instance_data.det_scores.sort().indices]
|
||||||
|
>>> sorted_results.det_scores
|
||||||
|
tensor([0.7000, 0.8000])
|
||||||
|
>>> print(instance_data[instance_data.det_scores > 0.75])
|
||||||
|
<InstanceData(
|
||||||
|
|
||||||
|
META INFORMATION
|
||||||
|
pad_shape: (800, 1216, 3)
|
||||||
|
img_shape: (800, 1196, 3)
|
||||||
|
|
||||||
|
DATA FIELDS
|
||||||
|
det_labels: tensor([0])
|
||||||
|
bboxes: tensor([[0.4997, 0.7707, 0.0595, 0.4188]])
|
||||||
|
det_scores: tensor([0.8000])
|
||||||
|
) at 0x7fb5cf6e2790>
|
||||||
|
>>> instance_data[instance_data.det_scores > 0.75].det_labels
|
||||||
|
tensor([0])
|
||||||
|
>>> instance_data[instance_data.det_scores > 0.75].det_scores
|
||||||
|
tensor([0.8000])
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __setattr__(self, name: str, value: Union[torch.Tensor, np.ndarray,
|
||||||
|
list]):
|
||||||
|
if name in ('_metainfo_fields', '_data_fields'):
|
||||||
|
if not hasattr(self, name):
|
||||||
|
super().__setattr__(name, value)
|
||||||
|
else:
|
||||||
|
raise AttributeError(
|
||||||
|
f'{name} has been used as a '
|
||||||
|
f'private attribute, which is immutable. ')
|
||||||
|
|
||||||
|
else:
|
||||||
|
assert isinstance(value, (torch.Tensor, np.ndarray, list)), \
|
||||||
|
f'Can set {type(value)}, only support' \
|
||||||
|
f' {(torch.Tensor, np.ndarray, list)}'
|
||||||
|
|
||||||
|
if len(self) > 0:
|
||||||
|
assert len(value) == len(self), f'the length of ' \
|
||||||
|
f'values {len(value)} is ' \
|
||||||
|
f'not consistent with' \
|
||||||
|
f' the length of this ' \
|
||||||
|
f':obj:`InstanceData` ' \
|
||||||
|
f'{len(self)} '
|
||||||
|
super().__setattr__(name, value)
|
||||||
|
|
||||||
|
def __getitem__(
|
||||||
|
self, item: Union[str, slice, int, torch.LongTensor, torch.BoolTensor]
|
||||||
|
) -> 'InstanceData':
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
item (str, obj:`slice`,
|
||||||
|
obj`torch.LongTensor`, obj:`torch.BoolTensor`):
|
||||||
|
get the corresponding values according to item.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
obj:`InstanceData`: Corresponding values.
|
||||||
|
"""
|
||||||
|
assert len(self) > 0, ' This is a empty instance'
|
||||||
|
|
||||||
|
assert isinstance(
|
||||||
|
item, (str, slice, int, torch.LongTensor, torch.BoolTensor))
|
||||||
|
|
||||||
|
if isinstance(item, str):
|
||||||
|
return getattr(self, item)
|
||||||
|
|
||||||
|
if type(item) == int:
|
||||||
|
if item >= len(self) or item < -len(self): # type:ignore
|
||||||
|
raise IndexError(f'Index {item} out of range!')
|
||||||
|
else:
|
||||||
|
# keep the dimension
|
||||||
|
item = slice(item, None, len(self))
|
||||||
|
|
||||||
|
new_data = self.new(data={})
|
||||||
|
if isinstance(item, torch.Tensor):
|
||||||
|
assert item.dim() == 1, 'Only support to get the' \
|
||||||
|
' values along the first dimension.'
|
||||||
|
if isinstance(item, torch.BoolTensor):
|
||||||
|
assert len(item) == len(self), f'The shape of the' \
|
||||||
|
f' input(BoolTensor)) ' \
|
||||||
|
f'{len(item)} ' \
|
||||||
|
f' does not match the shape ' \
|
||||||
|
f'of the indexed tensor ' \
|
||||||
|
f'in results_filed ' \
|
||||||
|
f'{len(self)} at ' \
|
||||||
|
f'first dimension. '
|
||||||
|
|
||||||
|
for k, v in self.items():
|
||||||
|
if isinstance(v, torch.Tensor):
|
||||||
|
new_data[k] = v[item]
|
||||||
|
elif isinstance(v, np.ndarray):
|
||||||
|
new_data[k] = v[item.cpu().numpy()]
|
||||||
|
elif isinstance(v, list):
|
||||||
|
r_list = []
|
||||||
|
# convert to indexes from boolTensor
|
||||||
|
if isinstance(item, torch.BoolTensor):
|
||||||
|
indexes = torch.nonzero(item).view(-1)
|
||||||
|
else:
|
||||||
|
indexes = item
|
||||||
|
for index in indexes:
|
||||||
|
r_list.append(v[index])
|
||||||
|
new_data[k] = r_list
|
||||||
|
else:
|
||||||
|
# item is a slice
|
||||||
|
for k, v in self.items():
|
||||||
|
new_data[k] = v[item]
|
||||||
|
return new_data # type:ignore
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def cat(instances_list: List['InstanceData']) -> 'InstanceData':
|
||||||
|
"""Concat the instances of all :obj:`InstanceData` in the list.
|
||||||
|
|
||||||
|
Note: To ensure that cat returns as expected, make sure that
|
||||||
|
all elements in the list must have exactly the same keys.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
instances_list (list[:obj:`InstanceData`]): A list
|
||||||
|
of :obj:`InstanceData`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
obj:`InstanceData`
|
||||||
|
"""
|
||||||
|
assert all(
|
||||||
|
isinstance(results, InstanceData) for results in instances_list)
|
||||||
|
assert len(instances_list) > 0
|
||||||
|
if len(instances_list) == 1:
|
||||||
|
return instances_list[0]
|
||||||
|
|
||||||
|
# metainfo and data_fields must be exactly the
|
||||||
|
# same for each element to avoid exceptions.
|
||||||
|
field_keys_list = [
|
||||||
|
instances.all_keys() for instances in instances_list
|
||||||
|
]
|
||||||
|
assert len(set([len(field_keys) for field_keys in field_keys_list])) \
|
||||||
|
== 1 and len(set(itertools.chain(*field_keys_list))) \
|
||||||
|
== len(field_keys_list[0]), 'There are different keys in ' \
|
||||||
|
'`instances_list`, which may ' \
|
||||||
|
'cause the cat operation ' \
|
||||||
|
'to fail. Please make sure all ' \
|
||||||
|
'elements in `instances_list` ' \
|
||||||
|
'have the exact same key '
|
||||||
|
|
||||||
|
new_data = instances_list[0].new(data={})
|
||||||
|
for k in instances_list[0].keys():
|
||||||
|
values = [results[k] for results in instances_list]
|
||||||
|
v0 = values[0]
|
||||||
|
if isinstance(v0, torch.Tensor):
|
||||||
|
values = torch.cat(values, dim=0)
|
||||||
|
elif isinstance(v0, np.ndarray):
|
||||||
|
values = np.concatenate(values, axis=0)
|
||||||
|
elif isinstance(v0, list):
|
||||||
|
values = list(itertools.chain(*values))
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f'Can not concat the {k} which is a {type(v0)}')
|
||||||
|
new_data[k] = values
|
||||||
|
return new_data # type:ignore
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
if len(self._data_fields) > 0:
|
||||||
|
return len(self.values()[0])
|
||||||
|
else:
|
||||||
|
return 0
|
105
tests/test_data/test_instance_data.py
Normal file
105
tests/test_data/test_instance_data.py
Normal file
@ -0,0 +1,105 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import random
|
||||||
|
from unittest import TestCase
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from mmengine.data import BaseDataElement, InstanceData
|
||||||
|
|
||||||
|
|
||||||
|
class TestInstanceData(TestCase):
|
||||||
|
|
||||||
|
def setup_data(self):
|
||||||
|
metainfo = dict(
|
||||||
|
img_id=random.randint(0, 100),
|
||||||
|
img_shape=(random.randint(400, 600), random.randint(400, 600)))
|
||||||
|
instances_infos = [1] * 5
|
||||||
|
bboxes = torch.rand((5, 4))
|
||||||
|
labels = np.random.rand(5)
|
||||||
|
instance_data = InstanceData(
|
||||||
|
metainfo=metainfo,
|
||||||
|
bboxes=bboxes,
|
||||||
|
labels=labels,
|
||||||
|
instances_infos=instances_infos)
|
||||||
|
return instance_data
|
||||||
|
|
||||||
|
def test_set_data(self):
|
||||||
|
instance_data = self.setup_data()
|
||||||
|
|
||||||
|
# test set '_metainfo_fields' or '_data_fields'
|
||||||
|
with self.assertRaises(AttributeError):
|
||||||
|
instance_data._metainfo_fields = 1
|
||||||
|
with self.assertRaises(AttributeError):
|
||||||
|
instance_data._data_fields = 1
|
||||||
|
|
||||||
|
# value only supports (torch.Tensor, np.ndarray, list)
|
||||||
|
with self.assertRaises(AssertionError):
|
||||||
|
instance_data.v = 'value'
|
||||||
|
|
||||||
|
# The data length in InstanceData must be the same
|
||||||
|
with self.assertRaises(AssertionError):
|
||||||
|
instance_data.keypoints = torch.rand((17, 2))
|
||||||
|
|
||||||
|
instance_data.keypoints = torch.rand((5, 2))
|
||||||
|
assert 'keypoints' in instance_data
|
||||||
|
|
||||||
|
def test_getitem(self):
|
||||||
|
instance_data = InstanceData()
|
||||||
|
# length must be greater than 0
|
||||||
|
with self.assertRaises(AssertionError):
|
||||||
|
instance_data[1]
|
||||||
|
|
||||||
|
instance_data = self.setup_data()
|
||||||
|
assert len(instance_data) == 5
|
||||||
|
slice_instance_data = instance_data[:2]
|
||||||
|
assert len(slice_instance_data) == 2
|
||||||
|
|
||||||
|
# assert the index should in 0 ~ len(instance_data) -1
|
||||||
|
with pytest.raises(IndexError):
|
||||||
|
instance_data[5]
|
||||||
|
|
||||||
|
# isinstance(str, slice, int, torch.LongTensor, torch.BoolTensor)
|
||||||
|
item = torch.Tensor([1, 2, 3, 4]) # float
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
instance_data[item]
|
||||||
|
|
||||||
|
# when input is a bool tensor, The shape of
|
||||||
|
# the input at index 0 should equal to
|
||||||
|
# the value length in instance_data_field
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
instance_data[item.bool()]
|
||||||
|
|
||||||
|
# test Longtensor
|
||||||
|
long_tensor = torch.randint(5, (2, ))
|
||||||
|
long_index_instance_data = instance_data[long_tensor]
|
||||||
|
assert len(long_index_instance_data) == len(long_tensor)
|
||||||
|
|
||||||
|
# test bool tensor
|
||||||
|
bool_tensor = torch.rand(5) > 0.5
|
||||||
|
bool_index_instance_data = instance_data[bool_tensor]
|
||||||
|
assert len(bool_index_instance_data) == bool_tensor.sum()
|
||||||
|
|
||||||
|
def test_cat(self):
|
||||||
|
instance_data_1 = self.setup_data()
|
||||||
|
instance_data_2 = self.setup_data()
|
||||||
|
cat_instance_data = InstanceData.cat(
|
||||||
|
[instance_data_1, instance_data_2])
|
||||||
|
assert len(cat_instance_data) == 10
|
||||||
|
|
||||||
|
# All inputs must be InstanceData
|
||||||
|
instance_data_2 = BaseDataElement(
|
||||||
|
bboxes=torch.rand((5, 4)), labels=torch.rand((5, )))
|
||||||
|
with self.assertRaises(AssertionError):
|
||||||
|
InstanceData.cat([instance_data_1, instance_data_2])
|
||||||
|
|
||||||
|
# Input List length must be greater than 0
|
||||||
|
with self.assertRaises(AssertionError):
|
||||||
|
InstanceData.cat([])
|
||||||
|
|
||||||
|
def test_len(self):
|
||||||
|
instance_data = self.setup_data()
|
||||||
|
assert len(instance_data) == 5
|
||||||
|
instance_data = InstanceData()
|
||||||
|
assert len(instance_data) == 0
|
Loading…
x
Reference in New Issue
Block a user