Merge branch 'limengzhang/fix_kwargs' into 'refactor_dev'

[Fix] Delete all **kwargs in Segmentor Forward function

See merge request openmmlab-enterprise/openmmlab-ce/mmsegmentation!52
pull/1801/head
zhengmiao 2022-06-22 08:24:13 +00:00
commit eef12a064b
5 changed files with 70 additions and 79 deletions

View File

@ -21,8 +21,7 @@ class BaseCascadeDecodeHead(BaseDecodeHead, metaclass=ABCMeta):
pass pass
def loss(self, inputs: List[Tensor], prev_output: Tensor, def loss(self, inputs: List[Tensor], prev_output: Tensor,
batch_data_samples: List[dict], train_cfg: ConfigType, batch_data_samples: List[dict], train_cfg: ConfigType) -> Tensor:
**kwargs) -> Tensor:
"""Forward function for training. """Forward function for training.
Args: Args:
@ -37,12 +36,12 @@ class BaseCascadeDecodeHead(BaseDecodeHead, metaclass=ABCMeta):
dict[str, Tensor]: a dictionary of loss components dict[str, Tensor]: a dictionary of loss components
""" """
seg_logits = self.forward(inputs, prev_output) seg_logits = self.forward(inputs, prev_output)
losses = self.loss_by_feat(seg_logits, batch_data_samples, **kwargs) losses = self.loss_by_feat(seg_logits, batch_data_samples)
return losses return losses
def predict(self, inputs: List[Tensor], prev_output: Tensor, def predict(self, inputs: List[Tensor], prev_output: Tensor,
batch_img_metas: List[dict], tese_cfg: ConfigType, **kwargs): batch_img_metas: List[dict], tese_cfg: ConfigType):
"""Forward function for testing. """Forward function for testing.
Args: Args:
@ -60,4 +59,4 @@ class BaseCascadeDecodeHead(BaseDecodeHead, metaclass=ABCMeta):
""" """
seg_logits = self.forward(inputs, prev_output) seg_logits = self.forward(inputs, prev_output)
return self.predict_by_feat(seg_logits, batch_img_metas, **kwargs) return self.predict_by_feat(seg_logits, batch_img_metas)

View File

@ -205,7 +205,7 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
return inputs return inputs
@abstractmethod @abstractmethod
def forward(self, inputs, **kwargs): def forward(self, inputs):
"""Placeholder of forward function.""" """Placeholder of forward function."""
pass pass
@ -217,7 +217,7 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
return output return output
def loss(self, inputs: Tuple[Tensor], batch_data_samples: SampleList, def loss(self, inputs: Tuple[Tensor], batch_data_samples: SampleList,
train_cfg: ConfigType, **kwargs) -> dict: train_cfg: ConfigType) -> dict:
"""Forward function for training. """Forward function for training.
Args: Args:
@ -230,12 +230,12 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
Returns: Returns:
dict[str, Tensor]: a dictionary of loss components dict[str, Tensor]: a dictionary of loss components
""" """
seg_logits = self.forward(inputs, **kwargs) seg_logits = self.forward(inputs)
losses = self.loss_by_feat(seg_logits, batch_data_samples, **kwargs) losses = self.loss_by_feat(seg_logits, batch_data_samples)
return losses return losses
def predict(self, inputs: Tuple[Tensor], batch_img_metas: List[dict], def predict(self, inputs: Tuple[Tensor], batch_img_metas: List[dict],
test_cfg: ConfigType, **kwargs) -> List[Tensor]: test_cfg: ConfigType) -> List[Tensor]:
"""Forward function for prediction. """Forward function for prediction.
Args: Args:
@ -250,9 +250,9 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
Returns: Returns:
List[Tensor]: Outputs segmentation logits map. List[Tensor]: Outputs segmentation logits map.
""" """
seg_logits = self.forward(inputs, **kwargs) seg_logits = self.forward(inputs)
return self.predict_by_feat(seg_logits, batch_img_metas, **kwargs) return self.predict_by_feat(seg_logits, batch_img_metas)
def _stack_batch_gt(self, batch_data_samples: SampleList) -> Tensor: def _stack_batch_gt(self, batch_data_samples: SampleList) -> Tensor:
gt_semantic_segs = [ gt_semantic_segs = [
@ -260,8 +260,8 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
] ]
return torch.stack(gt_semantic_segs, dim=0) return torch.stack(gt_semantic_segs, dim=0)
def loss_by_feat(self, seg_logits: Tensor, batch_data_samples: SampleList, def loss_by_feat(self, seg_logits: Tensor,
**kwargs) -> dict: batch_data_samples: SampleList) -> dict:
"""Compute segmentation loss. """Compute segmentation loss.
Args: Args:
@ -309,8 +309,8 @@ class BaseDecodeHead(BaseModule, metaclass=ABCMeta):
seg_logits, seg_label, ignore_index=self.ignore_index) seg_logits, seg_label, ignore_index=self.ignore_index)
return loss return loss
def predict_by_feat(self, seg_logits: Tensor, batch_img_metas: List[dict], def predict_by_feat(self, seg_logits: Tensor,
**kwargs) -> List[Tensor]: batch_img_metas: List[dict]) -> List[Tensor]:
"""Transform a batch of output seg_logits to the input shape. """Transform a batch of output seg_logits to the input shape.
Args: Args:

