[Fix] Fix a bug when using iter-based runner with 'val' workflow (#542)

* add **kwargs and a default value for optimizer in train_step and val_step

* update docstring

* update docstring

* update optional annotation
Ezra-Yu 2021-11-22 12:06:47 +08:00 committed by GitHub
parent 49cbfd776a
commit f361bd52e9
2 changed files with 31 additions and 5 deletions
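
For context, here is a minimal standalone sketch of the failure this commit fixes. It uses toy classes (OldClassifier and FixedClassifier are made up for illustration, and the runner call is paraphrased rather than quoted from mmcv): the iter-based runner's 'val' workflow calls ``val_step(data_batch)`` without passing an optimizer, which raises a TypeError as long as ``optimizer`` is a required positional argument. Making it optional and accepting ``**kwargs`` keeps both call styles working.

# Toy illustration only; these classes do not exist in mmcls.
class OldClassifier:
    def val_step(self, data, optimizer):  # optimizer is required
        return {'loss': 0.0, 'num_samples': len(data['img'])}


class FixedClassifier:
    def val_step(self, data, optimizer=None, **kwargs):  # optional, extras tolerated
        return {'loss': 0.0, 'num_samples': len(data['img'])}


data_batch = {'img': [0] * 16, 'gt_label': [0] * 16}

# The 'val' workflow effectively makes this call (no optimizer is passed):
try:
    OldClassifier().val_step(data_batch)
except TypeError as err:
    print('before the fix:', err)  # missing required positional argument: 'optimizer'

print('after the fix:', FixedClassifier().val_step(data_batch))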

@@ -118,7 +118,7 @@ class BaseClassifier(BaseModule, metaclass=ABCMeta):
         return loss, log_vars
 
-    def train_step(self, data, optimizer):
+    def train_step(self, data, optimizer=None, **kwargs):
         """The iteration step during training.
 
         This method defines an iteration step during training, except for the
@@ -129,9 +129,9 @@ class BaseClassifier(BaseModule, metaclass=ABCMeta):
         Args:
             data (dict): The output of dataloader.
-            optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of
-                runner is passed to ``train_step()``. This argument is unused
-                and reserved.
+            optimizer (:obj:`torch.optim.Optimizer` | dict, optional): The
+                optimizer of runner is passed to ``train_step()``. This
+                argument is unused and reserved.
 
         Returns:
             dict: Dict of outputs. The following fields are contained.
@@ -151,12 +151,28 @@ class BaseClassifier(BaseModule, metaclass=ABCMeta):
         return outputs
 
-    def val_step(self, data, optimizer):
+    def val_step(self, data, optimizer=None, **kwargs):
         """The iteration step during validation.
 
         This method shares the same signature as :func:`train_step`, but used
         during val epochs. Note that the evaluation after training epochs is
         not implemented with this method, but an evaluation hook.
+
+        Args:
+            data (dict): The output of dataloader.
+            optimizer (:obj:`torch.optim.Optimizer` | dict, optional): The
+                optimizer of runner is passed to ``train_step()``. This
+                argument is unused and reserved.
+
+        Returns:
+            dict: Dict of outputs. The following fields are contained.
+                - loss (torch.Tensor): A tensor for back propagation, which \
+                    can be a weighted sum of multiple losses.
+                - log_vars (dict): Dict contains all the variables to be sent \
+                    to the logger.
+                - num_samples (int): Indicates the batch size (when the model \
+                    is DDP, it means the batch size on each GPU), which is \
+                    used for averaging the logs.
         """
         losses = self(**data)
         loss, log_vars = self._parse_losses(losses)
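
With the change above, ``train_step`` and ``val_step`` accept the calls the runner makes in both the 'train' and 'val' workflows and return the dict described in the docstrings (``loss``, ``log_vars``, ``num_samples``). A rough usage sketch, assuming the usual mmcls v0.x builders and components (``build_classifier``, ``ImageClassifier``, ``ResNet``, ``GlobalAveragePooling``, ``LinearClsHead``); the config values are arbitrary:

import torch

from mmcls.models import build_classifier

# Arbitrary small config for the sketch; any valid classifier config works.
model = build_classifier(
    dict(
        type='ImageClassifier',
        backbone=dict(type='ResNet', depth=18, num_stages=4, out_indices=(3, )),
        neck=dict(type='GlobalAveragePooling'),
        head=dict(
            type='LinearClsHead',
            num_classes=10,
            in_channels=512,
            loss=dict(type='CrossEntropyLoss'))))

data = {
    'img': torch.rand(16, 3, 32, 32),
    'gt_label': torch.randint(0, 10, (16, )),
}

train_out = model.train_step(data)  # optimizer is now optional (unused and reserved)
val_out = model.val_step(data)      # no optimizer, as in the runner's 'val' workflow
assert {'loss', 'log_vars', 'num_samples'} <= set(train_out)
assert val_out['num_samples'] == 16

The updated test below exercises the same calls with and without the optimizer argument.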

@@ -44,11 +44,21 @@ def test_image_classifier():
     assert outputs['loss'].item() > 0
     assert outputs['num_samples'] == 16
 
+    # test train_step without optimizer
+    outputs = model.train_step({'img': imgs, 'gt_label': label})
+    assert outputs['loss'].item() > 0
+    assert outputs['num_samples'] == 16
+
     # test val_step
     outputs = model.val_step({'img': imgs, 'gt_label': label}, None)
     assert outputs['loss'].item() > 0
     assert outputs['num_samples'] == 16
 
+    # test val_step without optimizer
+    outputs = model.val_step({'img': imgs, 'gt_label': label})
+    assert outputs['loss'].item() > 0
+    assert outputs['num_samples'] == 16
+
     # test forward
     losses = model(imgs, return_loss=True, gt_label=label)
     assert losses['loss'].item() > 0