[Fix] Fix `BaseDataPreprocessor.cast_data` could not handle string data (#602)
* [Fix] Fix could not handle string data * Minor refine * Refine type hint Refine type hint * fix as comment * Minor refine * Update mmengine/model/base_model/data_preprocessor.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> pull/650/head
parent
1bf5c0c12e
commit
d1dd240796
|
@ -11,7 +11,8 @@ from mmengine.structures import BaseDataElement
|
|||
from mmengine.utils import is_list_of
|
||||
from ..utils import stack_batch
|
||||
|
||||
CastData = Union[tuple, dict, BaseDataElement, torch.Tensor, list]
|
||||
CastData = Union[tuple, dict, BaseDataElement, torch.Tensor, list, bytes, str,
|
||||
None]
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
|
@ -48,17 +49,20 @@ class BaseDataPreprocessor(nn.Module):
|
|||
"""
|
||||
if isinstance(data, Mapping):
|
||||
return {key: self.cast_data(data[key]) for key in data}
|
||||
elif isinstance(data, (str, bytes)) or data is None:
|
||||
return data
|
||||
elif isinstance(data, tuple) and hasattr(data, '_fields'):
|
||||
# namedtuple
|
||||
return type(data)(*(self.cast_data(sample)for sample in data)) # type: ignore # noqa: E501 # yapf:disable
|
||||
return type(data)(*(self.cast_data(sample) for sample in data)) # type: ignore # noqa: E501 # yapf:disable
|
||||
elif isinstance(data, Sequence):
|
||||
return [self.cast_data(sample) for sample in data]
|
||||
elif isinstance(data, torch.Tensor):
|
||||
return data.to(self.device, non_blocking=self._non_blocking)
|
||||
elif isinstance(data, BaseDataElement):
|
||||
return type(data)(self.cast_data(sample) for sample in data) # type: ignore # noqa: E501 # yapf:disable
|
||||
elif isinstance(data, (torch.Tensor, BaseDataElement)):
|
||||
return data.to(self.device, non_blocking=self._non_blocking)
|
||||
else:
|
||||
return data
|
||||
raise TypeError(
|
||||
'`BaseDataPreprocessor.cast_data`: batch data must contain '
|
||||
'tensors, numpy arrays, numbers, dicts or lists, but '
|
||||
f'found {type(data)}')
|
||||
|
||||
def forward(self, data: dict, training: bool = False) -> Union[dict, list]:
|
||||
"""Preprocesses the data into the model input format.
|
||||
|
|
|
@ -28,8 +28,8 @@ class TestBaseDataPreprocessor(TestCase):
|
|||
label1 = torch.randn(1)
|
||||
label2 = torch.randn(1)
|
||||
|
||||
# Test with dict of batch inputs and batch data samples
|
||||
data = dict(inputs=[input1, input2], data_sample=[label1, label2])
|
||||
|
||||
output = base_data_preprocessor(data)
|
||||
batch_inputs, batch_labels = output['inputs'], output['data_sample']
|
||||
self.assertTrue(torch.is_floating_point(batch_inputs[0]))
|
||||
|
@ -41,40 +41,54 @@ class TestBaseDataPreprocessor(TestCase):
|
|||
assert_allclose(label2, batch_labels[1])
|
||||
|
||||
# Test with tuple of batch inputs and batch data samples
|
||||
data = dict(
|
||||
inputs=torch.stack([input1, input2]), data_sample=[label1, label2])
|
||||
output = base_data_preprocessor(data)['inputs']
|
||||
data = (torch.stack([input1, input2]), (label1, label2))
|
||||
batch_inputs, batch_labels = base_data_preprocessor(data)
|
||||
self.assertTrue(torch.is_floating_point(batch_inputs))
|
||||
self.assertEqual(batch_inputs[0].shape, (1, 3, 5))
|
||||
self.assertEqual(batch_inputs[1].shape, (1, 3, 5))
|
||||
self.assertTrue(torch.is_floating_point(batch_inputs[0]))
|
||||
|
||||
# Test cuda forward
|
||||
if torch.cuda.is_available():
|
||||
# Test with list of data samples.
|
||||
data = dict(inputs=[input1, input2], data_sample=[label1, label2])
|
||||
base_data_preprocessor = base_data_preprocessor.cuda()
|
||||
output = base_data_preprocessor(data)
|
||||
batch_inputs, batch_labels = output['inputs'], output[
|
||||
'data_sample']
|
||||
self.assertTrue(torch.is_floating_point(batch_inputs))
|
||||
self.assertEqual(batch_inputs.device.type, 'cuda')
|
||||
self.assertTrue(torch.is_floating_point(batch_inputs[0]))
|
||||
self.assertEqual(batch_inputs[0].device.type, 'cuda')
|
||||
|
||||
# Fallback to test with cpu.
|
||||
base_data_preprocessor = base_data_preprocessor.cpu()
|
||||
output = base_data_preprocessor(data)
|
||||
batch_inputs, batch_labels = output['inputs'], output[
|
||||
'data_sample']
|
||||
self.assertTrue(torch.is_floating_point(batch_inputs))
|
||||
self.assertEqual(batch_inputs.device.type, 'cpu')
|
||||
self.assertTrue(torch.is_floating_point(batch_inputs[0]))
|
||||
self.assertEqual(batch_inputs[0].device.type, 'cpu')
|
||||
|
||||
# Test `base_data_preprocessor` can be moved to cuda again.
|
||||
base_data_preprocessor = base_data_preprocessor.to('cuda:0')
|
||||
output = base_data_preprocessor(data)
|
||||
batch_inputs, batch_labels = output['inputs'], output[
|
||||
'data_sample']
|
||||
self.assertTrue(torch.is_floating_point(batch_inputs))
|
||||
self.assertEqual(batch_inputs.device.type, 'cuda')
|
||||
self.assertTrue(torch.is_floating_point(batch_inputs[0]))
|
||||
self.assertEqual(batch_inputs[0].device.type, 'cuda')
|
||||
|
||||
# device of `base_data_preprocessor` is cuda, output should be
|
||||
# cuda tensor.
|
||||
self.assertEqual(batch_inputs.device.type, 'cuda')
|
||||
self.assertEqual(batch_inputs[0].device.type, 'cuda')
|
||||
self.assertEqual(batch_labels[0].device.type, 'cuda')
|
||||
|
||||
# Test forward with string value
|
||||
data = dict(string='abc')
|
||||
base_data_preprocessor(data)
|
||||
|
||||
with self.assertRaisesRegex(TypeError,
|
||||
'`BaseDataPreprocessor.cast_data`:'):
|
||||
data = dict(string=object())
|
||||
base_data_preprocessor(data)
|
||||
|
||||
|
||||
class TestImgDataPreprocessor(TestBaseDataPreprocessor):
|
||||
|
||||
|
|
Loading…
Reference in New Issue