mirror of https://github.com/JDAI-CV/fast-reid.git
Support AMP and resume training in fastface
AMP in partial-fc needs to be applied to the backbone only; to implement resume training, each GPU has to save and load its own part of the classifier weight.
pull/504/head
parent 91ff631184
commit c3ac4f504c
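In the diff below, the backbone-only autocast lives in FaceBaseline.forward and the per-GPU classifier shards are handled by PfcCheckpointer / PfcPeriodicCheckpointer. For orientation, here is a minimal sketch of the two patterns; the class and function names in the sketch are illustrative stand-ins, not the project's actual API.

    # Minimal sketch of the two ideas above (illustrative names, not the project API).
    import torch
    from torch import nn


    class BackboneOnlyAMP(nn.Module):
        """Run only the backbone under autocast; the partial-fc classifier stays in fp32."""

        def __init__(self, backbone: nn.Module, amp_enabled: bool):
            super().__init__()
            self.backbone = backbone
            self.amp_enabled = amp_enabled

        def forward(self, images: torch.Tensor) -> torch.Tensor:
            with torch.cuda.amp.autocast(enabled=self.amp_enabled):
                features = self.backbone(images)
            # Cast back to fp32 before the fp32 classifier / margin loss.
            return features.float() if self.amp_enabled else features


    def save_classifier_shard(weight: torch.Tensor, weight_mom: torch.Tensor,
                              save_dir: str, rank: int, epoch: int) -> None:
        # Each rank owns a different slice of the softmax weight, so each rank
        # writes (and later reloads) its own shard when resuming.
        torch.save({"weight": weight, "momentum": weight_mom},
                   f"{save_dir}/softmax_weight_{epoch:04d}_rank_{rank:02d}.pth")

The shard-per-rank naming in the sketch matches the softmax_weight_{epoch:04d}_rank_{rank:02d} files written by PfcPeriodicCheckpointer below.
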
@@ -4,6 +4,9 @@ MODEL:
  PIXEL_MEAN: [127.5, 127.5, 127.5]
  PIXEL_STD: [127.5, 127.5, 127.5]

  BACKBONE:
    NAME: build_iresnet_backbone

  HEADS:
    NAME: FaceHead
    WITH_BNNECK: True

@@ -30,7 +33,7 @@ MODEL:
DATASETS:
  REC_PATH: /export/home/DATA/Glint360k/train.rec
  NAMES: ("MS1MV2",)
  TESTS: ("CPLFW", "VGG2_FP", "CALFW", "CFP_FF", "CFP_FP", "AgeDB_30", "LFW")
  TESTS: ("CFP_FP", "AgeDB_30", "LFW")

INPUT:
  SIZE_TRAIN: [0,] # No need of resize

@@ -47,10 +50,10 @@ DATALOADER:
SOLVER:
  MAX_EPOCH: 20
  AMP:
    ENABLED: False
    ENABLED: True

  OPT: SGD
  BASE_LR: 0.1
  BASE_LR: 0.05
  MOMENTUM: 0.9

  SCHED: MultiStepLR

@@ -59,10 +62,10 @@ SOLVER:
  BIAS_LR_FACTOR: 1.
  WEIGHT_DECAY: 0.0005
  WEIGHT_DECAY_BIAS: 0.0005
  IMS_PER_BATCH: 512
  IMS_PER_BATCH: 256

  WARMUP_FACTOR: 0.1
  WARMUP_ITERS: 5000
  WARMUP_ITERS: 0

  CHECKPOINT_PERIOD: 1

@@ -3,13 +3,12 @@ _BASE_: face_base.yml
MODEL:

  BACKBONE:
    NAME: build_resnetIR_backbone
    DEPTH: 50x
    FEAT_DIM: 25088 # 512x7x7
    WITH_SE: True
    DROPOUT: 0.

  HEADS:
    PFC:
      ENABLED: True

OUTPUT_DIR: projects/FastFace/logs/ir_se50-glink360k-pfc0.1
OUTPUT_DIR: projects/FastFace/logs/pfc0.1_insightface

@@ -7,3 +7,4 @@
from .modeling import *
from .config import add_face_cfg
from .trainer import FaceTrainer
from .datasets import *

@@ -12,5 +12,7 @@ def add_face_cfg(cfg):

    _C.DATASETS.REC_PATH = ""

    _C.MODEL.BACKBONE.DROPOUT = 0.

    _C.MODEL.HEADS.PFC = CN({"ENABLED": False})
    _C.MODEL.HEADS.PFC.SAMPLE_RATE = 0.1

@@ -23,7 +23,7 @@ class MS1MV2(ImageDataset):
        required_files = [self.dataset_dir]
        self.check_before_run(required_files)

        train = self.process_dirs()
        train = self.process_dirs()[:10000]
        super().__init__(train, [], [], **kwargs)

    def process_dirs(self):

@@ -7,4 +7,4 @@
from .partial_fc import PartialFC
from .face_baseline import FaceBaseline
from .face_head import FaceHead
from .resnet_ir import build_resnetIR_backbone
from .iresnet import build_iresnet_backbone

@@ -4,6 +4,7 @@
@contact: sherlockliao01@gmail.com
"""

import torch
from fastreid.modeling.meta_arch import Baseline
from fastreid.modeling.meta_arch import META_ARCH_REGISTRY

@@ -13,12 +14,28 @@ class FaceBaseline(Baseline):
    def __init__(self, cfg):
        super().__init__(cfg)
        self.pfc_enabled = cfg.MODEL.HEADS.PFC.ENABLED
        self.amp_enabled = cfg.SOLVER.AMP.ENABLED

    def losses(self, outputs, gt_labels):
    def forward(self, batched_inputs):
        if not self.pfc_enabled:
            return super().losses(outputs, gt_labels)
            return super().forward(batched_inputs)

        images = self.preprocess_image(batched_inputs)
        with torch.cuda.amp.autocast(self.amp_enabled):
            features = self.backbone(images)
        features = features.float() if self.amp_enabled else features

        if self.training:
            assert "targets" in batched_inputs, "Person ID annotation are missing in training!"
            targets = batched_inputs["targets"]

            # PreciseBN flag, When do preciseBN on different dataset, the number of classes in new dataset
            # may be larger than that in the original dataset, so the circle/arcface will
            # throw an error. We just set all the targets to 0 to avoid this problem.
            if targets.sum() < 0: targets.zero_()

            outputs = self.heads(features, targets)
            return outputs, targets
        else:
            # model parallel with partial-fc
            # cls layer and loss computation in partial_fc.py
            pred_features = outputs["features"]
            return pred_features, gt_labels
            outputs = self.heads(features)
            return outputs

@@ -30,10 +30,4 @@ class FaceHead(EmbeddingHead):
        pool_feat = self.pool_layer(features)
        neck_feat = self.bottleneck(pool_feat)
        neck_feat = neck_feat[..., 0, 0]

        if not self.training:
            return neck_feat

        return {
            "features": neck_feat,
        }

@@ -0,0 +1,179 @@
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""

import torch
from torch import nn

from fastreid.layers import get_norm
from fastreid.modeling.backbones import BACKBONE_REGISTRY


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes,
                     out_planes,
                     kernel_size=1,
                     stride=stride,
                     bias=False)


class IBasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, bn_norm, stride=1, downsample=None,
                 groups=1, base_width=64, dilation=1):
        super().__init__()
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        self.bn1 = get_norm(bn_norm, inplanes)
        self.conv1 = conv3x3(inplanes, planes)
        self.bn2 = get_norm(bn_norm, planes)
        self.prelu = nn.PReLU(planes)
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn3 = get_norm(bn_norm, planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x
        out = self.bn1(x)
        out = self.conv1(out)
        out = self.bn2(out)
        out = self.prelu(out)
        out = self.conv2(out)
        out = self.bn3(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        return out


class IResNet(nn.Module):
    fc_scale = 7 * 7

    def __init__(self, block, layers, bn_norm, dropout=0, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False):
        super().__init__()
        self.inplanes = 64
        self.dilation = 1
        self.fp16 = fp16
        if replace_stride_with_dilation is None:
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = get_norm(bn_norm, self.inplanes)
        self.prelu = nn.PReLU(self.inplanes)
        self.layer1 = self._make_layer(block, 64, layers[0], bn_norm, stride=2)
        self.layer2 = self._make_layer(block,
                                       128,
                                       layers[1],
                                       bn_norm,
                                       stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block,
                                       256,
                                       layers[2],
                                       bn_norm,
                                       stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block,
                                       512,
                                       layers[3],
                                       bn_norm,
                                       stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.bn2 = get_norm(bn_norm, 512 * block.expansion)
        self.dropout = nn.Dropout(p=dropout, inplace=True)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, 0, 0.1)
            elif m.__class__.__name__.find('Norm') != -1:
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, IBasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, bn_norm, stride=1, dilate=False):
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                get_norm(bn_norm, planes * block.expansion),
            )
        layers = []
        layers.append(
            block(self.inplanes, planes, bn_norm, stride, downsample, self.groups,
                  self.base_width, previous_dilation))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(self.inplanes,
                      planes,
                      bn_norm,
                      groups=self.groups,
                      base_width=self.base_width,
                      dilation=self.dilation))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.prelu(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.bn2(x)
        x = self.dropout(x)
        return x


@BACKBONE_REGISTRY.register()
def build_iresnet_backbone(cfg):
    """
    Create a IResNet instance from config.
    Returns:
        ResNet: a :class:`ResNet` instance.
    """

    # fmt: off
    bn_norm = cfg.MODEL.BACKBONE.NORM
    depth = cfg.MODEL.BACKBONE.DEPTH
    dropout = cfg.MODEL.BACKBONE.DROPOUT
    fp16 = cfg.SOLVER.AMP.ENABLED
    # fmt: on

    num_blocks_per_stage = {
        '18x': [2, 2, 2, 2],
        '34x': [3, 4, 6, 3],
        '50x': [3, 4, 14, 3],
        '100x': [3, 13, 30, 3],
        '200x': [6, 26, 60, 6],
    }[depth]

    model = IResNet(IBasicBlock, num_blocks_per_stage, bn_norm, dropout, fp16=fp16)
    return model

@@ -52,23 +52,6 @@ class PartialFC(nn.Module):

        self.cls_layer = getattr(any_softmax, cls_type)(num_classes, scale, margin)

        """ TODO: consider resume training
        if resume:
            try:
                self.weight: torch.Tensor = torch.load(self.weight_name)
                logging.info("softmax weight resume successfully!")
            except (FileNotFoundError, KeyError, IndexError):
                self.weight = torch.normal(0, 0.01, (self.num_local, self.embedding_size), device=self.device)
                logging.info("softmax weight resume fail!")

            try:
                self.weight_mom: torch.Tensor = torch.load(self.weight_mom_name)
                logging.info("softmax weight mom resume successfully!")
            except (FileNotFoundError, KeyError, IndexError):
                self.weight_mom: torch.Tensor = torch.zeros_like(self.weight)
                logging.info("softmax weight mom resume fail!")
        else:
        """
        self.weight = torch.normal(0, 0.01, (self.num_local, self.embedding_size), device=self.device)
        self.weight_mom: torch.Tensor = torch.zeros_like(self.weight)
        logger.info("softmax weight init successfully!")

@@ -1,122 +0,0 @@
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""

from collections import namedtuple

from torch import nn

from fastreid.layers import get_norm, SELayer
from fastreid.modeling.backbones import BACKBONE_REGISTRY


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)


class bottleneck_IR(nn.Module):
    def __init__(self, in_channel, depth, bn_norm, stride, with_se=False):
        super(bottleneck_IR, self).__init__()
        if in_channel == depth:
            self.shortcut_layer = nn.MaxPool2d(1, stride)
        else:
            self.shortcut_layer = nn.Sequential(
                nn.Conv2d(in_channel, depth, (1, 1), stride, bias=False),
                get_norm(bn_norm, depth))
        self.res_layer = nn.Sequential(
            get_norm(bn_norm, in_channel),
            nn.Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
            nn.PReLU(depth),
            nn.Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
            get_norm(bn_norm, depth),
            SELayer(depth, 16) if with_se else nn.Identity()
        )

    def forward(self, x):
        shortcut = self.shortcut_layer(x)
        res = self.res_layer(x)
        return res + shortcut


class Bottleneck(namedtuple("Block", ["in_channel", "depth", "bn_norm", "stride", "with_se"])):
    """A named tuple describing a ResNet block."""


def get_block(in_channel, depth, bn_norm, num_units, with_se, stride=2):
    return [Bottleneck(in_channel, depth, bn_norm, stride, with_se)] + \
           [Bottleneck(depth, depth, bn_norm, 1, with_se) for _ in range(num_units - 1)]


def get_blocks(bn_norm, with_se, num_layers):
    if num_layers == "50x":
        blocks = [
            get_block(in_channel=64, depth=64, bn_norm=bn_norm, num_units=3, with_se=with_se),
            get_block(in_channel=64, depth=128, bn_norm=bn_norm, num_units=4, with_se=with_se),
            get_block(in_channel=128, depth=256, bn_norm=bn_norm, num_units=14, with_se=with_se),
            get_block(in_channel=256, depth=512, bn_norm=bn_norm, num_units=3, with_se=with_se)
        ]
    elif num_layers == "100x":
        blocks = [
            get_block(in_channel=64, depth=64, bn_norm=bn_norm, num_units=3, with_se=with_se),
            get_block(in_channel=64, depth=128, bn_norm=bn_norm, num_units=13, with_se=with_se),
            get_block(in_channel=128, depth=256, bn_norm=bn_norm, num_units=30, with_se=with_se),
            get_block(in_channel=256, depth=512, bn_norm=bn_norm, num_units=3, with_se=with_se)
        ]
    elif num_layers == "152x":
        blocks = [
            get_block(in_channel=64, depth=64, bn_norm=bn_norm, num_units=3, with_se=with_se),
            get_block(in_channel=64, depth=128, bn_norm=bn_norm, num_units=8, with_se=with_se),
            get_block(in_channel=128, depth=256, bn_norm=bn_norm, num_units=36, with_se=with_se),
            get_block(in_channel=256, depth=512, bn_norm=bn_norm, num_units=3, with_se=with_se)
        ]
    return blocks


class ResNetIR(nn.Module):
    def __init__(self, num_layers, bn_norm, drop_ratio, with_se):
        super(ResNetIR, self).__init__()
        assert num_layers in ["50x", "100x", "152x"], "num_layers should be 50,100, or 152"
        blocks = get_blocks(bn_norm, with_se, num_layers)
        self.input_layer = nn.Sequential(nn.Conv2d(3, 64, (3, 3), 1, 1, bias=False),
                                         get_norm(bn_norm, 64),
                                         nn.PReLU(64))
        self.output_layer = nn.Sequential(get_norm(bn_norm, 512),
                                          nn.Dropout(drop_ratio))
        modules = []
        for block in blocks:
            for bottleneck in block:
                modules.append(
                    bottleneck_IR(bottleneck.in_channel,
                                  bottleneck.depth,
                                  bottleneck.bn_norm,
                                  bottleneck.stride,
                                  bottleneck.with_se))
        self.body = nn.Sequential(*modules)

    def forward(self, x):
        x = self.input_layer(x)
        x = self.body(x)
        x = self.output_layer(x)
        return x


@BACKBONE_REGISTRY.register()
def build_resnetIR_backbone(cfg):
    """
    Create a ResNetIR instance from config.
    Returns:
        ResNet: a :class:`ResNet` instance.
    """

    # fmt: off
    bn_norm = cfg.MODEL.BACKBONE.NORM
    with_se = cfg.MODEL.BACKBONE.WITH_SE
    depth = cfg.MODEL.BACKBONE.DEPTH
    # fmt: on

    model = ResNetIR(depth, bn_norm, 0.5, with_se)
    return model

@@ -0,0 +1,84 @@
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""

import os
from typing import Any, Dict

import torch

from fastreid.engine.hooks import PeriodicCheckpointer
from fastreid.utils import comm
from fastreid.utils.checkpoint import Checkpointer
from fastreid.utils.file_io import PathManager


class PfcPeriodicCheckpointer(PeriodicCheckpointer):

    def step(self, epoch: int, **kwargs: Any):
        rank = comm.get_rank()
        if (epoch + 1) % self.period == 0 and epoch < self.max_epoch - 1:
            self.checkpointer.save(
                f"softmax_weight_{epoch:04d}_rank_{rank:02d}"
            )
        if epoch >= self.max_epoch - 1:
            self.checkpointer.save(f"softmax_weight_{rank:02d}", )


class PfcCheckpointer(Checkpointer):
    def __init__(self, model, save_dir, *, save_to_disk=True, **checkpointables):
        super().__init__(model, save_dir, save_to_disk=save_to_disk, **checkpointables)
        self.rank = comm.get_rank()

    def save(self, name: str, **kwargs: Dict[str, str]):
        if not self.save_dir or not self.save_to_disk:
            return

        data = {}
        data["model"] = {
            "weight": self.model.weight.data,
            "momentum": self.model.weight_mom,
        }
        for key, obj in self.checkpointables.items():
            data[key] = obj.state_dict()
        data.update(kwargs)

        basename = f"{name}.pth"
        save_file = os.path.join(self.save_dir, basename)
        assert os.path.basename(save_file) == basename, basename
        self.logger.info("Saving partial fc weights")
        with PathManager.open(save_file, "wb") as f:
            torch.save(data, f)
        self.tag_last_checkpoint(basename)

    def _load_model(self, checkpoint: Any):
        checkpoint_state_dict = checkpoint.pop("model")
        self._convert_ndarray_to_tensor(checkpoint_state_dict)
        self.model.weight.data.copy_(checkpoint_state_dict.pop("weight"))
        self.model.weight_mom.data.copy_(checkpoint_state_dict.pop("momentum"))

    def has_checkpoint(self):
        save_file = os.path.join(self.save_dir, f"last_weight_{self.rank:02d}")
        return PathManager.exists(save_file)

    def get_checkpoint_file(self):
        """
        Returns:
            str: The latest checkpoint file in target directory.
        """
        save_file = os.path.join(self.save_dir, f"last_weight_{self.rank:02d}")
        try:
            with PathManager.open(save_file, "r") as f:
                last_saved = f.read().strip()
        except IOError:
            # if file doesn't exist, maybe because it has just been
            # deleted by a separate process
            return ""
        return os.path.join(self.save_dir, last_saved)

    def tag_last_checkpoint(self, last_filename_basename: str):
        save_file = os.path.join(self.save_dir, f"last_weight_{self.rank:02d}")
        with PathManager.open(save_file, "w") as f:
            f.write(last_filename_basename)

@@ -8,20 +8,23 @@ import os
import time

from torch.nn.parallel import DistributedDataParallel
from torch.nn.utils import clip_grad_norm_

from fastreid.engine import hooks
from .face_data import TestFaceDataset
from fastreid.data.datasets import DATASET_REGISTRY
from fastreid.data.build import _root, build_reid_test_loader, build_reid_train_loader
from fastreid.data.datasets import DATASET_REGISTRY
from fastreid.data.transforms import build_transforms
from fastreid.engine import hooks
from fastreid.engine.defaults import DefaultTrainer, TrainerBase
from fastreid.engine.train_loop import SimpleTrainer
from fastreid.engine.train_loop import SimpleTrainer, AMPTrainer
from fastreid.utils import comm
from fastreid.utils.checkpoint import Checkpointer
from fastreid.utils.logger import setup_logger
from .face_data import MXFaceDataset
from .face_data import TestFaceDataset
from .face_evaluator import FaceEvaluator
from .modeling import PartialFC
from .pfc_checkpointer import PfcPeriodicCheckpointer, PfcCheckpointer
from .utils_amp import MaxClipGradScaler


class FaceTrainer(DefaultTrainer):

@@ -59,11 +62,17 @@ class FaceTrainer(DefaultTrainer):
        # for part of the parameters is not updated.
        model = DistributedDataParallel(
            model, device_ids=[comm.get_local_rank()], broadcast_buffers=False,
            find_unused_parameters=True
        )

        self._trainer = PFCTrainer(model, data_loader, optimizer, self.pfc_module, self.pfc_optimizer) \
            if cfg.MODEL.HEADS.PFC.ENABLED else SimpleTrainer(model, data_loader, optimizer)
        if cfg.MODEL.HEADS.PFC.ENABLED:
            mini_batch_size = cfg.SOLVER.IMS_PER_BATCH // comm.get_world_size()
            grad_scaler = MaxClipGradScaler(mini_batch_size, 128 * mini_batch_size, growth_interval=100)
            self._trainer = PFCTrainer(model, data_loader, optimizer,
                                       self.pfc_module, self.pfc_optimizer, cfg.SOLVER.AMP.ENABLED, grad_scaler)
        else:
            self._trainer = (AMPTrainer if cfg.SOLVER.AMP.ENABLED else SimpleTrainer)(
                model, data_loader, optimizer
            )

        self.iters_per_epoch = len(data_loader.dataset) // cfg.SOLVER.IMS_PER_BATCH
        self.scheduler = self.build_lr_scheduler(cfg, optimizer, self.iters_per_epoch)

@@ -80,7 +89,7 @@ class FaceTrainer(DefaultTrainer):
        )

        if cfg.MODEL.HEADS.PFC.ENABLED:
            self.pfc_checkpointer = Checkpointer(
            self.pfc_checkpointer = PfcCheckpointer(
                self.pfc_module,
                cfg.OUTPUT_DIR,
                optimizer=self.pfc_optimizer,

@@ -100,12 +109,23 @@ class FaceTrainer(DefaultTrainer):
        ret = super().build_hooks()

        if self.cfg.MODEL.HEADS.PFC.ENABLED:
            # Make sure checkpointer is after writer
            ret.insert(
                len(ret) - 1,
                PfcPeriodicCheckpointer(self.pfc_checkpointer, self.cfg.SOLVER.CHECKPOINT_PERIOD)
            )
            # partial fc scheduler hook
            ret.append(
                hooks.LRScheduler(self.pfc_optimizer, self.pfc_scheduler)
            )
        return ret

    def resume_or_load(self, resume=True):
        # Backbone loading state_dict
        super().resume_or_load(resume)
        # Partial-FC loading state_dict
        self.pfc_checkpointer.resume_or_load('', resume=resume)

    @classmethod
    def build_train_loader(cls, cfg):
        path_imgrec = cfg.DATASETS.REC_PATH

@@ -141,11 +161,14 @@ class PFCTrainer(SimpleTrainer):
    https://github.com/deepinsight/insightface/blob/master/recognition/arcface_torch/partial_fc.py
    """

    def __init__(self, model, data_loader, optimizer, pfc_module, pfc_optimizer):
    def __init__(self, model, data_loader, optimizer, pfc_module, pfc_optimizer, amp_enabled, grad_scaler):
        super().__init__(model, data_loader, optimizer)

        self.pfc_module = pfc_module
        self.pfc_optimizer = pfc_optimizer
        self.amp_enabled = amp_enabled

        self.grad_scaler = grad_scaler

    def run_step(self):
        assert self.model.training, "[PFCTrainer] model was changed to eval mode!"

@@ -156,18 +179,24 @@ class PFCTrainer(SimpleTrainer):

        features, targets = self.model(data)

        self.optimizer.zero_grad()
        self.pfc_optimizer.zero_grad()

        # Partial-fc backward
        f_grad, loss_v = self.pfc_module.forward_backward(features, targets, self.pfc_optimizer)

        if self.amp_enabled:
            features.backward(self.grad_scaler.scale(f_grad))
            self.grad_scaler.unscale_(self.optimizer)
            clip_grad_norm_(self.model.parameters(), max_norm=5, norm_type=2)
            self.grad_scaler.step(self.optimizer)
            self.grad_scaler.update()
        else:
            features.backward(f_grad)
            clip_grad_norm_(self.model.parameters(), max_norm=5, norm_type=2)
            self.optimizer.step()

        loss_dict = {"loss_cls": loss_v}
        self._write_metrics(loss_dict, data_time)

        self.optimizer.step()
        self.pfc_optimizer.step()

        self.pfc_module.update()
        self.optimizer.zero_grad()
        self.pfc_optimizer.zero_grad()

@@ -0,0 +1,86 @@
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""

from typing import Dict, List

import torch
from torch._six import container_abcs
from torch.cuda.amp import GradScaler


class _MultiDeviceReplicator(object):
    """
    Lazily serves copies of a tensor to requested devices. Copies are cached per-device.
    """

    def __init__(self, master_tensor: torch.Tensor) -> None:
        assert master_tensor.is_cuda
        self.master = master_tensor
        self._per_device_tensors: Dict[torch.device, torch.Tensor] = {}

    def get(self, device) -> torch.Tensor:
        retval = self._per_device_tensors.get(device, None)
        if retval is None:
            retval = self.master.to(device=device, non_blocking=True, copy=True)
            self._per_device_tensors[device] = retval
        return retval


class MaxClipGradScaler(GradScaler):
    def __init__(self, init_scale, max_scale: float, growth_interval=100):
        super().__init__(init_scale=init_scale, growth_interval=growth_interval)
        self.max_scale = max_scale

    def scale_clip(self):
        if self.get_scale() == self.max_scale:
            self.set_growth_factor(1)
        elif self.get_scale() < self.max_scale:
            self.set_growth_factor(2)
        elif self.get_scale() > self.max_scale:
            self._scale.fill_(self.max_scale)
            self.set_growth_factor(1)

    def scale(self, outputs):
        """
        Multiplies ('scales') a tensor or list of tensors by the scale factor.
        Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returned
        unmodified.
        Arguments:
            outputs (Tensor or iterable of Tensors): Outputs to scale.
        """
        if not self._enabled:
            return outputs
        self.scale_clip()
        # Short-circuit for the common case.
        if isinstance(outputs, torch.Tensor):
            assert outputs.is_cuda
            if self._scale is None:
                self._lazy_init_scale_growth_tracker(outputs.device)
            assert self._scale is not None
            return outputs * self._scale.to(device=outputs.device, non_blocking=True)

        # Invoke the more complex machinery only if we're treating multiple outputs.
        stash: List[_MultiDeviceReplicator] = []  # holds a reference that can be overwritten by apply_scale

        def apply_scale(val):
            if isinstance(val, torch.Tensor):
                assert val.is_cuda
                if len(stash) == 0:
                    if self._scale is None:
                        self._lazy_init_scale_growth_tracker(val.device)
                    assert self._scale is not None
                    stash.append(_MultiDeviceReplicator(self._scale))
                return val * stash[0].get(val.device)
            elif isinstance(val, container_abcs.Iterable):
                iterable = map(apply_scale, val)
                if isinstance(val, list) or isinstance(val, tuple):
                    return type(val)(iterable)
                else:
                    return iterable
            else:
                raise ValueError("outputs must be a Tensor or an iterable of Tensors")

        return apply_scale(outputs)

@@ -14,7 +14,6 @@ from fastreid.engine import default_argument_parser, default_setup, launch
from fastreid.utils.checkpoint import Checkpointer

from fastface import *
from fastface.datasets import *


def setup(args):