fix cgnet configs and encnet forward error

pull/1801/head
xiexinch 2022-07-07 14:18:17 +08:00
parent 95b4926e37
commit 55b0c3aef5
3 changed files with 9 additions and 5 deletions

View File

@ -22,7 +22,8 @@ train_cfg = dict(
type='IterBasedTrainLoop', max_iters=total_iters, val_interval=4000)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
default_hooks = dict(checkpoint=dict(by_epoch=False, interval=4000))
default_hooks = dict(
checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=4000))
crop_size = (512, 1024)
data_preprocessor = dict(size=crop_size)

View File

@ -22,7 +22,8 @@ train_cfg = dict(
type='IterBasedTrainLoop', max_iters=total_iters, val_interval=4000)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
default_hooks = dict(checkpoint=dict(by_epoch=False, interval=4000))
default_hooks = dict(
checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=4000))
crop_size = (680, 680)
data_preprocessor = dict(size=crop_size)

View File

@ -8,6 +8,7 @@ from mmcv.cnn import ConvModule, build_norm_layer
from torch import Tensor
from mmseg.core.utils import SampleList
from mmseg.core.utils.typing import ConfigType
from mmseg.ops import Encoding, resize
from mmseg.registry import MODELS
from ..builder import build_loss
@ -155,13 +156,13 @@ class EncHead(BaseDecodeHead):
return output
def predict(self, inputs: Tuple[Tensor], batch_img_metas: List[dict],
**kwargs):
test_cfg: ConfigType):
"""Forward function for testing, ignore se_loss."""
if self.use_se_loss:
seg_logits = self.forward(inputs)[0]
else:
seg_logits = self.forward(inputs)
return self.predict_by_feat(seg_logits, batch_img_metas, **kwargs)
return self.predict_by_feat(seg_logits, batch_img_metas)
@staticmethod
def _convert_to_onehot_labels(seg_label, num_classes):
@ -188,7 +189,8 @@ class EncHead(BaseDecodeHead):
"""Compute segmentation and semantic encoding loss."""
seg_logit, se_seg_logit = seg_logit
loss = dict()
loss.update(super(EncHead, self).losses(seg_logit, batch_data_samples))
loss.update(
super(EncHead, self).loss_by_feat(seg_logit, batch_data_samples))
seg_label = self._stack_batch_gt(batch_data_samples)
se_loss = self.loss_se_decode(