From 84cd19aaa2b1ba5c3b04e00aaa7456fba090260f Mon Sep 17 00:00:00 2001 From: RangiLyu Date: Mon, 1 Aug 2022 10:26:58 +0800 Subject: [PATCH] [Enhance] Support convert float begin&end in ParamScheduler. (#393) --- mmengine/optim/scheduler/param_scheduler.py | 32 ++++++++++----------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/mmengine/optim/scheduler/param_scheduler.py b/mmengine/optim/scheduler/param_scheduler.py index 2fea10a3..50ea25a1 100644 --- a/mmengine/optim/scheduler/param_scheduler.py +++ b/mmengine/optim/scheduler/param_scheduler.py @@ -281,9 +281,9 @@ class StepParamScheduler(_ParamScheduler): f'but got {epoch_length}.' by_epoch = False step_size = step_size * epoch_length - begin = begin * epoch_length + begin = int(begin * epoch_length) if end != INF: - end = end * epoch_length + end = int(end * epoch_length) return cls( *args, step_size=step_size, @@ -367,9 +367,9 @@ class MultiStepParamScheduler(_ParamScheduler): f'but got {epoch_length}.' by_epoch = False milestones = [i * epoch_length for i in milestones] - begin = begin * epoch_length + begin = int(begin * epoch_length) if end != INF: - end = end * epoch_length + end = int(end * epoch_length) return cls( *args, milestones=milestones, @@ -455,9 +455,9 @@ class ConstantParamScheduler(_ParamScheduler): f'`epoch_length` must be a positive integer, ' \ f'but got {epoch_length}.' by_epoch = False - begin = begin * epoch_length + begin = int(begin * epoch_length) if end != INF: - end = end * epoch_length + end = int(end * epoch_length) return cls(*args, begin=begin, end=end, by_epoch=by_epoch, **kwargs) def _get_value(self): @@ -536,9 +536,9 @@ class ExponentialParamScheduler(_ParamScheduler): f'`epoch_length` must be a positive integer, ' \ f'but got {epoch_length}.' by_epoch = False - begin = begin * epoch_length + begin = int(begin * epoch_length) if end != INF: - end = end * epoch_length + end = int(end * epoch_length) return cls(*args, begin=begin, end=end, by_epoch=by_epoch, **kwargs) def _get_value(self): @@ -642,9 +642,9 @@ class CosineAnnealingParamScheduler(_ParamScheduler): f'but got {epoch_length}.' by_epoch = False T_max = T_max * epoch_length - begin = begin * epoch_length + begin = int(begin * epoch_length) if end != INF: - end = end * epoch_length + end = int(end * epoch_length) return cls( *args, T_max=T_max, @@ -747,9 +747,9 @@ class LinearParamScheduler(_ParamScheduler): f'`epoch_length` must be a positive integer, ' \ f'but got {epoch_length}.' by_epoch = False - begin = begin * epoch_length + begin = int(begin * epoch_length) if end != INF: - end = end * epoch_length + end = int(end * epoch_length) return cls(*args, begin=begin, end=end, by_epoch=by_epoch, **kwargs) def _get_value(self): @@ -835,9 +835,9 @@ class PolyParamScheduler(_ParamScheduler): f'`epoch_length` must be a positive integer, ' \ f'but got {epoch_length}.' by_epoch = False - begin = begin * epoch_length + begin = int(begin * epoch_length) if end != INF: - end = end * epoch_length + end = int(end * epoch_length) return cls(*args, begin=begin, end=end, by_epoch=by_epoch, **kwargs) def _get_value(self): @@ -1055,9 +1055,9 @@ class OneCycleParamScheduler(_ParamScheduler): f'`epoch_length` must be a positive integer, ' \ f'but got {epoch_length}.' by_epoch = False - begin = begin * epoch_length + begin = int(begin * epoch_length) if end != INF: - end = end * epoch_length + end = int(end * epoch_length) if total_steps is not None: total_steps = total_steps * epoch_length return cls(