[Fix] Fix image dtype when enable_normalize=False. (#301)

* [Fix] Fix image dtype when enable_normalize=False.

* update ut

* move to collate

* update ut
pull/302/head
RangiLyu 2022-06-13 21:21:19 +08:00 committed by GitHub
parent bcab813242
commit 819e10c24c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 7 additions and 1 deletions

View File

@ -53,7 +53,7 @@ class BaseDataPreprocessor(nn.Module):
Tuple[List[torch.Tensor], Optional[list]]: Unstacked list of input
tensor and list of labels at target device.
"""
inputs = [_data['inputs'].to(self._device) for _data in data]
inputs = [_data['inputs'].to(self._device).float() for _data in data]
batch_data_samples: List[BaseDataElement] = []
# Model can get predictions without any data samples.
for _data in data:

View File

@ -28,6 +28,7 @@ class TestBaseDataPreprocessor(TestCase):
]
batch_inputs, batch_labels = base_data_preprocessor(data)
self.assertTrue(torch.is_floating_point(batch_inputs))
self.assertEqual(batch_inputs.shape, (2, 1, 3, 5))
assert_allclose(input1, batch_inputs[0])
@ -38,14 +39,17 @@ class TestBaseDataPreprocessor(TestCase):
if torch.cuda.is_available():
base_data_preprocessor = base_data_preprocessor.cuda()
batch_inputs, batch_labels = base_data_preprocessor(data)
self.assertTrue(torch.is_floating_point(batch_inputs))
self.assertEqual(batch_inputs.device.type, 'cuda')
base_data_preprocessor = base_data_preprocessor.cpu()
batch_inputs, batch_labels = base_data_preprocessor(data)
self.assertTrue(torch.is_floating_point(batch_inputs))
self.assertEqual(batch_inputs.device.type, 'cpu')
base_data_preprocessor = base_data_preprocessor.to('cuda:0')
batch_inputs, batch_labels = base_data_preprocessor(data)
self.assertTrue(torch.is_floating_point(batch_inputs))
self.assertEqual(batch_inputs.device.type, 'cuda')
@ -122,6 +126,7 @@ class TestImageDataPreprocessor(TestBaseDataPreprocessor):
target_inputs = [target_inputs1, target_inputs2]
inputs, data_samples = data_preprocessor(data, True)
self.assertTrue(torch.is_floating_point(inputs))
target_data_samples = [data_sample1, data_sample2]
for input_, data_sample, target_input, target_data_sample in zip(
@ -142,6 +147,7 @@ class TestImageDataPreprocessor(TestBaseDataPreprocessor):
target_inputs = [target_inputs1, target_inputs2]
inputs, data_samples = data_preprocessor(data, True)
self.assertTrue(torch.is_floating_point(inputs))
target_data_samples = [data_sample1, data_sample2]
for input_, data_sample, target_input, target_data_sample in zip(