[Fix] Fix SegTTAModel with no attribute '_gt_sem_seg' error (#3152)

## Motivation

When using the - tta command for multi-scale prediction, and the test
set is not annotated, although format_only has been set true in
test_evaluator, but SegTTAModel class still threw error 'AttributeError:
'SegDataSample' object has no attribute '_gt_sem_seg''.

## Modification

The reason is SegTTAModel didn't determine if there were annotations in
the dataset, so I added the code to make the judgment and let the
program run normally on my computer.
pull/3174/head^2
ZiAn-Su 2023-07-13 17:06:06 +08:00 committed by GitHub
parent 067a95e40b
commit 7254f5330f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 4 additions and 6 deletions

View File

@ -6,7 +6,6 @@ from mmengine.model import BaseTTAModel
from mmengine.structures import PixelData
from mmseg.registry import MODELS
from mmseg.structures import SegDataSample
from mmseg.utils import SampleList
@ -39,11 +38,10 @@ class SegTTAModel(BaseTTAModel):
).to(logits).squeeze(1)
else:
seg_pred = logits.argmax(dim=0)
data_sample = SegDataSample(
**{
'pred_sem_seg': PixelData(data=seg_pred),
'gt_sem_seg': data_samples[0].gt_sem_seg
})
data_sample.set_data({'pred_sem_seg': PixelData(data=seg_pred)})
if hasattr(data_samples[0], 'gt_sem_seg'):
data_sample.set_data(
{'gt_sem_seg': data_samples[0].gt_sem_seg})
data_sample.set_metainfo({'img_path': data_samples[0].img_path})
predictions.append(data_sample)
return predictions