From e26182e6ec862870c339d01f4eb849ab5b0c26c7 Mon Sep 17 00:00:00 2001
From: liaoxingyu <sherlockliao01@gmail.com>
Date: Fri, 22 Jan 2021 11:17:21 +0800
Subject: [PATCH] make lr warmup by iter

Summary: make LR warmup iteration-based instead of epoch-based; stepping the warmup scheduler per iteration is more flexible and keeps warmup meaningful when training runs for only a few epochs.
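
The warmup factor is now evaluated every iteration instead of every epoch.
As a minimal sketch of the linear rule this patch adopts in
fastreid/solver/lr_scheduler.py (the "constant" and "exp" methods use the
same warmup_iters cutoff), with the defaults from fastreid/config/defaults.py:

    # linear warmup: scale each base_lr by a factor that ramps from
    # warmup_factor up to 1.0 over the first warmup_iters optimizer steps
    def warmup_factor_at(it: int, warmup_iters: int = 1000,
                         warmup_factor: float = 0.1) -> float:
        if it >= warmup_iters:
            return 1.0
        alpha = it / warmup_iters
        return warmup_factor * (1 - alpha) + alpha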
---
 docs/GETTING_STARTED.md => GETTING_STARTED.md  |  0
 docs/INSTALL.md => INSTALL.md                  |  0
 docs/MODEL_ZOO.md => MODEL_ZOO.md              |  0
 README.md                                      |  4 ++--
 configs/Base-SBS.yml                           |  2 +-
 configs/Base-bagtricks.yml                     |  2 +-
 configs/VERIWild/bagtricks_R50-ibn.yml         |  2 +-
 configs/VeRi/sbs_R50-ibn.yml                   |  2 +-
 configs/VehicleID/bagtricks_R50-ibn.yml        |  2 +-
 fastreid/__init__.py                           |  2 +-
 fastreid/config/defaults.py                    |  2 +-
 fastreid/engine/defaults.py                    | 10 +++++-----
 fastreid/engine/hooks.py                       | 11 +++++++----
 fastreid/modeling/backbones/regnet/config.py   |  4 ++--
 .../regnet/regnetx/RegNetX-1.6GF_dds_8gpu.yaml |  2 +-
 .../regnet/regnetx/RegNetX-12GF_dds_8gpu.yaml  |  2 +-
 .../regnet/regnetx/RegNetX-16GF_dds_8gpu.yaml  |  2 +-
 .../regnet/regnetx/RegNetX-200MF_dds_8gpu.yaml |  2 +-
 .../regnet/regnetx/RegNetX-3.2GF_dds_8gpu.yaml |  2 +-
 .../regnet/regnetx/RegNetX-32GF_dds_8gpu.yaml  |  2 +-
 .../regnet/regnetx/RegNetX-4.0GF_dds_8gpu.yaml |  2 +-
 .../regnet/regnetx/RegNetX-400MF_dds_8gpu.yaml |  2 +-
 .../regnet/regnetx/RegNetX-6.4GF_dds_8gpu.yaml |  2 +-
 .../regnet/regnetx/RegNetX-600MF_dds_8gpu.yaml |  2 +-
 .../regnet/regnetx/RegNetX-8.0GF_dds_8gpu.yaml |  2 +-
 .../regnet/regnetx/RegNetX-800MF_dds_8gpu.yaml |  2 +-
 .../regnet/regnety/RegNetY-1.6GF_dds_8gpu.yaml |  2 +-
 .../regnet/regnety/RegNetY-12GF_dds_8gpu.yaml  |  2 +-
 .../regnet/regnety/RegNetY-16GF_dds_8gpu.yaml  |  2 +-
 .../regnet/regnety/RegNetY-200MF_dds_8gpu.yaml |  1 -
 .../regnet/regnety/RegNetY-3.2GF_dds_8gpu.yaml |  2 +-
 .../regnet/regnety/RegNetY-32GF_dds_8gpu.yaml  |  2 +-
 .../regnet/regnety/RegNetY-4.0GF_dds_8gpu.yaml |  2 +-
 .../regnet/regnety/RegNetY-400MF_dds_8gpu.yaml |  2 +-
 .../regnet/regnety/RegNetY-6.4GF_dds_8gpu.yaml |  2 +-
 .../regnet/regnety/RegNetY-600MF_dds_8gpu.yaml |  2 +-
 .../regnet/regnety/RegNetY-8.0GF_dds_8gpu.yaml |  2 +-
 .../regnet/regnety/RegNetY-800MF_dds_8gpu.yaml |  2 +-
 fastreid/modeling/meta_arch/distiller.py       |  2 +-
 fastreid/solver/build.py                       | 16 ++++++++--------
 fastreid/solver/lr_scheduler.py                | 18 +++++++++---------
 fastreid/utils/checkpoint.py                   | 14 ++++++++++++++
 projects/FastAttr/configs/Base-attribute.yml   |  2 +-
 projects/FastCls/configs/base-cls.yaml         |  2 +-
 projects/FastFace/configs/face_base.yml        |  2 +-
 .../FastRetri/configs/base-image_retri.yml     |  2 +-
 projects/FastTune/configs/search_trial.yml     |  2 +-
 .../PartialReID/configs/partial_market.yml     |  2 +-
 docs/requirements => requirements.txt          |  2 +-
 tools/plain_train_net.py                       | 14 +++++++-------
 50 files changed, 92 insertions(+), 76 deletions(-)
 rename docs/GETTING_STARTED.md => GETTING_STARTED.md (100%)
 rename docs/INSTALL.md => INSTALL.md (100%)
 rename docs/MODEL_ZOO.md => MODEL_ZOO.md (100%)
 rename docs/requirements => requirements.txt (95%)
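
Note for existing configs: SOLVER.WARMUP_EPOCHS is replaced by
SOLVER.WARMUP_ITERS. A rough conversion for an old epoch-based setting,
reusing the iters_per_epoch counting the trainer now uses (a sketch only,
not part of the patch; old_warmup_epochs stands for whatever value a
config previously set):

    # approximate WARMUP_ITERS for a config that used WARMUP_EPOCHS
    iters_per_epoch = len(data_loader.dataset) // cfg.SOLVER.IMS_PER_BATCH
    warmup_iters = old_warmup_epochs * iters_per_epoch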

diff --git a/docs/GETTING_STARTED.md b/GETTING_STARTED.md
similarity index 100%
rename from docs/GETTING_STARTED.md
rename to GETTING_STARTED.md
diff --git a/docs/INSTALL.md b/INSTALL.md
similarity index 100%
rename from docs/INSTALL.md
rename to INSTALL.md
diff --git a/docs/MODEL_ZOO.md b/MODEL_ZOO.md
similarity index 100%
rename from docs/MODEL_ZOO.md
rename to MODEL_ZOO.md
diff --git a/README.md b/README.md
index 0203e52..e8e64cf 100644
--- a/README.md
+++ b/README.md
@@ -26,13 +26,13 @@ See [INSTALL.md](https://github.com/JDAI-CV/fast-reid/blob/master/docs/INSTALL.m
 
 The designed architecture follows this guide [PyTorch-Project-Template](https://github.com/L1aoXingyu/PyTorch-Project-Template), you can check each folder's purpose by yourself.
 
-See [GETTING_STARTED.md](https://github.com/JDAI-CV/fast-reid/blob/master/docs/GETTING_STARTED.md).
+See [GETTING_STARTED.md](https://github.com/JDAI-CV/fast-reid/blob/master/GETTING_STARTED.md).
 
 Learn more at out [documentation](). And see [projects/](https://github.com/JDAI-CV/fast-reid/tree/master/projects) for some projects that are build on top of fastreid.
 
 ## Model Zoo and Baselines
 
-We provide a large set of baseline results and trained models available for download in the [Fastreid Model Zoo](https://github.com/JDAI-CV/fast-reid/blob/master/docs/MODEL_ZOO.md).
+We provide a large set of baseline results and trained models available for download in the [Fastreid Model Zoo](https://github.com/JDAI-CV/fast-reid/blob/master/MODEL_ZOO.md).
 
 ## Deployment
 
diff --git a/configs/Base-SBS.yml b/configs/Base-SBS.yml
index 04d1f9b..5085a59 100644
--- a/configs/Base-SBS.yml
+++ b/configs/Base-SBS.yml
@@ -50,7 +50,7 @@ SOLVER:
   ETA_MIN_LR: 0.0000007
 
   WARMUP_FACTOR: 0.1
-  WARMUP_EPOCHS: 10
+  WARMUP_ITERS: 2000
 
   FREEZE_ITERS: 1000
 
diff --git a/configs/Base-bagtricks.yml b/configs/Base-bagtricks.yml
index a3da846..204346f 100644
--- a/configs/Base-bagtricks.yml
+++ b/configs/Base-bagtricks.yml
@@ -60,7 +60,7 @@ SOLVER:
   GAMMA: 0.1
 
   WARMUP_FACTOR: 0.1
-  WARMUP_EPOCHS: 10
+  WARMUP_ITERS: 2000
 
   CHECKPOINT_PERIOD: 30
 
diff --git a/configs/VERIWild/bagtricks_R50-ibn.yml b/configs/VERIWild/bagtricks_R50-ibn.yml
index 77a52c2..de6da4a 100644
--- a/configs/VERIWild/bagtricks_R50-ibn.yml
+++ b/configs/VERIWild/bagtricks_R50-ibn.yml
@@ -22,7 +22,7 @@ SOLVER:
   IMS_PER_BATCH: 128
   MAX_ITER: 60
   STEPS: [30, 50]
-  WARMUP_EPOCHS: 10
+  WARMUP_ITERS: 2000
 
   CHECKPOINT_PERIOD: 20
 
diff --git a/configs/VeRi/sbs_R50-ibn.yml b/configs/VeRi/sbs_R50-ibn.yml
index 724d460..686b1fc 100644
--- a/configs/VeRi/sbs_R50-ibn.yml
+++ b/configs/VeRi/sbs_R50-ibn.yml
@@ -16,7 +16,7 @@ SOLVER:
   IMS_PER_BATCH: 64
   MAX_ITER: 60
   DELAY_ITERS: 30
-  WARMUP_EPOCHS: 10
+  WARMUP_ITERS: 2000
   FREEZE_ITERS: 10
 
   CHECKPOINT_PERIOD: 20
diff --git a/configs/VehicleID/bagtricks_R50-ibn.yml b/configs/VehicleID/bagtricks_R50-ibn.yml
index 9ce8456..4b1c439 100644
--- a/configs/VehicleID/bagtricks_R50-ibn.yml
+++ b/configs/VehicleID/bagtricks_R50-ibn.yml
@@ -24,7 +24,7 @@ SOLVER:
   IMS_PER_BATCH: 512
   MAX_ITER: 60
   STEPS: [30, 50]
-  WARMUP_EPOCHS: 10
+  WARMUP_ITERS: 2000
 
   CHECKPOINT_PERIOD: 20
 
diff --git a/fastreid/__init__.py b/fastreid/__init__.py
index 94eab6e..0efd06c 100644
--- a/fastreid/__init__.py
+++ b/fastreid/__init__.py
@@ -5,4 +5,4 @@
 """
 
 
-__version__ = "0.2.0"
+__version__ = "1.0.0"
diff --git a/fastreid/config/defaults.py b/fastreid/config/defaults.py
index 3ac8251..dbc2170 100644
--- a/fastreid/config/defaults.py
+++ b/fastreid/config/defaults.py
@@ -238,7 +238,7 @@ _C.SOLVER.ETA_MIN_LR = 1e-7
 
 # Warmup options
 _C.SOLVER.WARMUP_FACTOR = 0.1
-_C.SOLVER.WARMUP_EPOCHS = 10
+_C.SOLVER.WARMUP_ITERS = 1000
 _C.SOLVER.WARMUP_METHOD = "linear"
 
 # Backbone freeze iters
diff --git a/fastreid/engine/defaults.py b/fastreid/engine/defaults.py
index 3d52bfa..28bb560 100644
--- a/fastreid/engine/defaults.py
+++ b/fastreid/engine/defaults.py
@@ -233,7 +233,8 @@ class DefaultTrainer(TrainerBase):
             model, data_loader, optimizer
         )
 
-        self.scheduler = self.build_lr_scheduler(cfg, optimizer)
+        self.iters_per_epoch = len(data_loader.dataset) // cfg.SOLVER.IMS_PER_BATCH
+        self.scheduler = self.build_lr_scheduler(cfg, optimizer, self.iters_per_epoch)
 
         # Assume no other objects need to be checkpointed.
         # We can later make it checkpoint the stateful hooks
@@ -246,12 +247,11 @@ class DefaultTrainer(TrainerBase):
             **self.scheduler,
         )
 
-        self.iters_per_epoch = len(data_loader.dataset) // cfg.SOLVER.IMS_PER_BATCH
 
         self.start_epoch = 0
         self.max_epoch = cfg.SOLVER.MAX_EPOCH
         self.max_iter = self.max_epoch * self.iters_per_epoch
-        self.warmup_epochs = cfg.SOLVER.WARMUP_EPOCHS
+        self.warmup_iters = cfg.SOLVER.WARMUP_ITERS
         self.delay_epochs = cfg.SOLVER.DELAY_EPOCHS
         self.cfg = cfg
 
@@ -409,12 +409,12 @@ class DefaultTrainer(TrainerBase):
         return build_optimizer(cfg, model)
 
     @classmethod
-    def build_lr_scheduler(cls, cfg, optimizer):
+    def build_lr_scheduler(cls, cfg, optimizer, iters_per_epoch):
         """
         It now calls :func:`fastreid.solver.build_lr_scheduler`.
         Overwrite it if you'd like a different scheduler.
         """
-        return build_lr_scheduler(cfg, optimizer)
+        return build_lr_scheduler(cfg, optimizer, iters_per_epoch)
 
     @classmethod
     def build_train_loader(cls, cfg):
diff --git a/fastreid/engine/hooks.py b/fastreid/engine/hooks.py
index 91fbaf9..f7111a3 100644
--- a/fastreid/engine/hooks.py
+++ b/fastreid/engine/hooks.py
@@ -250,11 +250,14 @@ class LRScheduler(HookBase):
         lr = self._optimizer.param_groups[self._best_param_group_id]["lr"]
         self.trainer.storage.put_scalar("lr", lr, smoothing_hint=False)
 
-    def after_epoch(self):
-        next_epoch = self.trainer.epoch + 1
-        if next_epoch <= self.trainer.warmup_epochs:
+        next_iter = self.trainer.iter + 1
+        if next_iter <= self.trainer.warmup_iters:
             self._scheduler["warmup_sched"].step()
-        elif next_epoch >= self.trainer.delay_epochs:
+
+    def after_epoch(self):
+        next_iter = self.trainer.iter + 1
+        next_epoch = self.trainer.epoch + 1
+        if next_iter > self.trainer.warmup_iters and next_epoch >= self.trainer.delay_epochs:
             self._scheduler["lr_sched"].step()
 
 
diff --git a/fastreid/modeling/backbones/regnet/config.py b/fastreid/modeling/backbones/regnet/config.py
index 4496480..cb967dc 100644
--- a/fastreid/modeling/backbones/regnet/config.py
+++ b/fastreid/modeling/backbones/regnet/config.py
@@ -224,7 +224,7 @@ _C.OPTIM.WEIGHT_DECAY = 5e-4
 _C.OPTIM.WARMUP_FACTOR = 0.1
 
-# Gradually warm up the OPTIM.BASE_LR over this number of epochs
-_C.OPTIM.WARMUP_EPOCHS = 0
+# Gradually warm up the OPTIM.BASE_LR over this number of iterations
+_C.OPTIM.WARMUP_ITERS = 0
 
 # ------------------------------------------------------------------------------------ #
 # Training options
diff --git a/fastreid/modeling/backbones/regnet/regnetx/RegNetX-1.6GF_dds_8gpu.yaml b/fastreid/modeling/backbones/regnet/regnetx/RegNetX-1.6GF_dds_8gpu.yaml
index c8133d7..7127fd1 100644
--- a/fastreid/modeling/backbones/regnet/regnetx/RegNetX-1.6GF_dds_8gpu.yaml
+++ b/fastreid/modeling/backbones/regnet/regnetx/RegNetX-1.6GF_dds_8gpu.yaml
@@ -13,7 +13,7 @@ OPTIM:
   MAX_EPOCH: 100
   MOMENTUM: 0.9
   WEIGHT_DECAY: 5e-5
-  WARMUP_EPOCHS: 5
+  WARMUP_ITERS: 5
 TRAIN:
   DATASET: imagenet
   IM_SIZE: 224
diff --git a/fastreid/modeling/backbones/regnet/regnetx/RegNetX-12GF_dds_8gpu.yaml b/fastreid/modeling/backbones/regnet/regnetx/RegNetX-12GF_dds_8gpu.yaml
index a3edf42..63d4ef9 100644
--- a/fastreid/modeling/backbones/regnet/regnetx/RegNetX-12GF_dds_8gpu.yaml
+++ b/fastreid/modeling/backbones/regnet/regnetx/RegNetX-12GF_dds_8gpu.yaml
@@ -13,7 +13,7 @@ OPTIM:
   MAX_EPOCH: 100
   MOMENTUM: 0.9
   WEIGHT_DECAY: 5e-5
-  WARMUP_EPOCHS: 5
+  WARMUP_ITERS: 5
 TRAIN:
   DATASET: imagenet
   IM_SIZE: 224
diff --git a/fastreid/modeling/backbones/regnet/regnetx/RegNetX-16GF_dds_8gpu.yaml b/fastreid/modeling/backbones/regnet/regnetx/RegNetX-16GF_dds_8gpu.yaml
index 0f94f9b..768763b 100644
--- a/fastreid/modeling/backbones/regnet/regnetx/RegNetX-16GF_dds_8gpu.yaml
+++ b/fastreid/modeling/backbones/regnet/regnetx/RegNetX-16GF_dds_8gpu.yaml
@@ -13,7 +13,7 @@ OPTIM:
   MAX_EPOCH: 100
   MOMENTUM: 0.9
   WEIGHT_DECAY: 5e-5
-  WARMUP_EPOCHS: 5
+  WARMUP_ITERS: 5
 TRAIN:
   DATASET: imagenet
   IM_SIZE: 224
diff --git a/fastreid/modeling/backbones/regnet/regnetx/RegNetX-200MF_dds_8gpu.yaml b/fastreid/modeling/backbones/regnet/regnetx/RegNetX-200MF_dds_8gpu.yaml
index c0b13b6..e41c967 100644
--- a/fastreid/modeling/backbones/regnet/regnetx/RegNetX-200MF_dds_8gpu.yaml
+++ b/fastreid/modeling/backbones/regnet/regnetx/RegNetX-200MF_dds_8gpu.yaml
@@ -13,7 +13,7 @@ OPTIM:
   MAX_EPOCH: 100
   MOMENTUM: 0.9
   WEIGHT_DECAY: 5e-5
-  WARMUP_EPOCHS: 5
+  WARMUP_ITERS: 5
 TRAIN:
   DATASET: imagenet
   IM_SIZE: 224
diff --git a/fastreid/modeling/backbones/regnet/regnetx/RegNetX-3.2GF_dds_8gpu.yaml b/fastreid/modeling/backbones/regnet/regnetx/RegNetX-3.2GF_dds_8gpu.yaml
index 594d533..f7af0d7 100644
--- a/fastreid/modeling/backbones/regnet/regnetx/RegNetX-3.2GF_dds_8gpu.yaml
+++ b/fastreid/modeling/backbones/regnet/regnetx/RegNetX-3.2GF_dds_8gpu.yaml
@@ -13,7 +13,7 @@ OPTIM:
   MAX_EPOCH: 100
   MOMENTUM: 0.9
   WEIGHT_DECAY: 5e-5
-  WARMUP_EPOCHS: 5
+  WARMUP_ITERS: 5
 TRAIN:
   DATASET: imagenet
   IM_SIZE: 224
diff --git a/fastreid/modeling/backbones/regnet/regnetx/RegNetX-32GF_dds_8gpu.yaml b/fastreid/modeling/backbones/regnet/regnetx/RegNetX-32GF_dds_8gpu.yaml
index c1d34b8..9dfb78b 100644
--- a/fastreid/modeling/backbones/regnet/regnetx/RegNetX-32GF_dds_8gpu.yaml
+++ b/fastreid/modeling/backbones/regnet/regnetx/RegNetX-32GF_dds_8gpu.yaml
@@ -13,7 +13,7 @@ OPTIM:
   MAX_EPOCH: 100
   MOMENTUM: 0.9
   WEIGHT_DECAY: 5e-5
-  WARMUP_EPOCHS: 5
+  WARMUP_ITERS: 5
 TRAIN:
   DATASET: imagenet
   IM_SIZE: 224
diff --git a/fastreid/modeling/backbones/regnet/regnetx/RegNetX-4.0GF_dds_8gpu.yaml b/fastreid/modeling/backbones/regnet/regnetx/RegNetX-4.0GF_dds_8gpu.yaml
index bd95453..d884ed6 100644
--- a/fastreid/modeling/backbones/regnet/regnetx/RegNetX-4.0GF_dds_8gpu.yaml
+++ b/fastreid/modeling/backbones/regnet/regnetx/RegNetX-4.0GF_dds_8gpu.yaml
@@ -13,7 +13,7 @@ OPTIM:
   MAX_EPOCH: 100
   MOMENTUM: 0.9
   WEIGHT_DECAY: 5e-5
-  WARMUP_EPOCHS: 5
+  WARMUP_ITERS: 5
 TRAIN:
   DATASET: imagenet
   IM_SIZE: 224
diff --git a/fastreid/modeling/backbones/regnet/regnetx/RegNetX-400MF_dds_8gpu.yaml b/fastreid/modeling/backbones/regnet/regnetx/RegNetX-400MF_dds_8gpu.yaml
index 7b887ad..1c14a16 100644
--- a/fastreid/modeling/backbones/regnet/regnetx/RegNetX-400MF_dds_8gpu.yaml
+++ b/fastreid/modeling/backbones/regnet/regnetx/RegNetX-400MF_dds_8gpu.yaml
@@ -13,7 +13,7 @@ OPTIM:
   MAX_EPOCH: 100
   MOMENTUM: 0.9
   WEIGHT_DECAY: 5e-5
-  WARMUP_EPOCHS: 5
+  WARMUP_ITERS: 5
 TRAIN:
   DATASET: imagenet
   IM_SIZE: 224
diff --git a/fastreid/modeling/backbones/regnet/regnetx/RegNetX-6.4GF_dds_8gpu.yaml b/fastreid/modeling/backbones/regnet/regnetx/RegNetX-6.4GF_dds_8gpu.yaml
index f256e64..48aeb83 100644
--- a/fastreid/modeling/backbones/regnet/regnetx/RegNetX-6.4GF_dds_8gpu.yaml
+++ b/fastreid/modeling/backbones/regnet/regnetx/RegNetX-6.4GF_dds_8gpu.yaml
@@ -13,7 +13,7 @@ OPTIM:
   MAX_EPOCH: 100
   MOMENTUM: 0.9
   WEIGHT_DECAY: 5e-5
-  WARMUP_EPOCHS: 5
+  WARMUP_ITERS: 5
 TRAIN:
   DATASET: imagenet
   IM_SIZE: 224
diff --git a/fastreid/modeling/backbones/regnet/regnetx/RegNetX-600MF_dds_8gpu.yaml b/fastreid/modeling/backbones/regnet/regnetx/RegNetX-600MF_dds_8gpu.yaml
index aca28aa..8ddf668 100644
--- a/fastreid/modeling/backbones/regnet/regnetx/RegNetX-600MF_dds_8gpu.yaml
+++ b/fastreid/modeling/backbones/regnet/regnetx/RegNetX-600MF_dds_8gpu.yaml
@@ -13,7 +13,7 @@ OPTIM:
   MAX_EPOCH: 100
   MOMENTUM: 0.9
   WEIGHT_DECAY: 5e-5
-  WARMUP_EPOCHS: 5
+  WARMUP_ITERS: 5
 TRAIN:
   DATASET: imagenet
   IM_SIZE: 224
diff --git a/fastreid/modeling/backbones/regnet/regnetx/RegNetX-8.0GF_dds_8gpu.yaml b/fastreid/modeling/backbones/regnet/regnetx/RegNetX-8.0GF_dds_8gpu.yaml
index a4141d6..b75beb0 100644
--- a/fastreid/modeling/backbones/regnet/regnetx/RegNetX-8.0GF_dds_8gpu.yaml
+++ b/fastreid/modeling/backbones/regnet/regnetx/RegNetX-8.0GF_dds_8gpu.yaml
@@ -13,7 +13,7 @@ OPTIM:
   MAX_EPOCH: 100
   MOMENTUM: 0.9
   WEIGHT_DECAY: 5e-5
-  WARMUP_EPOCHS: 5
+  WARMUP_ITERS: 5
 TRAIN:
   DATASET: imagenet
   IM_SIZE: 224
diff --git a/fastreid/modeling/backbones/regnet/regnetx/RegNetX-800MF_dds_8gpu.yaml b/fastreid/modeling/backbones/regnet/regnetx/RegNetX-800MF_dds_8gpu.yaml
index 8d2f6ae..8e76db7 100644
--- a/fastreid/modeling/backbones/regnet/regnetx/RegNetX-800MF_dds_8gpu.yaml
+++ b/fastreid/modeling/backbones/regnet/regnetx/RegNetX-800MF_dds_8gpu.yaml
@@ -13,7 +13,7 @@ OPTIM:
   MAX_EPOCH: 100
   MOMENTUM: 0.9
   WEIGHT_DECAY: 5e-5
-  WARMUP_EPOCHS: 5
+  WARMUP_ITERS: 5
 TRAIN:
   DATASET: imagenet
   IM_SIZE: 224
diff --git a/fastreid/modeling/backbones/regnet/regnety/RegNetY-1.6GF_dds_8gpu.yaml b/fastreid/modeling/backbones/regnet/regnety/RegNetY-1.6GF_dds_8gpu.yaml
index 2dc9f37..25c76e9 100644
--- a/fastreid/modeling/backbones/regnet/regnety/RegNetY-1.6GF_dds_8gpu.yaml
+++ b/fastreid/modeling/backbones/regnet/regnety/RegNetY-1.6GF_dds_8gpu.yaml
@@ -14,7 +14,7 @@ OPTIM:
   MAX_EPOCH: 100
   MOMENTUM: 0.9
   WEIGHT_DECAY: 5e-5
-  WARMUP_EPOCHS: 5
+  WARMUP_ITERS: 5
 TRAIN:
   DATASET: imagenet
   IM_SIZE: 224
diff --git a/fastreid/modeling/backbones/regnet/regnety/RegNetY-12GF_dds_8gpu.yaml b/fastreid/modeling/backbones/regnet/regnety/RegNetY-12GF_dds_8gpu.yaml
index 6d27d5d..5dea853 100644
--- a/fastreid/modeling/backbones/regnet/regnety/RegNetY-12GF_dds_8gpu.yaml
+++ b/fastreid/modeling/backbones/regnet/regnety/RegNetY-12GF_dds_8gpu.yaml
@@ -14,7 +14,7 @@ OPTIM:
   MAX_EPOCH: 100
   MOMENTUM: 0.9
   WEIGHT_DECAY: 5e-5
-  WARMUP_EPOCHS: 5
+  WARMUP_ITERS: 5
 TRAIN:
   DATASET: imagenet
   IM_SIZE: 224
diff --git a/fastreid/modeling/backbones/regnet/regnety/RegNetY-16GF_dds_8gpu.yaml b/fastreid/modeling/backbones/regnet/regnety/RegNetY-16GF_dds_8gpu.yaml
index 605d215..9d77d4c 100644
--- a/fastreid/modeling/backbones/regnet/regnety/RegNetY-16GF_dds_8gpu.yaml
+++ b/fastreid/modeling/backbones/regnet/regnety/RegNetY-16GF_dds_8gpu.yaml
@@ -14,7 +14,7 @@ OPTIM:
   MAX_EPOCH: 100
   MOMENTUM: 0.9
   WEIGHT_DECAY: 5e-5
-  WARMUP_EPOCHS: 5
+  WARMUP_ITERS: 5
 TRAIN:
   DATASET: imagenet
   IM_SIZE: 224
diff --git a/fastreid/modeling/backbones/regnet/regnety/RegNetY-200MF_dds_8gpu.yaml b/fastreid/modeling/backbones/regnet/regnety/RegNetY-200MF_dds_8gpu.yaml
index 300cc43..cfe006b 100644
--- a/fastreid/modeling/backbones/regnet/regnety/RegNetY-200MF_dds_8gpu.yaml
+++ b/fastreid/modeling/backbones/regnet/regnety/RegNetY-200MF_dds_8gpu.yaml
@@ -14,7 +14,6 @@ OPTIM:
   MAX_EPOCH: 100
   MOMENTUM: 0.9
   WEIGHT_DECAY: 5e-5
-  WARMUP_EPOCHS: 5
 TRAIN:
   DATASET: imagenet
   IM_SIZE: 224
diff --git a/fastreid/modeling/backbones/regnet/regnety/RegNetY-3.2GF_dds_8gpu.yaml b/fastreid/modeling/backbones/regnet/regnety/RegNetY-3.2GF_dds_8gpu.yaml
index 95f05ba..70f5f37 100644
--- a/fastreid/modeling/backbones/regnet/regnety/RegNetY-3.2GF_dds_8gpu.yaml
+++ b/fastreid/modeling/backbones/regnet/regnety/RegNetY-3.2GF_dds_8gpu.yaml
@@ -14,7 +14,7 @@ OPTIM:
   MAX_EPOCH: 100
   MOMENTUM: 0.9
   WEIGHT_DECAY: 5e-5
-  WARMUP_EPOCHS: 5
+  WARMUP_ITERS: 5
 TRAIN:
   DATASET: imagenet
   IM_SIZE: 224
diff --git a/fastreid/modeling/backbones/regnet/regnety/RegNetY-32GF_dds_8gpu.yaml b/fastreid/modeling/backbones/regnet/regnety/RegNetY-32GF_dds_8gpu.yaml
index 753d7a5..89e83e1 100644
--- a/fastreid/modeling/backbones/regnet/regnety/RegNetY-32GF_dds_8gpu.yaml
+++ b/fastreid/modeling/backbones/regnet/regnety/RegNetY-32GF_dds_8gpu.yaml
@@ -14,7 +14,7 @@ OPTIM:
   MAX_EPOCH: 100
   MOMENTUM: 0.9
   WEIGHT_DECAY: 5e-5
-  WARMUP_EPOCHS: 5
+  WARMUP_ITERS: 5
 TRAIN:
   DATASET: imagenet
   IM_SIZE: 224
diff --git a/fastreid/modeling/backbones/regnet/regnety/RegNetY-4.0GF_dds_8gpu.yaml b/fastreid/modeling/backbones/regnet/regnety/RegNetY-4.0GF_dds_8gpu.yaml
index 27895a9..4a1afab 100644
--- a/fastreid/modeling/backbones/regnet/regnety/RegNetY-4.0GF_dds_8gpu.yaml
+++ b/fastreid/modeling/backbones/regnet/regnety/RegNetY-4.0GF_dds_8gpu.yaml
@@ -14,7 +14,7 @@ OPTIM:
   MAX_EPOCH: 100
   MOMENTUM: 0.9
   WEIGHT_DECAY: 5e-5
-  WARMUP_EPOCHS: 5
+  WARMUP_ITERS: 5
 TRAIN:
   DATASET: imagenet
   IM_SIZE: 224
diff --git a/fastreid/modeling/backbones/regnet/regnety/RegNetY-400MF_dds_8gpu.yaml b/fastreid/modeling/backbones/regnet/regnety/RegNetY-400MF_dds_8gpu.yaml
index 1b1c31b..0cca887 100644
--- a/fastreid/modeling/backbones/regnet/regnety/RegNetY-400MF_dds_8gpu.yaml
+++ b/fastreid/modeling/backbones/regnet/regnety/RegNetY-400MF_dds_8gpu.yaml
@@ -14,7 +14,7 @@ OPTIM:
   MAX_EPOCH: 100
   MOMENTUM: 0.9
   WEIGHT_DECAY: 5e-5
-  WARMUP_EPOCHS: 5
+  WARMUP_ITERS: 5
 TRAIN:
   DATASET: imagenet
   IM_SIZE: 224
diff --git a/fastreid/modeling/backbones/regnet/regnety/RegNetY-6.4GF_dds_8gpu.yaml b/fastreid/modeling/backbones/regnet/regnety/RegNetY-6.4GF_dds_8gpu.yaml
index 74535c2..a93c4ae 100644
--- a/fastreid/modeling/backbones/regnet/regnety/RegNetY-6.4GF_dds_8gpu.yaml
+++ b/fastreid/modeling/backbones/regnet/regnety/RegNetY-6.4GF_dds_8gpu.yaml
@@ -14,7 +14,7 @@ OPTIM:
   MAX_EPOCH: 100
   MOMENTUM: 0.9
   WEIGHT_DECAY: 5e-5
-  WARMUP_EPOCHS: 5
+  WARMUP_ITERS: 5
 TRAIN:
   DATASET: imagenet
   IM_SIZE: 224
diff --git a/fastreid/modeling/backbones/regnet/regnety/RegNetY-600MF_dds_8gpu.yaml b/fastreid/modeling/backbones/regnet/regnety/RegNetY-600MF_dds_8gpu.yaml
index 661e1a9..adc490d 100644
--- a/fastreid/modeling/backbones/regnet/regnety/RegNetY-600MF_dds_8gpu.yaml
+++ b/fastreid/modeling/backbones/regnet/regnety/RegNetY-600MF_dds_8gpu.yaml
@@ -14,7 +14,7 @@ OPTIM:
   MAX_EPOCH: 100
   MOMENTUM: 0.9
   WEIGHT_DECAY: 5e-5
-  WARMUP_EPOCHS: 5
+  WARMUP_ITERS: 5
 TRAIN:
   DATASET: imagenet
   IM_SIZE: 224
diff --git a/fastreid/modeling/backbones/regnet/regnety/RegNetY-8.0GF_dds_8gpu.yaml b/fastreid/modeling/backbones/regnet/regnety/RegNetY-8.0GF_dds_8gpu.yaml
index 792147a..78377ed 100644
--- a/fastreid/modeling/backbones/regnet/regnety/RegNetY-8.0GF_dds_8gpu.yaml
+++ b/fastreid/modeling/backbones/regnet/regnety/RegNetY-8.0GF_dds_8gpu.yaml
@@ -14,7 +14,7 @@ OPTIM:
   MAX_EPOCH: 100
   MOMENTUM: 0.9
   WEIGHT_DECAY: 5e-5
-  WARMUP_EPOCHS: 5
+  WARMUP_ITERS: 5
 TRAIN:
   DATASET: imagenet
   IM_SIZE: 224
diff --git a/fastreid/modeling/backbones/regnet/regnety/RegNetY-800MF_dds_8gpu.yaml b/fastreid/modeling/backbones/regnet/regnety/RegNetY-800MF_dds_8gpu.yaml
index 6e52823..c6db355 100644
--- a/fastreid/modeling/backbones/regnet/regnety/RegNetY-800MF_dds_8gpu.yaml
+++ b/fastreid/modeling/backbones/regnet/regnety/RegNetY-800MF_dds_8gpu.yaml
@@ -14,7 +14,7 @@ OPTIM:
   MAX_EPOCH: 100
   MOMENTUM: 0.9
   WEIGHT_DECAY: 5e-5
-  WARMUP_EPOCHS: 5
+  WARMUP_ITERS: 5
 TRAIN:
   DATASET: imagenet
   IM_SIZE: 224
diff --git a/fastreid/modeling/meta_arch/distiller.py b/fastreid/modeling/meta_arch/distiller.py
index 1ee3008..5405721 100644
--- a/fastreid/modeling/meta_arch/distiller.py
+++ b/fastreid/modeling/meta_arch/distiller.py
@@ -64,7 +64,7 @@ class Distiller(Baseline):
             return super(Distiller, self).forward(batched_inputs)
 
     def losses(self, s_outputs, t_outputs, gt_labels):
-        r"""
+        """
         Compute loss from modeling's outputs, the loss function input arguments
         must be the same as the outputs of the model forwarding.
         """
diff --git a/fastreid/solver/build.py b/fastreid/solver/build.py
index 524a72e..806aa7e 100644
--- a/fastreid/solver/build.py
+++ b/fastreid/solver/build.py
@@ -4,6 +4,8 @@
 @contact: sherlockliao01@gmail.com
 """
 
+import math
+
 from . import lr_scheduler
 from . import optim
 
@@ -34,11 +36,9 @@ def build_optimizer(cfg, model):
     return opt_fns
 
 
-def build_lr_scheduler(cfg, optimizer):
-    cfg = cfg.clone()
-    cfg.defrost()
-    cfg.SOLVER.MAX_EPOCH = cfg.SOLVER.MAX_EPOCH - max(
-        cfg.SOLVER.WARMUP_EPOCHS + 1, cfg.SOLVER.DELAY_EPOCHS)
+def build_lr_scheduler(cfg, optimizer, iters_per_epoch):
+    max_epoch = cfg.SOLVER.MAX_EPOCH - max(
+        math.ceil(cfg.SOLVER.WARMUP_ITERS / iters_per_epoch), cfg.SOLVER.DELAY_EPOCHS)
 
     scheduler_dict = {}
 
@@ -52,7 +52,7 @@ def build_lr_scheduler(cfg, optimizer):
         "CosineAnnealingLR": {
             "optimizer": optimizer,
             # cosine annealing lr scheduler options
-            "T_max": cfg.SOLVER.MAX_EPOCH,
+            "T_max": max_epoch,
             "eta_min": cfg.SOLVER.ETA_MIN_LR,
         },
 
@@ -61,13 +61,13 @@ def build_lr_scheduler(cfg, optimizer):
     scheduler_dict["lr_sched"] = getattr(lr_scheduler, cfg.SOLVER.SCHED)(
         **scheduler_args[cfg.SOLVER.SCHED])
 
-    if cfg.SOLVER.WARMUP_EPOCHS > 0:
+    if cfg.SOLVER.WARMUP_ITERS > 0:
         warmup_args = {
             "optimizer": optimizer,
 
             # warmup options
             "warmup_factor": cfg.SOLVER.WARMUP_FACTOR,
-            "warmup_epochs": cfg.SOLVER.WARMUP_EPOCHS,
+            "warmup_iters": cfg.SOLVER.WARMUP_ITERS,
             "warmup_method": cfg.SOLVER.WARMUP_METHOD,
         }
         scheduler_dict["warmup_sched"] = lr_scheduler.WarmupLR(**warmup_args)
diff --git a/fastreid/solver/lr_scheduler.py b/fastreid/solver/lr_scheduler.py
index 711c1b5..17ce2b9 100644
--- a/fastreid/solver/lr_scheduler.py
+++ b/fastreid/solver/lr_scheduler.py
@@ -15,18 +15,18 @@ class WarmupLR(torch.optim.lr_scheduler._LRScheduler):
             self,
             optimizer: torch.optim.Optimizer,
             warmup_factor: float = 0.1,
-            warmup_epochs: int = 10,
+            warmup_iters: int = 1000,
             warmup_method: str = "linear",
             last_epoch: int = -1,
     ):
         self.warmup_factor = warmup_factor
-        self.warmup_epochs = warmup_epochs
+        self.warmup_iters = warmup_iters
         self.warmup_method = warmup_method
         super().__init__(optimizer, last_epoch)
 
     def get_lr(self) -> List[float]:
         warmup_factor = _get_warmup_factor_at_epoch(
-            self.warmup_method, self.last_epoch, self.warmup_epochs, self.warmup_factor
+            self.warmup_method, self.last_epoch, self.warmup_iters, self.warmup_factor
         )
         return [
             base_lr * warmup_factor for base_lr in self.base_lrs
@@ -38,29 +38,29 @@ class WarmupLR(torch.optim.lr_scheduler._LRScheduler):
 
 
 def _get_warmup_factor_at_epoch(
-        method: str, epoch: int, warmup_epochs: int, warmup_factor: float
+        method: str, iter: int, warmup_iters: int, warmup_factor: float
 ) -> float:
     """
     Return the learning rate warmup factor at a specific iteration.
     See https://arxiv.org/abs/1706.02677 for more details.
     Args:
         method (str): warmup method; either "constant" or "linear".
-        epoch (int): epoch at which to calculate the warmup factor.
-        warmup_epochs (int): the number of warmup epochs.
+        iter (int): iteration at which to calculate the warmup factor.
+        warmup_iters (int): the number of warmup iterations.
         warmup_factor (float): the base warmup factor (the meaning changes according
             to the method used).
     Returns:
         float: the effective warmup factor at the given iteration.
     """
-    if epoch >= warmup_epochs:
+    if iter >= warmup_iters:
         return 1.0
 
     if method == "constant":
         return warmup_factor
     elif method == "linear":
-        alpha = epoch / warmup_epochs
+        alpha = iter / warmup_iters
         return warmup_factor * (1 - alpha) + alpha
     elif method == "exp":
-        return warmup_factor ** (1 - epoch / warmup_epochs)
+        return warmup_factor ** (1 - iter / warmup_iters)
     else:
         raise ValueError("Unknown warmup method: {}".format(method))
diff --git a/fastreid/utils/checkpoint.py b/fastreid/utils/checkpoint.py
index 1f142d0..b31c4c0 100644
--- a/fastreid/utils/checkpoint.py
+++ b/fastreid/utils/checkpoint.py
@@ -73,6 +73,7 @@ class Checkpointer(object):
     def save(self, name: str, **kwargs: Dict[str, str]):
         """
         Dump model and checkpointables to a file.
+
         Args:
             name (str): name of the file.
             kwargs (dict): extra arbitrary data to save.
@@ -98,6 +99,7 @@ class Checkpointer(object):
         """
         Load from the given checkpoint. When path points to network file, this
         function has to be called on all ranks.
+
         Args:
             path (str): path or url to the checkpoint. If empty, will not load
                 anything.
@@ -176,6 +178,7 @@ class Checkpointer(object):
         If `resume` is True, this method attempts to resume from the last
         checkpoint, if exists. Otherwise, load checkpoint from the given path.
         This is useful when restarting an interrupted training job.
+
         Args:
             path (str): path to the checkpoint.
             resume (bool): if True, resume from the last checkpoint if it exists.
@@ -191,6 +194,7 @@ class Checkpointer(object):
     def tag_last_checkpoint(self, last_filename_basename: str):
         """
         Tag the last checkpoint.
+
         Args:
             last_filename_basename (str): the basename of the last filename.
         """
@@ -202,6 +206,7 @@ class Checkpointer(object):
         """
         Load a checkpoint file. Can be overwritten by subclasses to support
         different formats.
+
         Args:
             f (str): a locally mounted file path.
         Returns:
@@ -214,6 +219,7 @@ class Checkpointer(object):
     def _load_model(self, checkpoint: Any):
         """
         Load weights from a checkpoint.
+
         Args:
             checkpoint (Any): checkpoint contains the weights.
         """
@@ -269,6 +275,7 @@ class Checkpointer(object):
     def _convert_ndarray_to_tensor(self, state_dict: dict):
         """
         In-place convert all numpy arrays in the state_dict to torch tensor.
+
         Args:
             state_dict (dict): a state-dict to be loaded to the model.
         """
@@ -313,6 +320,7 @@ class PeriodicCheckpointer:
     def step(self, epoch: int, **kwargs: Any):
         """
         Perform the appropriate action at the given iteration.
+
         Args:
             epoch (int): the current epoch, ranged in [0, max_epoch-1].
             kwargs (Any): extra data to save, same as in
@@ -342,6 +350,7 @@ class PeriodicCheckpointer:
         """
         Same argument as :meth:`Checkpointer.save`.
         Use this method to manually save checkpoints outside the schedule.
+
         Args:
             name (str): file name.
             kwargs (Any): extra data to save, same as in
@@ -374,6 +383,7 @@ def get_missing_parameters_message(keys: List[str]) -> str:
     """
     Get a logging-friendly message to report parameter names (keys) that are in
     the model but not found in a checkpoint.
+
     Args:
         keys (list[str]): List of keys that were not found in the checkpoint.
     Returns:
@@ -391,6 +401,7 @@ def get_unexpected_parameters_message(keys: List[str]) -> str:
     """
     Get a logging-friendly message to report parameter names (keys) that are in
     the checkpoint but not found in the model.
+
     Args:
         keys (list[str]): List of keys that were not found in the model.
     Returns:
@@ -407,6 +418,7 @@ def get_unexpected_parameters_message(keys: List[str]) -> str:
 def _strip_prefix_if_present(state_dict: Dict[str, Any], prefix: str) -> None:
     """
     Strip the prefix in metadata, if any.
+
     Args:
         state_dict (OrderedDict): a state-dict to be loaded to the model.
         prefix (str): prefix.
@@ -441,6 +453,7 @@ def _group_checkpoint_keys(keys: List[str]) -> Dict[str, List[str]]:
     """
     Group keys based on common prefixes. A prefix is the string up to the final
     "." in each key.
+
     Args:
         keys (list[str]): list of parameter names, i.e. keys in the model
             checkpoint dict.
@@ -461,6 +474,7 @@ def _group_checkpoint_keys(keys: List[str]) -> Dict[str, List[str]]:
 def _group_to_str(group: List[str]) -> str:
     """
     Format a group of parameter name suffixes into a loggable string.
+
     Args:
         group (list[str]): list of parameter name suffixes.
     Returns:
diff --git a/projects/FastAttr/configs/Base-attribute.yml b/projects/FastAttr/configs/Base-attribute.yml
index 33eb452..f9ee16e 100644
--- a/projects/FastAttr/configs/Base-attribute.yml
+++ b/projects/FastAttr/configs/Base-attribute.yml
@@ -50,7 +50,7 @@ SOLVER:
   STEPS: [ 15, 20, 25 ]
 
   WARMUP_FACTOR: 0.1
-  WARMUP_EPOCHS: 0
+  WARMUP_ITERS: 1000
 
   CHECKPOINT_PERIOD: 10
 
diff --git a/projects/FastCls/configs/base-cls.yaml b/projects/FastCls/configs/base-cls.yaml
index d02747a..c355703 100644
--- a/projects/FastCls/configs/base-cls.yaml
+++ b/projects/FastCls/configs/base-cls.yaml
@@ -60,7 +60,7 @@ SOLVER:
   ETA_MIN_LR: 0.00003
 
   WARMUP_FACTOR: 0.1
-  WARMUP_EPOCHS: 10
+  WARMUP_ITERS: 2000
 
   CHECKPOINT_PERIOD: 10
 
diff --git a/projects/FastFace/configs/face_base.yml b/projects/FastFace/configs/face_base.yml
index c6e94d5..8a5f5f8 100644
--- a/projects/FastFace/configs/face_base.yml
+++ b/projects/FastFace/configs/face_base.yml
@@ -59,7 +59,7 @@ SOLVER:
   IMS_PER_BATCH: 512
 
   WARMUP_FACTOR: 0.1
-  WARMUP_EPOCHS: 1
+  WARMUP_ITERS: 5000
 
   CHECKPOINT_PERIOD: 2
 
diff --git a/projects/FastRetri/configs/base-image_retri.yml b/projects/FastRetri/configs/base-image_retri.yml
index 85df78d..d1bfa87 100644
--- a/projects/FastRetri/configs/base-image_retri.yml
+++ b/projects/FastRetri/configs/base-image_retri.yml
@@ -61,7 +61,7 @@ SOLVER:
   ETA_MIN_LR: 0.00003
 
   WARMUP_FACTOR: 0.1
-  WARMUP_EPOCHS: 10
+  WARMUP_ITERS: 1000
 
   CHECKPOINT_PERIOD: 10
 
diff --git a/projects/FastTune/configs/search_trial.yml b/projects/FastTune/configs/search_trial.yml
index 1220dfe..18ae173 100644
--- a/projects/FastTune/configs/search_trial.yml
+++ b/projects/FastTune/configs/search_trial.yml
@@ -71,7 +71,7 @@ SOLVER:
   FREEZE_ITERS: 500
 
   WARMUP_FACTOR: 0.1
-  WARMUP_EPOCHS: 5
+  WARMUP_ITERS: 1000
 
   CHECKPOINT_PERIOD: 100
 
diff --git a/projects/PartialReID/configs/partial_market.yml b/projects/PartialReID/configs/partial_market.yml
index b78706e..d43a4e5 100644
--- a/projects/PartialReID/configs/partial_market.yml
+++ b/projects/PartialReID/configs/partial_market.yml
@@ -57,7 +57,7 @@ SOLVER:
   GAMMA: 0.1
 
   WARMUP_FACTOR: 0.01
-  WARMUP_ITERS: 5
+  WARMUP_ITERS: 1000
 
   CHECKPOINT_PERIOD: 10
 
diff --git a/docs/requirements b/requirements.txt
similarity index 95%
rename from docs/requirements
rename to requirements.txt
index f9f6594..a53f22c 100644
--- a/docs/requirements
+++ b/requirements.txt
@@ -17,4 +17,4 @@ termcolor
 scikit-learn
 tabulate
 gdown
-faiss-cpu
\ No newline at end of file
+faiss-gpu
\ No newline at end of file
diff --git a/tools/plain_train_net.py b/tools/plain_train_net.py
index 2d62360..53ce9c3 100644
--- a/tools/plain_train_net.py
+++ b/tools/plain_train_net.py
@@ -80,7 +80,8 @@ def do_train(cfg, model, resume=False):
 
     optimizer_ckpt = dict(optimizer=optimizer)
 
-    scheduler = build_lr_scheduler(cfg, optimizer)
+    iters_per_epoch = len(data_loader.dataset) // cfg.SOLVER.IMS_PER_BATCH
+    scheduler = build_lr_scheduler(cfg, optimizer, iters_per_epoch)
 
     checkpointer = Checkpointer(
         model,
@@ -90,8 +91,6 @@ def do_train(cfg, model, resume=False):
         **scheduler
     )
 
-    iters_per_epoch = len(data_loader.dataset) // cfg.SOLVER.IMS_PER_BATCH
-
     start_epoch = (
             checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=resume).get("epoch", -1) + 1
     )
@@ -99,7 +98,7 @@ def do_train(cfg, model, resume=False):
 
     max_epoch = cfg.SOLVER.MAX_EPOCH
     max_iter = max_epoch * iters_per_epoch
-    warmup_epochs = cfg.SOLVER.WARMUP_EPOCHS
+    warmup_iters = cfg.SOLVER.WARMUP_ITERS
     delay_epochs = cfg.SOLVER.DELAY_EPOCHS
 
     periodic_checkpointer = PeriodicCheckpointer(checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD, max_epoch)
@@ -146,13 +145,14 @@ def do_train(cfg, model, resume=False):
 
                 iteration += 1
 
+                if iteration <= warmup_iters:
+                    scheduler["warmup_sched"].step()
+
             # Write metrics after each epoch
             for writer in writers:
                 writer.write()
 
-            if (epoch + 1) <= warmup_epochs:
-                scheduler["warmup_sched"].step()
-            elif (epoch + 1) >= delay_epochs:
+            if iteration > warmup_iters and (epoch + 1) >= delay_epochs:
                 scheduler["lr_sched"].step()
 
             if (