mirror of
https://github.com/JDAI-CV/fast-reid.git
synced 2025-06-03 14:50:47 +08:00
update partial code
This commit is contained in:
parent
e0fc15269c
commit
1cf580b6b0
@ -4,39 +4,44 @@ MODEL:
|
||||
BACKBONE:
|
||||
NAME: "build_resnet_backbone"
|
||||
DEPTH: 50
|
||||
NORM: "BN"
|
||||
LAST_STRIDE: 1
|
||||
WITH_IBN: True
|
||||
PRETRAIN_PATH: "/export/home/lxy/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar"
|
||||
|
||||
HEADS:
|
||||
NAME: "DSRHead"
|
||||
NORM: "BN"
|
||||
POOL_LAYER: "avgpool"
|
||||
NECK_FEAT: "before"
|
||||
CLS_LAYER: "linear"
|
||||
|
||||
LOSSES:
|
||||
NAME: ("CrossEntropyLoss", "TripletLoss")
|
||||
CE:
|
||||
EPSILON: 0.1
|
||||
SCALE: 1.0
|
||||
SCALE: 1.
|
||||
TRI:
|
||||
MARGIN: 0.3
|
||||
HARD_MINING: False
|
||||
SCALE: 1.
|
||||
|
||||
|
||||
DATASETS:
|
||||
NAMES: ("Market1501",)
|
||||
TESTS: ("PartialREID","PartialiLIDS","OccludedREID",)
|
||||
TESTS: ("PartialREID", "PartialiLIDS","OccludedREID",)
|
||||
|
||||
INPUT:
|
||||
SIZE_TRAIN: [384, 128]
|
||||
SIZE_TEST: [384, 128]
|
||||
REA:
|
||||
ENABLED: False
|
||||
PROB: 0.5
|
||||
MEAN: [123.675, 116.28, 103.53]
|
||||
DO_PAD: False
|
||||
|
||||
DATALOADER:
|
||||
PK_SAMPLER: True
|
||||
NAIVE_WAY: True
|
||||
NUM_INSTANCE: 4
|
||||
NUM_WORKERS: 16
|
||||
NUM_WORKERS: 8
|
||||
|
||||
SOLVER:
|
||||
OPT: "Adam"
|
||||
@ -47,18 +52,19 @@ SOLVER:
|
||||
WEIGHT_DECAY_BIAS: 0.0
|
||||
IMS_PER_BATCH: 64
|
||||
|
||||
SCHED: "WarmupMultiStepLR"
|
||||
STEPS: [15, 25]
|
||||
GAMMA: 0.1
|
||||
|
||||
WARMUP_FACTOR: 0.01
|
||||
WARMUP_ITERS: 3
|
||||
WARMUP_ITERS: 5
|
||||
|
||||
CHECKPOINT_PERIOD: 10
|
||||
|
||||
TEST:
|
||||
EVAL_PERIOD: 5
|
||||
IMS_PER_BATCH: 512
|
||||
IMS_PER_BATCH: 128
|
||||
|
||||
CUDNN_BENCHMARK: True
|
||||
|
||||
OUTPUT_DIR: "logs/test_partial"
|
||||
OUTPUT_DIR: "projects/PartialReID/logs/test_partial"
|
||||
|
@ -12,5 +12,4 @@ def add_partialreid_config(cfg):
|
||||
|
||||
_C.TEST.DSR = CN()
|
||||
_C.TEST.DSR.ENABLED = True
|
||||
_C.TEST.DSR.TOPK = 30
|
||||
|
||||
|
@ -12,33 +12,42 @@ def normalize(nparray, order=2, axis=0):
|
||||
return nparray / (norm + np.finfo(np.float32).eps)
|
||||
|
||||
|
||||
def compute_dsr_dist(array1, array2, distmat, scores, topk=30):
|
||||
def compute_dsr_dist(array1, array2, distmat, scores):
|
||||
""" Compute the sptial feature reconstruction of all pairs
|
||||
array: [M, N, C] M: the number of query, N: the number of spatial feature, C: the dimension of each spatial feature
|
||||
array2: [M, N, C] M: the number of gallery
|
||||
:return:
|
||||
numpy array with shape [m1, m2]
|
||||
"""
|
||||
|
||||
dist = 100 * torch.ones(len(array1), len(array2))
|
||||
dist = dist.cuda()
|
||||
kappa = 0.001
|
||||
index = np.argsort(distmat, axis=1)
|
||||
|
||||
T = kappa * torch.eye(110)
|
||||
T = T.cuda()
|
||||
M = []
|
||||
for i in range(0, len(array2)):
|
||||
g = array2[i]
|
||||
g = torch.FloatTensor(g)
|
||||
g = g.view(g.size(0), g.size(1))
|
||||
g = g.cuda()
|
||||
Proj_M1 = torch.matmul(torch.inverse(torch.matmul(g.t(), g) + T), g.t())
|
||||
Proj_M1 = Proj_M1.cpu().numpy()
|
||||
M.append(Proj_M1)
|
||||
for i in range(0, len(array1)):
|
||||
q = torch.FloatTensor(array1[i])
|
||||
q = q.view(q.size(0), q.size(1))
|
||||
q = q.cuda()
|
||||
score = scores[i]
|
||||
for j in range(topk):
|
||||
for j in range(0, 100):
|
||||
g = array2[index[i, j]]
|
||||
g = torch.FloatTensor(g)
|
||||
g = g.view(g.size(0), g.size(1))
|
||||
g = g.cuda()
|
||||
sim = torch.matmul(q.t(), g)
|
||||
min_value, min_index = (1 - sim).min(1)
|
||||
dist[i, index[i, j]] = (min_value * score).sum()
|
||||
Proj_M = torch.FloatTensor(M[index[i, j]])
|
||||
Proj_M = Proj_M.cuda()
|
||||
a = torch.matmul(g, torch.matmul(Proj_M, q)) - q
|
||||
dist[i, index[i, j]] = ((torch.pow(a, 2).sum(0).sqrt()) * scores[i]).sum()
|
||||
dist = dist.cpu()
|
||||
dist = dist.numpy()
|
||||
dist = 0.98 * dist + 0.02 * distmat
|
||||
|
||||
return dist
|
||||
|
@ -38,13 +38,13 @@ class DsrEvaluator(DatasetEvaluator):
|
||||
self.pids = []
|
||||
self.camids = []
|
||||
|
||||
def process(self, outputs):
|
||||
def process(self, inputs, outputs):
|
||||
self.pids.extend(inputs["targets"].numpy())
|
||||
self.camids.extend(inputs["camid"].numpy())
|
||||
self.features.append(F.normalize(outputs[0]).cpu())
|
||||
outputs1 = F.normalize(outputs[1].data).cpu().numpy()
|
||||
self.spatial_features.append(outputs1)
|
||||
self.scores.append(outputs[2])
|
||||
self.pids.extend(inputs["targets"].numpy())
|
||||
self.camids.extend(inputs["camid"].numpy())
|
||||
|
||||
def evaluate(self):
|
||||
features = torch.cat(self.features, dim=0)
|
||||
@ -62,20 +62,25 @@ class DsrEvaluator(DatasetEvaluator):
|
||||
gallery_camids = np.asarray(self.camids[self._num_query:])
|
||||
|
||||
dist = 1 - torch.mm(query_features, gallery_features.t()).numpy()
|
||||
logger.info("Testing without DSR setting")
|
||||
self._results = OrderedDict()
|
||||
|
||||
if self.cfg.TEST.DSR.ENABLED:
|
||||
topk = self.cfg.TEST.DSR.TOPK
|
||||
dist = compute_dsr_dist(spatial_features[:self._num_query], spatial_features[self._num_query:], dist,
|
||||
scores[:self._num_query], topk)
|
||||
scores[:self._num_query])
|
||||
logger.info("Testing with DSR setting")
|
||||
|
||||
cmc, all_AP, all_INP = evaluate_rank(dist, query_pids, gallery_pids, query_camids, gallery_camids)
|
||||
mAP = np.mean(all_AP)
|
||||
mINP = np.mean(all_INP)
|
||||
|
||||
self._results['R-1'] = cmc[0]
|
||||
for r in [1, 5, 10]:
|
||||
self._results['Rank-{}'.format(r)] = cmc[r - 1]
|
||||
self._results['mAP'] = mAP
|
||||
self._results['mINP'] = mINP
|
||||
|
||||
tprs = evaluate_roc(dist, query_pids, gallery_pids, query_camids, gallery_camids)
|
||||
fprs = [1e-4, 1e-3, 1e-2]
|
||||
for i in range(len(fprs)):
|
||||
self._results["TPR@FPR={}".format(fprs[i])] = tprs[i]
|
||||
|
||||
return copy.deepcopy(self._results)
|
||||
|
@ -69,19 +69,17 @@ class DSRHead(nn.Module):
|
||||
if cfg.MODEL.HEADS.CLS_LAYER == 'linear':
|
||||
self.classifier = nn.Linear(in_feat, num_classes, bias=False)
|
||||
self.classifier_occ = nn.Linear(in_feat, num_classes, bias=False)
|
||||
self.classifier.apply(weights_init_classifier)
|
||||
self.classifier_occ.apply(weights_init_classifier)
|
||||
elif cfg.MODEL.HEADS.CLS_LAYER == 'arcface':
|
||||
self.classifier = Arcface(cfg, in_feat)
|
||||
self.classifier_occ = Arcface(cfg, in_feat)
|
||||
self.classifier = Arcface(cfg, in_feat, num_classes)
|
||||
self.classifier_occ = Arcface(cfg, in_feat, num_classes)
|
||||
elif cfg.MODEL.HEADS.CLS_LAYER == 'circle':
|
||||
self.classifier = Circle(cfg, in_feat)
|
||||
self.classifier_occ = Circle(cfg, in_feat)
|
||||
self.classifier = Circle(cfg, in_feat, num_classes)
|
||||
self.classifier_occ = Circle(cfg, in_feat, num_classes)
|
||||
else:
|
||||
self.classifier = nn.Linear(in_feat, num_classes, bias=False)
|
||||
self.classifier_occ = nn.Linear(in_feat, num_classes, bias=False)
|
||||
self.classifier.apply(weights_init_classifier)
|
||||
self.classifier_occ.apply(weights_init_classifier)
|
||||
self.classifier.apply(weights_init_classifier)
|
||||
self.classifier_occ.apply(weights_init_classifier)
|
||||
|
||||
def forward(self, features, targets=None):
|
||||
"""
|
||||
@ -112,9 +110,12 @@ class DSRHead(nn.Module):
|
||||
bn_feat = bn_feat[..., 0, 0]
|
||||
|
||||
try:
|
||||
pred_class_logits = self.classifier(bn_feat)
|
||||
fore_pred_class_legits = self.classifier_occ(bn_foreground_feat)
|
||||
cls_outputs = self.classifier(bn_feat)
|
||||
fore_cls_outputs = self.classifier_occ(bn_foreground_feat)
|
||||
except TypeError:
|
||||
pred_class_logits = self.classifier(bn_feat, targets)
|
||||
fore_pred_class_legits = self.classifier_occ(bn_foreground_feat, targets)
|
||||
return pred_class_logits, global_feat[..., 0, 0], fore_pred_class_legits, foreground_feat[..., 0, 0]
|
||||
cls_outputs = self.classifier(bn_feat, targets)
|
||||
fore_cls_outputs = self.classifier_occ(bn_foreground_feat, targets)
|
||||
|
||||
pred_class_logits = F.linear(bn_foreground_feat, self.classifier.weight)
|
||||
|
||||
return cls_outputs, fore_cls_outputs, pred_class_logits, global_feat[..., 0, 0], foreground_feat[..., 0, 0]
|
||||
|
@ -33,12 +33,14 @@ def process_test(query_path, gallery_path):
|
||||
|
||||
@DATASET_REGISTRY.register()
|
||||
class PartialREID(ImageDataset):
|
||||
|
||||
dataset_name = "partialreid"
|
||||
|
||||
def __init__(self, root='datasets',):
|
||||
self.root = root
|
||||
|
||||
self.query_dir = osp.join(self.root, 'PartialREID/query')
|
||||
self.gallery_dir = osp.join(self.root, 'PartialREID/gallery')
|
||||
self.query_dir = osp.join(self.root, 'Partial_REID/partial_body_images')
|
||||
self.gallery_dir = osp.join(self.root, 'Partial_REID/whole_body_images')
|
||||
query, gallery = process_test(self.query_dir, self.gallery_dir)
|
||||
|
||||
ImageDataset.__init__(self, [], query, gallery)
|
||||
@ -47,6 +49,7 @@ class PartialREID(ImageDataset):
|
||||
@DATASET_REGISTRY.register()
|
||||
class PartialiLIDS(ImageDataset):
|
||||
dataset_name = "partialilids"
|
||||
|
||||
def __init__(self, root='datasets',):
|
||||
self.root = root
|
||||
|
||||
@ -60,6 +63,7 @@ class PartialiLIDS(ImageDataset):
|
||||
@DATASET_REGISTRY.register()
|
||||
class OccludedREID(ImageDataset):
|
||||
dataset_name = "occludereid"
|
||||
|
||||
def __init__(self, root='datasets',):
|
||||
self.root = root
|
||||
|
||||
|
@ -12,28 +12,14 @@ from fastreid.modeling.meta_arch.build import META_ARCH_REGISTRY
|
||||
@META_ARCH_REGISTRY.register()
|
||||
class PartialBaseline(Baseline):
|
||||
|
||||
def forward(self, batched_inputs):
|
||||
images = self.preprocess_image(batched_inputs)
|
||||
features = self.backbone(images)
|
||||
|
||||
if self.training:
|
||||
assert "targets" in batched_inputs, "person ID annotation are missing in training!"
|
||||
targets = batched_inputs["targets"].long().to(self.device)
|
||||
|
||||
if targets.sum() < 0: targets.zero_()
|
||||
|
||||
cls_outputs, global_feat, fore_cls_outputs, fore_feat = self.heads(features, targets)
|
||||
return cls_outputs, global_feat, fore_cls_outputs, fore_feat, targets
|
||||
else:
|
||||
pred_features = self.heads(features)
|
||||
return pred_features
|
||||
|
||||
def losses(self, outputs):
|
||||
cls_outputs, global_feat, fore_cls_outputs, fore_feat, gt_labels = outputs
|
||||
def losses(self, outputs, gt_labels):
|
||||
cls_outputs, fore_cls_outputs, pred_class_logits, global_feat, fore_feat = outputs
|
||||
loss_dict = {}
|
||||
|
||||
loss_names = self._cfg.MODEL.LOSSES.NAME
|
||||
|
||||
# Log prediction accuracy
|
||||
CrossEntropyLoss.log_accuracy(pred_class_logits, gt_labels)
|
||||
|
||||
if "CrossEntropyLoss" in loss_names:
|
||||
loss_dict['loss_avg_branch_cls'] = CrossEntropyLoss(self._cfg)(cls_outputs, gt_labels)
|
||||
loss_dict['loss_fore_branch_cls'] = CrossEntropyLoss(self._cfg)(fore_cls_outputs, gt_labels)
|
||||
|
@ -43,13 +43,13 @@ def setup(args):
|
||||
def main(args):
|
||||
cfg = setup(args)
|
||||
|
||||
logger = logging.getLogger('fastreid.' + __name__)
|
||||
if args.eval_only:
|
||||
logger = logging.getLogger("fastreid.trainer")
|
||||
cfg.defrost()
|
||||
cfg.MODEL.BACKBONE.PRETRAIN = False
|
||||
model = Trainer.build_model(cfg)
|
||||
|
||||
Checkpointer(model, save_dir=cfg.OUTPUT_DIR).load(cfg.MODEL.WEIGHTS) # load trained model
|
||||
Checkpointer(model).load(cfg.MODEL.WEIGHTS) # load trained model
|
||||
|
||||
if cfg.TEST.PRECISE_BN.ENABLED and hooks.get_bn_modules(model):
|
||||
prebn_cfg = cfg.clone()
|
||||
|
Loading…
x
Reference in New Issue
Block a user