update partial code

This commit is contained in:
liaoxingyu 2020-07-15 15:08:53 +08:00
parent e0fc15269c
commit 1cf580b6b0
8 changed files with 72 additions and 62 deletions

View File

@ -4,20 +4,26 @@ MODEL:
BACKBONE:
NAME: "build_resnet_backbone"
DEPTH: 50
NORM: "BN"
LAST_STRIDE: 1
WITH_IBN: True
PRETRAIN_PATH: "/export/home/lxy/.cache/torch/checkpoints/resnet50_ibn_a.pth.tar"
HEADS:
NAME: "DSRHead"
NORM: "BN"
POOL_LAYER: "avgpool"
NECK_FEAT: "before"
CLS_LAYER: "linear"
LOSSES:
NAME: ("CrossEntropyLoss", "TripletLoss")
CE:
EPSILON: 0.1
SCALE: 1.0
SCALE: 1.
TRI:
MARGIN: 0.3
HARD_MINING: False
SCALE: 1.
DATASETS:
@ -29,14 +35,13 @@ INPUT:
SIZE_TEST: [384, 128]
REA:
ENABLED: False
PROB: 0.5
MEAN: [123.675, 116.28, 103.53]
DO_PAD: False
DATALOADER:
PK_SAMPLER: True
NAIVE_WAY: True
NUM_INSTANCE: 4
NUM_WORKERS: 16
NUM_WORKERS: 8
SOLVER:
OPT: "Adam"
@ -47,18 +52,19 @@ SOLVER:
WEIGHT_DECAY_BIAS: 0.0
IMS_PER_BATCH: 64
SCHED: "WarmupMultiStepLR"
STEPS: [15, 25]
GAMMA: 0.1
WARMUP_FACTOR: 0.01
WARMUP_ITERS: 3
WARMUP_ITERS: 5
CHECKPOINT_PERIOD: 10
TEST:
EVAL_PERIOD: 5
IMS_PER_BATCH: 512
IMS_PER_BATCH: 128
CUDNN_BENCHMARK: True
OUTPUT_DIR: "logs/test_partial"
OUTPUT_DIR: "projects/PartialReID/logs/test_partial"

View File

@ -12,5 +12,4 @@ def add_partialreid_config(cfg):
_C.TEST.DSR = CN()
_C.TEST.DSR.ENABLED = True
_C.TEST.DSR.TOPK = 30

View File

@ -12,33 +12,42 @@ def normalize(nparray, order=2, axis=0):
return nparray / (norm + np.finfo(np.float32).eps)
def compute_dsr_dist(array1, array2, distmat, scores, topk=30):
def compute_dsr_dist(array1, array2, distmat, scores):
""" Compute the sptial feature reconstruction of all pairs
array: [M, N, C] M: the number of query, N: the number of spatial feature, C: the dimension of each spatial feature
array2: [M, N, C] M: the number of gallery
:return:
numpy array with shape [m1, m2]
"""
dist = 100 * torch.ones(len(array1), len(array2))
dist = dist.cuda()
kappa = 0.001
index = np.argsort(distmat, axis=1)
T = kappa * torch.eye(110)
T = T.cuda()
M = []
for i in range(0, len(array2)):
g = array2[i]
g = torch.FloatTensor(g)
g = g.view(g.size(0), g.size(1))
g = g.cuda()
Proj_M1 = torch.matmul(torch.inverse(torch.matmul(g.t(), g) + T), g.t())
Proj_M1 = Proj_M1.cpu().numpy()
M.append(Proj_M1)
for i in range(0, len(array1)):
q = torch.FloatTensor(array1[i])
q = q.view(q.size(0), q.size(1))
q = q.cuda()
score = scores[i]
for j in range(topk):
for j in range(0, 100):
g = array2[index[i, j]]
g = torch.FloatTensor(g)
g = g.view(g.size(0), g.size(1))
g = g.cuda()
sim = torch.matmul(q.t(), g)
min_value, min_index = (1 - sim).min(1)
dist[i, index[i, j]] = (min_value * score).sum()
Proj_M = torch.FloatTensor(M[index[i, j]])
Proj_M = Proj_M.cuda()
a = torch.matmul(g, torch.matmul(Proj_M, q)) - q
dist[i, index[i, j]] = ((torch.pow(a, 2).sum(0).sqrt()) * scores[i]).sum()
dist = dist.cpu()
dist = dist.numpy()
dist = 0.98 * dist + 0.02 * distmat
return dist

