ZiAn-Su 7254f5330f
[Fix] Fix SegTTAModel with no attribute '_gt_sem_seg' error (#3152)
## Motivation

When running multi-scale prediction with the --tta option on a test set
that has no annotations, SegTTAModel raises "AttributeError:
'SegDataSample' object has no attribute '_gt_sem_seg'", even though
format_only is set to True in test_evaluator.

## Modification

SegTTAModel did not check whether the data samples carry annotations
before reading gt_sem_seg. This change adds that check, so ground truth
is copied only when it exists. With the check in place, the program runs
normally on my computer.
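
The guard described above can be illustrated in isolation. This is a minimal sketch, not the mmseg implementation: `FakeSample` and `copy_gt_if_present` are hypothetical names, and the sketch assumes (as the error message suggests) that `gt_sem_seg` is a property backed by a private `_gt_sem_seg` attribute that is only set when an annotation is loaded.

```python
class FakeSample:
    """Hypothetical stand-in for SegDataSample; not the mmseg class."""

    def __init__(self, gt=None):
        # The private attribute exists only when an annotation is given.
        if gt is not None:
            self._gt_sem_seg = gt

    @property
    def gt_sem_seg(self):
        # Raises AttributeError('_gt_sem_seg') when no annotation is stored.
        return self._gt_sem_seg


def copy_gt_if_present(src, dst):
    """Copy the ground truth into dst only if src actually has one."""
    # hasattr() returns False when the property raises AttributeError,
    # so unannotated samples are skipped instead of crashing.
    if hasattr(src, 'gt_sem_seg'):
        dst['gt_sem_seg'] = src.gt_sem_seg
    return dst


labelled = copy_gt_if_present(FakeSample(gt=[[0, 1]]), {})
unlabelled = copy_gt_if_present(FakeSample(), {})
```

Here `labelled` receives the ground truth while `unlabelled` stays empty. An unguarded `src.gt_sem_seg` on the second sample would raise the same kind of AttributeError reported in the motivation.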
2023-07-13 17:06:06 +08:00


# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

import torch
from mmengine.model import BaseTTAModel
from mmengine.structures import PixelData

from mmseg.registry import MODELS
from mmseg.utils import SampleList


@MODELS.register_module()
class SegTTAModel(BaseTTAModel):

    def merge_preds(self, data_samples_list: List[SampleList]) -> SampleList:
        """Merge predictions of enhanced data to one prediction.

        Args:
            data_samples_list (List[SampleList]): List of predictions
                of all enhanced data.

        Returns:
            SampleList: Merged prediction.
        """
        predictions = []
        for data_samples in data_samples_list:
            # Accumulator with the same shape, dtype and device as the logits.
            seg_logits = data_samples[0].seg_logits.data
            logits = torch.zeros(seg_logits.shape).to(seg_logits)
            # Sum the per-augmentation probabilities, then average them.
            for data_sample in data_samples:
                seg_logit = data_sample.seg_logits.data
                if self.module.out_channels > 1:
                    logits += seg_logit.softmax(dim=0)
                else:
                    logits += seg_logit.sigmoid()
            logits /= len(data_samples)
            if self.module.out_channels == 1:
                seg_pred = (logits > self.module.decode_head.threshold
                            ).to(logits).squeeze(1)
            else:
                seg_pred = logits.argmax(dim=0)
            # The last sample of the inner loop carries the merged result.
            data_sample.set_data({'pred_sem_seg': PixelData(data=seg_pred)})
            # Copy the ground truth only when the sample is annotated.
            if hasattr(data_samples[0], 'gt_sem_seg'):
                data_sample.set_data(
                    {'gt_sem_seg': data_samples[0].gt_sem_seg})
            data_sample.set_metainfo({'img_path': data_samples[0].img_path})
            predictions.append(data_sample)
        return predictions