From f361bd52e973b80fa2095b243f5664cbd866fb50 Mon Sep 17 00:00:00 2001
From: Ezra-Yu <1105212286@qq.com>
Date: Mon, 22 Nov 2021 12:06:47 +0800
Subject: [PATCH] [Fix] Fix a bug when using iter-based runner with 'val'
 workflow (#542)

* add kwargs and default of optimizer in train_step and val_step

* update docstring

* update docstring

* update optional annotation
---
 mmcls/models/classifiers/base.py      | 26 +++++++++++++++++++++-----
 tests/test_models/test_classifiers.py | 10 ++++++++++
 2 files changed, 31 insertions(+), 5 deletions(-)

diff --git a/mmcls/models/classifiers/base.py b/mmcls/models/classifiers/base.py
index 5090245e..02391c71 100644
--- a/mmcls/models/classifiers/base.py
+++ b/mmcls/models/classifiers/base.py
@@ -118,7 +118,7 @@ class BaseClassifier(BaseModule, metaclass=ABCMeta):
 
         return loss, log_vars
 
-    def train_step(self, data, optimizer):
+    def train_step(self, data, optimizer=None, **kwargs):
         """The iteration step during training.
 
         This method defines an iteration step during training, except for the
@@ -129,9 +129,9 @@ class BaseClassifier(BaseModule, metaclass=ABCMeta):
 
         Args:
             data (dict): The output of dataloader.
-            optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of
-                runner is passed to ``train_step()``. This argument is unused
-                and reserved.
+            optimizer (:obj:`torch.optim.Optimizer` | dict, optional): The
+                optimizer of runner is passed to ``train_step()``. This
+                argument is unused and reserved.
 
         Returns:
             dict: Dict of outputs. The following fields are contained.
@@ -151,12 +151,28 @@ class BaseClassifier(BaseModule, metaclass=ABCMeta):
 
         return outputs
 
-    def val_step(self, data, optimizer):
+    def val_step(self, data, optimizer=None, **kwargs):
         """The iteration step during validation.
 
         This method shares the same signature as :func:`train_step`, but used
         during val epochs. Note that the evaluation after training epochs is
         not implemented with this method, but an evaluation hook.
+
+        Args:
+            data (dict): The output of dataloader.
+            optimizer (:obj:`torch.optim.Optimizer` | dict, optional): The
+                optimizer of runner is passed to ``train_step()``. This
+                argument is unused and reserved.
+
+        Returns:
+            dict: Dict of outputs. The following fields are contained.
+                - loss (torch.Tensor): A tensor for back propagation, which \
+                    can be a weighted sum of multiple losses.
+                - log_vars (dict): Dict contains all the variables to be sent \
+                    to the logger.
+                - num_samples (int): Indicates the batch size (when the model \
+                    is DDP, it means the batch size on each GPU), which is \
+                    used for averaging the logs.
         """
         losses = self(**data)
         loss, log_vars = self._parse_losses(losses)
diff --git a/tests/test_models/test_classifiers.py b/tests/test_models/test_classifiers.py
index 7b5df469..41cffbb6 100644
--- a/tests/test_models/test_classifiers.py
+++ b/tests/test_models/test_classifiers.py
@@ -44,11 +44,21 @@ def test_image_classifier():
     assert outputs['loss'].item() > 0
     assert outputs['num_samples'] == 16
 
+    # test train_step without optimizer
+    outputs = model.train_step({'img': imgs, 'gt_label': label})
+    assert outputs['loss'].item() > 0
+    assert outputs['num_samples'] == 16
+
     # test val_step
     outputs = model.val_step({'img': imgs, 'gt_label': label}, None)
     assert outputs['loss'].item() > 0
     assert outputs['num_samples'] == 16
 
+    # test val_step without optimizer
+    outputs = model.val_step({'img': imgs, 'gt_label': label})
+    assert outputs['loss'].item() > 0
+    assert outputs['num_samples'] == 16
+
     # test forward
     losses = model(imgs, return_loss=True, gt_label=label)
     assert losses['loss'].item() > 0