View File

@ -38,13 +38,13 @@ class DsrEvaluator(DatasetEvaluator):
self.pids = []
self.camids = []
def process(self, outputs):
def process(self, inputs, outputs):
self.pids.extend(inputs["targets"].numpy())
self.camids.extend(inputs["camid"].numpy())
self.features.append(F.normalize(outputs[0]).cpu())
outputs1 = F.normalize(outputs[1].data).cpu().numpy()
self.spatial_features.append(outputs1)
self.scores.append(outputs[2])
self.pids.extend(inputs["targets"].numpy())
self.camids.extend(inputs["camid"].numpy())
def evaluate(self):
features = torch.cat(self.features, dim=0)
@ -62,20 +62,25 @@ class DsrEvaluator(DatasetEvaluator):
gallery_camids = np.asarray(self.camids[self._num_query:])
dist = 1 - torch.mm(query_features, gallery_features.t()).numpy()
logger.info("Testing without DSR setting")
self._results = OrderedDict()
if self.cfg.TEST.DSR.ENABLED:
topk = self.cfg.TEST.DSR.TOPK
dist = compute_dsr_dist(spatial_features[:self._num_query], spatial_features[self._num_query:], dist,
scores[:self._num_query], topk)
scores[:self._num_query])
logger.info("Testing with DSR setting")
cmc, all_AP, all_INP = evaluate_rank(dist, query_pids, gallery_pids, query_camids, gallery_camids)
mAP = np.mean(all_AP)
mINP = np.mean(all_INP)
self._results['R-1'] = cmc[0]
for r in [1, 5, 10]:
self._results['Rank-{}'.format(r)] = cmc[r - 1]
self._results['mAP'] = mAP
self._results['mINP'] = mINP
tprs = evaluate_roc(dist, query_pids, gallery_pids, query_camids, gallery_camids)
fprs = [1e-4, 1e-3, 1e-2]
for i in range(len(fprs)):
self._results["TPR@FPR={}".format(fprs[i])] = tprs[i]
return copy.deepcopy(self._results)

View File

@ -69,14 +69,12 @@ class DSRHead(nn.Module):
if cfg.MODEL.HEADS.CLS_LAYER == 'linear':
self.classifier = nn.Linear(in_feat, num_classes, bias=False)
self.classifier_occ = nn.Linear(in_feat, num_classes, bias=False)
self.classifier.apply(weights_init_classifier)
self.classifier_occ.apply(weights_init_classifier)
elif cfg.MODEL.HEADS.CLS_LAYER == 'arcface':
self.classifier = Arcface(cfg, in_feat)
self.classifier_occ = Arcface(cfg, in_feat)
self.classifier = Arcface(cfg, in_feat, num_classes)
self.classifier_occ = Arcface(cfg, in_feat, num_classes)
elif cfg.MODEL.HEADS.CLS_LAYER == 'circle':
self.classifier = Circle(cfg, in_feat)
self.classifier_occ = Circle(cfg, in_feat)
self.classifier = Circle(cfg, in_feat, num_classes)
self.classifier_occ = Circle(cfg, in_feat, num_classes)
else:
self.classifier = nn.Linear(in_feat, num_classes, bias=False)
self.classifier_occ = nn.Linear(in_feat, num_classes, bias=False)
@ -112,9 +110,12 @@ class DSRHead(nn.Module):
bn_feat = bn_feat[..., 0, 0]
try:
pred_class_logits = self.classifier(bn_feat)
fore_pred_class_legits = self.classifier_occ(bn_foreground_feat)
cls_outputs = self.classifier(bn_feat)
fore_cls_outputs = self.classifier_occ(bn_foreground_feat)
except TypeError:
pred_class_logits = self.classifier(bn_feat, targets)
fore_pred_class_legits = self.classifier_occ(bn_foreground_feat, targets)
return pred_class_logits, global_feat[..., 0, 0], fore_pred_class_legits, foreground_feat[..., 0, 0]
cls_outputs = self.classifier(bn_feat, targets)
fore_cls_outputs = self.classifier_occ(bn_foreground_feat, targets)
pred_class_logits = F.linear(bn_foreground_feat, self.classifier.weight)
return cls_outputs, fore_cls_outputs, pred_class_logits, global_feat[..., 0, 0], foreground_feat[..., 0, 0]

