mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Fix] Fix BaseDataPreprocessor.cast_data
could not handle string data (#602)
* [Fix] Fix could not handle string data * Minor refine * Refine type hint Refine type hint * fix as comment * Minor refine * Update mmengine/model/base_model/data_preprocessor.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
This commit is contained in:
parent
1bf5c0c12e
commit
d1dd240796
@ -11,7 +11,8 @@ from mmengine.structures import BaseDataElement
|
|||||||
from mmengine.utils import is_list_of
|
from mmengine.utils import is_list_of
|
||||||
from ..utils import stack_batch
|
from ..utils import stack_batch
|
||||||
|
|
||||||
CastData = Union[tuple, dict, BaseDataElement, torch.Tensor, list]
|
CastData = Union[tuple, dict, BaseDataElement, torch.Tensor, list, bytes, str,
|
||||||
|
None]
|
||||||
|
|
||||||
|
|
||||||
@MODELS.register_module()
|
@MODELS.register_module()
|
||||||
@ -48,17 +49,20 @@ class BaseDataPreprocessor(nn.Module):
|
|||||||
"""
|
"""
|
||||||
if isinstance(data, Mapping):
|
if isinstance(data, Mapping):
|
||||||
return {key: self.cast_data(data[key]) for key in data}
|
return {key: self.cast_data(data[key]) for key in data}
|
||||||
|
elif isinstance(data, (str, bytes)) or data is None:
|
||||||
|
return data
|
||||||
elif isinstance(data, tuple) and hasattr(data, '_fields'):
|
elif isinstance(data, tuple) and hasattr(data, '_fields'):
|
||||||
# namedtuple
|
# namedtuple
|
||||||
return type(data)(*(self.cast_data(sample) for sample in data)) # type: ignore # noqa: E501 # yapf:disable
|
return type(data)(*(self.cast_data(sample) for sample in data)) # type: ignore # noqa: E501 # yapf:disable
|
||||||
elif isinstance(data, Sequence):
|
elif isinstance(data, Sequence):
|
||||||
return [self.cast_data(sample) for sample in data]
|
return type(data)(self.cast_data(sample) for sample in data) # type: ignore # noqa: E501 # yapf:disable
|
||||||
elif isinstance(data, torch.Tensor):
|
elif isinstance(data, (torch.Tensor, BaseDataElement)):
|
||||||
return data.to(self.device, non_blocking=self._non_blocking)
|
|
||||||
elif isinstance(data, BaseDataElement):
|
|
||||||
return data.to(self.device, non_blocking=self._non_blocking)
|
return data.to(self.device, non_blocking=self._non_blocking)
|
||||||
else:
|
else:
|
||||||
return data
|
raise TypeError(
|
||||||
|
'`BaseDataPreprocessor.cast_data`: batch data must contain '
|
||||||
|
'tensors, numpy arrays, numbers, dicts or lists, but '
|
||||||
|
f'found {type(data)}')
|
||||||
|
|
||||||
def forward(self, data: dict, training: bool = False) -> Union[dict, list]:
|
def forward(self, data: dict, training: bool = False) -> Union[dict, list]:
|
||||||
"""Preprocesses the data into the model input format.
|
"""Preprocesses the data into the model input format.
|
||||||
|
@ -28,8 +28,8 @@ class TestBaseDataPreprocessor(TestCase):
|
|||||||
label1 = torch.randn(1)
|
label1 = torch.randn(1)
|
||||||
label2 = torch.randn(1)
|
label2 = torch.randn(1)
|
||||||
|
|
||||||
|
# Test with dict of batch inputs and batch data samples
|
||||||
data = dict(inputs=[input1, input2], data_sample=[label1, label2])
|
data = dict(inputs=[input1, input2], data_sample=[label1, label2])
|
||||||
|
|
||||||
output = base_data_preprocessor(data)
|
output = base_data_preprocessor(data)
|
||||||
batch_inputs, batch_labels = output['inputs'], output['data_sample']
|
batch_inputs, batch_labels = output['inputs'], output['data_sample']
|
||||||
self.assertTrue(torch.is_floating_point(batch_inputs[0]))
|
self.assertTrue(torch.is_floating_point(batch_inputs[0]))
|
||||||
@ -41,40 +41,54 @@ class TestBaseDataPreprocessor(TestCase):
|
|||||||
assert_allclose(label2, batch_labels[1])
|
assert_allclose(label2, batch_labels[1])
|
||||||
|
|
||||||
# Test with tuple of batch inputs and batch data samples
|
# Test with tuple of batch inputs and batch data samples
|
||||||
data = dict(
|
data = (torch.stack([input1, input2]), (label1, label2))
|
||||||
inputs=torch.stack([input1, input2]), data_sample=[label1, label2])
|
batch_inputs, batch_labels = base_data_preprocessor(data)
|
||||||
output = base_data_preprocessor(data)['inputs']
|
self.assertTrue(torch.is_floating_point(batch_inputs))
|
||||||
|
self.assertEqual(batch_inputs[0].shape, (1, 3, 5))
|
||||||
|
self.assertEqual(batch_inputs[1].shape, (1, 3, 5))
|
||||||
self.assertTrue(torch.is_floating_point(batch_inputs[0]))
|
self.assertTrue(torch.is_floating_point(batch_inputs[0]))
|
||||||
|
|
||||||
# Test cuda forward
|
# Test cuda forward
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
# Test with list of data samples.
|
# Test with list of data samples.
|
||||||
|
data = dict(inputs=[input1, input2], data_sample=[label1, label2])
|
||||||
base_data_preprocessor = base_data_preprocessor.cuda()
|
base_data_preprocessor = base_data_preprocessor.cuda()
|
||||||
output = base_data_preprocessor(data)
|
output = base_data_preprocessor(data)
|
||||||
batch_inputs, batch_labels = output['inputs'], output[
|
batch_inputs, batch_labels = output['inputs'], output[
|
||||||
'data_sample']
|
'data_sample']
|
||||||
self.assertTrue(torch.is_floating_point(batch_inputs))
|
self.assertTrue(torch.is_floating_point(batch_inputs[0]))
|
||||||
self.assertEqual(batch_inputs.device.type, 'cuda')
|
self.assertEqual(batch_inputs[0].device.type, 'cuda')
|
||||||
|
|
||||||
|
# Fallback to test with cpu.
|
||||||
base_data_preprocessor = base_data_preprocessor.cpu()
|
base_data_preprocessor = base_data_preprocessor.cpu()
|
||||||
output = base_data_preprocessor(data)
|
output = base_data_preprocessor(data)
|
||||||
batch_inputs, batch_labels = output['inputs'], output[
|
batch_inputs, batch_labels = output['inputs'], output[
|
||||||
'data_sample']
|
'data_sample']
|
||||||
self.assertTrue(torch.is_floating_point(batch_inputs))
|
self.assertTrue(torch.is_floating_point(batch_inputs[0]))
|
||||||
self.assertEqual(batch_inputs.device.type, 'cpu')
|
self.assertEqual(batch_inputs[0].device.type, 'cpu')
|
||||||
|
|
||||||
|
# Test `base_data_preprocessor` can be moved to cuda again.
|
||||||
base_data_preprocessor = base_data_preprocessor.to('cuda:0')
|
base_data_preprocessor = base_data_preprocessor.to('cuda:0')
|
||||||
output = base_data_preprocessor(data)
|
output = base_data_preprocessor(data)
|
||||||
batch_inputs, batch_labels = output['inputs'], output[
|
batch_inputs, batch_labels = output['inputs'], output[
|
||||||
'data_sample']
|
'data_sample']
|
||||||
self.assertTrue(torch.is_floating_point(batch_inputs))
|
self.assertTrue(torch.is_floating_point(batch_inputs[0]))
|
||||||
self.assertEqual(batch_inputs.device.type, 'cuda')
|
self.assertEqual(batch_inputs[0].device.type, 'cuda')
|
||||||
|
|
||||||
# device of `base_data_preprocessor` is cuda, output should be
|
# device of `base_data_preprocessor` is cuda, output should be
|
||||||
# cuda tensor.
|
# cuda tensor.
|
||||||
self.assertEqual(batch_inputs.device.type, 'cuda')
|
self.assertEqual(batch_inputs[0].device.type, 'cuda')
|
||||||
self.assertEqual(batch_labels[0].device.type, 'cuda')
|
self.assertEqual(batch_labels[0].device.type, 'cuda')
|
||||||
|
|
||||||
|
# Test forward with string value
|
||||||
|
data = dict(string='abc')
|
||||||
|
base_data_preprocessor(data)
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(TypeError,
|
||||||
|
'`BaseDataPreprocessor.cast_data`:'):
|
||||||
|
data = dict(string=object())
|
||||||
|
base_data_preprocessor(data)
|
||||||
|
|
||||||
|
|
||||||
class TestImgDataPreprocessor(TestBaseDataPreprocessor):
|
class TestImgDataPreprocessor(TestBaseDataPreprocessor):
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user