[Fix] Fix loss could be nan in optimizer wrapper (#345)

* fix optimizer wrapper counts

* fix ut
Mashiro 2022-07-06 16:42:49 +08:00 committed by GitHub
parent 96378fa748
commit b2ee9f8b11
2 changed files with 13 additions and 16 deletions


@@ -115,11 +115,6 @@ class OptimWrapper:
         # lost when the `_max_counts` is not divisible by
         # `accumulative_counts`.
         self._max_counts = -1
-        # If `_inner_count` is smaller than `_divisible_counts`, the loss
-        # factor used for gradient accumulation should be the same as
-        # `_accumulative_counts`. If `_max_counts` has not been initialized,
-        # the loss factor will always be the same as `_accumulative_counts`.
-        self._divisible_counts = -1
         # The `_remainder_iter` is used for calculating loss factor at the
         # last few iterations. If `_max_counts` has not been initialized,
         # the loss factor will always be the same as `_accumulative_counts`.
@@ -333,14 +328,8 @@ class OptimWrapper:
             self.logger.warning(
                 'Gradient accumulative may slightly decrease '
                 'performance because the model has BatchNorm layers.')
-        residual_counts = max_counts - init_counts
-        # The maximum number of training iteration that is divisible by
-        # `_accumulative_counts`.
-        self._divisible_counts = (
-            residual_counts // self._accumulative_counts *
-            self._accumulative_counts)
-        # Remainder of `_max_counts` divided by `_accumulative_counts`
-        self._remainder_counts = residual_counts - self._divisible_counts
+        self._remainder_counts = self._max_counts % self._accumulative_counts

     def should_update(self) -> bool:
         """Decide whether the parameters should be updated at the current
@@ -396,7 +385,7 @@
             # be divisible by `self._accumulative_counts`, so the
             # `loss_scale` for the last few iterations needs to be
             # recalculated.
-            if self._inner_count < self._divisible_counts:
+            if self._inner_count < self._max_counts - self._remainder_counts:
                 loss_factor = self._accumulative_counts
             else:
                 loss_factor = self._remainder_counts
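
Putting the two changes together, here is a minimal sketch of the corrected loss-factor selection (a hypothetical standalone helper that mirrors the branch above; the real logic lives inside OptimWrapper):

    def loss_factor_for(inner_count: int, max_counts: int,
                        accumulative_counts: int) -> int:
        # Size of the trailing, incomplete accumulation window.
        remainder_counts = max_counts % accumulative_counts
        if inner_count < max_counts - remainder_counts:
            # Inside a full window: average the loss over the whole window.
            return accumulative_counts
        # Inside the trailing window: average over the remaining steps only.
        return remainder_counts

    # Values matching the updated unit test below (96/100 with counts of 3).
    assert loss_factor_for(97, 100, 3) == 3
    assert loss_factor_for(99, 100, 3) == 1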


@@ -68,7 +68,6 @@ class TestOptimWrapper(MultiProcessTestCase):
         self.assertIs(optim_wrapper.message_hub, self.message_hub)
         self.assertEqual(optim_wrapper._inner_count, 0)
         self.assertEqual(optim_wrapper._max_counts, -1)
-        self.assertEqual(optim_wrapper._divisible_counts, -1)
         self.assertEqual(optim_wrapper._remainder_counts, -1)

         with self.assertRaisesRegex(AssertionError,
@@ -119,12 +118,21 @@ class TestOptimWrapper(MultiProcessTestCase):
         optim_wrapper.update_params(loss)
         optim_wrapper.step.assert_called()
         optim_wrapper.zero_grad.assert_called()
-        self.assertEqual(optim_wrapper.scaled_loss, torch.tensor(1))
+        self.assertEqual(optim_wrapper.scaled_loss, torch.tensor(1.))
         self._mock_method(optim_wrapper)
+        # optim_wrapper.step should not be called at iteration 97 98, and the
+        # loss factor should be 3 at iteration 99.
+        optim_wrapper.initialize_count_status(self.model, 96, 100)
+        for _ in range(2):
+            optim_wrapper.update_params(loss)
+            optim_wrapper.step.assert_not_called()
+            optim_wrapper.zero_grad.assert_not_called()
+            self.assertEqual(optim_wrapper.scaled_loss, torch.tensor(1.) / 3)

     def test_initialize_iter_status(self):
         optim_wrapper = OptimWrapper(self.optimizer, accumulative_counts=3)
         optim_wrapper.initialize_count_status(self.model, 0, 100)
-        self.assertEqual(optim_wrapper._divisible_counts, 99)
         self.assertEqual(optim_wrapper._remainder_counts, 1)

         # Indivisible cur_iter will output warning.
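
For the numbers in the new test case added above (illustrative arithmetic only, not part of the test file):

    accumulative_counts, init_counts, max_counts = 3, 96, 100
    remainder_counts = max_counts % accumulative_counts   # 1
    threshold = max_counts - remainder_counts             # 99
    # Both update_params(loss) calls run while the inner count is below 99,
    # so the loss is scaled by 1/3 (scaled_loss == 1/3); the counts reached
    # afterwards (97 and 98) are neither multiples of 3 nor equal to 100,
    # so step()/zero_grad() stay uncalled, matching the assert_not_called()
    # checks in the test.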