Merge branch 'limengzhang/fix_kwargs' into 'refactor_dev'

[Fix] Delete all **kwargs in Segmentor Forward function

See merge request openmmlab-enterprise/openmmlab-ce/mmsegmentation!52
pull/1801/head
zhengmiao 2022-06-22 08:24:13 +00:00
commit eef12a064b
5 changed files with 70 additions and 79 deletions

View File

@ -21,8 +21,7 @@ class BaseCascadeDecodeHead(BaseDecodeHead, metaclass=ABCMeta):
pass
def loss(self, inputs: List[Tensor], prev_output: Tensor,
batch_data_samples: List[dict], train_cfg: ConfigType,
**kwargs) -> Tensor:
batch_data_samples: List[dict], train_cfg: ConfigType) -> Tensor:
"""Forward function for training.
Args:
@ -37,12 +36,12 @@ class BaseCascadeDecodeHead(BaseDecodeHead, metaclass=ABCMeta):
dict[str, Tensor]: a dictionary of loss components
"""
seg_logits = self.forward(inputs, prev_output)
losses = self.loss_by_feat(seg_logits, batch_data_samples, **kwargs)
losses = self.loss_by_feat(seg_logits, batch_data_samples)
return losses
def predict(self, inputs: List[Tensor], prev_output: Tensor,
batch_img_metas: List[dict], tese_cfg: ConfigType, **kwargs):
batch_img_metas: List[dict], tese_cfg: ConfigType):
"""Forward function for testing.
Args:
@ -60,4 +59,4 @@ class BaseCascadeDecodeHead(BaseDecodeHead, metaclass=ABCMeta):
"""
seg_logits = self.forward(inputs, prev_output)
return self.predict_by_feat(seg_logits, batch_img_metas, **kwargs)
return self.predict_by_feat(seg_logits, batch_img_metas)

View File

@ -205,7 +205,7 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
return inputs
@abstractmethod
def forward(self, inputs, **kwargs):
def forward(self, inputs):
"""Placeholder of forward function."""
pass
@ -217,7 +217,7 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
return output
def loss(self, inputs: Tuple[Tensor], batch_data_samples: SampleList,
train_cfg: ConfigType, **kwargs) -> dict:
train_cfg: ConfigType) -> dict:
"""Forward function for training.
Args:
@ -230,12 +230,12 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
Returns:
dict[str, Tensor]: a dictionary of loss components
"""
seg_logits = self.forward(inputs, **kwargs)
losses = self.loss_by_feat(seg_logits, batch_data_samples, **kwargs)
seg_logits = self.forward(inputs)
losses = self.loss_by_feat(seg_logits, batch_data_samples)
return losses
def predict(self, inputs: Tuple[Tensor], batch_img_metas: List[dict],
test_cfg: ConfigType, **kwargs) -> List[Tensor]:
test_cfg: ConfigType) -> List[Tensor]:
"""Forward function for prediction.
Args:
@ -250,9 +250,9 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
Returns:
List[Tensor]: Outputs segmentation logits map.
"""
seg_logits = self.forward(inputs, **kwargs)
seg_logits = self.forward(inputs)
return self.predict_by_feat(seg_logits, batch_img_metas, **kwargs)
return self.predict_by_feat(seg_logits, batch_img_metas)
def _stack_batch_gt(self, batch_data_samples: SampleList) -> Tensor:
gt_semantic_segs = [
@ -260,8 +260,8 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
]
return torch.stack(gt_semantic_segs, dim=0)
def loss_by_feat(self, seg_logits: Tensor, batch_data_samples: SampleList,
**kwargs) -> dict:
def loss_by_feat(self, seg_logits: Tensor,
batch_data_samples: SampleList) -> dict:
"""Compute segmentation loss.
Args:
@ -309,8 +309,8 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
seg_logits, seg_label, ignore_index=self.ignore_index)
return loss
def predict_by_feat(self, seg_logits: Tensor, batch_img_metas: List[dict],
**kwargs) -> List[Tensor]:
def predict_by_feat(self, seg_logits: Tensor,
batch_img_metas: List[dict]) -> List[Tensor]:
"""Transform a batch of output seg_logits to the input shape.
Args:

