# Mirror of https://github.com/alibaba/EasyCV.git (37 lines, 1.1 KiB, Python)
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
import unittest
|
|
|
|
import numpy as np
|
|
|
|
from easycv.core.bbox import bbox_util
|
|
|
|
|
|
class BboxUtilTest(unittest.TestCase):
    """Unit tests for ``easycv.core.bbox.bbox_util``."""

    def test_batched_cxcywh2xyxy_with_shape(self):
        """Check normalized ``[cx, cy, w, h]`` -> absolute ``[x1, y1, x2, y2]``.

        The first case uses boxes fully inside a ``h x w`` image. The second
        case expects coordinates that fall outside the image to be clipped
        to the image bounds (``w`` for x coordinates, ``h`` for y).
        """
        h, w = 500, 600

        # Normal case: boxes lie entirely inside the image.
        normalized_cxcywh = np.array([[0.4, 0.6, 0.1, 0.2],
                                      [0.5, 0.4, 0.2, 0.3]])
        xyxy = bbox_util.batched_cxcywh2xyxy_with_shape(
            normalized_cxcywh, shape=[h, w])
        target = np.array([[0.35 * w, 0.5 * h, 0.45 * w, 0.7 * h],
                           [0.4 * w, 0.25 * h, 0.6 * w, 0.55 * h]])
        # Compare element-wise with a float tolerance. Comparing
        # ``xyxy.all()`` to ``target.all()`` would only check that both
        # arrays are entirely non-zero, not that the coordinates match.
        np.testing.assert_allclose(xyxy, target)

        # Out-of-bounds case: row 0 extends past the bottom edge
        # (cy + h/2 = 1.05), row 1 past the right edge (cx + w/2 = 1.2).
        cxcywh_out = np.array([[0.4, 0.6, 0.1, 0.9],
                               [0.8, 0.4, 0.8, 0.3]])
        xyxy_out = bbox_util.batched_cxcywh2xyxy_with_shape(
            cxcywh_out, shape=[h, w])
        target_out = np.array([[0.35 * w, 0.15 * h, 0.45 * w, 1.0 * h],
                               [0.4 * w, 0.25 * h, 1.0 * w, 0.55 * h]])
        np.testing.assert_allclose(xyxy_out, target_out)
|
|
|
|
|
|
if __name__ == '__main__':
    # Run the tests defined in this module when it is executed directly.
    unittest.main()
|