[Fix] Fix image dtype when enable_normalize=False. (#301)
* [Fix] Fix image dtype when enable_normalize=False. * update ut * move to collate * update utpull/302/head
parent
bcab813242
commit
819e10c24c
|
@ -53,7 +53,7 @@ class BaseDataPreprocessor(nn.Module):
|
|||
Tuple[List[torch.Tensor], Optional[list]]: Unstacked list of input
|
||||
tensor and list of labels at target device.
|
||||
"""
|
||||
inputs = [_data['inputs'].to(self._device) for _data in data]
|
||||
inputs = [_data['inputs'].to(self._device).float() for _data in data]
|
||||
batch_data_samples: List[BaseDataElement] = []
|
||||
# Model can get predictions without any data samples.
|
||||
for _data in data:
|
||||
|
|
|
@ -28,6 +28,7 @@ class TestBaseDataPreprocessor(TestCase):
|
|||
]
|
||||
|
||||
batch_inputs, batch_labels = base_data_preprocessor(data)
|
||||
self.assertTrue(torch.is_floating_point(batch_inputs))
|
||||
self.assertEqual(batch_inputs.shape, (2, 1, 3, 5))
|
||||
|
||||
assert_allclose(input1, batch_inputs[0])
|
||||
|
@ -38,14 +39,17 @@ class TestBaseDataPreprocessor(TestCase):
|
|||
if torch.cuda.is_available():
|
||||
base_data_preprocessor = base_data_preprocessor.cuda()
|
||||
batch_inputs, batch_labels = base_data_preprocessor(data)
|
||||
self.assertTrue(torch.is_floating_point(batch_inputs))
|
||||
self.assertEqual(batch_inputs.device.type, 'cuda')
|
||||
|
||||
base_data_preprocessor = base_data_preprocessor.cpu()
|
||||
batch_inputs, batch_labels = base_data_preprocessor(data)
|
||||
self.assertTrue(torch.is_floating_point(batch_inputs))
|
||||
self.assertEqual(batch_inputs.device.type, 'cpu')
|
||||
|
||||
base_data_preprocessor = base_data_preprocessor.to('cuda:0')
|
||||
batch_inputs, batch_labels = base_data_preprocessor(data)
|
||||
self.assertTrue(torch.is_floating_point(batch_inputs))
|
||||
self.assertEqual(batch_inputs.device.type, 'cuda')
|
||||
|
||||
|
||||
|
@ -122,6 +126,7 @@ class TestImageDataPreprocessor(TestBaseDataPreprocessor):
|
|||
|
||||
target_inputs = [target_inputs1, target_inputs2]
|
||||
inputs, data_samples = data_preprocessor(data, True)
|
||||
self.assertTrue(torch.is_floating_point(inputs))
|
||||
|
||||
target_data_samples = [data_sample1, data_sample2]
|
||||
for input_, data_sample, target_input, target_data_sample in zip(
|
||||
|
@ -142,6 +147,7 @@ class TestImageDataPreprocessor(TestBaseDataPreprocessor):
|
|||
|
||||
target_inputs = [target_inputs1, target_inputs2]
|
||||
inputs, data_samples = data_preprocessor(data, True)
|
||||
self.assertTrue(torch.is_floating_point(inputs))
|
||||
|
||||
target_data_samples = [data_sample1, data_sample2]
|
||||
for input_, data_sample, target_input, target_data_sample in zip(
|
||||
|
|
Loading…
Reference in New Issue