mirror of https://github.com/JDAI-CV/fast-reid.git (synced 2025-06-03 14:50:47 +08:00)

Fix data prefetcher minor bug

parent 12957f66aa
commit b020c7f0ae

@@ -64,7 +64,7 @@ def build_reid_test_loader(cfg, dataset_name):
         test_set,
         batch_sampler=batch_sampler,
         num_workers=num_workers,
-        collate_fn=fast_batch_collator, pin_memory=True)
+        collate_fn=fast_batch_collator)
     return data_prefetcher(cfg, test_loader), len(dataset.query)

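This hunk is the tail of the DataLoader construction inside build_reid_test_loader; only the collate line changes, and the loader is then wrapped in the prefetcher together with the query count. A self-contained sketch of the same construction pattern, where simple_collator and the fake tensor dataset are stand-ins rather than fast-reid's fast_batch_collator and real test set:

# Minimal sketch of the loader construction pattern in the hunk above.
# simple_collator is a stand-in; it is NOT fast-reid's fast_batch_collator.
import torch
from torch.utils.data import BatchSampler, DataLoader, SequentialSampler, TensorDataset


def simple_collator(batch):
    # Stack per-sample tensors into one batch tensor.
    return torch.stack([item[0] for item in batch], dim=0)


test_set = TensorDataset(torch.randn(10, 3, 256, 128))   # 10 fake images
batch_sampler = BatchSampler(SequentialSampler(test_set), batch_size=4, drop_last=False)
test_loader = DataLoader(
    test_set,
    batch_sampler=batch_sampler,
    num_workers=0,
    collate_fn=simple_collator)   # pin_memory omitted, matching the new line

for batch in test_loader:
    print(batch.shape)   # torch.Size([4, 3, 256, 128]) for full batches
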
@@ -68,6 +68,10 @@ class data_prefetcher():

         self.preload()

+    def reset(self):
+        self.loader_iter = iter(self.loader)
+        self.preload()
+
     def preload(self):
         try:
             self.next_inputs = next(self.loader_iter)

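The fix adds reset() so the prefetcher can rebuild its iterator and re-prime itself before a new pass; preload() alone only fetches the next batch from the current, possibly exhausted, iterator. A minimal sketch of a CUDA prefetcher with this interface; everything beyond the loader/reset/preload structure shown in the diff, including the next() method and the single-tensor batches, is an assumption for illustration, not fast-reid's exact code:

import torch


class PrefetcherSketch:
    """Overlaps host-to-device copies with compute by preloading the next batch.

    Illustrative only: fast-reid's real data_prefetcher also takes cfg and
    normalizes images; here we keep just the iteration/reset mechanics.
    """

    def __init__(self, loader):
        self.loader = loader
        self.stream = torch.cuda.Stream() if torch.cuda.is_available() else None
        self.reset()

    def reset(self):
        # Re-create the iterator and prime the pipeline; this is what the new
        # reset() in the diff enables before a fresh evaluation pass.
        self.loader_iter = iter(self.loader)
        self.preload()

    def preload(self):
        try:
            self.next_inputs = next(self.loader_iter)   # assumed to be one tensor
        except StopIteration:
            self.next_inputs = None
            return
        if self.stream is not None:
            with torch.cuda.stream(self.stream):
                self.next_inputs = self.next_inputs.cuda(non_blocking=True)

    def next(self):
        # Hand out the preloaded batch and immediately start fetching the next one.
        if self.stream is not None:
            torch.cuda.current_stream().wait_stream(self.stream)
        inputs = self.next_inputs
        self.preload()
        return inputs
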
@@ -6,3 +6,4 @@


 from .build import build_transforms
 from .transforms import *

@@ -15,7 +15,6 @@ from collections import OrderedDict

 import numpy as np
 import torch
-# from fvcore.nn.precise_bn import get_bn_modules
 from torch.nn import DataParallel

 from ..data import build_reid_test_loader, build_reid_train_loader

@@ -382,8 +381,8 @@ class DefaultTrainer(SimpleTrainer):
         return build_reid_test_loader(cfg, dataset_name)

     @classmethod
-    def build_evaluator(cls, cfg, num_query):
-        return ReidEvaluator(cfg, num_query)
+    def build_evaluator(cls, cfg, num_query, output_dir=None):
+        return ReidEvaluator(cfg, num_query, output_dir)

     @classmethod
     def test(cls, cfg, model, evaluators=None):

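build_evaluator now accepts an optional output_dir and forwards it to ReidEvaluator. A hedged sketch of a trainer subclass using the new signature; the import paths, class name, and default directory below are assumptions for illustration, not part of this commit:

# Sketch only: MyTrainer, the import paths, and "logs/eval" are hypothetical.
from fastreid.engine import DefaultTrainer
from fastreid.evaluation import ReidEvaluator


class MyTrainer(DefaultTrainer):
    @classmethod
    def build_evaluator(cls, cfg, num_query, output_dir=None):
        # Thread the optional output directory through to the evaluator,
        # matching the new signature in the hunk above.
        if output_dir is None:
            output_dir = "logs/eval"
        return ReidEvaluator(cfg, num_query, output_dir)
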
@@ -100,6 +100,7 @@ def inference_on_dataset(model, data_loader, evaluator):
     logger.info("Start inference on {} images".format(len(data_loader.loader.dataset)))

     total = len(data_loader.loader)  # inference data loader must have a fixed length
+    data_loader.reset()
     evaluator.reset()

     num_warmup = min(5, total - 1)

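The added data_loader.reset() is the core of the fix: inference_on_dataset now receives a prefetcher rather than a plain DataLoader, and its internal iterator must be re-primed before every evaluation pass; otherwise a second call would start from an exhausted iterator. A compact sketch of the consuming loop under that assumption; the evaluator's process() signature is also assumed here:

def run_inference_sketch(model, data_loader, evaluator):
    # data_loader is assumed to be the prefetcher: it exposes .loader, .reset()
    # and .next(), as the surrounding hunks suggest.
    total = len(data_loader.loader)   # fixed length comes from the wrapped loader
    data_loader.reset()               # the fix: re-prime before every fresh pass
    evaluator.reset()

    for _ in range(total):
        inputs = data_loader.next()
        outputs = model(inputs)
        evaluator.process(outputs)    # signature assumed for this sketch

    return evaluator.evaluate()
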
@@ -150,7 +151,6 @@ def inference_on_dataset(model, data_loader, evaluator):
            total_compute_time_str, total_compute_time / (total - num_warmup)
        )
    )

    results = evaluator.evaluate()
    # An evaluator may return None when not in main process.
    # Replace it by an empty dict instead to make it easier for downstream code to handle

@@ -112,6 +112,7 @@ def eval_market1501(distmat, q_pids, g_pids, q_camids, g_camids, max_rank):
     # compute cmc curve for each query
     all_cmc = []
     all_AP = []
+    all_INP = []
     num_valid_q = 0.  # number of valid query

     for q_idx in range(num_q):

@@ -125,13 +126,18 @@ def eval_market1501(distmat, q_pids, g_pids, q_camids, g_camids, max_rank):
         keep = np.invert(remove)

         # compute cmc curve
-        raw_cmc = matches[q_idx][
-            keep]  # binary vector, positions with value 1 are correct matches
+        raw_cmc = matches[q_idx][keep]  # binary vector, positions with value 1 are correct matches
         if not np.any(raw_cmc):
             # this condition is true when query identity does not appear in gallery
             continue

         cmc = raw_cmc.cumsum()
+
+        pos_idx = np.where(raw_cmc == 1)
+        max_pos_idx = np.max(pos_idx)
+        inp = cmc[max_pos_idx] / (max_pos_idx + 1.0)
+        all_INP.append(inp)
+
         cmc[cmc > 1] = 1

         all_cmc.append(cmc[:max_rank])

@@ -151,8 +157,9 @@ def eval_market1501(distmat, q_pids, g_pids, q_camids, g_camids, max_rank):
     all_cmc = np.asarray(all_cmc).astype(np.float32)
     all_cmc = all_cmc.sum(0) / num_valid_q
     mAP = np.mean(all_AP)
+    mINP = np.mean(all_INP)

-    return all_cmc, mAP
+    return all_cmc, mAP, mINP


 def evaluate_py(

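The new per-query lines compute the inverse negative penalty (INP): with raw_cmc the binary vector of correct matches in ranked order, max_pos_idx is the position of the hardest (last) correct match, and INP = cmc[max_pos_idx] / (max_pos_idx + 1), i.e. the fraction of gallery items up to that position that are correct; mINP averages this over valid queries. A small worked example using the same NumPy calls (the toy match vector is made up):

import numpy as np

# Toy ranked match vector for one query: correct matches at ranks 2 and 4
# (0-indexed positions 1 and 3).
raw_cmc = np.array([0, 1, 0, 1, 0])

cmc = raw_cmc.cumsum()                        # [0, 1, 1, 2, 2]
pos_idx = np.where(raw_cmc == 1)              # indices of the correct matches
max_pos_idx = np.max(pos_idx)                 # 3: position of the hardest correct match
inp = cmc[max_pos_idx] / (max_pos_idx + 1.0)
print(inp)  # 2 / 4 = 0.5; mINP is the mean of this value over all valid queries
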
@@ -7,14 +7,16 @@ import copy
 from collections import OrderedDict

 import torch
+import numpy as np

 from .evaluator import DatasetEvaluator
 from .rank import evaluate_rank


 class ReidEvaluator(DatasetEvaluator):
-    def __init__(self, cfg, num_query):
+    def __init__(self, cfg, num_query, output_dir=None):
         self._num_query = num_query
+        self._output_dir = output_dir

         self.features = []
         self.pids = []

@@ -35,20 +37,21 @@ class ReidEvaluator(DatasetEvaluator):

         # query feature, person ids and camera ids
         query_features = features[:self._num_query]
-        query_pids = self.pids[:self._num_query]
-        query_camids = self.camids[:self._num_query]
+        query_pids = np.asarray(self.pids[:self._num_query])
+        query_camids = np.asarray(self.camids[:self._num_query])

         # gallery features, person ids and camera ids
         gallery_features = features[self._num_query:]
-        gallery_pids = self.pids[self._num_query:]
-        gallery_camids = self.camids[self._num_query:]
+        gallery_pids = np.asarray(self.pids[self._num_query:])
+        gallery_camids = np.asarray(self.camids[self._num_query:])

         self._results = OrderedDict()

         cos_dist = torch.mm(query_features, gallery_features.t()).numpy()
-        cmc, mAP = evaluate_rank(-cos_dist, query_pids, gallery_pids, query_camids, gallery_camids)
+        cmc, mAP = evaluate_rank(1 - cos_dist, query_pids, gallery_pids, query_camids, gallery_camids)
         for r in [1, 5, 10]:
             self._results['Rank-{}'.format(r)] = cmc[r - 1]
         self._results['mAP'] = mAP
+        self._results['mINP'] = 0

         return copy.deepcopy(self._results)

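On the distance change: with L2-normalized features, torch.mm(query_features, gallery_features.t()) gives cosine similarities, so 1 - cos_sim is the usual cosine distance in [0, 2], whereas the previous -cos_sim was only a negated similarity. Since the two differ by a constant shift, the induced ranking, and hence CMC/mAP, should be unchanged; the new form is simply the conventional distance. A quick check of that equivalence on made-up values:

import numpy as np

cos_sim = np.array([[0.9, 0.2, 0.5],
                    [0.1, 0.8, 0.3]])        # toy query x gallery cosine similarities

order_neg = np.argsort(-cos_sim, axis=1)      # ranking from negated similarity
order_dist = np.argsort(1 - cos_sim, axis=1)  # ranking from cosine distance
assert (order_neg == order_dist).all()        # identical gallery orderings per query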