fix lr scheduler warning during AMP training

Skip the lr scheduler step when the current iteration produces NaN gradients
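For reference, the underlying behavior: under AMP, `torch.cuda.amp.GradScaler` skips `optimizer.step()` and lowers its scale whenever the unscaled gradients contain inf/NaN, and calling the LR scheduler on such an iteration triggers PyTorch's "lr_scheduler.step() called before optimizer.step()" warning. Below is a minimal sketch of the same scale-comparison pattern in a plain training loop; the model, data, and scheduler are made up for illustration (not fastreid code) and a CUDA device is assumed:

```python
import torch
import torch.nn.functional as F

model = torch.nn.Linear(10, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[30, 60])
scaler = torch.cuda.amp.GradScaler()

# Synthetic batches just to make the sketch self-contained.
batches = [(torch.randn(8, 10), torch.randint(0, 2, (8,))) for _ in range(5)]

for inputs, targets in batches:
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = F.cross_entropy(model(inputs.cuda()), targets.cuda())
    scaler.scale(loss).backward()

    scale_before = scaler.get_scale()
    scaler.step(optimizer)   # skipped internally if the gradients contain inf/NaN
    scaler.update()          # lowers the scale when the step was skipped

    # Advance the LR schedule only if the optimizer actually stepped,
    # i.e. GradScaler did not shrink the scale on this iteration.
    if scaler.get_scale() == scale_before:
        scheduler.step()
```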
pull/504/head
liaoxingyu 2021-06-02 16:35:46 +08:00
parent 2d2279be6a
commit 8f8cbf9411
2 changed files with 9 additions and 5 deletions

fastreid/engine/defaults.py

@@ -15,7 +15,6 @@ import sys
 from collections import OrderedDict
 
 import torch
 import torch.nn.functional as F
-from torch.nn.parallel import DistributedDataParallel
 
 from fastreid.data import build_reid_test_loader, build_reid_train_loader
@@ -155,7 +154,6 @@ class DefaultPredictor:
        return predictions
 
 
-
 class DefaultTrainer(TrainerBase):
     """
     A trainer with default training logic. Compared to `SimpleTrainer`, it
@@ -488,5 +486,5 @@ class DefaultTrainer(TrainerBase):
 
 
 # Access basic attributes from the underlying trainer
-for _attr in ["model", "data_loader", "optimizer"]:
-    setattr(DefaultTrainer, _attr, property(lambda self, x=_attr: getattr(self._trainer, x)))
+for _attr in ["model", "data_loader", "optimizer", "grad_scaler"]:
+    setattr(DefaultTrainer, _attr, property(lambda self, x=_attr: getattr(self._trainer, x, None)))
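The extra default in `getattr` is what makes this safe for the hook below: `DefaultTrainer` forwards these attributes to the wrapped `self._trainer`, and only the AMP trainer actually owns a `grad_scaler`, so non-AMP runs should get `trainer.grad_scaler == None` instead of an `AttributeError`. A toy sketch of the delegation pattern (class names here are placeholders, not fastreid classes):

```python
class _PlainTrainer:
    """Stand-in for a trainer created without AMP: it has no grad_scaler."""
    def __init__(self):
        self.model, self.data_loader, self.optimizer = "model", "loader", "opt"


class _Wrapper:
    def __init__(self):
        self._trainer = _PlainTrainer()


# Same trick as the diff above: expose the inner trainer's attributes as
# read-only properties, falling back to None when the attribute is missing.
for _attr in ["model", "data_loader", "optimizer", "grad_scaler"]:
    setattr(_Wrapper, _attr, property(lambda self, x=_attr: getattr(self._trainer, x, None)))

w = _Wrapper()
assert w.optimizer == "opt"
assert w.grad_scaler is None  # no AttributeError for non-AMP trainers
```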

fastreid/engine/hooks.py

@@ -226,6 +226,7 @@ class LRScheduler(HookBase):
         """
         self._optimizer = optimizer
         self._scheduler = scheduler
+        self._scale = 0
 
         # NOTE: some heuristics on what LR to summarize
         # summarize the param group with most parameters
@@ -246,13 +247,18 @@ class LRScheduler(HookBase):
                     self._best_param_group_id = i
                     break
 
+    def before_step(self):
+        if self.trainer.grad_scaler is not None:
+            self._scale = self.trainer.grad_scaler.get_scale()
+
     def after_step(self):
         lr = self._optimizer.param_groups[self._best_param_group_id]["lr"]
         self.trainer.storage.put_scalar("lr", lr, smoothing_hint=False)
 
         next_iter = self.trainer.iter + 1
         if next_iter <= self.trainer.warmup_iters:
-            self._scheduler["warmup_sched"].step()
+            if self.trainer.grad_scaler is None or self._scale == self.trainer.grad_scaler.get_scale():
+                self._scheduler["warmup_sched"].step()
 
     def after_epoch(self):
         next_iter = self.trainer.iter + 1
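Taken together: `before_step` records the scale `GradScaler` held going into the iteration, and `after_step` advances the warmup scheduler only if that scale is unchanged. Since `GradScaler.update()` shrinks the scale exactly when inf/NaN gradients forced it to skip `optimizer.step()`, an unchanged scale means the optimizer really stepped, so PyTorch no longer warns about `lr_scheduler.step()` being called before `optimizer.step()` on skipped iterations. When AMP is off, `trainer.grad_scaler` is `None` (via the `defaults.py` change above) and the scheduler steps as before.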