[Fix] Fix loss parse in val_step (#906)

* [Fix] Fix loss parse in val_step

* Add val_step unittest

* Add train_step unittest
This commit is contained in:
Julius Zhang 2021-09-26 09:17:40 +08:00 committed by GitHub
parent e171e806de
commit 29c82eaf13
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 30 additions and 3 deletions

View File

@@ -145,15 +145,22 @@ class BaseSegmentor(BaseModule, metaclass=ABCMeta):
return outputs
def val_step(self, data_batch, optimizer=None, **kwargs):
    """The iteration step during validation.

    This method shares the same signature as :func:`train_step`, but is
    used during val epochs. Note that the evaluation after training
    epochs is not implemented with this method, but with an evaluation
    hook.

    Args:
        data_batch (dict): The output of the dataloader. Must contain an
            ``img_metas`` entry, whose length is reported as the number
            of samples in this batch.
        optimizer: Unused here; accepted so the runner can call
            ``val_step`` with the same signature as ``train_step``.

    Returns:
        dict: A dict with three keys, mirroring ``train_step``:

            - ``loss`` (Tensor): the summed loss for back-compat logging.
            - ``log_vars`` (dict): scalar values to be logged.
            - ``num_samples`` (int): number of samples in the batch.
    """
    # Run a forward pass in training mode (returns a dict of losses)
    # rather than returning raw predictions, so the losses can be
    # parsed and logged exactly like in train_step.
    losses = self(**data_batch)
    loss, log_vars = self._parse_losses(losses)
    outputs = dict(
        loss=loss,
        log_vars=log_vars,
        num_samples=len(data_batch['img_metas']))
    return outputs
@staticmethod
def _parse_losses(losses):

View File

@@ -101,6 +101,26 @@ def _segmentor_forward_train_test(segmentor):
imgs, img_metas, gt_semantic_seg=gt_semantic_seg, return_loss=True)
assert isinstance(losses, dict)
# Test train_step
data_batch = dict(
img=imgs, img_metas=img_metas, gt_semantic_seg=gt_semantic_seg)
outputs = segmentor.train_step(data_batch, None)
assert isinstance(outputs, dict)
assert 'loss' in outputs
assert 'log_vars' in outputs
assert 'num_samples' in outputs
# Test val_step
with torch.no_grad():
segmentor.eval()
data_batch = dict(
img=imgs, img_metas=img_metas, gt_semantic_seg=gt_semantic_seg)
outputs = segmentor.val_step(data_batch, None)
assert isinstance(outputs, dict)
assert 'loss' in outputs
assert 'log_vars' in outputs
assert 'num_samples' in outputs
# Test forward simple test
with torch.no_grad():
segmentor.eval()