[Fix] Fix loss parse in val_step (#906)

* [Fix] Fix loss parse in val_step

* Add val_step unittest

* Add train_step unittest
This commit is contained in:
Julius Zhang 2021-09-26 09:17:40 +08:00 committed by GitHub
parent e171e806de
commit 29c82eaf13
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 30 additions and 3 deletions

View File

@ -145,15 +145,22 @@ class BaseSegmentor(BaseModule, metaclass=ABCMeta):
return outputs return outputs
def val_step(self, data_batch, **kwargs): def val_step(self, data_batch, optimizer=None, **kwargs):
"""The iteration step during validation. """The iteration step during validation.
This method shares the same signature as :func:`train_step`, but used This method shares the same signature as :func:`train_step`, but used
during val epochs. Note that the evaluation after training epochs is during val epochs. Note that the evaluation after training epochs is
not implemented with this method, but an evaluation hook. not implemented with this method, but an evaluation hook.
""" """
output = self(**data_batch, **kwargs) losses = self(**data_batch)
return output loss, log_vars = self._parse_losses(losses)
outputs = dict(
loss=loss,
log_vars=log_vars,
num_samples=len(data_batch['img_metas']))
return outputs
@staticmethod @staticmethod
def _parse_losses(losses): def _parse_losses(losses):

View File

@ -101,6 +101,26 @@ def _segmentor_forward_train_test(segmentor):
imgs, img_metas, gt_semantic_seg=gt_semantic_seg, return_loss=True) imgs, img_metas, gt_semantic_seg=gt_semantic_seg, return_loss=True)
assert isinstance(losses, dict) assert isinstance(losses, dict)
# Test train_step
data_batch = dict(
img=imgs, img_metas=img_metas, gt_semantic_seg=gt_semantic_seg)
outputs = segmentor.train_step(data_batch, None)
assert isinstance(outputs, dict)
assert 'loss' in outputs
assert 'log_vars' in outputs
assert 'num_samples' in outputs
# Test val_step
with torch.no_grad():
segmentor.eval()
data_batch = dict(
img=imgs, img_metas=img_metas, gt_semantic_seg=gt_semantic_seg)
outputs = segmentor.val_step(data_batch, None)
assert isinstance(outputs, dict)
assert 'loss' in outputs
assert 'log_vars' in outputs
assert 'num_samples' in outputs
# Test forward simple test # Test forward simple test
with torch.no_grad(): with torch.no_grad():
segmentor.eval() segmentor.eval()