[Enhancement] Speedup formatting by replacing np.transpose with torch.permute (#1719)

pull/1722/head^2
Tong Gao 2023-02-16 14:14:03 +08:00 committed by GitHub
parent f820470415
commit df0be646ea
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 55 additions and 6 deletions

View File

@ -90,8 +90,17 @@ class PackTextDetInputs(BaseTransform):
img = results['img']
if len(img.shape) < 3:
img = np.expand_dims(img, -1)
img = np.ascontiguousarray(img.transpose(2, 0, 1))
packed_results['inputs'] = to_tensor(img)
# A simple trick to speedup formatting by 3-5 times when
# OMP_NUM_THREADS != 1
# Refer to https://github.com/open-mmlab/mmdetection/pull/9533
# for more details
if img.flags.c_contiguous:
img = to_tensor(img)
img = img.permute(2, 0, 1).contiguous()
else:
img = np.ascontiguousarray(img.transpose(2, 0, 1))
img = to_tensor(img)
packed_results['inputs'] = img
data_sample = TextDetDataSample()
instance_data = InstanceData()
@ -174,8 +183,17 @@ class PackTextRecogInputs(BaseTransform):
img = results['img']
if len(img.shape) < 3:
img = np.expand_dims(img, -1)
img = np.ascontiguousarray(img.transpose(2, 0, 1))
packed_results['inputs'] = to_tensor(img)
# A simple trick to speedup formatting by 3-5 times when
# OMP_NUM_THREADS != 1
# Refer to https://github.com/open-mmlab/mmdetection/pull/9533
# for more details
if img.flags.c_contiguous:
img = to_tensor(img)
img = img.permute(2, 0, 1).contiguous()
else:
img = np.ascontiguousarray(img.transpose(2, 0, 1))
img = to_tensor(img)
packed_results['inputs'] = img
data_sample = TextRecogDataSample()
gt_text = LabelData()
@ -272,8 +290,17 @@ class PackKIEInputs(BaseTransform):
img = results['img']
if len(img.shape) < 3:
img = np.expand_dims(img, -1)
img = np.ascontiguousarray(img.transpose(2, 0, 1))
packed_results['inputs'] = to_tensor(img)
# A simple trick to speedup formatting by 3-5 times when
# OMP_NUM_THREADS != 1
# Refer to https://github.com/open-mmlab/mmdetection/pull/9533
# for more details
if img.flags.c_contiguous:
img = to_tensor(img)
img = img.permute(2, 0, 1).contiguous()
else:
img = np.ascontiguousarray(img.transpose(2, 0, 1))
img = to_tensor(img)
packed_results['inputs'] = img
else:
packed_results['inputs'] = torch.FloatTensor().reshape(0, 0, 0)

View File

@ -36,9 +36,17 @@ class TestPackTextDetInputs(TestCase):
transform = PackTextDetInputs()
results = transform(copy.deepcopy(datainfo))
self.assertIn('inputs', results)
self.assertEqual(results['inputs'].shape, torch.Size([1, 10, 10]))
self.assertTupleEqual(tuple(results['inputs'].shape), (1, 10, 10))
self.assertIn('data_samples', results)
# test non-contiugous img
nc_datainfo = copy.deepcopy(datainfo)
nc_datainfo['img'] = nc_datainfo['img'].transpose(1, 0)
results = transform(nc_datainfo)
self.assertIn('inputs', results)
self.assertEqual(results['inputs'].shape, torch.Size([1, 10, 10]))
data_sample = results['data_samples']
self.assertIn('bboxes', data_sample.gt_instances)
self.assertIsInstance(data_sample.gt_instances.bboxes, torch.Tensor)
@ -115,6 +123,13 @@ class TestPackTextRecogInputs(TestCase):
self.assertIn('valid_ratio', data_sample)
self.assertIn('pad_shape', data_sample)
# test non-contiugous img
nc_datainfo = copy.deepcopy(datainfo)
nc_datainfo['img'] = nc_datainfo['img'].transpose(1, 0)
results = transform(nc_datainfo)
self.assertIn('inputs', results)
self.assertEqual(results['inputs'].shape, torch.Size([1, 10, 10]))
transform = PackTextRecogInputs(meta_keys=('img_path', ))
results = transform(copy.deepcopy(datainfo))
self.assertIn('inputs', results)
@ -174,6 +189,13 @@ class TestPackKIEInputs(TestCase):
torch.int64)
self.assertIsInstance(data_sample.gt_instances.texts, list)
# test non-contiugous img
nc_datainfo = copy.deepcopy(datainfo)
nc_datainfo['img'] = nc_datainfo['img'].transpose(1, 0)
results = self.transform(nc_datainfo)
self.assertIn('inputs', results)
self.assertEqual(results['inputs'].shape, torch.Size([1, 10, 10]))
transform = PackKIEInputs(meta_keys=('img_path', ))
results = transform(copy.deepcopy(datainfo))
self.assertIn('inputs', results)