mirror of https://github.com/JDAI-CV/fast-reid.git
refactor code for partial reid (#277)
Summary: adapt PartialReID to the new fastreid code style. Closes #277 (branch: pull/299/head).
parent: f2d2467ead
commit: a25d8a6bc1
projects/PartialReID
@@ -7,6 +7,7 @@ MODEL:
     NORM: "BN"
     LAST_STRIDE: 1
     WITH_IBN: True
+    PRETRAIN_PATH: "/export/home/lxy/.cache/torch/checkpoints/resnet50_ibn_a-d9d0bb7b.pth"
 
   HEADS:
     NAME: "DSRHead"
@@ -38,7 +39,7 @@ INPUT:
 
 DATALOADER:
   PK_SAMPLER: True
-  NAIVE_WAY: True
+  NAIVE_WAY: False
   NUM_INSTANCE: 4
   NUM_WORKERS: 8
 
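
Note: PK_SAMPLER composes each batch from P identities with NUM_INSTANCE (K) images each; NAIVE_WAY toggles between fastreid's two identity-sampling strategies (a naive random PK sampler and a more balanced variant), and this commit turns the naive way off. A rough sketch of the generic PK idea — an illustrative helper, not the library's sampler:

# Sketch: naive PK sampling — pick P identities, then K instances each.
import random
from collections import defaultdict

def pk_batch(labels, P=16, K=4):
    """labels: list of person ids, one per dataset index (hypothetical input)."""
    index_by_pid = defaultdict(list)
    for idx, pid in enumerate(labels):
        index_by_pid[pid].append(idx)
    pids = random.sample(list(index_by_pid), P)
    batch = []
    for pid in pids:
        # sample with replacement if an identity has fewer than K images
        batch.extend(random.choices(index_by_pid[pid], k=K))
    return batch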
@@ -1,7 +1,7 @@
 # encoding: utf-8
 """
 @author: xingyu liao
-@contact: liaoxingyu5@jd.com
+@contact: sherlockliao01@gmail.com
 """
 
 from .partial_dataset import *
@@ -10,10 +10,12 @@ from collections import OrderedDict
 import numpy as np
 import torch
 import torch.nn.functional as F
+from sklearn import metrics
 
 from fastreid.evaluation.evaluator import DatasetEvaluator
 from fastreid.evaluation.rank import evaluate_rank
 from fastreid.evaluation.roc import evaluate_roc
+from fastreid.utils import comm
 from .dsr_distance import compute_dsr_dist
 
 logger = logging.getLogger('fastreid.partialreid.dsr_evaluation')
@@ -39,37 +41,73 @@ class DsrEvaluator(DatasetEvaluator):
         self.camids = []
 
     def process(self, inputs, outputs):
-        self.pids.extend(inputs["targets"].numpy())
-        self.camids.extend(inputs["camid"].numpy())
+        self.pids.extend(inputs["targets"])
+        self.camids.extend(inputs["camids"])
         self.features.append(F.normalize(outputs[0]).cpu())
-        outputs1 = F.normalize(outputs[1].data).cpu().numpy()
+        outputs1 = F.normalize(outputs[1].data).cpu()
         self.spatial_features.append(outputs1)
         self.scores.append(outputs[2])
 
     def evaluate(self):
-        features = torch.cat(self.features, dim=0)
-        spatial_features = np.vstack(self.spatial_features)
-        scores = torch.cat(self.scores, dim=0)
+        if comm.get_world_size() > 1:
+            comm.synchronize()
+            features = comm.gather(self.features)
+            features = sum(features, [])
+
+            spatial_features = comm.gather(self.spatial_features)
+            spatial_features = sum(spatial_features, [])
+
+            scores = comm.gather(self.scores)
+            scores = sum(scores, [])
+
+            pids = comm.gather(self.pids)
+            pids = sum(pids, [])
+
+            camids = comm.gather(self.camids)
+            camids = sum(camids, [])
+
+            # fmt: off
+            if not comm.is_main_process(): return {}
+            # fmt: on
+        else:
+            features = self.features
+            spatial_features = self.spatial_features
+            scores = self.scores
+            pids = self.pids
+            camids = self.camids
+
+        features = torch.cat(features, dim=0)
+        spatial_features = torch.cat(spatial_features, dim=0).numpy()
+        scores = torch.cat(scores, dim=0)
 
         # query feature, person ids and camera ids
         query_features = features[:self._num_query]
-        query_pids = np.asarray(self.pids[:self._num_query])
-        query_camids = np.asarray(self.camids[:self._num_query])
+        query_pids = np.asarray(pids[:self._num_query])
+        query_camids = np.asarray(camids[:self._num_query])
 
         # gallery features, person ids and camera ids
         gallery_features = features[self._num_query:]
-        gallery_pids = np.asarray(self.pids[self._num_query:])
-        gallery_camids = np.asarray(self.camids[self._num_query:])
+        gallery_pids = np.asarray(pids[self._num_query:])
+        gallery_camids = np.asarray(camids[self._num_query:])
 
+        if self.cfg.TEST.METRIC == "cosine":
+            query_features = F.normalize(query_features, dim=1)
+            gallery_features = F.normalize(gallery_features, dim=1)
+
         dist = 1 - torch.mm(query_features, gallery_features.t()).numpy()
         self._results = OrderedDict()
 
+        query_features = query_features.numpy()
+        gallery_features = gallery_features.numpy()
         if self.cfg.TEST.DSR.ENABLED:
+            logger.info("Testing with DSR setting")
             dist = compute_dsr_dist(spatial_features[:self._num_query], spatial_features[self._num_query:], dist,
                                     scores[:self._num_query])
-            logger.info("Testing with DSR setting")
 
-        cmc, all_AP, all_INP = evaluate_rank(dist, query_pids, gallery_pids, query_camids, gallery_camids)
+            cmc, all_AP, all_INP = evaluate_rank(dist, query_features, gallery_features, query_pids, gallery_pids,
+                                                 query_camids, gallery_camids, use_distmat=True)
+        else:
+            cmc, all_AP, all_INP = evaluate_rank(dist, query_features, gallery_features, query_pids, gallery_pids,
+                                                 query_camids, gallery_camids, use_distmat=False)
         mAP = np.mean(all_AP)
         mINP = np.mean(all_INP)
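
Note: on multiple GPUs, comm.gather returns a list with one sub-list per process, so the code above flattens with sum(..., []) before concatenating tensors. A tiny self-contained illustration of that flattening idiom:

# Sketch: why sum(gathered, []) appears above. comm.gather yields one
# list per process; summing with [] as the start value concatenates them.
per_process = [[1, 2], [3], [4, 5]]   # e.g. pids gathered from 3 workers
flat = sum(per_process, [])           # [1, 2, 3, 4, 5]
assert flat == [1, 2, 3, 4, 5]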
@@ -78,9 +116,13 @@ class DsrEvaluator(DatasetEvaluator):
         self._results['mAP'] = mAP
         self._results['mINP'] = mINP
 
-        tprs = evaluate_roc(dist, query_pids, gallery_pids, query_camids, gallery_camids)
-        fprs = [1e-4, 1e-3, 1e-2]
-        for i in range(len(fprs)):
-            self._results["TPR@FPR={}".format(fprs[i])] = tprs[i]
+        if self.cfg.TEST.ROC_ENABLED:
+            scores, labels = evaluate_roc(dist, query_features, gallery_features,
+                                          query_pids, gallery_pids, query_camids, gallery_camids)
+            fprs, tprs, thres = metrics.roc_curve(labels, scores)
+
+            for fpr in [1e-4, 1e-3, 1e-2]:
+                ind = np.argmin(np.abs(fprs - fpr))
+                self._results["TPR@FPR={:.0e}".format(fpr)] = tprs[ind]
 
         return copy.deepcopy(self._results)
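
Note: the reworked ROC block asks sklearn for the full curve and then reads TPR at the operating point whose FPR is closest to each target. A self-contained example with synthetic scores (the data is made up, purely for illustration):

# Sketch: reading TPR at fixed FPR points off an ROC curve, as the new
# evaluator code does.
import numpy as np
from sklearn import metrics

labels = np.array([1, 1, 0, 1, 0, 0, 1, 0])
scores = np.array([0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2])
fprs, tprs, thres = metrics.roc_curve(labels, scores)
for fpr in [1e-4, 1e-3, 1e-2]:
    ind = np.argmin(np.abs(fprs - fpr))   # closest operating point
    print("TPR@FPR={:.0e}: {:.3f}".format(fpr, tprs[ind]))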
@@ -4,6 +4,10 @@
 @contact: helingxiao3@jd.com
 """
 
+import torch
+import torch.nn.functional as F
 from torch import nn
+
 from fastreid.layers import *
 from fastreid.modeling.heads.build import REID_HEADS_REGISTRY
+from fastreid.utils.weight_init import weights_init_classifier, weights_init_kaiming
@@ -48,34 +52,53 @@ class OcclusionUnit(nn.Module):
 
 @REID_HEADS_REGISTRY.register()
 class DSRHead(nn.Module):
-    def __init__(self, cfg, in_feat, num_classes, pool_layer=nn.AdaptiveAvgPool2d(1)):
+    def __init__(self, cfg):
         super().__init__()
 
-        self.pool_layer = pool_layer
+        # fmt: off
+        feat_dim    = cfg.MODEL.BACKBONE.FEAT_DIM
+        num_classes = cfg.MODEL.HEADS.NUM_CLASSES
+        neck_feat   = cfg.MODEL.HEADS.NECK_FEAT
+        pool_type   = cfg.MODEL.HEADS.POOL_LAYER
+        cls_type    = cfg.MODEL.HEADS.CLS_LAYER
+        norm_type   = cfg.MODEL.HEADS.NORM
 
-        self.occ_unit = OcclusionUnit(in_planes=in_feat)
+        if pool_type == 'fastavgpool':   self.pool_layer = FastGlobalAvgPool2d()
+        elif pool_type == 'avgpool':     self.pool_layer = nn.AdaptiveAvgPool2d(1)
+        elif pool_type == 'maxpool':     self.pool_layer = nn.AdaptiveMaxPool2d(1)
+        elif pool_type == 'gempoolP':    self.pool_layer = GeneralizedMeanPoolingP()
+        elif pool_type == 'gempool':     self.pool_layer = GeneralizedMeanPooling()
+        elif pool_type == "avgmaxpool":  self.pool_layer = AdaptiveAvgMaxPool2d()
+        elif pool_type == 'clipavgpool': self.pool_layer = ClipGlobalAvgPool2d()
+        elif pool_type == "identity":    self.pool_layer = nn.Identity()
+        elif pool_type == "flatten":     self.pool_layer = Flatten()
+        else:                            raise KeyError(f"{pool_type} is not supported!")
+        # fmt: on
+
+        self.neck_feat = neck_feat
+
+        self.occ_unit = OcclusionUnit(in_planes=feat_dim)
         self.MaxPool1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
         self.MaxPool2 = nn.MaxPool2d(kernel_size=4, stride=2, padding=0)
         self.MaxPool3 = nn.MaxPool2d(kernel_size=6, stride=2, padding=0)
         self.MaxPool4 = nn.MaxPool2d(kernel_size=8, stride=2, padding=0)
 
-        self.bnneck = get_norm(cfg.MODEL.HEADS.NORM, in_feat, cfg.MODEL.HEADS.NORM_SPLIT, bias_freeze=True)
+        self.bnneck = get_norm(norm_type, feat_dim, bias_freeze=True)
         self.bnneck.apply(weights_init_kaiming)
 
-        self.bnneck_occ = get_norm(cfg.MODEL.HEADS.NORM, in_feat, cfg.MODEL.HEADS.NORM_SPLIT, bias_freeze=True)
+        self.bnneck_occ = get_norm(norm_type, feat_dim, bias_freeze=True)
         self.bnneck_occ.apply(weights_init_kaiming)
 
         # identity classification layer
-        cls_type = cfg.MODEL.HEADS.CLS_LAYER
         if cls_type == 'linear':
-            self.classifier = nn.Linear(in_feat, num_classes, bias=False)
-            self.classifier_occ = nn.Linear(in_feat, num_classes, bias=False)
+            self.classifier = nn.Linear(feat_dim, num_classes, bias=False)
+            self.classifier_occ = nn.Linear(feat_dim, num_classes, bias=False)
         elif cls_type == 'arcSoftmax':
-            self.classifier = ArcSoftmax(cfg, in_feat, num_classes)
-            self.classifier_occ = ArcSoftmax(cfg, in_feat, num_classes)
+            self.classifier = ArcSoftmax(cfg, feat_dim, num_classes)
+            self.classifier_occ = ArcSoftmax(cfg, feat_dim, num_classes)
         elif cls_type == 'circleSoftmax':
-            self.classifier = CircleSoftmax(cfg, in_feat, num_classes)
-            self.classifier_occ = CircleSoftmax(cfg, in_feat, num_classes)
+            self.classifier = CircleSoftmax(cfg, feat_dim, num_classes)
+            self.classifier_occ = CircleSoftmax(cfg, feat_dim, num_classes)
         else:
             raise KeyError(f"{cls_type} is invalid, please choose from "
                            f"'linear', 'arcSoftmax' and 'circleSoftmax'.")
@@ -111,13 +134,20 @@ class DSRHead(nn.Module):
         bn_feat = self.bnneck(global_feat)
         bn_feat = bn_feat[..., 0, 0]
 
-        try:
+        if self.classifier.__class__.__name__ == 'Linear':
             cls_outputs = self.classifier(bn_feat)
             fore_cls_outputs = self.classifier_occ(bn_foreground_feat)
-        except TypeError:
+            pred_class_logits = F.linear(bn_feat, self.classifier.weight)
+        else:
             cls_outputs = self.classifier(bn_feat, targets)
             fore_cls_outputs = self.classifier_occ(bn_foreground_feat, targets)
+            pred_class_logits = self.classifier.s * F.linear(F.normalize(bn_feat),
+                                                             F.normalize(self.classifier.weight))
 
-        pred_class_logits = F.linear(bn_foreground_feat, self.classifier.weight)
-
-        return cls_outputs, fore_cls_outputs, pred_class_logits, global_feat[..., 0, 0], foreground_feat[..., 0, 0]
+        return {
+            "cls_outputs": cls_outputs,
+            "fore_cls_outputs": fore_cls_outputs,
+            "pred_class_logits": pred_class_logits,
+            "global_features": global_feat[..., 0, 0],
+            "foreground_features": foreground_feat[..., 0, 0],
+        }
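
Note: in this forward pass, a plain Linear classifier yields usable logits directly, while ArcSoftmax/CircleSoftmax take the targets and emit margin-adjusted logits; the else-branch therefore recomputes margin-free scaled cosine logits for accuracy logging. A sketch of that computation (s is the classifier's scale hyperparameter; illustrative helper):

# Sketch: margin-free "pred_class_logits" for a cosine classifier, as in
# the else-branch above.
import torch.nn.functional as F

def cosine_logits(bn_feat, weight, s=64.0):
    # cosine similarity between L2-normalized features and class weights,
    # scaled by s; no margin applied, so it is safe for accuracy logging
    return s * F.linear(F.normalize(bn_feat), F.normalize(weight))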
@@ -12,20 +12,58 @@ from fastreid.modeling.meta_arch.build import META_ARCH_REGISTRY
 @META_ARCH_REGISTRY.register()
 class PartialBaseline(Baseline):
 
-    def losses(self, outputs, gt_labels):
-        cls_outputs, fore_cls_outputs, pred_class_logits, global_feat, fore_feat = outputs
+    def losses(self, outs):
+        r"""
+        Compute loss from modeling's outputs, the loss function input arguments
+        must be the same as the outputs of the model forwarding.
+        """
+        # fmt: off
+        outputs           = outs["outputs"]
+        gt_labels         = outs["targets"]
+        # model predictions
+        pred_class_logits = outputs['pred_class_logits'].detach()
+        cls_outputs       = outputs["cls_outputs"]
+        fore_cls_outputs  = outputs["fore_cls_outputs"]
+        global_feat       = outputs["global_features"]
+        fore_feat         = outputs["foreground_features"]
+        # fmt: on
+
+        # Log prediction accuracy
+        log_accuracy(pred_class_logits, gt_labels)
+
         loss_dict = {}
         loss_names = self._cfg.MODEL.LOSSES.NAME
 
-        # Log prediction accuracy
-        CrossEntropyLoss.log_accuracy(pred_class_logits, gt_labels)
-
         if "CrossEntropyLoss" in loss_names:
-            loss_dict['loss_avg_branch_cls'] = CrossEntropyLoss(self._cfg)(cls_outputs, gt_labels)
-            loss_dict['loss_fore_branch_cls'] = CrossEntropyLoss(self._cfg)(fore_cls_outputs, gt_labels)
+            loss_dict['loss_avg_branch_cls'] = cross_entropy_loss(
+                cls_outputs,
+                gt_labels,
+                self._cfg.MODEL.LOSSES.CE.EPSILON,
+                self._cfg.MODEL.LOSSES.CE.ALPHA,
+            ) * self._cfg.MODEL.LOSSES.CE.SCALE
+
+            loss_dict['loss_fore_branch_cls'] = cross_entropy_loss(
+                fore_cls_outputs,
+                gt_labels,
+                self._cfg.MODEL.LOSSES.CE.EPSILON,
+                self._cfg.MODEL.LOSSES.CE.ALPHA,
+            ) * self._cfg.MODEL.LOSSES.CE.SCALE
 
         if "TripletLoss" in loss_names:
-            loss_dict['loss_avg_branch_triplet'] = TripletLoss(self._cfg)(global_feat, gt_labels)
-            loss_dict['loss_fore_branch_triplet'] = TripletLoss(self._cfg)(fore_feat, gt_labels)
+            loss_dict['loss_avg_branch_triplet'] = triplet_loss(
+                global_feat,
+                gt_labels,
+                self._cfg.MODEL.LOSSES.TRI.MARGIN,
+                self._cfg.MODEL.LOSSES.TRI.NORM_FEAT,
+                self._cfg.MODEL.LOSSES.TRI.HARD_MINING,
+            ) * self._cfg.MODEL.LOSSES.TRI.SCALE
+
+            loss_dict['loss_fore_branch_triplet'] = triplet_loss(
+                fore_feat,
+                gt_labels,
+                self._cfg.MODEL.LOSSES.TRI.MARGIN,
+                self._cfg.MODEL.LOSSES.TRI.NORM_FEAT,
+                self._cfg.MODEL.LOSSES.TRI.HARD_MINING,
+            ) * self._cfg.MODEL.LOSSES.TRI.SCALE
         return loss_dict
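
Note: the losses method drops the old loss classes for plain functions whose hyperparameters come straight from the config. As one example, a label-smoothing cross entropy matching the call signature used above — an illustrative sketch, not necessarily fastreid's exact cross_entropy_loss (the ALPHA argument is ignored here for brevity):

# Sketch: label-smoothing cross entropy (eps corresponds to CE.EPSILON).
import torch
import torch.nn.functional as F

def cross_entropy_loss(logits, targets, eps=0.1, alpha=0.2):
    # alpha (CE.ALPHA) is unused in this simplified sketch
    num_classes = logits.size(1)
    log_probs = F.log_softmax(logits, dim=1)
    # smooth the one-hot targets: 1-eps on the true class, eps spread out
    smooth = torch.full_like(log_probs, eps / (num_classes - 1))
    smooth.scatter_(1, targets.unsqueeze(1), 1.0 - eps)
    return (-smooth * log_probs).sum(dim=1).mean()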
@@ -21,10 +21,9 @@ from partialreid import *
 
 class Trainer(DefaultTrainer):
     @classmethod
-    def build_evaluator(cls, cfg, num_query, output_folder=None):
-        if output_folder is None:
-            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
-        return DsrEvaluator(cfg, num_query)
+    def build_evaluator(cls, cfg, dataset_name, output_dir=None):
+        data_loader, num_query = cls.build_test_loader(cfg, dataset_name)
+        return data_loader, DsrEvaluator(cfg, num_query, output_dir)
 
 
 def setup(args):
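
Note: build_evaluator now receives a dataset name and returns the test loader together with the evaluator, instead of just the evaluator. A hedged sketch of how a trainer might consume the new contract (the driver loop and the model variable are assumptions, not fastreid code):

# Sketch: consuming the (data_loader, evaluator) pair per test dataset.
def run_eval(cls, cfg, model):
    results = {}
    for dataset_name in cfg.DATASETS.TESTS:
        data_loader, evaluator = cls.build_evaluator(cfg, dataset_name)
        for inputs in data_loader:
            outputs = model(inputs)            # forward pass, eval mode assumed
            evaluator.process(inputs, outputs)
        results[dataset_name] = evaluator.evaluate()
    return results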