mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Update schedulers
This commit is contained in:
parent
b5255960d9
commit
b1a5a71151
@ -21,8 +21,10 @@ class CosineLRScheduler(Scheduler):
|
||||
t_mul: float = 1.,
|
||||
lr_min: float = 0.,
|
||||
decay_rate: float = 1.,
|
||||
warmup_updates=0,
|
||||
warmup_t=0,
|
||||
warmup_lr_init=0,
|
||||
warmup_prefix=False,
|
||||
t_in_epochs=True,
|
||||
initialize=True) -> None:
|
||||
super().__init__(optimizer, param_group_field="lr", initialize=initialize)
|
||||
|
||||
@ -35,32 +37,31 @@ class CosineLRScheduler(Scheduler):
|
||||
self.t_mul = t_mul
|
||||
self.lr_min = lr_min
|
||||
self.decay_rate = decay_rate
|
||||
self.warmup_updates = warmup_updates
|
||||
self.warmup_t = warmup_t
|
||||
self.warmup_lr_init = warmup_lr_init
|
||||
if self.warmup_updates:
|
||||
self.warmup_steps = [(v - warmup_lr_init) / self.warmup_updates for v in self.base_values]
|
||||
self.warmup_prefix = warmup_prefix
|
||||
self.t_in_epochs = t_in_epochs
|
||||
if self.warmup_t:
|
||||
self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
|
||||
super().update_groups(self.warmup_lr_init)
|
||||
else:
|
||||
self.warmup_steps = [1 for _ in self.base_values]
|
||||
if self.warmup_lr_init:
|
||||
super().update_groups(self.warmup_lr_init)
|
||||
|
||||
def get_epoch_values(self, epoch: int):
|
||||
# this scheduler doesn't update on epoch
|
||||
return None
|
||||
|
||||
def get_update_values(self, num_updates: int):
|
||||
if num_updates < self.warmup_updates:
|
||||
lrs = [self.warmup_lr_init + num_updates * s for s in self.warmup_steps]
|
||||
def _get_lr(self, t):
|
||||
if t < self.warmup_t:
|
||||
lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
|
||||
else:
|
||||
curr_updates = num_updates - self.warmup_updates
|
||||
if self.warmup_prefix:
|
||||
t = t - self.warmup_t
|
||||
|
||||
if self.t_mul != 1:
|
||||
i = math.floor(math.log(1 - curr_updates / self.t_initial * (1 - self.t_mul), self.t_mul))
|
||||
i = math.floor(math.log(1 - t / self.t_initial * (1 - self.t_mul), self.t_mul))
|
||||
t_i = self.t_mul ** i * self.t_initial
|
||||
t_curr = curr_updates - (1 - self.t_mul ** i) / (1 - self.t_mul) * self.t_initial
|
||||
t_curr = t - (1 - self.t_mul ** i) / (1 - self.t_mul) * self.t_initial
|
||||
else:
|
||||
i = curr_updates // self.t_initial
|
||||
i = t // self.t_initial
|
||||
t_i = self.t_initial
|
||||
t_curr = curr_updates - (self.t_initial * i)
|
||||
t_curr = t - (self.t_initial * i)
|
||||
|
||||
gamma = self.decay_rate ** i
|
||||
lr_min = self.lr_min * gamma
|
||||
@ -70,3 +71,15 @@ class CosineLRScheduler(Scheduler):
|
||||
lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * t_curr / t_i)) for lr_max in lr_max_values
|
||||
]
|
||||
return lrs
|
||||
|
||||
def get_epoch_values(self, epoch: int):
|
||||
if self.t_in_epochs:
|
||||
return self._get_lr(epoch)
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_update_values(self, num_updates: int):
|
||||
if not self.t_in_epochs:
|
||||
return self._get_lr(num_updates)
|
||||
else:
|
||||
return None
|
||||
|
@ -56,7 +56,7 @@ class Scheduler:
|
||||
|
||||
def step(self, epoch: int, metric: float = None) -> None:
|
||||
self.metric = metric
|
||||
values = self.get_epoch_values(epoch)
|
||||
values = self.get_epoch_values(epoch + 1) # +1 to calculate for next epoch
|
||||
if values is not None:
|
||||
self.update_groups(values)
|
||||
|
||||
|
@ -10,39 +10,41 @@ class StepLRScheduler(Scheduler):
|
||||
|
||||
def __init__(self,
|
||||
optimizer: torch.optim.Optimizer,
|
||||
decay_epochs: int,
|
||||
decay_t: int,
|
||||
decay_rate: float = 1.,
|
||||
warmup_updates=0,
|
||||
warmup_t=0,
|
||||
warmup_lr_init=0,
|
||||
t_in_epochs=True,
|
||||
initialize=True) -> None:
|
||||
super().__init__(optimizer, param_group_field="lr", initialize=initialize)
|
||||
|
||||
self.decay_epochs = decay_epochs
|
||||
self.decay_t = decay_t
|
||||
self.decay_rate = decay_rate
|
||||
self.warmup_updates = warmup_updates
|
||||
self.warmup_t = warmup_t
|
||||
self.warmup_lr_init = warmup_lr_init
|
||||
|
||||
if self.warmup_updates:
|
||||
self.warmup_active = warmup_updates > 0 # this state updates with num_updates
|
||||
self.warmup_steps = [(v - warmup_lr_init) / self.warmup_updates for v in self.base_values]
|
||||
self.t_in_epochs = t_in_epochs
|
||||
if self.warmup_t:
|
||||
self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
|
||||
super().update_groups(self.warmup_lr_init)
|
||||
else:
|
||||
self.warmup_steps = [1 for _ in self.base_values]
|
||||
|
||||
def get_epoch_values(self, epoch: int):
|
||||
if not self.warmup_active:
|
||||
lrs = [v * (self.decay_rate ** ((epoch + 1) // self.decay_epochs))
|
||||
for v in self.base_values]
|
||||
def _get_lr(self, t):
|
||||
if t < self.warmup_t:
|
||||
lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
|
||||
else:
|
||||
lrs = None # no epoch updates while warming up
|
||||
lrs = [v * (self.decay_rate ** (t // self.decay_t))
|
||||
for v in self.base_values]
|
||||
return lrs
|
||||
|
||||
def get_epoch_values(self, epoch: int):
|
||||
if self.t_in_epochs:
|
||||
return self._get_lr(epoch)
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_update_values(self, num_updates: int):
|
||||
if num_updates < self.warmup_updates:
|
||||
lrs = [self.warmup_lr_init + num_updates * s for s in self.warmup_steps]
|
||||
if not self.t_in_epochs:
|
||||
return self._get_lr(num_updates)
|
||||
else:
|
||||
self.warmup_active = False # warmup cancelled by first update past warmup_update count
|
||||
lrs = None # no change on update afte warmup stage
|
||||
return lrs
|
||||
|
||||
|
||||
return None
|
||||
|
@ -27,7 +27,7 @@ class TanhLRScheduler(Scheduler):
|
||||
warmup_lr_init=0,
|
||||
warmup_prefix=False,
|
||||
cycle_limit=0,
|
||||
t_in_epochs=False,
|
||||
t_in_epochs=True,
|
||||
initialize=True) -> None:
|
||||
super().__init__(optimizer, param_group_field="lr", initialize=initialize)
|
||||
|
||||
|
22
train.py
22
train.py
@ -162,7 +162,7 @@ def main():
|
||||
if args.opt.lower() == 'sgd':
|
||||
optimizer = optim.SGD(
|
||||
model.parameters(), lr=args.lr,
|
||||
momentum=args.momentum, weight_decay=args.weight_decay, nesterov=False)
|
||||
momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)
|
||||
elif args.opt.lower() == 'adam':
|
||||
optimizer = optim.Adam(
|
||||
model.parameters(), lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps)
|
||||
@ -183,32 +183,32 @@ def main():
|
||||
if optimizer_state is not None:
|
||||
optimizer.load_state_dict(optimizer_state)
|
||||
|
||||
updates_per_epoch = len(loader_train)
|
||||
if args.sched == 'cosine':
|
||||
lr_scheduler = scheduler.CosineLRScheduler(
|
||||
optimizer,
|
||||
t_initial=100 * updates_per_epoch,
|
||||
t_initial=130,
|
||||
t_mul=1.0,
|
||||
lr_min=0,
|
||||
decay_rate=0.5,
|
||||
decay_rate=args.decay_rate,
|
||||
warmup_lr_init=1e-4,
|
||||
warmup_updates=1 * updates_per_epoch
|
||||
warmup_t=3,
|
||||
t_in_epochs=True,
|
||||
)
|
||||
elif args.sched == 'tanh':
|
||||
lr_scheduler = scheduler.TanhLRScheduler(
|
||||
optimizer,
|
||||
t_initial=80 * updates_per_epoch,
|
||||
t_initial=130,
|
||||
t_mul=1.0,
|
||||
lr_min=1e-5,
|
||||
decay_rate=0.5,
|
||||
lr_min=1e-6,
|
||||
warmup_lr_init=.001,
|
||||
warmup_t=5 * updates_per_epoch,
|
||||
cycle_limit=1
|
||||
warmup_t=3,
|
||||
cycle_limit=1,
|
||||
t_in_epochs=True,
|
||||
)
|
||||
else:
|
||||
lr_scheduler = scheduler.StepLRScheduler(
|
||||
optimizer,
|
||||
decay_epochs=args.decay_epochs,
|
||||
decay_t=args.decay_epochs,
|
||||
decay_rate=args.decay_rate,
|
||||
)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user