fix cgnet configs and encnet forward error

parent 95b4926e37
commit 55b0c3aef5
@@ -22,7 +22,8 @@ train_cfg = dict(
     type='IterBasedTrainLoop', max_iters=total_iters, val_interval=4000)
 val_cfg = dict(type='ValLoop')
 test_cfg = dict(type='TestLoop')
-default_hooks = dict(checkpoint=dict(by_epoch=False, interval=4000))
+default_hooks = dict(
+    checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=4000))

 crop_size = (512, 1024)
 data_preprocessor = dict(size=crop_size)
@@ -22,7 +22,8 @@ train_cfg = dict(
     type='IterBasedTrainLoop', max_iters=total_iters, val_interval=4000)
 val_cfg = dict(type='ValLoop')
 test_cfg = dict(type='TestLoop')
-default_hooks = dict(checkpoint=dict(by_epoch=False, interval=4000))
+default_hooks = dict(
+    checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=4000))

 crop_size = (680, 680)
 data_preprocessor = dict(size=crop_size)
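Note: the same fix lands in both cgnet configs (512x1024 and 680x680 crops): the checkpoint entry in default_hooks now names its hook type explicitly so the registry can build it. A minimal sketch of the affected config region after the change, with total_iters and the surrounding fields assumed for illustration:

# Sketch of the updated cgnet config region after this commit; the file name
# and fields outside this hunk are assumed, and total_iters is illustrative.
total_iters = 60000  # assumed value for illustration

train_cfg = dict(
    type='IterBasedTrainLoop', max_iters=total_iters, val_interval=4000)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

# The hook type is now explicit, so MMEngine builds CheckpointHook from the
# registry instead of relying on an implicit default.
default_hooks = dict(
    checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=4000))

crop_size = (512, 1024)  # (680, 680) in the second cgnet config
data_preprocessor = dict(size=crop_size)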
@@ -8,6 +8,7 @@ from mmcv.cnn import ConvModule, build_norm_layer
 from torch import Tensor

 from mmseg.core.utils import SampleList
+from mmseg.core.utils.typing import ConfigType
 from mmseg.ops import Encoding, resize
 from mmseg.registry import MODELS
 from ..builder import build_loss
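Note: the added import backs the new `test_cfg: ConfigType` annotation on `EncHead.predict` in the next hunk. Its definition is not part of this diff; as an assumption, `ConfigType` is the usual MMSegmentation alias for a config mapping, roughly:

# Assumed sketch of the ConfigType alias imported above; the exact module
# providing ConfigDict (mmengine here vs. mmcv on older branches) is a guess.
from typing import Union

from mmengine.config import ConfigDict

ConfigType = Union[ConfigDict, dict]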
@@ -155,13 +156,13 @@ class EncHead(BaseDecodeHead):
         return output

     def predict(self, inputs: Tuple[Tensor], batch_img_metas: List[dict],
-                **kwargs):
+                test_cfg: ConfigType):
         """Forward function for testing, ignore se_loss."""
         if self.use_se_loss:
             seg_logits = self.forward(inputs)[0]
         else:
             seg_logits = self.forward(inputs)
-        return self.predict_by_feat(seg_logits, batch_img_metas, **kwargs)
+        return self.predict_by_feat(seg_logits, batch_img_metas)

     @staticmethod
     def _convert_to_onehot_labels(seg_label, num_classes):
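Note: this signature is the forward error named in the commit title: the decode-head test path now receives test_cfg as an explicit positional argument, and predict_by_feat is called without forwarded kwargs. A minimal, self-contained stub of that flow (the class, feature shapes, and metas below are assumptions for illustration, not the real EncHead):

from typing import List, Tuple

import torch
from torch import Tensor


# Minimal stub of the updated predict() flow after this commit (an assumed
# simplification: the real EncHead builds conv/encoding layers and calls
# predict_by_feat() to resize logits to the original image shape).
class StubEncHead:

    def __init__(self, use_se_loss: bool = True) -> None:
        self.use_se_loss = use_se_loss

    def forward(self, inputs: Tuple[Tensor]):
        seg_logits = inputs[-1]
        if self.use_se_loss:
            # The training forward also returns the semantic-encoding output.
            se_output = seg_logits.mean(dim=(2, 3))
            return seg_logits, se_output
        return seg_logits

    def predict(self, inputs: Tuple[Tensor], batch_img_metas: List[dict],
                test_cfg: dict) -> Tensor:
        # test_cfg is now an explicit positional argument; **kwargs is gone,
        # and only the segmentation logits are kept when use_se_loss is set.
        if self.use_se_loss:
            return self.forward(inputs)[0]
        return self.forward(inputs)


head = StubEncHead(use_se_loss=True)
feats = (torch.randn(2, 19, 64, 128),)
metas = [dict(ori_shape=(512, 1024))] * 2
print(head.predict(feats, metas, test_cfg=dict(mode='whole')).shape)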
@@ -188,7 +189,8 @@ class EncHead(BaseDecodeHead):
         """Compute segmentation and semantic encoding loss."""
         seg_logit, se_seg_logit = seg_logit
         loss = dict()
-        loss.update(super(EncHead, self).losses(seg_logit, batch_data_samples))
+        loss.update(
+            super(EncHead, self).loss_by_feat(seg_logit, batch_data_samples))

         seg_label = self._stack_batch_gt(batch_data_samples)
         se_loss = self.loss_se_decode(
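Note: the loss hunk is the matching rename on the training path: the base class computes the pixel-wise loss in loss_by_feat (previously losses), and EncHead adds its semantic-encoding (SE) term on top. A rough, self-contained sketch of how the combined loss dict is assembled, with stubbed losses and assumed shapes rather than the actual implementation:

import torch
import torch.nn.functional as F


# Rough stub of the combined loss after this commit. Shapes and the BCE-style
# SE loss are assumptions for illustration; the real head uses its configured
# loss_se_decode and one-hot labels built by _convert_to_onehot_labels.
def enc_head_loss_stub(seg_logit, se_seg_logit, seg_label, num_classes=19):
    loss = dict()
    # Base-class part (loss_by_feat): ordinary pixel-wise segmentation loss.
    loss['loss_ce'] = F.cross_entropy(seg_logit, seg_label, ignore_index=255)
    # EncHead part: image-level prediction of which classes are present.
    onehot = torch.zeros_like(se_seg_logit)
    for b in range(seg_label.shape[0]):
        present = seg_label[b].unique()
        onehot[b, present[present < num_classes]] = 1.0
    loss['loss_se'] = F.binary_cross_entropy_with_logits(se_seg_logit, onehot)
    return loss


seg_logit = torch.randn(2, 19, 64, 64)
se_seg_logit = torch.randn(2, 19)
seg_label = torch.randint(0, 19, (2, 64, 64))
print(enc_head_loss_stub(seg_logit, se_seg_logit, seg_label))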