Support amp and resume training in fastface

AMP in partial-fc needs to be applied to the backbone only; to implement `resume training`, the part of the classifier weight held by each GPU has to be saved and loaded separately.
pull/504/head
liaoxingyu 2021-05-31 17:30:43 +08:00
parent 91ff631184
commit c3ac4f504c
15 changed files with 432 additions and 178 deletions
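
As a quick illustration of the two ideas in the commit description, the sketch below runs autocast only around the backbone (casting features back to fp32 before the classifier) and keeps one checkpoint file per rank for the classifier shard. This is a minimal sketch under assumed names (`backbone`, `weight_shard`, `weight_mom`, `save_dir`, `rank` are illustrative), not the repository's API; the actual implementation is in the FaceBaseline and PfcCheckpointer changes below.

import os

import torch


def forward_backbone_amp(backbone, images, amp_enabled=True):
    # autocast covers only the backbone; partial-fc logits and loss stay in fp32
    with torch.cuda.amp.autocast(enabled=amp_enabled):
        features = backbone(images)
    # cast back so the fp32 classifier shard receives fp32 features
    return features.float() if amp_enabled else features


def save_classifier_shard(weight_shard, weight_mom, save_dir, rank):
    # each rank owns a different slice of the softmax weight, so checkpoints are per rank
    path = os.path.join(save_dir, f"softmax_weight_rank_{rank:02d}.pth")
    torch.save({"weight": weight_shard.data, "momentum": weight_mom}, path)


def load_classifier_shard(weight_shard, weight_mom, save_dir, rank):
    # resume by restoring this rank's own slice of the classifier weight and its momentum
    state = torch.load(os.path.join(save_dir, f"softmax_weight_rank_{rank:02d}.pth"),
                       map_location=weight_shard.device)
    weight_shard.data.copy_(state["weight"])
    weight_mom.data.copy_(state["momentum"])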

View File

@@ -4,6 +4,9 @@ MODEL:
PIXEL_MEAN: [127.5, 127.5, 127.5]
PIXEL_STD: [127.5, 127.5, 127.5]
BACKBONE:
NAME: build_iresnet_backbone
HEADS:
NAME: FaceHead
WITH_BNNECK: True
@@ -30,7 +33,7 @@ MODEL:
DATASETS:
REC_PATH: /export/home/DATA/Glint360k/train.rec
NAMES: ("MS1MV2",)
TESTS: ("CPLFW", "VGG2_FP", "CALFW", "CFP_FF", "CFP_FP", "AgeDB_30", "LFW")
TESTS: ("CFP_FP", "AgeDB_30", "LFW")
INPUT:
SIZE_TRAIN: [0,] # No need of resize
@@ -47,10 +50,10 @@ DATALOADER:
SOLVER:
MAX_EPOCH: 20
AMP:
ENABLED: False
ENABLED: True
OPT: SGD
BASE_LR: 0.1
BASE_LR: 0.05
MOMENTUM: 0.9
SCHED: MultiStepLR
@@ -59,10 +62,10 @@ SOLVER:
BIAS_LR_FACTOR: 1.
WEIGHT_DECAY: 0.0005
WEIGHT_DECAY_BIAS: 0.0005
IMS_PER_BATCH: 512
IMS_PER_BATCH: 256
WARMUP_FACTOR: 0.1
WARMUP_ITERS: 5000
WARMUP_ITERS: 0
CHECKPOINT_PERIOD: 1

View File

@@ -3,13 +3,12 @@ _BASE_: face_base.yml
MODEL:
BACKBONE:
NAME: build_resnetIR_backbone
DEPTH: 50x
FEAT_DIM: 25088 # 512x7x7
WITH_SE: True
DROPOUT: 0.
HEADS:
PFC:
ENABLED: True
OUTPUT_DIR: projects/FastFace/logs/ir_se50-glink360k-pfc0.1
OUTPUT_DIR: projects/FastFace/logs/pfc0.1_insightface

View File

@@ -7,3 +7,4 @@
from .modeling import *
from .config import add_face_cfg
from .trainer import FaceTrainer
from .datasets import *

View File

@@ -12,5 +12,7 @@ def add_face_cfg(cfg):
_C.DATASETS.REC_PATH = ""
_C.MODEL.BACKBONE.DROPOUT = 0.
_C.MODEL.HEADS.PFC = CN({"ENABLED": False})
_C.MODEL.HEADS.PFC.SAMPLE_RATE = 0.1

View File

@@ -23,7 +23,7 @@ class MS1MV2(ImageDataset):
required_files = [self.dataset_dir]
self.check_before_run(required_files)
train = self.process_dirs()
train = self.process_dirs()[:10000]
super().__init__(train, [], [], **kwargs)
def process_dirs(self):

View File

@@ -7,4 +7,4 @@
from .partial_fc import PartialFC
from .face_baseline import FaceBaseline
from .face_head import FaceHead
from .resnet_ir import build_resnetIR_backbone
from .iresnet import build_iresnet_backbone

View File

@@ -4,6 +4,7 @@
@contact: sherlockliao01@gmail.com
"""
import torch
from fastreid.modeling.meta_arch import Baseline
from fastreid.modeling.meta_arch import META_ARCH_REGISTRY
@@ -13,12 +14,28 @@ class FaceBaseline(Baseline):
def __init__(self, cfg):
super().__init__(cfg)
self.pfc_enabled = cfg.MODEL.HEADS.PFC.ENABLED
self.amp_enabled = cfg.SOLVER.AMP.ENABLED
def losses(self, outputs, gt_labels):
def forward(self, batched_inputs):
if not self.pfc_enabled:
return super().losses(outputs, gt_labels)
return super().forward(batched_inputs)
images = self.preprocess_image(batched_inputs)
with torch.cuda.amp.autocast(self.amp_enabled):
features = self.backbone(images)
features = features.float() if self.amp_enabled else features
if self.training:
assert "targets" in batched_inputs, "Person ID annotation are missing in training!"
targets = batched_inputs["targets"]
# PreciseBN flag: when running preciseBN on a different dataset, the number of classes in the
# new dataset may be larger than that in the original dataset, so circle/arcface would
# throw an error. We simply set all the targets to 0 to avoid this problem.
if targets.sum() < 0: targets.zero_()
outputs = self.heads(features, targets)
return outputs, targets
else:
# model parallel with partial-fc
# cls layer and loss computation in partial_fc.py
pred_features = outputs["features"]
return pred_features, gt_labels
outputs = self.heads(features)
return outputs

View File

@@ -30,10 +30,4 @@ class FaceHead(EmbeddingHead):
pool_feat = self.pool_layer(features)
neck_feat = self.bottleneck(pool_feat)
neck_feat = neck_feat[..., 0, 0]
if not self.training:
return neck_feat
return {
"features": neck_feat,
}

View File

@@ -0,0 +1,179 @@
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""
import torch
from torch import nn
from fastreid.layers import get_norm
from fastreid.modeling.backbones import BACKBONE_REGISTRY
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=dilation, groups=groups, bias=False, dilation=dilation)
def conv1x1(in_planes, out_planes, stride=1):
"""1x1 convolution"""
return nn.Conv2d(in_planes,
out_planes,
kernel_size=1,
stride=stride,
bias=False)
class IBasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, bn_norm, stride=1, downsample=None,
groups=1, base_width=64, dilation=1):
super().__init__()
if groups != 1 or base_width != 64:
raise ValueError('BasicBlock only supports groups=1 and base_width=64')
if dilation > 1:
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
self.bn1 = get_norm(bn_norm, inplanes)
self.conv1 = conv3x3(inplanes, planes)
self.bn2 = get_norm(bn_norm, planes)
self.prelu = nn.PReLU(planes)
self.conv2 = conv3x3(planes, planes, stride)
self.bn3 = get_norm(bn_norm, planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.bn1(x)
out = self.conv1(out)
out = self.bn2(out)
out = self.prelu(out)
out = self.conv2(out)
out = self.bn3(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
return out
class IResNet(nn.Module):
fc_scale = 7 * 7
def __init__(self, block, layers, bn_norm, dropout=0, zero_init_residual=False,
groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False):
super().__init__()
self.inplanes = 64
self.dilation = 1
self.fp16 = fp16
if replace_stride_with_dilation is None:
replace_stride_with_dilation = [False, False, False]
if len(replace_stride_with_dilation) != 3:
raise ValueError("replace_stride_with_dilation should be None "
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
self.groups = groups
self.base_width = width_per_group
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
self.bn1 = get_norm(bn_norm, self.inplanes)
self.prelu = nn.PReLU(self.inplanes)
self.layer1 = self._make_layer(block, 64, layers[0], bn_norm, stride=2)
self.layer2 = self._make_layer(block,
128,
layers[1],
bn_norm,
stride=2,
dilate=replace_stride_with_dilation[0])
self.layer3 = self._make_layer(block,
256,
layers[2],
bn_norm,
stride=2,
dilate=replace_stride_with_dilation[1])
self.layer4 = self._make_layer(block,
512,
layers[3],
bn_norm,
stride=2,
dilate=replace_stride_with_dilation[2])
self.bn2 = get_norm(bn_norm, 512 * block.expansion)
self.dropout = nn.Dropout(p=dropout, inplace=True)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.normal_(m.weight, 0, 0.1)
elif m.__class__.__name__.find('Norm') != -1:
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
if zero_init_residual:
for m in self.modules():
if isinstance(m, IBasicBlock):
nn.init.constant_(m.bn2.weight, 0)
def _make_layer(self, block, planes, blocks, bn_norm, stride=1, dilate=False):
downsample = None
previous_dilation = self.dilation
if dilate:
self.dilation *= stride
stride = 1
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
conv1x1(self.inplanes, planes * block.expansion, stride),
get_norm(bn_norm, planes * block.expansion),
)
layers = []
layers.append(
block(self.inplanes, planes, bn_norm, stride, downsample, self.groups,
self.base_width, previous_dilation))
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(
block(self.inplanes,
planes,
bn_norm,
groups=self.groups,
base_width=self.base_width,
dilation=self.dilation))
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.prelu(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.bn2(x)
x = self.dropout(x)
return x
@BACKBONE_REGISTRY.register()
def build_iresnet_backbone(cfg):
"""
Create an IResNet instance from config.
Returns:
IResNet: an :class:`IResNet` instance.
"""
# fmt: off
bn_norm = cfg.MODEL.BACKBONE.NORM
depth = cfg.MODEL.BACKBONE.DEPTH
dropout = cfg.MODEL.BACKBONE.DROPOUT
fp16 = cfg.SOLVER.AMP.ENABLED
# fmt: on
num_blocks_per_stage = {
'18x': [2, 2, 2, 2],
'34x': [3, 4, 6, 3],
'50x': [3, 4, 14, 3],
'100x': [3, 13, 30, 3],
'200x': [6, 26, 60, 6],
}[depth]
model = IResNet(IBasicBlock, num_blocks_per_stage, bn_norm, dropout, fp16=fp16)
return model

View File

@@ -52,23 +52,6 @@ class PartialFC(nn.Module):
self.cls_layer = getattr(any_softmax, cls_type)(num_classes, scale, margin)
""" TODO: consider resume training
if resume:
try:
self.weight: torch.Tensor = torch.load(self.weight_name)
logging.info("softmax weight resume successfully!")
except (FileNotFoundError, KeyError, IndexError):
self.weight = torch.normal(0, 0.01, (self.num_local, self.embedding_size), device=self.device)
logging.info("softmax weight resume fail!")
try:
self.weight_mom: torch.Tensor = torch.load(self.weight_mom_name)
logging.info("softmax weight mom resume successfully!")
except (FileNotFoundError, KeyError, IndexError):
self.weight_mom: torch.Tensor = torch.zeros_like(self.weight)
logging.info("softmax weight mom resume fail!")
else:
"""
self.weight = torch.normal(0, 0.01, (self.num_local, self.embedding_size), device=self.device)
self.weight_mom: torch.Tensor = torch.zeros_like(self.weight)
logger.info("softmax weight init successfully!")

View File

@@ -1,122 +0,0 @@
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""
from collections import namedtuple
from torch import nn
from fastreid.layers import get_norm, SELayer
from fastreid.modeling.backbones import BACKBONE_REGISTRY
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=dilation, groups=groups, bias=False, dilation=dilation)
class bottleneck_IR(nn.Module):
def __init__(self, in_channel, depth, bn_norm, stride, with_se=False):
super(bottleneck_IR, self).__init__()
if in_channel == depth:
self.shortcut_layer = nn.MaxPool2d(1, stride)
else:
self.shortcut_layer = nn.Sequential(
nn.Conv2d(in_channel, depth, (1, 1), stride, bias=False),
get_norm(bn_norm, depth))
self.res_layer = nn.Sequential(
get_norm(bn_norm, in_channel),
nn.Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
nn.PReLU(depth),
nn.Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
get_norm(bn_norm, depth),
SELayer(depth, 16) if with_se else nn.Identity()
)
def forward(self, x):
shortcut = self.shortcut_layer(x)
res = self.res_layer(x)
return res + shortcut
class Bottleneck(namedtuple("Block", ["in_channel", "depth", "bn_norm", "stride", "with_se"])):
"""A named tuple describing a ResNet block."""
def get_block(in_channel, depth, bn_norm, num_units, with_se, stride=2):
return [Bottleneck(in_channel, depth, bn_norm, stride, with_se)] + \
[Bottleneck(depth, depth, bn_norm, 1, with_se) for _ in range(num_units - 1)]
def get_blocks(bn_norm, with_se, num_layers):
if num_layers == "50x":
blocks = [
get_block(in_channel=64, depth=64, bn_norm=bn_norm, num_units=3, with_se=with_se),
get_block(in_channel=64, depth=128, bn_norm=bn_norm, num_units=4, with_se=with_se),
get_block(in_channel=128, depth=256, bn_norm=bn_norm, num_units=14, with_se=with_se),
get_block(in_channel=256, depth=512, bn_norm=bn_norm, num_units=3, with_se=with_se)
]
elif num_layers == "100x":
blocks = [
get_block(in_channel=64, depth=64, bn_norm=bn_norm, num_units=3, with_se=with_se),
get_block(in_channel=64, depth=128, bn_norm=bn_norm, num_units=13, with_se=with_se),
get_block(in_channel=128, depth=256, bn_norm=bn_norm, num_units=30, with_se=with_se),
get_block(in_channel=256, depth=512, bn_norm=bn_norm, num_units=3, with_se=with_se)
]
elif num_layers == "152x":
blocks = [
get_block(in_channel=64, depth=64, bn_norm=bn_norm, num_units=3, with_se=with_se),
get_block(in_channel=64, depth=128, bn_norm=bn_norm, num_units=8, with_se=with_se),
get_block(in_channel=128, depth=256, bn_norm=bn_norm, num_units=36, with_se=with_se),
get_block(in_channel=256, depth=512, bn_norm=bn_norm, num_units=3, with_se=with_se)
]
return blocks
class ResNetIR(nn.Module):
def __init__(self, num_layers, bn_norm, drop_ratio, with_se):
super(ResNetIR, self).__init__()
assert num_layers in ["50x", "100x", "152x"], "num_layers should be 50,100, or 152"
blocks = get_blocks(bn_norm, with_se, num_layers)
self.input_layer = nn.Sequential(nn.Conv2d(3, 64, (3, 3), 1, 1, bias=False),
get_norm(bn_norm, 64),
nn.PReLU(64))
self.output_layer = nn.Sequential(get_norm(bn_norm, 512),
nn.Dropout(drop_ratio))
modules = []
for block in blocks:
for bottleneck in block:
modules.append(
bottleneck_IR(bottleneck.in_channel,
bottleneck.depth,
bottleneck.bn_norm,
bottleneck.stride,
bottleneck.with_se))
self.body = nn.Sequential(*modules)
def forward(self, x):
x = self.input_layer(x)
x = self.body(x)
x = self.output_layer(x)
return x
@BACKBONE_REGISTRY.register()
def build_resnetIR_backbone(cfg):
"""
Create a ResNetIR instance from config.
Returns:
ResNet: a :class:`ResNet` instance.
"""
# fmt: off
bn_norm = cfg.MODEL.BACKBONE.NORM
with_se = cfg.MODEL.BACKBONE.WITH_SE
depth = cfg.MODEL.BACKBONE.DEPTH
# fmt: on
model = ResNetIR(depth, bn_norm, 0.5, with_se)
return model

View File

@@ -0,0 +1,84 @@
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""
import os
from typing import Any, Dict
import torch
from fastreid.engine.hooks import PeriodicCheckpointer
from fastreid.utils import comm
from fastreid.utils.checkpoint import Checkpointer
from fastreid.utils.file_io import PathManager
class PfcPeriodicCheckpointer(PeriodicCheckpointer):
def step(self, epoch: int, **kwargs: Any):
rank = comm.get_rank()
if (epoch + 1) % self.period == 0 and epoch < self.max_epoch - 1:
self.checkpointer.save(
f"softmax_weight_{epoch:04d}_rank_{rank:02d}"
)
if epoch >= self.max_epoch - 1:
self.checkpointer.save(f"softmax_weight_{rank:02d}", )
class PfcCheckpointer(Checkpointer):
def __init__(self, model, save_dir, *, save_to_disk=True, **checkpointables):
super().__init__(model, save_dir, save_to_disk=save_to_disk, **checkpointables)
self.rank = comm.get_rank()
def save(self, name: str, **kwargs: Dict[str, str]):
if not self.save_dir or not self.save_to_disk:
return
data = {}
data["model"] = {
"weight": self.model.weight.data,
"momentum": self.model.weight_mom,
}
for key, obj in self.checkpointables.items():
data[key] = obj.state_dict()
data.update(kwargs)
basename = f"{name}.pth"
save_file = os.path.join(self.save_dir, basename)
assert os.path.basename(save_file) == basename, basename
self.logger.info("Saving partial fc weights")
with PathManager.open(save_file, "wb") as f:
torch.save(data, f)
self.tag_last_checkpoint(basename)
def _load_model(self, checkpoint: Any):
checkpoint_state_dict = checkpoint.pop("model")
self._convert_ndarray_to_tensor(checkpoint_state_dict)
self.model.weight.data.copy_(checkpoint_state_dict.pop("weight"))
self.model.weight_mom.data.copy_(checkpoint_state_dict.pop("momentum"))
def has_checkpoint(self):
save_file = os.path.join(self.save_dir, f"last_weight_{self.rank:02d}")
return PathManager.exists(save_file)
def get_checkpoint_file(self):
"""
Returns:
str: The latest checkpoint file in target directory.
"""
save_file = os.path.join(self.save_dir, f"last_weight_{self.rank:02d}")
try:
with PathManager.open(save_file, "r") as f:
last_saved = f.read().strip()
except IOError:
# if file doesn't exist, maybe because it has just been
# deleted by a separate process
return ""
return os.path.join(self.save_dir, last_saved)
def tag_last_checkpoint(self, last_filename_basename: str):
save_file = os.path.join(self.save_dir, f"last_weight_{self.rank:02d}")
with PathManager.open(save_file, "w") as f:
f.write(last_filename_basename)

View File

@@ -8,20 +8,23 @@ import os
import time
from torch.nn.parallel import DistributedDataParallel
from torch.nn.utils import clip_grad_norm_
from fastreid.engine import hooks
from .face_data import TestFaceDataset
from fastreid.data.datasets import DATASET_REGISTRY
from fastreid.data.build import _root, build_reid_test_loader, build_reid_train_loader
from fastreid.data.datasets import DATASET_REGISTRY
from fastreid.data.transforms import build_transforms
from fastreid.engine import hooks
from fastreid.engine.defaults import DefaultTrainer, TrainerBase
from fastreid.engine.train_loop import SimpleTrainer
from fastreid.engine.train_loop import SimpleTrainer, AMPTrainer
from fastreid.utils import comm
from fastreid.utils.checkpoint import Checkpointer
from fastreid.utils.logger import setup_logger
from .face_data import MXFaceDataset
from .face_data import TestFaceDataset
from .face_evaluator import FaceEvaluator
from .modeling import PartialFC
from .pfc_checkpointer import PfcPeriodicCheckpointer, PfcCheckpointer
from .utils_amp import MaxClipGradScaler
class FaceTrainer(DefaultTrainer):
@@ -59,11 +62,17 @@ class FaceTrainer(DefaultTrainer):
# for part of the parameters is not updated.
model = DistributedDataParallel(
model, device_ids=[comm.get_local_rank()], broadcast_buffers=False,
find_unused_parameters=True
)
self._trainer = PFCTrainer(model, data_loader, optimizer, self.pfc_module, self.pfc_optimizer) \
if cfg.MODEL.HEADS.PFC.ENABLED else SimpleTrainer(model, data_loader, optimizer)
if cfg.MODEL.HEADS.PFC.ENABLED:
mini_batch_size = cfg.SOLVER.IMS_PER_BATCH // comm.get_world_size()
grad_scaler = MaxClipGradScaler(mini_batch_size, 128 * mini_batch_size, growth_interval=100)
self._trainer = PFCTrainer(model, data_loader, optimizer,
self.pfc_module, self.pfc_optimizer, cfg.SOLVER.AMP.ENABLED, grad_scaler)
else:
self._trainer = (AMPTrainer if cfg.SOLVER.AMP.ENABLED else SimpleTrainer)(
model, data_loader, optimizer
)
self.iters_per_epoch = len(data_loader.dataset) // cfg.SOLVER.IMS_PER_BATCH
self.scheduler = self.build_lr_scheduler(cfg, optimizer, self.iters_per_epoch)
@@ -80,7 +89,7 @@ class FaceTrainer(DefaultTrainer):
)
if cfg.MODEL.HEADS.PFC.ENABLED:
self.pfc_checkpointer = Checkpointer(
self.pfc_checkpointer = PfcCheckpointer(
self.pfc_module,
cfg.OUTPUT_DIR,
optimizer=self.pfc_optimizer,
@@ -100,12 +109,23 @@
ret = super().build_hooks()
if self.cfg.MODEL.HEADS.PFC.ENABLED:
# Make sure checkpointer is after writer
ret.insert(
len(ret) - 1,
PfcPeriodicCheckpointer(self.pfc_checkpointer, self.cfg.SOLVER.CHECKPOINT_PERIOD)
)
# partial fc scheduler hook
ret.append(
hooks.LRScheduler(self.pfc_optimizer, self.pfc_scheduler)
)
return ret
def resume_or_load(self, resume=True):
# Backbone loading state_dict
super().resume_or_load(resume)
# Partial-FC loading state_dict
self.pfc_checkpointer.resume_or_load('', resume=resume)
@classmethod
def build_train_loader(cls, cfg):
path_imgrec = cfg.DATASETS.REC_PATH
@@ -141,11 +161,14 @@ class PFCTrainer(SimpleTrainer):
https://github.com/deepinsight/insightface/blob/master/recognition/arcface_torch/partial_fc.py
"""
def __init__(self, model, data_loader, optimizer, pfc_module, pfc_optimizer):
def __init__(self, model, data_loader, optimizer, pfc_module, pfc_optimizer, amp_enabled, grad_scaler):
super().__init__(model, data_loader, optimizer)
self.pfc_module = pfc_module
self.pfc_optimizer = pfc_optimizer
self.amp_enabled = amp_enabled
self.grad_scaler = grad_scaler
def run_step(self):
assert self.model.training, "[PFCTrainer] model was changed to eval mode!"
@@ -156,18 +179,24 @@ class PFCTrainer(SimpleTrainer):
features, targets = self.model(data)
self.optimizer.zero_grad()
self.pfc_optimizer.zero_grad()
# Partial-fc backward
f_grad, loss_v = self.pfc_module.forward_backward(features, targets, self.pfc_optimizer)
if self.amp_enabled:
features.backward(self.grad_scaler.scale(f_grad))
self.grad_scaler.unscale_(self.optimizer)
clip_grad_norm_(self.model.parameters(), max_norm=5, norm_type=2)
self.grad_scaler.step(self.optimizer)
self.grad_scaler.update()
else:
features.backward(f_grad)
clip_grad_norm_(self.model.parameters(), max_norm=5, norm_type=2)
self.optimizer.step()
loss_dict = {"loss_cls": loss_v}
self._write_metrics(loss_dict, data_time)
self.optimizer.step()
self.pfc_optimizer.step()
self.pfc_module.update()
self.optimizer.zero_grad()
self.pfc_optimizer.zero_grad()
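
The run_step above splices the gradient returned by partial-fc back into the backbone graph via `features.backward(f_grad)`. The following is a self-contained toy sketch of that gradient-splice pattern, with illustrative tensors only (the scaler and distributed pieces are omitted; this is not the project's code):

import torch

backbone = torch.nn.Linear(8, 4)           # stand-in for the real backbone
x = torch.randn(2, 8)
features = backbone(x)

# the classifier side works on a detached copy, as partial-fc does across GPUs
feat_detached = features.detach().requires_grad_(True)
logits = feat_detached @ torch.randn(4, 3)
logits.sum().backward()                    # gradient w.r.t. the detached features only
f_grad = feat_detached.grad

features.backward(f_grad)                  # splice the gradient back into the backbone graph
print(backbone.weight.grad.shape)          # torch.Size([4, 8])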

View File

@@ -0,0 +1,86 @@
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""
from typing import Dict, List
import torch
from torch._six import container_abcs
from torch.cuda.amp import GradScaler
class _MultiDeviceReplicator(object):
"""
Lazily serves copies of a tensor to requested devices. Copies are cached per-device.
"""
def __init__(self, master_tensor: torch.Tensor) -> None:
assert master_tensor.is_cuda
self.master = master_tensor
self._per_device_tensors: Dict[torch.device, torch.Tensor] = {}
def get(self, device) -> torch.Tensor:
retval = self._per_device_tensors.get(device, None)
if retval is None:
retval = self.master.to(device=device, non_blocking=True, copy=True)
self._per_device_tensors[device] = retval
return retval
class MaxClipGradScaler(GradScaler):
def __init__(self, init_scale, max_scale: float, growth_interval=100):
super().__init__(init_scale=init_scale, growth_interval=growth_interval)
self.max_scale = max_scale
def scale_clip(self):
if self.get_scale() == self.max_scale:
self.set_growth_factor(1)
elif self.get_scale() < self.max_scale:
self.set_growth_factor(2)
elif self.get_scale() > self.max_scale:
self._scale.fill_(self.max_scale)
self.set_growth_factor(1)
def scale(self, outputs):
"""
Multiplies ('scales') a tensor or list of tensors by the scale factor.
Returns scaled outputs. If this instance of :class:`GradScaler` is not enabled, outputs are returned
unmodified.
Arguments:
outputs (Tensor or iterable of Tensors): Outputs to scale.
"""
if not self._enabled:
return outputs
self.scale_clip()
# Short-circuit for the common case.
if isinstance(outputs, torch.Tensor):
assert outputs.is_cuda
if self._scale is None:
self._lazy_init_scale_growth_tracker(outputs.device)
assert self._scale is not None
return outputs * self._scale.to(device=outputs.device, non_blocking=True)
# Invoke the more complex machinery only if we're treating multiple outputs.
stash: List[_MultiDeviceReplicator] = [] # holds a reference that can be overwritten by apply_scale
def apply_scale(val):
if isinstance(val, torch.Tensor):
assert val.is_cuda
if len(stash) == 0:
if self._scale is None:
self._lazy_init_scale_growth_tracker(val.device)
assert self._scale is not None
stash.append(_MultiDeviceReplicator(self._scale))
return val * stash[0].get(val.device)
elif isinstance(val, container_abcs.Iterable):
iterable = map(apply_scale, val)
if isinstance(val, list) or isinstance(val, tuple):
return type(val)(iterable)
else:
return iterable
else:
raise ValueError("outputs must be a Tensor or an iterable of Tensors")
return apply_scale(outputs)

View File

@@ -14,7 +14,6 @@ from fastreid.engine import default_argument_parser, default_setup, launch
from fastreid.utils.checkpoint import Checkpointer
from fastface import *
from fastface.datasets import *
def setup(args):