View File

@ -53,7 +53,7 @@ class BaseSegmentor(BaseModel, metaclass=ABCMeta):
@abstractmethod @abstractmethod
def encode_decode(self, batch_inputs: Tensor, def encode_decode(self, batch_inputs: Tensor,
batch_data_samples: SampleList, **kwargs): batch_data_samples: SampleList):
"""Placeholder for encode images with backbone and decode into a """Placeholder for encode images with backbone and decode into a
semantic segmentation map of the same size as input.""" semantic segmentation map of the same size as input."""
pass pass
@ -61,8 +61,7 @@ class BaseSegmentor(BaseModel, metaclass=ABCMeta):
def forward(self, def forward(self,
batch_inputs: Tensor, batch_inputs: Tensor,
batch_data_samples: OptSampleList = None, batch_data_samples: OptSampleList = None,
mode: str = 'tensor', mode: str = 'tensor') -> ForwardResults:
**kwargs) -> ForwardResults:
"""The unified entry for a forward process in both training and test. """The unified entry for a forward process in both training and test.
The method should accept three modes: "tensor", "predict" and "loss": The method should accept three modes: "tensor", "predict" and "loss":
@ -92,33 +91,33 @@ class BaseSegmentor(BaseModel, metaclass=ABCMeta):
- If ``mode="loss"``, return a dict of tensor. - If ``mode="loss"``, return a dict of tensor.
""" """
if mode == 'loss': if mode == 'loss':
return self.loss(batch_inputs, batch_data_samples, **kwargs) return self.loss(batch_inputs, batch_data_samples)
elif mode == 'predict': elif mode == 'predict':
return self.predict(batch_inputs, batch_data_samples, **kwargs) return self.predict(batch_inputs, batch_data_samples)
elif mode == 'tensor': elif mode == 'tensor':
return self._forward(batch_inputs, batch_data_samples, **kwargs) return self._forward(batch_inputs, batch_data_samples)
else: else:
raise RuntimeError(f'Invalid mode "{mode}". ' raise RuntimeError(f'Invalid mode "{mode}". '
'Only supports loss, predict and tensor mode') 'Only supports loss, predict and tensor mode')
@abstractmethod @abstractmethod
def loss(self, batch_inputs: Tensor, batch_data_samples: SampleList, def loss(self, batch_inputs: Tensor,
**kwargs) -> dict: batch_data_samples: SampleList) -> dict:
"""Calculate losses from a batch of inputs and data samples.""" """Calculate losses from a batch of inputs and data samples."""
pass pass
@abstractmethod @abstractmethod
def predict(self, batch_inputs: Tensor, batch_data_samples: SampleList, def predict(self, batch_inputs: Tensor,
**kwargs) -> SampleList: batch_data_samples: SampleList) -> SampleList:
"""Predict results from a batch of inputs and data samples with post- """Predict results from a batch of inputs and data samples with post-
processing.""" processing."""
pass pass
@abstractmethod @abstractmethod
def _forward(self, def _forward(
self,
batch_inputs: Tensor, batch_inputs: Tensor,
batch_data_samples: OptSampleList = None, batch_data_samples: OptSampleList = None) -> Tuple[List[Tensor]]:
**kwargs) -> Tuple[List[Tensor]]:
"""Network forward process. """Network forward process.
Usually includes backbone, neck and head forward without any post- Usually includes backbone, neck and head forward without any post-
@ -127,7 +126,7 @@ class BaseSegmentor(BaseModel, metaclass=ABCMeta):
pass pass
@abstractmethod @abstractmethod
def aug_test(self, batch_inputs, batch_img_metas, **kwargs): def aug_test(self, batch_inputs, batch_img_metas):
"""Placeholder for augmentation test.""" """Placeholder for augmentation test."""
pass pass

