diff --git a/mmseg/models/segmentors/base.py b/mmseg/models/segmentors/base.py
index 906c6fe56..944da0f2e 100644
--- a/mmseg/models/segmentors/base.py
+++ b/mmseg/models/segmentors/base.py
@@ -145,15 +145,22 @@ class BaseSegmentor(BaseModule, metaclass=ABCMeta):
 
         return outputs
 
-    def val_step(self, data_batch, **kwargs):
+    def val_step(self, data_batch, optimizer=None, **kwargs):
         """The iteration step during validation.
 
         This method shares the same signature as :func:`train_step`, but used
         during val epochs. Note that the evaluation after training epochs is
         not implemented with this method, but an evaluation hook.
         """
-        output = self(**data_batch, **kwargs)
-        return output
+        losses = self(**data_batch)
+        loss, log_vars = self._parse_losses(losses)
+
+        outputs = dict(
+            loss=loss,
+            log_vars=log_vars,
+            num_samples=len(data_batch['img_metas']))
+
+        return outputs
 
     @staticmethod
     def _parse_losses(losses):
diff --git a/tests/test_models/test_segmentors/utils.py b/tests/test_models/test_segmentors/utils.py
index 0f51a4b1f..1826dbf85 100644
--- a/tests/test_models/test_segmentors/utils.py
+++ b/tests/test_models/test_segmentors/utils.py
@@ -101,6 +101,26 @@ def _segmentor_forward_train_test(segmentor):
         imgs, img_metas, gt_semantic_seg=gt_semantic_seg, return_loss=True)
     assert isinstance(losses, dict)
 
+    # Test train_step
+    data_batch = dict(
+        img=imgs, img_metas=img_metas, gt_semantic_seg=gt_semantic_seg)
+    outputs = segmentor.train_step(data_batch, None)
+    assert isinstance(outputs, dict)
+    assert 'loss' in outputs
+    assert 'log_vars' in outputs
+    assert 'num_samples' in outputs
+
+    # Test val_step
+    with torch.no_grad():
+        segmentor.eval()
+        data_batch = dict(
+            img=imgs, img_metas=img_metas, gt_semantic_seg=gt_semantic_seg)
+        outputs = segmentor.val_step(data_batch, None)
+        assert isinstance(outputs, dict)
+        assert 'loss' in outputs
+        assert 'log_vars' in outputs
+        assert 'num_samples' in outputs
+
     # Test forward simple test
     with torch.no_grad():
         segmentor.eval()
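
For reference, a minimal, self-contained sketch (not the mmseg implementation) of the contract the new val_step follows: run the forward pass, reduce the raw loss dict with _parse_losses, and return loss, log_vars and num_samples, mirroring train_step so that validation losses can be consumed by the runner's logging machinery. ToySegmentor and its simplified _parse_losses below are hypothetical stand-ins; only the layout of the returned dict mirrors the diff above.

import torch
import torch.nn as nn


class ToySegmentor(nn.Module):
    """Hypothetical stand-in mimicking the val_step contract from the diff."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 2, kernel_size=1)

    def forward(self, img, img_metas, gt_semantic_seg):
        # mmseg-style forward-train: return a dict of named loss terms.
        logits = self.conv(img)
        loss_seg = nn.functional.cross_entropy(logits, gt_semantic_seg)
        return dict(loss_seg=loss_seg)

    @staticmethod
    def _parse_losses(losses):
        # Simplified reduction: total loss is the sum of all '*loss*' terms,
        # and every term is exposed as a scalar for logging.
        log_vars = {name: value.mean().item() for name, value in losses.items()}
        loss = sum(value.mean() for name, value in losses.items() if 'loss' in name)
        log_vars['loss'] = loss.item()
        return loss, log_vars

    def val_step(self, data_batch, optimizer=None, **kwargs):
        # Same output layout as train_step, so validation losses can be logged.
        losses = self(**data_batch)
        loss, log_vars = self._parse_losses(losses)
        return dict(
            loss=loss,
            log_vars=log_vars,
            num_samples=len(data_batch['img_metas']))


if __name__ == '__main__':
    model = ToySegmentor().eval()
    data_batch = dict(
        img=torch.randn(2, 3, 8, 8),
        img_metas=[{}, {}],
        gt_semantic_seg=torch.zeros(2, 8, 8, dtype=torch.long))
    with torch.no_grad():
        outputs = model.val_step(data_batch, None)
    # The keys below are exactly what the new test asserts.
    assert {'loss', 'log_vars', 'num_samples'} <= outputs.keys()
    print(outputs['log_vars'], outputs['num_samples'])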