From 12957f66aaf0228871b1f7a8d364bf400dad7ddd Mon Sep 17 00:00:00 2001
From: L1aoXingyu <sherlockliao01@gmail.com>
Date: Tue, 18 Feb 2020 21:01:23 +0800
Subject: [PATCH] Change architecture: 1. delete redundant preprocess 2. add
 data prefetcher to accelerate data loading 3. fix minor bug of triplet
 sampler when only one image for one id

---
 .gitignore                                    |   3 +-
 README.md                                     |  14 +-
 fastreid/config/defaults.py                   |   4 +-
 fastreid/data/build.py                        |  64 +++++---
 fastreid/data/common.py                       |  39 ++++-
 fastreid/data/samplers/__init__.py            |   2 +-
 .../{training_sampler.py => data_sampler.py}  |  27 ++++
 fastreid/data/samplers/triplet_sampler.py     |   2 +-
 fastreid/data/transforms/build.py             |   7 +-
 fastreid/data/transforms/functional.py        | 109 ++++++--------
 fastreid/data/transforms/transforms.py        |  33 +++-
 fastreid/engine/defaults.py                   |  61 +-------
 fastreid/engine/hooks.py                      | 141 +++++++++---------
 fastreid/engine/train_loop.py                 |   9 +-
 fastreid/evaluation/evaluator.py              |  15 +-
 fastreid/evaluation/reid_evaluation.py        |  37 +----
 fastreid/modeling/backbones/resnet.py         |   3 +-
 fastreid/modeling/heads/arcface.py            |   2 +-
 fastreid/modeling/heads/bn_linear.py          |   4 +-
 fastreid/modeling/meta_arch/baseline.py       |  37 ++---
 fastreid/utils/precision_bn.py                |  34 ++---
 projects/AGWBaseline/configs/Base-AGW.yml     |   4 +-
 .../configs/Base-Strongbaseline.yml           |   4 +-
 .../configs/baseline_market1501.yml           |  16 +-
 projects/StrongBaseline/non_linear_head.py    |  78 ++++++++++
 projects/StrongBaseline/train_net.py          |   7 +
 26 files changed, 429 insertions(+), 327 deletions(-)
 rename fastreid/data/samplers/{training_sampler.py => data_sampler.py} (70%)
 create mode 100644 projects/StrongBaseline/non_linear_head.py

diff --git a/.gitignore b/.gitignore
index e567fd7..41d06a5 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,6 +2,7 @@
 __pycache__
 .DS_Store
 .vscode