View File

@ -53,7 +53,7 @@ class BaseSegmentor(BaseModel, metaclass=ABCMeta):
@abstractmethod
def encode_decode(self, batch_inputs: Tensor,
batch_data_samples: SampleList, **kwargs):
batch_data_samples: SampleList):
"""Placeholder for encode images with backbone and decode into a
semantic segmentation map of the same size as input."""
pass
@ -61,8 +61,7 @@ class BaseSegmentor(BaseModel, metaclass=ABCMeta):
def forward(self,
batch_inputs: Tensor,
batch_data_samples: OptSampleList = None,
mode: str = 'tensor',
**kwargs) -> ForwardResults:
mode: str = 'tensor') -> ForwardResults:
"""The unified entry for a forward process in both training and test.
The method should accept three modes: "tensor", "predict" and "loss":
@ -92,33 +91,33 @@ class BaseSegmentor(BaseModel, metaclass=ABCMeta):
- If ``mode="loss"``, return a dict of tensor.
"""
if mode == 'loss':
return self.loss(batch_inputs, batch_data_samples, **kwargs)
return self.loss(batch_inputs, batch_data_samples)
elif mode == 'predict':
return self.predict(batch_inputs, batch_data_samples, **kwargs)
return self.predict(batch_inputs, batch_data_samples)
elif mode == 'tensor':
return self._forward(batch_inputs, batch_data_samples, **kwargs)
return self._forward(batch_inputs, batch_data_samples)
else:
raise RuntimeError(f'Invalid mode "{mode}". '
'Only supports loss, predict and tensor mode')
@abstractmethod
def loss(self, batch_inputs: Tensor, batch_data_samples: SampleList,
**kwargs) -> dict:
def loss(self, batch_inputs: Tensor,
batch_data_samples: SampleList) -> dict:
"""Calculate losses from a batch of inputs and data samples."""
pass
@abstractmethod
def predict(self, batch_inputs: Tensor, batch_data_samples: SampleList,
**kwargs) -> SampleList:
def predict(self, batch_inputs: Tensor,
batch_data_samples: SampleList) -> SampleList:
"""Predict results from a batch of inputs and data samples with post-
processing."""
pass
@abstractmethod
def _forward(self,
batch_inputs: Tensor,
batch_data_samples: OptSampleList = None,
**kwargs) -> Tuple[List[Tensor]]:
def _forward(
self,
batch_inputs: Tensor,
batch_data_samples: OptSampleList = None) -> Tuple[List[Tensor]]:
"""Network forward process.
Usually includes backbone, neck and head forward without any post-
@ -127,7 +126,7 @@ class BaseSegmentor(BaseModel, metaclass=ABCMeta):
pass
@abstractmethod
def aug_test(self, batch_inputs, batch_img_metas, **kwargs):
def aug_test(self, batch_inputs, batch_img_metas):
"""Placeholder for augmentation test."""
pass

View File

@ -70,29 +70,28 @@ class CascadeEncoderDecoder(EncoderDecoder):
self.align_corners = self.decode_head[-1].align_corners
self.num_classes = self.decode_head[-1].num_classes
def encode_decode(self, batch_inputs: Tensor, batch_img_metas: List[dict],
**kwargs) -> List[Tensor]:
def encode_decode(self, batch_inputs: Tensor,
batch_img_metas: List[dict]) -> List[Tensor]:
"""Encode images with backbone and decode into a semantic segmentation
map of the same size as input."""
x = self.extract_feat(batch_inputs)
out = self.decode_head[0].forward(x, **kwargs)
out = self.decode_head[0].forward(x)
for i in range(1, self.num_stages - 1):
out = self.decode_head[i].forward(x, out, **kwargs)
out = self.decode_head[i].forward(x, out)
seg_logits_list = self.decode_head[-1].predict(x, out, batch_img_metas,
self.test_cfg, **kwargs)
self.test_cfg)
return seg_logits_list
def _decode_head_forward_train(self, batch_inputs: Tensor,
batch_data_samples: SampleList,
**kwargs) -> dict:
batch_data_samples: SampleList) -> dict:
"""Run forward function and calculate loss for decode head in
training."""
losses = dict()
loss_decode = self.decode_head[0].loss(batch_inputs,
batch_data_samples,
self.train_cfg, **kwargs)
self.train_cfg)
losses.update(add_prefix(loss_decode, 'decode_0'))
# get batch_img_metas
@ -105,22 +104,20 @@ class CascadeEncoderDecoder(EncoderDecoder):
for i in range(1, self.num_stages):
# Re-run the forward pass of the previous head; this may be unnecessary for most methods.
if i == 1:
prev_outputs = self.decode_head[0].forward(
batch_inputs, **kwargs)
prev_outputs = self.decode_head[0].forward(batch_inputs)
else:
prev_outputs = self.decode_head[i - 1].forward(
batch_inputs, prev_outputs, **kwargs)
batch_inputs, prev_outputs)
loss_decode = self.decode_head[i].loss(batch_inputs, prev_outputs,
batch_data_samples,
self.train_cfg, **kwargs)
self.train_cfg)
losses.update(add_prefix(loss_decode, f'decode_{i}'))
return losses
def _forward(self,
batch_inputs: Tensor,
data_samples: OptSampleList = None,
**kwargs) -> Tensor:
data_samples: OptSampleList = None) -> Tensor:
"""Network forward process.
Args:
@ -137,6 +134,6 @@ class CascadeEncoderDecoder(EncoderDecoder):
out = self.decode_head[0].forward(x)
for i in range(1, self.num_stages):
# TODO support PointRend tensor mode
out = self.decode_head[i].forward(x, out, **kwargs)
out = self.decode_head[i].forward(x, out)
return out

