mirror of
https://github.com/open-mmlab/mmclassification.git
synced 2025-06-03 21:53:55 +08:00
add unittest for backends
This commit is contained in:
parent
a0c814fecd
commit
0009591aa3
@ -64,6 +64,7 @@ def test_resize():
|
||||
results['img2'] = copy.deepcopy(original_img)
|
||||
results['img_shape'] = original_img.shape
|
||||
results['ori_shape'] = original_img.shape
|
||||
results['img_fields'] = ['img', 'img2']
|
||||
return results
|
||||
|
||||
# test resize when size is int
|
||||
@ -101,6 +102,26 @@ def test_resize():
|
||||
assert np.equal(results['img'], results['img2']).all()
|
||||
assert results['img_shape'] == (img_height * 2, img_width * 2, 3)
|
||||
|
||||
# test resize with different backends
|
||||
transform_cv2 = dict(
|
||||
type='Resize',
|
||||
size=(224, 256),
|
||||
interpolation='bilinear',
|
||||
backend='cv2')
|
||||
transform_pil = dict(
|
||||
type='Resize',
|
||||
size=(224, 256),
|
||||
interpolation='bilinear',
|
||||
backend='pillow')
|
||||
resize_module_cv2 = build_from_cfg(transform_cv2, PIPELINES)
|
||||
resize_module_pil = build_from_cfg(transform_pil, PIPELINES)
|
||||
results = reset_results(results, original_img)
|
||||
results['img_fields'] = ['img']
|
||||
results_cv2 = resize_module_cv2(results)
|
||||
results['img_fields'] = ['img2']
|
||||
results_pil = resize_module_pil(results)
|
||||
assert np.allclose(results_cv2['img'], results_pil['img2'], atol=45)
|
||||
|
||||
# compare results with torchvision
|
||||
transform = dict(type='Resize', size=(224, 224), interpolation='area')
|
||||
resize_module = build_from_cfg(transform, PIPELINES)
|
||||
|
Loading…
x
Reference in New Issue
Block a user