From 16655448c25dfb3e9cbf35e6a1f94d1731bbbf06 Mon Sep 17 00:00:00 2001
From: liaoxingyu <sherlockliao01@gmail.com>
Date: Wed, 29 Jul 2020 17:43:39 +0800
Subject: [PATCH] onnx/trt support

Summary: change the model pretraining mode and add ONNX/TensorRT export support
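
As a reference, a rough sketch of the ONNX export flow (illustrative only: the
dummy input size, tensor names and opset below are assumptions, not the exact
contents of tools/deploy/onnx_export.py):

    # Minimal sketch built around torch.onnx.export; config parsing and
    # checkpoint loading are handled by the deploy scripts, not shown here.
    import torch

    def export_onnx(model: torch.nn.Module, output_path: str, image_hw=(256, 128)):
        model.eval()
        dummy = torch.randn(1, 3, *image_hw)  # NCHW dummy input used for tracing
        torch.onnx.export(
            model, dummy, output_path,
            input_names=["images"], output_names=["features"],
            opset_version=10,  # opset version is an assumption
        )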
---
 README.md                                     |   1 +
 configs/Base-Strongerbaseline.yml             |   2 +-
 configs/Base-bagtricks.yml                    |   2 +-
 configs/DukeMTMC/AGW_R101-ibn.yml             |   1 -
 configs/DukeMTMC/AGW_R50-ibn.yml              |   1 -
 configs/DukeMTMC/bagtricks_R101-ibn.yml       |   1 -
 configs/DukeMTMC/bagtricks_R50-ibn.yml        |   1 -
 configs/DukeMTMC/mgn_R50-ibn.yml              |   1 -
 configs/DukeMTMC/sbs_R101-ibn.yml             |   1 -
 configs/DukeMTMC/sbs_R50-ibn.yml              |   1 -
 configs/MSMT17/AGW_R101-ibn.yml               |   1 -
 configs/MSMT17/AGW_R50-ibn.yml                |   1 -
 configs/MSMT17/bagtricks_R101-ibn.yml         |   1 -
 configs/MSMT17/bagtricks_R50-ibn.yml          |   1 -
 configs/MSMT17/mgn_R50-ibn.yml                |   1 -
 configs/MSMT17/sbs_R101-ibn.yml               |   1 -
 configs/MSMT17/sbs_R50-ibn.yml                |   1 -
 configs/Market1501/AGW_R101-ibn.yml           |   1 -
 configs/Market1501/AGW_R50-ibn.yml            |   1 -
 configs/Market1501/bagtricks_R101-ibn.yml     |   3 +-
 configs/Market1501/bagtricks_R50-ibn.yml      |   1 -
 configs/Market1501/mgn_R50-ibn.yml            |   1 -
 configs/Market1501/sbs_R101-ibn.yml           |   1 -
 configs/Market1501/sbs_R50-ibn.yml            |   1 -
 configs/VERIWild/bagtricks_R50-ibn.yml        |   2 -
 configs/VeRi/sbs_R50-ibn.yml                  |   4 -
 configs/VehicleID/bagtricks_R50-ibn.yml       |   2 -
 demo/demo.py                                  |   1 +
 demo/plot_roc_with_pickle.py                  |   2 +-
 demo/predictor.py                             |   2 +-
 demo/visualize_result.py                      |   2 +-
 fastreid/config/defaults.py                   |  12 +-
 fastreid/data/datasets/vehicleid.py           | 248 ++++++++--------
 fastreid/data/datasets/veri.py                | 134 ++++-----
 fastreid/data/datasets/veriwild.py            | 276 +++++++++---------
 fastreid/data/transforms/build.py             |   9 +-
 fastreid/data/transforms/transforms.py        | 109 +------
 fastreid/engine/launch.py                     |   2 +-
 fastreid/evaluation/query_expansion.py        |   2 +-
 fastreid/evaluation/reid_evaluation.py        |   3 +-
 fastreid/layers/__init__.py                   |   3 +-
 fastreid/layers/activation.py                 |   2 +-
 fastreid/layers/am_softmax.py                 |  43 +++
 fastreid/layers/arc_softmax.py                |   1 +
 fastreid/layers/circle_softmax.py             |   7 +-
 fastreid/layers/splat.py                      |   5 -
 fastreid/modeling/backbones/__init__.py       |   2 +-
 fastreid/modeling/backbones/osnet.py          |  96 +++---
 fastreid/modeling/backbones/regnet/regnet.py  |  76 ++++-
 fastreid/modeling/backbones/resnet.py         | 106 +++++--
 fastreid/modeling/backbones/resnext.py        | 158 +++++++---
 fastreid/modeling/heads/bnneck_head.py        |   3 +-
 fastreid/modeling/heads/linear_head.py        |   3 +-
 fastreid/modeling/heads/reduction_head.py     |   9 +-
 fastreid/modeling/losses/__init__.py          |   3 +-
 fastreid/modeling/losses/circle_loss.py       |  61 ++++
 fastreid/modeling/losses/cross_entroy_loss.py |   8 +-
 fastreid/modeling/losses/smooth_ap.py         | 241 +++++++++++++++
 .../{metric_loss.py => triplet_loss.py}       |  85 +-----
 fastreid/modeling/losses/utils.py             |  51 ++++
 fastreid/modeling/meta_arch/baseline.py       |  12 +-
 fastreid/solver/optim/adam.py                 |   2 +-
 fastreid/solver/optim/sgd.py                  |   2 +-
 fastreid/solver/optim/swa.py                  |   2 +-
 fastreid/utils/collect_env.py                 |   2 +-
 fastreid/utils/weight_init.py                 |   4 +-
 .../PartialReID/configs/partial_market.yml    |   3 +-
 tools/deploy/README.md                        | 124 +++++---
 tools/deploy/caffe_export.py                  |  11 +-
 tools/deploy/caffe_inference.py               |   4 +-
 tools/deploy/export2tf.py                     |  48 ---
 tools/deploy/onnx_export.py                   | 146 +++++++++
 tools/deploy/onnx_inference.py                |  85 ++++++
 tools/deploy/run_inference.sh                 |  10 -
 .../deploy/test_data/0022_c6s1_002976_01.jpg  | Bin 0 -> 2223 bytes
 .../deploy/test_data/0027_c2s2_091032_02.jpg  | Bin 0 -> 2137 bytes
 .../deploy/test_data/0032_c6s1_002851_01.jpg  | Bin 0 -> 2640 bytes
 .../deploy/test_data/0048_c1s1_005351_01.jpg  | Bin 0 -> 2149 bytes
 .../deploy/test_data/0065_c6s1_009501_02.jpg  | Bin 0 -> 1890 bytes
 tools/deploy/trt_export.py                    |  82 ++++++
 tools/deploy/trt_inference.py                 |  99 +++++++
 81 files changed, 1631 insertions(+), 806 deletions(-)
 create mode 100644 fastreid/layers/am_softmax.py
 create mode 100644 fastreid/modeling/losses/circle_loss.py
 create mode 100644 fastreid/modeling/losses/smooth_ap.py
 rename fastreid/modeling/losses/{metric_loss.py => triplet_loss.py} (63%)
 create mode 100644 fastreid/modeling/losses/utils.py
 delete mode 100644 tools/deploy/export2tf.py
 create mode 100644 tools/deploy/onnx_export.py
 create mode 100644 tools/deploy/onnx_inference.py
 delete mode 100644 tools/deploy/run_inference.sh
 create mode 100644 tools/deploy/test_data/0022_c6s1_002976_01.jpg
 create mode 100644 tools/deploy/test_data/0027_c2s2_091032_02.jpg
 create mode 100644 tools/deploy/test_data/0032_c6s1_002851_01.jpg
 create mode 100644 tools/deploy/test_data/0048_c1s1_005351_01.jpg
 create mode 100644 tools/deploy/test_data/0065_c6s1_009501_02.jpg
 create mode 100644 tools/deploy/trt_export.py
 create mode 100644 tools/deploy/trt_inference.py

diff --git a/README.md b/README.md
index d306afe..c312ef2 100644
--- a/README.md
+++ b/README.md
@@ -4,6 +4,7 @@ FastReID is a research platform that implements state-of-the-art re-identificati
 
 ## What's New
 
+- [Aug 2020] ONNX/TensorRT converter is supported.
 - [Jul 2020] Distributed training with multiple GPUs, it trains much faster.
 - [Jul 2020] `MAX_ITER` in config means `epoch`, it will auto scale to maximum iterations.
 - Includes more features such as circle loss, abundant visualization methods and evaluation metrics, SoTA results on conventional, cross-domain, partial and vehicle re-id, testing on multi-datasets simultaneously, etc.
diff --git a/configs/Base-Strongerbaseline.yml b/configs/Base-Strongerbaseline.yml
index 7cebdd0..81e6f03 100644
--- a/configs/Base-Strongerbaseline.yml
+++ b/configs/Base-Strongerbaseline.yml
@@ -9,7 +9,7 @@ MODEL:
   HEADS:
     NECK_FEAT: "after"
     POOL_LAYER: "gempool"
-    CLS_LAYER: "circle"
+    CLS_LAYER: "circleSoftmax"
     SCALE: 64
     MARGIN: 0.35
 
diff --git a/configs/Base-bagtricks.yml b/configs/Base-bagtricks.yml
index ee01d1b..1cc7fe1 100644
--- a/configs/Base-bagtricks.yml
+++ b/configs/Base-bagtricks.yml
@@ -4,7 +4,7 @@ MODEL:
   BACKBONE:
     NAME: "build_resnet_backbone"
     NORM: "BN"
-    DEPTH: 50
+    DEPTH: "50x"
     LAST_STRIDE: 1
     WITH_IBN: False
     PRETRAIN: True
diff --git a/configs/DukeMTMC/AGW_R101-ibn.yml b/configs/DukeMTMC/AGW_R101-ibn.yml
index 1b4766e..dcf8520 100644
--- a/configs/DukeMTMC/AGW_R101-ibn.yml
+++ b/configs/DukeMTMC/AGW_R101-ibn.yml
@@ -4,7 +4,6 @@ MODEL:
   BACKBONE:
     DEPTH: 101
     WITH_IBN: True
-    PRETRAIN_PATH: "/home/liaoxingyu2/lxy/.cache/torch/checkpoints/resnet101_ibn_a.pth.tar"
 
 DATASETS:
   NAMES: ("DukeMTMC",)
diff --git a/configs/DukeMTMC/AGW_R50-ibn.yml b/configs/DukeMTMC/AGW_R50-ibn.yml
index cd021a2..648660f 100644
--- a/configs/DukeMTMC/AGW_R50-ibn.yml
+++ b/configs/DukeMTMC/AGW_R50-ibn.yml
@@ -3,7 +3,6 @@ _BASE_: "../Base-AGW.yml"
 MODEL:
   BACKBONE:
     WITH_IBN: True
-    PRETRAIN_PATH: "/home/liaoxingyu2/lxy/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar"
 
 DATASETS:
   NAMES: ("DukeMTMC",)
diff --git a/configs/DukeMTMC/bagtricks_R101-ibn.yml b/configs/DukeMTMC/bagtricks_R101-ibn.yml
index 6e0bc9a..ca3b4bc 100644
--- a/configs/DukeMTMC/bagtricks_R101-ibn.yml
+++ b/configs/DukeMTMC/bagtricks_R101-ibn.yml
@@ -4,7 +4,6 @@ MODEL:
   BACKBONE:
     DEPTH: 101
     WITH_IBN: True
-    PRETRAIN_PATH: "/home/liaoxingyu2/lxy/.cache/torch/checkpoints/resnet101_ibn_a.pth.tar"
 
 DATASETS:
   NAMES: ("DukeMTMC",)
diff --git a/configs/DukeMTMC/bagtricks_R50-ibn.yml b/configs/DukeMTMC/bagtricks_R50-ibn.yml
index 9c51ab9..cb46929 100644
--- a/configs/DukeMTMC/bagtricks_R50-ibn.yml
+++ b/configs/DukeMTMC/bagtricks_R50-ibn.yml
@@ -3,7 +3,6 @@ _BASE_: "../Base-bagtricks.yml"
 MODEL:
   BACKBONE:
     WITH_IBN: True
-    PRETRAIN_PATH: "/home/liaoxingyu2/lxy/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar"
 
 DATASETS:
   NAMES: ("DukeMTMC",)
diff --git a/configs/DukeMTMC/mgn_R50-ibn.yml b/configs/DukeMTMC/mgn_R50-ibn.yml
index 2d8f6fb..ab6bf7c 100644
--- a/configs/DukeMTMC/mgn_R50-ibn.yml
+++ b/configs/DukeMTMC/mgn_R50-ibn.yml
@@ -3,7 +3,6 @@ _BASE_: "../Base-MGN.yml"
 MODEL:
   BACKBONE:
     WITH_IBN: True
-    PRETRAIN_PATH: "/home/liaoxingyu2/lxy/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar"
 
 DATASETS:
   NAMES: ("DukeMTMC",)
diff --git a/configs/DukeMTMC/sbs_R101-ibn.yml b/configs/DukeMTMC/sbs_R101-ibn.yml
index 10d26b5..b22aac0 100644
--- a/configs/DukeMTMC/sbs_R101-ibn.yml
+++ b/configs/DukeMTMC/sbs_R101-ibn.yml
@@ -4,7 +4,6 @@ MODEL:
   BACKBONE:
     DEPTH: 101
     WITH_IBN: True
-    PRETRAIN_PATH: "/home/liaoxingyu2/lxy/.cache/torch/checkpoints/resnet101_ibn_a.pth.tar"
 
 DATASETS:
   NAMES: ("DukeMTMC",)
diff --git a/configs/DukeMTMC/sbs_R50-ibn.yml b/configs/DukeMTMC/sbs_R50-ibn.yml
index 8b6cd87..cbca66a 100644
--- a/configs/DukeMTMC/sbs_R50-ibn.yml
+++ b/configs/DukeMTMC/sbs_R50-ibn.yml
@@ -3,7 +3,6 @@ _BASE_: "../Base-Strongerbaseline.yml"
 MODEL:
   BACKBONE:
     WITH_IBN: True
-    PRETRAIN_PATH: "/home/liaoxingyu2/lxy/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar"
 
 DATASETS:
   NAMES: ("DukeMTMC",)
diff --git a/configs/MSMT17/AGW_R101-ibn.yml b/configs/MSMT17/AGW_R101-ibn.yml
index 5283ba5..12f49a2 100644
--- a/configs/MSMT17/AGW_R101-ibn.yml
+++ b/configs/MSMT17/AGW_R101-ibn.yml
@@ -4,7 +4,6 @@ MODEL:
   BACKBONE:
     DEPTH: 101
     WITH_IBN: True
-    PRETRAIN_PATH: "/home/liaoxingyu2/lxy/.cache/torch/checkpoints/resnet101_ibn_a.pth.tar"
 
 DATASETS:
   NAMES: ("MSMT17",)
diff --git a/configs/MSMT17/AGW_R50-ibn.yml b/configs/MSMT17/AGW_R50-ibn.yml
index fb57808..6104ed6 100644
--- a/configs/MSMT17/AGW_R50-ibn.yml
+++ b/configs/MSMT17/AGW_R50-ibn.yml
@@ -3,7 +3,6 @@ _BASE_: "../Base-AGW.yml"
 MODEL:
   BACKBONE:
     WITH_IBN: True
-    PRETRAIN_PATH: "/home/liaoxingyu2/lxy/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar"
 
 DATASETS:
   NAMES: ("MSMT17",)
diff --git a/configs/MSMT17/bagtricks_R101-ibn.yml b/configs/MSMT17/bagtricks_R101-ibn.yml
index ba86d24..aa21848 100644
--- a/configs/MSMT17/bagtricks_R101-ibn.yml
+++ b/configs/MSMT17/bagtricks_R101-ibn.yml
@@ -4,7 +4,6 @@ MODEL:
   BACKBONE:
     DEPTH: 101
     WITH_IBN: True
-    PRETRAIN_PATH: "/home/liaoxingyu2/lxy/.cache/torch/checkpoints/resnet101_ibn_a.pth.tar"
 
 DATASETS:
   NAMES: ("MSMT17",)
diff --git a/configs/MSMT17/bagtricks_R50-ibn.yml b/configs/MSMT17/bagtricks_R50-ibn.yml
index 563d48e..fac921e 100644
--- a/configs/MSMT17/bagtricks_R50-ibn.yml
+++ b/configs/MSMT17/bagtricks_R50-ibn.yml
@@ -3,7 +3,6 @@ _BASE_: "../Base-bagtricks.yml"
 MODEL:
   BACKBONE:
     WITH_IBN: True
-    PRETRAIN_PATH: "/home/liaoxingyu2/lxy/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar"
 
 DATASETS:
   NAMES: ("MSMT17",)
diff --git a/configs/MSMT17/mgn_R50-ibn.yml b/configs/MSMT17/mgn_R50-ibn.yml
index 303371f..07f18dd 100644
--- a/configs/MSMT17/mgn_R50-ibn.yml
+++ b/configs/MSMT17/mgn_R50-ibn.yml
@@ -3,7 +3,6 @@ _BASE_: "../Base-MGN.yml"
 MODEL:
   BACKBONE:
     WITH_IBN: True
-    PRETRAIN_PATH: "/home/liaoxingyu2/lxy/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar"
 
 DATASETS:
   NAMES: ("MSMT17",)
diff --git a/configs/MSMT17/sbs_R101-ibn.yml b/configs/MSMT17/sbs_R101-ibn.yml
index cebd7ff..85c7c1a 100644
--- a/configs/MSMT17/sbs_R101-ibn.yml
+++ b/configs/MSMT17/sbs_R101-ibn.yml
@@ -4,7 +4,6 @@ MODEL:
   BACKBONE:
     DEPTH: 101
     WITH_IBN: True
-    PRETRAIN_PATH: "/home/liaoxingyu2/lxy/.cache/torch/checkpoints/resnet101_ibn_a.pth.tar"
 
 DATASETS:
   NAMES: ("MSMT17",)
diff --git a/configs/MSMT17/sbs_R50-ibn.yml b/configs/MSMT17/sbs_R50-ibn.yml
index 64cb29d..d90ce4a 100644
--- a/configs/MSMT17/sbs_R50-ibn.yml
+++ b/configs/MSMT17/sbs_R50-ibn.yml
@@ -3,7 +3,6 @@ _BASE_: "../Base-Strongerbaseline.yml"
 MODEL:
   BACKBONE:
     WITH_IBN: True
-    PRETRAIN_PATH: "/home/liaoxingyu2/lxy/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar"
 
 DATASETS:
   NAMES: ("MSMT17",)
diff --git a/configs/Market1501/AGW_R101-ibn.yml b/configs/Market1501/AGW_R101-ibn.yml
index afe2020..d86dbc8 100644
--- a/configs/Market1501/AGW_R101-ibn.yml
+++ b/configs/Market1501/AGW_R101-ibn.yml
@@ -4,7 +4,6 @@ MODEL:
   BACKBONE:
     DEPTH: 101
     WITH_IBN: True
-    PRETRAIN_PATH: "/home/liaoxingyu2/lxy/.cache/torch/checkpoints/resnet101_ibn_a.pth.tar"
 
 DATASETS:
   NAMES: ("Market1501",)
diff --git a/configs/Market1501/AGW_R50-ibn.yml b/configs/Market1501/AGW_R50-ibn.yml
index bee96c1..4ec8154 100644
--- a/configs/Market1501/AGW_R50-ibn.yml
+++ b/configs/Market1501/AGW_R50-ibn.yml
@@ -3,7 +3,6 @@ _BASE_: "../Base-AGW.yml"
 MODEL:
   BACKBONE:
     WITH_IBN: True
-    PRETRAIN_PATH: "/home/liaoxingyu2/lxy/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar"
 
 DATASETS:
   NAMES: ("Market1501",)
diff --git a/configs/Market1501/bagtricks_R101-ibn.yml b/configs/Market1501/bagtricks_R101-ibn.yml
index 5103391..03b5132 100644
--- a/configs/Market1501/bagtricks_R101-ibn.yml
+++ b/configs/Market1501/bagtricks_R101-ibn.yml
@@ -2,9 +2,8 @@ _BASE_: "../Base-bagtricks.yml"
 
 MODEL:
   BACKBONE:
-    DEPTH: 101
+    DEPTH: "101"
     WITH_IBN: True
-    PRETRAIN_PATH: "/home/liaoxingyu2/lxy/.cache/torch/checkpoints/resnet101_ibn_a.pth.tar"
 
 DATASETS:
   NAMES: ("Market1501",)
diff --git a/configs/Market1501/bagtricks_R50-ibn.yml b/configs/Market1501/bagtricks_R50-ibn.yml
index d8f0cf6..0a5a032 100644
--- a/configs/Market1501/bagtricks_R50-ibn.yml
+++ b/configs/Market1501/bagtricks_R50-ibn.yml
@@ -3,7 +3,6 @@ _BASE_: "../Base-bagtricks.yml"
 MODEL:
   BACKBONE:
     WITH_IBN: True
-    PRETRAIN_PATH: "/home/liaoxingyu2/lxy/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar"
 
 DATASETS:
   NAMES: ("Market1501",)
diff --git a/configs/Market1501/mgn_R50-ibn.yml b/configs/Market1501/mgn_R50-ibn.yml
index 058ca12..b2ca6f2 100644
--- a/configs/Market1501/mgn_R50-ibn.yml
+++ b/configs/Market1501/mgn_R50-ibn.yml
@@ -3,7 +3,6 @@ _BASE_: "../Base-MGN.yml"
 MODEL:
   BACKBONE:
     WITH_IBN: True
-    PRETRAIN_PATH: "/home/liaoxingyu2/lxy/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar"
 
 DATASETS:
   NAMES: ("Market1501",)
diff --git a/configs/Market1501/sbs_R101-ibn.yml b/configs/Market1501/sbs_R101-ibn.yml
index 44fc73d..ff74053 100644
--- a/configs/Market1501/sbs_R101-ibn.yml
+++ b/configs/Market1501/sbs_R101-ibn.yml
@@ -4,7 +4,6 @@ MODEL:
   BACKBONE:
     DEPTH: 101
     WITH_IBN: True
-    PRETRAIN_PATH: "/home/liaoxingyu2/lxy/.cache/torch/checkpoints/resnet101_ibn_a.pth.tar"
 
 DATASETS:
   NAMES: ("Market1501",)
diff --git a/configs/Market1501/sbs_R50-ibn.yml b/configs/Market1501/sbs_R50-ibn.yml
index 4b556a0..7302591 100644
--- a/configs/Market1501/sbs_R50-ibn.yml
+++ b/configs/Market1501/sbs_R50-ibn.yml
@@ -3,7 +3,6 @@ _BASE_: "../Base-Strongerbaseline.yml"
 MODEL:
   BACKBONE:
     WITH_IBN: True
-    PRETRAIN_PATH: "/home/liaoxingyu2/lxy/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar"
 
 DATASETS:
   NAMES: ("Market1501",)
diff --git a/configs/VERIWild/bagtricks_R50-ibn.yml b/configs/VERIWild/bagtricks_R50-ibn.yml
index a0375d1..20fc877 100644
--- a/configs/VERIWild/bagtricks_R50-ibn.yml
+++ b/configs/VERIWild/bagtricks_R50-ibn.yml
@@ -7,9 +7,7 @@ INPUT:
 MODEL:
   BACKBONE:
     WITH_IBN: True
-    PRETRAIN_PATH: '/export2/home/zjk/pretrain_models/resnet50_ibn_a.pth.tar'
   HEADS:
-    NUM_CLASSES: 30671
     POOL_LAYER: gempool
   LOSSES:
     TRI:
diff --git a/configs/VeRi/sbs_R50-ibn.yml b/configs/VeRi/sbs_R50-ibn.yml
index d64a10f..5324cde 100644
--- a/configs/VeRi/sbs_R50-ibn.yml
+++ b/configs/VeRi/sbs_R50-ibn.yml
@@ -7,10 +7,6 @@ INPUT:
 MODEL:
   BACKBONE:
     WITH_IBN: True
-    PRETRAIN_PATH: "/export2/home/zjk/pretrain_models/resnet50_ibn_a.pth.tar"
-
-  HEADS:
-    NUM_CLASSES: 575
 
 SOLVER:
   OPT: "SGD"
diff --git a/configs/VehicleID/bagtricks_R50-ibn.yml b/configs/VehicleID/bagtricks_R50-ibn.yml
index 6bd886c..0020c4f 100644
--- a/configs/VehicleID/bagtricks_R50-ibn.yml
+++ b/configs/VehicleID/bagtricks_R50-ibn.yml
@@ -7,9 +7,7 @@ INPUT:
 MODEL:
   BACKBONE:
     WITH_IBN: True
-    PRETRAIN_PATH: '/export2/home/zjk/pretrain_models/resnet50_ibn_a.pth.tar'
   HEADS:
-    NUM_CLASSES: 13164
     POOL_LAYER: gempool
   LOSSES:
     TRI:
diff --git a/demo/demo.py b/demo/demo.py
index bab0adb..349a796 100644
--- a/demo/demo.py
+++ b/demo/demo.py
@@ -29,6 +29,7 @@ cudnn.benchmark = True
 def setup_cfg(args):
     # load config from file and command-line arguments
     cfg = get_cfg()
+    # add_partialreid_config(cfg)
     cfg.merge_from_file(args.config_file)
     cfg.merge_from_list(args.opts)
     cfg.freeze()
diff --git a/demo/plot_roc_with_pickle.py b/demo/plot_roc_with_pickle.py
index 0ab51e3..c4e9aad 100644
--- a/demo/plot_roc_with_pickle.py
+++ b/demo/plot_roc_with_pickle.py
@@ -1,7 +1,7 @@
 # encoding: utf-8
 """
 @author:  xingyu liao
-@contact: liaoxingyu5@jd.com
+@contact: sherlockliao01@gmail.com
 """
 
 import matplotlib.pyplot as plt
diff --git a/demo/predictor.py b/demo/predictor.py
index a6b8512..b56f01c 100644
--- a/demo/predictor.py
+++ b/demo/predictor.py
@@ -1,7 +1,7 @@
 # encoding: utf-8
 """
 @author:  xingyu liao
-@contact: liaoxingyu5@jd.com
+@contact: sherlockliao01@gmail.com
 """
 
 import atexit
diff --git a/demo/visualize_result.py b/demo/visualize_result.py
index f3f315e..ea46381 100644
--- a/demo/visualize_result.py
+++ b/demo/visualize_result.py
@@ -1,7 +1,7 @@
 # encoding: utf-8
 """
 @author:  xingyu liao
-@contact: liaoxingyu5@jd.com
+@contact: sherlockliao01@gmail.com
 """
 
 import argparse
diff --git a/fastreid/config/defaults.py b/fastreid/config/defaults.py
index 86cbb78..fbaafd8 100644
--- a/fastreid/config/defaults.py
+++ b/fastreid/config/defaults.py
@@ -30,9 +30,7 @@ _C.MODEL.FREEZE_LAYERS = ['']
 _C.MODEL.BACKBONE = CN()
 
 _C.MODEL.BACKBONE.NAME = "build_resnet_backbone"
-_C.MODEL.BACKBONE.DEPTH = 50
-# RegNet volume
-_C.MODEL.BACKBONE.VOLUME = "800y"
+_C.MODEL.BACKBONE.DEPTH = "50x"
 _C.MODEL.BACKBONE.LAST_STRIDE = 1
 # Normalization method for the convolution layers.
 _C.MODEL.BACKBONE.NORM = "BN"
@@ -137,7 +135,13 @@ _C.INPUT.DO_PAD = True
 _C.INPUT.PADDING_MODE = 'constant'
 _C.INPUT.PADDING = 10
 # Random color jitter
-_C.INPUT.DO_CJ = False
+_C.INPUT.CJ = CN()
+_C.INPUT.CJ.ENABLED = False
+_C.INPUT.CJ.PROB = 0.8
+_C.INPUT.CJ.BRIGHTNESS = 0.15
+_C.INPUT.CJ.CONTRAST = 0.15
+_C.INPUT.CJ.SATURATION = 0.1
+_C.INPUT.CJ.HUE = 0.1
 # Auto augmentation
 _C.INPUT.DO_AUTOAUG = False
 # Augmix augmentation
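
The new INPUT.CJ node replaces the old INPUT.DO_CJ flag. A hedged example of
overriding these keys programmatically (values are arbitrary; the same keys can
be set in a YAML config under INPUT):

    # Override the new color-jitter options on top of the defaults above.
    from fastreid.config import get_cfg

    cfg = get_cfg()
    cfg.merge_from_list([
        "INPUT.CJ.ENABLED", True,
        "INPUT.CJ.PROB", 0.5,
        "INPUT.CJ.BRIGHTNESS", 0.2,
    ])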
diff --git a/fastreid/data/datasets/vehicleid.py b/fastreid/data/datasets/vehicleid.py
index 013ce1e..1bbc27e 100644
--- a/fastreid/data/datasets/vehicleid.py
+++ b/fastreid/data/datasets/vehicleid.py
@@ -1,124 +1,124 @@
-# encoding: utf-8
-"""
-@author:  Jinkai Zheng
-@contact: 1315673509@qq.com
-"""
-
-import os.path as osp
-import random
-
-from .bases import ImageDataset
-from ..datasets import DATASET_REGISTRY
-
-
-@DATASET_REGISTRY.register()
-class VehicleID(ImageDataset):
-    """VehicleID.
-
-    Reference:
-        Liu et al. Deep relative distance learning: Tell the difference between similar vehicles. CVPR 2016.
-
-    URL: `<https://pkuml.org/resources/pku-vehicleid.html>`_
-
-    Train dataset statistics:
-        - identities: 13164.
-        - images: 113346.
-    """
-    dataset_dir = "vehicleid"
-    dataset_name = "vehicleid"
-
-    def __init__(self, root='datasets', test_list='', **kwargs):
-        self.dataset_dir = osp.join(root, self.dataset_dir)
-
-        self.image_dir = osp.join(self.dataset_dir, 'image')
-        self.train_list = osp.join(self.dataset_dir, 'train_test_split/train_list.txt')
-        if test_list:
-            self.test_list = test_list
-        else:
-            self.test_list = osp.join(self.dataset_dir, 'train_test_split/test_list_13164.txt')
-
-        required_files = [
-            self.dataset_dir,
-            self.image_dir,
-            self.train_list,
-            self.test_list,
-        ]
-        self.check_before_run(required_files)
-
-        train = self.process_dir(self.train_list, is_train=True)
-        query, gallery = self.process_dir(self.test_list, is_train=False)
-
-        super(VehicleID, self).__init__(train, query, gallery, **kwargs)
-
-    def process_dir(self, list_file, is_train=True):
-        img_list_lines = open(list_file, 'r').readlines()
-
-        dataset = []
-        for idx, line in enumerate(img_list_lines):
-            line = line.strip()
-            vid = int(line.split(' ')[1])
-            imgid = line.split(' ')[0]
-            img_path = osp.join(self.image_dir, imgid + '.jpg')
-            if is_train:
-                vid = self.dataset_name + "_" + str(vid)
-            dataset.append((img_path, vid, int(imgid)))
-
-        if is_train: return dataset
-        else:
-            random.shuffle(dataset)
-            vid_container = set()
-            query = []
-            gallery = []
-            for sample in dataset:
-                if sample[1] not in vid_container:
-                    vid_container.add(sample[1])
-                    gallery.append(sample)
-                else:
-                    query.append(sample)
-
-            return query, gallery
-
-
-@DATASET_REGISTRY.register()
-class SmallVehicleID(VehicleID):
-    """VehicleID.
-    Small test dataset statistics:
-        - identities: 800.
-        - images: 6493.
-    """
-
-    def __init__(self, root='datasets', **kwargs):
-        self.dataset_dir = osp.join(root, self.dataset_dir)
-        self.test_list = osp.join(self.dataset_dir, 'train_test_split/test_list_800.txt')
-
-        super(SmallVehicleID, self).__init__(root, self.test_list, **kwargs)
-
-
-@DATASET_REGISTRY.register()
-class MediumVehicleID(VehicleID):
-    """VehicleID.
-    Medium test dataset statistics:
-        - identities: 1600.
-        - images: 13377.
-    """
-
-    def __init__(self, root='datasets', **kwargs):
-        self.dataset_dir = osp.join(root, self.dataset_dir)
-        self.test_list = osp.join(self.dataset_dir, 'train_test_split/test_list_1600.txt')
-
-        super(MediumVehicleID, self).__init__(root, self.test_list, **kwargs)
-
-
-@DATASET_REGISTRY.register()
-class LargeVehicleID(VehicleID):
-    """VehicleID.
-    Large test dataset statistics:
-        - identities: 2400.
-        - images: 19777.
-    """
-
-    def __init__(self, root='datasets', **kwargs):
-        self.dataset_dir = osp.join(root, self.dataset_dir)
-        self.test_list = osp.join(self.dataset_dir, 'train_test_split/test_list_2400.txt')
-
-        super(LargeVehicleID, self).__init__(root, self.test_list, **kwargs)
+# encoding: utf-8
+"""
+@author:  Jinkai Zheng
+@contact: 1315673509@qq.com
+"""
+
+import os.path as osp
+import random
+
+from .bases import ImageDataset
+from ..datasets import DATASET_REGISTRY
+
+
+@DATASET_REGISTRY.register()
+class VehicleID(ImageDataset):
+    """VehicleID.
+
+    Reference:
+        Liu et al. Deep relative distance learning: Tell the difference between similar vehicles. CVPR 2016.
+
+    URL: `<https://pkuml.org/resources/pku-vehicleid.html>`_
+
+    Train dataset statistics:
+        - identities: 13164.
+        - images: 113346.
+    """
+    dataset_dir = "vehicleid"
+    dataset_name = "vehicleid"
+
+    def __init__(self, root='datasets', test_list='', **kwargs):
+        self.dataset_dir = osp.join(root, self.dataset_dir)
+
+        self.image_dir = osp.join(self.dataset_dir, 'image')
+        self.train_list = osp.join(self.dataset_dir, 'train_test_split/train_list.txt')
+        if test_list:
+            self.test_list = test_list
+        else:
+            self.test_list = osp.join(self.dataset_dir, 'train_test_split/test_list_13164.txt')
+
+        required_files = [
+            self.dataset_dir,
+            self.image_dir,
+            self.train_list,
+            self.test_list,
+        ]
+        self.check_before_run(required_files)
+
+        train = self.process_dir(self.train_list, is_train=True)
+        query, gallery = self.process_dir(self.test_list, is_train=False)
+
+        super(VehicleID, self).__init__(train, query, gallery, **kwargs)
+
+    def process_dir(self, list_file, is_train=True):
+        img_list_lines = open(list_file, 'r').readlines()
+
+        dataset = []
+        for idx, line in enumerate(img_list_lines):
+            line = line.strip()
+            vid = int(line.split(' ')[1])
+            imgid = line.split(' ')[0]
+            img_path = osp.join(self.image_dir, imgid + '.jpg')
+            if is_train:
+                vid = self.dataset_name + "_" + str(vid)
+            dataset.append((img_path, vid, int(imgid)))
+
+        if is_train: return dataset
+        else:
+            random.shuffle(dataset)
+            vid_container = set()
+            query = []
+            gallery = []
+            for sample in dataset:
+                if sample[1] not in vid_container:
+                    vid_container.add(sample[1])
+                    gallery.append(sample)
+                else:
+                    query.append(sample)
+
+            return query, gallery
+
+
+@DATASET_REGISTRY.register()
+class SmallVehicleID(VehicleID):
+    """VehicleID.
+    Small test dataset statistics:
+        - identities: 800.
+        - images: 6493.
+    """
+
+    def __init__(self, root='datasets', **kwargs):
+        self.dataset_dir = osp.join(root, self.dataset_dir)
+        self.test_list = osp.join(self.dataset_dir, 'train_test_split/test_list_800.txt')
+
+        super(SmallVehicleID, self).__init__(root, self.test_list, **kwargs)
+
+
+@DATASET_REGISTRY.register()
+class MediumVehicleID(VehicleID):
+    """VehicleID.
+    Medium test dataset statistics:
+        - identities: 1600.
+        - images: 13377.
+    """
+
+    def __init__(self, root='datasets', **kwargs):
+        self.dataset_dir = osp.join(root, self.dataset_dir)
+        self.test_list = osp.join(self.dataset_dir, 'train_test_split/test_list_1600.txt')
+
+        super(MediumVehicleID, self).__init__(root, self.test_list, **kwargs)
+
+
+@DATASET_REGISTRY.register()
+class LargeVehicleID(VehicleID):
+    """VehicleID.
+    Large test dataset statistics:
+        - identities: 2400.
+        - images: 19777.
+    """
+
+    def __init__(self, root='datasets', **kwargs):
+        self.dataset_dir = osp.join(root, self.dataset_dir)
+        self.test_list = osp.join(self.dataset_dir, 'train_test_split/test_list_2400.txt')
+
+        super(LargeVehicleID, self).__init__(root, self.test_list, **kwargs)
diff --git a/fastreid/data/datasets/veri.py b/fastreid/data/datasets/veri.py
index 181d5c1..7e3b166 100644
--- a/fastreid/data/datasets/veri.py
+++ b/fastreid/data/datasets/veri.py
@@ -1,67 +1,67 @@
-# encoding: utf-8
-"""
-@author:  Jinkai Zheng
-@contact: 1315673509@qq.com
-"""
-
-import glob
-import os.path as osp
-import re
-
-from .bases import ImageDataset
-from ..datasets import DATASET_REGISTRY
-
-
-@DATASET_REGISTRY.register()
-class VeRi(ImageDataset):
-    """VeRi.
-
-    Reference:
-        Liu et al. A Deep Learning based Approach for Progressive Vehicle Re-Identification. ECCV 2016.
-
-    URL: `<https://vehiclereid.github.io/VeRi/>`_
-
-    Dataset statistics:
-        - identities: 775.
-        - images: 37778 (train) + 1678 (query) + 11579 (gallery).
-    """
-    dataset_dir = "veri"
-    dataset_name = "veri"
-
-    def __init__(self, root='datasets', **kwargs):
-        self.dataset_dir = osp.join(root, self.dataset_dir)
-
-        self.train_dir = osp.join(self.dataset_dir, 'image_train')
-        self.query_dir = osp.join(self.dataset_dir, 'image_query')
-        self.gallery_dir = osp.join(self.dataset_dir, 'image_test')
-
-        required_files = [
-            self.dataset_dir,
-            self.train_dir,
-            self.query_dir,
-            self.gallery_dir,
-        ]
-        self.check_before_run(required_files)
-
-        train = self.process_dir(self.train_dir)
-        query = self.process_dir(self.query_dir, is_train=False)
-        gallery = self.process_dir(self.gallery_dir, is_train=False)
-
-        super(VeRi, self).__init__(train, query, gallery, **kwargs)
-
-    def process_dir(self, dir_path, is_train=True):
-        img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
-        pattern = re.compile(r'([\d]+)_c(\d\d\d)')
-
-        data = []
-        for img_path in img_paths:
-            pid, camid = map(int, pattern.search(img_path).groups())
-            if pid == -1: continue  # junk images are just ignored
-            assert 1 <= pid <= 776
-            assert 1 <= camid <= 20
-            camid -= 1  # index starts from 0
-            if is_train:
-                pid = self.dataset_name + "_" + str(pid)
-            data.append((img_path, pid, camid))
-
-        return data
+# encoding: utf-8
+"""
+@author:  Jinkai Zheng
+@contact: 1315673509@qq.com
+"""
+
+import glob
+import os.path as osp
+import re
+
+from .bases import ImageDataset
+from ..datasets import DATASET_REGISTRY
+
+
+@DATASET_REGISTRY.register()
+class VeRi(ImageDataset):
+    """VeRi.
+
+    Reference:
+        Liu et al. A Deep Learning based Approach for Progressive Vehicle Re-Identification. ECCV 2016.
+
+    URL: `<https://vehiclereid.github.io/VeRi/>`_
+
+    Dataset statistics:
+        - identities: 775.
+        - images: 37778 (train) + 1678 (query) + 11579 (gallery).
+    """
+    dataset_dir = "veri"
+    dataset_name = "veri"
+
+    def __init__(self, root='datasets', **kwargs):
+        self.dataset_dir = osp.join(root, self.dataset_dir)
+
+        self.train_dir = osp.join(self.dataset_dir, 'image_train')
+        self.query_dir = osp.join(self.dataset_dir, 'image_query')
+        self.gallery_dir = osp.join(self.dataset_dir, 'image_test')
+
+        required_files = [
+            self.dataset_dir,
+            self.train_dir,
+            self.query_dir,
+            self.gallery_dir,
+        ]
+        self.check_before_run(required_files)
+
+        train = self.process_dir(self.train_dir)
+        query = self.process_dir(self.query_dir, is_train=False)
+        gallery = self.process_dir(self.gallery_dir, is_train=False)
+
+        super(VeRi, self).__init__(train, query, gallery, **kwargs)
+
+    def process_dir(self, dir_path, is_train=True):
+        img_paths = glob.glob(osp.join(dir_path, '*.jpg'))
+        pattern = re.compile(r'([\d]+)_c(\d\d\d)')
+
+        data = []
+        for img_path in img_paths:
+            pid, camid = map(int, pattern.search(img_path).groups())
+            if pid == -1: continue  # junk images are just ignored
+            assert 1 <= pid <= 776
+            assert 1 <= camid <= 20
+            camid -= 1  # index starts from 0
+            if is_train:
+                pid = self.dataset_name + "_" + str(pid)
+            data.append((img_path, pid, camid))
+
+        return data
diff --git a/fastreid/data/datasets/veriwild.py b/fastreid/data/datasets/veriwild.py
index 9637d7b..1e0b1cb 100644
--- a/fastreid/data/datasets/veriwild.py
+++ b/fastreid/data/datasets/veriwild.py
@@ -1,138 +1,138 @@
-# encoding: utf-8
-"""
-@author:  Jinkai Zheng
-@contact: 1315673509@qq.com
-"""
-
-import os.path as osp
-
-from .bases import ImageDataset
-from ..datasets import DATASET_REGISTRY
-
-
-@DATASET_REGISTRY.register()
-class VeRiWild(ImageDataset):
-    """VeRi-Wild.
-
-    Reference:
-        Lou et al. A Large-Scale Dataset for Vehicle Re-Identification in the Wild. CVPR 2019.
-
-    URL: `<https://github.com/PKU-IMRE/VERI-Wild>`_
-
-    Train dataset statistics:
-        - identities: 30671.
-        - images: 277797.
-    """
-    dataset_dir = "VERI-Wild"
-    dataset_name = "veriwild"
-
-    def __init__(self, root='datasets', query_list='', gallery_list='', **kwargs):
-        self.dataset_dir = osp.join(root, self.dataset_dir)
-
-        self.image_dir = osp.join(self.dataset_dir, 'images')
-        self.train_list = osp.join(self.dataset_dir, 'train_test_split/train_list.txt')
-        self.vehicle_info = osp.join(self.dataset_dir, 'train_test_split/vehicle_info.txt')
-        if query_list and gallery_list:
-            self.query_list = query_list
-            self.gallery_list = gallery_list
-        else:
-            self.query_list = osp.join(self.dataset_dir, 'train_test_split/test_10000_query.txt')
-            self.gallery_list = osp.join(self.dataset_dir, 'train_test_split/test_10000.txt')
-
-        required_files = [
-            self.image_dir,
-            self.train_list,
-            self.query_list,
-            self.gallery_list,
-            self.vehicle_info,
-        ]
-        self.check_before_run(required_files)
-
-        self.imgid2vid, self.imgid2camid, self.imgid2imgpath = self.process_vehicle(self.vehicle_info)
-
-        train = self.process_dir(self.train_list)
-        query = self.process_dir(self.query_list, is_train=False)
-        gallery = self.process_dir(self.gallery_list, is_train=False)
-
-        super(VeRiWild, self).__init__(train, query, gallery, **kwargs)
-
-    def process_dir(self, img_list, is_train=True):
-        img_list_lines = open(img_list, 'r').readlines()
-
-        dataset = []
-        for idx, line in enumerate(img_list_lines):
-            line = line.strip()
-            vid = int(line.split('/')[0])
-            imgid = line.split('/')[1]
-            if is_train:
-                vid = self.dataset_name + "_" + str(vid)
-            dataset.append((self.imgid2imgpath[imgid], vid, int(self.imgid2camid[imgid])))
-
-        assert len(dataset) == len(img_list_lines)
-        return dataset
-
-    def process_vehicle(self, vehicle_info):
-        imgid2vid = {}
-        imgid2camid = {}
-        imgid2imgpath = {}
-        vehicle_info_lines = open(vehicle_info, 'r').readlines()
-
-        for idx, line in enumerate(vehicle_info_lines[1:]):
-            vid = line.strip().split('/')[0]
-            imgid = line.strip().split(';')[0].split('/')[1]
-            camid = line.strip().split(';')[1]
-            img_path = osp.join(self.image_dir, vid, imgid + '.jpg')
-            imgid2vid[imgid] = vid
-            imgid2camid[imgid] = camid
-            imgid2imgpath[imgid] = img_path
-
-        assert len(imgid2vid) == len(vehicle_info_lines) - 1
-        return imgid2vid, imgid2camid, imgid2imgpath
-
-
-@DATASET_REGISTRY.register()
-class SmallVeRiWild(VeRiWild):
-    """VeRi-Wild.
-    Small test dataset statistics:
-        - identities: 3000.
-        - images: 41861.
-    """
-
-    def __init__(self, root='datasets', **kwargs):
-        self.dataset_dir = osp.join(root, self.dataset_dir)
-        self.query_list = osp.join(self.dataset_dir, 'train_test_split/test_3000_query.txt')
-        self.gallery_list = osp.join(self.dataset_dir, 'train_test_split/test_3000.txt')
-
-        super(SmallVeRiWild, self).__init__(root, self.query_list, self.gallery_list, **kwargs)
-
-
-@DATASET_REGISTRY.register()
-class MediumVeRiWild(VeRiWild):
-    """VeRi-Wild.
-    Medium test dataset statistics:
-        - identities: 5000.
-        - images: 69389.
-    """
-
-    def __init__(self, root='datasets', **kwargs):
-        self.dataset_dir = osp.join(root, self.dataset_dir)
-        self.query_list = osp.join(self.dataset_dir, 'train_test_split/test_5000_query.txt')
-        self.gallery_list = osp.join(self.dataset_dir, 'train_test_split/test_5000.txt')
-
-        super(MediumVeRiWild, self).__init__(root, self.query_list, self.gallery_list, **kwargs)
-
-
-@DATASET_REGISTRY.register()
-class LargeVeRiWild(VeRiWild):
-    """VeRi-Wild.
-    Large test dataset statistics:
-        - identities: 10000.
-        - images: 138517.
-    """
-
-    def __init__(self, root='datasets', **kwargs):
-        self.dataset_dir = osp.join(root, self.dataset_dir)
-        self.query_list = osp.join(self.dataset_dir, 'train_test_split/test_10000_query.txt')
-        self.gallery_list = osp.join(self.dataset_dir, 'train_test_split/test_10000.txt')
-
-        super(LargeVeRiWild, self).__init__(root, self.query_list, self.gallery_list, **kwargs)
+# encoding: utf-8
+"""
+@author:  Jinkai Zheng
+@contact: 1315673509@qq.com
+"""
+
+import os.path as osp
+
+from .bases import ImageDataset
+from ..datasets import DATASET_REGISTRY
+
+
+@DATASET_REGISTRY.register()
+class VeRiWild(ImageDataset):
+    """VeRi-Wild.
+
+    Reference:
+        Lou et al. A Large-Scale Dataset for Vehicle Re-Identification in the Wild. CVPR 2019.
+
+    URL: `<https://github.com/PKU-IMRE/VERI-Wild>`_
+
+    Train dataset statistics:
+        - identities: 30671.
+        - images: 277797.
+    """
+    dataset_dir = "VERI-Wild"
+    dataset_name = "veriwild"
+
+    def __init__(self, root='datasets', query_list='', gallery_list='', **kwargs):
+        self.dataset_dir = osp.join(root, self.dataset_dir)
+
+        self.image_dir = osp.join(self.dataset_dir, 'images')
+        self.train_list = osp.join(self.dataset_dir, 'train_test_split/train_list.txt')
+        self.vehicle_info = osp.join(self.dataset_dir, 'train_test_split/vehicle_info.txt')
+        if query_list and gallery_list:
+            self.query_list = query_list
+            self.gallery_list = gallery_list
+        else:
+            self.query_list = osp.join(self.dataset_dir, 'train_test_split/test_10000_query.txt')
+            self.gallery_list = osp.join(self.dataset_dir, 'train_test_split/test_10000.txt')
+
+        required_files = [
+            self.image_dir,
+            self.train_list,
+            self.query_list,
+            self.gallery_list,
+            self.vehicle_info,
+        ]
+        self.check_before_run(required_files)
+
+        self.imgid2vid, self.imgid2camid, self.imgid2imgpath = self.process_vehicle(self.vehicle_info)
+
+        train = self.process_dir(self.train_list)
+        query = self.process_dir(self.query_list, is_train=False)
+        gallery = self.process_dir(self.gallery_list, is_train=False)
+
+        super(VeRiWild, self).__init__(train, query, gallery, **kwargs)
+
+    def process_dir(self, img_list, is_train=True):
+        img_list_lines = open(img_list, 'r').readlines()
+
+        dataset = []
+        for idx, line in enumerate(img_list_lines):
+            line = line.strip()
+            vid = int(line.split('/')[0])
+            imgid = line.split('/')[1]
+            if is_train:
+                vid = self.dataset_name + "_" + str(vid)
+            dataset.append((self.imgid2imgpath[imgid], vid, int(self.imgid2camid[imgid])))
+
+        assert len(dataset) == len(img_list_lines)
+        return dataset
+
+    def process_vehicle(self, vehicle_info):
+        imgid2vid = {}
+        imgid2camid = {}
+        imgid2imgpath = {}
+        vehicle_info_lines = open(vehicle_info, 'r').readlines()
+
+        for idx, line in enumerate(vehicle_info_lines[1:]):
+            vid = line.strip().split('/')[0]
+            imgid = line.strip().split(';')[0].split('/')[1]
+            camid = line.strip().split(';')[1]
+            img_path = osp.join(self.image_dir, vid, imgid + '.jpg')
+            imgid2vid[imgid] = vid
+            imgid2camid[imgid] = camid
+            imgid2imgpath[imgid] = img_path
+
+        assert len(imgid2vid) == len(vehicle_info_lines) - 1
+        return imgid2vid, imgid2camid, imgid2imgpath
+
+
+@DATASET_REGISTRY.register()
+class SmallVeRiWild(VeRiWild):
+    """VeRi-Wild.
+    Small test dataset statistics:
+        - identities: 3000.
+        - images: 41861.
+    """
+
+    def __init__(self, root='datasets', **kwargs):
+        self.dataset_dir = osp.join(root, self.dataset_dir)
+        self.query_list = osp.join(self.dataset_dir, 'train_test_split/test_3000_query.txt')
+        self.gallery_list = osp.join(self.dataset_dir, 'train_test_split/test_3000.txt')
+
+        super(SmallVeRiWild, self).__init__(root, self.query_list, self.gallery_list, **kwargs)
+
+
+@DATASET_REGISTRY.register()
+class MediumVeRiWild(VeRiWild):
+    """VeRi-Wild.
+    Medium test dataset statistics:
+        - identities: 5000.
+        - images: 69389.
+    """
+
+    def __init__(self, root='datasets', **kwargs):
+        self.dataset_dir = osp.join(root, self.dataset_dir)
+        self.query_list = osp.join(self.dataset_dir, 'train_test_split/test_5000_query.txt')
+        self.gallery_list = osp.join(self.dataset_dir, 'train_test_split/test_5000.txt')
+
+        super(MediumVeRiWild, self).__init__(root, self.query_list, self.gallery_list, **kwargs)
+
+
+@DATASET_REGISTRY.register()
+class LargeVeRiWild(VeRiWild):
+    """VeRi-Wild.
+    Large test dataset statistics:
+        - identities: 10000.
+        - images: 138517.
+    """
+
+    def __init__(self, root='datasets', **kwargs):
+        self.dataset_dir = osp.join(root, self.dataset_dir)
+        self.query_list = osp.join(self.dataset_dir, 'train_test_split/test_10000_query.txt')
+        self.gallery_list = osp.join(self.dataset_dir, 'train_test_split/test_10000.txt')
+
+        super(LargeVeRiWild, self).__init__(root, self.query_list, self.gallery_list, **kwargs)
diff --git a/fastreid/data/transforms/build.py b/fastreid/data/transforms/build.py
index 35e7a9b..743c8a2 100644
--- a/fastreid/data/transforms/build.py
+++ b/fastreid/data/transforms/build.py
@@ -33,7 +33,12 @@ def build_transforms(cfg, is_train=True):
         padding_mode = cfg.INPUT.PADDING_MODE
 
         # color jitter
-        do_cj = cfg.INPUT.DO_CJ
+        do_cj = cfg.INPUT.CJ.ENABLED
+        cj_prob = cfg.INPUT.CJ.PROB
+        cj_brightness = cfg.INPUT.CJ.BRIGHTNESS
+        cj_contrast = cfg.INPUT.CJ.CONTRAST
+        cj_saturation = cfg.INPUT.CJ.SATURATION
+        cj_hue = cfg.INPUT.CJ.HUE
 
         # random erasing
         do_rea = cfg.INPUT.REA.ENABLED
@@ -52,7 +57,7 @@ def build_transforms(cfg, is_train=True):
             res.extend([T.Pad(padding, padding_mode=padding_mode),
                         T.RandomCrop(size_train)])
         if do_cj:
-            res.append(T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0))
+            res.append(T.RandomApply([T.ColorJitter(cj_brightness, cj_contrast, cj_saturation, cj_hue)], p=cj_prob))
         if do_augmix:
             res.append(AugMix())
         if do_rea:
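
A minimal sketch of how the color-jitter branch composes now (probability and
jitter strengths are the defaults added in fastreid/config/defaults.py):

    import torchvision.transforms as T

    # Wrapped in RandomApply so the whole jitter is applied with probability p,
    # then appended to the training transform list alongside crop/flip/erasing.
    cj = T.RandomApply(
        [T.ColorJitter(0.15, 0.15, 0.1, 0.1)],  # brightness, contrast, saturation, hue
        p=0.8,
    )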
diff --git a/fastreid/data/transforms/transforms.py b/fastreid/data/transforms/transforms.py
index b66a96f..0de5562 100644
--- a/fastreid/data/transforms/transforms.py
+++ b/fastreid/data/transforms/transforms.py
@@ -4,7 +4,7 @@
 @contact: sherlockliao01@gmail.com
 """
 
-__all__ = ['ToTensor', 'RandomErasing', 'RandomPatch', 'AugMix', ]
+__all__ = ['ToTensor', 'RandomErasing', 'RandomPatch', 'AugMix',]
 
 import math
 import random
@@ -202,110 +202,3 @@ class AugMix(object):
 
         mixed = (1 - m) * image + m * mix
         return mixed
-
-# class ColorJitter(object):
-#     """docstring for do_color"""
-#
-#     def __init__(self, probability=0.5):
-#         self.probability = probability
-#
-#     def do_brightness_shift(self, image, alpha=0.125):
-#         image = image.astype(np.float32)
-#         image = image + alpha * 255
-#         image = np.clip(image, 0, 255).astype(np.uint8)
-#         return image
-#
-#     def do_brightness_multiply(self, image, alpha=1):
-#         image = image.astype(np.float32)
-#         image = alpha * image
-#         image = np.clip(image, 0, 255).astype(np.uint8)
-#         return image
-#
-#     def do_contrast(self, image, alpha=1.0):
-#         image = image.astype(np.float32)
-#         gray = image * np.array([[[0.114, 0.587, 0.299]]])  # rgb to gray (YCbCr)
-#         gray = (3.0 * (1.0 - alpha) / gray.size) * np.sum(gray)
-#         image = alpha * image + gray
-#         image = np.clip(image, 0, 255).astype(np.uint8)
-#         return image
-#
-#     # https://www.pyimagesearch.com/2015/10/05/opencv-gamma-correction/
-#     def do_gamma(self, image, gamma=1.0):
-#         table = np.array([((i / 255.0) ** (1.0 / gamma)) * 255
-#                           for i in np.arange(0, 256)]).astype("uint8")
-#
-#         return cv2.LUT(image, table)  # apply gamma correction using the lookup table
-#
-#     def do_clahe(self, image, clip=2, grid=16):
-#         grid = int(grid)
-#
-#         lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)
-#         gray, a, b = cv2.split(lab)
-#         gray = cv2.createCLAHE(clipLimit=clip, tileGridSize=(grid, grid)).apply(gray)
-#         lab = cv2.merge((gray, a, b))
-#         image = cv2.cvtColor(lab, cv2.COLOR_LAB2BGR)
-#
-#         return image
-#
-#     def __call__(self, image):
-#         if random.uniform(0, 1) > self.probability:
-#             return image
-#
-#         image = np.asarray(image, dtype=np.uint8).copy()
-#         index = random.randint(0, 4)
-#         if index == 0:
-#             image = self.do_brightness_shift(image, 0.1)
-#         elif index == 1:
-#             image = self.do_gamma(image, 1)
-#         elif index == 2:
-#             image = self.do_clahe(image)
-#         elif index == 3:
-#             image = self.do_brightness_multiply(image)
-#         elif index == 4:
-#             image = self.do_contrast(image)
-#         return image
-
-
-# class random_shift(object):
-#     """docstring for do_color"""
-#
-#     def __init__(self, probability=0.5):
-#         self.probability = probability
-#
-#     def __call__(self, image):
-#         if random.uniform(0, 1) > self.probability:
-#             return image
-#
-#         width, height, d = image.shape
-#         zero_image = np.zeros_like(image)
-#         w = random.randint(0, 20) - 10
-#         h = random.randint(0, 30) - 15
-#         zero_image[max(0, w): min(w + width, width), max(h, 0): min(h + height, height)] = \
-#             image[max(0, -w): min(-w + width, width), max(-h, 0): min(-h + height, height)]
-#         image = zero_image.copy()
-#         return image
-#
-#
-# class random_scale(object):
-#     """docstring for do_color"""
-#
-#     def __init__(self, probability=0.5):
-#         self.probability = probability
-#
-#     def __call__(self, image):
-#         if random.uniform(0, 1) > self.probability:
-#             return image
-#
-#         scale = random.random() * 0.1 + 0.9
-#         assert 0.9 <= scale <= 1
-#         width, height, d = image.shape
-#         zero_image = np.zeros_like(image)
-#         new_width = round(width * scale)
-#         new_height = round(height * scale)
-#         image = cv2.resize(image, (new_height, new_width))
-#         start_w = random.randint(0, width - new_width)
-#         start_h = random.randint(0, height - new_height)
-#         zero_image[start_w: start_w + new_width,
-#         start_h:start_h + new_height] = image
-#         image = zero_image.copy()
-#         return image
diff --git a/fastreid/engine/launch.py b/fastreid/engine/launch.py
index 328db59..a8f5c9b 100644
--- a/fastreid/engine/launch.py
+++ b/fastreid/engine/launch.py
@@ -1,7 +1,7 @@
 # encoding: utf-8
 """
 @author:  xingyu liao
-@contact: liaoxingyu5@jd.com
+@contact: sherlockliao01@gmail.com
 """
 
 # based on:
diff --git a/fastreid/evaluation/query_expansion.py b/fastreid/evaluation/query_expansion.py
index 873aed6..637c097 100644
--- a/fastreid/evaluation/query_expansion.py
+++ b/fastreid/evaluation/query_expansion.py
@@ -1,7 +1,7 @@
 # encoding: utf-8
 """
 @author:  xingyu liao
-@contact: liaoxingyu5@jd.com
+@contact: sherlockliao01@gmail.com
 """
 
 # based on
diff --git a/fastreid/evaluation/reid_evaluation.py b/fastreid/evaluation/reid_evaluation.py
index d83e399..a46913e 100644
--- a/fastreid/evaluation/reid_evaluation.py
+++ b/fastreid/evaluation/reid_evaluation.py
@@ -10,7 +10,6 @@ from collections import OrderedDict
 import numpy as np
 import torch
 import torch.nn.functional as F
-from tabulate import tabulate
 
 from .evaluator import DatasetEvaluator
 from .query_expansion import aqe
@@ -101,6 +100,6 @@ class ReidEvaluator(DatasetEvaluator):
         tprs = evaluate_roc(dist, query_pids, gallery_pids, query_camids, gallery_camids)
         fprs = [1e-4, 1e-3, 1e-2]
         for i in range(len(fprs)):
-            self._results["TPR@FPR={}".format(fprs[i])] = tprs[i]
+            self._results["TPR@FPR={:.0e}".format(fprs[i])] = tprs[i]
 
         return copy.deepcopy(self._results)
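
The ROC metric keys are now formatted in scientific notation; a quick
illustration of the resulting dictionary keys:

    "TPR@FPR={}".format(1e-4)      # -> 'TPR@FPR=0.0001'
    "TPR@FPR={:.0e}".format(1e-4)  # -> 'TPR@FPR=1e-04'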
diff --git a/fastreid/layers/__init__.py b/fastreid/layers/__init__.py
index 606b3fc..dccaccd 100644
--- a/fastreid/layers/__init__.py
+++ b/fastreid/layers/__init__.py
@@ -6,9 +6,10 @@
 
 from .activation import *
 from .arc_softmax import ArcSoftmax
+from .circle_softmax import CircleSoftmax
+from .am_softmax import AMSoftmax
 from .batch_drop import BatchDrop
 from .batch_norm import *
-from .circle_softmax import CircleSoftmax
 from .context_block import ContextBlock
 from .frn import FRN, TLU
 from .non_local import Non_local
diff --git a/fastreid/layers/activation.py b/fastreid/layers/activation.py
index 3e0aea6..dafbf60 100644
--- a/fastreid/layers/activation.py
+++ b/fastreid/layers/activation.py
@@ -1,7 +1,7 @@
 # encoding: utf-8
 """
 @author:  xingyu liao
-@contact: liaoxingyu5@jd.com
+@contact: sherlockliao01@gmail.com
 """
 
 import math
diff --git a/fastreid/layers/am_softmax.py b/fastreid/layers/am_softmax.py
new file mode 100644
index 0000000..3f5a59e
--- /dev/null
+++ b/fastreid/layers/am_softmax.py
@@ -0,0 +1,43 @@
+# encoding: utf-8
+"""
+@author:  xingyu liao
+@contact: sherlockliao01@gmail.com
+"""
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+from torch.nn import Parameter
+
+
+class AMSoftmax(nn.Module):
+    r"""Implement of large margin cosine distance:
+    Args:
+        in_feat: size of each input sample
+        num_classes: size of each output sample
+    """
+
+    def __init__(self, cfg, in_feat, num_classes):
+        super().__init__()
+        self.in_features = in_feat
+        self._num_classes = num_classes
+        self._s = cfg.MODEL.HEADS.SCALE
+        self._m = cfg.MODEL.HEADS.MARGIN
+        self.weight = Parameter(torch.Tensor(num_classes, in_feat))
+        nn.init.xavier_uniform_(self.weight)
+
+    def forward(self, features, targets):
+        # --------------------------- cos(theta) & phi(theta) ---------------------------
+        cosine = F.linear(F.normalize(features), F.normalize(self.weight))
+        phi = cosine - self._m
+        # --------------------------- convert label to one-hot ---------------------------
+        targets = F.one_hot(targets, num_classes=self._num_classes)
+        output = (targets * phi) + ((1.0 - targets) * cosine)
+        output *= self._s
+
+        return output
+
+    def extra_repr(self):
+        return 'in_features={}, num_classes={}, scale={}, margin={}'.format(
+            self.in_features, self._num_classes, self._s, self._m
+        )
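
A hedged usage sketch for the new AMSoftmax head (the cfg stub only fills the
two fields the layer reads; embedding size and class count are made up):

    import torch
    from yacs.config import CfgNode as CN
    from fastreid.layers import AMSoftmax  # exported via fastreid/layers/__init__.py

    cfg = CN()
    cfg.MODEL = CN()
    cfg.MODEL.HEADS = CN()
    cfg.MODEL.HEADS.SCALE = 64      # same field as SCALE in Base-Strongerbaseline.yml
    cfg.MODEL.HEADS.MARGIN = 0.35   # same field as MARGIN in Base-Strongerbaseline.yml

    head = AMSoftmax(cfg, in_feat=512, num_classes=751)
    features = torch.randn(8, 512)             # batch of embeddings
    targets = torch.randint(0, 751, (8,))      # integer identity labels
    logits = head(features, targets)           # (8, 751) margin-scaled logits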
diff --git a/fastreid/layers/arc_softmax.py b/fastreid/layers/arc_softmax.py
index 32a70f7..455309c 100644
--- a/fastreid/layers/arc_softmax.py
+++ b/fastreid/layers/arc_softmax.py
@@ -26,6 +26,7 @@ class ArcSoftmax(nn.Module):
         self.mm = math.sin(math.pi - self._m) * self._m
 
         self.weight = Parameter(torch.Tensor(num_classes, in_feat))
+        nn.init.xavier_uniform_(self.weight)
         self.register_buffer('t', torch.zeros(1))
 
     def forward(self, features, targets):
diff --git a/fastreid/layers/circle_softmax.py b/fastreid/layers/circle_softmax.py
index 2d596b6..e4e392d 100644
--- a/fastreid/layers/circle_softmax.py
+++ b/fastreid/layers/circle_softmax.py
@@ -4,6 +4,8 @@
 @contact: sherlockliao01@gmail.com
 """
 
+import math
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -19,11 +21,12 @@ class CircleSoftmax(nn.Module):
         self._m = cfg.MODEL.HEADS.MARGIN
 
         self.weight = Parameter(torch.Tensor(num_classes, in_feat))
+        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
 
     def forward(self, features, targets):
         sim_mat = F.linear(F.normalize(features), F.normalize(self.weight))
-        alpha_p = F.relu(-sim_mat.detach() + 1 + self._m)
-        alpha_n = F.relu(sim_mat.detach() + self._m)
+        alpha_p = torch.clamp_min(-sim_mat.detach() + 1 + self._m, min=0.)
+        alpha_n = torch.clamp_min(sim_mat.detach() + self._m, min=0.)
         delta_p = 1 - self._m
         delta_n = self._m
 
diff --git a/fastreid/layers/splat.py b/fastreid/layers/splat.py
index 5559eab..b7451c5 100644
--- a/fastreid/layers/splat.py
+++ b/fastreid/layers/splat.py
@@ -1,9 +1,4 @@
 # encoding: utf-8
-"""
-@author:  xingyu liao
-@contact: liaoxingyu5@jd.com
-"""
-
 import torch
 import torch.nn.functional as F
 from torch import nn
diff --git a/fastreid/modeling/backbones/__init__.py b/fastreid/modeling/backbones/__init__.py
index f964933..d3ee3be 100644
--- a/fastreid/modeling/backbones/__init__.py
+++ b/fastreid/modeling/backbones/__init__.py
@@ -10,4 +10,4 @@ from .resnet import build_resnet_backbone
 from .osnet import build_osnet_backbone
 from .resnest import build_resnest_backbone
 from .resnext import build_resnext_backbone
-from .regnet import build_regnet_backbone
\ No newline at end of file
+from .regnet import build_regnet_backbone
diff --git a/fastreid/modeling/backbones/osnet.py b/fastreid/modeling/backbones/osnet.py
index d71615a..b9da1fa 100644
--- a/fastreid/modeling/backbones/osnet.py
+++ b/fastreid/modeling/backbones/osnet.py
@@ -1,7 +1,7 @@
 # encoding: utf-8
 """
 @author:  xingyu liao
-@contact: liaoxingyu5@jd.com
+@contact: sherlockliao01@gmail.com
 """
 
 # based on:
@@ -9,7 +9,8 @@
 
 import torch
 from torch import nn
-from torch.nn import functional as F
+
+from fastreid.layers import get_norm
 from .build import BACKBONE_REGISTRY
 
 model_urls = {
@@ -37,6 +38,8 @@ class ConvLayer(nn.Module):
             in_channels,
             out_channels,
             kernel_size,
+            bn_norm,
+            num_splits,
             stride=1,
             padding=0,
             groups=1,
@@ -55,7 +58,7 @@ class ConvLayer(nn.Module):
         if IN:
             self.bn = nn.InstanceNorm2d(out_channels, affine=True)
         else:
-            self.bn = nn.BatchNorm2d(out_channels)
+            self.bn = get_norm(bn_norm, out_channels, num_splits)
         self.relu = nn.ReLU(inplace=True)
 
     def forward(self, x):
@@ -68,7 +71,7 @@ class ConvLayer(nn.Module):
 class Conv1x1(nn.Module):
     """1x1 convolution + bn + relu."""
 
-    def __init__(self, in_channels, out_channels, stride=1, groups=1):
+    def __init__(self, in_channels, out_channels, bn_norm, num_splits, stride=1, groups=1):
         super(Conv1x1, self).__init__()
         self.conv = nn.Conv2d(
             in_channels,
@@ -79,7 +82,7 @@ class Conv1x1(nn.Module):
             bias=False,
             groups=groups
         )
-        self.bn = nn.BatchNorm2d(out_channels)
+        self.bn = get_norm(bn_norm, out_channels, num_splits)
         self.relu = nn.ReLU(inplace=True)
 
     def forward(self, x):
@@ -92,12 +95,12 @@ class Conv1x1(nn.Module):
 class Conv1x1Linear(nn.Module):
     """1x1 convolution + bn (w/o non-linearity)."""
 
-    def __init__(self, in_channels, out_channels, stride=1):
+    def __init__(self, in_channels, out_channels, bn_norm, num_splits, stride=1):
         super(Conv1x1Linear, self).__init__()
         self.conv = nn.Conv2d(
             in_channels, out_channels, 1, stride=stride, padding=0, bias=False
         )
-        self.bn = nn.BatchNorm2d(out_channels)
+        self.bn = get_norm(bn_norm, out_channels, num_splits)
 
     def forward(self, x):
         x = self.conv(x)
@@ -108,7 +111,7 @@ class Conv1x1Linear(nn.Module):
 class Conv3x3(nn.Module):
     """3x3 convolution + bn + relu."""
 
-    def __init__(self, in_channels, out_channels, stride=1, groups=1):
+    def __init__(self, in_channels, out_channels, bn_norm, num_splits, stride=1, groups=1):
         super(Conv3x3, self).__init__()
         self.conv = nn.Conv2d(
             in_channels,
@@ -119,7 +122,7 @@ class Conv3x3(nn.Module):
             bias=False,
             groups=groups
         )
-        self.bn = nn.BatchNorm2d(out_channels)
+        self.bn = get_norm(bn_norm, out_channels, num_splits)
         self.relu = nn.ReLU(inplace=True)
 
     def forward(self, x):
@@ -134,7 +137,7 @@ class LightConv3x3(nn.Module):
     1x1 (linear) + dw 3x3 (nonlinear).
     """
 
-    def __init__(self, in_channels, out_channels):
+    def __init__(self, in_channels, out_channels, bn_norm, num_splits):
         super(LightConv3x3, self).__init__()
         self.conv1 = nn.Conv2d(
             in_channels, out_channels, 1, stride=1, padding=0, bias=False
@@ -148,7 +151,7 @@ class LightConv3x3(nn.Module):
             bias=False,
             groups=out_channels
         )
-        self.bn = nn.BatchNorm2d(out_channels)
+        self.bn = get_norm(bn_norm, out_channels, num_splits)
         self.relu = nn.ReLU(inplace=True)
 
     def forward(self, x):
@@ -197,9 +200,12 @@ class ChannelGate(nn.Module):
             bias=True,
             padding=0
         )
-        if gate_activation == 'sigmoid':  self.gate_activation = nn.Sigmoid()
-        elif gate_activation == 'relu':   self.gate_activation = nn.ReLU(inplace=True)
-        elif gate_activation == 'linear': self.gate_activation = nn.Identity()
+        if gate_activation == 'sigmoid':
+            self.gate_activation = nn.Sigmoid()
+        elif gate_activation == 'relu':
+            self.gate_activation = nn.ReLU(inplace=True)
+        elif gate_activation == 'linear':
+            self.gate_activation = nn.Identity()
         else:
             raise RuntimeError(
                 "Unknown gate activation: {}".format(gate_activation)
@@ -224,34 +230,36 @@ class OSBlock(nn.Module):
             self,
             in_channels,
             out_channels,
+            bn_norm,
+            num_splits,
             IN=False,
             bottleneck_reduction=4,
             **kwargs
     ):
         super(OSBlock, self).__init__()
         mid_channels = out_channels // bottleneck_reduction
-        self.conv1 = Conv1x1(in_channels, mid_channels)
-        self.conv2a = LightConv3x3(mid_channels, mid_channels)
+        self.conv1 = Conv1x1(in_channels, mid_channels, bn_norm, num_splits)
+        self.conv2a = LightConv3x3(mid_channels, mid_channels, bn_norm, num_splits)
         self.conv2b = nn.Sequential(
-            LightConv3x3(mid_channels, mid_channels),
-            LightConv3x3(mid_channels, mid_channels),
+            LightConv3x3(mid_channels, mid_channels, bn_norm, num_splits),
+            LightConv3x3(mid_channels, mid_channels, bn_norm, num_splits),
         )
         self.conv2c = nn.Sequential(
-            LightConv3x3(mid_channels, mid_channels),
-            LightConv3x3(mid_channels, mid_channels),
-            LightConv3x3(mid_channels, mid_channels),
+            LightConv3x3(mid_channels, mid_channels, bn_norm, num_splits),
+            LightConv3x3(mid_channels, mid_channels, bn_norm, num_splits),
+            LightConv3x3(mid_channels, mid_channels, bn_norm, num_splits),
         )
         self.conv2d = nn.Sequential(
-            LightConv3x3(mid_channels, mid_channels),
-            LightConv3x3(mid_channels, mid_channels),
-            LightConv3x3(mid_channels, mid_channels),
-            LightConv3x3(mid_channels, mid_channels),
+            LightConv3x3(mid_channels, mid_channels, bn_norm, num_splits),
+            LightConv3x3(mid_channels, mid_channels, bn_norm, num_splits),
+            LightConv3x3(mid_channels, mid_channels, bn_norm, num_splits),
+            LightConv3x3(mid_channels, mid_channels, bn_norm, num_splits),
         )
         self.gate = ChannelGate(mid_channels)
-        self.conv3 = Conv1x1Linear(mid_channels, out_channels)
+        self.conv3 = Conv1x1Linear(mid_channels, out_channels, bn_norm, num_splits)
         self.downsample = None
         if in_channels != out_channels:
-            self.downsample = Conv1x1Linear(in_channels, out_channels)
+            self.downsample = Conv1x1Linear(in_channels, out_channels, bn_norm, num_splits)
         self.IN = None
         if IN: self.IN = nn.InstanceNorm2d(out_channels, affine=True)
         self.relu = nn.ReLU(True)
@@ -290,6 +298,8 @@ class OSNet(nn.Module):
             blocks,
             layers,
             channels,
+            bn_norm,
+            num_splits,
             IN=False,
             **kwargs
     ):
@@ -299,13 +309,15 @@ class OSNet(nn.Module):
         assert num_blocks == len(channels) - 1
 
         # convolutional backbone
-        self.conv1 = ConvLayer(3, channels[0], 7, stride=2, padding=3, IN=IN)
+        self.conv1 = ConvLayer(3, channels[0], 7, bn_norm, num_splits, stride=2, padding=3, IN=IN)
         self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
         self.conv2 = self._make_layer(
             blocks[0],
             layers[0],
             channels[0],
             channels[1],
+            bn_norm,
+            num_splits,
             reduce_spatial_size=True,
             IN=IN
         )
@@ -314,6 +326,8 @@ class OSNet(nn.Module):
             layers[1],
             channels[1],
             channels[2],
+            bn_norm,
+            num_splits,
             reduce_spatial_size=True
         )
         self.conv4 = self._make_layer(
@@ -321,9 +335,11 @@ class OSNet(nn.Module):
             layers[2],
             channels[2],
             channels[3],
+            bn_norm,
+            num_splits,
             reduce_spatial_size=False
         )
-        self.conv5 = Conv1x1(channels[3], channels[3])
+        self.conv5 = Conv1x1(channels[3], channels[3], bn_norm, num_splits)
 
         self._init_params()
 
@@ -333,19 +349,21 @@ class OSNet(nn.Module):
             layer,
             in_channels,
             out_channels,
+            bn_norm,
+            num_splits,
             reduce_spatial_size,
             IN=False
     ):
         layers = []
 
-        layers.append(block(in_channels, out_channels, IN=IN))
+        layers.append(block(in_channels, out_channels, bn_norm, num_splits, IN=IN))
         for i in range(1, layer):
-            layers.append(block(out_channels, out_channels, IN=IN))
+            layers.append(block(out_channels, out_channels, bn_norm, num_splits, IN=IN))
 
         if reduce_spatial_size:
             layers.append(
                 nn.Sequential(
-                    Conv1x1(out_channels, out_channels),
+                    Conv1x1(out_channels, out_channels, bn_norm, num_splits),
                     nn.AvgPool2d(2, stride=2),
                 )
             )
@@ -477,11 +495,19 @@ def build_osnet_backbone(cfg):
     # fmt: off
     pretrain = cfg.MODEL.BACKBONE.PRETRAIN
     with_ibn = cfg.MODEL.BACKBONE.WITH_IBN
+    bn_norm = cfg.MODEL.BACKBONE.NORM
+    num_splits = cfg.MODEL.BACKBONE.NORM_SPLIT
+    depth = cfg.MODEL.BACKBONE.DEPTH
 
     num_blocks_per_stage = [2, 2, 2]
-    num_channels_per_stage = [64, 256, 384, 512]
-    model = OSNet([OSBlock, OSBlock, OSBlock], num_blocks_per_stage, num_channels_per_stage, with_ibn)
-    pretrain_key = 'osnet_ibn_x1_0' if with_ibn else 'osnet_x1_0'
+    num_channels_per_stage = {"x1_0": [64, 256, 384, 512], "x0_75": [48, 192, 288, 384], "x0_5": [32, 128, 192, 256],
+                              "x0_25": [16, 64, 96, 128]}[depth]
+    model = OSNet([OSBlock, OSBlock, OSBlock], num_blocks_per_stage, num_channels_per_stage,
+                  bn_norm, num_splits, IN=with_ibn)
+
     if pretrain:
+        if with_ibn: pretrain_key = "osnet_ibn_" + depth
+        else:        pretrain_key = "osnet_" + depth
+
         init_pretrained_weights(model, pretrain_key)
     return model
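The depth-keyed channel table and pretrain-key naming above replace the old hard-coded `osnet_x1_0` setup. Below is a minimal sketch of that selection logic; the `depth` and `with_ibn` values are illustrative, not defaults.

```python
# Sketch only: mirrors the depth -> channels table and the pretrain-key naming
# in build_osnet_backbone; "x0_5" / with_ibn=False are example values.
depth = "x0_5"                # one of "x1_0", "x0_75", "x0_5", "x0_25"
with_ibn = False

channels = {"x1_0": [64, 256, 384, 512], "x0_75": [48, 192, 288, 384],
            "x0_5": [32, 128, 192, 256], "x0_25": [16, 64, 96, 128]}[depth]
pretrain_key = ("osnet_ibn_" if with_ibn else "osnet_") + depth   # -> "osnet_x0_5"
print(channels, pretrain_key)
```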
diff --git a/fastreid/modeling/backbones/regnet/regnet.py b/fastreid/modeling/backbones/regnet/regnet.py
index 782731b..5a55097 100644
--- a/fastreid/modeling/backbones/regnet/regnet.py
+++ b/fastreid/modeling/backbones/regnet/regnet.py
@@ -10,7 +10,18 @@ from ..build import BACKBONE_REGISTRY
 from .config import regnet_cfg
 
 logger = logging.getLogger(__name__)
-
+model_urls = {
+    '800x': 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/160906036/RegNetX-800MF_dds_8gpu.pyth',
+    '800y': 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/160906567/RegNetY-800MF_dds_8gpu.pyth',
+    '1600x': 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/160990626/RegNetX-1.6GF_dds_8gpu.pyth',
+    '1600y': 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/160906681/RegNetY-1.6GF_dds_8gpu.pyth',
+    '3200x': 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/160906139/RegNetX-3.2GF_dds_8gpu.pyth',
+    '3200y': 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/160906834/RegNetY-3.2GF_dds_8gpu.pyth',
+    '4000x': 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/160906383/RegNetX-4.0GF_dds_8gpu.pyth',
+    '4000y': 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/160906838/RegNetY-4.0GF_dds_8gpu.pyth',
+    '6400x': 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/161116590/RegNetX-6.4GF_dds_8gpu.pyth',
+    '6400y': 'https://dl.fbaipublicfiles.com/pycls/dds_baselines/160907112/RegNetY-6.4GF_dds_8gpu.pyth',
+}
 
 def init_weights(m):
     """Performs ResNet-style weight initialization."""
@@ -464,14 +475,60 @@ class RegNet(AnyNet):
         super(RegNet, self).__init__(**kwargs)
 
 
+def init_pretrained_weights(key):
+    """Initializes model with pretrained weights.
+
+    Layers that don't match with pretrained layers in name or size are kept unchanged.
+    """
+    import os
+    import errno
+    import gdown
+
+    def _get_torch_home():
+        ENV_TORCH_HOME = 'TORCH_HOME'
+        ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
+        DEFAULT_CACHE_DIR = '~/.cache'
+        torch_home = os.path.expanduser(
+            os.getenv(
+                ENV_TORCH_HOME,
+                os.path.join(
+                    os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'torch'
+                )
+            )
+        )
+        return torch_home
+
+    torch_home = _get_torch_home()
+    model_dir = os.path.join(torch_home, 'checkpoints')
+    try:
+        os.makedirs(model_dir)
+    except OSError as e:
+        if e.errno == errno.EEXIST:
+            # Directory already exists, ignore.
+            pass
+        else:
+            # Unexpected OSError, re-raise.
+            raise
+
+    filename = model_urls[key].split('/')[-1]
+
+    cached_file = os.path.join(model_dir, filename)
+
+    if not os.path.exists(cached_file):
+        gdown.download(model_urls[key], cached_file, quiet=False)
+
+    logger.info(f"Loading pretrained model from {cached_file}")
+    state_dict = torch.load(cached_file)['model_state']
+
+    return state_dict
+
 @BACKBONE_REGISTRY.register()
 def build_regnet_backbone(cfg):
     # fmt: off
     pretrain = cfg.MODEL.BACKBONE.PRETRAIN
-    pretrain_path = cfg.MODEL.BACKBONE.PRETRAIN_PATH
     last_stride = cfg.MODEL.BACKBONE.LAST_STRIDE
     bn_norm = cfg.MODEL.BACKBONE.NORM
-    volume = cfg.MODEL.BACKBONE.VOLUME
+    depth = cfg.MODEL.BACKBONE.DEPTH
 
     cfg_files = {
         '800x': 'fastreid/modeling/backbones/regnet/regnetx/RegNetX-800MF_dds_8gpu.yaml',
@@ -480,21 +537,18 @@ def build_regnet_backbone(cfg):
         '1600y': 'fastreid/modeling/backbones/regnet/regnety/RegNetY-1.6GF_dds_8gpu.yaml',
         '3200x': 'fastreid/modeling/backbones/regnet/regnetx/RegNetX-3.2GF_dds_8gpu.yaml',
         '3200y': 'fastreid/modeling/backbones/regnet/regnety/RegNetY-3.2GF_dds_8gpu.yaml',
+        '4000x': 'fastreid/modeling/backbones/regnet/regnetx/RegNetX-4.0GF_dds_8gpu.yaml',
+        '4000y': 'fastreid/modeling/backbones/regnet/regnety/RegNetY-4.0GF_dds_8gpu.yaml',
         '6400x': 'fastreid/modeling/backbones/regnet/regnetx/RegNetX-6.4GF_dds_8gpu.yaml',
         '6400y': 'fastreid/modeling/backbones/regnet/regnety/RegNetY-6.4GF_dds_8gpu.yaml',
-    }[volume]
+    }[depth]
 
     regnet_cfg.merge_from_file(cfg_files)
     model = RegNet(last_stride, bn_norm)
 
     if pretrain:
-        try:
-            state_dict = torch.load(pretrain_path, map_location=torch.device('cpu'))['model_state']
-        except FileNotFoundError as e:
-            logger.info(f'{pretrain_path} is not found! Please check this path.')
-            raise e
-
-        logger.info(f"Loading pretrained model from {pretrain_path}")
+        key = depth
+        state_dict = init_pretrained_weights(key)
 
         incompatible = model.load_state_dict(state_dict, strict=False)
         if incompatible.missing_keys:
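The new `init_pretrained_weights` helper resolves a local cache directory before fetching a checkpoint with `gdown`. A minimal sketch of the path resolution it performs, assuming no `TORCH_HOME` override is set:

```python
import os

# Sketch of the cache-path convention used by init_pretrained_weights:
# $TORCH_HOME (default $XDG_CACHE_HOME/torch or ~/.cache/torch)/checkpoints/<url basename>
torch_home = os.path.expanduser(
    os.getenv("TORCH_HOME",
              os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch")))
url = "https://dl.fbaipublicfiles.com/pycls/dds_baselines/160906036/RegNetX-800MF_dds_8gpu.pyth"
cached_file = os.path.join(torch_home, "checkpoints", url.split("/")[-1])
print(cached_file)  # gdown.download(...) runs only when this file does not exist yet
```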
diff --git a/fastreid/modeling/backbones/resnet.py b/fastreid/modeling/backbones/resnet.py
index 47d7846..1cae38f 100644
--- a/fastreid/modeling/backbones/resnet.py
+++ b/fastreid/modeling/backbones/resnet.py
@@ -9,7 +9,6 @@ import math
 
 import torch
 from torch import nn
-from torch.utils import model_zoo
 
 from fastreid.layers import (
     IBN,
@@ -22,11 +21,15 @@ from .build import BACKBONE_REGISTRY
 
 logger = logging.getLogger(__name__)
 model_urls = {
-    18: 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
-    34: 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
-    50: 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
-    101: 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
-    152: 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
+    '18x': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
+    '34x': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
+    '50x': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
+    '101x': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
+    'ibn_18x': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet18_ibn_a-2f571257.pth',
+    'ibn_34x': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet34_ibn_a-94bc1577.pth',
+    'ibn_50x': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet50_ibn_a-d9d0bb7b.pth',
+    'ibn_101x': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnet101_ibn_a-59ea0ac6.pth',
+    'se_ibn_101x': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/se_resnet101_ibn_a-fabed4e2.pth',
 }
 
 
@@ -37,7 +40,10 @@ class BasicBlock(nn.Module):
                  stride=1, downsample=None, reduction=16):
         super(BasicBlock, self).__init__()
         self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
-        self.bn1 = get_norm(bn_norm, planes, num_splits)
+        if with_ibn:
+            self.bn1 = IBN(planes, bn_norm, num_splits)
+        else:
+            self.bn1 = get_norm(bn_norm, planes, num_splits)
         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
         self.bn2 = get_norm(bn_norm, planes, num_splits)
         self.relu = nn.ReLU(inplace=True)
@@ -146,8 +152,6 @@ class ResNet(nn.Module):
             )
 
         layers = []
-        if planes == 512:
-            with_ibn = False
         layers.append(block(self.inplanes, planes, bn_norm, num_splits, with_ibn, with_se, stride, downsample))
         self.inplanes = planes * block.expansion
         for i in range(1, blocks):
@@ -227,6 +231,54 @@ class ResNet(nn.Module):
                 nn.init.constant_(m.bias, 0)
 
 
+def init_pretrained_weights(key):
+    """Initializes model with pretrained weights.
+
+    Layers that don't match with pretrained layers in name or size are kept unchanged.
+    """
+    import os
+    import errno
+    import gdown
+
+    def _get_torch_home():
+        ENV_TORCH_HOME = 'TORCH_HOME'
+        ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
+        DEFAULT_CACHE_DIR = '~/.cache'
+        torch_home = os.path.expanduser(
+            os.getenv(
+                ENV_TORCH_HOME,
+                os.path.join(
+                    os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'torch'
+                )
+            )
+        )
+        return torch_home
+
+    torch_home = _get_torch_home()
+    model_dir = os.path.join(torch_home, 'checkpoints')
+    try:
+        os.makedirs(model_dir)
+    except OSError as e:
+        if e.errno == errno.EEXIST:
+            # Directory already exists, ignore.
+            pass
+        else:
+            # Unexpected OSError, re-raise.
+            raise
+
+    filename = model_urls[key].split('/')[-1]
+
+    cached_file = os.path.join(model_dir, filename)
+
+    if not os.path.exists(cached_file):
+        gdown.download(model_urls[key], cached_file, quiet=False)
+
+    logger.info(f"Loading pretrained model from {cached_file}")
+    state_dict = torch.load(cached_file)
+
+    return state_dict
+
+
 @BACKBONE_REGISTRY.register()
 def build_resnet_backbone(cfg):
     """
@@ -246,13 +298,15 @@ def build_resnet_backbone(cfg):
     with_nl = cfg.MODEL.BACKBONE.WITH_NL
     depth = cfg.MODEL.BACKBONE.DEPTH
 
-    num_blocks_per_stage = {34: [3, 4, 6, 3], 50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3], }[depth]
-    nl_layers_per_stage = {34: [0, 2, 3, 0], 50: [0, 2, 3, 0], 101: [0, 2, 9, 0]}[depth]
-    block = {34: BasicBlock, 50: Bottleneck, 101: Bottleneck}[depth]
+    num_blocks_per_stage = {'18x': [2, 2, 2, 2], '34x': [3, 4, 6, 3], '50x': [3, 4, 6, 3],
+                            '101x': [3, 4, 23, 3],}[depth]
+    nl_layers_per_stage = {'18x': [0, 0, 0, 0], '34x': [0, 0, 0, 0], '50x': [0, 2, 3, 0], '101x': [0, 2, 9, 0]}[depth]
+    block = {'18x': BasicBlock, '34x': BasicBlock, '50x': Bottleneck, '101x': Bottleneck}[depth]
     model = ResNet(last_stride, bn_norm, num_splits, with_ibn, with_se, with_nl, block,
                    num_blocks_per_stage, nl_layers_per_stage)
     if pretrain:
-        if not with_ibn:
+        # Load pretrained weights from the given path if one is specified
+        if pretrain_path:
             try:
                 state_dict = torch.load(pretrain_path, map_location=torch.device('cpu'))['model']
                 # Remove module.encoder in name
@@ -263,20 +317,19 @@ def build_resnet_backbone(cfg):
                         new_state_dict[new_k] = state_dict[k]
                 state_dict = new_state_dict
                 logger.info(f"Loading pretrained model from {pretrain_path}")
-            except FileNotFoundError or KeyError:
-                # original resnet
-                state_dict = model_zoo.load_url(model_urls[depth])
-                logger.info("Loading pretrained model from torchvision")
+            except FileNotFoundError as e:
+                logger.info(f'{pretrain_path} is not found! Please check this path.')
+                raise e
+            except KeyError as e:
+                logger.info("State dict keys error! Please check the state dict.")
+                raise e
         else:
-            state_dict = torch.load(pretrain_path, map_location=torch.device('cpu'))['state_dict']  # ibn-net
-            # Remove module in name
-            new_state_dict = {}
-            for k in state_dict:
-                new_k = '.'.join(k.split('.')[1:])
-                if new_k in model.state_dict() and (model.state_dict()[new_k].shape == state_dict[k].shape):
-                    new_state_dict[new_k] = state_dict[k]
-            state_dict = new_state_dict
-            logger.info(f"Loading pretrained model from {pretrain_path}")
+            key = depth
+            if with_ibn: key = 'ibn_' + key
+            if with_se:  key = 'se_' + key
+
+            state_dict = init_pretrained_weights(key)
+
         incompatible = model.load_state_dict(state_dict, strict=False)
         if incompatible.missing_keys:
             logger.info(
@@ -286,4 +339,5 @@ def build_resnet_backbone(cfg):
             logger.info(
                 get_unexpected_parameters_message(incompatible.unexpected_keys)
             )
+
     return model
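With the torchvision `model_zoo` path dropped, `build_resnet_backbone` now assembles a string key from the config flags and looks it up in `model_urls`. A small sketch of that naming scheme with example flag values:

```python
# Sketch of the pretrain-key naming used before calling init_pretrained_weights;
# the flag values below are examples, not defaults.
depth, with_ibn, with_se = "101x", True, True

key = depth
if with_ibn: key = "ibn_" + key
if with_se:  key = "se_" + key

print(key)  # -> "se_ibn_101x", which maps to the IBN-Net release URL in model_urls
```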
diff --git a/fastreid/modeling/backbones/resnext.py b/fastreid/modeling/backbones/resnext.py
index 88ae3ac..a9b2519 100644
--- a/fastreid/modeling/backbones/resnext.py
+++ b/fastreid/modeling/backbones/resnext.py
@@ -1,21 +1,27 @@
 # encoding: utf-8
 """
 @author:  xingyu liao
-@contact: liaoxingyu5@jd.com
+@contact: sherlockliao01@gmail.com
 """
 
 # based on:
 # https://github.com/XingangPan/IBN-Net/blob/master/models/imagenet/resnext_ibn_a.py
 
-import math
 import logging
-import torch.nn as nn
-import torch.nn.functional as F
-from torch.nn import init
+import math
+
 import torch
-from ...layers import IBN
+import torch.nn as nn
+
+from fastreid.layers import IBN, get_norm
+from fastreid.utils.checkpoint import get_missing_parameters_message, get_unexpected_parameters_message
 from .build import BACKBONE_REGISTRY
 
+logger = logging.getLogger(__name__)
+model_urls = {
+    'ibn_101x': 'https://github.com/XingangPan/IBN-Net/releases/download/v1.0/resnext101_ibn_a-6ace051d.pth',
+}
+
 
 class Bottleneck(nn.Module):
     """
@@ -23,7 +29,8 @@ class Bottleneck(nn.Module):
     """
     expansion = 4
 
-    def __init__(self, inplanes, planes, with_ibn, baseWidth, cardinality, stride=1, downsample=None):
+    def __init__(self, inplanes, planes, bn_norm, num_splits, with_ibn, baseWidth, cardinality, stride=1,
+                 downsample=None):
         """ Constructor
         Args:
             inplanes: input channel dimensionality
@@ -38,13 +45,13 @@ class Bottleneck(nn.Module):
         C = cardinality
         self.conv1 = nn.Conv2d(inplanes, D * C, kernel_size=1, stride=1, padding=0, bias=False)
         if with_ibn:
-            self.bn1 = IBN(D * C)
+            self.bn1 = IBN(D * C, bn_norm, num_splits)
         else:
-            self.bn1 = nn.BatchNorm2d(D * C)
+            self.bn1 = get_norm(bn_norm, D * C, num_splits)
         self.conv2 = nn.Conv2d(D * C, D * C, kernel_size=3, stride=stride, padding=1, groups=C, bias=False)
-        self.bn2 = nn.BatchNorm2d(D * C)
+        self.bn2 = get_norm(bn_norm, D * C, num_splits)
         self.conv3 = nn.Conv2d(D * C, planes * 4, kernel_size=1, stride=1, padding=0, bias=False)
-        self.bn3 = nn.BatchNorm2d(planes * 4)
+        self.bn3 = get_norm(bn_norm, planes * 4, num_splits)
         self.relu = nn.ReLU(inplace=True)
 
         self.downsample = downsample
@@ -78,7 +85,7 @@ class ResNeXt(nn.Module):
     https://arxiv.org/pdf/1611.05431.pdf
     """
 
-    def __init__(self, last_stride, with_ibn, block, layers, baseWidth=4, cardinality=32):
+    def __init__(self, last_stride, bn_norm, num_splits, with_ibn, block, layers, baseWidth=4, cardinality=32):
         """ Constructor
         Args:
             baseWidth: baseWidth for ResNeXt.
@@ -94,17 +101,17 @@ class ResNeXt(nn.Module):
         self.output_size = 64
 
         self.conv1 = nn.Conv2d(3, 64, 7, 2, 3, bias=False)
-        self.bn1 = nn.BatchNorm2d(64)
+        self.bn1 = get_norm(bn_norm, 64, num_splits)
         self.relu = nn.ReLU(inplace=True)
         self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
-        self.layer1 = self._make_layer(block, 64, layers[0], with_ibn=with_ibn)
-        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, with_ibn=with_ibn)
-        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, with_ibn=with_ibn)
-        self.layer4 = self._make_layer(block, 512, layers[3], stride=last_stride, with_ibn=with_ibn)
+        self.layer1 = self._make_layer(block, 64, layers[0], 1, bn_norm, num_splits, with_ibn=with_ibn)
+        self.layer2 = self._make_layer(block, 128, layers[1], 2, bn_norm, num_splits, with_ibn=with_ibn)
+        self.layer3 = self._make_layer(block, 256, layers[2], 2, bn_norm, num_splits, with_ibn=with_ibn)
+        self.layer4 = self._make_layer(block, 512, layers[3], last_stride, bn_norm, num_splits, with_ibn=with_ibn)
 
         self.random_init()
 
-    def _make_layer(self, block, planes, blocks, stride=1, with_ibn=False):
+    def _make_layer(self, block, planes, blocks, stride=1, bn_norm='BN', num_splits=1, with_ibn=False):
         """ Stack n bottleneck modules where n is inferred from the depth of the network.
         Args:
             block: block type used to construct ResNext
@@ -118,16 +125,18 @@ class ResNeXt(nn.Module):
             downsample = nn.Sequential(
                 nn.Conv2d(self.inplanes, planes * block.expansion,
                           kernel_size=1, stride=stride, bias=False),
-                nn.BatchNorm2d(planes * block.expansion),
+                get_norm(bn_norm, planes * block.expansion, num_splits),
             )
 
         layers = []
         if planes == 512:
             with_ibn = False
-        layers.append(block(self.inplanes, planes, with_ibn, self.baseWidth, self.cardinality, stride, downsample))
+        layers.append(block(self.inplanes, planes, bn_norm, num_splits, with_ibn,
+                            self.baseWidth, self.cardinality, stride, downsample))
         self.inplanes = planes * block.expansion
         for i in range(1, blocks):
-            layers.append(block(self.inplanes, planes, with_ibn, self.baseWidth, self.cardinality, 1, None))
+            layers.append(
+                block(self.inplanes, planes, bn_norm, num_splits, with_ibn, self.baseWidth, self.cardinality, 1, None))
 
         return nn.Sequential(*layers)
 
@@ -157,6 +166,53 @@ class ResNeXt(nn.Module):
                 m.bias.data.zero_()
 
 
+def init_pretrained_weights(key):
+    """Initializes model with pretrained weights.
+
+    Layers that don't match with pretrained layers in name or size are kept unchanged.
+    """
+    import os
+    import errno
+    import gdown
+
+    def _get_torch_home():
+        ENV_TORCH_HOME = 'TORCH_HOME'
+        ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
+        DEFAULT_CACHE_DIR = '~/.cache'
+        torch_home = os.path.expanduser(
+            os.getenv(
+                ENV_TORCH_HOME,
+                os.path.join(
+                    os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'torch'
+                )
+            )
+        )
+        return torch_home
+
+    torch_home = _get_torch_home()
+    model_dir = os.path.join(torch_home, 'checkpoints')
+    try:
+        os.makedirs(model_dir)
+    except OSError as e:
+        if e.errno == errno.EEXIST:
+            # Directory already exists, ignore.
+            pass
+        else:
+            # Unexpected OSError, re-raise.
+            raise
+
+    filename = model_urls[key].split('/')[-1]
+
+    cached_file = os.path.join(model_dir, filename)
+
+    if not os.path.exists(cached_file):
+        gdown.download(model_urls[key], cached_file, quiet=False)
+
+    logger.info(f"Loading pretrained model from {cached_file}")
+    state_dict = torch.load(cached_file)
+
+    return state_dict
+
 @BACKBONE_REGISTRY.register()
 def build_resnext_backbone(cfg):
     """
@@ -169,30 +225,48 @@ def build_resnext_backbone(cfg):
     pretrain = cfg.MODEL.BACKBONE.PRETRAIN
     pretrain_path = cfg.MODEL.BACKBONE.PRETRAIN_PATH
     last_stride = cfg.MODEL.BACKBONE.LAST_STRIDE
+    bn_norm = cfg.MODEL.BACKBONE.NORM
+    num_splits = cfg.MODEL.BACKBONE.NORM_SPLIT
     with_ibn = cfg.MODEL.BACKBONE.WITH_IBN
-    with_se = cfg.MODEL.BACKBONE.WITH_SE
     with_nl = cfg.MODEL.BACKBONE.WITH_NL
     depth = cfg.MODEL.BACKBONE.DEPTH
 
-    num_blocks_per_stage = {50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3], }[depth]
-    nl_layers_per_stage = {50: [0, 2, 3, 0], 101: [0, 2, 3, 0]}[depth]
-    model = ResNeXt(last_stride, with_ibn, Bottleneck, num_blocks_per_stage)
+    num_blocks_per_stage = {'50x': [3, 4, 6, 3], '101x': [3, 4, 23, 3], '152x': [3, 8, 36, 3], }[depth]
+    nl_layers_per_stage = {'50x': [0, 2, 3, 0], '101x': [0, 2, 3, 0]}[depth]
+    model = ResNeXt(last_stride, bn_norm, num_splits, with_ibn, Bottleneck, num_blocks_per_stage)
+
     if pretrain:
-        # if not with_ibn:
-        # original resnet
-        # state_dict = model_zoo.load_url(model_urls[depth])
-        # else:
-        # ibn resnet
-        state_dict = torch.load(pretrain_path)['state_dict']
-        # remove module in name
-        new_state_dict = {}
-        for k in state_dict:
-            new_k = '.'.join(k.split('.')[1:])
-            if new_k in model.state_dict() and (model.state_dict()[new_k].shape == state_dict[k].shape):
-                new_state_dict[new_k] = state_dict[k]
-        state_dict = new_state_dict
-        res = model.load_state_dict(state_dict, strict=False)
-        logger = logging.getLogger(__name__)
-        logger.info('missing keys is {}'.format(res.missing_keys))
-        logger.info('unexpected keys is {}'.format(res.unexpected_keys))
+        if pretrain_path:
+            try:
+                state_dict = torch.load(pretrain_path, map_location=torch.device('cpu'))['model']
+                # Remove module.encoder in name
+                new_state_dict = {}
+                for k in state_dict:
+                    new_k = '.'.join(k.split('.')[2:])
+                    if new_k in model.state_dict() and (model.state_dict()[new_k].shape == state_dict[k].shape):
+                        new_state_dict[new_k] = state_dict[k]
+                state_dict = new_state_dict
+                logger.info(f"Loading pretrained model from {pretrain_path}")
+            except FileNotFoundError as e:
+                logger.info(f'{pretrain_path} is not found! Please check this path.')
+                raise e
+            except KeyError as e:
+                logger.info("State dict keys error! Please check the state dict.")
+                raise e
+        else:
+            key = depth
+            if with_ibn: key = 'ibn_' + key
+
+            state_dict = init_pretrained_weights(key)
+
+        incompatible = model.load_state_dict(state_dict, strict=False)
+        if incompatible.missing_keys:
+            logger.info(
+                get_missing_parameters_message(incompatible.missing_keys)
+            )
+        if incompatible.unexpected_keys:
+            logger.info(
+                get_unexpected_parameters_message(incompatible.unexpected_keys)
+            )
+
     return model
diff --git a/fastreid/modeling/heads/bnneck_head.py b/fastreid/modeling/heads/bnneck_head.py
index 2d410bc..5b46c6a 100644
--- a/fastreid/modeling/heads/bnneck_head.py
+++ b/fastreid/modeling/heads/bnneck_head.py
@@ -24,9 +24,10 @@ class BNneckHead(nn.Module):
         if cls_type == 'linear':          self.classifier = nn.Linear(in_feat, num_classes, bias=False)
         elif cls_type == 'arcSoftmax':    self.classifier = ArcSoftmax(cfg, in_feat, num_classes)
         elif cls_type == 'circleSoftmax': self.classifier = CircleSoftmax(cfg, in_feat, num_classes)
+        elif cls_type == 'amSoftmax':     self.classifier = AMSoftmax(cfg, in_feat, num_classes)
         else:
             raise KeyError(f"{cls_type} is invalid, please choose from "
-                           f"'linear', 'arcSoftmax' and 'circleSoftmax'.")
+                           f"'linear', 'arcSoftmax', 'amSoftmax' and 'circleSoftmax'.")
 
         self.classifier.apply(weights_init_classifier)
 
diff --git a/fastreid/modeling/heads/linear_head.py b/fastreid/modeling/heads/linear_head.py
index 2641407..b22bd7f 100644
--- a/fastreid/modeling/heads/linear_head.py
+++ b/fastreid/modeling/heads/linear_head.py
@@ -20,9 +20,10 @@ class LinearHead(nn.Module):
         if cls_type == 'linear':          self.classifier = nn.Linear(in_feat, num_classes, bias=False)
         elif cls_type == 'arcSoftmax':    self.classifier = ArcSoftmax(cfg, in_feat, num_classes)
         elif cls_type == 'circleSoftmax': self.classifier = CircleSoftmax(cfg, in_feat, num_classes)
+        elif cls_type == 'amSoftmax':     self.classifier = AMSoftmax(cfg, in_feat, num_classes)
         else:
             raise KeyError(f"{cls_type} is invalid, please choose from "
-                           f"'linear', 'arcSoftmax' and 'circleSoftmax'.")
+                           f"'linear', 'arcSoftmax', 'amSoftmax' and 'circleSoftmax'.")
 
         self.classifier.apply(weights_init_classifier)
 
diff --git a/fastreid/modeling/heads/reduction_head.py b/fastreid/modeling/heads/reduction_head.py
index af6ea40..7d55b3f 100644
--- a/fastreid/modeling/heads/reduction_head.py
+++ b/fastreid/modeling/heads/reduction_head.py
@@ -32,12 +32,13 @@ class ReductionHead(nn.Module):
 
         # identity classification layer
         cls_type = cfg.MODEL.HEADS.CLS_LAYER
-        if cls_type == 'linear':          self.classifier = nn.Linear(in_feat, num_classes, bias=False)
-        elif cls_type == 'arcSoftmax':    self.classifier = ArcSoftmax(cfg, in_feat, num_classes)
-        elif cls_type == 'circleSoftmax': self.classifier = CircleSoftmax(cfg, in_feat, num_classes)
+        if cls_type == 'linear':          self.classifier = nn.Linear(reduction_dim, num_classes, bias=False)
+        elif cls_type == 'arcSoftmax':    self.classifier = ArcSoftmax(cfg, reduction_dim, num_classes)
+        elif cls_type == 'circleSoftmax': self.classifier = CircleSoftmax(cfg, reduction_dim, num_classes)
+        elif cls_type == 'amSoftmax':     self.classifier = AMSoftmax(cfg, reduction_dim, num_classes)
         else:
             raise KeyError(f"{cls_type} is invalid, please choose from "
-                           f"'linear', 'arcSoftmax' and 'circleSoftmax'.")
+                           f"'linear', 'arcSoftmax', 'amSoftmax' and 'circleSoftmax'.")
 
         self.classifier.apply(weights_init_classifier)
 
diff --git a/fastreid/modeling/losses/__init__.py b/fastreid/modeling/losses/__init__.py
index 977e5c3..7411aff 100644
--- a/fastreid/modeling/losses/__init__.py
+++ b/fastreid/modeling/losses/__init__.py
@@ -6,4 +6,5 @@
 
 from .cross_entroy_loss import CrossEntropyLoss
 from .focal_loss import FocalLoss
-from .metric_loss import TripletLoss, CircleLoss
+from .triplet_loss import TripletLoss
+from .circle_loss import CircleLoss
diff --git a/fastreid/modeling/losses/circle_loss.py b/fastreid/modeling/losses/circle_loss.py
new file mode 100644
index 0000000..b945659
--- /dev/null
+++ b/fastreid/modeling/losses/circle_loss.py
@@ -0,0 +1,61 @@
+# encoding: utf-8
+"""
+@author:  xingyu liao
+@contact: sherlockliao01@gmail.com
+"""
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+from fastreid.utils import comm
+from .utils import concat_all_gather
+
+
+class CircleLoss(object):
+    def __init__(self, cfg):
+        self._scale = cfg.MODEL.LOSSES.CIRCLE.SCALE
+
+        self._m = cfg.MODEL.LOSSES.CIRCLE.MARGIN
+        self._s = cfg.MODEL.LOSSES.CIRCLE.ALPHA
+
+    def __call__(self, embedding, targets):
+        embedding = nn.functional.normalize(embedding, dim=1)
+
+        if comm.get_world_size() > 1:
+            all_embedding = concat_all_gather(embedding)
+            all_targets = concat_all_gather(targets)
+        else:
+            all_embedding = embedding
+            all_targets = targets
+
+        dist_mat = torch.matmul(embedding, all_embedding.t())
+
+        N, M = dist_mat.size()
+        is_pos = targets.view(N, 1).expand(N, M).eq(all_targets.view(M, 1).expand(M, N).t()).float()
+
+        # Compute the mask which ignores the relevance score of the query to itself
+        if M > N:
+            identity_indx = torch.eye(N, N, device=is_pos.device)
+            remain_indx = torch.zeros(N, M - N, device=is_pos.device)
+            identity_indx = torch.cat((identity_indx, remain_indx), dim=1)
+            is_pos = is_pos - identity_indx
+        else:
+            is_pos = is_pos - torch.eye(N, N, device=is_pos.device)
+
+        is_neg = targets.view(N, 1).expand(N, M).ne(all_targets.view(M, 1).expand(M, N).t())
+
+        s_p = dist_mat * is_pos
+        s_n = dist_mat * is_neg
+
+        alpha_p = torch.clamp_min(-s_p.detach() + 1 + self._m, min=0.)
+        alpha_n = torch.clamp_min(s_n.detach() + self._m, min=0.)
+        delta_p = 1 - self._m
+        delta_n = self._m
+
+        logit_p = - self._s * alpha_p * (s_p - delta_p)
+        logit_n = self._s * alpha_n * (s_n - delta_n)
+
+        loss = nn.functional.softplus(torch.logsumexp(logit_p, dim=1) + torch.logsumexp(logit_n, dim=1)).mean()
+
+        return loss * self._scale
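The rewritten `CircleLoss` keeps the full similarity matrix and masks positives and negatives multiplicatively instead of boolean-indexing them. A minimal single-process usage sketch; the `SimpleNamespace` config stand-in, the import path, and the tensor shapes are assumptions for illustration.

```python
from types import SimpleNamespace
import torch
from fastreid.modeling.losses import CircleLoss  # assumes fastreid is importable

# Hypothetical stand-in for the yacs cfg; only the fields CircleLoss reads.
cfg = SimpleNamespace(MODEL=SimpleNamespace(LOSSES=SimpleNamespace(
    CIRCLE=SimpleNamespace(SCALE=1.0, MARGIN=0.25, ALPHA=128))))

loss_fn = CircleLoss(cfg)
embedding = torch.randn(8, 128, requires_grad=True)   # (batch, feat_dim), example sizes
targets = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3])
loss = loss_fn(embedding, targets)
loss.backward()
```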
diff --git a/fastreid/modeling/losses/cross_entroy_loss.py b/fastreid/modeling/losses/cross_entroy_loss.py
index 8a09843..c86e47e 100644
--- a/fastreid/modeling/losses/cross_entroy_loss.py
+++ b/fastreid/modeling/losses/cross_entroy_loss.py
@@ -58,5 +58,11 @@ class CrossEntropyLoss(object):
             targets *= smooth_param / (self._num_classes - 1)
             targets.scatter_(1, gt_classes.data.unsqueeze(1), (1 - smooth_param))
 
-        loss = (-targets * log_probs).mean(0).sum()
+        loss = (-targets * log_probs).sum(dim=1)
+
+        with torch.no_grad():
+            non_zero_cnt = max(loss.nonzero().size(0), 1)
+
+        loss = loss.sum() / non_zero_cnt
+
         return loss * self._scale
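The reduction change above averages the per-sample cross-entropy only over samples that contribute a non-zero loss rather than over the whole batch. A tiny numeric sketch of the difference, with hypothetical per-sample values:

```python
import torch

per_sample = torch.tensor([0.7, 0.0, 1.3, 0.0])        # hypothetical per-sample CE losses

old_style = per_sample.mean()                           # 0.5, averaged over all 4 samples
non_zero_cnt = max(per_sample.nonzero().size(0), 1)     # 2 contributing samples
new_style = per_sample.sum() / non_zero_cnt             # 1.0

print(old_style.item(), new_style.item())
```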
diff --git a/fastreid/modeling/losses/smooth_ap.py b/fastreid/modeling/losses/smooth_ap.py
new file mode 100644
index 0000000..6305ca7
--- /dev/null
+++ b/fastreid/modeling/losses/smooth_ap.py
@@ -0,0 +1,241 @@
+# encoding: utf-8
+"""
+@author:  xingyu liao
+@contact: sherlockliao01@gmail.com
+"""
+
+# based on:
+# https://github.com/Andrew-Brown1/Smooth_AP/blob/master/src/Smooth_AP_loss.py
+
+import torch
+import torch.nn.functional as F
+
+from fastreid.utils import comm
+from fastreid.modeling.losses.utils import concat_all_gather
+
+
+def sigmoid(tensor, temp=1.0):
+    """ temperature controlled sigmoid
+    takes as input a torch tensor (tensor) and passes it through a sigmoid, controlled by temperature: temp
+    """
+    exponent = -tensor / temp
+    # clamp the input tensor for stability
+    exponent = torch.clamp(exponent, min=-50, max=50)
+    y = 1.0 / (1.0 + torch.exp(exponent))
+    return y
+
+
+class SmoothAP(object):
+    r"""PyTorch implementation of the Smooth-AP loss.
+    implementation of the Smooth-AP loss. Takes as input the mini-batch of CNN-produced feature embeddings and returns
+    the value of the Smooth-AP loss. The mini-batch must be formed of a defined number of classes. Each class must
+    have the same number of instances represented in the mini-batch and must be ordered sequentially by class.
+    e.g. the labels for a mini-batch with batch size 9, and 3 represented classes (A,B,C) must look like:
+        labels = ( A, A, A, B, B, B, C, C, C)
+    (the order of the classes however does not matter)
+    For each instance in the mini-batch, the loss computes the Smooth-AP when it is used as the query and the rest of the
+    mini-batch is used as the retrieval set. The positive set is formed of the other instances in the batch from the
+    same class. The loss returns the average Smooth-AP across all instances in the mini-batch.
+    Args:
+        anneal : float
+            the temperature of the sigmoid that is used to smooth the ranking function. A low value of the temperature
+            results in a steep sigmoid, that tightly approximates the heaviside step function in the ranking function.
+        batch_size : int
+            the batch size being used during training.
+        num_id : int
+            the number of different classes that are represented in the batch.
+        feat_dims : int
+            the dimension of the input feature embeddings
+    Shape:
+        - Input (preds): (batch_size, feat_dims) (must be a cuda torch float tensor)
+        - Output: scalar
+    Examples::
+        >>> loss = SmoothAP(0.01, 60, 6, 256)
+        >>> input = torch.randn(60, 256, requires_grad=True).cuda()
+        >>> output = loss(input)
+        >>> output.backward()
+    """
+
+    def __init__(self, cfg):
+        r"""
+        Parameters
+        ----------
+        cfg: (cfgNode)
+
+        anneal : float
+            the temperature of the sigmoid that is used to smooth the ranking function
+        batch_size : int
+            the batch size being used
+        num_id : int
+            the number of different classes that are represented in the batch
+        feat_dims : int
+            the dimension of the input feature embeddings
+        """
+
+        self.anneal = 0.01
+        self.num_id = cfg.SOLVER.IMS_PER_BATCH // cfg.DATALOADER.NUM_INSTANCE
+        # self.num_id = 6
+
+    def __call__(self, embedding, targets):
+        """Forward pass for all input predictions: preds - (batch_size x feat_dims) """
+
+        # ------ differentiable ranking of all retrieval set ------
+        embedding = F.normalize(embedding, dim=1)
+
+        feat_dim = embedding.size(1)
+
+        # For distributed training, gather all features from different process.
+        if comm.get_world_size() > 1:
+            all_embedding = concat_all_gather(embedding)
+            all_targets = concat_all_gather(targets)
+        else:
+            all_embedding = embedding
+            all_targets = targets
+
+        sim_dist = torch.matmul(embedding, all_embedding.t())
+        N, M = sim_dist.size()
+
+        # Compute the mask which ignores the relevance score of the query to itself
+        mask_indx = 1.0 - torch.eye(M, device=sim_dist.device)
+        mask_indx = mask_indx.unsqueeze(dim=0).repeat(N, 1, 1)  # (N, M, M)
+
+        # sim_dist: (N, M) -> (N, 1, M) -> (N, M, M)
+        sim_dist_repeat = sim_dist.unsqueeze(dim=1).repeat(1, M, 1)  # (N, M, M)
+        # sim_dist_repeat_t = sim_dist.t().unsqueeze(dim=1).repeat(1, N, 1)  # (N, N, M)
+
+        # Compute the difference matrix
+        sim_diff = sim_dist_repeat - sim_dist_repeat.permute(0, 2, 1)  # (N, M, M)
+
+        # Pass through the sigmoid
+        sim_sg = sigmoid(sim_diff, temp=self.anneal) * mask_indx
+
+        # Compute all the rankings
+        sim_all_rk = torch.sum(sim_sg, dim=-1) + 1  # (N, M)
+
+        pos_mask = targets.view(N, 1).expand(N, M).eq(all_targets.view(M, 1).expand(M, N).t()).float()  # (N, M)
+
+        pos_mask_repeat = pos_mask.unsqueeze(1).repeat(1, M, 1)  # (N, M, M)
+
+        # Compute positive rankings
+        pos_sim_sg = sim_sg * pos_mask_repeat
+        sim_pos_rk = torch.sum(pos_sim_sg, dim=-1) + 1  # (N, M)
+
+        # sum the values of the Smooth-AP for all instances in the mini-batch
+        ap = 0
+        group = N // self.num_id
+        for ind in range(self.num_id):
+            pos_divide = torch.sum(
+                sim_pos_rk[(ind * group):((ind + 1) * group), (ind * group):((ind + 1) * group)] / (sim_all_rk[(ind * group):((ind + 1) * group), (ind * group):((ind + 1) * group)]))
+            ap += pos_divide / torch.sum(pos_mask[ind*group]) / N
+        return 1 - ap
+
+
+class SmoothAP_old(torch.nn.Module):
+    """PyTorch implementation of the Smooth-AP loss.
+    implementation of the Smooth-AP loss. Takes as input the mini-batch of CNN-produced feature embeddings and returns
+    the value of the Smooth-AP loss. The mini-batch must be formed of a defined number of classes. Each class must
+    have the same number of instances represented in the mini-batch and must be ordered sequentially by class.
+    e.g. the labels for a mini-batch with batch size 9, and 3 represented classes (A,B,C) must look like:
+        labels = ( A, A, A, B, B, B, C, C, C)
+    (the order of the classes however does not matter)
+    For each instance in the mini-batch, the loss computes the Smooth-AP when it is used as the query and the rest of the
+    mini-batch is used as the retrieval set. The positive set is formed of the other instances in the batch from the
+    same class. The loss returns the average Smooth-AP across all instances in the mini-batch.
+    Args:
+        anneal : float
+            the temperature of the sigmoid that is used to smooth the ranking function. A low value of the temperature
+            results in a steep sigmoid, that tightly approximates the heaviside step function in the ranking function.
+        batch_size : int
+            the batch size being used during training.
+        num_id : int
+            the number of different classes that are represented in the batch.
+        feat_dims : int
+            the dimension of the input feature embeddings
+    Shape:
+        - Input (preds): (batch_size, feat_dims) (must be a cuda torch float tensor)
+        - Output: scalar
+    Examples::
+        >>> loss = SmoothAP(0.01, 60, 6, 256)
+        >>> input = torch.randn(60, 256, requires_grad=True).cuda()
+        >>> output = loss(input)
+        >>> output.backward()
+    """
+
+    def __init__(self, anneal, batch_size, num_id, feat_dims):
+        """
+        Parameters
+        ----------
+        anneal : float
+            the temperature of the sigmoid that is used to smooth the ranking function
+        batch_size : int
+            the batch size being used
+        num_id : int
+            the number of different classes that are represented in the batch
+        feat_dims : int
+            the dimension of the input feature embeddings
+        """
+        super().__init__()
+
+        assert(batch_size%num_id==0)
+
+        self.anneal = anneal
+        self.batch_size = batch_size
+        self.num_id = num_id
+        self.feat_dims = feat_dims
+
+    def forward(self, preds):
+        """Forward pass for all input predictions: preds - (batch_size x feat_dims) """
+
+        preds = F.normalize(preds, dim=1)
+        # ------ differentiable ranking of all retrieval set ------
+        # compute the mask which ignores the relevance score of the query to itself
+        mask = 1.0 - torch.eye(self.batch_size)
+        mask = mask.unsqueeze(dim=0).repeat(self.batch_size, 1, 1)
+        # compute the relevance scores via cosine similarity of the CNN-produced embedding vectors
+        sim_all = torch.mm(preds, preds.t())
+        sim_all_repeat = sim_all.unsqueeze(dim=1).repeat(1, self.batch_size, 1)
+        # compute the difference matrix
+        sim_diff = sim_all_repeat - sim_all_repeat.permute(0, 2, 1)
+        # pass through the sigmoid
+        sim_sg = sigmoid(sim_diff, temp=self.anneal) * mask
+        # compute the rankings
+        sim_all_rk = torch.sum(sim_sg, dim=-1) + 1
+
+        # ------ differentiable ranking of only positive set in retrieval set ------
+        # compute the mask which only gives non-zero weights to the positive set
+        xs = preds.view(self.num_id, int(self.batch_size / self.num_id), self.feat_dims)
+        pos_mask = 1.0 - torch.eye(int(self.batch_size / self.num_id))
+        pos_mask = pos_mask.unsqueeze(dim=0).unsqueeze(dim=0).repeat(self.num_id, int(self.batch_size / self.num_id), 1, 1)
+        # compute the relevance scores
+        sim_pos = torch.bmm(xs, xs.permute(0, 2, 1))
+        sim_pos_repeat = sim_pos.unsqueeze(dim=2).repeat(1, 1, int(self.batch_size / self.num_id), 1)
+        # compute the difference matrix
+        sim_pos_diff = sim_pos_repeat - sim_pos_repeat.permute(0, 1, 3, 2)
+        # pass through the sigmoid
+        sim_pos_sg = sigmoid(sim_pos_diff, temp=self.anneal) * pos_mask
+        # compute the rankings of the positive set
+        sim_pos_rk = torch.sum(sim_pos_sg, dim=-1) + 1
+
+        # sum the values of the Smooth-AP for all instances in the mini-batch
+        ap = torch.zeros(1)
+        group = int(self.batch_size / self.num_id)
+        for ind in range(self.num_id):
+            pos_divide = torch.sum(sim_pos_rk[ind] / (sim_all_rk[(ind * group):((ind + 1) * group), (ind * group):((ind + 1) * group)]))
+            ap = ap + ((pos_divide / group) / self.batch_size)
+
+        return 1-ap
+
+if __name__ == '__main__':
+    from types import SimpleNamespace
+
+    # SmoothAP now reads its settings from a cfg node, so build a minimal
+    # stand-in exposing the two fields it uses (batch of 60, 10 instances per id).
+    dummy_cfg = SimpleNamespace(
+        SOLVER=SimpleNamespace(IMS_PER_BATCH=60),
+        DATALOADER=SimpleNamespace(NUM_INSTANCE=10),
+    )
+    loss1 = SmoothAP(dummy_cfg)
+    loss2 = SmoothAP_old(0.01, 60, 6, 256)
+
+    inputs = torch.randn(60, 256, requires_grad=True)
+    targets = []
+    for i in range(6):
+        targets.extend([i]*10)
+    targets = torch.LongTensor(targets)
+
+    output1 = loss1(inputs, targets)
+    output2 = loss2(inputs)
+
+    print(torch.sum(output1 - output2))
diff --git a/fastreid/modeling/losses/metric_loss.py b/fastreid/modeling/losses/triplet_loss.py
similarity index 63%
rename from fastreid/modeling/losses/metric_loss.py
rename to fastreid/modeling/losses/triplet_loss.py
index cca9edb..5e1aa74 100644
--- a/fastreid/modeling/losses/metric_loss.py
+++ b/fastreid/modeling/losses/triplet_loss.py
@@ -8,51 +8,7 @@ import torch
 import torch.nn.functional as F
 
 from fastreid.utils import comm
-
-
-# utils
-@torch.no_grad()
-def concat_all_gather(tensor):
-    """
-    Performs all_gather operation on the provided tensors.
-    *** Warning ***: torch.distributed.all_gather has no gradient.
-    """
-    tensors_gather = [torch.ones_like(tensor)
-                      for _ in range(torch.distributed.get_world_size())]
-    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
-
-    output = torch.cat(tensors_gather, dim=0)
-    return output
-
-
-def normalize(x, axis=-1):
-    """Normalizing to unit length along the specified dimension.
-    Args:
-      x: pytorch Variable
-    Returns:
-      x: pytorch Variable, same shape as input
-    """
-    x = 1. * x / (torch.norm(x, 2, axis, keepdim=True).expand_as(x) + 1e-12)
-    return x
-
-
-def euclidean_dist(x, y):
-    m, n = x.size(0), y.size(0)
-    xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n)
-    yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t()
-    dist = xx + yy
-    dist.addmm_(1, -2, x, y.t())
-    dist = dist.clamp(min=1e-12).sqrt()  # for numerical stability
-    return dist
-
-
-def cosine_dist(x, y):
-    bs1, bs2 = x.size(0), y.size(0)
-    frac_up = torch.matmul(x, y.transpose(0, 1))
-    frac_down = (torch.sqrt(torch.sum(torch.pow(x, 2), 1))).view(bs1, 1).repeat(1, bs2) * \
-                (torch.sqrt(torch.sum(torch.pow(y, 2), 1))).view(1, bs2).repeat(bs1, 1)
-    cosine = frac_up / frac_down
-    return 1 - cosine
+from .utils import concat_all_gather, euclidean_dist, normalize
 
 
 def softmax_weights(dist, mask):
@@ -174,42 +130,3 @@ class TripletLoss(object):
             if loss == float('Inf'): loss = F.margin_ranking_loss(dist_an, dist_ap, y, margin=0.3)
 
         return loss * self._scale
-
-
-class CircleLoss(object):
-    def __init__(self, cfg):
-        self._scale = cfg.MODEL.LOSSES.CIRCLE.SCALE
-
-        self.m = cfg.MODEL.LOSSES.CIRCLE.MARGIN
-        self.s = cfg.MODEL.LOSSES.CIRCLE.ALPHA
-
-    def __call__(self, embedding, targets):
-        embedding = F.normalize(embedding, dim=1)
-
-        if comm.get_world_size() > 1:
-            all_embedding = concat_all_gather(embedding)
-            all_targets = concat_all_gather(targets)
-        else:
-            all_embedding = embedding
-            all_targets = targets
-
-        dist_mat = torch.matmul(embedding, all_embedding.t())
-
-        N, M = dist_mat.size()
-        is_pos = targets.view(N, 1).expand(N, M).eq(all_targets.view(M, 1).expand(M, N).t())
-        is_neg = targets.view(N, 1).expand(N, M).ne(all_targets.view(M, 1).expand(M, N).t())
-
-        s_p = dist_mat[is_pos].contiguous().view(N, -1)
-        s_n = dist_mat[is_neg].contiguous().view(N, -1)
-
-        alpha_p = F.relu(-s_p.detach() + 1 + self.m)
-        alpha_n = F.relu(s_n.detach() + self.m)
-        delta_p = 1 - self.m
-        delta_n = self.m
-
-        logit_p = - self.s * alpha_p * (s_p - delta_p)
-        logit_n = self.s * alpha_n * (s_n - delta_n)
-
-        loss = F.softplus(torch.logsumexp(logit_p, dim=1) + torch.logsumexp(logit_n, dim=1)).mean()
-
-        return loss * self._scale
diff --git a/fastreid/modeling/losses/utils.py b/fastreid/modeling/losses/utils.py
new file mode 100644
index 0000000..a0ff648
--- /dev/null
+++ b/fastreid/modeling/losses/utils.py
@@ -0,0 +1,51 @@
+# encoding: utf-8
+"""
+@author:  xingyu liao
+@contact: sherlockliao01@gmail.com
+"""
+
+import torch
+
+
+@torch.no_grad()
+def concat_all_gather(tensor):
+    """
+    Performs all_gather operation on the provided tensors.
+    *** Warning ***: torch.distributed.all_gather has no gradient.
+    """
+    tensors_gather = [torch.ones_like(tensor)
+                      for _ in range(torch.distributed.get_world_size())]
+    torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
+
+    output = torch.cat(tensors_gather, dim=0)
+    return output
+
+
+def normalize(x, axis=-1):
+    """Normalizing to unit length along the specified dimension.
+    Args:
+      x: pytorch Variable
+    Returns:
+      x: pytorch Variable, same shape as input
+    """
+    x = 1. * x / (torch.norm(x, 2, axis, keepdim=True).expand_as(x) + 1e-12)
+    return x
+
+
+def euclidean_dist(x, y):
+    m, n = x.size(0), y.size(0)
+    xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n)
+    yy = torch.pow(y, 2).sum(1, keepdim=True).expand(n, m).t()
+    dist = xx + yy
+    dist.addmm_(1, -2, x, y.t())
+    dist = dist.clamp(min=1e-12).sqrt()  # for numerical stability
+    return dist
+
+
+def cosine_dist(x, y):
+    bs1, bs2 = x.size(0), y.size(0)
+    frac_up = torch.matmul(x, y.transpose(0, 1))
+    frac_down = (torch.sqrt(torch.sum(torch.pow(x, 2), 1))).view(bs1, 1).repeat(1, bs2) * \
+                (torch.sqrt(torch.sum(torch.pow(y, 2), 1))).view(1, bs2).repeat(bs1, 1)
+    cosine = frac_up / frac_down
+    return 1 - cosine
diff --git a/fastreid/modeling/meta_arch/baseline.py b/fastreid/modeling/meta_arch/baseline.py
index aa59460..543cb0c 100644
--- a/fastreid/modeling/meta_arch/baseline.py
+++ b/fastreid/modeling/meta_arch/baseline.py
@@ -28,7 +28,8 @@ class Baseline(nn.Module):
 
         # head
         pool_type = cfg.MODEL.HEADS.POOL_LAYER
-        if pool_type == 'avgpool':      pool_layer = FastGlobalAvgPool2d()
+        if pool_type == 'fastavgpool':  pool_layer = FastGlobalAvgPool2d()
+        elif pool_type == 'avgpool':    pool_layer = nn.AdaptiveAvgPool2d(1)
         elif pool_type == 'maxpool':    pool_layer = nn.AdaptiveMaxPool2d(1)
         elif pool_type == 'gempool':    pool_layer = GeneralizedMeanPoolingP()
         elif pool_type == "avgmaxpool": pool_layer = AdaptiveAvgMaxPool2d()
@@ -66,8 +67,10 @@ class Baseline(nn.Module):
         """
         Normalize and batch the input images.
         """
-        images = batched_inputs["images"].to(self.device)
-        # images = batched_inputs
+        if isinstance(batched_inputs, dict):
+            images = batched_inputs["images"].to(self.device)
+        elif isinstance(batched_inputs, torch.Tensor):
+            images = batched_inputs.to(self.device)
         images.sub_(self.pixel_mean).div_(self.pixel_std)
         return images
 
@@ -89,4 +92,7 @@ class Baseline(nn.Module):
         if "TripletLoss" in loss_names:
             loss_dict['loss_triplet'] = TripletLoss(self._cfg)(pred_features, gt_labels)
 
+        if "CircleLoss" in loss_names:
+            loss_dict['loss_circle'] = CircleLoss(self._cfg)(pred_features, gt_labels)
+
         return loss_dict
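Because `preprocess_image` now accepts a plain tensor in addition to the usual dict of batched inputs, the meta-architecture can be traced directly for ONNX export. A minimal sketch, assuming `model` is an already-built `Baseline` in eval mode and that the input resolution, filename, and opset below are placeholder choices:

```python
import torch

dummy = torch.randn(1, 3, 256, 128)         # NCHW dummy input; the size is an assumption
model.eval()                                 # `model` is assumed to be a Baseline instance
torch.onnx.export(
    model,
    dummy,
    "baseline_r50.onnx",                     # output filename is illustrative
    input_names=["images"],
    output_names=["features"],
    opset_version=10,                        # opset choice is an assumption
)
```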
diff --git a/fastreid/solver/optim/adam.py b/fastreid/solver/optim/adam.py
index fd051f1..a9b83e3 100644
--- a/fastreid/solver/optim/adam.py
+++ b/fastreid/solver/optim/adam.py
@@ -1,7 +1,7 @@
 # encoding: utf-8
 """
 @author:  xingyu liao
-@contact: liaoxingyu5@jd.com
+@contact: sherlockliao01@gmail.com
 """
 
 import torch
diff --git a/fastreid/solver/optim/sgd.py b/fastreid/solver/optim/sgd.py
index fbff412..9e31bc0 100644
--- a/fastreid/solver/optim/sgd.py
+++ b/fastreid/solver/optim/sgd.py
@@ -1,7 +1,7 @@
 # encoding: utf-8
 """
 @author:  xingyu liao
-@contact: liaoxingyu5@jd.com
+@contact: sherlockliao01@gmail.com
 """
 
 
diff --git a/fastreid/solver/optim/swa.py b/fastreid/solver/optim/swa.py
index 889239a..1d45e02 100644
--- a/fastreid/solver/optim/swa.py
+++ b/fastreid/solver/optim/swa.py
@@ -1,7 +1,7 @@
 # encoding: utf-8
 """
 @author:  xingyu liao
-@contact: liaoxingyu5@jd.com
+@contact: sherlockliao01@gmail.com
 """
 # based on:
 # https://github.com/pytorch/contrib/blob/master/torchcontrib/optim/swa.py
diff --git a/fastreid/utils/collect_env.py b/fastreid/utils/collect_env.py
index 6a7ec32..5affc33 100644
--- a/fastreid/utils/collect_env.py
+++ b/fastreid/utils/collect_env.py
@@ -1,7 +1,7 @@
 # encoding: utf-8
 """
 @author:  xingyu liao
-@contact: liaoxingyu5@jd.com
+@contact: sherlockliao01@gmail.com
 """
 
 # based on
diff --git a/fastreid/utils/weight_init.py b/fastreid/utils/weight_init.py
index 989c1df..0bd22e3 100644
--- a/fastreid/utils/weight_init.py
+++ b/fastreid/utils/weight_init.py
@@ -1,7 +1,7 @@
 # encoding: utf-8
 """
 @author:  xingyu liao
-@contact: liaoxingyu5@jd.com
+@contact: sherlockliao01@gmail.com
 """
 
 import math
@@ -35,5 +35,3 @@ def weights_init_classifier(m):
         nn.init.normal_(m.weight, std=0.001)
         if m.bias is not None:
             nn.init.constant_(m.bias, 0.0)
-    elif classname.find("Arcface") != -1 or classname.find("Circle") != -1:
-        nn.init.kaiming_uniform_(m.weight, a=math.sqrt(5))
diff --git a/projects/PartialReID/configs/partial_market.yml b/projects/PartialReID/configs/partial_market.yml
index 5f16d6f..76b9b0e 100644
--- a/projects/PartialReID/configs/partial_market.yml
+++ b/projects/PartialReID/configs/partial_market.yml
@@ -3,11 +3,10 @@ MODEL:
 
   BACKBONE:
     NAME: "build_resnet_backbone"
-    DEPTH: 50
+    DEPTH: "50x"
     NORM: "BN"
     LAST_STRIDE: 1
     WITH_IBN: True
-    PRETRAIN_PATH: "/export/home/lxy/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar"
 
   HEADS:
     NAME: "DSRHead"
diff --git a/tools/deploy/README.md b/tools/deploy/README.md
index 96167f5..2a340b0 100644
--- a/tools/deploy/README.md
+++ b/tools/deploy/README.md
@@ -1,38 +1,29 @@
-# Deployment
+# Model Deployment
 
 This directory contains:
 
-1. A script that converts a fastreid model to Caffe format.
+1. The scripts that convert a fastreid model to Caffe/ONNX/TRT format.
 
-2. An exmpale that loads a R50 baseline model in Caffe and run inference.
+2. The examples that load an R50 baseline model in Caffe/ONNX/TRT and run inference.
 
 ## Tutorial
 
+### Caffe Convert
+
+<details>
+<summary>step-by-step pipeline for Caffe convert</summary>
+
 This is a tiny example for converting the fastreid baseline in `meta_arch` to a Caffe model. If you want to convert a more complex architecture, you need to customize more things.
 
-1. Change `preprocess_image` in `fastreid/modeling/meta_arch/baseline.py` as below
-
-    ```python
-    def preprocess_image(self, batched_inputs):
-        """
-        Normalize and batch the input images.
-        """
-        # images = [x["images"] for x in batched_inputs]
-        # images = batched_inputs["images"]
-        images = batched_inputs
-        images.sub_(self.pixel_mean).div_(self.pixel_std)
-        return images
-    ```
-
-2. Run `caffe_export.py` to get the converted Caffe model,
+1. Run `caffe_export.py` to get the converted Caffe model,
 
     ```bash
-    python caffe_export.py --config-file "/export/home/lxy/fast-reid/logs/market1501/bagtricks_R50/config.yaml" --name "baseline_R50" --output "logs/caffe_model" --opts MODEL.WEIGHTS "/export/home/lxy/fast-reid/logs/market1501/bagtricks_R50/model_final.pth"
+    python caffe_export.py --config-file root-path/market1501/bagtricks_R50/config.yml --name "baseline_R50" --output outputs/caffe_model --opts MODEL.WEIGHTS root-path/logs/market1501/bagtricks_R50/model_final.pth
     ```
 
-    then you can check the Caffe model and prototxt in `logs/caffe_model`.
+    then you can check the Caffe model and prototxt in `outputs/caffe_model`.
 
-3. Change `prototxt` following next three steps:
+2. Change `prototxt` following the next three steps:
 
    1) Edit `max_pooling` in `baseline_R50.prototxt` like this
 
@@ -67,7 +58,7 @@ This is a tiny example for converting fastreid-baseline in `meta_arch` to Caffe
         }
         ```
 
-    3) Change the last layer `top` name to `output`
+   3) Change the last layer `top` name to `output`
 
         ```prototxt
         layer {
@@ -81,22 +72,89 @@ This is a tiny example for converting fastreid-baseline in `meta_arch` to Caffe
         }
         ```
 
-4. (optional) You can open [Netscope](https://ethereon.github.io/netscope/quickstart.html), then enter you network `prototxt` to visualize the network.
+3. (optional) You can open [Netscope](https://ethereon.github.io/netscope/quickstart.html), then enter your network `prototxt` to visualize the network.
 
-5. Run `caffe_inference.py` to save Caffe model features with input images
+4. Run `caffe_inference.py` to save Caffe model features with input images
 
    ```bash
-    python caffe_inference.py --model-def "logs/caffe_model/baseline_R50.prototxt" \
-    --model-weights "logs/caffe_model/baseline_R50.caffemodel" \
-    --input \
-    '/export/home/DATA/Market-1501-v15.09.15/bounding_box_test/1182_c5s3_015240_04.jpg' \
-    '/export/home/DATA/Market-1501-v15.09.15/bounding_box_test/1182_c6s3_038217_01.jpg' \
-    '/export/home/DATA/Market-1501-v15.09.15/bounding_box_test/1183_c5s3_006943_05.jpg' \
-    --output "caffe_R34_output"
+    python caffe_inference.py --model-def outputs/caffe_model/baseline_R50.prototxt \
+    --model-weights outputs/caffe_model/baseline_R50.caffemodel \
+    --input test_data/*.jpg --output caffe_output
    ```
 
-6. Run `demo/demo.py` to get fastreid model features with the same input images, then compute the cosine similarity of difference model features to verify if you convert Caffe model successfully.
+5. Run `demo/demo.py` to get fastreid model features with the same input images, then verify that Caffe and PyTorch are computing the same value for the network.
+
+    ```python
+    np.testing.assert_allclose(torch_out, caffe_out, rtol=1e-3, atol=1e-6)
+    ```
+
+</details>
+
+### ONNX Convert
+
+<details>
+<summary>step-by-step pipeline for ONNX convert</summary>
+
+This is a tiny example for converting the fastreid baseline in `meta_arch` to an ONNX model. ONNX supports most PyTorch operators; if some operators are not supported by ONNX, you need to customize them (one common approach is sketched below).
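+
+If an operator in your model has no ONNX mapping, one common way to "customize" it is to register a custom symbolic that rewrites the op with supported ONNX primitives before export. A rough sketch only, not something this repo requires; the `hardswish` rewrite below is purely illustrative:
+
+```python
+import torch.onnx
+
+def hardswish_symbolic(g, x):
+    # hardswish(x) = x * hardsigmoid(x); HardSigmoid with alpha=1/6 matches PyTorch's hardsigmoid
+    return g.op("Mul", x, g.op("HardSigmoid", x, alpha_f=1.0 / 6))
+
+torch.onnx.register_custom_op_symbolic("aten::hardswish", hardswish_symbolic, opset_version=9)
+```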
+
+1. Run `onnx_export.py` to get the converted ONNX model,
+
+    ```bash
+    python onnx_export.py --config-file root-path/bagtricks_R50/config.yml --name "baseline_R50" --output outputs/onnx_model --opts MODEL.WEIGHTS root-path/logs/market1501/bagtricks_R50/model_final.pth
+    ```
+
+    then you can check the ONNX model in `outputs/onnx_model`.
+
+2. (optional) You can use [Netron](https://github.com/lutzroeder/netron) to visualize the network.
+
+3. Run `onnx_inference.py` to save ONNX model features with input images
+
+   ```bash
+    python onnx_inference.py --model-path outputs/onnx_model/baseline_R50.onnx \
+    --input test_data/*.jpg --output onnx_output
+   ```
+
+4. Run `demo/demo.py` to get fastreid model features with the same input images, then verify that ONNX Runtime and PyTorch are computing the same value for the network.
+
+    ```python
+    np.testing.assert_allclose(torch_out, ort_out, rtol=1e-3, atol=1e-6)
+    ```
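+
+    A minimal sketch of that check, assuming the per-image features from both runs were saved as `.npy` files (the paths below are just an example):
+
+    ```python
+    import numpy as np
+
+    torch_out = np.load("demo_output/0022_c6s1_002976_01.npy")
+    ort_out = np.load("onnx_output/0022_c6s1_002976_01.npy")
+    np.testing.assert_allclose(torch_out, ort_out, rtol=1e-3, atol=1e-6)
+    ```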
+
+</details>
+
+### TensorRT Convert
+
+<details>
+<summary>step-by-step pipeline for TensorRT convert</summary>
+
+This is a tiny example for converting the fastreid baseline in `meta_arch` to a TensorRT model. We use [tiny-tensorrt](https://github.com/zerollzeng/tiny-tensorrt), a simple and easy-to-use NVIDIA TensorRT wrapper, to convert the model to TensorRT.
+
+First, convert the PyTorch model to ONNX format following [ONNX Convert](https://github.com/JDAI-CV/fast-reid#fastreid), and remember the model's `output` name (one way to look it up is sketched below). Then you can convert the ONNX model to TensorRT following the instructions below.
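+
+A small sketch for looking up the output name with the `onnx` Python package (the path below assumes the export command from the ONNX section):
+
+```python
+import onnx
+
+model = onnx.load("outputs/onnx_model/baseline.onnx")
+print([output.name for output in model.graph.output])
+```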
+
+1. Run the command below to get the converted TRT model from the ONNX model,
+
+    ```bash
+    python trt_export.py --name "baseline_R50" --output outputs/trt_model --onnx-model outputs/onnx_model/baseline.onnx --height 256 --width 128
+    ```
+
+    then you can check the TRT model in `outputs/trt_model`.
+
+2. Run `trt_inference.py` to save TRT model features with input images
+
+   ```bash
+    python trt_inference.py --model-path outputs/trt_model/baseline.engine \
+    --input test_data/*.jpg --output trt_output --output-name trt_model_outputname
+   ```
+
+3. Run `demo/demo.py` to get fastreid model features with the same input images, then verify that TensorRT and PyTorch are computing the same value for the network.
+
+    ```python
+    np.testing.assert_allclose(torch_out, trt_out, rtol=1e-3, atol=1e-6)
+    ```
+
+</details>
 
 ## Acknowledgements
 
-Thank to [CPFLAME](https://github.com/CPFLAME), [gcong18](https://github.com/gcong18), [YuxiangJohn](https://github.com/YuxiangJohn) and [wiggin66](https://github.com/wiggin66) at JDAI Model Acceleration Group for help in PyTorch to Caffe model converting.
+Thanks to [CPFLAME](https://github.com/CPFLAME), [gcong18](https://github.com/gcong18), [YuxiangJohn](https://github.com/YuxiangJohn) and [wiggin66](https://github.com/wiggin66) at JDAI Model Acceleration Group for their help with PyTorch model conversion.
diff --git a/tools/deploy/caffe_export.py b/tools/deploy/caffe_export.py
index ec54af5..db89612 100644
--- a/tools/deploy/caffe_export.py
+++ b/tools/deploy/caffe_export.py
@@ -1,7 +1,7 @@
 # encoding: utf-8
 """
 @author:  xingyu liao
-@contact: liaoxingyu5@jd.com
+@contact: sherlockliao01@gmail.com
 """
 
 import argparse
@@ -15,6 +15,9 @@ from fastreid.config import get_cfg
 from fastreid.modeling.meta_arch import build_model
 from fastreid.utils.file_io import PathManager
 from fastreid.utils.checkpoint import Checkpointer
+from fastreid.utils.logger import setup_logger
+
+logger = setup_logger(name='caffe_export')
 
 
 def setup_cfg(args):
@@ -64,10 +67,12 @@ if __name__ == '__main__':
     model = build_model(cfg)
     Checkpointer(model).load(cfg.MODEL.WEIGHTS)
     model.eval()
-    print(model)
+    logger.info(model)
 
-    inputs = torch.randn(1, 3, cfg.INPUT.SIZE_TEST[0], cfg.INPUT.SIZE_TEST[1]).cuda()
+    inputs = torch.randn(1, 3, cfg.INPUT.SIZE_TEST[0], cfg.INPUT.SIZE_TEST[1])
     PathManager.mkdirs(args.output)
     pytorch_to_caffe.trans_net(model, inputs, args.name)
     pytorch_to_caffe.save_prototxt(f"{args.output}/{args.name}.prototxt")
     pytorch_to_caffe.save_caffemodel(f"{args.output}/{args.name}.caffemodel")
+
+    logger.info(f"Export caffe model in {args.output} sucessfully!")
diff --git a/tools/deploy/caffe_inference.py b/tools/deploy/caffe_inference.py
index 6b92f82..2956816 100644
--- a/tools/deploy/caffe_inference.py
+++ b/tools/deploy/caffe_inference.py
@@ -1,7 +1,7 @@
 # encoding: utf-8
 """
 @author:  xingyu liao
-@contact: liaoxingyu5@jd.com
+@contact: sherlockliao01@gmail.com
 """
 
 import caffe
@@ -43,7 +43,7 @@ def get_parser():
     parser.add_argument(
         "--height",
         type=int,
-        default=384,
+        default=256,
         help="height of image"
     )
     parser.add_argument(
diff --git a/tools/deploy/export2tf.py b/tools/deploy/export2tf.py
deleted file mode 100644
index 593c7cb..0000000
--- a/tools/deploy/export2tf.py
+++ /dev/null
@@ -1,48 +0,0 @@
-# encoding: utf-8
-"""
-@author:  sherlock
-@contact: sherlockliao01@gmail.com
-"""
-
-import sys
-
-import torch
-sys.path.append('../..')
-from fastreid.config import get_cfg
-from fastreid.engine import default_argument_parser, default_setup
-from fastreid.modeling.meta_arch import build_model
-from fastreid.export.tensorflow_export import export_tf_reid_model
-from fastreid.export.tf_modeling import TfMetaArch
-
-
-def setup(args):
-    """
-    Create configs and perform basic setups.
-    """
-    cfg = get_cfg()
-    # cfg.merge_from_file(args.config_file)
-    cfg.merge_from_list(args.opts)
-    cfg.freeze()
-    default_setup(cfg, args)
-    return cfg
-
-
-if __name__ == "__main__":
-    args = default_argument_parser().parse_args()
-    print("Command Line Args:", args)
-    cfg = setup(args)
-    cfg.defrost()
-    cfg.MODEL.BACKBONE.NAME = "build_resnet_backbone"
-    cfg.MODEL.BACKBONE.DEPTH = 50
-    cfg.MODEL.BACKBONE.LAST_STRIDE = 1
-    # If use IBN block in backbone
-    cfg.MODEL.BACKBONE.WITH_IBN = False
-    cfg.MODEL.BACKBONE.PRETRAIN = False
-
-    from torchvision.models import resnet50
-    # model = TfMetaArch(cfg)
-    model = resnet50(pretrained=False)
-    # model.load_params_wo_fc(torch.load('logs/bjstation/res50_baseline_v0.4/ckpts/model_epoch80.pth'))
-    model.eval()
-    dummy_inputs = torch.randn(1, 3, 256, 128)
-    export_tf_reid_model(model, dummy_inputs, 'reid_tf.pb')
diff --git a/tools/deploy/onnx_export.py b/tools/deploy/onnx_export.py
new file mode 100644
index 0000000..449db3a
--- /dev/null
+++ b/tools/deploy/onnx_export.py
@@ -0,0 +1,146 @@
+# encoding: utf-8
+"""
+@author:  xingyu liao
+@contact: sherlockliao01@gmail.com
+"""
+
+import argparse
+import io
+import sys
+
+import onnx
+import torch
+from onnxsim import simplify
+from torch.onnx import OperatorExportTypes
+
+sys.path.append('../../')
+
+from fastreid.config import get_cfg
+from fastreid.modeling.meta_arch import build_model
+from fastreid.utils.file_io import PathManager
+from fastreid.utils.checkpoint import Checkpointer
+from fastreid.utils.logger import setup_logger
+
+logger = setup_logger(name='onnx_export')
+
+
+def setup_cfg(args):
+    cfg = get_cfg()
+    cfg.merge_from_file(args.config_file)
+    cfg.merge_from_list(args.opts)
+    cfg.freeze()
+    return cfg
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(description="Convert Pytorch to ONNX model")
+
+    parser.add_argument(
+        "--config-file",
+        metavar="FILE",
+        help="path to config file",
+    )
+    parser.add_argument(
+        "--name",
+        default="baseline",
+        help="name for converted model"
+    )
+    parser.add_argument(
+        "--output",
+        default='onnx_model',
+        help='path to save converted onnx model'
+    )
+    parser.add_argument(
+        "--opts",
+        help="Modify config options using the command-line 'KEY VALUE' pairs",
+        default=[],
+        nargs=argparse.REMAINDER,
+    )
+    return parser
+
+
+def remove_initializer_from_input(model):
+    if model.ir_version < 4:
+        print(
+            'Model with ir_version below 4 requires initializers to be included in graph input'
+        )
+        return model
+
+    inputs = model.graph.input
+    name_to_input = {}
+    for input in inputs:
+        name_to_input[input.name] = input
+
+    for initializer in model.graph.initializer:
+        if initializer.name in name_to_input:
+            inputs.remove(name_to_input[initializer.name])
+
+    return model
+
+
+def export_onnx_model(model, inputs):
+    """
+    Trace and export a model to onnx format.
+    Args:
+        model (nn.Module):
+        inputs (torch.Tensor): the model will be called by `model(*inputs)`
+    Returns:
+        an onnx model
+    """
+    assert isinstance(model, torch.nn.Module)
+
+    # make sure all modules are in eval mode, onnx may change the training state
+    # of the module if the states are not consistent
+    def _check_eval(module):
+        assert not module.training
+
+    model.apply(_check_eval)
+
+    # Export the model to ONNX
+    with torch.no_grad():
+        with io.BytesIO() as f:
+            torch.onnx.export(
+                model,
+                inputs,
+                f,
+                operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK,
+                # verbose=True,  # NOTE: uncomment this for debugging
+                # export_params=True,
+            )
+            onnx_model = onnx.load_from_string(f.getvalue())
+
+    # Apply ONNX's Optimization
+    all_passes = onnx.optimizer.get_available_passes()
+    passes = ["extract_constant_to_initializer", "eliminate_unused_initializer", "fuse_bn_into_conv"]
+    assert all(p in all_passes for p in passes)
+    onnx_model = onnx.optimizer.optimize(onnx_model, passes)
+    return onnx_model
+
+
+if __name__ == '__main__':
+    args = get_parser().parse_args()
+    cfg = setup_cfg(args)
+
+    cfg.defrost()
+    cfg.MODEL.BACKBONE.PRETRAIN = False
+    if cfg.MODEL.HEADS.POOL_LAYER == 'fastavgpool':
+        cfg.MODEL.HEADS.POOL_LAYER = 'avgpool'
+    model = build_model(cfg)
+    Checkpointer(model).load(cfg.MODEL.WEIGHTS)
+    model.eval()
+    logger.info(model)
+
+    inputs = torch.randn(1, 3, cfg.INPUT.SIZE_TEST[0], cfg.INPUT.SIZE_TEST[1])
+    onnx_model = export_onnx_model(model, inputs)
+
+    model_simp, check = simplify(onnx_model)
+    assert check, "Simplified ONNX model could not be validated"
+
+    model_simp = remove_initializer_from_input(model_simp)
+
+    PathManager.mkdirs(args.output)
+
+    onnx.save_model(model_simp, f"{args.output}/{args.name}.onnx")
+
+    logger.info(f"Export onnx model in {args.output} successfully!")
diff --git a/tools/deploy/onnx_inference.py b/tools/deploy/onnx_inference.py
new file mode 100644
index 0000000..2e29f62
--- /dev/null
+++ b/tools/deploy/onnx_inference.py
@@ -0,0 +1,85 @@
+# encoding: utf-8
+"""
+@author:  xingyu liao
+@contact: sherlockliao01@gmail.com
+"""
+
+import argparse
+import glob
+import os
+
+import cv2
+import numpy as np
+import onnxruntime
+import tqdm
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(description="onnx model inference")
+
+    parser.add_argument(
+        "--model-path",
+        default="onnx_model/baseline.onnx",
+        help="onnx model path"
+    )
+    parser.add_argument(
+        "--input",
+        nargs="+",
+        help="A list of space separated input images; "
+             "or a single glob pattern such as 'directory/*.jpg'",
+    )
+    parser.add_argument(
+        "--output",
+        default='onnx_output',
+        help='path to save onnx model inference results'
+    )
+    parser.add_argument(
+        "--height",
+        type=int,
+        default=256,
+        help="height of image"
+    )
+    parser.add_argument(
+        "--width",
+        type=int,
+        default=128,
+        help="width of image"
+    )
+    return parser
+
+
+def preprocess(image_path, image_height, image_width):
+    original_image = cv2.imread(image_path)
+    # the model expects RGB inputs
+    original_image = original_image[:, :, ::-1]
+
+    # Apply pre-processing to image.
+    img = cv2.resize(original_image, (image_width, image_height), interpolation=cv2.INTER_CUBIC)
+    img = img.astype("float32").transpose(2, 0, 1)[np.newaxis]  # (1, 3, h, w)
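+    # NOTE: no mean/std normalization here; the exported graph should already apply it,
+    # since `preprocess_image` in fastreid/modeling/meta_arch/baseline.py is traced into the model.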
+    return img
+
+
+def normalize(nparray, order=2, axis=-1):
+    """Normalize a N-D numpy array along the specified axis."""
+    norm = np.linalg.norm(nparray, ord=order, axis=axis, keepdims=True)
+    return nparray / (norm + np.finfo(np.float32).eps)
+
+
+if __name__ == "__main__":
+    args = get_parser().parse_args()
+
+    ort_sess = onnxruntime.InferenceSession(args.model_path)
+
+    input_name = ort_sess.get_inputs()[0].name
+
+    if not os.path.exists(args.output): os.makedirs(args.output)
+
+    if args.input:
+        if os.path.isdir(args.input[0]):
+            args.input = glob.glob(os.path.expanduser(args.input[0]))
+            assert args.input, "The input path(s) was not found"
+        for path in tqdm.tqdm(args.input):
+            image = preprocess(path, args.height, args.width)
+            feat = ort_sess.run(None, {input_name: image})[0]
+            feat = normalize(feat, axis=1)
+            np.save(os.path.join(args.output, path.replace('.jpg', '.npy').split('/')[-1]), feat)
diff --git a/tools/deploy/run_inference.sh b/tools/deploy/run_inference.sh
deleted file mode 100644
index 2bc54ab..0000000
--- a/tools/deploy/run_inference.sh
+++ /dev/null
@@ -1,10 +0,0 @@
-
-python caffe_inference.py --model-def "logs/caffe_R34/baseline_R34.prototxt" \
---model-weights "logs/caffe_R34/baseline_R34.caffemodel" \
---height 256 --width 128 \
---input \
-'/export/home/DATA/Market-1501-v15.09.15/bounding_box_test/1182_c5s3_015240_04.jpg' \
-'/export/home/DATA/Market-1501-v15.09.15/bounding_box_test/1182_c6s3_038217_01.jpg' \
-'/export/home/DATA/Market-1501-v15.09.15/bounding_box_test/1183_c5s3_006943_05.jpg' \
-'/export/home/DATA/DukeMTMC-reID/bounding_box_train/0728_c4_f0161265.jpg' \
---output "caffe_R34_output"
diff --git a/tools/deploy/test_data/0022_c6s1_002976_01.jpg b/tools/deploy/test_data/0022_c6s1_002976_01.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..e15dab72f2170b561c9049e4139aa9770a9b50dc
GIT binary patch
literal 2223
zcmbW!XHb)C76<V6B_RP4KnNmD=_NEt2we=lM2d85NRxnw+yH`*D2YgwA|Snq6a|e+
z6ZBqG1dJd>AyNWL2~|ZYQdW^%ymx1JW<Txj|IG7zdCr_Uzj-(#oJn90!OGSO06_o%
zcMIT*0p<V?3<igB^T6S7US1wPqzFF}fk282qXb3vOGqEsFCir*qoAoQBd0DeC8eT=
zR@cJda5!mY17m&cAx&)@_SX=QmzNibK#KA6i(zG@WU>EqIL`qg9)Jp%KtMTwO9+Gr
zft(INdUsE5@UH;>49Eq6a>L*}ynKk=hNe9L7YKoHK_T4SQ0Q*=rQLS`D#R@;i!p_v
zoP6MNBoQnvvz$lXtgdaZ^V?+wZQrmfynLej#Ka{Om6TP`syH28J$(a1a|=r=Ya3gF
zi>sTvhv%{5e*R|yh=D<8!y_(4MqP}KiMtw~keHO5l9hcUCpRy@pzu~jWmWa<J2kcS
z4UJ8Ynp++}dC}g{`I7Oft9xLO`EF=<WRx{E&7PT^n_pP`xU#zTZGB^N>mT2Lxj+E&
zH*5F)8}=VAp<Ndj6bgaDf4M-esNDk*f^y4ZV8W(Oa32y%4ol+^G0QBkYvYyIc3$4=
z8}^n@Q~@`oxblnkm+ZfTUHQLc|APJ7#R8BJaQE{dLI56EYpZz=_V#dq&E-M=XDzv+
z)-45ZH~`|Xr3X5lS>zxn@(r<9(HVdYHO7x;n%$k-!!vmx$I6F#H8rhWp6Ov4U7^uq
ztD9eKBM3!pI#yYXyU)HT%1mcny>t6qMd?j}ri$`Y4QbA1F!#>q+g5nGTd3e!smoId
zF?^#E{t_OE>%wDcW2Q|RKDi5Mw6(v$jokVD;&NCcKf2+B(Fp@wQf!1Z#dV{W($E=b
zTGbo{S-UIx$mIh*GbmK2r;_cj;yJ_tXxoA{TivFrsE{i8rQC(|(@f=>&85@kfrUyX
zQ$L^&2XA`bO0yz2kn6q#)_EL)5Vjnr=||%fQ#?CjTE`|cws^W8*ey@Aqsl#T!V?UV
z{G(M$J#d&B&?>F}>B!OPQns7o{&XuxSn@u8_2{M5VWImJ^`V(Z&qVXyc&Trp5tR7L
z-oI%xK6$;iRCG}Dc~YM{-xbpw*Aya1V&Z<!N8K)H4*MCVW;g^Lji;o)eEKZMFO%|E
zx4xy#T^wnQ@seIS`2695a0$__x$`Rh$3kAE9XZ3#!&lH`VB0?E=E&&C*<9Odxt9QD
z`>k@F%x}S=4(}$JI%@q5j#6(bF*z#X%%+P`*DfUIGr*7!-d+7?KWt$zJIo;$lB#bK
ztt|3Z(Y%LhkmUkdE9`Tv^|`__9VEex_lrILP^@ZovUH)D9<dOps)PwVnG<~2-9c4G
zFk0|a?iYvy^8*2K?zvzvj_fb2CGK)^zB92@F!6hQZ^Ml&QCo;WE9qzO>qcWMczp`F
zV$No9$Xz3&I>rfYsX<6RfxXb$;FXX`1<k0zO<RO(#KTn-`8A7p*`p;FyYPK?0<SL~
zPh&pX6e~2}B$N@p{vk^wRypD#SA?Zzz-^`K?D~hU<Dc2ZHU{oq*l<>fWc2MwtqQ7&
zP~wknD<fi@^Lb#Jo@JPutt=8yDWEr9Y>#g*@lp_1(&;nS<69=52|PEIIx|H|`L5R4
ztw}K;pYxQ>(M|Bl95%-DEgN}f2U`uWL-XlPl!s<bXgcCr9q~vmzu}oIwCs4TtL{`^
z%uvTugLk<6sKuAeUL~0SDYmppaM4cJX`5R5Mrp_-mU<ICAGG8K%@oe*dwTGim+I;o
z2N2$5R6>dCRT*r25|rlO&=`7){e8ID2pgBkzGG198-;q%pAmjyG*Y&-g#(D|G>Px@
zHgDVOj<L_*&zk;Z-JEHZG{ds~<aK5aqmS2G@6<=7mL9GyQ}wmmKzvP>h)s%^UG*+y
zxrTmV$n+h)N;;-?H|UBJCfR~IAC;x;6&tv!Yhg*_0R92@5}ug9?%HB_x~8Bp5+}YL
z>eal&u69&p)?vFac(mA9N(*cDW1VyAh<@Lc{mw*j0zYCVH9SSWMH?LB5ioKmmq)+d
z_Z8I{Xe&}w@Y+vVIbGHGQ04JIoUw3Ah2o!TWQ?dreta84it@6GqSal4ces&<ku2IE
zTZ~GE(+1aT6dKC{#8*-p8)d)P`xvpxbYQk0LHfa^8td))?6ti|?1NQ8HcsYbu8mJW
zfDa1#Y#E35hfNa^tqSWuj&}7QznEv^P11Q$Ov<^NwdMQH@ynOhOKo@Z+U*##PfXyE
z-l}rz=n93LNd1uK5xwTYVtK9>LC~{d+N|o9G`gmL@(AYiw9EA%7LIg;tdNb*+w-w4
zO6c3(cl*kbqZT$UzSl2(l*}PUSgInjU=9F1xK(1`VBsdA^~ceqfWVVJ?9IjAa1#9k
z9f2SdLlPyEp4Z+ONWbM2rI99QM|l7#+g#Xbf4n4Oy#rM(Sl+1x)%+)xQ6>Wys3xL$
zCK-0a$lEo|=yA=Gn%=JDx}7n;0-WTgjPLtPwsX~nO=G%9ne6w`Tsc)wNvmB{vNvC^
z*X*KS(u(E1-0#=MlA^k)um~}D2BBwwiC{Lm8O^@>v^5-56BOs<7F@tW57X4`czqU3
z?>+mB36MAusvFcVSq##aP7z`}`&d~EdP)p;Ru2~e=|E!C-G5=Qbkk8?PRVe|F5D|x
zzR|I+%>&-jd7g1`QjPG+22wLK6#lygF1>QX!qAd2>!f=h><FwlKtHx;!iod5$#MXy
TYle&u4RzVX{Uog?mNWV%al_>}

literal 0
HcmV?d00001

diff --git a/tools/deploy/test_data/0027_c2s2_091032_02.jpg b/tools/deploy/test_data/0027_c2s2_091032_02.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..8cf495998221cfa7a9772b16ee04baf06cda1fff
GIT binary patch
literal 2137
zcmbW!dpOg59|!Q?Y;4YLHi<EZm9%nYM)x$_n$sv*35`l2=NLJpgvzNL8cjJQ-OPDq
zg>2E`Mw7eDX{wQ&vKjSAa!BRYeLvUpT-Wp0^E~hG_5FPQ_<pYM^?hC6wbyH-zy>FK
ziah`V0RU)y0c%458Gy;jLFHs&P$*Ph9;Tp#R8~?{RN914Q$=cPVsx}M(P&+)v7xU1
zb^|ne>n@}1J52}#0>+SN`3s(8Y)ZiYyaXgKFR!Giq@k>=f!9Op;s58XJpte_APM*t
z4AKW=;2<y@wAKz_*54@$`YXUc1CjwlWaXeRc?HGwfyNDh3<wOCfq-RYA&~Xyg!OX(
z0+&VTnb^pwQIA9Q=}3Gs<2KA-59cZBK<^yZ^hD$Zc?I>28k;mXZ!z3zw2fe9zUvpF
z1=-fl-ocULbkOaPyNBmtny;V#$$-G1;HcBlF|l!H;xArGNli=7xSX4JjhTP_MnU0k
zW#z1j%Bt#`d-od}n;tZCdB3-{cRb^FKJV)59}o_{di_Q;Ha;==eo8EvmVTQ5yzpgl
zX?f-APZtOP|IJ#jf5ZO61z&f`Kp<cU^rs6X6T7~_aEPp)i5$X)3O!C&)5j;nkb4-n
zIZx#cOb^VVPDJ)9sACCZn?L=e{U!VFU>E)`*}q`_c8LHbFlhbrz;M7C5EN1Kz4qTH
z^sXE+pZzdBSpQMqSu3Vw#MSib9+i-U1|0dhJ9LI>oi=tlaKJIDCxGpU%^~JG7Z@vS
zjqw$r^Y&3hhWfug$vG<Qt;6{7t~S5XFToQvYBt6vbhB_2^x5K)wu=<U{x^$Ti=KHi
znK>Mihx6D_dLyEjf((}8RfeayK?@K6IAP1dkFzDQ$4&};vX;)xB9kJ{)Cy;vuQ-%m
z(TWxai@y&a%&LQNMUR=J>8BJOKmHH7z*Pm`o9jI|^f}jTkX<RdH8LDKbT#wL!n7@o
znbMbs9J+98Pq|xDy<KJCr-x2GX^Mt7O4Z#BEcvL;;b6S;`Eif8CC<KDtuyfvQ}Nef
z<}b7CMWV&Im%0e}k~e+j>wvZYCA0xju;hZ%v2-mXY?|pLR!_7HJG{Wv5K6O~%NfzL
zddl1z4+f@kEiwD@=d}tpB5a>ttL`9^SieD{QIOajC62X|*=<3lO#!KU%6G0*lQ@!)
zn@x`ohm~I@4?Pa2yN8|16~}%Jo$I=xoGR-XKl*A1&0UBV8nJ3HHSbz@xF;&!^@b8v
z)UzX*$I4C*m%694_P;oJ7-Og*cbk{<0p77?=O27jYpz)qBPVk2Mn=7LoZDizz-l}W
z<wYfOf#{>#I-2Jvh*CRjTXn4xtIU5UJ6JCK#zb`m?UA+}1^*$>tZcg{t21;mB8ZdJ
zZXiZ%k;QJ?)TNer(PK1ox`!r>b+7R9WeWUs8kv+T_TEZDennVj0ppSVsW=hN&QCKG
zF+6Pev2(>nO1mT3?6<NZNfi<%HoHFIcrxeEs<BO${d<0^Pwi(bYT3>7#*VhGZ<AAZ
z#3VQPmUfvs-Iu3#C=4)N-d6VA*&A7tP?G^sXokfdU@Q4^Uv?QpwT#ylwbGn%S19b`
zXFapcot8Bgzf_$oo989GfBHS7!NQXDyH}3E3uCv@;`~q#%Y@7j^b6(19W+}P<juBE
zRV~pL+V;Zx7Y3!2Z^2>btvMkvGngej-|{<dH-FCQUeGJ0OAaK!XbH-qkMwap`lkF2
z=-GT$?Pj^CYKTX;-41Ne#s*$Xc5I4g9j>)=7I#-93HacWel?!?O+rDrl{qHq6r5sq
z8IKcvSNaG1z7T};)q|Wy1~%`6BfG)Bsd&`m7%PqwHks|^Y=q@yhGo3>#s+A-%@+4~
zphjyAEQQ0DOH(q91cSUafJ5>@l&V(MN!=oQx}&D>q<ujIi)-Q#yp_UEl!IGy{xML?
zE-9q`h2{`_h4XIit}VAAbtb%!Jzt#ZaecX;%<FHoWw3I+4qrEU;FF-*8r4|s6q3(3
zJeo5%II=rS(JQ&j)WvCbB0?)f@U1xduFbM930^c4ourtmSLhSzxbW5eim!n{(?G4o
zurAc~2l^m0^2{{t{MIDeaxFE(zQ_M<4%=09?%vt)Yc?gf=9XMityKt1uvqUdPx|W!
zT*cXtF#eLy-SGUHX=<IO&Mjozk@_88T1hjvswz4n_Z7#~@Jfo*Ys&>e78hWfRil0I
z%6rN2nC|kF^k@jZBnULNeUmnWSGyrmLDHfYPgnzxHe6?i2vQSxWiaM(VCa`s3Gn1F
z1w(#G&5p@<V6^Ssh_)T!P(k9eJTiTJRh-=ROzYHGzpl@qx6gAdHM?#sk#;$-)9PfT
z{0wbG{r+lV#=<hJD+7|XfG(t?a)Rdju-ODxbN##1dtbki22-%^j3N%9XI8~+Vy?XX
z@_2aEegt?}XiZRgm}s%TA*;M*GP&P!pDyvSM%*beRd~DQp6isqNYZM*yRiF*ptpsH
zXp|fwQ>D7lCHg8y`#dv*ng|b8dfw`jW*uq%gg@r_N^^U!SN#D2@#FdFZ$Y@i>hiSU
zM?1NX`rr4Dh&sY1r_nLehgNCqW6R-{Wk*J&hyb{5hE*jyl5aO~l$>t5U?zE_{y@_u
tHUA&&^Cr3bPn<opg&b=U!N|T;xW9<B3*%IA-$e4da^>UhDF$rq&7ay?)0F@K

literal 0
HcmV?d00001

diff --git a/tools/deploy/test_data/0032_c6s1_002851_01.jpg b/tools/deploy/test_data/0032_c6s1_002851_01.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..141ed15f4eb0e96db521c66e5c4a2d268a175493
GIT binary patch
literal 2640
zcmbW!c{CLK8VB$hGmM?FWf|*?5uuE+Z-wk4#yUjSF!p`Rma-*7cukfjWEXFC^=4;O
z6cfTr8KgpEU*qb1@45G!``5kq`<&nJpWkzy^L)<lbn5gIfZa&XP!9kE0sz3X0-VkO
zbO7{pbPzgPdI$u<z(CK)%*DdY#Kg?U!O6xY$S*7;$Pa^wAY>&(#H8Uc*m)&M=}U5G
zG+J0fMMD{-E-Q~l{Tc#fU|?WoV&-LG;YEqUL{a~9PP+h5dO$qjG6*OJpn(EGP~d4F
zK=^D=THs#+{4+or5SW$@LeId+bk@+!4xj;oKr~<wEiD*))*XBH4gf=GIYi~I&~aYB
z2@wnALM7md^l<IQZf=YBRD`@sP$B~(&pBQ`esQG4c}Xd>f})bLimHz8RXu$JLnBL!
zm9>qnoxQ7@yNBm3FYn-x(6I1`$f%?{$#?Iiq^4!(Jj~6@|E+-Vq_nL3X+>pK_3zJ{
znp;}i+B<rB`}zk4he*R?<Kzz$lT#n3KQB-gmzKY*eEqhuxwXBsySIOE_{#+Zfc|Ek
z-G9UW!v#Ha(SX4qFyxmDNE3c`Ku|EPs2m-~mFtk3ft+Hf1bQxQJh8Ex0WNPr<#q{r
z&&Y#7e-_{PMf*$k-@y|9FWJ9f|8`9Sm_fj^&jUdLS^%;WuYLPb_H9<Fh+`|a!DY?|
zUBcgQ#JG-)YPt`{>cy4Qu%L@CmN`{37VARF3>Tt-zVJTBGK{#*@Q4BTLa?`zlIs2b
zw;$tFX0u_=4{$bi9(nH^^%1IB(#qYu5cQ~u@HDh>%A!&RlTc=Gu^NA>0?+-M0;Rwo
zw*w3y-1EokubMit*GD$b8&#*m!@C9|;0ntYLV=E%=^k|HTzTKpgGFC=YALVEjn&?(
z`WED6k9d5tKrX2sKi9i!60bLQeo!|Dw8{O&m#*UHqJOE3h@-t+e8C`5n4&L<(3<AC
ze+rOw!k+@ZkNvEEv(u7wTjJPDrmi`?W~z>71es@JOXh`G1b%hWUMc1r3&EldYJI1D
z<BjF@m{bPUi<73CRs_xH8yombaZTetkj_!M8C|j0l<V_vP$$~{V3oVCX(tP#fHR&P
z9TL`dB(dFuS(_~6pJF?^qx>24Gi27eS4Cm%hGNW@zV(V^23?TKT1IEI^KA}1qW-|X
zh#j9cW_EVVZkL#1?37JwF<T@Big#Y58uD3a*&n*|C}>GXn~k*<naZeJQKDlNh?&uj
z3MM&u=2!gJPa=eiFfJ1dz8Yq{s*(OXQCV8)jC3VYuY%2;j`rg!pW!yxdNtUrQWu{m
z;QTf=o8a_j9oxr^H$>xZ3MQ}a=AT3IeG1B~6JUm7Oq<$Qg;>THV#D7!_jN-<N$U|`
z_X*qPT^oeOV+og*Ft(WWnR-pk@HXh!$7p1mRQ)I{%iTOXy-PhaAd-uW)!mVpY^vNT
z;$QB$rrv)ba8x&S68a`aZDP4?d>!rcWB+&@>oK<^pKt(uFx+vG-KXJ}ePG)|e|+Xt
zsi0#`_l``orI1KlXG_j7!7_gn>jgeaD=@w4?aS|XU8A$V7@n^49zl?>?#Nk8P+Qm7
zShG52DLn<at)wHm3A?KYldSr|Lk>rtTgVqmM@|^ysrvf#VQ6pmT3eye!L!FPYITFp
zz6Lt>&?H`>gc+c<jf}ND&k3?T>2i*m%#AUi7A3DYvK+~hJRkU6!uh$c1a8;2m<Eb(
zaCUcGVb&5kUqIWrW&Og~m8blPIx*aseO^Gad+(=IIS*bgYN|C7t!OO;uHhKy3~S)k
z{LnQMk!btE%PVN};#5h93>nx-T-+7QH{VbvG@SxGjFn4c+{;M5uZ%t)J9{+0cy8q}
zMRj}j2Zi8Wh6xdd2Q4X_h(_G?nTNji%c!Gm_{gh+&mm7<^Z_NS0NfT2KWLzqQEpyj
zmb&D@a_@Q~P7if8V8mhM3(hL{?pt<f7KZx=VV+u&>P|g0^=HgCZiYS{oxMGHFrl0W
zu}tMkva{JtsVgD<R#Qdm$R?IzoSUP(1t0fK{_vEv=2!00D=Ppl8EL`xNux>5MMIvv
z(rZJ<7vibw6_L@z%3`mlp)Q>dKV@kc)Co2)Xa8YYvJhlg|H=zl8?;hT^=TiOZ<1!`
za$mN7wlRghkF{6J=*tPx*3>=R6;8Ag?IOW8rPi942GU&HGTFtFSgi_S36prtaWq9n
zpzpR|t1&h@#ZUB+mTHD`hEVW!<sFLzhslC3hcT7>7L+Av@7en}%2KZWL}RuAWk|hn
z(GtIQExfceTCP8+DP(qxwPw#jm{{qai?FB-Vgq|rueYo`BJJTA^Uy~Q8`z%sbIz#q
zhAkYqc@@O{urX{L+e~f&T>tfA@ikNXh!>L!HV+SXYphj$rggZk_a#FwBd?HZ9)S7!
zYJsa<E*%qt0&?oQ?{q#Nya`|*P5!emvw8V}4?49I4tZSP**`rzOtt62kmkB>=caX<
z#L3dv*o$iLjT-v|l=K)>v#H&iX_a^xHf>!z-BKY~?RrbCusGpru#z02W9fYcP`I#L
zn|g^=3WK>fG9;R2XCM$@W-^KW94qU;rresbrk1h|nyA4}6U`MD!-wj7dv{+8R&@k%
z3nzc!qahfEh?Aa|M$YhhQpS~^4mfPV+@iP3TCzWka41TUGsc%PE$CCqZY&v=$tQ1M
z6!*lY4JfuVU>=$5d84Jv1X<l-bC^d*3Dj#c(CKQZ!lTQXZBYD)!i{;;=t^7q;iQ!d
z5#q0_3d=pzuJpIui;IcG<vM7h0#6i@{XHA}V2$4gqgJ0rQm76SReqf=Jx$3yneAW6
zguZ5J#gt8?*48`P(RcW_Mn<3Jl0*{20PK;<pqHT&0n_YRTzO&Uba3C{3rZ%f#XiKB
z@HrJRea4D!raq<>oXFAryq^6LO)`V<-cY#)gkx{lwUG{iy}4@Web+b7?#OiR6yS<v
zTf$75#H_};yo;)AbNwl)m&+`1qHQgqHD`@GQLe4Og&{1hFU33PJp#u!E@rE3nw3Y*
z;Y(NERtz{dv`Xm(*&rsf$~X`Sd2q2kn`e4dy~?fU_D#|WHlo2<v?e(%N!HPYHhp!9
zGN@fSTosHHwVo2PGsHbm_Iv3aog+LMUuaW*tUcDh(OR?y8(<xmRdMnkX#4(($(DrU
zH59>FmRsrTNsjwtT<x0jIZUV~wUtC$9=Mj#c#|A{%M^8K?YXihc}>eOyN~v+0p-T0
zjOAwSky7L`@JB0mS<y!Z+eGn%yNLEYlT$#{#gk?xUoUy?)-aiQsM`r9pVQ%)+$^?z
XU6s>rF7(Y{xC$ik9Nyyk>BoNnwWG;Q

literal 0
HcmV?d00001

diff --git a/tools/deploy/test_data/0048_c1s1_005351_01.jpg b/tools/deploy/test_data/0048_c1s1_005351_01.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..593caeed98671d64fa1fa961602a9399bdb7823e
GIT binary patch
literal 2149
zcmbW!dpy(s9tZI6?82DK#`>AI#Kv4EjS*5vnT39tYeGhda+}+P3P}@{d-$mp=Eu(^
zW!R!0xhy(L%Uv85m1|A>bP-3t-+7$#c$~k^Ij{HQ^ZDofe!L%_=i{?EvN;7P9JZm^
z03Z+mfVK+Q90#lcDHu!=29=VOl!U{jq!CK82pJg!2Bj#kgi}%7fm6X^cM^6HcdC={
zSgh6_ZPIQsg+fs!n(i|(FyCcJG5E0r1c$>BG6-c^S!Dw?teV09oXs`>DFs{rEWsdk
zKmrK@BSD*;fa=yep`bql{4*d4Fa!#dl!8mkYz^=f00|HnECB&Sp%BQ{bmCSYfFPkL
zHL?{<(UBpk9-(BAlyzGQf1th{O@IE8VCc(ChD&ezNg1P}p-I%z)}a^~?=dkovp#sp
zhH6VY?Bwj?>gImT!|$YjK;WsM;K-=xm{?Za*_4Z^m(nh?)3bAO^YROB6c*hnEi30%
z{90Lc?|uWXv8kCac+}C^)h&G7(<^#0FgP?kGCKC^^_%ILx9{H1&WV>lt*rjB_W8@#
zA1)98{)@GB{{{OG7jnxb0fB%al0RG^iP)_JMna%!WEjfIQIZj%sBVxXrF0<cc6~b>
zZ%F@$_GLbo-bSFj(h&ck{VDtJV9EcN>|d~dyT$+n7_{|yU?gAxJOkr#mB#CBDIU78
ztesz?3+2Bx*3i2uW@|e_OQSXcJ7bRY<%^X>4yM-HMQrqB!hD-QgU!-=S+v3votYdQ
z41L?9{N?rd_tU+r`tn|xuB;lpnG(C`Od3b}{J5Y9_h|Lmkl>c=+^|ZKndj8KYb|?+
z^^C8?hjjIoiRB7Ustdy89iV*o30wPwK2^<RerRU52?1nJis&CqIbWKQyq5a=e7=t6
z0Oh*;Qr)%YHck#x^x<iq%arkN8w2TgD)O~fF<rdeJ$sn%ZmEamCOR-a>)SPn19&ri
z#IF=<l4(z!duC0;a=walL_6yHR7%b&r91NM`Hptdm0eYn2f$gVC;V%Tr<)pi_g-+^
z^>(IfwFxIMM#b@!dd<tZOF8o{3_%dlEw34yPQssC+Aj0J553qz$0HPaWy537Bcd+x
z<4eqDePy%Na&Z@i9rrCCZfC-D$^>~0C$tM=Rbb&}!c<-e%Fy(ANaz&KKb5^6F~n~e
z?9hKwIy#MItOc|Nt=1V6FPT^SZ|y^~Zq`)RaO=Zw1rp_gq!!Pij++IU<&8Dsnx92@
z$)<{V&G~z-$JIR_S*cb{Q|T^daQgg7X^Uq@fut0d1Zv!NZxHIS-EIWOM}@npt@*V4
zrN&$Phm57EX_q>S%u<ZewNsq3aMBTX&@Ykc{m2`RqtBj70x?|8YG-~+RKU@5et$P|
zvQMq8#hUt3_ct7oTb_tK%Do%4J*Yh7_6bfXzmlm#J|oIq5TxTOB23<hTIRcpd_7p~
z_Ch}tV!iE?+lj8R&g|BS<|{!2M3TGUYU%l^iA3(_L-;p<VYoJGu#n(E8WOQzKZy1>
z#k!8o(=Q@}$^KLc9~}nah^;#&V&aBMr`YRDa(l>|jr|XKq#@ao{Ng(AF-t#M!Z*X&
zwE}UC{GwE?eT1cQotv+&FsiuKx7S5_!N+{#V@&NFjtOZ*+|yi07{SeXE<}^|X18We
zI2}#<nOzoupjvqcRRxQJ*Nz92_+{+4_y?Dj;+0LeEHHf!Wgg5AS)c_UUcBM3eCn(-
za@1?~j%d<F_NwxZ)q|e{>7#$^M~>@xZUU)1tr{aR;!9_X+KSOjw3hIRZ{)nk$h*+#
zx1NeUdN*O(LhDP3g4B(l3)Y+l27fblG{(f5yxVu^IA3s;I*$L<aox1mgQr(NKXtch
zt&Dm0wS&>`V@thO63KYBq)}I6#Z20AA>TwoppQM#Q0_~G0`GOxJadJu4oc$#FSfHC
zJ?k~%JgB611W`b7I=L@K!>=>9_-xTloQOMY^f<)Fr&=!3yZ_-f3Cgiha)ZwzdU_L(
zpH^TEJ--5Vx=u5xi+q@>kFp&)s<5IpK+DaVjy3Bphdh%8shXI~1wRnR@B5{IZkT-G
z){S9LaU0Qq*TN*PqN5Jyj68w!GOg9#Q_FoSbxW3oG@Hj__R>sq%p^_NYPHT<fI39O
zljofA%8^Ht`a+b;>)W;Yd+SGD=iq|Au{aqqkE`aa$~gy54l`sr{8GF#u}&$=-srC?
z)9aX%*C+)A5FOBk#MAfH2dO>k8@M1c_e=?rU~VTPAE@-++<b!};S7k|s{(YHaxeA2
zar3nvndmH8&VHi%1s1S6>=X4n727)Y@?rI^Si~@gwb4b`J;`?N$v_h|o2hNDPT6PU
z-5N?XKOAd49)IeJzPQoE@!01$f^(*wAr{zodeB@WY=2UjQij&eu)@&?F68_WN^!gl
zp&=cuW`nsr?^IiiQH{=xk#QS#7ZBnW!<xQN1~aNV_pd})^wXOcoHWV3T#HE~#2Wdy
z<%`~F8_+9id0^%Nnw$sP(KgH_$&KwL)lO%XA!c}s^bS|LUslKMH=2A`qNu5bRxS*)
z7(0`dPn;>}AF_p$5>|H%sKHA;yH3}P)68;qtSjga&G`FPXBgjwB<c#GE!-K5B7WAp
zq#{^>#y~JXpwuro?53>eC4|clh`txz7xpF|a{X!Bu^H~Vw!=jItCPjc9Qn=Be*j8<
By;=YO

literal 0
HcmV?d00001

diff --git a/tools/deploy/test_data/0065_c6s1_009501_02.jpg b/tools/deploy/test_data/0065_c6s1_009501_02.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..e7fb874372399effdb5e55afa8287a01109732ff
GIT binary patch
literal 1890
zcmbW!dpOkj9tZH>+{}ziYYf5}mn4@VG6plnPQ+L<wbo_EpiYvz&=`as>y|Js<(iPn
zxWp8>OKFo7O2Qa&%P>Q%wQiNlWwpD{InQ(cI_JDU&-eS!_w#(7&+GZ_jO<JS(iE~S
z832I*0JK}c&KPhEK!}OK#b5|H9F9aH#3f~=Bqbyy6=l%h$to+UswgXAFlrk5T59Tg
z`!N`8QyslShIl+)Rf}j&z?tbA;c?%FfRIR}q=ck`l#~K)A7&r!f6mSm0EGZzfuj(R
zIsit2ASlqzGeC8BPZ;RW0DlJrhCpFra0F6ZVz;4D8UTYJ5HJ)1gF&IY-O;=602Bq2
z*=J}WhPFQgS7*xN;xca|_FLAr%Q^HfYZ!Ti#3RM!_bMnV9njR$*1_V9O-u>I!^f<w
z$u>XOQXHLVbZ3T(>se1PZy#Sje^zK%_=SkbsDw+2my?oH*jd>*xp~*}uNRbXO3TVC
z?o?LQ)i*RYHMcxyecJJ?^ZARe?w)}`-VlFy<n5?ndPXRkoqIpOu(G<g{^8@Njn7}c
zxj+EqFV^n;7wkV=s9hHr3WY%7-&`Q@h1~-|L1Ft0#bhk(;b)j=bzB@m)-v;UZ98(m
zk;Af_M@YZ8yarxyVC5U_PuYJ5i~ql5|APJ7H3~>VK)as@K>;LSmBYW4@5#QovGAy3
z7B9o@G&QDvWM$}ZJG{%#(o73014T$dY7^GwWHyiI1lU{4KH(AiRBs$CE)3(IzSmY2
zZ|Cp=pw-k1Ra@cwR+3n4UBR?>Y~3<M{5eZR=2e!nh9ArE8a@RRX*n@6D!K$l*Xf|h
z_3E!DhbOgFv~!FoF{zEr5GuAsjm1x<UWX1?yM%sZG~XZnEQt2YoFD8n$PGpx0c9C(
z`Hv4f^w@^=$<>$qdh_jq^r*<r(W-PB%REt3nz#By<ehT$ifa;7d~uVCiAuclC^38{
z3&xBUuD4gEEurJd_Z{Cee2SdCrp(P=%1v~3m}YLO6^J+~#mYT4<`?Oc>C=UQrnY$l
zr{2yfoOWeb2`F^ygF#-`=%LA*>~R`>Y@YV0AOAZuOD-eIY0`Tszp6Y-Gf2=p;!(tH
z%Y_3?*h4dBgSq+L3Q({cp(=mMZ-XsA*x<{FOI|xXyjQyGWk_^D!-`Y+iSNf_LY);A
z%y^F2+rRjBa<r6T*5!;Czh6<V!bG2!(J)3Dg5%yY+h~T8QFaJO+D65FsIV`s?xdJ?
z1_s*(ntYhKY_dwa#HUgd(zegaEnLg9rv|y6F02GT-F13%FbL#FbNWI@-OWf!3MX_4
z>sOn4(%g7TwLY?YO>d2H7A?BZoT6^{zczuA8C!(1V#1<Y7$aB00NCR{zT{a_^2z{}
zxphO8GJ}pQS`FZ?NZef9Fq(Xze+XNd{nNXMLO<ej-O6QJ>pzH?tB7;ijAm`<cQ5Q+
zRTTMp=k(GXF+}qbHkQkkJNjDz6GOTX(HF%%kp1H{V?2>m=5=B^ow%pJWowWBRt<Y`
zv#orB7j+fJ)OIOO2;v)*YC^|m^W2RX1<5jcS)@stns-^-%7LO0nn#j3l@^e?g`cm_
zE5zXhk>&yANTbW4qZf$E9zx?w%vYoc<dgMkcQE|1|InaeZXK1?uE_-heNmH!wW&Ui
zEMh<;4lAYkY{A)5`iPJ0*IV?7&r|%G*tv{%|8&Y(QJ+#`7oH3~l~`ya@0Lm?4^6Cz
zr&>1^Y%6MnZ%v*j)KFK>g<thW_v#ZYL7nF(!Q_VxnI_x6Gux|e!mGn8bBym9>)qX)
zC7QQXJDjj2KaB-Cqu^JjR`p~r4qsRK&<2Y>?%#K7E_z5lvcl9=_m0qq8mSR@&hp0i
za$7G|kdB*kUXf~`zf(W9)~jZw{<Yvv=mQQcRi@{VO<8!cvPFfj=IA&GKi;dh(D8=P
z^5<g_ycK#QjGrFqPMR6e&~3>$_0WANI4KjcPAJk5-h5Fow@H8Z-rm(}B&g;JWzlRu
z;{I|ey5|>cd~|w9x1VRQq8qDlBRitbEPUoNT*J92aLSol)2q+rQdwHZ1xMHKW_t%j
zVY^<JRxkbJRVv|}&T;&?=V|kUl!*qqL)t;nL&HT4%yxq)afs4^yC~qk5%sV82uG~!
z1}eJ+_oT*1XVUFL2)!%{Ifb$I89%R;tv7Z(xi4y7WBGMctB0k&kd5)w6r|AyHr%Wv
zO~wG5^Uteeo*wZ)r)WV2U*R=%ReZuqdINM)9*2;UX=lqHe=peN4wNp75^B}`;y(GO
zMfkzLJj@`C8dH8Cy3F2lw-DKhj+>lZulX@pQZF^(@HWlmkveCouJ+3iOZ?T&+dlwk
CKuZ<?

literal 0
HcmV?d00001

diff --git a/tools/deploy/trt_export.py b/tools/deploy/trt_export.py
new file mode 100644
index 0000000..edc9f1e
--- /dev/null
+++ b/tools/deploy/trt_export.py
@@ -0,0 +1,82 @@
+# encoding: utf-8
+"""
+@author:  xingyu liao
+@contact: sherlockliao01@gmail.com
+"""
+
+import argparse
+import os
+import numpy as np
+import sys
+
+sys.path.append('../../')
+sys.path.append("/export/home/lxy/runtimelib-tensorrt-tiny/build")
+
+import pytrt
+from fastreid.utils.logger import setup_logger
+from fastreid.utils.file_io import PathManager
+
+
+logger = setup_logger(name='trt_export')
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(description="Convert ONNX to TRT model")
+
+    parser.add_argument(
+        "--name",
+        default="baseline",
+        help="name for converted model"
+    )
+    parser.add_argument(
+        "--output",
+        default='outputs/trt_model',
+        help='path to save converted trt model'
+    )
+    parser.add_argument(
+        "--onnx-model",
+        default='outputs/onnx_model/baseline.onnx',
+        help='path to onnx model'
+    )
+    parser.add_argument(
+        "--height",
+        type=int,
+        default=256,
+        help="height of image"
+    )
+    parser.add_argument(
+        "--width",
+        type=int,
+        default=128,
+        help="width of image"
+    )
+    return parser
+
+
+def export_trt_model(onnxModel, engineFile, input_numpy_array):
+    r"""
+    Export a model to trt format.
+    """
+
+    trt = pytrt.Trt()
+
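+    # Rough gloss of CreateEngine's arguments (see the tiny-tensorrt README for the authoritative
+    # description): extra output names (none here), max batch size, build precision `mode`,
+    # and calibration data used only for INT8 builds.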
+    customOutput = []
+    maxBatchSize = 1
+    calibratorData = []
+    mode = 2
+    trt.CreateEngine(onnxModel, engineFile, customOutput, maxBatchSize, mode, calibratorData)
+    trt.DoInference(input_numpy_array)  # slightly different from c++
+    return 0
+
+
+if __name__ == '__main__':
+    args = get_parser().parse_args()
+
+    inputs = np.zeros(shape=(32, args.height, args.width, 3))
+    onnxModel = args.onnx_model
+    engineFile = os.path.join(args.output, args.name+'.engine')
+
+    PathManager.mkdirs(args.output)
+    export_trt_model(onnxModel, engineFile, inputs)
+
+    logger.info(f"Export trt model in {args.output} successfully!")
diff --git a/tools/deploy/trt_inference.py b/tools/deploy/trt_inference.py
new file mode 100644
index 0000000..0775364
--- /dev/null
+++ b/tools/deploy/trt_inference.py
@@ -0,0 +1,99 @@
+# encoding: utf-8
+"""
+@author:  xingyu liao
+@contact: sherlockliao01@gmail.com
+"""
+import argparse
+import glob
+import os
+import sys
+
+import cv2
+import numpy as np
+# import tqdm
+
+sys.path.append("/export/home/lxy/runtimelib-tensorrt-tiny/build")
+
+import pytrt
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(description="trt model inference")
+
+    parser.add_argument(
+        "--model-path",
+        default="outputs/trt_model/baseline.engine",
+        help="trt model path"
+    )
+    parser.add_argument(
+        "--input",
+        nargs="+",
+        help="A list of space separated input images; "
+             "or a single glob pattern such as 'directory/*.jpg'",
+    )
+    parser.add_argument(
+        "--output",
+        default="trt_output",
+        help="path to save trt model inference results"
+    )
+    parser.add_argument(
+        "--output-name",
+        help="tensorRT model output name"
+    )
+    parser.add_argument(
+        "--height",
+        type=int,
+        default=256,
+        help="height of image"
+    )
+    parser.add_argument(
+        "--width",
+        type=int,
+        default=128,
+        help="width of image"
+    )
+    return parser
+
+
+def preprocess(image_path, image_height, image_width):
+    original_image = cv2.imread(image_path)
+    # the model expects RGB inputs
+    original_image = original_image[:, :, ::-1]
+
+    # Apply pre-processing to image.
+    img = cv2.resize(original_image, (image_width, image_height), interpolation=cv2.INTER_CUBIC)
+    img = img.astype("float32").transpose(2, 0, 1)[np.newaxis]  # (1, 3, h, w)
+    return img
+
+
+def normalize(nparray, order=2, axis=-1):
+    """Normalize a N-D numpy array along the specified axis."""
+    norm = np.linalg.norm(nparray, ord=order, axis=axis, keepdims=True)
+    return nparray / (norm + np.finfo(np.float32).eps)
+
+
+if __name__ == "__main__":
+    args = get_parser().parse_args()
+
+    trt = pytrt.Trt()
+
+    onnxModel = ""
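+    # onnxModel is left empty: the serialized engine is loaded straight from engineFile,
+    # so no ONNX parsing/rebuild should happen (assumption based on how trt_export.py uses the same call).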
+    engineFile = args.model_path
+    customOutput = []
+    maxBatchSize = 1
+    calibratorData = []
+    mode = 2
+    trt.CreateEngine(onnxModel, engineFile, customOutput, maxBatchSize, mode, calibratorData)
+
+    if not os.path.exists(args.output): os.makedirs(args.output)
+
+    if args.input:
+        if os.path.isdir(args.input[0]):
+            args.input = glob.glob(os.path.expanduser(args.input[0]))
+            assert args.input, "The input path(s) was not found"
+        for path in args.input:
+            input_numpy_array = preprocess(path, args.height, args.width)
+            trt.DoInference(input_numpy_array)
+            feat = trt.GetOutput(args.output_name)
+            feat = normalize(feat, axis=1)
+            np.save(os.path.join(args.output, path.replace('.jpg', '.npy').split('/')[-1]), feat)