diff --git a/README.md b/README.md
index d306afe..c312ef2 100644
--- a/README.md
+++ b/README.md
@@ -4,6 +4,7 @@ FastReID is a research platform that implements state-of-the-art re-identificati
 
 ## What's New
 
+- [Aug 2020] ONNX/TensorRT converter is supported.
 - [Jul 2020] Distributed training with multiple GPUs is supported and makes training much faster (see the launch example below).
 - [Jul 2020] `MAX_ITER` in the config now means `epoch`; it is automatically scaled to the corresponding maximum number of iterations.
 - Includes more features, such as circle loss, abundant visualization methods and evaluation metrics, SoTA results on conventional, cross-domain, partial and vehicle re-id, and testing on multiple datasets simultaneously.
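+
+A minimal multi-GPU launch sketch (assuming the detectron2-style `--num-gpus` argument of `tools/train_net.py`):
+
+```bash
+python3 tools/train_net.py --config-file configs/Market1501/bagtricks_R50-ibn.yml --num-gpus 4
+```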
diff --git a/configs/Base-Strongerbaseline.yml b/configs/Base-Strongerbaseline.yml
index 7cebdd0..81e6f03 100644
--- a/configs/Base-Strongerbaseline.yml
+++ b/configs/Base-Strongerbaseline.yml
@@ -9,7 +9,7 @@ MODEL:
   HEADS:
     NECK_FEAT: "after"
     POOL_LAYER: "gempool"
-    CLS_LAYER: "circle"
+    CLS_LAYER: "circleSoftmax"
     SCALE: 64
     MARGIN: 0.35
 
diff --git a/configs/Base-bagtricks.yml b/configs/Base-bagtricks.yml
index ee01d1b..1cc7fe1 100644
--- a/configs/Base-bagtricks.yml
+++ b/configs/Base-bagtricks.yml
@@ -4,7 +4,7 @@ MODEL:
   BACKBONE:
     NAME: "build_resnet_backbone"
     NORM: "BN"
-    DEPTH: 50
+    DEPTH: "50x"
     LAST_STRIDE: 1
     WITH_IBN: False
     PRETRAIN: True
diff --git a/configs/DukeMTMC/AGW_R101-ibn.yml b/configs/DukeMTMC/AGW_R101-ibn.yml
index 1b4766e..dcf8520 100644
--- a/configs/DukeMTMC/AGW_R101-ibn.yml
+++ b/configs/DukeMTMC/AGW_R101-ibn.yml
@@ -4,7 +4,6 @@ MODEL:
   BACKBONE:
     DEPTH: 101
     WITH_IBN: True
-    PRETRAIN_PATH: "/home/liaoxingyu2/lxy/.cache/torch/checkpoints/resnet101_ibn_a.pth.tar"
 
 DATASETS:
   NAMES: ("DukeMTMC",)
diff --git a/configs/DukeMTMC/AGW_R50-ibn.yml b/configs/DukeMTMC/AGW_R50-ibn.yml
index cd021a2..648660f 100644
--- a/configs/DukeMTMC/AGW_R50-ibn.yml
+++ b/configs/DukeMTMC/AGW_R50-ibn.yml
@@ -3,7 +3,6 @@ _BASE_: "../Base-AGW.yml"
 MODEL:
   BACKBONE:
     WITH_IBN: True
-    PRETRAIN_PATH: "/home/liaoxingyu2/lxy/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar"
 
 DATASETS:
   NAMES: ("DukeMTMC",)
diff --git a/configs/DukeMTMC/bagtricks_R101-ibn.yml b/configs/DukeMTMC/bagtricks_R101-ibn.yml
index 6e0bc9a..ca3b4bc 100644
--- a/configs/DukeMTMC/bagtricks_R101-ibn.yml
+++ b/configs/DukeMTMC/bagtricks_R101-ibn.yml
@@ -4,7 +4,6 @@ MODEL:
   BACKBONE:
     DEPTH: 101
     WITH_IBN: True
-    PRETRAIN_PATH: "/home/liaoxingyu2/lxy/.cache/torch/checkpoints/resnet101_ibn_a.pth.tar"
 
 DATASETS:
   NAMES: ("DukeMTMC",)
diff --git a/configs/DukeMTMC/bagtricks_R50-ibn.yml b/configs/DukeMTMC/bagtricks_R50-ibn.yml
index 9c51ab9..cb46929 100644
--- a/configs/DukeMTMC/bagtricks_R50-ibn.yml
+++ b/configs/DukeMTMC/bagtricks_R50-ibn.yml
@@ -3,7 +3,6 @@ _BASE_: "../Base-bagtricks.yml"
 MODEL:
   BACKBONE:
     WITH_IBN: True
-    PRETRAIN_PATH: "/home/liaoxingyu2/lxy/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar"
 
 DATASETS:
   NAMES: ("DukeMTMC",)
diff --git a/configs/DukeMTMC/mgn_R50-ibn.yml b/configs/DukeMTMC/mgn_R50-ibn.yml
index 2d8f6fb..ab6bf7c 100644
--- a/configs/DukeMTMC/mgn_R50-ibn.yml
+++ b/configs/DukeMTMC/mgn_R50-ibn.yml
@@ -3,7 +3,6 @@ _BASE_: "../Base-MGN.yml"
 MODEL:
   BACKBONE:
     WITH_IBN: True
-    PRETRAIN_PATH: "/home/liaoxingyu2/lxy/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar"
 
 DATASETS:
   NAMES: ("DukeMTMC",)
diff --git a/configs/DukeMTMC/sbs_R101-ibn.yml b/configs/DukeMTMC/sbs_R101-ibn.yml
index 10d26b5..b22aac0 100644
--- a/configs/DukeMTMC/sbs_R101-ibn.yml
+++ b/configs/DukeMTMC/sbs_R101-ibn.yml
@@ -4,7 +4,6 @@ MODEL:
   BACKBONE:
     DEPTH: 101
     WITH_IBN: True
-    PRETRAIN_PATH: "/home/liaoxingyu2/lxy/.cache/torch/checkpoints/resnet101_ibn_a.pth.tar"
 
 DATASETS:
   NAMES: ("DukeMTMC",)
diff --git a/configs/DukeMTMC/sbs_R50-ibn.yml b/configs/DukeMTMC/sbs_R50-ibn.yml
index 8b6cd87..cbca66a 100644
--- a/configs/DukeMTMC/sbs_R50-ibn.yml
+++ b/configs/DukeMTMC/sbs_R50-ibn.yml
@@ -3,7 +3,6 @@ _BASE_: "../Base-Strongerbaseline.yml"
 MODEL:
   BACKBONE:
     WITH_IBN: True
-    PRETRAIN_PATH: "/home/liaoxingyu2/lxy/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar"
 
 DATASETS:
   NAMES: ("DukeMTMC",)
diff --git a/configs/MSMT17/AGW_R101-ibn.yml b/configs/MSMT17/AGW_R101-ibn.yml
index 5283ba5..12f49a2 100644
--- a/configs/MSMT17/AGW_R101-ibn.yml
+++ b/configs/MSMT17/AGW_R101-ibn.yml
@@ -4,7 +4,6 @@ MODEL:
   BACKBONE:
     DEPTH: 101
     WITH_IBN: True
-    PRETRAIN_PATH: "/home/liaoxingyu2/lxy/.cache/torch/checkpoints/resnet101_ibn_a.pth.tar"
 
 DATASETS:
   NAMES: ("MSMT17",)
diff --git a/configs/MSMT17/AGW_R50-ibn.yml b/configs/MSMT17/AGW_R50-ibn.yml
index fb57808..6104ed6 100644
--- a/configs/MSMT17/AGW_R50-ibn.yml
+++ b/configs/MSMT17/AGW_R50-ibn.yml
@@ -3,7 +3,6 @@ _BASE_: "../Base-AGW.yml"
 MODEL:
   BACKBONE:
     WITH_IBN: True
-    PRETRAIN_PATH: "/home/liaoxingyu2/lxy/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar"
 
 DATASETS:
   NAMES: ("MSMT17",)
diff --git a/configs/MSMT17/bagtricks_R101-ibn.yml b/configs/MSMT17/bagtricks_R101-ibn.yml
index ba86d24..aa21848 100644
--- a/configs/MSMT17/bagtricks_R101-ibn.yml
+++ b/configs/MSMT17/bagtricks_R101-ibn.yml
@@ -4,7 +4,6 @@ MODEL:
   BACKBONE:
     DEPTH: 101
     WITH_IBN: True
-    PRETRAIN_PATH: "/home/liaoxingyu2/lxy/.cache/torch/checkpoints/resnet101_ibn_a.pth.tar"
 
 DATASETS:
   NAMES: ("MSMT17",)
diff --git a/configs/MSMT17/bagtricks_R50-ibn.yml b/configs/MSMT17/bagtricks_R50-ibn.yml
index 563d48e..fac921e 100644
--- a/configs/MSMT17/bagtricks_R50-ibn.yml
+++ b/configs/MSMT17/bagtricks_R50-ibn.yml
@@ -3,7 +3,6 @@ _BASE_: "../Base-bagtricks.yml"
 MODEL:
   BACKBONE:
     WITH_IBN: True
-    PRETRAIN_PATH: "/home/liaoxingyu2/lxy/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar"
 
 DATASETS:
   NAMES: ("MSMT17",)
diff --git a/configs/MSMT17/mgn_R50-ibn.yml b/configs/MSMT17/mgn_R50-ibn.yml
index 303371f..07f18dd 100644
--- a/configs/MSMT17/mgn_R50-ibn.yml
+++ b/configs/MSMT17/mgn_R50-ibn.yml
@@ -3,7 +3,6 @@ _BASE_: "../Base-MGN.yml"
 MODEL:
   BACKBONE:
     WITH_IBN: True
-    PRETRAIN_PATH: "/home/liaoxingyu2/lxy/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar"
 
 DATASETS:
   NAMES: ("MSMT17",)
diff --git a/configs/MSMT17/sbs_R101-ibn.yml b/configs/MSMT17/sbs_R101-ibn.yml
index cebd7ff..85c7c1a 100644
--- a/configs/MSMT17/sbs_R101-ibn.yml
+++ b/configs/MSMT17/sbs_R101-ibn.yml
@@ -4,7 +4,6 @@ MODEL:
   BACKBONE:
     DEPTH: 101
     WITH_IBN: True
-    PRETRAIN_PATH: "/home/liaoxingyu2/lxy/.cache/torch/checkpoints/resnet101_ibn_a.pth.tar"
 
 DATASETS:
   NAMES: ("MSMT17",)
diff --git a/configs/MSMT17/sbs_R50-ibn.yml b/configs/MSMT17/sbs_R50-ibn.yml
index 64cb29d..d90ce4a 100644
--- a/configs/MSMT17/sbs_R50-ibn.yml
+++ b/configs/MSMT17/sbs_R50-ibn.yml
@@ -3,7 +3,6 @@ _BASE_: "../Base-Strongerbaseline.yml"
 MODEL:
   BACKBONE:
     WITH_IBN: True
-    PRETRAIN_PATH: "/home/liaoxingyu2/lxy/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar"
 
 DATASETS:
   NAMES: ("MSMT17",)
diff --git a/configs/Market1501/AGW_R101-ibn.yml b/configs/Market1501/AGW_R101-ibn.yml
index afe2020..d86dbc8 100644
--- a/configs/Market1501/AGW_R101-ibn.yml
+++ b/configs/Market1501/AGW_R101-ibn.yml
@@ -4,7 +4,6 @@ MODEL:
   BACKBONE:
     DEPTH: 101
     WITH_IBN: True
-    PRETRAIN_PATH: "/home/liaoxingyu2/lxy/.cache/torch/checkpoints/resnet101_ibn_a.pth.tar"
 
 DATASETS:
   NAMES: ("Market1501",)
diff --git a/configs/Market1501/AGW_R50-ibn.yml b/configs/Market1501/AGW_R50-ibn.yml
index bee96c1..4ec8154 100644
--- a/configs/Market1501/AGW_R50-ibn.yml
+++ b/configs/Market1501/AGW_R50-ibn.yml
@@ -3,7 +3,6 @@ _BASE_: "../Base-AGW.yml"
 MODEL:
   BACKBONE:
     WITH_IBN: True
-    PRETRAIN_PATH: "/home/liaoxingyu2/lxy/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar"
 
 DATASETS:
   NAMES: ("Market1501",)
diff --git a/configs/Market1501/bagtricks_R101-ibn.yml b/configs/Market1501/bagtricks_R101-ibn.yml
index 5103391..03b5132 100644
--- a/configs/Market1501/bagtricks_R101-ibn.yml
+++ b/configs/Market1501/bagtricks_R101-ibn.yml
@@ -2,9 +2,8 @@ _BASE_: "../Base-bagtricks.yml"
 
 MODEL:
   BACKBONE:
-    DEPTH: 101
+    DEPTH: "101"
     WITH_IBN: True
-    PRETRAIN_PATH: "/home/liaoxingyu2/lxy/.cache/torch/checkpoints/resnet101_ibn_a.pth.tar"
 
 DATASETS:
   NAMES: ("Market1501",)
diff --git a/configs/Market1501/bagtricks_R50-ibn.yml b/configs/Market1501/bagtricks_R50-ibn.yml
index d8f0cf6..0a5a032 100644
--- a/configs/Market1501/bagtricks_R50-ibn.yml
+++ b/configs/Market1501/bagtricks_R50-ibn.yml
@@ -3,7 +3,6 @@ _BASE_: "../Base-bagtricks.yml"
 MODEL:
   BACKBONE:
     WITH_IBN: True
-    PRETRAIN_PATH: "/home/liaoxingyu2/lxy/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar"
 
 DATASETS:
   NAMES: ("Market1501",)
diff --git a/configs/Market1501/mgn_R50-ibn.yml b/configs/Market1501/mgn_R50-ibn.yml
index 058ca12..b2ca6f2 100644
--- a/configs/Market1501/mgn_R50-ibn.yml
+++ b/configs/Market1501/mgn_R50-ibn.yml
@@ -3,7 +3,6 @@ _BASE_: "../Base-MGN.yml"
 MODEL:
   BACKBONE:
     WITH_IBN: True
-    PRETRAIN_PATH: "/home/liaoxingyu2/lxy/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar"
 
 DATASETS:
   NAMES: ("Market1501",)
diff --git a/configs/Market1501/sbs_R101-ibn.yml b/configs/Market1501/sbs_R101-ibn.yml
index 44fc73d..ff74053 100644
--- a/configs/Market1501/sbs_R101-ibn.yml
+++ b/configs/Market1501/sbs_R101-ibn.yml
@@ -4,7 +4,6 @@ MODEL:
   BACKBONE:
     DEPTH: 101
     WITH_IBN: True
-    PRETRAIN_PATH: "/home/liaoxingyu2/lxy/.cache/torch/checkpoints/resnet101_ibn_a.pth.tar"
 
 DATASETS:
   NAMES: ("Market1501",)
diff --git a/configs/Market1501/sbs_R50-ibn.yml b/configs/Market1501/sbs_R50-ibn.yml
index 4b556a0..7302591 100644
--- a/configs/Market1501/sbs_R50-ibn.yml
+++ b/configs/Market1501/sbs_R50-ibn.yml
@@ -3,7 +3,6 @@ _BASE_: "../Base-Strongerbaseline.yml"
 MODEL:
   BACKBONE:
     WITH_IBN: True
-    PRETRAIN_PATH: "/home/liaoxingyu2/lxy/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar"
 
 DATASETS:
   NAMES: ("Market1501",)
diff --git a/configs/VERIWild/bagtricks_R50-ibn.yml b/configs/VERIWild/bagtricks_R50-ibn.yml
index a0375d1..20fc877 100644
--- a/configs/VERIWild/bagtricks_R50-ibn.yml
+++ b/configs/VERIWild/bagtricks_R50-ibn.yml
@@ -7,9 +7,7 @@ INPUT:
 MODEL:
   BACKBONE:
     WITH_IBN: True
-    PRETRAIN_PATH: '/export2/home/zjk/pretrain_models/resnet50_ibn_a.pth.tar'
   HEADS:
-    NUM_CLASSES: 30671
     POOL_LAYER: gempool
   LOSSES:
     TRI:
diff --git a/configs/VeRi/sbs_R50-ibn.yml b/configs/VeRi/sbs_R50-ibn.yml
index d64a10f..5324cde 100644
--- a/configs/VeRi/sbs_R50-ibn.yml
+++ b/configs/VeRi/sbs_R50-ibn.yml
@@ -7,10 +7,6 @@ INPUT:
 MODEL:
   BACKBONE:
     WITH_IBN: True
-    PRETRAIN_PATH: "/export2/home/zjk/pretrain_models/resnet50_ibn_a.pth.tar"
-
-  HEADS:
-    NUM_CLASSES: 575
 
 SOLVER:
   OPT: "SGD"
diff --git a/configs/VehicleID/bagtricks_R50-ibn.yml b/configs/VehicleID/bagtricks_R50-ibn.yml
index 6bd886c..0020c4f 100644
--- a/configs/VehicleID/bagtricks_R50-ibn.yml
+++ b/configs/VehicleID/bagtricks_R50-ibn.yml
@@ -7,9 +7,7 @@ INPUT:
 MODEL:
   BACKBONE:
     WITH_IBN: True
-    PRETRAIN_PATH: '/export2/home/zjk/pretrain_models/resnet50_ibn_a.pth.tar'
   HEADS:
-    NUM_CLASSES: 13164
     POOL_LAYER: gempool
   LOSSES:
     TRI:
diff --git a/demo/demo.py b/demo/demo.py
index bab0adb..349a796 100644
--- a/demo/demo.py
+++ b/demo/demo.py
@@ -29,6 +29,7 @@ cudnn.benchmark = True
 def setup_cfg(args):
     # load config from file and command-line arguments
     cfg = get_cfg()
+    # add_partialreid_config(cfg)
     cfg.merge_from_file(args.config_file)
     cfg.merge_from_list(args.opts)
     cfg.freeze()
diff --git a/demo/plot_roc_with_pickle.py b/demo/plot_roc_with_pickle.py
index 0ab51e3..c4e9aad 100644
--- a/demo/plot_roc_with_pickle.py
+++ b/demo/plot_roc_with_pickle.py
@@ -1,7 +1,7 @@
 # encoding: utf-8
 """
 @author:  xingyu liao
-@contact: liaoxingyu5@jd.com
+@contact: sherlockliao01@gmail.com
 """
 
 import matplotlib.pyplot as plt
diff --git a/demo/predictor.py b/demo/predictor.py
index a6b8512..b56f01c 100644
--- a/demo/predictor.py
+++ b/demo/predictor.py
@@ -1,7 +1,7 @@
 # encoding: utf-8
 """
 @author:  xingyu liao
-@contact: liaoxingyu5@jd.com
+@contact: sherlockliao01@gmail.com
 """
 
 import atexit
diff --git a/demo/visualize_result.py b/demo/visualize_result.py
index f3f315e..ea46381 100644
--- a/demo/visualize_result.py
+++ b/demo/visualize_result.py
@@ -1,7 +1,7 @@
 # encoding: utf-8
 """
 @author:  xingyu liao
-@contact: liaoxingyu5@jd.com
+@contact: sherlockliao01@gmail.com
 """
 
 import argparse
diff --git a/fastreid/config/defaults.py b/fastreid/config/defaults.py
index 86cbb78..fbaafd8 100644
--- a/fastreid/config/defaults.py
+++ b/fastreid/config/defaults.py
@@ -30,9 +30,7 @@ _C.MODEL.FREEZE_LAYERS = ['']
 _C.MODEL.BACKBONE = CN()
 
 _C.MODEL.BACKBONE.NAME = "build_resnet_backbone"
-_C.MODEL.BACKBONE.DEPTH = 50
-# RegNet volume
-_C.MODEL.BACKBONE.VOLUME = "800y"
+_C.MODEL.BACKBONE.DEPTH = "50x"
 _C.MODEL.BACKBONE.LAST_STRIDE = 1
 # Normalization method for the convolution layers.
 _C.MODEL.BACKBONE.NORM = "BN"
@@ -137,7 +135,13 @@ _C.INPUT.DO_PAD = True
 _C.INPUT.PADDING_MODE = 'constant'
 _C.INPUT.PADDING = 10
 # Random color jitter
-_C.INPUT.DO_CJ = False
+_C.INPUT.CJ = CN()
+_C.INPUT.CJ.ENABLED = False
+_C.INPUT.CJ.PROB = 0.8
+_C.INPUT.CJ.BRIGHTNESS = 0.15
+_C.INPUT.CJ.CONTRAST = 0.15
+_C.INPUT.CJ.SATURATION = 0.1
+_C.INPUT.CJ.HUE = 0.1
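+# e.g., enable color jitter from a YAML config:
+#   INPUT:
+#     CJ:
+#       ENABLED: True
+#       PROB: 0.5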
 # Auto augmentation
 _C.INPUT.DO_AUTOAUG = False
 # Augmix augmentation
diff --git a/fastreid/data/transforms/build.py b/fastreid/data/transforms/build.py
index 35e7a9b..743c8a2 100644
--- a/fastreid/data/transforms/build.py
+++ b/fastreid/data/transforms/build.py
@@ -33,7 +33,12 @@ def build_transforms(cfg, is_train=True):
         padding_mode = cfg.INPUT.PADDING_MODE
 
         # color jitter
-        do_cj = cfg.INPUT.DO_CJ
+        do_cj = cfg.INPUT.CJ.ENABLED
+        cj_prob = cfg.INPUT.CJ.PROB
+        cj_brightness = cfg.INPUT.CJ.BRIGHTNESS
+        cj_contrast = cfg.INPUT.CJ.CONTRAST
+        cj_saturation = cfg.INPUT.CJ.SATURATION
+        cj_hue = cfg.INPUT.CJ.HUE
 
         # random erasing
         do_rea = cfg.INPUT.REA.ENABLED
@@ -52,7 +57,7 @@ def build_transforms(cfg, is_train=True):
             res.extend([T.Pad(padding, padding_mode=padding_mode),
                         T.RandomCrop(size_train)])
         if do_cj:
-            res.append(T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0))
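+            # apply the jitter with probability cj_prob (RandomApply) instead of always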
+            res.append(T.RandomApply([T.ColorJitter(cj_brightness, cj_contrast, cj_saturation, cj_hue)], p=cj_prob))
         if do_augmix:
             res.append(AugMix())
         if do_rea:
diff --git a/fastreid/data/transforms/transforms.py b/fastreid/data/transforms/transforms.py
index b66a96f..0de5562 100644
--- a/fastreid/data/transforms/transforms.py
+++ b/fastreid/data/transforms/transforms.py
@@ -4,7 +4,7 @@
 @contact: sherlockliao01@gmail.com
 """
 
-__all__ = ['ToTensor', 'RandomErasing', 'RandomPatch', 'AugMix', ]
+__all__ = ['ToTensor', 'RandomErasing', 'RandomPatch', 'AugMix']
 
 import math
 import random
@@ -202,110 +202,3 @@ class AugMix(object):
 
         mixed = (1 - m) * image + m * mix
         return mixed
-
-# class ColorJitter(object):
-#     """docstring for do_color"""
-#
-#     def __init__(self, probability=0.5):
-#         self.probability = probability
-#
-#     def do_brightness_shift(self, image, alpha=0.125):
-#         image = image.astype(np.float32)
-#         image = image + alpha * 255
-#         image = np.clip(image, 0, 255).astype(np.uint8)
-#         return image
-#
-#     def do_brightness_multiply(self, image, alpha=1):
-#         image = image.astype(np.float32)
-#         image = alpha * image
-#         image = np.clip(image, 0, 255).astype(np.uint8)
-#         return image
-#
-#     def do_contrast(self, image, alpha=1.0):
-#         image = image.astype(np.float32)
-#         gray = image * np.array([[[0.114, 0.587, 0.299]]])  # rgb to gray (YCbCr)
-#         gray = (3.0 * (1.0 - alpha) / gray.size) * np.sum(gray)
-#         image = alpha * image + gray
-#         image = np.clip(image, 0, 255).astype(np.uint8)
-#         return image
-#
-#     # https://www.pyimagesearch.com/2015/10/05/opencv-gamma-correction/
-#     def do_gamma(self, image, gamma=1.0):
-#         table = np.array([((i / 255.0) ** (1.0 / gamma)) * 255
-#                           for i in np.arange(0, 256)]).astype("uint8")
-#
-#         return cv2.LUT(image, table)  # apply gamma correction using the lookup table
-#
-#     def do_clahe(self, image, clip=2, grid=16):
-#         grid = int(grid)
-#
-#         lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)
-#         gray, a, b = cv2.split(lab)
-#         gray = cv2.createCLAHE(clipLimit=clip, tileGridSize=(grid, grid)).apply(gray)
-#         lab = cv2.merge((gray, a, b))
-#         image = cv2.cvtColor(lab, cv2.COLOR_LAB2BGR)
-#
-#         return image
-#
-#     def __call__(self, image):
-#         if random.uniform(0, 1) > self.probability:
-#             return image
-#
-#         image = np.asarray(image, dtype=np.uint8).copy()
-#         index = random.randint(0, 4)
-#         if index == 0:
-#             image = self.do_brightness_shift(image, 0.1)
-#         elif index == 1:
-#             image = self.do_gamma(image, 1)
-#         elif index == 2:
-#             image = self.do_clahe(image)
-#         elif index == 3:
-#             image = self.do_brightness_multiply(image)
-#         elif index == 4:
-#             image = self.do_contrast(image)
-#         return image
-
-
-# class random_shift(object):
-#     """docstring for do_color"""
-#
-#     def __init__(self, probability=0.5):
-#         self.probability = probability
-#
-#     def __call__(self, image):
-#         if random.uniform(0, 1) > self.probability:
-#             return image
-#
-#         width, height, d = image.shape
-#         zero_image = np.zeros_like(image)
-#         w = random.randint(0, 20) - 10
-#         h = random.randint(0, 30) - 15
-#         zero_image[max(0, w): min(w + width, width), max(h, 0): min(h + height, height)] = \
-#             image[max(0, -w): min(-w + width, width), max(-h, 0): min(-h + height, height)]
-#         image = zero_image.copy()
-#         return image
-#
-#
-# class random_scale(object):
-#     """docstring for do_color"""
-#
-#     def __init__(self, probability=0.5):
-#         self.probability = probability
-#
-#     def __call__(self, image):
-#         if random.uniform(0, 1) > self.probability:
-#             return image
-#
-#         scale = random.random() * 0.1 + 0.9
-#         assert 0.9 <= scale <= 1
-#         width, height, d = image.shape
-#         zero_image = np.zeros_like(image)
-#         new_width = round(width * scale)
-#         new_height = round(height * scale)
-#         image = cv2.resize(image, (new_height, new_width))
-#         start_w = random.randint(0, width - new_width)
-#         start_h = random.randint(0, height - new_height)
-#         zero_image[start_w: start_w + new_width,
-#         start_h:start_h + new_height] = image
-#         image = zero_image.copy()
-#         return image
diff --git a/fastreid/engine/launch.py b/fastreid/engine/launch.py
index 328db59..a8f5c9b 100644
--- a/fastreid/engine/launch.py
+++ b/fastreid/engine/launch.py
@@ -1,7 +1,7 @@
 # encoding: utf-8
 """
 @author:  xingyu liao
-@contact: liaoxingyu5@jd.com
+@contact: sherlockliao01@gmail.com
 """
 
 # based on:
diff --git a/fastreid/evaluation/query_expansion.py b/fastreid/evaluation/query_expansion.py
index 873aed6..637c097 100644
--- a/fastreid/evaluation/query_expansion.py
+++ b/fastreid/evaluation/query_expansion.py
@@ -1,7 +1,7 @@
 # encoding: utf-8
 """
 @author:  xingyu liao
-@contact: liaoxingyu5@jd.com
+@contact: sherlockliao01@gmail.com
 """
 
 # based on
diff --git a/fastreid/evaluation/reid_evaluation.py b/fastreid/evaluation/reid_evaluation.py
index d83e399..a46913e 100644
--- a/fastreid/evaluation/reid_evaluation.py
+++ b/fastreid/evaluation/reid_evaluation.py
@@ -10,7 +10,6 @@ from collections import OrderedDict
 import numpy as np
 import torch
 import torch.nn.functional as F
-from tabulate import tabulate
 
 from .evaluator import DatasetEvaluator
 from .query_expansion import aqe
@@ -101,6 +100,6 @@ class ReidEvaluator(DatasetEvaluator):
         tprs = evaluate_roc(dist, query_pids, gallery_pids, query_camids, gallery_camids)
         fprs = [1e-4, 1e-3, 1e-2]
         for i in range(len(fprs)):
-            self._results["TPR@FPR={}".format(fprs[i])] = tprs[i]
+            self._results["TPR@FPR={:.0e}".format(fprs[i])] = tprs[i]
 
         return copy.deepcopy(self._results)
diff --git a/fastreid/layers/__init__.py b/fastreid/layers/__init__.py
index 606b3fc..dccaccd 100644
--- a/fastreid/layers/__init__.py
+++ b/fastreid/layers/__init__.py
@@ -6,9 +6,10 @@
 
 from .activation import *
 from .arc_softmax import ArcSoftmax
+from .circle_softmax import CircleSoftmax
+from .am_softmax import AMSoftmax
 from .batch_drop import BatchDrop
 from .batch_norm import *
-from .circle_softmax import CircleSoftmax
 from .context_block import ContextBlock
 from .frn import FRN, TLU
 from .non_local import Non_local
diff --git a/fastreid/layers/activation.py b/fastreid/layers/activation.py
index 3e0aea6..dafbf60 100644
--- a/fastreid/layers/activation.py
+++ b/fastreid/layers/activation.py
@@ -1,7 +1,7 @@
 # encoding: utf-8
 """
 @author:  xingyu liao
-@contact: liaoxingyu5@jd.com
+@contact: sherlockliao01@gmail.com
 """
 
 import math
diff --git a/fastreid/layers/am_softmax.py b/fastreid/layers/am_softmax.py
new file mode 100644
index 0000000..3f5a59e
--- /dev/null
+++ b/fastreid/layers/am_softmax.py
@@ -0,0 +1,43 @@
+# encoding: utf-8
+"""
+@author:  xingyu liao
+@contact: sherlockliao01@gmail.com
+"""
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+from torch.nn import Parameter
+
+
+class AMSoftmax(nn.Module):
+    r"""Implement of large margin cosine distance:
+    Args:
+        in_feat: size of each input sample
+        num_classes: size of each output sample
+    """
+
+    def __init__(self, cfg, in_feat, num_classes):
+        super().__init__()
+        self.in_features = in_feat
+        self._num_classes = num_classes
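+        # AM-Softmax margin: the logit for the target class becomes s * (cos(theta_y) - m)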
+        self._s = cfg.MODEL.HEADS.SCALE
+        self._m = cfg.MODEL.HEADS.MARGIN
+        self.weight = Parameter(torch.Tensor(num_classes, in_feat))
+        nn.init.xavier_uniform_(self.weight)
+
+    def forward(self, features, targets):
+        # --------------------------- cos(theta) & phi(theta) ---------------------------
+        cosine = F.linear(F.normalize(features), F.normalize(self.weight))
+        phi = cosine - self._m
+        # --------------------------- convert label to one-hot ---------------------------
+        targets = F.one_hot(targets, num_classes=self._num_classes)
+        output = (targets * phi) + ((1.0 - targets) * cosine)
+        output *= self._s
+
+        return output
+
+    def extra_repr(self):
+        return 'in_features={}, num_classes={}, scale={}, margin={}'.format(
+            self.in_features, self._num_classes, self._s, self._m
+        )
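+
+# A minimal usage sketch (hypothetical sizes; cfg must define MODEL.HEADS.SCALE and MODEL.HEADS.MARGIN):
+#   head = AMSoftmax(cfg, in_feat=2048, num_classes=751)
+#   logits = head(features, targets)  # features: (N, 2048) float, targets: (N,) long
+#   loss = F.cross_entropy(logits, targets)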
diff --git a/fastreid/layers/arc_softmax.py b/fastreid/layers/arc_softmax.py
index 32a70f7..455309c 100644
--- a/fastreid/layers/arc_softmax.py
+++ b/fastreid/layers/arc_softmax.py
@@ -26,6 +26,7 @@ class ArcSoftmax(nn.Module):
         self.mm = math.sin(math.pi - self._m) * self._m
 
         self.weight = Parameter(torch.Tensor(num_classes, in_feat))
+        nn.init.xavier_uniform_(self.weight)
         self.register_buffer('t', torch.zeros(1))
 
     def forward(self, features, targets):
diff --git a/fastreid/layers/circle_softmax.py b/fastreid/layers/circle_softmax.py
index 2d596b6..e4e392d 100644
--- a/fastreid/layers/circle_softmax.py
+++ b/fastreid/layers/circle_softmax.py
@@ -4,6 +4,8 @@
 @contact: sherlockliao01@gmail.com
 """
 
+import math
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -19,11 +21,12 @@ class CircleSoftmax(nn.Module):
         self._m = cfg.MODEL.HEADS.MARGIN
 
         self.weight = Parameter(torch.Tensor(num_classes, in_feat))
+        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
 
     def forward(self, features, targets):
         sim_mat = F.linear(F.normalize(features), F.normalize(self.weight))
-        alpha_p = F.relu(-sim_mat.detach() + 1 + self._m)
-        alpha_n = F.relu(sim_mat.detach() + self._m)
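+        # Circle loss adaptive weights: alpha_p = [1 + m - s_p]_+ and alpha_n = [s_n + m]_+, clamped at zero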
+        alpha_p = torch.clamp_min(-sim_mat.detach() + 1 + self._m, min=0.)
+        alpha_n = torch.clamp_min(sim_mat.detach() + self._m, min=0.)
         delta_p = 1 - self._m
         delta_n = self._m
 
diff --git a/fastreid/layers/splat.py b/fastreid/layers/splat.py
index 5559eab..b7451c5 100644
--- a/fastreid/layers/splat.py
+++ b/fastreid/layers/splat.py
@@ -1,9 +1,4 @@
 # encoding: utf-8
-"""
-@author:  xingyu liao
-@contact: liaoxingyu5@jd.com
-"""
-
 import torch
 import torch.nn.functional as F
 from torch import nn
diff --git a/fastreid/modeling/backbones/__init__.py b/fastreid/modeling/backbones/__init__.py
index f964933..d3ee3be 100644
--- a/fastreid/modeling/backbones/__init__.py
+++ b/fastreid/modeling/backbones/__init__.py
@@ -10,4 +10,4 @@ from .resnet import build_resnet_backbone
 from .osnet import build_osnet_backbone
 from .resnest import build_resnest_backbone
 from .resnext import build_resnext_backbone
-from .regnet import build_regnet_backbone
\ No newline at end of file
+from .regnet import build_regnet_backbone
diff --git a/fastreid/modeling/backbones/osnet.py b/fastreid/modeling/backbones/osnet.py
index d71615a..b9da1fa 100644
--- a/fastreid/modeling/backbones/osnet.py
+++ b/fastreid/modeling/backbones/osnet.py
@@ -1,7 +1,7 @@
 # encoding: utf-8
 """
 @author:  xingyu liao
-@contact: liaoxingyu5@jd.com
+@contact: sherlockliao01@gmail.com
 """
 
 # based on:
@@ -9,7 +9,8 @@
 
 import torch
 from torch import nn
-from torch.nn import functional as F
+
+from fastreid.layers import get_norm
 from .build import BACKBONE_REGISTRY
 
 model_urls = {
@@ -37,6 +38,8 @@ class ConvLayer(nn.Module):
             in_channels,
             out_channels,
             kernel_size,
+            bn_norm,
+            num_splits,
             stride=1,
             padding=0,
             groups=1,
@@ -55,7 +58,7 @@ class ConvLayer(nn.Module):
         if IN:
             self.bn = nn.InstanceNorm2d(out_channels, affine=True)
         else:
-            self.bn = nn.BatchNorm2d(out_channels)
+            self.bn = get_norm(bn_norm, out_channels, num_splits)
         self.relu = nn.ReLU(inplace=True)
 
     def forward(self, x):
@@ -68,7 +71,7 @@ class ConvLayer(nn.Module):
 class Conv1x1(nn.Module):
     """1x1 convolution + bn + relu."""
 
-    def __init__(self, in_channels, out_channels, stride=1, groups=1):
+    def __init__(self, in_channels, out_channels, bn_norm, num_splits, stride=1, groups=1):
         super(Conv1x1, self).__init__()
         self.conv = nn.Conv2d(
             in_channels,
@@ -79,7 +82,7 @@ class Conv1x1(nn.Module):
             bias=False,
             groups=groups
         )
-        self.bn = nn.BatchNorm2d(out_channels)
+        self.bn = get_norm(bn_norm, out_channels, num_splits)
         self.relu = nn.ReLU(inplace=True)
 
     def forward(self, x):
@@ -92,12 +95,12 @@ class Conv1x1(nn.Module):
 class Conv1x1Linear(nn.Module):
     """1x1 convolution + bn (w/o non-linearity)."""
 
-    def __init__(self, in_channels, out_channels, stride=1):
+    def __init__(self, in_channels, out_channels, bn_norm, num_splits, stride=1):
         super(Conv1x1Linear, self).__init__()
         self.conv = nn.Conv2d(
             in_channels, out_channels, 1, stride=stride, padding=0, bias=False
         )
-        self.bn = nn.BatchNorm2d(out_channels)
+        self.bn = get_norm(bn_norm, out_channels, num_splits)
 
     def forward(self, x):
         x = self.conv(x)
@@ -108,7 +111,7 @@ class Conv1x1Linear(nn.Module):
 class Conv3x3(nn.Module):
     """3x3 convolution + bn + relu."""
 
-    def __init__(self, in_channels, out_channels, stride=1, groups=1):
+    def __init__(self, in_channels, out_channels, bn_norm, num_splits, stride=1, groups=1):
         super(Conv3x3, self).__init__()
         self.conv = nn.Conv2d(
             in_channels,
@@ -119,7 +122,7 @@ class Conv3x3(nn.Module):
             bias=False,
             groups=groups
         )
-        self.bn = nn.BatchNorm2d(out_channels)
+        self.bn = get_norm(bn_norm, out_channels, num_splits)
         self.relu = nn.ReLU(inplace=True)
 
     def forward(self, x):
@@ -134,7 +137,7 @@ class LightConv3x3(nn.Module):
     1x1 (linear) + dw 3x3 (nonlinear).
     """
 
-    def __init__(self, in_channels, out_channels):
+    def __init__(self, in_channels, out_channels, bn_norm, num_splits):
         super(LightConv3x3, self).__init__()
         self.conv1 = nn.Conv2d(
             in_channels, out_channels, 1, stride=1, padding=0, bias=False
@@ -148,7 +151,7 @@ class LightConv3x3(nn.Module):
             bias=False,
             groups=out_channels
         )
-        self.bn = nn.BatchNorm2d(out_channels)
+        self.bn = get_norm(bn_norm, out_channels, num_splits)
         self.relu = nn.ReLU(inplace=True)
 
     def forward(self, x):
@@ -197,9 +200,12 @@ class ChannelGate(nn.Module):
             bias=True,
             padding=0
         )
-        if gate_activation == 'sigmoid':  self.gate_activation = nn.Sigmoid()
-        elif gate_activation == 'relu':   self.gate_activation = nn.ReLU(inplace=True)
-        elif gate_activation == 'linear': self.gate_activation = nn.Identity()
+        if gate_activation == 'sigmoid':
+            self.gate_activation = nn.Sigmoid()
+        elif gate_activation == 'relu':
+            self.gate_activation = nn.ReLU(inplace=True)
+        elif gate_activation == 'linear':
+            self.gate_activation = nn.Identity()
         else:
             raise RuntimeError(
                 "Unknown gate activation: {}".format(gate_activation)
@@ -224,34 +230,36 @@ class OSBlock(nn.Module):
             self,
             in_channels,
             out_channels,
+            bn_norm,
+            num_splits,
             IN=False,
             bottleneck_reduction=4,
             **kwargs
     ):
         super(OSBlock, self).__init__()
         mid_channels = out_channels // bottleneck_reduction
-        self.conv1 = Conv1x1(in_channels, mid_channels)
-        self.conv2a = LightConv3x3(mid_channels, mid_channels)
+        self.conv1 = Conv1x1(in_channels, mid_channels, bn_norm, num_splits)
+        self.conv2a = LightConv3x3(mid_channels, mid_channels, bn_norm, num_splits)
         self.conv2b = nn.Sequential(
-            LightConv3x3(mid_channels, mid_channels),
-            LightConv3x3(mid_channels, mid_channels),
+            LightConv3x3(mid_channels, mid_channels, bn_norm, num_splits),
+            LightConv3x3(mid_channels, mid_channels, bn_norm, num_splits),
         )
         self.conv2c = nn.Sequential(
-            LightConv3x3(mid_channels, mid_channels),
-            LightConv3x3(mid_channels, mid_channels),
-            LightConv3x3(mid_channels, mid_channels),
+            LightConv3x3(mid_channels, mid_channels, bn_norm, num_splits),
+            LightConv3x3(mid_channels, mid_channels, bn_norm, num_splits),
+            LightConv3x3(mid_channels, mid_channels, bn_norm, num_splits),
         )
         self.conv2d = nn.Sequential(
-            LightConv3x3(mid_channels, mid_channels),
-            LightConv3x3(mid_channels, mid_channels),
-            LightConv3x3(mid_channels, mid_channels),
-            LightConv3x3(mid_channels, mid_channels),
+            LightConv3x3(mid_channels, mid_channels, bn_norm, num_splits),
+            LightConv3x3(mid_channels, mid_channels, bn_norm, num_splits),
+            LightConv3x3(mid_channels, mid_channels, bn_norm, num_splits),
+            LightConv3x3(mid_channels, mid_channels, bn_norm, num_splits),
         )
         self.gate = ChannelGate(mid_channels)
-        self.conv3 = Conv1x1Linear(mid_channels, out_channels)
+        self.conv3 = Conv1x1Linear(mid_channels, out_channels, bn_norm, num_splits)
         self.downsample = None
         if in_channels != out_channels:
-            self.downsample = Conv1x1Linear(in_channels, out_channels)
+            self.downsample = Conv1x1Linear(in_channels, out_channels, bn_norm, num_splits)
         self.IN = None
         if IN: self.IN = nn.InstanceNorm2d(out_channels, affine=True)
         self.relu = nn.ReLU(True)
@@ -290,6 +298,8 @@ class OSNet(nn.Module):
             blocks,
             layers,
             channels,
+            bn_norm,
+            num_splits,
             IN=False,
             **kwargs
     ):
@@ -299,13 +309,15 @@ class OSNet(nn.Module):
         assert num_blocks == len(channels) - 1
 
         # convolutional backbone
-        self.conv1 = ConvLayer(3, channels[0], 7, stride=2, padding=3, IN=IN)
+        self.conv1 = ConvLayer(3, channels[0], 7, bn_norm, num_splits, stride=2, padding=3, IN=IN)
         self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
         self.conv2 = self._make_layer(
             blocks[0],
             layers[0],
             channels[0],
             channels[1],
+            bn_norm,
+            num_splits,
             reduce_spatial_size=True,
             IN=IN
         )
@@ -314,6 +326,8 @@ class OSNet(nn.Module):
             layers[1],
             channels[1],
             channels[2],
+            bn_norm,
+            num_splits,
             reduce_spatial_size=True
         )
         self.conv4 = self._make_layer(
@@ -321,9 +335,11 @@ class OSNet(nn.Module):
             layers[2],
             channels[2],
             channels[3],
+            bn_norm,
+            num_splits,
             reduce_spatial_size=False
         )
-        self.conv5 = Conv1x1(channels[3], channels[3])
+        self.conv5 = Conv1x1(channels[3], channels[3], bn_norm, num_splits)
 
         self._init_params()
 
@@ -333,19 +349,21 @@ class OSNet(nn.Module):
             layer,
             in_channels,
             out_channels,
+            bn_norm,
+            num_splits,
             reduce_spatial_size,
             IN=False
     ):
         layers = []
 
-        layers.append(block(in_channels, out_channels, IN=IN))
+        layers.append(block(in_channels, out_channels, bn_norm, num_splits, IN=IN))
         for i in range(1, layer):
-            layers.append(block(out_channels, out_channels, IN=IN))
+            layers.append(block(out_channels, out_channels, bn_norm, num_splits, IN=IN))
 
         if reduce_spatial_size:
             layers.append(
                 nn.Sequential(
-                    Conv1x1(out_channels, out_channels),
+                    Conv1x1(out_channels, out_channels, bn_norm, num_splits),
                     nn.AvgPool2d(2, stride=2),
                 )
             )
@@ -477,11 +495,19 @@ def build_osnet_backbone(cfg):
     # fmt: off
     pretrain = cfg.MODEL.BACKBONE.PRETRAIN
     with_ibn = cfg.MODEL.BACKBONE.WITH_IBN
+    bn_norm = cfg.MODEL.BACKBONE.NORM
+    num_splits = cfg.MODEL.BACKBONE.NORM_SPLIT
+    depth = cfg.MODEL.BACKBONE.DEPTH
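+    # DEPTH is an OSNet width tag ("x1_0", "x0_75", "x0_5" or "x0_25");
+    # e.g. "x0_5" makes every stage half as wide as the default "x1_0".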
 
     num_blocks_per_stage = [2, 2, 2]
-    num_channels_per_stage = [64, 256, 384, 512]
-    model = OSNet([OSBlock, OSBlock, OSBlock], num_blocks_per_stage, num_channels_per_stage, with_ibn)
-    pretrain_key = 'osnet_ibn_x1_0' if with_ibn else 'osnet_x1_0'
+    num_channels_per_stage = {"x1_0": [64, 256, 384, 512], "x0_75": [48, 192, 288, 384], "x0_5": [32, 128, 192, 256],
+                              "x0_25": [16, 64, 96, 128]}[depth]
+    model = OSNet([OSBlock, OSBlock, OSBlock], num_blocks_per_stage, num_channels_per_stage,
+                  bn_norm, num_splits, IN=with_ibn)
+
     if pretrain:
+        if with_ibn: pretrain_key = "osnet_ibn_" + depth
+        else:        pretrain_key = "osnet_" + depth
+
         init_pretrained_weights(model, pretrain_key)
     return model
diff --git a/fastreid/modeling/backbones/regnet/regnet.py b/fastreid/modeling/backbones/regnet/regnet.py
index 782731b..5a55097 100644
--- a/fastreid/modeling/backbones/regnet/regnet.py
+++ b/fastreid/modeling/backbones/regnet/regnet.py
@@ -10,7 +10,18 @@ from ..build import BACKBONE_REGISTRY
 from .config import regnet_cfg
 
 logger = logging.getLogger(__name__)
-
+model_urls = {
+    '800x': 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/160906036/RegNetX-800MF_dds_8gpu.pyth',
+    '800y': 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/160906567/RegNetY-800MF_dds_8gpu.pyth',
+    '1600x': 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/160990626/RegNetX-1.6GF_dds_8gpu.pyth',
+    '1600y': 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/160906681/RegNetY-1.6GF_dds_8gpu.pyth',
+    '3200x': 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/160906139/RegNetX-3.2GF_dds_8gpu.pyth',
+    '3200y': 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/160906834/RegNetY-3.2GF_dds_8gpu.pyth',
+    '4000x': 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/160906383/RegNetX-4.0GF_dds_8gpu.pyth',
+    '4000y': 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/160906838/RegNetY-4.0GF_dds_8gpu.pyth',
+    '6400x': 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/161116590/RegNetX-6.4GF_dds_8gpu.pyth',
+    '6400y': 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/160907112/RegNetY-6.4GF_dds_8gpu.pyth',
+}
 
 def init_weights(m):
     """Performs ResNet-style weight initialization."""
@@ -464,14 +475,60 @@ class RegNet(AnyNet):
         super(RegNet, self).__init__(**kwargs)
 
 
+def init_pretrained_weights(key):
+    """Initializes model with pretrained weights.
+
+    Layers that don't match with pretrained layers in name or size are kept unchanged.
+    """
+    import os
+    import errno
+    import gdown
+
+    def _get_torch_home():
+        ENV_TORCH_HOME = 'TORCH_HOME'
+        ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
+        DEFAULT_CACHE_DIR = '~/.cache'
+        torch_home = os.path.expanduser(
+            os.getenv(
+                ENV_TORCH_HOME,
+                os.path.join(
+                    os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'torch'
+                )
+            )
+        )
+        return torch_home
+
+    torch_home = _get_torch_home()
+    model_dir = os.path.join(torch_home, 'checkpoints')
+    try:
+        os.makedirs(model_dir)
+    except OSError as e:
+        if e.errno == errno.EEXIST:
+            # Directory already exists, ignore.
+            pass
+        else:
+            # Unexpected OSError, re-raise.
+            raise
+
+    filename = model_urls[key].split('/')[-1]
+
+    cached_file = os.path.join(model_dir, filename)
+
+    if not os.path.exists(cached_file):
+        gdown.download(model_urls[key], cached_file, quiet=False)
+
+    logger.info(f"Loading pretrained model from {cached_file}")
+    state_dict = torch.load(cached_file)['model_state']
+
+    return state_dict
+
 @BACKBONE_REGISTRY.register()
 def build_regnet_backbone(cfg):
     # fmt: off
     pretrain = cfg.MODEL.BACKBONE.PRETRAIN
-    pretrain_path = cfg.MODEL.BACKBONE.PRETRAIN_PATH
     last_stride = cfg.MODEL.BACKBONE.LAST_STRIDE
     bn_norm = cfg.MODEL.BACKBONE.NORM
-    volume = cfg.MODEL.BACKBONE.VOLUME
+    depth = cfg.MODEL.BACKBONE.DEPTH
 
     cfg_files = {
         '800x': 'fastreid/modeling/backbones/regnet/regnetx/RegNetX-800MF_dds_8gpu.yaml',
@@ -480,21 +537,18 @@ def build_regnet_backbone(cfg):
         '1600y': 'fastreid/modeling/backbones/regnet/regnety/RegNetY-1.6GF_dds_8gpu.yaml',
         '3200x': 'fastreid/modeling/backbones/regnet/regnetx/RegNetX-3.2GF_dds_8gpu.yaml',
         '3200y': 'fastreid/modeling/backbones/regnet/regnety/RegNetY-3.2GF_dds_8gpu.yaml',
+        '4000x': 'fastreid/modeling/backbones/regnet/regnetx/RegNetX-4.0GF_dds_8gpu.yaml',
+        '4000y': 'fastreid/modeling/backbones/regnet/regnety/RegNetY-4.0GF_dds_8gpu.yaml',
         '6400x': 'fastreid/modeling/backbones/regnet/regnetx/RegNetX-6.4GF_dds_8gpu.yaml',
         '6400y': 'fastreid/modeling/backbones/regnet/regnety/RegNetY-6.4GF_dds_8gpu.yaml',
-    }[volume]
+    }[depth]
 
     regnet_cfg.merge_from_file(cfg_files)
     model = RegNet(last_stride, bn_norm)
 
     if pretrain:
-        try:
-            state_dict = torch.load(pretrain_path, map_location=torch.device('cpu'))['model_state']
-        except FileNotFoundError as e:
-            logger.info(f'{pretrain_path} is not found! Please check this path.')
-            raise e
-
-        logger.info(f"Loading pretrained model from {pretrain_path}")
+        key = depth
+        state_dict = init_pretrained_weights(key)
 
         incompatible = model.load_state_dict(state_dict, strict=False)
         if incompatible.missing_keys:
diff --git a/fastreid/modeling/backbones/resnet.py b/fastreid/modeling/backbones/resnet.py
index 47d7846..1cae38f 100644
--- a/fastreid/modeling/backbones/resnet.py
+++ b/fastreid/modeling/backbones/resnet.py
@@ -9,7 +9,6 @@ import math
 
 import torch
 from torch import nn
-from torch.utils import model_zoo
 
 from fastreid.layers import (
     IBN,
@@ -22,11 +21,15 @@ from .build import BACKBONE_REGISTRY
 
 logger = logging.getLogger(__name__)
 model_urls = {
-    18: 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
-    34: 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
-    50: 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
-    101: 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
-    152: 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
+    '18x': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
+    '34x': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
+    '50x': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
+    '101x': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
+    'ibn_18x': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet18_ibn_a-2f571257.pth',
+    'ibn_34x': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet34_ibn_a-94bc1577.pth',
+    'ibn_50x': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet50_ibn_a-d9d0bb7b.pth',
+    'ibn_101x': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet101_ibn_a-59ea0ac6.pth',
+    'se_ibn_101x': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/se_resnet101_ibn_a-fabed4e2.pth',
 }
 
 
@@ -37,7 +40,10 @@ class BasicBlock(nn.Module):
                  stride=1, downsample=None, reduction=16):
         super(BasicBlock, self).__init__()
         self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
-        self.bn1 = get_norm(bn_norm, planes, num_splits)
+        if with_ibn:
+            self.bn1 = IBN(planes, bn_norm, num_splits)
+        else:
+            self.bn1 = get_norm(bn_norm, planes, num_splits)
         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
         self.bn2 = get_norm(bn_norm, planes, num_splits)
         self.relu = nn.ReLU(inplace=True)
@@ -146,8 +152,6 @@ class ResNet(nn.Module):
             )
 
         layers = []
-        if planes == 512:
-            with_ibn = False
         layers.append(block(self.inplanes, planes, bn_norm, num_splits, with_ibn, with_se, stride, downsample))
         self.inplanes = planes * block.expansion
         for i in range(1, blocks):
@@ -227,6 +231,54 @@ class ResNet(nn.Module):
                 nn.init.constant_(m.bias, 0)
 
 
+def init_pretrained_weights(key):
+    """Initializes model with pretrained weights.
+
+    Layers that don't match with pretrained layers in name or size are kept unchanged.
+    """
+    import os
+    import errno
+    import gdown
+
+    def _get_torch_home():
+        ENV_TORCH_HOME = 'TORCH_HOME'
+        ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
+        DEFAULT_CACHE_DIR = '~/.cache'
+        torch_home = os.path.expanduser(
+            os.getenv(
+                ENV_TORCH_HOME,
+                os.path.join(
+                    os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'torch'
+                )
+            )
+        )
+        return torch_home
+
+    torch_home = _get_torch_home()
+    model_dir = os.path.join(torch_home, 'checkpoints')
+    try:
+        os.makedirs(model_dir)
+    except OSError as e:
+        if e.errno == errno.EEXIST:
+            # Directory already exists, ignore.
+            pass
+        else:
+            # Unexpected OSError, re-raise.
+            raise
+
+    filename = model_urls[key].split('/')[-1]
+
+    cached_file = os.path.join(model_dir, filename)
+
+    if not os.path.exists(cached_file):
+        gdown.download(model_urls[key], cached_file, quiet=False)
+
+    logger.info(f"Loading pretrained model from {cached_file}")
+    state_dict = torch.load(cached_file)
+
+    return state_dict
+
+
 @BACKBONE_REGISTRY.register()
 def build_resnet_backbone(cfg):
     """
@@ -246,13 +298,15 @@ def build_resnet_backbone(cfg):
     with_nl = cfg.MODEL.BACKBONE.WITH_NL
     depth = cfg.MODEL.BACKBONE.DEPTH
 
-    num_blocks_per_stage = {34: [3, 4, 6, 3], 50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3], }[depth]
-    nl_layers_per_stage = {34: [0, 2, 3, 0], 50: [0, 2, 3, 0], 101: [0, 2, 9, 0]}[depth]
-    block = {34: BasicBlock, 50: Bottleneck, 101: Bottleneck}[depth]
+    num_blocks_per_stage = {'18x': [2, 2, 2, 2], '34x': [3, 4, 6, 3], '50x': [3, 4, 6, 3],
+                            '101x': [3, 4, 23, 3],}[depth]
+    nl_layers_per_stage = {'18x': [0, 0, 0, 0], '34x': [0, 0, 0, 0], '50x': [0, 2, 3, 0], '101x': [0, 2, 9, 0]}[depth]
+    block = {'18x': BasicBlock, '34x': BasicBlock, '50x': Bottleneck, '101x': Bottleneck}[depth]
     model = ResNet(last_stride, bn_norm, num_splits, with_ibn, with_se, with_nl, block,
                    num_blocks_per_stage, nl_layers_per_stage)
     if pretrain:
-        if not with_ibn:
+        # Load pretrained weights from the given path if one is specified
+        if pretrain_path:
             try:
                 state_dict = torch.load(pretrain_path, map_location=torch.device('cpu'))['model']
                 # Remove module.encoder in name
@@ -263,20 +317,19 @@ def build_resnet_backbone(cfg):
                         new_state_dict[new_k] = state_dict[k]
                 state_dict = new_state_dict
                 logger.info(f"Loading pretrained model from {pretrain_path}")
-            except FileNotFoundError or KeyError:
-                # original resnet
-                state_dict = model_zoo.load_url(model_urls[depth])
-                logger.info("Loading pretrained model from torchvision")
+            except FileNotFoundError as e:
+                logger.info(f'{pretrain_path} is not found! Please check this path.')
+                raise e
+            except KeyError as e:
+                logger.info("State dict keys error! Please check the state dict.")
+                raise e
         else:
-            state_dict = torch.load(pretrain_path, map_location=torch.device('cpu'))['state_dict']  # ibn-net
-            # Remove module in name
-            new_state_dict = {}
-            for k in state_dict:
-                new_k = '.'.join(k.split('.')[1:])
-                if new_k in model.state_dict() and (model.state_dict()[new_k].shape == state_dict[k].shape):
-                    new_state_dict[new_k] = state_dict[k]
-            state_dict = new_state_dict
-            logger.info(f"Loading pretrained model from {pretrain_path}")
+            key = depth
+            if with_ibn: key = 'ibn_' + key
+            if with_se:  key = 'se_' + key
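+            # e.g. depth '101x' with IBN and SE composes the key 'se_ibn_101x'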
+
+            state_dict = init_pretrained_weights(key)
+
         incompatible = model.load_state_dict(state_dict, strict=False)
         if incompatible.missing_keys:
             logger.info(
@@ -286,4 +339,5 @@ def build_resnet_backbone(cfg):
             logger.info(
                 get_unexpected_parameters_message(incompatible.unexpected_keys)
             )
+
     return model
diff --git a/fastreid/modeling/backbones/resnext.py b/fastreid/modeling/backbones/resnext.py
index 88ae3ac..a9b2519 100644
--- a/fastreid/modeling/backbones/resnext.py
+++ b/fastreid/modeling/backbones/resnext.py
@@ -1,21 +1,27 @@
 # encoding: utf-8
 """
 @author:  xingyu liao
-@contact: liaoxingyu5@jd.com
+@contact: sherlockliao01@gmail.com
 """
 
 # based on:
 # https://github.com/XingangPan/IBN-Net/blob/master/models/imagenet/resnext_ibn_a.py
 
-import math
 import logging
-import torch.nn as nn
-import torch.nn.functional as F
-from torch.nn import init
+import math
+
 import torch
-from ...layers import IBN
+import torch.nn as nn
+
+from fastreid.layers import IBN, get_norm
+from fastreid.utils.checkpoint import get_missing_parameters_message, get_unexpected_parameters_message
 from .build import BACKBONE_REGISTRY
 
+logger = logging.getLogger(__name__)
+model_urls = {
+    'ibn_101x': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnext101_ibn_a-6ace051d.pth',
+}
+
 
 class Bottleneck(nn.Module):
     """
@@ -23,7 +29,8 @@ class Bottleneck(nn.Module):
     """
     expansion = 4
 
-    def __init__(self, inplanes, planes, with_ibn, baseWidth, cardinality, stride=1, downsample=None):
+    def __init__(self, inplanes, planes, bn_norm, num_splits, with_ibn, baseWidth, cardinality, stride=1,
+                 downsample=None):
         """ Constructor
         Args:
             inplanes: input channel dimensionality
@@ -38,13 +45,13 @@ class Bottleneck(nn.Module):
         C = cardinality
         self.conv1 = nn.Conv2d(inplanes, D * C, kernel_size=1, stride=1, padding=0, bias=False)
         if with_ibn:
-            self.bn1 = IBN(D * C)
+            self.bn1 = IBN(D * C, bn_norm, num_splits)
         else:
-            self.bn1 = nn.BatchNorm2d(D * C)
+            self.bn1 = get_norm(bn_norm, D * C, num_splits)
         self.conv2 = nn.Conv2d(D * C, D * C, kernel_size=3, stride=stride, padding=1, groups=C, bias=False)
-        self.bn2 = nn.BatchNorm2d(D * C)
+        self.bn2 = get_norm(bn_norm, D * C, num_splits)
         self.conv3 = nn.Conv2d(D * C, planes * 4, kernel_size=1, stride=1, padding=0, bias=False)
-        self.bn3 = nn.BatchNorm2d(planes * 4)
+        self.bn3 = get_norm(bn_norm, planes * 4, num_splits)
         self.relu = nn.ReLU(inplace=True)
 
         self.downsample = downsample
@@ -78,7 +85,7 @@ class ResNeXt(nn.Module):
     https://arxiv.org/pdf/1611.05431.pdf
     """
 
-    def __init__(self, last_stride, with_ibn, block, layers, baseWidth=4, cardinality=32):
+    def __init__(self, last_stride, bn_norm, num_splits, with_ibn, block, layers, baseWidth=4, cardinality=32):
         """ Constructor
         Args:
             baseWidth: baseWidth for ResNeXt.
@@ -94,17 +101,17 @@ class ResNeXt(nn.Module):
         self.output_size = 64
 
         self.conv1 = nn.Conv2d(3, 64, 7, 2, 3, bias=False)
-        self.bn1 = nn.BatchNorm2d(64)
+        self.bn1 = get_norm(bn_norm, 64, num_splits)
         self.relu = nn.ReLU(inplace=True)
         self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
-        self.layer1 = self._make_layer(block, 64, layers[0], with_ibn=with_ibn)
-        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, with_ibn=with_ibn)
-        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, with_ibn=with_ibn)
-        self.layer4 = self._make_layer(block, 512, layers[3], stride=last_stride, with_ibn=with_ibn)
+        self.layer1 = self._make_layer(block, 64, layers[0], 1, bn_norm, num_splits, with_ibn=with_ibn)
+        self.layer2 = self._make_layer(block, 128, layers[1], 2, bn_norm, num_splits, with_ibn=with_ibn)
+        self.layer3 = self._make_layer(block, 256, layers[2], 2, bn_norm, num_splits, with_ibn=with_ibn)
+        self.layer4 = self._make_layer(block, 512, layers[3], last_stride, bn_norm, num_splits, with_ibn=with_ibn)
 
         self.random_init()
 
-    def _make_layer(self, block, planes, blocks, stride=1, with_ibn=False):
+    def _make_layer(self, block, planes, blocks, stride=1, bn_norm='BN', num_splits=1, with_ibn=False):
         """ Stack n bottleneck modules where n is inferred from the depth of the network.
         Args:
             block: block type used to construct ResNext
@@ -118,16 +125,18 @@ class ResNeXt(nn.Module):
             downsample = nn.Sequential(
                 nn.Conv2d(self.inplanes, planes * block.expansion,
                           kernel_size=1, stride=stride, bias=False),
-                nn.BatchNorm2d(planes * block.expansion),
+                get_norm(bn_norm, planes * block.expansion, num_splits),
             )
 
         layers = []
         if planes == 512:
             with_ibn = False
-        layers.append(block(self.inplanes, planes, with_ibn, self.baseWidth, self.cardinality, stride, downsample))
+        layers.append(block(self.inplanes, planes, bn_norm, num_splits, with_ibn,
+                            self.baseWidth, self.cardinality, stride, downsample))
         self.inplanes = planes * block.expansion
         for i in range(1, blocks):
-            layers.append(block(self.inplanes, planes, with_ibn, self.baseWidth, self.cardinality, 1, None))
+            layers.append(
+                block(self.inplanes, planes, bn_norm, num_splits, with_ibn, self.baseWidth, self.cardinality, 1, None))
 
         return nn.Sequential(*layers)
 
@@ -157,6 +166,53 @@ class ResNeXt(nn.Module):
                 m.bias.data.zero_()
 
 
+def init_pretrained_weights(key):
+    """Initializes model with pretrained weights.
+
+    Layers that don't match with pretrained layers in name or size are kept unchanged.
+    """
+    import os
+    import errno
+    import gdown
+
+    def _get_torch_home():
+        ENV_TORCH_HOME = 'TORCH_HOME'
+        ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
+        DEFAULT_CACHE_DIR = '~/.cache'
+        torch_home = os.path.expanduser(
+            os.getenv(
+                ENV_TORCH_HOME,
+                os.path.join(
+                    os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'torch'
+                )
+            )
+        )
+        return torch_home
+
+    torch_home = _get_torch_home()
+    model_dir = os.path.join(torch_home, 'checkpoints')
+    try:
+        os.makedirs(model_dir)
+    except OSError as e:
+        if e.errno == errno.EEXIST:
+            # Directory already exists, ignore.
+            pass
+        else:
+            # Unexpected OSError, re-raise.
+            raise
+
+    filename = model_urls[key].split('/')[-1]
+
+    cached_file = os.path.join(model_dir, filename)
+
+    if not os.path.exists(cached_file):
+        gdown.download(model_urls[key], cached_file, quiet=False)
+
+    logger.info(f"Loading pretrained model from {cached_file}")
+    state_dict = torch.load(cached_file)
+
+    return state_dict
+
 @BACKBONE_REGISTRY.register()
 def build_resnext_backbone(cfg):
     """
@@ -169,30 +225,48 @@ def build_resnext_backbone(cfg):
     pretrain = cfg.MODEL.BACKBONE.PRETRAIN
     pretrain_path = cfg.MODEL.BACKBONE.PRETRAIN_PATH
     last_stride = cfg.MODEL.BACKBONE.LAST_STRIDE
+    bn_norm = cfg.MODEL.BACKBONE.NORM
+    num_splits = cfg.MODEL.BACKBONE.NORM_SPLIT
     with_ibn = cfg.MODEL.BACKBONE.WITH_IBN
-    with_se = cfg.MODEL.BACKBONE.WITH_SE
     with_nl = cfg.MODEL.BACKBONE.WITH_NL
     depth = cfg.MODEL.BACKBONE.DEPTH
 
-    num_blocks_per_stage = {50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3], }[depth]
-    nl_layers_per_stage = {50: [0, 2, 3, 0], 101: [0, 2, 3, 0]}[depth]
-    model = ResNeXt(last_stride, with_ibn, Bottleneck, num_blocks_per_stage)
+    num_blocks_per_stage = {'50x': [3, 4, 6, 3], '101x': [3, 4, 23, 3], '152x': [3, 8, 36, 3], }[depth]
+    nl_layers_per_stage = {'50x': [0, 2, 3, 0], '101x': [0, 2, 3, 0]}[depth]
+    model = ResNeXt(last_stride, bn_norm, num_splits, with_ibn, Bottleneck, num_blocks_per_stage)
+
     if pretrain:
-        # if not with_ibn:
-        # original resnet
-        # state_dict = model_zoo.load_url(model_urls[depth])
-        # else:
-        # ibn resnet
-        state_dict = torch.load(pretrain_path)['state_dict']
-        # remove module in name
-        new_state_dict = {}
-        for k in state_dict:
-            new_k = '.'.join(k.split('.')[1:])
-            if new_k in model.state_dict() and (model.state_dict()[new_k].shape == state_dict[k].shape):
-                new_state_dict[new_k] = state_dict[k]
-        state_dict = new_state_dict
-        res = model.load_state_dict(state_dict, strict=False)
-        logger = logging.getLogger(__name__)
-        logger.info('missing keys is {}'.format(res.missing_keys))
-        logger.info('unexpected keys is {}'.format(res.unexpected_keys))
+        if pretrain_path:
+            try:
+                state_dict = torch.load(pretrain_path, map_location=torch.device('cpu'))['model']
+                # Remove module.encoder in name
+                new_state_dict = {}
+                for k in state_dict:
+                    new_k = '.'.join(k.split('.')[2:])
+                    if new_k in model.state_dict() and (model.state_dict()[new_k].shape == state_dict[k].shape):
+                        new_state_dict[new_k] = state_dict[k]
+                state_dict = new_state_dict
+                logger.info(f"Loading pretrained model from {pretrain_path}")
+            except FileNotFoundError as e:
+                logger.info(f'{pretrain_path} is not found! Please check this path.')
+                raise e
+            except KeyError as e:
+                logger.info("State dict keys error! Please check the state dict.")
+                raise e
+        else:
+            key = depth
+            if with_ibn: key = 'ibn_' + key
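+            # e.g. depth '101x' with IBN -> pretrain key 'ibn_101x'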
+
+            state_dict = init_pretrained_weights(key)
+
+        incompatible = model.load_state_dict(state_dict, strict=False)
+        if incompatible.missing_keys:
+            logger.info(
+                get_missing_parameters_message(incompatible.missing_keys)
+            )
+        if incompatible.unexpected_keys:
+            logger.info(
+                get_unexpected_parameters_message(incompatible.unexpected_keys)
+            )
+
     return model
diff --git a/fastreid/modeling/heads/bnneck_head.py b/fastreid/modeling/heads/bnneck_head.py
index 2d410bc..5b46c6a 100644
--- a/fastreid/modeling/heads/bnneck_head.py
+++ b/fastreid/modeling/heads/bnneck_head.py
@@ -24,9 +24,10 @@ class BNneckHead(nn.Module):
         if cls_type == 'linear':          self.classifier = nn.Linear(in_feat, num_classes, bias=False)
         elif cls_type == 'arcSoftmax':    self.classifier = ArcSoftmax(cfg, in_feat, num_classes)
         elif cls_type == 'circleSoftmax': self.classifier = CircleSoftmax(cfg, in_feat, num_classes)
+        elif cls_type == 'amSoftmax':     self.classifier = AMSoftmax(cfg, in_feat, num_classes)
         else:
             raise KeyError(f"{cls_type} is invalid, please choose from "
-                           f"'linear', 'arcSoftmax' and 'circleSoftmax'.")
+                           f"'linear', 'arcSoftmax', 'amSoftmax' and 'circleSoftmax'.")
 
         self.classifier.apply(weights_init_classifier)
 
diff --git a/fastreid/modeling/heads/linear_head.py b/fastreid/modeling/heads/linear_head.py
index 2641407..b22bd7f 100644
--- a/fastreid/modeling/heads/linear_head.py
+++ b/fastreid/modeling/heads/linear_head.py
@@ -20,9 +20,10 @@ class LinearHead(nn.Module):
         if cls_type == 'linear':          self.classifier = nn.Linear(in_feat, num_classes, bias=False)
         elif cls_type == 'arcSoftmax':    self.classifier = ArcSoftmax(cfg, in_feat, num_classes)
         elif cls_type == 'circleSoftmax': self.classifier = CircleSoftmax(cfg, in_feat, num_classes)
+        elif cls_type == 'amSoftmax':     self.classifier = AMSoftmax(cfg, in_feat, num_classes)
         else:
             raise KeyError(f"{cls_type} is invalid, please choose from "
-                           f"'linear', 'arcSoftmax' and 'circleSoftmax'.")
+                           f"'linear', 'arcSoftmax', 'amSoftmax' and 'circleSoftmax'.")
 
         self.classifier.apply(weights_init_classifier)
 
diff --git a/fastreid/modeling/heads/reduction_head.py b/fastreid/modeling/heads/reduction_head.py
index af6ea40..7d55b3f 100644
--- a/fastreid/modeling/heads/reduction_head.py
+++ b/fastreid/modeling/heads/reduction_head.py
@@ -32,12 +32,13 @@ class ReductionHead(nn.Module):
 
         # identity classification layer
         cls_type = cfg.MODEL.HEADS.CLS_LAYER
-        if cls_type == 'linear':          self.classifier = nn.Linear(in_feat, num_classes, bias=False)
-        elif cls_type == 'arcSoftmax':    self.classifier = ArcSoftmax(cfg, in_feat, num_classes)
-        elif cls_type == 'circleSoftmax': self.classifier = CircleSoftmax(cfg, in_feat, num_classes)
+        if cls_type == 'linear':          self.classifier = nn.Linear(reduction_dim, num_classes, bias=False)
+        elif cls_type == 'arcSoftmax':    self.classifier = ArcSoftmax(cfg, reduction_dim, num_classes)
+        elif cls_type == 'circleSoftmax': self.classifier = CircleSoftmax(cfg, reduction_dim, num_classes)
+        elif cls_type == 'amSoftmax':     self.classifier = AMSoftmax(cfg, reduction_dim, num_classes)
         else:
             raise KeyError(f"{cls_type} is invalid, please choose from "
-                           f"'linear', 'arcSoftmax' and 'circleSoftmax'.")
+                           f"'linear', 'arcSoftmax', 'amSoftmax' and 'circleSoftmax'.")
 
         self.classifier.apply(weights_init_classifier)
 
diff --git a/fastreid/modeling/losses/__init__.py b/fastreid/modeling/losses/__init__.py
index 977e5c3..7411aff 100644
--- a/fastreid/modeling/losses/__init__.py
+++ b/fastreid/modeling/losses/__init__.py
@@ -6,4 +6,5 @@
 
 from .cross_entroy_loss import CrossEntropyLoss
 from .focal_loss import FocalLoss
-from .metric_loss import TripletLoss, CircleLoss
+from .triplet_loss import TripletLoss
+from .circle_loss import CircleLoss
diff --git a/fastreid/modeling/losses/circle_loss.py b/fastreid/modeling/losses/circle_loss.py
new file mode 100644
index 0000000..b945659
--- /dev/null
+++ b/fastreid/modeling/losses/circle_loss.py
@@ -0,0 +1,61 @@
+# encoding: utf-8
+"""
+@author:  xingyu liao
+@contact: sherlockliao01@gmail.com
+"""
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from fastreid.utils import comm
+from .utils import concat_all_gather
+
+
+class CircleLoss(object):
+    def __init__(self, cfg):
+        self._scale = cfg.MODEL.LOSSES.CIRCLE.SCALE
+
+        self._m = cfg.MODEL.LOSSES.CIRCLE.MARGIN
+        self._s = cfg.MODEL.LOSSES.CIRCLE.ALPHA
+
+    def __call__(self, embedding, targets):
+        embedding = nn.functional.normalize(embedding, dim=1)
+
+        if comm.get_world_size() > 1:
+            all_embedding = concat_all_gather(embedding)
+            all_targets = concat_all_gather(targets)
+        else:
+            all_embedding = embedding
+            all_targets = targets
+
+        dist_mat = torch.matmul(embedding, all_embedding.t())
+
+        N, M = dist_mat.size()
+        is_pos = targets.view(N, 1).expand(N, M).eq(all_targets.view(M, 1).expand(M, N).t()).float()
+
+        # Compute the mask which ignores the relevance score of the query to itself
+        if M > N:
+            identity_indx = torch.eye(N, N, device=is_pos.device)
+            remain_indx = torch.zeros(N, M - N, device=is_pos.device)
+            identity_indx = torch.cat((identity_indx, remain_indx), dim=1)
+            is_pos = is_pos - identity_indx
+        else:
+            is_pos = is_pos - torch.eye(N, N, device=is_pos.device)
+
+        is_neg = targets.view(N, 1).expand(N, M).ne(all_targets.view(M, 1).expand(M, N).t())
+
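+        # Zero out irrelevant entries with float masks instead of boolean
+        # indexing, so each row may keep a different number of positives.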
+        s_p = dist_mat * is_pos
+        s_n = dist_mat * is_neg
+
+        alpha_p = torch.clamp_min(-s_p.detach() + 1 + self._m, min=0.)
+        alpha_n = torch.clamp_min(s_n.detach() + self._m, min=0.)
+        delta_p = 1 - self._m
+        delta_n = self._m
+
+        logit_p = - self._s * alpha_p * (s_p - delta_p)
+        logit_n = self._s * alpha_n * (s_n - delta_n)
+
+        loss = nn.functional.softplus(torch.logsumexp(logit_p, dim=1) + torch.logsumexp(logit_n, dim=1)).mean()
+
+        return loss * self._scale
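+
+
+if __name__ == '__main__':
+    # Minimal smoke test, not part of the training pipeline. It assumes a
+    # single-process run (comm.get_world_size() == 1) and fakes only the
+    # config fields this loss reads.
+    from types import SimpleNamespace
+
+    circle_cfg = SimpleNamespace(SCALE=1.0, MARGIN=0.25, ALPHA=128)
+    cfg = SimpleNamespace(MODEL=SimpleNamespace(LOSSES=SimpleNamespace(CIRCLE=circle_cfg)))
+
+    embedding = torch.randn(32, 128)
+    targets = torch.randint(0, 8, (32,))
+    print(CircleLoss(cfg)(embedding, targets))  # expect a finite scalar tensor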
diff --git a/fastreid/modeling/losses/cross_entroy_loss.py b/fastreid/modeling/losses/cross_entroy_loss.py
index 8a09843..c86e47e 100644
--- a/fastreid/modeling/losses/cross_entroy_loss.py
+++ b/fastreid/modeling/losses/cross_entroy_loss.py
@@ -58,5 +58,11 @@ class CrossEntropyLoss(object):
             targets *= smooth_param / (self._num_classes - 1)
             targets.scatter_(1, gt_classes.data.unsqueeze(1), (1 - smooth_param))
 
-        loss = (-targets * log_probs).mean(0).sum()
+        loss = (-targets * log_probs).sum(dim=1)
+
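+        # Normalize by the number of entries with a non-zero loss (at least 1)
+        # rather than by the full batch size.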
+        with torch.no_grad():
+            non_zero_cnt = max(loss.nonzero().size(0), 1)
+
+        loss = loss.sum() / non_zero_cnt
+
         return loss * self._scale
diff --git a/fastreid/modeling/losses/smooth_ap.py b/fastreid/modeling/losses/smooth_ap.py
new file mode 100644
index 0000000..6305ca7
--- /dev/null
+++ b/fastreid/modeling/losses/smooth_ap.py
@@ -0,0 +1,241 @@
+# encoding: utf-8
+"""
+@author:  xingyu liao
+@contact: sherlockliao01@gmail.com
+"""
+
+# based on:
+# https://github.com/Andrew-Brown1/Smooth_AP/blob/master/src/Smooth_AP_loss.py
+
+import torch
+import torch.nn.functional as F
+
+from fastreid.utils import comm
+from fastreid.modeling.losses.utils import concat_all_gather
+
+
+def sigmoid(tensor, temp=1.0):
+    """ temperature controlled sigmoid
+    takes as input a torch tensor (tensor) and passes it through a sigmoid, controlled by temperature: temp
+    """
+    exponent = -tensor / temp
+    # clamp the input tensor for stability
+    exponent = torch.clamp(exponent, min=-50, max=50)
+    y = 1.0 / (1.0 + torch.exp(exponent))
+    return y
+
+
+class SmoothAP(object):
+    r"""PyTorch implementation of the Smooth-AP loss.
+    Takes as input the mini-batch of CNN-produced feature embeddings and returns
+    the value of the Smooth-AP loss. The mini-batch must be formed of a defined number of classes. Each class must
+    have the same number of instances represented in the mini-batch and must be ordered sequentially by class.
+    e.g. the labels for a mini-batch with batch size 9, and 3 represented classes (A,B,C) must look like:
+        labels = ( A, A, A, B, B, B, C, C, C)
+    (the order of the classes however does not matter)
+    For each instance in the mini-batch, the loss computes the Smooth-AP when it is used as the query and the rest of the
+    mini-batch is used as the retrieval set. The positive set is formed of the other instances in the batch from the
+    same class. The loss returns the average Smooth-AP across all instances in the mini-batch.
+    Args:
+        anneal : float
+            the temperature of the sigmoid that is used to smooth the ranking function. A low value of the temperature
+            results in a steep sigmoid, that tightly approximates the heaviside step function in the ranking function.
+        batch_size : int
+            the batch size being used during training.
+        num_id : int
+            the number of different classes that are represented in the batch.
+        feat_dims : int
+            the dimension of the input feature embeddings
+    Shape:
+        - Input (preds): (batch_size, feat_dims) (must be a cuda torch float tensor)
+        - Output: scalar
+    Examples::
+        >>> loss = SmoothAP(cfg)
+        >>> input = torch.randn(60, 256, requires_grad=True).cuda()
+        >>> targets = torch.arange(6).repeat_interleave(10).cuda()
+        >>> output = loss(input, targets)
+        >>> output.backward()
+    """
+
+    def __init__(self, cfg):
+        r"""
+        Parameters
+        ----------
+        cfg: (cfgNode)
+
+        anneal : float
+            the temperature of the sigmoid that is used to smooth the ranking function
+        batch_size : int
+            the batch size being used
+        num_id : int
+            the number of different classes that are represented in the batch
+        feat_dims : int
+            the dimension of the input feature embeddings
+        """
+
+        self.anneal = 0.01
+        self.num_id = cfg.SOLVER.IMS_PER_BATCH // cfg.DATALOADER.NUM_INSTANCE
+
+    def __call__(self, embedding, targets):
+        """Forward pass for all input predictions: preds - (batch_size x feat_dims) """
+
+        # ------ differentiable ranking of all retrieval set ------
+        embedding = F.normalize(embedding, dim=1)
+
+        feat_dim = embedding.size(1)
+
+        # For distributed training, gather all features from different process.
+        if comm.get_world_size() > 1:
+            all_embedding = concat_all_gather(embedding)
+            all_targets = concat_all_gather(targets)
+        else:
+            all_embedding = embedding
+            all_targets = targets
+
+        sim_dist = torch.matmul(embedding, all_embedding.t())
+        N, M = sim_dist.size()
+
+        # Compute the mask which ignores the relevance score of the query to itself
+        mask_indx = 1.0 - torch.eye(M, device=sim_dist.device)
+        mask_indx = mask_indx.unsqueeze(dim=0).repeat(N, 1, 1)  # (N, M, M)
+
+        # sim_dist: (N, M) -> (N, 1, M) -> (N, M, M)
+        sim_dist_repeat = sim_dist.unsqueeze(dim=1).repeat(1, M, 1)  # (N, M, M)
+
+        # Compute the difference matrix
+        sim_diff = sim_dist_repeat - sim_dist_repeat.permute(0, 2, 1)  # (N, M, M)
+
+        # Pass through the sigmoid
+        sim_sg = sigmoid(sim_diff, temp=self.anneal) * mask_indx
+
+        # Compute all the rankings
+        sim_all_rk = torch.sum(sim_sg, dim=-1) + 1  # (N, M)
+
+        pos_mask = targets.view(N, 1).expand(N, M).eq(all_targets.view(M, 1).expand(M, N).t()).float()  # (N, M)
+
+        pos_mask_repeat = pos_mask.unsqueeze(1).repeat(1, M, 1)  # (N, M, M)
+
+        # Compute positive rankings
+        pos_sim_sg = sim_sg * pos_mask_repeat
+        sim_pos_rk = torch.sum(pos_sim_sg, dim=-1) + 1  # (N, M)
+
+        # sum the values of the Smooth-AP for all instances in the mini-batch
+        ap = 0
+        group = N // self.num_id
+        for ind in range(self.num_id):
+            pos_divide = torch.sum(
+                sim_pos_rk[(ind * group):((ind + 1) * group), (ind * group):((ind + 1) * group)] / (sim_all_rk[(ind * group):((ind + 1) * group), (ind * group):((ind + 1) * group)]))
+            ap += pos_divide / torch.sum(pos_mask[ind*group]) / N
+        return 1 - ap
+
+
+class SmoothAP_old(torch.nn.Module):
+    """PyTorch implementation of the Smooth-AP loss.
+    Takes as input the mini-batch of CNN-produced feature embeddings and returns
+    the value of the Smooth-AP loss. The mini-batch must be formed of a defined number of classes. Each class must
+    have the same number of instances represented in the mini-batch and must be ordered sequentially by class.
+    e.g. the labels for a mini-batch with batch size 9, and 3 represented classes (A,B,C) must look like:
+        labels = ( A, A, A, B, B, B, C, C, C)
+    (the order of the classes however does not matter)
+    For each instance in the mini-batch, the loss computes the Smooth-AP when it is used as the query and the rest of the
+    mini-batch is used as the retrieval set. The positive set is formed of the other instances in the batch from the
+    same class. The loss returns the average Smooth-AP across all instances in the mini-batch.
+    Args:
+        anneal : float
+            the temperature of the sigmoid that is used to smooth the ranking function. A low value of the temperature
+            results in a steep sigmoid, that tightly approximates the heaviside step function in the ranking function.
+        batch_size : int
+            the batch size being used during training.
+        num_id : int
+            the number of different classes that are represented in the batch.
+        feat_dims : int
+            the dimension of the input feature embeddings
+    Shape:
+        - Input (preds): (batch_size, feat_dims) (must be a cuda torch float tensor)
+        - Output: scalar
+    Examples::
+        >>> loss = SmoothAP_old(0.01, 60, 6, 256)
+        >>> input = torch.randn(60, 256, requires_grad=True).cuda()
+        >>> output = loss(input)
+        >>> output.backward()
+    """
+
+    def __init__(self, anneal, batch_size, num_id, feat_dims):
+        """
+        Parameters
+        ----------
+        anneal : float
+            the temperature of the sigmoid that is used to smooth the ranking function
+        batch_size : int
+            the batch size being used
+        num_id : int
+            the number of different classes that are represented in the batch
+        feat_dims : int
+            the dimension of the input feature embeddings
+        """
+        super().__init__()
+
+        assert batch_size % num_id == 0
+
+        self.anneal = anneal
+        self.batch_size = batch_size
+        self.num_id = num_id
+        self.feat_dims = feat_dims
+
+    def forward(self, preds):
+        """Forward pass for all input predictions: preds - (batch_size x feat_dims) """
+
+        preds = F.normalize(preds, dim=1)
+        # ------ differentiable ranking of all retrieval set ------
+        # compute the mask which ignores the relevance score of the query to itself
+        mask = 1.0 - torch.eye(self.batch_size)
+        mask = mask.unsqueeze(dim=0).repeat(self.batch_size, 1, 1)
+        # compute the relevance scores via cosine similarity of the CNN-produced embedding vectors
+        sim_all = torch.mm(preds, preds.t())
+        sim_all_repeat = sim_all.unsqueeze(dim=1).repeat(1, self.batch_size, 1)
+        # compute the difference matrix
+        sim_diff = sim_all_repeat - sim_all_repeat.permute(0, 2, 1)
+        # pass through the sigmoid
+        sim_sg = sigmoid(sim_diff, temp=self.anneal) * mask
+        # compute the rankings
+        sim_all_rk = torch.sum(sim_sg, dim=-1) + 1
+
+        # ------ differentiable ranking of only positive set in retrieval set ------
+        # compute the mask which only gives non-zero weights to the positive set
+        xs = preds.view(self.num_id, int(self.batch_size / self.num_id), self.feat_dims)
+        pos_mask = 1.0 - torch.eye(int(self.batch_size / self.num_id))
+        pos_mask = pos_mask.unsqueeze(dim=0).unsqueeze(dim=0).repeat(self.num_id, int(self.batch_size / self.num_id), 1, 1)
+        # compute the relevance scores
+        sim_pos = torch.bmm(xs, xs.permute(0, 2, 1))
+        sim_pos_repeat = sim_pos.unsqueeze(dim=2).repeat(1, 1, int(self.batch_size / self.num_id), 1)
+        # compute the difference matrix
+        sim_pos_diff = sim_pos_repeat - sim_pos_repeat.permute(0, 1, 3, 2)
+        # pass through the sigmoid
+        sim_pos_sg = sigmoid(sim_pos_diff, temp=self.anneal) * pos_mask
+        # compute the rankings of the positive set
+        sim_pos_rk = torch.sum(sim_pos_sg, dim=-1) + 1
+
+        # sum the values of the Smooth-AP for all instances in the mini-batch
+        ap = torch.zeros(1)
+        group = int(self.batch_size / self.num_id)
+        for ind in range(self.num_id):
+            pos_divide = torch.sum(sim_pos_rk[ind] / (sim_all_rk[(ind * group):((ind + 1) * group), (ind * group):((ind + 1) * group)]))
+            ap = ap + ((pos_divide / group) / self.batch_size)
+
+        return 1-ap
+
+if __name__ == '__main__':
+    from types import SimpleNamespace
+
+    # SmoothAP now reads its hyper-parameters from cfg; this dummy cfg fakes
+    # the two fields it needs (60 images per batch, 10 instances per id).
+    cfg = SimpleNamespace(SOLVER=SimpleNamespace(IMS_PER_BATCH=60),
+                          DATALOADER=SimpleNamespace(NUM_INSTANCE=10))
+    loss1 = SmoothAP(cfg)
+    loss2 = SmoothAP_old(0.01, 60, 6, 256)
+
+    inputs = torch.randn(60, 256, requires_grad=True)
+    targets = []
+    for i in range(6):
+        targets.extend([i]*10)
+    targets = torch.LongTensor(targets)
+
+    output1 = loss1(inputs, targets)
+    output2 = loss2(inputs)
+
+    print(torch.sum(output1 - output2))
diff --git a/fastreid/modeling/losses/metric_loss.py b/fastreid/modeling/losses/triplet_loss.py
similarity index 63%
rename from fastreid/modeling/losses/metric_loss.py
rename to fastreid/modeling/losses/triplet_loss.py
index cca9edb..5e1aa74 100644
--- a/fastreid/modeling/losses/metric_loss.py
+++ b/fastreid/modeling/losses/triplet_loss.py
@@ -8,51 +8,7 @@ import torch
 import torch.nn.functional as F
 
 from fastreid.utils import comm
-
-
-# utils
-@torch.no_grad()
-def concat_all_gather(tensor):
-    """
-    Performs all_gather operation on the provided tensors.
-    *** Warning ***: torch.distributed.all_gather has no gradient.
-    """
-    tensors_gather = [torch.ones_like(tensor)
-                      for _ in range(torch.distributed.get_world_size())]
-    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
-
-    output = torch.cat(tensors_gather, dim=0)
-    return output
-
-
-def normalize(x, axis=-1):
-    """Normalizing to unit length along the specified dimension.
-    Args:
-      x: pytorch Variable
-    Returns:
-      x: pytorch Variable, same shape as input
-    """
-    x = 1. * x / (torch.norm(x, 2, axis, keepdim=True).expand_as(x) + 1e-12)
-    return x
-
-
-def euclidean_dist(x, y):
-    m, n = x.size(0), y.size(0)
-    xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n)
-    yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t()
-    dist = xx + yy
-    dist.addmm_(1, -2, x, y.t())
-    dist = dist.clamp(min=1e-12).sqrt()  # for numerical stability
-    return dist
-
-
-def cosine_dist(x, y):
-    bs1, bs2 = x.size(0), y.size(0)
-    frac_up = torch.matmul(x, y.transpose(0, 1))
-    frac_down = (torch.sqrt(torch.sum(torch.pow(x, 2), 1))).view(bs1, 1).repeat(1, bs2) * \
-                (torch.sqrt(torch.sum(torch.pow(y, 2), 1))).view(1, bs2).repeat(bs1, 1)
-    cosine = frac_up / frac_down
-    return 1 - cosine
+from .utils import concat_all_gather, euclidean_dist, normalize
 
 
 def softmax_weights(dist, mask):
@@ -174,42 +130,3 @@ class TripletLoss(object):
             if loss == float('Inf'): loss = F.margin_ranking_loss(dist_an, dist_ap, y, margin=0.3)
 
         return loss * self._scale
-
-
-class CircleLoss(object):
-    def __init__(self, cfg):
-        self._scale = cfg.MODEL.LOSSES.CIRCLE.SCALE
-
-        self.m = cfg.MODEL.LOSSES.CIRCLE.MARGIN
-        self.s = cfg.MODEL.LOSSES.CIRCLE.ALPHA
-
-    def __call__(self, embedding, targets):
-        embedding = F.normalize(embedding, dim=1)
-
-        if comm.get_world_size() > 1:
-            all_embedding = concat_all_gather(embedding)
-            all_targets = concat_all_gather(targets)
-        else:
-            all_embedding = embedding
-            all_targets = targets
-
-        dist_mat = torch.matmul(embedding, all_embedding.t())
-
-        N, M = dist_mat.size()
-        is_pos = targets.view(N, 1).expand(N, M).eq(all_targets.view(M, 1).expand(M, N).t())
-        is_neg = targets.view(N, 1).expand(N, M).ne(all_targets.view(M, 1).expand(M, N).t())
-
-        s_p = dist_mat[is_pos].contiguous().view(N, -1)
-        s_n = dist_mat[is_neg].contiguous().view(N, -1)
-
-        alpha_p = F.relu(-s_p.detach() + 1 + self.m)
-        alpha_n = F.relu(s_n.detach() + self.m)
-        delta_p = 1 - self.m
-        delta_n = self.m
-
-        logit_p = - self.s * alpha_p * (s_p - delta_p)
-        logit_n = self.s * alpha_n * (s_n - delta_n)
-
-        loss = F.softplus(torch.logsumexp(logit_p, dim=1) + torch.logsumexp(logit_n, dim=1)).mean()
-
-        return loss * self._scale
diff --git a/fastreid/modeling/losses/utils.py b/fastreid/modeling/losses/utils.py
new file mode 100644
index 0000000..a0ff648
--- /dev/null
+++ b/fastreid/modeling/losses/utils.py
@@ -0,0 +1,51 @@
+# encoding: utf-8
+"""
+@author:  xingyu liao
+@contact: sherlockliao01@gmail.com
+"""
+
+import torch
+
+
+@torch.no_grad()
+def concat_all_gather(tensor):
+    """
+    Performs all_gather operation on the provided tensors.
+    *** Warning ***: torch.distributed.all_gather has no gradient.
+    """
+    tensors_gather = [torch.ones_like(tensor)
+                      for _ in range(torch.distributed.get_world_size())]
+    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
+
+    output = torch.cat(tensors_gather, dim=0)
+    return output
+
+
+def normalize(x, axis=-1):
+    """Normalizing to unit length along the specified dimension.
+    Args:
+      x: pytorch Variable
+    Returns:
+      x: pytorch Variable, same shape as input
+    """
+    x = 1. * x / (torch.norm(x, 2, axis, keepdim=True).expand_as(x) + 1e-12)
+    return x
+
+
+def euclidean_dist(x, y):
+    m, n = x.size(0), y.size(0)
+    xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n)
+    yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t()
+    dist = xx + yy
+    dist.addmm_(1, -2, x, y.t())
+    dist = dist.clamp(min=1e-12).sqrt()  # for numerical stability
+    return dist
+
+
+def cosine_dist(x, y):
+    bs1, bs2 = x.size(0), y.size(0)
+    frac_up = torch.matmul(x, y.transpose(0, 1))
+    frac_down = (torch.sqrt(torch.sum(torch.pow(x, 2), 1))).view(bs1, 1).repeat(1, bs2) * \
+                (torch.sqrt(torch.sum(torch.pow(y, 2), 1))).view(1, bs2).repeat(bs1, 1)
+    cosine = frac_up / frac_down
+    return 1 - cosine
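+
+
+if __name__ == '__main__':
+    # Illustrative sanity check (not used by the library): euclidean_dist
+    # should agree with torch.cdist up to numerical precision.
+    a, b = torch.randn(4, 8), torch.randn(5, 8)
+    print(torch.allclose(euclidean_dist(a, b), torch.cdist(a, b), atol=1e-5))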
diff --git a/fastreid/modeling/meta_arch/baseline.py b/fastreid/modeling/meta_arch/baseline.py
index aa59460..543cb0c 100644
--- a/fastreid/modeling/meta_arch/baseline.py
+++ b/fastreid/modeling/meta_arch/baseline.py
@@ -28,7 +28,8 @@ class Baseline(nn.Module):
 
         # head
         pool_type = cfg.MODEL.HEADS.POOL_LAYER
-        if pool_type == 'avgpool':      pool_layer = FastGlobalAvgPool2d()
+        if pool_type == 'fastavgpool':  pool_layer = FastGlobalAvgPool2d()
+        elif pool_type == 'avgpool':    pool_layer = nn.AdaptiveAvgPool2d(1)
         elif pool_type == 'maxpool':    pool_layer = nn.AdaptiveMaxPool2d(1)
         elif pool_type == 'gempool':    pool_layer = GeneralizedMeanPoolingP()
         elif pool_type == "avgmaxpool": pool_layer = AdaptiveAvgMaxPool2d()
@@ -66,8 +67,10 @@ class Baseline(nn.Module):
         """
         Normalize and batch the input images.
         """
-        images = batched_inputs["images"].to(self.device)
-        # images = batched_inputs
+        if isinstance(batched_inputs, dict):
+            images = batched_inputs["images"].to(self.device)
+        elif isinstance(batched_inputs, torch.Tensor):
+            images = batched_inputs.to(self.device)
+        else:
+            raise TypeError(f"batched_inputs must be dict or torch.Tensor, got {type(batched_inputs)}")
         images.sub_(self.pixel_mean).div_(self.pixel_std)
         return images
 
@@ -89,4 +92,7 @@ class Baseline(nn.Module):
         if "TripletLoss" in loss_names:
             loss_dict['loss_triplet'] = TripletLoss(self._cfg)(pred_features, gt_labels)
 
+        if "CircleLoss" in loss_names:
+            loss_dict['loss_circle'] = CircleLoss(self._cfg)(pred_features, gt_labels)
+
         return loss_dict
diff --git a/fastreid/solver/optim/adam.py b/fastreid/solver/optim/adam.py
index fd051f1..a9b83e3 100644
--- a/fastreid/solver/optim/adam.py
+++ b/fastreid/solver/optim/adam.py
@@ -1,7 +1,7 @@
 # encoding: utf-8
 """
 @author:  xingyu liao
-@contact: liaoxingyu5@jd.com
+@contact: sherlockliao01@gmail.com
 """
 
 import torch
diff --git a/fastreid/solver/optim/sgd.py b/fastreid/solver/optim/sgd.py
index fbff412..9e31bc0 100644
--- a/fastreid/solver/optim/sgd.py
+++ b/fastreid/solver/optim/sgd.py
@@ -1,7 +1,7 @@
 # encoding: utf-8
 """
 @author:  xingyu liao
-@contact: liaoxingyu5@jd.com
+@contact: sherlockliao01@gmail.com
 """
 
 
diff --git a/fastreid/solver/optim/swa.py b/fastreid/solver/optim/swa.py
index 889239a..1d45e02 100644
--- a/fastreid/solver/optim/swa.py
+++ b/fastreid/solver/optim/swa.py
@@ -1,7 +1,7 @@
 # encoding: utf-8
 """
 @author:  xingyu liao
-@contact: liaoxingyu5@jd.com
+@contact: sherlockliao01@gmail.com
 """
 # based on:
 # https://github.com/pytorch/contrib/blob/master/torchcontrib/optim/swa.py
diff --git a/fastreid/utils/collect_env.py b/fastreid/utils/collect_env.py
index 6a7ec32..5affc33 100644
--- a/fastreid/utils/collect_env.py
+++ b/fastreid/utils/collect_env.py
@@ -1,7 +1,7 @@
 # encoding: utf-8
 """
 @author:  xingyu liao
-@contact: liaoxingyu5@jd.com
+@contact: sherlockliao01@gmail.com
 """
 
 # based on
diff --git a/fastreid/utils/weight_init.py b/fastreid/utils/weight_init.py
index 989c1df..0bd22e3 100644
--- a/fastreid/utils/weight_init.py
+++ b/fastreid/utils/weight_init.py
@@ -1,7 +1,7 @@
 # encoding: utf-8
 """
 @author:  xingyu liao
-@contact: liaoxingyu5@jd.com
+@contact: sherlockliao01@gmail.com
 """
 
 import math
@@ -35,5 +35,3 @@ def weights_init_classifier(m):
         nn.init.normal_(m.weight, std=0.001)
         if m.bias is not None:
             nn.init.constant_(m.bias, 0.0)
-    elif classname.find("Arcface") != -1 or classname.find("Circle") != -1:
-        nn.init.kaiming_uniform_(m.weight, a=math.sqrt(5))
diff --git a/projects/PartialReID/configs/partial_market.yml b/projects/PartialReID/configs/partial_market.yml
index 5f16d6f..76b9b0e 100644
--- a/projects/PartialReID/configs/partial_market.yml
+++ b/projects/PartialReID/configs/partial_market.yml
@@ -3,11 +3,10 @@ MODEL:
 
   BACKBONE:
     NAME: "build_resnet_backbone"
-    DEPTH: 50
+    DEPTH: "50x"
     NORM: "BN"
     LAST_STRIDE: 1
     WITH_IBN: True
-    PRETRAIN_PATH: "/export/home/lxy/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar"
 
   HEADS:
     NAME: "DSRHead"
diff --git a/tools/deploy/README.md b/tools/deploy/README.md
index 96167f5..2a340b0 100644
--- a/tools/deploy/README.md
+++ b/tools/deploy/README.md
@@ -1,38 +1,29 @@
-# Deployment
+# Model Deployment
 
 This directory contains:
 
-1. A script that converts a fastreid model to Caffe format.
+1. The scripts that convert a fastreid model to Caffe/ONNX/TRT format.
 
-2. An exmpale that loads a R50 baseline model in Caffe and run inference.
+2. The examples that load an R50 baseline model in Caffe/ONNX/TRT and run inference.
 
 ## Tutorial
 
+### Caffe Convert
+
+<details>
+<summary>step-by-step pipeline for caffe conversion</summary>
+
 This is a tiny example of converting the fastreid baseline in `meta_arch` to a Caffe model; if you want to convert a more complex architecture, you need to customize more things.
 
-1. Change `preprocess_image` in `fastreid/modeling/meta_arch/baseline.py` as below
-
-    ```python
-    def preprocess_image(self, batched_inputs):
-        """
-        Normalize and batch the input images.
-        """
-        # images = [x["images"] for x in batched_inputs]
-        # images = batched_inputs["images"]
-        images = batched_inputs
-        images.sub_(self.pixel_mean).div_(self.pixel_std)
-        return images
-    ```
-
-2. Run `caffe_export.py` to get the converted Caffe model,
+1. Run `caffe_export.py` to get the converted Caffe model,
 
     ```bash
-    python caffe_export.py --config-file "/export/home/lxy/fast-reid/logs/market1501/bagtricks_R50/config.yaml" --name "baseline_R50" --output "logs/caffe_model" --opts MODEL.WEIGHTS "/export/home/lxy/fast-reid/logs/market1501/bagtricks_R50/model_final.pth"
+    python caffe_export.py --config-file root-path/market1501/bagtricks_R50/config.yml --name "baseline_R50" --output outputs/caffe_model --opts MODEL.WEIGHTS root-path/logs/market1501/bagtricks_R50/model_final.pth
     ```
 
-    then you can check the Caffe model and prototxt in `logs/caffe_model`.
+    then you can check the Caffe model and prototxt in `outputs/caffe_model`.
 
-3. Change `prototxt` following next three steps:
+2. Change `prototxt` following next three steps:
 
    1) Edit `max_pooling` in `baseline_R50.prototxt` like this
 
@@ -67,7 +58,7 @@ This is a tiny example for converting fastreid-baseline in `meta_arch` to Caffe
         }
         ```
 
-    3) Change the last layer `top` name to `output`
+   3) Change the last layer `top` name to `output`
 
         ```prototxt
         layer {
@@ -81,22 +72,89 @@ This is a tiny example for converting fastreid-baseline in `meta_arch` to Caffe
         }
         ```
 
-4. (optional) You can open [Netscope](https://ethereon.github.io/netscope/quickstart.html), then enter you network `prototxt` to visualize the network.
+3. (optional) You can open [Netscope](https://ethereon.github.io/netscope/quickstart.html), then enter your network `prototxt` to visualize the network.
 
-5. Run `caffe_inference.py` to save Caffe model features with input images
+4. Run `caffe_inference.py` to save Caffe model features with input images
 
    ```bash
-    python caffe_inference.py --model-def "logs/caffe_model/baseline_R50.prototxt" \
-    --model-weights "logs/caffe_model/baseline_R50.caffemodel" \
-    --input \
-    '/export/home/DATA/Market-1501-v15.09.15/bounding_box_test/1182_c5s3_015240_04.jpg' \
-    '/export/home/DATA/Market-1501-v15.09.15/bounding_box_test/1182_c6s3_038217_01.jpg' \
-    '/export/home/DATA/Market-1501-v15.09.15/bounding_box_test/1183_c5s3_006943_05.jpg' \
-    --output "caffe_R34_output"
+    python caffe_inference.py --model-def outputs/caffe_model/baseline_R50.prototxt \
+    --model-weights outputs/caffe_model/baseline_R50.caffemodel \
+    --input test_data/*.jpg --output caffe_output
    ```
 
-6. Run `demo/demo.py` to get fastreid model features with the same input images, then compute the cosine similarity of difference model features to verify if you convert Caffe model successfully.
+5. Run `demo/demo.py` to get fastreid model features with the same input images, then verify that Caffe and PyTorch are computing the same value for the network.
+
+    ```python
+    np.testing.assert_allclose(torch_out, caffe_out, rtol=1e-3, atol=1e-6)
+    ```
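+
+    If you prefer the cosine-similarity check, a minimal sketch looks like this (the `.npy` paths are hypothetical; point them at wherever the two scripts actually saved the features):
+
+    ```python
+    import numpy as np
+
+    torch_feat = np.load("demo_output/feat.npy")   # hypothetical path
+    caffe_feat = np.load("caffe_output/feat.npy")  # hypothetical path
+
+    cos = np.dot(torch_feat, caffe_feat) / (np.linalg.norm(torch_feat) * np.linalg.norm(caffe_feat))
+    print(f"cosine similarity: {cos:.6f}")  # ~1.0 for a faithful conversion
+    ```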
+
+</details>
+
+### ONNX Convert
+
+<details>
+<summary>step-by-step pipeline for onnx conversion</summary>
+
+This is a tiny example for converting fastreid-baseline in `meta_arch` to an ONNX model. ONNX supports most PyTorch operators; if an operator is not supported by ONNX, you need to customize it yourself.
+
+1. Run `onnx_export.py` to get the converted ONNX model,
+
+    ```bash
+    python onnx_export.py --config-file root-path/bagtricks_R50/config.yml --name "baseline_R50" --output outputs/onnx_model --opts MODEL.WEIGHTS root-path/logs/market1501/bagtricks_R50/model_final.pth
+    ```
+
+    then you can check the ONNX model in `outputs/onnx_model`.
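+
+    As an optional sanity check (this assumes the `onnx` Python package is installed), you can verify the exported graph before running inference:
+
+    ```python
+    import onnx
+
+    # Path follows the --name used above; adjust it if you exported differently.
+    model = onnx.load("outputs/onnx_model/baseline_R50.onnx")
+    onnx.checker.check_model(model)  # raises if the graph is malformed
+    ```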
+
+2. (optional) You can use [Netron](https://github.com/lutzroeder/netron) to visualize the network.
+
+3. Run `onnx_inference.py` to save ONNX model features with input images
+
+   ```bash
+    python onnx_inference.py --model-path outputs/onnx_model/baseline_R50.onnx \
+    --input test_data/*.jpg --output onnx_output
+   ```
+
+4. Run `demo/demo.py` to get fastreid model features with the same input images, then verify that ONNX Runtime and PyTorch are computing the same value for the network.
+
+    ```python
+    np.testing.assert_allclose(torch_out, ort_out, rtol=1e-3, atol=1e-6)
+    ```
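+
+    A minimal sketch of how `ort_out` could be produced with ONNX Runtime (this assumes `onnxruntime` is installed and `img` is the same preprocessed float32 `1x3xHxW` array fed to PyTorch):
+
+    ```python
+    import onnxruntime as ort
+
+    sess = ort.InferenceSession("outputs/onnx_model/baseline_R50.onnx")
+    input_name = sess.get_inputs()[0].name
+    ort_out = sess.run(None, {input_name: img})[0]
+    ```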
+
+</details>
+
+### TensorRT Convert
+
+<details>
+<summary>step-by-step pipeline for trt conversion</summary>
+
+This is a tiny example for converting the fastreid baseline in `meta_arch` to a TensorRT model. We use [tiny-tensorrt](https://github.com/zerollzeng/tiny-tensorrt), a simple and easy-to-use NVIDIA TensorRT wrapper, to convert the model to TensorRT.
+
+First, convert the PyTorch model to ONNX format following [ONNX Convert](https://github.com/JDAI-CV/fast-reid#fastreid), and remember your `output` name (the sketch below shows how to look it up). Then you can convert the ONNX model to TensorRT following the instructions below.
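+
+If you forget the `output` name, here is a minimal sketch for looking it up in the exported ONNX model (the path assumes the `--name` used in the ONNX section):
+
+```python
+import onnx
+
+# the graph outputs are the tensor names the TRT engine will expose
+onnx_model = onnx.load("outputs/onnx_model/baseline_R50.onnx")
+print([o.name for o in onnx_model.graph.output])
+```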
+
+1. Run command line below to get the converted TRT model from ONNX model,
+
+    ```bash
+    python trt_export.py --name "baseline_R50" --output outputs/trt_model \
+    --onnx-model outputs/onnx_model/baseline_R50.onnx --height 256 --width 128
+    ```
+
+    then you can check the TRT model in `outputs/trt_model`.
+
+2. Run `trt_inference.py` to compute and save TRT model features for the input images
+
+   ```bash
+    python trt_inference.py --model-path outputs/trt_model/baseline_R50.engine \
+    --input test_data/*.jpg --output trt_output --output-name trt_model_outputname
+   ```
+
+3. Run `demo/demo.py` to extract fastreid model features for the same input images, then verify that TensorRT and PyTorch are computing the same value for the network.
+
+    ```python
+    np.testing.assert_allclose(torch_out, trt_out, rtol=1e-3, atol=1e-6)
+    ```
+
+</details>
 
 ## Acknowledgements
 
-Thank to [CPFLAME](https://github.com/CPFLAME), [gcong18](https://github.com/gcong18), [YuxiangJohn](https://github.com/YuxiangJohn) and [wiggin66](https://github.com/wiggin66) at JDAI Model Acceleration Group for help in PyTorch to Caffe model converting.
+Thanks to [CPFLAME](https://github.com/CPFLAME), [gcong18](https://github.com/gcong18), [YuxiangJohn](https://github.com/YuxiangJohn) and [wiggin66](https://github.com/wiggin66) at JDAI Model Acceleration Group for help with PyTorch model conversion.
diff --git a/tools/deploy/caffe_export.py b/tools/deploy/caffe_export.py
index ec54af5..db89612 100644
--- a/tools/deploy/caffe_export.py
+++ b/tools/deploy/caffe_export.py
@@ -1,7 +1,7 @@
 # encoding: utf-8
 """
 @author:  xingyu liao
-@contact: liaoxingyu5@jd.com
+@contact: sherlockliao01@gmail.com
 """
 
 import argparse
@@ -15,6 +15,9 @@ from fastreid.config import get_cfg
 from fastreid.modeling.meta_arch import build_model
 from fastreid.utils.file_io import PathManager
 from fastreid.utils.checkpoint import Checkpointer
+from fastreid.utils.logger import setup_logger
+
+logger = setup_logger(name='caffe_export')
 
 
 def setup_cfg(args):
@@ -64,10 +67,12 @@ if __name__ == '__main__':
     model = build_model(cfg)
     Checkpointer(model).load(cfg.MODEL.WEIGHTS)
     model.eval()
-    print(model)
+    logger.info(model)
 
-    inputs = torch.randn(1, 3, cfg.INPUT.SIZE_TEST[0], cfg.INPUT.SIZE_TEST[1]).cuda()
+    inputs = torch.randn(1, 3, cfg.INPUT.SIZE_TEST[0], cfg.INPUT.SIZE_TEST[1])
     PathManager.mkdirs(args.output)
     pytorch_to_caffe.trans_net(model, inputs, args.name)
     pytorch_to_caffe.save_prototxt(f"{args.output}/{args.name}.prototxt")
     pytorch_to_caffe.save_caffemodel(f"{args.output}/{args.name}.caffemodel")
+
+    logger.info(f"Export caffe model in {args.output} sucessfully!")
diff --git a/tools/deploy/caffe_inference.py b/tools/deploy/caffe_inference.py
index 6b92f82..2956816 100644
--- a/tools/deploy/caffe_inference.py
+++ b/tools/deploy/caffe_inference.py
@@ -1,7 +1,7 @@
 # encoding: utf-8
 """
 @author:  xingyu liao
-@contact: liaoxingyu5@jd.com
+@contact: sherlockliao01@gmail.com
 """
 
 import caffe
@@ -43,7 +43,7 @@ def get_parser():
     parser.add_argument(
         "--height",
         type=int,
-        default=384,
+        default=256,
         help="height of image"
     )
     parser.add_argument(
diff --git a/tools/deploy/export2tf.py b/tools/deploy/export2tf.py
deleted file mode 100644
index 593c7cb..0000000
--- a/tools/deploy/export2tf.py
+++ /dev/null
@@ -1,48 +0,0 @@
-# encoding: utf-8
-"""
-@author:  sherlock
-@contact: sherlockliao01@gmail.com
-"""
-
-import sys
-
-import torch
-sys.path.append('../..')
-from fastreid.config import get_cfg
-from fastreid.engine import default_argument_parser, default_setup
-from fastreid.modeling.meta_arch import build_model
-from fastreid.export.tensorflow_export import export_tf_reid_model
-from fastreid.export.tf_modeling import TfMetaArch
-
-
-def setup(args):
-    """
-    Create configs and perform basic setups.
-    """
-    cfg = get_cfg()
-    # cfg.merge_from_file(args.config_file)
-    cfg.merge_from_list(args.opts)
-    cfg.freeze()
-    default_setup(cfg, args)
-    return cfg
-
-
-if __name__ == "__main__":
-    args = default_argument_parser().parse_args()
-    print("Command Line Args:", args)
-    cfg = setup(args)
-    cfg.defrost()
-    cfg.MODEL.BACKBONE.NAME = "build_resnet_backbone"
-    cfg.MODEL.BACKBONE.DEPTH = 50
-    cfg.MODEL.BACKBONE.LAST_STRIDE = 1
-    # If use IBN block in backbone
-    cfg.MODEL.BACKBONE.WITH_IBN = False
-    cfg.MODEL.BACKBONE.PRETRAIN = False
-
-    from torchvision.models import resnet50
-    # model = TfMetaArch(cfg)
-    model = resnet50(pretrained=False)
-    # model.load_params_wo_fc(torch.load('logs/bjstation/res50_baseline_v0.4/ckpts/model_epoch80.pth'))
-    model.eval()
-    dummy_inputs = torch.randn(1, 3, 256, 128)
-    export_tf_reid_model(model, dummy_inputs, 'reid_tf.pb')
diff --git a/tools/deploy/onnx_export.py b/tools/deploy/onnx_export.py
new file mode 100644
index 0000000..449db3a
--- /dev/null
+++ b/tools/deploy/onnx_export.py
@@ -0,0 +1,146 @@
+# encoding: utf-8
+"""
+@author:  xingyu liao
+@contact: sherlockliao01@gmail.com
+"""
+
+import argparse
+import io
+import sys
+
+import onnx
+import torch
+from onnxsim import simplify
+from torch.onnx import OperatorExportTypes
+
+sys.path.append('../../')
+
+from fastreid.config import get_cfg
+from fastreid.modeling.meta_arch import build_model
+from fastreid.utils.file_io import PathManager
+from fastreid.utils.checkpoint import Checkpointer
+from fastreid.utils.logger import setup_logger
+
+logger = setup_logger(name='onnx_export')
+
+
+def setup_cfg(args):
+    cfg = get_cfg()
+    cfg.merge_from_file(args.config_file)
+    cfg.merge_from_list(args.opts)
+    cfg.freeze()
+    return cfg
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(description="Convert Pytorch to ONNX model")
+
+    parser.add_argument(
+        "--config-file",
+        metavar="FILE",
+        help="path to config file",
+    )
+    parser.add_argument(
+        "--name",
+        default="baseline",
+        help="name for converted model"
+    )
+    parser.add_argument(
+        "--output",
+        default='onnx_model',
+        help='path to save converted onnx model'
+    )
+    parser.add_argument(
+        "--opts",
+        help="Modify config options using the command-line 'KEY VALUE' pairs",
+        default=[],
+        nargs=argparse.REMAINDER,
+    )
+    return parser
+
+
+def remove_initializer_from_input(model):
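+    # ONNX IR version 4 stopped requiring initializers to be listed as graph
+    # inputs; leaving them listed marks those weights as user-overridable,
+    # which prevents onnxruntime from constant-folding them.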
+    if model.ir_version < 4:
+        print(
+            'Model with ir_version below 4 requires initializers to be included in graph input'
+        )
+        return
+
+    inputs = model.graph.input
+    name_to_input = {}
+    for input in inputs:
+        name_to_input[input.name] = input
+
+    for initializer in model.graph.initializer:
+        if initializer.name in name_to_input:
+            inputs.remove(name_to_input[initializer.name])
+
+    return model
+
+
+def export_onnx_model(model, inputs):
+    """
+    Trace and export a model to onnx format.
+    Args:
+        model (nn.Module):
+        inputs (torch.Tensor): the model will be called by `model(*inputs)`
+    Returns:
+        an onnx model
+    """
+    assert isinstance(model, torch.nn.Module)
+
+    # make sure all modules are in eval mode, onnx may change the training state
+    # of the module if the states are not consistent
+    def _check_eval(module):
+        assert not module.training
+
+    model.apply(_check_eval)
+
+    # Export the model to ONNX
+    with torch.no_grad():
+        with io.BytesIO() as f:
+            torch.onnx.export(
+                model,
+                inputs,
+                f,
+                operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK,
+                # verbose=True,  # NOTE: uncomment this for debugging
+                # export_params=True,
+            )
+            onnx_model = onnx.load_from_string(f.getvalue())
+
+    # Apply ONNX's Optimization
+    all_passes = onnx.optimizer.get_available_passes()
+    passes = ["extract_constant_to_initializer", "eliminate_unused_initializer", "fuse_bn_into_conv"]
+    assert all(p in all_passes for p in passes)
+    onnx_model = onnx.optimizer.optimize(onnx_model, passes)
+    return onnx_model
+
+
+if __name__ == '__main__':
+    args = get_parser().parse_args()
+    cfg = setup_cfg(args)
+
+    cfg.defrost()
+    cfg.MODEL.BACKBONE.PRETRAIN = False
+    if cfg.MODEL.HEADS.POOL_LAYER == 'fastavgpool':
+        cfg.MODEL.HEADS.POOL_LAYER = 'avgpool'
+    model = build_model(cfg)
+    Checkpointer(model).load(cfg.MODEL.WEIGHTS)
+    model.eval()
+    logger.info(model)
+
+    inputs = torch.randn(1, 3, cfg.INPUT.SIZE_TEST[0], cfg.INPUT.SIZE_TEST[1])
+    onnx_model = export_onnx_model(model, inputs)
+
+    model_simp, check = simplify(onnx_model)
+
+    model_simp = remove_initializer_from_input(model_simp)
+
+    assert check, "Simplified ONNX model could not be validated"
+
+    PathManager.mkdirs(args.output)
+
+    onnx.save_model(model_simp, f"{args.output}/{args.name}.onnx")
+
+    logger.info(f"Export onnx model in {args.output} successfully!")
diff --git a/tools/deploy/onnx_inference.py b/tools/deploy/onnx_inference.py
new file mode 100644
index 0000000..2e29f62
--- /dev/null
+++ b/tools/deploy/onnx_inference.py
@@ -0,0 +1,85 @@
+# encoding: utf-8
+"""
+@author:  xingyu liao
+@contact: sherlockliao01@gmail.com
+"""
+
+import argparse
+import glob
+import os
+
+import cv2
+import numpy as np
+import onnxruntime
+import tqdm
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(description="onnx model inference")
+
+    parser.add_argument(
+        "--model-path",
+        default="onnx_model/baseline.onnx",
+        help="onnx model path"
+    )
+    parser.add_argument(
+        "--input",
+        nargs="+",
+        help="A list of space separated input images; "
+             "or a single glob pattern such as 'directory/*.jpg'",
+    )
+    parser.add_argument(
+        "--output",
+        default='onnx_output',
+        help='path to save onnx model inference results'
+    )
+    parser.add_argument(
+        "--height",
+        type=int,
+        default=256,
+        help="height of image"
+    )
+    parser.add_argument(
+        "--width",
+        type=int,
+        default=128,
+        help="width of image"
+    )
+    return parser
+
+
+def preprocess(image_path, image_height, image_width):
+    original_image = cv2.imread(image_path)
+    # the model expects RGB inputs
+    original_image = original_image[:, :, ::-1]
+
+    # Apply pre-processing to image.
+    img = cv2.resize(original_image, (image_width, image_height), interpolation=cv2.INTER_CUBIC)
+    img = img.astype("float32").transpose(2, 0, 1)[np.newaxis]  # (1, 3, h, w)
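+    # no mean/std normalization here: fastreid applies it inside the model's
+    # forward pass, so it is already baked into the exported graph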
+    return img
+
+
+def normalize(nparray, order=2, axis=-1):
+    """Normalize a N-D numpy array along the specified axis."""
+    norm = np.linalg.norm(nparray, ord=order, axis=axis, keepdims=True)
+    return nparray / (norm + np.finfo(np.float32).eps)
+
+
+if __name__ == "__main__":
+    args = get_parser().parse_args()
+
+    ort_sess = onnxruntime.InferenceSession(args.model_path)
+
+    input_name = ort_sess.get_inputs()[0].name
+
+    os.makedirs(args.output, exist_ok=True)
+
+    if args.input:
+        if os.path.isdir(args.input[0]):
+            args.input = glob.glob(os.path.expanduser(args.input[0]))
+            assert args.input, "The input path(s) was not found"
+        for path in tqdm.tqdm(args.input):
+            image = preprocess(path, args.height, args.width)
+            feat = ort_sess.run(None, {input_name: image})[0]
+            feat = normalize(feat, axis=1)
+            np.save(os.path.join(args.output, os.path.basename(path).replace('.jpg', '.npy')), feat)
diff --git a/tools/deploy/run_inference.sh b/tools/deploy/run_inference.sh
deleted file mode 100644
index 2bc54ab..0000000
--- a/tools/deploy/run_inference.sh
+++ /dev/null
@@ -1,10 +0,0 @@
-
-python caffe_inference.py --model-def "logs/caffe_R34/baseline_R34.prototxt" \
---model-weights "logs/caffe_R34/baseline_R34.caffemodel" \
---height 256 --width 128 \
---input \
-'/export/home/DATA/Market-1501-v15.09.15/bounding_box_test/1182_c5s3_015240_04.jpg' \
-'/export/home/DATA/Market-1501-v15.09.15/bounding_box_test/1182_c6s3_038217_01.jpg' \
-'/export/home/DATA/Market-1501-v15.09.15/bounding_box_test/1183_c5s3_006943_05.jpg' \
-'/export/home/DATA/DukeMTMC-reID/bounding_box_train/0728_c4_f0161265.jpg' \
---output "caffe_R34_output"
diff --git a/tools/deploy/test_data/0022_c6s1_002976_01.jpg b/tools/deploy/test_data/0022_c6s1_002976_01.jpg
new file mode 100644
index 0000000..e15dab7
Binary files /dev/null and b/tools/deploy/test_data/0022_c6s1_002976_01.jpg differ
diff --git a/tools/deploy/test_data/0027_c2s2_091032_02.jpg b/tools/deploy/test_data/0027_c2s2_091032_02.jpg
new file mode 100644
index 0000000..8cf4959
Binary files /dev/null and b/tools/deploy/test_data/0027_c2s2_091032_02.jpg differ
diff --git a/tools/deploy/test_data/0032_c6s1_002851_01.jpg b/tools/deploy/test_data/0032_c6s1_002851_01.jpg
new file mode 100644
index 0000000..141ed15
Binary files /dev/null and b/tools/deploy/test_data/0032_c6s1_002851_01.jpg differ
diff --git a/tools/deploy/test_data/0048_c1s1_005351_01.jpg b/tools/deploy/test_data/0048_c1s1_005351_01.jpg
new file mode 100644
index 0000000..593caee
Binary files /dev/null and b/tools/deploy/test_data/0048_c1s1_005351_01.jpg differ
diff --git a/tools/deploy/test_data/0065_c6s1_009501_02.jpg b/tools/deploy/test_data/0065_c6s1_009501_02.jpg
new file mode 100644
index 0000000..e7fb874
Binary files /dev/null and b/tools/deploy/test_data/0065_c6s1_009501_02.jpg differ
diff --git a/tools/deploy/trt_export.py b/tools/deploy/trt_export.py
new file mode 100644
index 0000000..edc9f1e
--- /dev/null
+++ b/tools/deploy/trt_export.py
@@ -0,0 +1,82 @@
+# encoding: utf-8
+"""
+@author:  xingyu liao
+@contact: sherlockliao01@gmail.com
+"""
+
+import argparse
+import os
+import numpy as np
+import sys
+
+sys.path.append('../../')
+sys.path.append("/export/home/lxy/runtimelib-tensorrt-tiny/build")
+
+import pytrt
+from fastreid.utils.logger import setup_logger
+from fastreid.utils.file_io import PathManager
+
+
+logger = setup_logger(name='trt_export')
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(description="Convert ONNX to TRT model")
+
+    parser.add_argument(
+        "--name",
+        default="baseline",
+        help="name for converted model"
+    )
+    parser.add_argument(
+        "--output",
+        default='outputs/trt_model',
+        help='path to save converted trt model'
+    )
+    parser.add_argument(
+        "--onnx-model",
+        default='outputs/onnx_model/baseline.onnx',
+        help='path to onnx model'
+    )
+    parser.add_argument(
+        "--height",
+        type=int,
+        default=256,
+        help="height of image"
+    )
+    parser.add_argument(
+        "--width",
+        type=int,
+        default=128,
+        help="width of image"
+    )
+    return parser
+
+
+def export_trt_model(onnxModel, engineFile, input_numpy_array):
+    r"""
+    Export a model to trt format.
+    """
+
+    trt = pytrt.Trt()
+
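+    # arguments for tiny-tensorrt's CreateEngine (see its README for details):
+    # customOutput can mark extra tensors as engine outputs (empty keeps the
+    # ONNX graph outputs), maxBatchSize caps the batch dimension, mode selects
+    # the build precision, and calibratorData is only consulted for INT8.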
+    customOutput = []
+    maxBatchSize = 1
+    calibratorData = []
+    mode = 2
+    trt.CreateEngine(onnxModel, engineFile, customOutput, maxBatchSize, mode, calibratorData)
+    trt.DoInference(input_numpy_array)  # slightly different from c++
+    return 0
+
+
+if __name__ == '__main__':
+    args = get_parser().parse_args()
+
+    inputs = np.zeros(shape=(32, args.height, args.width, 3))
+    onnxModel = args.onnx_model
+    engineFile = os.path.join(args.output, args.name+'.engine')
+
+    PathManager.mkdirs(args.output)
+    export_trt_model(onnxModel, engineFile, inputs)
+
+    logger.info(f"Export trt model in {args.output} successfully!")
diff --git a/tools/deploy/trt_inference.py b/tools/deploy/trt_inference.py
new file mode 100644
index 0000000..0775364
--- /dev/null
+++ b/tools/deploy/trt_inference.py
@@ -0,0 +1,99 @@
+# encoding: utf-8
+"""
+@author:  xingyu liao
+@contact: sherlockliao01@gmail.com
+"""
+import argparse
+import glob
+import os
+import sys
+
+import cv2
+import numpy as np
+# import tqdm
+
+sys.path.append("/export/home/lxy/runtimelib-tensorrt-tiny/build")
+
+import pytrt
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(description="trt model inference")
+
+    parser.add_argument(
+        "--model-path",
+        default="outputs/trt_model/baseline.engine",
+        help="trt model path"
+    )
+    parser.add_argument(
+        "--input",
+        nargs="+",
+        help="A list of space separated input images; "
+             "or a single glob pattern such as 'directory/*.jpg'",
+    )
+    parser.add_argument(
+        "--output",
+        default="trt_output",
+        help="path to save trt model inference results"
+    )
+    parser.add_argument(
+        "--output-name",
+        help="tensorRT model output name"
+    )
+    parser.add_argument(
+        "--height",
+        type=int,
+        default=256,
+        help="height of image"
+    )
+    parser.add_argument(
+        "--width",
+        type=int,
+        default=128,
+        help="width of image"
+    )
+    return parser
+
+
+def preprocess(image_path, image_height, image_width):
+    original_image = cv2.imread(image_path)
+    # the model expects RGB inputs
+    original_image = original_image[:, :, ::-1]
+
+    # Apply pre-processing to image.
+    img = cv2.resize(original_image, (image_width, image_height), interpolation=cv2.INTER_CUBIC)
+    img = img.astype("float32").transpose(2, 0, 1)[np.newaxis]  # (1, 3, h, w)
+    return img
+
+
+def normalize(nparray, order=2, axis=-1):
+    """Normalize a N-D numpy array along the specified axis."""
+    norm = np.linalg.norm(nparray, ord=order, axis=axis, keepdims=True)
+    return nparray / (norm + np.finfo(np.float32).eps)
+
+
+if __name__ == "__main__":
+    args = get_parser().parse_args()
+
+    trt = pytrt.Trt()
+
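+    # with an empty onnxModel path and an existing engine file, tiny-tensorrt
+    # deserializes the prebuilt engine instead of building a new one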
+    onnxModel = ""
+    engineFile = args.model_path
+    customOutput = []
+    maxBatchSize = 1
+    calibratorData = []
+    mode = 2
+    trt.CreateEngine(onnxModel, engineFile, customOutput, maxBatchSize, mode, calibratorData)
+
+    os.makedirs(args.output, exist_ok=True)
+
+    if args.input:
+        if os.path.isdir(args.input[0]):
+            args.input = glob.glob(os.path.expanduser(args.input[0]))
+            assert args.input, "The input path(s) was not found"
+        for path in args.input:
+            input_numpy_array = preprocess(path, args.height, args.width)
+            trt.DoInference(input_numpy_array)
+            feat = trt.GetOutput(args.output_name)
+            feat = normalize(feat, axis=1)
+            np.save(os.path.join(args.output, os.path.basename(path).replace('.jpg', '.npy')), feat)