[Fix] Delete all **kwargs in Segmentor Forward function
parent 9ed2e1cdd8
commit c5ad7fb0b7
@@ -21,8 +21,7 @@ class BaseCascadeDecodeHead(BaseDecodeHead, metaclass=ABCMeta):
         pass
 
     def loss(self, inputs: List[Tensor], prev_output: Tensor,
-             batch_data_samples: List[dict], train_cfg: ConfigType,
-             **kwargs) -> Tensor:
+             batch_data_samples: List[dict], train_cfg: ConfigType) -> Tensor:
         """Forward function for training.
 
         Args:
@@ -37,12 +36,12 @@ class BaseCascadeDecodeHead(BaseDecodeHead, metaclass=ABCMeta):
             dict[str, Tensor]: a dictionary of loss components
         """
         seg_logits = self.forward(inputs, prev_output)
-        losses = self.loss_by_feat(seg_logits, batch_data_samples, **kwargs)
+        losses = self.loss_by_feat(seg_logits, batch_data_samples)
 
         return losses
 
     def predict(self, inputs: List[Tensor], prev_output: Tensor,
-                batch_img_metas: List[dict], tese_cfg: ConfigType, **kwargs):
+                batch_img_metas: List[dict], tese_cfg: ConfigType):
         """Forward function for testing.
 
         Args:
@@ -60,4 +59,4 @@ class BaseCascadeDecodeHead(BaseDecodeHead, metaclass=ABCMeta):
         """
         seg_logits = self.forward(inputs, prev_output)
 
-        return self.predict_by_feat(seg_logits, batch_img_metas, **kwargs)
+        return self.predict_by_feat(seg_logits, batch_img_metas)
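After these hunks, a cascade stage implements `forward(inputs, prev_output)` and is driven only through the `loss(...)` and `predict(...)` entry points shown above, with no extra keyword arguments threaded through. A minimal sketch of a conforming stage head; the import path, the fusion conv, and the assumption that `prev_output` matches the feature map's spatial size are illustrative choices, not part of this commit:

from typing import List

import torch
from torch import Tensor, nn

# Assumed location of the class patched above.
from mmseg.models.decode_heads.cascade_decode_head import BaseCascadeDecodeHead


class TinyCascadeHead(BaseCascadeDecodeHead):
    """Refine the previous stage's logits with a single fusion conv."""

    def __init__(self, **init_kwargs):
        super().__init__(**init_kwargs)
        # Fuse the backbone feature (in_channels) with the previous logits
        # (num_classes) down to the head's working width (channels).
        self.fuse = nn.Conv2d(self.in_channels + self.num_classes,
                              self.channels, kernel_size=3, padding=1)

    def forward(self, inputs: List[Tensor], prev_output: Tensor) -> Tensor:
        x = self._transform_inputs(inputs)      # select the configured feature map
        x = torch.cat([x, prev_output], dim=1)  # assumes matching spatial sizes
        return self.cls_seg(self.fuse(x))       # -> [N, num_classes, H, W]

Training then runs through `loss(...)` and inference through `predict(...)`, exactly as wired in the hunks above.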
@@ -205,7 +205,7 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
         return inputs
 
     @abstractmethod
-    def forward(self, inputs, **kwargs):
+    def forward(self, inputs):
         """Placeholder of forward function."""
         pass
 
@@ -217,7 +217,7 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
         return output
 
     def loss(self, inputs: Tuple[Tensor], batch_data_samples: SampleList,
-             train_cfg: ConfigType, **kwargs) -> dict:
+             train_cfg: ConfigType) -> dict:
         """Forward function for training.
 
         Args:
@@ -230,12 +230,12 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
         Returns:
             dict[str, Tensor]: a dictionary of loss components
         """
-        seg_logits = self.forward(inputs, **kwargs)
-        losses = self.loss_by_feat(seg_logits, batch_data_samples, **kwargs)
+        seg_logits = self.forward(inputs)
+        losses = self.loss_by_feat(seg_logits, batch_data_samples)
         return losses
 
     def predict(self, inputs: Tuple[Tensor], batch_img_metas: List[dict],
-                test_cfg: ConfigType, **kwargs) -> List[Tensor]:
+                test_cfg: ConfigType) -> List[Tensor]:
         """Forward function for prediction.
 
         Args:
@@ -250,9 +250,9 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
         Returns:
             List[Tensor]: Outputs segmentation logits map.
         """
-        seg_logits = self.forward(inputs, **kwargs)
+        seg_logits = self.forward(inputs)
 
-        return self.predict_by_feat(seg_logits, batch_img_metas, **kwargs)
+        return self.predict_by_feat(seg_logits, batch_img_metas)
 
     def _stack_batch_gt(self, batch_data_samples: SampleList) -> Tensor:
         gt_semantic_segs = [
@@ -260,8 +260,8 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
         ]
         return torch.stack(gt_semantic_segs, dim=0)
 
-    def loss_by_feat(self, seg_logits: Tensor, batch_data_samples: SampleList,
-                     **kwargs) -> dict:
+    def loss_by_feat(self, seg_logits: Tensor,
+                     batch_data_samples: SampleList) -> dict:
         """Compute segmentation loss.
 
         Args:
@@ -309,8 +309,8 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
             seg_logits, seg_label, ignore_index=self.ignore_index)
         return loss
 
-    def predict_by_feat(self, seg_logits: Tensor, batch_img_metas: List[dict],
-                        **kwargs) -> List[Tensor]:
+    def predict_by_feat(self, seg_logits: Tensor,
+                        batch_img_metas: List[dict]) -> List[Tensor]:
         """Transform a batch of output seg_logits to the input shape.
 
         Args:
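The `predict_by_feat` trimmed here only appears through its signature and docstring ("transform a batch of output seg_logits to the input shape"); roughly, it upsamples each logit map back to the image shape recorded in `batch_img_metas`. A hedged sketch of that behaviour, not the patched implementation:

from typing import List

import torch.nn.functional as F
from torch import Tensor


def predict_by_feat_sketch(seg_logits: Tensor,
                           batch_img_metas: List[dict]) -> List[Tensor]:
    """Approximate the head's predict_by_feat: resize logits per image.

    Assumes each meta dict carries an 'img_shape' (H, W) entry, as produced
    by the usual mmseg data pipeline; this is an illustration only.
    """
    results = []
    for i, meta in enumerate(batch_img_metas):
        logits = F.interpolate(
            seg_logits[i:i + 1],          # keep the batch dim for interpolate
            size=meta['img_shape'][:2],   # target (H, W)
            mode='bilinear',
            align_corners=False)
        results.append(logits.squeeze(0))  # [num_classes, H, W]
    return results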
@@ -53,7 +53,7 @@ class BaseSegmentor(BaseModel, metaclass=ABCMeta):
 
     @abstractmethod
     def encode_decode(self, batch_inputs: Tensor,
-                      batch_data_samples: SampleList, **kwargs):
+                      batch_data_samples: SampleList):
         """Placeholder for encode images with backbone and decode into a
         semantic segmentation map of the same size as input."""
         pass
@@ -61,8 +61,7 @@ class BaseSegmentor(BaseModel, metaclass=ABCMeta):
     def forward(self,
                 batch_inputs: Tensor,
                 batch_data_samples: OptSampleList = None,
-                mode: str = 'tensor',
-                **kwargs) -> ForwardResults:
+                mode: str = 'tensor') -> ForwardResults:
         """The unified entry for a forward process in both training and test.
 
         The method should accept three modes: "tensor", "predict" and "loss":
@@ -92,33 +91,33 @@ class BaseSegmentor(BaseModel, metaclass=ABCMeta):
             - If ``mode="loss"``, return a dict of tensor.
         """
         if mode == 'loss':
-            return self.loss(batch_inputs, batch_data_samples, **kwargs)
+            return self.loss(batch_inputs, batch_data_samples)
         elif mode == 'predict':
-            return self.predict(batch_inputs, batch_data_samples, **kwargs)
+            return self.predict(batch_inputs, batch_data_samples)
         elif mode == 'tensor':
-            return self._forward(batch_inputs, batch_data_samples, **kwargs)
+            return self._forward(batch_inputs, batch_data_samples)
         else:
             raise RuntimeError(f'Invalid mode "{mode}". '
                                'Only supports loss, predict and tensor mode')
 
     @abstractmethod
-    def loss(self, batch_inputs: Tensor, batch_data_samples: SampleList,
-             **kwargs) -> dict:
+    def loss(self, batch_inputs: Tensor,
+             batch_data_samples: SampleList) -> dict:
         """Calculate losses from a batch of inputs and data samples."""
         pass
 
     @abstractmethod
-    def predict(self, batch_inputs: Tensor, batch_data_samples: SampleList,
-                **kwargs) -> SampleList:
+    def predict(self, batch_inputs: Tensor,
+                batch_data_samples: SampleList) -> SampleList:
         """Predict results from a batch of inputs and data samples with post-
         processing."""
         pass
 
     @abstractmethod
-    def _forward(self,
-                 batch_inputs: Tensor,
-                 batch_data_samples: OptSampleList = None,
-                 **kwargs) -> Tuple[List[Tensor]]:
+    def _forward(
+            self,
+            batch_inputs: Tensor,
+            batch_data_samples: OptSampleList = None) -> Tuple[List[Tensor]]:
         """Network forward process.
 
         Usually includes backbone, neck and head forward without any post-
@@ -127,7 +126,7 @@ class BaseSegmentor(BaseModel, metaclass=ABCMeta):
         pass
 
     @abstractmethod
-    def aug_test(self, batch_inputs, batch_img_metas, **kwargs):
+    def aug_test(self, batch_inputs, batch_img_metas):
         """Placeholder for augmentation test."""
         pass
 
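With `**kwargs` gone, the unified `forward` above is driven purely by `mode`, and each mode maps to one of the abstract methods that follow. A small usage sketch, assuming `model` is a built `BaseSegmentor` subclass and `data_samples` comes from the usual data pipeline (both placeholders, not defined by this commit):

import torch


def run_all_modes(model, batch_inputs: torch.Tensor, data_samples):
    """Exercise the three documented forward modes of a segmentor."""
    loss_dict = model(batch_inputs, data_samples, mode='loss')        # dict of tensors
    pred_samples = model(batch_inputs, data_samples, mode='predict')  # post-processed samples
    raw_output = model(batch_inputs, data_samples, mode='tensor')     # raw forward output
    return loss_dict, pred_samples, raw_output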
@@ -70,29 +70,28 @@ class CascadeEncoderDecoder(EncoderDecoder):
         self.align_corners = self.decode_head[-1].align_corners
         self.num_classes = self.decode_head[-1].num_classes
 
-    def encode_decode(self, batch_inputs: Tensor, batch_img_metas: List[dict],
-                      **kwargs) -> List[Tensor]:
+    def encode_decode(self, batch_inputs: Tensor,
+                      batch_img_metas: List[dict]) -> List[Tensor]:
         """Encode images with backbone and decode into a semantic segmentation
         map of the same size as input."""
         x = self.extract_feat(batch_inputs)
-        out = self.decode_head[0].forward(x, **kwargs)
+        out = self.decode_head[0].forward(x)
         for i in range(1, self.num_stages - 1):
-            out = self.decode_head[i].forward(x, out, **kwargs)
+            out = self.decode_head[i].forward(x, out)
         seg_logits_list = self.decode_head[-1].predict(x, out, batch_img_metas,
-                                                       self.test_cfg, **kwargs)
+                                                       self.test_cfg)
 
         return seg_logits_list
 
     def _decode_head_forward_train(self, batch_inputs: Tensor,
-                                   batch_data_samples: SampleList,
-                                   **kwargs) -> dict:
+                                   batch_data_samples: SampleList) -> dict:
         """Run forward function and calculate loss for decode head in
         training."""
         losses = dict()
 
         loss_decode = self.decode_head[0].loss(batch_inputs,
                                                batch_data_samples,
-                                               self.train_cfg, **kwargs)
+                                               self.train_cfg)
 
         losses.update(add_prefix(loss_decode, 'decode_0'))
         # get batch_img_metas
@@ -105,22 +104,20 @@ class CascadeEncoderDecoder(EncoderDecoder):
         for i in range(1, self.num_stages):
             # forward test again, maybe unnecessary for most methods.
             if i == 1:
-                prev_outputs = self.decode_head[0].forward(
-                    batch_inputs, **kwargs)
+                prev_outputs = self.decode_head[0].forward(batch_inputs)
             else:
                 prev_outputs = self.decode_head[i - 1].forward(
-                    batch_inputs, prev_outputs, **kwargs)
+                    batch_inputs, prev_outputs)
             loss_decode = self.decode_head[i].loss(batch_inputs, prev_outputs,
                                                    batch_data_samples,
-                                                   self.train_cfg, **kwargs)
+                                                   self.train_cfg)
             losses.update(add_prefix(loss_decode, f'decode_{i}'))
 
         return losses
 
     def _forward(self,
                  batch_inputs: Tensor,
-                 data_samples: OptSampleList = None,
-                 **kwargs) -> Tensor:
+                 data_samples: OptSampleList = None) -> Tensor:
         """Network forward process.
 
         Args:
@@ -137,6 +134,6 @@ class CascadeEncoderDecoder(EncoderDecoder):
         out = self.decode_head[0].forward(x)
         for i in range(1, self.num_stages):
             # TODO support PointRend tensor mode
-            out = self.decode_head[i].forward(x, out, **kwargs)
+            out = self.decode_head[i].forward(x, out)
 
         return out
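The loops above encode the cascade contract: stage 0 consumes only the backbone features, every later stage also consumes the previous stage's output, and the last stage's `predict` produces the logits list. A standalone sketch of that chaining with plain callables (no mmseg types), mirroring the `_forward` loop above:

from typing import Callable, List

from torch import Tensor


def cascade_forward(heads: List[Callable[..., Tensor]],
                    feats: List[Tensor]) -> Tensor:
    """Chain cascade stages without any extra keyword arguments."""
    out = heads[0](feats)       # stage 0: features only
    for head in heads[1:]:
        out = head(feats, out)  # later stages: features + previous output
    return out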
@@ -120,49 +120,50 @@ class EncoderDecoder(BaseSegmentor):
             x = self.neck(x)
         return x
 
-    def encode_decode(self, batch_inputs: Tensor, batch_img_metas: List[dict],
-                      **kwargs) -> List[Tensor]:
+    def encode_decode(self, batch_inputs: Tensor,
+                      batch_img_metas: List[dict]) -> List[Tensor]:
         """Encode images with backbone and decode into a semantic segmentation
         map of the same size as input."""
         x = self.extract_feat(batch_inputs)
         seg_logits_list = self.decode_head.predict(x, batch_img_metas,
-                                                   self.test_cfg, **kwargs)
+                                                   self.test_cfg)
 
         return seg_logits_list
 
     def _decode_head_forward_train(self, batch_inputs: List[Tensor],
-                                   batch_data_samples: SampleList,
-                                   **kwargs) -> dict:
+                                   batch_data_samples: SampleList) -> dict:
         """Run forward function and calculate loss for decode head in
         training."""
         losses = dict()
         loss_decode = self.decode_head.loss(batch_inputs, batch_data_samples,
-                                            self.train_cfg, **kwargs)
+                                            self.train_cfg)
 
         losses.update(add_prefix(loss_decode, 'decode'))
         return losses
 
-    def _auxiliary_head_forward_train(self, batch_inputs: List[Tensor],
-                                      batch_data_samples: SampleList,
-                                      **kwargs) -> dict:
+    def _auxiliary_head_forward_train(
+            self,
+            batch_inputs: List[Tensor],
+            batch_data_samples: SampleList,
+    ) -> dict:
         """Run forward function and calculate loss for auxiliary head in
         training."""
         losses = dict()
         if isinstance(self.auxiliary_head, nn.ModuleList):
             for idx, aux_head in enumerate(self.auxiliary_head):
                 loss_aux = aux_head.loss(batch_inputs, batch_data_samples,
-                                         self.train_cfg, **kwargs)
+                                         self.train_cfg)
                 losses.update(add_prefix(loss_aux, f'aux_{idx}'))
         else:
             loss_aux = self.auxiliary_head.loss(batch_inputs,
                                                 batch_data_samples,
-                                                self.train_cfg, **kwargs)
+                                                self.train_cfg)
             losses.update(add_prefix(loss_aux, 'aux'))
 
         return losses
 
-    def loss(self, batch_inputs: Tensor, batch_data_samples: SampleList,
-             **kwargs) -> dict:
+    def loss(self, batch_inputs: Tensor,
+             batch_data_samples: SampleList) -> dict:
         """Calculate losses from a batch of inputs and data samples.
 
         Args:
@@ -179,19 +180,18 @@ class EncoderDecoder(BaseSegmentor):
 
         losses = dict()
 
-        loss_decode = self._decode_head_forward_train(x, batch_data_samples,
-                                                      **kwargs)
+        loss_decode = self._decode_head_forward_train(x, batch_data_samples)
         losses.update(loss_decode)
 
         if self.with_auxiliary_head:
             loss_aux = self._auxiliary_head_forward_train(
-                x, batch_data_samples, **kwargs)
+                x, batch_data_samples)
             losses.update(loss_aux)
 
         return losses
 
-    def predict(self, batch_inputs: Tensor, batch_data_samples: SampleList,
-                **kwargs) -> SampleList:
+    def predict(self, batch_inputs: Tensor,
+                batch_data_samples: SampleList) -> SampleList:
         """Predict results from a batch of inputs and data samples with post-
         processing.
 
@@ -213,15 +213,13 @@ class EncoderDecoder(BaseSegmentor):
         for data_sample in batch_data_samples:
             batch_img_metas.append(data_sample.metainfo)
 
-        seg_logit_list = self.inference(batch_inputs, batch_img_metas,
-                                        **kwargs)
+        seg_logit_list = self.inference(batch_inputs, batch_img_metas)
 
         return self.postprocess_result(seg_logit_list, batch_img_metas)
 
     def _forward(self,
                  batch_inputs: Tensor,
-                 data_samples: OptSampleList = None,
-                 **kwargs) -> Tensor:
+                 data_samples: OptSampleList = None) -> Tensor:
         """Network forward process.
 
         Args:
@@ -234,10 +232,10 @@ class EncoderDecoder(BaseSegmentor):
             Tensor: Forward output of model without any post-processes.
         """
         x = self.extract_feat(batch_inputs)
-        return self.decode_head.forward(x, **kwargs)
+        return self.decode_head.forward(x)
 
     def slide_inference(self, batch_inputs: Tensor,
-                        batch_img_metas: List[dict], **kwargs) -> List[Tensor]:
+                        batch_img_metas: List[dict]) -> List[Tensor]:
         """Inference by sliding-window with overlap.
 
         If h_crop > h_img or w_crop > w_img, the small patch will be used to
@@ -279,8 +277,7 @@ class EncoderDecoder(BaseSegmentor):
                 # the output of encode_decode is list of seg logits map
                 # with shape [C, H, W]
                 crop_seg_logit = torch.stack(
-                    self.encode_decode(crop_img, batch_img_metas, **kwargs),
-                    dim=0)
+                    self.encode_decode(crop_img, batch_img_metas), dim=0)
                 preds += F.pad(crop_seg_logit,
                                (int(x1), int(preds.shape[3] - x2), int(y1),
                                 int(preds.shape[2] - y2)))
@@ -292,7 +289,7 @@ class EncoderDecoder(BaseSegmentor):
         return seg_logits_list
 
     def whole_inference(self, batch_inputs: Tensor,
-                        batch_img_metas: List[dict], **kwargs) -> List[Tensor]:
+                        batch_img_metas: List[dict]) -> List[Tensor]:
         """Inference with full image.
 
         Args:
@@ -309,13 +306,12 @@ class EncoderDecoder(BaseSegmentor):
                 model of each input image.
         """
 
-        seg_logits_list = self.encode_decode(batch_inputs, batch_img_metas,
-                                             **kwargs)
+        seg_logits_list = self.encode_decode(batch_inputs, batch_img_metas)
 
         return seg_logits_list
 
-    def inference(self, batch_inputs: Tensor, batch_img_metas: List[dict],
-                  **kwargs) -> List[Tensor]:
+    def inference(self, batch_inputs: Tensor,
+                  batch_img_metas: List[dict]) -> List[Tensor]:
         """Inference with slide/whole style.
 
         Args:
@@ -336,10 +332,10 @@ class EncoderDecoder(BaseSegmentor):
         assert all(_['ori_shape'] == ori_shape for _ in batch_img_metas)
         if self.test_cfg.mode == 'slide':
             seg_logit_list = self.slide_inference(batch_inputs,
-                                                  batch_img_metas, **kwargs)
+                                                  batch_img_metas)
         else:
             seg_logit_list = self.whole_inference(batch_inputs,
-                                                  batch_img_metas, **kwargs)
+                                                  batch_img_metas)
 
         return seg_logit_list
 
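`inference` chooses between the two paths above from `test_cfg.mode`. For context, the two configurations look roughly like this; the key names follow the usual mmseg `test_cfg` convention, and the crop and stride values are arbitrary examples, not taken from this commit:

# Whole-image inference: a single encode_decode call on the full input.
test_cfg_whole = dict(mode='whole')

# Sliding-window inference: crops of `crop_size` are decoded every `stride`
# pixels and the overlapping logits are accumulated by slide_inference.
test_cfg_slide = dict(mode='slide', crop_size=(512, 512), stride=(341, 341))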