[Refactor] Move tensor2list operation to EncoderDecoder
parent
a63f77d249
commit
3e8594d2dc
|
@ -310,7 +310,7 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
def predict_by_feat(self, seg_logits: Tensor,
|
def predict_by_feat(self, seg_logits: Tensor,
|
||||||
batch_img_metas: List[dict]) -> List[Tensor]:
|
batch_img_metas: List[dict]) -> Tensor:
|
||||||
"""Transform a batch of output seg_logits to the input shape.
|
"""Transform a batch of output seg_logits to the input shape.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -319,7 +319,7 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
|
||||||
image size, scaling factor, etc.
|
image size, scaling factor, etc.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[Tensor]: Outputs segmentation logits map.
|
Tensor: Outputs segmentation logits map.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
seg_logits = resize(
|
seg_logits = resize(
|
||||||
|
@ -327,4 +327,4 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
|
||||||
size=batch_img_metas[0]['img_shape'],
|
size=batch_img_metas[0]['img_shape'],
|
||||||
mode='bilinear',
|
mode='bilinear',
|
||||||
align_corners=self.align_corners)
|
align_corners=self.align_corners)
|
||||||
return list(seg_logits)
|
return seg_logits
|
||||||
|
|
|
@ -125,10 +125,10 @@ class EncoderDecoder(BaseSegmentor):
|
||||||
"""Encode images with backbone and decode into a semantic segmentation
|
"""Encode images with backbone and decode into a semantic segmentation
|
||||||
map of the same size as input."""
|
map of the same size as input."""
|
||||||
x = self.extract_feat(batch_inputs)
|
x = self.extract_feat(batch_inputs)
|
||||||
seg_logits_list = self.decode_head.predict(x, batch_img_metas,
|
seg_logits = self.decode_head.predict(x, batch_img_metas,
|
||||||
self.test_cfg)
|
self.test_cfg)
|
||||||
|
|
||||||
return seg_logits_list
|
return list(seg_logits)
|
||||||
|
|
||||||
def _decode_head_forward_train(self, batch_inputs: List[Tensor],
|
def _decode_head_forward_train(self, batch_inputs: List[Tensor],
|
||||||
batch_data_samples: SampleList) -> dict:
|
batch_data_samples: SampleList) -> dict:
|
||||||
|
|
Loading…
Reference in New Issue