From b2ee9f8b11f243c3000679be88f29efade9d3ee7 Mon Sep 17 00:00:00 2001
From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com>
Date: Wed, 6 Jul 2022 16:42:49 +0800
Subject: [PATCH] [Fix] Fix loss could be nan in optimizer wrapper (#345)

* fix optimizer wrapper counts

* fix ut
---
 mmengine/optim/optimizer/optimizer_wrapper.py | 15 ++-------------
 .../test_optimizer/test_optimizer_wrapper.py  | 14 +++++++++++---
 2 files changed, 13 insertions(+), 16 deletions(-)

diff --git a/mmengine/optim/optimizer/optimizer_wrapper.py b/mmengine/optim/optimizer/optimizer_wrapper.py
index b7fa4290..e502001b 100644
--- a/mmengine/optim/optimizer/optimizer_wrapper.py
+++ b/mmengine/optim/optimizer/optimizer_wrapper.py
@@ -115,11 +115,6 @@ class OptimWrapper:
         # lost when the `_max_counts` is not divisible by
         # `accumulative_counts`.
         self._max_counts = -1
-        # If `_inner_count` is smaller than `_divisible_counts`, the loss
-        # factor used for gradient accumulation should be the same as
-        # `_accumulative_counts`. If `_max_counts` has not been initialized,
-        # the loss factor will always be the same as `_accumulative_counts`.
-        self._divisible_counts = -1
         # The `_remainder_iter` is used for calculating loss factor at the
         # last few iterations. If `_max_counts` has not been initialized,
         # the loss factor will always be the same as `_accumulative_counts`.
@@ -333,14 +328,8 @@ class OptimWrapper:
             self.logger.warning(
                 'Gradient accumulative may slightly decrease '
                 'performance because the model has BatchNorm layers.')
-        residual_counts = max_counts - init_counts
-        # The maximum number of training iteration that is divisible by
-        # `_accumulative_counts`.
-        self._divisible_counts = (
-            residual_counts // self._accumulative_counts *
-            self._accumulative_counts)
         # Remainder of `_max_counts` divided by `_accumulative_counts`
-        self._remainder_counts = residual_counts - self._divisible_counts
+        self._remainder_counts = self._max_counts % self._accumulative_counts

     def should_update(self) -> bool:
         """Decide whether the parameters should be updated at the current
         iteration.
@@ -396,7 +385,7 @@ class OptimWrapper:
             # be divisible by `self._accumulative_counts`, so the
             # `loss_scale` for the last few iterations needs to be
             # recalculated.
-            if self._inner_count < self._divisible_counts:
+            if self._inner_count < self._max_counts - self._remainder_counts:
                 loss_factor = self._accumulative_counts
             else:
                 loss_factor = self._remainder_counts
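Why the removed bookkeeping could drive the loss to nan: `_divisible_counts` was measured relative to `init_counts`, yet it was compared against the absolute `_inner_count`. On a resumed run whose remaining iterations divide evenly by `_accumulative_counts`, the old remainder, and with it the loss factor, became 0, and scaling the loss by a zero factor produced inf/nan. The following is a minimal standalone sketch of the two computations; `old_loss_factor` and `new_loss_factor` are hypothetical helpers written for this note, not MMEngine APIs, and the final division assumes the wrapper scales the loss as `loss / loss_factor`, which the `scaled_loss == torch.tensor(1.) / 3` assertion in the updated test suggests.

    import torch

    def old_loss_factor(inner_count, init_counts, max_counts, acc):
        # Deleted logic: counts are measured relative to `init_counts`
        # but compared against the absolute `inner_count`.
        residual_counts = max_counts - init_counts
        divisible_counts = residual_counts // acc * acc
        remainder_counts = residual_counts - divisible_counts
        return acc if inner_count < divisible_counts else remainder_counts

    def new_loss_factor(inner_count, max_counts, acc):
        # Added logic: both the remainder and the threshold are absolute,
        # matching `inner_count`.
        remainder_counts = max_counts % acc
        if inner_count < max_counts - remainder_counts:
            return acc
        return remainder_counts

    # Resume training at iteration 97 of 100 with accumulative_counts=3.
    # residual_counts = 3 divides evenly by 3, so the old remainder is 0
    # and every remaining iteration gets loss factor 0.
    for it in (97, 98, 99):
        print(it, old_loss_factor(it, 97, 100, 3), new_loss_factor(it, 100, 3))
    # 97 0 3
    # 98 0 3
    # 99 0 1

    # Dividing by the zero factor is what corrupted the loss:
    print(torch.tensor(1.) / 0)  # tensor(inf); downstream inf * 0 products become nan
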
diff --git a/tests/test_optim/test_optimizer/test_optimizer_wrapper.py b/tests/test_optim/test_optimizer/test_optimizer_wrapper.py
index 798b37eb..db5f4b41 100644
--- a/tests/test_optim/test_optimizer/test_optimizer_wrapper.py
+++ b/tests/test_optim/test_optimizer/test_optimizer_wrapper.py
@@ -68,7 +68,6 @@ class TestOptimWrapper(MultiProcessTestCase):
         self.assertIs(optim_wrapper.message_hub, self.message_hub)
         self.assertEqual(optim_wrapper._inner_count, 0)
         self.assertEqual(optim_wrapper._max_counts, -1)
-        self.assertEqual(optim_wrapper._divisible_counts, -1)
         self.assertEqual(optim_wrapper._remainder_counts, -1)

         with self.assertRaisesRegex(AssertionError,
@@ -119,12 +118,21 @@ class TestOptimWrapper(MultiProcessTestCase):
         optim_wrapper.update_params(loss)
         optim_wrapper.step.assert_called()
         optim_wrapper.zero_grad.assert_called()
-        self.assertEqual(optim_wrapper.scaled_loss, torch.tensor(1))
+        self.assertEqual(optim_wrapper.scaled_loss, torch.tensor(1.))
+        self._mock_method(optim_wrapper)
+
+        # optim_wrapper.step should not be called at iterations 97 and 98,
+        # and the loss factor should be 3 at iteration 99.
+        optim_wrapper.initialize_count_status(self.model, 96, 100)
+        for _ in range(2):
+            optim_wrapper.update_params(loss)
+            optim_wrapper.step.assert_not_called()
+            optim_wrapper.zero_grad.assert_not_called()
+        self.assertEqual(optim_wrapper.scaled_loss, torch.tensor(1.) / 3)

     def test_initialize_iter_status(self):
         optim_wrapper = OptimWrapper(self.optimizer, accumulative_counts=3)
         optim_wrapper.initialize_count_status(self.model, 0, 100)
-        self.assertEqual(optim_wrapper._divisible_counts, 99)
         self.assertEqual(optim_wrapper._remainder_counts, 1)

         # Indivisible cur_iter will output warning.
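The numbers asserted by the new test case follow from the same arithmetic. Below is a standalone walk-through of the resumed run it sets up (init_counts=96, max_counts=100, accumulative_counts=3); the update rule here is an assumption paraphrasing OptimWrapper's counting (parameters update once the running count divides evenly by `accumulative_counts` or reaches `max_counts`), not a copy of its code.

    # Walk-through of the resumed run in the updated test:
    # initialize_count_status(self.model, 96, 100), accumulative_counts=3.
    init_counts, max_counts, acc = 96, 100, 3
    remainder = max_counts % acc        # 1, as test_initialize_iter_status checks
    threshold = max_counts - remainder  # 99

    inner_count = init_counts
    for _ in range(2):  # the two update_params(loss) calls in the test
        # Loss factor for this call (matches the `if` in the diff above).
        factor = acc if inner_count < threshold else remainder
        inner_count += 1
        # Assumed rule: step once the count divides evenly or training ends.
        should_update = inner_count % acc == 0 or inner_count == max_counts
        print(inner_count, factor, should_update)
    # 97 3 False  -> step/zero_grad not called, scaled_loss == loss / 3
    # 98 3 False  -> step/zero_grad not called, scaled_loss == loss / 3

A third `update_params` call would raise the count to 99, which divides evenly by 3, so the three accumulated, 1/3-scaled losses would be applied there; this is what the "loss factor should be 3 at iteration 99" comment in the test refers to.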