From c3ac4f504cba74ea8d300a26ec9296aca956c001 Mon Sep 17 00:00:00 2001
From: liaoxingyu <sherlockliao01@gmail.com>
Date: Mon, 31 May 2021 17:30:43 +0800
Subject: [PATCH] Support amp and resume training in fastface

AMP in partial-fc needs to be done only on backbone; In order to impl `resume training`, need to save & load different part of classifier weight in each GPU.
---
 projects/FastFace/configs/face_base.yml       |  13 +-
 projects/FastFace/configs/r50_ir.yml          |   5 +-
 projects/FastFace/fastface/__init__.py        |   1 +
 projects/FastFace/fastface/config.py          |   2 +
 projects/FastFace/fastface/datasets/ms1mv2.py |   2 +-
 .../FastFace/fastface/modeling/__init__.py    |   2 +-
 .../fastface/modeling/face_baseline.py        |  29 ++-
 .../FastFace/fastface/modeling/face_head.py   |   8 +-
 .../FastFace/fastface/modeling/iresnet.py     | 179 ++++++++++++++++++
 .../FastFace/fastface/modeling/partial_fc.py  |  17 --
 .../FastFace/fastface/modeling/resnet_ir.py   | 122 ------------
 .../FastFace/fastface/pfc_checkpointer.py     |  84 ++++++++
 projects/FastFace/fastface/trainer.py         |  59 ++++--
 projects/FastFace/fastface/utils_amp.py       |  86 +++++++++
 projects/FastFace/train_net.py                |   1 -
 15 files changed, 432 insertions(+), 178 deletions(-)
 create mode 100644 projects/FastFace/fastface/modeling/iresnet.py
 delete mode 100644 projects/FastFace/fastface/modeling/resnet_ir.py
 create mode 100644 projects/FastFace/fastface/pfc_checkpointer.py
 create mode 100644 projects/FastFace/fastface/utils_amp.py

diff --git a/projects/FastFace/configs/face_base.yml b/projects/FastFace/configs/face_base.yml
index f125cf0..4bbec67 100644
--- a/projects/FastFace/configs/face_base.yml
+++ b/projects/FastFace/configs/face_base.yml
@@ -4,6 +4,9 @@ MODEL:
   PIXEL_MEAN: [127.5, 127.5, 127.5]
   PIXEL_STD: [127.5, 127.5, 127.5]
 
+  BACKBONE:
+    NAME: build_iresnet_backbone
+
   HEADS:
     NAME: FaceHead
     WITH_BNNECK: True
@@ -30,7 +33,7 @@ MODEL:
 DATASETS:
   REC_PATH: /export/home/DATA/Glint360k/train.rec
   NAMES: ("MS1MV2",)
-  TESTS: ("CPLFW", "VGG2_FP", "CALFW", "CFP_FF", "CFP_FP", "AgeDB_30", "LFW")
+  TESTS: ("CFP_FP", "AgeDB_30", "LFW")
 
 INPUT:
   SIZE_TRAIN: [0,]  # No need of resize
@@ -47,10 +50,10 @@ DATALOADER:
 SOLVER:
   MAX_EPOCH: 20
   AMP:
-    ENABLED: False
+    ENABLED: True
 
   OPT: SGD
-  BASE_LR: 0.1
+  BASE_LR: 0.05
   MOMENTUM: 0.9
 
   SCHED: MultiStepLR
@@ -59,10 +62,10 @@ SOLVER:
   BIAS_LR_FACTOR: 1.
   WEIGHT_DECAY: 0.0005
   WEIGHT_DECAY_BIAS: 0.0005
-  IMS_PER_BATCH: 512
+  IMS_PER_BATCH: 256
 
   WARMUP_FACTOR: 0.1
-  WARMUP_ITERS: 5000
+  WARMUP_ITERS: 0
 
   CHECKPOINT_PERIOD: 1
 
diff --git a/projects/FastFace/configs/r50_ir.yml b/projects/FastFace/configs/r50_ir.yml
index d18bc59..7340754 100644
--- a/projects/FastFace/configs/r50_ir.yml
+++ b/projects/FastFace/configs/r50_ir.yml
@@ -3,13 +3,12 @@ _BASE_: face_base.yml
 MODEL:
 
   BACKBONE:
-    NAME: build_resnetIR_backbone
     DEPTH: 50x
     FEAT_DIM: 25088 # 512x7x7
-    WITH_SE: True
+    DROPOUT: 0.
 
   HEADS:
     PFC:
       ENABLED: True
 
