sched noise dup code remove
parent
07379c6d5d
commit
cf57695938
|
@ -43,7 +43,7 @@ class PlateauLRScheduler(Scheduler):
|
|||
min_lr=lr_min
|
||||
)
|
||||
|
||||
self.noise_range = noise_range_t
|
||||
self.noise_range_t = noise_range_t
|
||||
self.noise_pct = noise_pct
|
||||
self.noise_type = noise_type
|
||||
self.noise_std = noise_std
|
||||
|
@ -82,25 +82,12 @@ class PlateauLRScheduler(Scheduler):
|
|||
|
||||
self.lr_scheduler.step(metric, epoch) # step the base scheduler
|
||||
|
||||
if self.noise_range is not None:
|
||||
if isinstance(self.noise_range, (list, tuple)):
|
||||
apply_noise = self.noise_range[0] <= epoch < self.noise_range[1]
|
||||
else:
|
||||
apply_noise = epoch >= self.noise_range
|
||||
if apply_noise:
|
||||
self._apply_noise(epoch)
|
||||
if self._is_apply_noise(epoch):
|
||||
self._apply_noise(epoch)
|
||||
|
||||
|
||||
def _apply_noise(self, epoch):
|
||||
g = torch.Generator()
|
||||
g.manual_seed(self.noise_seed + epoch)
|
||||
if self.noise_type == 'normal':
|
||||
while True:
|
||||
# resample if noise out of percent limit, brute force but shouldn't spin much
|
||||
noise = torch.randn(1, generator=g).item()
|
||||
if abs(noise) < self.noise_pct:
|
||||
break
|
||||
else:
|
||||
noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct
|
||||
noise = self._calculate_noise(epoch)
|
||||
|
||||
# apply the noise on top of previous LR, cache the old value so we can restore for normal
|
||||
# stepping of base scheduler
|
||||
|
|
|
@ -85,21 +85,29 @@ class Scheduler:
|
|||
param_group[self.param_group_field] = value
|
||||
|
||||
def _add_noise(self, lrs, t):
|
||||
if self._is_apply_noise(t):
|
||||
noise = self._calculate_noise(t)
|
||||
lrs = [v + v * noise for v in lrs]
|
||||
return lrs
|
||||
|
||||
def _is_apply_noise(self, t) -> bool:
|
||||
"""Return True if scheduler in noise range."""
|
||||
if self.noise_range_t is not None:
|
||||
if isinstance(self.noise_range_t, (list, tuple)):
|
||||
apply_noise = self.noise_range_t[0] <= t < self.noise_range_t[1]
|
||||
else:
|
||||
apply_noise = t >= self.noise_range_t
|
||||
if apply_noise:
|
||||
g = torch.Generator()
|
||||
g.manual_seed(self.noise_seed + t)
|
||||
if self.noise_type == 'normal':
|
||||
while True:
|
||||
return apply_noise
|
||||
|
||||
def _calculate_noise(self, t) -> float:
|
||||
g = torch.Generator()
|
||||
g.manual_seed(self.noise_seed + t)
|
||||
if self.noise_type == 'normal':
|
||||
while True:
|
||||
# resample if noise out of percent limit, brute force but shouldn't spin much
|
||||
noise = torch.randn(1, generator=g).item()
|
||||
if abs(noise) < self.noise_pct:
|
||||
break
|
||||
else:
|
||||
noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct
|
||||
lrs = [v + v * noise for v in lrs]
|
||||
return lrs
|
||||
noise = torch.randn(1, generator=g).item()
|
||||
if abs(noise) < self.noise_pct:
|
||||
return noise
|
||||
else:
|
||||
noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct
|
||||
return noise
|
||||
|
|
Loading…
Reference in New Issue