mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Bug Fix] Fix TTA resize scale (#334)
* fix tta bug * modify as suggested * fix test_tta bug
This commit is contained in:
parent
7c4e505e7d
commit
022b055a66
@ -104,7 +104,7 @@ class MultiScaleFlipAug(object):
|
|||||||
aug_data = []
|
aug_data = []
|
||||||
if self.img_scale is None and mmcv.is_list_of(self.img_ratios, float):
|
if self.img_scale is None and mmcv.is_list_of(self.img_ratios, float):
|
||||||
h, w = results['img'].shape[:2]
|
h, w = results['img'].shape[:2]
|
||||||
img_scale = [(int(h * ratio), int(w * ratio))
|
img_scale = [(int(w * ratio), int(h * ratio))
|
||||||
for ratio in self.img_ratios]
|
for ratio in self.img_ratios]
|
||||||
else:
|
else:
|
||||||
img_scale = self.img_scale
|
img_scale = self.img_scale
|
||||||
|
@ -156,8 +156,9 @@ class Resize(object):
|
|||||||
|
|
||||||
if self.ratio_range is not None:
|
if self.ratio_range is not None:
|
||||||
if self.img_scale is None:
|
if self.img_scale is None:
|
||||||
scale, scale_idx = self.random_sample_ratio(
|
h, w = results['img'].shape[:2]
|
||||||
results['img'].shape[:2], self.ratio_range)
|
scale, scale_idx = self.random_sample_ratio((w, h),
|
||||||
|
self.ratio_range)
|
||||||
else:
|
else:
|
||||||
scale, scale_idx = self.random_sample_ratio(
|
scale, scale_idx = self.random_sample_ratio(
|
||||||
self.img_scale[0], self.ratio_range)
|
self.img_scale[0], self.ratio_range)
|
||||||
|
@ -108,7 +108,7 @@ def test_multi_scale_flip_aug():
|
|||||||
)
|
)
|
||||||
tta_module = build_from_cfg(tta_transform, PIPELINES)
|
tta_module = build_from_cfg(tta_transform, PIPELINES)
|
||||||
tta_results = tta_module(results.copy())
|
tta_results = tta_module(results.copy())
|
||||||
assert tta_results['scale'] == [(144, 256), (288, 512), (576, 1024)]
|
assert tta_results['scale'] == [(256, 144), (512, 288), (1024, 576)]
|
||||||
assert tta_results['flip'] == [False, False, False]
|
assert tta_results['flip'] == [False, False, False]
|
||||||
|
|
||||||
tta_transform = dict(
|
tta_transform = dict(
|
||||||
@ -120,8 +120,8 @@ def test_multi_scale_flip_aug():
|
|||||||
)
|
)
|
||||||
tta_module = build_from_cfg(tta_transform, PIPELINES)
|
tta_module = build_from_cfg(tta_transform, PIPELINES)
|
||||||
tta_results = tta_module(results.copy())
|
tta_results = tta_module(results.copy())
|
||||||
assert tta_results['scale'] == [(144, 256), (144, 256), (288, 512),
|
assert tta_results['scale'] == [(256, 144), (256, 144), (512, 288),
|
||||||
(288, 512), (576, 1024), (576, 1024)]
|
(512, 288), (1024, 576), (1024, 576)]
|
||||||
assert tta_results['flip'] == [False, True, False, True, False, True]
|
assert tta_results['flip'] == [False, True, False, True, False, True]
|
||||||
|
|
||||||
tta_transform = dict(
|
tta_transform = dict(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user