[Refactor]: Grascale return uint8 type

pull/2133/head
YuanLiuuuuuu 2022-05-30 18:16:02 +08:00 committed by zhouzaida
parent 5867a97a41
commit a03774d6db
2 changed files with 6 additions and 4 deletions

View File

@ -692,6 +692,7 @@ class RandomGrayscale(BaseTransform):
normalized_weights = (
np.array(self.channel_weights) / sum(self.channel_weights))
img = (normalized_weights * img).sum(axis=2)
img = img.astype('uint8')
if self.keep_channels:
img = img[:, :, None]
results['img'] = np.dstack(
@ -699,6 +700,7 @@ class RandomGrayscale(BaseTransform):
else:
results['img'] = img
return results
img = img.astype('uint8')
results['img'] = img
return results

View File

@ -474,7 +474,7 @@ class TestRandomGrayscale:
@classmethod
def setup_class(cls):
cls.img = np.random.rand(10, 10, 3).astype(np.float32)
cls.img = (np.random.rand(10, 10, 3) * 255).astype(np.uint8)
def test_repr(self):
# test repr
@ -504,9 +504,9 @@ class TestRandomGrayscale:
random_gray_scale_module = TRANSFORMS.build(transform)
results['img'] = copy.deepcopy(self.img)
img = random_gray_scale_module(results)['img']
computed_gray = (
self.img[:, :, 0] * 0.299 + self.img[:, :, 1] * 0.587 +
self.img[:, :, 2] * 0.114)
computed_gray = (self.img[:, :, 0] * 0.299 +
self.img[:, :, 1] * 0.587 +
self.img[:, :, 2] * 0.114).astype(np.uint8)
for i in range(img.shape[2]):
assert_array_almost_equal(img[:, :, i], computed_gray, decimal=4)
assert img.shape == (10, 10, 3)