From 3e8594d2dc26c35a23ac5e6f70278647c66555a0 Mon Sep 17 00:00:00 2001 From: "xiexinchen.vendor" Date: Fri, 8 Jul 2022 10:34:03 +0000 Subject: [PATCH] [Refactor] Move tensor2list operation to EncoderDecoder --- mmseg/models/decode_heads/decode_head.py | 6 +++--- mmseg/models/segmentors/encoder_decoder.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/mmseg/models/decode_heads/decode_head.py b/mmseg/models/decode_heads/decode_head.py index 479aa1931..a797d61b9 100644 --- a/mmseg/models/decode_heads/decode_head.py +++ b/mmseg/models/decode_heads/decode_head.py @@ -310,7 +310,7 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta): return loss def predict_by_feat(self, seg_logits: Tensor, - batch_img_metas: List[dict]) -> List[Tensor]: + batch_img_metas: List[dict]) -> Tensor: """Transform a batch of output seg_logits to the input shape. Args: @@ -319,7 +319,7 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta): image size, scaling factor, etc. Returns: - List[Tensor]: Outputs segmentation logits map. + Tensor: Outputs segmentation logits map. """ seg_logits = resize( @@ -327,4 +327,4 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta): size=batch_img_metas[0]['img_shape'], mode='bilinear', align_corners=self.align_corners) - return list(seg_logits) + return seg_logits diff --git a/mmseg/models/segmentors/encoder_decoder.py b/mmseg/models/segmentors/encoder_decoder.py index 135e29528..a87168569 100644 --- a/mmseg/models/segmentors/encoder_decoder.py +++ b/mmseg/models/segmentors/encoder_decoder.py @@ -125,10 +125,10 @@ class EncoderDecoder(BaseSegmentor): """Encode images with backbone and decode into a semantic segmentation map of the same size as input.""" x = self.extract_feat(batch_inputs) - seg_logits_list = self.decode_head.predict(x, batch_img_metas, - self.test_cfg) + seg_logits = self.decode_head.predict(x, batch_img_metas, + self.test_cfg) - return seg_logits_list + return list(seg_logits) def _decode_head_forward_train(self, batch_inputs: List[Tensor], batch_data_samples: SampleList) -> dict: