2019-03-25 01:22:43 +08:00
<!DOCTYPE html>
<!-- [if IE 8]><html class="no - js lt - ie9" lang="en" > <![endif] -->
<!-- [if gt IE 8]><! --> < html class = "no-js" lang = "en" > <!-- <![endif] -->
< head >
< meta charset = "utf-8" >
< meta name = "viewport" content = "width=device-width, initial-scale=1.0" >
2019-05-22 23:18:39 +08:00
< title > torchreid.metrics.rank — torchreid 0.7.6 documentation< / title >
2019-03-25 01:22:43 +08:00
< script type = "text/javascript" src = "../../../_static/js/modernizr.min.js" > < / script >
< script type = "text/javascript" id = "documentation_options" data-url_root = "../../../" src = "../../../_static/documentation_options.js" > < / script >
< script type = "text/javascript" src = "../../../_static/jquery.js" > < / script >
< script type = "text/javascript" src = "../../../_static/underscore.js" > < / script >
< script type = "text/javascript" src = "../../../_static/doctools.js" > < / script >
< script type = "text/javascript" src = "../../../_static/language_data.js" > < / script >
< script async = "async" type = "text/javascript" src = "https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML" > < / script >
< script type = "text/javascript" src = "../../../_static/js/theme.js" > < / script >
< link rel = "stylesheet" href = "../../../_static/css/theme.css" type = "text/css" / >
< link rel = "stylesheet" href = "../../../_static/pygments.css" type = "text/css" / >
< link rel = "index" title = "Index" href = "../../../genindex.html" / >
< link rel = "search" title = "Search" href = "../../../search.html" / >
< / head >
< body class = "wy-body-for-nav" >
< div class = "wy-grid-for-nav" >
< nav data-toggle = "wy-nav-shift" class = "wy-nav-side" >
< div class = "wy-side-scroll" >
< div class = "wy-side-nav-search" >
< a href = "../../../index.html" class = "icon icon-home" > torchreid
< / a >
< div class = "version" >
2019-05-22 23:18:39 +08:00
0.7.6
2019-03-25 01:22:43 +08:00
< / div >
< div role = "search" >
< form id = "rtd-search-form" class = "wy-form" action = "../../../search.html" method = "get" >
< input type = "text" name = "q" placeholder = "Search docs" / >
< input type = "hidden" name = "check_keywords" value = "yes" / >
< input type = "hidden" name = "area" value = "default" / >
< / form >
< / div >
< / div >
< div class = "wy-menu wy-menu-vertical" data-spy = "affix" role = "navigation" aria-label = "main navigation" >
< ul >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../../user_guide.html" > How-to< / a > < / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../../datasets.html" > Datasets< / a > < / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../../evaluation.html" > Evaluation< / a > < / li >
< / ul >
< p class = "caption" > < span class = "caption-text" > Package Reference< / span > < / p >
< ul >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../../pkg/data.html" > torchreid.data< / a > < / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../../pkg/engine.html" > torchreid.engine< / a > < / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../../pkg/losses.html" > torchreid.losses< / a > < / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../../pkg/metrics.html" > torchreid.metrics< / a > < / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../../pkg/models.html" > torchreid.models< / a > < / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../../pkg/optim.html" > torchreid.optim< / a > < / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../../pkg/utils.html" > torchreid.utils< / a > < / li >
< / ul >
< p class = "caption" > < span class = "caption-text" > Resources< / span > < / p >
< ul >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../../AWESOME_REID.html" > Awesome-ReID< / a > < / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../../MODEL_ZOO.html" > Model Zoo< / a > < / li >
< / ul >
< / div >
< / div >
< / nav >
< section data-toggle = "wy-nav-shift" class = "wy-nav-content-wrap" >
< nav class = "wy-nav-top" aria-label = "top navigation" >
< i data-toggle = "wy-nav-top" class = "fa fa-bars" > < / i >
< a href = "../../../index.html" > torchreid< / a >
< / nav >
< div class = "wy-nav-content" >
< div class = "rst-content" >
< div role = "navigation" aria-label = "breadcrumbs navigation" >
< ul class = "wy-breadcrumbs" >
< li > < a href = "../../../index.html" > Docs< / a > » < / li >
< li > < a href = "../../index.html" > Module code< / a > » < / li >
< li > torchreid.metrics.rank< / li >
< li class = "wy-breadcrumbs-aside" >
< / li >
< / ul >
< hr / >
< / div >
< div role = "main" class = "document" itemscope = "itemscope" itemtype = "http://schema.org/Article" >
< div itemprop = "articleBody" >
< h1 > Source code for torchreid.metrics.rank< / h1 > < div class = "highlight" > < pre >
< span > < / span > < span class = "kn" > from< / span > < span class = "nn" > __future__< / span > < span class = "k" > import< / span > < span class = "n" > absolute_import< / span >
< span class = "kn" > from< / span > < span class = "nn" > __future__< / span > < span class = "k" > import< / span > < span class = "n" > print_function< / span >
< span class = "kn" > from< / span > < span class = "nn" > __future__< / span > < span class = "k" > import< / span > < span class = "n" > division< / span >
< span class = "kn" > import< / span > < span class = "nn" > numpy< / span > < span class = "k" > as< / span > < span class = "nn" > np< / span >
< span class = "kn" > import< / span > < span class = "nn" > copy< / span >
< span class = "kn" > from< / span > < span class = "nn" > collections< / span > < span class = "k" > import< / span > < span class = "n" > defaultdict< / span >
< span class = "kn" > import< / span > < span class = "nn" > sys< / span >
< span class = "kn" > import< / span > < span class = "nn" > warnings< / span >
< span class = "k" > try< / span > < span class = "p" > :< / span >
< span class = "kn" > from< / span > < span class = "nn" > torchreid.metrics.rank_cylib.rank_cy< / span > < span class = "k" > import< / span > < span class = "n" > evaluate_cy< / span >
< span class = "n" > IS_CYTHON_AVAI< / span > < span class = "o" > =< / span > < span class = "kc" > True< / span >
< span class = "k" > except< / span > < span class = "ne" > ImportError< / span > < span class = "p" > :< / span >
< span class = "n" > IS_CYTHON_AVAI< / span > < span class = "o" > =< / span > < span class = "kc" > False< / span >
< span class = "n" > warnings< / span > < span class = "o" > .< / span > < span class = "n" > warn< / span > < span class = "p" > (< / span >
< span class = "s1" > ' Cython evaluation (very fast so highly recommended) is ' < / span >
< span class = "s1" > ' unavailable, now use python evaluation.' < / span >
< span class = "p" > )< / span >
< span class = "k" > def< / span > < span class = "nf" > eval_cuhk03< / span > < span class = "p" > (< / span > < span class = "n" > distmat< / span > < span class = "p" > ,< / span > < span class = "n" > q_pids< / span > < span class = "p" > ,< / span > < span class = "n" > g_pids< / span > < span class = "p" > ,< / span > < span class = "n" > q_camids< / span > < span class = "p" > ,< / span > < span class = "n" > g_camids< / span > < span class = "p" > ,< / span > < span class = "n" > max_rank< / span > < span class = "p" > ):< / span >
< span class = "sd" > " " " Evaluation with cuhk03 metric< / span >
< span class = "sd" > Key: one image for each gallery identity is randomly sampled for each query identity.< / span >
< span class = "sd" > Random sampling is performed num_repeats times.< / span >
< span class = "sd" > " " " < / span >
< span class = "n" > num_repeats< / span > < span class = "o" > =< / span > < span class = "mi" > 10< / span >
< span class = "n" > num_q< / span > < span class = "p" > ,< / span > < span class = "n" > num_g< / span > < span class = "o" > =< / span > < span class = "n" > distmat< / span > < span class = "o" > .< / span > < span class = "n" > shape< / span >
< span class = "k" > if< / span > < span class = "n" > num_g< / span > < span class = "o" > < < / span > < span class = "n" > max_rank< / span > < span class = "p" > :< / span >
< span class = "n" > max_rank< / span > < span class = "o" > =< / span > < span class = "n" > num_g< / span >
< span class = "nb" > print< / span > < span class = "p" > (< / span > < span class = "s1" > ' Note: number of gallery samples is quite small, got < / span > < span class = "si" > {}< / span > < span class = "s1" > ' < / span > < span class = "o" > .< / span > < span class = "n" > format< / span > < span class = "p" > (< / span > < span class = "n" > num_g< / span > < span class = "p" > ))< / span >
< span class = "n" > indices< / span > < span class = "o" > =< / span > < span class = "n" > np< / span > < span class = "o" > .< / span > < span class = "n" > argsort< / span > < span class = "p" > (< / span > < span class = "n" > distmat< / span > < span class = "p" > ,< / span > < span class = "n" > axis< / span > < span class = "o" > =< / span > < span class = "mi" > 1< / span > < span class = "p" > )< / span >
< span class = "n" > matches< / span > < span class = "o" > =< / span > < span class = "p" > (< / span > < span class = "n" > g_pids< / span > < span class = "p" > [< / span > < span class = "n" > indices< / span > < span class = "p" > ]< / span > < span class = "o" > ==< / span > < span class = "n" > q_pids< / span > < span class = "p" > [:,< / span > < span class = "n" > np< / span > < span class = "o" > .< / span > < span class = "n" > newaxis< / span > < span class = "p" > ])< / span > < span class = "o" > .< / span > < span class = "n" > astype< / span > < span class = "p" > (< / span > < span class = "n" > np< / span > < span class = "o" > .< / span > < span class = "n" > int32< / span > < span class = "p" > )< / span >
< span class = "c1" > # compute cmc curve for each query< / span >
< span class = "n" > all_cmc< / span > < span class = "o" > =< / span > < span class = "p" > []< / span >
< span class = "n" > all_AP< / span > < span class = "o" > =< / span > < span class = "p" > []< / span >
< span class = "n" > num_valid_q< / span > < span class = "o" > =< / span > < span class = "mf" > 0.< / span > < span class = "c1" > # number of valid query< / span >
< span class = "k" > for< / span > < span class = "n" > q_idx< / span > < span class = "ow" > in< / span > < span class = "nb" > range< / span > < span class = "p" > (< / span > < span class = "n" > num_q< / span > < span class = "p" > ):< / span >
< span class = "c1" > # get query pid and camid< / span >
< span class = "n" > q_pid< / span > < span class = "o" > =< / span > < span class = "n" > q_pids< / span > < span class = "p" > [< / span > < span class = "n" > q_idx< / span > < span class = "p" > ]< / span >
< span class = "n" > q_camid< / span > < span class = "o" > =< / span > < span class = "n" > q_camids< / span > < span class = "p" > [< / span > < span class = "n" > q_idx< / span > < span class = "p" > ]< / span >
< span class = "c1" > # remove gallery samples that have the same pid and camid with query< / span >
< span class = "n" > order< / span > < span class = "o" > =< / span > < span class = "n" > indices< / span > < span class = "p" > [< / span > < span class = "n" > q_idx< / span > < span class = "p" > ]< / span >
< span class = "n" > remove< / span > < span class = "o" > =< / span > < span class = "p" > (< / span > < span class = "n" > g_pids< / span > < span class = "p" > [< / span > < span class = "n" > order< / span > < span class = "p" > ]< / span > < span class = "o" > ==< / span > < span class = "n" > q_pid< / span > < span class = "p" > )< / span > < span class = "o" > & < / span > < span class = "p" > (< / span > < span class = "n" > g_camids< / span > < span class = "p" > [< / span > < span class = "n" > order< / span > < span class = "p" > ]< / span > < span class = "o" > ==< / span > < span class = "n" > q_camid< / span > < span class = "p" > )< / span >
< span class = "n" > keep< / span > < span class = "o" > =< / span > < span class = "n" > np< / span > < span class = "o" > .< / span > < span class = "n" > invert< / span > < span class = "p" > (< / span > < span class = "n" > remove< / span > < span class = "p" > )< / span >
< span class = "c1" > # compute cmc curve< / span >
< span class = "n" > raw_cmc< / span > < span class = "o" > =< / span > < span class = "n" > matches< / span > < span class = "p" > [< / span > < span class = "n" > q_idx< / span > < span class = "p" > ][< / span > < span class = "n" > keep< / span > < span class = "p" > ]< / span > < span class = "c1" > # binary vector, positions with value 1 are correct matches< / span >
< span class = "k" > if< / span > < span class = "ow" > not< / span > < span class = "n" > np< / span > < span class = "o" > .< / span > < span class = "n" > any< / span > < span class = "p" > (< / span > < span class = "n" > raw_cmc< / span > < span class = "p" > ):< / span >
< span class = "c1" > # this condition is true when query identity does not appear in gallery< / span >
< span class = "k" > continue< / span >
< span class = "n" > kept_g_pids< / span > < span class = "o" > =< / span > < span class = "n" > g_pids< / span > < span class = "p" > [< / span > < span class = "n" > order< / span > < span class = "p" > ][< / span > < span class = "n" > keep< / span > < span class = "p" > ]< / span >
< span class = "n" > g_pids_dict< / span > < span class = "o" > =< / span > < span class = "n" > defaultdict< / span > < span class = "p" > (< / span > < span class = "nb" > list< / span > < span class = "p" > )< / span >
< span class = "k" > for< / span > < span class = "n" > idx< / span > < span class = "p" > ,< / span > < span class = "n" > pid< / span > < span class = "ow" > in< / span > < span class = "nb" > enumerate< / span > < span class = "p" > (< / span > < span class = "n" > kept_g_pids< / span > < span class = "p" > ):< / span >
< span class = "n" > g_pids_dict< / span > < span class = "p" > [< / span > < span class = "n" > pid< / span > < span class = "p" > ]< / span > < span class = "o" > .< / span > < span class = "n" > append< / span > < span class = "p" > (< / span > < span class = "n" > idx< / span > < span class = "p" > )< / span >
< span class = "n" > cmc< / span > < span class = "o" > =< / span > < span class = "mf" > 0.< / span >
< span class = "k" > for< / span > < span class = "n" > repeat_idx< / span > < span class = "ow" > in< / span > < span class = "nb" > range< / span > < span class = "p" > (< / span > < span class = "n" > num_repeats< / span > < span class = "p" > ):< / span >
< span class = "n" > mask< / span > < span class = "o" > =< / span > < span class = "n" > np< / span > < span class = "o" > .< / span > < span class = "n" > zeros< / span > < span class = "p" > (< / span > < span class = "nb" > len< / span > < span class = "p" > (< / span > < span class = "n" > raw_cmc< / span > < span class = "p" > ),< / span > < span class = "n" > dtype< / span > < span class = "o" > =< / span > < span class = "n" > np< / span > < span class = "o" > .< / span > < span class = "n" > bool< / span > < span class = "p" > )< / span >
< span class = "k" > for< / span > < span class = "n" > _< / span > < span class = "p" > ,< / span > < span class = "n" > idxs< / span > < span class = "ow" > in< / span > < span class = "n" > g_pids_dict< / span > < span class = "o" > .< / span > < span class = "n" > items< / span > < span class = "p" > ():< / span >
< span class = "c1" > # randomly sample one image for each gallery person< / span >
< span class = "n" > rnd_idx< / span > < span class = "o" > =< / span > < span class = "n" > np< / span > < span class = "o" > .< / span > < span class = "n" > random< / span > < span class = "o" > .< / span > < span class = "n" > choice< / span > < span class = "p" > (< / span > < span class = "n" > idxs< / span > < span class = "p" > )< / span >
< span class = "n" > mask< / span > < span class = "p" > [< / span > < span class = "n" > rnd_idx< / span > < span class = "p" > ]< / span > < span class = "o" > =< / span > < span class = "kc" > True< / span >
< span class = "n" > masked_raw_cmc< / span > < span class = "o" > =< / span > < span class = "n" > raw_cmc< / span > < span class = "p" > [< / span > < span class = "n" > mask< / span > < span class = "p" > ]< / span >
< span class = "n" > _cmc< / span > < span class = "o" > =< / span > < span class = "n" > masked_raw_cmc< / span > < span class = "o" > .< / span > < span class = "n" > cumsum< / span > < span class = "p" > ()< / span >
< span class = "n" > _cmc< / span > < span class = "p" > [< / span > < span class = "n" > _cmc< / span > < span class = "o" > > < / span > < span class = "mi" > 1< / span > < span class = "p" > ]< / span > < span class = "o" > =< / span > < span class = "mi" > 1< / span >
< span class = "n" > cmc< / span > < span class = "o" > +=< / span > < span class = "n" > _cmc< / span > < span class = "p" > [:< / span > < span class = "n" > max_rank< / span > < span class = "p" > ]< / span > < span class = "o" > .< / span > < span class = "n" > astype< / span > < span class = "p" > (< / span > < span class = "n" > np< / span > < span class = "o" > .< / span > < span class = "n" > float32< / span > < span class = "p" > )< / span >
< span class = "n" > cmc< / span > < span class = "o" > /=< / span > < span class = "n" > num_repeats< / span >
< span class = "n" > all_cmc< / span > < span class = "o" > .< / span > < span class = "n" > append< / span > < span class = "p" > (< / span > < span class = "n" > cmc< / span > < span class = "p" > )< / span >
< span class = "c1" > # compute AP< / span >
< span class = "n" > num_rel< / span > < span class = "o" > =< / span > < span class = "n" > raw_cmc< / span > < span class = "o" > .< / span > < span class = "n" > sum< / span > < span class = "p" > ()< / span >
< span class = "n" > tmp_cmc< / span > < span class = "o" > =< / span > < span class = "n" > raw_cmc< / span > < span class = "o" > .< / span > < span class = "n" > cumsum< / span > < span class = "p" > ()< / span >
< span class = "n" > tmp_cmc< / span > < span class = "o" > =< / span > < span class = "p" > [< / span > < span class = "n" > x< / span > < span class = "o" > /< / span > < span class = "p" > (< / span > < span class = "n" > i< / span > < span class = "o" > +< / span > < span class = "mf" > 1.< / span > < span class = "p" > )< / span > < span class = "k" > for< / span > < span class = "n" > i< / span > < span class = "p" > ,< / span > < span class = "n" > x< / span > < span class = "ow" > in< / span > < span class = "nb" > enumerate< / span > < span class = "p" > (< / span > < span class = "n" > tmp_cmc< / span > < span class = "p" > )]< / span >
< span class = "n" > tmp_cmc< / span > < span class = "o" > =< / span > < span class = "n" > np< / span > < span class = "o" > .< / span > < span class = "n" > asarray< / span > < span class = "p" > (< / span > < span class = "n" > tmp_cmc< / span > < span class = "p" > )< / span > < span class = "o" > *< / span > < span class = "n" > raw_cmc< / span >
< span class = "n" > AP< / span > < span class = "o" > =< / span > < span class = "n" > tmp_cmc< / span > < span class = "o" > .< / span > < span class = "n" > sum< / span > < span class = "p" > ()< / span > < span class = "o" > /< / span > < span class = "n" > num_rel< / span >
< span class = "n" > all_AP< / span > < span class = "o" > .< / span > < span class = "n" > append< / span > < span class = "p" > (< / span > < span class = "n" > AP< / span > < span class = "p" > )< / span >
< span class = "n" > num_valid_q< / span > < span class = "o" > +=< / span > < span class = "mf" > 1.< / span >
< span class = "k" > assert< / span > < span class = "n" > num_valid_q< / span > < span class = "o" > > < / span > < span class = "mi" > 0< / span > < span class = "p" > ,< / span > < span class = "s1" > ' Error: all query identities do not appear in gallery' < / span >
< span class = "n" > all_cmc< / span > < span class = "o" > =< / span > < span class = "n" > np< / span > < span class = "o" > .< / span > < span class = "n" > asarray< / span > < span class = "p" > (< / span > < span class = "n" > all_cmc< / span > < span class = "p" > )< / span > < span class = "o" > .< / span > < span class = "n" > astype< / span > < span class = "p" > (< / span > < span class = "n" > np< / span > < span class = "o" > .< / span > < span class = "n" > float32< / span > < span class = "p" > )< / span >
< span class = "n" > all_cmc< / span > < span class = "o" > =< / span > < span class = "n" > all_cmc< / span > < span class = "o" > .< / span > < span class = "n" > sum< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > )< / span > < span class = "o" > /< / span > < span class = "n" > num_valid_q< / span >
< span class = "n" > mAP< / span > < span class = "o" > =< / span > < span class = "n" > np< / span > < span class = "o" > .< / span > < span class = "n" > mean< / span > < span class = "p" > (< / span > < span class = "n" > all_AP< / span > < span class = "p" > )< / span >
< span class = "k" > return< / span > < span class = "n" > all_cmc< / span > < span class = "p" > ,< / span > < span class = "n" > mAP< / span >
< span class = "k" > def< / span > < span class = "nf" > eval_market1501< / span > < span class = "p" > (< / span > < span class = "n" > distmat< / span > < span class = "p" > ,< / span > < span class = "n" > q_pids< / span > < span class = "p" > ,< / span > < span class = "n" > g_pids< / span > < span class = "p" > ,< / span > < span class = "n" > q_camids< / span > < span class = "p" > ,< / span > < span class = "n" > g_camids< / span > < span class = "p" > ,< / span > < span class = "n" > max_rank< / span > < span class = "p" > ):< / span >
< span class = "sd" > " " " Evaluation with market1501 metric< / span >
< span class = "sd" > Key: for each query identity, its gallery images from the same camera view are discarded.< / span >
< span class = "sd" > " " " < / span >
< span class = "n" > num_q< / span > < span class = "p" > ,< / span > < span class = "n" > num_g< / span > < span class = "o" > =< / span > < span class = "n" > distmat< / span > < span class = "o" > .< / span > < span class = "n" > shape< / span >
< span class = "k" > if< / span > < span class = "n" > num_g< / span > < span class = "o" > < < / span > < span class = "n" > max_rank< / span > < span class = "p" > :< / span >
< span class = "n" > max_rank< / span > < span class = "o" > =< / span > < span class = "n" > num_g< / span >
< span class = "nb" > print< / span > < span class = "p" > (< / span > < span class = "s1" > ' Note: number of gallery samples is quite small, got < / span > < span class = "si" > {}< / span > < span class = "s1" > ' < / span > < span class = "o" > .< / span > < span class = "n" > format< / span > < span class = "p" > (< / span > < span class = "n" > num_g< / span > < span class = "p" > ))< / span >
< span class = "n" > indices< / span > < span class = "o" > =< / span > < span class = "n" > np< / span > < span class = "o" > .< / span > < span class = "n" > argsort< / span > < span class = "p" > (< / span > < span class = "n" > distmat< / span > < span class = "p" > ,< / span > < span class = "n" > axis< / span > < span class = "o" > =< / span > < span class = "mi" > 1< / span > < span class = "p" > )< / span >
< span class = "n" > matches< / span > < span class = "o" > =< / span > < span class = "p" > (< / span > < span class = "n" > g_pids< / span > < span class = "p" > [< / span > < span class = "n" > indices< / span > < span class = "p" > ]< / span > < span class = "o" > ==< / span > < span class = "n" > q_pids< / span > < span class = "p" > [:,< / span > < span class = "n" > np< / span > < span class = "o" > .< / span > < span class = "n" > newaxis< / span > < span class = "p" > ])< / span > < span class = "o" > .< / span > < span class = "n" > astype< / span > < span class = "p" > (< / span > < span class = "n" > np< / span > < span class = "o" > .< / span > < span class = "n" > int32< / span > < span class = "p" > )< / span >
< span class = "c1" > # compute cmc curve for each query< / span >
< span class = "n" > all_cmc< / span > < span class = "o" > =< / span > < span class = "p" > []< / span >
< span class = "n" > all_AP< / span > < span class = "o" > =< / span > < span class = "p" > []< / span >
< span class = "n" > num_valid_q< / span > < span class = "o" > =< / span > < span class = "mf" > 0.< / span > < span class = "c1" > # number of valid query< / span >
< span class = "k" > for< / span > < span class = "n" > q_idx< / span > < span class = "ow" > in< / span > < span class = "nb" > range< / span > < span class = "p" > (< / span > < span class = "n" > num_q< / span > < span class = "p" > ):< / span >
< span class = "c1" > # get query pid and camid< / span >
< span class = "n" > q_pid< / span > < span class = "o" > =< / span > < span class = "n" > q_pids< / span > < span class = "p" > [< / span > < span class = "n" > q_idx< / span > < span class = "p" > ]< / span >
< span class = "n" > q_camid< / span > < span class = "o" > =< / span > < span class = "n" > q_camids< / span > < span class = "p" > [< / span > < span class = "n" > q_idx< / span > < span class = "p" > ]< / span >
< span class = "c1" > # remove gallery samples that have the same pid and camid with query< / span >
< span class = "n" > order< / span > < span class = "o" > =< / span > < span class = "n" > indices< / span > < span class = "p" > [< / span > < span class = "n" > q_idx< / span > < span class = "p" > ]< / span >
< span class = "n" > remove< / span > < span class = "o" > =< / span > < span class = "p" > (< / span > < span class = "n" > g_pids< / span > < span class = "p" > [< / span > < span class = "n" > order< / span > < span class = "p" > ]< / span > < span class = "o" > ==< / span > < span class = "n" > q_pid< / span > < span class = "p" > )< / span > < span class = "o" > & < / span > < span class = "p" > (< / span > < span class = "n" > g_camids< / span > < span class = "p" > [< / span > < span class = "n" > order< / span > < span class = "p" > ]< / span > < span class = "o" > ==< / span > < span class = "n" > q_camid< / span > < span class = "p" > )< / span >
< span class = "n" > keep< / span > < span class = "o" > =< / span > < span class = "n" > np< / span > < span class = "o" > .< / span > < span class = "n" > invert< / span > < span class = "p" > (< / span > < span class = "n" > remove< / span > < span class = "p" > )< / span >
< span class = "c1" > # compute cmc curve< / span >
< span class = "n" > raw_cmc< / span > < span class = "o" > =< / span > < span class = "n" > matches< / span > < span class = "p" > [< / span > < span class = "n" > q_idx< / span > < span class = "p" > ][< / span > < span class = "n" > keep< / span > < span class = "p" > ]< / span > < span class = "c1" > # binary vector, positions with value 1 are correct matches< / span >
< span class = "k" > if< / span > < span class = "ow" > not< / span > < span class = "n" > np< / span > < span class = "o" > .< / span > < span class = "n" > any< / span > < span class = "p" > (< / span > < span class = "n" > raw_cmc< / span > < span class = "p" > ):< / span >
< span class = "c1" > # this condition is true when query identity does not appear in gallery< / span >
< span class = "k" > continue< / span >
< span class = "n" > cmc< / span > < span class = "o" > =< / span > < span class = "n" > raw_cmc< / span > < span class = "o" > .< / span > < span class = "n" > cumsum< / span > < span class = "p" > ()< / span >
< span class = "n" > cmc< / span > < span class = "p" > [< / span > < span class = "n" > cmc< / span > < span class = "o" > > < / span > < span class = "mi" > 1< / span > < span class = "p" > ]< / span > < span class = "o" > =< / span > < span class = "mi" > 1< / span >
< span class = "n" > all_cmc< / span > < span class = "o" > .< / span > < span class = "n" > append< / span > < span class = "p" > (< / span > < span class = "n" > cmc< / span > < span class = "p" > [:< / span > < span class = "n" > max_rank< / span > < span class = "p" > ])< / span >
< span class = "n" > num_valid_q< / span > < span class = "o" > +=< / span > < span class = "mf" > 1.< / span >
< span class = "c1" > # compute average precision< / span >
< span class = "c1" > # reference: https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Average_precision< / span >
< span class = "n" > num_rel< / span > < span class = "o" > =< / span > < span class = "n" > raw_cmc< / span > < span class = "o" > .< / span > < span class = "n" > sum< / span > < span class = "p" > ()< / span >
< span class = "n" > tmp_cmc< / span > < span class = "o" > =< / span > < span class = "n" > raw_cmc< / span > < span class = "o" > .< / span > < span class = "n" > cumsum< / span > < span class = "p" > ()< / span >
< span class = "n" > tmp_cmc< / span > < span class = "o" > =< / span > < span class = "p" > [< / span > < span class = "n" > x< / span > < span class = "o" > /< / span > < span class = "p" > (< / span > < span class = "n" > i< / span > < span class = "o" > +< / span > < span class = "mf" > 1.< / span > < span class = "p" > )< / span > < span class = "k" > for< / span > < span class = "n" > i< / span > < span class = "p" > ,< / span > < span class = "n" > x< / span > < span class = "ow" > in< / span > < span class = "nb" > enumerate< / span > < span class = "p" > (< / span > < span class = "n" > tmp_cmc< / span > < span class = "p" > )]< / span >
< span class = "n" > tmp_cmc< / span > < span class = "o" > =< / span > < span class = "n" > np< / span > < span class = "o" > .< / span > < span class = "n" > asarray< / span > < span class = "p" > (< / span > < span class = "n" > tmp_cmc< / span > < span class = "p" > )< / span > < span class = "o" > *< / span > < span class = "n" > raw_cmc< / span >
< span class = "n" > AP< / span > < span class = "o" > =< / span > < span class = "n" > tmp_cmc< / span > < span class = "o" > .< / span > < span class = "n" > sum< / span > < span class = "p" > ()< / span > < span class = "o" > /< / span > < span class = "n" > num_rel< / span >
< span class = "n" > all_AP< / span > < span class = "o" > .< / span > < span class = "n" > append< / span > < span class = "p" > (< / span > < span class = "n" > AP< / span > < span class = "p" > )< / span >
< span class = "k" > assert< / span > < span class = "n" > num_valid_q< / span > < span class = "o" > > < / span > < span class = "mi" > 0< / span > < span class = "p" > ,< / span > < span class = "s1" > ' Error: all query identities do not appear in gallery' < / span >
< span class = "n" > all_cmc< / span > < span class = "o" > =< / span > < span class = "n" > np< / span > < span class = "o" > .< / span > < span class = "n" > asarray< / span > < span class = "p" > (< / span > < span class = "n" > all_cmc< / span > < span class = "p" > )< / span > < span class = "o" > .< / span > < span class = "n" > astype< / span > < span class = "p" > (< / span > < span class = "n" > np< / span > < span class = "o" > .< / span > < span class = "n" > float32< / span > < span class = "p" > )< / span >
< span class = "n" > all_cmc< / span > < span class = "o" > =< / span > < span class = "n" > all_cmc< / span > < span class = "o" > .< / span > < span class = "n" > sum< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > )< / span > < span class = "o" > /< / span > < span class = "n" > num_valid_q< / span >
< span class = "n" > mAP< / span > < span class = "o" > =< / span > < span class = "n" > np< / span > < span class = "o" > .< / span > < span class = "n" > mean< / span > < span class = "p" > (< / span > < span class = "n" > all_AP< / span > < span class = "p" > )< / span >
< span class = "k" > return< / span > < span class = "n" > all_cmc< / span > < span class = "p" > ,< / span > < span class = "n" > mAP< / span >
<!-- Generated by Sphinx viewcode: Pygments-highlighted listing of evaluate_py
from torchreid.metrics.rank (do not hand-edit; change the Python module and rebuild).
It returns eval_cuhk03(...) when use_metric_cuhk03 is truthy and eval_market1501(...)
otherwise, forwarding distmat, q_pids, g_pids, q_camids, g_camids and max_rank unchanged.
The comment is placed inline so no extra line is rendered inside the pre block. -->< span class = "k" > def< / span > < span class = "nf" > evaluate_py< / span > < span class = "p" > (< / span > < span class = "n" > distmat< / span > < span class = "p" > ,< / span > < span class = "n" > q_pids< / span > < span class = "p" > ,< / span > < span class = "n" > g_pids< / span > < span class = "p" > ,< / span > < span class = "n" > q_camids< / span > < span class = "p" > ,< / span > < span class = "n" > g_camids< / span > < span class = "p" > ,< / span > < span class = "n" > max_rank< / span > < span class = "p" > ,< / span > < span class = "n" > use_metric_cuhk03< / span > < span class = "p" > ):< / span >
< span class = "k" > if< / span > < span class = "n" > use_metric_cuhk03< / span > < span class = "p" > :< / span >
< span class = "k" > return< / span > < span class = "n" > eval_cuhk03< / span > < span class = "p" > (< / span > < span class = "n" > distmat< / span > < span class = "p" > ,< / span > < span class = "n" > q_pids< / span > < span class = "p" > ,< / span > < span class = "n" > g_pids< / span > < span class = "p" > ,< / span > < span class = "n" > q_camids< / span > < span class = "p" > ,< / span > < span class = "n" > g_camids< / span > < span class = "p" > ,< / span > < span class = "n" > max_rank< / span > < span class = "p" > )< / span >
< span class = "k" > else< / span > < span class = "p" > :< / span >
< span class = "k" > return< / span > < span class = "n" > eval_market1501< / span > < span class = "p" > (< / span > < span class = "n" > distmat< / span > < span class = "p" > ,< / span > < span class = "n" > q_pids< / span > < span class = "p" > ,< / span > < span class = "n" > g_pids< / span > < span class = "p" > ,< / span > < span class = "n" > q_camids< / span > < span class = "p" > ,< / span > < span class = "n" > g_camids< / span > < span class = "p" > ,< / span > < span class = "n" > max_rank< / span > < span class = "p" > )< / span >
<!-- Generated by Sphinx viewcode: listing of the public evaluate_rank from
torchreid.metrics.rank; the [docs] link points back to its API entry in
pkg/metrics.html (do not hand-edit; regenerate from the Python source).
It returns evaluate_cy(...) when both use_cython and IS_CYTHON_AVAI are truthy,
otherwise evaluate_py(...), passing the same seven arguments positionally.
Defaults shown in the signature: max_rank=50, use_metric_cuhk03=False, use_cython=True.
The comment is placed inline so no extra line is rendered inside the pre block. -->< div class = "viewcode-block" id = "evaluate_rank" > < a class = "viewcode-back" href = "../../../pkg/metrics.html#torchreid.metrics.rank.evaluate_rank" > [docs]< / a > < span class = "k" > def< / span > < span class = "nf" > evaluate_rank< / span > < span class = "p" > (< / span > < span class = "n" > distmat< / span > < span class = "p" > ,< / span > < span class = "n" > q_pids< / span > < span class = "p" > ,< / span > < span class = "n" > g_pids< / span > < span class = "p" > ,< / span > < span class = "n" > q_camids< / span > < span class = "p" > ,< / span > < span class = "n" > g_camids< / span > < span class = "p" > ,< / span > < span class = "n" > max_rank< / span > < span class = "o" > =< / span > < span class = "mi" > 50< / span > < span class = "p" > ,< / span >
< span class = "n" > use_metric_cuhk03< / span > < span class = "o" > =< / span > < span class = "kc" > False< / span > < span class = "p" > ,< / span > < span class = "n" > use_cython< / span > < span class = "o" > =< / span > < span class = "kc" > True< / span > < span class = "p" > ):< / span >
< span class = "sd" > " " " Evaluates CMC rank.< / span >
< span class = "sd" > Args:< / span >
< span class = "sd" > distmat (numpy.ndarray): distance matrix of shape (num_query, num_gallery).< / span >
< span class = "sd" > q_pids (numpy.ndarray): 1-D array containing person identities< / span >
< span class = "sd" > of each query instance.< / span >
< span class = "sd" > g_pids (numpy.ndarray): 1-D array containing person identities< / span >
< span class = "sd" > of each gallery instance.< / span >
< span class = "sd" > q_camids (numpy.ndarray): 1-D array containing camera views under< / span >
< span class = "sd" > which each query instance is captured.< / span >
< span class = "sd" > g_camids (numpy.ndarray): 1-D array containing camera views under< / span >
< span class = "sd" > which each gallery instance is captured.< / span >
< span class = "sd" > max_rank (int, optional): maximum CMC rank to be computed. Default is 50.< / span >
< span class = "sd" > use_metric_cuhk03 (bool, optional): use single-gallery-shot setting for cuhk03.< / span >
< span class = "sd" > Default is False. This should be enabled when using cuhk03 classic split.< / span >
< span class = "sd" > use_cython (bool, optional): use cython code for evaluation. Default is True.< / span >
< span class = "sd" > This is highly recommended as the cython code can speed up the cmc computation< / span >
< span class = "sd" > by more than 10x. This requires Cython to be installed.< / span >
< span class = "sd" > " " " < / span >
< span class = "k" > if< / span > < span class = "n" > use_cython< / span > < span class = "ow" > and< / span > < span class = "n" > IS_CYTHON_AVAI< / span > < span class = "p" > :< / span >
< span class = "k" > return< / span > < span class = "n" > evaluate_cy< / span > < span class = "p" > (< / span > < span class = "n" > distmat< / span > < span class = "p" > ,< / span > < span class = "n" > q_pids< / span > < span class = "p" > ,< / span > < span class = "n" > g_pids< / span > < span class = "p" > ,< / span > < span class = "n" > q_camids< / span > < span class = "p" > ,< / span > < span class = "n" > g_camids< / span > < span class = "p" > ,< / span > < span class = "n" > max_rank< / span > < span class = "p" > ,< / span > < span class = "n" > use_metric_cuhk03< / span > < span class = "p" > )< / span >
< span class = "k" > else< / span > < span class = "p" > :< / span >
< span class = "k" > return< / span > < span class = "n" > evaluate_py< / span > < span class = "p" > (< / span > < span class = "n" > distmat< / span > < span class = "p" > ,< / span > < span class = "n" > q_pids< / span > < span class = "p" > ,< / span > < span class = "n" > g_pids< / span > < span class = "p" > ,< / span > < span class = "n" > q_camids< / span > < span class = "p" > ,< / span > < span class = "n" > g_camids< / span > < span class = "p" > ,< / span > < span class = "n" > max_rank< / span > < span class = "p" > ,< / span > < span class = "n" > use_metric_cuhk03< / span > < span class = "p" > )< / span > < / div >
< / pre > < / div >
< / div >
< / div >
< footer >
<!-- Theme-generated page footer: a thematic break, the copyright notice, and credit links.
The Sphinx credit link points at the https site so this page never links out over plain http. -->
< hr >
< div role = "contentinfo" >
< p >
© Copyright 2019, Kaiyang Zhou
< / p >
< / div >
Built with < a href = "https://www.sphinx-doc.org/" > Sphinx< / a > using a < a href = "https://github.com/rtfd/sphinx_rtd_theme" > theme< / a > provided by < a href = "https://readthedocs.org" > Read the Docs< / a > .
< / footer >
< / div >
< / div >
< / section >
< / div >
< script >
  // Once the DOM has loaded (jQuery(fn) is jQuery's document-ready shorthand),
  // switch on the theme's navigation handling.
  jQuery(function enableThemeNavigation() {
    SphinxRtdTheme.Navigation.enable(true);
  });
< / script >
< / body >
< / html >