-OUTPUT_DIR: projects/FastFace/logs/ir_se50-glink360k-pfc0.1
+OUTPUT_DIR: projects/FastFace/logs/pfc0.1_insightface
diff --git a/projects/FastFace/fastface/__init__.py b/projects/FastFace/fastface/__init__.py
index 10c6a39..5b2ee35 100644
--- a/projects/FastFace/fastface/__init__.py
+++ b/projects/FastFace/fastface/__init__.py
@@ -7,3 +7,4 @@
 from .modeling import *
 from .config import add_face_cfg
 from .trainer import FaceTrainer
+from .datasets import *
diff --git a/projects/FastFace/fastface/config.py b/projects/FastFace/fastface/config.py
index f47fa9f..af9e65e 100644
--- a/projects/FastFace/fastface/config.py
+++ b/projects/FastFace/fastface/config.py
@@ -12,5 +12,7 @@ def add_face_cfg(cfg):
 
     _C.DATASETS.REC_PATH = ""
 
+    _C.MODEL.BACKBONE.DROPOUT = 0.
+
     _C.MODEL.HEADS.PFC = CN({"ENABLED": False})
     _C.MODEL.HEADS.PFC.SAMPLE_RATE = 0.1
diff --git a/projects/FastFace/fastface/datasets/ms1mv2.py b/projects/FastFace/fastface/datasets/ms1mv2.py
index b19d4f8..c633a47 100644
--- a/projects/FastFace/fastface/datasets/ms1mv2.py
+++ b/projects/FastFace/fastface/datasets/ms1mv2.py
@@ -23,7 +23,7 @@ class MS1MV2(ImageDataset):
         required_files = [self.dataset_dir]
         self.check_before_run(required_files)
 
-        train = self.process_dirs()
+        train = self.process_dirs()[:10000]
         super().__init__(train, [], [], **kwargs)
 
     def process_dirs(self):
diff --git a/projects/FastFace/fastface/modeling/__init__.py b/projects/FastFace/fastface/modeling/__init__.py
index 7897014..4cf69a0 100644
--- a/projects/FastFace/fastface/modeling/__init__.py
+++ b/projects/FastFace/fastface/modeling/__init__.py
@@ -7,4 +7,4 @@
 from .partial_fc import PartialFC
 from .face_baseline import FaceBaseline
 from .face_head import FaceHead
