2019-03-24 17:22:43 +00:00
<!DOCTYPE html>
<!-- [if IE 8]><html class="no - js lt - ie9" lang="en" > <![endif] -->
<!-- [if gt IE 8]><! --> < html class = "no-js" lang = "en" > <!-- <![endif] -->
< head >
< meta charset = "utf-8" >
< meta name = "viewport" content = "width=device-width, initial-scale=1.0" >
2019-04-18 12:12:17 +01:00
< title > torchreid.engine.engine — torchreid 0.7.3 documentation< / title >
2019-03-24 17:22:43 +00:00
< script type = "text/javascript" src = "../../../_static/js/modernizr.min.js" > < / script >
< script type = "text/javascript" id = "documentation_options" data-url_root = "../../../" src = "../../../_static/documentation_options.js" > < / script >
< script type = "text/javascript" src = "../../../_static/jquery.js" > < / script >
< script type = "text/javascript" src = "../../../_static/underscore.js" > < / script >
< script type = "text/javascript" src = "../../../_static/doctools.js" > < / script >
< script type = "text/javascript" src = "../../../_static/language_data.js" > < / script >
< script async = "async" type = "text/javascript" src = "https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML" > < / script >
< script type = "text/javascript" src = "../../../_static/js/theme.js" > < / script >
< link rel = "stylesheet" href = "../../../_static/css/theme.css" type = "text/css" / >
< link rel = "stylesheet" href = "../../../_static/pygments.css" type = "text/css" / >
< link rel = "index" title = "Index" href = "../../../genindex.html" / >
< link rel = "search" title = "Search" href = "../../../search.html" / >
< / head >
< body class = "wy-body-for-nav" >
< div class = "wy-grid-for-nav" >
< nav data-toggle = "wy-nav-shift" class = "wy-nav-side" >
< div class = "wy-side-scroll" >
< div class = "wy-side-nav-search" >
< a href = "../../../index.html" class = "icon icon-home" > torchreid
< / a >
< div class = "version" >
2019-04-18 12:12:17 +01:00
0.7.3
2019-03-24 17:22:43 +00:00
< / div >
< div role = "search" >
< form id = "rtd-search-form" class = "wy-form" action = "../../../search.html" method = "get" >
< input type = "text" name = "q" placeholder = "Search docs" / >
< input type = "hidden" name = "check_keywords" value = "yes" / >
< input type = "hidden" name = "area" value = "default" / >
< / form >
< / div >
< / div >
< div class = "wy-menu wy-menu-vertical" data-spy = "affix" role = "navigation" aria-label = "main navigation" >
< ul >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../../user_guide.html" > How-to< / a > < / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../../datasets.html" > Datasets< / a > < / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../../evaluation.html" > Evaluation< / a > < / li >
< / ul >
< p class = "caption" > < span class = "caption-text" > Package Reference< / span > < / p >
< ul >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../../pkg/data.html" > torchreid.data< / a > < / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../../pkg/engine.html" > torchreid.engine< / a > < / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../../pkg/losses.html" > torchreid.losses< / a > < / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../../pkg/metrics.html" > torchreid.metrics< / a > < / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../../pkg/models.html" > torchreid.models< / a > < / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../../pkg/optim.html" > torchreid.optim< / a > < / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../../pkg/utils.html" > torchreid.utils< / a > < / li >
< / ul >
< p class = "caption" > < span class = "caption-text" > Resources< / span > < / p >
< ul >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../../AWESOME_REID.html" > Awesome-ReID< / a > < / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../../MODEL_ZOO.html" > Model Zoo< / a > < / li >
< / ul >
< / div >
< / div >
< / nav >
< section data-toggle = "wy-nav-shift" class = "wy-nav-content-wrap" >
< nav class = "wy-nav-top" aria-label = "top navigation" >
< i data-toggle = "wy-nav-top" class = "fa fa-bars" > < / i >
< a href = "../../../index.html" > torchreid< / a >
< / nav >
< div class = "wy-nav-content" >
< div class = "rst-content" >
< div role = "navigation" aria-label = "breadcrumbs navigation" >
< ul class = "wy-breadcrumbs" >
< li > < a href = "../../../index.html" > Docs< / a > » < / li >
< li > < a href = "../../index.html" > Module code< / a > » < / li >
< li > torchreid.engine.engine< / li >
< li class = "wy-breadcrumbs-aside" >
< / li >
< / ul >
< hr / >
< / div >
< div role = "main" class = "document" itemscope = "itemscope" itemtype = "http://schema.org/Article" >
< div itemprop = "articleBody" >
< h1 > Source code for torchreid.engine.engine< / h1 > < div class = "highlight" > < pre >
< span > < / span > < span class = "kn" > from< / span > < span class = "nn" > __future__< / span > < span class = "k" > import< / span > < span class = "n" > absolute_import< / span >
< span class = "kn" > from< / span > < span class = "nn" > __future__< / span > < span class = "k" > import< / span > < span class = "n" > print_function< / span >
< span class = "kn" > from< / span > < span class = "nn" > __future__< / span > < span class = "k" > import< / span > < span class = "n" > division< / span >
< span class = "kn" > import< / span > < span class = "nn" > os< / span >
< span class = "kn" > import< / span > < span class = "nn" > os.path< / span > < span class = "k" > as< / span > < span class = "nn" > osp< / span >
< span class = "kn" > import< / span > < span class = "nn" > time< / span >
< span class = "kn" > import< / span > < span class = "nn" > datetime< / span >
< span class = "kn" > import< / span > < span class = "nn" > numpy< / span > < span class = "k" > as< / span > < span class = "nn" > np< / span >
< span class = "kn" > import< / span > < span class = "nn" > torch< / span >
< span class = "kn" > import< / span > < span class = "nn" > torch.nn< / span > < span class = "k" > as< / span > < span class = "nn" > nn< / span >
< span class = "kn" > import< / span > < span class = "nn" > torchreid< / span >
< span class = "kn" > from< / span > < span class = "nn" > torchreid.utils< / span > < span class = "k" > import< / span > < span class = "n" > AverageMeter< / span > < span class = "p" > ,< / span > < span class = "n" > visualize_ranked_results< / span > < span class = "p" > ,< / span > < span class = "n" > save_checkpoint< / span >
< span class = "kn" > from< / span > < span class = "nn" > torchreid.losses< / span > < span class = "k" > import< / span > < span class = "n" > DeepSupervision< / span >
< span class = "kn" > from< / span > < span class = "nn" > torchreid< / span > < span class = "k" > import< / span > < span class = "n" > metrics< / span >
< div class = "viewcode-block" id = "Engine" > < a class = "viewcode-back" href = "../../../pkg/engine.html#torchreid.engine.engine.Engine" > [docs]< / a > < span class = "k" > class< / span > < span class = "nc" > Engine< / span > < span class = "p" > (< / span > < span class = "nb" > object< / span > < span class = "p" > ):< / span >
< span class = "sd" > " " " A generic base Engine class for both image- and video-reid.< / span >
< span class = "sd" > Args:< / span >
< span class = "sd" > datamanager (DataManager): an instance of ``torchreid.data.ImageDataManager``< / span >
< span class = "sd" > or ``torchreid.data.VideoDataManager``.< / span >
< span class = "sd" > model (nn.Module): model instance.< / span >
< span class = "sd" > optimizer (Optimizer): an Optimizer.< / span >
< span class = "sd" > scheduler (LRScheduler, optional): if None, no learning rate decay will be performed.< / span >
< span class = "sd" > use_cpu (bool, optional): use cpu. Default is False.< / span >
< span class = "sd" > " " " < / span >
< span class = "k" > def< / span > < span class = "nf" > __init__< / span > < span class = "p" > (< / span > < span class = "bp" > self< / span > < span class = "p" > ,< / span > < span class = "n" > datamanager< / span > < span class = "p" > ,< / span > < span class = "n" > model< / span > < span class = "p" > ,< / span > < span class = "n" > optimizer< / span > < span class = "o" > =< / span > < span class = "kc" > None< / span > < span class = "p" > ,< / span > < span class = "n" > scheduler< / span > < span class = "o" > =< / span > < span class = "kc" > None< / span > < span class = "p" > ,< / span > < span class = "n" > use_cpu< / span > < span class = "o" > =< / span > < span class = "kc" > False< / span > < span class = "p" > ):< / span >
< span class = "bp" > self< / span > < span class = "o" > .< / span > < span class = "n" > datamanager< / span > < span class = "o" > =< / span > < span class = "n" > datamanager< / span >
< span class = "bp" > self< / span > < span class = "o" > .< / span > < span class = "n" > model< / span > < span class = "o" > =< / span > < span class = "n" > model< / span >
< span class = "bp" > self< / span > < span class = "o" > .< / span > < span class = "n" > optimizer< / span > < span class = "o" > =< / span > < span class = "n" > optimizer< / span >
< span class = "bp" > self< / span > < span class = "o" > .< / span > < span class = "n" > scheduler< / span > < span class = "o" > =< / span > < span class = "n" > scheduler< / span >
< span class = "bp" > self< / span > < span class = "o" > .< / span > < span class = "n" > use_gpu< / span > < span class = "o" > =< / span > < span class = "p" > (< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > cuda< / span > < span class = "o" > .< / span > < span class = "n" > is_available< / span > < span class = "p" > ()< / span > < span class = "ow" > and< / span > < span class = "ow" > not< / span > < span class = "n" > use_cpu< / span > < span class = "p" > )< / span >
< span class = "c1" > # check attributes< / span >
< span class = "k" > if< / span > < span class = "ow" > not< / span > < span class = "nb" > isinstance< / span > < span class = "p" > (< / span > < span class = "bp" > self< / span > < span class = "o" > .< / span > < span class = "n" > model< / span > < span class = "p" > ,< / span > < span class = "n" > nn< / span > < span class = "o" > .< / span > < span class = "n" > Module< / span > < span class = "p" > ):< / span >
< span class = "k" > raise< / span > < span class = "ne" > TypeError< / span > < span class = "p" > (< / span > < span class = "s1" > ' model must be an instance of nn.Module' < / span > < span class = "p" > )< / span >
< div class = "viewcode-block" id = "Engine.run" > < a class = "viewcode-back" href = "../../../pkg/engine.html#torchreid.engine.engine.Engine.run" > [docs]< / a > < span class = "k" > def< / span > < span class = "nf" > run< / span > < span class = "p" > (< / span > < span class = "bp" > self< / span > < span class = "p" > ,< / span > < span class = "n" > save_dir< / span > < span class = "o" > =< / span > < span class = "s1" > ' log' < / span > < span class = "p" > ,< / span > < span class = "n" > max_epoch< / span > < span class = "o" > =< / span > < span class = "mi" > 0< / span > < span class = "p" > ,< / span > < span class = "n" > start_epoch< / span > < span class = "o" > =< / span > < span class = "mi" > 0< / span > < span class = "p" > ,< / span > < span class = "n" > fixbase_epoch< / span > < span class = "o" > =< / span > < span class = "mi" > 0< / span > < span class = "p" > ,< / span > < span class = "n" > open_layers< / span > < span class = "o" > =< / span > < span class = "kc" > None< / span > < span class = "p" > ,< / span >
< span class = "n" > start_eval< / span > < span class = "o" > =< / span > < span class = "mi" > 0< / span > < span class = "p" > ,< / span > < span class = "n" > eval_freq< / span > < span class = "o" > =-< / span > < span class = "mi" > 1< / span > < span class = "p" > ,< / span > < span class = "n" > test_only< / span > < span class = "o" > =< / span > < span class = "kc" > False< / span > < span class = "p" > ,< / span > < span class = "n" > print_freq< / span > < span class = "o" > =< / span > < span class = "mi" > 10< / span > < span class = "p" > ,< / span >
< span class = "n" > dist_metric< / span > < span class = "o" > =< / span > < span class = "s1" > ' euclidean' < / span > < span class = "p" > ,< / span > < span class = "n" > visrank< / span > < span class = "o" > =< / span > < span class = "kc" > False< / span > < span class = "p" > ,< / span > < span class = "n" > visrank_topk< / span > < span class = "o" > =< / span > < span class = "mi" > 20< / span > < span class = "p" > ,< / span >
< span class = "n" > use_metric_cuhk03< / span > < span class = "o" > =< / span > < span class = "kc" > False< / span > < span class = "p" > ,< / span > < span class = "n" > ranks< / span > < span class = "o" > =< / span > < span class = "p" > [< / span > < span class = "mi" > 1< / span > < span class = "p" > ,< / span > < span class = "mi" > 5< / span > < span class = "p" > ,< / span > < span class = "mi" > 10< / span > < span class = "p" > ,< / span > < span class = "mi" > 20< / span > < span class = "p" > ]):< / span >
< span class = "sd" > " " " A unified pipeline for training and evaluating a model.< / span >
< span class = "sd" > Args:< / span >
< span class = "sd" > save_dir (str): directory to save model.< / span >
< span class = "sd" > max_epoch (int): maximum epoch.< / span >
< span class = "sd" > start_epoch (int, optional): starting epoch. Default is 0.< / span >
< span class = "sd" > fixbase_epoch (int, optional): number of epochs to train ``open_layers`` (new layers)< / span >
< span class = "sd" > while keeping base layers frozen. Default is 0. ``fixbase_epoch`` is not counted< / span >
< span class = "sd" > in ``max_epoch``.< / span >
< span class = "sd" > open_layers (str or list, optional): layers (attribute names) open for training.< / span >
< span class = "sd" > start_eval (int, optional): from which epoch to start evaluation. Default is 0.< / span >
< span class = "sd" > eval_freq (int, optional): evaluation frequency. Default is -1 (meaning evaluation< / span >
< span class = "sd" > is only performed at the end of training).< / span >
< span class = "sd" > test_only (bool, optional): if True, only runs evaluation on test datasets.< / span >
< span class = "sd" > Default is False.< / span >
< span class = "sd" > print_freq (int, optional): print_frequency. Default is 10.< / span >
< span class = "sd" > dist_metric (str, optional): distance metric used to compute distance matrix< / span >
< span class = "sd" > between query and gallery. Default is " euclidean" .< / span >
< span class = "sd" > visrank (bool, optional): visualizes ranked results. Default is False. Visualization< / span >
< span class = "sd" > will be performed every test time, so it is recommended to enable ``visrank`` when< / span >
< span class = "sd" > ``test_only`` is True. The ranked images will be saved to< / span >
< span class = "sd" > " save_dir/ranks-epoch/dataset_name" , e.g. " save_dir/ranks-60/market1501" .< / span >
< span class = "sd" > visrank_topk (int, optional): top-k ranked images to be visualized. Default is 20.< / span >
< span class = "sd" > use_metric_cuhk03 (bool, optional): use single-gallery-shot setting for cuhk03.< / span >
< span class = "sd" > Default is False. This should be enabled when using cuhk03 classic split.< / span >
< span class = "sd" > ranks (list, optional): cmc ranks to be computed. Default is [1, 5, 10, 20].< / span >
< span class = "sd" > " " " < / span >
< span class = "n" > trainloader< / span > < span class = "p" > ,< / span > < span class = "n" > testloader< / span > < span class = "o" > =< / span > < span class = "bp" > self< / span > < span class = "o" > .< / span > < span class = "n" > datamanager< / span > < span class = "o" > .< / span > < span class = "n" > return_dataloaders< / span > < span class = "p" > ()< / span >
< span class = "k" > if< / span > < span class = "n" > test_only< / span > < span class = "p" > :< / span >
< span class = "bp" > self< / span > < span class = "o" > .< / span > < span class = "n" > test< / span > < span class = "p" > (< / span >
< span class = "mi" > 0< / span > < span class = "p" > ,< / span >
< span class = "n" > testloader< / span > < span class = "p" > ,< / span >
< span class = "n" > dist_metric< / span > < span class = "o" > =< / span > < span class = "n" > dist_metric< / span > < span class = "p" > ,< / span >
< span class = "n" > visrank< / span > < span class = "o" > =< / span > < span class = "n" > visrank< / span > < span class = "p" > ,< / span >
< span class = "n" > visrank_topk< / span > < span class = "o" > =< / span > < span class = "n" > visrank_topk< / span > < span class = "p" > ,< / span >
< span class = "n" > save_dir< / span > < span class = "o" > =< / span > < span class = "n" > save_dir< / span > < span class = "p" > ,< / span >
< span class = "n" > use_metric_cuhk03< / span > < span class = "o" > =< / span > < span class = "n" > use_metric_cuhk03< / span > < span class = "p" > ,< / span >
< span class = "n" > ranks< / span > < span class = "o" > =< / span > < span class = "n" > ranks< / span >
< span class = "p" > )< / span >
< span class = "k" > return< / span >
< span class = "n" > time_start< / span > < span class = "o" > =< / span > < span class = "n" > time< / span > < span class = "o" > .< / span > < span class = "n" > time< / span > < span class = "p" > ()< / span >
< span class = "nb" > print< / span > < span class = "p" > (< / span > < span class = "s1" > ' => Start training' < / span > < span class = "p" > )< / span >
< span class = "k" > if< / span > < span class = "n" > fixbase_epoch< / span > < span class = "o" > > < / span > < span class = "mi" > 0< / span > < span class = "ow" > and< / span > < span class = "p" > (< / span > < span class = "n" > open_layers< / span > < span class = "ow" > is< / span > < span class = "ow" > not< / span > < span class = "kc" > None< / span > < span class = "p" > ):< / span >
< span class = "nb" > print< / span > < span class = "p" > (< / span > < span class = "s1" > ' Pretrain open layers (< / span > < span class = "si" > {}< / span > < span class = "s1" > ) for < / span > < span class = "si" > {}< / span > < span class = "s1" > epochs' < / span > < span class = "o" > .< / span > < span class = "n" > format< / span > < span class = "p" > (< / span > < span class = "n" > open_layers< / span > < span class = "p" > ,< / span > < span class = "n" > fixbase_epoch< / span > < span class = "p" > ))< / span >
< span class = "k" > for< / span > < span class = "n" > epoch< / span > < span class = "ow" > in< / span > < span class = "nb" > range< / span > < span class = "p" > (< / span > < span class = "n" > fixbase_epoch< / span > < span class = "p" > ):< / span >
< span class = "bp" > self< / span > < span class = "o" > .< / span > < span class = "n" > train< / span > < span class = "p" > (< / span > < span class = "n" > epoch< / span > < span class = "p" > ,< / span > < span class = "n" > trainloader< / span > < span class = "p" > ,< / span > < span class = "n" > fixbase< / span > < span class = "o" > =< / span > < span class = "kc" > True< / span > < span class = "p" > ,< / span > < span class = "n" > open_layers< / span > < span class = "o" > =< / span > < span class = "n" > open_layers< / span > < span class = "p" > ,< / span >
< span class = "n" > print_freq< / span > < span class = "o" > =< / span > < span class = "n" > print_freq< / span > < span class = "p" > )< / span >
< span class = "nb" > print< / span > < span class = "p" > (< / span > < span class = "s1" > ' Done. From now on all layers are open to train for < / span > < span class = "si" > {}< / span > < span class = "s1" > epochs' < / span > < span class = "o" > .< / span > < span class = "n" > format< / span > < span class = "p" > (< / span > < span class = "n" > max_epoch< / span > < span class = "p" > ))< / span >
< span class = "k" > for< / span > < span class = "n" > epoch< / span > < span class = "ow" > in< / span > < span class = "nb" > range< / span > < span class = "p" > (< / span > < span class = "n" > start_epoch< / span > < span class = "p" > ,< / span > < span class = "n" > max_epoch< / span > < span class = "p" > ):< / span >
< span class = "bp" > self< / span > < span class = "o" > .< / span > < span class = "n" > train< / span > < span class = "p" > (< / span > < span class = "n" > epoch< / span > < span class = "p" > ,< / span > < span class = "n" > trainloader< / span > < span class = "p" > ,< / span > < span class = "n" > print_freq< / span > < span class = "o" > =< / span > < span class = "n" > print_freq< / span > < span class = "p" > )< / span >
< span class = "k" > if< / span > < span class = "p" > (< / span > < span class = "n" > epoch< / span > < span class = "o" > +< / span > < span class = "mi" > 1< / span > < span class = "p" > )< / span > < span class = "o" > > < / span > < span class = "n" > start_eval< / span > < span class = "ow" > and< / span > < span class = "n" > eval_freq< / span > < span class = "o" > > < / span > < span class = "mi" > 0< / span > < span class = "ow" > and< / span > < span class = "p" > (< / span > < span class = "n" > epoch< / span > < span class = "o" > +< / span > < span class = "mi" > 1< / span > < span class = "p" > )< / span > < span class = "o" > %< / span > < span class = "n" > eval_freq< / span > < span class = "o" > ==< / span > < span class = "mi" > 0< / span > < span class = "ow" > and< / span > < span class = "p" > (< / span > < span class = "n" > epoch< / span > < span class = "o" > +< / span > < span class = "mi" > 1< / span > < span class = "p" > )< / span > < span class = "o" > !=< / span > < span class = "n" > max_epoch< / span > < span class = "p" > :< / span >
< span class = "n" > rank1< / span > < span class = "o" > =< / span > < span class = "bp" > self< / span > < span class = "o" > .< / span > < span class = "n" > test< / span > < span class = "p" > (< / span >
< span class = "n" > epoch< / span > < span class = "p" > ,< / span >
< span class = "n" > testloader< / span > < span class = "p" > ,< / span >
< span class = "n" > dist_metric< / span > < span class = "o" > =< / span > < span class = "n" > dist_metric< / span > < span class = "p" > ,< / span >
< span class = "n" > visrank< / span > < span class = "o" > =< / span > < span class = "n" > visrank< / span > < span class = "p" > ,< / span >
< span class = "n" > visrank_topk< / span > < span class = "o" > =< / span > < span class = "n" > visrank_topk< / span > < span class = "p" > ,< / span >
< span class = "n" > save_dir< / span > < span class = "o" > =< / span > < span class = "n" > save_dir< / span > < span class = "p" > ,< / span >
< span class = "n" > use_metric_cuhk03< / span > < span class = "o" > =< / span > < span class = "n" > use_metric_cuhk03< / span > < span class = "p" > ,< / span >
< span class = "n" > ranks< / span > < span class = "o" > =< / span > < span class = "n" > ranks< / span >
< span class = "p" > )< / span >
< span class = "bp" > self< / span > < span class = "o" > .< / span > < span class = "n" > _save_checkpoint< / span > < span class = "p" > (< / span > < span class = "n" > epoch< / span > < span class = "p" > ,< / span > < span class = "n" > rank1< / span > < span class = "p" > ,< / span > < span class = "n" > save_dir< / span > < span class = "p" > )< / span >
< span class = "k" > if< / span > < span class = "n" > max_epoch< / span > < span class = "o" > > < / span > < span class = "mi" > 0< / span > < span class = "p" > :< / span >
< span class = "nb" > print< / span > < span class = "p" > (< / span > < span class = "s1" > ' => Final test' < / span > < span class = "p" > )< / span >
< span class = "n" > rank1< / span > < span class = "o" > =< / span > < span class = "bp" > self< / span > < span class = "o" > .< / span > < span class = "n" > test< / span > < span class = "p" > (< / span >
< span class = "n" > epoch< / span > < span class = "p" > ,< / span >
< span class = "n" > testloader< / span > < span class = "p" > ,< / span >
< span class = "n" > dist_metric< / span > < span class = "o" > =< / span > < span class = "n" > dist_metric< / span > < span class = "p" > ,< / span >
< span class = "n" > visrank< / span > < span class = "o" > =< / span > < span class = "n" > visrank< / span > < span class = "p" > ,< / span >
< span class = "n" > visrank_topk< / span > < span class = "o" > =< / span > < span class = "n" > visrank_topk< / span > < span class = "p" > ,< / span >
< span class = "n" > save_dir< / span > < span class = "o" > =< / span > < span class = "n" > save_dir< / span > < span class = "p" > ,< / span >
< span class = "n" > use_metric_cuhk03< / span > < span class = "o" > =< / span > < span class = "n" > use_metric_cuhk03< / span > < span class = "p" > ,< / span >
< span class = "n" > ranks< / span > < span class = "o" > =< / span > < span class = "n" > ranks< / span >
< span class = "p" > )< / span >
< span class = "bp" > self< / span > < span class = "o" > .< / span > < span class = "n" > _save_checkpoint< / span > < span class = "p" > (< / span > < span class = "n" > epoch< / span > < span class = "p" > ,< / span > < span class = "n" > rank1< / span > < span class = "p" > ,< / span > < span class = "n" > save_dir< / span > < span class = "p" > )< / span >
< span class = "n" > elapsed< / span > < span class = "o" > =< / span > < span class = "nb" > round< / span > < span class = "p" > (< / span > < span class = "n" > time< / span > < span class = "o" > .< / span > < span class = "n" > time< / span > < span class = "p" > ()< / span > < span class = "o" > -< / span > < span class = "n" > time_start< / span > < span class = "p" > )< / span >
< span class = "n" > elapsed< / span > < span class = "o" > =< / span > < span class = "nb" > str< / span > < span class = "p" > (< / span > < span class = "n" > datetime< / span > < span class = "o" > .< / span > < span class = "n" > timedelta< / span > < span class = "p" > (< / span > < span class = "n" > seconds< / span > < span class = "o" > =< / span > < span class = "n" > elapsed< / span > < span class = "p" > ))< / span >
< span class = "nb" > print< / span > < span class = "p" > (< / span > < span class = "s1" > ' Elapsed < / span > < span class = "si" > {}< / span > < span class = "s1" > ' < / span > < span class = "o" > .< / span > < span class = "n" > format< / span > < span class = "p" > (< / span > < span class = "n" > elapsed< / span > < span class = "p" > ))< / span > < / div >
< div class = "viewcode-block" id = "Engine.train" > < a class = "viewcode-back" href = "../../../pkg/engine.html#torchreid.engine.engine.Engine.train" > [docs]< / a > < span class = "k" > def< / span > < span class = "nf" > train< / span > < span class = "p" > (< / span > < span class = "bp" > self< / span > < span class = "p" > ):< / span >
< span class = "sd" > " " " Performs training on source datasets for one epoch.< / span >
< span class = "sd" > This will be called every epoch in ``run()``, e.g.< / span >
< span class = "sd" > .. code-block:: python< / span >
< span class = "sd" > < / span >
< span class = "sd" > for epoch in range(start_epoch, max_epoch):< / span >
< span class = "sd" > self.train(some_arguments)< / span >
< span class = "sd" > .. note::< / span >
< span class = "sd" > < / span >
< span class = "sd" > This needs to be implemented in subclasses.< / span >
< span class = "sd" > " " " < / span >
< span class = "k" > raise< / span > < span class = "ne" > NotImplementedError< / span > < / div >
< div class = "viewcode-block" id = "Engine.test" > < a class = "viewcode-back" href = "../../../pkg/engine.html#torchreid.engine.engine.Engine.test" > [docs]< / a > < span class = "k" > def< / span > < span class = "nf" > test< / span > < span class = "p" > (< / span > < span class = "bp" > self< / span > < span class = "p" > ,< / span > < span class = "n" > epoch< / span > < span class = "p" > ,< / span > < span class = "n" > testloader< / span > < span class = "p" > ,< / span > < span class = "n" > dist_metric< / span > < span class = "o" > =< / span > < span class = "s1" > ' euclidean' < / span > < span class = "p" > ,< / span > < span class = "n" > visrank< / span > < span class = "o" > =< / span > < span class = "kc" > False< / span > < span class = "p" > ,< / span > < span class = "n" > visrank_topk< / span > < span class = "o" > =< / span > < span class = "mi" > 20< / span > < span class = "p" > ,< / span >
< span class = "n" > save_dir< / span > < span class = "o" > =< / span > < span class = "s1" > ' ' < / span > < span class = "p" > ,< / span > < span class = "n" > use_metric_cuhk03< / span > < span class = "o" > =< / span > < span class = "kc" > False< / span > < span class = "p" > ,< / span > < span class = "n" > ranks< / span > < span class = "o" > =< / span > < span class = "p" > [< / span > < span class = "mi" > 1< / span > < span class = "p" > ,< / span > < span class = "mi" > 5< / span > < span class = "p" > ,< / span > < span class = "mi" > 10< / span > < span class = "p" > ,< / span > < span class = "mi" > 20< / span > < span class = "p" > ]):< / span >
< span class = "sd" > " " " Tests model on target datasets.< / span >
< span class = "sd" > .. note::< / span >
< span class = "sd" > This function has been called in ``run()`` when necessary.< / span >
< span class = "sd" > .. note::< / span >
< span class = "sd" > The test pipeline implemented in this function suits both image- and< / span >
< span class = "sd" > video-reid. In general, a subclass of Engine only needs to re-implement< / span >
< span class = "sd" > ``_extract_features()`` and ``_parse_data_for_eval()`` when necessary,< / span >
< span class = "sd" > but not a must. Please refer to the source code for more details.< / span >
< span class = "sd" > Args:< / span >
< span class = "sd" > epoch (int): current epoch.< / span >
< span class = "sd" > testloader (dict): dictionary containing< / span >
< span class = "sd" > {dataset_name: ' query' : queryloader, ' gallery' : galleryloader}.< / span >
< span class = "sd" > dist_metric (str, optional): distance metric used to compute distance matrix< / span >
< span class = "sd" > between query and gallery. Default is " euclidean" .< / span >
< span class = "sd" > visrank (bool, optional): visualizes ranked results. Default is False. Visualization< / span >
< span class = "sd" > will be performed every test time, so it is recommended to enable ``visrank`` when< / span >
< span class = "sd" > ``test_only`` is True. The ranked images will be saved to< / span >
< span class = "sd" > " save_dir/ranks-epoch/dataset_name" , e.g. " save_dir/ranks-60/market1501" .< / span >
< span class = "sd" > visrank_topk (int, optional): top-k ranked images to be visualized. Default is 20.< / span >
< span class = "sd" > save_dir (str): directory to save visualized results if ``visrank`` is True.< / span >
< span class = "sd" > use_metric_cuhk03 (bool, optional): use single-gallery-shot setting for cuhk03.< / span >
< span class = "sd" > Default is False. This should be enabled when using cuhk03 classic split.< / span >
< span class = "sd" > ranks (list, optional): cmc ranks to be computed. Default is [1, 5, 10, 20].< / span >
< span class = "sd" > " " " < / span >
< span class = "n" > targets< / span > < span class = "o" > =< / span > < span class = "nb" > list< / span > < span class = "p" > (< / span > < span class = "n" > testloader< / span > < span class = "o" > .< / span > < span class = "n" > keys< / span > < span class = "p" > ())< / span >
< span class = "k" > for< / span > < span class = "n" > name< / span > < span class = "ow" > in< / span > < span class = "n" > targets< / span > < span class = "p" > :< / span >
2019-03-27 22:31:26 +00:00
< span class = "n" > domain< / span > < span class = "o" > =< / span > < span class = "s1" > ' source' < / span > < span class = "k" > if< / span > < span class = "n" > name< / span > < span class = "ow" > in< / span > < span class = "bp" > self< / span > < span class = "o" > .< / span > < span class = "n" > datamanager< / span > < span class = "o" > .< / span > < span class = "n" > sources< / span > < span class = "k" > else< / span > < span class = "s1" > ' target' < / span >
< span class = "nb" > print< / span > < span class = "p" > (< / span > < span class = "s1" > ' < / span > < span class = "se" > \n< / span > < span class = "s1" > ##### Evaluating < / span > < span class = "si" > {}< / span > < span class = "s1" > (< / span > < span class = "si" > {}< / span > < span class = "s1" > ) #####' < / span > < span class = "o" > .< / span > < span class = "n" > format< / span > < span class = "p" > (< / span > < span class = "n" > name< / span > < span class = "p" > ,< / span > < span class = "n" > domain< / span > < span class = "p" > ))< / span >
2019-03-24 17:22:43 +00:00
< span class = "n" > queryloader< / span > < span class = "o" > =< / span > < span class = "n" > testloader< / span > < span class = "p" > [< / span > < span class = "n" > name< / span > < span class = "p" > ][< / span > < span class = "s1" > ' query' < / span > < span class = "p" > ]< / span >
< span class = "n" > galleryloader< / span > < span class = "o" > =< / span > < span class = "n" > testloader< / span > < span class = "p" > [< / span > < span class = "n" > name< / span > < span class = "p" > ][< / span > < span class = "s1" > ' gallery' < / span > < span class = "p" > ]< / span >
< span class = "n" > rank1< / span > < span class = "o" > =< / span > < span class = "bp" > self< / span > < span class = "o" > .< / span > < span class = "n" > _evaluate< / span > < span class = "p" > (< / span >
< span class = "n" > epoch< / span > < span class = "p" > ,< / span >
< span class = "n" > dataset_name< / span > < span class = "o" > =< / span > < span class = "n" > name< / span > < span class = "p" > ,< / span >
< span class = "n" > queryloader< / span > < span class = "o" > =< / span > < span class = "n" > queryloader< / span > < span class = "p" > ,< / span >
< span class = "n" > galleryloader< / span > < span class = "o" > =< / span > < span class = "n" > galleryloader< / span > < span class = "p" > ,< / span >
< span class = "n" > dist_metric< / span > < span class = "o" > =< / span > < span class = "n" > dist_metric< / span > < span class = "p" > ,< / span >
< span class = "n" > visrank< / span > < span class = "o" > =< / span > < span class = "n" > visrank< / span > < span class = "p" > ,< / span >
< span class = "n" > visrank_topk< / span > < span class = "o" > =< / span > < span class = "n" > visrank_topk< / span > < span class = "p" > ,< / span >
< span class = "n" > save_dir< / span > < span class = "o" > =< / span > < span class = "n" > save_dir< / span > < span class = "p" > ,< / span >
< span class = "n" > use_metric_cuhk03< / span > < span class = "o" > =< / span > < span class = "n" > use_metric_cuhk03< / span > < span class = "p" > ,< / span >
< span class = "n" > ranks< / span > < span class = "o" > =< / span > < span class = "n" > ranks< / span >
< span class = "p" > )< / span >
< span class = "k" > return< / span > < span class = "n" > rank1< / span > < / div >
< span class = "nd" > @torch< / span > < span class = "o" > .< / span > < span class = "n" > no_grad< / span > < span class = "p" > ()< / span >
< span class = "k" > def< / span > < span class = "nf" > _evaluate< / span > < span class = "p" > (< / span > < span class = "bp" > self< / span > < span class = "p" > ,< / span > < span class = "n" > epoch< / span > < span class = "p" > ,< / span > < span class = "n" > dataset_name< / span > < span class = "o" > =< / span > < span class = "s1" > ' ' < / span > < span class = "p" > ,< / span > < span class = "n" > queryloader< / span > < span class = "o" > =< / span > < span class = "kc" > None< / span > < span class = "p" > ,< / span > < span class = "n" > galleryloader< / span > < span class = "o" > =< / span > < span class = "kc" > None< / span > < span class = "p" > ,< / span >
< span class = "n" > dist_metric< / span > < span class = "o" > =< / span > < span class = "s1" > ' euclidean' < / span > < span class = "p" > ,< / span > < span class = "n" > visrank< / span > < span class = "o" > =< / span > < span class = "kc" > False< / span > < span class = "p" > ,< / span > < span class = "n" > visrank_topk< / span > < span class = "o" > =< / span > < span class = "mi" > 20< / span > < span class = "p" > ,< / span > < span class = "n" > save_dir< / span > < span class = "o" > =< / span > < span class = "s1" > ' ' < / span > < span class = "p" > ,< / span >
< span class = "n" > use_metric_cuhk03< / span > < span class = "o" > =< / span > < span class = "kc" > False< / span > < span class = "p" > ,< / span > < span class = "n" > ranks< / span > < span class = "o" > =< / span > < span class = "p" > [< / span > < span class = "mi" > 1< / span > < span class = "p" > ,< / span > < span class = "mi" > 5< / span > < span class = "p" > ,< / span > < span class = "mi" > 10< / span > < span class = "p" > ,< / span > < span class = "mi" > 20< / span > < span class = "p" > ]):< / span >
< span class = "n" > batch_time< / span > < span class = "o" > =< / span > < span class = "n" > AverageMeter< / span > < span class = "p" > ()< / span >
< span class = "bp" > self< / span > < span class = "o" > .< / span > < span class = "n" > model< / span > < span class = "o" > .< / span > < span class = "n" > eval< / span > < span class = "p" > ()< / span >
< span class = "nb" > print< / span > < span class = "p" > (< / span > < span class = "s1" > ' Extracting features from query set ...' < / span > < span class = "p" > )< / span >
< span class = "n" > qf< / span > < span class = "p" > ,< / span > < span class = "n" > q_pids< / span > < span class = "p" > ,< / span > < span class = "n" > q_camids< / span > < span class = "o" > =< / span > < span class = "p" > [],< / span > < span class = "p" > [],< / span > < span class = "p" > []< / span >
< span class = "k" > for< / span > < span class = "n" > batch_idx< / span > < span class = "p" > ,< / span > < span class = "n" > data< / span > < span class = "ow" > in< / span > < span class = "nb" > enumerate< / span > < span class = "p" > (< / span > < span class = "n" > queryloader< / span > < span class = "p" > ):< / span >
< span class = "n" > imgs< / span > < span class = "p" > ,< / span > < span class = "n" > pids< / span > < span class = "p" > ,< / span > < span class = "n" > camids< / span > < span class = "o" > =< / span > < span class = "bp" > self< / span > < span class = "o" > .< / span > < span class = "n" > _parse_data_for_eval< / span > < span class = "p" > (< / span > < span class = "n" > data< / span > < span class = "p" > )< / span >
< span class = "k" > if< / span > < span class = "bp" > self< / span > < span class = "o" > .< / span > < span class = "n" > use_gpu< / span > < span class = "p" > :< / span >
< span class = "n" > imgs< / span > < span class = "o" > =< / span > < span class = "n" > imgs< / span > < span class = "o" > .< / span > < span class = "n" > cuda< / span > < span class = "p" > ()< / span >
< span class = "n" > end< / span > < span class = "o" > =< / span > < span class = "n" > time< / span > < span class = "o" > .< / span > < span class = "n" > time< / span > < span class = "p" > ()< / span >
< span class = "n" > features< / span > < span class = "o" > =< / span > < span class = "bp" > self< / span > < span class = "o" > .< / span > < span class = "n" > _extract_features< / span > < span class = "p" > (< / span > < span class = "n" > imgs< / span > < span class = "p" > )< / span >
< span class = "n" > batch_time< / span > < span class = "o" > .< / span > < span class = "n" > update< / span > < span class = "p" > (< / span > < span class = "n" > time< / span > < span class = "o" > .< / span > < span class = "n" > time< / span > < span class = "p" > ()< / span > < span class = "o" > -< / span > < span class = "n" > end< / span > < span class = "p" > )< / span >
< span class = "n" > features< / span > < span class = "o" > =< / span > < span class = "n" > features< / span > < span class = "o" > .< / span > < span class = "n" > data< / span > < span class = "o" > .< / span > < span class = "n" > cpu< / span > < span class = "p" > ()< / span >
< span class = "n" > qf< / span > < span class = "o" > .< / span > < span class = "n" > append< / span > < span class = "p" > (< / span > < span class = "n" > features< / span > < span class = "p" > )< / span >
< span class = "n" > q_pids< / span > < span class = "o" > .< / span > < span class = "n" > extend< / span > < span class = "p" > (< / span > < span class = "n" > pids< / span > < span class = "p" > )< / span >
< span class = "n" > q_camids< / span > < span class = "o" > .< / span > < span class = "n" > extend< / span > < span class = "p" > (< / span > < span class = "n" > camids< / span > < span class = "p" > )< / span >
< span class = "n" > qf< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > cat< / span > < span class = "p" > (< / span > < span class = "n" > qf< / span > < span class = "p" > ,< / span > < span class = "mi" > 0< / span > < span class = "p" > )< / span >
< span class = "n" > q_pids< / span > < span class = "o" > =< / span > < span class = "n" > np< / span > < span class = "o" > .< / span > < span class = "n" > asarray< / span > < span class = "p" > (< / span > < span class = "n" > q_pids< / span > < span class = "p" > )< / span >
< span class = "n" > q_camids< / span > < span class = "o" > =< / span > < span class = "n" > np< / span > < span class = "o" > .< / span > < span class = "n" > asarray< / span > < span class = "p" > (< / span > < span class = "n" > q_camids< / span > < span class = "p" > )< / span >
< span class = "nb" > print< / span > < span class = "p" > (< / span > < span class = "s1" > ' Done, obtained < / span > < span class = "si" > {}< / span > < span class = "s1" > -by-< / span > < span class = "si" > {}< / span > < span class = "s1" > matrix' < / span > < span class = "o" > .< / span > < span class = "n" > format< / span > < span class = "p" > (< / span > < span class = "n" > qf< / span > < span class = "o" > .< / span > < span class = "n" > size< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > ),< / span > < span class = "n" > qf< / span > < span class = "o" > .< / span > < span class = "n" > size< / span > < span class = "p" > (< / span > < span class = "mi" > 1< / span > < span class = "p" > )))< / span >
< span class = "nb" > print< / span > < span class = "p" > (< / span > < span class = "s1" > ' Extracting features from gallery set ...' < / span > < span class = "p" > )< / span >
< span class = "n" > gf< / span > < span class = "p" > ,< / span > < span class = "n" > g_pids< / span > < span class = "p" > ,< / span > < span class = "n" > g_camids< / span > < span class = "o" > =< / span > < span class = "p" > [],< / span > < span class = "p" > [],< / span > < span class = "p" > []< / span >
< span class = "n" > end< / span > < span class = "o" > =< / span > < span class = "n" > time< / span > < span class = "o" > .< / span > < span class = "n" > time< / span > < span class = "p" > ()< / span >
< span class = "k" > for< / span > < span class = "n" > batch_idx< / span > < span class = "p" > ,< / span > < span class = "n" > data< / span > < span class = "ow" > in< / span > < span class = "nb" > enumerate< / span > < span class = "p" > (< / span > < span class = "n" > galleryloader< / span > < span class = "p" > ):< / span >
< span class = "n" > imgs< / span > < span class = "p" > ,< / span > < span class = "n" > pids< / span > < span class = "p" > ,< / span > < span class = "n" > camids< / span > < span class = "o" > =< / span > < span class = "bp" > self< / span > < span class = "o" > .< / span > < span class = "n" > _parse_data_for_eval< / span > < span class = "p" > (< / span > < span class = "n" > data< / span > < span class = "p" > )< / span >
< span class = "k" > if< / span > < span class = "bp" > self< / span > < span class = "o" > .< / span > < span class = "n" > use_gpu< / span > < span class = "p" > :< / span >
< span class = "n" > imgs< / span > < span class = "o" > =< / span > < span class = "n" > imgs< / span > < span class = "o" > .< / span > < span class = "n" > cuda< / span > < span class = "p" > ()< / span >
< span class = "n" > end< / span > < span class = "o" > =< / span > < span class = "n" > time< / span > < span class = "o" > .< / span > < span class = "n" > time< / span > < span class = "p" > ()< / span >
< span class = "n" > features< / span > < span class = "o" > =< / span > < span class = "bp" > self< / span > < span class = "o" > .< / span > < span class = "n" > _extract_features< / span > < span class = "p" > (< / span > < span class = "n" > imgs< / span > < span class = "p" > )< / span >
< span class = "n" > batch_time< / span > < span class = "o" > .< / span > < span class = "n" > update< / span > < span class = "p" > (< / span > < span class = "n" > time< / span > < span class = "o" > .< / span > < span class = "n" > time< / span > < span class = "p" > ()< / span > < span class = "o" > -< / span > < span class = "n" > end< / span > < span class = "p" > )< / span >
< span class = "n" > features< / span > < span class = "o" > =< / span > < span class = "n" > features< / span > < span class = "o" > .< / span > < span class = "n" > data< / span > < span class = "o" > .< / span > < span class = "n" > cpu< / span > < span class = "p" > ()< / span >
< span class = "n" > gf< / span > < span class = "o" > .< / span > < span class = "n" > append< / span > < span class = "p" > (< / span > < span class = "n" > features< / span > < span class = "p" > )< / span >
< span class = "n" > g_pids< / span > < span class = "o" > .< / span > < span class = "n" > extend< / span > < span class = "p" > (< / span > < span class = "n" > pids< / span > < span class = "p" > )< / span >
< span class = "n" > g_camids< / span > < span class = "o" > .< / span > < span class = "n" > extend< / span > < span class = "p" > (< / span > < span class = "n" > camids< / span > < span class = "p" > )< / span >
< span class = "n" > gf< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > cat< / span > < span class = "p" > (< / span > < span class = "n" > gf< / span > < span class = "p" > ,< / span > < span class = "mi" > 0< / span > < span class = "p" > )< / span >
< span class = "n" > g_pids< / span > < span class = "o" > =< / span > < span class = "n" > np< / span > < span class = "o" > .< / span > < span class = "n" > asarray< / span > < span class = "p" > (< / span > < span class = "n" > g_pids< / span > < span class = "p" > )< / span >
< span class = "n" > g_camids< / span > < span class = "o" > =< / span > < span class = "n" > np< / span > < span class = "o" > .< / span > < span class = "n" > asarray< / span > < span class = "p" > (< / span > < span class = "n" > g_camids< / span > < span class = "p" > )< / span >
< span class = "nb" > print< / span > < span class = "p" > (< / span > < span class = "s1" > ' Done, obtained < / span > < span class = "si" > {}< / span > < span class = "s1" > -by-< / span > < span class = "si" > {}< / span > < span class = "s1" > matrix' < / span > < span class = "o" > .< / span > < span class = "n" > format< / span > < span class = "p" > (< / span > < span class = "n" > gf< / span > < span class = "o" > .< / span > < span class = "n" > size< / span > < span class = "p" > (< / span > < span class = "mi" > 0< / span > < span class = "p" > ),< / span > < span class = "n" > gf< / span > < span class = "o" > .< / span > < span class = "n" > size< / span > < span class = "p" > (< / span > < span class = "mi" > 1< / span > < span class = "p" > )))< / span >
< span class = "nb" > print< / span > < span class = "p" > (< / span > < span class = "s1" > ' Speed: < / span > < span class = "si" > {:.4f}< / span > < span class = "s1" > sec/batch' < / span > < span class = "o" > .< / span > < span class = "n" > format< / span > < span class = "p" > (< / span > < span class = "n" > batch_time< / span > < span class = "o" > .< / span > < span class = "n" > avg< / span > < span class = "p" > ))< / span >
< span class = "n" > distmat< / span > < span class = "o" > =< / span > < span class = "n" > metrics< / span > < span class = "o" > .< / span > < span class = "n" > compute_distance_matrix< / span > < span class = "p" > (< / span > < span class = "n" > qf< / span > < span class = "p" > ,< / span > < span class = "n" > gf< / span > < span class = "p" > ,< / span > < span class = "n" > dist_metric< / span > < span class = "p" > )< / span >
< span class = "n" > distmat< / span > < span class = "o" > =< / span > < span class = "n" > distmat< / span > < span class = "o" > .< / span > < span class = "n" > numpy< / span > < span class = "p" > ()< / span >
< span class = "nb" > print< / span > < span class = "p" > (< / span > < span class = "s1" > ' Computing CMC and mAP ...' < / span > < span class = "p" > )< / span >
< span class = "n" > cmc< / span > < span class = "p" > ,< / span > < span class = "n" > mAP< / span > < span class = "o" > =< / span > < span class = "n" > metrics< / span > < span class = "o" > .< / span > < span class = "n" > evaluate_rank< / span > < span class = "p" > (< / span >
< span class = "n" > distmat< / span > < span class = "p" > ,< / span >
< span class = "n" > q_pids< / span > < span class = "p" > ,< / span >
< span class = "n" > g_pids< / span > < span class = "p" > ,< / span >
< span class = "n" > q_camids< / span > < span class = "p" > ,< / span >
< span class = "n" > g_camids< / span > < span class = "p" > ,< / span >
< span class = "n" > use_metric_cuhk03< / span > < span class = "o" > =< / span > < span class = "n" > use_metric_cuhk03< / span >
< span class = "p" > )< / span >
< span class = "nb" > print< / span > < span class = "p" > (< / span > < span class = "s1" > ' ** Results **' < / span > < span class = "p" > )< / span >
< span class = "nb" > print< / span > < span class = "p" > (< / span > < span class = "s1" > ' mAP: < / span > < span class = "si" > {:.1%}< / span > < span class = "s1" > ' < / span > < span class = "o" > .< / span > < span class = "n" > format< / span > < span class = "p" > (< / span > < span class = "n" > mAP< / span > < span class = "p" > ))< / span >
< span class = "nb" > print< / span > < span class = "p" > (< / span > < span class = "s1" > ' CMC curve' < / span > < span class = "p" > )< / span >
< span class = "k" > for< / span > < span class = "n" > r< / span > < span class = "ow" > in< / span > < span class = "n" > ranks< / span > < span class = "p" > :< / span >
< span class = "nb" > print< / span > < span class = "p" > (< / span > < span class = "s1" > ' Rank-< / span > < span class = "si" > {:< 3}< / span > < span class = "s1" > : < / span > < span class = "si" > {:.1%}< / span > < span class = "s1" > ' < / span > < span class = "o" > .< / span > < span class = "n" > format< / span > < span class = "p" > (< / span > < span class = "n" > r< / span > < span class = "p" > ,< / span > < span class = "n" > cmc< / span > < span class = "p" > [< / span > < span class = "n" > r< / span > < span class = "o" > -< / span > < span class = "mi" > 1< / span > < span class = "p" > ]))< / span >
< span class = "k" > if< / span > < span class = "n" > visrank< / span > < span class = "p" > :< / span >
< span class = "n" > visualize_ranked_results< / span > < span class = "p" > (< / span >
< span class = "n" > distmat< / span > < span class = "p" > ,< / span >
< span class = "bp" > self< / span > < span class = "o" > .< / span > < span class = "n" > datamanager< / span > < span class = "o" > .< / span > < span class = "n" > return_testdataset_by_name< / span > < span class = "p" > (< / span > < span class = "n" > dataset_name< / span > < span class = "p" > ),< / span >
< span class = "n" > save_dir< / span > < span class = "o" > =< / span > < span class = "n" > osp< / span > < span class = "o" > .< / span > < span class = "n" > join< / span > < span class = "p" > (< / span > < span class = "n" > save_dir< / span > < span class = "p" > ,< / span > < span class = "s1" > ' visrank-' < / span > < span class = "o" > +< / span > < span class = "nb" > str< / span > < span class = "p" > (< / span > < span class = "n" > epoch< / span > < span class = "o" > +< / span > < span class = "mi" > 1< / span > < span class = "p" > ),< / span > < span class = "n" > dataset_name< / span > < span class = "p" > ),< / span >
< span class = "n" > topk< / span > < span class = "o" > =< / span > < span class = "n" > visrank_topk< / span >
< span class = "p" > )< / span >
< span class = "k" > return< / span > < span class = "n" > cmc< / span > < span class = "p" > [< / span > < span class = "mi" > 0< / span > < span class = "p" > ]< / span >
< span class = "k" > def< / span > < span class = "nf" > _compute_loss< / span > < span class = "p" > (< / span > < span class = "bp" > self< / span > < span class = "p" > ,< / span > < span class = "n" > criterion< / span > < span class = "p" > ,< / span > < span class = "n" > outputs< / span > < span class = "p" > ,< / span > < span class = "n" > targets< / span > < span class = "p" > ):< / span >
< span class = "k" > if< / span > < span class = "nb" > isinstance< / span > < span class = "p" > (< / span > < span class = "n" > outputs< / span > < span class = "p" > ,< / span > < span class = "p" > (< / span > < span class = "nb" > tuple< / span > < span class = "p" > ,< / span > < span class = "nb" > list< / span > < span class = "p" > )):< / span >
< span class = "n" > loss< / span > < span class = "o" > =< / span > < span class = "n" > DeepSupervision< / span > < span class = "p" > (< / span > < span class = "n" > criterion< / span > < span class = "p" > ,< / span > < span class = "n" > outputs< / span > < span class = "p" > ,< / span > < span class = "n" > targets< / span > < span class = "p" > )< / span >
< span class = "k" > else< / span > < span class = "p" > :< / span >
< span class = "n" > loss< / span > < span class = "o" > =< / span > < span class = "n" > criterion< / span > < span class = "p" > (< / span > < span class = "n" > outputs< / span > < span class = "p" > ,< / span > < span class = "n" > targets< / span > < span class = "p" > )< / span >
< span class = "k" > return< / span > < span class = "n" > loss< / span >
< span class = "k" > def< / span > < span class = "nf" > _extract_features< / span > < span class = "p" > (< / span > < span class = "bp" > self< / span > < span class = "p" > ,< / span > < span class = "nb" > input< / span > < span class = "p" > ):< / span >
< span class = "bp" > self< / span > < span class = "o" > .< / span > < span class = "n" > model< / span > < span class = "o" > .< / span > < span class = "n" > eval< / span > < span class = "p" > ()< / span >
< span class = "k" > return< / span > < span class = "bp" > self< / span > < span class = "o" > .< / span > < span class = "n" > model< / span > < span class = "p" > (< / span > < span class = "nb" > input< / span > < span class = "p" > )< / span >
< span class = "k" > def< / span > < span class = "nf" > _parse_data_for_train< / span > < span class = "p" > (< / span > < span class = "bp" > self< / span > < span class = "p" > ,< / span > < span class = "n" > data< / span > < span class = "p" > ):< / span >
< span class = "n" > imgs< / span > < span class = "o" > =< / span > < span class = "n" > data< / span > < span class = "p" > [< / span > < span class = "mi" > 0< / span > < span class = "p" > ]< / span >
< span class = "n" > pids< / span > < span class = "o" > =< / span > < span class = "n" > data< / span > < span class = "p" > [< / span > < span class = "mi" > 1< / span > < span class = "p" > ]< / span >
< span class = "k" > return< / span > < span class = "n" > imgs< / span > < span class = "p" > ,< / span > < span class = "n" > pids< / span >
< span class = "k" > def< / span > < span class = "nf" > _parse_data_for_eval< / span > < span class = "p" > (< / span > < span class = "bp" > self< / span > < span class = "p" > ,< / span > < span class = "n" > data< / span > < span class = "p" > ):< / span >
< span class = "n" > imgs< / span > < span class = "o" > =< / span > < span class = "n" > data< / span > < span class = "p" > [< / span > < span class = "mi" > 0< / span > < span class = "p" > ]< / span >
< span class = "n" > pids< / span > < span class = "o" > =< / span > < span class = "n" > data< / span > < span class = "p" > [< / span > < span class = "mi" > 1< / span > < span class = "p" > ]< / span >
< span class = "n" > camids< / span > < span class = "o" > =< / span > < span class = "n" > data< / span > < span class = "p" > [< / span > < span class = "mi" > 2< / span > < span class = "p" > ]< / span >
< span class = "k" > return< / span > < span class = "n" > imgs< / span > < span class = "p" > ,< / span > < span class = "n" > pids< / span > < span class = "p" > ,< / span > < span class = "n" > camids< / span >
< span class = "k" > def< / span > < span class = "nf" > _save_checkpoint< / span > < span class = "p" > (< / span > < span class = "bp" > self< / span > < span class = "p" > ,< / span > < span class = "n" > epoch< / span > < span class = "p" > ,< / span > < span class = "n" > rank1< / span > < span class = "p" > ,< / span > < span class = "n" > save_dir< / span > < span class = "p" > ,< / span > < span class = "n" > is_best< / span > < span class = "o" > =< / span > < span class = "kc" > False< / span > < span class = "p" > ):< / span >
< span class = "n" > save_checkpoint< / span > < span class = "p" > ({< / span >
< span class = "s1" > ' state_dict' < / span > < span class = "p" > :< / span > < span class = "bp" > self< / span > < span class = "o" > .< / span > < span class = "n" > model< / span > < span class = "o" > .< / span > < span class = "n" > state_dict< / span > < span class = "p" > (),< / span >
< span class = "s1" > ' epoch' < / span > < span class = "p" > :< / span > < span class = "n" > epoch< / span > < span class = "o" > +< / span > < span class = "mi" > 1< / span > < span class = "p" > ,< / span >
< span class = "s1" > ' rank1' < / span > < span class = "p" > :< / span > < span class = "n" > rank1< / span > < span class = "p" > ,< / span >
< span class = "s1" > ' optimizer' < / span > < span class = "p" > :< / span > < span class = "bp" > self< / span > < span class = "o" > .< / span > < span class = "n" > optimizer< / span > < span class = "o" > .< / span > < span class = "n" > state_dict< / span > < span class = "p" > (),< / span >
< span class = "p" > },< / span > < span class = "n" > save_dir< / span > < span class = "p" > ,< / span > < span class = "n" > is_best< / span > < span class = "o" > =< / span > < span class = "n" > is_best< / span > < span class = "p" > )< / span > < / div >
< / pre > < / div >
< / div >
< / div >
< footer >
< hr / >
< div role = "contentinfo" >
< p >
© Copyright 2019, Kaiyang Zhou
< / p >
< / div >
Built with < a href = "http://sphinx-doc.org/" > Sphinx< / a > using a < a href = "https://github.com/rtfd/sphinx_rtd_theme" > theme< / a > provided by < a href = "https://readthedocs.org" > Read the Docs< / a > .
< / footer >
< / div >
< / div >
< / section >
< / div >
< script type = "text/javascript" >
jQuery(function () {
SphinxRtdTheme.Navigation.enable(true);
});
< / script >
< / body >
< / html >