[Enhance] Support convert float begin&end in ParamScheduler. (#393)

This commit is contained in:
RangiLyu 2022-08-01 10:26:58 +08:00 committed by GitHub
parent 4432e54c97
commit 84cd19aaa2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -281,9 +281,9 @@ class StepParamScheduler(_ParamScheduler):
f'but got {epoch_length}.'
by_epoch = False
step_size = step_size * epoch_length
begin = begin * epoch_length
begin = int(begin * epoch_length)
if end != INF:
end = end * epoch_length
end = int(end * epoch_length)
return cls(
*args,
step_size=step_size,
@ -367,9 +367,9 @@ class MultiStepParamScheduler(_ParamScheduler):
f'but got {epoch_length}.'
by_epoch = False
milestones = [i * epoch_length for i in milestones]
begin = begin * epoch_length
begin = int(begin * epoch_length)
if end != INF:
end = end * epoch_length
end = int(end * epoch_length)
return cls(
*args,
milestones=milestones,
@ -455,9 +455,9 @@ class ConstantParamScheduler(_ParamScheduler):
f'`epoch_length` must be a positive integer, ' \
f'but got {epoch_length}.'
by_epoch = False
begin = begin * epoch_length
begin = int(begin * epoch_length)
if end != INF:
end = end * epoch_length
end = int(end * epoch_length)
return cls(*args, begin=begin, end=end, by_epoch=by_epoch, **kwargs)
def _get_value(self):
@ -536,9 +536,9 @@ class ExponentialParamScheduler(_ParamScheduler):
f'`epoch_length` must be a positive integer, ' \
f'but got {epoch_length}.'
by_epoch = False
begin = begin * epoch_length
begin = int(begin * epoch_length)
if end != INF:
end = end * epoch_length
end = int(end * epoch_length)
return cls(*args, begin=begin, end=end, by_epoch=by_epoch, **kwargs)
def _get_value(self):
@ -642,9 +642,9 @@ class CosineAnnealingParamScheduler(_ParamScheduler):
f'but got {epoch_length}.'
by_epoch = False
T_max = T_max * epoch_length
begin = begin * epoch_length
begin = int(begin * epoch_length)
if end != INF:
end = end * epoch_length
end = int(end * epoch_length)
return cls(
*args,
T_max=T_max,
@ -747,9 +747,9 @@ class LinearParamScheduler(_ParamScheduler):
f'`epoch_length` must be a positive integer, ' \
f'but got {epoch_length}.'
by_epoch = False
begin = begin * epoch_length
begin = int(begin * epoch_length)
if end != INF:
end = end * epoch_length
end = int(end * epoch_length)
return cls(*args, begin=begin, end=end, by_epoch=by_epoch, **kwargs)
def _get_value(self):
@ -835,9 +835,9 @@ class PolyParamScheduler(_ParamScheduler):
f'`epoch_length` must be a positive integer, ' \
f'but got {epoch_length}.'
by_epoch = False
begin = begin * epoch_length
begin = int(begin * epoch_length)
if end != INF:
end = end * epoch_length
end = int(end * epoch_length)
return cls(*args, begin=begin, end=end, by_epoch=by_epoch, **kwargs)
def _get_value(self):
@ -1055,9 +1055,9 @@ class OneCycleParamScheduler(_ParamScheduler):
f'`epoch_length` must be a positive integer, ' \
f'but got {epoch_length}.'
by_epoch = False
begin = begin * epoch_length
begin = int(begin * epoch_length)
if end != INF:
end = end * epoch_length
end = int(end * epoch_length)
if total_steps is not None:
total_steps = total_steps * epoch_length
return cls(