-from .resnet_ir import build_resnetIR_backbone
+from .iresnet import build_iresnet_backbone
diff --git a/projects/FastFace/fastface/modeling/face_baseline.py b/projects/FastFace/fastface/modeling/face_baseline.py
index 9cf6ec7..ad0b24e 100644
--- a/projects/FastFace/fastface/modeling/face_baseline.py
+++ b/projects/FastFace/fastface/modeling/face_baseline.py
@@ -4,6 +4,7 @@
 @contact: sherlockliao01@gmail.com
 """
 
+import torch
 from fastreid.modeling.meta_arch import Baseline
 from fastreid.modeling.meta_arch import META_ARCH_REGISTRY
 
@@ -13,12 +14,28 @@ class FaceBaseline(Baseline):
     def __init__(self, cfg):
         super().__init__(cfg)
         self.pfc_enabled = cfg.MODEL.HEADS.PFC.ENABLED
+        self.amp_enabled = cfg.SOLVER.AMP.ENABLED
 
-    def losses(self, outputs, gt_labels):
+    def forward(self, batched_inputs):
         if not self.pfc_enabled:
-            return super().losses(outputs, gt_labels)
+            return super().forward(batched_inputs)
+
+        images = self.preprocess_image(batched_inputs)
+        with torch.cuda.amp.autocast(self.amp_enabled):
+            features = self.backbone(images)
+        features = features.float() if self.amp_enabled else features
+
+        if self.training:
+            assert "targets" in batched_inputs, "Person ID annotation are missing in training!"
+            targets = batched_inputs["targets"]
+
+            # PreciseBN flag, When do preciseBN on different dataset, the number of classes in new dataset
+            # may be larger than that in the original dataset, so the circle/arcface will
+            # throw an error. We just set all the targets to 0 to avoid this problem.
+            if targets.sum() < 0: targets.zero_()
+
+            outputs = self.heads(features, targets)
+            return outputs, targets
         else:
-            # model parallel with partial-fc
-            # cls layer and loss computation in partial_fc.py
-            pred_features = outputs["features"]
-            return pred_features, gt_labels
+            outputs = self.heads(features)
+            return outputs
diff --git a/projects/FastFace/fastface/modeling/face_head.py b/projects/FastFace/fastface/modeling/face_head.py
index 0168b84..7583a0b 100644
--- a/projects/FastFace/fastface/modeling/face_head.py
+++ b/projects/FastFace/fastface/modeling/face_head.py
@@ -30,10 +30,4 @@ class FaceHead(EmbeddingHead):
             pool_feat = self.pool_layer(features)
             neck_feat = self.bottleneck(pool_feat)
             neck_feat = neck_feat[..., 0, 0]
-
-            if not self.training:
-                return neck_feat
-
-            return {
-                "features": neck_feat,
-            }
+            return neck_feat
diff --git a/projects/FastFace/fastface/modeling/iresnet.py b/projects/FastFace/fastface/modeling/iresnet.py
new file mode 100644
index 0000000..da2d593
--- /dev/null
+++ b/projects/FastFace/fastface/modeling/iresnet.py
@@ -0,0 +1,179 @@
+# encoding: utf-8
+"""
+@author:  xingyu liao
+@contact: sherlockliao01@gmail.com
+"""
+
+import torch
+from torch import nn
+
+from fastreid.layers import get_norm
+from fastreid.modeling.backbones import BACKBONE_REGISTRY
+
+
+def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
+    """3x3 convolution with padding"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+                     padding=dilation, groups=groups, bias=False, dilation=dilation)
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+    """1x1 convolution"""
+    return nn.Conv2d(in_planes,
+                     out_planes,
+                     kernel_size=1,
+                     stride=stride,
+                     bias=False)
+
+
+class IBasicBlock(nn.Module):
+    expansion = 1
+
+    def __init__(self, inplanes, planes, bn_norm, stride=1, downsample=None,
+                 groups=1, base_width=64, dilation=1):
+        super().__init__()
+        if groups != 1 or base_width != 64:
+            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
+        if dilation > 1:
+            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
+        self.bn1 = get_norm(bn_norm, inplanes)
+        self.conv1 = conv3x3(inplanes, planes)
+        self.bn2 = get_norm(bn_norm, planes)
+        self.prelu = nn.PReLU(planes)
+        self.conv2 = conv3x3(planes, planes, stride)
+        self.bn3 = get_norm(bn_norm, planes)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        identity = x
+        out = self.bn1(x)
+        out = self.conv1(out)
+        out = self.bn2(out)
+        out = self.prelu(out)
+        out = self.conv2(out)
+        out = self.bn3(out)
+        if self.downsample is not None:
+            identity = self.downsample(x)
+        out += identity
+        return out
+
+
+class IResNet(nn.Module):
+    fc_scale = 7 * 7
+
+    def __init__(self, block, layers, bn_norm, dropout=0, zero_init_residual=False,
+                 groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False):
+        super().__init__()
+        self.inplanes = 64
+        self.dilation = 1
+        self.fp16 = fp16
+        if replace_stride_with_dilation is None:
+            replace_stride_with_dilation = [False, False, False]
+        if len(replace_stride_with_dilation) != 3:
+            raise ValueError("replace_stride_with_dilation should be None "
+                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
+        self.groups = groups
+        self.base_width = width_per_group
+        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
+        self.bn1 = get_norm(bn_norm, self.inplanes)
+        self.prelu = nn.PReLU(self.inplanes)
+        self.layer1 = self._make_layer(block, 64, layers[0], bn_norm, stride=2)
+        self.layer2 = self._make_layer(block,
+                                       128,
+                                       layers[1],
+                                       bn_norm,
+                                       stride=2,
+                                       dilate=replace_stride_with_dilation[0])
+        self.layer3 = self._make_layer(block,
+                                       256,
+                                       layers[2],
+                                       bn_norm,
+                                       stride=2,
+                                       dilate=replace_stride_with_dilation[1])
+        self.layer4 = self._make_layer(block,
+                                       512,
+                                       layers[3],
+                                       bn_norm,
+                                       stride=2,
+                                       dilate=replace_stride_with_dilation[2])
+        self.bn2 = get_norm(bn_norm, 512 * block.expansion)
+        self.dropout = nn.Dropout(p=dropout, inplace=True)
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.normal_(m.weight, 0, 0.1)
+            elif m.__class__.__name__.find('Norm') != -1:
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+
+        if zero_init_residual:
+            for m in self.modules():
+                if isinstance(m, IBasicBlock):
+                    nn.init.constant_(m.bn2.weight, 0)
+
+    def _make_layer(self, block, planes, blocks, bn_norm, stride=1, dilate=False):
+        downsample = None
+        previous_dilation = self.dilation
+        if dilate:
+            self.dilation *= stride
+            stride = 1
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                conv1x1(self.inplanes, planes * block.expansion, stride),
+                get_norm(bn_norm, planes * block.expansion),
+            )
+        layers = []
+        layers.append(
+            block(self.inplanes, planes, bn_norm, stride, downsample, self.groups,
+                  self.base_width, previous_dilation))
+        self.inplanes = planes * block.expansion
+        for _ in range(1, blocks):
+            layers.append(
+                block(self.inplanes,
+                      planes,
+                      bn_norm,
+                      groups=self.groups,
+                      base_width=self.base_width,
+                      dilation=self.dilation))
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.prelu(x)
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+        x = self.bn2(x)
+        x = self.dropout(x)
+        return x
+
+
+@BACKBONE_REGISTRY.register()
+def build_iresnet_backbone(cfg):
+    """
+    Create a IResNet instance from config.
+    Returns:
+        ResNet: a :class:`ResNet` instance.
+    """
+
+    # fmt: off
+    bn_norm = cfg.MODEL.BACKBONE.NORM
+    depth   = cfg.MODEL.BACKBONE.DEPTH
+    dropout = cfg.MODEL.BACKBONE.DROPOUT
+    fp16    = cfg.SOLVER.AMP.ENABLED
+    # fmt: on
+
+    num_blocks_per_stage = {
+        '18x': [2, 2, 2, 2],
+        '34x': [3, 4, 6, 3],
+        '50x': [3, 4, 14, 3],
+        '100x': [3, 13, 30, 3],
+        '200x': [6, 26, 60, 6],
+    }[depth]
+
+    model = IResNet(IBasicBlock, num_blocks_per_stage, bn_norm, dropout, fp16=fp16)
+    return model
diff --git a/projects/FastFace/fastface/modeling/partial_fc.py b/projects/FastFace/fastface/modeling/partial_fc.py
index 408cace..b1c6c42 100644
--- a/projects/FastFace/fastface/modeling/partial_fc.py
+++ b/projects/FastFace/fastface/modeling/partial_fc.py
@@ -52,23 +52,6 @@ class PartialFC(nn.Module):
 
         self.cls_layer = getattr(any_softmax, cls_type)(num_classes, scale, margin)
 
