mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Merge pull request #94 from rwightman/lr_noise
Learning rate noise, MobileNetV3 weights, and activate MobileNetV3/EfficientNet weight init change
This commit is contained in:
commit
56e2ac3a6d
16
README.md
16
README.md
@ -2,6 +2,14 @@
|
||||
|
||||
## What's New
|
||||
|
||||
### Feb 29, 2020
|
||||
* New MobileNet-V3 Large weights trained from stratch with this code to 75.77% top-1
|
||||
* IMPORTANT CHANGE - default weight init changed for all MobilenetV3 / EfficientNet / related models
|
||||
* overall results similar to a bit better training from scratch on a few smaller models tried
|
||||
* performance early in training seems consistently improved but less difference by end
|
||||
* set `fix_group_fanout=False` in `_init_weight_goog` fn if you need to reproducte past behaviour
|
||||
* Experimental LR noise feature added applies a random perturbation to LR each epoch in specified range of training
|
||||
|
||||
### Feb 18, 2020
|
||||
* Big refactor of model layers and addition of several attention mechanisms. Several additions motivated by 'Compounding the Performance Improvements...' (https://arxiv.org/abs/2001.06268):
|
||||
* Move layer/module impl into `layers` subfolder/module of `models` and organize in a more granular fashion
|
||||
@ -187,7 +195,8 @@ I've leveraged the training scripts in this repository to train a few of the mod
|
||||
| skresnet34 | 76.912 (23.088) | 93.322 (6.678) | 22.2M | bicubic | 224 |
|
||||
| resnet26d | 76.68 (23.32) | 93.166 (6.834) | 16M | bicubic | 224 |
|
||||
| mixnet_s | 75.988 (24.012) | 92.794 (7.206) | 4.13M | bicubic | 224 |
|
||||
| mobilenetv3_100 | 75.634 (24.366) | 92.708 (7.292) | 5.5M | bicubic | 224 |
|
||||
| mobilenetv3_large_100 | 75.766 (24.234) | 92.542 (7.458) | 5.5M | bicubic | 224 |
|
||||
| mobilenetv3_rw | 75.634 (24.366) | 92.708 (7.292) | 5.5M | bicubic | 224 |
|
||||
| mnasnet_a1 | 75.448 (24.552) | 92.604 (7.396) | 3.89M | bicubic | 224 |
|
||||
| resnet26 | 75.292 (24.708) | 92.57 (7.43) | 16M | bicubic | 224 |
|
||||
| fbnetc_100 | 75.124 (24.876) | 92.386 (7.614) | 5.6M | bilinear | 224 |
|
||||
@ -361,6 +370,11 @@ Trained by [Andrew Lavin](https://github.com/andravin) with 8 V100 cards. Model
|
||||
|
||||
`./distributed_train.sh 8 /imagenet --model efficientnet_es -b 128 --sched step --epochs 450 --decay-epochs 2.4 --decay-rate .97 --opt rmsproptf --opt-eps .001 -j 8 --warmup-lr 1e-6 --weight-decay 1e-5 --drop 0.2 --drop-connect 0.2 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.2 --amp --lr .064`
|
||||
|
||||
### MobileNetV3-Large-100 - 75.766 top-1, 92,542 top-5
|
||||
|
||||
`./distributed_train.sh 2 /imagenet/ --model mobilenetv3_large_100 -b 512 --sched step --epochs 600 --decay-epochs 2.4 --decay-rate .973 --opt rmsproptf --opt-eps .001 -j 7 --warmup-lr 1e-6 --weight-decay 1e-5 --drop 0.2 --drop-connect 0.2 --model-ema --model-ema-decay 0.9999 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.2 --amp --lr .064 --lr-noise 0.42 0.9`
|
||||
|
||||
|
||||
**TODO dig up some more**
|
||||
|
||||
|
||||
|
@ -93,7 +93,7 @@ model_list = [
|
||||
_entry('semnasnet_100', 'MnasNet-A1', '1807.11626'),
|
||||
_entry('spnasnet_100', 'Single-Path NAS', '1904.02877',
|
||||
model_desc='Trained in PyTorch with SGD, cosine LR decay'),
|
||||
_entry('mobilenetv3_rw', 'MobileNet V3-Large 1.0', '1905.02244',
|
||||
_entry('mobilenetv3_large_100', 'MobileNet V3-Large 1.0', '1905.02244',
|
||||
model_desc='Trained in PyTorch with RMSProp, exponential LR decay, and hyper-params matching '
|
||||
'paper as closely as possible.'),
|
||||
|
||||
|
@ -359,15 +359,13 @@ class EfficientNetBuilder:
|
||||
return stages
|
||||
|
||||
|
||||
def _init_weight_goog(m, n='', fix_group_fanout=False):
|
||||
def _init_weight_goog(m, n='', fix_group_fanout=True):
|
||||
""" Weight initialization as per Tensorflow official implementations.
|
||||
|
||||
Args:
|
||||
m (nn.Module): module to init
|
||||
n (str): module name
|
||||
fix_group_fanout (bool): enable correct fanout calculation w/ group convs
|
||||
|
||||
FIXME change fix_group_fanout to default to True if experiments show better training results
|
||||
fix_group_fanout (bool): enable correct (matching Tensorflow TPU impl) fanout calculation w/ group convs
|
||||
|
||||
Handles layers in EfficientNet, EfficientNet-CondConv, MixNet, MnasNet, MobileNetV3, etc:
|
||||
* https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_model.py
|
||||
|
@ -31,7 +31,9 @@ def _cfg(url='', **kwargs):
|
||||
|
||||
default_cfgs = {
|
||||
'mobilenetv3_large_075': _cfg(url=''),
|
||||
'mobilenetv3_large_100': _cfg(url=''),
|
||||
'mobilenetv3_large_100': _cfg(
|
||||
interpolation='bicubic',
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_large_100_ra-f55367f5.pth'),
|
||||
'mobilenetv3_small_075': _cfg(url=''),
|
||||
'mobilenetv3_small_100': _cfg(url=''),
|
||||
'mobilenetv3_rw': _cfg(
|
||||
|
@ -29,8 +29,15 @@ class CosineLRScheduler(Scheduler):
|
||||
warmup_prefix=False,
|
||||
cycle_limit=0,
|
||||
t_in_epochs=True,
|
||||
noise_range_t=None,
|
||||
noise_pct=0.67,
|
||||
noise_std=1.0,
|
||||
noise_seed=42,
|
||||
initialize=True) -> None:
|
||||
super().__init__(optimizer, param_group_field="lr", initialize=initialize)
|
||||
super().__init__(
|
||||
optimizer, param_group_field="lr",
|
||||
noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
|
||||
initialize=initialize)
|
||||
|
||||
assert t_initial > 0
|
||||
assert lr_min >= 0
|
||||
|
@ -8,33 +8,34 @@ class PlateauLRScheduler(Scheduler):
|
||||
|
||||
def __init__(self,
|
||||
optimizer,
|
||||
factor=0.1,
|
||||
patience=10,
|
||||
verbose=False,
|
||||
decay_rate=0.1,
|
||||
patience_t=10,
|
||||
verbose=True,
|
||||
threshold=1e-4,
|
||||
cooldown_epochs=0,
|
||||
warmup_updates=0,
|
||||
cooldown_t=0,
|
||||
warmup_t=0,
|
||||
warmup_lr_init=0,
|
||||
lr_min=0,
|
||||
mode='min',
|
||||
initialize=True,
|
||||
):
|
||||
super().__init__(optimizer, 'lr', initialize=False)
|
||||
super().__init__(optimizer, 'lr', initialize=initialize)
|
||||
|
||||
self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
||||
self.optimizer.optimizer,
|
||||
patience=patience,
|
||||
factor=factor,
|
||||
self.optimizer,
|
||||
patience=patience_t,
|
||||
factor=decay_rate,
|
||||
verbose=verbose,
|
||||
threshold=threshold,
|
||||
cooldown=cooldown_epochs,
|
||||
cooldown=cooldown_t,
|
||||
mode=mode,
|
||||
min_lr=lr_min
|
||||
)
|
||||
|
||||
self.warmup_updates = warmup_updates
|
||||
self.warmup_t = warmup_t
|
||||
self.warmup_lr_init = warmup_lr_init
|
||||
|
||||
if self.warmup_updates:
|
||||
self.warmup_active = warmup_updates > 0 # this state updates with num_updates
|
||||
self.warmup_steps = [(v - warmup_lr_init) / self.warmup_updates for v in self.base_values]
|
||||
if self.warmup_t:
|
||||
self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
|
||||
super().update_groups(self.warmup_lr_init)
|
||||
else:
|
||||
self.warmup_steps = [1 for _ in self.base_values]
|
||||
@ -51,18 +52,9 @@ class PlateauLRScheduler(Scheduler):
|
||||
self.lr_scheduler.last_epoch = state_dict['last_epoch']
|
||||
|
||||
# override the base class step fn completely
|
||||
def step(self, epoch, val_loss=None):
|
||||
"""Update the learning rate at the end of the given epoch."""
|
||||
if val_loss is not None and not self.warmup_active:
|
||||
self.lr_scheduler.step(val_loss, epoch)
|
||||
def step(self, epoch, metric=None):
|
||||
if epoch <= self.warmup_t:
|
||||
lrs = [self.warmup_lr_init + epoch * s for s in self.warmup_steps]
|
||||
super().update_groups(lrs)
|
||||
else:
|
||||
self.lr_scheduler.last_epoch = epoch
|
||||
|
||||
def get_update_values(self, num_updates: int):
|
||||
if num_updates < self.warmup_updates:
|
||||
lrs = [self.warmup_lr_init + num_updates * s for s in self.warmup_steps]
|
||||
else:
|
||||
self.warmup_active = False # warmup cancelled by first update past warmup_update count
|
||||
lrs = None # no change on update after warmup stage
|
||||
return lrs
|
||||
|
||||
self.lr_scheduler.step(metric, epoch)
|
||||
|
@ -25,6 +25,11 @@ class Scheduler:
|
||||
def __init__(self,
|
||||
optimizer: torch.optim.Optimizer,
|
||||
param_group_field: str,
|
||||
noise_range_t=None,
|
||||
noise_type='normal',
|
||||
noise_pct=0.67,
|
||||
noise_std=1.0,
|
||||
noise_seed=None,
|
||||
initialize: bool = True) -> None:
|
||||
self.optimizer = optimizer
|
||||
self.param_group_field = param_group_field
|
||||
@ -40,6 +45,11 @@ class Scheduler:
|
||||
raise KeyError(f"{self._initial_param_group_field} missing from param_groups[{i}]")
|
||||
self.base_values = [group[self._initial_param_group_field] for group in self.optimizer.param_groups]
|
||||
self.metric = None # any point to having this for all?
|
||||
self.noise_range_t = noise_range_t
|
||||
self.noise_pct = noise_pct
|
||||
self.noise_type = noise_type
|
||||
self.noise_std = noise_std
|
||||
self.noise_seed = noise_seed if noise_seed is not None else 42
|
||||
self.update_groups(self.base_values)
|
||||
|
||||
def state_dict(self) -> Dict[str, Any]:
|
||||
@ -58,12 +68,14 @@ class Scheduler:
|
||||
self.metric = metric
|
||||
values = self.get_epoch_values(epoch)
|
||||
if values is not None:
|
||||
values = self._add_noise(values, epoch)
|
||||
self.update_groups(values)
|
||||
|
||||
def step_update(self, num_updates: int, metric: float = None):
|
||||
self.metric = metric
|
||||
values = self.get_update_values(num_updates)
|
||||
if values is not None:
|
||||
values = self._add_noise(values, num_updates)
|
||||
self.update_groups(values)
|
||||
|
||||
def update_groups(self, values):
|
||||
@ -71,3 +83,23 @@ class Scheduler:
|
||||
values = [values] * len(self.optimizer.param_groups)
|
||||
for param_group, value in zip(self.optimizer.param_groups, values):
|
||||
param_group[self.param_group_field] = value
|
||||
|
||||
def _add_noise(self, lrs, t):
|
||||
if self.noise_range_t is not None:
|
||||
if isinstance(self.noise_range_t, (list, tuple)):
|
||||
apply_noise = self.noise_range_t[0] <= t < self.noise_range_t[1]
|
||||
else:
|
||||
apply_noise = t >= self.noise_range_t
|
||||
if apply_noise:
|
||||
g = torch.Generator()
|
||||
g.manual_seed(self.noise_seed + t)
|
||||
if self.noise_type == 'normal':
|
||||
while True:
|
||||
# resample if noise out of percent limit, brute force but shouldn't spin much
|
||||
noise = torch.randn(1, generator=g).item()
|
||||
if abs(noise) < self.noise_pct:
|
||||
break
|
||||
else:
|
||||
noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct
|
||||
lrs = [v + v * noise for v in lrs]
|
||||
return lrs
|
||||
|
@ -1,10 +1,22 @@
|
||||
from .cosine_lr import CosineLRScheduler
|
||||
from .tanh_lr import TanhLRScheduler
|
||||
from .step_lr import StepLRScheduler
|
||||
from .plateau_lr import PlateauLRScheduler
|
||||
|
||||
|
||||
def create_scheduler(args, optimizer):
|
||||
num_epochs = args.epochs
|
||||
|
||||
if args.lr_noise is not None:
|
||||
if isinstance(args.lr_noise, (list, tuple)):
|
||||
noise_range = [n * num_epochs for n in args.lr_noise]
|
||||
if len(noise_range) == 1:
|
||||
noise_range = noise_range[0]
|
||||
else:
|
||||
noise_range = args.lr_noise * num_epochs
|
||||
else:
|
||||
noise_range = None
|
||||
|
||||
lr_scheduler = None
|
||||
#FIXME expose cycle parms of the scheduler config to arguments
|
||||
if args.sched == 'cosine':
|
||||
@ -18,6 +30,10 @@ def create_scheduler(args, optimizer):
|
||||
warmup_t=args.warmup_epochs,
|
||||
cycle_limit=1,
|
||||
t_in_epochs=True,
|
||||
noise_range_t=noise_range,
|
||||
noise_pct=args.lr_noise_pct,
|
||||
noise_std=args.lr_noise_std,
|
||||
noise_seed=args.seed,
|
||||
)
|
||||
num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs
|
||||
elif args.sched == 'tanh':
|
||||
@ -30,6 +46,10 @@ def create_scheduler(args, optimizer):
|
||||
warmup_t=args.warmup_epochs,
|
||||
cycle_limit=1,
|
||||
t_in_epochs=True,
|
||||
noise_range_t=noise_range,
|
||||
noise_pct=args.lr_noise_pct,
|
||||
noise_std=args.lr_noise_std,
|
||||
noise_seed=args.seed,
|
||||
)
|
||||
num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs
|
||||
elif args.sched == 'step':
|
||||
@ -39,5 +59,20 @@ def create_scheduler(args, optimizer):
|
||||
decay_rate=args.decay_rate,
|
||||
warmup_lr_init=args.warmup_lr,
|
||||
warmup_t=args.warmup_epochs,
|
||||
noise_range_t=noise_range,
|
||||
noise_pct=args.lr_noise_pct,
|
||||
noise_std=args.lr_noise_std,
|
||||
noise_seed=args.seed,
|
||||
)
|
||||
elif args.sched == 'plateau':
|
||||
lr_scheduler = PlateauLRScheduler(
|
||||
optimizer,
|
||||
decay_rate=args.decay_rate,
|
||||
patience_t=args.patience_epochs,
|
||||
lr_min=args.min_lr,
|
||||
warmup_lr_init=args.warmup_lr,
|
||||
warmup_t=args.warmup_epochs,
|
||||
cooldown_t=args.cooldown_epochs,
|
||||
)
|
||||
|
||||
return lr_scheduler, num_epochs
|
||||
|
@ -10,13 +10,21 @@ class StepLRScheduler(Scheduler):
|
||||
|
||||
def __init__(self,
|
||||
optimizer: torch.optim.Optimizer,
|
||||
decay_t: int,
|
||||
decay_t: float,
|
||||
decay_rate: float = 1.,
|
||||
warmup_t=0,
|
||||
warmup_lr_init=0,
|
||||
t_in_epochs=True,
|
||||
initialize=True) -> None:
|
||||
super().__init__(optimizer, param_group_field="lr", initialize=initialize)
|
||||
noise_range_t=None,
|
||||
noise_pct=0.67,
|
||||
noise_std=1.0,
|
||||
noise_seed=42,
|
||||
initialize=True,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
optimizer, param_group_field="lr",
|
||||
noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
|
||||
initialize=initialize)
|
||||
|
||||
self.decay_t = decay_t
|
||||
self.decay_rate = decay_rate
|
||||
@ -33,8 +41,7 @@ class StepLRScheduler(Scheduler):
|
||||
if t < self.warmup_t:
|
||||
lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
|
||||
else:
|
||||
lrs = [v * (self.decay_rate ** (t // self.decay_t))
|
||||
for v in self.base_values]
|
||||
lrs = [v * (self.decay_rate ** (t // self.decay_t)) for v in self.base_values]
|
||||
return lrs
|
||||
|
||||
def get_epoch_values(self, epoch: int):
|
||||
|
@ -28,8 +28,15 @@ class TanhLRScheduler(Scheduler):
|
||||
warmup_prefix=False,
|
||||
cycle_limit=0,
|
||||
t_in_epochs=True,
|
||||
noise_range_t=None,
|
||||
noise_pct=0.67,
|
||||
noise_std=1.0,
|
||||
noise_seed=42,
|
||||
initialize=True) -> None:
|
||||
super().__init__(optimizer, param_group_field="lr", initialize=initialize)
|
||||
super().__init__(
|
||||
optimizer, param_group_field="lr",
|
||||
noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
|
||||
initialize=initialize)
|
||||
|
||||
assert t_initial > 0
|
||||
assert lr_min >= 0
|
||||
|
8
train.py
8
train.py
@ -105,6 +105,12 @@ parser.add_argument('--sched', default='step', type=str, metavar='SCHEDULER',
|
||||
help='LR scheduler (default: "step"')
|
||||
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
|
||||
help='learning rate (default: 0.01)')
|
||||
parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
|
||||
help='learning rate noise on/off epoch percentages')
|
||||
parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
|
||||
help='learning rate noise limit percent (default: 0.67)')
|
||||
parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
|
||||
help='learning rate noise std-dev (default: 1.0)')
|
||||
parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR',
|
||||
help='warmup learning rate (default: 0.0001)')
|
||||
parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
|
||||
@ -119,6 +125,8 @@ parser.add_argument('--warmup-epochs', type=int, default=3, metavar='N',
|
||||
help='epochs to warmup LR, if scheduler supports')
|
||||
parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
|
||||
help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
|
||||
parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
|
||||
help='patience epochs for Plateau LR scheduler (default: 10')
|
||||
parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
|
||||
help='LR decay rate (default: 0.1)')
|
||||
# Augmentation parameters
|
||||
|
Loading…
x
Reference in New Issue
Block a user