fix cgnet configs and encnet forward error
parent
95b4926e37
commit
55b0c3aef5
|
@ -22,7 +22,8 @@ train_cfg = dict(
|
||||||
type='IterBasedTrainLoop', max_iters=total_iters, val_interval=4000)
|
type='IterBasedTrainLoop', max_iters=total_iters, val_interval=4000)
|
||||||
val_cfg = dict(type='ValLoop')
|
val_cfg = dict(type='ValLoop')
|
||||||
test_cfg = dict(type='TestLoop')
|
test_cfg = dict(type='TestLoop')
|
||||||
default_hooks = dict(checkpoint=dict(by_epoch=False, interval=4000))
|
default_hooks = dict(
|
||||||
|
checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=4000))
|
||||||
|
|
||||||
crop_size = (512, 1024)
|
crop_size = (512, 1024)
|
||||||
data_preprocessor = dict(size=crop_size)
|
data_preprocessor = dict(size=crop_size)
|
||||||
|
|
|
@ -22,7 +22,8 @@ train_cfg = dict(
|
||||||
type='IterBasedTrainLoop', max_iters=total_iters, val_interval=4000)
|
type='IterBasedTrainLoop', max_iters=total_iters, val_interval=4000)
|
||||||
val_cfg = dict(type='ValLoop')
|
val_cfg = dict(type='ValLoop')
|
||||||
test_cfg = dict(type='TestLoop')
|
test_cfg = dict(type='TestLoop')
|
||||||
default_hooks = dict(checkpoint=dict(by_epoch=False, interval=4000))
|
default_hooks = dict(
|
||||||
|
checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=4000))
|
||||||
|
|
||||||
crop_size = (680, 680)
|
crop_size = (680, 680)
|
||||||
data_preprocessor = dict(size=crop_size)
|
data_preprocessor = dict(size=crop_size)
|
||||||
|
|
|
@ -8,6 +8,7 @@ from mmcv.cnn import ConvModule, build_norm_layer
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
from mmseg.core.utils import SampleList
|
from mmseg.core.utils import SampleList
|
||||||
|
from mmseg.core.utils.typing import ConfigType
|
||||||
from mmseg.ops import Encoding, resize
|
from mmseg.ops import Encoding, resize
|
||||||
from mmseg.registry import MODELS
|
from mmseg.registry import MODELS
|
||||||
from ..builder import build_loss
|
from ..builder import build_loss
|
||||||
|
@ -155,13 +156,13 @@ class EncHead(BaseDecodeHead):
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def predict(self, inputs: Tuple[Tensor], batch_img_metas: List[dict],
|
def predict(self, inputs: Tuple[Tensor], batch_img_metas: List[dict],
|
||||||
**kwargs):
|
test_cfg: ConfigType):
|
||||||
"""Forward function for testing, ignore se_loss."""
|
"""Forward function for testing, ignore se_loss."""
|
||||||
if self.use_se_loss:
|
if self.use_se_loss:
|
||||||
seg_logits = self.forward(inputs)[0]
|
seg_logits = self.forward(inputs)[0]
|
||||||
else:
|
else:
|
||||||
seg_logits = self.forward(inputs)
|
seg_logits = self.forward(inputs)
|
||||||
return self.predict_by_feat(seg_logits, batch_img_metas, **kwargs)
|
return self.predict_by_feat(seg_logits, batch_img_metas)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _convert_to_onehot_labels(seg_label, num_classes):
|
def _convert_to_onehot_labels(seg_label, num_classes):
|
||||||
|
@ -188,7 +189,8 @@ class EncHead(BaseDecodeHead):
|
||||||
"""Compute segmentation and semantic encoding loss."""
|
"""Compute segmentation and semantic encoding loss."""
|
||||||
seg_logit, se_seg_logit = seg_logit
|
seg_logit, se_seg_logit = seg_logit
|
||||||
loss = dict()
|
loss = dict()
|
||||||
loss.update(super(EncHead, self).losses(seg_logit, batch_data_samples))
|
loss.update(
|
||||||
|
super(EncHead, self).loss_by_feat(seg_logit, batch_data_samples))
|
||||||
|
|
||||||
seg_label = self._stack_batch_gt(batch_data_samples)
|
seg_label = self._stack_batch_gt(batch_data_samples)
|
||||||
se_loss = self.loss_se_decode(
|
se_loss = self.loss_se_decode(
|
||||||
|
|
Loading…
Reference in New Issue