Mirror of https://github.com/open-mmlab/mmsegmentation.git, synced 2025-06-03 22:03:48 +08:00
[Fix] Fix loss parse in val_step (#906)
* [Fix] Fix loss parse in val_step
* Add val_step unittest
* Add train_step unittest
parent e171e806de
commit 29c82eaf13
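For context: before this fix, val_step forwarded the batch through the model and returned the raw loss dictionary, while runner-style validation code expects the same parsed dictionary shape that train_step produces ('loss', 'log_vars', 'num_samples'). A minimal, illustrative sketch of that consuming side (not mmcv source; run_val_epoch and its bookkeeping are placeholders):

# Illustrative sketch only, not mmcv source: a runner-style validation loop
# consuming the dict returned by val_step. Before this fix, val_step returned
# the raw loss dict from forward(), so 'log_vars' and 'num_samples' were
# missing and bookkeeping like this broke.
import torch


def run_val_epoch(model, data_loader):
    model.eval()
    total, count = 0.0, 0
    with torch.no_grad():
        for data_batch in data_loader:
            outputs = model.val_step(data_batch, None)
            # After the fix, log_vars holds scalar values and num_samples is
            # the batch size, mirroring train_step.
            total += outputs['log_vars']['loss'] * outputs['num_samples']
            count += outputs['num_samples']
    return total / max(count, 1)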
@@ -145,15 +145,22 @@ class BaseSegmentor(BaseModule, metaclass=ABCMeta):
 
         return outputs
 
-    def val_step(self, data_batch, **kwargs):
+    def val_step(self, data_batch, optimizer=None, **kwargs):
         """The iteration step during validation.
 
         This method shares the same signature as :func:`train_step`, but used
         during val epochs. Note that the evaluation after training epochs is
         not implemented with this method, but an evaluation hook.
         """
-        output = self(**data_batch, **kwargs)
-        return output
+        losses = self(**data_batch)
+        loss, log_vars = self._parse_losses(losses)
+
+        outputs = dict(
+            loss=loss,
+            log_vars=log_vars,
+            num_samples=len(data_batch['img_metas']))
+
+        return outputs
 
     @staticmethod
     def _parse_losses(losses):
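The new val_step relies on _parse_losses, which appears only as context at the end of this hunk. As a reading aid, here is a rough, assumption-based sketch of what that helper conventionally does in MMCV-style models; the real method just below the hunk may differ in details (for example, averaging log_vars across GPUs in distributed runs):

# Rough, assumption-based sketch of an MMCV-style _parse_losses, as a reading
# aid for the new val_step; not copied from BaseSegmentor.
from collections import OrderedDict

import torch


def _parse_losses_sketch(losses):
    log_vars = OrderedDict()
    for name, value in losses.items():
        if isinstance(value, torch.Tensor):
            log_vars[name] = value.mean()
        elif isinstance(value, list):
            log_vars[name] = sum(v.mean() for v in value)
        else:
            raise TypeError(f'{name} is not a tensor or list of tensors')

    # The total loss is the sum of every entry whose key contains 'loss'.
    loss = sum(v for k, v in log_vars.items() if 'loss' in k)
    log_vars['loss'] = loss
    # Logged values are reduced to plain Python numbers.
    log_vars = OrderedDict((k, v.item()) for k, v in log_vars.items())
    return loss, log_vars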
@@ -101,6 +101,26 @@ def _segmentor_forward_train_test(segmentor):
         imgs, img_metas, gt_semantic_seg=gt_semantic_seg, return_loss=True)
     assert isinstance(losses, dict)
 
+    # Test train_step
+    data_batch = dict(
+        img=imgs, img_metas=img_metas, gt_semantic_seg=gt_semantic_seg)
+    outputs = segmentor.train_step(data_batch, None)
+    assert isinstance(outputs, dict)
+    assert 'loss' in outputs
+    assert 'log_vars' in outputs
+    assert 'num_samples' in outputs
+
+    # Test val_step
+    with torch.no_grad():
+        segmentor.eval()
+        data_batch = dict(
+            img=imgs, img_metas=img_metas, gt_semantic_seg=gt_semantic_seg)
+        outputs = segmentor.val_step(data_batch, None)
+    assert isinstance(outputs, dict)
+    assert 'loss' in outputs
+    assert 'log_vars' in outputs
+    assert 'num_samples' in outputs
+
     # Test forward simple test
     with torch.no_grad():
         segmentor.eval()
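Both train_step and val_step are now exercised with the same (data_batch, None) call shape, which is also why val_step gains an optimizer=None parameter it never uses: runner-style code can drive training and validation iterations uniformly. A hedged sketch of that pattern (illustrative names, not mmcv source):

# Illustrative sketch, not mmcv source: one helper driving both phases with
# the same call shape; only the training branch touches the optimizer.
import torch


def run_iter(model, data_batch, optimizer, train_mode):
    if train_mode:
        outputs = model.train_step(data_batch, optimizer)
        optimizer.zero_grad()
        outputs['loss'].backward()
        optimizer.step()
    else:
        with torch.no_grad():
            outputs = model.val_step(data_batch, optimizer)
    return outputs['log_vars'], outputs['num_samples']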