|
|
|
|
<!DOCTYPE html>
|
|
<html class="writer-html5" lang="en" >
|
|
<head>
|
|
<meta charset="utf-8">
|
|
|
|
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
|
|
|
<title>How-to — torchreid 1.4.0 documentation</title>
|
|
|
|
|
|
|
|
<link rel="stylesheet" href="_static/css/theme.css" type="text/css" />
|
|
<link rel="stylesheet" href="_static/pygments.css" type="text/css" />
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
<!--[if lt IE 9]>
|
|
<script src="_static/js/html5shiv.min.js"></script>
|
|
<![endif]-->
|
|
|
|
|
|
<script type="text/javascript" id="documentation_options" data-url_root="./" src="_static/documentation_options.js"></script>
|
|
<script type="text/javascript" src="_static/jquery.js"></script>
|
|
<script type="text/javascript" src="_static/underscore.js"></script>
|
|
<script type="text/javascript" src="_static/doctools.js"></script>
|
|
<script type="text/javascript" src="_static/language_data.js"></script>
|
|
<script async type="text/javascript" src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML"></script>
|
|
|
|
<script type="text/javascript" src="_static/js/theme.js"></script>
|
|
|
|
|
|
<link rel="index" title="Index" href="genindex.html" />
|
|
<link rel="search" title="Search" href="search.html" />
|
|
<link rel="next" title="Datasets" href="datasets.html" />
|
|
<link rel="prev" title="Torchreid" href="index.html" />
|
|
</head>
|
|
|
|
<body class="wy-body-for-nav">
|
|
|
|
|
|
<div class="wy-grid-for-nav">
|
|
|
|
<nav data-toggle="wy-nav-shift" class="wy-nav-side">
|
|
<div class="wy-side-scroll">
|
|
<div class="wy-side-nav-search" >
|
|
|
|
|
|
|
|
<a href="index.html" class="icon icon-home" aria-label="Documentation Home"> torchreid
|
|
|
|
|
|
|
|
</a>
|
|
|
|
|
|
|
|
|
|
<div class="version">
|
|
1.4.0
|
|
</div>
|
|
|
|
|
|
|
|
|
|
<div role="search">
|
|
<form id="rtd-search-form" class="wy-form" action="search.html" method="get">
|
|
<input type="text" name="q" placeholder="Search docs" aria-label="Search docs" />
|
|
<input type="hidden" name="check_keywords" value="yes" />
|
|
<input type="hidden" name="area" value="default" />
|
|
</form>
|
|
</div>
|
|
|
|
|
|
</div>
|
|
|
|
|
|
<div class="wy-menu wy-menu-vertical" data-spy="affix" role="navigation" aria-label="main navigation">
|
|
|
|
|
|
|
|
|
|
|
|
|
|
<ul class="current">
|
|
<li class="toctree-l1 current"><a class="current reference internal" href="#">How-to</a><ul>
|
|
<li class="toctree-l2"><a class="reference internal" href="#prepare-datasets">Prepare datasets</a></li>
|
|
<li class="toctree-l2"><a class="reference internal" href="#find-model-keys">Find model keys</a></li>
|
|
<li class="toctree-l2"><a class="reference internal" href="#show-available-models">Show available models</a></li>
|
|
<li class="toctree-l2"><a class="reference internal" href="#change-the-training-sampler">Change the training sampler</a></li>
|
|
<li class="toctree-l2"><a class="reference internal" href="#choose-an-optimizer-lr-scheduler">Choose an optimizer/lr_scheduler</a></li>
|
|
<li class="toctree-l2"><a class="reference internal" href="#resume-training">Resume training</a></li>
|
|
<li class="toctree-l2"><a class="reference internal" href="#compute-model-complexity">Compute model complexity</a></li>
|
|
<li class="toctree-l2"><a class="reference internal" href="#combine-multiple-datasets">Combine multiple datasets</a></li>
|
|
<li class="toctree-l2"><a class="reference internal" href="#do-cross-dataset-evaluation">Do cross-dataset evaluation</a></li>
|
|
<li class="toctree-l2"><a class="reference internal" href="#combine-train-query-and-gallery">Combine train, query and gallery</a></li>
|
|
<li class="toctree-l2"><a class="reference internal" href="#optimize-layers-with-different-learning-rates">Optimize layers with different learning rates</a></li>
|
|
<li class="toctree-l2"><a class="reference internal" href="#do-two-stepped-transfer-learning">Do two-stepped transfer learning</a></li>
|
|
<li class="toctree-l2"><a class="reference internal" href="#test-a-trained-model">Test a trained model</a></li>
|
|
<li class="toctree-l2"><a class="reference internal" href="#fine-tune-a-model-pre-trained-on-reid-datasets">Fine-tune a model pre-trained on reid datasets</a></li>
|
|
<li class="toctree-l2"><a class="reference internal" href="#visualize-learning-curves-with-tensorboard">Visualize learning curves with tensorboard</a></li>
|
|
<li class="toctree-l2"><a class="reference internal" href="#visualize-ranking-results">Visualize ranking results</a></li>
|
|
<li class="toctree-l2"><a class="reference internal" href="#visualize-activation-maps">Visualize activation maps</a></li>
|
|
<li class="toctree-l2"><a class="reference internal" href="#use-your-own-dataset">Use your own dataset</a></li>
|
|
<li class="toctree-l2"><a class="reference internal" href="#design-your-own-engine">Design your own Engine</a></li>
|
|
<li class="toctree-l2"><a class="reference internal" href="#use-torchreid-as-a-feature-extractor-in-your-projects">Use Torchreid as a feature extractor in your projects</a></li>
|
|
</ul>
|
|
</li>
|
|
<li class="toctree-l1"><a class="reference internal" href="datasets.html">Datasets</a></li>
|
|
<li class="toctree-l1"><a class="reference internal" href="evaluation.html">Evaluation</a></li>
|
|
</ul>
|
|
<p class="caption"><span class="caption-text">Package Reference</span></p>
|
|
<ul>
|
|
<li class="toctree-l1"><a class="reference internal" href="pkg/data.html">torchreid.data</a></li>
|
|
<li class="toctree-l1"><a class="reference internal" href="pkg/engine.html">torchreid.engine</a></li>
|
|
<li class="toctree-l1"><a class="reference internal" href="pkg/losses.html">torchreid.losses</a></li>
|
|
<li class="toctree-l1"><a class="reference internal" href="pkg/metrics.html">torchreid.metrics</a></li>
|
|
<li class="toctree-l1"><a class="reference internal" href="pkg/models.html">torchreid.models</a></li>
|
|
<li class="toctree-l1"><a class="reference internal" href="pkg/optim.html">torchreid.optim</a></li>
|
|
<li class="toctree-l1"><a class="reference internal" href="pkg/utils.html">torchreid.utils</a></li>
|
|
</ul>
|
|
<p class="caption"><span class="caption-text">Resources</span></p>
|
|
<ul>
|
|
<li class="toctree-l1"><a class="reference internal" href="AWESOME_REID.html">Awesome-ReID</a></li>
|
|
<li class="toctree-l1"><a class="reference internal" href="MODEL_ZOO.html">Model Zoo</a></li>
|
|
</ul>
|
|
|
|
|
|
|
|
</div>
|
|
|
|
</div>
|
|
</nav>
|
|
|
|
<section data-toggle="wy-nav-shift" class="wy-nav-content-wrap">
|
|
|
|
|
|
<nav class="wy-nav-top" aria-label="top navigation">
|
|
|
|
<i data-toggle="wy-nav-top" class="fa fa-bars"></i>
|
|
<a href="index.html">torchreid</a>
|
|
|
|
</nav>
|
|
|
|
|
|
<div class="wy-nav-content">
|
|
|
|
<div class="rst-content">
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
<div role="navigation" aria-label="breadcrumbs navigation">
|
|
|
|
<ul class="wy-breadcrumbs">
|
|
|
|
<li><a href="index.html" class="icon icon-home" aria-label="Home"></a> »</li>
|
|
|
|
<li>How-to</li>
|
|
|
|
|
|
<li class="wy-breadcrumbs-aside">
|
|
|
|
|
|
<a href="_sources/user_guide.rst.txt" rel="nofollow"> View page source</a>
|
|
|
|
|
|
</li>
|
|
|
|
</ul>
|
|
|
|
|
|
<hr/>
|
|
</div>
|
|
<div role="main" class="document" itemscope="itemscope" itemtype="http://schema.org/Article">
|
|
<div itemprop="articleBody">
|
|
|
|
<div class="section" id="how-to">
|
|
<h1>How-to<a class="headerlink" href="#how-to" title="Permalink to this headline">¶</a></h1>
|
|
<div class="contents local topic" id="contents">
|
|
<ul class="simple">
|
|
<li><p><a class="reference internal" href="#prepare-datasets" id="id1">Prepare datasets</a></p></li>
|
|
<li><p><a class="reference internal" href="#find-model-keys" id="id2">Find model keys</a></p></li>
|
|
<li><p><a class="reference internal" href="#show-available-models" id="id3">Show available models</a></p></li>
|
|
<li><p><a class="reference internal" href="#change-the-training-sampler" id="id4">Change the training sampler</a></p></li>
|
|
<li><p><a class="reference internal" href="#choose-an-optimizer-lr-scheduler" id="id5">Choose an optimizer/lr_scheduler</a></p></li>
|
|
<li><p><a class="reference internal" href="#resume-training" id="id6">Resume training</a></p></li>
|
|
<li><p><a class="reference internal" href="#compute-model-complexity" id="id7">Compute model complexity</a></p></li>
|
|
<li><p><a class="reference internal" href="#combine-multiple-datasets" id="id8">Combine multiple datasets</a></p></li>
|
|
<li><p><a class="reference internal" href="#do-cross-dataset-evaluation" id="id9">Do cross-dataset evaluation</a></p></li>
|
|
<li><p><a class="reference internal" href="#combine-train-query-and-gallery" id="id10">Combine train, query and gallery</a></p></li>
|
|
<li><p><a class="reference internal" href="#optimize-layers-with-different-learning-rates" id="id11">Optimize layers with different learning rates</a></p></li>
|
|
<li><p><a class="reference internal" href="#do-two-stepped-transfer-learning" id="id12">Do two-stepped transfer learning</a></p></li>
|
|
<li><p><a class="reference internal" href="#test-a-trained-model" id="id13">Test a trained model</a></p></li>
|
|
<li><p><a class="reference internal" href="#fine-tune-a-model-pre-trained-on-reid-datasets" id="id14">Fine-tune a model pre-trained on reid datasets</a></p></li>
|
|
<li><p><a class="reference internal" href="#visualize-learning-curves-with-tensorboard" id="id15">Visualize learning curves with tensorboard</a></p></li>
|
|
<li><p><a class="reference internal" href="#visualize-ranking-results" id="id16">Visualize ranking results</a></p></li>
|
|
<li><p><a class="reference internal" href="#visualize-activation-maps" id="id17">Visualize activation maps</a></p></li>
|
|
<li><p><a class="reference internal" href="#use-your-own-dataset" id="id18">Use your own dataset</a></p></li>
|
|
<li><p><a class="reference internal" href="#design-your-own-engine" id="id19">Design your own Engine</a></p></li>
|
|
<li><p><a class="reference internal" href="#use-torchreid-as-a-feature-extractor-in-your-projects" id="id20">Use Torchreid as a feature extractor in your projects</a></p></li>
|
|
</ul>
|
|
</div>
|
|
<div class="section" id="prepare-datasets">
|
|
<h2><a class="toc-backref" href="#id1">Prepare datasets</a><a class="headerlink" href="#prepare-datasets" title="Permalink to this headline">¶</a></h2>
|
|
<p>See <a class="reference internal" href="datasets.html#datasets"><span class="std std-ref">Datasets</span></a>.</p>
|
|
</div>
|
|
<div class="section" id="find-model-keys">
|
|
<h2><a class="toc-backref" href="#id2">Find model keys</a><a class="headerlink" href="#find-model-keys" title="Permalink to this headline">¶</a></h2>
|
|
<p>Keys are listed under the <em>Public keys</em> section within each model class in <a class="reference internal" href="pkg/models.html#torchreid-models"><span class="std std-ref">torchreid.models</span></a>.</p>
|
|
</div>
|
|
<div class="section" id="show-available-models">
|
|
<h2><a class="toc-backref" href="#id3">Show available models</a><a class="headerlink" href="#show-available-models" title="Permalink to this headline">¶</a></h2>
|
|
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">torchreid</span>
|
|
<span class="n">torchreid</span><span class="o">.</span><span class="n">models</span><span class="o">.</span><span class="n">show_avai_models</span><span class="p">()</span>
|
|
</pre></div>
|
|
</div>
|
|
</div>
|
|
<div class="section" id="change-the-training-sampler">
|
|
<h2><a class="toc-backref" href="#id4">Change the training sampler</a><a class="headerlink" href="#change-the-training-sampler" title="Permalink to this headline">¶</a></h2>
|
|
<p>The default <code class="docutils literal notranslate"><span class="pre">train_sampler</span></code> is “RandomSampler”. You can give the specific sampler name as input to <code class="docutils literal notranslate"><span class="pre">train_sampler</span></code>, e.g. <code class="docutils literal notranslate"><span class="pre">train_sampler='RandomIdentitySampler'</span></code> for triplet loss.</p>
|
|
</div>
|
|
<div class="section" id="choose-an-optimizer-lr-scheduler">
|
|
<h2><a class="toc-backref" href="#id5">Choose an optimizer/lr_scheduler</a><a class="headerlink" href="#choose-an-optimizer-lr-scheduler" title="Permalink to this headline">¶</a></h2>
|
|
<p>Please refer to the source code of <code class="docutils literal notranslate"><span class="pre">build_optimizer</span></code>/<code class="docutils literal notranslate"><span class="pre">build_lr_scheduler</span></code> in <a class="reference internal" href="pkg/optim.html#torchreid-optim"><span class="std std-ref">torchreid.optim</span></a> for details.</p>
|
|
</div>
|
|
<div class="section" id="resume-training">
|
|
<h2><a class="toc-backref" href="#id6">Resume training</a><a class="headerlink" href="#resume-training" title="Permalink to this headline">¶</a></h2>
|
|
<p>Suppose the checkpoint is saved in “log/resnet50/model.pth.tar-30”, you can do</p>
|
|
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">start_epoch</span> <span class="o">=</span> <span class="n">torchreid</span><span class="o">.</span><span class="n">utils</span><span class="o">.</span><span class="n">resume_from_checkpoint</span><span class="p">(</span>
|
|
<span class="s1">'log/resnet50/model.pth.tar-30'</span><span class="p">,</span>
|
|
<span class="n">model</span><span class="p">,</span>
|
|
<span class="n">optimizer</span>
|
|
<span class="p">)</span>
|
|
|
|
<span class="n">engine</span><span class="o">.</span><span class="n">run</span><span class="p">(</span>
|
|
<span class="n">save_dir</span><span class="o">=</span><span class="s1">'log/resnet50'</span><span class="p">,</span>
|
|
<span class="n">max_epoch</span><span class="o">=</span><span class="mi">60</span><span class="p">,</span>
|
|
<span class="n">start_epoch</span><span class="o">=</span><span class="n">start_epoch</span>
|
|
<span class="p">)</span>
|
|
</pre></div>
|
|
</div>
|
|
</div>
|
|
<div class="section" id="compute-model-complexity">
|
|
<h2><a class="toc-backref" href="#id7">Compute model complexity</a><a class="headerlink" href="#compute-model-complexity" title="Permalink to this headline">¶</a></h2>
|
|
<p>We provide a tool in <code class="docutils literal notranslate"><span class="pre">torchreid.utils.model_complexity.py</span></code> to automatically compute the model complexity, i.e. number of parameters and FLOPs.</p>
|
|
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">from</span> <span class="nn">torchreid</span> <span class="kn">import</span> <span class="n">models</span><span class="p">,</span> <span class="n">utils</span>
|
|
|
|
<span class="n">model</span> <span class="o">=</span> <span class="n">models</span><span class="o">.</span><span class="n">build_model</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s1">'resnet50'</span><span class="p">,</span> <span class="n">num_classes</span><span class="o">=</span><span class="mi">1000</span><span class="p">)</span>
|
|
<span class="n">num_params</span><span class="p">,</span> <span class="n">flops</span> <span class="o">=</span> <span class="n">utils</span><span class="o">.</span><span class="n">compute_model_complexity</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">256</span><span class="p">,</span> <span class="mi">128</span><span class="p">))</span>
|
|
|
|
<span class="c1"># show detailed complexity for each module</span>
|
|
<span class="n">utils</span><span class="o">.</span><span class="n">compute_model_complexity</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">256</span><span class="p">,</span> <span class="mi">128</span><span class="p">),</span> <span class="n">verbose</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
|
|
|
|
<span class="c1"># count flops for all layers including ReLU and BatchNorm</span>
|
|
<span class="n">utils</span><span class="o">.</span><span class="n">compute_model_complexity</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">3</span><span class="p">,</span> <span class="mi">256</span><span class="p">,</span> <span class="mi">128</span><span class="p">),</span> <span class="n">verbose</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">only_conv_linear</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
|
|
</pre></div>
|
|
</div>
|
|
<p>Note that (1) this function only provides an estimate of the theoretical time complexity rather than the actual running time, which depends on implementations and hardware; (2) FLOPs are only counted for layers that are used at test time. This means that redundant layers such as the person ID classification layer will be ignored. The inference graph depends on how you define the computations in <code class="docutils literal notranslate"><span class="pre">forward()</span></code>.</p>
|
|
</div>
|
|
<div class="section" id="combine-multiple-datasets">
|
|
<h2><a class="toc-backref" href="#id8">Combine multiple datasets</a><a class="headerlink" href="#combine-multiple-datasets" title="Permalink to this headline">¶</a></h2>
|
|
<p>Easy. Just give whatever datasets (keys) you want to the <code class="docutils literal notranslate"><span class="pre">sources</span></code> argument when instantiating a data manager. For example,</p>
|
|
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">datamanager</span> <span class="o">=</span> <span class="n">torchreid</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">ImageDataManager</span><span class="p">(</span>
|
|
<span class="n">root</span><span class="o">=</span><span class="s1">'reid-data'</span><span class="p">,</span>
|
|
<span class="n">sources</span><span class="o">=</span><span class="p">[</span><span class="s1">'market1501'</span><span class="p">,</span> <span class="s1">'dukemtmcreid'</span><span class="p">,</span> <span class="s1">'cuhk03'</span><span class="p">,</span> <span class="s1">'msmt17'</span><span class="p">],</span>
|
|
<span class="n">height</span><span class="o">=</span><span class="mi">256</span><span class="p">,</span>
|
|
<span class="n">width</span><span class="o">=</span><span class="mi">128</span><span class="p">,</span>
|
|
<span class="n">batch_size</span><span class="o">=</span><span class="mi">32</span>
|
|
<span class="p">)</span>
|
|
</pre></div>
|
|
</div>
|
|
<p>In this example, the target datasets are Market1501, DukeMTMC-reID, CUHK03 and MSMT17 as the <code class="docutils literal notranslate"><span class="pre">targets</span></code> argument is not specified. Please refer to <code class="docutils literal notranslate"><span class="pre">Engine.test()</span></code> in <a class="reference internal" href="pkg/engine.html#torchreid-engine"><span class="std std-ref">torchreid.engine</span></a> for details regarding how evaluation is performed.</p>
|
|
</div>
|
|
<div class="section" id="do-cross-dataset-evaluation">
|
|
<h2><a class="toc-backref" href="#id9">Do cross-dataset evaluation</a><a class="headerlink" href="#do-cross-dataset-evaluation" title="Permalink to this headline">¶</a></h2>
|
|
<p>Easy. Just give whatever datasets (keys) you want to the argument <code class="docutils literal notranslate"><span class="pre">targets</span></code>, like</p>
|
|
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">datamanager</span> <span class="o">=</span> <span class="n">torchreid</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">ImageDataManager</span><span class="p">(</span>
|
|
<span class="n">root</span><span class="o">=</span><span class="s1">'reid-data'</span><span class="p">,</span>
|
|
<span class="n">sources</span><span class="o">=</span><span class="s1">'market1501'</span><span class="p">,</span>
|
|
<span class="n">targets</span><span class="o">=</span><span class="s1">'dukemtmcreid'</span><span class="p">,</span> <span class="c1"># or targets='cuhk03' or targets=['dukemtmcreid', 'cuhk03']</span>
|
|
<span class="n">height</span><span class="o">=</span><span class="mi">256</span><span class="p">,</span>
|
|
<span class="n">width</span><span class="o">=</span><span class="mi">128</span><span class="p">,</span>
|
|
<span class="n">batch_size</span><span class="o">=</span><span class="mi">32</span>
|
|
<span class="p">)</span>
|
|
</pre></div>
|
|
</div>
|
|
</div>
|
|
<div class="section" id="combine-train-query-and-gallery">
|
|
<h2><a class="toc-backref" href="#id10">Combine train, query and gallery</a><a class="headerlink" href="#combine-train-query-and-gallery" title="Permalink to this headline">¶</a></h2>
|
|
<p>This can be easily done by setting <code class="docutils literal notranslate"><span class="pre">combineall=True</span></code> when instantiating a data manager. Below is an example of using Market1501,</p>
|
|
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">datamanager</span> <span class="o">=</span> <span class="n">torchreid</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">ImageDataManager</span><span class="p">(</span>
|
|
<span class="n">root</span><span class="o">=</span><span class="s1">'reid-data'</span><span class="p">,</span>
|
|
<span class="n">sources</span><span class="o">=</span><span class="s1">'market1501'</span><span class="p">,</span>
|
|
<span class="n">height</span><span class="o">=</span><span class="mi">256</span><span class="p">,</span>
|
|
<span class="n">width</span><span class="o">=</span><span class="mi">128</span><span class="p">,</span>
|
|
<span class="n">batch_size</span><span class="o">=</span><span class="mi">32</span><span class="p">,</span>
|
|
<span class="n">market1501_500k</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
|
<span class="n">combineall</span><span class="o">=</span><span class="kc">True</span> <span class="c1"># it's me, here</span>
|
|
<span class="p">)</span>
|
|
</pre></div>
|
|
</div>
|
|
<p>More specifically, with <code class="docutils literal notranslate"><span class="pre">combineall=False</span></code>, you will get</p>
|
|
<div class="highlight-none notranslate"><div class="highlight"><pre><span></span>=> Loaded Market1501
|
|
----------------------------------------
|
|
subset | # ids | # images | # cameras
|
|
----------------------------------------
|
|
train | 751 | 12936 | 6
|
|
query | 750 | 3368 | 6
|
|
gallery | 751 | 15913 | 6
|
|
---------------------------------------
|
|
</pre></div>
|
|
</div>
|
|
<p>with <code class="docutils literal notranslate"><span class="pre">combineall=True</span></code>, you will get</p>
|
|
<div class="highlight-none notranslate"><div class="highlight"><pre><span></span>=> Loaded Market1501
|
|
----------------------------------------
|
|
subset | # ids | # images | # cameras
|
|
----------------------------------------
|
|
train | 1501 | 29419 | 6
|
|
query | 750 | 3368 | 6
|
|
gallery | 751 | 15913 | 6
|
|
---------------------------------------
|
|
</pre></div>
|
|
</div>
|
|
</div>
|
|
<div class="section" id="optimize-layers-with-different-learning-rates">
|
|
<h2><a class="toc-backref" href="#id11">Optimize layers with different learning rates</a><a class="headerlink" href="#optimize-layers-with-different-learning-rates" title="Permalink to this headline">¶</a></h2>
|
|
<p>A common practice for fine-tuning pretrained models is to use a smaller learning rate for base layers and a large learning rate for randomly initialized layers (referred to as <code class="docutils literal notranslate"><span class="pre">new_layers</span></code>). <code class="docutils literal notranslate"><span class="pre">torchreid.optim.optimizer</span></code> has implemented such a feature. What you need to do is to set <code class="docutils literal notranslate"><span class="pre">staged_lr=True</span></code> and give the names of <code class="docutils literal notranslate"><span class="pre">new_layers</span></code> such as “classifier”.</p>
|
|
<p>Below is an example of setting different learning rates for base layers and new layers in ResNet50,</p>
|
|
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="c1"># New layer "classifier" has a learning rate of 0.01</span>
|
|
<span class="c1"># The base layers have a learning rate of 0.001</span>
|
|
<span class="n">optimizer</span> <span class="o">=</span> <span class="n">torchreid</span><span class="o">.</span><span class="n">optim</span><span class="o">.</span><span class="n">build_optimizer</span><span class="p">(</span>
|
|
<span class="n">model</span><span class="p">,</span>
|
|
<span class="n">optim</span><span class="o">=</span><span class="s1">'sgd'</span><span class="p">,</span>
|
|
<span class="n">lr</span><span class="o">=</span><span class="mf">0.01</span><span class="p">,</span>
|
|
<span class="n">staged_lr</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
|
|
<span class="n">new_layers</span><span class="o">=</span><span class="s1">'classifier'</span><span class="p">,</span>
|
|
<span class="n">base_lr_mult</span><span class="o">=</span><span class="mf">0.1</span>
|
|
<span class="p">)</span>
|
|
</pre></div>
|
|
</div>
|
|
<p>Please refer to <a class="reference internal" href="pkg/optim.html#torchreid-optim"><span class="std std-ref">torchreid.optim</span></a> for more details.</p>
|
|
</div>
|
|
<div class="section" id="do-two-stepped-transfer-learning">
|
|
<h2><a class="toc-backref" href="#id12">Do two-stepped transfer learning</a><a class="headerlink" href="#do-two-stepped-transfer-learning" title="Permalink to this headline">¶</a></h2>
|
|
<p>To prevent the pretrained layers from being damaged by harmful gradients back-propagated from randomly initialized layers, one can adopt the <em>two-stepped transfer learning strategy</em> presented in <a class="reference external" href="https://arxiv.org/abs/1611.05244">Deep Transfer Learning for Person Re-identification</a>. The basic idea is to pretrain the randomly initialized layers for a few epochs while keeping the base layers frozen before training all layers end-to-end.</p>
|
|
<p>This has been implemented in <code class="docutils literal notranslate"><span class="pre">Engine.train()</span></code> (see <a class="reference internal" href="pkg/engine.html#torchreid-engine"><span class="std std-ref">torchreid.engine</span></a>). The arguments related to this feature are <code class="docutils literal notranslate"><span class="pre">fixbase_epoch</span></code> and <code class="docutils literal notranslate"><span class="pre">open_layers</span></code>. Intuitively, <code class="docutils literal notranslate"><span class="pre">fixbase_epoch</span></code> denotes the number of epochs to keep the base layers frozen; <code class="docutils literal notranslate"><span class="pre">open_layers</span></code> specifies which layers are open for training.</p>
|
|
<p>For example, say you want to pretrain the classification layer named “classifier” in ResNet50 for 5 epochs before training all layers, you can do</p>
|
|
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">engine</span><span class="o">.</span><span class="n">run</span><span class="p">(</span>
|
|
<span class="n">save_dir</span><span class="o">=</span><span class="s1">'log/resnet50'</span><span class="p">,</span>
|
|
<span class="n">max_epoch</span><span class="o">=</span><span class="mi">60</span><span class="p">,</span>
|
|
<span class="n">eval_freq</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span>
|
|
<span class="n">print_freq</span><span class="o">=</span><span class="mi">10</span><span class="p">,</span>
|
|
<span class="n">test_only</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span>
|
|
<span class="n">fixbase_epoch</span><span class="o">=</span><span class="mi">5</span><span class="p">,</span>
|
|
<span class="n">open_layers</span><span class="o">=</span><span class="s1">'classifier'</span>
|
|
<span class="p">)</span>
|
|
<span class="c1"># or open_layers=['fc', 'classifier'] if there is another fc layer that</span>
|
|
<span class="c1"># is randomly initialized, like resnet50_fc512</span>
|
|
</pre></div>
|
|
</div>
|
|
<p>Note that <code class="docutils literal notranslate"><span class="pre">fixbase_epoch</span></code> is counted into <code class="docutils literal notranslate"><span class="pre">max_epoch</span></code>. In the above example, the base network will be fixed for 5 epochs and then open for training for 55 epochs. Thus, if you want to freeze some layers throughout the training, what you can do is to set <code class="docutils literal notranslate"><span class="pre">fixbase_epoch</span></code> equal to <code class="docutils literal notranslate"><span class="pre">max_epoch</span></code> and put the layer names in <code class="docutils literal notranslate"><span class="pre">open_layers</span></code> which you want to train.</p>
|
|
</div>
|
|
<div class="section" id="test-a-trained-model">
|
|
<h2><a class="toc-backref" href="#id13">Test a trained model</a><a class="headerlink" href="#test-a-trained-model" title="Permalink to this headline">¶</a></h2>
|
|
<p>You can load a trained model using <code class="code docutils literal notranslate"><span class="pre">torchreid.utils.load_pretrained_weights(model,</span> <span class="pre">weight_path)</span></code> and set <code class="docutils literal notranslate"><span class="pre">test_only=True</span></code> in <code class="docutils literal notranslate"><span class="pre">engine.run()</span></code>.</p>
|
|
</div>
|
|
<div class="section" id="fine-tune-a-model-pre-trained-on-reid-datasets">
|
|
<h2><a class="toc-backref" href="#id14">Fine-tune a model pre-trained on reid datasets</a><a class="headerlink" href="#fine-tune-a-model-pre-trained-on-reid-datasets" title="Permalink to this headline">¶</a></h2>
|
|
<p>Use <code class="code docutils literal notranslate"><span class="pre">torchreid.utils.load_pretrained_weights(model,</span> <span class="pre">weight_path)</span></code> to load the pre-trained weights and then fine-tune on the dataset you want.</p>
|
|
</div>
|
|
<div class="section" id="visualize-learning-curves-with-tensorboard">
|
|
<h2><a class="toc-backref" href="#id15">Visualize learning curves with tensorboard</a><a class="headerlink" href="#visualize-learning-curves-with-tensorboard" title="Permalink to this headline">¶</a></h2>
|
|
<p>The <code class="docutils literal notranslate"><span class="pre">SummaryWriter()</span></code> for tensorboard will be automatically initialized in <code class="docutils literal notranslate"><span class="pre">engine.run()</span></code> when you are training your model. Therefore, you do not need to do any extra work. After the training is done, the <code class="docutils literal notranslate"><span class="pre">*tf.events*</span></code> file will be saved in <code class="docutils literal notranslate"><span class="pre">save_dir</span></code>. Then, you just call <code class="docutils literal notranslate"><span class="pre">tensorboard</span> <span class="pre">--logdir=your_save_dir</span></code> in your terminal and visit <code class="docutils literal notranslate"><span class="pre">http://localhost:6006/</span></code> in a web browser. See <a class="reference external" href="https://pytorch.org/docs/stable/tensorboard.html">pytorch tensorboard</a> for further information.</p>
|
|
</div>
|
|
<div class="section" id="visualize-ranking-results">
|
|
<h2><a class="toc-backref" href="#id16">Visualize ranking results</a><a class="headerlink" href="#visualize-ranking-results" title="Permalink to this headline">¶</a></h2>
|
|
<p>This can be achieved by setting <code class="docutils literal notranslate"><span class="pre">visrank</span></code> to true in <code class="docutils literal notranslate"><span class="pre">engine.run()</span></code>. <code class="docutils literal notranslate"><span class="pre">visrank_topk</span></code> determines the top-k images to be visualized (Default is <code class="docutils literal notranslate"><span class="pre">visrank_topk=10</span></code>). Note that <code class="docutils literal notranslate"><span class="pre">visrank</span></code> can only be used in test mode, i.e. <code class="docutils literal notranslate"><span class="pre">test_only=True</span></code> in <code class="docutils literal notranslate"><span class="pre">engine.run()</span></code>. The output will be saved under <code class="docutils literal notranslate"><span class="pre">save_dir/visrank_DATASETNAME</span></code> where each plot contains the top-k similar gallery images given a query. An example is shown below where red and green denote incorrect and correct matches respectively.</p>
|
|
<a class="reference internal image-reference" href="_images/ranking_results.jpg"><img alt="_images/ranking_results.jpg" class="align-center" src="_images/ranking_results.jpg" style="width: 800px;" /></a>
|
|
</div>
|
|
<div class="section" id="visualize-activation-maps">
|
|
<h2><a class="toc-backref" href="#id17">Visualize activation maps</a><a class="headerlink" href="#visualize-activation-maps" title="Permalink to this headline">¶</a></h2>
|
|
<p>To understand which image regions the CNN focuses on when extracting features for ReID, you can visualize the activation maps as in <a class="reference external" href="https://arxiv.org/abs/1905.00953">OSNet</a>. This is implemented in <code class="docutils literal notranslate"><span class="pre">tools/visualize_actmap.py</span></code> (check the code for more details). An example command is</p>
|
|
<div class="highlight-shell notranslate"><div class="highlight"><pre><span></span>python tools/visualize_actmap.py <span class="se">\</span>
|
|
--root <span class="nv">$DATA</span>/reid <span class="se">\</span>
|
|
-d market1501 <span class="se">\</span>
|
|
-m osnet_x1_0 <span class="se">\</span>
|
|
--weights PATH_TO_PRETRAINED_WEIGHTS <span class="se">\</span>
|
|
--save-dir log/visactmap_osnet_x1_0_market1501
|
|
</pre></div>
|
|
</div>
|
|
<p>The output will look like (from left to right: image, activation map, overlapped image)</p>
|
|
<a class="reference internal image-reference" href="_images/actmap.jpg"><img alt="_images/actmap.jpg" class="align-center" src="_images/actmap.jpg" style="width: 300px;" /></a>
|
|
<div class="admonition note">
|
|
<p class="admonition-title">Note</p>
|
|
<p>In order to visualize activation maps, the CNN needs to output the last convolutional feature maps at eval mode. See <code class="docutils literal notranslate"><span class="pre">torchreid/models/osnet.py</span></code> for example.</p>
|
|
</div>
|
|
</div>
|
|
<div class="section" id="use-your-own-dataset">
|
|
<h2><a class="toc-backref" href="#id18">Use your own dataset</a><a class="headerlink" href="#use-your-own-dataset" title="Permalink to this headline">¶</a></h2>
|
|
<ol class="arabic simple">
|
|
<li><p>Write your own dataset class. Below is a template for an image dataset. However, it can also be adapted for a video dataset class by simply changing <code class="docutils literal notranslate"><span class="pre">ImageDataset</span></code> to <code class="docutils literal notranslate"><span class="pre">VideoDataset</span></code>.</p></li>
|
|
</ol>
|
|
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">from</span> <span class="nn">__future__</span> <span class="kn">import</span> <span class="n">absolute_import</span>
|
|
<span class="kn">from</span> <span class="nn">__future__</span> <span class="kn">import</span> <span class="n">print_function</span>
|
|
<span class="kn">from</span> <span class="nn">__future__</span> <span class="kn">import</span> <span class="n">division</span>
|
|
|
|
<span class="kn">import</span> <span class="nn">sys</span>
|
|
<span class="kn">import</span> <span class="nn">os</span>
|
|
<span class="kn">import</span> <span class="nn">os.path</span> <span class="k">as</span> <span class="nn">osp</span>
|
|
|
|
<span class="kn">from</span> <span class="nn">torchreid.data</span> <span class="kn">import</span> <span class="n">ImageDataset</span>
|
|
|
|
|
|
<span class="k">class</span> <span class="nc">NewDataset</span><span class="p">(</span><span class="n">ImageDataset</span><span class="p">):</span>
|
|
<span class="n">dataset_dir</span> <span class="o">=</span> <span class="s1">'new_dataset'</span>
|
|
|
|
<span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">root</span><span class="o">=</span><span class="s1">''</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">root</span> <span class="o">=</span> <span class="n">osp</span><span class="o">.</span><span class="n">abspath</span><span class="p">(</span><span class="n">osp</span><span class="o">.</span><span class="n">expanduser</span><span class="p">(</span><span class="n">root</span><span class="p">))</span>
|
|
<span class="bp">self</span><span class="o">.</span><span class="n">dataset_dir</span> <span class="o">=</span> <span class="n">osp</span><span class="o">.</span><span class="n">join</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">root</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">dataset_dir</span><span class="p">)</span>
|
|
|
|
<span class="c1"># All you need to do here is to generate three lists,</span>
|
|
<span class="c1"># which are train, query and gallery.</span>
|
|
<span class="c1"># Each list contains tuples of (img_path, pid, camid),</span>
|
|
<span class="c1"># where</span>
|
|
<span class="c1"># - img_path (str): absolute path to an image.</span>
|
|
<span class="c1"># - pid (int): person ID, e.g. 0, 1.</span>
|
|
<span class="c1"># - camid (int): camera ID, e.g. 0, 1.</span>
|
|
<span class="c1"># Note that</span>
|
|
<span class="c1"># - pid and camid should be 0-based.</span>
|
|
<span class="c1"># - query and gallery should share the same pid scope (e.g.</span>
|
|
<span class="c1"># pid=0 in query refers to the same person as pid=0 in gallery).</span>
|
|
<span class="c1"># - train, query and gallery share the same camid scope (e.g.</span>
|
|
<span class="c1"># camid=0 in train refers to the same camera as camid=0</span>
|
|
<span class="c1"># in query/gallery).</span>
|
|
<span class="n">train</span> <span class="o">=</span> <span class="o">...</span>
|
|
<span class="n">query</span> <span class="o">=</span> <span class="o">...</span>
|
|
<span class="n">gallery</span> <span class="o">=</span> <span class="o">...</span>
|
|
|
|
<span class="nb">super</span><span class="p">(</span><span class="n">NewDataset</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">train</span><span class="p">,</span> <span class="n">query</span><span class="p">,</span> <span class="n">gallery</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
|
|
</pre></div>
|
|
</div>
|
|
<ol class="arabic simple" start="2">
|
|
<li><p>Register your dataset.</p></li>
|
|
</ol>
|
|
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">torchreid</span>
|
|
<span class="n">torchreid</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">register_image_dataset</span><span class="p">(</span><span class="s1">'new_dataset'</span><span class="p">,</span> <span class="n">NewDataset</span><span class="p">)</span>
|
|
</pre></div>
|
|
</div>
|
|
<ol class="arabic simple" start="3">
|
|
<li><p>Initialize a data manager with your dataset.</p></li>
|
|
</ol>
|
|
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="c1"># use your own dataset only</span>
|
|
<span class="n">datamanager</span> <span class="o">=</span> <span class="n">torchreid</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">ImageDataManager</span><span class="p">(</span>
|
|
<span class="n">root</span><span class="o">=</span><span class="s1">'reid-data'</span><span class="p">,</span>
|
|
<span class="n">sources</span><span class="o">=</span><span class="s1">'new_dataset'</span>
|
|
<span class="p">)</span>
|
|
<span class="c1"># combine with other datasets</span>
|
|
<span class="n">datamanager</span> <span class="o">=</span> <span class="n">torchreid</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">ImageDataManager</span><span class="p">(</span>
|
|
<span class="n">root</span><span class="o">=</span><span class="s1">'reid-data'</span><span class="p">,</span>
|
|
<span class="n">sources</span><span class="o">=</span><span class="p">[</span><span class="s1">'new_dataset'</span><span class="p">,</span> <span class="s1">'dukemtmcreid'</span><span class="p">]</span>
|
|
<span class="p">)</span>
|
|
<span class="c1"># cross-dataset evaluation</span>
|
|
<span class="n">datamanager</span> <span class="o">=</span> <span class="n">torchreid</span><span class="o">.</span><span class="n">data</span><span class="o">.</span><span class="n">ImageDataManager</span><span class="p">(</span>
|
|
<span class="n">root</span><span class="o">=</span><span class="s1">'reid-data'</span><span class="p">,</span>
|
|
<span class="n">sources</span><span class="o">=</span><span class="p">[</span><span class="s1">'new_dataset'</span><span class="p">,</span> <span class="s1">'dukemtmcreid'</span><span class="p">],</span>
|
|
<span class="n">targets</span><span class="o">=</span><span class="s1">'market1501'</span> <span class="c1"># or targets=['market1501', 'cuhk03']</span>
|
|
<span class="p">)</span>
|
|
</pre></div>
|
|
</div>
|
|
</div>
|
|
<div class="section" id="design-your-own-engine">
|
|
<h2><a class="toc-backref" href="#id19">Design your own Engine</a><a class="headerlink" href="#design-your-own-engine" title="Permalink to this headline">¶</a></h2>
|
|
<p>A new Engine should be designed if you have your own loss function. The base Engine class <code class="docutils literal notranslate"><span class="pre">torchreid.engine.Engine</span></code> has implemented some generic methods which you can inherit to avoid re-writing. Please refer to the source code for more details. We suggest that you look at how <code class="docutils literal notranslate"><span class="pre">ImageSoftmaxEngine</span></code> and <code class="docutils literal notranslate"><span class="pre">ImageTripletEngine</span></code> are constructed (also <code class="docutils literal notranslate"><span class="pre">VideoSoftmaxEngine</span></code> and <code class="docutils literal notranslate"><span class="pre">VideoTripletEngine</span></code>). Often, all you need to implement is a <code class="docutils literal notranslate"><span class="pre">forward_backward()</span></code> function.</p>
|
|
</div>
|
|
<div class="section" id="use-torchreid-as-a-feature-extractor-in-your-projects">
|
|
<h2><a class="toc-backref" href="#id20">Use Torchreid as a feature extractor in your projects</a><a class="headerlink" href="#use-torchreid-as-a-feature-extractor-in-your-projects" title="Permalink to this headline">¶</a></h2>
|
|
<p>We have provided a simple API for feature extraction, which accepts input of various types such as a list of image paths or numpy arrays. More details can be found in the code at <code class="docutils literal notranslate"><span class="pre">torchreid/utils/feature_extractor.py</span></code>. Here we show a simple example of how to extract features given a list of image paths.</p>
|
|
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">from</span> <span class="nn">torchreid.utils</span> <span class="kn">import</span> <span class="n">FeatureExtractor</span>
|
|
|
|
<span class="n">extractor</span> <span class="o">=</span> <span class="n">FeatureExtractor</span><span class="p">(</span>
|
|
<span class="n">model_name</span><span class="o">=</span><span class="s1">'osnet_x1_0'</span><span class="p">,</span>
|
|
<span class="n">model_path</span><span class="o">=</span><span class="s1">'a/b/c/model.pth.tar'</span><span class="p">,</span>
|
|
<span class="n">device</span><span class="o">=</span><span class="s1">'cuda'</span>
|
|
<span class="p">)</span>
|
|
|
|
<span class="n">image_list</span> <span class="o">=</span> <span class="p">[</span>
|
|
<span class="s1">'a/b/c/image001.jpg'</span><span class="p">,</span>
|
|
<span class="s1">'a/b/c/image002.jpg'</span><span class="p">,</span>
|
|
<span class="s1">'a/b/c/image003.jpg'</span><span class="p">,</span>
|
|
<span class="s1">'a/b/c/image004.jpg'</span><span class="p">,</span>
|
|
<span class="s1">'a/b/c/image005.jpg'</span>
|
|
<span class="p">]</span>
|
|
|
|
<span class="n">features</span> <span class="o">=</span> <span class="n">extractor</span><span class="p">(</span><span class="n">image_list</span><span class="p">)</span>
|
|
<span class="nb">print</span><span class="p">(</span><span class="n">features</span><span class="o">.</span><span class="n">shape</span><span class="p">)</span> <span class="c1"># output (5, 512)</span>
|
|
</pre></div>
|
|
</div>
|
|
</div>
|
|
</div>
|
|
|
|
|
|
</div>
|
|
|
|
</div>
|
|
<footer>
|
|
|
|
<div class="rst-footer-buttons" role="navigation" aria-label="footer navigation">
|
|
|
|
<a href="datasets.html" class="btn btn-neutral float-right" title="Datasets" accesskey="n" rel="next">Next <span class="fa fa-arrow-circle-right"></span></a>
|
|
|
|
|
|
<a href="index.html" class="btn btn-neutral float-left" title="Torchreid" accesskey="p" rel="prev"><span class="fa fa-arrow-circle-left"></span> Previous</a>
|
|
|
|
</div>
|
|
|
|
|
|
<hr/>
|
|
|
|
<div role="contentinfo">
|
|
<p>
|
|
|
|
© Copyright 2019, Kaiyang Zhou
|
|
|
|
</p>
|
|
</div>
|
|
|
|
|
|
|
|
Built with <a href="http://sphinx-doc.org/">Sphinx</a> using a
|
|
|
|
<a href="https://github.com/rtfd/sphinx_rtd_theme">theme</a>
|
|
|
|
provided by <a href="https://readthedocs.org">Read the Docs</a>.
|
|
|
|
</footer>
|
|
|
|
</div>
|
|
</div>
|
|
|
|
</section>
|
|
|
|
</div>
|
|
|
|
|
|
<script type="text/javascript">
|
|
  // Theme bootstrap: jQuery(fn) is the DOM-ready shorthand, so this runs
  // once the page has loaded.
  jQuery(function () {
|
|
  // Enables the sphinx_rtd_theme sidebar navigation behaviour; the boolean
  // flag presumably toggles sticky scrolling — confirm against the
  // sphinx_rtd_theme documentation.
  SphinxRtdTheme.Navigation.enable(true);
|
|
  });
|
|
  </script>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
</body>
|
|
</html> |