[Refactor] Move tensor2list operation to EncoderDecoder
parent
a63f77d249
commit
3e8594d2dc
|
@ -310,7 +310,7 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
|
|||
return loss
|
||||
|
||||
def predict_by_feat(self, seg_logits: Tensor,
|
||||
batch_img_metas: List[dict]) -> List[Tensor]:
|
||||
batch_img_metas: List[dict]) -> Tensor:
|
||||
"""Transform a batch of output seg_logits to the input shape.
|
||||
|
||||
Args:
|
||||
|
@ -319,7 +319,7 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
|
|||
image size, scaling factor, etc.
|
||||
|
||||
Returns:
|
||||
List[Tensor]: Outputs segmentation logits map.
|
||||
Tensor: Outputs segmentation logits map.
|
||||
"""
|
||||
|
||||
seg_logits = resize(
|
||||
|
@ -327,4 +327,4 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
|
|||
size=batch_img_metas[0]['img_shape'],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
return list(seg_logits)
|
||||
return seg_logits
|
||||
|
|
|
@ -125,10 +125,10 @@ class EncoderDecoder(BaseSegmentor):
|
|||
"""Encode images with backbone and decode into a semantic segmentation
|
||||
map of the same size as input."""
|
||||
x = self.extract_feat(batch_inputs)
|
||||
seg_logits_list = self.decode_head.predict(x, batch_img_metas,
|
||||
self.test_cfg)
|
||||
seg_logits = self.decode_head.predict(x, batch_img_metas,
|
||||
self.test_cfg)
|
||||
|
||||
return seg_logits_list
|
||||
return list(seg_logits)
|
||||
|
||||
def _decode_head_forward_train(self, batch_inputs: List[Tensor],
|
||||
batch_data_samples: SampleList) -> dict:
|
||||
|
|
Loading…
Reference in New Issue