-        """ TODO: consider resume training
-        if resume:
-            try:
-                self.weight: torch.Tensor = torch.load(self.weight_name)
-                logging.info("softmax weight resume successfully!")
-            except (FileNotFoundError, KeyError, IndexError):
-                self.weight = torch.normal(0, 0.01, (self.num_local, self.embedding_size), device=self.device)
-                logging.info("softmax weight resume fail!")
-
-            try:
-                self.weight_mom: torch.Tensor = torch.load(self.weight_mom_name)
-                logging.info("softmax weight mom resume successfully!")
-            except (FileNotFoundError, KeyError, IndexError):
-                self.weight_mom: torch.Tensor = torch.zeros_like(self.weight)
-                logging.info("softmax weight mom resume fail!")
-        else:
-        """
         self.weight = torch.normal(0, 0.01, (self.num_local, self.embedding_size), device=self.device)
         self.weight_mom: torch.Tensor = torch.zeros_like(self.weight)
         logger.info("softmax weight init successfully!")
diff --git a/projects/FastFace/fastface/modeling/resnet_ir.py b/projects/FastFace/fastface/modeling/resnet_ir.py
deleted file mode 100644
index 67632a2..0000000
--- a/projects/FastFace/fastface/modeling/resnet_ir.py
+++ /dev/null
@@ -1,122 +0,0 @@
-# encoding: utf-8
-"""
-@author:  xingyu liao
-@contact: sherlockliao01@gmail.com
-"""
-
-from collections import namedtuple
-
-from torch import nn
-
-from fastreid.layers import get_norm, SELayer
-from fastreid.modeling.backbones import BACKBONE_REGISTRY
-
-
-def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
-    """3x3 convolution with padding"""
-    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
-                     padding=dilation, groups=groups, bias=False, dilation=dilation)
-
-
-class bottleneck_IR(nn.Module):
-    def __init__(self, in_channel, depth, bn_norm, stride, with_se=False):
-        super(bottleneck_IR, self).__init__()
-        if in_channel == depth:
-            self.shortcut_layer = nn.MaxPool2d(1, stride)
-        else:
-            self.shortcut_layer = nn.Sequential(
-                nn.Conv2d(in_channel, depth, (1, 1), stride, bias=False),
-                get_norm(bn_norm, depth))
-        self.res_layer = nn.Sequential(
-            get_norm(bn_norm, in_channel),
-            nn.Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
-            nn.PReLU(depth),
-            nn.Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
-            get_norm(bn_norm, depth),
-            SELayer(depth, 16) if with_se else nn.Identity()
-        )
-
-    def forward(self, x):
-        shortcut = self.shortcut_layer(x)
-        res = self.res_layer(x)
-        return res + shortcut
-
-
-class Bottleneck(namedtuple("Block", ["in_channel", "depth", "bn_norm", "stride", "with_se"])):
-    """A named tuple describing a ResNet block."""
-
-
-def get_block(in_channel, depth, bn_norm, num_units, with_se, stride=2):
-    return [Bottleneck(in_channel, depth, bn_norm, stride, with_se)] + \
-           [Bottleneck(depth, depth, bn_norm, 1, with_se) for _ in range(num_units - 1)]
-
-
-def get_blocks(bn_norm, with_se, num_layers):
-    if num_layers == "50x":
-        blocks = [
-            get_block(in_channel=64, depth=64, bn_norm=bn_norm, num_units=3, with_se=with_se),
-            get_block(in_channel=64, depth=128, bn_norm=bn_norm, num_units=4, with_se=with_se),
-            get_block(in_channel=128, depth=256, bn_norm=bn_norm, num_units=14, with_se=with_se),
-            get_block(in_channel=256, depth=512, bn_norm=bn_norm, num_units=3, with_se=with_se)
-        ]
-    elif num_layers == "100x":
-        blocks = [
-            get_block(in_channel=64, depth=64, bn_norm=bn_norm, num_units=3, with_se=with_se),
-            get_block(in_channel=64, depth=128, bn_norm=bn_norm, num_units=13, with_se=with_se),
-            get_block(in_channel=128, depth=256, bn_norm=bn_norm, num_units=30, with_se=with_se),
-            get_block(in_channel=256, depth=512, bn_norm=bn_norm, num_units=3, with_se=with_se)
-        ]
-    elif num_layers == "152x":
-        blocks = [
-            get_block(in_channel=64, depth=64, bn_norm=bn_norm, num_units=3, with_se=with_se),
-            get_block(in_channel=64, depth=128, bn_norm=bn_norm, num_units=8, with_se=with_se),
-            get_block(in_channel=128, depth=256, bn_norm=bn_norm, num_units=36, with_se=with_se),
-            get_block(in_channel=256, depth=512, bn_norm=bn_norm, num_units=3, with_se=with_se)
-        ]
-    return blocks
-
-
-class ResNetIR(nn.Module):
-    def __init__(self, num_layers, bn_norm, drop_ratio, with_se):
-        super(ResNetIR, self).__init__()
-        assert num_layers in ["50x", "100x", "152x"], "num_layers should be 50,100, or 152"
-        blocks = get_blocks(bn_norm, with_se, num_layers)
-        self.input_layer = nn.Sequential(nn.Conv2d(3, 64, (3, 3), 1, 1, bias=False),
-                                         get_norm(bn_norm, 64),
-                                         nn.PReLU(64))
-        self.output_layer = nn.Sequential(get_norm(bn_norm, 512),
-                                          nn.Dropout(drop_ratio))
-        modules = []
-        for block in blocks:
-            for bottleneck in block:
-                modules.append(
-                    bottleneck_IR(bottleneck.in_channel,
-                                  bottleneck.depth,
-                                  bottleneck.bn_norm,
-                                  bottleneck.stride,
-                                  bottleneck.with_se))
-        self.body = nn.Sequential(*modules)
-
-    def forward(self, x):
-        x = self.input_layer(x)
-        x = self.body(x)
-        x = self.output_layer(x)
-        return x
-
-
-@BACKBONE_REGISTRY.register()
-def build_resnetIR_backbone(cfg):
-    """
-    Create a ResNetIR instance from config.
-    Returns:
-        ResNet: a :class:`ResNet` instance.
-    """
-
-    # fmt: off
-    bn_norm = cfg.MODEL.BACKBONE.NORM
-    with_se = cfg.MODEL.BACKBONE.WITH_SE
-    depth   = cfg.MODEL.BACKBONE.DEPTH
-    # fmt: on
-
-    model = ResNetIR(depth, bn_norm, 0.5, with_se)
-    return model
diff --git a/projects/FastFace/fastface/pfc_checkpointer.py b/projects/FastFace/fastface/pfc_checkpointer.py
new file mode 100644
index 0000000..0ea759b
--- /dev/null
+++ b/projects/FastFace/fastface/pfc_checkpointer.py
@@ -0,0 +1,84 @@
+# encoding: utf-8
+"""
+@author:  xingyu liao
+@contact: sherlockliao01@gmail.com
+"""
+
+import os
+from typing import Any, Dict
+
+import torch
+
+from fastreid.engine.hooks import PeriodicCheckpointer
+from fastreid.utils import comm
+from fastreid.utils.checkpoint import Checkpointer
+from fastreid.utils.file_io import PathManager
+
+
+class PfcPeriodicCheckpointer(PeriodicCheckpointer):
+
+    def step(self, epoch: int, **kwargs: Any):
+        rank = comm.get_rank()
+        if (epoch + 1) % self.period == 0 and epoch < self.max_epoch - 1:
+            self.checkpointer.save(
+                f"softmax_weight_{epoch:04d}_rank_{rank:02d}"
+            )
+        if epoch >= self.max_epoch - 1:
+            self.checkpointer.save(f"softmax_weight_{rank:02d}", )
+
+
+class PfcCheckpointer(Checkpointer):
+    def __init__(self, model, save_dir, *, save_to_disk=True, **checkpointables):
+        super().__init__(model, save_dir, save_to_disk=save_to_disk, **checkpointables)
+        self.rank = comm.get_rank()
+
+    def save(self, name: str, **kwargs: Dict[str, str]):
+        if not self.save_dir or not self.save_to_disk:
+            return
+
+        data = {}
+        data["model"] = {
+            "weight": self.model.weight.data,
+            "momentum": self.model.weight_mom,
+        }
+        for key, obj in self.checkpointables.items():
+            data[key] = obj.state_dict()
+        data.update(kwargs)
+
+        basename = f"{name}.pth"
+        save_file = os.path.join(self.save_dir, basename)
+        assert os.path.basename(save_file) == basename, basename
+        self.logger.info("Saving partial fc weights")
+        with PathManager.open(save_file, "wb") as f:
+            torch.save(data, f)
+        self.tag_last_checkpoint(basename)
+
+    def _load_model(self, checkpoint: Any):
+        checkpoint_state_dict = checkpoint.pop("model")
+        self._convert_ndarray_to_tensor(checkpoint_state_dict)
+        self.model.weight.data.copy_(checkpoint_state_dict.pop("weight"))
+        self.model.weight_mom.data.copy_(checkpoint_state_dict.pop("momentum"))
+
+    def has_checkpoint(self):
+        save_file = os.path.join(self.save_dir, f"last_weight_{self.rank:02d}")
+        return PathManager.exists(save_file)
+
+    def get_checkpoint_file(self):
+        """
+        Returns:
+            str: The latest checkpoint file in target directory.
+        """
+        save_file = os.path.join(self.save_dir, f"last_weight_{self.rank:02d}")
+        try:
+            with PathManager.open(save_file, "r") as f:
+                last_saved = f.read().strip()
+        except IOError:
+            # if file doesn't exist, maybe because it has just been
+            # deleted by a separate process
+            return ""
+        return os.path.join(self.save_dir, last_saved)
+
+    def tag_last_checkpoint(self, last_filename_basename: str):
+        save_file = os.path.join(self.save_dir, f"last_weight_{self.rank:02d}")
+        with PathManager.open(save_file, "w") as f:
+            f.write(last_filename_basename)
diff --git a/projects/FastFace/fastface/trainer.py b/projects/FastFace/fastface/trainer.py
index 334a141..90db184 100644
--- a/projects/FastFace/fastface/trainer.py
+++ b/projects/FastFace/fastface/trainer.py
@@ -8,20 +8,23 @@ import os
 import time
 
 from torch.nn.parallel import DistributedDataParallel
