diff --git a/projects/FastFace/configs/face_base.yml b/projects/FastFace/configs/face_base.yml index f125cf0..4bbec67 100644 --- a/projects/FastFace/configs/face_base.yml +++ b/projects/FastFace/configs/face_base.yml @@ -4,6 +4,9 @@ MODEL: PIXEL_MEAN: [127.5, 127.5, 127.5] PIXEL_STD: [127.5, 127.5, 127.5] + BACKBONE: + NAME: build_iresnet_backbone + HEADS: NAME: FaceHead WITH_BNNECK: True @@ -30,7 +33,7 @@ MODEL: DATASETS: REC_PATH: /export/home/DATA/Glint360k/train.rec NAMES: ("MS1MV2",) - TESTS: ("CPLFW", "VGG2_FP", "CALFW", "CFP_FF", "CFP_FP", "AgeDB_30", "LFW") + TESTS: ("CFP_FP", "AgeDB_30", "LFW") INPUT: SIZE_TRAIN: [0,] # No need of resize @@ -47,10 +50,10 @@ DATALOADER: SOLVER: MAX_EPOCH: 20 AMP: - ENABLED: False + ENABLED: True OPT: SGD - BASE_LR: 0.1 + BASE_LR: 0.05 MOMENTUM: 0.9 SCHED: MultiStepLR @@ -59,10 +62,10 @@ SOLVER: BIAS_LR_FACTOR: 1. WEIGHT_DECAY: 0.0005 WEIGHT_DECAY_BIAS: 0.0005 - IMS_PER_BATCH: 512 + IMS_PER_BATCH: 256 WARMUP_FACTOR: 0.1 - WARMUP_ITERS: 5000 + WARMUP_ITERS: 0 CHECKPOINT_PERIOD: 1 diff --git a/projects/FastFace/configs/r50_ir.yml b/projects/FastFace/configs/r50_ir.yml index d18bc59..7340754 100644 --- a/projects/FastFace/configs/r50_ir.yml +++ b/projects/FastFace/configs/r50_ir.yml @@ -3,13 +3,12 @@ _BASE_: face_base.yml MODEL: BACKBONE: - NAME: build_resnetIR_backbone DEPTH: 50x FEAT_DIM: 25088 # 512x7x7 - WITH_SE: True + DROPOUT: 0. HEADS: PFC: ENABLED: True -OUTPUT_DIR: projects/FastFace/logs/ir_se50-glink360k-pfc0.1 +OUTPUT_DIR: projects/FastFace/logs/pfc0.1_insightface diff --git a/projects/FastFace/fastface/__init__.py b/projects/FastFace/fastface/__init__.py index 10c6a39..5b2ee35 100644 --- a/projects/FastFace/fastface/__init__.py +++ b/projects/FastFace/fastface/__init__.py @@ -7,3 +7,4 @@ from .modeling import * from .config import add_face_cfg from .trainer import FaceTrainer +from .datasets import * diff --git a/projects/FastFace/fastface/config.py b/projects/FastFace/fastface/config.py index f47fa9f..af9e65e 100644 --- a/projects/FastFace/fastface/config.py +++ b/projects/FastFace/fastface/config.py @@ -12,5 +12,7 @@ def add_face_cfg(cfg): _C.DATASETS.REC_PATH = "" + _C.MODEL.BACKBONE.DROPOUT = 0. + _C.MODEL.HEADS.PFC = CN({"ENABLED": False}) _C.MODEL.HEADS.PFC.SAMPLE_RATE = 0.1 diff --git a/projects/FastFace/fastface/datasets/ms1mv2.py b/projects/FastFace/fastface/datasets/ms1mv2.py index b19d4f8..c633a47 100644 --- a/projects/FastFace/fastface/datasets/ms1mv2.py +++ b/projects/FastFace/fastface/datasets/ms1mv2.py @@ -23,7 +23,7 @@ class MS1MV2(ImageDataset): required_files = [self.dataset_dir] self.check_before_run(required_files) - train = self.process_dirs() + train = self.process_dirs()[:10000] super().__init__(train, [], [], **kwargs) def process_dirs(self): diff --git a/projects/FastFace/fastface/modeling/__init__.py b/projects/FastFace/fastface/modeling/__init__.py index 7897014..4cf69a0 100644 --- a/projects/FastFace/fastface/modeling/__init__.py +++ b/projects/FastFace/fastface/modeling/__init__.py @@ -7,4 +7,4 @@ from .partial_fc import PartialFC from .face_baseline import FaceBaseline from .face_head import FaceHead -from .resnet_ir import build_resnetIR_backbone +from .iresnet import build_iresnet_backbone diff --git a/projects/FastFace/fastface/modeling/face_baseline.py b/projects/FastFace/fastface/modeling/face_baseline.py index 9cf6ec7..ad0b24e 100644 --- a/projects/FastFace/fastface/modeling/face_baseline.py +++ b/projects/FastFace/fastface/modeling/face_baseline.py @@ -4,6 +4,7 @@ @contact: sherlockliao01@gmail.com """ +import torch from fastreid.modeling.meta_arch import Baseline from fastreid.modeling.meta_arch import META_ARCH_REGISTRY @@ -13,12 +14,28 @@ class FaceBaseline(Baseline): def __init__(self, cfg): super().__init__(cfg) self.pfc_enabled = cfg.MODEL.HEADS.PFC.ENABLED + self.amp_enabled = cfg.SOLVER.AMP.ENABLED - def losses(self, outputs, gt_labels): + def forward(self, batched_inputs): if not self.pfc_enabled: - return super().losses(outputs, gt_labels) + return super().forward(batched_inputs) + + images = self.preprocess_image(batched_inputs) + with torch.cuda.amp.autocast(self.amp_enabled): + features = self.backbone(images) + features = features.float() if self.amp_enabled else features + + if self.training: + assert "targets" in batched_inputs, "Person ID annotation are missing in training!" + targets = batched_inputs["targets"] + + # PreciseBN flag, When do preciseBN on different dataset, the number of classes in new dataset + # may be larger than that in the original dataset, so the circle/arcface will + # throw an error. We just set all the targets to 0 to avoid this problem. + if targets.sum() < 0: targets.zero_() + + outputs = self.heads(features, targets) + return outputs, targets else: - # model parallel with partial-fc - # cls layer and loss computation in partial_fc.py - pred_features = outputs["features"] - return pred_features, gt_labels + outputs = self.heads(features) + return outputs diff --git a/projects/FastFace/fastface/modeling/face_head.py b/projects/FastFace/fastface/modeling/face_head.py index 0168b84..7583a0b 100644 --- a/projects/FastFace/fastface/modeling/face_head.py +++ b/projects/FastFace/fastface/modeling/face_head.py @@ -30,10 +30,4 @@ class FaceHead(EmbeddingHead): pool_feat = self.pool_layer(features) neck_feat = self.bottleneck(pool_feat) neck_feat = neck_feat[..., 0, 0] - - if not self.training: - return neck_feat - - return { - "features": neck_feat, - } + return neck_feat diff --git a/projects/FastFace/fastface/modeling/iresnet.py b/projects/FastFace/fastface/modeling/iresnet.py new file mode 100644 index 0000000..da2d593 --- /dev/null +++ b/projects/FastFace/fastface/modeling/iresnet.py @@ -0,0 +1,179 @@ +# encoding: utf-8 +""" +@author: xingyu liao +@contact: sherlockliao01@gmail.com +""" + +import torch +from torch import nn + +from fastreid.layers import get_norm +from fastreid.modeling.backbones import BACKBONE_REGISTRY + + +def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=dilation, groups=groups, bias=False, dilation=dilation) + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d(in_planes, + out_planes, + kernel_size=1, + stride=stride, + bias=False) + + +class IBasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, bn_norm, stride=1, downsample=None, + groups=1, base_width=64, dilation=1): + super().__init__() + if groups != 1 or base_width != 64: + raise ValueError('BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + self.bn1 = get_norm(bn_norm, inplanes) + self.conv1 = conv3x3(inplanes, planes) + self.bn2 = get_norm(bn_norm, planes) + self.prelu = nn.PReLU(planes) + self.conv2 = conv3x3(planes, planes, stride) + self.bn3 = get_norm(bn_norm, planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + out = self.bn1(x) + out = self.conv1(out) + out = self.bn2(out) + out = self.prelu(out) + out = self.conv2(out) + out = self.bn3(out) + if self.downsample is not None: + identity = self.downsample(x) + out += identity + return out + + +class IResNet(nn.Module): + fc_scale = 7 * 7 + + def __init__(self, block, layers, bn_norm, dropout=0, zero_init_residual=False, + groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False): + super().__init__() + self.inplanes = 64 + self.dilation = 1 + self.fp16 = fp16 + if replace_stride_with_dilation is None: + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError("replace_stride_with_dilation should be None " + "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) + self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) + self.bn1 = get_norm(bn_norm, self.inplanes) + self.prelu = nn.PReLU(self.inplanes) + self.layer1 = self._make_layer(block, 64, layers[0], bn_norm, stride=2) + self.layer2 = self._make_layer(block, + 128, + layers[1], + bn_norm, + stride=2, + dilate=replace_stride_with_dilation[0]) + self.layer3 = self._make_layer(block, + 256, + layers[2], + bn_norm, + stride=2, + dilate=replace_stride_with_dilation[1]) + self.layer4 = self._make_layer(block, + 512, + layers[3], + bn_norm, + stride=2, + dilate=replace_stride_with_dilation[2]) + self.bn2 = get_norm(bn_norm, 512 * block.expansion) + self.dropout = nn.Dropout(p=dropout, inplace=True) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.normal_(m.weight, 0, 0.1) + elif m.__class__.__name__.find('Norm') != -1: + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + if zero_init_residual: + for m in self.modules(): + if isinstance(m, IBasicBlock): + nn.init.constant_(m.bn2.weight, 0) + + def _make_layer(self, block, planes, blocks, bn_norm, stride=1, dilate=False): + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + get_norm(bn_norm, planes * block.expansion), + ) + layers = [] + layers.append( + block(self.inplanes, planes, bn_norm, stride, downsample, self.groups, + self.base_width, previous_dilation)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append( + block(self.inplanes, + planes, + bn_norm, + groups=self.groups, + base_width=self.base_width, + dilation=self.dilation)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.prelu(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.bn2(x) + x = self.dropout(x) + return x + + +@BACKBONE_REGISTRY.register() +def build_iresnet_backbone(cfg): + """ + Create a IResNet instance from config. + Returns: + ResNet: a :class:`ResNet` instance. + """ + + # fmt: off + bn_norm = cfg.MODEL.BACKBONE.NORM + depth = cfg.MODEL.BACKBONE.DEPTH + dropout = cfg.MODEL.BACKBONE.DROPOUT + fp16 = cfg.SOLVER.AMP.ENABLED + # fmt: on + + num_blocks_per_stage = { + '18x': [2, 2, 2, 2], + '34x': [3, 4, 6, 3], + '50x': [3, 4, 14, 3], + '100x': [3, 13, 30, 3], + '200x': [6, 26, 60, 6], + }[depth] + + model = IResNet(IBasicBlock, num_blocks_per_stage, bn_norm, dropout, fp16=fp16) + return model diff --git a/projects/FastFace/fastface/modeling/partial_fc.py b/projects/FastFace/fastface/modeling/partial_fc.py index 408cace..b1c6c42 100644 --- a/projects/FastFace/fastface/modeling/partial_fc.py +++ b/projects/FastFace/fastface/modeling/partial_fc.py @@ -52,23 +52,6 @@ class PartialFC(nn.Module): self.cls_layer = getattr(any_softmax, cls_type)(num_classes, scale, margin) - """ TODO: consider resume training - if resume: - try: - self.weight: torch.Tensor = torch.load(self.weight_name) - logging.info("softmax weight resume successfully!") - except (FileNotFoundError, KeyError, IndexError): - self.weight = torch.normal(0, 0.01, (self.num_local, self.embedding_size), device=self.device) - logging.info("softmax weight resume fail!") - - try: - self.weight_mom: torch.Tensor = torch.load(self.weight_mom_name) - logging.info("softmax weight mom resume successfully!") - except (FileNotFoundError, KeyError, IndexError): - self.weight_mom: torch.Tensor = torch.zeros_like(self.weight) - logging.info("softmax weight mom resume fail!") - else: - """ self.weight = torch.normal(0, 0.01, (self.num_local, self.embedding_size), device=self.device) self.weight_mom: torch.Tensor = torch.zeros_like(self.weight) logger.info("softmax weight init successfully!") diff --git a/projects/FastFace/fastface/modeling/resnet_ir.py b/projects/FastFace/fastface/modeling/resnet_ir.py deleted file mode 100644 index 67632a2..0000000 --- a/projects/FastFace/fastface/modeling/resnet_ir.py +++ /dev/null @@ -1,122 +0,0 @@ -# encoding: utf-8 -""" -@author: xingyu liao -@contact: sherlockliao01@gmail.com -""" - -from collections import namedtuple - -from torch import nn - -from fastreid.layers import get_norm, SELayer -from fastreid.modeling.backbones import BACKBONE_REGISTRY - - -def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): - """3x3 convolution with padding""" - return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, - padding=dilation, groups=groups, bias=False, dilation=dilation) - - -class bottleneck_IR(nn.Module): - def __init__(self, in_channel, depth, bn_norm, stride, with_se=False): - super(bottleneck_IR, self).__init__() - if in_channel == depth: - self.shortcut_layer = nn.MaxPool2d(1, stride) - else: - self.shortcut_layer = nn.Sequential( - nn.Conv2d(in_channel, depth, (1, 1), stride, bias=False), - get_norm(bn_norm, depth)) - self.res_layer = nn.Sequential( - get_norm(bn_norm, in_channel), - nn.Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), - nn.PReLU(depth), - nn.Conv2d(depth, depth, (3, 3), stride, 1, bias=False), - get_norm(bn_norm, depth), - SELayer(depth, 16) if with_se else nn.Identity() - ) - - def forward(self, x): - shortcut = self.shortcut_layer(x) - res = self.res_layer(x) - return res + shortcut - - -class Bottleneck(namedtuple("Block", ["in_channel", "depth", "bn_norm", "stride", "with_se"])): - """A named tuple describing a ResNet block.""" - - -def get_block(in_channel, depth, bn_norm, num_units, with_se, stride=2): - return [Bottleneck(in_channel, depth, bn_norm, stride, with_se)] + \ - [Bottleneck(depth, depth, bn_norm, 1, with_se) for _ in range(num_units - 1)] - - -def get_blocks(bn_norm, with_se, num_layers): - if num_layers == "50x": - blocks = [ - get_block(in_channel=64, depth=64, bn_norm=bn_norm, num_units=3, with_se=with_se), - get_block(in_channel=64, depth=128, bn_norm=bn_norm, num_units=4, with_se=with_se), - get_block(in_channel=128, depth=256, bn_norm=bn_norm, num_units=14, with_se=with_se), - get_block(in_channel=256, depth=512, bn_norm=bn_norm, num_units=3, with_se=with_se) - ] - elif num_layers == "100x": - blocks = [ - get_block(in_channel=64, depth=64, bn_norm=bn_norm, num_units=3, with_se=with_se), - get_block(in_channel=64, depth=128, bn_norm=bn_norm, num_units=13, with_se=with_se), - get_block(in_channel=128, depth=256, bn_norm=bn_norm, num_units=30, with_se=with_se), - get_block(in_channel=256, depth=512, bn_norm=bn_norm, num_units=3, with_se=with_se) - ] - elif num_layers == "152x": - blocks = [ - get_block(in_channel=64, depth=64, bn_norm=bn_norm, num_units=3, with_se=with_se), - get_block(in_channel=64, depth=128, bn_norm=bn_norm, num_units=8, with_se=with_se), - get_block(in_channel=128, depth=256, bn_norm=bn_norm, num_units=36, with_se=with_se), - get_block(in_channel=256, depth=512, bn_norm=bn_norm, num_units=3, with_se=with_se) - ] - return blocks - - -class ResNetIR(nn.Module): - def __init__(self, num_layers, bn_norm, drop_ratio, with_se): - super(ResNetIR, self).__init__() - assert num_layers in ["50x", "100x", "152x"], "num_layers should be 50,100, or 152" - blocks = get_blocks(bn_norm, with_se, num_layers) - self.input_layer = nn.Sequential(nn.Conv2d(3, 64, (3, 3), 1, 1, bias=False), - get_norm(bn_norm, 64), - nn.PReLU(64)) - self.output_layer = nn.Sequential(get_norm(bn_norm, 512), - nn.Dropout(drop_ratio)) - modules = [] - for block in blocks: - for bottleneck in block: - modules.append( - bottleneck_IR(bottleneck.in_channel, - bottleneck.depth, - bottleneck.bn_norm, - bottleneck.stride, - bottleneck.with_se)) - self.body = nn.Sequential(*modules) - - def forward(self, x): - x = self.input_layer(x) - x = self.body(x) - x = self.output_layer(x) - return x - - -@BACKBONE_REGISTRY.register() -def build_resnetIR_backbone(cfg): - """ - Create a ResNetIR instance from config. - Returns: - ResNet: a :class:`ResNet` instance. - """ - - # fmt: off - bn_norm = cfg.MODEL.BACKBONE.NORM - with_se = cfg.MODEL.BACKBONE.WITH_SE - depth = cfg.MODEL.BACKBONE.DEPTH - # fmt: on - - model = ResNetIR(depth, bn_norm, 0.5, with_se) - return model diff --git a/projects/FastFace/fastface/pfc_checkpointer.py b/projects/FastFace/fastface/pfc_checkpointer.py new file mode 100644 index 0000000..0ea759b --- /dev/null +++ b/projects/FastFace/fastface/pfc_checkpointer.py @@ -0,0 +1,84 @@ +# encoding: utf-8 +""" +@author: xingyu liao +@contact: sherlockliao01@gmail.com +""" + +import os +from typing import Any, Dict + +import torch + +from fastreid.engine.hooks import PeriodicCheckpointer +from fastreid.utils import comm +from fastreid.utils.checkpoint import Checkpointer +from fastreid.utils.file_io import PathManager + + +class PfcPeriodicCheckpointer(PeriodicCheckpointer): + + def step(self, epoch: int, **kwargs: Any): + rank = comm.get_rank() + if (epoch + 1) % self.period == 0 and epoch < self.max_epoch - 1: + self.checkpointer.save( + f"softmax_weight_{epoch:04d}_rank_{rank:02d}" + ) + if epoch >= self.max_epoch - 1: + self.checkpointer.save(f"softmax_weight_{rank:02d}", ) + + +class PfcCheckpointer(Checkpointer): + def __init__(self, model, save_dir, *, save_to_disk=True, **checkpointables): + super().__init__(model, save_dir, save_to_disk=save_to_disk, **checkpointables) + self.rank = comm.get_rank() + + def save(self, name: str, **kwargs: Dict[str, str]): + if not self.save_dir or not self.save_to_disk: + return + + data = {} + data["model"] = { + "weight": self.model.weight.data, + "momentum": self.model.weight_mom, + } + for key, obj in self.checkpointables.items(): + data[key] = obj.state_dict() + data.update(kwargs) + + basename = f"{name}.pth" + save_file = os.path.join(self.save_dir, basename) + assert os.path.basename(save_file) == basename, basename + self.logger.info("Saving partial fc weights") + with PathManager.open(save_file, "wb") as f: + torch.save(data, f) + self.tag_last_checkpoint(basename) + + def _load_model(self, checkpoint: Any): + checkpoint_state_dict = checkpoint.pop("model") + self._convert_ndarray_to_tensor(checkpoint_state_dict) + self.model.weight.data.copy_(checkpoint_state_dict.pop("weight")) + self.model.weight_mom.data.copy_(checkpoint_state_dict.pop("momentum")) + + def has_checkpoint(self): + save_file = os.path.join(self.save_dir, f"last_weight_{self.rank:02d}") + return PathManager.exists(save_file) + + def get_checkpoint_file(self): + """ + Returns: + str: The latest checkpoint file in target directory. + """ + save_file = os.path.join(self.save_dir, f"last_weight_{self.rank:02d}") + try: + with PathManager.open(save_file, "r") as f: + last_saved = f.read().strip() + except IOError: + # if file doesn't exist, maybe because it has just been + # deleted by a separate process + return "" + return os.path.join(self.save_dir, last_saved) + + def tag_last_checkpoint(self, last_filename_basename: str): + save_file = os.path.join(self.save_dir, f"last_weight_{self.rank:02d}") + with PathManager.open(save_file, "w") as f: + f.write(last_filename_basename) diff --git a/projects/FastFace/fastface/trainer.py b/projects/FastFace/fastface/trainer.py index 334a141..90db184 100644 --- a/projects/FastFace/fastface/trainer.py +++ b/projects/FastFace/fastface/trainer.py @@ -8,20 +8,23 @@ import os import time from torch.nn.parallel import DistributedDataParallel +from torch.nn.utils import clip_grad_norm_ -from fastreid.engine import hooks -from .face_data import TestFaceDataset -from fastreid.data.datasets import DATASET_REGISTRY from fastreid.data.build import _root, build_reid_test_loader, build_reid_train_loader +from fastreid.data.datasets import DATASET_REGISTRY from fastreid.data.transforms import build_transforms +from fastreid.engine import hooks from fastreid.engine.defaults import DefaultTrainer, TrainerBase -from fastreid.engine.train_loop import SimpleTrainer +from fastreid.engine.train_loop import SimpleTrainer, AMPTrainer from fastreid.utils import comm from fastreid.utils.checkpoint import Checkpointer from fastreid.utils.logger import setup_logger from .face_data import MXFaceDataset +from .face_data import TestFaceDataset from .face_evaluator import FaceEvaluator from .modeling import PartialFC +from .pfc_checkpointer import PfcPeriodicCheckpointer, PfcCheckpointer +from .utils_amp import MaxClipGradScaler class FaceTrainer(DefaultTrainer): @@ -59,11 +62,17 @@ class FaceTrainer(DefaultTrainer): # for part of the parameters is not updated. model = DistributedDataParallel( model, device_ids=[comm.get_local_rank()], broadcast_buffers=False, - find_unused_parameters=True ) - self._trainer = PFCTrainer(model, data_loader, optimizer, self.pfc_module, self.pfc_optimizer) \ - if cfg.MODEL.HEADS.PFC.ENABLED else SimpleTrainer(model, data_loader, optimizer) + if cfg.MODEL.HEADS.PFC.ENABLED: + mini_batch_size = cfg.SOLVER.IMS_PER_BATCH // comm.get_world_size() + grad_scaler = MaxClipGradScaler(mini_batch_size, 128 * mini_batch_size, growth_interval=100) + self._trainer = PFCTrainer(model, data_loader, optimizer, + self.pfc_module, self.pfc_optimizer, cfg.SOLVER.AMP.ENABLED, grad_scaler) + else: + self._trainer = (AMPTrainer if cfg.SOLVER.AMP.ENABLED else SimpleTrainer)( + model, data_loader, optimizer + ) self.iters_per_epoch = len(data_loader.dataset) // cfg.SOLVER.IMS_PER_BATCH self.scheduler = self.build_lr_scheduler(cfg, optimizer, self.iters_per_epoch) @@ -80,7 +89,7 @@ class FaceTrainer(DefaultTrainer): ) if cfg.MODEL.HEADS.PFC.ENABLED: - self.pfc_checkpointer = Checkpointer( + self.pfc_checkpointer = PfcCheckpointer( self.pfc_module, cfg.OUTPUT_DIR, optimizer=self.pfc_optimizer, @@ -100,12 +109,23 @@ class FaceTrainer(DefaultTrainer): ret = super().build_hooks() if self.cfg.MODEL.HEADS.PFC.ENABLED: + # Make sure checkpointer is after writer + ret.insert( + len(ret) - 1, + PfcPeriodicCheckpointer(self.pfc_checkpointer, self.cfg.SOLVER.CHECKPOINT_PERIOD) + ) # partial fc scheduler hook ret.append( hooks.LRScheduler(self.pfc_optimizer, self.pfc_scheduler) ) return ret + def resume_or_load(self, resume=True): + # Backbone loading state_dict + super().resume_or_load(resume) + # Partial-FC loading state_dict + self.pfc_checkpointer.resume_or_load('', resume=resume) + @classmethod def build_train_loader(cls, cfg): path_imgrec = cfg.DATASETS.REC_PATH @@ -141,11 +161,14 @@ class PFCTrainer(SimpleTrainer): https://github.com/deepinsight/insightface/blob/master/recognition/arcface_torch/partial_fc.py """ - def __init__(self, model, data_loader, optimizer, pfc_module, pfc_optimizer): + def __init__(self, model, data_loader, optimizer, pfc_module, pfc_optimizer, amp_enabled, grad_scaler): super().__init__(model, data_loader, optimizer) self.pfc_module = pfc_module self.pfc_optimizer = pfc_optimizer + self.amp_enabled = amp_enabled + + self.grad_scaler = grad_scaler def run_step(self): assert self.model.training, "[PFCTrainer] model was changed to eval mode!" @@ -156,18 +179,24 @@ class PFCTrainer(SimpleTrainer): features, targets = self.model(data) - self.optimizer.zero_grad() - self.pfc_optimizer.zero_grad() - # Partial-fc backward f_grad, loss_v = self.pfc_module.forward_backward(features, targets, self.pfc_optimizer) - features.backward(f_grad) + if self.amp_enabled: + features.backward(self.grad_scaler.scale(f_grad)) + self.grad_scaler.unscale_(self.optimizer) + clip_grad_norm_(self.model.parameters(), max_norm=5, norm_type=2) + self.grad_scaler.step(self.optimizer) + self.grad_scaler.update() + else: + features.backward(f_grad) + clip_grad_norm_(self.model.parameters(), max_norm=5, norm_type=2) + self.optimizer.step() loss_dict = {"loss_cls": loss_v} self._write_metrics(loss_dict, data_time) - self.optimizer.step() self.pfc_optimizer.step() - self.pfc_module.update() + self.optimizer.zero_grad() + self.pfc_optimizer.zero_grad() diff --git a/projects/FastFace/fastface/utils_amp.py b/projects/FastFace/fastface/utils_amp.py new file mode 100644 index 0000000..de115c6 --- /dev/null +++ b/projects/FastFace/fastface/utils_amp.py @@ -0,0 +1,86 @@ +# encoding: utf-8 +""" +@author: xingyu liao +@contact: sherlockliao01@gmail.com +""" + +from typing import Dict, List + +import torch +from torch._six import container_abcs +from torch.cuda.amp import GradScaler + + +class _MultiDeviceReplicator(object): + """ + Lazily serves copies of a tensor to requested devices. Copies are cached per-device. + """ + + def __init__(self, master_tensor: torch.Tensor) -> None: + assert master_tensor.is_cuda + self.master = master_tensor + self._per_device_tensors: Dict[torch.device, torch.Tensor] = {} + + def get(self, device) -> torch.Tensor: + retval = self._per_device_tensors.get(device, None) + if retval is None: + retval = self.master.to(device=device, non_blocking=True, copy=True) + self._per_device_tensors[device] = retval + return retval + + +class MaxClipGradScaler(GradScaler): + def __init__(self, init_scale, max_scale: float, growth_interval=100): + super().__init__(init_scale=init_scale, growth_interval=growth_interval) + self.max_scale = max_scale + + def scale_clip(self): + if self.get_scale() == self.max_scale: + self.set_growth_factor(1) + elif self.get_scale() < self.max_scale: + self.set_growth_factor(2) + elif self.get_scale() > self.max_scale: + self._scale.fill_(self.max_scale) + self.set_growth_factor(1) + + def scale(self, outputs): + """ + Multiplies ('scales') a tensor or list of tensors by the scale factor. + Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returned + unmodified. + Arguments: + outputs (Tensor or iterable of Tensors): Outputs to scale. + """ + if not self._enabled: + return outputs + self.scale_clip() + # Short-circuit for the common case. + if isinstance(outputs, torch.Tensor): + assert outputs.is_cuda + if self._scale is None: + self._lazy_init_scale_growth_tracker(outputs.device) + assert self._scale is not None + return outputs * self._scale.to(device=outputs.device, non_blocking=True) + + # Invoke the more complex machinery only if we're treating multiple outputs. + stash: List[_MultiDeviceReplicator] = [] # holds a reference that can be overwritten by apply_scale + + def apply_scale(val): + if isinstance(val, torch.Tensor): + assert val.is_cuda + if len(stash) == 0: + if self._scale is None: + self._lazy_init_scale_growth_tracker(val.device) + assert self._scale is not None + stash.append(_MultiDeviceReplicator(self._scale)) + return val * stash[0].get(val.device) + elif isinstance(val, container_abcs.Iterable): + iterable = map(apply_scale, val) + if isinstance(val, list) or isinstance(val, tuple): + return type(val)(iterable) + else: + return iterable + else: + raise ValueError("outputs must be a Tensor or an iterable of Tensors") + + return apply_scale(outputs) diff --git a/projects/FastFace/train_net.py b/projects/FastFace/train_net.py index 2b47401..8e24769 100644 --- a/projects/FastFace/train_net.py +++ b/projects/FastFace/train_net.py @@ -14,7 +14,6 @@ from fastreid.engine import default_argument_parser, default_setup, launch from fastreid.utils.checkpoint import Checkpointer from fastface import * -from fastface.datasets import * def setup(args):