From fd5d06243fcdb1018f85fe4c8dda44dace8837b2 Mon Sep 17 00:00:00 2001
From: fanqiNO1 <75657629+fanqiNO1@users.noreply.github.com>
Date: Mon, 20 Nov 2023 16:36:43 +0800
Subject: [PATCH] [Fix] Fix scale_lr in SingleDeviceStrategy (#1428)

---
 mmengine/_strategy/base.py          |  3 ++-
 mmengine/_strategy/single_device.py | 11 +++++------
 mmengine/runner/_flexible_runner.py |  3 ++-
 3 files changed, 9 insertions(+), 8 deletions(-)

diff --git a/mmengine/_strategy/base.py b/mmengine/_strategy/base.py
index 7b91d30f..708d0dbe 100644
--- a/mmengine/_strategy/base.py
+++ b/mmengine/_strategy/base.py
@@ -707,7 +707,8 @@ class BaseStrategy(metaclass=ABCMeta):
                 return any(_is_built(s) for s in schedulers)
             return isinstance(schedulers, _ParamScheduler)
 
-        if _is_built(self.param_schedulers):
+        if hasattr(self, 'param_schedulers') and _is_built(
+                self.param_schedulers):
             raise RuntimeError('`scale_lr` should be called before building '
                                'ParamScheduler because ParamScheduler will '
                                'store initial lr from optimizer wrappers')
diff --git a/mmengine/_strategy/single_device.py b/mmengine/_strategy/single_device.py
index 180ad79c..c7d8accd 100644
--- a/mmengine/_strategy/single_device.py
+++ b/mmengine/_strategy/single_device.py
@@ -63,12 +63,6 @@ class SingleDeviceStrategy(BaseStrategy):
 
         if optim_wrapper is not None:
             self.optim_wrapper = self.build_optim_wrapper(optim_wrapper, model)
-
-        if param_scheduler is not None:
-            self.param_schedulers = self.build_param_scheduler(
-                param_scheduler, self.optim_wrapper)
-
-        if optim_wrapper is not None:
             self._scale_lr()
 
             accumulative_counts = getattr(self.optim_wrapper,
@@ -82,6 +76,11 @@ class SingleDeviceStrategy(BaseStrategy):
 
                 self.optim_wrapper.initialize_count_status(  # type: ignore
                     self.model, 0, self.dispatch_kwargs['max_iters'])
+
+        if param_scheduler is not None:
+            self.param_schedulers = self.build_param_scheduler(
+                param_scheduler, self.optim_wrapper)
+
         self._prepared = True
         return self._prepared_components()
 
diff --git a/mmengine/runner/_flexible_runner.py b/mmengine/runner/_flexible_runner.py
index 8a771b05..6d727fb4 100644
--- a/mmengine/runner/_flexible_runner.py
+++ b/mmengine/runner/_flexible_runner.py
@@ -1176,7 +1176,8 @@ class FlexibleRunner:
             epoch_length=len(self.train_dataloader),
             max_epochs=self.max_epochs,
             max_iters=self.max_iters,
-        )
+            train_micro_batch_size_per_gpu=_get_batch_size(
+                self.train_dataloader))  # type: ignore
 
         self.strategy.prepare(
             self.model,