[Fix] Fix image dtype when enable_normalize=False. (#301)
* [Fix] Fix image dtype when enable_normalize=False. * update ut * move to collate * update utpull/302/head
parent
bcab813242
commit
819e10c24c
|
@ -53,7 +53,7 @@ class BaseDataPreprocessor(nn.Module):
|
||||||
Tuple[List[torch.Tensor], Optional[list]]: Unstacked list of input
|
Tuple[List[torch.Tensor], Optional[list]]: Unstacked list of input
|
||||||
tensor and list of labels at target device.
|
tensor and list of labels at target device.
|
||||||
"""
|
"""
|
||||||
inputs = [_data['inputs'].to(self._device) for _data in data]
|
inputs = [_data['inputs'].to(self._device).float() for _data in data]
|
||||||
batch_data_samples: List[BaseDataElement] = []
|
batch_data_samples: List[BaseDataElement] = []
|
||||||
# Model can get predictions without any data samples.
|
# Model can get predictions without any data samples.
|
||||||
for _data in data:
|
for _data in data:
|
||||||
|
|
|
@ -28,6 +28,7 @@ class TestBaseDataPreprocessor(TestCase):
|
||||||
]
|
]
|
||||||
|
|
||||||
batch_inputs, batch_labels = base_data_preprocessor(data)
|
batch_inputs, batch_labels = base_data_preprocessor(data)
|
||||||
|
self.assertTrue(torch.is_floating_point(batch_inputs))
|
||||||
self.assertEqual(batch_inputs.shape, (2, 1, 3, 5))
|
self.assertEqual(batch_inputs.shape, (2, 1, 3, 5))
|
||||||
|
|
||||||
assert_allclose(input1, batch_inputs[0])
|
assert_allclose(input1, batch_inputs[0])
|
||||||
|
@ -38,14 +39,17 @@ class TestBaseDataPreprocessor(TestCase):
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
base_data_preprocessor = base_data_preprocessor.cuda()
|
base_data_preprocessor = base_data_preprocessor.cuda()
|
||||||
batch_inputs, batch_labels = base_data_preprocessor(data)
|
batch_inputs, batch_labels = base_data_preprocessor(data)
|
||||||
|
self.assertTrue(torch.is_floating_point(batch_inputs))
|
||||||
self.assertEqual(batch_inputs.device.type, 'cuda')
|
self.assertEqual(batch_inputs.device.type, 'cuda')
|
||||||
|
|
||||||
base_data_preprocessor = base_data_preprocessor.cpu()
|
base_data_preprocessor = base_data_preprocessor.cpu()
|
||||||
batch_inputs, batch_labels = base_data_preprocessor(data)
|
batch_inputs, batch_labels = base_data_preprocessor(data)
|
||||||
|
self.assertTrue(torch.is_floating_point(batch_inputs))
|
||||||
self.assertEqual(batch_inputs.device.type, 'cpu')
|
self.assertEqual(batch_inputs.device.type, 'cpu')
|
||||||
|
|
||||||
base_data_preprocessor = base_data_preprocessor.to('cuda:0')
|
base_data_preprocessor = base_data_preprocessor.to('cuda:0')
|
||||||
batch_inputs, batch_labels = base_data_preprocessor(data)
|
batch_inputs, batch_labels = base_data_preprocessor(data)
|
||||||
|
self.assertTrue(torch.is_floating_point(batch_inputs))
|
||||||
self.assertEqual(batch_inputs.device.type, 'cuda')
|
self.assertEqual(batch_inputs.device.type, 'cuda')
|
||||||
|
|
||||||
|
|
||||||
|
@ -122,6 +126,7 @@ class TestImageDataPreprocessor(TestBaseDataPreprocessor):
|
||||||
|
|
||||||
target_inputs = [target_inputs1, target_inputs2]
|
target_inputs = [target_inputs1, target_inputs2]
|
||||||
inputs, data_samples = data_preprocessor(data, True)
|
inputs, data_samples = data_preprocessor(data, True)
|
||||||
|
self.assertTrue(torch.is_floating_point(inputs))
|
||||||
|
|
||||||
target_data_samples = [data_sample1, data_sample2]
|
target_data_samples = [data_sample1, data_sample2]
|
||||||
for input_, data_sample, target_input, target_data_sample in zip(
|
for input_, data_sample, target_input, target_data_sample in zip(
|
||||||
|
@ -142,6 +147,7 @@ class TestImageDataPreprocessor(TestBaseDataPreprocessor):
|
||||||
|
|
||||||
target_inputs = [target_inputs1, target_inputs2]
|
target_inputs = [target_inputs1, target_inputs2]
|
||||||
inputs, data_samples = data_preprocessor(data, True)
|
inputs, data_samples = data_preprocessor(data, True)
|
||||||
|
self.assertTrue(torch.is_floating_point(inputs))
|
||||||
|
|
||||||
target_data_samples = [data_sample1, data_sample2]
|
target_data_samples = [data_sample1, data_sample2]
|
||||||
for input_, data_sample, target_input, target_data_sample in zip(
|
for input_, data_sample, target_input, target_data_sample in zip(
|
||||||
|
|
Loading…
Reference in New Issue