diff --git a/fastreid/modeling/backbones/__init__.py b/fastreid/modeling/backbones/__init__.py
index 59b6872..b7c8c62 100644
--- a/fastreid/modeling/backbones/__init__.py
+++ b/fastreid/modeling/backbones/__init__.py
@@ -4,7 +4,6 @@
 @contact: sherlockliao01@gmail.com
 """
 
-from se_pcb_net import build_senet_pcb_backbone
 from .build import build_backbone, BACKBONE_REGISTRY
 from .mobilenet import build_mobilenetv2_backbone
 from .osnet import build_osnet_backbone
@@ -15,3 +14,4 @@ from .resnet import build_resnet_backbone
 from .resnext import build_resnext_backbone
 from .shufflenet import build_shufflenetv2_backbone
 from .vision_transformer import build_vit_backbone
+from .se_pcb_net import build_senet_pcb_backbone
diff --git a/fastreid/modeling/backbones/se_pcb_net.py b/fastreid/modeling/backbones/se_pcb_net.py
index 67efe23..8427a6d 100644
--- a/fastreid/modeling/backbones/se_pcb_net.py
+++ b/fastreid/modeling/backbones/se_pcb_net.py
@@ -22,13 +22,13 @@ class SePcbNet(nn.Module):
 	             part_num: int,
 	             embedding_dim: int,
 	             part_dim: int,
-	             last_stride: Tuple[int, int]
+	             last_stride: int,
 	             ):
 		super(SePcbNet, self).__init__()
 		self.part_num = part_num
 		self.embedding_dim = embedding_dim
 		self.part_dim = part_dim
-		self.last_stride = last_stride
+		self.last_stride = (last_stride, last_stride)
 
 		self.cnn = pretrainedmodels.__dict__["se_resnext101_32x4d"](pretrained='imagenet')
 		self.cnn.layer4[0].downsample[0].stride = self.last_stride
@@ -40,7 +40,7 @@ class SePcbNet(nn.Module):
 			setattr(self, 'reduction_' + str(i),
 			        nn.Sequential(
 				        nn.Conv2d(self.embedding_dim, self.part_dim, (1, 1), bias=False),
-				        nn.BatchNorm2d(self.part_num),
+				        nn.BatchNorm2d(self.part_dim),
 				        nn.ReLU()
 			        ))
 
@@ -70,7 +70,6 @@ class SePcbNet(nn.Module):
 		}
 
 	def random_init(self):
-		self.conv1.weight.data.normal_(0, math.sqrt(2. / (7 * 7 * 64)))
 		for m in self.modules():
 			if isinstance(m, nn.Conv2d):
 				n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
@@ -82,6 +81,8 @@ class SePcbNet(nn.Module):
 				m.weight.data.fill_(1)
 				m.bias.data.zero_()
 
+		self.cnn.layer0.conv1.weight.data.normal_(0, math.sqrt(2. / (7 * 7 * 64)))
+
 
 @BACKBONE_REGISTRY.register()
 def build_senet_pcb_backbone(cfg: CfgNode):
@@ -99,10 +100,10 @@ def build_senet_pcb_backbone(cfg: CfgNode):
 	if pretrain:
 		if pretrain_path:
 			try:
-				state_dict = torch.load(pretrain_path, map_location=torch.device('cpu'))['model']
+				state_dict = torch.load(pretrain_path, map_location=torch.device('cpu'))
 				new_state_dict = {}
 				for k in state_dict:
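+					# the raw se_resnext101 checkpoint has no 'cnn.' prefix; add it so keys match this wrapper's submodule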
-					new_k = '.'.join(k.split('.')[2:])
+					new_k = 'cnn.' + k
 					if new_k in model.state_dict() and (model.state_dict()[new_k].shape == state_dict[k].shape):
 						new_state_dict[new_k] = state_dict[k]
 				state_dict = new_state_dict
diff --git a/fastreid/modeling/heads/pcb_head.py b/fastreid/modeling/heads/pcb_head.py
index 0b59f85..5c426cc 100644
--- a/fastreid/modeling/heads/pcb_head.py
+++ b/fastreid/modeling/heads/pcb_head.py
@@ -79,37 +79,31 @@ class PcbHead(nn.Module):
         self.reset_parameters()
 
     def forward(self, features, targets=None):
-        full = features['full']
-        parts = features['parts']
-        bsz = full.size(0)
-        
-        # normalize
-        full = self._normalize(full)
-        parts = self._normalize(parts)
+        query_feature = features['query']
+        gallery_feature = features['gallery']
 
-        # split features into pair
-        query_full, gallery_full = self._split_features(full, bsz)
-        query_part_0, gallery_part_0 = self._split_features(parts[0], bsz)
-        query_part_1, gallery_part_1 = self._split_features(parts[1], bsz)
-        query_part_2, gallery_part_2 = self._split_features(parts[2], bsz)
+        query_full, query_part_0, query_part_1, query_part_2 = torch.split(query_feature,
+            [self.full_dim, self.part_dim, self.part_dim, self.part_dim], dim=-1)
+        gallery_full, gallery_part_0, gallery_part_1, gallery_part_2 = torch.split(gallery_feature,
+            [self.full_dim, self.part_dim, self.part_dim, self.part_dim], dim=-1)
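+        # each matcher scores a pair from [query, gallery, |query - gallery|, query * gallery] concatenated on the feature dim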
 
         m_full = self.match_full(
-                    torch.cat([query_full, gallery_full, query_full - gallery_full, 
+                    torch.cat([query_full, gallery_full, (query_full - gallery_full).abs(),
                         query_full * gallery_full], dim=-1)
                 )
         
         m_part_0 = self.match_part_0(
-                    torch.cat([query_part_0, gallery_part_0, query_part_0 - gallery_part_0, 
+                    torch.cat([query_part_0, gallery_part_0, (query_part_0 - gallery_part_0).abs(),
                         query_part_0 * gallery_part_0], dim=-1)
                 )
 
         m_part_1 = self.match_part_1(
-                    torch.cat([query_part_1, gallery_part_1, query_part_1 - gallery_part_1, 
+                    torch.cat([query_part_1, gallery_part_1, (query_part_1 - gallery_part_1).abs(),
                         query_part_1 * gallery_part_1], dim=-1)
                 )
 
         m_part_2 = self.match_part_2(
-                    torch.cat([query_part_2, gallery_part_2, query_part_2 - gallery_part_2, 
+                    torch.cat([query_part_2, gallery_part_2, (query_part_2 - gallery_part_2).abs(),
                         query_part_2 * gallery_part_2], dim=-1)
                 )
 
diff --git a/fastreid/modeling/meta_arch/__init__.py b/fastreid/modeling/meta_arch/__init__.py
index e825462..ed718e3 100644
--- a/fastreid/modeling/meta_arch/__init__.py
+++ b/fastreid/modeling/meta_arch/__init__.py
@@ -14,3 +14,4 @@ from .moco import MoCo
 from .distiller import Distiller
 from .metric import Metric
 from .pcb import PCB
+from .pcb_online import PcbOnline
diff --git a/fastreid/modeling/meta_arch/pcb_online.py b/fastreid/modeling/meta_arch/pcb_online.py
new file mode 100644
index 0000000..becdce6
--- /dev/null
+++ b/fastreid/modeling/meta_arch/pcb_online.py
@@ -0,0 +1,60 @@
+# coding: utf-8
+"""
+Sun, Y., Zheng, L., Yang, Y., Tian, Q., & Wang, S. (2017). Beyond part models: person retrieval with refined part pooling (and a strong convolutional baseline). Springer, Cham.
+Implements PCB exactly as it runs online (in production).
+"""
+import torch
+import torch.nn.functional as F
+
+from fastreid.modeling.losses import cross_entropy_loss, log_accuracy
+from fastreid.modeling.meta_arch import Baseline
+from fastreid.modeling.meta_arch import META_ARCH_REGISTRY
+
+
+@META_ARCH_REGISTRY.register()
+class PcbOnline(Baseline):
+
+    def forward(self, batched_inputs):
+        images = self.preprocess_image(batched_inputs)
+        bsz = int(images.size(0) / 2)
+        feats = self.backbone(images)
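+        # concatenate the global feature with the three part features into a single vector per image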
+        feats = torch.cat((feats['full'], feats['parts'][0], feats['parts'][1], feats['parts'][2]), 1)
+        feats = F.normalize(feats, p=2.0, dim=-1)
+
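+        # the batch interleaves pairs: even indices are query images, odd indices their gallery counterparts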
+        qf = feats[0: bsz * 2: 2, ...]
+        xf = feats[1: bsz * 2: 2, ...]
+        outputs = self.heads({'query': qf, 'gallery': xf})
+
+        if self.training:
+            targets = batched_inputs['targets']
+            losses = self.losses(outputs, targets)
+            return losses
+        else:
+            return outputs
+
+    def losses(self, outputs, gt_labels):
+        """
+        Compute loss from modeling's outputs, the loss function input arguments
+        must be the same as the outputs of the model forwarding.
+        """
+        # model predictions
+        pred_class_logits = outputs['pred_class_logits'].detach()
+        cls_outputs = outputs['cls_outputs']
+        
+
+        log_accuracy(pred_class_logits, gt_labels)
+
+        loss_dict = {}
+        loss_names = self.loss_kwargs['loss_names']
+
+        if 'CrossEntropyLoss' in loss_names:
+            ce_kwargs = self.loss_kwargs.get('ce')
+            loss_dict['loss_cls'] = cross_entropy_loss(
+                cls_outputs,
+                gt_labels,
+                ce_kwargs.get('eps'),
+                ce_kwargs.get('alpha')
+            ) * ce_kwargs.get('scale')
+
+        return loss_dict
diff --git a/projects/FastShoe/configs/online-pcb.yaml b/projects/FastShoe/configs/online-pcb.yaml
index 555572d..266300f 100644
--- a/projects/FastShoe/configs/online-pcb.yaml
+++ b/projects/FastShoe/configs/online-pcb.yaml
@@ -1,8 +1,8 @@
 _BASE_: base.yaml
 
 MODEL:
-  META_ARCHITECTURE: PCB
-  
+  META_ARCHITECTURE: PcbOnline
+
   PCB:
     PART_NUM: 3
     PART_DIM: 512
@@ -14,10 +14,12 @@ MODEL:
       EMBEDDING_DIM: 512
   
   BACKBONE:
-    NAME: build_resnet_backbone
+    PRETRAIN: True
+    PRETRAIN_PATH: /home/apps/.cache/torch/hub/checkpoints/se_resnext101_32x4d-3b2fe3d8.pth
+    NAME: build_senet_pcb_backbone
     DEPTH: 101x
     NORM: BN
-    LAST_STRIDE: 2
+    LAST_STRIDE: 1
     FEAT_DIM: 512
     PRETRAIN: True
     WITH_IBN: True
@@ -46,11 +48,34 @@ INPUT:
     ENABLED: True
     SIZE: [270, 260]
     SCALE: [0.8, 1.2]
-    RATIO: [3./4, 4./3]
+    RATIO: [0.75, 1.33333333]
+
+DATALOADER:
+  NUM_WORKERS: 8
+
+SOLVER:
+  OPT: SGD
+  SCHED: CosineAnnealingLR
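+  # warm up for WARMUP_ITERS iterations, then cosine-anneal from BASE_LR down to ETA_MIN_LR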
+
+  BASE_LR: 0.001
+  MOMENTUM: 0.9
+  NESTEROV: False
+
+  BIAS_LR_FACTOR: 1.
+  WEIGHT_DECAY: 0.0005
+  WEIGHT_DECAY_BIAS: 0.
+  ETA_MIN_LR: 0.00003
+
+  WARMUP_FACTOR: 0.1
+  WARMUP_ITERS: 1000
+
+  IMS_PER_BATCH: 40
+
+TEST:
+  IMS_PER_BATCH: 64
 
 DATASETS:
   NAMES: ("ShoeDataset",)
-  TESTS: ("ShoeDataset", "OnlineDataset")
+  TESTS: ("ShoeDataset",)
 
 OUTPUT_DIR: projects/FastShoe/logs/online-pcb
-
diff --git a/projects/FastShoe/fastshoe/data/pair_dataset.py b/projects/FastShoe/fastshoe/data/pair_dataset.py
index 5a53440..21d31f7 100644
--- a/projects/FastShoe/fastshoe/data/pair_dataset.py
+++ b/projects/FastShoe/fastshoe/data/pair_dataset.py
@@ -28,14 +28,9 @@ class PairDataset(Dataset):
             self._logger.info('set {} with {} random seed: 12345'.format(self.mode, self.__class__.__name__))
             seed_all_rng(12345)
         
-        # if self.mode == 'train':
-        #     # make negative sample come from all negative folders when train
-        #     self.neg_folders = sum(self.neg_folders, list())
-
     def __len__(self):
         if self.mode == 'test':
             return len(self.pos_folders) * 10
-
         return len(self.pos_folders)
 
     def __getitem__(self, idx):
@@ -43,9 +38,6 @@ class PairDataset(Dataset):
             idx = int(idx / 10)
 		
         pf = self.pos_folders[idx]
-        # if self.mode == 'train':
-        #     nf = self.neg_folders
-        # else:
         nf = self.neg_folders[idx]
 
         label = 1