+from torch.nn.utils import clip_grad_norm_
 
-from fastreid.engine import hooks
-from .face_data import TestFaceDataset
-from fastreid.data.datasets import DATASET_REGISTRY
 from fastreid.data.build import _root, build_reid_test_loader, build_reid_train_loader
+from fastreid.data.datasets import DATASET_REGISTRY
 from fastreid.data.transforms import build_transforms
+from fastreid.engine import hooks
 from fastreid.engine.defaults import DefaultTrainer, TrainerBase
-from fastreid.engine.train_loop import SimpleTrainer
+from fastreid.engine.train_loop import SimpleTrainer, AMPTrainer
 from fastreid.utils import comm
 from fastreid.utils.checkpoint import Checkpointer
 from fastreid.utils.logger import setup_logger
 from .face_data import MXFaceDataset
+from .face_data import TestFaceDataset
 from .face_evaluator import FaceEvaluator
 from .modeling import PartialFC
+from .pfc_checkpointer import PfcPeriodicCheckpointer, PfcCheckpointer
+from .utils_amp import MaxClipGradScaler
 
 
 class FaceTrainer(DefaultTrainer):
@@ -59,11 +62,17 @@ class FaceTrainer(DefaultTrainer):
             # for part of the parameters is not updated.
             model = DistributedDataParallel(
                 model, device_ids=[comm.get_local_rank()], broadcast_buffers=False,
-                find_unused_parameters=True
             )
 