View File

@ -70,29 +70,28 @@ class CascadeEncoderDecoder(EncoderDecoder):
self.align_corners = self.decode_head[-1].align_corners self.align_corners = self.decode_head[-1].align_corners
self.num_classes = self.decode_head[-1].num_classes self.num_classes = self.decode_head[-1].num_classes
def encode_decode(self, batch_inputs: Tensor, batch_img_metas: List[dict], def encode_decode(self, batch_inputs: Tensor,
**kwargs) -> List[Tensor]: batch_img_metas: List[dict]) -> List[Tensor]:
"""Encode images with backbone and decode into a semantic segmentation """Encode images with backbone and decode into a semantic segmentation
map of the same size as input.""" map of the same size as input."""
x = self.extract_feat(batch_inputs) x = self.extract_feat(batch_inputs)
out = self.decode_head[0].forward(x, **kwargs) out = self.decode_head[0].forward(x)
for i in range(1, self.num_stages - 1): for i in range(1, self.num_stages - 1):
out = self.decode_head[i].forward(x, out, **kwargs) out = self.decode_head[i].forward(x, out)
seg_logits_list = self.decode_head[-1].predict(x, out, batch_img_metas, seg_logits_list = self.decode_head[-1].predict(x, out, batch_img_metas,
self.test_cfg, **kwargs) self.test_cfg)
return seg_logits_list return seg_logits_list
def _decode_head_forward_train(self, batch_inputs: Tensor, def _decode_head_forward_train(self, batch_inputs: Tensor,
batch_data_samples: SampleList, batch_data_samples: SampleList) -> dict:
**kwargs) -> dict:
"""Run forward function and calculate loss for decode head in """Run forward function and calculate loss for decode head in
training.""" training."""
losses = dict() losses = dict()
loss_decode = self.decode_head[0].loss(batch_inputs, loss_decode = self.decode_head[0].loss(batch_inputs,
batch_data_samples, batch_data_samples,
self.train_cfg, **kwargs) self.train_cfg)
losses.update(add_prefix(loss_decode, 'decode_0')) losses.update(add_prefix(loss_decode, 'decode_0'))
# get batch_img_metas # get batch_img_metas
@ -105,22 +104,20 @@ class CascadeEncoderDecoder(EncoderDecoder):
for i in range(1, self.num_stages): for i in range(1, self.num_stages):
# forward test again, maybe unnecessary for most methods. # forward test again, maybe unnecessary for most methods.
if i == 1: if i == 1:
prev_outputs = self.decode_head[0].forward( prev_outputs = self.decode_head[0].forward(batch_inputs)
batch_inputs, **kwargs)
else: else:
prev_outputs = self.decode_head[i - 1].forward( prev_outputs = self.decode_head[i - 1].forward(
batch_inputs, prev_outputs, **kwargs) batch_inputs, prev_outputs)
loss_decode = self.decode_head[i].loss(batch_inputs, prev_outputs, loss_decode = self.decode_head[i].loss(batch_inputs, prev_outputs,
batch_data_samples, batch_data_samples,
self.train_cfg, **kwargs) self.train_cfg)
losses.update(add_prefix(loss_decode, f'decode_{i}')) losses.update(add_prefix(loss_decode, f'decode_{i}'))
return losses return losses
def _forward(self, def _forward(self,
batch_inputs: Tensor, batch_inputs: Tensor,
data_samples: OptSampleList = None, data_samples: OptSampleList = None) -> Tensor:
**kwargs) -> Tensor:
"""Network forward process. """Network forward process.
Args: Args:
@ -137,6 +134,6 @@ class CascadeEncoderDecoder(EncoderDecoder):
out = self.decode_head[0].forward(x) out = self.decode_head[0].forward(x)
for i in range(1, self.num_stages): for i in range(1, self.num_stages):
# TODO support PointRend tensor mode # TODO support PointRend tensor mode
out = self.decode_head[i].forward(x, out, **kwargs) out = self.decode_head[i].forward(x, out)
return out return out

