Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
Merge pull request #94 from rwightman/lr_noise
Learning rate noise, MobileNetV3 weights, and activate MobileNetV3/EfficientNet weight init change
Commit 56e2ac3a6d

README.md (16 lines changed)
@@ -2,6 +2,14 @@
 ## What's New
 
+### Feb 29, 2020
+* New MobileNet-V3 Large weights trained from scratch with this code to 75.77% top-1
+* IMPORTANT CHANGE - default weight init changed for all MobilenetV3 / EfficientNet / related models
+  * overall results similar to, or a bit better than, training from scratch on the few smaller models tried
+  * performance early in training seems consistently improved, but there is less difference by the end
+  * set `fix_group_fanout=False` in the `_init_weight_goog` fn if you need to reproduce past behaviour
+* Experimental LR noise feature added; applies a random perturbation to the LR each epoch within a specified range of training
+
 ### Feb 18, 2020
 * Big refactor of model layers and addition of several attention mechanisms. Several additions motivated by 'Compounding the Performance Improvements...' (https://arxiv.org/abs/2001.06268):
   * Move layer/module impl into `layers` subfolder/module of `models` and organize in a more granular fashion
@@ -187,7 +195,8 @@ I've leveraged the training scripts in this repository to train a few of the mod
 | skresnet34 | 76.912 (23.088) | 93.322 (6.678) | 22.2M | bicubic | 224 |
 | resnet26d | 76.68 (23.32) | 93.166 (6.834) | 16M | bicubic | 224 |
 | mixnet_s | 75.988 (24.012) | 92.794 (7.206) | 4.13M | bicubic | 224 |
-| mobilenetv3_100 | 75.634 (24.366) | 92.708 (7.292) | 5.5M | bicubic | 224 |
+| mobilenetv3_large_100 | 75.766 (24.234) | 92.542 (7.458) | 5.5M | bicubic | 224 |
+| mobilenetv3_rw | 75.634 (24.366) | 92.708 (7.292) | 5.5M | bicubic | 224 |
 | mnasnet_a1 | 75.448 (24.552) | 92.604 (7.396) | 3.89M | bicubic | 224 |
 | resnet26 | 75.292 (24.708) | 92.57 (7.43) | 16M | bicubic | 224 |
 | fbnetc_100 | 75.124 (24.876) | 92.386 (7.614) | 5.6M | bilinear | 224 |
@@ -361,6 +370,11 @@ Trained by [Andrew Lavin](https://github.com/andravin) with 8 V100 cards. Model
 
 `./distributed_train.sh 8 /imagenet --model efficientnet_es -b 128 --sched step --epochs 450 --decay-epochs 2.4 --decay-rate .97 --opt rmsproptf --opt-eps .001 -j 8 --warmup-lr 1e-6 --weight-decay 1e-5 --drop 0.2 --drop-connect 0.2 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.2 --amp --lr .064`
 
+### MobileNetV3-Large-100 - 75.766 top-1, 92.542 top-5
+
+`./distributed_train.sh 2 /imagenet/ --model mobilenetv3_large_100 -b 512 --sched step --epochs 600 --decay-epochs 2.4 --decay-rate .973 --opt rmsproptf --opt-eps .001 -j 7 --warmup-lr 1e-6 --weight-decay 1e-5 --drop 0.2 --drop-connect 0.2 --model-ema --model-ema-decay 0.9999 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.2 --amp --lr .064 --lr-noise 0.42 0.9`
+
+
 **TODO dig up some more**
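Note that the two `--lr-noise` values are fractions of the run rather than absolute epochs: noise starts roughly 42% of the way into training and switches off at 90%. The exact mapping to epochs is done in the scheduler factory change further below.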
@@ -93,7 +93,7 @@ model_list = [
     _entry('semnasnet_100', 'MnasNet-A1', '1807.11626'),
     _entry('spnasnet_100', 'Single-Path NAS', '1904.02877',
         model_desc='Trained in PyTorch with SGD, cosine LR decay'),
-    _entry('mobilenetv3_rw', 'MobileNet V3-Large 1.0', '1905.02244',
+    _entry('mobilenetv3_large_100', 'MobileNet V3-Large 1.0', '1905.02244',
         model_desc='Trained in PyTorch with RMSProp, exponential LR decay, and hyper-params matching '
                    'paper as closely as possible.'),
 
@@ -359,15 +359,13 @@ class EfficientNetBuilder:
         return stages
 
 
-def _init_weight_goog(m, n='', fix_group_fanout=False):
+def _init_weight_goog(m, n='', fix_group_fanout=True):
     """ Weight initialization as per Tensorflow official implementations.
 
     Args:
         m (nn.Module): module to init
        n (str): module name
-        fix_group_fanout (bool): enable correct fanout calculation w/ group convs
+        fix_group_fanout (bool): enable correct (matching Tensorflow TPU impl) fanout calculation w/ group convs
 
-    FIXME change fix_group_fanout to default to True if experiments show better training results
-
     Handles layers in EfficientNet, EfficientNet-CondConv, MixNet, MnasNet, MobileNetV3, etc:
     * https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_model.py
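For anyone who needs the pre-change initialization mentioned in the README note above, a minimal sketch might look like the following. It assumes the helper is importable from `timm.models.efficientnet_builder` (module path not shown in this diff) and simply re-runs the init over a freshly created model; it is illustrative, not part of this commit.

```python
import timm
from timm.models.efficientnet_builder import _init_weight_goog  # assumed module path

# Build a model, then re-apply the Google-style init with the old fanout behaviour.
model = timm.create_model('efficientnet_b0', pretrained=False)
for name, module in model.named_modules():
    _init_weight_goog(module, name, fix_group_fanout=False)
```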
@@ -31,7 +31,9 @@ def _cfg(url='', **kwargs):
 
 default_cfgs = {
     'mobilenetv3_large_075': _cfg(url=''),
-    'mobilenetv3_large_100': _cfg(url=''),
+    'mobilenetv3_large_100': _cfg(
+        interpolation='bicubic',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_large_100_ra-f55367f5.pth'),
     'mobilenetv3_small_075': _cfg(url=''),
     'mobilenetv3_small_100': _cfg(url=''),
     'mobilenetv3_rw': _cfg(
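With the URL wired into the default cfg, the new weights load through the usual factory path; a quick sanity-check sketch (illustrative only):

```python
import torch
import timm

# Downloads mobilenetv3_large_100_ra-f55367f5.pth on first use.
model = timm.create_model('mobilenetv3_large_100', pretrained=True)
model.eval()

with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))  # 224x224 input, per the README table
print(logits.shape)  # torch.Size([1, 1000])
```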
@@ -29,8 +29,15 @@ class CosineLRScheduler(Scheduler):
                  warmup_prefix=False,
                  cycle_limit=0,
                  t_in_epochs=True,
+                 noise_range_t=None,
+                 noise_pct=0.67,
+                 noise_std=1.0,
+                 noise_seed=42,
                  initialize=True) -> None:
-        super().__init__(optimizer, param_group_field="lr", initialize=initialize)
+        super().__init__(
+            optimizer, param_group_field="lr",
+            noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
+            initialize=initialize)
 
         assert t_initial > 0
         assert lr_min >= 0
@@ -8,33 +8,34 @@ class PlateauLRScheduler(Scheduler):
 
     def __init__(self,
                  optimizer,
-                 factor=0.1,
-                 patience=10,
-                 verbose=False,
+                 decay_rate=0.1,
+                 patience_t=10,
+                 verbose=True,
                  threshold=1e-4,
-                 cooldown_epochs=0,
-                 warmup_updates=0,
+                 cooldown_t=0,
+                 warmup_t=0,
                  warmup_lr_init=0,
                  lr_min=0,
+                 mode='min',
+                 initialize=True,
                  ):
-        super().__init__(optimizer, 'lr', initialize=False)
+        super().__init__(optimizer, 'lr', initialize=initialize)
 
         self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
-            self.optimizer.optimizer,
-            patience=patience,
-            factor=factor,
+            self.optimizer,
+            patience=patience_t,
+            factor=decay_rate,
             verbose=verbose,
             threshold=threshold,
-            cooldown=cooldown_epochs,
+            cooldown=cooldown_t,
+            mode=mode,
             min_lr=lr_min
         )
 
-        self.warmup_updates = warmup_updates
+        self.warmup_t = warmup_t
         self.warmup_lr_init = warmup_lr_init
-        if self.warmup_updates:
-            self.warmup_active = warmup_updates > 0  # this state updates with num_updates
-            self.warmup_steps = [(v - warmup_lr_init) / self.warmup_updates for v in self.base_values]
+        if self.warmup_t:
+            self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
             super().update_groups(self.warmup_lr_init)
         else:
             self.warmup_steps = [1 for _ in self.base_values]
@@ -51,18 +52,9 @@ class PlateauLRScheduler(Scheduler):
         self.lr_scheduler.last_epoch = state_dict['last_epoch']
 
     # override the base class step fn completely
-    def step(self, epoch, val_loss=None):
-        """Update the learning rate at the end of the given epoch."""
-        if val_loss is not None and not self.warmup_active:
-            self.lr_scheduler.step(val_loss, epoch)
-        else:
-            self.lr_scheduler.last_epoch = epoch
-
-    def get_update_values(self, num_updates: int):
-        if num_updates < self.warmup_updates:
-            lrs = [self.warmup_lr_init + num_updates * s for s in self.warmup_steps]
-        else:
-            self.warmup_active = False  # warmup cancelled by first update past warmup_update count
-            lrs = None  # no change on update after warmup stage
-        return lrs
+    def step(self, epoch, metric=None):
+        if epoch <= self.warmup_t:
+            lrs = [self.warmup_lr_init + epoch * s for s in self.warmup_steps]
+            super().update_groups(lrs)
+        else:
+            self.lr_scheduler.step(metric, epoch)
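A rough usage sketch of the reworked scheduler with its renamed arguments (assuming it lives at `timm.scheduler.plateau_lr` as in the main repo layout; the metric value here is a stand-in for a real validation loss):

```python
import torch
from timm.scheduler.plateau_lr import PlateauLRScheduler  # assumed module path

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

scheduler = PlateauLRScheduler(
    optimizer,
    decay_rate=0.1,      # was `factor`
    patience_t=10,       # was `patience`
    warmup_t=3,          # warmup is now counted in epochs
    warmup_lr_init=1e-4,
    cooldown_t=0,        # was `cooldown_epochs`
    mode='min',
)

for epoch in range(20):
    # ... train and validate here ...
    val_loss = 1.0  # placeholder metric
    scheduler.step(epoch, metric=val_loss)
```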
@@ -25,6 +25,11 @@ class Scheduler:
     def __init__(self,
                  optimizer: torch.optim.Optimizer,
                  param_group_field: str,
+                 noise_range_t=None,
+                 noise_type='normal',
+                 noise_pct=0.67,
+                 noise_std=1.0,
+                 noise_seed=None,
                  initialize: bool = True) -> None:
         self.optimizer = optimizer
         self.param_group_field = param_group_field
@@ -40,6 +45,11 @@ class Scheduler:
                     raise KeyError(f"{self._initial_param_group_field} missing from param_groups[{i}]")
         self.base_values = [group[self._initial_param_group_field] for group in self.optimizer.param_groups]
         self.metric = None  # any point to having this for all?
+        self.noise_range_t = noise_range_t
+        self.noise_pct = noise_pct
+        self.noise_type = noise_type
+        self.noise_std = noise_std
+        self.noise_seed = noise_seed if noise_seed is not None else 42
         self.update_groups(self.base_values)
 
     def state_dict(self) -> Dict[str, Any]:
@@ -58,12 +68,14 @@ class Scheduler:
         self.metric = metric
         values = self.get_epoch_values(epoch)
         if values is not None:
+            values = self._add_noise(values, epoch)
             self.update_groups(values)
 
     def step_update(self, num_updates: int, metric: float = None):
         self.metric = metric
         values = self.get_update_values(num_updates)
         if values is not None:
+            values = self._add_noise(values, num_updates)
             self.update_groups(values)
 
     def update_groups(self, values):
@@ -71,3 +83,23 @@ class Scheduler:
             values = [values] * len(self.optimizer.param_groups)
         for param_group, value in zip(self.optimizer.param_groups, values):
             param_group[self.param_group_field] = value
+
+    def _add_noise(self, lrs, t):
+        if self.noise_range_t is not None:
+            if isinstance(self.noise_range_t, (list, tuple)):
+                apply_noise = self.noise_range_t[0] <= t < self.noise_range_t[1]
+            else:
+                apply_noise = t >= self.noise_range_t
+            if apply_noise:
+                g = torch.Generator()
+                g.manual_seed(self.noise_seed + t)
+                if self.noise_type == 'normal':
+                    while True:
+                        # resample if noise out of percent limit, brute force but shouldn't spin much
+                        noise = torch.randn(1, generator=g).item()
+                        if abs(noise) < self.noise_pct:
+                            break
+                else:
+                    noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct
+                lrs = [v + v * noise for v in lrs]
+        return lrs
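As a standalone illustration of this behaviour (not part of the commit): the perturbation is deterministic for a given seed and epoch, and its relative magnitude is capped at `noise_pct`.

```python
import torch

def noised_lr(lr, t, noise_pct=0.67, noise_seed=42):
    # Mirrors the 'normal' branch of Scheduler._add_noise for a single LR value.
    g = torch.Generator()
    g.manual_seed(noise_seed + t)
    while True:
        noise = torch.randn(1, generator=g).item()  # resample until inside the cap
        if abs(noise) < noise_pct:
            break
    return lr + lr * noise

# The same seed + epoch always yields the same perturbed LR, within +/-67% of the base LR.
print([round(noised_lr(0.064, t), 5) for t in range(252, 256)])
```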
@@ -1,10 +1,22 @@
 from .cosine_lr import CosineLRScheduler
 from .tanh_lr import TanhLRScheduler
 from .step_lr import StepLRScheduler
+from .plateau_lr import PlateauLRScheduler
 
 
 def create_scheduler(args, optimizer):
     num_epochs = args.epochs
+
+    if args.lr_noise is not None:
+        if isinstance(args.lr_noise, (list, tuple)):
+            noise_range = [n * num_epochs for n in args.lr_noise]
+            if len(noise_range) == 1:
+                noise_range = noise_range[0]
+        else:
+            noise_range = args.lr_noise * num_epochs
+    else:
+        noise_range = None
+
     lr_scheduler = None
     #FIXME expose cycle parms of the scheduler config to arguments
     if args.sched == 'cosine':
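Worked through with the README command above (`--lr-noise 0.42 0.9` and `--epochs 600`), this mapping gives:

```python
lr_noise = [0.42, 0.9]   # from --lr-noise
num_epochs = 600         # from --epochs
noise_range = [n * num_epochs for n in lr_noise]
print(noise_range)  # [252.0, 540.0] -> noise is applied for epochs 252 <= t < 540
```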
@@ -18,6 +30,10 @@ def create_scheduler(args, optimizer):
             warmup_t=args.warmup_epochs,
             cycle_limit=1,
             t_in_epochs=True,
+            noise_range_t=noise_range,
+            noise_pct=args.lr_noise_pct,
+            noise_std=args.lr_noise_std,
+            noise_seed=args.seed,
         )
         num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs
     elif args.sched == 'tanh':
@@ -30,6 +46,10 @@ def create_scheduler(args, optimizer):
             warmup_t=args.warmup_epochs,
             cycle_limit=1,
             t_in_epochs=True,
+            noise_range_t=noise_range,
+            noise_pct=args.lr_noise_pct,
+            noise_std=args.lr_noise_std,
+            noise_seed=args.seed,
         )
         num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs
     elif args.sched == 'step':
@@ -39,5 +59,20 @@ def create_scheduler(args, optimizer):
             decay_rate=args.decay_rate,
             warmup_lr_init=args.warmup_lr,
             warmup_t=args.warmup_epochs,
+            noise_range_t=noise_range,
+            noise_pct=args.lr_noise_pct,
+            noise_std=args.lr_noise_std,
+            noise_seed=args.seed,
         )
+    elif args.sched == 'plateau':
+        lr_scheduler = PlateauLRScheduler(
+            optimizer,
+            decay_rate=args.decay_rate,
+            patience_t=args.patience_epochs,
+            lr_min=args.min_lr,
+            warmup_lr_init=args.warmup_lr,
+            warmup_t=args.warmup_epochs,
+            cooldown_t=args.cooldown_epochs,
+        )
+
     return lr_scheduler, num_epochs
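With the factory branch and the new `--patience-epochs` argument in place, the plateau schedule can be selected from the training script; an illustrative combination of the flags added or already present in train.py: `./distributed_train.sh 2 /imagenet --model mobilenetv3_large_100 -b 512 --sched plateau --decay-rate 0.1 --patience-epochs 10 --warmup-epochs 3 --min-lr 1e-5 --lr .064`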
@@ -10,13 +10,21 @@ class StepLRScheduler(Scheduler):
 
     def __init__(self,
                  optimizer: torch.optim.Optimizer,
-                 decay_t: int,
+                 decay_t: float,
                  decay_rate: float = 1.,
                  warmup_t=0,
                  warmup_lr_init=0,
                  t_in_epochs=True,
-                 initialize=True) -> None:
-        super().__init__(optimizer, param_group_field="lr", initialize=initialize)
+                 noise_range_t=None,
+                 noise_pct=0.67,
+                 noise_std=1.0,
+                 noise_seed=42,
+                 initialize=True,
+                 ) -> None:
+        super().__init__(
+            optimizer, param_group_field="lr",
+            noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
+            initialize=initialize)
 
         self.decay_t = decay_t
         self.decay_rate = decay_rate
@@ -33,8 +41,7 @@ class StepLRScheduler(Scheduler):
         if t < self.warmup_t:
             lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
         else:
-            lrs = [v * (self.decay_rate ** (t // self.decay_t))
-                   for v in self.base_values]
+            lrs = [v * (self.decay_rate ** (t // self.decay_t)) for v in self.base_values]
         return lrs
 
     def get_epoch_values(self, epoch: int):
@@ -28,8 +28,15 @@ class TanhLRScheduler(Scheduler):
                  warmup_prefix=False,
                  cycle_limit=0,
                  t_in_epochs=True,
+                 noise_range_t=None,
+                 noise_pct=0.67,
+                 noise_std=1.0,
+                 noise_seed=42,
                  initialize=True) -> None:
-        super().__init__(optimizer, param_group_field="lr", initialize=initialize)
+        super().__init__(
+            optimizer, param_group_field="lr",
+            noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
+            initialize=initialize)
 
         assert t_initial > 0
         assert lr_min >= 0
train.py (8 lines changed)
@@ -105,6 +105,12 @@ parser.add_argument('--sched', default='step', type=str, metavar='SCHEDULER',
                     help='LR scheduler (default: "step"')
 parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                     help='learning rate (default: 0.01)')
+parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
+                    help='learning rate noise on/off epoch percentages')
+parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
+                    help='learning rate noise limit percent (default: 0.67)')
+parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
+                    help='learning rate noise std-dev (default: 1.0)')
 parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR',
                     help='warmup learning rate (default: 0.0001)')
 parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
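Because `--lr-noise` uses `nargs='+'`, it accepts either one or two values: a pair bounds the noisy window, while a single value (e.g. `--lr-noise 0.5`) is collapsed to a scalar by `create_scheduler` and turns noise on from that fraction of training onward.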
@@ -119,6 +125,8 @@ parser.add_argument('--warmup-epochs', type=int, default=3, metavar='N',
                     help='epochs to warmup LR, if scheduler supports')
 parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
                     help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
+parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
+                    help='patience epochs for Plateau LR scheduler (default: 10')
 parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
                     help='LR decay rate (default: 0.1)')
 # Augmentation parameters
|
Loading…
x
Reference in New Issue
Block a user