Fix data prefetcher minor bug

This commit is contained in:
L1aoXingyu 2020-02-27 12:16:57 +08:00
parent 12957f66aa
commit b020c7f0ae
7 changed files with 28 additions and 14 deletions

View File

@ -64,7 +64,7 @@ def build_reid_test_loader(cfg, dataset_name):
test_set,
batch_sampler=batch_sampler,
num_workers=num_workers,
collate_fn=fast_batch_collator, pin_memory=True)
collate_fn=fast_batch_collator)
return data_prefetcher(cfg, test_loader), len(dataset.query)

View File

@ -68,6 +68,10 @@ class data_prefetcher():
self.preload()
def reset(self):
self.loader_iter = iter(self.loader)
self.preload()
def preload(self):
try:
self.next_inputs = next(self.loader_iter)

View File

@ -6,3 +6,4 @@
from .build import build_transforms
from .transforms import *

View File

@ -15,7 +15,6 @@ from collections import OrderedDict
import numpy as np
import torch
# from fvcore.nn.precise_bn import get_bn_modules
from torch.nn import DataParallel
from ..data import build_reid_test_loader, build_reid_train_loader
@ -382,8 +381,8 @@ class DefaultTrainer(SimpleTrainer):
return build_reid_test_loader(cfg, dataset_name)
@classmethod
def build_evaluator(cls, cfg, num_query):
return ReidEvaluator(cfg, num_query)
def build_evaluator(cls, cfg, num_query, output_dir=None):
return ReidEvaluator(cfg, num_query, output_dir)
@classmethod
def test(cls, cfg, model, evaluators=None):

View File

@ -100,6 +100,7 @@ def inference_on_dataset(model, data_loader, evaluator):
logger.info("Start inference on {} images".format(len(data_loader.loader.dataset)))
total = len(data_loader.loader) # inference data loader must have a fixed length
data_loader.reset()
evaluator.reset()
num_warmup = min(5, total - 1)
@ -150,7 +151,6 @@ def inference_on_dataset(model, data_loader, evaluator):
total_compute_time_str, total_compute_time / (total - num_warmup)
)
)
results = evaluator.evaluate()
# An evaluator may return None when not in main process.
# Replace it by an empty dict instead to make it easier for downstream code to handle

View File

@ -112,6 +112,7 @@ def eval_market1501(distmat, q_pids, g_pids, q_camids, g_camids, max_rank):
# compute cmc curve for each query
all_cmc = []
all_AP = []
all_INP = []
num_valid_q = 0. # number of valid query
for q_idx in range(num_q):
@ -125,13 +126,18 @@ def eval_market1501(distmat, q_pids, g_pids, q_camids, g_camids, max_rank):
keep = np.invert(remove)
# compute cmc curve
raw_cmc = matches[q_idx][
keep] # binary vector, positions with value 1 are correct matches
raw_cmc = matches[q_idx][keep] # binary vector, positions with value 1 are correct matches
if not np.any(raw_cmc):
# this condition is true when query identity does not appear in gallery
continue
cmc = raw_cmc.cumsum()
pos_idx = np.where(raw_cmc == 1)
max_pos_idx = np.max(pos_idx)
inp = cmc[max_pos_idx] / (max_pos_idx + 1.0)
all_INP.append(inp)
cmc[cmc > 1] = 1
all_cmc.append(cmc[:max_rank])
@ -151,8 +157,9 @@ def eval_market1501(distmat, q_pids, g_pids, q_camids, g_camids, max_rank):
all_cmc = np.asarray(all_cmc).astype(np.float32)
all_cmc = all_cmc.sum(0) / num_valid_q
mAP = np.mean(all_AP)
mINP = np.mean(all_INP)
return all_cmc, mAP
return all_cmc, mAP, mINP
def evaluate_py(

View File

@ -7,14 +7,16 @@ import copy
from collections import OrderedDict
import torch
import numpy as np
from .evaluator import DatasetEvaluator
from .rank import evaluate_rank
class ReidEvaluator(DatasetEvaluator):
def __init__(self, cfg, num_query):
def __init__(self, cfg, num_query, output_dir=None):
self._num_query = num_query
self._output_dir = output_dir
self.features = []
self.pids = []
@ -35,20 +37,21 @@ class ReidEvaluator(DatasetEvaluator):
# query feature, person ids and camera ids
query_features = features[:self._num_query]
query_pids = self.pids[:self._num_query]
query_camids = self.camids[:self._num_query]
query_pids = np.asarray(self.pids[:self._num_query])
query_camids = np.asarray(self.camids[:self._num_query])
# gallery features, person ids and camera ids
gallery_features = features[self._num_query:]
gallery_pids = self.pids[self._num_query:]
gallery_camids = self.camids[self._num_query:]
gallery_pids = np.asarray(self.pids[self._num_query:])
gallery_camids = np.asarray(self.camids[self._num_query:])
self._results = OrderedDict()
cos_dist = torch.mm(query_features, gallery_features.t()).numpy()
cmc, mAP = evaluate_rank(-cos_dist, query_pids, gallery_pids, query_camids, gallery_camids)
cmc, mAP = evaluate_rank(1 - cos_dist, query_pids, gallery_pids, query_camids, gallery_camids)
for r in [1, 5, 10]:
self._results['Rank-{}'.format(r)] = cmc[r - 1]
self._results['mAP'] = mAP
self._results['mINP'] = 0
return copy.deepcopy(self._results)