View File

@ -120,49 +120,50 @@ class EncoderDecoder(BaseSegmentor):
x = self.neck(x) x = self.neck(x)
return x return x
def encode_decode(self, batch_inputs: Tensor, batch_img_metas: List[dict], def encode_decode(self, batch_inputs: Tensor,
**kwargs) -> List[Tensor]: batch_img_metas: List[dict]) -> List[Tensor]:
"""Encode images with backbone and decode into a semantic segmentation """Encode images with backbone and decode into a semantic segmentation
map of the same size as input.""" map of the same size as input."""
x = self.extract_feat(batch_inputs) x = self.extract_feat(batch_inputs)
seg_logits_list = self.decode_head.predict(x, batch_img_metas, seg_logits_list = self.decode_head.predict(x, batch_img_metas,
self.test_cfg, **kwargs) self.test_cfg)
return seg_logits_list return seg_logits_list
def _decode_head_forward_train(self, batch_inputs: List[Tensor], def _decode_head_forward_train(self, batch_inputs: List[Tensor],
batch_data_samples: SampleList, batch_data_samples: SampleList) -> dict:
**kwargs) -> dict:
"""Run forward function and calculate loss for decode head in """Run forward function and calculate loss for decode head in
training.""" training."""
losses = dict() losses = dict()
loss_decode = self.decode_head.loss(batch_inputs, batch_data_samples, loss_decode = self.decode_head.loss(batch_inputs, batch_data_samples,
self.train_cfg, **kwargs) self.train_cfg)
losses.update(add_prefix(loss_decode, 'decode')) losses.update(add_prefix(loss_decode, 'decode'))
return losses return losses
def _auxiliary_head_forward_train(self, batch_inputs: List[Tensor], def _auxiliary_head_forward_train(
self,
batch_inputs: List[Tensor],
batch_data_samples: SampleList, batch_data_samples: SampleList,
**kwargs) -> dict: ) -> dict:
"""Run forward function and calculate loss for auxiliary head in """Run forward function and calculate loss for auxiliary head in
training.""" training."""
losses = dict() losses = dict()
if isinstance(self.auxiliary_head, nn.ModuleList): if isinstance(self.auxiliary_head, nn.ModuleList):
for idx, aux_head in enumerate(self.auxiliary_head): for idx, aux_head in enumerate(self.auxiliary_head):
loss_aux = aux_head.loss(batch_inputs, batch_data_samples, loss_aux = aux_head.loss(batch_inputs, batch_data_samples,
self.train_cfg, **kwargs) self.train_cfg)
losses.update(add_prefix(loss_aux, f'aux_{idx}')) losses.update(add_prefix(loss_aux, f'aux_{idx}'))
else: else:
loss_aux = self.auxiliary_head.loss(batch_inputs, loss_aux = self.auxiliary_head.loss(batch_inputs,
batch_data_samples, batch_data_samples,
self.train_cfg, **kwargs) self.train_cfg)
losses.update(add_prefix(loss_aux, 'aux')) losses.update(add_prefix(loss_aux, 'aux'))
return losses return losses
def loss(self, batch_inputs: Tensor, batch_data_samples: SampleList, def loss(self, batch_inputs: Tensor,
**kwargs) -> dict: batch_data_samples: SampleList) -> dict:
"""Calculate losses from a batch of inputs and data samples. """Calculate losses from a batch of inputs and data samples.
Args: Args:
@ -179,19 +180,18 @@ class EncoderDecoder(BaseSegmentor):
losses = dict() losses = dict()
loss_decode = self._decode_head_forward_train(x, batch_data_samples, loss_decode = self._decode_head_forward_train(x, batch_data_samples)
**kwargs)
losses.update(loss_decode) losses.update(loss_decode)
if self.with_auxiliary_head: if self.with_auxiliary_head:
loss_aux = self._auxiliary_head_forward_train( loss_aux = self._auxiliary_head_forward_train(
x, batch_data_samples, **kwargs) x, batch_data_samples)
losses.update(loss_aux) losses.update(loss_aux)
return losses return losses
def predict(self, batch_inputs: Tensor, batch_data_samples: SampleList, def predict(self, batch_inputs: Tensor,
**kwargs) -> SampleList: batch_data_samples: SampleList) -> SampleList:
"""Predict results from a batch of inputs and data samples with post- """Predict results from a batch of inputs and data samples with post-
processing. processing.
@ -213,15 +213,13 @@ class EncoderDecoder(BaseSegmentor):
for data_sample in batch_data_samples: for data_sample in batch_data_samples:
batch_img_metas.append(data_sample.metainfo) batch_img_metas.append(data_sample.metainfo)
seg_logit_list = self.inference(batch_inputs, batch_img_metas, seg_logit_list = self.inference(batch_inputs, batch_img_metas)
**kwargs)
return self.postprocess_result(seg_logit_list, batch_img_metas) return self.postprocess_result(seg_logit_list, batch_img_metas)
def _forward(self, def _forward(self,
batch_inputs: Tensor, batch_inputs: Tensor,
data_samples: OptSampleList = None, data_samples: OptSampleList = None) -> Tensor:
**kwargs) -> Tensor:
"""Network forward process. """Network forward process.
Args: Args:
@ -234,10 +232,10 @@ class EncoderDecoder(BaseSegmentor):
Tensor: Forward output of model without any post-processes. Tensor: Forward output of model without any post-processes.
""" """
x = self.extract_feat(batch_inputs) x = self.extract_feat(batch_inputs)
return self.decode_head.forward(x, **kwargs) return self.decode_head.forward(x)
def slide_inference(self, batch_inputs: Tensor, def slide_inference(self, batch_inputs: Tensor,
batch_img_metas: List[dict], **kwargs) -> List[Tensor]: batch_img_metas: List[dict]) -> List[Tensor]:
"""Inference by sliding-window with overlap. """Inference by sliding-window with overlap.
If h_crop > h_img or w_crop > w_img, the small patch will be used to If h_crop > h_img or w_crop > w_img, the small patch will be used to
@ -279,8 +277,7 @@ class EncoderDecoder(BaseSegmentor):
# the output of encode_decode is list of seg logits map # the output of encode_decode is list of seg logits map
# with shape [C, H, W] # with shape [C, H, W]
crop_seg_logit = torch.stack( crop_seg_logit = torch.stack(
self.encode_decode(crop_img, batch_img_metas, **kwargs), self.encode_decode(crop_img, batch_img_metas), dim=0)
dim=0)
preds += F.pad(crop_seg_logit, preds += F.pad(crop_seg_logit,
(int(x1), int(preds.shape[3] - x2), int(y1), (int(x1), int(preds.shape[3] - x2), int(y1),
int(preds.shape[2] - y2))) int(preds.shape[2] - y2)))
@ -292,7 +289,7 @@ class EncoderDecoder(BaseSegmentor):
return seg_logits_list return seg_logits_list
def whole_inference(self, batch_inputs: Tensor, def whole_inference(self, batch_inputs: Tensor,
batch_img_metas: List[dict], **kwargs) -> List[Tensor]: batch_img_metas: List[dict]) -> List[Tensor]:
"""Inference with full image. """Inference with full image.
Args: Args:
@ -309,13 +306,12 @@ class EncoderDecoder(BaseSegmentor):
model of each input image. model of each input image.
""" """
seg_logits_list = self.encode_decode(batch_inputs, batch_img_metas, seg_logits_list = self.encode_decode(batch_inputs, batch_img_metas)
**kwargs)
return seg_logits_list return seg_logits_list
def inference(self, batch_inputs: Tensor, batch_img_metas: List[dict], def inference(self, batch_inputs: Tensor,
**kwargs) -> List[Tensor]: batch_img_metas: List[dict]) -> List[Tensor]:
"""Inference with slide/whole style. """Inference with slide/whole style.
Args: Args:
@ -336,10 +332,10 @@ class EncoderDecoder(BaseSegmentor):
assert all(_['ori_shape'] == ori_shape for _ in batch_img_metas) assert all(_['ori_shape'] == ori_shape for _ in batch_img_metas)
if self.test_cfg.mode == 'slide': if self.test_cfg.mode == 'slide':
seg_logit_list = self.slide_inference(batch_inputs, seg_logit_list = self.slide_inference(batch_inputs,
batch_img_metas, **kwargs) batch_img_metas)
else: else:
seg_logit_list = self.whole_inference(batch_inputs, seg_logit_list = self.whole_inference(batch_inputs,
batch_img_metas, **kwargs) batch_img_metas)
return seg_logit_list return seg_logit_list