From a03774d6dbb24113be51e2f0aa2256586c5befe4 Mon Sep 17 00:00:00 2001 From: YuanLiuuuuuu <3463423099@qq.com> Date: Mon, 30 May 2022 18:16:02 +0800 Subject: [PATCH] [Refactor]: Grascale return uint8 type --- mmcv/transforms/processing.py | 2 ++ tests/test_transforms/test_transforms_processing.py | 8 ++++---- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/mmcv/transforms/processing.py b/mmcv/transforms/processing.py index e2b5f2e7c..97e384bb5 100644 --- a/mmcv/transforms/processing.py +++ b/mmcv/transforms/processing.py @@ -692,6 +692,7 @@ class RandomGrayscale(BaseTransform): normalized_weights = ( np.array(self.channel_weights) / sum(self.channel_weights)) img = (normalized_weights * img).sum(axis=2) + img = img.astype('uint8') if self.keep_channels: img = img[:, :, None] results['img'] = np.dstack( @@ -699,6 +700,7 @@ class RandomGrayscale(BaseTransform): else: results['img'] = img return results + img = img.astype('uint8') results['img'] = img return results diff --git a/tests/test_transforms/test_transforms_processing.py b/tests/test_transforms/test_transforms_processing.py index 5a64f5ac6..1f21bfc24 100644 --- a/tests/test_transforms/test_transforms_processing.py +++ b/tests/test_transforms/test_transforms_processing.py @@ -474,7 +474,7 @@ class TestRandomGrayscale: @classmethod def setup_class(cls): - cls.img = np.random.rand(10, 10, 3).astype(np.float32) + cls.img = (np.random.rand(10, 10, 3) * 255).astype(np.uint8) def test_repr(self): # test repr @@ -504,9 +504,9 @@ class TestRandomGrayscale: random_gray_scale_module = TRANSFORMS.build(transform) results['img'] = copy.deepcopy(self.img) img = random_gray_scale_module(results)['img'] - computed_gray = ( - self.img[:, :, 0] * 0.299 + self.img[:, :, 1] * 0.587 + - self.img[:, :, 2] * 0.114) + computed_gray = (self.img[:, :, 0] * 0.299 + + self.img[:, :, 1] * 0.587 + + self.img[:, :, 2] * 0.114).astype(np.uint8) for i in range(img.shape[2]): assert_array_almost_equal(img[:, :, i], computed_gray, decimal=4) assert img.shape == (10, 10, 3)