2019-03-24 17:22:43 +00:00
<!DOCTYPE html>
<!--[if IE 8]><html class="no-js lt-ie9" lang="en"><![endif]-->
<!-- [if gt IE 8]><! --> < html class = "no-js" lang = "en" > <!-- <![endif] -->
< head >
< meta charset = "utf-8" >
< meta name = "viewport" content = "width=device-width, initial-scale=1.0" >
2019-09-01 15:05:27 +01:00
< title > torchreid.engine — torchreid 1.0.2 documentation< / title >
2019-03-24 17:22:43 +00:00
< script type = "text/javascript" src = "../_static/js/modernizr.min.js" > < / script >
< script type = "text/javascript" id = "documentation_options" data-url_root = "../" src = "../_static/documentation_options.js" > < / script >
< script type = "text/javascript" src = "../_static/jquery.js" > < / script >
< script type = "text/javascript" src = "../_static/underscore.js" > < / script >
< script type = "text/javascript" src = "../_static/doctools.js" > < / script >
< script type = "text/javascript" src = "../_static/language_data.js" > < / script >
< script async = "async" type = "text/javascript" src = "https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML" > < / script >
< script type = "text/javascript" src = "../_static/js/theme.js" > < / script >
< link rel = "stylesheet" href = "../_static/css/theme.css" type = "text/css" / >
< link rel = "stylesheet" href = "../_static/pygments.css" type = "text/css" / >
< link rel = "index" title = "Index" href = "../genindex.html" / >
< link rel = "search" title = "Search" href = "../search.html" / >
< link rel = "next" title = "torchreid.losses" href = "losses.html" / >
< link rel = "prev" title = "torchreid.data" href = "data.html" / >
< / head >
< body class = "wy-body-for-nav" >
< div class = "wy-grid-for-nav" >
< nav data-toggle = "wy-nav-shift" class = "wy-nav-side" >
< div class = "wy-side-scroll" >
< div class = "wy-side-nav-search" >
< a href = "../index.html" class = "icon icon-home" > torchreid
< / a >
< div class = "version" >
2019-09-01 15:05:27 +01:00
1.0.2
2019-03-24 17:22:43 +00:00
< / div >
< div role = "search" >
< form id = "rtd-search-form" class = "wy-form" action = "../search.html" method = "get" >
< input type = "text" name = "q" placeholder = "Search docs" / >
< input type = "hidden" name = "check_keywords" value = "yes" / >
< input type = "hidden" name = "area" value = "default" / >
< / form >
< / div >
< / div >
< div class = "wy-menu wy-menu-vertical" data-spy = "affix" role = "navigation" aria-label = "main navigation" >
< ul >
< li class = "toctree-l1" > < a class = "reference internal" href = "../user_guide.html" > How-to< / a > < / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "../datasets.html" > Datasets< / a > < / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "../evaluation.html" > Evaluation< / a > < / li >
< / ul >
< p class = "caption" > < span class = "caption-text" > Package Reference< / span > < / p >
< ul class = "current" >
< li class = "toctree-l1" > < a class = "reference internal" href = "data.html" > torchreid.data< / a > < / li >
< li class = "toctree-l1 current" > < a class = "current reference internal" href = "#" > torchreid.engine< / a > < ul >
< li class = "toctree-l2" > < a class = "reference internal" href = "#base-engine" > Base Engine< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "#image-engines" > Image Engines< / a > < / li >
< li class = "toctree-l2" > < a class = "reference internal" href = "#video-engines" > Video Engines< / a > < / li >
< / ul >
< / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "losses.html" > torchreid.losses< / a > < / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "metrics.html" > torchreid.metrics< / a > < / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "models.html" > torchreid.models< / a > < / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "optim.html" > torchreid.optim< / a > < / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "utils.html" > torchreid.utils< / a > < / li >
< / ul >
< p class = "caption" > < span class = "caption-text" > Resources< / span > < / p >
< ul >
< li class = "toctree-l1" > < a class = "reference internal" href = "../AWESOME_REID.html" > Awesome-ReID< / a > < / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "../MODEL_ZOO.html" > Model Zoo< / a > < / li >
< / ul >
< / div >
< / div >
< / nav >
< section data-toggle = "wy-nav-shift" class = "wy-nav-content-wrap" >
< nav class = "wy-nav-top" aria-label = "top navigation" >
< i data-toggle = "wy-nav-top" class = "fa fa-bars" > < / i >
< a href = "../index.html" > torchreid< / a >
< / nav >
< div class = "wy-nav-content" >
< div class = "rst-content" >
< div role = "navigation" aria-label = "breadcrumbs navigation" >
< ul class = "wy-breadcrumbs" >
< li > < a href = "../index.html" > Docs< / a > » < / li >
< li > torchreid.engine< / li >
< li class = "wy-breadcrumbs-aside" >
< a href = "../_sources/pkg/engine.rst.txt" rel = "nofollow" > View page source< / a >
< / li >
< / ul >
< hr / >
< / div >
< div role = "main" class = "document" itemscope = "itemscope" itemtype = "http://schema.org/Article" >
< div itemprop = "articleBody" >
< div class = "section" id = "torchreid-engine" >
< span id = "id1" > < / span > < h1 > torchreid.engine< a class = "headerlink" href = "#torchreid-engine" title = "Permalink to this headline" > ¶< / a > < / h1 >
< div class = "section" id = "base-engine" >
< h2 > Base Engine< a class = "headerlink" href = "#base-engine" title = "Permalink to this headline" > ¶< / a > < / h2 >
< dl class = "class" >
< dt id = "torchreid.engine.engine.Engine" >
2019-09-01 15:05:27 +01:00
< em class = "property" > class < / em > < code class = "sig-prename descclassname" > torchreid.engine.engine.< / code > < code class = "sig-name descname" > Engine< / code > < span class = "sig-paren" > (< / span > < em class = "sig-param" > datamanager< / em > , < em class = "sig-param" > model< / em > , < em class = "sig-param" > optimizer=None< / em > , < em class = "sig-param" > scheduler=None< / em > , < em class = "sig-param" > use_gpu=True< / em > < span class = "sig-paren" > )< / span > < a class = "reference internal" href = "../_modules/torchreid/engine/engine.html#Engine" > < span class = "viewcode-link" > [source]< / span > < / a > < a class = "headerlink" href = "#torchreid.engine.engine.Engine" title = "Permalink to this definition" > ¶< / a > < / dt >
2019-03-24 17:22:43 +00:00
< dd > < p > A generic base Engine class for both image- and video-reid.< / p >
2019-09-01 15:05:27 +01:00
< dl class = "field-list simple" >
< dt class = "field-odd" > Parameters< / dt >
< dd class = "field-odd" > < ul class = "simple" >
< li > < p > < strong > datamanager< / strong > (< a class = "reference internal" href = "data.html#torchreid.data.datamanager.DataManager" title = "torchreid.data.datamanager.DataManager" > < em > DataManager< / em > < / a > ) – an instance of < code class = "docutils literal notranslate" > < span class = "pre" > torchreid.data.ImageDataManager< / span > < / code >
or < code class = "docutils literal notranslate" > < span class = "pre" > torchreid.data.VideoDataManager< / span > < / code > .< / p > < / li >
< li > < p > < strong > model< / strong > (< em > nn.Module< / em > ) – model instance.< / p > < / li >
< li > < p > < strong > optimizer< / strong > (< em > Optimizer< / em > ) – an Optimizer.< / p > < / li >
< li > < p > < strong > scheduler< / strong > (< em > LRScheduler< / em > < em > , < / em > < em > optional< / em > ) – if None, no learning rate decay will be performed.< / p > < / li >
< li > < p > < strong > use_gpu< / strong > (< em > bool< / em > < em > , < / em > < em > optional< / em > ) – use gpu. Default is True.< / p > < / li >
2019-03-24 17:22:43 +00:00
< / ul >
2019-09-01 15:05:27 +01:00
< / dd >
< / dl >
2019-03-24 17:22:43 +00:00
< dl class = "method" >
< dt id = "torchreid.engine.engine.Engine.run" >
2019-09-01 15:05:27 +01:00
< code class = "sig-name descname" > run< / code > < span class = "sig-paren" > (< / span > < em class = "sig-param" > save_dir='log', max_epoch=0, start_epoch=0, fixbase_epoch=0, open_layers=None, start_eval=0, eval_freq=-1, test_only=False, print_freq=10, dist_metric='euclidean', normalize_feature=False, visrank=False, visrank_topk=10, use_metric_cuhk03=False, ranks=[1, 5, 10, 20], rerank=False, visactmap=False< / em > < span class = "sig-paren" > )< / span > < a class = "reference internal" href = "../_modules/torchreid/engine/engine.html#Engine.run" > < span class = "viewcode-link" > [source]< / span > < / a > < a class = "headerlink" href = "#torchreid.engine.engine.Engine.run" title = "Permalink to this definition" > ¶< / a > < / dt >
2019-03-24 17:22:43 +00:00
< dd > < p > A unified pipeline for training and evaluating a model.< / p >
2019-09-01 15:05:27 +01:00
< dl class = "field-list simple" >
< dt class = "field-odd" > Parameters< / dt >
< dd class = "field-odd" > < ul class = "simple" >
< li > < p > < strong > save_dir< / strong > (< em > str< / em > ) – directory to save model.< / p > < / li >
< li > < p > < strong > max_epoch< / strong > (< em > int< / em > ) – maximum epoch.< / p > < / li >
< li > < p > < strong > start_epoch< / strong > (< em > int< / em > < em > , < / em > < em > optional< / em > ) – starting epoch. Default is 0.< / p > < / li >
< li > < p > < strong > fixbase_epoch< / strong > (< em > int< / em > < em > , < / em > < em > optional< / em > ) – number of epochs to train < code class = "docutils literal notranslate" > < span class = "pre" > open_layers< / span > < / code > (new layers)
2019-05-24 16:30:24 +01:00
while keeping base layers frozen. Default is 0. < code class = "docutils literal notranslate" > < span class = "pre" > fixbase_epoch< / span > < / code > is counted
2019-09-01 15:05:27 +01:00
in < code class = "docutils literal notranslate" > < span class = "pre" > max_epoch< / span > < / code > .< / p > < / li >
< li > < p > < strong > open_layers< / strong > (< em > str< / em > < em > or < / em > < em > list< / em > < em > , < / em > < em > optional< / em > ) – layers (attribute names) open for training.< / p > < / li >
< li > < p > < strong > start_eval< / strong > (< em > int< / em > < em > , < / em > < em > optional< / em > ) – from which epoch to start evaluation. Default is 0.< / p > < / li >
< li > < p > < strong > eval_freq< / strong > (< em > int< / em > < em > , < / em > < em > optional< / em > ) – evaluation frequency. Default is -1 (meaning evaluation
is only performed at the end of training).< / p > < / li >
< li > < p > < strong > test_only< / strong > (< em > bool< / em > < em > , < / em > < em > optional< / em > ) – if True, only runs evaluation on test datasets.
Default is False.< / p > < / li >
< li > < p > < strong > print_freq< / strong > (< em > int< / em > < em > , < / em > < em > optional< / em > ) – print frequency. Default is 10.< / p > < / li >
< li > < p > < strong > dist_metric< / strong > (< em > str< / em > < em > , < / em > < em > optional< / em > ) – distance metric used to compute distance matrix
between query and gallery. Default is “euclidean”.< / p > < / li >
< li > < p > < strong > normalize_feature< / strong > (< em > bool< / em > < em > , < / em > < em > optional< / em > ) – performs L2 normalization on feature vectors before
computing feature distance. Default is False.< / p > < / li >
< li > < p > < strong > visrank< / strong > (< em > bool< / em > < em > , < / em > < em > optional< / em > ) – visualizes ranked results. Default is False. It is recommended to
2019-08-03 23:16:36 +01:00
enable < code class = "docutils literal notranslate" > < span class = "pre" > visrank< / span > < / code > when < code class = "docutils literal notranslate" > < span class = "pre" > test_only< / span > < / code > is True. The ranked images will be saved to
2019-09-01 15:05:27 +01:00
“save_dir/visrank_dataset”, e.g. “save_dir/visrank_market1501”.< / p > < / li >
< li > < p > < strong > visrank_topk< / strong > (< em > int< / em > < em > , < / em > < em > optional< / em > ) – top-k ranked images to be visualized. Default is 10.< / p > < / li >
< li > < p > < strong > use_metric_cuhk03< / strong > (< em > bool< / em > < em > , < / em > < em > optional< / em > ) – use single-gallery-shot setting for cuhk03.
Default is False. This should be enabled when using cuhk03 classic split.< / p > < / li >
< li > < p > < strong > ranks< / strong > (< em > list< / em > < em > , < / em > < em > optional< / em > ) – cmc ranks to be computed. Default is [1, 5, 10, 20].< / p > < / li >
< li > < p > < strong > rerank< / strong > (< em > bool< / em > < em > , < / em > < em > optional< / em > ) – uses person re-ranking (by Zhong et al. CVPR’17).
Default is False. This is only enabled when test_only=True.< / p > < / li >
< li > < p > < strong > visactmap< / strong > (< em > bool< / em > < em > , < / em > < em > optional< / em > ) – visualizes activation maps. Default is False.< / p > < / li >
2019-03-24 17:22:43 +00:00
< / ul >
2019-09-01 15:05:27 +01:00
< / dd >
< / dl >
2019-03-24 17:22:43 +00:00
< / dd > < / dl >
< dl class = "method" >
< dt id = "torchreid.engine.engine.Engine.test" >
2019-09-01 15:05:27 +01:00
< code class = "sig-name descname" > test< / code > < span class = "sig-paren" > (< / span > < em class = "sig-param" > epoch, testloader, dist_metric='euclidean', normalize_feature=False, visrank=False, visrank_topk=10, save_dir='', use_metric_cuhk03=False, ranks=[1, 5, 10, 20], rerank=False< / em > < span class = "sig-paren" > )< / span > < a class = "reference internal" href = "../_modules/torchreid/engine/engine.html#Engine.test" > < span class = "viewcode-link" > [source]< / span > < / a > < a class = "headerlink" href = "#torchreid.engine.engine.Engine.test" title = "Permalink to this definition" > ¶< / a > < / dt >
2019-03-24 17:22:43 +00:00
< dd > < p > Tests model on target datasets.< / p >
< div class = "admonition note" >
2019-09-01 15:05:27 +01:00
< p class = "admonition-title" > Note< / p >
< p > This function is called in < code class = "docutils literal notranslate" > < span class = "pre" > run()< / span > < / code > .< / p >
2019-03-24 17:22:43 +00:00
< / div >
< div class = "admonition note" >
2019-09-01 15:05:27 +01:00
< p class = "admonition-title" > Note< / p >
< p > The test pipeline implemented in this function suits both image- and
2019-03-24 17:22:43 +00:00
video-reid. In general, a subclass of Engine only needs to re-implement
2019-08-03 23:16:36 +01:00
< code class = "docutils literal notranslate" > < span class = "pre" > _extract_features()< / span > < / code > and < code class = "docutils literal notranslate" > < span class = "pre" > _parse_data_for_eval()< / span > < / code > (most of the time),
2019-03-24 17:22:43 +00:00
but this is not strictly required. Please refer to the source code for more details.< / p >
< / div >
< / dd > < / dl >
< dl class = "method" >
< dt id = "torchreid.engine.engine.Engine.train" >
2019-09-01 15:05:27 +01:00
< code class = "sig-name descname" > train< / code > < span class = "sig-paren" > (< / span > < span class = "sig-paren" > )< / span > < a class = "reference internal" href = "../_modules/torchreid/engine/engine.html#Engine.train" > < span class = "viewcode-link" > [source]< / span > < / a > < a class = "headerlink" href = "#torchreid.engine.engine.Engine.train" title = "Permalink to this definition" > ¶< / a > < / dt >
2019-03-24 17:22:43 +00:00
< dd > < p > Performs training on source datasets for one epoch.< / p >
< p > This will be called every epoch in < code class = "docutils literal notranslate" > < span class = "pre" > run()< / span > < / code > , e.g.< / p >
< div class = "highlight-python notranslate" > < div class = "highlight" > < pre > < span > < / span > < span class = "k" > for< / span > < span class = "n" > epoch< / span > < span class = "ow" > in< / span > < span class = "nb" > range< / span > < span class = "p" > (< / span > < span class = "n" > start_epoch< / span > < span class = "p" > ,< / span > < span class = "n" > max_epoch< / span > < span class = "p" > ):< / span >
< span class = "bp" > self< / span > < span class = "o" > .< / span > < span class = "n" > train< / span > < span class = "p" > (< / span > < span class = "n" > some_arguments< / span > < span class = "p" > )< / span >
< / pre > < / div >
< / div >
< div class = "admonition note" >
2019-09-01 15:05:27 +01:00
< p class = "admonition-title" > Note< / p >
< p > This must be implemented in subclasses.< / p >
2019-03-24 17:22:43 +00:00
< / div >
< / dd > < / dl >
2019-08-03 23:16:36 +01:00
< dl class = "method" >
< dt id = "torchreid.engine.engine.Engine.visactmap" >
2019-09-01 15:05:27 +01:00
< code class = "sig-name descname" > visactmap< / code > < span class = "sig-paren" > (< / span > < em class = "sig-param" > testloader< / em > , < em class = "sig-param" > save_dir< / em > , < em class = "sig-param" > width< / em > , < em class = "sig-param" > height< / em > , < em class = "sig-param" > print_freq< / em > < span class = "sig-paren" > )< / span > < a class = "reference internal" href = "../_modules/torchreid/engine/engine.html#Engine.visactmap" > < span class = "viewcode-link" > [source]< / span > < / a > < a class = "headerlink" href = "#torchreid.engine.engine.Engine.visactmap" title = "Permalink to this definition" > ¶< / a > < / dt >
2019-08-03 23:16:36 +01:00
< dd > < p > Visualizes CNN activation maps to see where the CNN focuses when extracting features.< / p >
< p > This function takes as input the query images of target datasets.< / p >
2019-09-01 15:05:27 +01:00
< dl class = "simple" >
< dt > Reference:< / dt > < dd > < ul class = "simple" >
< li > < p > Zagoruyko and Komodakis. Paying more attention to attention: Improving the
performance of convolutional neural networks via attention transfer. ICLR, 2017.< / p > < / li >
< li > < p > Zhou et al. Omni-Scale Feature Learning for Person Re-Identification. ICCV, 2019.< / p > < / li >
2019-08-03 23:16:36 +01:00
< / ul >
< / dd >
< / dl >
< / dd > < / dl >
2019-03-24 17:22:43 +00:00
< / dd > < / dl >
< / div >
< div class = "section" id = "image-engines" >
< h2 > Image Engines< a class = "headerlink" href = "#image-engines" title = "Permalink to this headline" > ¶< / a > < / h2 >
< dl class = "class" >
< dt id = "torchreid.engine.image.softmax.ImageSoftmaxEngine" >
2019-09-01 15:05:27 +01:00
< em class = "property" > class < / em > < code class = "sig-prename descclassname" > torchreid.engine.image.softmax.< / code > < code class = "sig-name descname" > ImageSoftmaxEngine< / code > < span class = "sig-paren" > (< / span > < em class = "sig-param" > datamanager< / em > , < em class = "sig-param" > model< / em > , < em class = "sig-param" > optimizer< / em > , < em class = "sig-param" > scheduler=None< / em > , < em class = "sig-param" > use_gpu=True< / em > , < em class = "sig-param" > label_smooth=True< / em > < span class = "sig-paren" > )< / span > < a class = "reference internal" href = "../_modules/torchreid/engine/image/softmax.html#ImageSoftmaxEngine" > < span class = "viewcode-link" > [source]< / span > < / a > < a class = "headerlink" href = "#torchreid.engine.image.softmax.ImageSoftmaxEngine" title = "Permalink to this definition" > ¶< / a > < / dt >
2019-03-24 17:22:43 +00:00
< dd > < p > Softmax-loss engine for image-reid.< / p >
2019-09-01 15:05:27 +01:00
< dl class = "field-list simple" >
< dt class = "field-odd" > Parameters< / dt >
< dd class = "field-odd" > < ul class = "simple" >
< li > < p > < strong > datamanager< / strong > (< a class = "reference internal" href = "data.html#torchreid.data.datamanager.DataManager" title = "torchreid.data.datamanager.DataManager" > < em > DataManager< / em > < / a > ) – an instance of < code class = "docutils literal notranslate" > < span class = "pre" > torchreid.data.ImageDataManager< / span > < / code >
or < code class = "docutils literal notranslate" > < span class = "pre" > torchreid.data.VideoDataManager< / span > < / code > .< / p > < / li >
< li > < p > < strong > model< / strong > (< em > nn.Module< / em > ) – model instance.< / p > < / li >
< li > < p > < strong > optimizer< / strong > (< em > Optimizer< / em > ) – an Optimizer.< / p > < / li >
< li > < p > < strong > scheduler< / strong > (< em > LRScheduler< / em > < em > , < / em > < em > optional< / em > ) – if None, no learning rate decay will be performed.< / p > < / li >
< li > < p > < strong > use_gpu< / strong > (< em > bool< / em > < em > , < / em > < em > optional< / em > ) – use gpu. Default is True.< / p > < / li >
< li > < p > < strong > label_smooth< / strong > (< em > bool< / em > < em > , < / em > < em > optional< / em > ) – use label smoothing regularizer. Default is True.< / p > < / li >
2019-03-24 17:22:43 +00:00
< / ul >
2019-09-01 15:05:27 +01:00
< / dd >
< / dl >
2019-03-24 17:22:43 +00:00
< p > Examples:< / p >
< div class = "highlight-default notranslate" > < div class = "highlight" > < pre > < span > < / span > < span class = "kn" > import< / span > < span class = "nn" > torch< / span >
< span class = "kn" > import< / span > < span class = "nn" > torchreid< / span >
< span class = "n" > datamanager< / span > < span class = "o" > =< / span > < span class = "n" > torchreid< / span > < span class = "o" > .< / span > < span class = "n" > data< / span > < span class = "o" > .< / span > < span class = "n" > ImageDataManager< / span > < span class = "p" > (< / span >
< span class = "n" > root< / span > < span class = "o" > =< / span > < span class = "s1" > ' path/to/reid-data' < / span > < span class = "p" > ,< / span >
< span class = "n" > sources< / span > < span class = "o" > =< / span > < span class = "s1" > ' market1501' < / span > < span class = "p" > ,< / span >
< span class = "n" > height< / span > < span class = "o" > =< / span > < span class = "mi" > 256< / span > < span class = "p" > ,< / span >
< span class = "n" > width< / span > < span class = "o" > =< / span > < span class = "mi" > 128< / span > < span class = "p" > ,< / span >
< span class = "n" > combineall< / span > < span class = "o" > =< / span > < span class = "kc" > False< / span > < span class = "p" > ,< / span >
< span class = "n" > batch_size< / span > < span class = "o" > =< / span > < span class = "mi" > 32< / span >
< span class = "p" > )< / span >
< span class = "n" > model< / span > < span class = "o" > =< / span > < span class = "n" > torchreid< / span > < span class = "o" > .< / span > < span class = "n" > models< / span > < span class = "o" > .< / span > < span class = "n" > build_model< / span > < span class = "p" > (< / span >
< span class = "n" > name< / span > < span class = "o" > =< / span > < span class = "s1" > ' resnet50' < / span > < span class = "p" > ,< / span >
< span class = "n" > num_classes< / span > < span class = "o" > =< / span > < span class = "n" > datamanager< / span > < span class = "o" > .< / span > < span class = "n" > num_train_pids< / span > < span class = "p" > ,< / span >
< span class = "n" > loss< / span > < span class = "o" > =< / span > < span class = "s1" > ' softmax' < / span >
< span class = "p" > )< / span >
< span class = "n" > model< / span > < span class = "o" > =< / span > < span class = "n" > model< / span > < span class = "o" > .< / span > < span class = "n" > cuda< / span > < span class = "p" > ()< / span >
< span class = "n" > optimizer< / span > < span class = "o" > =< / span > < span class = "n" > torchreid< / span > < span class = "o" > .< / span > < span class = "n" > optim< / span > < span class = "o" > .< / span > < span class = "n" > build_optimizer< / span > < span class = "p" > (< / span >
< span class = "n" > model< / span > < span class = "p" > ,< / span > < span class = "n" > optim< / span > < span class = "o" > =< / span > < span class = "s1" > ' adam' < / span > < span class = "p" > ,< / span > < span class = "n" > lr< / span > < span class = "o" > =< / span > < span class = "mf" > 0.0003< / span >
< span class = "p" > )< / span >
< span class = "n" > scheduler< / span > < span class = "o" > =< / span > < span class = "n" > torchreid< / span > < span class = "o" > .< / span > < span class = "n" > optim< / span > < span class = "o" > .< / span > < span class = "n" > build_lr_scheduler< / span > < span class = "p" > (< / span >
< span class = "n" > optimizer< / span > < span class = "p" > ,< / span >
< span class = "n" > lr_scheduler< / span > < span class = "o" > =< / span > < span class = "s1" > ' single_step' < / span > < span class = "p" > ,< / span >
< span class = "n" > stepsize< / span > < span class = "o" > =< / span > < span class = "mi" > 20< / span >
< span class = "p" > )< / span >
< span class = "n" > engine< / span > < span class = "o" > =< / span > < span class = "n" > torchreid< / span > < span class = "o" > .< / span > < span class = "n" > engine< / span > < span class = "o" > .< / span > < span class = "n" > ImageSoftmaxEngine< / span > < span class = "p" > (< / span >
< span class = "n" > datamanager< / span > < span class = "p" > ,< / span > < span class = "n" > model< / span > < span class = "p" > ,< / span > < span class = "n" > optimizer< / span > < span class = "p" > ,< / span > < span class = "n" > scheduler< / span > < span class = "o" > =< / span > < span class = "n" > scheduler< / span >
< span class = "p" > )< / span >
< span class = "n" > engine< / span > < span class = "o" > .< / span > < span class = "n" > run< / span > < span class = "p" > (< / span >
< span class = "n" > max_epoch< / span > < span class = "o" > =< / span > < span class = "mi" > 60< / span > < span class = "p" > ,< / span >
< span class = "n" > save_dir< / span > < span class = "o" > =< / span > < span class = "s1" > ' log/resnet50-softmax-market1501' < / span > < span class = "p" > ,< / span >
< span class = "n" > print_freq< / span > < span class = "o" > =< / span > < span class = "mi" > 10< / span >
< span class = "p" > )< / span >
< / pre > < / div >
< / div >
< dl class = "method" >
< dt id = "torchreid.engine.image.softmax.ImageSoftmaxEngine.train" >
2019-09-01 15:05:27 +01:00
< code class = "sig-name descname" > train< / code > < span class = "sig-paren" > (< / span > < em class = "sig-param" > epoch< / em > , < em class = "sig-param" > max_epoch< / em > , < em class = "sig-param" > trainloader< / em > , < em class = "sig-param" > fixbase_epoch=0< / em > , < em class = "sig-param" > open_layers=None< / em > , < em class = "sig-param" > print_freq=10< / em > < span class = "sig-paren" > )< / span > < a class = "reference internal" href = "../_modules/torchreid/engine/image/softmax.html#ImageSoftmaxEngine.train" > < span class = "viewcode-link" > [source]< / span > < / a > < a class = "headerlink" href = "#torchreid.engine.image.softmax.ImageSoftmaxEngine.train" title = "Permalink to this definition" > ¶< / a > < / dt >
2019-05-24 16:30:24 +01:00
< dd > < p > Performs training on source datasets for one epoch.< / p >
< p > This will be called every epoch in < code class = "docutils literal notranslate" > < span class = "pre" > run()< / span > < / code > , e.g.< / p >
< div class = "highlight-python notranslate" > < div class = "highlight" > < pre > < span > < / span > < span class = "k" > for< / span > < span class = "n" > epoch< / span > < span class = "ow" > in< / span > < span class = "nb" > range< / span > < span class = "p" > (< / span > < span class = "n" > start_epoch< / span > < span class = "p" > ,< / span > < span class = "n" > max_epoch< / span > < span class = "p" > ):< / span >
< span class = "bp" > self< / span > < span class = "o" > .< / span > < span class = "n" > train< / span > < span class = "p" > (< / span > < span class = "n" > some_arguments< / span > < span class = "p" > )< / span >
< / pre > < / div >
< / div >
< div class = "admonition note" >
2019-09-01 15:05:27 +01:00
< p class = "admonition-title" > Note< / p >
< p > This must be implemented in subclasses.< / p >
2019-05-24 16:30:24 +01:00
< / div >
2019-03-24 17:22:43 +00:00
< / dd > < / dl >
< / dd > < / dl >
< dl class = "class" >
< dt id = "torchreid.engine.image.triplet.ImageTripletEngine" >
2019-09-01 15:05:27 +01:00
< em class = "property" > class < / em > < code class = "sig-prename descclassname" > torchreid.engine.image.triplet.< / code > < code class = "sig-name descname" > ImageTripletEngine< / code > < span class = "sig-paren" > (< / span > < em class = "sig-param" > datamanager< / em > , < em class = "sig-param" > model< / em > , < em class = "sig-param" > optimizer< / em > , < em class = "sig-param" > margin=0.3< / em > , < em class = "sig-param" > weight_t=1< / em > , < em class = "sig-param" > weight_x=1< / em > , < em class = "sig-param" > scheduler=None< / em > , < em class = "sig-param" > use_gpu=True< / em > , < em class = "sig-param" > label_smooth=True< / em > < span class = "sig-paren" > )< / span > < a class = "reference internal" href = "../_modules/torchreid/engine/image/triplet.html#ImageTripletEngine" > < span class = "viewcode-link" > [source]< / span > < / a > < a class = "headerlink" href = "#torchreid.engine.image.triplet.ImageTripletEngine" title = "Permalink to this definition" > ¶< / a > < / dt >
2019-03-24 17:22:43 +00:00
< dd > < p > Triplet-loss engine for image-reid.< / p >
2019-09-01 15:05:27 +01:00
< dl class = "field-list simple" >
< dt class = "field-odd" > Parameters< / dt >
< dd class = "field-odd" > < ul class = "simple" >
< li > < p > < strong > datamanager< / strong > (< a class = "reference internal" href = "data.html#torchreid.data.datamanager.DataManager" title = "torchreid.data.datamanager.DataManager" > < em > DataManager< / em > < / a > ) – an instance of < code class = "docutils literal notranslate" > < span class = "pre" > torchreid.data.ImageDataManager< / span > < / code >
or < code class = "docutils literal notranslate" > < span class = "pre" > torchreid.data.VideoDataManager< / span > < / code > .< / p > < / li >
< li > < p > < strong > model< / strong > (< em > nn.Module< / em > ) – model instance.< / p > < / li >
< li > < p > < strong > optimizer< / strong > (< em > Optimizer< / em > ) – an Optimizer.< / p > < / li >
< li > < p > < strong > margin< / strong > (< em > float< / em > < em > , < / em > < em > optional< / em > ) – margin for triplet loss. Default is 0.3.< / p > < / li >
< li > < p > < strong > weight_t< / strong > (< em > float< / em > < em > , < / em > < em > optional< / em > ) – weight for triplet loss. Default is 1.< / p > < / li >
< li > < p > < strong > weight_x< / strong > (< em > float< / em > < em > , < / em > < em > optional< / em > ) – weight for softmax loss. Default is 1.< / p > < / li >
< li > < p > < strong > scheduler< / strong > (< em > LRScheduler< / em > < em > , < / em > < em > optional< / em > ) – if None, no learning rate decay will be performed.< / p > < / li >
< li > < p > < strong > use_gpu< / strong > (< em > bool< / em > < em > , < / em > < em > optional< / em > ) – use gpu. Default is True.< / p > < / li >
< li > < p > < strong > label_smooth< / strong > (< em > bool< / em > < em > , < / em > < em > optional< / em > ) – use label smoothing regularizer. Default is True.< / p > < / li >
2019-03-24 17:22:43 +00:00
< / ul >
2019-09-01 15:05:27 +01:00
< / dd >
< / dl >
2019-03-24 17:22:43 +00:00
< p > Examples:< / p >
< div class = "highlight-default notranslate" > < div class = "highlight" > < pre > < span > < / span > < span class = "kn" > import< / span > < span class = "nn" > torch< / span >
< span class = "kn" > import< / span > < span class = "nn" > torchreid< / span >
< span class = "n" > datamanager< / span > < span class = "o" > =< / span > < span class = "n" > torchreid< / span > < span class = "o" > .< / span > < span class = "n" > data< / span > < span class = "o" > .< / span > < span class = "n" > ImageDataManager< / span > < span class = "p" > (< / span >
< span class = "n" > root< / span > < span class = "o" > =< / span > < span class = "s1" > ' path/to/reid-data' < / span > < span class = "p" > ,< / span >
< span class = "n" > sources< / span > < span class = "o" > =< / span > < span class = "s1" > ' market1501' < / span > < span class = "p" > ,< / span >
< span class = "n" > height< / span > < span class = "o" > =< / span > < span class = "mi" > 256< / span > < span class = "p" > ,< / span >
< span class = "n" > width< / span > < span class = "o" > =< / span > < span class = "mi" > 128< / span > < span class = "p" > ,< / span >
< span class = "n" > combineall< / span > < span class = "o" > =< / span > < span class = "kc" > False< / span > < span class = "p" > ,< / span >
< span class = "n" > batch_size< / span > < span class = "o" > =< / span > < span class = "mi" > 32< / span > < span class = "p" > ,< / span >
< span class = "n" > num_instances< / span > < span class = "o" > =< / span > < span class = "mi" > 4< / span > < span class = "p" > ,< / span >
< span class = "n" > train_sampler< / span > < span class = "o" > =< / span > < span class = "s1" > ' RandomIdentitySampler' < / span > < span class = "c1" > # this is important< / span >
< span class = "p" > )< / span >
< span class = "n" > model< / span > < span class = "o" > =< / span > < span class = "n" > torchreid< / span > < span class = "o" > .< / span > < span class = "n" > models< / span > < span class = "o" > .< / span > < span class = "n" > build_model< / span > < span class = "p" > (< / span >
< span class = "n" > name< / span > < span class = "o" > =< / span > < span class = "s1" > ' resnet50' < / span > < span class = "p" > ,< / span >
< span class = "n" > num_classes< / span > < span class = "o" > =< / span > < span class = "n" > datamanager< / span > < span class = "o" > .< / span > < span class = "n" > num_train_pids< / span > < span class = "p" > ,< / span >
< span class = "n" > loss< / span > < span class = "o" > =< / span > < span class = "s1" > ' triplet' < / span >
< span class = "p" > )< / span >
< span class = "n" > model< / span > < span class = "o" > =< / span > < span class = "n" > model< / span > < span class = "o" > .< / span > < span class = "n" > cuda< / span > < span class = "p" > ()< / span >
< span class = "n" > optimizer< / span > < span class = "o" > =< / span > < span class = "n" > torchreid< / span > < span class = "o" > .< / span > < span class = "n" > optim< / span > < span class = "o" > .< / span > < span class = "n" > build_optimizer< / span > < span class = "p" > (< / span >
< span class = "n" > model< / span > < span class = "p" > ,< / span > < span class = "n" > optim< / span > < span class = "o" > =< / span > < span class = "s1" > ' adam' < / span > < span class = "p" > ,< / span > < span class = "n" > lr< / span > < span class = "o" > =< / span > < span class = "mf" > 0.0003< / span >
< span class = "p" > )< / span >
< span class = "n" > scheduler< / span > < span class = "o" > =< / span > < span class = "n" > torchreid< / span > < span class = "o" > .< / span > < span class = "n" > optim< / span > < span class = "o" > .< / span > < span class = "n" > build_lr_scheduler< / span > < span class = "p" > (< / span >
< span class = "n" > optimizer< / span > < span class = "p" > ,< / span >
< span class = "n" > lr_scheduler< / span > < span class = "o" > =< / span > < span class = "s1" > ' single_step' < / span > < span class = "p" > ,< / span >
< span class = "n" > stepsize< / span > < span class = "o" > =< / span > < span class = "mi" > 20< / span >
< span class = "p" > )< / span >
< span class = "n" > engine< / span > < span class = "o" > =< / span > < span class = "n" > torchreid< / span > < span class = "o" > .< / span > < span class = "n" > engine< / span > < span class = "o" > .< / span > < span class = "n" > ImageTripletEngine< / span > < span class = "p" > (< / span >
< span class = "n" > datamanager< / span > < span class = "p" > ,< / span > < span class = "n" > model< / span > < span class = "p" > ,< / span > < span class = "n" > optimizer< / span > < span class = "p" > ,< / span > < span class = "n" > margin< / span > < span class = "o" > =< / span > < span class = "mf" > 0.3< / span > < span class = "p" > ,< / span >
< span class = "n" > weight_t< / span > < span class = "o" > =< / span > < span class = "mf" > 0.7< / span > < span class = "p" > ,< / span > < span class = "n" > weight_x< / span > < span class = "o" > =< / span > < span class = "mi" > 1< / span > < span class = "p" > ,< / span > < span class = "n" > scheduler< / span > < span class = "o" > =< / span > < span class = "n" > scheduler< / span >
< span class = "p" > )< / span >
< span class = "n" > engine< / span > < span class = "o" > .< / span > < span class = "n" > run< / span > < span class = "p" > (< / span >
< span class = "n" > max_epoch< / span > < span class = "o" > =< / span > < span class = "mi" > 60< / span > < span class = "p" > ,< / span >
< span class = "n" > save_dir< / span > < span class = "o" > =< / span > < span class = "s1" > ' log/resnet50-triplet-market1501' < / span > < span class = "p" > ,< / span >
< span class = "n" > print_freq< / span > < span class = "o" > =< / span > < span class = "mi" > 10< / span >
< span class = "p" > )< / span >
< / pre > < / div >
< / div >
< dl class = "method" >
< dt id = "torchreid.engine.image.triplet.ImageTripletEngine.train" >
2019-09-01 15:05:27 +01:00
< code class = "sig-name descname" > train< / code > < span class = "sig-paren" > (< / span > < em class = "sig-param" > epoch< / em > , < em class = "sig-param" > max_epoch< / em > , < em class = "sig-param" > trainloader< / em > , < em class = "sig-param" > fixbase_epoch=0< / em > , < em class = "sig-param" > open_layers=None< / em > , < em class = "sig-param" > print_freq=10< / em > < span class = "sig-paren" > )< / span > < a class = "reference internal" href = "../_modules/torchreid/engine/image/triplet.html#ImageTripletEngine.train" > < span class = "viewcode-link" > [source]< / span > < / a > < a class = "headerlink" href = "#torchreid.engine.image.triplet.ImageTripletEngine.train" title = "Permalink to this definition" > ¶< / a > < / dt >
2019-05-24 16:30:24 +01:00
< dd > < p > Performs training on source datasets for one epoch.< / p >
< p > This will be called every epoch in < code class = "docutils literal notranslate" > < span class = "pre" > run()< / span > < / code > , e.g.< / p >
< div class = "highlight-python notranslate" > < div class = "highlight" > < pre > < span > < / span > < span class = "k" > for< / span > < span class = "n" > epoch< / span > < span class = "ow" > in< / span > < span class = "nb" > range< / span > < span class = "p" > (< / span > < span class = "n" > start_epoch< / span > < span class = "p" > ,< / span > < span class = "n" > max_epoch< / span > < span class = "p" > ):< / span >
< span class = "bp" > self< / span > < span class = "o" > .< / span > < span class = "n" > train< / span > < span class = "p" > (< / span > < span class = "n" > some_arguments< / span > < span class = "p" > )< / span >
< / pre > < / div >
< / div >
< div class = "admonition note" >
2019-09-01 15:05:27 +01:00
< p class = "admonition-title" > Note< / p >
< p > This must be implemented in subclasses.< / p >
2019-05-24 16:30:24 +01:00
< / div >
2019-03-24 17:22:43 +00:00
< / dd > < / dl >
< / dd > < / dl >
< / div >
< div class = "section" id = "video-engines" >
< h2 > Video Engines< a class = "headerlink" href = "#video-engines" title = "Permalink to this headline" > ¶< / a > < / h2 >
< dl class = "class" >
< dt id = "torchreid.engine.video.softmax.VideoSoftmaxEngine" >
2019-09-01 15:05:27 +01:00
< em class = "property" > class < / em > < code class = "sig-prename descclassname" > torchreid.engine.video.softmax.< / code > < code class = "sig-name descname" > VideoSoftmaxEngine< / code > < span class = "sig-paren" > (< / span > < em class = "sig-param" > datamanager< / em > , < em class = "sig-param" > model< / em > , < em class = "sig-param" > optimizer< / em > , < em class = "sig-param" > scheduler=None< / em > , < em class = "sig-param" > use_gpu=True< / em > , < em class = "sig-param" > label_smooth=True< / em > , < em class = "sig-param" > pooling_method='avg'< / em > < span class = "sig-paren" > )< / span > < a class = "reference internal" href = "../_modules/torchreid/engine/video/softmax.html#VideoSoftmaxEngine" > < span class = "viewcode-link" > [source]< / span > < / a > < a class = "headerlink" href = "#torchreid.engine.video.softmax.VideoSoftmaxEngine" title = "Permalink to this definition" > ¶< / a > < / dt >
2019-03-24 17:22:43 +00:00
< dd > < p > Softmax-loss engine for video-reid.< / p >
2019-09-01 15:05:27 +01:00
< dl class = "field-list simple" >
< dt class = "field-odd" > Parameters< / dt >
< dd class = "field-odd" > < ul class = "simple" >
< li > < p > < strong > datamanager< / strong > (< a class = "reference internal" href = "data.html#torchreid.data.datamanager.DataManager" title = "torchreid.data.datamanager.DataManager" > < em > DataManager< / em > < / a > ) – an instance of < code class = "docutils literal notranslate" > < span class = "pre" > torchreid.data.ImageDataManager< / span > < / code >
or < code class = "docutils literal notranslate" > < span class = "pre" > torchreid.data.VideoDataManager< / span > < / code > .< / p > < / li >
< li > < p > < strong > model< / strong > (< em > nn.Module< / em > ) – model instance.< / p > < / li >
< li > < p > < strong > optimizer< / strong > (< em > Optimizer< / em > ) – an Optimizer.< / p > < / li >
< li > < p > < strong > scheduler< / strong > (< em > LRScheduler< / em > < em > , < / em > < em > optional< / em > ) – if None, no learning rate decay will be performed.< / p > < / li >
< li > < p > < strong > use_gpu< / strong > (< em > bool< / em > < em > , < / em > < em > optional< / em > ) – use gpu. Default is True.< / p > < / li >
< li > < p > < strong > label_smooth< / strong > (< em > bool< / em > < em > , < / em > < em > optional< / em > ) – use label smoothing regularizer. Default is True.< / p > < / li >
< li > < p > < strong > pooling_method< / strong > (< em > str< / em > < em > , < / em > < em > optional< / em > ) – how to pool features for a tracklet.
Default is “avg” (average). Choices are [“avg”, “max”].< / p > < / li >
2019-03-24 17:22:43 +00:00
< / ul >
2019-09-01 15:05:27 +01:00
< / dd >
< / dl >
2019-03-24 17:22:43 +00:00
< p > Examples:< / p >
< div class = "highlight-default notranslate" > < div class = "highlight" > < pre > < span > < / span > < span class = "kn" > import< / span > < span class = "nn" > torch< / span >
< span class = "kn" > import< / span > < span class = "nn" > torchreid< / span >
< span class = "c1" > # Each batch contains batch_size*seq_len images< / span >
< span class = "n" > datamanager< / span > < span class = "o" > =< / span > < span class = "n" > torchreid< / span > < span class = "o" > .< / span > < span class = "n" > data< / span > < span class = "o" > .< / span > < span class = "n" > VideoDataManager< / span > < span class = "p" > (< / span >
< span class = "n" > root< / span > < span class = "o" > =< / span > < span class = "s1" > ' path/to/reid-data' < / span > < span class = "p" > ,< / span >
< span class = "n" > sources< / span > < span class = "o" > =< / span > < span class = "s1" > ' mars' < / span > < span class = "p" > ,< / span >
< span class = "n" > height< / span > < span class = "o" > =< / span > < span class = "mi" > 256< / span > < span class = "p" > ,< / span >
< span class = "n" > width< / span > < span class = "o" > =< / span > < span class = "mi" > 128< / span > < span class = "p" > ,< / span >
< span class = "n" > combineall< / span > < span class = "o" > =< / span > < span class = "kc" > False< / span > < span class = "p" > ,< / span >
< span class = "n" > batch_size< / span > < span class = "o" > =< / span > < span class = "mi" > 8< / span > < span class = "p" > ,< / span > < span class = "c1" > # number of tracklets< / span >
< span class = "n" > seq_len< / span > < span class = "o" > =< / span > < span class = "mi" > 15< / span > < span class = "c1" > # number of images in each tracklet< / span >
< span class = "p" > )< / span >
< span class = "n" > model< / span > < span class = "o" > =< / span > < span class = "n" > torchreid< / span > < span class = "o" > .< / span > < span class = "n" > models< / span > < span class = "o" > .< / span > < span class = "n" > build_model< / span > < span class = "p" > (< / span >
< span class = "n" > name< / span > < span class = "o" > =< / span > < span class = "s1" > ' resnet50' < / span > < span class = "p" > ,< / span >
< span class = "n" > num_classes< / span > < span class = "o" > =< / span > < span class = "n" > datamanager< / span > < span class = "o" > .< / span > < span class = "n" > num_train_pids< / span > < span class = "p" > ,< / span >
< span class = "n" > loss< / span > < span class = "o" > =< / span > < span class = "s1" > ' softmax' < / span >
< span class = "p" > )< / span >
< span class = "n" > model< / span > < span class = "o" > =< / span > < span class = "n" > model< / span > < span class = "o" > .< / span > < span class = "n" > cuda< / span > < span class = "p" > ()< / span >
< span class = "n" > optimizer< / span > < span class = "o" > =< / span > < span class = "n" > torchreid< / span > < span class = "o" > .< / span > < span class = "n" > optim< / span > < span class = "o" > .< / span > < span class = "n" > build_optimizer< / span > < span class = "p" > (< / span >
< span class = "n" > model< / span > < span class = "p" > ,< / span > < span class = "n" > optim< / span > < span class = "o" > =< / span > < span class = "s1" > ' adam' < / span > < span class = "p" > ,< / span > < span class = "n" > lr< / span > < span class = "o" > =< / span > < span class = "mf" > 0.0003< / span >
< span class = "p" > )< / span >
< span class = "n" > scheduler< / span > < span class = "o" > =< / span > < span class = "n" > torchreid< / span > < span class = "o" > .< / span > < span class = "n" > optim< / span > < span class = "o" > .< / span > < span class = "n" > build_lr_scheduler< / span > < span class = "p" > (< / span >
< span class = "n" > optimizer< / span > < span class = "p" > ,< / span >
< span class = "n" > lr_scheduler< / span > < span class = "o" > =< / span > < span class = "s1" > ' single_step' < / span > < span class = "p" > ,< / span >
< span class = "n" > stepsize< / span > < span class = "o" > =< / span > < span class = "mi" > 20< / span >
< span class = "p" > )< / span >
< span class = "n" > engine< / span > < span class = "o" > =< / span > < span class = "n" > torchreid< / span > < span class = "o" > .< / span > < span class = "n" > engine< / span > < span class = "o" > .< / span > < span class = "n" > VideoSoftmaxEngine< / span > < span class = "p" > (< / span >
< span class = "n" > datamanager< / span > < span class = "p" > ,< / span > < span class = "n" > model< / span > < span class = "p" > ,< / span > < span class = "n" > optimizer< / span > < span class = "p" > ,< / span > < span class = "n" > scheduler< / span > < span class = "o" > =< / span > < span class = "n" > scheduler< / span > < span class = "p" > ,< / span >
< span class = "n" > pooling_method< / span > < span class = "o" > =< / span > < span class = "s1" > ' avg' < / span >
< span class = "p" > )< / span >
< span class = "n" > engine< / span > < span class = "o" > .< / span > < span class = "n" > run< / span > < span class = "p" > (< / span >
< span class = "n" > max_epoch< / span > < span class = "o" > =< / span > < span class = "mi" > 60< / span > < span class = "p" > ,< / span >
< span class = "n" > save_dir< / span > < span class = "o" > =< / span > < span class = "s1" > ' log/resnet50-softmax-mars' < / span > < span class = "p" > ,< / span >
< span class = "n" > print_freq< / span > < span class = "o" > =< / span > < span class = "mi" > 10< / span >
< span class = "p" > )< / span >
< / pre > < / div >
< / div >
< / dd > < / dl >
< dl class = "class" >
< dt id = "torchreid.engine.video.triplet.VideoTripletEngine" >
2019-09-01 15:05:27 +01:00
< em class = "property" > class < / em > < code class = "sig-prename descclassname" > torchreid.engine.video.triplet.< / code > < code class = "sig-name descname" > VideoTripletEngine< / code > < span class = "sig-paren" > (< / span > < em class = "sig-param" > datamanager< / em > , < em class = "sig-param" > model< / em > , < em class = "sig-param" > optimizer< / em > , < em class = "sig-param" > margin=0.3< / em > , < em class = "sig-param" > weight_t=1< / em > , < em class = "sig-param" > weight_x=1< / em > , < em class = "sig-param" > scheduler=None< / em > , < em class = "sig-param" > use_gpu=False< / em > , < em class = "sig-param" > label_smooth=True< / em > , < em class = "sig-param" > pooling_method='avg'< / em > < span class = "sig-paren" > )< / span > < a class = "reference internal" href = "../_modules/torchreid/engine/video/triplet.html#VideoTripletEngine" > < span class = "viewcode-link" > [source]< / span > < / a > < a class = "headerlink" href = "#torchreid.engine.video.triplet.VideoTripletEngine" title = "Permalink to this definition" > ¶< / a > < / dt >
2019-03-24 17:22:43 +00:00
< dd > < p > Triplet-loss engine for video-reid.< / p >
2019-09-01 15:05:27 +01:00
< dl class = "field-list simple" >
< dt class = "field-odd" > Parameters< / dt >
< dd class = "field-odd" > < ul class = "simple" >
< li > < p > < strong > datamanager< / strong > (< a class = "reference internal" href = "data.html#torchreid.data.datamanager.DataManager" title = "torchreid.data.datamanager.DataManager" > < em > DataManager< / em > < / a > ) – an instance of < code class = "docutils literal notranslate" > < span class = "pre" > torchreid.data.ImageDataManager< / span > < / code >
or < code class = "docutils literal notranslate" > < span class = "pre" > torchreid.data.VideoDataManager< / span > < / code > .< / p > < / li >
< li > < p > < strong > model< / strong > (< em > nn.Module< / em > ) – model instance.< / p > < / li >
< li > < p > < strong > optimizer< / strong > (< em > Optimizer< / em > ) – an Optimizer.< / p > < / li >
< li > < p > < strong > margin< / strong > (< em > float< / em > < em > , < / em > < em > optional< / em > ) – margin for triplet loss. Default is 0.3.< / p > < / li >
< li > < p > < strong > weight_t< / strong > (< em > float< / em > < em > , < / em > < em > optional< / em > ) – weight for triplet loss. Default is 1.< / p > < / li >
< li > < p > < strong > weight_x< / strong > (< em > float< / em > < em > , < / em > < em > optional< / em > ) – weight for softmax loss. Default is 1.< / p > < / li >
< li > < p > < strong > scheduler< / strong > (< em > LRScheduler< / em > < em > , < / em > < em > optional< / em > ) – if None, no learning rate decay will be performed.< / p > < / li >
< li > < p > < strong > use_gpu< / strong > (< em > bool< / em > < em > , < / em > < em > optional< / em > ) – use gpu. Default is False.< / p > < / li >
< li > < p > < strong > label_smooth< / strong > (< em > bool< / em > < em > , < / em > < em > optional< / em > ) – use label smoothing regularizer. Default is True.< / p > < / li >
< li > < p > < strong > pooling_method< / strong > (< em > str< / em > < em > , < / em > < em > optional< / em > ) – how to pool features for a tracklet.
Default is “avg” (average). Choices are [“avg”, “max”].< / p > < / li >
2019-03-24 17:22:43 +00:00
< / ul >
2019-09-01 15:05:27 +01:00
< / dd >
< / dl >
2019-03-24 17:22:43 +00:00
< p > Examples:< / p >
< div class = "highlight-default notranslate" > < div class = "highlight" > < pre > < span > < / span > < span class = "kn" > import< / span > < span class = "nn" > torch< / span >
< span class = "kn" > import< / span > < span class = "nn" > torchreid< / span >
< span class = "c1" > # Each batch contains batch_size*seq_len images< / span >
< span class = "c1" > # Each identity is sampled with num_instances tracklets< / span >
< span class = "n" > datamanager< / span > < span class = "o" > =< / span > < span class = "n" > torchreid< / span > < span class = "o" > .< / span > < span class = "n" > data< / span > < span class = "o" > .< / span > < span class = "n" > VideoDataManager< / span > < span class = "p" > (< / span >
< span class = "n" > root< / span > < span class = "o" > =< / span > < span class = "s1" > ' path/to/reid-data' < / span > < span class = "p" > ,< / span >
< span class = "n" > sources< / span > < span class = "o" > =< / span > < span class = "s1" > ' mars' < / span > < span class = "p" > ,< / span >
< span class = "n" > height< / span > < span class = "o" > =< / span > < span class = "mi" > 256< / span > < span class = "p" > ,< / span >
< span class = "n" > width< / span > < span class = "o" > =< / span > < span class = "mi" > 128< / span > < span class = "p" > ,< / span >
< span class = "n" > combineall< / span > < span class = "o" > =< / span > < span class = "kc" > False< / span > < span class = "p" > ,< / span >
< span class = "n" > num_instances< / span > < span class = "o" > =< / span > < span class = "mi" > 4< / span > < span class = "p" > ,< / span >
< span class = "n" > train_sampler< / span > < span class = "o" > =< / span > < span class = "s1" > ' RandomIdentitySampler' < / span > < span class = "p" > ,< / span >
< span class = "n" > batch_size< / span > < span class = "o" > =< / span > < span class = "mi" > 8< / span > < span class = "p" > ,< / span > < span class = "c1" > # number of tracklets< / span >
< span class = "n" > seq_len< / span > < span class = "o" > =< / span > < span class = "mi" > 15< / span > < span class = "c1" > # number of images in each tracklet< / span >
< span class = "p" > )< / span >
< span class = "n" > model< / span > < span class = "o" > =< / span > < span class = "n" > torchreid< / span > < span class = "o" > .< / span > < span class = "n" > models< / span > < span class = "o" > .< / span > < span class = "n" > build_model< / span > < span class = "p" > (< / span >
< span class = "n" > name< / span > < span class = "o" > =< / span > < span class = "s1" > ' resnet50' < / span > < span class = "p" > ,< / span >
< span class = "n" > num_classes< / span > < span class = "o" > =< / span > < span class = "n" > datamanager< / span > < span class = "o" > .< / span > < span class = "n" > num_train_pids< / span > < span class = "p" > ,< / span >
< span class = "n" > loss< / span > < span class = "o" > =< / span > < span class = "s1" > ' triplet' < / span >
< span class = "p" > )< / span >
< span class = "n" > model< / span > < span class = "o" > =< / span > < span class = "n" > model< / span > < span class = "o" > .< / span > < span class = "n" > cuda< / span > < span class = "p" > ()< / span >
< span class = "n" > optimizer< / span > < span class = "o" > =< / span > < span class = "n" > torchreid< / span > < span class = "o" > .< / span > < span class = "n" > optim< / span > < span class = "o" > .< / span > < span class = "n" > build_optimizer< / span > < span class = "p" > (< / span >
< span class = "n" > model< / span > < span class = "p" > ,< / span > < span class = "n" > optim< / span > < span class = "o" > =< / span > < span class = "s1" > ' adam' < / span > < span class = "p" > ,< / span > < span class = "n" > lr< / span > < span class = "o" > =< / span > < span class = "mf" > 0.0003< / span >
< span class = "p" > )< / span >
< span class = "n" > scheduler< / span > < span class = "o" > =< / span > < span class = "n" > torchreid< / span > < span class = "o" > .< / span > < span class = "n" > optim< / span > < span class = "o" > .< / span > < span class = "n" > build_lr_scheduler< / span > < span class = "p" > (< / span >
< span class = "n" > optimizer< / span > < span class = "p" > ,< / span >
< span class = "n" > lr_scheduler< / span > < span class = "o" > =< / span > < span class = "s1" > ' single_step' < / span > < span class = "p" > ,< / span >
< span class = "n" > stepsize< / span > < span class = "o" > =< / span > < span class = "mi" > 20< / span >
< span class = "p" > )< / span >
< span class = "n" > engine< / span > < span class = "o" > =< / span > < span class = "n" > torchreid< / span > < span class = "o" > .< / span > < span class = "n" > engine< / span > < span class = "o" > .< / span > < span class = "n" > VideoTripletEngine< / span > < span class = "p" > (< / span >
< span class = "n" > datamanager< / span > < span class = "p" > ,< / span > < span class = "n" > model< / span > < span class = "p" > ,< / span > < span class = "n" > optimizer< / span > < span class = "p" > ,< / span > < span class = "n" > margin< / span > < span class = "o" > =< / span > < span class = "mf" > 0.3< / span > < span class = "p" > ,< / span >
< span class = "n" > weight_t< / span > < span class = "o" > =< / span > < span class = "mf" > 0.7< / span > < span class = "p" > ,< / span > < span class = "n" > weight_x< / span > < span class = "o" > =< / span > < span class = "mi" > 1< / span > < span class = "p" > ,< / span > < span class = "n" > scheduler< / span > < span class = "o" > =< / span > < span class = "n" > scheduler< / span > < span class = "p" > ,< / span >
< span class = "n" > pooling_method< / span > < span class = "o" > =< / span > < span class = "s1" > ' avg' < / span >
< span class = "p" > )< / span >
< span class = "n" > engine< / span > < span class = "o" > .< / span > < span class = "n" > run< / span > < span class = "p" > (< / span >
< span class = "n" > max_epoch< / span > < span class = "o" > =< / span > < span class = "mi" > 60< / span > < span class = "p" > ,< / span >
< span class = "n" > save_dir< / span > < span class = "o" > =< / span > < span class = "s1" > ' log/resnet50-triplet-mars' < / span > < span class = "p" > ,< / span >
< span class = "n" > print_freq< / span > < span class = "o" > =< / span > < span class = "mi" > 10< / span >
< span class = "p" > )< / span >
< / pre > < / div >
< / div >
< / dd > < / dl >
< / div >
< / div >
< / div >
< / div >
< footer >
< div class = "rst-footer-buttons" role = "navigation" aria-label = "footer navigation" >
< a href = "losses.html" class = "btn btn-neutral float-right" title = "torchreid.losses" accesskey = "n" rel = "next" > Next < span class = "fa fa-arrow-circle-right" > < / span > < / a >
< a href = "data.html" class = "btn btn-neutral float-left" title = "torchreid.data" accesskey = "p" rel = "prev" > < span class = "fa fa-arrow-circle-left" > < / span > Previous< / a >
< / div >
< hr / >
< div role = "contentinfo" >
< p >
© Copyright 2019, Kaiyang Zhou
< / p >
< / div >
Built with < a href = "http://sphinx-doc.org/" > Sphinx< / a > using a < a href = "https://github.com/rtfd/sphinx_rtd_theme" > theme< / a > provided by < a href = "https://readthedocs.org" > Read the Docs< / a > .
< / footer >
< / div >
< / div >
< / section >
< / div >
< script type = "text/javascript" >
jQuery(function () {
SphinxRtdTheme.Navigation.enable(true);
});
< / script >
< / body >
< / html >