Mirror of https://github.com/open-mmlab/mmengine.git (synced 2025-06-03 21:54:44 +08:00)

[Fix] Fix loss could be nan in optimizer wrapper (#345)

* fix optimizer wrapper counts
* fix ut

Parent: 96378fa748
Commit: b2ee9f8b11
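
In short, the patch drops the resume-relative `_divisible_counts` bookkeeping and derives the gradient-accumulation loss factor from `_max_counts % _accumulative_counts` instead. A minimal sketch of the new rule, using standalone names in place of the wrapper's private attributes and counting `inner_count` from zero, as the wrapper does:

def loss_factor(inner_count: int, max_counts: int, accumulative_counts: int) -> int:
    # Remainder of the whole schedule; only used for the last few iterations.
    remainder_counts = max_counts % accumulative_counts
    if inner_count < max_counts - remainder_counts:
        return accumulative_counts
    return remainder_counts

# With max_counts=100 and accumulative_counts=3 (the values in the updated
# unit test), iterations 96-98 are scaled by 3 and the final one by 1.
assert [loss_factor(i, 100, 3) for i in (96, 97, 98, 99)] == [3, 3, 3, 1]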
@@ -115,11 +115,6 @@ class OptimWrapper:
         # lost when the `_max_counts` is not divisible by
         # `accumulative_counts`.
         self._max_counts = -1
-        # If `_inner_count` is smaller than `_divisible_counts`, the loss
-        # factor used for gradient accumulation should be the same as
-        # `_accumulative_counts`. If `_max_counts` has not been initialized,
-        # the loss factor will always be the same as `_accumulative_counts`.
-        self._divisible_counts = -1
         # The `_remainder_iter` is used for calculating loss factor at the
         # last few iterations. If `_max_counts` has not been initialized,
         # the loss factor will always be the same as `_accumulative_counts`.
@@ -333,14 +328,8 @@ class OptimWrapper:
             self.logger.warning(
                 'Gradient accumulative may slightly decrease '
                 'performance because the model has BatchNorm layers.')
-        residual_counts = max_counts - init_counts
-        # The maximum number of training iteration that is divisible by
-        # `_accumulative_counts`.
-        self._divisible_counts = (
-            residual_counts // self._accumulative_counts *
-            self._accumulative_counts)
         # Remainder of `_max_counts` divided by `_accumulative_counts`
-        self._remainder_counts = residual_counts - self._divisible_counts
+        self._remainder_counts = self._max_counts % self._accumulative_counts
 
     def should_update(self) -> bool:
         """Decide whether the parameters should be updated at the current
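
One reading of why the old bookkeeping could yield the reported nan: the previous remainder was measured from the resume point (`max_counts - init_counts`), so resuming at a point where the residual happened to be divisible by `_accumulative_counts` left a remainder of 0, and dividing the loss by a zero factor gives inf/nan. A numeric comparison with plain local variables; the resume point 97 is chosen for illustration and is not part of the commit:

# Illustrative resume scenario, not taken from the diff.
init_counts, max_counts, accumulative_counts = 97, 100, 3

# Old computation: remainder relative to the resumed residual.
residual_counts = max_counts - init_counts                     # 3
divisible_counts = (residual_counts // accumulative_counts
                    * accumulative_counts)                     # 3
old_remainder = residual_counts - divisible_counts             # 0 -> loss / 0

# New computation: remainder of the full schedule, independent of the
# resume point; when it is 0 the remainder branch is never taken.
new_remainder = max_counts % accumulative_counts               # 1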
@@ -396,7 +385,7 @@ class OptimWrapper:
             # be divisible by `self._accumulative_counts`, so the
             # `loss_scale` for the last few iterations needs to be
             # recalculated.
-            if self._inner_count < self._divisible_counts:
+            if self._inner_count < self._max_counts - self._remainder_counts:
                 loss_factor = self._accumulative_counts
             else:
                 loss_factor = self._remainder_counts
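
The comparison changes accordingly: `_max_counts - _remainder_counts` is the largest multiple of `_accumulative_counts` not exceeding `_max_counts`, expressed as an absolute iteration count, whereas the removed `_divisible_counts` was measured from the resume point and so did not line up with the absolute `_inner_count` when training resumed mid-run. A small sanity check of that identity, using local variables only:

max_counts, accumulative_counts = 100, 3
remainder_counts = max_counts % accumulative_counts      # 1
threshold = max_counts - remainder_counts                # 99
assert threshold == (max_counts // accumulative_counts) * accumulative_counts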
@@ -68,7 +68,6 @@ class TestOptimWrapper(MultiProcessTestCase):
         self.assertIs(optim_wrapper.message_hub, self.message_hub)
         self.assertEqual(optim_wrapper._inner_count, 0)
         self.assertEqual(optim_wrapper._max_counts, -1)
-        self.assertEqual(optim_wrapper._divisible_counts, -1)
         self.assertEqual(optim_wrapper._remainder_counts, -1)
 
         with self.assertRaisesRegex(AssertionError,
@@ -119,12 +118,21 @@ class TestOptimWrapper(MultiProcessTestCase):
         optim_wrapper.update_params(loss)
         optim_wrapper.step.assert_called()
         optim_wrapper.zero_grad.assert_called()
-        self.assertEqual(optim_wrapper.scaled_loss, torch.tensor(1))
+        self.assertEqual(optim_wrapper.scaled_loss, torch.tensor(1.))
+        self._mock_method(optim_wrapper)
 
+        # optim_wrapper.step should not be called at iteration 97 98, and the
+        # loss factor should be 3 at iteration 99.
+        optim_wrapper.initialize_count_status(self.model, 96, 100)
+        for _ in range(2):
+            optim_wrapper.update_params(loss)
+            optim_wrapper.step.assert_not_called()
+            optim_wrapper.zero_grad.assert_not_called()
+        self.assertEqual(optim_wrapper.scaled_loss, torch.tensor(1.) / 3)
+
     def test_initialize_iter_status(self):
         optim_wrapper = OptimWrapper(self.optimizer, accumulative_counts=3)
         optim_wrapper.initialize_count_status(self.model, 0, 100)
-        self.assertEqual(optim_wrapper._divisible_counts, 99)
         self.assertEqual(optim_wrapper._remainder_counts, 1)
 
         # Indivisible cur_iter will output warning.
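
For reference, a hypothetical standalone sketch of the resumed-accumulation path the new test exercises; the toy model, learning rate, and loop length are illustrative and not part of the commit, while `OptimWrapper`, `initialize_count_status`, and `update_params` are the mmengine APIs touched by this diff:

import torch
import torch.nn as nn
from mmengine.optim import OptimWrapper

model = nn.Linear(2, 2)
wrapper = OptimWrapper(torch.optim.SGD(model.parameters(), lr=0.1),
                       accumulative_counts=3)
# Pretend training resumed at iteration 96 of a 100-iteration schedule.
wrapper.initialize_count_status(model, 96, 100)
for _ in range(4):
    loss = model(torch.ones(1, 2)).sum()
    wrapper.update_params(loss)  # loss factors: 3, 3, 3, then 1 -- never 0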