mirror of https://github.com/JDAI-CV/fast-reid.git
fix lr scheduler warning when amp training
Skip the lr scheduler step when the current iteration produces NaN gradients (pull/504/head)
parent 2d2279be6a
commit 8f8cbf9411
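Background for the change: under mixed-precision training, torch.cuda.amp.GradScaler silently skips optimizer.step() whenever the unscaled gradients contain inf/NaN, and then lowers its loss scale in update(). If the LR scheduler still steps in that iteration, PyTorch emits the "Detected call of `lr_scheduler.step()` before `optimizer.step()`" warning and the schedule drifts ahead of the optimizer. The sketch below is not fastreid code, just the generic loop-level form of the same fix: it detects the skipped step by comparing the scaler's scale before and after update(). The model, optimizer and data names are placeholders.

import torch

model = torch.nn.Linear(8, 2).cuda()   # placeholder model, assumes a CUDA device
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100)
scaler = torch.cuda.amp.GradScaler()

def train_step(inputs, targets):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = torch.nn.functional.cross_entropy(model(inputs), targets)
    scaler.scale(loss).backward()

    scale_before = scaler.get_scale()
    scaler.step(optimizer)    # silently skipped if gradients contain inf/NaN
    scaler.update()           # lowers the scale when the step was skipped

    # Advance the LR schedule only when the optimizer actually stepped.
    if scaler.get_scale() >= scale_before:
        scheduler.step()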
@@ -15,7 +15,6 @@ import sys
-from collections import OrderedDict
 
 import torch
 import torch.nn.functional as F
 from torch.nn.parallel import DistributedDataParallel
 
 from fastreid.data import build_reid_test_loader, build_reid_train_loader
@@ -155,7 +154,6 @@ class DefaultPredictor:
         return predictions
 
 
-
 class DefaultTrainer(TrainerBase):
     """
     A trainer with default training logic. Compared to `SimpleTrainer`, it
@@ -488,5 +486,5 @@ class DefaultTrainer(TrainerBase):
 
 
 # Access basic attributes from the underlying trainer
-for _attr in ["model", "data_loader", "optimizer"]:
-    setattr(DefaultTrainer, _attr, property(lambda self, x=_attr: getattr(self._trainer, x)))
+for _attr in ["model", "data_loader", "optimizer", "grad_scaler"]:
+    setattr(DefaultTrainer, _attr, property(lambda self, x=_attr: getattr(self._trainer, x, None)))
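The hunk above extends the attribute-forwarding loop on DefaultTrainer so that hooks can also reach the AMP grad scaler through the trainer, and the extra None default in getattr keeps the lookup safe when the wrapped trainer was built without AMP. A small self-contained illustration of that idiom (the classes here are stand-ins, not fastreid's):

# Stand-in classes, only to illustrate the getattr default.
class _Inner:
    def __init__(self):
        self.model = "model"
        self.optimizer = "optimizer"      # note: no grad_scaler attribute

class _Wrapper:
    def __init__(self):
        self._trainer = _Inner()

# Each listed name becomes a read-only property that delegates to the wrapped
# trainer; the third getattr argument turns a missing attribute into None.
for _attr in ["model", "optimizer", "grad_scaler"]:
    setattr(_Wrapper, _attr, property(lambda self, x=_attr: getattr(self._trainer, x, None)))

w = _Wrapper()
print(w.model)         # -> "model"
print(w.grad_scaler)   # -> None, so hooks can test `trainer.grad_scaler is None`

This is why the LRScheduler hook changed below can branch on `self.trainer.grad_scaler is None` without caring whether AMP is enabled.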
@@ -226,6 +226,7 @@ class LRScheduler(HookBase):
         """
         self._optimizer = optimizer
         self._scheduler = scheduler
+        self._scale = 0
 
         # NOTE: some heuristics on what LR to summarize
         # summarize the param group with most parameters
@@ -246,13 +247,18 @@ class LRScheduler(HookBase):
                 self._best_param_group_id = i
                 break
 
+    def before_step(self):
+        if self.trainer.grad_scaler is not None:
+            self._scale = self.trainer.grad_scaler.get_scale()
+
     def after_step(self):
         lr = self._optimizer.param_groups[self._best_param_group_id]["lr"]
         self.trainer.storage.put_scalar("lr", lr, smoothing_hint=False)
 
         next_iter = self.trainer.iter + 1
         if next_iter <= self.trainer.warmup_iters:
-            self._scheduler["warmup_sched"].step()
+            if self.trainer.grad_scaler is None or self._scale == self.trainer.grad_scaler.get_scale():
+                self._scheduler["warmup_sched"].step()
 
     def after_epoch(self):
         next_iter = self.trainer.iter + 1
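The before_step/after_step pair works because GradScaler lowers its scale when an iteration produced non-finite gradients and skipped optimizer.step(), which is exactly what the unchanged-scale check above detects. A quick standalone check of that behaviour, outside fastreid (assumes a CUDA device; the tiny model and the forced NaN loss are only for demonstration):

import torch

model = torch.nn.Linear(4, 4).cuda()
opt = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()

for force_nan in (False, True):
    opt.zero_grad()
    loss = model(torch.randn(2, 4, device="cuda")).sum()
    if force_nan:
        loss = loss * float("nan")       # makes every gradient non-finite
    scaler.scale(loss).backward()

    scale_before = scaler.get_scale()
    scaler.step(opt)                     # skipped when gradients are inf/NaN
    scaler.update()                      # backs the scale off after a skipped step

    stepped = scaler.get_scale() == scale_before
    print(f"force_nan={force_nan} -> optimizer stepped: {stepped}")
    # prints True for the clean iteration, False for the NaN one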