2019-03-25 01:22:43 +08:00
<!DOCTYPE html>
<!-- [if IE 8]><html class="no - js lt - ie9" lang="en" > <![endif] -->
<!-- [if gt IE 8]><! --> < html class = "no-js" lang = "en" > <!-- <![endif] -->
< head >
< meta charset = "utf-8" >
< meta name = "viewport" content = "width=device-width, initial-scale=1.0" >
2019-04-18 19:12:17 +08:00
< title > torchreid.utils.torchtools — torchreid 0.7.3 documentation< / title >
2019-03-25 01:22:43 +08:00
< script type = "text/javascript" src = "../../../_static/js/modernizr.min.js" > < / script >
< script type = "text/javascript" id = "documentation_options" data-url_root = "../../../" src = "../../../_static/documentation_options.js" > < / script >
< script type = "text/javascript" src = "../../../_static/jquery.js" > < / script >
< script type = "text/javascript" src = "../../../_static/underscore.js" > < / script >
< script type = "text/javascript" src = "../../../_static/doctools.js" > < / script >
< script type = "text/javascript" src = "../../../_static/language_data.js" > < / script >
< script async = "async" type = "text/javascript" src = "https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML" > < / script >
< script type = "text/javascript" src = "../../../_static/js/theme.js" > < / script >
< link rel = "stylesheet" href = "../../../_static/css/theme.css" type = "text/css" / >
< link rel = "stylesheet" href = "../../../_static/pygments.css" type = "text/css" / >
< link rel = "index" title = "Index" href = "../../../genindex.html" / >
< link rel = "search" title = "Search" href = "../../../search.html" / >
< / head >
< body class = "wy-body-for-nav" >
< div class = "wy-grid-for-nav" >
< nav data-toggle = "wy-nav-shift" class = "wy-nav-side" >
< div class = "wy-side-scroll" >
< div class = "wy-side-nav-search" >
< a href = "../../../index.html" class = "icon icon-home" > torchreid
< / a >
< div class = "version" >
2019-04-18 19:12:17 +08:00
0.7.3
2019-03-25 01:22:43 +08:00
< / div >
< div role = "search" >
< form id = "rtd-search-form" class = "wy-form" action = "../../../search.html" method = "get" >
< input type = "text" name = "q" placeholder = "Search docs" / >
< input type = "hidden" name = "check_keywords" value = "yes" / >
< input type = "hidden" name = "area" value = "default" / >
< / form >
< / div >
< / div >
< div class = "wy-menu wy-menu-vertical" data-spy = "affix" role = "navigation" aria-label = "main navigation" >
< ul >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../../user_guide.html" > How-to< / a > < / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../../datasets.html" > Datasets< / a > < / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../../evaluation.html" > Evaluation< / a > < / li >
< / ul >
< p class = "caption" > < span class = "caption-text" > Package Reference< / span > < / p >
< ul >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../../pkg/data.html" > torchreid.data< / a > < / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../../pkg/engine.html" > torchreid.engine< / a > < / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../../pkg/losses.html" > torchreid.losses< / a > < / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../../pkg/metrics.html" > torchreid.metrics< / a > < / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../../pkg/models.html" > torchreid.models< / a > < / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../../pkg/optim.html" > torchreid.optim< / a > < / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../../pkg/utils.html" > torchreid.utils< / a > < / li >
< / ul >
< p class = "caption" > < span class = "caption-text" > Resources< / span > < / p >
< ul >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../../AWESOME_REID.html" > Awesome-ReID< / a > < / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../../MODEL_ZOO.html" > Model Zoo< / a > < / li >
< / ul >
< / div >
< / div >
< / nav >
< section data-toggle = "wy-nav-shift" class = "wy-nav-content-wrap" >
< nav class = "wy-nav-top" aria-label = "top navigation" >
< i data-toggle = "wy-nav-top" class = "fa fa-bars" > < / i >
< a href = "../../../index.html" > torchreid< / a >
< / nav >
< div class = "wy-nav-content" >
< div class = "rst-content" >
< div role = "navigation" aria-label = "breadcrumbs navigation" >
< ul class = "wy-breadcrumbs" >
< li > < a href = "../../../index.html" > Docs< / a > » < / li >
< li > < a href = "../../index.html" > Module code< / a > » < / li >
< li > torchreid.utils.torchtools< / li >
< li class = "wy-breadcrumbs-aside" >
< / li >
< / ul >
< hr / >
< / div >
< div role = "main" class = "document" itemscope = "itemscope" itemtype = "http://schema.org/Article" >
< div itemprop = "articleBody" >
< h1 > Source code for torchreid.utils.torchtools< / h1 > < div class = "highlight" > < pre >
< span > < / span > < span class = "kn" > from< / span > < span class = "nn" > __future__< / span > < span class = "k" > import< / span > < span class = "n" > absolute_import< / span >
< span class = "kn" > from< / span > < span class = "nn" > __future__< / span > < span class = "k" > import< / span > < span class = "n" > print_function< / span >
< span class = "kn" > from< / span > < span class = "nn" > __future__< / span > < span class = "k" > import< / span > < span class = "n" > division< / span >
< span class = "n" > __all__< / span > < span class = "o" > =< / span > < span class = "p" > [< / span > < span class = "s1" > ' save_checkpoint' < / span > < span class = "p" > ,< / span > < span class = "s1" > ' load_checkpoint' < / span > < span class = "p" > ,< / span > < span class = "s1" > ' resume_from_checkpoint' < / span > < span class = "p" > ,< / span >
< span class = "s1" > ' open_all_layers' < / span > < span class = "p" > ,< / span > < span class = "s1" > ' open_specified_layers' < / span > < span class = "p" > ,< / span > < span class = "s1" > ' count_num_param' < / span > < span class = "p" > ,< / span >
< span class = "s1" > ' load_pretrained_weights' < / span > < span class = "p" > ]< / span >
< span class = "kn" > from< / span > < span class = "nn" > collections< / span > < span class = "k" > import< / span > < span class = "n" > OrderedDict< / span >
< span class = "kn" > import< / span > < span class = "nn" > shutil< / span >
< span class = "kn" > import< / span > < span class = "nn" > warnings< / span >
< span class = "kn" > import< / span > < span class = "nn" > os< / span >
< span class = "kn" > import< / span > < span class = "nn" > os.path< / span > < span class = "k" > as< / span > < span class = "nn" > osp< / span >
< span class = "kn" > from< / span > < span class = "nn" > functools< / span > < span class = "k" > import< / span > < span class = "n" > partial< / span >
< span class = "kn" > import< / span > < span class = "nn" > pickle< / span >
< span class = "kn" > import< / span > < span class = "nn" > torch< / span >
< span class = "kn" > import< / span > < span class = "nn" > torch.nn< / span > < span class = "k" > as< / span > < span class = "nn" > nn< / span >
< span class = "kn" > from< / span > < span class = "nn" > .tools< / span > < span class = "k" > import< / span > < span class = "n" > mkdir_if_missing< / span >
< div class = "viewcode-block" id = "save_checkpoint" > < a class = "viewcode-back" href = "../../../pkg/utils.html#torchreid.utils.torchtools.save_checkpoint" > [docs]< / a > < span class = "k" > def< / span > < span class = "nf" > save_checkpoint< / span > < span class = "p" > (< / span > < span class = "n" > state< / span > < span class = "p" > ,< / span > < span class = "n" > save_dir< / span > < span class = "p" > ,< / span > < span class = "n" > is_best< / span > < span class = "o" > =< / span > < span class = "kc" > False< / span > < span class = "p" > ,< / span > < span class = "n" > remove_module_from_keys< / span > < span class = "o" > =< / span > < span class = "kc" > False< / span > < span class = "p" > ):< / span >
< span class = "sd" > " " " Saves checkpoint.< / span >
< span class = "sd" > Args:< / span >
< span class = "sd" > state (dict): dictionary.< / span >
< span class = "sd" > save_dir (str): directory to save checkpoint.< / span >
< span class = "sd" > is_best (bool, optional): if True, this checkpoint will be copied and named< / span >
< span class = "sd" > " model-best.pth.tar" . Default is False.< / span >
< span class = "sd" > remove_module_from_keys (bool, optional): whether to remove " module." < / span >
< span class = "sd" > from layer names. Default is False.< / span >
< span class = "sd" > Examples::< / span >
< span class = "sd" > > > > state = {< / span >
< span class = "sd" > > > > ' state_dict' : model.state_dict(),< / span >
< span class = "sd" > > > > ' epoch' : 10,< / span >
< span class = "sd" > > > > ' rank1' : 0.5,< / span >
< span class = "sd" > > > > ' optimizer' : optimizer.state_dict()< / span >
< span class = "sd" > > > > }< / span >
< span class = "sd" > > > > save_checkpoint(state, ' log/my_model' )< / span >
< span class = "sd" > " " " < / span >
< span class = "n" > mkdir_if_missing< / span > < span class = "p" > (< / span > < span class = "n" > save_dir< / span > < span class = "p" > )< / span >
< span class = "k" > if< / span > < span class = "n" > remove_module_from_keys< / span > < span class = "p" > :< / span >
< span class = "c1" > # remove ' module.' in state_dict' s keys< / span >
< span class = "n" > state_dict< / span > < span class = "o" > =< / span > < span class = "n" > state< / span > < span class = "p" > [< / span > < span class = "s1" > ' state_dict' < / span > < span class = "p" > ]< / span >
< span class = "n" > new_state_dict< / span > < span class = "o" > =< / span > < span class = "n" > OrderedDict< / span > < span class = "p" > ()< / span >
< span class = "k" > for< / span > < span class = "n" > k< / span > < span class = "p" > ,< / span > < span class = "n" > v< / span > < span class = "ow" > in< / span > < span class = "n" > state_dict< / span > < span class = "o" > .< / span > < span class = "n" > items< / span > < span class = "p" > ():< / span >
< span class = "k" > if< / span > < span class = "n" > k< / span > < span class = "o" > .< / span > < span class = "n" > startswith< / span > < span class = "p" > (< / span > < span class = "s1" > ' module.' < / span > < span class = "p" > ):< / span >
< span class = "n" > k< / span > < span class = "o" > =< / span > < span class = "n" > k< / span > < span class = "p" > [< / span > < span class = "mi" > 7< / span > < span class = "p" > :]< / span >
< span class = "n" > new_state_dict< / span > < span class = "p" > [< / span > < span class = "n" > k< / span > < span class = "p" > ]< / span > < span class = "o" > =< / span > < span class = "n" > v< / span >
< span class = "n" > state< / span > < span class = "p" > [< / span > < span class = "s1" > ' state_dict' < / span > < span class = "p" > ]< / span > < span class = "o" > =< / span > < span class = "n" > new_state_dict< / span >
< span class = "c1" > # save< / span >
< span class = "n" > epoch< / span > < span class = "o" > =< / span > < span class = "n" > state< / span > < span class = "p" > [< / span > < span class = "s1" > ' epoch' < / span > < span class = "p" > ]< / span >
< span class = "n" > fpath< / span > < span class = "o" > =< / span > < span class = "n" > osp< / span > < span class = "o" > .< / span > < span class = "n" > join< / span > < span class = "p" > (< / span > < span class = "n" > save_dir< / span > < span class = "p" > ,< / span > < span class = "s1" > ' model.pth.tar-' < / span > < span class = "o" > +< / span > < span class = "nb" > str< / span > < span class = "p" > (< / span > < span class = "n" > epoch< / span > < span class = "p" > ))< / span >
< span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > save< / span > < span class = "p" > (< / span > < span class = "n" > state< / span > < span class = "p" > ,< / span > < span class = "n" > fpath< / span > < span class = "p" > )< / span >
< span class = "nb" > print< / span > < span class = "p" > (< / span > < span class = "s1" > ' Checkpoint saved to " < / span > < span class = "si" > {}< / span > < span class = "s1" > " ' < / span > < span class = "o" > .< / span > < span class = "n" > format< / span > < span class = "p" > (< / span > < span class = "n" > fpath< / span > < span class = "p" > ))< / span >
< span class = "k" > if< / span > < span class = "n" > is_best< / span > < span class = "p" > :< / span >
< span class = "n" > shutil< / span > < span class = "o" > .< / span > < span class = "n" > copy< / span > < span class = "p" > (< / span > < span class = "n" > fpath< / span > < span class = "p" > ,< / span > < span class = "n" > osp< / span > < span class = "o" > .< / span > < span class = "n" > join< / span > < span class = "p" > (< / span > < span class = "n" > osp< / span > < span class = "o" > .< / span > < span class = "n" > dirname< / span > < span class = "p" > (< / span > < span class = "n" > fpath< / span > < span class = "p" > ),< / span > < span class = "s1" > ' model-best.pth.tar' < / span > < span class = "p" > ))< / span > < / div >
< div class = "viewcode-block" id = "load_checkpoint" > < a class = "viewcode-back" href = "../../../pkg/utils.html#torchreid.utils.torchtools.load_checkpoint" > [docs]< / a > < span class = "k" > def< / span > < span class = "nf" > load_checkpoint< / span > < span class = "p" > (< / span > < span class = "n" > fpath< / span > < span class = "p" > ):< / span >
< span class = "sd" > " " " Loads checkpoint.< / span >
< span class = "sd" > ``UnicodeDecodeError`` can be well handled, which means< / span >
< span class = "sd" > python2-saved files can be read from python3.< / span >
< span class = "sd" > Args:< / span >
< span class = "sd" > fpath (str): path to checkpoint.< / span >
< span class = "sd" > Returns:< / span >
< span class = "sd" > dict< / span >
< span class = "sd" > Examples::< / span >
< span class = "sd" > > > > from torchreid.utils import load_checkpoint< / span >
< span class = "sd" > > > > fpath = ' log/my_model/model.pth.tar-10' < / span >
< span class = "sd" > > > > checkpoint = load_checkpoint(fpath)< / span >
< span class = "sd" > " " " < / span >
< span class = "n" > map_location< / span > < span class = "o" > =< / span > < span class = "kc" > None< / span > < span class = "k" > if< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > cuda< / span > < span class = "o" > .< / span > < span class = "n" > is_available< / span > < span class = "p" > ()< / span > < span class = "k" > else< / span > < span class = "s1" > ' cpu' < / span >
< span class = "k" > try< / span > < span class = "p" > :< / span >
< span class = "n" > checkpoint< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > load< / span > < span class = "p" > (< / span > < span class = "n" > fpath< / span > < span class = "p" > ,< / span > < span class = "n" > map_location< / span > < span class = "o" > =< / span > < span class = "n" > map_location< / span > < span class = "p" > )< / span >
< span class = "k" > except< / span > < span class = "ne" > UnicodeDecodeError< / span > < span class = "p" > :< / span >
< span class = "n" > pickle< / span > < span class = "o" > .< / span > < span class = "n" > load< / span > < span class = "o" > =< / span > < span class = "n" > partial< / span > < span class = "p" > (< / span > < span class = "n" > pickle< / span > < span class = "o" > .< / span > < span class = "n" > load< / span > < span class = "p" > ,< / span > < span class = "n" > encoding< / span > < span class = "o" > =< / span > < span class = "s2" > " latin1" < / span > < span class = "p" > )< / span >
< span class = "n" > pickle< / span > < span class = "o" > .< / span > < span class = "n" > Unpickler< / span > < span class = "o" > =< / span > < span class = "n" > partial< / span > < span class = "p" > (< / span > < span class = "n" > pickle< / span > < span class = "o" > .< / span > < span class = "n" > Unpickler< / span > < span class = "p" > ,< / span > < span class = "n" > encoding< / span > < span class = "o" > =< / span > < span class = "s2" > " latin1" < / span > < span class = "p" > )< / span >
< span class = "n" > checkpoint< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > load< / span > < span class = "p" > (< / span > < span class = "n" > fpath< / span > < span class = "p" > ,< / span > < span class = "n" > pickle_module< / span > < span class = "o" > =< / span > < span class = "n" > pickle< / span > < span class = "p" > ,< / span > < span class = "n" > map_location< / span > < span class = "o" > =< / span > < span class = "n" > map_location< / span > < span class = "p" > )< / span >
< span class = "k" > except< / span > < span class = "ne" > Exception< / span > < span class = "p" > :< / span >
< span class = "nb" > print< / span > < span class = "p" > (< / span > < span class = "s1" > ' Unable to load checkpoint from " < / span > < span class = "si" > {}< / span > < span class = "s1" > " ' < / span > < span class = "o" > .< / span > < span class = "n" > format< / span > < span class = "p" > (< / span > < span class = "n" > fpath< / span > < span class = "p" > ))< / span >
< span class = "k" > raise< / span >
< span class = "k" > return< / span > < span class = "n" > checkpoint< / span > < / div >
< div class = "viewcode-block" id = "resume_from_checkpoint" > < a class = "viewcode-back" href = "../../../pkg/utils.html#torchreid.utils.torchtools.resume_from_checkpoint" > [docs]< / a > < span class = "k" > def< / span > < span class = "nf" > resume_from_checkpoint< / span > < span class = "p" > (< / span > < span class = "n" > fpath< / span > < span class = "p" > ,< / span > < span class = "n" > model< / span > < span class = "p" > ,< / span > < span class = "n" > optimizer< / span > < span class = "o" > =< / span > < span class = "kc" > None< / span > < span class = "p" > ):< / span >
< span class = "sd" > " " " Resumes training from a checkpoint.< / span >
< span class = "sd" > This will load (1) model weights and (2) ``state_dict``< / span >
< span class = "sd" > of optimizer if ``optimizer`` is not None.< / span >
< span class = "sd" > Args:< / span >
< span class = "sd" > fpath (str): path to checkpoint.< / span >
< span class = "sd" > model (nn.Module): model.< / span >
< span class = "sd" > optimizer (Optimizer, optional): an Optimizer.< / span >
< span class = "sd" > Returns:< / span >
< span class = "sd" > int: start_epoch.< / span >
< span class = "sd" > Examples::< / span >
< span class = "sd" > > > > from torchreid.utils import resume_from_checkpoint< / span >
< span class = "sd" > > > > fpath = ' log/my_model/model.pth.tar-10' < / span >
< span class = "sd" > > > > start_epoch = resume_from_checkpoint(fpath, model, optimizer)< / span >
< span class = "sd" > " " " < / span >
< span class = "nb" > print< / span > < span class = "p" > (< / span > < span class = "s1" > ' Loading checkpoint from " < / span > < span class = "si" > {}< / span > < span class = "s1" > " ' < / span > < span class = "o" > .< / span > < span class = "n" > format< / span > < span class = "p" > (< / span > < span class = "n" > fpath< / span > < span class = "p" > ))< / span >
< span class = "n" > checkpoint< / span > < span class = "o" > =< / span > < span class = "n" > load_checkpoint< / span > < span class = "p" > (< / span > < span class = "n" > fpath< / span > < span class = "p" > )< / span >
< span class = "n" > model< / span > < span class = "o" > .< / span > < span class = "n" > load_state_dict< / span > < span class = "p" > (< / span > < span class = "n" > checkpoint< / span > < span class = "p" > [< / span > < span class = "s1" > ' state_dict' < / span > < span class = "p" > ])< / span >
< span class = "nb" > print< / span > < span class = "p" > (< / span > < span class = "s1" > ' Loaded model weights' < / span > < span class = "p" > )< / span >
< span class = "k" > if< / span > < span class = "n" > optimizer< / span > < span class = "ow" > is< / span > < span class = "ow" > not< / span > < span class = "kc" > None< / span > < span class = "ow" > and< / span > < span class = "s1" > ' optimizer' < / span > < span class = "ow" > in< / span > < span class = "n" > checkpoint< / span > < span class = "o" > .< / span > < span class = "n" > keys< / span > < span class = "p" > ():< / span >
< span class = "n" > optimizer< / span > < span class = "o" > .< / span > < span class = "n" > load_state_dict< / span > < span class = "p" > (< / span > < span class = "n" > checkpoint< / span > < span class = "p" > [< / span > < span class = "s1" > ' optimizer' < / span > < span class = "p" > ])< / span >
< span class = "nb" > print< / span > < span class = "p" > (< / span > < span class = "s1" > ' Loaded optimizer' < / span > < span class = "p" > )< / span >
< span class = "n" > start_epoch< / span > < span class = "o" > =< / span > < span class = "n" > checkpoint< / span > < span class = "p" > [< / span > < span class = "s1" > ' epoch' < / span > < span class = "p" > ]< / span >
< span class = "nb" > print< / span > < span class = "p" > (< / span > < span class = "s1" > ' Last epoch = < / span > < span class = "si" > {}< / span > < span class = "s1" > ' < / span > < span class = "o" > .< / span > < span class = "n" > format< / span > < span class = "p" > (< / span > < span class = "n" > start_epoch< / span > < span class = "p" > ))< / span >
< span class = "k" > if< / span > < span class = "s1" > ' rank1' < / span > < span class = "ow" > in< / span > < span class = "n" > checkpoint< / span > < span class = "o" > .< / span > < span class = "n" > keys< / span > < span class = "p" > ():< / span >
< span class = "nb" > print< / span > < span class = "p" > (< / span > < span class = "s1" > ' Last rank1 = < / span > < span class = "si" > {:.1%}< / span > < span class = "s1" > ' < / span > < span class = "o" > .< / span > < span class = "n" > format< / span > < span class = "p" > (< / span > < span class = "n" > checkpoint< / span > < span class = "p" > [< / span > < span class = "s1" > ' rank1' < / span > < span class = "p" > ]))< / span >
< span class = "k" > return< / span > < span class = "n" > start_epoch< / span > < / div >
< span class = "k" > def< / span > < span class = "nf" > adjust_learning_rate< / span > < span class = "p" > (< / span > < span class = "n" > optimizer< / span > < span class = "p" > ,< / span > < span class = "n" > base_lr< / span > < span class = "p" > ,< / span > < span class = "n" > epoch< / span > < span class = "p" > ,< / span > < span class = "n" > stepsize< / span > < span class = "o" > =< / span > < span class = "mi" > 20< / span > < span class = "p" > ,< / span > < span class = "n" > gamma< / span > < span class = "o" > =< / span > < span class = "mf" > 0.1< / span > < span class = "p" > ,< / span >
< span class = "n" > linear_decay< / span > < span class = "o" > =< / span > < span class = "kc" > False< / span > < span class = "p" > ,< / span > < span class = "n" > final_lr< / span > < span class = "o" > =< / span > < span class = "mi" > 0< / span > < span class = "p" > ,< / span > < span class = "n" > max_epoch< / span > < span class = "o" > =< / span > < span class = "mi" > 100< / span > < span class = "p" > ):< / span >
< span class = "sd" > " " " Adjusts learning rate.< / span >
< span class = "sd" > Deprecated.< / span >
< span class = "sd" > " " " < / span >
< span class = "k" > if< / span > < span class = "n" > linear_decay< / span > < span class = "p" > :< / span >
< span class = "c1" > # linearly decay learning rate from base_lr to final_lr< / span >
< span class = "n" > frac_done< / span > < span class = "o" > =< / span > < span class = "n" > epoch< / span > < span class = "o" > /< / span > < span class = "n" > max_epoch< / span >
< span class = "n" > lr< / span > < span class = "o" > =< / span > < span class = "n" > frac_done< / span > < span class = "o" > *< / span > < span class = "n" > final_lr< / span > < span class = "o" > +< / span > < span class = "p" > (< / span > < span class = "mf" > 1.< / span > < span class = "o" > -< / span > < span class = "n" > frac_done< / span > < span class = "p" > )< / span > < span class = "o" > *< / span > < span class = "n" > base_lr< / span >
< span class = "k" > else< / span > < span class = "p" > :< / span >
< span class = "c1" > # decay learning rate by gamma for every stepsize< / span >
< span class = "n" > lr< / span > < span class = "o" > =< / span > < span class = "n" > base_lr< / span > < span class = "o" > *< / span > < span class = "p" > (< / span > < span class = "n" > gamma< / span > < span class = "o" > **< / span > < span class = "p" > (< / span > < span class = "n" > epoch< / span > < span class = "o" > //< / span > < span class = "n" > stepsize< / span > < span class = "p" > ))< / span >
< span class = "k" > for< / span > < span class = "n" > param_group< / span > < span class = "ow" > in< / span > < span class = "n" > optimizer< / span > < span class = "o" > .< / span > < span class = "n" > param_groups< / span > < span class = "p" > :< / span >
< span class = "n" > param_group< / span > < span class = "p" > [< / span > < span class = "s1" > ' lr' < / span > < span class = "p" > ]< / span > < span class = "o" > =< / span > < span class = "n" > lr< / span >
< span class = "k" > def< / span > < span class = "nf" > set_bn_to_eval< / span > < span class = "p" > (< / span > < span class = "n" > m< / span > < span class = "p" > ):< / span >
< span class = "sd" > " " " Sets BatchNorm layers to eval mode." " " < / span >
< span class = "c1" > # 1. no update for running mean and var< / span >
< span class = "c1" > # 2. scale and shift parameters are still trainable< / span >
< span class = "n" > classname< / span > < span class = "o" > =< / span > < span class = "n" > m< / span > < span class = "o" > .< / span > < span class = "vm" > __class__< / span > < span class = "o" > .< / span > < span class = "vm" > __name__< / span >
< span class = "k" > if< / span > < span class = "n" > classname< / span > < span class = "o" > .< / span > < span class = "n" > find< / span > < span class = "p" > (< / span > < span class = "s1" > ' BatchNorm' < / span > < span class = "p" > )< / span > < span class = "o" > !=< / span > < span class = "o" > -< / span > < span class = "mi" > 1< / span > < span class = "p" > :< / span >
< span class = "n" > m< / span > < span class = "o" > .< / span > < span class = "n" > eval< / span > < span class = "p" > ()< / span >
< div class = "viewcode-block" id = "open_all_layers" > < a class = "viewcode-back" href = "../../../pkg/utils.html#torchreid.utils.torchtools.open_all_layers" > [docs]< / a > < span class = "k" > def< / span > < span class = "nf" > open_all_layers< / span > < span class = "p" > (< / span > < span class = "n" > model< / span > < span class = "p" > ):< / span >
< span class = "sd" > " " " Opens all layers in model for training.< / span >
< span class = "sd" > Examples::< / span >
< span class = "sd" > > > > from torchreid.utils import open_all_layers< / span >
< span class = "sd" > > > > open_all_layers(model)< / span >
< span class = "sd" > " " " < / span >
< span class = "n" > model< / span > < span class = "o" > .< / span > < span class = "n" > train< / span > < span class = "p" > ()< / span >
< span class = "k" > for< / span > < span class = "n" > p< / span > < span class = "ow" > in< / span > < span class = "n" > model< / span > < span class = "o" > .< / span > < span class = "n" > parameters< / span > < span class = "p" > ():< / span >
< span class = "n" > p< / span > < span class = "o" > .< / span > < span class = "n" > requires_grad< / span > < span class = "o" > =< / span > < span class = "kc" > True< / span > < / div >
< div class = "viewcode-block" id = "open_specified_layers" > < a class = "viewcode-back" href = "../../../pkg/utils.html#torchreid.utils.torchtools.open_specified_layers" > [docs]< / a > < span class = "k" > def< / span > < span class = "nf" > open_specified_layers< / span > < span class = "p" > (< / span > < span class = "n" > model< / span > < span class = "p" > ,< / span > < span class = "n" > open_layers< / span > < span class = "p" > ):< / span >
< span class = "sd" > " " " Opens specified layers in model for training while keeping< / span >
< span class = "sd" > other layers frozen.< / span >
< span class = "sd" > Args:< / span >
< span class = "sd" > model (nn.Module): neural net model.< / span >
< span class = "sd" > open_layers (str or list): layers open for training.< / span >
< span class = "sd" > Examples::< / span >
< span class = "sd" > > > > from torchreid.utils import open_specified_layers< / span >
< span class = "sd" > > > > # Only model.classifier will be updated.< / span >
< span class = "sd" > > > > open_layers = ' classifier' < / span >
< span class = "sd" > > > > open_specified_layers(model, open_layers)< / span >
< span class = "sd" > > > > # Only model.fc and model.classifier will be updated.< / span >
< span class = "sd" > > > > open_layers = [' fc' , ' classifier' ]< / span >
< span class = "sd" > > > > open_specified_layers(model, open_layers)< / span >
< span class = "sd" > " " " < / span >
< span class = "k" > if< / span > < span class = "nb" > isinstance< / span > < span class = "p" > (< / span > < span class = "n" > model< / span > < span class = "p" > ,< / span > < span class = "n" > nn< / span > < span class = "o" > .< / span > < span class = "n" > DataParallel< / span > < span class = "p" > ):< / span >
< span class = "n" > model< / span > < span class = "o" > =< / span > < span class = "n" > model< / span > < span class = "o" > .< / span > < span class = "n" > module< / span >
< span class = "k" > if< / span > < span class = "nb" > isinstance< / span > < span class = "p" > (< / span > < span class = "n" > open_layers< / span > < span class = "p" > ,< / span > < span class = "nb" > str< / span > < span class = "p" > ):< / span >
< span class = "n" > open_layers< / span > < span class = "o" > =< / span > < span class = "p" > [< / span > < span class = "n" > open_layers< / span > < span class = "p" > ]< / span >
< span class = "k" > for< / span > < span class = "n" > layer< / span > < span class = "ow" > in< / span > < span class = "n" > open_layers< / span > < span class = "p" > :< / span >
< span class = "k" > assert< / span > < span class = "nb" > hasattr< / span > < span class = "p" > (< / span > < span class = "n" > model< / span > < span class = "p" > ,< / span > < span class = "n" > layer< / span > < span class = "p" > ),< / span > < span class = "s1" > ' " < / span > < span class = "si" > {}< / span > < span class = "s1" > " is not an attribute of the model, please provide the correct name' < / span > < span class = "o" > .< / span > < span class = "n" > format< / span > < span class = "p" > (< / span > < span class = "n" > layer< / span > < span class = "p" > )< / span >
< span class = "k" > for< / span > < span class = "n" > name< / span > < span class = "p" > ,< / span > < span class = "n" > module< / span > < span class = "ow" > in< / span > < span class = "n" > model< / span > < span class = "o" > .< / span > < span class = "n" > named_children< / span > < span class = "p" > ():< / span >
< span class = "k" > if< / span > < span class = "n" > name< / span > < span class = "ow" > in< / span > < span class = "n" > open_layers< / span > < span class = "p" > :< / span >
< span class = "n" > module< / span > < span class = "o" > .< / span > < span class = "n" > train< / span > < span class = "p" > ()< / span >
< span class = "k" > for< / span > < span class = "n" > p< / span > < span class = "ow" > in< / span > < span class = "n" > module< / span > < span class = "o" > .< / span > < span class = "n" > parameters< / span > < span class = "p" > ():< / span >
< span class = "n" > p< / span > < span class = "o" > .< / span > < span class = "n" > requires_grad< / span > < span class = "o" > =< / span > < span class = "kc" > True< / span >
< span class = "k" > else< / span > < span class = "p" > :< / span >
< span class = "n" > module< / span > < span class = "o" > .< / span > < span class = "n" > eval< / span > < span class = "p" > ()< / span >
< span class = "k" > for< / span > < span class = "n" > p< / span > < span class = "ow" > in< / span > < span class = "n" > module< / span > < span class = "o" > .< / span > < span class = "n" > parameters< / span > < span class = "p" > ():< / span >
< span class = "n" > p< / span > < span class = "o" > .< / span > < span class = "n" > requires_grad< / span > < span class = "o" > =< / span > < span class = "kc" > False< / span > < / div >
<!-- Sphinx viewcode rendering (syntax-highlighted, read-only copy) of torchreid.utils.torchtools.count_num_param; the [docs] link leads to its API entry in pkg/utils.html. This page appears to be build output: change code or docstring text in the Python source and rebuild instead of editing here. The rendered function sums numel() over model.parameters(), divides by 1e+06, then, after unwrapping nn.DataParallel, subtracts the parameters of model.classifier when that attribute exists and is an nn.Module. -->< div class = "viewcode-block" id = "count_num_param" > < a class = "viewcode-back" href = "../../../pkg/utils.html#torchreid.utils.torchtools.count_num_param" > [docs]< / a > < span class = "k" > def< / span > < span class = "nf" > count_num_param< / span > < span class = "p" > (< / span > < span class = "n" > model< / span > < span class = "p" > ):< / span >
< span class = "sd" > " " " Counts number of parameters in a model.< / span >
< span class = "sd" > Examples::< / span >
< span class = "sd" > > > > from torchreid.utils import count_num_param< / span >
< span class = "sd" > > > > model_size = count_num_param(model)< / span >
< span class = "sd" > " " " < / span >
< span class = "n" > num_param< / span > < span class = "o" > =< / span > < span class = "nb" > sum< / span > < span class = "p" > (< / span > < span class = "n" > p< / span > < span class = "o" > .< / span > < span class = "n" > numel< / span > < span class = "p" > ()< / span > < span class = "k" > for< / span > < span class = "n" > p< / span > < span class = "ow" > in< / span > < span class = "n" > model< / span > < span class = "o" > .< / span > < span class = "n" > parameters< / span > < span class = "p" > ())< / span > < span class = "o" > /< / span > < span class = "mf" > 1e+06< / span >
< span class = "k" > if< / span > < span class = "nb" > isinstance< / span > < span class = "p" > (< / span > < span class = "n" > model< / span > < span class = "p" > ,< / span > < span class = "n" > nn< / span > < span class = "o" > .< / span > < span class = "n" > DataParallel< / span > < span class = "p" > ):< / span >
< span class = "n" > model< / span > < span class = "o" > =< / span > < span class = "n" > model< / span > < span class = "o" > .< / span > < span class = "n" > module< / span >
< span class = "k" > if< / span > < span class = "nb" > hasattr< / span > < span class = "p" > (< / span > < span class = "n" > model< / span > < span class = "p" > ,< / span > < span class = "s1" > ' classifier' < / span > < span class = "p" > )< / span > < span class = "ow" > and< / span > < span class = "nb" > isinstance< / span > < span class = "p" > (< / span > < span class = "n" > model< / span > < span class = "o" > .< / span > < span class = "n" > classifier< / span > < span class = "p" > ,< / span > < span class = "n" > nn< / span > < span class = "o" > .< / span > < span class = "n" > Module< / span > < span class = "p" > ):< / span >
< span class = "c1" > # we ignore the classifier because it is unused at test time< / span >
< span class = "n" > num_param< / span > < span class = "o" > -=< / span > < span class = "nb" > sum< / span > < span class = "p" > (< / span > < span class = "n" > p< / span > < span class = "o" > .< / span > < span class = "n" > numel< / span > < span class = "p" > ()< / span > < span class = "k" > for< / span > < span class = "n" > p< / span > < span class = "ow" > in< / span > < span class = "n" > model< / span > < span class = "o" > .< / span > < span class = "n" > classifier< / span > < span class = "o" > .< / span > < span class = "n" > parameters< / span > < span class = "p" > ())< / span > < span class = "o" > /< / span > < span class = "mf" > 1e+06< / span >
< span class = "k" > return< / span > < span class = "n" > num_param< / span > < / div >
<!-- Sphinx viewcode rendering (syntax-highlighted, read-only copy) of torchreid.utils.torchtools.load_pretrained_weights; the [docs] link leads to its API entry in pkg/utils.html. This page appears to be build output: change code or docstring text in the Python source and rebuild instead of editing here. The rendered function loads a checkpoint, uses its state_dict entry if present, strips a leading "module." from each key, keeps only keys present in model.state_dict() with matching size(), loads the merged dict, then warns if nothing matched or prints the matched/discarded summary. NOTE(review): the docstring typo "pretrianed" comes from the Python source and should be fixed there. NOTE(review): model_dict is taken from model.state_dict() without unwrapping nn.DataParallel (count_num_param in this module does unwrap); if model is DataParallel-wrapped its own keys presumably keep the "module." prefix, so the stripped checkpoint keys would not match. Worth confirming in the source. -->< div class = "viewcode-block" id = "load_pretrained_weights" > < a class = "viewcode-back" href = "../../../pkg/utils.html#torchreid.utils.torchtools.load_pretrained_weights" > [docs]< / a > < span class = "k" > def< / span > < span class = "nf" > load_pretrained_weights< / span > < span class = "p" > (< / span > < span class = "n" > model< / span > < span class = "p" > ,< / span > < span class = "n" > weight_path< / span > < span class = "p" > ):< / span >
< span class = "sd" > " " " Loads pretrianed weights to model.< / span >
< span class = "sd" > Features::< / span >
< span class = "sd" > - Incompatible layers (unmatched in name or size) will be ignored.< / span >
< span class = "sd" > - Can automatically deal with keys containing " module." .< / span >
< span class = "sd" > Args:< / span >
< span class = "sd" > model (nn.Module): model.< / span >
< span class = "sd" > weight_path (str): path to pretrained weights.< / span >
< span class = "sd" > Examples::< / span >
< span class = "sd" > > > > from torchreid.utils import load_pretrained_weights< / span >
< span class = "sd" > > > > weight_path = ' log/my_model/model-best.pth.tar' < / span >
< span class = "sd" > > > > load_pretrained_weights(model, weight_path)< / span >
< span class = "sd" > " " " < / span >
< span class = "n" > checkpoint< / span > < span class = "o" > =< / span > < span class = "n" > load_checkpoint< / span > < span class = "p" > (< / span > < span class = "n" > weight_path< / span > < span class = "p" > )< / span >
< span class = "k" > if< / span > < span class = "s1" > ' state_dict' < / span > < span class = "ow" > in< / span > < span class = "n" > checkpoint< / span > < span class = "p" > :< / span >
< span class = "n" > state_dict< / span > < span class = "o" > =< / span > < span class = "n" > checkpoint< / span > < span class = "p" > [< / span > < span class = "s1" > ' state_dict' < / span > < span class = "p" > ]< / span >
< span class = "k" > else< / span > < span class = "p" > :< / span >
< span class = "n" > state_dict< / span > < span class = "o" > =< / span > < span class = "n" > checkpoint< / span >
< span class = "n" > model_dict< / span > < span class = "o" > =< / span > < span class = "n" > model< / span > < span class = "o" > .< / span > < span class = "n" > state_dict< / span > < span class = "p" > ()< / span >
< span class = "n" > new_state_dict< / span > < span class = "o" > =< / span > < span class = "n" > OrderedDict< / span > < span class = "p" > ()< / span >
< span class = "n" > matched_layers< / span > < span class = "p" > ,< / span > < span class = "n" > discarded_layers< / span > < span class = "o" > =< / span > < span class = "p" > [],< / span > < span class = "p" > []< / span >
< span class = "k" > for< / span > < span class = "n" > k< / span > < span class = "p" > ,< / span > < span class = "n" > v< / span > < span class = "ow" > in< / span > < span class = "n" > state_dict< / span > < span class = "o" > .< / span > < span class = "n" > items< / span > < span class = "p" > ():< / span >
< span class = "k" > if< / span > < span class = "n" > k< / span > < span class = "o" > .< / span > < span class = "n" > startswith< / span > < span class = "p" > (< / span > < span class = "s1" > ' module.' < / span > < span class = "p" > ):< / span >
< span class = "n" > k< / span > < span class = "o" > =< / span > < span class = "n" > k< / span > < span class = "p" > [< / span > < span class = "mi" > 7< / span > < span class = "p" > :]< / span > < span class = "c1" > # discard module.< / span >
< span class = "k" > if< / span > < span class = "n" > k< / span > < span class = "ow" > in< / span > < span class = "n" > model_dict< / span > < span class = "ow" > and< / span > < span class = "n" > model_dict< / span > < span class = "p" > [< / span > < span class = "n" > k< / span > < span class = "p" > ]< / span > < span class = "o" > .< / span > < span class = "n" > size< / span > < span class = "p" > ()< / span > < span class = "o" > ==< / span > < span class = "n" > v< / span > < span class = "o" > .< / span > < span class = "n" > size< / span > < span class = "p" > ():< / span >
< span class = "n" > new_state_dict< / span > < span class = "p" > [< / span > < span class = "n" > k< / span > < span class = "p" > ]< / span > < span class = "o" > =< / span > < span class = "n" > v< / span >
< span class = "n" > matched_layers< / span > < span class = "o" > .< / span > < span class = "n" > append< / span > < span class = "p" > (< / span > < span class = "n" > k< / span > < span class = "p" > )< / span >
< span class = "k" > else< / span > < span class = "p" > :< / span >
< span class = "n" > discarded_layers< / span > < span class = "o" > .< / span > < span class = "n" > append< / span > < span class = "p" > (< / span > < span class = "n" > k< / span > < span class = "p" > )< / span >
< span class = "n" > model_dict< / span > < span class = "o" > .< / span > < span class = "n" > update< / span > < span class = "p" > (< / span > < span class = "n" > new_state_dict< / span > < span class = "p" > )< / span >
< span class = "n" > model< / span > < span class = "o" > .< / span > < span class = "n" > load_state_dict< / span > < span class = "p" > (< / span > < span class = "n" > model_dict< / span > < span class = "p" > )< / span >
< span class = "k" > if< / span > < span class = "nb" > len< / span > < span class = "p" > (< / span > < span class = "n" > matched_layers< / span > < span class = "p" > )< / span > < span class = "o" > ==< / span > < span class = "mi" > 0< / span > < span class = "p" > :< / span >
< span class = "n" > warnings< / span > < span class = "o" > .< / span > < span class = "n" > warn< / span > < span class = "p" > (< / span >
< span class = "s1" > ' The pretrained weights " < / span > < span class = "si" > {}< / span > < span class = "s1" > " cannot be loaded, ' < / span >
< span class = "s1" > ' please check the key names manually ' < / span >
< span class = "s1" > ' (** ignored and continue **)' < / span > < span class = "o" > .< / span > < span class = "n" > format< / span > < span class = "p" > (< / span > < span class = "n" > weight_path< / span > < span class = "p" > ))< / span >
< span class = "k" > else< / span > < span class = "p" > :< / span >
< span class = "nb" > print< / span > < span class = "p" > (< / span > < span class = "s1" > ' Successfully loaded pretrained weights from " < / span > < span class = "si" > {}< / span > < span class = "s1" > " ' < / span > < span class = "o" > .< / span > < span class = "n" > format< / span > < span class = "p" > (< / span > < span class = "n" > weight_path< / span > < span class = "p" > ))< / span >
< span class = "k" > if< / span > < span class = "nb" > len< / span > < span class = "p" > (< / span > < span class = "n" > discarded_layers< / span > < span class = "p" > )< / span > < span class = "o" > > < / span > < span class = "mi" > 0< / span > < span class = "p" > :< / span >
< span class = "nb" > print< / span > < span class = "p" > (< / span > < span class = "s1" > ' ** The following layers are discarded ' < / span >
< span class = "s1" > ' due to unmatched keys or layer size: < / span > < span class = "si" > {}< / span > < span class = "s1" > ' < / span > < span class = "o" > .< / span > < span class = "n" > format< / span > < span class = "p" > (< / span > < span class = "n" > discarded_layers< / span > < span class = "p" > ))< / span > < / div >
< / pre > < / div >
<!-- End of the highlighted module-source listing: the line above closes the pre element holding it and its wrapper div. -->
< / div >
< / div >
<!-- Page footer: copyright notice plus attribution links to Sphinx, the sphinx_rtd_theme repository and Read the Docs. -->
< footer >
< hr / >
< div role = "contentinfo" >
< p >
© Copyright 2019, Kaiyang Zhou
< / p >
< / div >
Built with < a href = "http://sphinx-doc.org/" > Sphinx< / a > using a < a href = "https://github.com/rtfd/sphinx_rtd_theme" > theme< / a > provided by < a href = "https://readthedocs.org" > Read the Docs< / a > .
< / footer >
< / div >
< / div >
< / section >
< / div >
<!-- Theme initialisation: jQuery(fn) runs the callback once the DOM is ready and calls SphinxRtdTheme.Navigation.enable(true). SphinxRtdTheme is presumably provided by _static/js/theme.js, which the head loads; the meaning of the true argument is not visible from this page. -->
< script type = "text/javascript" >
jQuery(function () {
SphinxRtdTheme.Navigation.enable(true);
});
< / script >
< / body >
< / html >