mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Enhance] Support convert float begin&end in ParamScheduler. (#393)
This commit is contained in:
parent
4432e54c97
commit
84cd19aaa2
@ -281,9 +281,9 @@ class StepParamScheduler(_ParamScheduler):
|
||||
f'but got {epoch_length}.'
|
||||
by_epoch = False
|
||||
step_size = step_size * epoch_length
|
||||
begin = begin * epoch_length
|
||||
begin = int(begin * epoch_length)
|
||||
if end != INF:
|
||||
end = end * epoch_length
|
||||
end = int(end * epoch_length)
|
||||
return cls(
|
||||
*args,
|
||||
step_size=step_size,
|
||||
@ -367,9 +367,9 @@ class MultiStepParamScheduler(_ParamScheduler):
|
||||
f'but got {epoch_length}.'
|
||||
by_epoch = False
|
||||
milestones = [i * epoch_length for i in milestones]
|
||||
begin = begin * epoch_length
|
||||
begin = int(begin * epoch_length)
|
||||
if end != INF:
|
||||
end = end * epoch_length
|
||||
end = int(end * epoch_length)
|
||||
return cls(
|
||||
*args,
|
||||
milestones=milestones,
|
||||
@ -455,9 +455,9 @@ class ConstantParamScheduler(_ParamScheduler):
|
||||
f'`epoch_length` must be a positive integer, ' \
|
||||
f'but got {epoch_length}.'
|
||||
by_epoch = False
|
||||
begin = begin * epoch_length
|
||||
begin = int(begin * epoch_length)
|
||||
if end != INF:
|
||||
end = end * epoch_length
|
||||
end = int(end * epoch_length)
|
||||
return cls(*args, begin=begin, end=end, by_epoch=by_epoch, **kwargs)
|
||||
|
||||
def _get_value(self):
|
||||
@ -536,9 +536,9 @@ class ExponentialParamScheduler(_ParamScheduler):
|
||||
f'`epoch_length` must be a positive integer, ' \
|
||||
f'but got {epoch_length}.'
|
||||
by_epoch = False
|
||||
begin = begin * epoch_length
|
||||
begin = int(begin * epoch_length)
|
||||
if end != INF:
|
||||
end = end * epoch_length
|
||||
end = int(end * epoch_length)
|
||||
return cls(*args, begin=begin, end=end, by_epoch=by_epoch, **kwargs)
|
||||
|
||||
def _get_value(self):
|
||||
@ -642,9 +642,9 @@ class CosineAnnealingParamScheduler(_ParamScheduler):
|
||||
f'but got {epoch_length}.'
|
||||
by_epoch = False
|
||||
T_max = T_max * epoch_length
|
||||
begin = begin * epoch_length
|
||||
begin = int(begin * epoch_length)
|
||||
if end != INF:
|
||||
end = end * epoch_length
|
||||
end = int(end * epoch_length)
|
||||
return cls(
|
||||
*args,
|
||||
T_max=T_max,
|
||||
@ -747,9 +747,9 @@ class LinearParamScheduler(_ParamScheduler):
|
||||
f'`epoch_length` must be a positive integer, ' \
|
||||
f'but got {epoch_length}.'
|
||||
by_epoch = False
|
||||
begin = begin * epoch_length
|
||||
begin = int(begin * epoch_length)
|
||||
if end != INF:
|
||||
end = end * epoch_length
|
||||
end = int(end * epoch_length)
|
||||
return cls(*args, begin=begin, end=end, by_epoch=by_epoch, **kwargs)
|
||||
|
||||
def _get_value(self):
|
||||
@ -835,9 +835,9 @@ class PolyParamScheduler(_ParamScheduler):
|
||||
f'`epoch_length` must be a positive integer, ' \
|
||||
f'but got {epoch_length}.'
|
||||
by_epoch = False
|
||||
begin = begin * epoch_length
|
||||
begin = int(begin * epoch_length)
|
||||
if end != INF:
|
||||
end = end * epoch_length
|
||||
end = int(end * epoch_length)
|
||||
return cls(*args, begin=begin, end=end, by_epoch=by_epoch, **kwargs)
|
||||
|
||||
def _get_value(self):
|
||||
@ -1055,9 +1055,9 @@ class OneCycleParamScheduler(_ParamScheduler):
|
||||
f'`epoch_length` must be a positive integer, ' \
|
||||
f'but got {epoch_length}.'
|
||||
by_epoch = False
|
||||
begin = begin * epoch_length
|
||||
begin = int(begin * epoch_length)
|
||||
if end != INF:
|
||||
end = end * epoch_length
|
||||
end = int(end * epoch_length)
|
||||
if total_steps is not None:
|
||||
total_steps = total_steps * epoch_length
|
||||
return cls(
|
||||
|
Loading…
x
Reference in New Issue
Block a user