-csrc/eval_cylib/*.so
+*.so
 logs/
 .ipynb_checkpoints
+logs
\ No newline at end of file
diff --git a/README.md b/README.md
index 0695041..08f871a 100644
--- a/README.md
+++ b/README.md
@@ -3,6 +3,7 @@
 FastReID is a research platform that implements state-of-the-art re-identification algorithms. 
 
 ## Quick Start
+
 The designed architecture follows this guide [PyTorch-Project-Template](https://github.com/L1aoXingyu/PyTorch-Project-Template), you can check each folder's purpose by yourself.
 
 1. `cd` to folder where you want to download this repo
@@ -13,25 +14,30 @@ The designed architecture follows this guide [PyTorch-Project-Template](https://
     - tensorboard
     - [yacs](https://github.com/rbgirshick/yacs)
 4. Prepare dataset
-    Create a directory to store reid datasets under this repo via
+    Create a directory to store reid datasets under projects, for example
+
     ```bash
-    cd fast-reid
+    cd fast-reid/projects/StrongBaseline
     mkdir datasets
     ```
+
     1. Download dataset to `datasets/` from [baidu pan](https://pan.baidu.com/s/1ntIi2Op) or [google driver](https://drive.google.com/file/d/0B8-rUzbwVRk0c054eEozWG9COHM/view)
     2. Extract dataset. The dataset structure would like:
+
     ```bash
     datasets
         Market-1501-v15.09.15
             bounding_box_test/
             bounding_box_train/
     ```
+
 5. Prepare pretrained model.
     If you use origin ResNet, you do not need to do anything. But if you want to use ResNet_ibn, you need to download pretrain model in [here](https://drive.google.com/open?id=1thS2B8UOSBi_cJX6zRy6YYRwz_nVFI_S). And then you can put it in `~/.cache/torch/checkpoints` or anywhere you like.
-    
-    Then you should set the pretrain model path in `configs/softmax_triplet.yml`.
+
+    Then you should set the pretrain model path in `configs/baseline_market1501.yml`.
 
 6. compile with cython to accelerate evalution
+
     ```bash
     cd fastreid/evaluation/rank_cylib; make all
     ```
diff --git a/fastreid/config/defaults.py b/fastreid/config/defaults.py
index fe2f158..8e95828 100644
--- a/fastreid/config/defaults.py
+++ b/fastreid/config/defaults.py
@@ -95,12 +95,12 @@ _C.INPUT.BRIGHTNESS = 0.4
 _C.INPUT.CONTRAST = 0.4
 # Random erasing
 _C.INPUT.RE = CN()
-_C.INPUT.RE.DO = True
+_C.INPUT.RE.ENABLED = True
 _C.INPUT.RE.PROB = 0.5
 _C.INPUT.RE.MEAN = [0.596*255, 0.558*255, 0.497*255]
 # Cutout
 _C.INPUT.CUTOUT = CN()
-_C.INPUT.CUTOUT.DO = False
+_C.INPUT.CUTOUT.ENABLED = False
 _C.INPUT.CUTOUT.PROB = 0.5
 _C.INPUT.CUTOUT.SIZE = 64
 _C.INPUT.CUTOUT.MEAN = [0, 0, 0]
diff --git a/fastreid/data/build.py b/fastreid/data/build.py
index a828054..68b0291 100644
--- a/fastreid/data/build.py
+++ b/fastreid/data/build.py
@@ -6,10 +6,11 @@
 import logging
 
 import torch
+from torch._six import container_abcs, string_classes, int_classes
 from torch.utils.data import DataLoader
 
 from . import samplers
-from .common import ReidDataset
+from .common import CommDataset, data_prefetcher
 from .datasets import DATASET_REGISTRY
 from .transforms import build_transforms
 
@@ -18,13 +19,13 @@ def build_reid_train_loader(cfg):
     train_transforms = build_transforms(cfg, is_train=True)
 
     logger = logging.getLogger(__name__)
-    train_img_items = list()
+    train_items = list()
     for d in cfg.DATASETS.NAMES:
         logger.info('prepare training set {}'.format(d))
         dataset = DATASET_REGISTRY.get(d)()
-        train_img_items.extend(dataset.train)
+        train_items.extend(dataset.train)
 
-    train_set = ReidDataset(train_img_items, train_transforms, relabel=True)
+    train_set = CommDataset(train_items, train_transforms, relabel=True)
 
     num_workers = cfg.DATALOADER.NUM_WORKERS
     batch_size = cfg.SOLVER.IMS_PER_BATCH
@@ -40,37 +41,31 @@ def build_reid_train_loader(cfg):
         train_set,
         num_workers=num_workers,
         batch_sampler=batch_sampler,
-        collate_fn=trivial_batch_collator,
+        collate_fn=fast_batch_collator,
     )
-    return train_loader
+    return data_prefetcher(cfg, train_loader)
 
 
 def build_reid_test_loader(cfg, dataset_name):
-    # tng_tfms = build_transforms(cfg, is_train=True)
     test_transforms = build_transforms(cfg, is_train=False)
 
     logger = logging.getLogger(__name__)
     logger.info('prepare test set {}'.format(dataset_name))
     dataset = DATASET_REGISTRY.get(dataset_name)()
-    query_names, gallery_names = dataset.query, dataset.gallery
-    test_img_items = query_names + gallery_names
+    test_items = dataset.query + dataset.gallery
+
+    test_set = CommDataset(test_items, test_transforms, relabel=False)
 
     num_workers = cfg.DATALOADER.NUM_WORKERS
     batch_size = cfg.TEST.IMS_PER_BATCH
-    # train_img_items = list()
-    # for d in cfg.DATASETS.NAMES:
-    #     dataset = init_dataset(d)
-    #     train_img_items.extend(dataset.train)
-
-    # tng_set = ImageDataset(train_img_items, tng_tfms, relabel=True)
-
-    # tng_set = ReidDataset(query_names + gallery_names, tng_tfms, False)
-    # tng_dataloader = DataLoader(tng_set, cfg.SOLVER.IMS_PER_BATCH, shuffle=True,
-    #                             num_workers=num_workers, collate_fn=fast_collate_fn, pin_memory=True, drop_last=True)
-    test_set = ReidDataset(test_img_items, test_transforms, relabel=False)
-    test_loader = DataLoader(test_set, batch_size, num_workers=num_workers,
-                             collate_fn=trivial_batch_collator, pin_memory=True)
-    return test_loader, len(query_names)
+    data_sampler = samplers.InferenceSampler(len(test_set))
+    batch_sampler = torch.utils.data.BatchSampler(data_sampler, batch_size, False)
+    test_loader = DataLoader(
+        test_set,
+        batch_sampler=batch_sampler,
+        num_workers=num_workers,
+        collate_fn=fast_batch_collator, pin_memory=True)
+    return data_prefetcher(cfg, test_loader), len(dataset.query)
 
 
 def trivial_batch_collator(batch):
@@ -78,3 +73,26 @@ def trivial_batch_collator(batch):
     A batch collator that does nothing.
     """
     return batch
+
+
+def fast_batch_collator(batched_inputs):
+    """
+    A simple batch collator for most common reid tasks
+    """
+
+    elem = batched_inputs[0]
+    if isinstance(elem, torch.Tensor):
+        out = torch.zeros((len(batched_inputs), *elem.size()), dtype=elem.dtype)
+        for i, tensor in enumerate(batched_inputs):
+            out[i] += tensor
+        return out
+
+    elif isinstance(elem, container_abcs.Mapping):
+        return {key: fast_batch_collator([d[key] for d in batched_inputs]) for key in elem}
+
+    elif isinstance(elem, float):
+        return torch.tensor(batched_inputs, dtype=torch.float64)
+    elif isinstance(elem, int_classes):
+        return torch.tensor(batched_inputs)
+    elif isinstance(elem, string_classes):
+        return batched_inputs
diff --git a/fastreid/data/common.py b/fastreid/data/common.py
index 8bd2eba..5bd6f9b 100644
--- a/fastreid/data/common.py
+++ b/fastreid/data/common.py
@@ -4,16 +4,17 @@
 @contact: sherlockliao01@gmail.com
 """
 
+import torch
 from torch.utils.data import Dataset
 
 from .data_utils import read_image
 
 
-class ReidDataset(Dataset):
+class CommDataset(Dataset):
     """Image Person ReID Dataset"""
 
     def __init__(self, img_items, transform=None, relabel=True):
-        self.tfms = transform
+        self.transform = transform
         self.relabel = relabel
 
         self.pid2label = None
@@ -35,8 +36,10 @@ class ReidDataset(Dataset):
     def __getitem__(self, index):
         img_path, pid, camid = self.img_items[index]
         img = read_image(img_path)
-        if self.tfms is not None:   img = self.tfms(img)
-        if self.relabel:            pid = self.pid2label[pid]
+        if self.transform is not None:
+            img = self.transform(img)
+        if self.relabel:
+            pid = self.pid2label[pid]
         return {
             'images': img,
             'targets': pid,
@@ -50,3 +53,31 @@ class ReidDataset(Dataset):
         else:
             prefix = file_path.split('/')[1]
         return prefix + '_' + str(pid)
+
+
+class data_prefetcher():
+    def __init__(self, cfg, loader):
+        self.loader = loader
+        self.loader_iter = iter(loader)
+
+        # normalize
+        assert len(cfg.MODEL.PIXEL_MEAN) == len(cfg.MODEL.PIXEL_STD)
+        num_channels = len(cfg.MODEL.PIXEL_MEAN)
+        self.mean = torch.tensor(cfg.MODEL.PIXEL_MEAN).view(1, num_channels, 1, 1)
+        self.std = torch.tensor(cfg.MODEL.PIXEL_STD).view(1, num_channels, 1, 1)
+
+        self.preload()
+
+    def preload(self):
+        try:
+            self.next_inputs = next(self.loader_iter)
+        except StopIteration:
+            self.next_inputs = None
+            return
+
+        self.next_inputs["images"].sub_(self.mean).div_(self.std)
+
+    def next(self):
+        inputs = self.next_inputs
+        self.preload()
+        return inputs
diff --git a/fastreid/data/samplers/__init__.py b/fastreid/data/samplers/__init__.py
index 3bdce19..556b70d 100644
--- a/fastreid/data/samplers/__init__.py
+++ b/fastreid/data/samplers/__init__.py
@@ -5,4 +5,4 @@
 """
 
 from .triplet_sampler import RandomIdentitySampler
-from .training_sampler import TrainingSampler
+from .data_sampler import TrainingSampler, InferenceSampler
diff --git a/fastreid/data/samplers/training_sampler.py b/fastreid/data/samplers/data_sampler.py
similarity index 70%
rename from fastreid/data/samplers/training_sampler.py
rename to fastreid/data/samplers/data_sampler.py
index dd0f21a..463a192 100644
--- a/fastreid/data/samplers/training_sampler.py
+++ b/fastreid/data/samplers/data_sampler.py
@@ -47,3 +47,30 @@ class TrainingSampler(Sampler):
                 yield from np.random.permutation(self._size)
             else:
                 yield from np.arange(self._size)
+
+
+class InferenceSampler(Sampler):
+    """
+    Produce indices for inference.
+    Inference needs to run on the __exact__ set of samples,
+    therefore when the total number of samples is not divisible by the number of workers,
+    this sampler produces different number of samples on different workers.
+    """
+
+    def __init__(self, size: int):
+        """
+        Args:
+            size (int): the total number of data of the underlying dataset to sample from
+        """
+        self._size = size
+        assert size > 0
+
+        begin = 0
+        end = self._size
+        self._local_indices = range(begin, end)
+
+    def __iter__(self):
+        yield from self._local_indices
+
+    def __len__(self):
+        return len(self._local_indices)
\ No newline at end of file
diff --git a/fastreid/data/samplers/triplet_sampler.py b/fastreid/data/samplers/triplet_sampler.py
index 882e6dd..9962775 100644
--- a/fastreid/data/samplers/triplet_sampler.py
+++ b/fastreid/data/samplers/triplet_sampler.py
@@ -63,7 +63,7 @@ class RandomIdentitySampler(Sampler):
                 select_indexes = No_index(index, i)
                 if not select_indexes:
                     # only one image for this identity
-                    ind_indexes = [i] * (self.num_instances - 1)
+                    ind_indexes = [0] * (self.num_instances - 1)
                 elif len(select_indexes) >= self.num_instances:
                     ind_indexes = np.random.choice(select_indexes, size=self.num_instances - 1, replace=False)
                 else:
diff --git a/fastreid/data/transforms/build.py b/fastreid/data/transforms/build.py
index 748d71f..730c519 100644
--- a/fastreid/data/transforms/build.py
+++ b/fastreid/data/transforms/build.py
@@ -22,10 +22,10 @@ def build_transforms(cfg, is_train=True):
         padding = cfg.INPUT.PADDING
         padding_mode = cfg.INPUT.PADDING_MODE
         # random erasing
-        do_re = cfg.INPUT.RE.DO
+        do_re = cfg.INPUT.RE.ENABLED
         re_prob = cfg.INPUT.RE.PROB
         re_mean = cfg.INPUT.RE.MEAN
-        res.append(T.Resize(size_train))
+        res.append(T.Resize(size_train, interpolation=3))
         if do_flip:
             res.append(T.RandomHorizontalFlip(p=flip_prob))
         if do_pad:
@@ -38,5 +38,6 @@ def build_transforms(cfg, is_train=True):
         #                       mean=cfg.INPUT.CUTOUT.MEAN))
     else:
         size_test = cfg.INPUT.SIZE_TEST
-        res.append(T.Resize(size_test))
+        res.append(T.Resize(size_test, interpolation=3))
+    res.append(ToTensor())
     return T.Compose(res)
diff --git a/fastreid/data/transforms/functional.py b/fastreid/data/transforms/functional.py
index 3becae9..4849e6a 100644
--- a/fastreid/data/transforms/functional.py
+++ b/fastreid/data/transforms/functional.py
@@ -3,69 +3,58 @@
 @author:  liaoxingyu
 @contact: sherlockliao01@gmail.com
 """
-import random
-from PIL import Image
 
-__all__ = ['swap']
+import numpy as np
+import torch
 
 
-def swap(img, crop):
-    def crop_image(image, cropnum):
-        width, high = image.size
-        crop_x = [int((width / cropnum[0]) * i) for i in range(cropnum[0] + 1)]
-        crop_y = [int((high / cropnum[1]) * i) for i in range(cropnum[1] + 1)]
-        im_list = []
-        for j in range(len(crop_y) - 1):
-            for i in range(len(crop_x) - 1):
-                im_list.append(image.crop((crop_x[i], crop_y[j], min(crop_x[i + 1], width), min(crop_y[j + 1], high))))
-        return im_list
+def to_tensor(pic):
+    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
 
-    widthcut, highcut = img.size
-    img = img.crop((10, 10, widthcut - 10, highcut - 10))
-    images = crop_image(img, crop)
-    pro = 5
-    if pro >= 5:
-        tmpx = []
-        tmpy = []
-        count_x = 0
-        count_y = 0
-        k = 1
-        RAN = 2
-        for i in range(crop[1] * crop[0]):
-            tmpx.append(images[i])
-            count_x += 1
-            if len(tmpx) >= k:
-                tmp = tmpx[count_x - RAN:count_x]
-                random.shuffle(tmp)
-                tmpx[count_x - RAN:count_x] = tmp
-            if count_x == crop[0]:
-                tmpy.append(tmpx)
-                count_x = 0
-                count_y += 1
-                tmpx = []
-            if len(tmpy) >= k:
-                tmp2 = tmpy[count_y - RAN:count_y]
-                random.shuffle(tmp2)
-                tmpy[count_y - RAN:count_y] = tmp2
-        random_im = []
-        for line in tmpy:
-            random_im.extend(line)
+    See ``ToTensor`` for more details.
 
-        # random.shuffle(images)
-        width, high = img.size
-        iw = int(width / crop[0])
-        ih = int(high / crop[1])
-        toImage = Image.new('RGB', (iw * crop[0], ih * crop[1]))
-        x = 0
-        y = 0
-        for i in random_im:
-            i = i.resize((iw, ih), Image.ANTIALIAS)
-            toImage.paste(i, (x * iw, y * ih))
-            x += 1
-            if x == crop[0]:
-                x = 0
-                y += 1
+    Args:
+        pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
+
+    Returns:
+        Tensor: Converted image.
+    """
+    if isinstance(pic, np.ndarray):
+        assert len(pic.shape) in (2, 3)
+        # handle numpy array
+        if pic.ndim == 2:
+            pic = pic[:, :, None]
+
+        img = torch.from_numpy(pic.transpose((2, 0, 1)))
+        # backward compatibility
+        if isinstance(img, torch.ByteTensor):
+            return img.float()
+        else:
+            return img
+
+    # handle PIL Image
+    if pic.mode == 'I':
+        img = torch.from_numpy(np.array(pic, np.int32, copy=False))
+    elif pic.mode == 'I;16':
+        img = torch.from_numpy(np.array(pic, np.int16, copy=False))
+    elif pic.mode == 'F':
+        img = torch.from_numpy(np.array(pic, np.float32, copy=False))
+    elif pic.mode == '1':
+        img = 255 * torch.from_numpy(np.array(pic, np.uint8, copy=False))
     else:
-        toImage = img
-    toImage = toImage.resize((widthcut, highcut))
-    return toImage
+        img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
+    # PIL image mode: L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK
+    if pic.mode == 'YCbCr':
+        nchannel = 3
+    elif pic.mode == 'I;16':
+        nchannel = 1
+    else:
+        nchannel = len(pic.mode)
+    img = img.view(pic.size[1], pic.size[0], nchannel)
+    # put it from HWC to CHW format
+    # yikes, this transpose takes 80% of the loading time/CPU
+    img = img.transpose(0, 1).transpose(0, 2).contiguous()
+    if isinstance(img, torch.ByteTensor):
+        return img.float()
+    else:
+        return img
diff --git a/fastreid/data/transforms/transforms.py b/fastreid/data/transforms/transforms.py
index 3b40e87..c413800 100644
--- a/fastreid/data/transforms/transforms.py
+++ b/fastreid/data/transforms/transforms.py
@@ -4,16 +4,41 @@
 @contact: sherlockliao01@gmail.com
 """
 
-__all__ = ['RandomErasing', 'Cutout', 'random_angle_rotate', 'do_color', 'random_shift', 'random_scale']
+__all__ = ['ToTensor', 'RandomErasing', 'Cutout', 'random_angle_rotate',
+           'do_color', 'random_shift', 'random_scale']
 
 import math
 import random
-from PIL import Image
-import cv2
 
+import cv2
 import numpy as np
 
-from .functional import *
+from .functional import to_tensor
+
+
+class ToTensor(object):
+    """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
+
+    Converts a PIL Image or numpy.ndarray (H x W x C) in the range
+    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
+    if the PIL Image belongs to one of the modes (L, LA, P, I, F, RGB, YCbCr, RGBA, CMYK, 1)
+    or if the numpy.ndarray has dtype = np.uint8
+
+    In the other cases, tensors are returned without scaling.
+    """
+
+    def __call__(self, pic):
+        """
+        Args:
+            pic (PIL Image or numpy.ndarray): Image to be converted to tensor.
+
+        Returns:
+            Tensor: Converted image.
+        """
+        return to_tensor(pic)
+
+    def __repr__(self):
+        return self.__class__.__name__ + '()'
 
 
 class RandomErasing(object):
diff --git a/fastreid/engine/defaults.py b/fastreid/engine/defaults.py
index f03e76b..16833e9 100644
--- a/fastreid/engine/defaults.py
+++ b/fastreid/engine/defaults.py
@@ -18,18 +18,9 @@ import torch
 # from fvcore.nn.precise_bn import get_bn_modules
 from torch.nn import DataParallel
 
-from . import hooks
-from .train_loop import SimpleTrainer
-from ..data import (
-    build_reid_test_loader,
-    build_reid_train_loader,
-)
-from ..evaluation import (
-    DatasetEvaluator,
-    inference_on_dataset,
-    print_csv_format,
-    ReidEvaluator,
-)
+from ..data import build_reid_test_loader, build_reid_train_loader
+from ..evaluation import (DatasetEvaluator, ReidEvaluator,
+                          inference_on_dataset, print_csv_format)
 from ..modeling.losses import build_criterion
 from ..modeling.meta_arch import build_model
 from ..solver import build_lr_scheduler, build_optimizer
@@ -38,6 +29,8 @@ from ..utils.checkpoint import Checkpointer
 from ..utils.events import CommonMetricPrinter, JSONWriter, TensorboardXWriter
 from ..utils.file_io import PathManager
 from ..utils.logger import setup_logger
+from . import hooks
+from .train_loop import SimpleTrainer
 
 __all__ = ["default_argument_parser", "default_setup", "DefaultPredictor", "DefaultTrainer"]
 
@@ -147,13 +140,6 @@ class DefaultPredictor:
         checkpointer = Checkpointer(self.model)
         checkpointer.load(cfg.MODEL.WEIGHTS)
 
-        # self.transform_gen = T.Resize(
-        #     [cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST], cfg.INPUT.MAX_SIZE_TEST
-        # )
-
-        self.input_format = cfg.INPUT.FORMAT
-        assert self.input_format in ["RGB", "BGR"], self.input_format
-
     def __call__(self, original_image):
         """
         Args:
@@ -213,20 +199,19 @@ class DefaultTrainer(SimpleTrainer):
         Args:
             cfg (CfgNode):
         """
-        logger = logging.getLogger("fastreid")
+        logger = logging.getLogger("fastreid."+__name__)
         if not logger.isEnabledFor(logging.INFO):  # setup_logger is not called for d2
             setup_logger()
         # Assume these objects must be constructed in this order.
         model = self.build_model(cfg)
         optimizer = self.build_optimizer(cfg, model)
         data_loader = self.build_train_loader(cfg)
-        preprocess_inputs = self.build_preprocess_inputs(cfg)
         criterion = self.build_criterion(cfg)
 
         # For training, wrap with DP. But don't need this for inference.
         model = DataParallel(model)
         model = model.cuda()
-        super().__init__(model, data_loader, optimizer, preprocess_inputs, criterion)
+        super().__init__(model, data_loader, optimizer, criterion)
 
         self.scheduler = self.build_lr_scheduler(cfg, optimizer)
         # Assume no other objects need to be checkpointed.
@@ -341,38 +326,6 @@ class DefaultTrainer(SimpleTrainer):
         #     verify_results(self.cfg, self._last_eval_results)
         #     return self._last_eval_results
 
-    @classmethod
-    def build_preprocess_inputs(cls, cfg):
-        assert len(cfg.MODEL.PIXEL_MEAN) == len(cfg.MODEL.PIXEL_STD)
-        num_channels = len(cfg.MODEL.PIXEL_MEAN)
-        pixel_mean = torch.tensor(cfg.MODEL.PIXEL_MEAN).view(1, num_channels, 1, 1)
-        pixel_std = torch.tensor(cfg.MODEL.PIXEL_STD).view(1, num_channels, 1, 1)
-        normalizer = lambda x: (x - pixel_mean) / pixel_std
-
-        def preprocess_inputs(batched_inputs):
-            # images
-            images = [x["images"] for x in batched_inputs]
-            is_ndarray = isinstance(images[0], np.ndarray)
-            if not is_ndarray:
-                w = images[0].size[0]
-                h = images[0].size[1]
-            else:
-                w = images[0].shape[1]
-                h = images[0].shape[0]
-            tensor = torch.zeros((len(images), 3, h, w), dtype=torch.float32)
-            for i, image in enumerate(images):
-                if not is_ndarray:
-                    image = np.asarray(image, dtype=np.float32)
-                numpy_array = np.rollaxis(image, 2)
-                tensor[i] += torch.from_numpy(numpy_array)
-
-            # labels
-            labels = torch.tensor([x["targets"] for x in batched_inputs]).long()
-
-            return normalizer(tensor), labels
-
-        return preprocess_inputs
-
     @classmethod
     def build_model(cls, cfg):
         """
diff --git a/fastreid/engine/hooks.py b/fastreid/engine/hooks.py
index 51f3ffe..0100116 100644
--- a/fastreid/engine/hooks.py
+++ b/fastreid/engine/hooks.py
@@ -11,11 +11,12 @@ from collections import Counter
 
 import torch
 
+from ..evaluation.testing import flatten_results_dict
 from ..utils import comm
 from ..utils.checkpoint import PeriodicCheckpointer as _PeriodicCheckpointer
 from ..utils.events import EventStorage, EventWriter
-from ..evaluation.testing import flatten_results_dict
 from ..utils.file_io import PathManager
+from ..utils.precision_bn import update_bn_stats, get_bn_modules
 from ..utils.timer import Timer
 from .train_loop import HookBase
 
@@ -27,7 +28,7 @@ __all__ = [
     "LRScheduler",
     "AutogradProfiler",
     "EvalHook",
-    # "PreciseBN",
+    "PreciseBN",
 ]
 
 """
@@ -344,72 +345,70 @@ class EvalHook(HookBase):
         # therefore we clean it to avoid circular reference in the end
         del self._func
 
-# class PreciseBN(HookBase):
-#     """
-#     The standard implementation of BatchNorm uses EMA in inference, which is
-#     sometimes suboptimal.
-#     This class computes the true average of statistics rather than the moving average,
-#     and put true averages to every BN layer in the given model.
-#     It is executed every ``period`` iterations and after the last iteration.
-#     """
-#
-#     def __init__(self, period, model, data_loader, num_iter):
-#         """
-#         Args:
-#             period (int): the period this hook is run, or 0 to not run during training.
-#                 The hook will always run in the end of training.
-#             model (nn.Module): a module whose all BN layers in training mode will be
-#                 updated by precise BN.
-#                 Note that user is responsible for ensuring the BN layers to be
-#                 updated are in training mode when this hook is triggered.
-#             data_loader (iterable): it will produce data to be run by `model(data)`.
-#             num_iter (int): number of iterations used to compute the precise
-#                 statistics.
-#         """
-#         self._logger = logging.getLogger(__name__)
-#         if len(get_bn_modules(model)) == 0:
-#             self._logger.info(
-#                 "PreciseBN is disabled because model does not contain BN layers in training mode."
-#             )
-#             self._disabled = True
-#             return
-#
-#         self._model = model
-#         self._data_loader = data_loader
-#         self._num_iter = num_iter
-#         self._period = period
-#         self._disabled = False
-#
-#         self._data_iter = None
-#
-#     def after_step(self):
-#         next_iter = self.trainer.iter + 1
-#         is_final = next_iter == self.trainer.max_iter
-#         if is_final or (self._period > 0 and next_iter % self._period == 0):
-#             self.update_stats()
-#
-#     def update_stats(self):
-#         """
-#         Update the model with precise statistics. Users can manually call this method.
-#         """
-#         if self._disabled:
-#             return
-#
-#         if self._data_iter is None:
-#             self._data_iter = iter(self._data_loader)
-#
-#         def data_loader():
-#             for num_iter in itertools.count(1):
-#                 if num_iter % 100 == 0:
-#                     self._logger.info(
-#                         "Running precise-BN ... {}/{} iterations.".format(num_iter, self._num_iter)
-#                     )
-#                 # This way we can reuse the same iterator
-#                 yield next(self._data_iter)
-#
-#         with EventStorage():  # capture events in a new storage to discard them
-#             self._logger.info(
-#                 "Running precise-BN for {} iterations...  ".format(self._num_iter)
-#                 + "Note that this could produce different statistics every time."
-#             )
-#             update_bn_stats(self._model, data_loader(), self._num_iter)
+
+class PreciseBN(HookBase):
+    """
+    The standard implementation of BatchNorm uses EMA in inference, which is
+    sometimes suboptimal.
+    This class computes the true average of statistics rather than the moving average,
+    and put true averages to every BN layer in the given model.
+    It is executed after the last iteration.
+    """
+
+    def __init__(self, model, data_loader, num_iter):
+        """
+        Args:
+            model (nn.Module): a module whose all BN layers in training mode will be
+                updated by precise BN.
+                Note that user is responsible for ensuring the BN layers to be
+                updated are in training mode when this hook is triggered.
+            data_loader (iterable): it will produce data to be run by `model(data)`.
+            num_iter (int): number of iterations used to compute the precise
+                statistics.
+        """
+        self._logger = logging.getLogger(__name__)
+        if len(get_bn_modules(model)) == 0:
+            self._logger.info(
+                "PreciseBN is disabled because model does not contain BN layers in training mode."
+            )
+            self._disabled = True
+            return
+
+        self._model = model
+        self._data_loader = data_loader
+        self._num_iter = num_iter
+        self._disabled = False
+
+        self._data_iter = None
+
+    def after_step(self):
+        next_iter = self.trainer.iter + 1
+        is_final = next_iter == self.trainer.max_iter
+        if is_final:
+            self.update_stats()
+
+    def update_stats(self):
+        """
+        Update the model with precise statistics. Users can manually call this method.
+        """
+        if self._disabled:
+            return
+
+        if self._data_iter is None:
+            self._data_iter = self._data_loader
+
+        def data_loader():
+            for num_iter in itertools.count(1):
+                if num_iter % 100 == 0:
+                    self._logger.info(
+                        "Running precise-BN ... {}/{} iterations.".format(num_iter, self._num_iter)
+                    )
+                # This way we can reuse the same iterator
+                yield self._data_iter.next()
+
+        with EventStorage():  # capture events in a new storage to discard them
+            self._logger.info(
+                "Running precise-BN for {} iterations...  ".format(self._num_iter)
+                + "Note that this could produce different statistics every time."
+            )
+            update_bn_stats(self._model, data_loader(), self._num_iter)
diff --git a/fastreid/engine/train_loop.py b/fastreid/engine/train_loop.py
index 78f872b..510b80d 100644
--- a/fastreid/engine/train_loop.py
+++ b/fastreid/engine/train_loop.py
@@ -160,7 +160,7 @@ class SimpleTrainer(TrainerBase):
     or write your own training loop.
     """
 
-    def __init__(self, model, data_loader, optimizer, preprocess_inputs, criterion):
+    def __init__(self, model, data_loader, optimizer, criterion):
         """
         Args:
             model: a torch Module. Takes a data from data_loader and returns a
@@ -180,9 +180,7 @@ class SimpleTrainer(TrainerBase):
 
         self.model = model
         self.data_loader = data_loader
-        self._data_loader_iter = iter(data_loader)
         self.optimizer = optimizer
-        self.preprocess_inputs = preprocess_inputs
         self.criterion = criterion
 
     def run_step(self):
@@ -194,14 +192,13 @@ class SimpleTrainer(TrainerBase):
         """
         If your want to do something with the data, you can wrap the dataloader.
         """
-        data = next(self._data_loader_iter)
+        data = self.data_loader.next()
         data_time = time.perf_counter() - start
 
         """
         If your want to do something with the heads, you can wrap the model.
         """
-        inputs = self.preprocess_inputs(data)
-        outputs = self.model(*inputs)
+        outputs = self.model(data)
         loss_dict = self.criterion(*outputs)
         losses = sum(loss for loss in loss_dict.values())
         self._detect_anomaly(losses, loss_dict)
diff --git a/fastreid/evaluation/evaluator.py b/fastreid/evaluation/evaluator.py
index baf4723..d763877 100644
--- a/fastreid/evaluation/evaluator.py
+++ b/fastreid/evaluation/evaluator.py
@@ -97,28 +97,31 @@ def inference_on_dataset(model, data_loader, evaluator):
     """
     # num_devices = torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
     logger = logging.getLogger(__name__)
-    logger.info("Start inference on {} images".format(len(data_loader.dataset)))
+    logger.info("Start inference on {} images".format(len(data_loader.loader.dataset)))
 
-    total = len(data_loader)  # inference data loader must have a fixed length
+    total = len(data_loader.loader)  # inference data loader must have a fixed length
     evaluator.reset()
 
     num_warmup = min(5, total - 1)
     start_time = time.perf_counter()
     total_compute_time = 0
     with inference_context(model), torch.no_grad():
-        for idx, inputs in enumerate(data_loader):
+        idx = 0
+        inputs = data_loader.next()
+        while inputs is not None:
             if idx == num_warmup:
                 start_time = time.perf_counter()
                 total_compute_time = 0
 
             start_compute_time = time.perf_counter()
-            inputs = evaluator.preprocess_inputs(inputs)
-            outputs = model(*inputs)
+            outputs = model(inputs)
             if torch.cuda.is_available():
                 torch.cuda.synchronize()
             total_compute_time += time.perf_counter() - start_compute_time
-            evaluator.process(*outputs)
+            evaluator.process(outputs)
 
+            idx += 1
+            inputs = data_loader.next()
             # iters_after_start = idx + 1 - num_warmup * int(idx >= num_warmup)
             # seconds_per_img = total_compute_time / iters_after_start
             # if idx >= num_warmup * 2 or seconds_per_img > 30:
diff --git a/fastreid/evaluation/reid_evaluation.py b/fastreid/evaluation/reid_evaluation.py
index 8f473e8..a4cac7b 100644
--- a/fastreid/evaluation/reid_evaluation.py
+++ b/fastreid/evaluation/reid_evaluation.py
@@ -4,12 +4,9 @@
 @contact: sherlockliao01@gmail.com
 """
 import copy
-import logging
 from collections import OrderedDict
 
-import numpy as np
 import torch
-import torch.nn.functional as F
 
 from .evaluator import DatasetEvaluator
 from .rank import evaluate_rank
@@ -18,13 +15,6 @@ from .rank import evaluate_rank
 class ReidEvaluator(DatasetEvaluator):
     def __init__(self, cfg, num_query):
         self._num_query = num_query
-        self._logger = logging.getLogger(__name__)
-
-        assert len(cfg.MODEL.PIXEL_MEAN) == len(cfg.MODEL.PIXEL_STD)
-        num_channels = len(cfg.MODEL.PIXEL_MEAN)
-        pixel_mean = torch.tensor(cfg.MODEL.PIXEL_MEAN).view(1, num_channels, 1, 1)
-        pixel_std = torch.tensor(cfg.MODEL.PIXEL_STD).view(1, num_channels, 1, 1)
-        self.normalizer = lambda x: (x - pixel_mean) / pixel_std
 
         self.features = []
         self.pids = []
@@ -35,31 +25,10 @@ class ReidEvaluator(DatasetEvaluator):
         self.pids = []
         self.camids = []
 
-    def preprocess_inputs(self, inputs):
-        # images
-        images = [x["images"] for x in inputs]
-        is_ndarray = isinstance(images[0], np.ndarray)
-        if not is_ndarray:
-            w = images[0].size[0]
-            h = images[0].size[1]
-        else:
-            w = images[0].shape[1]
-            h = images[0].shpae[0]
-        tensor = torch.zeros((len(images), 3, h, w), dtype=torch.float32)
-        for i, image in enumerate(images):
-            if not is_ndarray:
-                image = np.asarray(image, dtype=np.float32)
-            numpy_array = np.rollaxis(image, 2)
-            tensor[i] += torch.from_numpy(numpy_array)
-
-        # labels
-        for input in inputs:
-            self.pids.append(input['targets'])
-            self.camids.append(input['camid'])
-        return self.normalizer(tensor),
-
     def process(self, outputs):
-        self.features.append(outputs.cpu())
+        self.features.append(outputs[0].cpu())
+        self.pids.extend(outputs[1].cpu().numpy())
+        self.camids.extend(outputs[2].cpu().numpy())
 
     def evaluate(self):
         features = torch.cat(self.features, dim=0)
diff --git a/fastreid/modeling/backbones/resnet.py b/fastreid/modeling/backbones/resnet.py
index 903775b..bd9fcef 100644
--- a/fastreid/modeling/backbones/resnet.py
+++ b/fastreid/modeling/backbones/resnet.py
@@ -186,5 +186,6 @@ def build_resnet_backbone(cfg):
             state_dict = new_state_dict
         res = model.load_state_dict(state_dict, strict=False)
         logger = logging.getLogger(__name__)
-        logger.info('missing keys is {} and unexpected keys is {}'.format(res.missing_keys, res.unexpected_keys))
+        logger.info('missing keys is {}'.format(res.missing_keys))
+        logger.info('unexpected keys is {}'.format(res.unexpected_keys))
     return model
diff --git a/fastreid/modeling/heads/arcface.py b/fastreid/modeling/heads/arcface.py
index ad3460e..55d3281 100644
--- a/fastreid/modeling/heads/arcface.py
+++ b/fastreid/modeling/heads/arcface.py
@@ -50,7 +50,7 @@ class ArcFace(nn.Module):
         bn_features = self.bnneck(global_features)
 
         if not self.training:
-            return F.normalize(bn_features),
+            return F.normalize(bn_features)
 
         cosine = F.linear(F.normalize(bn_features), F.normalize(self.weight))
         sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1))
diff --git a/fastreid/modeling/heads/bn_linear.py b/fastreid/modeling/heads/bn_linear.py
index af03062..414d7df 100644
--- a/fastreid/modeling/heads/bn_linear.py
+++ b/fastreid/modeling/heads/bn_linear.py
@@ -35,7 +35,7 @@ class BNneckLinear(nn.Module):
         bn_features = self.bnneck(global_features)
 
         if not self.training:
-            return F.normalize(bn_features),
+            return F.normalize(bn_features)
 
         pred_class_logits = self.classifier(bn_features)
-        return pred_class_logits, global_features, targets,
+        return pred_class_logits, global_features, targets
diff --git a/fastreid/modeling/meta_arch/baseline.py b/fastreid/modeling/meta_arch/baseline.py
index d778ff2..5a29ba8 100644
--- a/fastreid/modeling/meta_arch/baseline.py
+++ b/fastreid/modeling/meta_arch/baseline.py
@@ -4,13 +4,11 @@
 @contact: sherlockliao01@gmail.com
 """
 
-import torch
 from torch import nn
 
 from .build import META_ARCH_REGISTRY
 from ..backbones import build_backbone
 from ..heads import build_reid_heads
-from ...layers import Lambda
 
 
 @META_ARCH_REGISTRY.register()
@@ -20,26 +18,19 @@ class Baseline(nn.Module):
         self.backbone = build_backbone(cfg)
         self.heads = build_reid_heads(cfg)
 
-    def forward(self, inputs, labels=None):
-        global_feat = self.backbone(inputs)  # (bs, 2048, 16, 8)
-        outputs = self.heads(global_feat, labels)
+    def forward(self, inputs):
+        if not self.training:
+            return self.inference(inputs)
+
+        images = inputs["images"]
+        targets = inputs["targets"]
+        global_feat = self.backbone(images)  # (bs, 2048, 16, 8)
+        outputs = self.heads(global_feat, targets)
         return outputs
 
-    # def unfreeze_all_layers(self, ):
-    #     self.train()
-    #     for p in self.parameters():
-    #         p.requires_grad_()
-    #
-    # def unfreeze_specific_layer(self, names):
-    #     if isinstance(names, str):
-    #         names = [names]
-    #
-    #     for name, module in self.named_children():
-    #         if name in names:
-    #             module.train()
-    #             for p in module.parameters():
-    #                 p.requires_grad_()
-    #         else:
-    #             module.eval()
-    #             for p in module.parameters():
-    #                 p.requires_grad_(False)
+    def inference(self, inputs):
+        assert not self.training
+        images = inputs["images"]
+        global_feat = self.backbone(images)
+        pred_features = self.heads(global_feat)
+        return pred_features, inputs["targets"], inputs["camid"]
diff --git a/fastreid/utils/precision_bn.py b/fastreid/utils/precision_bn.py
index d87270b..9c3727b 100644
--- a/fastreid/utils/precision_bn.py
+++ b/fastreid/utils/precision_bn.py
@@ -5,8 +5,9 @@
 """
 
 import itertools
+
 import torch
-from data.prefetcher import data_prefetcher
+
 
 BN_MODULE_TYPES = (
     torch.nn.BatchNorm1d,
@@ -57,26 +58,19 @@ def update_bn_stats(model, data_loader, num_iters: int = 200):
     running_mean = [torch.zeros_like(bn.running_mean) for bn in bn_layers]
     running_var = [torch.zeros_like(bn.running_var) for bn in bn_layers]
 
-    ind = 0
-    num_epoch = num_iters // len(data_loader) + 1
-    for _ in range(num_epoch):
-        prefetcher = data_prefetcher(data_loader)
-        batch = prefetcher.next()
-        while batch[0] is not None:
-            model(batch[0], batch[1])
+    for ind, inputs in enumerate(itertools.islice(data_loader, num_iters)):
+        with torch.no_grad():  # No need to backward
+            model(inputs)
 
-            for i, bn in enumerate(bn_layers):
-                # Accumulates the bn stats.
-                running_mean[i] += (bn.running_mean - running_mean[i]) / (ind + 1)
-                running_var[i] += (bn.running_var - running_var[i]) / (ind + 1)
-                # We compute the "average of variance" across iterations.
-
-            if ind == (num_iters - 1):
-                print(f"update_bn_stats is running for {num_iters} iterations.")
-                break
-
-            ind += 1
-            batch = prefetcher.next()
+        for i, bn in enumerate(bn_layers):
+            # Accumulates the bn stats.
+            running_mean[i] += (bn.running_mean - running_mean[i]) / (ind + 1)
+            running_var[i] += (bn.running_var - running_var[i]) / (ind + 1)
+            # We compute the "average of variance" across iterations.
+    assert ind == num_iters - 1, (
+        "update_bn_stats is meant to run for {} iterations, "
+        "but the dataloader stops at {} iterations.".format(num_iters, ind)
+    )
 
     for i, bn in enumerate(bn_layers):
         # Sets the precise bn stats.
diff --git a/projects/AGWBaseline/configs/Base-AGW.yml b/projects/AGWBaseline/configs/Base-AGW.yml
index c25600c..2f8e906 100644
--- a/projects/AGWBaseline/configs/Base-AGW.yml
+++ b/projects/AGWBaseline/configs/Base-AGW.yml
@@ -28,10 +28,10 @@ INPUT:
   SIZE_TRAIN: [256, 128]
   SIZE_TEST: [256, 128]
   RE:
-    DO: True
+    ENABLED: True
     PROB: 0.5
   CUTOUT:
-    DO: False
+    ENABLED: False
   DO_PAD: True
 
   DO_LIGHTING: False
diff --git a/projects/StrongBaseline/configs/Base-Strongbaseline.yml b/projects/StrongBaseline/configs/Base-Strongbaseline.yml
index cbbc15a..de333c4 100644
--- a/projects/StrongBaseline/configs/Base-Strongbaseline.yml
+++ b/projects/StrongBaseline/configs/Base-Strongbaseline.yml
@@ -28,10 +28,10 @@ INPUT:
   SIZE_TRAIN: [256, 128]
   SIZE_TEST: [256, 128]
   RE:
-    DO: True
+    ENABLED: True
     PROB: 0.5
   CUTOUT:
-    DO: False
+    ENABLED: False
   DO_PAD: True
 
   DO_LIGHTING: False
diff --git a/projects/StrongBaseline/configs/baseline_market1501.yml b/projects/StrongBaseline/configs/baseline_market1501.yml
index 5a0534c..38d3c46 100644
--- a/projects/StrongBaseline/configs/baseline_market1501.yml
+++ b/projects/StrongBaseline/configs/baseline_market1501.yml
@@ -2,12 +2,24 @@ _BASE_: "Base-Strongbaseline.yml"
 
 MODEL:
   BACKBONE:
-    PRETRAIN: False
+    PRETRAIN: True
+
   HEADS:
+    NAME: "BNneckLinear"
     NUM_CLASSES: 751
 
+  LOSSES:
+    NAME: ("CrossEntropyLoss", "TripletLoss")
+    SMOOTH_ON: True
+    SCALE_CE: 1.0
+
+    MARGIN: 0.0
+    SCALE_TRI: 1.0
+
+
 DATASETS:
   NAMES: ("Market1501",)
   TESTS: ("Market1501",)
 
-OUTPUT_DIR: "logs/fastreid_market1501/softmax_softmargin_wo_pretrain"
+
+OUTPUT_DIR: "logs/market1501/test"
diff --git a/projects/StrongBaseline/non_linear_head.py b/projects/StrongBaseline/non_linear_head.py
new file mode 100644
index 0000000..4dc2ee9
--- /dev/null
+++ b/projects/StrongBaseline/non_linear_head.py
@@ -0,0 +1,78 @@
+# encoding: utf-8
+"""
+@author:  l1aoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+import math
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from torch.nn import Parameter
+
+from fastreid.modeling.heads import REID_HEADS_REGISTRY
+from fastreid.modeling.model_utils import weights_init_classifier, weights_init_kaiming
+
+
+@REID_HEADS_REGISTRY.register()
+class NonLinear(nn.Module):
+    def __init__(self, cfg):
+        super().__init__()
+        self._num_classes = cfg.MODEL.HEADS.NUM_CLASSES
+        self.gap = nn.AdaptiveAvgPool2d(1)
+
+        self.fc1 = nn.Linear(2048, 1024, bias=False)
+        self.bn1 = nn.BatchNorm1d(1024)
+        # self.bn1.bias.requires_grad_(False)
+        self.relu = nn.ReLU(True)
+        self.fc2 = nn.Linear(1024, 512, bias=False)
+        self.bn2 = nn.BatchNorm1d(512)
+        self.bn2.bias.requires_grad_(False)
+
+        self._m = 0.50
+        self._s = 30.0
+        self._in_features = 512
+        self.cos_m = math.cos(self._m)
+        self.sin_m = math.sin(self._m)
+
+        self.th = math.cos(math.pi - self._m)
+        self.mm = math.sin(math.pi - self._m) * self._m
+
+        self.weight = Parameter(torch.Tensor(self._num_classes, self._in_features))
+
+        self.init_parameters()
+
+    def init_parameters(self):
+        self.fc1.apply(weights_init_kaiming)
+        self.bn1.apply(weights_init_kaiming)
+        self.fc2.apply(weights_init_kaiming)
+        self.bn2.apply(weights_init_kaiming)
+        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
+
+    def forward(self, features, targets=None):
+        global_features = self.gap(features)
+        global_features = global_features.view(global_features.shape[0], -1)
+
+        if not self.training:
+            return F.normalize(global_features)
+
+        fc_features = self.fc1(global_features)
+        fc_features = self.bn1(fc_features)
+        fc_features = self.relu(fc_features)
+        fc_features = self.fc2(fc_features)
+        fc_features = self.bn2(fc_features)
+
+        cosine = F.linear(F.normalize(fc_features), F.normalize(self.weight))
+        sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1))
+        phi = cosine * self.cos_m - sine * self.sin_m
+        phi = torch.where(cosine > self.th, phi, cosine - self.mm)
+        # --------------------------- convert label to one-hot ---------------------------
+        # one_hot = torch.zeros(cosine.size(), requires_grad=True, device='cuda')
+        one_hot = torch.zeros(cosine.size(), device='cuda')
+        one_hot.scatter_(1, targets.view(-1, 1).long(), 1)
+        # -------------torch.where(out_i = {x_i if condition_i else y_i) -------------
+        pred_class_logits = (one_hot * phi) + (
+                    (1.0 - one_hot) * cosine)  # you can use torch.where if your torch.__version__ is 0.4
+        pred_class_logits *= self._s
+        return pred_class_logits, global_features, targets
diff --git a/projects/StrongBaseline/train_net.py b/projects/StrongBaseline/train_net.py
index 4014751..3d83408 100644
--- a/projects/StrongBaseline/train_net.py
+++ b/projects/StrongBaseline/train_net.py
@@ -11,6 +11,8 @@ from fastreid.config import get_cfg
 from fastreid.engine import DefaultTrainer, default_argument_parser, default_setup
 from fastreid.utils.checkpoint import Checkpointer
 
+from non_linear_head import NonLinear
+
 
 def setup(args):
     """
@@ -36,6 +38,11 @@ def main(args):
         return res
 
     trainer = DefaultTrainer(cfg)
+    # moco pretrain
+    # import torch
+    # state_dict = torch.load('logs/model_0109999.pth')['model_ema']
+    # ret = trainer.model.module.load_state_dict(state_dict, strict=False)
+    #
     trainer.resume_or_load(resume=args.resume)
     return trainer.train()