mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Fix] Fix scale_lr in SingleDeviceStrategy (#1428)
This commit is contained in:
parent
5a90805b1e
commit
fd5d06243f
@ -707,7 +707,8 @@ class BaseStrategy(metaclass=ABCMeta):
|
||||
return any(_is_built(s) for s in schedulers)
|
||||
return isinstance(schedulers, _ParamScheduler)
|
||||
|
||||
if _is_built(self.param_schedulers):
|
||||
if hasattr(self, 'param_schedulers') and _is_built(
|
||||
self.param_schedulers):
|
||||
raise RuntimeError('`scale_lr` should be called before building '
|
||||
'ParamScheduler because ParamScheduler will '
|
||||
'store initial lr from optimizer wrappers')
|
||||
|
@ -63,12 +63,6 @@ class SingleDeviceStrategy(BaseStrategy):
|
||||
|
||||
if optim_wrapper is not None:
|
||||
self.optim_wrapper = self.build_optim_wrapper(optim_wrapper, model)
|
||||
|
||||
if param_scheduler is not None:
|
||||
self.param_schedulers = self.build_param_scheduler(
|
||||
param_scheduler, self.optim_wrapper)
|
||||
|
||||
if optim_wrapper is not None:
|
||||
self._scale_lr()
|
||||
|
||||
accumulative_counts = getattr(self.optim_wrapper,
|
||||
@ -82,6 +76,11 @@ class SingleDeviceStrategy(BaseStrategy):
|
||||
|
||||
self.optim_wrapper.initialize_count_status( # type: ignore
|
||||
self.model, 0, self.dispatch_kwargs['max_iters'])
|
||||
|
||||
if param_scheduler is not None:
|
||||
self.param_schedulers = self.build_param_scheduler(
|
||||
param_scheduler, self.optim_wrapper)
|
||||
|
||||
self._prepared = True
|
||||
return self._prepared_components()
|
||||
|
||||
|
@ -1176,7 +1176,8 @@ class FlexibleRunner:
|
||||
epoch_length=len(self.train_dataloader),
|
||||
max_epochs=self.max_epochs,
|
||||
max_iters=self.max_iters,
|
||||
)
|
||||
train_micro_batch_size_per_gpu=_get_batch_size(
|
||||
self.train_dataloader)) # type: ignore
|
||||
|
||||
self.strategy.prepare(
|
||||
self.model,
|
||||
|
Loading…
x
Reference in New Issue
Block a user