[Bug Fix] Fix TTA resize scale (#334)

* fix tta bug

* modify as suggested

* fix test_tta bug
This commit is contained in:
yamengxi 2021-01-08 01:58:34 +08:00 committed by GitHub
parent 7c4e505e7d
commit 022b055a66
3 changed files with 7 additions and 6 deletions

View File

@ -104,7 +104,7 @@ class MultiScaleFlipAug(object):
aug_data = []
if self.img_scale is None and mmcv.is_list_of(self.img_ratios, float):
h, w = results['img'].shape[:2]
img_scale = [(int(h * ratio), int(w * ratio))
img_scale = [(int(w * ratio), int(h * ratio))
for ratio in self.img_ratios]
else:
img_scale = self.img_scale

View File

@ -156,8 +156,9 @@ class Resize(object):
if self.ratio_range is not None:
if self.img_scale is None:
scale, scale_idx = self.random_sample_ratio(
results['img'].shape[:2], self.ratio_range)
h, w = results['img'].shape[:2]
scale, scale_idx = self.random_sample_ratio((w, h),
self.ratio_range)
else:
scale, scale_idx = self.random_sample_ratio(
self.img_scale[0], self.ratio_range)

View File

@ -108,7 +108,7 @@ def test_multi_scale_flip_aug():
)
tta_module = build_from_cfg(tta_transform, PIPELINES)
tta_results = tta_module(results.copy())
assert tta_results['scale'] == [(144, 256), (288, 512), (576, 1024)]
assert tta_results['scale'] == [(256, 144), (512, 288), (1024, 576)]
assert tta_results['flip'] == [False, False, False]
tta_transform = dict(
@ -120,8 +120,8 @@ def test_multi_scale_flip_aug():
)
tta_module = build_from_cfg(tta_transform, PIPELINES)
tta_results = tta_module(results.copy())
assert tta_results['scale'] == [(144, 256), (144, 256), (288, 512),
(288, 512), (576, 1024), (576, 1024)]
assert tta_results['scale'] == [(256, 144), (256, 144), (512, 288),
(512, 288), (1024, 576), (1024, 576)]
assert tta_results['flip'] == [False, True, False, True, False, True]
tta_transform = dict(