[Fix] Fix a bug when using iter-based runner with 'val' workflow (#542)
* add kwargs and default of optimizer in train_step and val_step * update docstring * update docstring * update optional annotationpull/525/head
parent
49cbfd776a
commit
f361bd52e9
|
@ -118,7 +118,7 @@ class BaseClassifier(BaseModule, metaclass=ABCMeta):
|
|||
|
||||
return loss, log_vars
|
||||
|
||||
def train_step(self, data, optimizer):
|
||||
def train_step(self, data, optimizer=None, **kwargs):
|
||||
"""The iteration step during training.
|
||||
|
||||
This method defines an iteration step during training, except for the
|
||||
|
@ -129,9 +129,9 @@ class BaseClassifier(BaseModule, metaclass=ABCMeta):
|
|||
|
||||
Args:
|
||||
data (dict): The output of dataloader.
|
||||
optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of
|
||||
runner is passed to ``train_step()``. This argument is unused
|
||||
and reserved.
|
||||
optimizer (:obj:`torch.optim.Optimizer` | dict, optional): The
|
||||
optimizer of runner is passed to ``train_step()``. This
|
||||
argument is unused and reserved.
|
||||
|
||||
Returns:
|
||||
dict: Dict of outputs. The following fields are contained.
|
||||
|
@ -151,12 +151,28 @@ class BaseClassifier(BaseModule, metaclass=ABCMeta):
|
|||
|
||||
return outputs
|
||||
|
||||
def val_step(self, data, optimizer):
|
||||
def val_step(self, data, optimizer=None, **kwargs):
|
||||
"""The iteration step during validation.
|
||||
|
||||
This method shares the same signature as :func:`train_step`, but used
|
||||
during val epochs. Note that the evaluation after training epochs is
|
||||
not implemented with this method, but an evaluation hook.
|
||||
|
||||
Args:
|
||||
data (dict): The output of dataloader.
|
||||
optimizer (:obj:`torch.optim.Optimizer` | dict, optional): The
|
||||
optimizer of runner is passed to ``train_step()``. This
|
||||
argument is unused and reserved.
|
||||
|
||||
Returns:
|
||||
dict: Dict of outputs. The following fields are contained.
|
||||
- loss (torch.Tensor): A tensor for back propagation, which \
|
||||
can be a weighted sum of multiple losses.
|
||||
- log_vars (dict): Dict contains all the variables to be sent \
|
||||
to the logger.
|
||||
- num_samples (int): Indicates the batch size (when the model \
|
||||
is DDP, it means the batch size on each GPU), which is \
|
||||
used for averaging the logs.
|
||||
"""
|
||||
losses = self(**data)
|
||||
loss, log_vars = self._parse_losses(losses)
|
||||
|
|
|
@ -44,11 +44,21 @@ def test_image_classifier():
|
|||
assert outputs['loss'].item() > 0
|
||||
assert outputs['num_samples'] == 16
|
||||
|
||||
# test train_step without optimizer
|
||||
outputs = model.train_step({'img': imgs, 'gt_label': label})
|
||||
assert outputs['loss'].item() > 0
|
||||
assert outputs['num_samples'] == 16
|
||||
|
||||
# test val_step
|
||||
outputs = model.val_step({'img': imgs, 'gt_label': label}, None)
|
||||
assert outputs['loss'].item() > 0
|
||||
assert outputs['num_samples'] == 16
|
||||
|
||||
# test val_step without optimizer
|
||||
outputs = model.val_step({'img': imgs, 'gt_label': label})
|
||||
assert outputs['loss'].item() > 0
|
||||
assert outputs['num_samples'] == 16
|
||||
|
||||
# test forward
|
||||
losses = model(imgs, return_loss=True, gt_label=label)
|
||||
assert losses['loss'].item() > 0
|
||||
|
|
Loading…
Reference in New Issue