mirror of https://github.com/open-mmlab/mmocr.git
[Enhancement] Speedup formatting by replacing np.transpose with torch.permute (#1719)
parent
f820470415
commit
df0be646ea
|
@ -90,8 +90,17 @@ class PackTextDetInputs(BaseTransform):
|
|||
img = results['img']
|
||||
if len(img.shape) < 3:
|
||||
img = np.expand_dims(img, -1)
|
||||
img = np.ascontiguousarray(img.transpose(2, 0, 1))
|
||||
packed_results['inputs'] = to_tensor(img)
|
||||
# A simple trick to speedup formatting by 3-5 times when
|
||||
# OMP_NUM_THREADS != 1
|
||||
# Refer to https://github.com/open-mmlab/mmdetection/pull/9533
|
||||
# for more details
|
||||
if img.flags.c_contiguous:
|
||||
img = to_tensor(img)
|
||||
img = img.permute(2, 0, 1).contiguous()
|
||||
else:
|
||||
img = np.ascontiguousarray(img.transpose(2, 0, 1))
|
||||
img = to_tensor(img)
|
||||
packed_results['inputs'] = img
|
||||
|
||||
data_sample = TextDetDataSample()
|
||||
instance_data = InstanceData()
|
||||
|
@ -174,8 +183,17 @@ class PackTextRecogInputs(BaseTransform):
|
|||
img = results['img']
|
||||
if len(img.shape) < 3:
|
||||
img = np.expand_dims(img, -1)
|
||||
img = np.ascontiguousarray(img.transpose(2, 0, 1))
|
||||
packed_results['inputs'] = to_tensor(img)
|
||||
# A simple trick to speedup formatting by 3-5 times when
|
||||
# OMP_NUM_THREADS != 1
|
||||
# Refer to https://github.com/open-mmlab/mmdetection/pull/9533
|
||||
# for more details
|
||||
if img.flags.c_contiguous:
|
||||
img = to_tensor(img)
|
||||
img = img.permute(2, 0, 1).contiguous()
|
||||
else:
|
||||
img = np.ascontiguousarray(img.transpose(2, 0, 1))
|
||||
img = to_tensor(img)
|
||||
packed_results['inputs'] = img
|
||||
|
||||
data_sample = TextRecogDataSample()
|
||||
gt_text = LabelData()
|
||||
|
@ -272,8 +290,17 @@ class PackKIEInputs(BaseTransform):
|
|||
img = results['img']
|
||||
if len(img.shape) < 3:
|
||||
img = np.expand_dims(img, -1)
|
||||
img = np.ascontiguousarray(img.transpose(2, 0, 1))
|
||||
packed_results['inputs'] = to_tensor(img)
|
||||
# A simple trick to speedup formatting by 3-5 times when
|
||||
# OMP_NUM_THREADS != 1
|
||||
# Refer to https://github.com/open-mmlab/mmdetection/pull/9533
|
||||
# for more details
|
||||
if img.flags.c_contiguous:
|
||||
img = to_tensor(img)
|
||||
img = img.permute(2, 0, 1).contiguous()
|
||||
else:
|
||||
img = np.ascontiguousarray(img.transpose(2, 0, 1))
|
||||
img = to_tensor(img)
|
||||
packed_results['inputs'] = img
|
||||
else:
|
||||
packed_results['inputs'] = torch.FloatTensor().reshape(0, 0, 0)
|
||||
|
||||
|
|
|
@ -36,9 +36,17 @@ class TestPackTextDetInputs(TestCase):
|
|||
transform = PackTextDetInputs()
|
||||
results = transform(copy.deepcopy(datainfo))
|
||||
self.assertIn('inputs', results)
|
||||
self.assertEqual(results['inputs'].shape, torch.Size([1, 10, 10]))
|
||||
self.assertTupleEqual(tuple(results['inputs'].shape), (1, 10, 10))
|
||||
self.assertIn('data_samples', results)
|
||||
|
||||
# test non-contiugous img
|
||||
nc_datainfo = copy.deepcopy(datainfo)
|
||||
nc_datainfo['img'] = nc_datainfo['img'].transpose(1, 0)
|
||||
results = transform(nc_datainfo)
|
||||
self.assertIn('inputs', results)
|
||||
self.assertEqual(results['inputs'].shape, torch.Size([1, 10, 10]))
|
||||
|
||||
data_sample = results['data_samples']
|
||||
self.assertIn('bboxes', data_sample.gt_instances)
|
||||
self.assertIsInstance(data_sample.gt_instances.bboxes, torch.Tensor)
|
||||
|
@ -115,6 +123,13 @@ class TestPackTextRecogInputs(TestCase):
|
|||
self.assertIn('valid_ratio', data_sample)
|
||||
self.assertIn('pad_shape', data_sample)
|
||||
|
||||
# test non-contiugous img
|
||||
nc_datainfo = copy.deepcopy(datainfo)
|
||||
nc_datainfo['img'] = nc_datainfo['img'].transpose(1, 0)
|
||||
results = transform(nc_datainfo)
|
||||
self.assertIn('inputs', results)
|
||||
self.assertEqual(results['inputs'].shape, torch.Size([1, 10, 10]))
|
||||
|
||||
transform = PackTextRecogInputs(meta_keys=('img_path', ))
|
||||
results = transform(copy.deepcopy(datainfo))
|
||||
self.assertIn('inputs', results)
|
||||
|
@ -174,6 +189,13 @@ class TestPackKIEInputs(TestCase):
|
|||
torch.int64)
|
||||
self.assertIsInstance(data_sample.gt_instances.texts, list)
|
||||
|
||||
# test non-contiugous img
|
||||
nc_datainfo = copy.deepcopy(datainfo)
|
||||
nc_datainfo['img'] = nc_datainfo['img'].transpose(1, 0)
|
||||
results = self.transform(nc_datainfo)
|
||||
self.assertIn('inputs', results)
|
||||
self.assertEqual(results['inputs'].shape, torch.Size([1, 10, 10]))
|
||||
|
||||
transform = PackKIEInputs(meta_keys=('img_path', ))
|
||||
results = transform(copy.deepcopy(datainfo))
|
||||
self.assertIn('inputs', results)
|
||||
|
|
Loading…
Reference in New Issue