# File: fast-reid/projects/FastFace/fastface/modeling/face_baseline.py

# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""
import torch
from fastreid.modeling.meta_arch import Baseline
from fastreid.modeling.meta_arch import META_ARCH_REGISTRY
@META_ARCH_REGISTRY.register()
class FaceBaseline(Baseline):
    """``Baseline`` variant used by FastFace.

    If partial-FC is disabled (``cfg.MODEL.HEADS.PFC.ENABLED`` is falsy), the
    forward pass is delegated unchanged to ``Baseline.forward``. If it is
    enabled, this class runs the backbone (under CUDA autocast when
    ``cfg.SOLVER.AMP.ENABLED`` is set) and then the heads itself. In training
    mode it returns ``(head_outputs, targets)`` rather than anything
    loss-related. Presumably the loss is computed outside the model in that
    case; TODO confirm against the PFC training loop.
    """

    def __init__(self, cfg):
        super().__init__(cfg)
        # Switches consulted by ``forward`` to pick the code path / precision.
        self.pfc_enabled = cfg.MODEL.HEADS.PFC.ENABLED
        self.amp_enabled = cfg.SOLVER.AMP.ENABLED

    def forward(self, batched_inputs):
        """Compute head outputs, using a custom path when partial-FC is on.

        Args:
            batched_inputs: dict-like batch handed to ``preprocess_image``.
                In training mode it must also hold a ``"targets"`` entry
                (tensor-like: it is summed and may be zeroed in place).

        Returns:
            With PFC disabled, whatever ``Baseline.forward`` returns.
            Otherwise, in eval mode, the heads' output for the backbone
            features. In training mode, the tuple ``(head_outputs, targets)``.

        Raises:
            AssertionError: in training mode when ``"targets"`` is absent
                (this is an ``assert``, so it is skipped under ``python -O``).
        """
        # Guard: without partial-FC, the parent implementation handles everything.
        if not self.pfc_enabled:
            return super().forward(batched_inputs)

        imgs = self.preprocess_image(batched_inputs)
        with torch.cuda.amp.autocast(self.amp_enabled):
            feats = self.backbone(imgs)
        if self.amp_enabled:
            # Autocast may yield reduced-precision activations; hand the heads fp32.
            feats = feats.float()

        # Inference: no labels are needed, just run the heads.
        if not self.training:
            return self.heads(feats)

        assert "targets" in batched_inputs, "Person ID annotation are missing in training!"
        labels = batched_inputs["targets"]

        # A negative label sum flags a PreciseBN pass over a different dataset.
        # That dataset may contain more classes than the heads were built for,
        # which would make circle/arcface fail. So every label is zeroed in
        # place; this also mutates batched_inputs["targets"].
        if labels.sum() < 0:
            labels.zero_()

        return self.heads(feats, labels), labels