diff --git a/configs/cgnet/cgnet_512x1024_60k_cityscapes.py b/configs/cgnet/cgnet_512x1024_60k_cityscapes.py index dcabaf01c..fc9ad7c1c 100644 --- a/configs/cgnet/cgnet_512x1024_60k_cityscapes.py +++ b/configs/cgnet/cgnet_512x1024_60k_cityscapes.py @@ -22,7 +22,8 @@ train_cfg = dict( type='IterBasedTrainLoop', max_iters=total_iters, val_interval=4000) val_cfg = dict(type='ValLoop') test_cfg = dict(type='TestLoop') -default_hooks = dict(checkpoint=dict(by_epoch=False, interval=4000)) +default_hooks = dict( + checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=4000)) crop_size = (512, 1024) data_preprocessor = dict(size=crop_size) diff --git a/configs/cgnet/cgnet_680x680_60k_cityscapes.py b/configs/cgnet/cgnet_680x680_60k_cityscapes.py index 93457fe51..4854b481e 100644 --- a/configs/cgnet/cgnet_680x680_60k_cityscapes.py +++ b/configs/cgnet/cgnet_680x680_60k_cityscapes.py @@ -22,7 +22,8 @@ train_cfg = dict( type='IterBasedTrainLoop', max_iters=total_iters, val_interval=4000) val_cfg = dict(type='ValLoop') test_cfg = dict(type='TestLoop') -default_hooks = dict(checkpoint=dict(by_epoch=False, interval=4000)) +default_hooks = dict( + checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=4000)) crop_size = (680, 680) data_preprocessor = dict(size=crop_size) diff --git a/mmseg/models/decode_heads/enc_head.py b/mmseg/models/decode_heads/enc_head.py index 0915d35fb..05eaa8582 100644 --- a/mmseg/models/decode_heads/enc_head.py +++ b/mmseg/models/decode_heads/enc_head.py @@ -8,6 +8,7 @@ from mmcv.cnn import ConvModule, build_norm_layer from torch import Tensor from mmseg.core.utils import SampleList +from mmseg.core.utils.typing import ConfigType from mmseg.ops import Encoding, resize from mmseg.registry import MODELS from ..builder import build_loss @@ -155,13 +156,13 @@ class EncHead(BaseDecodeHead): return output def predict(self, inputs: Tuple[Tensor], batch_img_metas: List[dict], - **kwargs): + test_cfg: ConfigType): """Forward function for testing, ignore se_loss.""" if self.use_se_loss: seg_logits = self.forward(inputs)[0] else: seg_logits = self.forward(inputs) - return self.predict_by_feat(seg_logits, batch_img_metas, **kwargs) + return self.predict_by_feat(seg_logits, batch_img_metas) @staticmethod def _convert_to_onehot_labels(seg_label, num_classes): @@ -188,7 +189,8 @@ class EncHead(BaseDecodeHead): """Compute segmentation and semantic encoding loss.""" seg_logit, se_seg_logit = seg_logit loss = dict() - loss.update(super(EncHead, self).losses(seg_logit, batch_data_samples)) + loss.update( + super(EncHead, self).loss_by_feat(seg_logit, batch_data_samples)) seg_label = self._stack_batch_gt(batch_data_samples) se_loss = self.loss_se_decode(