-        self._trainer = PFCTrainer(model, data_loader, optimizer, self.pfc_module, self.pfc_optimizer) \
-            if cfg.MODEL.HEADS.PFC.ENABLED else SimpleTrainer(model, data_loader, optimizer)
+        if cfg.MODEL.HEADS.PFC.ENABLED:
+            mini_batch_size = cfg.SOLVER.IMS_PER_BATCH // comm.get_world_size()
+            grad_scaler = MaxClipGradScaler(mini_batch_size, 128 * mini_batch_size, growth_interval=100)
+            self._trainer = PFCTrainer(model, data_loader, optimizer,
+                                       self.pfc_module, self.pfc_optimizer, cfg.SOLVER.AMP.ENABLED, grad_scaler)
+        else:
+            self._trainer = (AMPTrainer if cfg.SOLVER.AMP.ENABLED else SimpleTrainer)(
+                model, data_loader, optimizer
+            )
 
         self.iters_per_epoch = len(data_loader.dataset) // cfg.SOLVER.IMS_PER_BATCH
         self.scheduler = self.build_lr_scheduler(cfg, optimizer, self.iters_per_epoch)
@@ -80,7 +89,7 @@ class FaceTrainer(DefaultTrainer):
         )
 
         if cfg.MODEL.HEADS.PFC.ENABLED:
