fix cgnet configs and encnet forward error

pull/1801/head
xiexinch 2022-07-07 14:18:17 +08:00
parent 95b4926e37
commit 55b0c3aef5
3 changed files with 9 additions and 5 deletions

View File

@@ -22,7 +22,8 @@ train_cfg = dict(
     type='IterBasedTrainLoop', max_iters=total_iters, val_interval=4000)
 val_cfg = dict(type='ValLoop')
 test_cfg = dict(type='TestLoop')
-default_hooks = dict(checkpoint=dict(by_epoch=False, interval=4000))
+default_hooks = dict(
+    checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=4000))
 crop_size = (512, 1024)
 data_preprocessor = dict(size=crop_size)
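
Both cgnet configs receive the same fix (the 680x680 variant follows below): the checkpoint entry now names its hook type explicitly instead of leaving it to the runner to infer. A minimal sketch of what the corrected entry maps onto, using the standard mmengine hook class and the values from the diff:

# Hedged sketch: the corrected config entry corresponds to this hook.
from mmengine.hooks import CheckpointHook

hook = CheckpointHook(by_epoch=False, interval=4000)  # save every 4000 iters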

View File

@@ -22,7 +22,8 @@ train_cfg = dict(
     type='IterBasedTrainLoop', max_iters=total_iters, val_interval=4000)
 val_cfg = dict(type='ValLoop')
 test_cfg = dict(type='TestLoop')
-default_hooks = dict(checkpoint=dict(by_epoch=False, interval=4000))
+default_hooks = dict(
+    checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=4000))
 crop_size = (680, 680)
 data_preprocessor = dict(size=crop_size)

View File

@@ -8,6 +8,7 @@ from mmcv.cnn import ConvModule, build_norm_layer
 from torch import Tensor

 from mmseg.core.utils import SampleList
+from mmseg.core.utils.typing import ConfigType
 from mmseg.ops import Encoding, resize
 from mmseg.registry import MODELS
 from ..builder import build_loss
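
The new import brings in ConfigType for the updated predict() signature further down. The alias itself is not shown in this diff; a common shape for it (an assumption, not a copy of mmseg.core.utils.typing) is:

# Assumed definition of the ConfigType alias used for typed config arguments.
from typing import Union
from mmengine.config import ConfigDict

ConfigType = Union[ConfigDict, dict]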
@@ -155,13 +156,13 @@ class EncHead(BaseDecodeHead):
         return output

     def predict(self, inputs: Tuple[Tensor], batch_img_metas: List[dict],
-                **kwargs):
+                test_cfg: ConfigType):
         """Forward function for testing, ignore se_loss."""
         if self.use_se_loss:
             seg_logits = self.forward(inputs)[0]
         else:
             seg_logits = self.forward(inputs)
-        return self.predict_by_feat(seg_logits, batch_img_metas, **kwargs)
+        return self.predict_by_feat(seg_logits, batch_img_metas)

     @staticmethod
     def _convert_to_onehot_labels(seg_label, num_classes):
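
Replacing **kwargs with an explicit test_cfg argument makes predict() match the interface the segmentor calls at inference time, and the head no longer forwards extra keyword arguments into predict_by_feat(). A hypothetical call site, with variable names and meta keys chosen for illustration only:

# Hedged sketch of inference through the head; `enc_head`, `feats` and the
# meta dict contents are illustrative, not taken from the repository.
feats = backbone(images)  # tuple of multi-scale features
batch_img_metas = [dict(img_shape=(512, 1024), ori_shape=(1024, 2048))]
seg_logits = enc_head.predict(feats, batch_img_metas, dict(mode='whole'))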
@@ -188,7 +189,8 @@ class EncHead(BaseDecodeHead):
         """Compute segmentation and semantic encoding loss."""
         seg_logit, se_seg_logit = seg_logit
         loss = dict()
-        loss.update(super(EncHead, self).losses(seg_logit, batch_data_samples))
+        loss.update(
+            super(EncHead, self).loss_by_feat(seg_logit, batch_data_samples))
         seg_label = self._stack_batch_gt(batch_data_samples)
         se_loss = self.loss_se_decode(
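
The last hunk switches the parent-class call from losses() to loss_by_feat(), the 1.x-style name for computing losses from raw logits plus data samples; the semantic-encoding (SE) loss is then added on top of those per-pixel losses. A compressed sketch of that pattern, with the loss key name and the closing lines assumed, since the hunk ends mid-call:

# Hedged sketch of the overall pattern; 'loss_se' and the tail of the
# method are assumptions, not copied from the repository.
def loss_by_feat(self, seg_logit, batch_data_samples):
    seg_logit, se_seg_logit = seg_logit          # decode logits + SE logits
    loss = dict()
    loss.update(super().loss_by_feat(seg_logit, batch_data_samples))
    seg_label = self._stack_batch_gt(batch_data_samples)
    loss['loss_se'] = self.loss_se_decode(
        se_seg_logit,
        self._convert_to_onehot_labels(seg_label, self.num_classes))
    return loss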