diff --git a/fastreid/config/defaults.py b/fastreid/config/defaults.py
index 8e95828..06253e1 100644
--- a/fastreid/config/defaults.py
+++ b/fastreid/config/defaults.py
@@ -43,7 +43,8 @@ _C.MODEL.BACKBONE.PRETRAIN_PATH = ''
 # REID HEADS options
 # ---------------------------------------------------------------------------- #
 _C.MODEL.HEADS = CN()
-_C.MODEL.HEADS.NAME = "BNneckLinear"
+_C.MODEL.HEADS.NAME = "StandardHead"
+_C.MODEL.HEADS.POOL_LAYER = 'avgpool'
 _C.MODEL.HEADS.NUM_CLASSES = 751
 
 # ---------------------------------------------------------------------------- #
@@ -95,7 +96,7 @@ _C.INPUT.BRIGHTNESS = 0.4
 _C.INPUT.CONTRAST = 0.4
 # Random erasing
 _C.INPUT.RE = CN()
-_C.INPUT.RE.ENABLED = True
+_C.INPUT.RE.ENABLED = False
 _C.INPUT.RE.PROB = 0.5
 _C.INPUT.RE.MEAN = [0.596*255, 0.558*255, 0.497*255]
 # Cutout
@@ -103,7 +104,7 @@ _C.INPUT.CUTOUT = CN()
 _C.INPUT.CUTOUT.ENABLED = False
 _C.INPUT.CUTOUT.PROB = 0.5
 _C.INPUT.CUTOUT.SIZE = 64
-_C.INPUT.CUTOUT.MEAN = [0, 0, 0]
+_C.INPUT.CUTOUT.MEAN = [0.485*255, 0.456*255, 0.406*255]
 
 # -----------------------------------------------------------------------------
 # Dataset
@@ -129,7 +130,7 @@ _C.DATALOADER.NUM_WORKERS = 8
 # ---------------------------------------------------------------------------- #
 _C.SOLVER = CN()
 
-_C.SOLVER.OPT = "adam"
+_C.SOLVER.OPT = "Adam"
 
 _C.SOLVER.MAX_ITER = 40000
 
@@ -141,9 +142,15 @@ _C.SOLVER.MOMENTUM = 0.9
 _C.SOLVER.WEIGHT_DECAY = 0.0005
 _C.SOLVER.WEIGHT_DECAY_BIAS = 0.
 
+_C.SOLVER.SCHED = "warmup"
+# warmup config
 _C.SOLVER.GAMMA = 0.1
 _C.SOLVER.STEPS = (30, 55)
 
+# cosine annealing
+_C.SOLVER.DELAY_ITERS = 100
+_C.SOLVER.COS_ANNEAL_ITERS = 100
+
 _C.SOLVER.WARMUP_FACTOR = 0.1
 _C.SOLVER.WARMUP_ITERS = 10
 _C.SOLVER.WARMUP_METHOD = "linear"
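
For orientation, a hedged sketch of how the new SOLVER.SCHED / DELAY_ITERS / COS_ANNEAL_ITERS keys could be consumed. The real build_lr_scheduler in fastreid.solver is not part of this diff, so the dispatch below is purely illustrative.

import math
import torch

def build_lr_scheduler_sketch(cfg, optimizer):
    # Hypothetical dispatch; fastreid's actual build_lr_scheduler is not shown in this diff.
    if cfg.SOLVER.SCHED == "warmup":
        # step decay at SOLVER.STEPS scaled by SOLVER.GAMMA (warmup factor omitted here)
        return torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=list(cfg.SOLVER.STEPS), gamma=cfg.SOLVER.GAMMA)
    # otherwise: hold the base LR for SOLVER.DELAY_ITERS iterations, then
    # cosine-anneal it over the next SOLVER.COS_ANNEAL_ITERS iterations
    def lr_factor(it):
        if it < cfg.SOLVER.DELAY_ITERS:
            return 1.0
        t = min(it - cfg.SOLVER.DELAY_ITERS, cfg.SOLVER.COS_ANNEAL_ITERS)
        return 0.5 * (1.0 + math.cos(math.pi * t / cfg.SOLVER.COS_ANNEAL_ITERS))
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_factor)
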
diff --git a/fastreid/data/transforms/build.py b/fastreid/data/transforms/build.py
index 730c519..6ce1dac 100644
--- a/fastreid/data/transforms/build.py
+++ b/fastreid/data/transforms/build.py
@@ -25,6 +25,11 @@ def build_transforms(cfg, is_train=True):
         do_re = cfg.INPUT.RE.ENABLED
         re_prob = cfg.INPUT.RE.PROB
         re_mean = cfg.INPUT.RE.MEAN
+        # cutout
+        do_cutout = cfg.INPUT.CUTOUT.ENABLED
+        cutout_prob = cfg.INPUT.CUTOUT.PROB
+        cutout_size = cfg.INPUT.CUTOUT.SIZE
+        cutout_mean = cfg.INPUT.CUTOUT.MEAN
         res.append(T.Resize(size_train, interpolation=3))
         if do_flip:
             res.append(T.RandomHorizontalFlip(p=flip_prob))
@@ -33,9 +38,9 @@ def build_transforms(cfg, is_train=True):
                         T.RandomCrop(size_train)])
         if do_re:
             res.append(RandomErasing(probability=re_prob, mean=re_mean))
-        # if cfg.INPUT.CUTOUT.DO:
-        #     res.append(Cutout(probability=cfg.INPUT.CUTOUT.PROB, size=cfg.INPUT.CUTOUT.SIZE,
-        #                       mean=cfg.INPUT.CUTOUT.MEAN))
+        if do_cutout:
+            res.append(Cutout(probability=cutout_prob, size=cutout_size,
+                              mean=cutout_mean))
     else:
         size_test = cfg.INPUT.SIZE_TEST
         res.append(T.Resize(size_test, interpolation=3))
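
A usage sketch for the newly wired Cutout branch. The get_cfg import path is an assumption; only defaults.py and this build.py appear in the diff.

# Sketch: enable Cutout through the new config keys and rebuild the train pipeline.
# (get_cfg and the import paths are assumptions, not shown in this diff.)
from fastreid.config import get_cfg
from fastreid.data.transforms.build import build_transforms

cfg = get_cfg()
cfg.INPUT.CUTOUT.ENABLED = True
cfg.INPUT.CUTOUT.SIZE = 64
train_transforms = build_transforms(cfg, is_train=True)  # Cutout is appended after RandomErasing
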
diff --git a/fastreid/data/transforms/transforms.py b/fastreid/data/transforms/transforms.py
index c413800..1dac971 100644
--- a/fastreid/data/transforms/transforms.py
+++ b/fastreid/data/transforms/transforms.py
@@ -93,14 +93,13 @@ class Cutout(object):
         self.size = size
 
     def __call__(self, img):
-        img = np.asarray(img, dtype=np.uint8).copy()
+        img = np.asarray(img, dtype=np.float32).copy()
         if random.uniform(0, 1) > self.probability:
             return img
 
         h = self.size
         w = self.size
         for attempt in range(100):
-            area = img.shape[0] * img.shape[1]
             if w < img.shape[1] and h < img.shape[0]:
                 x1 = random.randint(0, img.shape[0] - h)
                 y1 = random.randint(0, img.shape[1] - w)
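
The Cutout transform now works on a float32 copy of the image; in plain NumPy the operation reduces to overwriting one random size x size window with the per-channel mean, roughly:

# Standalone sketch of the Cutout operation (mirrors the class above,
# minus the probability check and retry loop).
import random
import numpy as np

def cutout(img, size, mean):
    img = np.asarray(img, dtype=np.float32).copy()
    h, w = img.shape[:2]
    if size < h and size < w:
        x1 = random.randint(0, h - size)
        y1 = random.randint(0, w - size)
        img[x1:x1 + size, y1:y1 + size, :] = mean
    return img

patched = cutout(np.zeros((256, 128, 3)), size=64, mean=[124.0, 116.0, 104.0])
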
diff --git a/fastreid/engine/defaults.py b/fastreid/engine/defaults.py
index df91998..d0bd29e 100644
--- a/fastreid/engine/defaults.py
+++ b/fastreid/engine/defaults.py
@@ -17,10 +17,11 @@ import numpy as np
 import torch
 from torch.nn import DataParallel
 
+from . import hooks
+from .train_loop import SimpleTrainer
 from ..data import build_reid_test_loader, build_reid_train_loader
 from ..evaluation import (DatasetEvaluator, ReidEvaluator,
                           inference_on_dataset, print_csv_format)
-from ..modeling.losses import build_criterion
 from ..modeling.meta_arch import build_model
 from ..solver import build_lr_scheduler, build_optimizer
 from ..utils import comm
@@ -28,8 +29,6 @@ from ..utils.checkpoint import Checkpointer
 from ..utils.events import CommonMetricPrinter, JSONWriter, TensorboardXWriter
 from ..utils.file_io import PathManager
 from ..utils.logger import setup_logger
-from . import hooks
-from .train_loop import SimpleTrainer
 
 __all__ = ["default_argument_parser", "default_setup", "DefaultPredictor", "DefaultTrainer"]
 
@@ -198,19 +197,18 @@ class DefaultTrainer(SimpleTrainer):
         Args:
             cfg (CfgNode):
         """
-        logger = logging.getLogger("fastreid."+__name__)
+        logger = logging.getLogger("fastreid." + __name__)
         if not logger.isEnabledFor(logging.INFO):  # setup_logger is not called for d2
             setup_logger()
         # Assume these objects must be constructed in this order.
         model = self.build_model(cfg)
         optimizer = self.build_optimizer(cfg, model)
         data_loader = self.build_train_loader(cfg)
-        criterion = self.build_criterion(cfg)
 
         # For training, wrap with DP. But don't need this for inference.
         model = DataParallel(model)
         model = model.cuda()
-        super().__init__(model, data_loader, optimizer, criterion)
+        super().__init__(model, data_loader, optimizer)
 
         self.scheduler = self.build_lr_scheduler(cfg, optimizer)
         # Assume no other objects need to be checkpointed.
@@ -338,10 +336,6 @@ class DefaultTrainer(SimpleTrainer):
         # logger.info("Model:\n{}".format(model))
         return model
 
-    @classmethod
-    def build_criterion(cls, cfg):
-        return build_criterion(cfg)
-
     @classmethod
     def build_optimizer(cls, cfg, model):
         """
diff --git a/fastreid/engine/train_loop.py b/fastreid/engine/train_loop.py
index 510b80d..cac96e5 100644
--- a/fastreid/engine/train_loop.py
+++ b/fastreid/engine/train_loop.py
@@ -160,7 +160,7 @@ class SimpleTrainer(TrainerBase):
     or write your own training loop.
     """
 
-    def __init__(self, model, data_loader, optimizer, criterion):
+    def __init__(self, model, data_loader, optimizer):
         """
         Args:
             model: a torch Module. Takes a data from data_loader and returns a
@@ -181,7 +181,6 @@ class SimpleTrainer(TrainerBase):
         self.model = model
         self.data_loader = data_loader
         self.optimizer = optimizer
-        self.criterion = criterion
 
     def run_step(self):
         """
@@ -199,7 +198,7 @@ class SimpleTrainer(TrainerBase):
         If your want to do something with the heads, you can wrap the model.
         """
         outputs = self.model(data)
-        loss_dict = self.criterion(*outputs)
+        loss_dict = self.model.module.losses(outputs)
         losses = sum(loss for loss in loss_dict.values())
         self._detect_anomaly(losses, loss_dict)
 
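
run_step now asks the model itself for its losses via model.module.losses(outputs) instead of a standalone criterion. The Baseline changes are not included in this diff, so the snippet below is only a toy illustration of the interface the trainer expects.

# Toy meta-arch illustrating the contract: forward() returns whatever losses()
# needs, and losses() returns a dict of named scalar losses that run_step sums.
import torch
import torch.nn.functional as F
from torch import nn

class ToyMetaArch(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(32, 64), nn.ReLU())
        self.classifier = nn.Linear(64, num_classes)

    def forward(self, data):
        feat = self.backbone(data["images"])
        return self.classifier(feat), feat, data["targets"]

    def losses(self, outputs):
        logits, feat, targets = outputs
        return {"loss_cls": F.cross_entropy(logits, targets)}

model = nn.DataParallel(ToyMetaArch())
outputs = model({"images": torch.randn(8, 32), "targets": torch.randint(0, 10, (8,))})
loss_dict = model.module.losses(outputs)
losses = sum(loss for loss in loss_dict.values())
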
diff --git a/fastreid/evaluation/rank.py b/fastreid/evaluation/rank.py
index 63dc77d..c92a0c6 100644
--- a/fastreid/evaluation/rank.py
+++ b/fastreid/evaluation/rank.py
@@ -101,12 +101,10 @@ def eval_market1501(distmat, q_pids, g_pids, q_camids, g_camids, max_rank):
 
     if num_g < max_rank:
         max_rank = num_g
-        print(
-            'Note: number of gallery samples is quite small, got {}'.
-                format(num_g)
-        )
+        print('Note: number of gallery samples is quite small, got {}'.format(num_g))
 
     indices = np.argsort(distmat, axis=1)
+
     matches = (g_pids[indices] == q_pids[:, np.newaxis]).astype(np.int32)
 
     # compute cmc curve for each query
diff --git a/fastreid/evaluation/rank_cylib/rank_cy.pyx b/fastreid/evaluation/rank_cylib/rank_cy.pyx
index bf568d6..dd34a12 100644
--- a/fastreid/evaluation/rank_cylib/rank_cy.pyx
+++ b/fastreid/evaluation/rank_cylib/rank_cy.pyx
@@ -163,6 +163,7 @@ cpdef eval_market1501_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids,
 
         float[:,:] all_cmc = np.zeros((num_q, max_rank), dtype=np.float32)
         float[:] all_AP = np.zeros(num_q, dtype=np.float32)
+        float[:] all_INP = np.zeros(num_q, dtype=np.float32)
         float num_valid_q = 0. # number of valid query
 
         long q_idx, q_pid, q_camid, g_idx
@@ -171,6 +172,8 @@ cpdef eval_market1501_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids,
 
         float[:] raw_cmc = np.zeros(num_g, dtype=np.float32) # binary vector, positions with value 1 are correct matches
         float[:] cmc = np.zeros(num_g, dtype=np.float32)
+        long max_pos_idx = 0
+        float inp
         long num_g_real, rank_idx
         unsigned long meet_condition
 
@@ -183,16 +186,17 @@ cpdef eval_market1501_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids,
         q_pid = q_pids[q_idx]
         q_camid = q_camids[q_idx]
 
-        # remove gallery samples that have the same pid and camid with query
         for g_idx in range(num_g):
             order[g_idx] = indices[q_idx, g_idx]
         num_g_real = 0
         meet_condition = 0
 
+        # remove gallery samples that have the same pid and camid as the query
         for g_idx in range(num_g):
             if (g_pids[order[g_idx]] != q_pid) or (g_camids[order[g_idx]] != q_camid):
                 raw_cmc[num_g_real] = matches[q_idx][g_idx]
                 num_g_real += 1
+                # this condition is true if the query appears in the gallery
                 if matches[q_idx][g_idx] > 1e-31:
                     meet_condition = 1
 
@@ -202,6 +206,15 @@ cpdef eval_market1501_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids,
 
         # compute cmc
         function_cumsum(raw_cmc, cmc, num_g_real)
+        # compute mean inverse negative penalty
+        # reference : https://github.com/mangye16/ReID-Survey/blob/master/utils/reid_metric.py
+        max_pos_idx = 0
+        for g_idx in range(num_g_real):
+            if (raw_cmc[g_idx] == 1) and (g_idx > max_pos_idx):
+                max_pos_idx = g_idx
+        inp = cmc[max_pos_idx] / (max_pos_idx + 1.0)
+        all_INP[q_idx] = inp
+
         for g_idx in range(num_g_real):
             if cmc[g_idx] > 1:
                 cmc[g_idx] = 1
@@ -230,11 +243,14 @@ cpdef eval_market1501_cy(float[:,:] distmat, long[:] q_pids, long[:]g_pids,
         avg_cmc[rank_idx] /= num_valid_q
 
     cdef float mAP = 0
+    cdef float mINP = 0
     for q_idx in range(num_q):
         mAP += all_AP[q_idx]
+        mINP += all_INP[q_idx]
     mAP /= num_valid_q
+    mINP /= num_valid_q
 
-    return np.asarray(avg_cmc).astype(np.float32), mAP
+    return np.asarray(avg_cmc).astype(np.float32), mAP, mINP
 
 
 # Compute the cumulative sum
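
For reference, the INP term added above in plain NumPy (per query): take the cumulative match count at the position of the hardest, i.e. last, correct gallery match and divide by that position + 1. Queries without any correct match are skipped earlier in the Cython code, so at least one 1 is assumed here.

import numpy as np

def inp_for_query(raw_cmc):
    # raw_cmc: binary vector over kept gallery samples, 1 = correct match
    cmc = raw_cmc.cumsum()
    max_pos_idx = np.where(raw_cmc == 1)[0].max()
    return cmc[max_pos_idx] / (max_pos_idx + 1.0)

print(inp_for_query(np.array([1, 0, 1, 0])))  # hardest match at index 2 -> 2/3
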
diff --git a/fastreid/evaluation/rank_cylib/test_cython.py b/fastreid/evaluation/rank_cylib/test_cython.py
index e7f46a6..b6a1213 100644
--- a/fastreid/evaluation/rank_cylib/test_cython.py
+++ b/fastreid/evaluation/rank_cylib/test_cython.py
@@ -33,36 +33,37 @@ q_camids = np.random.randint(0, 5, size=num_q)
 g_camids = np.random.randint(0, 5, size=num_g)
 '''
 
-print('=> Using market1501\'s metric')
-pytime = timeit.timeit(
-    'evaluation.evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_cython=False)',
-    setup=setup,
-    number=20
-)
-cytime = timeit.timeit(
-    'evaluation.evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_cython=True)',
-    setup=setup,
-    number=20
-)
-print('Python time: {} s'.format(pytime))
-print('Cython time: {} s'.format(cytime))
-print('Cython is {} times faster than python\n'.format(pytime / cytime))
+# print('=> Using market1501\'s metric')
+# pytime = timeit.timeit(
+#     'evaluation.evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_cython=False)',
+#     setup=setup,
+#     number=20
+# )
+# cytime = timeit.timeit(
+#     'evaluation.evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_cython=True)',
+#     setup=setup,
+#     number=20
+# )
+# print('Python time: {} s'.format(pytime))
+# print('Cython time: {} s'.format(cytime))
+# print('Cython is {} times faster than python\n'.format(pytime / cytime))
+#
+# print('=> Using cuhk03\'s metric')
+# pytime = timeit.timeit(
+#     'evaluation.evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03=True, use_cython=False)',
+#     setup=setup,
+#     number=20
+# )
+# cytime = timeit.timeit(
+#     'evaluation.evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03=True, use_cython=True)',
+#     setup=setup,
+#     number=20
+# )
+# print('Python time: {} s'.format(pytime))
+# print('Cython time: {} s'.format(cytime))
+# print('Cython is {} times faster than python\n'.format(pytime / cytime))
 
-print('=> Using cuhk03\'s metric')
-pytime = timeit.timeit(
-    'evaluation.evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03=True, use_cython=False)',
-    setup=setup,
-    number=20
-)
-cytime = timeit.timeit(
-    'evaluation.evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_metric_cuhk03=True, use_cython=True)',
-    setup=setup,
-    number=20
-)
-print('Python time: {} s'.format(pytime))
-print('Cython time: {} s'.format(cytime))
-print('Cython is {} times faster than python\n'.format(pytime / cytime))
-"""
+from fastreid.evaluation import evaluate_rank
 print("=> Check precision")
 num_q = 30
 num_g = 300
@@ -72,8 +73,7 @@ q_pids = np.random.randint(0, num_q, size=num_q)
 g_pids = np.random.randint(0, num_g, size=num_g)
 q_camids = np.random.randint(0, 5, size=num_q)
 g_camids = np.random.randint(0, 5, size=num_g)
-cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_cython=False)
-print("Python:\nmAP = {} \ncmc = {}\n".format(mAP, cmc))
-cmc, mAP = evaluate(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_cython=True)
-print("Cython:\nmAP = {} \ncmc = {}\n".format(mAP, cmc))
-"""
+cmc, mAP, mINP = evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_cython=False)
+print("Python:\nmAP = {} \ncmc = {}\nmINP = {}".format(mAP, cmc, mINP))
+cmc, mAP, mINP = evaluate_rank(distmat, q_pids, g_pids, q_camids, g_camids, max_rank, use_cython=True)
+print("Cython:\nmAP = {} \ncmc = {}\nmINP = {}".format(mAP, cmc, mINP))
diff --git a/fastreid/evaluation/reid_evaluation.py b/fastreid/evaluation/reid_evaluation.py
index 4b709b1..db8afa9 100644
--- a/fastreid/evaluation/reid_evaluation.py
+++ b/fastreid/evaluation/reid_evaluation.py
@@ -48,10 +48,10 @@ class ReidEvaluator(DatasetEvaluator):
         self._results = OrderedDict()
 
         cos_dist = torch.mm(query_features, gallery_features.t()).numpy()
-        cmc, mAP = evaluate_rank(1 - cos_dist, query_pids, gallery_pids, query_camids, gallery_camids)
+        cmc, mAP, mINP = evaluate_rank(1 - cos_dist, query_pids, gallery_pids, query_camids, gallery_camids)
         for r in [1, 5, 10]:
             self._results['Rank-{}'.format(r)] = cmc[r - 1]
         self._results['mAP'] = mAP
-        self._results['mINP'] = 0
+        self._results['mINP'] = mINP
 
         return copy.deepcopy(self._results)
diff --git a/fastreid/export/tensorflow_export.py b/fastreid/export/tensorflow_export.py
index 0edfadb..13d6ea7 100644
--- a/fastreid/export/tensorflow_export.py
+++ b/fastreid/export/tensorflow_export.py
@@ -8,7 +8,6 @@ import warnings
 
 warnings.filterwarnings('ignore')  # Ignore all the warning messages in this tutorial
 from onnx_tf.backend import prepare
-from onnx import optimizer
 
 import tensorflow as tf
 from PIL import Image
@@ -19,16 +18,11 @@ import numpy as np
 import torch
 from torch.backends import cudnn
 import io
-import sys
-
-sys.path.insert(0, './')
-
-from modeling import Baseline
 
 cudnn.benchmark = True
 
 
 def _export_via_onnx(model, inputs):
     def _check_val(module):
         assert not module.training
 
@@ -58,10 +55,10 @@ def _export_via_onnx(model, inputs):
     # )
 
     # Apply ONNX's Optimization
-    all_passes = optimizer.get_available_passes()
-    passes = ["fuse_bn_into_conv"]
-    assert all(p in all_passes for p in passes)
-    onnx_model = optimizer.optimize(onnx_model, passes)
+    # all_passes = optimizer.get_available_passes()
+    # passes = ["fuse_bn_into_conv"]
+    # assert all(p in all_passes for p in passes)
+    # onnx_model = optimizer.optimize(onnx_model, passes)
 
     # Convert ONNX Model to Tensorflow Model
     tf_rep = prepare(onnx_model, strict=False)  # Import the ONNX model to Tensorflow
@@ -158,154 +155,155 @@ def export_tf_reid_model(model: torch.nn.Module, tensor_inputs: torch.Tensor, gr
     print("Checking if tf.pb is right")
     _check_pytorch_tf_model(model, graph_save_path)
 
+# if __name__ == '__main__':
+# args = default_argument_parser().parse_args()
+# print("Command Line Args:", args)
+# cfg = setup(args)
+# cfg = cfg.defrost()
+# cfg.MODEL.BACKBONE.NAME = "build_resnet_backbone"
+# cfg.MODEL.BACKBONE.DEPTH = 50
+# cfg.MODEL.BACKBONE.LAST_STRIDE = 1
+# # If use IBN block in backbone
+# cfg.MODEL.BACKBONE.WITH_IBN = True
+#
+# model = build_model(cfg)
+# # model.load_params_wo_fc(torch.load('logs/bjstation/res50_baseline_v0.4/ckpts/model_epoch80.pth'))
+# model.cuda()
+# model.eval()
+# dummy_inputs = torch.randn(1, 3, 256, 128)
+# export_tf_reid_model(model, dummy_inputs, 'reid_tf.pb')
 
-if __name__ == '__main__':
-    model = Baseline('resnet50',
-                     num_classes=0,
-                     last_stride=1,
-                     with_ibn=False,
-                     with_se=False,
-                     gcb=None,
-                     stage_with_gcb=[False, False, False, False],
-                     pretrain=False,
-                     model_path='')
-    model.load_params_wo_fc(torch.load('logs/bjstation/res50_baseline_v0.4/ckpts/model_epoch80.pth'))
-    # model.cuda()
-    model.eval()
-    dummy_inputs = torch.randn(1, 3, 384, 128)
-    export_tf_reid_model(model, dummy_inputs, 'reid_tf.pb')
+# inputs = torch.rand(1, 3, 384, 128).cuda()
+#
+# _export_via_onnx(model, inputs)
+# onnx_model = onnx.load("reid_test.onnx")
+# onnx.checker.check_model(onnx_model)
+#
+# from PIL import Image
+# import torchvision.transforms as transforms
+#
+# img = Image.open("demo_imgs/dog.jpg")
+#
+# resize = transforms.Resize([384, 128])
+# img = resize(img)
+#
+# to_tensor = transforms.ToTensor()
+# img = to_tensor(img)
+# img.unsqueeze_(0)
+# img = img.cuda()
+#
+# with torch.no_grad():
+#     torch_out = model(img)
+#
+# ort_session = onnxruntime.InferenceSession("reid_test.onnx")
+#
+# # compute ONNX Runtime output prediction
+# ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(img)}
+# ort_outs = ort_session.run(None, ort_inputs)
+# img_out_y = ort_outs[0]
+#
+#
+# # compare ONNX Runtime and PyTorch results
+# np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)
+#
+# print("Exported model has been tested with ONNXRuntime, and the result looks good!")
 
-    # inputs = torch.rand(1, 3, 384, 128).cuda()
-    #
-    # _export_via_onnx(model, inputs)
-    # onnx_model = onnx.load("reid_test.onnx")
-    # onnx.checker.check_model(onnx_model)
-    #
-    # from PIL import Image
-    # import torchvision.transforms as transforms
-    #
-    # img = Image.open("demo_imgs/dog.jpg")
-    #
-    # resize = transforms.Resize([384, 128])
-    # img = resize(img)
-    #
-    # to_tensor = transforms.ToTensor()
-    # img = to_tensor(img)
-    # img.unsqueeze_(0)
-    # img = img.cuda()
-    #
-    # with torch.no_grad():
-    #     torch_out = model(img)
-    #
-    # ort_session = onnxruntime.InferenceSession("reid_test.onnx")
-    #
-    # # compute ONNX Runtime output prediction
-    # ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(img)}
-    # ort_outs = ort_session.run(None, ort_inputs)
-    # img_out_y = ort_outs[0]
-    #
-    #
-    # # compare ONNX Runtime and PyTorch results
-    # np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)
-    #
-    # print("Exported model has been tested with ONNXRuntime, and the result looks good!")
+# img = Image.open("demo_imgs/dog.jpg")
+#
+# resize = transforms.Resize([384, 128])
+# img = resize(img)
+#
+# to_tensor = transforms.ToTensor()
+# img = to_tensor(img)
+# img.unsqueeze_(0)
+# img = torch.cat([img.clone(), img.clone()], dim=0)
 
-    # img = Image.open("demo_imgs/dog.jpg")
-    #
-    # resize = transforms.Resize([384, 128])
-    # img = resize(img)
-    #
-    # to_tensor = transforms.ToTensor()
-    # img = to_tensor(img)
-    # img.unsqueeze_(0)
-    # img = torch.cat([img.clone(), img.clone()], dim=0)
+# ort_session = onnxruntime.InferenceSession("reid_test.onnx")
 
-    # ort_session = onnxruntime.InferenceSession("reid_test.onnx")
+# # compute ONNX Runtime output prediction
+# ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(img)}
+# ort_outs = ort_session.run(None, ort_inputs)
 
-    # # compute ONNX Runtime output prediction
-    # ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(img)}
-    # ort_outs = ort_session.run(None, ort_inputs)
+# model = onnx.load('reid_test.onnx')  # Load the ONNX file
+# tf_rep = prepare(model, strict=False) # Import the ONNX model to Tensorflow
+# print(tf_rep.inputs)  # Input nodes to the model
+# print('-----')
+# print(tf_rep.outputs)  # Output nodes from the model
+# print('-----')
+# # print(tf_rep.tensor_dict)  # All nodes in the model
 
-    # model = onnx.load('reid_test.onnx')  # Load the ONNX file
-    # tf_rep = prepare(model, strict=False) # Import the ONNX model to Tensorflow
-    # print(tf_rep.inputs)  # Input nodes to the model
-    # print('-----')
-    # print(tf_rep.outputs)  # Output nodes from the model
-    # print('-----')
-    # # print(tf_rep.tensor_dict)  # All nodes in the model
+# install onnx-tensorflow from GitHub, and use tf_rep = prepare(onnx_model, strict=False)
+# Reference: https://github.com/onnx/onnx-tensorflow/issues/167
+# tf_rep = prepare(onnx_model)  # without strict=False leads to KeyError: 'pyfunc_0'
 
-    # install onnx-tensorflow from github,and tf_rep = prepare(onnx_model, strict=False)
-    # Reference https://github.com/onnx/onnx-tensorflow/issues/167
-    # tf_rep = prepare(onnx_model) # whthout strict=False leads to KeyError: 'pyfunc_0'
+# # debug, here using the same input to check onnx and tf.
+# # output_onnx_tf = tf_rep.run(to_numpy(img))
+# # print('output_onnx_tf = {}'.format(output_onnx_tf))
+# # onnx --> tf.graph.pb
+# tf_pb_path = 'reid_tf_graph.pb'
+# tf_rep.export_graph(tf_pb_path)
 
-    # # debug, here using the same input to check onnx and tf.
-    # # output_onnx_tf = tf_rep.run(to_numpy(img))
-    # # print('output_onnx_tf = {}'.format(output_onnx_tf))
-    # # onnx --> tf.graph.pb
-    # tf_pb_path = 'reid_tf_graph.pb'
-    # tf_rep.export_graph(tf_pb_path)
+# # step 3, check if tf.pb is right.
+# with tf.Graph().as_default():
+#     graph_def = tf.GraphDef()
+#     with open(tf_pb_path, "rb") as f:
+#         graph_def.ParseFromString(f.read())
+#         tf.import_graph_def(graph_def, name="")
+#     with tf.Session() as sess:
+#         # init = tf.initialize_all_variables()
+#         init = tf.global_variables_initializer()
+#         # sess.run(init)
 
-    # # step 3, check if tf.pb is right.
-    # with tf.Graph().as_default():
-    #     graph_def = tf.GraphDef()
-    #     with open(tf_pb_path, "rb") as f:
-    #         graph_def.ParseFromString(f.read())
-    #         tf.import_graph_def(graph_def, name="")
-    #     with tf.Session() as sess:
-    #         # init = tf.initialize_all_variables()
-    #         init = tf.global_variables_initializer()
-    #         # sess.run(init)
+#         # print all ops, check input/output tensor name.
+#         # uncomment it if you do not know the io tensor names.
+#         '''
+#         print('-------------ops---------------------')
+#         op = sess.graph.get_operations()
+#         for m in op:
+#             try:
+#                 # if 'input' in m.values()[0].name:
+#                 #     print(m.values())
+#                 if m.values()[0].shape.as_list()[1] == 2048: #and (len(m.values()[0].shape.as_list()) == 4):
+#                     print(m.values())
+#             except:
+#                 pass
+#         print('-------------ops done.---------------------')
+#         '''
+#         input_x = sess.graph.get_tensor_by_name('input.1:0')  # input
+#         outputs = sess.graph.get_tensor_by_name('502:0')  # 5
+#         output_tf_pb = sess.run(outputs, feed_dict={input_x: to_numpy(img)})
+#         print('output_tf_pb = {}'.format(output_tf_pb))
+# np.testing.assert_allclose(ort_outs[0], output_tf_pb, rtol=1e-03, atol=1e-05)
 
-    #         # print all ops, check input/output tensor name.
-    #         # uncomment it if you donnot know io tensor names.
-    #         '''
-    #         print('-------------ops---------------------')
-    #         op = sess.graph.get_operations()
-    #         for m in op:
-    #             try:
-    #                 # if 'input' in m.values()[0].name:
-    #                 #     print(m.values())
-    #                 if m.values()[0].shape.as_list()[1] == 2048: #and (len(m.values()[0].shape.as_list()) == 4):
-    #                     print(m.values())
-    #             except:
-    #                 pass
-    #         print('-------------ops done.---------------------')
-    #         '''
-    #         input_x = sess.graph.get_tensor_by_name('input.1:0')  # input
-    #         outputs = sess.graph.get_tensor_by_name('502:0')  # 5
-    #         output_tf_pb = sess.run(outputs, feed_dict={input_x: to_numpy(img)})
-    #         print('output_tf_pb = {}'.format(output_tf_pb))
-    # np.testing.assert_allclose(ort_outs[0], output_tf_pb, rtol=1e-03, atol=1e-05)
-
-    # with tf.Graph().as_default():
-    #     graph_def = tf.GraphDef()
-    #     with open(tf_pb_path, "rb") as f:
-    #         graph_def.ParseFromString(f.read())
-    #         tf.import_graph_def(graph_def, name="")
-    #     with tf.Session() as sess:
-    #         # init = tf.initialize_all_variables()
-    #         init = tf.global_variables_initializer()
-    #         # sess.run(init)
-    #
-    #         # print all ops, check input/output tensor name.
-    #         # uncomment it if you donnot know io tensor names.
-    #         '''
-    #         print('-------------ops---------------------')
-    #         op = sess.graph.get_operations()
-    #         for m in op:
-    #             try:
-    #                 # if 'input' in m.values()[0].name:
-    #                 #     print(m.values())
-    #                 if m.values()[0].shape.as_list()[1] == 2048: #and (len(m.values()[0].shape.as_list()) == 4):
-    #                     print(m.values())
-    #             except:
-    #                 pass
-    #         print('-------------ops done.---------------------')
-    #         '''
-    #         input_x = sess.graph.get_tensor_by_name('input.1:0')  # input
-    #         outputs = sess.graph.get_tensor_by_name('502:0')  # 5
-    #         output_tf_pb = sess.run(outputs, feed_dict={input_x: to_numpy(img)})
-    #         from ipdb import set_trace;
-    #
-    #         set_trace()
-    #         print('output_tf_pb = {}'.format(output_tf_pb))
+# with tf.Graph().as_default():
+#     graph_def = tf.GraphDef()
+#     with open(tf_pb_path, "rb") as f:
+#         graph_def.ParseFromString(f.read())
+#         tf.import_graph_def(graph_def, name="")
+#     with tf.Session() as sess:
+#         # init = tf.initialize_all_variables()
+#         init = tf.global_variables_initializer()
+#         # sess.run(init)
+#
+#         # print all ops, check input/output tensor name.
+#         # uncomment it if you do not know the io tensor names.
+#         '''
+#         print('-------------ops---------------------')
+#         op = sess.graph.get_operations()
+#         for m in op:
+#             try:
+#                 # if 'input' in m.values()[0].name:
+#                 #     print(m.values())
+#                 if m.values()[0].shape.as_list()[1] == 2048: #and (len(m.values()[0].shape.as_list()) == 4):
+#                     print(m.values())
+#             except:
+#                 pass
+#         print('-------------ops done.---------------------')
+#         '''
+#         input_x = sess.graph.get_tensor_by_name('input.1:0')  # input
+#         outputs = sess.graph.get_tensor_by_name('502:0')  # 5
+#         output_tf_pb = sess.run(outputs, feed_dict={input_x: to_numpy(img)})
+#         from ipdb import set_trace;
+#
+#         set_trace()
+#         print('output_tf_pb = {}'.format(output_tf_pb))
diff --git a/fastreid/export/tf_modeling.py b/fastreid/export/tf_modeling.py
new file mode 100644
index 0000000..5ba4cef
--- /dev/null
+++ b/fastreid/export/tf_modeling.py
@@ -0,0 +1,20 @@
+# encoding: utf-8
+"""
+@author:  l1aoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+from torch import nn
+from ..modeling.backbones import build_backbone
+from ..modeling.heads import build_reid_heads
+
+
+class TfMetaArch(nn.Module):
+    def __init__(self, cfg):
+        super().__init__()
+        self.backbone = build_backbone(cfg)
+        self.heads = build_reid_heads(cfg, 2048, nn.AdaptiveAvgPool2d(1))  # in_feat=2048 assumes ResNet-50
+
+    def forward(self, x):
+        global_feat = self.backbone(x)
+        pred_features = self.heads(global_feat)
+        return pred_features
diff --git a/fastreid/layers/__init__.py b/fastreid/layers/__init__.py
index ca72f52..2060af0 100644
--- a/fastreid/layers/__init__.py
+++ b/fastreid/layers/__init__.py
@@ -5,18 +5,15 @@
 """
 from torch import nn
 
-from .context_block import ContextBlock
 from .batch_drop import BatchDrop
+from .attention import *
 from .batch_norm import bn_no_bias
-from .pooling import GeM
+from .context_block import ContextBlock
 from .frn import FRN, TLU
+from .mish import Mish
+from .gem_pool import GeneralizedMeanPoolingP
 
 
-class Lambda(nn.Module):
-    "Create a layer that simply calls `func` with `x`"
-    def __init__(self, func):
-        super().__init__()
-        self.func=func
-
-    def forward(self, x):
-        return self.func(x)
\ No newline at end of file
+class Flatten(nn.Module):
+    def forward(self, input):
+        return input.view(input.size(0), -1)
diff --git a/fastreid/layers/attention.py b/fastreid/layers/attention.py
new file mode 100644
index 0000000..41a45d4
--- /dev/null
+++ b/fastreid/layers/attention.py
@@ -0,0 +1,177 @@
+# encoding: utf-8
+"""
+@author:  CASIA IVA
+@contact: jliu@nlpr.ia.ac.cn
+"""
+
+import torch
+from torch.nn import Module, Conv2d, Parameter, Softmax
+import torch.nn as nn
+
+__all__ = ['PAM_Module', 'CAM_Module', 'DANetHead',]
+
+
+class DANetHead(nn.Module):
+    def __init__(self,
+                 in_channels: int,
+                 out_channels: int,
+                 norm_layer: nn.Module,
+                 module_class: type,
+                 dim_collapsion: int=2):
+        super(DANetHead, self).__init__()
+
+        inter_channels = in_channels // dim_collapsion
+
+        self.conv5c = nn.Sequential(
+            nn.Conv2d(
+                in_channels,
+                inter_channels,
+                3,
+                padding=1,
+                bias=False
+            ),
+            norm_layer(inter_channels),
+            nn.ReLU()
+        )
+
+        self.attention_module = module_class(inter_channels)
+        self.conv52 = nn.Sequential(
+            nn.Conv2d(
+                inter_channels,
+                inter_channels,
+                3,
+                padding=1,
+                bias=False
+            ),
+            norm_layer(inter_channels),
+            nn.ReLU()
+        )
+
+        self.conv7 = nn.Sequential(
+            nn.Dropout2d(0.1, False),
+            nn.Conv2d(inter_channels, out_channels, 1)
+        )
+
+    def forward(self, x):
+
+        feat2 = self.conv5c(x)
+        sc_feat = self.attention_module(feat2)
+        sc_conv = self.conv52(sc_feat)
+        sc_output = self.conv7(sc_conv)
+
+        return sc_output
+
+
+class PAM_Module(nn.Module):
+    """ Position attention module"""
+    # Ref from SAGAN
+
+    def __init__(self, in_dim):
+        super(PAM_Module, self).__init__()
+        self.channel_in = in_dim
+
+        self.query_conv = Conv2d(
+            in_channels=in_dim,
+            out_channels=in_dim // 8,
+            kernel_size=1
+        )
+        self.key_conv = Conv2d(
+            in_channels=in_dim,
+            out_channels=in_dim // 8,
+            kernel_size=1
+        )
+        self.value_conv = Conv2d(
+            in_channels=in_dim,
+            out_channels=in_dim,
+            kernel_size=1
+        )
+        self.gamma = Parameter(torch.zeros(1))
+
+        self.softmax = Softmax(dim=-1)
+
+    def forward(self, x):
+        """
+            inputs :
+                x : input feature maps( B X C X H X W)
+            returns :
+                out : attention value + input feature
+                attention: B X (HxW) X (HxW)
+        """
+        m_batchsize, C, height, width = x.size()
+        proj_query = self.query_conv(x).view(m_batchsize, -1, width * height).permute(0, 2, 1)
+        proj_key = self.key_conv(x).view(m_batchsize, -1, width * height)
+        energy = torch.bmm(proj_query, proj_key)
+        attention = self.softmax(energy)
+        proj_value = self.value_conv(x).view(m_batchsize, -1, width * height)
+
+        out = torch.bmm(
+            proj_value,
+            attention.permute(0, 2, 1)
+        )
+        attention_mask = out.view(m_batchsize, C, height, width)
+
+        out = self.gamma * attention_mask + x
+        return out
+
+
+class CAM_Module(nn.Module):
+    """ Channel attention module"""
+
+    def __init__(self, in_dim):
+        super().__init__()
+        self.channel_in = in_dim
+
+        self.gamma = Parameter(torch.zeros(1))
+        self.softmax = Softmax(dim=-1)
+
+    def forward(self, x):
+        """
+            inputs :
+                x : input feature maps( B X C X H X W)
+            returns :
+                out : attention value + input feature
+                attention: B X C X C
+        """
+        m_batchsize, C, height, width = x.size()
+        proj_query = x.view(m_batchsize, C, -1)
+        proj_key = x.view(m_batchsize, C, -1).permute(0, 2, 1)
+        energy = torch.bmm(proj_query, proj_key)
+        max_energy_0 = torch.max(energy, -1, keepdim=True)[0].expand_as(energy)
+        energy_new = max_energy_0 - energy
+        attention = self.softmax(energy_new)
+        proj_value = x.view(m_batchsize, C, -1)
+
+        out = torch.bmm(attention, proj_value)
+        out = out.view(m_batchsize, C, height, width)
+
+        gamma = self.gamma.to(out.device)
+        out = gamma * out + x
+        return out
+
+
+# def get_attention_module_instance(
+#     name: 'cam | pam | identity',
+#     dim: int,
+#     *,
+#     out_dim=None,
+#     use_head: bool=False,
+#     dim_collapsion=2  # Used iff `used_head` set to True
+# ):
+#
+#     name = name.lower()
+#     assert name in ('cam', 'pam', 'identity')
+#
+#     module_class = name_module_class_mapping[name]
+#
+#     if out_dim is None:
+#         out_dim = dim
+#
+#     if use_head:
+#         return DANetHead(
+#             dim, out_dim,
+#             nn.BatchNorm2d,
+#             module_class,
+#             dim_collapsion=dim_collapsion
+#         )
+#     else:
+#         return module_class(dim)
\ No newline at end of file
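
A quick shape check for the new attention modules (illustrative usage; the re-export through fastreid.layers comes from the __init__.py change above).

import torch
from torch import nn
from fastreid.layers import PAM_Module, CAM_Module, DANetHead

x = torch.randn(2, 256, 16, 8)                      # B x C x H x W feature map
pam = PAM_Module(in_dim=256)
cam = CAM_Module(in_dim=256)
head = DANetHead(in_channels=256, out_channels=256,
                 norm_layer=nn.BatchNorm2d, module_class=CAM_Module)
print(pam(x).shape, cam(x).shape, head(x).shape)    # each stays (2, 256, 16, 8)
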
diff --git a/fastreid/layers/batch_drop.py b/fastreid/layers/batch_drop.py
index b4250ec..5c25697 100644
--- a/fastreid/layers/batch_drop.py
+++ b/fastreid/layers/batch_drop.py
@@ -5,15 +5,17 @@
 """
 
 import random
+
 from torch import nn
 
 
 class BatchDrop(nn.Module):
-    """Copy from https://github.com/daizuozhuo/batch-dropblock-network/blob/master/models/networks.py
+    """ref: https://github.com/daizuozhuo/batch-dropblock-network/blob/master/models/networks.py
     batch drop mask
     """
+
     def __init__(self, h_ratio, w_ratio):
-        super().__init__()
+        super(BatchDrop, self).__init__()
         self.h_ratio = h_ratio
         self.w_ratio = w_ratio
 
@@ -22,9 +24,9 @@ class BatchDrop(nn.Module):
             h, w = x.size()[-2:]
             rh = round(self.h_ratio * h)
             rw = round(self.w_ratio * w)
-            sx = random.randint(0, h-rh)
-            sy = random.randint(0, w-rw)
+            sx = random.randint(0, h - rh)
+            sy = random.randint(0, w - rw)
             mask = x.new_ones(x.size())
-            mask[:, :, sx:sx+rh, sy:sy+rw] = 0
+            mask[:, :, sx:sx + rh, sy:sy + rw] = 0
             x = x * mask
-        return x
\ No newline at end of file
+        return x
diff --git a/fastreid/layers/batch_norm.py b/fastreid/layers/batch_norm.py
index 2b08bbd..621d0b3 100644
--- a/fastreid/layers/batch_norm.py
+++ b/fastreid/layers/batch_norm.py
@@ -10,4 +10,4 @@ from torch import nn
 def bn_no_bias(in_features):
     bn_layer = nn.BatchNorm1d(in_features)
     bn_layer.bias.requires_grad_(False)
-    return bn_layer
\ No newline at end of file
+    return bn_layer
diff --git a/fastreid/layers/gem_pool.py b/fastreid/layers/gem_pool.py
new file mode 100644
index 0000000..ca94b32
--- /dev/null
+++ b/fastreid/layers/gem_pool.py
@@ -0,0 +1,49 @@
+# encoding: utf-8
+"""
+@author:  l1aoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+
+class GeneralizedMeanPooling(nn.Module):
+    r"""Applies a 2D power-average adaptive pooling over an input signal composed of several input planes.
+    The function computed is: :math:`f(X) = pow(sum(pow(X, p)), 1/p)`
+        - At p = infinity, one gets Max Pooling
+        - At p = 1, one gets Average Pooling
+    The output is of size H x W, for any input size.
+    The number of output features is equal to the number of input planes.
+    Args:
+        output_size: the target output size of the image of the form H x W.
+                     Can be a tuple (H, W) or a single H for a square image H x H
+                     H and W can be either a ``int``, or ``None`` which means the size will
+                     be the same as that of the input.
+    """
+
+    def __init__(self, norm, output_size=1, eps=1e-6):
+        super(GeneralizedMeanPooling, self).__init__()
+        assert norm > 0
+        self.p = float(norm)
+        self.output_size = output_size
+        self.eps = eps
+
+    def forward(self, x):
+        x = x.clamp(min=self.eps).pow(self.p)
+        return torch.nn.functional.adaptive_avg_pool2d(x, self.output_size).pow(1. / self.p)
+
+    def __repr__(self):
+        return self.__class__.__name__ + '(' \
+               + str(self.p) + ', ' \
+               + 'output_size=' + str(self.output_size) + ')'
+
+
+class GeneralizedMeanPoolingP(GeneralizedMeanPooling):
+    """ Same, but norm is trainable
+    """
+
+    def __init__(self, norm=3, output_size=1, eps=1e-6):
+        super(GeneralizedMeanPoolingP, self).__init__(norm, output_size, eps)
+        self.p = nn.Parameter(torch.ones(1) * norm)
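
A small numeric check of the GeM behaviour: p = 1 reduces to average pooling and large p approaches max pooling, while GeneralizedMeanPoolingP makes p learnable.

import torch
from fastreid.layers.gem_pool import GeneralizedMeanPooling, GeneralizedMeanPoolingP

x = torch.tensor([[[[1.0, 2.0], [3.0, 4.0]]]])      # 1 x 1 x 2 x 2 feature map
print(GeneralizedMeanPooling(norm=1)(x).item())     # 2.5  (average pooling)
print(GeneralizedMeanPooling(norm=50)(x).item())    # ~3.9 (approaching the max, 4)
print(GeneralizedMeanPoolingP(norm=3)(x).shape)     # here p is an nn.Parameter
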
diff --git a/fastreid/layers/mish.py b/fastreid/layers/mish.py
new file mode 100644
index 0000000..3fb7ce1
--- /dev/null
+++ b/fastreid/layers/mish.py
@@ -0,0 +1,22 @@
+####
+# CODE TAKEN FROM https://github.com/lessw2020/mish
+# ORIGINAL PAPER https://arxiv.org/abs/1908.08681v1
+####
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+#Mish - "Mish: A Self Regularized Non-Monotonic Neural Activation Function"
+#https://arxiv.org/abs/1908.08681v1
+#implemented for PyTorch / FastAI by lessw2020 
+#github: https://github.com/lessw2020/mish
+
+
+class Mish(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        #inlining this saves 1 second per epoch (V100 GPU) vs having a temp x and then returning x(!)
+        return x * torch.tanh(F.softplus(x))
\ No newline at end of file
diff --git a/fastreid/layers/pooling.py b/fastreid/layers/pooling.py
deleted file mode 100644
index b27b156..0000000
--- a/fastreid/layers/pooling.py
+++ /dev/null
@@ -1,22 +0,0 @@
-# encoding: utf-8
-"""
-@author:  liaoxingyu
-@contact: sherlockliao01@gmail.com
-"""
-
-import torch
-from torch import nn
-from torch.nn.parameter import Parameter
-import torch.nn.functional as F
-
-__all__ = ['GeM',]
-
-
-class GeM(nn.Module):
-    def __init__(self, p=3, eps=1e-6):
-        super().__init__()
-        self.p = Parameter(torch.ones(1)*p)
-        self.eps = eps
-
-    def forward(self, x):
-        return F.avg_pool2d(x.clamp(min=self.eps).pow(self.p), (x.size(-2), x.size(-1))).pow(1./self.p)
\ No newline at end of file
diff --git a/fastreid/modeling/__init__.py b/fastreid/modeling/__init__.py
index ba0e0ea..71a2e7d 100644
--- a/fastreid/modeling/__init__.py
+++ b/fastreid/modeling/__init__.py
@@ -4,4 +4,4 @@
 @contact: sherlockliao01@gmail.com
 """
 
-
+from .meta_arch import build_model
diff --git a/fastreid/modeling/backbones/resnet.py b/fastreid/modeling/backbones/resnet.py
index bd9fcef..22b60a2 100644
--- a/fastreid/modeling/backbones/resnet.py
+++ b/fastreid/modeling/backbones/resnet.py
@@ -45,10 +45,28 @@ class IBN(nn.Module):
         return out
 
 
+class SELayer(nn.Module):
+    def __init__(self, channel, reduction=16):
+        super(SELayer, self).__init__()
+        self.avg_pool = nn.AdaptiveAvgPool2d(1)
+        self.fc = nn.Sequential(
+            nn.Linear(channel, int(channel / reduction), bias=False),
+            nn.ReLU(inplace=True),
+            nn.Linear(int(channel / reduction), channel, bias=False),
+            nn.Sigmoid()
+        )
+
+    def forward(self, x):
+        b, c, _, _ = x.size()
+        y = self.avg_pool(x).view(b, c)
+        y = self.fc(y).view(b, c, 1, 1)
+        return x * y.expand_as(x)
+
+
 class Bottleneck(nn.Module):
     expansion = 4
 
-    def __init__(self, inplanes, planes, with_ibn=False, stride=1, downsample=None):
+    def __init__(self, inplanes, planes, with_ibn=False, with_se=False, stride=1, downsample=None, reduction=16):
         super(Bottleneck, self).__init__()
         self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
         if with_ibn:
@@ -61,6 +79,10 @@ class Bottleneck(nn.Module):
         self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
         self.bn3 = nn.BatchNorm2d(planes * 4)
         self.relu = nn.ReLU(inplace=True)
+        if with_se:
+            self.se = SELayer(planes * 4, reduction)
+        else:
+            self.se = nn.Identity()
         self.downsample = downsample
         self.stride = stride
 
@@ -77,6 +99,7 @@ class Bottleneck(nn.Module):
 
         out = self.conv3(out)
         out = self.bn3(out)
+        out = self.se(out)
 
         if self.downsample is not None:
             residual = self.downsample(x)
@@ -97,14 +120,14 @@ class ResNet(nn.Module):
         self.bn1 = nn.BatchNorm2d(64)
         self.relu = nn.ReLU(inplace=True)
         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
-        self.layer1 = self._make_layer(block, scale, layers[0], with_ibn=with_ibn)
-        self.layer2 = self._make_layer(block, scale * 2, layers[1], stride=2, with_ibn=with_ibn)
-        self.layer3 = self._make_layer(block, scale * 4, layers[2], stride=2, with_ibn=with_ibn)
-        self.layer4 = self._make_layer(block, scale * 8, layers[3], stride=last_stride)
+        self.layer1 = self._make_layer(block, scale, layers[0], with_ibn=with_ibn, with_se=with_se)
+        self.layer2 = self._make_layer(block, scale * 2, layers[1], stride=2, with_ibn=with_ibn, with_se=with_se)
+        self.layer3 = self._make_layer(block, scale * 4, layers[2], stride=2, with_ibn=with_ibn, with_se=with_se)
+        self.layer4 = self._make_layer(block, scale * 8, layers[3], stride=last_stride, with_se=with_se)
 
         self.random_init()
 
-    def _make_layer(self, block, planes, blocks, stride=1, with_ibn=False):
+    def _make_layer(self, block, planes, blocks, stride=1, with_ibn=False, with_se=False):
         downsample = None
         if stride != 1 or self.inplanes != planes * block.expansion:
             downsample = nn.Sequential(
@@ -116,10 +139,10 @@ class ResNet(nn.Module):
         layers = []
         if planes == 512:
             with_ibn = False
-        layers.append(block(self.inplanes, planes, with_ibn, stride, downsample))
+        layers.append(block(self.inplanes, planes, with_ibn, with_se, stride, downsample))
         self.inplanes = planes * block.expansion
         for i in range(1, blocks):
-            layers.append(block(self.inplanes, planes, with_ibn))
+            layers.append(block(self.inplanes, planes, with_ibn, with_se))
 
         return nn.Sequential(*layers)
 
@@ -168,20 +191,14 @@ def build_resnet_backbone(cfg):
         if not with_ibn:
             # original resnet
             state_dict = model_zoo.load_url(model_urls[depth])
-            # remove fully-connected-layers
-            state_dict.pop('fc.weight')
-            state_dict.pop('fc.bias')
         else:
             # ibn resnet
             state_dict = torch.load(pretrain_path)['state_dict']
-            # remove fully-connected-layers
-            state_dict.pop('module.fc.weight')
-            state_dict.pop('module.fc.bias')
             # remove module in name
             new_state_dict = {}
             for k in state_dict:
                 new_k = '.'.join(k.split('.')[1:])
-                if model.state_dict()[new_k].shape == state_dict[k].shape:
+                if new_k in model.state_dict() and (model.state_dict()[new_k].shape == state_dict[k].shape):
                     new_state_dict[new_k] = state_dict[k]
             state_dict = new_state_dict
         res = model.load_state_dict(state_dict, strict=False)
@@ -189,3 +206,5 @@ def build_resnet_backbone(cfg):
         logger.info('missing keys is {}'.format(res.missing_keys))
         logger.info('unexpected keys is {}'.format(res.unexpected_keys))
     return model
+
+
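
The new SE branch in isolation (the cfg plumbing that passes with_se into build_resnet_backbone sits outside this hunk):

# SELayer performs squeeze-and-excitation: global average pool, a two-layer
# MLP with sigmoid, then per-channel rescaling of the input.
import torch
from fastreid.modeling.backbones.resnet import SELayer

se = SELayer(channel=256, reduction=16)
x = torch.randn(2, 256, 16, 8)
print(se(x).shape)   # torch.Size([2, 256, 16, 8])
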
diff --git a/fastreid/modeling/heads/__init__.py b/fastreid/modeling/heads/__init__.py
index b857907..680470d 100644
--- a/fastreid/modeling/heads/__init__.py
+++ b/fastreid/modeling/heads/__init__.py
@@ -7,5 +7,6 @@
 from .build import REID_HEADS_REGISTRY, build_reid_heads
 
 # import all the meta_arch, so they will be registered
-from .bn_linear import BNneckLinear
-from .arcface import ArcFace
+from .linear_head import LinearHead
+from .bnneck_head import BNneckHead
+from .arcface import ArcfaceHead
diff --git a/fastreid/modeling/heads/arcface.py b/fastreid/modeling/heads/arcface.py
index 55d3281..ba525aa 100644
--- a/fastreid/modeling/heads/arcface.py
+++ b/fastreid/modeling/heads/arcface.py
@@ -17,42 +17,30 @@ from ...layers import bn_no_bias
 
 
 @REID_HEADS_REGISTRY.register()
-class ArcFace(nn.Module):
-    def __init__(self, cfg):
+class ArcfaceHead(nn.Module):
+    def __init__(self, cfg, in_feat):
         super().__init__()
-        self._in_features = 2048
         self._num_classes = cfg.MODEL.HEADS.NUM_CLASSES
         self._s = 30.0
         self._m = 0.50
 
-        self.gap = nn.AdaptiveAvgPool2d(1)
-        self.bnneck = bn_no_bias(self._in_features)
-        self.bnneck.apply(weights_init_kaiming)
-
         self.cos_m = math.cos(self._m)
         self.sin_m = math.sin(self._m)
 
         self.th = math.cos(math.pi - self._m)
         self.mm = math.sin(math.pi - self._m) * self._m
 
-        self.weight = Parameter(torch.Tensor(self._num_classes, self._in_features))
+        self.weight = Parameter(torch.Tensor(self._num_classes, in_feat))
         self.reset_parameters()
 
     def reset_parameters(self):
         nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
 
-    def forward(self, features, targets=None):
+    def forward(self, features, targets):
         """
         See :class:`ReIDHeads.forward`.
         """
-        global_features = self.gap(features)
-        global_features = global_features.view(global_features.shape[0], -1)
-        bn_features = self.bnneck(global_features)
-
-        if not self.training:
-            return F.normalize(bn_features)
-
-        cosine = F.linear(F.normalize(bn_features), F.normalize(self.weight))
+        cosine = F.linear(F.normalize(features), F.normalize(self.weight))
         sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1))
         phi = cosine * self.cos_m - sine * self.sin_m
         phi = torch.where(cosine > self.th, phi, cosine - self.mm)
@@ -64,5 +52,4 @@ class ArcFace(nn.Module):
         pred_class_logits = (one_hot * phi) + (
                     (1.0 - one_hot) * cosine)  # you can use torch.where if your torch.__version__ is 0.4
         pred_class_logits *= self._s
-
-        return pred_class_logits, global_features, targets,
+        return pred_class_logits
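
ArcfaceHead now receives already pooled (and bnneck'd) features and returns only the margin logits. The additive angular margin in isolation, with illustrative numbers:

# cos(theta + m) from cos(theta), with the same fallback the head uses when
# theta + m would pass pi (values are illustrative).
import math
import torch

s, m = 30.0, 0.50
cos_m, sin_m = math.cos(m), math.sin(m)
th, mm = math.cos(math.pi - m), math.sin(math.pi - m) * m

cosine = torch.tensor([0.90, -0.95])                 # cos(theta) for two samples
sine = torch.sqrt((1.0 - cosine.pow(2)).clamp(0, 1))
phi = cosine * cos_m - sine * sin_m                  # cos(theta + m)
phi = torch.where(cosine > th, phi, cosine - mm)     # fallback branch
print(s * phi)                                       # scaled target-class logits
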
diff --git a/fastreid/modeling/heads/bn_linear.py b/fastreid/modeling/heads/bn_linear.py
deleted file mode 100644
index 414d7df..0000000
--- a/fastreid/modeling/heads/bn_linear.py
+++ /dev/null
@@ -1,41 +0,0 @@
-# encoding: utf-8
-"""
-@author:  liaoxingyu
-@contact: sherlockliao01@gmail.com
-"""
-
-from torch import nn
-import torch.nn.functional as F
-
-from .build import REID_HEADS_REGISTRY
-from ..model_utils import weights_init_classifier, weights_init_kaiming
-from ...layers import bn_no_bias
-
-
-@REID_HEADS_REGISTRY.register()
-class BNneckLinear(nn.Module):
-
-    def __init__(self, cfg):
-        super().__init__()
-        self._num_classes = cfg.MODEL.HEADS.NUM_CLASSES
-
-        self.gap = nn.AdaptiveAvgPool2d(1)
-        self.bnneck = bn_no_bias(2048)
-        self.bnneck.apply(weights_init_kaiming)
-
-        self.classifier = nn.Linear(2048, self._num_classes, bias=False)
-        self.classifier.apply(weights_init_classifier)
-
-    def forward(self, features, targets=None):
-        """
-        See :class:`ReIDHeads.forward`.
-        """
-        global_features = self.gap(features)
-        global_features = global_features.view(global_features.shape[0], -1)
-        bn_features = self.bnneck(global_features)
-
-        if not self.training:
-            return F.normalize(bn_features)
-
-        pred_class_logits = self.classifier(bn_features)
-        return pred_class_logits, global_features, targets
diff --git a/fastreid/modeling/heads/bnneck_head.py b/fastreid/modeling/heads/bnneck_head.py
new file mode 100644
index 0000000..fe5dd62
--- /dev/null
+++ b/fastreid/modeling/heads/bnneck_head.py
@@ -0,0 +1,46 @@
+# encoding: utf-8
+"""
+@author:  liaoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+from torch import nn
+
+from .build import REID_HEADS_REGISTRY
+from .linear_head import LinearHead
+from ..model_utils import weights_init_classifier, weights_init_kaiming
+from ...layers import bn_no_bias, Flatten
+
+
+@REID_HEADS_REGISTRY.register()
+class BNneckHead(nn.Module):
+
+    def __init__(self, cfg, in_feat, pool_layer=nn.AdaptiveAvgPool2d(1)):
+        super().__init__()
+        self._num_classes = cfg.MODEL.HEADS.NUM_CLASSES
+
+        self.pool_layer = nn.Sequential(
+            pool_layer,
+            Flatten()
+        )
+        self.bnneck = bn_no_bias(in_feat)
+        self.bnneck.apply(weights_init_kaiming)
+
+        self.classifier = nn.Linear(in_feat, self._num_classes, bias=False)
+        self.classifier.apply(weights_init_classifier)
+
+    def forward(self, features, targets=None):
+        """
+        See :class:`ReIDHeads.forward`.
+        """
+        global_feat = self.pool_layer(features)
+        bn_feat = self.bnneck(global_feat)
+        if not self.training:
+            return bn_feat
+        # training
+        pred_class_logits = self.classifier(bn_feat)
+        return pred_class_logits, global_feat
+
+    @classmethod
+    def losses(cls, cfg, pred_class_logits, global_features, gt_classes, prefix='') -> dict:
+        return LinearHead.losses(cfg, pred_class_logits, global_features, gt_classes, prefix)
diff --git a/fastreid/modeling/heads/build.py b/fastreid/modeling/heads/build.py
index 139c938..ddbb90b 100644
--- a/fastreid/modeling/heads/build.py
+++ b/fastreid/modeling/heads/build.py
@@ -16,9 +16,9 @@ The call is expected to return an :class:`ROIHeads`.
 """
 
 
-def build_reid_heads(cfg):
+def build_reid_heads(cfg, in_feat, pool_layer):
     """
     Build REIDHeads defined by `cfg.MODEL.REID_HEADS.NAME`.
     """
     head = cfg.MODEL.HEADS.NAME
-    return REID_HEADS_REGISTRY.get(head)(cfg)
+    return REID_HEADS_REGISTRY.get(head)(cfg, in_feat, pool_layer)
diff --git a/fastreid/modeling/heads/linear_head.py b/fastreid/modeling/heads/linear_head.py
new file mode 100644
index 0000000..4ba044a
--- /dev/null
+++ b/fastreid/modeling/heads/linear_head.py
@@ -0,0 +1,55 @@
+# encoding: utf-8
+"""
+@author:  liaoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+from torch import nn
+
+from .build import REID_HEADS_REGISTRY
+from ..losses import CrossEntropyLoss, TripletLoss
+from ..model_utils import weights_init_classifier, weights_init_kaiming
+from ...layers import bn_no_bias, Flatten
+
+
+@REID_HEADS_REGISTRY.register()
+class LinearHead(nn.Module):
+
+    def __init__(self, cfg, in_feat, pool_layer=nn.AdaptiveAvgPool2d(1)):
+        super().__init__()
+        self._num_classes = cfg.MODEL.HEADS.NUM_CLASSES
+
+        self.pool_layer = nn.Sequential(
+            pool_layer,
+            Flatten()
+        )
+
+        self.classifier = nn.Linear(in_feat, self._num_classes, bias=False)
+        self.classifier.apply(weights_init_classifier)
+
+    def forward(self, features, targets=None):
+        """
+        See :class:`ReIDHeads.forward`.
+        """
+        global_feat = self.pool_layer(features)
+        if not self.training:
+            return global_feat
+        # training
+        pred_class_logits = self.classifier(global_feat)
+        return pred_class_logits, global_feat
+
+    @classmethod
+    def losses(cls, cfg, pred_class_logits, global_features, gt_classes, prefix='') -> dict:
+        loss_dict = {}
+        if "CrossEntropyLoss" in cfg.MODEL.LOSSES.NAME and pred_class_logits is not None:
+            loss = CrossEntropyLoss(cfg)(pred_class_logits, gt_classes)
+            loss_dict.update(loss)
+        if "TripletLoss" in cfg.MODEL.LOSSES.NAME and global_features is not None:
+            loss = TripletLoss(cfg)(global_features, gt_classes)
+            loss_dict.update(loss)
+        # rename
+        name_loss_dict = {}
+        for name in loss_dict.keys():
+            name_loss_dict[prefix + name] = loss_dict[name]
+        del loss_dict
+        return name_loss_dict
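
A self-contained sketch of the new head interface; the stand-in cfg only carries the one field LinearHead.__init__ reads, and in_feat=2048 assumes a ResNet-50 backbone.

from types import SimpleNamespace
import torch
from torch import nn
from fastreid.modeling.heads import LinearHead

cfg = SimpleNamespace(MODEL=SimpleNamespace(HEADS=SimpleNamespace(NUM_CLASSES=751)))

head = LinearHead(cfg, in_feat=2048, pool_layer=nn.AdaptiveAvgPool2d(1))
features = torch.randn(4, 2048, 16, 8)               # fake backbone output
targets = torch.randint(0, 751, (4,))

head.train()
logits, global_feat = head(features, targets)        # shapes (4, 751), (4, 2048)
head.eval()
embedding = head(features)                           # (4, 2048): pooled feature only
# During training the meta-arch is then expected to call
# LinearHead.losses(cfg, logits, global_feat, targets) to obtain the loss dict.
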
diff --git a/fastreid/modeling/losses/__init__.py b/fastreid/modeling/losses/__init__.py
index a0223aa..5f7e8b4 100644
--- a/fastreid/modeling/losses/__init__.py
+++ b/fastreid/modeling/losses/__init__.py
@@ -4,7 +4,5 @@
 @contact: sherlockliao01@gmail.com
 """
 
-from .build import build_criterion, LOSS_REGISTRY
-
 from .cross_entroy_loss import CrossEntropyLoss
 from .margin_loss import TripletLoss
diff --git a/fastreid/modeling/losses/build.py b/fastreid/modeling/losses/build.py
index f8124b0..b8a88d6 100644
--- a/fastreid/modeling/losses/build.py
+++ b/fastreid/modeling/losses/build.py
@@ -23,9 +23,9 @@ def build_criterion(cfg):
 
     loss_names = cfg.MODEL.LOSSES.NAME
     loss_funcs = [LOSS_REGISTRY.get(loss_name)(cfg) for loss_name in loss_names]
-    loss_dict = {}
 
     def criterion(*args):
+        loss_dict = {}
         for loss_func in loss_funcs:
             loss = loss_func(*args)
             loss_dict.update(loss)
diff --git a/fastreid/modeling/losses/cross_entroy_loss.py b/fastreid/modeling/losses/cross_entroy_loss.py
index 3a3d5ca..6b1bfa6 100644
--- a/fastreid/modeling/losses/cross_entroy_loss.py
+++ b/fastreid/modeling/losses/cross_entroy_loss.py
@@ -5,13 +5,10 @@
 """
 import torch
 import torch.nn.functional as F
-from torch import nn
 
-from .build import LOSS_REGISTRY
 from ...utils.events import get_event_storage
 
 
-@LOSS_REGISTRY.register()
 class CrossEntropyLoss(object):
     """
     A class that stores information and compute losses about outputs of a Baseline head.
@@ -43,7 +40,7 @@ class CrossEntropyLoss(object):
         storage = get_event_storage()
         storage.put_scalar("cls_accuracy", ret[0])
 
-    def __call__(self, pred_class_logits, pred_features, gt_classes):
+    def __call__(self, pred_class_logits, gt_classes):
         """
         Compute the softmax cross entropy loss for box classification.
         Returns:
@@ -59,5 +56,5 @@ class CrossEntropyLoss(object):
         else:
             loss = F.cross_entropy(pred_class_logits, gt_classes, reduction="mean")
         return {
-            "loss_cls": loss*self._scale,
+            "loss_cls": loss * self._scale,
         }
diff --git a/fastreid/modeling/losses/margin_loss.py b/fastreid/modeling/losses/margin_loss.py
index 3328782..b692caa 100644
--- a/fastreid/modeling/losses/margin_loss.py
+++ b/fastreid/modeling/losses/margin_loss.py
@@ -7,8 +7,6 @@
 import torch
 from torch import nn
 
-from .build import LOSS_REGISTRY
-
 
 def normalize(x, axis=-1):
     """Normalizing to unit length along the specified dimension.
@@ -102,7 +100,6 @@ def hard_example_mining(dist_mat, labels, return_inds=False):
     return dist_ap, dist_an
 
 
-@LOSS_REGISTRY.register()
 class TripletLoss(object):
     """Modified from Tong Xiao's open-reid (https://github.com/Cysu/open-reid).
     Related Triplet Loss theory can be found in paper 'In Defense of the Triplet
@@ -118,7 +115,7 @@ class TripletLoss(object):
         else:
             self.ranking_loss = nn.SoftMarginLoss()
 
-    def __call__(self, pred_class_logits, global_features, targets):
+    def __call__(self, global_features, targets):
         if self._normalize_feature:
             global_features = normalize(global_features, axis=-1)
 
diff --git a/fastreid/modeling/meta_arch/__init__.py b/fastreid/modeling/meta_arch/__init__.py
index 63c5d08..d27e25f 100644
--- a/fastreid/modeling/meta_arch/__init__.py
+++ b/fastreid/modeling/meta_arch/__init__.py
@@ -9,3 +9,8 @@ from .build import META_ARCH_REGISTRY, build_model
 
 # import all the meta_arch, so they will be registered
 from .baseline import Baseline
+from .bdb_network import BDB_net
+from .mf_network import MF_net
+from .abd_network import ABD_net
+from .mid_network import MidNetwork
+from .mgn import MGN
diff --git a/fastreid/modeling/meta_arch/abd_network.py b/fastreid/modeling/meta_arch/abd_network.py
new file mode 100644
index 0000000..921f43e
--- /dev/null
+++ b/fastreid/modeling/meta_arch/abd_network.py
@@ -0,0 +1,178 @@
+# encoding: utf-8
+"""
+@author:  l1aoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+import copy
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from .build import META_ARCH_REGISTRY
+from ..backbones import build_backbone
+from ..heads import build_reid_heads, BNneckHead
+from ..model_utils import weights_init_kaiming
+from ...layers import CAM_Module, PAM_Module, DANetHead, Flatten, bn_no_bias
+
+
+@META_ARCH_REGISTRY.register()
+class ABD_net(nn.Module):
+    def __init__(self, cfg):
+        super().__init__()
+        self._cfg = cfg
+        # backbone
+        backbone = build_backbone(cfg)
+        self.backbone1 = nn.Sequential(
+            backbone.conv1,
+            backbone.bn1,
+            backbone.relu,
+            backbone.maxpool,
+            backbone.layer1,
+        )
+        self.shallow_cam = CAM_Module(256)
+        self.backbone2 = nn.Sequential(
+            backbone.layer2,
+            backbone.layer3
+        )
+
+        # global branch
+        self.global_res4 = copy.deepcopy(backbone.layer4)
+        self.global_branch = nn.Sequential(
+            nn.AdaptiveAvgPool2d(1),
+            Flatten(),
+            # reduce
+            nn.Linear(2048, 1024, bias=False),
+            nn.BatchNorm1d(1024),
+            nn.ReLU(True),
+        )
+        self.global_branch.apply(weights_init_kaiming)
+
+        self.global_head = build_reid_heads(cfg, 1024, nn.Identity())
+
+        # attention branch
+        self.att_res4 = copy.deepcopy(backbone.layer4)
+        # reduce
+        self.att_reduce = nn.Sequential(
+            nn.Conv2d(2048, 1024, kernel_size=1, bias=False),
+            nn.BatchNorm2d(1024),
+            nn.ReLU(True),
+        )
+        self.att_reduce.apply(weights_init_kaiming)
+
+        self.abd_branch = ABDBranch(1024)
+        self.abd_branch.apply(weights_init_kaiming)
+
+        self.att_head = build_reid_heads(cfg, 1024, nn.Identity())
+
+    def forward(self, inputs):
+        images = inputs["images"]
+        targets = inputs["targets"]
+
+        if not self.training:
+            pred_feat = self.inference(images)
+            return pred_feat, targets, inputs["camid"]
+
+        feat = self.backbone1(images)
+        feat = self.shallow_cam(feat)
+        feat = self.backbone2(feat)
+
+        # global branch
+        global_feat = self.global_res4(feat)
+        global_feat = self.global_branch(global_feat)
+        global_logits, global_feat = self.global_head(global_feat, targets)
+
+        # attention branch
+        att_feat = self.att_res4(feat)
+        att_feat = self.att_reduce(att_feat)
+        att_feat = self.abd_branch(att_feat)
+        att_logits, att_feat = self.att_head(att_feat, targets)
+
+        return global_logits, global_feat, att_logits, att_feat, targets
+
+    def losses(self, outputs):
+        loss_dict = {}
+        loss_dict.update(self.global_head.losses(self._cfg, outputs[0], outputs[1], outputs[-1], 'global_'))
+        loss_dict.update(self.att_head.losses(self._cfg, outputs[2], outputs[3], outputs[-1], 'att_'))
+        return loss_dict
+
+    def inference(self, images):
+        assert not self.training
+        feat = self.backbone1(images)
+        feat = self.shallow_cam(feat)
+        feat = self.backbone2(feat)
+
+        # global branch
+        global_feat = self.global_res4(feat)
+        global_feat = self.global_branch(global_feat)
+        global_pred_feat = self.global_head(global_feat)
+
+        # attention branch
+        att_feat = self.att_res4(feat)
+        att_feat = self.att_reduce(att_feat)
+        att_feat = self.abd_branch(att_feat)
+        att_pred_feat = self.att_head(att_feat)
+
+        pred_feat = torch.cat([global_pred_feat, att_pred_feat], dim=1)
+        return F.normalize(pred_feat)
+
+
+class ABDBranch(nn.Module):
+
+    def __init__(self, input_dim):
+        super().__init__()
+        self.input_dim = input_dim
+        self.output_dim = 1024
+        self.part_num = 2
+        self.avg_pool = nn.Sequential(
+            nn.AdaptiveAvgPool2d(1),
+            Flatten())
+
+        self._init_attention_modules()
+
+    def _init_attention_modules(self):
+        self.before_module = DANetHead(self.output_dim, self.output_dim, nn.BatchNorm2d, nn.Identity)
+
+        self.cam_module = DANetHead(self.output_dim, self.output_dim, nn.BatchNorm2d, CAM_Module)
+
+        self.pam_module = DANetHead(self.output_dim, self.output_dim, nn.BatchNorm2d, PAM_Module)
+
+        self.sum_conv = nn.Sequential(
+            nn.Dropout2d(0.1, False),
+            nn.Conv2d(self.output_dim, self.output_dim, kernel_size=1)
+        )
+
+    def forward(self, x):
+        assert x.size(2) % self.part_num == 0, \
+            "Height {} is not a multiplication of {}. Aborted.".format(x.size(2), self.part_num)
+
+        before_x = self.before_module(x)
+        cam_x = self.cam_module(x)
+        pam_x = self.pam_module(x)
+        sum_x = before_x + cam_x + pam_x
+        att_feat = self.sum_conv(sum_x)
+        avg_feat = self.avg_pool(att_feat)
+        return avg_feat
+        # margin = x.size(2) // self.part_num
+        # for p in range(self.part_num):
+        #     x_sliced = x[:, :, margin * p:margin * (p + 1), :]
+        #
+        #     to_sum = []
+        #     # module_name: str
+        #     for module_name in self.dan_module_names:
+        #         x_out = getattr(self, module_name)(x_sliced)
+        #         to_sum.append(x_out)
+        #         fmap[module_name.partition('_')[0]].append(x_out)
+        #
+        #     fmap_after = self.sum_conv(sum(to_sum))
+        #     fmap['after'].append(fmap_after)
+        #
+        #     v = self.avgpool(fmap_after)
+        #     v = v.view(v.size(0), -1)
+        #     triplet.append(v)
+        #     predict.append(v)
+        #     v = self.classifiers[p](v)
+        #     xent.append(v)
+        #
+        # return predict, xent, triplet, fmap
diff --git a/fastreid/modeling/meta_arch/baseline.py b/fastreid/modeling/meta_arch/baseline.py
index 5a29ba8..881ec86 100644
--- a/fastreid/modeling/meta_arch/baseline.py
+++ b/fastreid/modeling/meta_arch/baseline.py
@@ -5,32 +5,51 @@
 """
 
 from torch import nn
+import torch.nn.functional as F
 
 from .build import META_ARCH_REGISTRY
 from ..backbones import build_backbone
 from ..heads import build_reid_heads
+from ...layers import GeneralizedMeanPoolingP
 
 
 @META_ARCH_REGISTRY.register()
 class Baseline(nn.Module):
     def __init__(self, cfg):
         super().__init__()
+        self._cfg = cfg
+        # backbone
         self.backbone = build_backbone(cfg)
-        self.heads = build_reid_heads(cfg)
+
+        # head
+        if cfg.MODEL.HEADS.POOL_LAYER == 'avgpool':
+            pool_layer = nn.AdaptiveAvgPool2d(1)
+        elif cfg.MODEL.HEADS.POOL_LAYER == 'maxpool':
+            pool_layer = nn.AdaptiveMaxPool2d(1)
+        elif cfg.MODEL.HEADS.POOL_LAYER == 'gempool':
+            pool_layer = GeneralizedMeanPoolingP()
+        else:
+            pool_layer = nn.Identity()
+        self.heads = build_reid_heads(cfg, 2048, pool_layer)
 
     def forward(self, inputs):
-        if not self.training:
-            return self.inference(inputs)
-
         images = inputs["images"]
         targets = inputs["targets"]
-        global_feat = self.backbone(images)  # (bs, 2048, 16, 8)
-        outputs = self.heads(global_feat, targets)
-        return outputs
 
-    def inference(self, inputs):
+        if not self.training:
+            pred_feat = self.inference(images)
+            return pred_feat, targets, inputs["camid"]
+
+        # training
+        features = self.backbone(images)  # (bs, 2048, 16, 8)
+        logits, global_feat = self.heads(features, targets)
+        return logits, global_feat, targets
+
+    def inference(self, images):
         assert not self.training
-        images = inputs["images"]
-        global_feat = self.backbone(images)
-        pred_features = self.heads(global_feat)
-        return pred_features, inputs["targets"], inputs["camid"]
+        features = self.backbone(images)  # (bs, 2048, 16, 8)
+        pred_feat = self.heads(features)
+        return F.normalize(pred_feat)
+
+    def losses(self, outputs):
+        return self.heads.losses(self._cfg, *outputs)
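The Baseline now picks its pooling layer from cfg.MODEL.HEADS.POOL_LAYER ('avgpool' | 'maxpool' | 'gempool'; anything else falls back to nn.Identity()). A hedged config sketch follows; get_cfg as the config entry point and MODEL.META_ARCHITECTURE as the selector key are assumptions, not confirmed by this patch.

# Sketch: selecting GeM pooling for the Baseline via the config object.
from fastreid.config import get_cfg        # assumed entry point returning the defaults

cfg = get_cfg()
cfg.MODEL.META_ARCHITECTURE = 'Baseline'   # assumed key consumed by build_model
cfg.MODEL.HEADS.NAME = 'BNneckHead'        # any registered head
cfg.MODEL.HEADS.POOL_LAYER = 'gempool'     # GeneralizedMeanPoolingP, per the branch above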
diff --git a/fastreid/modeling/meta_arch/bdb_network.py b/fastreid/modeling/meta_arch/bdb_network.py
new file mode 100644
index 0000000..b2ca4b8
--- /dev/null
+++ b/fastreid/modeling/meta_arch/bdb_network.py
@@ -0,0 +1,99 @@
+# encoding: utf-8
+"""
+@author:  liaoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from .build import META_ARCH_REGISTRY
+from ..backbones import build_backbone
+from ..backbones.resnet import Bottleneck
+from ..heads import build_reid_heads, BNneckHead
+from ..model_utils import weights_init_kaiming
+from ...layers import BatchDrop, bn_no_bias, Flatten, GeneralizedMeanPoolingP
+
+
+@META_ARCH_REGISTRY.register()
+class BDB_net(nn.Module):
+    def __init__(self, cfg):
+        super().__init__()
+        self._cfg = cfg
+        self.backbone = build_backbone(cfg)
+
+        # global branch
+        if cfg.MODEL.HEADS.POOL_LAYER == 'avgpool':
+            pool_layer = nn.AdaptiveAvgPool2d(1)
+        elif cfg.MODEL.HEADS.POOL_LAYER == 'maxpool':
+            pool_layer = nn.AdaptiveMaxPool2d(1)
+        elif cfg.MODEL.HEADS.POOL_LAYER == 'gempool':
+            pool_layer = GeneralizedMeanPoolingP()
+        else:
+            pool_layer = nn.Identity()
+
+        self.global_branch = nn.Sequential(
+            pool_layer,
+            Flatten(),
+            nn.Linear(2048, 512, bias=False),
+            nn.BatchNorm1d(512),
+            nn.ReLU(True),
+        )
+        self.global_head = build_reid_heads(cfg, 512, nn.Identity())
+
+        # part branch
+        self.part_branch = nn.Sequential(
+            Bottleneck(2048, 512),
+            BatchDrop(0.33, 1),
+            nn.AdaptiveMaxPool2d(1),
+            Flatten(),
+            nn.Linear(2048, 1024, bias=False),
+            nn.BatchNorm1d(1024),
+            nn.ReLU(True),
+        )
+        self.part_head = build_reid_heads(cfg, 1024, nn.Identity())
+
+        # initialize
+        self.global_branch.apply(weights_init_kaiming)
+        self.part_branch.apply(weights_init_kaiming)
+
+    def forward(self, inputs):
+        images = inputs["images"]
+        targets = inputs["targets"]
+
+        if not self.training:
+            pred_feat = self.inference(images)
+            return pred_feat, targets, inputs["camid"]
+
+        # training
+        features = self.backbone(images)
+        # global branch
+        global_feat = self.global_branch(features)
+        global_logits, global_feat = self.global_head(global_feat, targets)
+
+        # part branch
+        part_feat = self.part_branch(features)
+        part_logits, part_feat = self.part_head(part_feat, targets)
+
+        return global_logits, global_feat, part_logits, part_feat, targets
+
+    def inference(self, images):
+        assert not self.training
+        features = self.backbone(images)
+        # global branch
+        global_feat = self.global_branch(features)
+        global_bn_feat = self.global_head(global_feat)
+
+        # part branch
+        part_feat = self.part_branch(features)
+        part_bn_feat = self.part_head(part_feat)
+
+        pred_feat = torch.cat([global_bn_feat, part_bn_feat], dim=1)
+        return F.normalize(pred_feat)
+
+    def losses(self, outputs):
+        loss_dict = {}
+        loss_dict.update(self.global_head.losses(self._cfg, outputs[0], outputs[1], outputs[-1], 'global_'))
+        loss_dict.update(self.part_head.losses(self._cfg, outputs[2], outputs[3], outputs[-1], 'part_'))
+        return loss_dict
diff --git a/fastreid/modeling/meta_arch/bdnet.py b/fastreid/modeling/meta_arch/bdnet.py
deleted file mode 100644
index c5f3aaf..0000000
--- a/fastreid/modeling/meta_arch/bdnet.py
+++ /dev/null
@@ -1,124 +0,0 @@
-# encoding: utf-8
-"""
-@author:  liaoxingyu
-@contact: sherlockliao01@gmail.com
-"""
-
-import torch
-from torch import nn
-import torch.nn.functional as F
-
-from fastreid.modeling.backbones import *
-from fastreid.modeling.backbones.resnet import Bottleneck
-from fastreid.modeling.model_utils import *
-from fastreid.modeling.heads import *
-from fastreid.layers import BatchDrop
-
-
-class BDNet(nn.Module):
-    def __init__(self, 
-                 backbone, 
-                 num_classes, 
-                 last_stride, 
-                 with_ibn, 
-                 gcb, 
-                 stage_with_gcb, 
-                 pretrain=True, 
-                 model_path=''):
-        super().__init__()
-        self.num_classes = num_classes
-        if 'resnet' in backbone:
-            self.base = ResNet.from_name(backbone, last_stride, with_ibn, gcb, stage_with_gcb)
-            self.base.load_pretrain(model_path)
-            self.in_planes = 2048
-        elif 'osnet' in backbone:
-            if with_ibn:
-                self.base = osnet_ibn_x1_0(pretrained=pretrain)
-            else:
-                self.base = osnet_x1_0(pretrained=pretrain)
-            self.in_planes = 512
-        else:
-            print(f'not support {backbone} backbone')
-
-        # global branch
-        self.global_reduction = nn.Sequential(
-            nn.Conv2d(self.in_planes, 512, 1),
-            nn.BatchNorm2d(512),
-            nn.ReLU(True)
-        )
-
-        self.gap = nn.AdaptiveAvgPool2d(1)
-        self.global_bn = bn2d_no_bias(512)
-        self.global_classifier = nn.Linear(512, self.num_classes, bias=False)
-
-        # mask brach
-        self.part = Bottleneck(2048, 512)
-        self.batch_drop = BatchDrop(1.0, 0.33)
-        self.part_pool = nn.AdaptiveMaxPool2d(1)
-
-        self.part_reduction = nn.Sequential(
-            nn.Conv2d(self.in_planes, 1024, 1),
-            nn.BatchNorm2d(1024),
-            nn.ReLU(True)
-        )
-        self.part_bn = bn2d_no_bias(1024)
-        self.part_classifier = nn.Linear(1024, self.num_classes, bias=False)
-
-        # initialize 
-        self.part.apply(weights_init_kaiming)
-        self.global_reduction.apply(weights_init_kaiming)
-        self.part_reduction.apply(weights_init_kaiming)
-        self.global_classifier.apply(weights_init_classifier)
-        self.part_classifier.apply(weights_init_classifier)
-
-    def forward(self, x, label=None):
-        # feature extractor
-        feat = self.base(x)
-
-        # global branch
-        g_feat = self.global_reduction(feat)
-        g_feat = self.gap(g_feat)  # (bs, 512, 1, 1)
-        g_bn_feat = self.global_bn(g_feat)  # (bs, 512, 1, 1)
-        g_bn_feat = g_bn_feat.view(-1, g_bn_feat.shape[1])  # (bs, 512)
-
-        # mask branch
-        p_feat = self.part(feat)
-        p_feat = self.batch_drop(p_feat)
-        p_feat = self.part_pool(p_feat)  # (bs, 512, 1, 1)
-        p_feat = self.part_reduction(p_feat)
-        p_bn_feat = self.part_bn(p_feat)
-        p_bn_feat = p_bn_feat.view(-1, p_bn_feat.shape[1])  # (bs, 512)
-
-        if self.training:
-            global_cls = self.global_classifier(g_bn_feat)
-            part_cls = self.part_classifier(p_bn_feat)
-            return global_cls, part_cls, g_feat.view(-1, g_feat.shape[1]), p_feat.view(-1, p_feat.shape[1])
-
-        return torch.cat([g_bn_feat, p_bn_feat], dim=1)
-
-    def load_params_wo_fc(self, state_dict):
-        state_dict.pop('global_classifier.weight')
-        state_dict.pop('part_classifier.weight')
-
-        res = self.load_state_dict(state_dict, strict=False)
-        print(f'missing keys {res.missing_keys}')
-        # assert str(res.missing_keys) == str(['classifier.weight',]), 'issue loading pretrained weights'
-
-    def unfreeze_all_layers(self,):
-        self.train()
-        for p in self.parameters():
-            p.requires_grad = True
-
-    def unfreeze_specific_layer(self, names):
-        if isinstance(names, str):
-            names = [names]
-        
-        for name, module in self.named_children():
-            if name in names:
-                module.train()
-                for p in module.parameters():
-                    p.requires_grad = True
-            else:
-                module.eval()
-                for p in module.parameters():
-                    p.requires_grad = False
diff --git a/fastreid/modeling/meta_arch/mf_network.py b/fastreid/modeling/meta_arch/mf_network.py
new file mode 100644
index 0000000..6989dad
--- /dev/null
+++ b/fastreid/modeling/meta_arch/mf_network.py
@@ -0,0 +1,139 @@
+# encoding: utf-8
+"""
+@author:  liaoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from .build import META_ARCH_REGISTRY
+from ..model_utils import weights_init_kaiming
+from ..backbones import build_backbone
+from ..heads import build_reid_heads, BNneckHead
+from ...layers import Flatten, bn_no_bias
+
+
+@META_ARCH_REGISTRY.register()
+class MF_net(nn.Module):
+    def __init__(self, cfg):
+        super().__init__()
+        self._cfg = cfg
+        # backbone
+        backbone = build_backbone(cfg)
+        self.backbone = nn.Sequential(
+            backbone.conv1,
+            backbone.bn1,
+            backbone.relu,
+            backbone.maxpool,
+            backbone.layer1,
+            backbone.layer2,
+            backbone.layer3
+        )
+        # body
+        self.res4 = backbone.layer4
+        self.avgpool = nn.AdaptiveAvgPool2d(1)
+        self.maxpool = nn.AdaptiveMaxPool2d(1)
+        self.avgpool_2 = nn.AdaptiveAvgPool2d((2, 2))
+        self.maxpool_2 = nn.AdaptiveMaxPool2d((2, 2))
+        # branch 1
+        self.branch_1 = nn.Sequential(
+            Flatten(),
+            nn.BatchNorm1d(2048),
+            nn.LeakyReLU(0.1, True),
+            nn.Linear(2048, 512, bias=False),
+        )
+        self.branch_1.apply(weights_init_kaiming)
+        self.head1 = build_reid_heads(cfg, 512, nn.Identity())
+
+        # branch 2
+        self.branch_2 = nn.Sequential(
+            Flatten(),
+            nn.BatchNorm1d(8192),
+            nn.LeakyReLU(0.1, True),
+            nn.Linear(8192, 512, bias=False),
+        )
+        self.branch_2.apply(weights_init_kaiming)
+        self.head2 = build_reid_heads(cfg, 512, nn.Identity())
+        # branch 3
+        self.branch_3 = nn.Sequential(
+            Flatten(),
+            nn.BatchNorm1d(1024),
+            nn.LeakyReLU(0.1, True),
+            nn.Linear(1024, 512, bias=False),
+        )
+        self.branch_3.apply(weights_init_kaiming)
+        self.head3 = build_reid_heads(cfg, 512, nn.Identity())
+
+    def forward(self, inputs):
+        images = inputs["images"]
+        targets = inputs["targets"]
+
+        if not self.training:
+            pred_feat = self.inference(images)
+            return pred_feat, targets, inputs["camid"]
+
+        mid_feat = self.backbone(images)
+        feat = self.res4(mid_feat)
+
+        # branch 1
+        avg_feat1 = self.avgpool(feat)
+        max_feat1 = self.maxpool(feat)
+        feat1 = avg_feat1 + max_feat1
+        feat1 = self.branch_1(feat1)
+        logits_1, feat1 = self.head1(feat1, targets)
+        # branch 2
+        avg_feat2 = self.avgpool_2(feat)
+        max_feat2 = self.maxpool_2(feat)
+        feat2 = avg_feat2 + max_feat2
+        feat2 = self.branch_2(feat2)
+        logits_2, feat2 = self.head2(feat2, targets)
+        # branch 3
+        avg_feat3 = self.avgpool(mid_feat)
+        max_feat3 = self.maxpool(mid_feat)
+        feat3 = avg_feat3 + max_feat3
+        feat3 = self.branch_3(feat3)
+        logits_3, feat3 = self.head3(feat3, targets)
+
+        return logits_1, logits_2, logits_3, \
+               Flatten()(avg_feat1), Flatten()(avg_feat2), Flatten()(avg_feat3),\
+               Flatten()(max_feat1), Flatten()(max_feat2), Flatten()(max_feat3), targets
+
+    def inference(self, images):
+        assert not self.training
+
+        mid_feat = self.backbone(images)
+        feat = self.res4(mid_feat)
+
+        # branch 1
+        avg_feat1 = self.avgpool(feat)
+        max_feat1 = self.maxpool(feat)
+        feat1 = avg_feat1 + max_feat1
+        feat1 = self.branch_1(feat1)
+        pred_feat1 = self.head1(feat1)
+        # branch 2
+        avg_feat2 = self.avgpool_2(feat)
+        max_feat2 = self.maxpool_2(feat)
+        feat2 = avg_feat2 + max_feat2
+        feat2 = self.branch_2(feat2)
+        pred_feat2 = self.head2(feat2)
+        # branch 3
+        avg_feat3 = self.avgpool(mid_feat)
+        max_feat3 = self.maxpool(mid_feat)
+        feat3 = avg_feat3 + max_feat3
+        feat3 = self.branch_3(feat3)
+        pred_feat3 = self.head3(feat3)
+
+        pred_feat = torch.cat([pred_feat1, pred_feat2, pred_feat3], dim=1)
+        return F.normalize(pred_feat)
+
+    def losses(self, outputs):
+        loss_dict = {}
+        loss_dict.update(self.head1.losses(self._cfg, outputs[0], outputs[3], outputs[-1], 'b1_'))
+        loss_dict.update(self.head2.losses(self._cfg, outputs[1], outputs[4], outputs[-1], 'b2_'))
+        loss_dict.update(self.head3.losses(self._cfg, outputs[2], outputs[5], outputs[-1], 'b3_'))
+        loss_dict.update(self.head1.losses(self._cfg, None, outputs[6], outputs[-1], 'mp1_'))
+        loss_dict.update(self.head2.losses(self._cfg, None, outputs[7], outputs[-1], 'mp2_'))
+        loss_dict.update(self.head3.losses(self._cfg, None, outputs[8], outputs[-1], 'mp3_'))
+        return loss_dict
diff --git a/fastreid/modeling/meta_arch/mgn.py b/fastreid/modeling/meta_arch/mgn.py
index 502c411..5dd916c 100644
--- a/fastreid/modeling/meta_arch/mgn.py
+++ b/fastreid/modeling/meta_arch/mgn.py
@@ -6,148 +6,202 @@
 import copy
 
 import torch
+import torch.nn.functional as F
 from torch import nn
 
-from fastreid.modeling.backbones import ResNet, Bottleneck
-from fastreid.modeling.model_utils import *
+from .build import META_ARCH_REGISTRY
+from ..backbones import build_backbone
+from ..backbones.resnet import Bottleneck
+from ..heads import build_reid_heads
+from ..model_utils import weights_init_kaiming
+from ...layers import GeneralizedMeanPoolingP, Flatten
 
 
+@META_ARCH_REGISTRY.register()
 class MGN(nn.Module):
-    in_planes = 2048
-    feats = 256
-
-    def __init__(self,
-                 backbone,
-                 num_classes,
-                 last_stride,
-                 with_ibn,
-                 gcb,
-                 stage_with_gcb,
-                 pretrain=True,
-                 model_path=''):
+    def __init__(self, cfg):
         super().__init__()
-        try:
-            base_module = ResNet.from_name(backbone, last_stride, with_ibn, gcb, stage_with_gcb)
-        except:
-            print(f'not support {backbone} backbone')
-
-        if pretrain:
-            base_module.load_pretrain(model_path)
-
-        self.num_classes = num_classes
-
+        self._cfg = cfg
+        # backbone
+        backbone = build_backbone(cfg)
         self.backbone = nn.Sequential(
-            base_module.conv1,
-            base_module.bn1,
-            base_module.relu,
-            base_module.maxpool,
-            base_module.layer1,
-            base_module.layer2,
-            base_module.layer3[0]
+            backbone.conv1,
+            backbone.bn1,
+            backbone.relu,
+            backbone.maxpool,
+            backbone.layer1,
+            backbone.layer2,
+            backbone.layer3[0]
         )
-        
-        res_conv4 = nn.Sequential(*base_module.layer3[1:])
-        
-        res_g_conv5 = base_module.layer4
-        
+        res_conv4 = nn.Sequential(*backbone.layer3[1:])
+        res_g_conv5 = backbone.layer4
+
         res_p_conv5 = nn.Sequential(
-            Bottleneck(1024, 512, downsample=nn.Sequential(nn.Conv2d(1024, 2048, 1, bias=False),
-                                                           nn.BatchNorm2d(2048))),
+            Bottleneck(1024, 512, downsample=nn.Sequential(
+                nn.Conv2d(1024, 2048, 1, bias=False), nn.BatchNorm2d(2048))),
             Bottleneck(2048, 512),
-            Bottleneck(2048, 512)
-        )
-        res_p_conv5.load_state_dict(base_module.layer4.state_dict())
+            Bottleneck(2048, 512))
+        res_p_conv5.load_state_dict(backbone.layer4.state_dict())
 
-        self.p1 = nn.Sequential(copy.deepcopy(res_conv4), copy.deepcopy(res_g_conv5))
-        self.p2 = nn.Sequential(copy.deepcopy(res_conv4), copy.deepcopy(res_p_conv5))
-        self.p3 = nn.Sequential(copy.deepcopy(res_conv4), copy.deepcopy(res_p_conv5))
-
-        self.avgpool = nn.AdaptiveAvgPool2d(1)
-        self.maxpool_zp2 = nn.MaxPool2d((12, 9))
-        self.maxpool_zp3 = nn.MaxPool2d((8, 9))
-
-        self.reduction = nn.Conv2d(2048, self.feats, 1, bias=False)
-        self.bn_neck = BN_no_bias(self.feats)
-        # self.bn_neck_2048_0 = BN_no_bias(self.feats)
-        # self.bn_neck_2048_1 = BN_no_bias(self.feats)
-        # self.bn_neck_2048_2 = BN_no_bias(self.feats)
-        # self.bn_neck_256_1_0 = BN_no_bias(self.feats)
-        # self.bn_neck_256_1_1 = BN_no_bias(self.feats)
-        # self.bn_neck_256_2_0 = BN_no_bias(self.feats)
-        # self.bn_neck_256_2_1 = BN_no_bias(self.feats)
-        # self.bn_neck_256_2_2 = BN_no_bias(self.feats)
-
-        self.fc_id_2048_0 = nn.Linear(self.feats, self.num_classes, bias=False)
-        self.fc_id_2048_1 = nn.Linear(self.feats, self.num_classes, bias=False)
-        self.fc_id_2048_2 = nn.Linear(self.feats, self.num_classes, bias=False)
-
-        self.fc_id_256_1_0 = nn.Linear(self.feats, self.num_classes, bias=False)
-        self.fc_id_256_1_1 = nn.Linear(self.feats, self.num_classes, bias=False)
-        self.fc_id_256_2_0 = nn.Linear(self.feats, self.num_classes, bias=False)
-        self.fc_id_256_2_1 = nn.Linear(self.feats, self.num_classes, bias=False)
-        self.fc_id_256_2_2 = nn.Linear(self.feats, self.num_classes, bias=False)
-
-        self.fc_id_2048_0.apply(weights_init_classifier)
-        self.fc_id_2048_1.apply(weights_init_classifier)
-        self.fc_id_2048_2.apply(weights_init_classifier)
-        self.fc_id_256_1_0.apply(weights_init_classifier)
-        self.fc_id_256_1_1.apply(weights_init_classifier)
-        self.fc_id_256_2_0.apply(weights_init_classifier)
-        self.fc_id_256_2_1.apply(weights_init_classifier)
-        self.fc_id_256_2_2.apply(weights_init_classifier)
-
-    def forward(self, x, label=None):
-        global_feat = self.backbone(x)
-
-        p1 = self.p1(global_feat)  # (bs, 2048, 18, 9)
-        p2 = self.p2(global_feat)  # (bs, 2048, 18, 9)
-        p3 = self.p3(global_feat)  # (bs, 2048, 18, 9)
-
-        zg_p1 = self.avgpool(p1)  # (bs, 2048, 1, 1)
-        zg_p2 = self.avgpool(p2)  # (bs, 2048, 1, 1)
-        zg_p3 = self.avgpool(p3)  # (bs, 2048, 1, 1)
-
-        zp2 = self.maxpool_zp2(p2)
-        z0_p2 = zp2[:, :, 0:1, :]
-        z1_p2 = zp2[:, :, 1:2, :]
-
-        zp3 = self.maxpool_zp3(p3)
-        z0_p3 = zp3[:, :, 0:1, :]
-        z1_p3 = zp3[:, :, 1:2, :]
-        z2_p3 = zp3[:, :, 2:3, :]
-
-        g_p1 = zg_p1.squeeze(3).squeeze(2)  # (bs, 2048)
-        fg_p1 = self.reduction(zg_p1).squeeze(3).squeeze(2)
-        bn_fg_p1 = self.bn_neck(fg_p1)
-        g_p2 = zg_p2.squeeze(3).squeeze(2)
-        fg_p2 = self.reduction(zg_p2).squeeze(3).squeeze(2)  # (bs, 256)
-        bn_fg_p2 = self.bn_neck(fg_p2)
-        g_p3 = zg_p3.squeeze(3).squeeze(2)
-        fg_p3 = self.reduction(zg_p3).squeeze(3).squeeze(2)
-        bn_fg_p3 = self.bn_neck(fg_p3)
-
-        f0_p2 = self.bn_neck(self.reduction(z0_p2).squeeze(3).squeeze(2))
-        f1_p2 = self.bn_neck(self.reduction(z1_p2).squeeze(3).squeeze(2))
-        f0_p3 = self.bn_neck(self.reduction(z0_p3).squeeze(3).squeeze(2))
-        f1_p3 = self.bn_neck(self.reduction(z1_p3).squeeze(3).squeeze(2))
-        f2_p3 = self.bn_neck(self.reduction(z2_p3).squeeze(3).squeeze(2))
-
-        if self.training:
-            l_p1 = self.fc_id_2048_0(bn_fg_p1)
-            l_p2 = self.fc_id_2048_1(bn_fg_p2)
-            l_p3 = self.fc_id_2048_2(bn_fg_p3)
-
-            l0_p2 = self.fc_id_256_1_0(f0_p2)
-            l1_p2 = self.fc_id_256_1_1(f1_p2)
-            l0_p3 = self.fc_id_256_2_0(f0_p3)
-            l1_p3 = self.fc_id_256_2_1(f1_p3)
-            l2_p3 = self.fc_id_256_2_2(f2_p3)
-            return g_p1, g_p2, g_p3, l_p1, l_p2, l_p3, l0_p2, l1_p2, l0_p3, l1_p3, l2_p3
-            # return g_p2, l_p2, l0_p2, l1_p2, l0_p3, l1_p3, l2_p3
+        if cfg.MODEL.HEADS.POOL_LAYER == 'avgpool':
+            pool_layer = nn.AdaptiveAvgPool2d(1)
+        elif cfg.MODEL.HEADS.POOL_LAYER == 'maxpool':
+            pool_layer = nn.AdaptiveMaxPool2d(1)
+        elif cfg.MODEL.HEADS.POOL_LAYER == 'gempool':
+            pool_layer = GeneralizedMeanPoolingP()
         else:
-            return torch.cat([bn_fg_p1, bn_fg_p2, bn_fg_p3, f0_p2, f1_p2, f0_p3, f1_p3, f2_p3], dim=1)
+            pool_layer = nn.Identity()
 
-    def load_params_wo_fc(self, state_dict):
-        # state_dict.pop('classifier.weight')
-        res = self.load_state_dict(state_dict, strict=False)
-        assert str(res.missing_keys) == str(['classifier.weight',]), 'issue loading pretrained weights'
+        # branch1
+        self.b1 = nn.Sequential(
+            copy.deepcopy(res_conv4), copy.deepcopy(res_g_conv5)
+        )
+        self.b1_pool = self._build_pool_reduce(pool_layer)
+        self.b1_head = build_reid_heads(cfg, 256, nn.Identity())
+
+        # branch2
+        self.b2 = nn.Sequential(
+            copy.deepcopy(res_conv4), copy.deepcopy(res_p_conv5)
+        )
+        self.b2_pool = self._build_pool_reduce(pool_layer)
+        self.b2_head = build_reid_heads(cfg, 256, nn.Identity())
+
+        self.b21_pool = self._build_pool_reduce(pool_layer)
+        self.b21_head = build_reid_heads(cfg, 256, nn.Identity())
+
+        self.b22_pool = self._build_pool_reduce(pool_layer)
+        self.b22_head = build_reid_heads(cfg, 256, nn.Identity())
+
+        # branch3
+        self.b3 = nn.Sequential(
+            copy.deepcopy(res_conv4), copy.deepcopy(res_p_conv5)
+        )
+        self.b3_pool = self._build_pool_reduce(pool_layer)
+        self.b3_head = build_reid_heads(cfg, 256, nn.Identity())
+
+        self.b31_pool = self._build_pool_reduce(pool_layer)
+        self.b31_head = build_reid_heads(cfg, 256, nn.Identity())
+
+        self.b32_pool = self._build_pool_reduce(pool_layer)
+        self.b32_head = build_reid_heads(cfg, 256, nn.Identity())
+
+        self.b33_pool = self._build_pool_reduce(pool_layer)
+        self.b33_head = build_reid_heads(cfg, 256, nn.Identity())
+
+    def _build_pool_reduce(self, pool_layer, input_dim=2048, reduce_dim=256):
+        pool_reduce = nn.Sequential(
+            pool_layer,
+            nn.Conv2d(input_dim, reduce_dim, 1, bias=False),
+            nn.BatchNorm2d(reduce_dim),
+            nn.ReLU(True),
+            Flatten()
+        )
+        pool_reduce.apply(weights_init_kaiming)
+        return pool_reduce
+
+    def forward(self, inputs):
+        images = inputs["images"]
+        targets = inputs["targets"]
+
+        if not self.training:
+            pred_feat = self.inference(images)
+            return pred_feat, targets, inputs["camid"]
+
+        features = self.backbone(images)  # (bs, 2048, 16, 8)
+
+        # branch1
+        b1_feat = self.b1(features)
+        b1_pool_feat = self.b1_pool(b1_feat)
+        b1_logits, b1_pool_feat = self.b1_head(b1_pool_feat, targets)
+
+        # branch2
+        b2_feat = self.b2(features)
+        # global
+        b2_pool_feat = self.b2_pool(b2_feat)
+        b2_logits, b2_pool_feat = self.b2_head(b2_pool_feat, targets)
+
+        b21_feat, b22_feat = torch.chunk(b2_feat, 2, dim=2)
+        # part1
+        b21_pool_feat = self.b21_pool(b21_feat)
+        b21_logits, b21_pool_feat = self.b21_head(b21_pool_feat, targets)
+        # part2
+        b22_pool_feat = self.b22_pool(b22_feat)
+        b22_logits, b22_pool_feat = self.b22_head(b22_pool_feat, targets)
+
+        # branch3
+        b3_feat = self.b3(features)
+        # global
+        b3_pool_feat = self.b3_pool(b3_feat)
+        b3_logits, b3_pool_feat = self.b3_head(b3_pool_feat, targets)
+
+        b31_feat, b32_feat, b33_feat = torch.chunk(b3_feat, 3, dim=2)
+        # part1
+        b31_pool_feat = self.b31_pool(b31_feat)
+        b31_logits, b31_pool_feat = self.b31_head(b31_pool_feat, targets)
+        # part2
+        b32_pool_feat = self.b32_pool(b32_feat)
+        b32_logits, b32_pool_feat = self.b32_head(b32_pool_feat, targets)
+        # part3
+        b33_pool_feat = self.b33_pool(b33_feat)
+        b33_logits, b33_pool_feat = self.b33_head(b33_pool_feat, targets)
+
+        return (b1_logits, b2_logits, b3_logits, b21_logits, b22_logits, b31_logits, b32_logits, b33_logits), \
+               (b1_pool_feat, b2_pool_feat, b3_pool_feat), \
+               targets
+
+    def inference(self, images):
+        assert not self.training
+        features = self.backbone(images)  # (bs, 2048, 16, 8)
+
+        # branch1
+        b1_feat = self.b1(features)
+        b1_pool_feat = self.b1_pool(b1_feat)
+        b1_pool_feat = self.b1_head(b1_pool_feat)
+
+        # branch2
+        b2_feat = self.b2(features)
+        # global
+        b2_pool_feat = self.b2_pool(b2_feat)
+        b2_pool_feat = self.b2_head(b2_pool_feat)
+
+        b21_feat, b22_feat = torch.chunk(b2_feat, 2, dim=2)
+        # part1
+        b21_pool_feat = self.b21_pool(b21_feat)
+        b21_pool_feat = self.b21_head(b21_pool_feat)
+        # part2
+        b22_pool_feat = self.b22_pool(b22_feat)
+        b22_pool_feat = self.b22_head(b22_pool_feat)
+
+        # branch3
+        b3_feat = self.b3(features)
+        # global
+        b3_pool_feat = self.b3_pool(b3_feat)
+        b3_pool_feat = self.b3_head(b3_pool_feat)
+
+        b31_feat, b32_feat, b33_feat = torch.chunk(b3_feat, 3, dim=2)
+        # part1
+        b31_pool_feat = self.b31_pool(b31_feat)
+        b31_pool_feat = self.b31_head(b31_pool_feat)
+        # part2
+        b32_pool_feat = self.b32_pool(b32_feat)
+        b32_pool_feat = self.b32_head(b32_pool_feat)
+        # part3
+        b33_pool_feat = self.b33_pool(b33_feat)
+        b33_pool_feat = self.b33_head(b33_pool_feat)
+
+        pred_feat = torch.cat([b1_pool_feat, b2_pool_feat, b3_pool_feat, b21_pool_feat,
+                               b22_pool_feat, b31_pool_feat, b32_pool_feat, b33_pool_feat], dim=1)
+
+        return F.normalize(pred_feat)
+
+    def losses(self, outputs):
+        loss_dict = {}
+        loss_dict.update(self.b1_head.losses(self._cfg, outputs[0][0], outputs[1][0], outputs[2], 'b1_'))
+        loss_dict.update(self.b2_head.losses(self._cfg, outputs[0][1], outputs[1][1], outputs[2], 'b2_'))
+        loss_dict.update(self.b3_head.losses(self._cfg, outputs[0][2], outputs[1][2], outputs[2], 'b3_'))
+        loss_dict.update(self.b2_head.losses(self._cfg, outputs[0][3], None, outputs[2], 'b21_'))
+        loss_dict.update(self.b2_head.losses(self._cfg, outputs[0][4], None, outputs[2], 'b22_'))
+        loss_dict.update(self.b3_head.losses(self._cfg, outputs[0][5], None, outputs[2], 'b31_'))
+        loss_dict.update(self.b3_head.losses(self._cfg, outputs[0][6], None, outputs[2], 'b32_'))
+        loss_dict.update(self.b3_head.losses(self._cfg, outputs[0][7], None, outputs[2], 'b33_'))
+        return loss_dict
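At inference the eight branch embeddings above are concatenated, so with the default 256-dim pool-reduce blocks the final MGN descriptor is 8 x 256 = 2048-dim. A trivial bookkeeping check:

# Sketch: descriptor size implied by _build_pool_reduce(reduce_dim=256).
branch_feats = ['b1', 'b2', 'b3', 'b21', 'b22', 'b31', 'b32', 'b33']
reduce_dim = 256
assert len(branch_feats) * reduce_dim == 2048   # size of the concatenated feature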
diff --git a/fastreid/modeling/meta_arch/mid_network.py b/fastreid/modeling/meta_arch/mid_network.py
new file mode 100644
index 0000000..ebe03b3
--- /dev/null
+++ b/fastreid/modeling/meta_arch/mid_network.py
@@ -0,0 +1,99 @@
+# encoding: utf-8
+"""
+@author:  l1aoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from .build import META_ARCH_REGISTRY
+from ..backbones import build_backbone
+from ..heads import build_reid_heads
+from ..model_utils import weights_init_kaiming
+from ...layers import Flatten, bn_no_bias
+
+
+@META_ARCH_REGISTRY.register()
+class MidNetwork(nn.Module):
+    """Residual network + mid-level features.
+
+    Reference:
+        Yu et al. The Devil is in the Middle: Exploiting Mid-level Representations for
+        Cross-Domain Instance Matching. arXiv:1711.08106.
+    Public keys:
+        - ``resnet50mid``: ResNet50 + mid-level feature fusion.
+    """
+
+    def __init__(self, cfg):
+        super().__init__()
+        self._cfg = cfg
+        # backbone
+        backbone = build_backbone(cfg)
+        self.backbone = nn.Sequential(
+            backbone.conv1,
+            backbone.bn1,
+            backbone.relu,
+            backbone.maxpool,
+            backbone.layer1,
+            backbone.layer2,
+            backbone.layer3
+        )
+        # body
+        self.res4 = backbone.layer4
+        self.avg_pool = nn.Sequential(
+            nn.AdaptiveAvgPool2d(1),
+            Flatten(),
+        )
+        self.fusion = nn.Sequential(
+            nn.Linear(4096, 1024, bias=False),
+            nn.BatchNorm1d(1024),
+            nn.ReLU(True)
+        )
+        self.fusion.apply(weights_init_kaiming)
+
+        # head
+        self.head = build_reid_heads(cfg, 3072, nn.Identity())
+
+    def forward(self, inputs):
+        images = inputs['images']
+        targets = inputs['targets']
+
+        if not self.training:
+            pred_feat = self.inference(images)
+            return pred_feat, targets, inputs['camid']
+
+        feat = self.backbone(images)
+        feat_4a = self.res4[0](feat)
+        feat_4b = self.res4[1](feat_4a)
+        feat_4c = self.res4[2](feat_4b)
+
+        feat_4a = self.avg_pool(feat_4a)
+        feat_4b = self.avg_pool(feat_4b)
+        feat_4c = self.avg_pool(feat_4c)
+        feat_4ab = torch.cat([feat_4a, feat_4b], dim=1)
+        feat_4ab = self.fusion(feat_4ab)
+        feat = torch.cat([feat_4ab, feat_4c], 1)
+
+        logits, feat = self.head(feat, targets)
+        return logits, feat, targets
+
+    def losses(self, outputs):
+        return self.head.losses(self._cfg, outputs[0], outputs[1], outputs[2])
+
+    def inference(self, images):
+        assert not self.training
+        feat = self.backbone(images)
+        feat_4a = self.res4[0](feat)
+        feat_4b = self.res4[1](feat_4a)
+        feat_4c = self.res4[2](feat_4b)
+
+        feat_4a = self.avg_pool(feat_4a)
+        feat_4b = self.avg_pool(feat_4b)
+        feat_4c = self.avg_pool(feat_4c)
+        feat_4ab = torch.cat([feat_4a, feat_4b], dim=1)
+        feat_4ab = self.fusion(feat_4ab)
+        feat = torch.cat([feat_4ab, feat_4c], 1)
+        pred_feat = self.head(feat)
+        return F.normalize(pred_feat)
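The in_feat of 3072 passed to build_reid_heads follows from the fusion above: the pooled layer4 outputs 4a and 4b (2048-dim each) are concatenated and reduced by the 4096-to-1024 fusion block, then concatenated with the pooled 4c feature (2048-dim), giving 1024 + 2048 = 3072. A quick check:

# Sketch: dimensionality bookkeeping for the mid-level fusion.
dim_4a = dim_4b = dim_4c = 2048
fusion_in, fusion_out = dim_4a + dim_4b, 1024   # nn.Linear(4096, 1024)
assert fusion_in == 4096
assert fusion_out + dim_4c == 3072              # head in_feat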
diff --git a/fastreid/modeling/model_utils.py b/fastreid/modeling/model_utils.py
index 2405a04..6e68791 100644
--- a/fastreid/modeling/model_utils.py
+++ b/fastreid/modeling/model_utils.py
@@ -11,16 +11,16 @@ __all__ = ['weights_init_classifier', 'weights_init_kaiming', ]
 def weights_init_kaiming(m):
     classname = m.__class__.__name__
     if classname.find('Linear') != -1:
-        nn.init.kaiming_normal_(m.weight, a=0, mode='fan_out')
+        nn.init.normal_(m.weight, 0, 0.01)
         if m.bias is not None:
             nn.init.constant_(m.bias, 0.0)
     elif classname.find('Conv') != -1:
-        nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in')
+        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
         if m.bias is not None:
             nn.init.constant_(m.bias, 0.0)
     elif classname.find('BatchNorm') != -1:
         if m.affine:
-            nn.init.constant_(m.weight, 1.0)
+            nn.init.normal_(m.weight, 1.0, 0.02)
             nn.init.constant_(m.bias, 0.0)
 
 
diff --git a/fastreid/solver/build.py b/fastreid/solver/build.py
index 66cd22f..a696f53 100644
--- a/fastreid/solver/build.py
+++ b/fastreid/solver/build.py
@@ -4,9 +4,8 @@
 @contact: sherlockliao01@gmail.com
 """
 
-
-import torch
-from .lr_scheduler import WarmupMultiStepLR
+from . import lr_scheduler
+from . import optim
 
 
 def build_optimizer(cfg, model):
@@ -16,29 +15,39 @@ def build_optimizer(cfg, model):
             continue
         lr = cfg.SOLVER.BASE_LR
         weight_decay = cfg.SOLVER.WEIGHT_DECAY
-        # if "base" in key:
-        #     lr = cfg.SOLVER.BASE_LR * 0.1
+        # if "heads" in key:
+        #     lr = cfg.SOLVER.BASE_LR * 10
         if "bias" in key:
-            lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BIAS_LR_FACTOR
+            lr = lr * cfg.SOLVER.BIAS_LR_FACTOR
             weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS
         params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}]
-    if cfg.SOLVER.OPT == 'sgd':
-        opt_fns = torch.optim.SGD(params, momentum=cfg.SOLVER.MOMENTUM)
-    elif cfg.SOLVER.OPT == 'adam':
-        opt_fns = torch.optim.Adam(params)
-    elif cfg.SOLVER.OPT == 'adamw':
-        opt_fns = torch.optim.AdamW(params)
+    solver_opt = cfg.SOLVER.OPT
+    if hasattr(optim, solver_opt):
+        if solver_opt == "SGD":
+            opt_fns = getattr(optim, solver_opt)(params, momentum=cfg.SOLVER.MOMENTUM)
+        else:
+            opt_fns = getattr(optim, solver_opt)(params)
     else:
-        raise NameError(f'optimizer {cfg.SOLVER.OPT} not support')
+        raise NameError("optimizer {} not support".format(cfg.SOLVER.OPT))
     return opt_fns
 
 
 def build_lr_scheduler(cfg, optimizer):
-    return WarmupMultiStepLR(
-        optimizer,
-        cfg.SOLVER.STEPS,
-        cfg.SOLVER.GAMMA,
-        warmup_factor=cfg.SOLVER.WARMUP_FACTOR,
-        warmup_iters=cfg.SOLVER.WARMUP_ITERS,
-        warmup_method=cfg.SOLVER.WARMUP_METHOD
-    )
+    if cfg.SOLVER.SCHED == "warmup":
+        return lr_scheduler.WarmupMultiStepLR(
+            optimizer,
+            cfg.SOLVER.STEPS,
+            cfg.SOLVER.GAMMA,
+            warmup_factor=cfg.SOLVER.WARMUP_FACTOR,
+            warmup_iters=cfg.SOLVER.WARMUP_ITERS,
+            warmup_method=cfg.SOLVER.WARMUP_METHOD
+        )
+    elif cfg.SOLVER.SCHED == "delay":
+        return lr_scheduler.DelayedCosineAnnealingLR(
+            optimizer,
+            cfg.SOLVER.DELAY_ITERS,
+            cfg.SOLVER.COS_ANNEAL_ITERS,
+            warmup_factor=cfg.SOLVER.WARMUP_FACTOR,
+            warmup_iters=cfg.SOLVER.WARMUP_ITERS,
+            warmup_method=cfg.SOLVER.WARMUP_METHOD
+        )
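With this change the optimizer is looked up by name on fastreid.solver.optim (which re-exports torch.optim plus the extra optimizers added below), and cfg.SOLVER.SCHED selects between the warmup multi-step and the delayed cosine schedule. A hedged sketch of the two knobs; cfg and model are assumed to exist and the iteration counts are illustrative.

# Sketch: any attribute of the optim package works for SOLVER.OPT
# (e.g. 'Adam', 'SGD', 'Ranger', 'Lamb'); SCHED is 'warmup' or 'delay'.
from fastreid.solver.build import build_optimizer, build_lr_scheduler

cfg.SOLVER.OPT = 'Ranger'
cfg.SOLVER.SCHED = 'delay'
cfg.SOLVER.DELAY_ITERS = 4000        # flat phase (after warmup) before annealing
cfg.SOLVER.COS_ANNEAL_ITERS = 6000   # length of the cosine decay

optimizer = build_optimizer(cfg, model)
scheduler = build_lr_scheduler(cfg, optimizer)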
diff --git a/fastreid/solver/lr_scheduler.py b/fastreid/solver/lr_scheduler.py
index e918555..7e66275 100644
--- a/fastreid/solver/lr_scheduler.py
+++ b/fastreid/solver/lr_scheduler.py
@@ -8,9 +8,12 @@ from bisect import bisect_right
 from typing import List
 
 import torch
+from torch.optim.lr_scheduler import _LRScheduler, CosineAnnealingLR
+
+__all__ = ["WarmupMultiStepLR", "DelayerScheduler"]
 
 
-class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler):
+class WarmupMultiStepLR(_LRScheduler):
     def __init__(
             self,
             optimizer: torch.optim.Optimizer,
@@ -72,3 +75,48 @@ def _get_warmup_factor_at_iter(
     else:
         raise ValueError("Unknown warmup method: {}".format(method))
 
+
+class DelayerScheduler(_LRScheduler):
+    """ Starts with a flat lr schedule until it reaches N epochs the applies a scheduler
+    Args:
+        optimizer (Optimizer): Wrapped optimizer.
+        delay_epochs: number of epochs to keep the initial lr until starting aplying the scheduler
+        after_scheduler: after target_epoch, use this scheduler(eg. ReduceLROnPlateau)
+    """
+
+    def __init__(self, optimizer, delay_epochs, after_scheduler, warmup_factor, warmup_iters, warmup_method):
+        self.delay_epochs = delay_epochs
+        self.after_scheduler = after_scheduler
+        self.finished = False
+        self.warmup_factor = warmup_factor
+        self.warmup_iters = warmup_iters
+        self.warmup_method = warmup_method
+        super().__init__(optimizer)
+
+    def get_lr(self):
+
+        if self.last_epoch >= self.delay_epochs:
+            if not self.finished:
+                self.after_scheduler.base_lrs = self.base_lrs
+                self.finished = True
+            return self.after_scheduler.get_lr()
+
+        warmup_factor = _get_warmup_factor_at_iter(
+            self.warmup_method, self.last_epoch, self.warmup_iters, self.warmup_factor
+        )
+        return [base_lr * warmup_factor for base_lr in self.base_lrs]
+
+    def step(self, epoch=None):
+        if self.finished:
+            if epoch is None:
+                self.after_scheduler.step(None)
+            else:
+                self.after_scheduler.step(epoch - self.delay_epochs)
+        else:
+            return super(DelayerScheduler, self).step(epoch)
+
+
+def DelayedCosineAnnealingLR(optimizer, delay_epochs, cosine_annealing_epochs, warmup_factor,
+                             warmup_iters, warmup_method):
+    base_scheduler = CosineAnnealingLR(optimizer, cosine_annealing_epochs, eta_min=0)
+    return DelayerScheduler(optimizer, delay_epochs, base_scheduler, warmup_factor, warmup_iters, warmup_method)
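A minimal usage sketch of the new delayed cosine schedule, assuming one scheduler step per iteration; the model, learning rate, and iteration counts are illustrative only.

# Sketch: linear warmup for 10 iters, hold the base lr until iter 100,
# then cosine-anneal over 100 iters towards eta_min=0.
import torch
from fastreid.solver.lr_scheduler import DelayedCosineAnnealingLR

model = torch.nn.Linear(10, 2)   # stand-in model
opt = torch.optim.Adam(model.parameters(), lr=3.5e-4)
sched = DelayedCosineAnnealingLR(opt, delay_epochs=100, cosine_annealing_epochs=100,
                                 warmup_factor=0.1, warmup_iters=10, warmup_method='linear')

for it in range(200):
    opt.step()
    sched.step()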
diff --git a/fastreid/solver/optim/__init__.py b/fastreid/solver/optim/__init__.py
new file mode 100644
index 0000000..2ba0948
--- /dev/null
+++ b/fastreid/solver/optim/__init__.py
@@ -0,0 +1,9 @@
+from .lamb import Lamb
+from .lookahead import Lookahead, LookaheadAdam
+from .novograd import Novograd
+from .over9000 import Over9000, RangerLars
+from .radam import RAdam, PlainRAdam, AdamW
+from .ralamb import Ralamb
+from .ranger import Ranger
+
+from torch.optim import *
\ No newline at end of file
diff --git a/fastreid/solver/optim/lamb.py b/fastreid/solver/optim/lamb.py
new file mode 100644
index 0000000..a8d3d26
--- /dev/null
+++ b/fastreid/solver/optim/lamb.py
@@ -0,0 +1,126 @@
+####
+# CODE TAKEN FROM https://github.com/mgrankin/over9000
+####
+
+import collections
+import math
+
+import torch
+from torch.optim.optimizer import Optimizer
+
+try: 
+    from tensorboardX import SummaryWriter
+
+    def log_lamb_rs(optimizer: Optimizer, event_writer: SummaryWriter, token_count: int):
+        """Log a histogram of trust ratio scalars in across layers."""
+        results = collections.defaultdict(list)
+        for group in optimizer.param_groups:
+            for p in group['params']:
+                state = optimizer.state[p]
+                for i in ('weight_norm', 'adam_norm', 'trust_ratio'):
+                    if i in state:
+                        results[i].append(state[i])
+
+        for k, v in results.items():
+            event_writer.add_histogram(f'lamb/{k}', torch.tensor(v), token_count)
+except ModuleNotFoundError:
+    print("To use log_lamb_rs, please run 'pip install tensorboardx'. You must also have TensorBoard running to see the results.")
+
+class Lamb(Optimizer):
+    r"""Implements Lamb algorithm.
+    It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.
+    Arguments:
+        params (iterable): iterable of parameters to optimize or dicts defining
+            parameter groups
+        lr (float, optional): learning rate (default: 1e-3)
+        betas (Tuple[float, float], optional): coefficients used for computing
+            running averages of gradient and its square (default: (0.9, 0.999))
+        eps (float, optional): term added to the denominator to improve
+            numerical stability (default: 1e-8)
+        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
+        adam (bool, optional): always use trust ratio = 1, which turns this into
+            Adam. Useful for comparison purposes.
+    .. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes:
+        https://arxiv.org/abs/1904.00962
+    """
+
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6,
+                 weight_decay=0, adam=False):
+        if not 0.0 <= lr:
+            raise ValueError("Invalid learning rate: {}".format(lr))
+        if not 0.0 <= eps:
+            raise ValueError("Invalid epsilon value: {}".format(eps))
+        if not 0.0 <= betas[0] < 1.0:
+            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+        if not 0.0 <= betas[1] < 1.0:
+            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+        defaults = dict(lr=lr, betas=betas, eps=eps,
+                        weight_decay=weight_decay)
+        self.adam = adam
+        super(Lamb, self).__init__(params, defaults)
+
+    def step(self, closure=None):
+        """Performs a single optimization step.
+        Arguments:
+            closure (callable, optional): A closure that reevaluates the model
+                and returns the loss.
+        """
+        loss = None
+        if closure is not None:
+            loss = closure()
+
+        for group in self.param_groups:
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+                grad = p.grad.data
+                if grad.is_sparse:
+                    raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instead.')
+
+                state = self.state[p]
+
+                # State initialization
+                if len(state) == 0:
+                    state['step'] = 0
+                    # Exponential moving average of gradient values
+                    state['exp_avg'] = torch.zeros_like(p.data)
+                    # Exponential moving average of squared gradient values
+                    state['exp_avg_sq'] = torch.zeros_like(p.data)
+
+                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+                beta1, beta2 = group['betas']
+
+                state['step'] += 1
+
+                # Decay the first and second moment running average coefficient
+                # m_t
+                exp_avg.mul_(beta1).add_(1 - beta1, grad)
+                # v_t
+                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
+
+                # Paper v3 does not use debiasing.
+                # bias_correction1 = 1 - beta1 ** state['step']
+                # bias_correction2 = 1 - beta2 ** state['step']
+                # Apply bias to lr to avoid broadcast.
+                step_size = group['lr'] # * math.sqrt(bias_correction2) / bias_correction1
+
+                weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10)
+
+                adam_step = exp_avg / exp_avg_sq.sqrt().add(group['eps'])
+                if group['weight_decay'] != 0:
+                    adam_step.add_(group['weight_decay'], p.data)
+
+                adam_norm = adam_step.pow(2).sum().sqrt()
+                if weight_norm == 0 or adam_norm == 0:
+                    trust_ratio = 1
+                else:
+                    trust_ratio = weight_norm / adam_norm
+                state['weight_norm'] = weight_norm
+                state['adam_norm'] = adam_norm
+                state['trust_ratio'] = trust_ratio
+                if self.adam:
+                    trust_ratio = 1
+
+                p.data.add_(-step_size * trust_ratio, adam_step)
+
+        return loss
\ No newline at end of file
diff --git a/fastreid/solver/optim/lookahead.py b/fastreid/solver/optim/lookahead.py
new file mode 100644
index 0000000..9cbbd97
--- /dev/null
+++ b/fastreid/solver/optim/lookahead.py
@@ -0,0 +1,104 @@
+####
+# CODE TAKEN FROM https://github.com/lonePatient/lookahead_pytorch
+# Original paper: https://arxiv.org/abs/1907.08610
+####
+# Lookahead implementation from https://github.com/rwightman/pytorch-image-models/blob/master/timm/optim/lookahead.py
+
+""" Lookahead Optimizer Wrapper.
+Implementation modified from: https://github.com/alphadl/lookahead.pytorch
+Paper: `Lookahead Optimizer: k steps forward, 1 step back` - https://arxiv.org/abs/1907.08610
+"""
+from collections import defaultdict
+
+import torch
+from torch.optim import Adam
+from torch.optim.optimizer import Optimizer
+
+
+class Lookahead(Optimizer):
+    def __init__(self, base_optimizer, alpha=0.5, k=6):
+        if not 0.0 <= alpha <= 1.0:
+            raise ValueError(f'Invalid slow update rate: {alpha}')
+        if not 1 <= k:
+            raise ValueError(f'Invalid lookahead steps: {k}')
+        defaults = dict(lookahead_alpha=alpha, lookahead_k=k, lookahead_step=0)
+        self.base_optimizer = base_optimizer
+        self.param_groups = self.base_optimizer.param_groups
+        self.defaults = base_optimizer.defaults
+        self.defaults.update(defaults)
+        self.state = defaultdict(dict)
+        # manually add our defaults to the param groups
+        for name, default in defaults.items():
+            for group in self.param_groups:
+                group.setdefault(name, default)
+
+    def update_slow(self, group):
+        for fast_p in group["params"]:
+            if fast_p.grad is None:
+                continue
+            param_state = self.state[fast_p]
+            if 'slow_buffer' not in param_state:
+                param_state['slow_buffer'] = torch.empty_like(fast_p.data)
+                param_state['slow_buffer'].copy_(fast_p.data)
+            slow = param_state['slow_buffer']
+            slow.add_(group['lookahead_alpha'], fast_p.data - slow)
+            fast_p.data.copy_(slow)
+
+    def sync_lookahead(self):
+        for group in self.param_groups:
+            self.update_slow(group)
+
+    def step(self, closure=None):
+        # print(self.k)
+        # assert id(self.param_groups) == id(self.base_optimizer.param_groups)
+        loss = self.base_optimizer.step(closure)
+        for group in self.param_groups:
+            group['lookahead_step'] += 1
+            if group['lookahead_step'] % group['lookahead_k'] == 0:
+                self.update_slow(group)
+        return loss
+
+    def state_dict(self):
+        fast_state_dict = self.base_optimizer.state_dict()
+        slow_state = {
+            (id(k) if isinstance(k, torch.Tensor) else k): v
+            for k, v in self.state.items()
+        }
+        fast_state = fast_state_dict['state']
+        param_groups = fast_state_dict['param_groups']
+        return {
+            'state': fast_state,
+            'slow_state': slow_state,
+            'param_groups': param_groups,
+        }
+
+    def load_state_dict(self, state_dict):
+        fast_state_dict = {
+            'state': state_dict['state'],
+            'param_groups': state_dict['param_groups'],
+        }
+        self.base_optimizer.load_state_dict(fast_state_dict)
+
+        # We want to restore the slow state, but share the param_groups reference
+        # with base_optimizer. This is a bit redundant, but it keeps the code short.
+        slow_state_new = False
+        if 'slow_state' not in state_dict:
+            print('Loading state_dict from optimizer without Lookahead applied.')
+            state_dict['slow_state'] = defaultdict(dict)
+            slow_state_new = True
+        slow_state_dict = {
+            'state': state_dict['slow_state'],
+            'param_groups': state_dict['param_groups'],  # this is pointless but saves code
+        }
+        super(Lookahead, self).load_state_dict(slow_state_dict)
+        self.param_groups = self.base_optimizer.param_groups  # make both ref same container
+        if slow_state_new:
+            # reapply defaults to catch missing lookahead specific ones
+            for name, default in self.defaults.items():
+                for group in self.param_groups:
+                    group.setdefault(name, default)
+
+
+def LookaheadAdam(params, alpha=0.5, k=6, *args, **kwargs):
+    adam = Adam(params, *args, **kwargs)
+    return Lookahead(adam, alpha, k)
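A minimal usage sketch for the wrapper above (the model, import path, and hyper-parameters are placeholders; fastreid normally constructs its optimizer through the solver factory rather than by hand):

```python
import torch
from torch import nn

from fastreid.solver.optim.lookahead import Lookahead  # assumes this module path is importable

model = nn.Linear(10, 2)                      # placeholder module
base = torch.optim.Adam(model.parameters(), lr=3.5e-4)
opt = Lookahead(base, alpha=0.5, k=6)         # slow weights blended every k fast steps

for _ in range(20):
    x, y = torch.randn(8, 10), torch.randn(8, 2)
    loss = nn.functional.mse_loss(model(x), y)
    model.zero_grad()                         # clear gradients before the backward pass
    loss.backward()
    opt.step()                                # steps Adam; every k-th call also updates slow weights

opt.sync_lookahead()                          # blend and copy the slow weights into the model before evaluation
```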
diff --git a/fastreid/solver/optim/novograd.py b/fastreid/solver/optim/novograd.py
new file mode 100644
index 0000000..e14e232
--- /dev/null
+++ b/fastreid/solver/optim/novograd.py
@@ -0,0 +1,229 @@
+####
+# CODE TAKEN FROM https://github.com/mgrankin/over9000
+####
+
+# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from torch.optim.optimizer import Optimizer
+import math
+
+
+class AdamW(Optimizer):
+    """Implements AdamW algorithm.
+  
+    It has been proposed in `Adam: A Method for Stochastic Optimization`_.
+  
+    Arguments:
+        params (iterable): iterable of parameters to optimize or dicts defining
+            parameter groups
+        lr (float, optional): learning rate (default: 1e-3)
+        betas (Tuple[float, float], optional): coefficients used for computing
+            running averages of gradient and its square (default: (0.9, 0.999))
+        eps (float, optional): term added to the denominator to improve
+            numerical stability (default: 1e-8)
+        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
+        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
+            algorithm from the paper `On the Convergence of Adam and Beyond`_
+  
+        Adam: A Method for Stochastic Optimization:
+        https://arxiv.org/abs/1412.6980
+        On the Convergence of Adam and Beyond:
+        https://openreview.net/forum?id=ryQu7f-RZ
+    """
+
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
+                 weight_decay=0, amsgrad=False):
+        if not 0.0 <= lr:
+            raise ValueError("Invalid learning rate: {}".format(lr))
+        if not 0.0 <= eps:
+            raise ValueError("Invalid epsilon value: {}".format(eps))
+        if not 0.0 <= betas[0] < 1.0:
+            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+        if not 0.0 <= betas[1] < 1.0:
+            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+        defaults = dict(lr=lr, betas=betas, eps=eps,
+                        weight_decay=weight_decay, amsgrad=amsgrad)
+        super(AdamW, self).__init__(params, defaults)
+
+    def __setstate__(self, state):
+        super(AdamW, self).__setstate__(state)
+        for group in self.param_groups:
+            group.setdefault('amsgrad', False)
+
+    def step(self, closure=None):
+        """Performs a single optimization step.
+  
+        Arguments:
+            closure (callable, optional): A closure that reevaluates the model
+                and returns the loss.
+        """
+        loss = None
+        if closure is not None:
+            loss = closure()
+
+        for group in self.param_groups:
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+                grad = p.grad.data
+                if grad.is_sparse:
+                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
+                amsgrad = group['amsgrad']
+
+                state = self.state[p]
+
+                # State initialization
+                if len(state) == 0:
+                    state['step'] = 0
+                    # Exponential moving average of gradient values
+                    state['exp_avg'] = torch.zeros_like(p.data)
+                    # Exponential moving average of squared gradient values
+                    state['exp_avg_sq'] = torch.zeros_like(p.data)
+                    if amsgrad:
+                        # Maintains max of all exp. moving avg. of sq. grad. values
+                        state['max_exp_avg_sq'] = torch.zeros_like(p.data)
+
+                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+                if amsgrad:
+                    max_exp_avg_sq = state['max_exp_avg_sq']
+                beta1, beta2 = group['betas']
+
+                state['step'] += 1
+                # Decay the first and second moment running average coefficient
+                exp_avg.mul_(beta1).add_(1 - beta1, grad)
+                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
+                if amsgrad:
+                    # Maintains the maximum of all 2nd moment running avg. till now
+                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
+                    # Use the max. for normalizing running avg. of gradient
+                    denom = max_exp_avg_sq.sqrt().add_(group['eps'])
+                else:
+                    denom = exp_avg_sq.sqrt().add_(group['eps'])
+
+                bias_correction1 = 1 - beta1 ** state['step']
+                bias_correction2 = 1 - beta2 ** state['step']
+                step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1
+                p.data.add_(-step_size, torch.mul(p.data, group['weight_decay']).addcdiv_(1, exp_avg, denom))
+
+        return loss
+
+
+class Novograd(Optimizer):
+    """
+    Implements Novograd algorithm.
+
+    Args:
+        params (iterable): iterable of parameters to optimize or dicts defining
+            parameter groups
+        lr (float, optional): learning rate (default: 1e-3)
+        betas (Tuple[float, float], optional): coefficients used for computing
+            running averages of gradient and its square (default: (0.95, 0))
+        eps (float, optional): term added to the denominator to improve
+            numerical stability (default: 1e-8)
+        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
+        grad_averaging (bool, optional): multiply the gradient by (1 - beta1) before adding it to the first moment (default: False)
+        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
+            algorithm from the paper `On the Convergence of Adam and Beyond`_
+            (default: False)
+    """
+
+    def __init__(self, params, lr=1e-3, betas=(0.95, 0), eps=1e-8,
+                 weight_decay=0, grad_averaging=False, amsgrad=False):
+        if not 0.0 <= lr:
+            raise ValueError("Invalid learning rate: {}".format(lr))
+        if not 0.0 <= eps:
+            raise ValueError("Invalid epsilon value: {}".format(eps))
+        if not 0.0 <= betas[0] < 1.0:
+            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+        if not 0.0 <= betas[1] < 1.0:
+            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+        defaults = dict(lr=lr, betas=betas, eps=eps,
+                        weight_decay=weight_decay,
+                        grad_averaging=grad_averaging,
+                        amsgrad=amsgrad)
+
+        super(Novograd, self).__init__(params, defaults)
+
+    def __setstate__(self, state):
+        super(Novograd, self).__setstate__(state)
+        for group in self.param_groups:
+            group.setdefault('amsgrad', False)
+
+    def step(self, closure=None):
+        """Performs a single optimization step.
+
+        Arguments:
+            closure (callable, optional): A closure that reevaluates the model
+            and returns the loss.
+        """
+        loss = None
+        if closure is not None:
+            loss = closure()
+
+        for group in self.param_groups:
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+                grad = p.grad.data
+                if grad.is_sparse:
+                    raise RuntimeError('Sparse gradients are not supported.')
+                amsgrad = group['amsgrad']
+
+                state = self.state[p]
+
+                # State initialization
+                if len(state) == 0:
+                    state['step'] = 0
+                    # Exponential moving average of gradient values
+                    state['exp_avg'] = torch.zeros_like(p.data)
+                    # Exponential moving average of squared gradient values
+                    state['exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device)
+                    if amsgrad:
+                        # Maintains max of all exp. moving avg. of sq. grad. values
+                        state['max_exp_avg_sq'] = torch.zeros([]).to(state['exp_avg'].device)
+
+                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+                if amsgrad:
+                    max_exp_avg_sq = state['max_exp_avg_sq']
+                beta1, beta2 = group['betas']
+
+                state['step'] += 1
+
+                norm = torch.sum(torch.pow(grad, 2))
+
+                if exp_avg_sq == 0:
+                    exp_avg_sq.copy_(norm)
+                else:
+                    exp_avg_sq.mul_(beta2).add_(1 - beta2, norm)
+
+                if amsgrad:
+                    # Maintains the maximum of all 2nd moment running avg. till now
+                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
+                    # Use the max. for normalizing running avg. of gradient
+                    denom = max_exp_avg_sq.sqrt().add_(group['eps'])
+                else:
+                    denom = exp_avg_sq.sqrt().add_(group['eps'])
+
+                grad.div_(denom)
+                if group['weight_decay'] != 0:
+                    grad.add_(group['weight_decay'], p.data)
+                if group['grad_averaging']:
+                    grad.mul_(1 - beta1)
+                exp_avg.mul_(beta1).add_(grad)
+
+                p.data.add_(-group['lr'], exp_avg)
+
+        return loss
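Unlike the Adam variants above, Novograd keeps only a scalar second moment per parameter tensor: the running average of the layer's squared gradient norm (the layer-wise moment described in https://arxiv.org/abs/1905.11286). A rough single-layer sketch of one update, mirroring the step above with illustrative values:

```python
import torch

lr, beta1, beta2, eps, wd = 1e-3, 0.95, 0.0, 1e-8, 0.0   # matches the defaults above

w = torch.randn(64, 32)            # one layer's weights (placeholder)
grad = torch.randn(64, 32)         # its gradient for this step (placeholder)
exp_avg = torch.zeros_like(w)      # first moment: a full tensor
exp_avg_sq = torch.zeros([])       # second moment: a single scalar for the whole layer

norm = grad.pow(2).sum()           # squared gradient norm of the layer
exp_avg_sq = norm if exp_avg_sq == 0 else beta2 * exp_avg_sq + (1 - beta2) * norm

g = grad / (exp_avg_sq.sqrt() + eps)   # normalize the gradient by the layer-wise norm
g = g + wd * w                         # weight decay is applied after normalization
exp_avg = beta1 * exp_avg + g          # no (1 - beta1) factor unless grad_averaging is on
w = w - lr * exp_avg
```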
diff --git a/fastreid/solver/optim/over9000.py b/fastreid/solver/optim/over9000.py
new file mode 100644
index 0000000..a6b3aed
--- /dev/null
+++ b/fastreid/solver/optim/over9000.py
@@ -0,0 +1,19 @@
+####
+# CODE TAKEN FROM https://github.com/mgrankin/over9000
+####
+
+from .lookahead import Lookahead
+from .ralamb import Ralamb
+
+
+# RAdam + LARS + LookAHead
+
+# Lookahead implementation from https://github.com/lonePatient/lookahead_pytorch/blob/master/optimizer.py
+# RAdam + LARS implementation from https://gist.github.com/redknightlois/c4023d393eb8f92bb44b2ab582d7ec20
+
+def Over9000(params, alpha=0.5, k=6, *args, **kwargs):
+    ralamb = Ralamb(params, *args, **kwargs)
+    return Lookahead(ralamb, alpha, k)
+
+
+RangerLars = Over9000
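So `Over9000` (aliased as `RangerLars`) is simply Ralamb, i.e. RAdam with a LARS-style trust ratio, wrapped in Lookahead. A minimal, hypothetical construction with placeholder hyper-parameters:

```python
from torch import nn

from fastreid.solver.optim.over9000 import Over9000  # assumes this module path is importable

model = nn.Linear(10, 2)   # placeholder module
# Equivalent to Lookahead(Ralamb(params, lr=3.5e-4, weight_decay=5e-4), alpha=0.5, k=6)
opt = Over9000(model.parameters(), lr=3.5e-4, weight_decay=5e-4, alpha=0.5, k=6)
```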
diff --git a/fastreid/solver/optim/radam.py b/fastreid/solver/optim/radam.py
new file mode 100644
index 0000000..476009b
--- /dev/null
+++ b/fastreid/solver/optim/radam.py
@@ -0,0 +1,255 @@
+####
+# CODE TAKEN FROM https://github.com/LiyuanLucasLiu/RAdam
+# Paper: https://arxiv.org/abs/1908.03265
+####
+
+import math
+
+import torch
+from torch.optim.optimizer import Optimizer
+
+
+class RAdam(Optimizer):
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True):
+        if not 0.0 <= lr:
+            raise ValueError("Invalid learning rate: {}".format(lr))
+        if not 0.0 <= eps:
+            raise ValueError("Invalid epsilon value: {}".format(eps))
+        if not 0.0 <= betas[0] < 1.0:
+            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+        if not 0.0 <= betas[1] < 1.0:
+            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+
+        self.degenerated_to_sgd = degenerated_to_sgd
+        if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict):
+            for param in params:
+                if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]):
+                    param['buffer'] = [[None, None, None] for _ in range(10)]
+        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
+                        buffer=[[None, None, None] for _ in range(10)])
+        super(RAdam, self).__init__(params, defaults)
+
+    def __setstate__(self, state):
+        super(RAdam, self).__setstate__(state)
+
+    def step(self, closure=None):
+
+        loss = None
+        if closure is not None:
+            loss = closure()
+
+        for group in self.param_groups:
+
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+                grad = p.grad.data.float()
+                if grad.is_sparse:
+                    raise RuntimeError('RAdam does not support sparse gradients')
+
+                p_data_fp32 = p.data.float()
+
+                state = self.state[p]
+
+                if len(state) == 0:
+                    state['step'] = 0
+                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
+                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
+                else:
+                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
+                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
+
+                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+                beta1, beta2 = group['betas']
+
+                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
+                exp_avg.mul_(beta1).add_(1 - beta1, grad)
+
+                state['step'] += 1
+                buffered = group['buffer'][int(state['step'] % 10)]
+                if state['step'] == buffered[0]:
+                    N_sma, step_size = buffered[1], buffered[2]
+                else:
+                    buffered[0] = state['step']
+                    beta2_t = beta2 ** state['step']
+                    N_sma_max = 2 / (1 - beta2) - 1
+                    N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
+                    buffered[1] = N_sma
+
+                    # more conservative since it's an approximated value
+                    if N_sma >= 5:
+                        step_size = math.sqrt(
+                            (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (
+                                        N_sma_max - 2)) / (1 - beta1 ** state['step'])
+                    elif self.degenerated_to_sgd:
+                        step_size = 1.0 / (1 - beta1 ** state['step'])
+                    else:
+                        step_size = -1
+                    buffered[2] = step_size
+
+                # more conservative since it's an approximated value
+                if N_sma >= 5:
+                    if group['weight_decay'] != 0:
+                        p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
+                    denom = exp_avg_sq.sqrt().add_(group['eps'])
+                    p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom)
+                    p.data.copy_(p_data_fp32)
+                elif step_size > 0:
+                    if group['weight_decay'] != 0:
+                        p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
+                    p_data_fp32.add_(-step_size * group['lr'], exp_avg)
+                    p.data.copy_(p_data_fp32)
+
+        return loss
+
+
+class PlainRAdam(Optimizer):
+
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True):
+        if not 0.0 <= lr:
+            raise ValueError("Invalid learning rate: {}".format(lr))
+        if not 0.0 <= eps:
+            raise ValueError("Invalid epsilon value: {}".format(eps))
+        if not 0.0 <= betas[0] < 1.0:
+            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+        if not 0.0 <= betas[1] < 1.0:
+            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+
+        self.degenerated_to_sgd = degenerated_to_sgd
+        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
+
+        super(PlainRAdam, self).__init__(params, defaults)
+
+    def __setstate__(self, state):
+        super(PlainRAdam, self).__setstate__(state)
+
+    def step(self, closure=None):
+
+        loss = None
+        if closure is not None:
+            loss = closure()
+
+        for group in self.param_groups:
+
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+                grad = p.grad.data.float()
+                if grad.is_sparse:
+                    raise RuntimeError('RAdam does not support sparse gradients')
+
+                p_data_fp32 = p.data.float()
+
+                state = self.state[p]
+
+                if len(state) == 0:
+                    state['step'] = 0
+                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
+                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
+                else:
+                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
+                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
+
+                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+                beta1, beta2 = group['betas']
+
+                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
+                exp_avg.mul_(beta1).add_(1 - beta1, grad)
+
+                state['step'] += 1
+                beta2_t = beta2 ** state['step']
+                N_sma_max = 2 / (1 - beta2) - 1
+                N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
+
+                # more conservative since it's an approximated value
+                if N_sma >= 5:
+                    if group['weight_decay'] != 0:
+                        p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
+                    step_size = group['lr'] * math.sqrt(
+                        (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (
+                                    N_sma_max - 2)) / (1 - beta1 ** state['step'])
+                    denom = exp_avg_sq.sqrt().add_(group['eps'])
+                    p_data_fp32.addcdiv_(-step_size, exp_avg, denom)
+                    p.data.copy_(p_data_fp32)
+                elif self.degenerated_to_sgd:
+                    if group['weight_decay'] != 0:
+                        p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
+                    step_size = group['lr'] / (1 - beta1 ** state['step'])
+                    p_data_fp32.add_(-step_size, exp_avg)
+                    p.data.copy_(p_data_fp32)
+
+        return loss
+
+
+class AdamW(Optimizer):
+
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, warmup=0):
+        if not 0.0 <= lr:
+            raise ValueError("Invalid learning rate: {}".format(lr))
+        if not 0.0 <= eps:
+            raise ValueError("Invalid epsilon value: {}".format(eps))
+        if not 0.0 <= betas[0] < 1.0:
+            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+        if not 0.0 <= betas[1] < 1.0:
+            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+
+        defaults = dict(lr=lr, betas=betas, eps=eps,
+                        weight_decay=weight_decay, warmup=warmup)
+        super(AdamW, self).__init__(params, defaults)
+
+    def __setstate__(self, state):
+        super(AdamW, self).__setstate__(state)
+
+    def step(self, closure=None):
+        loss = None
+        if closure is not None:
+            loss = closure()
+
+        for group in self.param_groups:
+
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+                grad = p.grad.data.float()
+                if grad.is_sparse:
+                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
+
+                p_data_fp32 = p.data.float()
+
+                state = self.state[p]
+
+                if len(state) == 0:
+                    state['step'] = 0
+                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
+                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
+                else:
+                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
+                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
+
+                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+                beta1, beta2 = group['betas']
+
+                state['step'] += 1
+
+                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
+                exp_avg.mul_(beta1).add_(1 - beta1, grad)
+
+                denom = exp_avg_sq.sqrt().add_(group['eps'])
+                bias_correction1 = 1 - beta1 ** state['step']
+                bias_correction2 = 1 - beta2 ** state['step']
+
+                if group['warmup'] > state['step']:
+                    scheduled_lr = 1e-8 + state['step'] * group['lr'] / group['warmup']
+                else:
+                    scheduled_lr = group['lr']
+
+                step_size = scheduled_lr * math.sqrt(bias_correction2) / bias_correction1
+
+                if group['weight_decay'] != 0:
+                    p_data_fp32.add_(-group['weight_decay'] * scheduled_lr, p_data_fp32)
+
+                p_data_fp32.addcdiv_(-step_size, exp_avg, denom)
+
+                p.data.copy_(p_data_fp32)
+
+        return loss
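The per-step buffer in `RAdam` above caches two quantities: the approximated SMA length `N_sma` and the resulting step multiplier. When `N_sma` falls below 5 the variance rectification is treated as unreliable and, with `degenerated_to_sgd=True`, the update degenerates to bias-corrected SGD with momentum. (Note this file also defines a second `AdamW`, with a linear warmup option, separate from the one in `novograd.py`.) A small worked computation of those buffered terms, following the formulas in the code:

```python
import math

beta1, beta2 = 0.9, 0.999   # RAdam defaults above

def radam_terms(step):
    """Return (N_sma, per-step multiplier) as computed inside RAdam.step."""
    beta2_t = beta2 ** step
    n_sma_max = 2 / (1 - beta2) - 1                 # 1999 for beta2 = 0.999
    n_sma = n_sma_max - 2 * step * beta2_t / (1 - beta2_t)
    if n_sma >= 5:
        rect = math.sqrt((1 - beta2_t) * (n_sma - 4) / (n_sma_max - 4)
                         * (n_sma - 2) / n_sma * n_sma_max / (n_sma_max - 2))
        return n_sma, rect / (1 - beta1 ** step)    # rectified Adam step
    return n_sma, 1.0 / (1 - beta1 ** step)         # SGD-with-momentum fallback

for step in (1, 4, 5, 100, 10000):
    print(step, radam_terms(step))
```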
diff --git a/fastreid/solver/optim/ralamb.py b/fastreid/solver/optim/ralamb.py
new file mode 100644
index 0000000..79c09cc
--- /dev/null
+++ b/fastreid/solver/optim/ralamb.py
@@ -0,0 +1,103 @@
+####
+# CODE TAKEN FROM https://github.com/mgrankin/over9000
+####
+
+import torch, math
+from torch.optim.optimizer import Optimizer
+
+# RAdam + LARS
+class Ralamb(Optimizer):
+
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
+        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
+        self.buffer = [[None, None, None] for ind in range(10)]
+        super(Ralamb, self).__init__(params, defaults)
+
+    def __setstate__(self, state):
+        super(Ralamb, self).__setstate__(state)
+
+    def step(self, closure=None):
+
+        loss = None
+        if closure is not None:
+            loss = closure()
+
+        for group in self.param_groups:
+
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+                grad = p.grad.data.float()
+                if grad.is_sparse:
+                    raise RuntimeError('Ralamb does not support sparse gradients')
+
+                p_data_fp32 = p.data.float()
+
+                state = self.state[p]
+
+                if len(state) == 0:
+                    state['step'] = 0
+                    state['exp_avg'] = torch.zeros_like(p_data_fp32)
+                    state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
+                else:
+                    state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
+                    state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
+
+                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+                beta1, beta2 = group['betas']
+
+                # Decay the first and second moment running average coefficient
+                # m_t
+                exp_avg.mul_(beta1).add_(1 - beta1, grad)
+                # v_t
+                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
+
+                state['step'] += 1
+                buffered = self.buffer[int(state['step'] % 10)]
+
+                if state['step'] == buffered[0]:
+                    N_sma, radam_step_size = buffered[1], buffered[2]
+                else:
+                    buffered[0] = state['step']
+                    beta2_t = beta2 ** state['step']
+                    N_sma_max = 2 / (1 - beta2) - 1
+                    N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
+                    buffered[1] = N_sma
+
+                    # more conservative since it's an approximated value
+                    if N_sma >= 5:
+                        radam_step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])
+                    else:
+                        radam_step_size = 1.0 / (1 - beta1 ** state['step'])
+                    buffered[2] = radam_step_size
+
+                if group['weight_decay'] != 0:
+                    p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
+
+                # more conservative since it's an approximated value
+                radam_step = p_data_fp32.clone()
+                if N_sma >= 5:
+                    denom = exp_avg_sq.sqrt().add_(group['eps'])
+                    radam_step.addcdiv_(-radam_step_size * group['lr'], exp_avg, denom)
+                else:
+                    radam_step.add_(-radam_step_size * group['lr'], exp_avg)
+
+                radam_norm = radam_step.pow(2).sum().sqrt()
+                weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10)
+                if weight_norm == 0 or radam_norm == 0:
+                    trust_ratio = 1
+                else:
+                    trust_ratio = weight_norm / radam_norm
+
+                state['weight_norm'] = weight_norm
+                state['adam_norm'] = radam_norm
+                state['trust_ratio'] = trust_ratio
+
+                if N_sma >= 5:
+                    p_data_fp32.addcdiv_(-radam_step_size * group['lr'] * trust_ratio, exp_avg, denom)
+                else:
+                    p_data_fp32.add_(-radam_step_size * group['lr'] * trust_ratio, exp_avg)
+
+                p.data.copy_(p_data_fp32)
+
+        return loss
\ No newline at end of file
diff --git a/fastreid/solver/optim/ranger.py b/fastreid/solver/optim/ranger.py
new file mode 100644
index 0000000..e6fd9f1
--- /dev/null
+++ b/fastreid/solver/optim/ranger.py
@@ -0,0 +1,14 @@
+####
+# CODE TAKEN FROM https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer
+# Blog post: https://medium.com/@lessw/new-deep-learning-optimizer-ranger-synergistic-combination-of-radam-lookahead-for-the-best-of-2dc83f79a48d
+####
+
+import math
+import torch
+from .lookahead import Lookahead
+from .radam import RAdam
+
+
+def Ranger(params, alpha=0.5, k=6, betas=(.95, 0.999), *args, **kwargs):
+    radam = RAdam(params, betas=betas, *args, **kwargs)
+    return Lookahead(radam, alpha, k)
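Ranger is the same Lookahead composition as above, but over RAdam and with a default beta1 of 0.95 instead of RAdam's own 0.9. A short, hypothetical construction with placeholder values:

```python
from torch import nn

from fastreid.solver.optim.ranger import Ranger  # assumes this module path is importable

model = nn.Linear(10, 2)   # placeholder module
# Equivalent to Lookahead(RAdam(params, betas=(0.95, 0.999), lr=3.5e-4), alpha=0.5, k=6)
opt = Ranger(model.parameters(), lr=3.5e-4)
```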
diff --git a/projects/AGWBaseline/agwbaseline/gem_pool.py b/projects/AGWBaseline/agwbaseline/gem_pool.py
index 2a11d7b..54c863e 100644
--- a/projects/AGWBaseline/agwbaseline/gem_pool.py
+++ b/projects/AGWBaseline/agwbaseline/gem_pool.py
@@ -73,7 +73,7 @@ class GeM_BN_Linear(nn.Module):
         bn_features = self.bnneck(global_features)
 
         if not self.training:
-            return F.normalize(bn_features),
+            return F.normalize(bn_features)
 
         pred_class_logits = self.classifier(bn_features)
         return pred_class_logits, global_features, targets,
diff --git a/projects/StrongBaseline/README.md b/projects/StrongBaseline/README.md
index 7265784..dd82ec0 100644
--- a/projects/StrongBaseline/README.md
+++ b/projects/StrongBaseline/README.md
@@ -19,24 +19,21 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 python train_net.py --config-file='configs/baseline
 
 ### Market1501 dataset
 
-| Method | Pretrained | Rank@1 | mAP |
-| :---: | :---: | :---: |:---: |
-| BagTricks | ImageNet | 93.3% | 85.2% |
-| BagTricks + Ibn-a | ImageNet | 94.9% | 87.1% |
-| BagTricks + Ibn-a + softMargin | ImageNet | 94.8% | 87.7% |
+| Method | Pretrained | Rank@1 | mAP | mINP |
+| :---: | :---: | :---: |:---: | :---: |
+| BagTricks | ImageNet | 93.6% | 85.1% | 58.1% |
+| BagTricks + Ibn-a | ImageNet | 94.8% | 87.3% | 63.5% |
 
 ### DukeMTMC dataset
 
-| Method | Pretrained | Rank@1 | mAP |
-| :---: | :---: | :---: |:---: |
-| BagTricks | ImageNet | 86.6% | 77.3% |
-| BagTricks + Ibn-a | ImageNet | 88.8% | 78.6% |
-| BagTricks + Ibn-a + softMargin | ImageNet | 89.1% | 78.9% |
+| Method | Pretrained | Rank@1 | mAP | mINP |
+| :---: | :---: | :---: |:---: | :---: |
+| BagTricks | ImageNet | 86.1% | 75.9% | 38.7% |
+| BagTricks + Ibn-a | ImageNet | 89.0% | 78.8% | 43.6% |
 
 ### MSMT17 dataset
 
-| Method | Pretrained | Rank@1 | mAP |
-| :---: | :---: | :---: |:---: |
-| BagTricks | ImageNet | 72.0% | 48.6% |
-| BagTricks + Ibn-a | ImageNet | 77.7% | 54.6% |
-| BagTricks + Ibn-a + softMargin | ImageNet | 77.3% | 55.7% |
+| Method | Pretrained | Rank@1 | mAP | mINP |
+| :---: | :---: | :---: |:---: | :---: |
+| BagTricks | ImageNet | 70.4%  | 47.5% | 9.6% |
+| BagTricks + Ibn-a | ImageNet | 76.9% | 55.0% | 13.5% |
diff --git a/projects/StrongBaseline/configs/Base-Strongbaseline.yml b/projects/StrongBaseline/configs/Base-Strongbaseline.yml
index de333c4..5500de8 100644
--- a/projects/StrongBaseline/configs/Base-Strongbaseline.yml
+++ b/projects/StrongBaseline/configs/Base-Strongbaseline.yml
@@ -9,15 +9,14 @@ MODEL:
     PRETRAIN: True
 
   HEADS:
-    NAME: "BNneckLinear"
-    NUM_CLASSES: 702
+    NAME: "BNneckHead"
 
   LOSSES:
     NAME: ("CrossEntropyLoss", "TripletLoss")
     SMOOTH_ON: True
     SCALE_CE: 1.0
 
-    MARGIN: 0.0
+    MARGIN: 0.3
     SCALE_TRI: 1.0
 
 DATASETS:
@@ -42,17 +41,18 @@ DATALOADER:
   NUM_WORKERS: 16
 
 SOLVER:
-  OPT: "adam"
+  OPT: "Adam"
   MAX_ITER: 18000
   BASE_LR: 0.00035
+  BIAS_LR_FACTOR: 2
   WEIGHT_DECAY: 0.0005
-  WEIGHT_DECAY_BIAS: 0.0005
+  WEIGHT_DECAY_BIAS: 0.0
   IMS_PER_BATCH: 64
 
   STEPS: [8000, 14000]
   GAMMA: 0.1
 
-  WARMUP_FACTOR: 0.1
+  WARMUP_FACTOR: 0.01
   WARMUP_ITERS: 2000
 
   LOG_PERIOD: 200
@@ -64,4 +64,4 @@ TEST:
 
 CUDNN_BENCHMARK: True
 
-OUTPUT_DIR: "logs/fastreid_dukemtmc/ibn_softmax_softtriplet"
+OUTPUT_DIR: "logs/dukemtmc/softmax"
diff --git a/projects/StrongBaseline/configs/Base-Strongbaseline_ibn.yml b/projects/StrongBaseline/configs/Base-Strongbaseline_ibn.yml
index 8f4d3c2..874935a 100644
--- a/projects/StrongBaseline/configs/Base-Strongbaseline_ibn.yml
+++ b/projects/StrongBaseline/configs/Base-Strongbaseline_ibn.yml
@@ -3,5 +3,5 @@ _BASE_: "Base-Strongbaseline.yml"
 MODEL:
   BACKBONE:
     WITH_IBN: True
-    PRETRAIN_PATH: "/home/liaoxingyu2/lxy/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar"
+    PRETRAIN_PATH: "/export/home/lxy/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar"
 
diff --git a/projects/StrongBaseline/configs/baseline_dukemtmc.yml b/projects/StrongBaseline/configs/baseline_dukemtmc.yml
index def7d1b..8946ceb 100644
--- a/projects/StrongBaseline/configs/baseline_dukemtmc.yml
+++ b/projects/StrongBaseline/configs/baseline_dukemtmc.yml
@@ -1,11 +1,43 @@
 _BASE_: "Base-Strongbaseline.yml"
 
 MODEL:
+  META_ARCHITECTURE: "MGN_v2"
   HEADS:
+    POOL_LAYER: "maxpool"
+    NAME: "StandardHead"
     NUM_CLASSES: 702
 
+  LOSSES:
+    NAME: ("CrossEntropyLoss", "TripletLoss")
+    SMOOTH_ON: True
+    SCALE_CE: 0.1
+
+    MARGIN: 0.3
+    SCALE_TRI: 0.167
+
+INPUT:
+  RE:
+    ENABLED: True
+    PROB: 0.5
+  CUTOUT:
+    ENABLED: False
+
+SOLVER:
+  MAX_ITER: 9000
+  BASE_LR: 0.00035
+  BIAS_LR_FACTOR: 2
+  WEIGHT_DECAY: 0.0005
+  WEIGHT_DECAY_BIAS: 0.0
+  IMS_PER_BATCH: 256
+
+  STEPS: [4000, 7000]
+  GAMMA: 0.1
+
+  WARMUP_FACTOR: 0.01
+  WARMUP_ITERS: 1000
+
 DATASETS:
   NAMES: ("DukeMTMC",)
   TESTS: ("DukeMTMC",)
 
-OUTPUT_DIR: "logs/fastreid_dukemtmc/softmax_softmargin"
+OUTPUT_DIR: "logs/dukemtmc/mgn_v2"
diff --git a/projects/StrongBaseline/configs/baseline_ibn_dukemtmc.yml b/projects/StrongBaseline/configs/baseline_ibn_dukemtmc.yml
index ba1ef13..630a699 100644
--- a/projects/StrongBaseline/configs/baseline_ibn_dukemtmc.yml
+++ b/projects/StrongBaseline/configs/baseline_ibn_dukemtmc.yml
@@ -8,4 +8,4 @@ DATASETS:
   NAMES: ("DukeMTMC",)
   TESTS: ("DukeMTMC",)
 
-OUTPUT_DIR: "logs/fastreid_dukemtmc/ibn_softmax_softmargin"
+OUTPUT_DIR: "logs/dukemtmc/ibn_bagtricks"
diff --git a/projects/StrongBaseline/configs/baseline_ibn_market1501.yml b/projects/StrongBaseline/configs/baseline_ibn_market1501.yml
index 87395d9..809bff7 100644
--- a/projects/StrongBaseline/configs/baseline_ibn_market1501.yml
+++ b/projects/StrongBaseline/configs/baseline_ibn_market1501.yml
@@ -8,4 +8,4 @@ DATASETS:
   NAMES: ("Market1501",)
   TESTS: ("Market1501",)
 
-OUTPUT_DIR: "logs/fastreid_market1501/ibn_softmax_softmargin"
+OUTPUT_DIR: "logs/market1501/ibn_bagtricks"
diff --git a/projects/StrongBaseline/configs/baseline_ibn_msmt17.yml b/projects/StrongBaseline/configs/baseline_ibn_msmt17.yml
index 2b2f10a..cb403ab 100644
--- a/projects/StrongBaseline/configs/baseline_ibn_msmt17.yml
+++ b/projects/StrongBaseline/configs/baseline_ibn_msmt17.yml
@@ -10,10 +10,8 @@ DATASETS:
 
 SOLVER:
   MAX_ITER: 45000
-
   STEPS: [20000, 35000]
-
-  WARMUP_ITERS: 5000
+  WARMUP_ITERS: 2000
 
   LOG_PERIOD: 500
   CHECKPOINT_PERIOD: 15000
@@ -21,4 +19,4 @@ SOLVER:
 TEST:
   EVAL_PERIOD: 15000
 
-OUTPUT_DIR: "logs/fastreid_msmt17/ibn_softmax_softmargin"
+OUTPUT_DIR: "logs/msmt17/ibn_bagtricks"
diff --git a/projects/StrongBaseline/configs/baseline_market1501.yml b/projects/StrongBaseline/configs/baseline_market1501.yml
index 38d3c46..325afd6 100644
--- a/projects/StrongBaseline/configs/baseline_market1501.yml
+++ b/projects/StrongBaseline/configs/baseline_market1501.yml
@@ -1,25 +1,12 @@
 _BASE_: "Base-Strongbaseline.yml"
 
 MODEL:
-  BACKBONE:
-    PRETRAIN: True
-
   HEADS:
-    NAME: "BNneckLinear"
     NUM_CLASSES: 751
 
-  LOSSES:
-    NAME: ("CrossEntropyLoss", "TripletLoss")
-    SMOOTH_ON: True
-    SCALE_CE: 1.0
-
-    MARGIN: 0.0
-    SCALE_TRI: 1.0
-
-
 DATASETS:
   NAMES: ("Market1501",)
   TESTS: ("Market1501",)
 
 
-OUTPUT_DIR: "logs/market1501/test"
+OUTPUT_DIR: "logs/market1501/bagtricks"
diff --git a/projects/StrongBaseline/configs/baseline_msmt17.yml b/projects/StrongBaseline/configs/baseline_msmt17.yml
index 760d2f6..b5003ab 100644
--- a/projects/StrongBaseline/configs/baseline_msmt17.yml
+++ b/projects/StrongBaseline/configs/baseline_msmt17.yml
@@ -10,10 +10,8 @@ DATASETS:
 
 SOLVER:
   MAX_ITER: 45000
-
   STEPS: [20000, 35000]
-
-  WARMUP_ITERS: 5000
+  WARMUP_ITERS: 2000
 
   LOG_PERIOD: 500
   CHECKPOINT_PERIOD: 15000
@@ -21,4 +19,4 @@ SOLVER:
 TEST:
   EVAL_PERIOD: 15000
 
-OUTPUT_DIR: "logs/fastreid_msmt17/softmax_softmargin"
+OUTPUT_DIR: "logs/msmt17/bagtricks"
diff --git a/projects/StrongBaseline/non_linear_head.py b/projects/StrongBaseline/non_linear_head.py
deleted file mode 100644
index 4dc2ee9..0000000
--- a/projects/StrongBaseline/non_linear_head.py
+++ /dev/null
@@ -1,78 +0,0 @@
-# encoding: utf-8
-"""
-@author:  l1aoxingyu
-@contact: sherlockliao01@gmail.com
-"""
-
-import math
-
-import torch
-import torch.nn.functional as F
-from torch import nn
-from torch.nn import Parameter
-
-from fastreid.modeling.heads import REID_HEADS_REGISTRY
-from fastreid.modeling.model_utils import weights_init_classifier, weights_init_kaiming
-
-
-@REID_HEADS_REGISTRY.register()
-class NonLinear(nn.Module):
-    def __init__(self, cfg):
-        super().__init__()
-        self._num_classes = cfg.MODEL.HEADS.NUM_CLASSES
-        self.gap = nn.AdaptiveAvgPool2d(1)
-
-        self.fc1 = nn.Linear(2048, 1024, bias=False)
-        self.bn1 = nn.BatchNorm1d(1024)
-        # self.bn1.bias.requires_grad_(False)
-        self.relu = nn.ReLU(True)
-        self.fc2 = nn.Linear(1024, 512, bias=False)
-        self.bn2 = nn.BatchNorm1d(512)
-        self.bn2.bias.requires_grad_(False)
-
-        self._m = 0.50
-        self._s = 30.0
-        self._in_features = 512
-        self.cos_m = math.cos(self._m)
-        self.sin_m = math.sin(self._m)
-
-        self.th = math.cos(math.pi - self._m)
-        self.mm = math.sin(math.pi - self._m) * self._m
-
-        self.weight = Parameter(torch.Tensor(self._num_classes, self._in_features))
-
-        self.init_parameters()
-
-    def init_parameters(self):
-        self.fc1.apply(weights_init_kaiming)
-        self.bn1.apply(weights_init_kaiming)
-        self.fc2.apply(weights_init_kaiming)
-        self.bn2.apply(weights_init_kaiming)
-        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
-
-    def forward(self, features, targets=None):
-        global_features = self.gap(features)
-        global_features = global_features.view(global_features.shape[0], -1)
-
-        if not self.training:
-            return F.normalize(global_features)
-
-        fc_features = self.fc1(global_features)
-        fc_features = self.bn1(fc_features)
-        fc_features = self.relu(fc_features)
-        fc_features = self.fc2(fc_features)
-        fc_features = self.bn2(fc_features)
-
-        cosine = F.linear(F.normalize(fc_features), F.normalize(self.weight))
-        sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1))
-        phi = cosine * self.cos_m - sine * self.sin_m
-        phi = torch.where(cosine > self.th, phi, cosine - self.mm)
-        # --------------------------- convert label to one-hot ---------------------------
-        # one_hot = torch.zeros(cosine.size(), requires_grad=True, device='cuda')
-        one_hot = torch.zeros(cosine.size(), device='cuda')
-        one_hot.scatter_(1, targets.view(-1, 1).long(), 1)
-        # -------------torch.where(out_i = {x_i if condition_i else y_i) -------------
-        pred_class_logits = (one_hot * phi) + (
-                    (1.0 - one_hot) * cosine)  # you can use torch.where if your torch.__version__ is 0.4
-        pred_class_logits *= self._s
-        return pred_class_logits, global_features, targets
diff --git a/projects/StrongBaseline/train_net.py b/projects/StrongBaseline/train_net.py
index 3d83408..05fc9ca 100644
--- a/projects/StrongBaseline/train_net.py
+++ b/projects/StrongBaseline/train_net.py
@@ -4,14 +4,25 @@
 @contact: sherlockliao01@gmail.com
 """
 
+import os
 import sys
 
+from torch import nn
+
 sys.path.append('../..')
 from fastreid.config import get_cfg
 from fastreid.engine import DefaultTrainer, default_argument_parser, default_setup
 from fastreid.utils.checkpoint import Checkpointer
+from fastreid.evaluation import ReidEvaluator
+from reduce_head import ReduceHead
 
-from non_linear_head import NonLinear
+
+class Trainer(DefaultTrainer):
+    @classmethod
+    def build_evaluator(cls, cfg, num_query, output_folder=None):
+        if output_folder is None:
+            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
+        return ReidEvaluator(cfg, num_query)
 
 
 def setup(args):
@@ -30,19 +41,18 @@ def main(args):
     cfg = setup(args)
 
     if args.eval_only:
-        model = DefaultTrainer.build_model(cfg)
+        cfg.defrost()
+        cfg.MODEL.BACKBONE.PRETRAIN = False
+        model = Trainer.build_model(cfg)
+        model = nn.DataParallel(model)
+        model = model.cuda()
         Checkpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
             cfg.MODEL.WEIGHTS, resume=args.resume
         )
-        res = DefaultTrainer.test(cfg, model)
+        res = Trainer.test(cfg, model)
         return res
 
-    trainer = DefaultTrainer(cfg)
-    # moco pretrain
-    # import torch
-    # state_dict = torch.load('logs/model_0109999.pth')['model_ema']
-    # ret = trainer.model.module.load_state_dict(state_dict, strict=False)
-    #
+    trainer = Trainer(cfg)
     trainer.resume_or_load(resume=args.resume)
     return trainer.train()
 
diff --git a/tests/layer_test.py b/tests/layer_test.py
deleted file mode 100644
index 37790cb..0000000
--- a/tests/layer_test.py
+++ /dev/null
@@ -1,10 +0,0 @@
-# encoding: utf-8
-"""
-@author:  liaoxingyu
-@contact: sherlockliao01@gmail.com
-"""
-
-from ops import BatchCrop
-
-
-net = BatchCrop()
\ No newline at end of file
diff --git a/tests/model_test.py b/tests/model_test.py
index c832b82..fa13e87 100644
--- a/tests/model_test.py
+++ b/tests/model_test.py
@@ -1,34 +1,38 @@
-# encoding: utf-8
-"""
-@author:  liaoxingyu
-@contact: sherlockliao01@gmail.com
-"""
+import unittest
 
 import torch
 
 import sys
-sys.path.append(".")
-from config import cfg
-from modeling import build_model
-from modeling.bdnet import BDNet
-
-cfg.MODEL.BACKBONE = 'resnet50'
-cfg.MODEL.WITH_IBN = False
-# cfg.MODEL.PRETRAIN_PATH = '/export/home/lxy/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar'
-
-net = BDNet('resnet50', 100, 1, False, None, cfg.MODEL.STAGE_WITH_GCB, False)
-y = net(torch.randn(2, 3, 256, 128))
-print(3)
-# net = MGN_P('resnet50', 100, 1, False, None, cfg.MODEL.STAGE_WITH_GCB, cfg.MODEL.PRETRAIN, cfg.MODEL.PRETRAIN_PATH)
-# net = MGN('resnet50', 100, 2, False,None, cfg.MODEL.STAGE_WITH_GCB, cfg.MODEL.PRETRAIN, cfg.MODEL.PRETRAIN_PATH)
-# net.eval()
-# net = net.cuda()
-# x = torch.randn(10, 3, 256, 128)
-# y = net(x)
-# net = osnet_x1_0(False)
-# net(torch.randn(1, 3, 256, 128))
-# from ipdb import set_trace; set_trace()
-# label = torch.ones(10).long().cuda()
-# y = net(x, label)
+sys.path.append('.')
+from fastreid.config import cfg
+from fastreid.modeling.backbones import build_resnet_backbone
+from fastreid.modeling.backbones.resnet_ibn_a import se_resnet101_ibn_a
+from torch import nn
 
 
+class MyTestCase(unittest.TestCase):
+    def test_se_resnet101(self):
+        cfg.MODEL.BACKBONE.NAME = 'resnet101'
+        cfg.MODEL.BACKBONE.DEPTH = 101
+        cfg.MODEL.BACKBONE.WITH_IBN = True
+        cfg.MODEL.BACKBONE.WITH_SE = True
+        cfg.MODEL.BACKBONE.PRETRAIN_PATH = '/export/home/lxy/.cache/torch/checkpoints/se_resnet101_ibn_a.pth.tar'
+
+        net1 = build_resnet_backbone(cfg)
+        net1.cuda()
+        net2 = nn.DataParallel(se_resnet101_ibn_a())
+        res = net2.load_state_dict(torch.load(cfg.MODEL.BACKBONE.PRETRAIN_PATH)['state_dict'], strict=False)
+        net2.cuda()
+        x = torch.randn(10, 3, 256, 128).cuda()
+        y1 = net1(x)
+        y2 = net2(x)
+        assert y1.sum() == y2.sum(), 'train mode problem'
+        net1.eval()
+        net2.eval()
+        y1 = net1(x)
+        y2 = net2(x)
+        assert y1.sum() == y2.sum(), 'eval mode problem'
+
+
+if __name__ == '__main__':
+    unittest.main()