[Refactor] Move tensor2list operation to EncoderDecoder

pull/1801/head
xiexinchen.vendor 2022-07-08 10:34:03 +00:00 committed by zhengmiao
parent a63f77d249
commit 3e8594d2dc
2 changed files with 6 additions and 6 deletions

View File

@ -310,7 +310,7 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
return loss
def predict_by_feat(self, seg_logits: Tensor,
batch_img_metas: List[dict]) -> List[Tensor]:
batch_img_metas: List[dict]) -> Tensor:
"""Transform a batch of output seg_logits to the input shape.
Args:
@ -319,7 +319,7 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
image size, scaling factor, etc.
Returns:
List[Tensor]: Outputs segmentation logits map.
Tensor: Outputs segmentation logits map.
"""
seg_logits = resize(
@ -327,4 +327,4 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
size=batch_img_metas[0]['img_shape'],
mode='bilinear',
align_corners=self.align_corners)
return list(seg_logits)
return seg_logits

View File

@ -125,10 +125,10 @@ class EncoderDecoder(BaseSegmentor):
"""Encode images with backbone and decode into a semantic segmentation
map of the same size as input."""
x = self.extract_feat(batch_inputs)
seg_logits_list = self.decode_head.predict(x, batch_img_metas,
self.test_cfg)
seg_logits = self.decode_head.predict(x, batch_img_metas,
self.test_cfg)
return seg_logits_list
return list(seg_logits)
def _decode_head_forward_train(self, batch_inputs: List[Tensor],
batch_data_samples: SampleList) -> dict: