From df0be646ea4a15866f60678ca5f170b90876f098 Mon Sep 17 00:00:00 2001
From: Tong Gao <gaotongxiao@gmail.com>
Date: Thu, 16 Feb 2023 14:14:03 +0800
Subject: [PATCH] [Enhancement] Speedup formatting by replacing np.transpose
 with torch.permute (#1719)

---
 mmocr/datasets/transforms/formatting.py       | 39 ++++++++++++++++---
 .../test_transforms/test_formatting.py        | 22 +++++++++++
 2 files changed, 55 insertions(+), 6 deletions(-)

diff --git a/mmocr/datasets/transforms/formatting.py b/mmocr/datasets/transforms/formatting.py
index 1649850e..b9b71437 100644
--- a/mmocr/datasets/transforms/formatting.py
+++ b/mmocr/datasets/transforms/formatting.py
@@ -90,8 +90,17 @@ class PackTextDetInputs(BaseTransform):
             img = results['img']
             if len(img.shape) < 3:
                 img = np.expand_dims(img, -1)
-            img = np.ascontiguousarray(img.transpose(2, 0, 1))
-            packed_results['inputs'] = to_tensor(img)
+            # A simple trick to speed up formatting by 3-5 times when
+            # OMP_NUM_THREADS != 1
+            # Refer to https://github.com/open-mmlab/mmdetection/pull/9533
+            # for more details
+            if img.flags.c_contiguous:
+                img = to_tensor(img)
+                img = img.permute(2, 0, 1).contiguous()
+            else:
+                img = np.ascontiguousarray(img.transpose(2, 0, 1))
+                img = to_tensor(img)
+            packed_results['inputs'] = img
 
         data_sample = TextDetDataSample()
         instance_data = InstanceData()
@@ -174,8 +183,17 @@ class PackTextRecogInputs(BaseTransform):
             img = results['img']
             if len(img.shape) < 3:
                 img = np.expand_dims(img, -1)
-            img = np.ascontiguousarray(img.transpose(2, 0, 1))
-            packed_results['inputs'] = to_tensor(img)
+            # A simple trick to speed up formatting by 3-5 times when
+            # OMP_NUM_THREADS != 1
+            # Refer to https://github.com/open-mmlab/mmdetection/pull/9533
+            # for more details
+            if img.flags.c_contiguous:
+                img = to_tensor(img)
+                img = img.permute(2, 0, 1).contiguous()
+            else:
+                img = np.ascontiguousarray(img.transpose(2, 0, 1))
+                img = to_tensor(img)
+            packed_results['inputs'] = img
 
         data_sample = TextRecogDataSample()
         gt_text = LabelData()
@@ -272,8 +290,17 @@ class PackKIEInputs(BaseTransform):
             img = results['img']
             if len(img.shape) < 3:
                 img = np.expand_dims(img, -1)
-            img = np.ascontiguousarray(img.transpose(2, 0, 1))
-            packed_results['inputs'] = to_tensor(img)
+            # A simple trick to speed up formatting by 3-5 times when
+            # OMP_NUM_THREADS != 1
+            # Refer to https://github.com/open-mmlab/mmdetection/pull/9533
+            # for more details
+            if img.flags.c_contiguous:
+                img = to_tensor(img)
+                img = img.permute(2, 0, 1).contiguous()
+            else:
+                img = np.ascontiguousarray(img.transpose(2, 0, 1))
+                img = to_tensor(img)
+            packed_results['inputs'] = img
         else:
             packed_results['inputs'] = torch.FloatTensor().reshape(0, 0, 0)
 
diff --git a/tests/test_datasets/test_transforms/test_formatting.py b/tests/test_datasets/test_transforms/test_formatting.py
index a29eecf1..21e9d10f 100644
--- a/tests/test_datasets/test_transforms/test_formatting.py
+++ b/tests/test_datasets/test_transforms/test_formatting.py
@@ -36,9 +36,17 @@ class TestPackTextDetInputs(TestCase):
         transform = PackTextDetInputs()
         results = transform(copy.deepcopy(datainfo))
         self.assertIn('inputs', results)
+        self.assertEqual(results['inputs'].shape, torch.Size([1, 10, 10]))
         self.assertTupleEqual(tuple(results['inputs'].shape), (1, 10, 10))
         self.assertIn('data_samples', results)
 
+        # test non-contiguous img
+        nc_datainfo = copy.deepcopy(datainfo)
+        nc_datainfo['img'] = nc_datainfo['img'].transpose(1, 0)
+        results = transform(nc_datainfo)
+        self.assertIn('inputs', results)
+        self.assertEqual(results['inputs'].shape, torch.Size([1, 10, 10]))
+
         data_sample = results['data_samples']
         self.assertIn('bboxes', data_sample.gt_instances)
         self.assertIsInstance(data_sample.gt_instances.bboxes, torch.Tensor)
@@ -115,6 +123,13 @@ class TestPackTextRecogInputs(TestCase):
         self.assertIn('valid_ratio', data_sample)
         self.assertIn('pad_shape', data_sample)
 
+        # test non-contiguous img
+        nc_datainfo = copy.deepcopy(datainfo)
+        nc_datainfo['img'] = nc_datainfo['img'].transpose(1, 0)
+        results = transform(nc_datainfo)
+        self.assertIn('inputs', results)
+        self.assertEqual(results['inputs'].shape, torch.Size([1, 10, 10]))
+
         transform = PackTextRecogInputs(meta_keys=('img_path', ))
         results = transform(copy.deepcopy(datainfo))
         self.assertIn('inputs', results)
@@ -174,6 +189,13 @@ class TestPackKIEInputs(TestCase):
                          torch.int64)
         self.assertIsInstance(data_sample.gt_instances.texts, list)
 
+        # test non-contiguous img
+        nc_datainfo = copy.deepcopy(datainfo)
+        nc_datainfo['img'] = nc_datainfo['img'].transpose(1, 0)
+        results = self.transform(nc_datainfo)
+        self.assertIn('inputs', results)
+        self.assertEqual(results['inputs'].shape, torch.Size([1, 10, 10]))
+
         transform = PackKIEInputs(meta_keys=('img_path', ))
         results = transform(copy.deepcopy(datainfo))
         self.assertIn('inputs', results)