View File

@ -33,12 +33,14 @@ def process_test(query_path, gallery_path):
@DATASET_REGISTRY.register()
class PartialREID(ImageDataset):
dataset_name = "partialreid"
def __init__(self, root='datasets',):
self.root = root
self.query_dir = osp.join(self.root, 'PartialREID/query')
self.gallery_dir = osp.join(self.root, 'PartialREID/gallery')
self.query_dir = osp.join(self.root, 'Partial_REID/partial_body_images')
self.gallery_dir = osp.join(self.root, 'Partial_REID/whole_body_images')
query, gallery = process_test(self.query_dir, self.gallery_dir)
ImageDataset.__init__(self, [], query, gallery)
@ -47,6 +49,7 @@ class PartialREID(ImageDataset):
@DATASET_REGISTRY.register()
class PartialiLIDS(ImageDataset):
dataset_name = "partialilids"
def __init__(self, root='datasets',):
self.root = root
@ -60,6 +63,7 @@ class PartialiLIDS(ImageDataset):
@DATASET_REGISTRY.register()
class OccludedREID(ImageDataset):
dataset_name = "occludereid"
def __init__(self, root='datasets',):
self.root = root

View File

@ -12,28 +12,14 @@ from fastreid.modeling.meta_arch.build import META_ARCH_REGISTRY
@META_ARCH_REGISTRY.register()
class PartialBaseline(Baseline):
def forward(self, batched_inputs):
images = self.preprocess_image(batched_inputs)
features = self.backbone(images)
if self.training:
assert "targets" in batched_inputs, "person ID annotation are missing in training!"
targets = batched_inputs["targets"].long().to(self.device)
if targets.sum() < 0: targets.zero_()
cls_outputs, global_feat, fore_cls_outputs, fore_feat = self.heads(features, targets)
return cls_outputs, global_feat, fore_cls_outputs, fore_feat, targets
else:
pred_features = self.heads(features)
return pred_features
def losses(self, outputs):
cls_outputs, global_feat, fore_cls_outputs, fore_feat, gt_labels = outputs
def losses(self, outputs, gt_labels):
cls_outputs, fore_cls_outputs, pred_class_logits, global_feat, fore_feat = outputs
loss_dict = {}
loss_names = self._cfg.MODEL.LOSSES.NAME
# Log prediction accuracy
CrossEntropyLoss.log_accuracy(pred_class_logits, gt_labels)
if "CrossEntropyLoss" in loss_names:
loss_dict['loss_avg_branch_cls'] = CrossEntropyLoss(self._cfg)(cls_outputs, gt_labels)
loss_dict['loss_fore_branch_cls'] = CrossEntropyLoss(self._cfg)(fore_cls_outputs, gt_labels)

View File

@ -43,13 +43,13 @@ def setup(args):
def main(args):
cfg = setup(args)
logger = logging.getLogger('fastreid.' + __name__)
if args.eval_only:
logger = logging.getLogger("fastreid.trainer")
cfg.defrost()
cfg.MODEL.BACKBONE.PRETRAIN = False
model = Trainer.build_model(cfg)
Checkpointer(model, save_dir=cfg.OUTPUT_DIR).load(cfg.MODEL.WEIGHTS) # load trained model
Checkpointer(model).load(cfg.MODEL.WEIGHTS) # load trained model
if cfg.TEST.PRECISE_BN.ENABLED and hooks.get_bn_modules(model):
prebn_cfg = cfg.clone()