mirror of https://github.com/open-mmlab/mmcv.git
[Refactor]: Grascale return uint8 type
parent
5867a97a41
commit
a03774d6db
|
@ -692,6 +692,7 @@ class RandomGrayscale(BaseTransform):
|
|||
normalized_weights = (
|
||||
np.array(self.channel_weights) / sum(self.channel_weights))
|
||||
img = (normalized_weights * img).sum(axis=2)
|
||||
img = img.astype('uint8')
|
||||
if self.keep_channels:
|
||||
img = img[:, :, None]
|
||||
results['img'] = np.dstack(
|
||||
|
@ -699,6 +700,7 @@ class RandomGrayscale(BaseTransform):
|
|||
else:
|
||||
results['img'] = img
|
||||
return results
|
||||
img = img.astype('uint8')
|
||||
results['img'] = img
|
||||
return results
|
||||
|
||||
|
|
|
@ -474,7 +474,7 @@ class TestRandomGrayscale:
|
|||
|
||||
@classmethod
|
||||
def setup_class(cls):
|
||||
cls.img = np.random.rand(10, 10, 3).astype(np.float32)
|
||||
cls.img = (np.random.rand(10, 10, 3) * 255).astype(np.uint8)
|
||||
|
||||
def test_repr(self):
|
||||
# test repr
|
||||
|
@ -504,9 +504,9 @@ class TestRandomGrayscale:
|
|||
random_gray_scale_module = TRANSFORMS.build(transform)
|
||||
results['img'] = copy.deepcopy(self.img)
|
||||
img = random_gray_scale_module(results)['img']
|
||||
computed_gray = (
|
||||
self.img[:, :, 0] * 0.299 + self.img[:, :, 1] * 0.587 +
|
||||
self.img[:, :, 2] * 0.114)
|
||||
computed_gray = (self.img[:, :, 0] * 0.299 +
|
||||
self.img[:, :, 1] * 0.587 +
|
||||
self.img[:, :, 2] * 0.114).astype(np.uint8)
|
||||
for i in range(img.shape[2]):
|
||||
assert_array_almost_equal(img[:, :, i], computed_gray, decimal=4)
|
||||
assert img.shape == (10, 10, 3)
|
||||
|
|
Loading…
Reference in New Issue