v1.3.6: added University-1652
parent
6e498f8b17
commit
93b8c9f3db
torchreid
data/datasets
utils/GPU-Re-Ranking
|
@ -33,7 +33,7 @@ You can find some research projects that are built on top of Torchreid `here <ht
|
|||
|
||||
What's new
|
||||
------------
|
||||
- [Feb 2021] We support the new multi-view multi-source geo-localization dataset `University-1652 <https://dl.acm.org/doi/abs/10.1145/3394171.3413896>`_.
|
||||
- [Feb 2021] ``v1.3.6`` Added `University-1652 <https://dl.acm.org/doi/abs/10.1145/3394171.3413896>`_, a new dataset for multi-view multi-source geo-localization (credit to `Zhedong Zheng <https://github.com/layumi>`_).
|
||||
- [Feb 2021] ``v1.3.5``: Now the `cython code <https://github.com/KaiyangZhou/deep-person-reid/pull/412>`_ works on Windows (credit to `lablabla <https://github.com/lablabla>`_).
|
||||
- [Jan 2021] Our recent work, `MixStyle <https://openreview.net/forum?id=6xHJ37MVxxp>`_ (mixing instance-level feature statistics of samples of different domains for improving domain generalization), has been accepted to ICLR'21. The code has been released at https://github.com/KaiyangZhou/mixstyle-release where the person re-ID part is based on Torchreid.
|
||||
- [Jan 2021] A new evaluation metric called `mean Inverse Negative Penalty (mINP)` for person re-ID has been introduced in `Deep Learning for Person Re-identification: A Survey and Outlook (TPAMI 2021) <https://arxiv.org/abs/2001.04193>`_. Their code can be accessed at `<https://github.com/mangye16/ReID-Survey>`_.
|
||||
|
@ -232,7 +232,7 @@ Image-reid datasets
|
|||
- `PRID <https://pdfs.semanticscholar.org/4c1b/f0592be3e535faf256c95e27982db9b3d3d3.pdf>`_
|
||||
|
||||
Geo-localization datasets
|
||||
^^^^^^^^^^^^^^^^^^^^^^^
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
- `University-1652 <https://dl.acm.org/doi/abs/10.1145/3394171.3413896>`_
|
||||
|
||||
Video-reid datasets
|
||||
|
|
|
@ -2,7 +2,7 @@ from __future__ import print_function, absolute_import
|
|||
|
||||
from torchreid import data, optim, utils, engine, losses, models, metrics
|
||||
|
||||
__version__ = '1.3.5'
|
||||
__version__ = '1.3.6'
|
||||
__author__ = 'Kaiyang Zhou'
|
||||
__homepage__ = 'https://kaiyangzhou.github.io/'
|
||||
__description__ = 'Deep learning person re-identification in PyTorch'
|
||||
|
|
|
@ -2,7 +2,7 @@ from __future__ import print_function, absolute_import
|
|||
|
||||
from .image import (
|
||||
GRID, PRID, CUHK01, CUHK02, CUHK03, MSMT17, VIPeR, SenseReID, Market1501,
|
||||
DukeMTMCreID, iLIDS, University1652
|
||||
DukeMTMCreID, University1652, iLIDS
|
||||
)
|
||||
from .video import PRID2011, Mars, DukeMTMCVidReID, iLIDSVID
|
||||
from .dataset import Dataset, ImageDataset, VideoDataset
|
||||
|
@ -19,7 +19,7 @@ __image_datasets = {
|
|||
'sensereid': SenseReID,
|
||||
'prid': PRID,
|
||||
'cuhk02': CUHK02,
|
||||
'university1652':University1652
|
||||
'university1652': University1652
|
||||
}
|
||||
|
||||
__video_datasets = {
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
from __future__ import division, print_function, absolute_import
|
||||
import os
|
||||
import re
|
||||
import glob
|
||||
import os.path as osp
|
||||
import os
|
||||
import gdown
|
||||
|
||||
from ..dataset import ImageDataset
|
||||
|
@ -15,51 +15,56 @@ class University1652(ImageDataset):
|
|||
- Zheng et al. University-1652: A Multi-view Multi-source Benchmark for Drone-based Geo-localization. ACM MM 2020.
|
||||
|
||||
URL: `<https://github.com/layumi/University1652-Baseline>`_
|
||||
OneDrive:
|
||||
https://studentutsedu-my.sharepoint.com/:u:/g/personal/12639605_student_uts_edu_au/Ecrz6xK-PcdCjFdpNb0T0s8B_9J5ynaUy3q63_XumjJyrA?e=z4hpcz
|
||||
[Backup] GoogleDrive:
|
||||
https://drive.google.com/file/d/1iVnP4gjw-iHXa0KerZQ1IfIO0i1jADsR/view?usp=sharing
|
||||
[Backup] Baidu Yun:
|
||||
https://pan.baidu.com/s/1H_wBnWwikKbaBY1pMPjoqQ password: hrqp
|
||||
|
||||
Dataset statistics:
|
||||
- buildings: 1652 (train + query).
|
||||
- The dataset split is as follows:
|
||||
| Split | #imgs | #buildings | #universities|
|
||||
| -------- | ----- | ----| ----|
|
||||
| Training | 50,218 | 701 | 33 |
|
||||
| Query_drone | 37,855 | 701 | 39 |
|
||||
| Query_satellite | 701 | 701 | 39|
|
||||
| Query_ground | 2,579 | 701 | 39|
|
||||
| Gallery_drone | 51,355 | 951 | 39|
|
||||
| Gallery_satellite | 951 | 951 | 39|
|
||||
| Gallery_ground | 2,921 | 793 | 39|
|
||||
- cameras: None.
|
||||
datamanager = torchreid.data.ImageDataManager(
|
||||
root='reid-data',
|
||||
sources='university1652',
|
||||
targets='university1652',
|
||||
height=256,
|
||||
width=256,
|
||||
batch_size_train=32,
|
||||
batch_size_test=100,
|
||||
transforms=['random_flip', 'random_crop']
|
||||
)
|
||||
OneDrive:
|
||||
https://studentutsedu-my.sharepoint.com/:u:/g/personal/12639605_student_uts_edu_au/Ecrz6xK-PcdCjFdpNb0T0s8B_9J5ynaUy3q63_XumjJyrA?e=z4hpcz
|
||||
[Backup] GoogleDrive:
|
||||
https://drive.google.com/file/d/1iVnP4gjw-iHXa0KerZQ1IfIO0i1jADsR/view?usp=sharing
|
||||
[Backup] Baidu Yun:
|
||||
https://pan.baidu.com/s/1H_wBnWwikKbaBY1pMPjoqQ password: hrqp
|
||||
|
||||
Dataset statistics:
|
||||
- buildings: 1652 (train + query).
|
||||
- The dataset split is as follows:
|
||||
| Split | #imgs | #buildings | #universities|
|
||||
| -------- | ----- | ----| ----|
|
||||
| Training | 50,218 | 701 | 33 |
|
||||
| Query_drone | 37,855 | 701 | 39 |
|
||||
| Query_satellite | 701 | 701 | 39|
|
||||
| Query_ground | 2,579 | 701 | 39|
|
||||
| Gallery_drone | 51,355 | 951 | 39|
|
||||
| Gallery_satellite | 951 | 951 | 39|
|
||||
| Gallery_ground | 2,921 | 793 | 39|
|
||||
- cameras: None.
|
||||
datamanager = torchreid.data.ImageDataManager(
|
||||
root='reid-data',
|
||||
sources='university1652',
|
||||
targets='university1652',
|
||||
height=256,
|
||||
width=256,
|
||||
batch_size_train=32,
|
||||
batch_size_test=100,
|
||||
transforms=['random_flip', 'random_crop']
|
||||
)
|
||||
"""
|
||||
dataset_dir = 'university1652'
|
||||
dataset_url = 'https://drive.google.com/uc?id=1iVnP4gjw-iHXa0KerZQ1IfIO0i1jADsR'
|
||||
|
||||
def __init__(self, root='', **kwargs):
|
||||
self.root = osp.abspath(osp.expanduser(root))
|
||||
self.dataset_dir = osp.join(self.root, self.dataset_dir)
|
||||
print(self.dataset_dir)
|
||||
if not os.path.isdir(self.dataset_dir):
|
||||
os.mkdir(self.dataset_dir)
|
||||
gdown.download(self.dataset_url, self.dataset_dir+'data.zip', quiet=False)
|
||||
os.system('unzip %s'%(self.dataset_dir+'data.zip'))
|
||||
gdown.download(
|
||||
self.dataset_url, self.dataset_dir + 'data.zip', quiet=False
|
||||
)
|
||||
os.system('unzip %s' % (self.dataset_dir + 'data.zip'))
|
||||
self.train_dir = osp.join(
|
||||
self.dataset_dir,'University-Release/train/'
|
||||
self.dataset_dir, 'University-Release/train/'
|
||||
)
|
||||
self.query_dir = osp.join(
|
||||
self.dataset_dir, 'University-Release/test/query_drone'
|
||||
)
|
||||
self.query_dir = osp.join(self.dataset_dir, 'University-Release/test/query_drone')
|
||||
self.gallery_dir = osp.join(
|
||||
self.dataset_dir, 'University-Release/test/gallery_satellite'
|
||||
)
|
||||
|
@ -77,7 +82,10 @@ datamanager = torchreid.data.ImageDataManager(
|
|||
super(University1652, self).__init__(train, query, gallery, **kwargs)
|
||||
|
||||
def process_dir(self, dir_path, relabel=False, train=False):
|
||||
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')
|
||||
IMG_EXTENSIONS = (
|
||||
'.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff',
|
||||
'.webp'
|
||||
)
|
||||
if train:
|
||||
img_paths = glob.glob(osp.join(dir_path, '*/*/*'))
|
||||
else:
|
||||
|
@ -86,7 +94,7 @@ datamanager = torchreid.data.ImageDataManager(
|
|||
for img_path in img_paths:
|
||||
if not img_path.lower().endswith(IMG_EXTENSIONS):
|
||||
continue
|
||||
pid = int(os.path.basename(os.path.dirname(img_path)))
|
||||
pid = int(os.path.basename(os.path.dirname(img_path)))
|
||||
pid_container.add(pid)
|
||||
pid2label = {pid: label for label, pid in enumerate(pid_container)}
|
||||
data = []
|
||||
|
@ -94,9 +102,9 @@ datamanager = torchreid.data.ImageDataManager(
|
|||
for img_path in img_paths:
|
||||
if not img_path.lower().endswith(IMG_EXTENSIONS):
|
||||
continue
|
||||
pid = int(os.path.basename(os.path.dirname(img_path)))
|
||||
pid = int(os.path.basename(os.path.dirname(img_path)))
|
||||
if relabel:
|
||||
pid = pid2label[pid]
|
||||
data.append((img_path, pid, self.fake_camid))
|
||||
self.fake_camid +=1
|
||||
self.fake_camid += 1
|
||||
return data
|
||||
|
|
|
@ -16,22 +16,21 @@
|
|||
with limited time cost.
|
||||
"""
|
||||
|
||||
from setuptools import setup, Extension
|
||||
|
||||
from setuptools import Extension, setup
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.autograd import Function
|
||||
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
||||
|
||||
from torch.utils.cpp_extension import CUDAExtension, BuildExtension
|
||||
|
||||
setup(
|
||||
name='build_adjacency_matrix',
|
||||
ext_modules=[
|
||||
CUDAExtension('build_adjacency_matrix', [
|
||||
'build_adjacency_matrix.cpp',
|
||||
'build_adjacency_matrix_kernel.cu',
|
||||
]),
|
||||
CUDAExtension(
|
||||
'build_adjacency_matrix', [
|
||||
'build_adjacency_matrix.cpp',
|
||||
'build_adjacency_matrix_kernel.cu',
|
||||
]
|
||||
),
|
||||
],
|
||||
cmdclass={
|
||||
'build_ext':BuildExtension
|
||||
})
|
||||
cmdclass={'build_ext': BuildExtension}
|
||||
)
|
||||
|
|
|
@ -16,22 +16,21 @@
|
|||
with limited time cost.
|
||||
"""
|
||||
|
||||
from setuptools import setup, Extension
|
||||
|
||||
from setuptools import Extension, setup
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.autograd import Function
|
||||
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
||||
|
||||
from torch.utils.cpp_extension import CUDAExtension, BuildExtension
|
||||
|
||||
setup(
|
||||
name='gnn_propagate',
|
||||
ext_modules=[
|
||||
CUDAExtension('gnn_propagate', [
|
||||
'gnn_propagate.cpp',
|
||||
'gnn_propagate_kernel.cu',
|
||||
]),
|
||||
CUDAExtension(
|
||||
'gnn_propagate', [
|
||||
'gnn_propagate.cpp',
|
||||
'gnn_propagate_kernel.cu',
|
||||
]
|
||||
),
|
||||
],
|
||||
cmdclass={
|
||||
'build_ext':BuildExtension
|
||||
})
|
||||
cmdclass={'build_ext': BuildExtension}
|
||||
)
|
||||
|
|
|
@ -16,42 +16,44 @@
|
|||
with limited time cost.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
import build_adjacency_matrix
|
||||
import gnn_propagate
|
||||
|
||||
import build_adjacency_matrix
|
||||
from utils import *
|
||||
|
||||
|
||||
|
||||
def gnn_reranking(X_q, X_g, k1, k2):
|
||||
query_num, gallery_num = X_q.shape[0], X_g.shape[0]
|
||||
|
||||
X_u = torch.cat((X_q, X_g), axis = 0)
|
||||
X_u = torch.cat((X_q, X_g), axis=0)
|
||||
original_score = torch.mm(X_u, X_u.t())
|
||||
del X_u, X_q, X_g
|
||||
|
||||
# initial ranking list
|
||||
S, initial_rank = original_score.topk(k=k1, dim=-1, largest=True, sorted=True)
|
||||
|
||||
S, initial_rank = original_score.topk(
|
||||
k=k1, dim=-1, largest=True, sorted=True
|
||||
)
|
||||
|
||||
# stage 1
|
||||
A = build_adjacency_matrix.forward(initial_rank.float())
|
||||
A = build_adjacency_matrix.forward(initial_rank.float())
|
||||
S = S * S
|
||||
|
||||
# stage 2
|
||||
if k2 != 1:
|
||||
if k2 != 1:
|
||||
for i in range(2):
|
||||
A = A + A.T
|
||||
A = gnn_propagate.forward(A, initial_rank[:, :k2].contiguous().float(), S[:, :k2].contiguous().float())
|
||||
A = gnn_propagate.forward(
|
||||
A, initial_rank[:, :k2].contiguous().float(),
|
||||
S[:, :k2].contiguous().float()
|
||||
)
|
||||
A_norm = torch.norm(A, p=2, dim=1, keepdim=True)
|
||||
A = A.div(A_norm.expand_as(A))
|
||||
|
||||
|
||||
cosine_similarity = torch.mm(A[:query_num,], A[query_num:, ].t())
|
||||
A = A.div(A_norm.expand_as(A))
|
||||
|
||||
cosine_similarity = torch.mm(A[:query_num, ], A[query_num:, ].t())
|
||||
del A, S
|
||||
|
||||
L = torch.sort(-cosine_similarity, dim = 1)[1]
|
||||
|
||||
L = torch.sort(-cosine_similarity, dim=1)[1]
|
||||
L = L.data.cpu().numpy()
|
||||
return L
|
||||
|
|
|
@ -17,46 +17,56 @@
|
|||
"""
|
||||
|
||||
import os
|
||||
import torch
|
||||
import argparse
|
||||
import numpy as np
|
||||
import argparse
|
||||
import torch
|
||||
|
||||
from utils import *
|
||||
from gnn_reranking import *
|
||||
|
||||
parser = argparse.ArgumentParser(description='Reranking_is_GNN')
|
||||
parser.add_argument('--data_path',
|
||||
type=str,
|
||||
default='../xm_rerank_gpu_2/features/market_88_test.pkl',
|
||||
help='path to dataset')
|
||||
parser.add_argument('--k1',
|
||||
type=int,
|
||||
default=26, # Market-1501
|
||||
# default=60, # Veri-776
|
||||
help='parameter k1')
|
||||
parser.add_argument('--k2',
|
||||
type=int,
|
||||
default=7, # Market-1501
|
||||
# default=10, # Veri-776
|
||||
help='parameter k2')
|
||||
parser.add_argument(
|
||||
'--data_path',
|
||||
type=str,
|
||||
default='../xm_rerank_gpu_2/features/market_88_test.pkl',
|
||||
help='path to dataset'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--k1',
|
||||
type=int,
|
||||
default=26, # Market-1501
|
||||
# default=60, # Veri-776
|
||||
help='parameter k1'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--k2',
|
||||
type=int,
|
||||
default=7, # Market-1501
|
||||
# default=10, # Veri-776
|
||||
help='parameter k2'
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
def main():
|
||||
|
||||
def main():
|
||||
data = load_pickle(args.data_path)
|
||||
|
||||
|
||||
query_cam = data['query_cam']
|
||||
query_label = data['query_label']
|
||||
gallery_cam = data['gallery_cam']
|
||||
gallery_label = data['gallery_label']
|
||||
|
||||
|
||||
gallery_feature = torch.FloatTensor(data['gallery_f'])
|
||||
query_feature = torch.FloatTensor(data['query_f'])
|
||||
query_feature = query_feature.cuda()
|
||||
gallery_feature = gallery_feature.cuda()
|
||||
|
||||
indices = gnn_reranking(query_feature, gallery_feature, args.k1, args.k2)
|
||||
evaluate_ranking_list(indices, query_label, query_cam, gallery_label, gallery_cam)
|
||||
|
||||
evaluate_ranking_list(
|
||||
indices, query_label, query_cam, gallery_label, gallery_cam
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
main()
|
||||
|
|
|
@ -16,21 +16,23 @@
|
|||
with limited time cost.
|
||||
"""
|
||||
|
||||
import pickle
|
||||
import numpy as np
|
||||
import os
|
||||
import numpy as np
|
||||
import pickle
|
||||
import torch
|
||||
|
||||
|
||||
def load_pickle(pickle_path):
|
||||
with open(pickle_path, 'rb') as f:
|
||||
data = pickle.load(f)
|
||||
return data
|
||||
with open(pickle_path, 'rb') as f:
|
||||
data = pickle.load(f)
|
||||
return data
|
||||
|
||||
|
||||
def save_pickle(pickle_path, data):
|
||||
with open(pickle_path, 'wb') as f:
|
||||
pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
|
||||
def pairwise_squared_distance(x):
|
||||
'''
|
||||
x : (n_samples, n_points, dims)
|
||||
|
@ -38,18 +40,24 @@ def pairwise_squared_distance(x):
|
|||
'''
|
||||
x2s = (x * x).sum(-1, keepdim=True)
|
||||
return x2s + x2s.transpose(-1, -2) - 2 * x @ x.transpose(-1, -2)
|
||||
|
||||
|
||||
|
||||
def pairwise_distance(x, y):
|
||||
m, n = x.size(0), y.size(0)
|
||||
|
||||
|
||||
x = x.view(m, -1)
|
||||
y = y.view(n, -1)
|
||||
|
||||
dist = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) + torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n,m).t()
|
||||
dist = torch.pow(x, 2).sum(
|
||||
dim=1, keepdim=True
|
||||
).expand(m, n) + torch.pow(y, 2).sum(
|
||||
dim=1, keepdim=True
|
||||
).expand(n, m).t()
|
||||
dist.addmm_(1, -2, x, y.t())
|
||||
|
||||
return dist
|
||||
|
||||
|
||||
def cosine_similarity(x, y):
|
||||
m, n = x.size(0), y.size(0)
|
||||
|
||||
|
@ -61,30 +69,40 @@ def cosine_similarity(x, y):
|
|||
|
||||
return score
|
||||
|
||||
def evaluate_ranking_list(indices, query_label, query_cam, gallery_label, gallery_cam):
|
||||
|
||||
def evaluate_ranking_list(
|
||||
indices, query_label, query_cam, gallery_label, gallery_cam
|
||||
):
|
||||
CMC = np.zeros((len(gallery_label)), dtype=np.int)
|
||||
ap = 0.0
|
||||
|
||||
for i in range(len(query_label)):
|
||||
ap_tmp, CMC_tmp = evaluate(indices[i],query_label[i], query_cam[i], gallery_label, gallery_cam)
|
||||
if CMC_tmp[0]==-1:
|
||||
ap_tmp, CMC_tmp = evaluate(
|
||||
indices[i], query_label[i], query_cam[i], gallery_label,
|
||||
gallery_cam
|
||||
)
|
||||
if CMC_tmp[0] == -1:
|
||||
continue
|
||||
CMC = CMC + CMC_tmp
|
||||
ap += ap_tmp
|
||||
ap += ap_tmp
|
||||
|
||||
CMC = CMC.astype(np.float32)
|
||||
CMC = CMC/len(query_label) #average CMC
|
||||
print('Rank@1:%f Rank@5:%f Rank@10:%f mAP:%f'%(CMC[0],CMC[4],CMC[9],ap/len(query_label)))
|
||||
CMC = CMC / len(query_label) #average CMC
|
||||
print(
|
||||
'Rank@1:%f Rank@5:%f Rank@10:%f mAP:%f' %
|
||||
(CMC[0], CMC[4], CMC[9], ap / len(query_label))
|
||||
)
|
||||
|
||||
def evaluate(index, ql,qc,gl,gc):
|
||||
query_index = np.argwhere(gl==ql)
|
||||
camera_index = np.argwhere(gc==qc)
|
||||
|
||||
def evaluate(index, ql, qc, gl, gc):
|
||||
query_index = np.argwhere(gl == ql)
|
||||
camera_index = np.argwhere(gc == qc)
|
||||
|
||||
good_index = np.setdiff1d(query_index, camera_index, assume_unique=True)
|
||||
junk_index1 = np.argwhere(gl==-1)
|
||||
junk_index1 = np.argwhere(gl == -1)
|
||||
junk_index2 = np.intersect1d(query_index, camera_index)
|
||||
junk_index = np.append(junk_index2, junk_index1) #.flatten())
|
||||
|
||||
|
||||
CMC_tmp = compute_mAP(index, good_index, junk_index)
|
||||
return CMC_tmp
|
||||
|
||||
|
@ -92,9 +110,9 @@ def evaluate(index, ql,qc,gl,gc):
|
|||
def compute_mAP(index, good_index, junk_index):
|
||||
ap = 0
|
||||
cmc = np.zeros((len(index)), dtype=np.int)
|
||||
if good_index.size==0: # if empty
|
||||
if good_index.size == 0: # if empty
|
||||
cmc[0] = -1
|
||||
return ap,cmc
|
||||
return ap, cmc
|
||||
|
||||
# remove junk_index
|
||||
mask = np.in1d(index, junk_index, invert=True)
|
||||
|
@ -103,17 +121,17 @@ def compute_mAP(index, good_index, junk_index):
|
|||
# find good_index index
|
||||
ngood = len(good_index)
|
||||
mask = np.in1d(index, good_index)
|
||||
rows_good = np.argwhere(mask==True)
|
||||
rows_good = np.argwhere(mask == True)
|
||||
rows_good = rows_good.flatten()
|
||||
|
||||
|
||||
cmc[rows_good[0]:] = 1
|
||||
for i in range(ngood):
|
||||
d_recall = 1.0/ngood
|
||||
precision = (i+1)*1.0/(rows_good[i]+1)
|
||||
if rows_good[i]!=0:
|
||||
old_precision = i*1.0/rows_good[i]
|
||||
d_recall = 1.0 / ngood
|
||||
precision = (i+1) * 1.0 / (rows_good[i] + 1)
|
||||
if rows_good[i] != 0:
|
||||
old_precision = i * 1.0 / rows_good[i]
|
||||
else:
|
||||
old_precision=1.0
|
||||
ap = ap + d_recall*(old_precision + precision)/2
|
||||
old_precision = 1.0
|
||||
ap = ap + d_recall * (old_precision+precision) / 2
|
||||
|
||||
return ap, cmc
|
||||
return ap, cmc
|
||||
|
|
Loading…
Reference in New Issue