mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
sched noise dup code remove
This commit is contained in:
parent
07379c6d5d
commit
cf57695938
@ -43,7 +43,7 @@ class PlateauLRScheduler(Scheduler):
|
|||||||
min_lr=lr_min
|
min_lr=lr_min
|
||||||
)
|
)
|
||||||
|
|
||||||
self.noise_range = noise_range_t
|
self.noise_range_t = noise_range_t
|
||||||
self.noise_pct = noise_pct
|
self.noise_pct = noise_pct
|
||||||
self.noise_type = noise_type
|
self.noise_type = noise_type
|
||||||
self.noise_std = noise_std
|
self.noise_std = noise_std
|
||||||
@ -82,25 +82,12 @@ class PlateauLRScheduler(Scheduler):
|
|||||||
|
|
||||||
self.lr_scheduler.step(metric, epoch) # step the base scheduler
|
self.lr_scheduler.step(metric, epoch) # step the base scheduler
|
||||||
|
|
||||||
if self.noise_range is not None:
|
if self._is_apply_noise(epoch):
|
||||||
if isinstance(self.noise_range, (list, tuple)):
|
self._apply_noise(epoch)
|
||||||
apply_noise = self.noise_range[0] <= epoch < self.noise_range[1]
|
|
||||||
else:
|
|
||||||
apply_noise = epoch >= self.noise_range
|
|
||||||
if apply_noise:
|
|
||||||
self._apply_noise(epoch)
|
|
||||||
|
|
||||||
def _apply_noise(self, epoch):
|
def _apply_noise(self, epoch):
|
||||||
g = torch.Generator()
|
noise = self._calculate_noise(epoch)
|
||||||
g.manual_seed(self.noise_seed + epoch)
|
|
||||||
if self.noise_type == 'normal':
|
|
||||||
while True:
|
|
||||||
# resample if noise out of percent limit, brute force but shouldn't spin much
|
|
||||||
noise = torch.randn(1, generator=g).item()
|
|
||||||
if abs(noise) < self.noise_pct:
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct
|
|
||||||
|
|
||||||
# apply the noise on top of previous LR, cache the old value so we can restore for normal
|
# apply the noise on top of previous LR, cache the old value so we can restore for normal
|
||||||
# stepping of base scheduler
|
# stepping of base scheduler
|
||||||
|
@ -85,21 +85,29 @@ class Scheduler:
|
|||||||
param_group[self.param_group_field] = value
|
param_group[self.param_group_field] = value
|
||||||
|
|
||||||
def _add_noise(self, lrs, t):
|
def _add_noise(self, lrs, t):
|
||||||
|
if self._is_apply_noise(t):
|
||||||
|
noise = self._calculate_noise(t)
|
||||||
|
lrs = [v + v * noise for v in lrs]
|
||||||
|
return lrs
|
||||||
|
|
||||||
|
def _is_apply_noise(self, t) -> bool:
|
||||||
|
"""Return True if scheduler in noise range."""
|
||||||
if self.noise_range_t is not None:
|
if self.noise_range_t is not None:
|
||||||
if isinstance(self.noise_range_t, (list, tuple)):
|
if isinstance(self.noise_range_t, (list, tuple)):
|
||||||
apply_noise = self.noise_range_t[0] <= t < self.noise_range_t[1]
|
apply_noise = self.noise_range_t[0] <= t < self.noise_range_t[1]
|
||||||
else:
|
else:
|
||||||
apply_noise = t >= self.noise_range_t
|
apply_noise = t >= self.noise_range_t
|
||||||
if apply_noise:
|
return apply_noise
|
||||||
g = torch.Generator()
|
|
||||||
g.manual_seed(self.noise_seed + t)
|
def _calculate_noise(self, t) -> float:
|
||||||
if self.noise_type == 'normal':
|
g = torch.Generator()
|
||||||
while True:
|
g.manual_seed(self.noise_seed + t)
|
||||||
|
if self.noise_type == 'normal':
|
||||||
|
while True:
|
||||||
# resample if noise out of percent limit, brute force but shouldn't spin much
|
# resample if noise out of percent limit, brute force but shouldn't spin much
|
||||||
noise = torch.randn(1, generator=g).item()
|
noise = torch.randn(1, generator=g).item()
|
||||||
if abs(noise) < self.noise_pct:
|
if abs(noise) < self.noise_pct:
|
||||||
break
|
return noise
|
||||||
else:
|
else:
|
||||||
noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct
|
noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct
|
||||||
lrs = [v + v * noise for v in lrs]
|
return noise
|
||||||
return lrs
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user