make lr warmup by iter

Summary: switch the learning-rate warmup schedule from epoch-based to iteration-based, which makes warmup more flexible when training for only a few epochs
pull/389/head
liaoxingyu 2021-01-22 11:17:21 +08:00
parent fdaa4b1a84
commit e26182e6ec
50 changed files with 91 additions and 75 deletions
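
Why iteration-based warmup helps: with epoch-based warmup, a setting like `WARMUP_EPOCHS: 10` can swallow most or all of a short run, while `WARMUP_ITERS: 2000` always ends after 2000 updates no matter how many iterations an epoch contains. A minimal sketch of the linear rule the diff below adopts (the standalone helper and the numbers are illustrative, not code from the commit):

```python
def linear_warmup_factor(it: int, warmup_iters: int, warmup_factor: float) -> float:
    # Linearly ramp the LR multiplier from warmup_factor up to 1.0 over warmup_iters.
    if it >= warmup_iters:
        return 1.0
    alpha = it / warmup_iters
    return warmup_factor * (1 - alpha) + alpha

# With 2000 warmup iters and a base factor of 0.1 (the new config defaults),
# warmup finishes after 2000 iterations regardless of epoch length:
for it in (0, 500, 1000, 2000):
    print(it, linear_warmup_factor(it, warmup_iters=2000, warmup_factor=0.1))
# 0 -> 0.1, 500 -> 0.325, 1000 -> 0.55, 2000 -> 1.0
```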

View File

@@ -26,13 +26,13 @@ See [INSTALL.md](https://github.com/JDAI-CV/fast-reid/blob/master/docs/INSTALL.md)
 The designed architecture follows this guide [PyTorch-Project-Template](https://github.com/L1aoXingyu/PyTorch-Project-Template), you can check each folder's purpose by yourself.
-See [GETTING_STARTED.md](https://github.com/JDAI-CV/fast-reid/blob/master/docs/GETTING_STARTED.md).
+See [GETTING_STARTED.md](https://github.com/JDAI-CV/fast-reid/blob/master/GETTING_STARTED.md).
 Learn more at out [documentation](). And see [projects/](https://github.com/JDAI-CV/fast-reid/tree/master/projects) for some projects that are build on top of fastreid.
 ## Model Zoo and Baselines
-We provide a large set of baseline results and trained models available for download in the [Fastreid Model Zoo](https://github.com/JDAI-CV/fast-reid/blob/master/docs/MODEL_ZOO.md).
+We provide a large set of baseline results and trained models available for download in the [Fastreid Model Zoo](https://github.com/JDAI-CV/fast-reid/blob/master/MODEL_ZOO.md).
 ## Deployment

View File

@@ -50,7 +50,7 @@ SOLVER:
   ETA_MIN_LR: 0.0000007
   WARMUP_FACTOR: 0.1
-  WARMUP_EPOCHS: 10
+  WARMUP_ITERS: 2000
   FREEZE_ITERS: 1000

View File

@@ -60,7 +60,7 @@ SOLVER:
   GAMMA: 0.1
   WARMUP_FACTOR: 0.1
-  WARMUP_EPOCHS: 10
+  WARMUP_ITERS: 2000
   CHECKPOINT_PERIOD: 30

View File

@@ -22,7 +22,7 @@ SOLVER:
   IMS_PER_BATCH: 128
   MAX_ITER: 60
   STEPS: [30, 50]
-  WARMUP_EPOCHS: 10
+  WARMUP_ITERS: 2000
   CHECKPOINT_PERIOD: 20

View File

@@ -16,7 +16,7 @@ SOLVER:
   IMS_PER_BATCH: 64
   MAX_ITER: 60
   DELAY_ITERS: 30
-  WARMUP_EPOCHS: 10
+  WARMUP_ITERS: 2000
   FREEZE_ITERS: 10
   CHECKPOINT_PERIOD: 20

View File

@@ -24,7 +24,7 @@ SOLVER:
   IMS_PER_BATCH: 512
   MAX_ITER: 60
   STEPS: [30, 50]
-  WARMUP_EPOCHS: 10
+  WARMUP_ITERS: 2000
   CHECKPOINT_PERIOD: 20

View File

@@ -5,4 +5,4 @@
 """
-__version__ = "0.2.0"
+__version__ = "1.0.0"

View File

@@ -238,7 +238,7 @@ _C.SOLVER.ETA_MIN_LR = 1e-7
 # Warmup options
 _C.SOLVER.WARMUP_FACTOR = 0.1
-_C.SOLVER.WARMUP_EPOCHS = 10
+_C.SOLVER.WARMUP_ITERS = 1000
 _C.SOLVER.WARMUP_METHOD = "linear"
 # Backbone freeze iters

View File

@@ -233,7 +233,8 @@ class DefaultTrainer(TrainerBase):
             model, data_loader, optimizer
         )
-        self.scheduler = self.build_lr_scheduler(cfg, optimizer)
+        self.iters_per_epoch = len(data_loader.dataset) // cfg.SOLVER.IMS_PER_BATCH
+        self.scheduler = self.build_lr_scheduler(cfg, optimizer, self.iters_per_epoch)
 
         # Assume no other objects need to be checkpointed.
         # We can later make it checkpoint the stateful hooks
@@ -246,12 +247,11 @@ class DefaultTrainer(TrainerBase):
             **self.scheduler,
         )
-        self.iters_per_epoch = len(data_loader.dataset) // cfg.SOLVER.IMS_PER_BATCH
         self.start_epoch = 0
         self.max_epoch = cfg.SOLVER.MAX_EPOCH
         self.max_iter = self.max_epoch * self.iters_per_epoch
-        self.warmup_epochs = cfg.SOLVER.WARMUP_EPOCHS
+        self.warmup_iters = cfg.SOLVER.WARMUP_ITERS
         self.delay_epochs = cfg.SOLVER.DELAY_EPOCHS
         self.cfg = cfg
@@ -409,12 +409,12 @@ class DefaultTrainer(TrainerBase):
         return build_optimizer(cfg, model)
 
     @classmethod
-    def build_lr_scheduler(cls, cfg, optimizer):
+    def build_lr_scheduler(cls, cfg, optimizer, iters_per_epoch):
         """
         It now calls :func:`fastreid.solver.build_lr_scheduler`.
         Overwrite it if you'd like a different scheduler.
         """
-        return build_lr_scheduler(cfg, optimizer)
+        return build_lr_scheduler(cfg, optimizer, iters_per_epoch)
 
     @classmethod
     def build_train_loader(cls, cfg):

View File

@@ -250,11 +250,14 @@ class LRScheduler(HookBase):
         lr = self._optimizer.param_groups[self._best_param_group_id]["lr"]
         self.trainer.storage.put_scalar("lr", lr, smoothing_hint=False)
 
-    def after_epoch(self):
-        next_epoch = self.trainer.epoch + 1
-        if next_epoch <= self.trainer.warmup_epochs:
+        next_iter = self.trainer.iter + 1
+        if next_iter <= self.trainer.warmup_iters:
             self._scheduler["warmup_sched"].step()
-        elif next_epoch >= self.trainer.delay_epochs:
+
+    def after_epoch(self):
+        next_iter = self.trainer.iter + 1
+        next_epoch = self.trainer.epoch + 1
+        if next_iter > self.trainer.warmup_iters and next_epoch >= self.trainer.delay_epochs:
             self._scheduler["lr_sched"].step()

View File

@@ -224,7 +224,7 @@ _C.OPTIM.WEIGHT_DECAY = 5e-4
 _C.OPTIM.WARMUP_FACTOR = 0.1
 # Gradually warm up the OPTIM.BASE_LR over this number of epochs
-_C.OPTIM.WARMUP_EPOCHS = 0
+_C.OPTIM.WARMUP_ITERS = 0
 # ------------------------------------------------------------------------------------ #
 # Training options

View File

@@ -13,7 +13,7 @@ OPTIM:
   MAX_EPOCH: 100
   MOMENTUM: 0.9
   WEIGHT_DECAY: 5e-5
-  WARMUP_EPOCHS: 5
+  WARMUP_ITERS: 5
 TRAIN:
   DATASET: imagenet
   IM_SIZE: 224

View File

@@ -13,7 +13,7 @@ OPTIM:
   MAX_EPOCH: 100
   MOMENTUM: 0.9
   WEIGHT_DECAY: 5e-5
-  WARMUP_EPOCHS: 5
+  WARMUP_ITERS: 5
 TRAIN:
   DATASET: imagenet
   IM_SIZE: 224

View File

@@ -13,7 +13,7 @@ OPTIM:
   MAX_EPOCH: 100
   MOMENTUM: 0.9
   WEIGHT_DECAY: 5e-5
-  WARMUP_EPOCHS: 5
+  WARMUP_ITERS: 5
 TRAIN:
   DATASET: imagenet
   IM_SIZE: 224

View File

@@ -13,7 +13,7 @@ OPTIM:
   MAX_EPOCH: 100
   MOMENTUM: 0.9
   WEIGHT_DECAY: 5e-5
-  WARMUP_EPOCHS: 5
+  WARMUP_ITERS: 5
 TRAIN:
   DATASET: imagenet
   IM_SIZE: 224

View File

@@ -13,7 +13,7 @@ OPTIM:
   MAX_EPOCH: 100
   MOMENTUM: 0.9
   WEIGHT_DECAY: 5e-5
-  WARMUP_EPOCHS: 5
+  WARMUP_ITERS: 5
 TRAIN:
   DATASET: imagenet
   IM_SIZE: 224

View File

@@ -13,7 +13,7 @@ OPTIM:
   MAX_EPOCH: 100
   MOMENTUM: 0.9
   WEIGHT_DECAY: 5e-5
-  WARMUP_EPOCHS: 5
+  WARMUP_ITERS: 5
 TRAIN:
   DATASET: imagenet
   IM_SIZE: 224

View File

@@ -13,7 +13,7 @@ OPTIM:
   MAX_EPOCH: 100
   MOMENTUM: 0.9
   WEIGHT_DECAY: 5e-5
-  WARMUP_EPOCHS: 5
+  WARMUP_ITERS: 5
 TRAIN:
   DATASET: imagenet
   IM_SIZE: 224

View File

@@ -13,7 +13,7 @@ OPTIM:
   MAX_EPOCH: 100
   MOMENTUM: 0.9
   WEIGHT_DECAY: 5e-5
-  WARMUP_EPOCHS: 5
+  WARMUP_ITERS: 5
 TRAIN:
   DATASET: imagenet
   IM_SIZE: 224

View File

@@ -13,7 +13,7 @@ OPTIM:
   MAX_EPOCH: 100
   MOMENTUM: 0.9
   WEIGHT_DECAY: 5e-5
-  WARMUP_EPOCHS: 5
+  WARMUP_ITERS: 5
 TRAIN:
   DATASET: imagenet
   IM_SIZE: 224

View File

@@ -13,7 +13,7 @@ OPTIM:
   MAX_EPOCH: 100
   MOMENTUM: 0.9
   WEIGHT_DECAY: 5e-5
-  WARMUP_EPOCHS: 5
+  WARMUP_ITERS: 5
 TRAIN:
   DATASET: imagenet
   IM_SIZE: 224

View File

@@ -13,7 +13,7 @@ OPTIM:
   MAX_EPOCH: 100
   MOMENTUM: 0.9
   WEIGHT_DECAY: 5e-5
-  WARMUP_EPOCHS: 5
+  WARMUP_ITERS: 5
 TRAIN:
   DATASET: imagenet
   IM_SIZE: 224

View File

@@ -13,7 +13,7 @@ OPTIM:
   MAX_EPOCH: 100
   MOMENTUM: 0.9
   WEIGHT_DECAY: 5e-5
-  WARMUP_EPOCHS: 5
+  WARMUP_ITERS: 5
 TRAIN:
   DATASET: imagenet
   IM_SIZE: 224

View File

@@ -14,7 +14,7 @@ OPTIM:
   MAX_EPOCH: 100
   MOMENTUM: 0.9
   WEIGHT_DECAY: 5e-5
-  WARMUP_EPOCHS: 5
+  WARMUP_ITERS: 5
 TRAIN:
   DATASET: imagenet
   IM_SIZE: 224

View File

@@ -14,7 +14,7 @@ OPTIM:
   MAX_EPOCH: 100
   MOMENTUM: 0.9
   WEIGHT_DECAY: 5e-5
-  WARMUP_EPOCHS: 5
+  WARMUP_ITERS: 5
 TRAIN:
   DATASET: imagenet
   IM_SIZE: 224

View File

@@ -14,7 +14,7 @@ OPTIM:
   MAX_EPOCH: 100
   MOMENTUM: 0.9
   WEIGHT_DECAY: 5e-5
-  WARMUP_EPOCHS: 5
+  WARMUP_ITERS: 5
 TRAIN:
   DATASET: imagenet
   IM_SIZE: 224

View File

@@ -14,7 +14,6 @@ OPTIM:
   MAX_EPOCH: 100
   MOMENTUM: 0.9
   WEIGHT_DECAY: 5e-5
-  WARMUP_EPOCHS: 5
 TRAIN:
   DATASET: imagenet
   IM_SIZE: 224

View File

@@ -14,7 +14,7 @@ OPTIM:
   MAX_EPOCH: 100
   MOMENTUM: 0.9
   WEIGHT_DECAY: 5e-5
-  WARMUP_EPOCHS: 5
+  WARMUP_ITERS: 5
 TRAIN:
   DATASET: imagenet
   IM_SIZE: 224

View File

@@ -14,7 +14,7 @@ OPTIM:
   MAX_EPOCH: 100
   MOMENTUM: 0.9
   WEIGHT_DECAY: 5e-5
-  WARMUP_EPOCHS: 5
+  WARMUP_ITERS: 5
 TRAIN:
   DATASET: imagenet
   IM_SIZE: 224

View File

@@ -14,7 +14,7 @@ OPTIM:
   MAX_EPOCH: 100
   MOMENTUM: 0.9
   WEIGHT_DECAY: 5e-5
-  WARMUP_EPOCHS: 5
+  WARMUP_ITERS: 5
 TRAIN:
   DATASET: imagenet
   IM_SIZE: 224

View File

@@ -14,7 +14,7 @@ OPTIM:
   MAX_EPOCH: 100
   MOMENTUM: 0.9
   WEIGHT_DECAY: 5e-5
-  WARMUP_EPOCHS: 5
+  WARMUP_ITERS: 5
 TRAIN:
   DATASET: imagenet
   IM_SIZE: 224

View File

@@ -14,7 +14,7 @@ OPTIM:
   MAX_EPOCH: 100
   MOMENTUM: 0.9
   WEIGHT_DECAY: 5e-5
-  WARMUP_EPOCHS: 5
+  WARMUP_ITERS: 5
 TRAIN:
   DATASET: imagenet
   IM_SIZE: 224

View File

@@ -14,7 +14,7 @@ OPTIM:
   MAX_EPOCH: 100
   MOMENTUM: 0.9
   WEIGHT_DECAY: 5e-5
-  WARMUP_EPOCHS: 5
+  WARMUP_ITERS: 5
 TRAIN:
   DATASET: imagenet
   IM_SIZE: 224

View File

@@ -14,7 +14,7 @@ OPTIM:
   MAX_EPOCH: 100
   MOMENTUM: 0.9
   WEIGHT_DECAY: 5e-5
-  WARMUP_EPOCHS: 5
+  WARMUP_ITERS: 5
 TRAIN:
   DATASET: imagenet
   IM_SIZE: 224

View File

@@ -14,7 +14,7 @@ OPTIM:
   MAX_EPOCH: 100
   MOMENTUM: 0.9
   WEIGHT_DECAY: 5e-5
-  WARMUP_EPOCHS: 5
+  WARMUP_ITERS: 5
 TRAIN:
   DATASET: imagenet
   IM_SIZE: 224

View File

@@ -64,7 +64,7 @@ class Distiller(Baseline):
         return super(Distiller, self).forward(batched_inputs)
 
     def losses(self, s_outputs, t_outputs, gt_labels):
-        r"""
+        """
         Compute loss from modeling's outputs, the loss function input arguments
         must be the same as the outputs of the model forwarding.
         """

View File

@@ -4,6 +4,8 @@
 @contact: sherlockliao01@gmail.com
 """
+import math
+
 from . import lr_scheduler
 from . import optim
@@ -34,11 +36,9 @@ def build_optimizer(cfg, model):
     return opt_fns
 
-def build_lr_scheduler(cfg, optimizer):
-    cfg = cfg.clone()
-    cfg.defrost()
-    cfg.SOLVER.MAX_EPOCH = cfg.SOLVER.MAX_EPOCH - max(
-        cfg.SOLVER.WARMUP_EPOCHS + 1, cfg.SOLVER.DELAY_EPOCHS)
+def build_lr_scheduler(cfg, optimizer, iters_per_epoch):
+    max_epoch = cfg.SOLVER.MAX_EPOCH - max(
+        math.ceil(cfg.SOLVER.WARMUP_ITERS / iters_per_epoch), cfg.SOLVER.DELAY_EPOCHS)
 
     scheduler_dict = {}
@@ -52,7 +52,7 @@ def build_lr_scheduler(cfg, optimizer):
         "CosineAnnealingLR": {
             "optimizer": optimizer,
             # cosine annealing lr scheduler options
-            "T_max": cfg.SOLVER.MAX_EPOCH,
+            "T_max": max_epoch,
             "eta_min": cfg.SOLVER.ETA_MIN_LR,
         },
@@ -61,13 +61,13 @@ def build_lr_scheduler(cfg, optimizer):
     scheduler_dict["lr_sched"] = getattr(lr_scheduler, cfg.SOLVER.SCHED)(
         **scheduler_args[cfg.SOLVER.SCHED])
 
-    if cfg.SOLVER.WARMUP_EPOCHS > 0:
+    if cfg.SOLVER.WARMUP_ITERS > 0:
         warmup_args = {
             "optimizer": optimizer,
             # warmup options
             "warmup_factor": cfg.SOLVER.WARMUP_FACTOR,
-            "warmup_epochs": cfg.SOLVER.WARMUP_EPOCHS,
+            "warmup_iters": cfg.SOLVER.WARMUP_ITERS,
             "warmup_method": cfg.SOLVER.WARMUP_METHOD,
         }
         scheduler_dict["warmup_sched"] = lr_scheduler.WarmupLR(**warmup_args)

View File

@@ -15,18 +15,18 @@ class WarmupLR(torch.optim.lr_scheduler._LRScheduler):
         self,
         optimizer: torch.optim.Optimizer,
         warmup_factor: float = 0.1,
-        warmup_epochs: int = 10,
+        warmup_iters: int = 1000,
         warmup_method: str = "linear",
         last_epoch: int = -1,
     ):
         self.warmup_factor = warmup_factor
-        self.warmup_epochs = warmup_epochs
+        self.warmup_iters = warmup_iters
         self.warmup_method = warmup_method
         super().__init__(optimizer, last_epoch)
 
     def get_lr(self) -> List[float]:
         warmup_factor = _get_warmup_factor_at_epoch(
-            self.warmup_method, self.last_epoch, self.warmup_epochs, self.warmup_factor
+            self.warmup_method, self.last_epoch, self.warmup_iters, self.warmup_factor
         )
         return [
             base_lr * warmup_factor for base_lr in self.base_lrs
@@ -38,29 +38,29 @@ class WarmupLR(torch.optim.lr_scheduler._LRScheduler):
 def _get_warmup_factor_at_epoch(
-    method: str, epoch: int, warmup_epochs: int, warmup_factor: float
+    method: str, iter: int, warmup_iters: int, warmup_factor: float
 ) -> float:
     """
     Return the learning rate warmup factor at a specific iteration.
     See https://arxiv.org/abs/1706.02677 for more details.
     Args:
         method (str): warmup method; either "constant" or "linear".
-        epoch (int): epoch at which to calculate the warmup factor.
-        warmup_epochs (int): the number of warmup epochs.
+        iter (int): iter at which to calculate the warmup factor.
+        warmup_iters (int): the number of warmup epochs.
         warmup_factor (float): the base warmup factor (the meaning changes according
             to the method used).
     Returns:
         float: the effective warmup factor at the given iteration.
     """
-    if epoch >= warmup_epochs:
+    if iter >= warmup_iters:
         return 1.0
 
     if method == "constant":
         return warmup_factor
     elif method == "linear":
-        alpha = epoch / warmup_epochs
+        alpha = iter / warmup_iters
         return warmup_factor * (1 - alpha) + alpha
     elif method == "exp":
-        return warmup_factor ** (1 - epoch / warmup_epochs)
+        return warmup_factor ** (1 - iter / warmup_iters)
     else:
         raise ValueError("Unknown warmup method: {}".format(method))
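
With the new signature, `WarmupLR` is meant to be stepped once per batch instead of once per epoch. A minimal usage sketch (the dummy model, optimizer, and loop are illustrative; only the scheduler arguments come from the diff):

```python
import torch
from fastreid.solver.lr_scheduler import WarmupLR

model = torch.nn.Linear(10, 2)                          # illustrative model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
warmup = WarmupLR(optimizer, warmup_factor=0.1,
                  warmup_iters=2000, warmup_method="linear")

for it in range(2000):
    ...                # forward/backward + optimizer.step() would go here
    warmup.step()      # one step per iteration: lr ramps from 0.01 up to 0.1
```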

View File

@@ -73,6 +73,7 @@ class Checkpointer(object):
     def save(self, name: str, **kwargs: Dict[str, str]):
         """
         Dump model and checkpointables to a file.
+
         Args:
             name (str): name of the file.
             kwargs (dict): extra arbitrary data to save.
@@ -98,6 +99,7 @@ class Checkpointer(object):
         """
         Load from the given checkpoint. When path points to network file, this
         function has to be called on all ranks.
+
         Args:
             path (str): path or url to the checkpoint. If empty, will not load
                 anything.
@@ -176,6 +178,7 @@ class Checkpointer(object):
         If `resume` is True, this method attempts to resume from the last
         checkpoint, if exists. Otherwise, load checkpoint from the given path.
         This is useful when restarting an interrupted training job.
+
         Args:
             path (str): path to the checkpoint.
             resume (bool): if True, resume from the last checkpoint if it exists.
@@ -191,6 +194,7 @@ class Checkpointer(object):
     def tag_last_checkpoint(self, last_filename_basename: str):
         """
         Tag the last checkpoint.
+
         Args:
             last_filename_basename (str): the basename of the last filename.
         """
@@ -202,6 +206,7 @@ class Checkpointer(object):
         """
         Load a checkpoint file. Can be overwritten by subclasses to support
         different formats.
+
         Args:
             f (str): a locally mounted file path.
         Returns:
@@ -214,6 +219,7 @@ class Checkpointer(object):
     def _load_model(self, checkpoint: Any):
         """
         Load weights from a checkpoint.
+
         Args:
             checkpoint (Any): checkpoint contains the weights.
         """
@@ -269,6 +275,7 @@ class Checkpointer(object):
     def _convert_ndarray_to_tensor(self, state_dict: dict):
         """
         In-place convert all numpy arrays in the state_dict to torch tensor.
+
         Args:
             state_dict (dict): a state-dict to be loaded to the model.
         """
@@ -313,6 +320,7 @@ class PeriodicCheckpointer:
     def step(self, epoch: int, **kwargs: Any):
         """
         Perform the appropriate action at the given iteration.
+
         Args:
             epoch (int): the current epoch, ranged in [0, max_epoch-1].
             kwargs (Any): extra data to save, same as in
@@ -342,6 +350,7 @@ class PeriodicCheckpointer:
         """
         Same argument as :meth:`Checkpointer.save`.
         Use this method to manually save checkpoints outside the schedule.
+
         Args:
             name (str): file name.
             kwargs (Any): extra data to save, same as in
@@ -374,6 +383,7 @@ def get_missing_parameters_message(keys: List[str]) -> str:
     """
     Get a logging-friendly message to report parameter names (keys) that are in
     the model but not found in a checkpoint.
+
     Args:
         keys (list[str]): List of keys that were not found in the checkpoint.
     Returns:
@@ -391,6 +401,7 @@ def get_unexpected_parameters_message(keys: List[str]) -> str:
     """
     Get a logging-friendly message to report parameter names (keys) that are in
     the checkpoint but not found in the model.
+
     Args:
         keys (list[str]): List of keys that were not found in the model.
     Returns:
@@ -407,6 +418,7 @@ def get_unexpected_parameters_message(keys: List[str]) -> str:
 def _strip_prefix_if_present(state_dict: Dict[str, Any], prefix: str) -> None:
     """
     Strip the prefix in metadata, if any.
+
     Args:
         state_dict (OrderedDict): a state-dict to be loaded to the model.
         prefix (str): prefix.
@@ -441,6 +453,7 @@ def _group_checkpoint_keys(keys: List[str]) -> Dict[str, List[str]]:
     """
     Group keys based on common prefixes. A prefix is the string up to the final
     "." in each key.
+
     Args:
         keys (list[str]): list of parameter names, i.e. keys in the model
             checkpoint dict.
@@ -461,6 +474,7 @@ def _group_checkpoint_keys(keys: List[str]) -> Dict[str, List[str]]:
 def _group_to_str(group: List[str]) -> str:
     """
     Format a group of parameter name suffixes into a loggable string.
+
     Args:
         group (list[str]): list of parameter name suffixes.
     Returns:

View File

@@ -50,7 +50,7 @@ SOLVER:
   STEPS: [ 15, 20, 25 ]
   WARMUP_FACTOR: 0.1
-  WARMUP_EPOCHS: 0
+  WARMUP_ITERS: 1000
   CHECKPOINT_PERIOD: 10

View File

@@ -60,7 +60,7 @@ SOLVER:
   ETA_MIN_LR: 0.00003
   WARMUP_FACTOR: 0.1
-  WARMUP_EPOCHS: 10
+  WARMUP_ITERS: 2000
   CHECKPOINT_PERIOD: 10

View File

@@ -59,7 +59,7 @@ SOLVER:
   IMS_PER_BATCH: 512
   WARMUP_FACTOR: 0.1
-  WARMUP_EPOCHS: 1
+  WARMUP_ITERS: 5000
   CHECKPOINT_PERIOD: 2

View File

@@ -61,7 +61,7 @@ SOLVER:
   ETA_MIN_LR: 0.00003
   WARMUP_FACTOR: 0.1
-  WARMUP_EPOCHS: 10
+  WARMUP_ITERS: 1000
   CHECKPOINT_PERIOD: 10

View File

@@ -71,7 +71,7 @@ SOLVER:
   FREEZE_ITERS: 500
   WARMUP_FACTOR: 0.1
-  WARMUP_EPOCHS: 5
+  WARMUP_ITERS: 1000
   CHECKPOINT_PERIOD: 100

View File

@@ -57,7 +57,7 @@ SOLVER:
   GAMMA: 0.1
   WARMUP_FACTOR: 0.01
-  WARMUP_ITERS: 5
+  WARMUP_ITERS: 1000
   CHECKPOINT_PERIOD: 10

View File

@@ -17,4 +17,4 @@ termcolor
 scikit-learn
 tabulate
 gdown
-faiss-cpu
+faiss-gpu

View File

@@ -80,7 +80,8 @@ def do_train(cfg, model, resume=False):
     optimizer_ckpt = dict(optimizer=optimizer)
 
-    scheduler = build_lr_scheduler(cfg, optimizer)
+    iters_per_epoch = len(data_loader.dataset) // cfg.SOLVER.IMS_PER_BATCH
+    scheduler = build_lr_scheduler(cfg, optimizer, iters_per_epoch)
 
     checkpointer = Checkpointer(
         model,
@@ -90,8 +91,6 @@ def do_train(cfg, model, resume=False):
         **scheduler
     )
 
-    iters_per_epoch = len(data_loader.dataset) // cfg.SOLVER.IMS_PER_BATCH
-
     start_epoch = (
         checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=resume).get("epoch", -1) + 1
     )
@@ -99,7 +98,7 @@ def do_train(cfg, model, resume=False):
     max_epoch = cfg.SOLVER.MAX_EPOCH
     max_iter = max_epoch * iters_per_epoch
-    warmup_epochs = cfg.SOLVER.WARMUP_EPOCHS
+    warmup_iters = cfg.SOLVER.WARMUP_ITERS
     delay_epochs = cfg.SOLVER.DELAY_EPOCHS
 
     periodic_checkpointer = PeriodicCheckpointer(checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD, max_epoch)
@@ -146,13 +145,14 @@ def do_train(cfg, model, resume=False):
             iteration += 1
 
+            if iteration <= warmup_iters:
+                scheduler["warmup_sched"].step()
+
             # Write metrics after each epoch
             for writer in writers:
                 writer.write()
 
-        if (epoch + 1) <= warmup_epochs:
-            scheduler["warmup_sched"].step()
-        elif (epoch + 1) >= delay_epochs:
+        if iteration > warmup_iters and (epoch + 1) >= delay_epochs:
             scheduler["lr_sched"].step()
 
         if (