[Fix] Fix loss could be nan in optimizer wrapper (#345)
* fix optimizer wrapper counts
* fix ut
parent 96378fa748
commit b2ee9f8b11
@@ -115,11 +115,6 @@ class OptimWrapper:
         # lost when the `_max_counts` is not divisible by
         # `accumulative_counts`.
         self._max_counts = -1
-        # If `_inner_count` is smaller than `_divisible_counts`, the loss
-        # factor used for gradient accumulation should be the same as
-        # `_accumulative_counts`. If `_max_counts` has not been initialized,
-        # the loss factor will always be the same as `_accumulative_counts`.
-        self._divisible_counts = -1
         # The `_remainder_iter` is used for calculating loss factor at the
         # last few iterations. If `_max_counts` has not been initialized,
         # the loss factor will always be the same as `_accumulative_counts`.
@@ -333,14 +328,8 @@ class OptimWrapper:
             self.logger.warning(
                 'Gradient accumulative may slightly decrease '
                 'performance because the model has BatchNorm layers.')
-        residual_counts = max_counts - init_counts
-        # The maximum number of training iteration that is divisible by
-        # `_accumulative_counts`.
-        self._divisible_counts = (
-            residual_counts // self._accumulative_counts *
-            self._accumulative_counts)
         # Remainder of `_max_counts` divided by `_accumulative_counts`
-        self._remainder_counts = residual_counts - self._divisible_counts
+        self._remainder_counts = self._max_counts % self._accumulative_counts
 
     def should_update(self) -> bool:
         """Decide whether the parameters should be updated at the current
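Note (not part of the commit): a minimal sketch of why the removed arithmetic could produce a NaN/Inf loss. The numbers below (resuming at iteration 97 of a 100-iteration schedule with `accumulative_counts=3`) are chosen purely for illustration; the variable names mirror the diff, but this is standalone Python, not the OptimWrapper API.

    # Illustrative values: resume at iteration 97 of 100, accumulate over 3 steps.
    init_counts, max_counts, accumulative_counts = 97, 100, 3

    # Old arithmetic (removed above): remainder taken over the residual iterations.
    residual_counts = max_counts - init_counts                                        # 3
    divisible_counts = residual_counts // accumulative_counts * accumulative_counts   # 3
    old_remainder = residual_counts - divisible_counts                                # 0

    # New arithmetic (added above): remainder taken over the full schedule.
    new_remainder = max_counts % accumulative_counts                                  # 1

    # The remainder is used as the loss factor for the trailing iterations, so the
    # old value of 0 would mean dividing the loss by zero (Inf/NaN); the new is 1.
    print(old_remainder, new_remainder)  # 0 1

When `max_counts` itself is divisible by `accumulative_counts`, the new remainder is also 0, but the updated condition in the next hunk (`_inner_count < _max_counts - _remainder_counts`) then holds for every training iteration, so the zero remainder is never selected as the divisor.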
@@ -396,7 +385,7 @@ class OptimWrapper:
             # be divisible by `self._accumulative_counts`, so the
             # `loss_scale` for the last few iterations needs to be
             # recalculated.
-            if self._inner_count < self._divisible_counts:
+            if self._inner_count < self._max_counts - self._remainder_counts:
                 loss_factor = self._accumulative_counts
             else:
                 loss_factor = self._remainder_counts
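Note (not part of the commit): a short standalone walk-through of the updated condition, assuming `max_counts=100` and `accumulative_counts=3` as in the unit test below. The `loss_factor` helper is defined here only for the example.

    max_counts, accumulative_counts = 100, 3
    remainder_counts = max_counts % accumulative_counts  # 1

    def loss_factor(inner_count: int) -> int:
        # Mirrors the updated branch above: only the last `remainder_counts`
        # iterations use the remainder as the loss factor.
        if inner_count < max_counts - remainder_counts:
            return accumulative_counts
        return remainder_counts

    print([loss_factor(i) for i in (0, 97, 98, 99)])  # [3, 3, 3, 1]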
@@ -68,7 +68,6 @@ class TestOptimWrapper(MultiProcessTestCase):
         self.assertIs(optim_wrapper.message_hub, self.message_hub)
         self.assertEqual(optim_wrapper._inner_count, 0)
         self.assertEqual(optim_wrapper._max_counts, -1)
-        self.assertEqual(optim_wrapper._divisible_counts, -1)
         self.assertEqual(optim_wrapper._remainder_counts, -1)
 
         with self.assertRaisesRegex(AssertionError,
@@ -119,12 +118,21 @@ class TestOptimWrapper(MultiProcessTestCase):
         optim_wrapper.update_params(loss)
         optim_wrapper.step.assert_called()
         optim_wrapper.zero_grad.assert_called()
-        self.assertEqual(optim_wrapper.scaled_loss, torch.tensor(1))
+        self.assertEqual(optim_wrapper.scaled_loss, torch.tensor(1.))
         self._mock_method(optim_wrapper)
+
+        # optim_wrapper.step should not be called at iterations 97 and 98,
+        # and the loss factor should be 3 at iteration 99.
+        optim_wrapper.initialize_count_status(self.model, 96, 100)
+        for _ in range(2):
+            optim_wrapper.update_params(loss)
+            optim_wrapper.step.assert_not_called()
+            optim_wrapper.zero_grad.assert_not_called()
+        self.assertEqual(optim_wrapper.scaled_loss, torch.tensor(1.) / 3)
+
     def test_initialize_iter_status(self):
         optim_wrapper = OptimWrapper(self.optimizer, accumulative_counts=3)
         optim_wrapper.initialize_count_status(self.model, 0, 100)
-        self.assertEqual(optim_wrapper._divisible_counts, 99)
         self.assertEqual(optim_wrapper._remainder_counts, 1)
 
         # Indivisible cur_iter will output warning.
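Note (not part of the commit): a quick sanity check of the numbers used in the new test case, assuming the loss fed to `update_params` is `torch.tensor(1.)` as in the surrounding test code, and assuming the inner count is read for loss scaling before it is incremented.

    import torch

    max_counts, accumulative_counts = 100, 3
    remainder_counts = max_counts % accumulative_counts   # 1
    boundary = max_counts - remainder_counts              # 99

    loss = torch.tensor(1.)
    # The two extra `update_params` calls scale the loss at inner counts 96 and 97,
    # both below the boundary, so the loss factor is 3 for each of them.
    for inner_count in (96, 97):
        factor = accumulative_counts if inner_count < boundary else remainder_counts
        scaled = loss / factor
    print(scaled)  # tensor(0.3333), i.e. torch.tensor(1.) / 3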