# encoding: utf-8
"""
@author: sherlock
@contact: sherlockliao01@gmail.com
"""

import logging

import torch
import torch.nn as nn
from ignite.engine import Engine

from utils.reid_metric import R1_mAP, R1_mAP_reranking


def create_supervised_evaluator(model, metrics, device=None):
    """
    Factory function for creating an evaluator for supervised models.

    Args:
        model (`torch.nn.Module`): the model to evaluate
        metrics (dict of str - :class:`ignite.metrics.Metric`): a map of
            metric names to Metrics
        device (str, optional): device type specification (default: None).
            Applies to both model and batches.

    Returns:
        Engine: an evaluator engine with supervised inference function
    """
    if device:
        # Wrap with DataParallel when more than one GPU is visible so
        # evaluation batches are split across devices.
        if torch.cuda.device_count() > 1:
            model = nn.DataParallel(model)
        model.to(device)

    def _inference(engine, batch):
        # One evaluation step: each batch carries images plus the person IDs
        # and camera IDs that the ReID metrics need.
        model.eval()
        with torch.no_grad():
            data, pids, camids = batch
            data = data.to(device) if torch.cuda.device_count() >= 1 else data
            feat = model(data)
            return feat, pids, camids

    engine = Engine(_inference)

    # Attach each metric so it accumulates over the run and is exposed
    # via engine.state.metrics under the given name.
    for name, metric in metrics.items():
        metric.attach(engine, name)

    return engine
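
# Minimal usage sketch (hypothetical values; assumes a torch.nn.Module
# `model` and a validation DataLoader `val_loader` built elsewhere in the
# project, with 100 query images at the front of the validation set):
#
#     evaluator = create_supervised_evaluator(
#         model,
#         metrics={'r1_mAP': R1_mAP(100, max_rank=50, feat_norm='yes')},
#         device='cuda')
#     evaluator.run(val_loader)
#     cmc, mAP = evaluator.state.metrics['r1_mAP']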


def inference(cfg, model, val_loader, num_query):
    device = cfg.MODEL.DEVICE

    logger = logging.getLogger("reid_baseline.inference")
    logger.info("Enter inferencing")
    if cfg.TEST.RE_RANKING == 'no':
        logger.info("Create evaluator")
        evaluator = create_supervised_evaluator(
            model,
            metrics={'r1_mAP': R1_mAP(num_query, max_rank=50,
                                      feat_norm=cfg.TEST.FEAT_NORM)},
            device=device)
    elif cfg.TEST.RE_RANKING == 'yes':
        logger.info("Create evaluator for reranking")
        evaluator = create_supervised_evaluator(
            model,
            metrics={'r1_mAP': R1_mAP_reranking(num_query, max_rank=50,
                                                feat_norm=cfg.TEST.FEAT_NORM)},
            device=device)
    else:
        # Fail fast on a bad config; otherwise `evaluator` would be
        # undefined at the .run() call below.
        raise ValueError("Unsupported re_ranking config. Only support for "
                         "no or yes, but got {}.".format(cfg.TEST.RE_RANKING))

    evaluator.run(val_loader)
    cmc, mAP = evaluator.state.metrics['r1_mAP']
    logger.info('Validation Results')
    logger.info("mAP: {:.1%}".format(mAP))
    for r in [1, 5, 10]:
        logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1]))
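
# Typical call site (a sketch; `cfg`, `model`, `val_loader`, and `num_query`
# are assumed to come from the project's config and data-loading utilities).
# The config fields consumed here are cfg.MODEL.DEVICE, cfg.TEST.RE_RANKING,
# and cfg.TEST.FEAT_NORM:
#
#     inference(cfg, model, val_loader, num_query)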