Merge branch 'xiexinchen/fix_enc_cgnet' into 'refactor_dev'

[Fix] Fix cgnet configs and encnet forward error

See merge request openmmlab-enterprise/openmmlab-ce/mmsegmentation!59
This commit is contained in:
zhengmiao 2022-07-07 11:03:13 +00:00
commit 320e41004d
3 changed files with 9 additions and 5 deletions

View File

@ -22,7 +22,8 @@ train_cfg = dict(
type='IterBasedTrainLoop', max_iters=total_iters, val_interval=4000) type='IterBasedTrainLoop', max_iters=total_iters, val_interval=4000)
val_cfg = dict(type='ValLoop') val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop') test_cfg = dict(type='TestLoop')
default_hooks = dict(checkpoint=dict(by_epoch=False, interval=4000)) default_hooks = dict(
checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=4000))
crop_size = (512, 1024) crop_size = (512, 1024)
data_preprocessor = dict(size=crop_size) data_preprocessor = dict(size=crop_size)

View File

@ -22,7 +22,8 @@ train_cfg = dict(
type='IterBasedTrainLoop', max_iters=total_iters, val_interval=4000) type='IterBasedTrainLoop', max_iters=total_iters, val_interval=4000)
val_cfg = dict(type='ValLoop') val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop') test_cfg = dict(type='TestLoop')
default_hooks = dict(checkpoint=dict(by_epoch=False, interval=4000)) default_hooks = dict(
checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=4000))
crop_size = (680, 680) crop_size = (680, 680)
data_preprocessor = dict(size=crop_size) data_preprocessor = dict(size=crop_size)

View File

@ -8,6 +8,7 @@ from mmcv.cnn import ConvModule, build_norm_layer
from torch import Tensor from torch import Tensor
from mmseg.core.utils import SampleList from mmseg.core.utils import SampleList
from mmseg.core.utils.typing import ConfigType
from mmseg.ops import Encoding, resize from mmseg.ops import Encoding, resize
from mmseg.registry import MODELS from mmseg.registry import MODELS
from ..builder import build_loss from ..builder import build_loss
@ -155,13 +156,13 @@ class EncHead(BaseDecodeHead):
return output return output
def predict(self, inputs: Tuple[Tensor], batch_img_metas: List[dict], def predict(self, inputs: Tuple[Tensor], batch_img_metas: List[dict],
**kwargs): test_cfg: ConfigType):
"""Forward function for testing, ignore se_loss.""" """Forward function for testing, ignore se_loss."""
if self.use_se_loss: if self.use_se_loss:
seg_logits = self.forward(inputs)[0] seg_logits = self.forward(inputs)[0]
else: else:
seg_logits = self.forward(inputs) seg_logits = self.forward(inputs)
return self.predict_by_feat(seg_logits, batch_img_metas, **kwargs) return self.predict_by_feat(seg_logits, batch_img_metas)
@staticmethod @staticmethod
def _convert_to_onehot_labels(seg_label, num_classes): def _convert_to_onehot_labels(seg_label, num_classes):
@ -188,7 +189,8 @@ class EncHead(BaseDecodeHead):
"""Compute segmentation and semantic encoding loss.""" """Compute segmentation and semantic encoding loss."""
seg_logit, se_seg_logit = seg_logit seg_logit, se_seg_logit = seg_logit
loss = dict() loss = dict()
loss.update(super(EncHead, self).losses(seg_logit, batch_data_samples)) loss.update(
super(EncHead, self).loss_by_feat(seg_logit, batch_data_samples))
seg_label = self._stack_batch_gt(batch_data_samples) seg_label = self._stack_batch_gt(batch_data_samples)
se_loss = self.loss_se_decode( se_loss = self.loss_se_decode(