-            self.pfc_checkpointer = Checkpointer(
+            self.pfc_checkpointer = PfcCheckpointer(
                 self.pfc_module,
                 cfg.OUTPUT_DIR,
                 optimizer=self.pfc_optimizer,
@@ -100,12 +109,23 @@ class FaceTrainer(DefaultTrainer):
         ret = super().build_hooks()
 
         if self.cfg.MODEL.HEADS.PFC.ENABLED:
+            # Make sure checkpointer is after writer
+            ret.insert(
+                len(ret) - 1,
+                PfcPeriodicCheckpointer(self.pfc_checkpointer, self.cfg.SOLVER.CHECKPOINT_PERIOD)
+            )
             # partial fc scheduler hook
             ret.append(
                 hooks.LRScheduler(self.pfc_optimizer, self.pfc_scheduler)
             )
         return ret
 
+    def resume_or_load(self, resume=True):
+        # Backbone loading state_dict
+        super().resume_or_load(resume)
+        # Partial-FC loading state_dict
+        self.pfc_checkpointer.resume_or_load('', resume=resume)
+
     @classmethod
     def build_train_loader(cls, cfg):
         path_imgrec = cfg.DATASETS.REC_PATH
@@ -141,11 +161,14 @@ class PFCTrainer(SimpleTrainer):
     https://github.com/deepinsight/insightface/blob/master/recognition/arcface_torch/partial_fc.py
     """
 
-    def __init__(self, model, data_loader, optimizer, pfc_module, pfc_optimizer):
+    def __init__(self, model, data_loader, optimizer, pfc_module, pfc_optimizer, amp_enabled, grad_scaler):
         super().__init__(model, data_loader, optimizer)
 
         self.pfc_module = pfc_module
         self.pfc_optimizer = pfc_optimizer
+        self.amp_enabled = amp_enabled
+
+        self.grad_scaler = grad_scaler
 
     def run_step(self):
         assert self.model.training, "[PFCTrainer] model was changed to eval mode!"
@@ -156,18 +179,24 @@ class PFCTrainer(SimpleTrainer):
 
         features, targets = self.model(data)
 
-        self.optimizer.zero_grad()
-        self.pfc_optimizer.zero_grad()
-
         # Partial-fc backward
         f_grad, loss_v = self.pfc_module.forward_backward(features, targets, self.pfc_optimizer)
 
-        features.backward(f_grad)
+        if self.amp_enabled:
+            features.backward(self.grad_scaler.scale(f_grad))
+            self.grad_scaler.unscale_(self.optimizer)
+            clip_grad_norm_(self.model.parameters(), max_norm=5, norm_type=2)
+            self.grad_scaler.step(self.optimizer)
+            self.grad_scaler.update()
+        else:
+            features.backward(f_grad)
+            clip_grad_norm_(self.model.parameters(), max_norm=5, norm_type=2)
+            self.optimizer.step()
 
         loss_dict = {"loss_cls": loss_v}
         self._write_metrics(loss_dict, data_time)
 
-        self.optimizer.step()
         self.pfc_optimizer.step()
-
         self.pfc_module.update()
+        self.optimizer.zero_grad()
+        self.pfc_optimizer.zero_grad()
diff --git a/projects/FastFace/fastface/utils_amp.py b/projects/FastFace/fastface/utils_amp.py
new file mode 100644
index 0000000..de115c6
--- /dev/null
+++ b/projects/FastFace/fastface/utils_amp.py
@@ -0,0 +1,86 @@
+# encoding: utf-8
+"""
+@author:  xingyu liao
+@contact: sherlockliao01@gmail.com
+"""
+
+from typing import Dict, List
+
+import torch
+from torch._six import container_abcs
+from torch.cuda.amp import GradScaler
+
+
+class _MultiDeviceReplicator(object):
+    """
+    Lazily serves copies of a tensor to requested devices.  Copies are cached per-device.
+    """
+
+    def __init__(self, master_tensor: torch.Tensor) -> None:
+        assert master_tensor.is_cuda
+        self.master = master_tensor
+        self._per_device_tensors: Dict[torch.device, torch.Tensor] = {}
+
+    def get(self, device) -> torch.Tensor:
+        retval = self._per_device_tensors.get(device, None)
+        if retval is None:
+            retval = self.master.to(device=device, non_blocking=True, copy=True)
+            self._per_device_tensors[device] = retval
+        return retval
+
+
+class MaxClipGradScaler(GradScaler):
+    def __init__(self, init_scale, max_scale: float, growth_interval=100):
+        super().__init__(init_scale=init_scale, growth_interval=growth_interval)
+        self.max_scale = max_scale
+
+    def scale_clip(self):
+        if self.get_scale() == self.max_scale:
+            self.set_growth_factor(1)
+        elif self.get_scale() < self.max_scale:
+            self.set_growth_factor(2)
+        elif self.get_scale() > self.max_scale:
+            self._scale.fill_(self.max_scale)
+            self.set_growth_factor(1)
+
+    def scale(self, outputs):
+        """
+        Multiplies ('scales') a tensor or list of tensors by the scale factor.
+        Returns scaled outputs.  If this instance of :class:`GradScaler` is not enabled, outputs are returned
+        unmodified.
+        Arguments:
+            outputs (Tensor or iterable of Tensors):  Outputs to scale.
+        """
+        if not self._enabled:
+            return outputs
+        self.scale_clip()
+        # Short-circuit for the common case.
+        if isinstance(outputs, torch.Tensor):
+            assert outputs.is_cuda
+            if self._scale is None:
+                self._lazy_init_scale_growth_tracker(outputs.device)
+            assert self._scale is not None
+            return outputs * self._scale.to(device=outputs.device, non_blocking=True)
+
+        # Invoke the more complex machinery only if we're treating multiple outputs.
+        stash: List[_MultiDeviceReplicator] = []  # holds a reference that can be overwritten by apply_scale
+
+        def apply_scale(val):
+            if isinstance(val, torch.Tensor):
+                assert val.is_cuda
+                if len(stash) == 0:
+                    if self._scale is None:
+                        self._lazy_init_scale_growth_tracker(val.device)
+                    assert self._scale is not None
+                    stash.append(_MultiDeviceReplicator(self._scale))
+                return val * stash[0].get(val.device)
+            elif isinstance(val, container_abcs.Iterable):
+                iterable = map(apply_scale, val)
+                if isinstance(val, list) or isinstance(val, tuple):
+                    return type(val)(iterable)
+                else:
+                    return iterable
+            else:
+                raise ValueError("outputs must be a Tensor or an iterable of Tensors")
+
+        return apply_scale(outputs)
diff --git a/projects/FastFace/train_net.py b/projects/FastFace/train_net.py
index 2b47401..8e24769 100644
--- a/projects/FastFace/train_net.py
+++ b/projects/FastFace/train_net.py
@@ -14,7 +14,6 @@ from fastreid.engine import default_argument_parser, default_setup, launch
 from fastreid.utils.checkpoint import Checkpointer
 
 from fastface import *
-from fastface.datasets import *
 
 
 def setup(args):