Merge branch 'develop_reg' of https://github.com/weisy11/PaddleClas into develop_reg
commit
2921934016
|
@ -148,9 +148,9 @@ Infer:
|
|||
|
||||
Metric:
|
||||
Train:
|
||||
- Topk:
|
||||
k: [1, 5]
|
||||
- TopkAcc:
|
||||
topk: [1, 5]
|
||||
Eval:
|
||||
- Topk:
|
||||
k: [1, 5]
|
||||
- TopkAcc:
|
||||
topk: [1, 5]
|
||||
|
||||
|
|
|
@ -1,6 +1,4 @@
|
|||
# global configs
|
||||
Trainer:
|
||||
name: TrainerReID
|
||||
Global:
|
||||
checkpoints: null
|
||||
pretrained_model: null
|
||||
|
@ -16,8 +14,7 @@ Global:
|
|||
# used for static mode and model export
|
||||
image_shape: [3, 224, 224]
|
||||
save_inference_dir: "./inference"
|
||||
num_split: 1
|
||||
feature_normalize: True
|
||||
eval_mode: "retrieval"
|
||||
|
||||
# model architecture
|
||||
Arch:
|
||||
|
@ -99,10 +96,10 @@ DataLoader:
|
|||
loader:
|
||||
num_workers: 6
|
||||
use_shared_memory: False
|
||||
|
||||
Query:
|
||||
Eval:
|
||||
Query:
|
||||
# TOTO: modify to the latest trainer
|
||||
dataset:
|
||||
dataset:
|
||||
name: "VeriWild"
|
||||
image_root: "/work/dataset/VeRI-Wild/images"
|
||||
cls_label_path: "/work/dataset/VeRI-Wild/train_test_split/test_3000_id_query.txt"
|
||||
|
@ -114,18 +111,18 @@ DataLoader:
|
|||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: ''
|
||||
sampler:
|
||||
sampler:
|
||||
name: DistributedBatchSampler
|
||||
batch_size: 64
|
||||
drop_last: False
|
||||
shuffle: False
|
||||
loader:
|
||||
loader:
|
||||
num_workers: 6
|
||||
use_shared_memory: False
|
||||
|
||||
Gallery:
|
||||
Gallery:
|
||||
# TOTO: modify to the latest trainer
|
||||
dataset:
|
||||
dataset:
|
||||
name: "VeriWild"
|
||||
image_root: "/work/dataset/VeRI-Wild/images"
|
||||
cls_label_path: "/work/dataset/VeRI-Wild/train_test_split/test_3000_id.txt"
|
||||
|
@ -137,15 +134,21 @@ DataLoader:
|
|||
mean: [0.485, 0.456, 0.406]
|
||||
std: [0.229, 0.224, 0.225]
|
||||
order: ''
|
||||
sampler:
|
||||
sampler:
|
||||
name: DistributedBatchSampler
|
||||
batch_size: 64
|
||||
drop_last: False
|
||||
shuffle: False
|
||||
loader:
|
||||
loader:
|
||||
num_workers: 6
|
||||
use_shared_memory: False
|
||||
|
||||
Metric:
|
||||
Eval:
|
||||
- Recallk:
|
||||
topk: [1, 5]
|
||||
- mAP: {}
|
||||
|
||||
Infer:
|
||||
infer_imgs: "docs/images/whl/demo.jpg"
|
||||
batch_size: 10
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
import numpy as np
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
from functools import lru_cache
|
||||
|
||||
|
||||
# TODO: fix the format
|
||||
|
@ -38,23 +39,13 @@ class TopkAcc(nn.Layer):
|
|||
|
||||
|
||||
class mAP(nn.Layer):
|
||||
def __init__(self, max_rank=50):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.max_rank = max_rank
|
||||
|
||||
def forward(self, similarities_matrix, query_img_id, gallery_img_id):
|
||||
metric_dict = dict()
|
||||
num_q, num_g = similarities_matrix.shape
|
||||
q_pids = query_img_id.numpy().reshape((query_img_id.shape[0]))
|
||||
g_pids = gallery_img_id.numpy().reshape((gallery_img_id.shape[0]))
|
||||
if num_g < self.max_rank:
|
||||
self.max_rank = num_g
|
||||
print('Note: number of gallery samples is quite small, got {}'.
|
||||
format(num_g))
|
||||
indices = paddle.argsort(
|
||||
similarities_matrix, axis=1, descending=True).numpy()
|
||||
_, all_AP, _ = get_metrics(indices, num_q, num_g, q_pids, g_pids,
|
||||
self.max_rank)
|
||||
_, all_AP, _ = get_metrics(similarities_matrix, query_img_id,
|
||||
gallery_img_id)
|
||||
|
||||
mAP = np.mean(all_AP)
|
||||
metric_dict["mAP"] = mAP
|
||||
|
@ -62,23 +53,13 @@ class mAP(nn.Layer):
|
|||
|
||||
|
||||
class mINP(nn.Layer):
|
||||
def __init__(self, max_rank=50):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.max_rank = max_rank
|
||||
|
||||
def forward(self, similarities_matrix, query_img_id, gallery_img_id):
|
||||
metric_dict = dict()
|
||||
num_q, num_g = similarities_matrix.shape
|
||||
q_pids = query_img_id.numpy().reshape((query_img_id.shape[0]))
|
||||
g_pids = gallery_img_id.numpy().reshape((gallery_img_id.shape[0]))
|
||||
if num_g < self.max_rank:
|
||||
max_rank = num_g
|
||||
print('Note: number of gallery samples is quite small, got {}'.
|
||||
format(num_g))
|
||||
indices = paddle.argsort(
|
||||
similarities_matrix, axis=1, descending=True).numpy()
|
||||
_, _, all_INP = get_metrics(indices, num_q, num_g, q_pids, g_pids,
|
||||
self.max_rank)
|
||||
_, _, all_INP = get_metrics(similarities_matrix, query_img_id,
|
||||
gallery_img_id)
|
||||
|
||||
mINP = np.mean(all_INP)
|
||||
metric_dict["mINP"] = mINP
|
||||
|
@ -86,34 +67,37 @@ class mINP(nn.Layer):
|
|||
|
||||
|
||||
class Recallk(nn.Layer):
|
||||
def __init__(self, max_rank=50, topk=(1, 5)):
|
||||
def __init__(self, topk=(1, 5)):
|
||||
super().__init__()
|
||||
self.max_rank = max_rank
|
||||
assert isinstance(topk, (int, list))
|
||||
if isinstance(topk, int):
|
||||
topk = [topk]
|
||||
self.topk = topk
|
||||
self.max_rank = max(self.topk) if max(self.topk) > 50 else 50
|
||||
|
||||
def forward(self, similarities_matrix, query_img_id, gallery_img_id):
|
||||
metric_dict = dict()
|
||||
num_q, num_g = similarities_matrix.shape
|
||||
q_pids = query_img_id.numpy().reshape((query_img_id.shape[0]))
|
||||
g_pids = gallery_img_id.numpy().reshape((gallery_img_id.shape[0]))
|
||||
if num_g < self.max_rank:
|
||||
max_rank = num_g
|
||||
print('Note: number of gallery samples is quite small, got {}'.
|
||||
format(num_g))
|
||||
indices = paddle.argsort(
|
||||
similarities_matrix, axis=1, descending=True).numpy()
|
||||
all_cmc, _, _ = get_metrics(indices, num_q, num_g, q_pids, g_pids,
|
||||
self.max_rank)
|
||||
all_cmc, _, _ = get_metrics(similarities_matrix, query_img_id,
|
||||
gallery_img_id, self.max_rank)
|
||||
|
||||
for k in self.topk:
|
||||
metric_dict["recall{}".format(k)] = all_cmc[k - 1]
|
||||
return metric_dict
|
||||
|
||||
|
||||
def get_metrics(indices, num_q, num_g, q_pids, g_pids, max_rank=50):
|
||||
@lru_cache()
|
||||
def get_metrics(similarities_matrix, query_img_id, gallery_img_id,
|
||||
max_rank=50):
|
||||
num_q, num_g = similarities_matrix.shape
|
||||
q_pids = query_img_id.numpy().reshape((query_img_id.shape[0]))
|
||||
g_pids = gallery_img_id.numpy().reshape((gallery_img_id.shape[0]))
|
||||
if num_g < max_rank:
|
||||
max_rank = num_g
|
||||
print('Note: number of gallery samples is quite small, got {}'.format(
|
||||
num_g))
|
||||
indices = paddle.argsort(
|
||||
similarities_matrix, axis=1, descending=True).numpy()
|
||||
|
||||
all_cmc = []
|
||||
all_AP = []
|
||||
all_INP = []
|
||||
|
|
Loading…
Reference in New Issue