View File

@ -120,49 +120,50 @@ class EncoderDecoder(BaseSegmentor):
x = self.neck(x)
return x
def encode_decode(self, batch_inputs: Tensor, batch_img_metas: List[dict],
**kwargs) -> List[Tensor]:
def encode_decode(self, batch_inputs: Tensor,
batch_img_metas: List[dict]) -> List[Tensor]:
"""Encode images with backbone and decode into a semantic segmentation
map of the same size as input."""
x = self.extract_feat(batch_inputs)
seg_logits_list = self.decode_head.predict(x, batch_img_metas,
self.test_cfg, **kwargs)
self.test_cfg)
return seg_logits_list
def _decode_head_forward_train(self, batch_inputs: List[Tensor],
batch_data_samples: SampleList,
**kwargs) -> dict:
batch_data_samples: SampleList) -> dict:
"""Run forward function and calculate loss for decode head in
training."""
losses = dict()
loss_decode = self.decode_head.loss(batch_inputs, batch_data_samples,
self.train_cfg, **kwargs)
self.train_cfg)
losses.update(add_prefix(loss_decode, 'decode'))
return losses
def _auxiliary_head_forward_train(self, batch_inputs: List[Tensor],
batch_data_samples: SampleList,
**kwargs) -> dict:
def _auxiliary_head_forward_train(
self,
batch_inputs: List[Tensor],
batch_data_samples: SampleList,
) -> dict:
"""Run forward function and calculate loss for auxiliary head in
training."""
losses = dict()
if isinstance(self.auxiliary_head, nn.ModuleList):
for idx, aux_head in enumerate(self.auxiliary_head):
loss_aux = aux_head.loss(batch_inputs, batch_data_samples,
self.train_cfg, **kwargs)
self.train_cfg)
losses.update(add_prefix(loss_aux, f'aux_{idx}'))
else:
loss_aux = self.auxiliary_head.loss(batch_inputs,
batch_data_samples,
self.train_cfg, **kwargs)
self.train_cfg)
losses.update(add_prefix(loss_aux, 'aux'))
return losses
def loss(self, batch_inputs: Tensor, batch_data_samples: SampleList,
**kwargs) -> dict:
def loss(self, batch_inputs: Tensor,
batch_data_samples: SampleList) -> dict:
"""Calculate losses from a batch of inputs and data samples.
Args:
@ -179,19 +180,18 @@ class EncoderDecoder(BaseSegmentor):
losses = dict()
loss_decode = self._decode_head_forward_train(x, batch_data_samples,
**kwargs)
loss_decode = self._decode_head_forward_train(x, batch_data_samples)
losses.update(loss_decode)
if self.with_auxiliary_head:
loss_aux = self._auxiliary_head_forward_train(
x, batch_data_samples, **kwargs)
x, batch_data_samples)
losses.update(loss_aux)
return losses
def predict(self, batch_inputs: Tensor, batch_data_samples: SampleList,
**kwargs) -> SampleList:
def predict(self, batch_inputs: Tensor,
batch_data_samples: SampleList) -> SampleList:
"""Predict results from a batch of inputs and data samples with post-
processing.
@ -213,15 +213,13 @@ class EncoderDecoder(BaseSegmentor):
for data_sample in batch_data_samples:
batch_img_metas.append(data_sample.metainfo)
seg_logit_list = self.inference(batch_inputs, batch_img_metas,
**kwargs)
seg_logit_list = self.inference(batch_inputs, batch_img_metas)
return self.postprocess_result(seg_logit_list, batch_img_metas)
def _forward(self,
batch_inputs: Tensor,
data_samples: OptSampleList = None,
**kwargs) -> Tensor:
data_samples: OptSampleList = None) -> Tensor:
"""Network forward process.
Args:
@ -234,10 +232,10 @@ class EncoderDecoder(BaseSegmentor):
Tensor: Forward output of model without any post-processes.
"""
x = self.extract_feat(batch_inputs)
return self.decode_head.forward(x, **kwargs)
return self.decode_head.forward(x)
def slide_inference(self, batch_inputs: Tensor,
batch_img_metas: List[dict], **kwargs) -> List[Tensor]:
batch_img_metas: List[dict]) -> List[Tensor]:
"""Inference by sliding-window with overlap.
If h_crop > h_img or w_crop > w_img, the small patch will be used to
@ -279,8 +277,7 @@ class EncoderDecoder(BaseSegmentor):
# encode_decode returns a list of seg logit maps,
# each with shape [C, H, W]
crop_seg_logit = torch.stack(
self.encode_decode(crop_img, batch_img_metas, **kwargs),
dim=0)
self.encode_decode(crop_img, batch_img_metas), dim=0)
preds += F.pad(crop_seg_logit,
(int(x1), int(preds.shape[3] - x2), int(y1),
int(preds.shape[2] - y2)))
@ -292,7 +289,7 @@ class EncoderDecoder(BaseSegmentor):
return seg_logits_list
def whole_inference(self, batch_inputs: Tensor,
batch_img_metas: List[dict], **kwargs) -> List[Tensor]:
batch_img_metas: List[dict]) -> List[Tensor]:
"""Inference with full image.
Args:
@ -309,13 +306,12 @@ class EncoderDecoder(BaseSegmentor):
model of each input image.
"""
seg_logits_list = self.encode_decode(batch_inputs, batch_img_metas,
**kwargs)
seg_logits_list = self.encode_decode(batch_inputs, batch_img_metas)
return seg_logits_list
def inference(self, batch_inputs: Tensor, batch_img_metas: List[dict],
**kwargs) -> List[Tensor]:
def inference(self, batch_inputs: Tensor,
batch_img_metas: List[dict]) -> List[Tensor]:
"""Inference with slide/whole style.
Args:
@ -336,10 +332,10 @@ class EncoderDecoder(BaseSegmentor):
assert all(_['ori_shape'] == ori_shape for _ in batch_img_metas)
if self.test_cfg.mode == 'slide':
seg_logit_list = self.slide_inference(batch_inputs,
batch_img_metas, **kwargs)
batch_img_metas)
else:
seg_logit_list = self.whole_inference(batch_inputs,
batch_img_metas, **kwargs)
batch_img_metas)
return seg_logit_list