2019-05-22 23:18:39 +08:00
<!DOCTYPE html>
<!-- [if IE 8]><html class="no - js lt - ie9" lang="en" > <![endif] -->
<!-- [if gt IE 8]><! --> < html class = "no-js" lang = "en" > <!-- <![endif] -->
< head >
< meta charset = "utf-8" >
< meta name = "viewport" content = "width=device-width, initial-scale=1.0" >
2019-05-24 23:30:24 +08:00
< title > torchreid.utils.model_complexity — torchreid 0.7.7 documentation< / title >
2019-05-22 23:18:39 +08:00
< script type = "text/javascript" src = "../../../_static/js/modernizr.min.js" > < / script >
< script type = "text/javascript" id = "documentation_options" data-url_root = "../../../" src = "../../../_static/documentation_options.js" > < / script >
< script type = "text/javascript" src = "../../../_static/jquery.js" > < / script >
< script type = "text/javascript" src = "../../../_static/underscore.js" > < / script >
< script type = "text/javascript" src = "../../../_static/doctools.js" > < / script >
< script type = "text/javascript" src = "../../../_static/language_data.js" > < / script >
< script async = "async" type = "text/javascript" src = "https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.5/latest.js?config=TeX-AMS-MML_HTMLorMML" > < / script >
< script type = "text/javascript" src = "../../../_static/js/theme.js" > < / script >
< link rel = "stylesheet" href = "../../../_static/css/theme.css" type = "text/css" / >
< link rel = "stylesheet" href = "../../../_static/pygments.css" type = "text/css" / >
< link rel = "index" title = "Index" href = "../../../genindex.html" / >
< link rel = "search" title = "Search" href = "../../../search.html" / >
< / head >
< body class = "wy-body-for-nav" >
< div class = "wy-grid-for-nav" >
< nav data-toggle = "wy-nav-shift" class = "wy-nav-side" >
< div class = "wy-side-scroll" >
< div class = "wy-side-nav-search" >
< a href = "../../../index.html" class = "icon icon-home" > torchreid
< / a >
< div class = "version" >
2019-05-24 23:30:24 +08:00
0.7.7
2019-05-22 23:18:39 +08:00
< / div >
< div role = "search" >
< form id = "rtd-search-form" class = "wy-form" action = "../../../search.html" method = "get" >
< input type = "text" name = "q" placeholder = "Search docs" / >
< input type = "hidden" name = "check_keywords" value = "yes" / >
< input type = "hidden" name = "area" value = "default" / >
< / form >
< / div >
< / div >
< div class = "wy-menu wy-menu-vertical" data-spy = "affix" role = "navigation" aria-label = "main navigation" >
< ul >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../../user_guide.html" > How-to< / a > < / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../../datasets.html" > Datasets< / a > < / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../../evaluation.html" > Evaluation< / a > < / li >
< / ul >
< p class = "caption" > < span class = "caption-text" > Package Reference< / span > < / p >
< ul >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../../pkg/data.html" > torchreid.data< / a > < / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../../pkg/engine.html" > torchreid.engine< / a > < / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../../pkg/losses.html" > torchreid.losses< / a > < / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../../pkg/metrics.html" > torchreid.metrics< / a > < / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../../pkg/models.html" > torchreid.models< / a > < / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../../pkg/optim.html" > torchreid.optim< / a > < / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../../pkg/utils.html" > torchreid.utils< / a > < / li >
< / ul >
< p class = "caption" > < span class = "caption-text" > Resources< / span > < / p >
< ul >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../../AWESOME_REID.html" > Awesome-ReID< / a > < / li >
< li class = "toctree-l1" > < a class = "reference internal" href = "../../../MODEL_ZOO.html" > Model Zoo< / a > < / li >
< / ul >
< / div >
< / div >
< / nav >
< section data-toggle = "wy-nav-shift" class = "wy-nav-content-wrap" >
< nav class = "wy-nav-top" aria-label = "top navigation" >
< i data-toggle = "wy-nav-top" class = "fa fa-bars" > < / i >
< a href = "../../../index.html" > torchreid< / a >
< / nav >
< div class = "wy-nav-content" >
< div class = "rst-content" >
< div role = "navigation" aria-label = "breadcrumbs navigation" >
< ul class = "wy-breadcrumbs" >
< li > < a href = "../../../index.html" > Docs< / a > » < / li >
< li > < a href = "../../index.html" > Module code< / a > » < / li >
< li > torchreid.utils.model_complexity< / li >
< li class = "wy-breadcrumbs-aside" >
< / li >
< / ul >
< hr / >
< / div >
< div role = "main" class = "document" itemscope = "itemscope" itemtype = "http://schema.org/Article" >
< div itemprop = "articleBody" >
< h1 > Source code for torchreid.utils.model_complexity< / h1 > < div class = "highlight" > < pre >
< span > < / span > < span class = "kn" > from< / span > < span class = "nn" > __future__< / span > < span class = "k" > import< / span > < span class = "n" > absolute_import< / span > < span class = "p" > ,< / span > < span class = "n" > division< / span > < span class = "p" > ,< / span > < span class = "n" > print_function< / span >
< span class = "n" > __all__< / span > < span class = "o" > =< / span > < span class = "p" > [< / span > < span class = "s1" > ' compute_model_complexity' < / span > < span class = "p" > ]< / span >
< span class = "kn" > from< / span > < span class = "nn" > collections< / span > < span class = "k" > import< / span > < span class = "n" > namedtuple< / span > < span class = "p" > ,< / span > < span class = "n" > defaultdict< / span >
< span class = "kn" > import< / span > < span class = "nn" > numpy< / span > < span class = "k" > as< / span > < span class = "nn" > np< / span >
< span class = "kn" > import< / span > < span class = "nn" > math< / span >
< span class = "kn" > from< / span > < span class = "nn" > itertools< / span > < span class = "k" > import< / span > < span class = "n" > repeat< / span >
< span class = "kn" > import< / span > < span class = "nn" > torch< / span >
< span class = "kn" > import< / span > < span class = "nn" > torch.nn< / span > < span class = "k" > as< / span > < span class = "nn" > nn< / span >
< span class = "sd" > " " " < / span >
< span class = "sd" > Utility< / span >
< span class = "sd" > " " " < / span >
< span class = "k" > def< / span > < span class = "nf" > _ntuple< / span > < span class = "p" > (< / span > < span class = "n" > n< / span > < span class = "p" > ):< / span >
< span class = "k" > def< / span > < span class = "nf" > parse< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "p" > ):< / span >
< span class = "k" > if< / span > < span class = "nb" > isinstance< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "nb" > int< / span > < span class = "p" > ):< / span >
< span class = "k" > return< / span > < span class = "nb" > tuple< / span > < span class = "p" > (< / span > < span class = "n" > repeat< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "n" > n< / span > < span class = "p" > ))< / span >
< span class = "k" > return< / span > < span class = "n" > x< / span >
< span class = "k" > return< / span > < span class = "n" > parse< / span >
< span class = "n" > _single< / span > < span class = "o" > =< / span > < span class = "n" > _ntuple< / span > < span class = "p" > (< / span > < span class = "mi" > 1< / span > < span class = "p" > )< / span >
< span class = "n" > _pair< / span > < span class = "o" > =< / span > < span class = "n" > _ntuple< / span > < span class = "p" > (< / span > < span class = "mi" > 2< / span > < span class = "p" > )< / span >
< span class = "n" > _triple< / span > < span class = "o" > =< / span > < span class = "n" > _ntuple< / span > < span class = "p" > (< / span > < span class = "mi" > 3< / span > < span class = "p" > )< / span >
< span class = "sd" > " " " < / span >
< span class = "sd" > Convolution< / span >
< span class = "sd" > " " " < / span >
< span class = "k" > def< / span > < span class = "nf" > hook_convNd< / span > < span class = "p" > (< / span > < span class = "n" > m< / span > < span class = "p" > ,< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "n" > y< / span > < span class = "p" > ):< / span >
< span class = "n" > k< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > prod< / span > < span class = "p" > (< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > Tensor< / span > < span class = "p" > (< / span > < span class = "n" > m< / span > < span class = "o" > .< / span > < span class = "n" > kernel_size< / span > < span class = "p" > ))< / span > < span class = "o" > .< / span > < span class = "n" > item< / span > < span class = "p" > ()< / span >
< span class = "n" > cin< / span > < span class = "o" > =< / span > < span class = "n" > m< / span > < span class = "o" > .< / span > < span class = "n" > in_channels< / span >
< span class = "n" > flops_per_ele< / span > < span class = "o" > =< / span > < span class = "n" > k< / span > < span class = "o" > *< / span > < span class = "n" > cin< / span > < span class = "c1" > #+ (k*cin-1)< / span >
< span class = "k" > if< / span > < span class = "n" > m< / span > < span class = "o" > .< / span > < span class = "n" > bias< / span > < span class = "ow" > is< / span > < span class = "ow" > not< / span > < span class = "kc" > None< / span > < span class = "p" > :< / span >
< span class = "n" > flops_per_ele< / span > < span class = "o" > +=< / span > < span class = "mi" > 1< / span >
< span class = "n" > flops< / span > < span class = "o" > =< / span > < span class = "n" > flops_per_ele< / span > < span class = "o" > *< / span > < span class = "n" > y< / span > < span class = "o" > .< / span > < span class = "n" > numel< / span > < span class = "p" > ()< / span > < span class = "o" > /< / span > < span class = "n" > m< / span > < span class = "o" > .< / span > < span class = "n" > groups< / span >
< span class = "k" > return< / span > < span class = "nb" > int< / span > < span class = "p" > (< / span > < span class = "n" > flops< / span > < span class = "p" > )< / span >
< span class = "sd" > " " " < / span >
< span class = "sd" > Pooling< / span >
< span class = "sd" > " " " < / span >
< span class = "k" > def< / span > < span class = "nf" > hook_maxpool1d< / span > < span class = "p" > (< / span > < span class = "n" > m< / span > < span class = "p" > ,< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "n" > y< / span > < span class = "p" > ):< / span >
< span class = "n" > flops_per_ele< / span > < span class = "o" > =< / span > < span class = "n" > m< / span > < span class = "o" > .< / span > < span class = "n" > kernel_size< / span > < span class = "o" > -< / span > < span class = "mi" > 1< / span >
< span class = "n" > flops< / span > < span class = "o" > =< / span > < span class = "n" > flops_per_ele< / span > < span class = "o" > *< / span > < span class = "n" > y< / span > < span class = "o" > .< / span > < span class = "n" > numel< / span > < span class = "p" > ()< / span >
< span class = "k" > return< / span > < span class = "nb" > int< / span > < span class = "p" > (< / span > < span class = "n" > flops< / span > < span class = "p" > )< / span >
< span class = "k" > def< / span > < span class = "nf" > hook_maxpool2d< / span > < span class = "p" > (< / span > < span class = "n" > m< / span > < span class = "p" > ,< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "n" > y< / span > < span class = "p" > ):< / span >
< span class = "n" > k< / span > < span class = "o" > =< / span > < span class = "n" > _pair< / span > < span class = "p" > (< / span > < span class = "n" > m< / span > < span class = "o" > .< / span > < span class = "n" > kernel_size< / span > < span class = "p" > )< / span >
< span class = "n" > k< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > prod< / span > < span class = "p" > (< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > Tensor< / span > < span class = "p" > (< / span > < span class = "n" > k< / span > < span class = "p" > ))< / span > < span class = "o" > .< / span > < span class = "n" > item< / span > < span class = "p" > ()< / span >
< span class = "c1" > # ops: compare< / span >
< span class = "n" > flops_per_ele< / span > < span class = "o" > =< / span > < span class = "n" > k< / span > < span class = "o" > -< / span > < span class = "mi" > 1< / span >
< span class = "n" > flops< / span > < span class = "o" > =< / span > < span class = "n" > flops_per_ele< / span > < span class = "o" > *< / span > < span class = "n" > y< / span > < span class = "o" > .< / span > < span class = "n" > numel< / span > < span class = "p" > ()< / span >
< span class = "k" > return< / span > < span class = "nb" > int< / span > < span class = "p" > (< / span > < span class = "n" > flops< / span > < span class = "p" > )< / span >
< span class = "k" > def< / span > < span class = "nf" > hook_maxpool3d< / span > < span class = "p" > (< / span > < span class = "n" > m< / span > < span class = "p" > ,< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "n" > y< / span > < span class = "p" > ):< / span >
< span class = "n" > k< / span > < span class = "o" > =< / span > < span class = "n" > _triple< / span > < span class = "p" > (< / span > < span class = "n" > m< / span > < span class = "o" > .< / span > < span class = "n" > kernel_size< / span > < span class = "p" > )< / span >
< span class = "n" > k< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > prod< / span > < span class = "p" > (< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > Tensor< / span > < span class = "p" > (< / span > < span class = "n" > k< / span > < span class = "p" > ))< / span > < span class = "o" > .< / span > < span class = "n" > item< / span > < span class = "p" > ()< / span >
< span class = "n" > flops_per_ele< / span > < span class = "o" > =< / span > < span class = "n" > k< / span > < span class = "o" > -< / span > < span class = "mi" > 1< / span >
< span class = "n" > flops< / span > < span class = "o" > =< / span > < span class = "n" > flops_per_ele< / span > < span class = "o" > *< / span > < span class = "n" > y< / span > < span class = "o" > .< / span > < span class = "n" > numel< / span > < span class = "p" > ()< / span >
< span class = "k" > return< / span > < span class = "nb" > int< / span > < span class = "p" > (< / span > < span class = "n" > flops< / span > < span class = "p" > )< / span >
< span class = "k" > def< / span > < span class = "nf" > hook_avgpool1d< / span > < span class = "p" > (< / span > < span class = "n" > m< / span > < span class = "p" > ,< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "n" > y< / span > < span class = "p" > ):< / span >
< span class = "n" > flops_per_ele< / span > < span class = "o" > =< / span > < span class = "n" > m< / span > < span class = "o" > .< / span > < span class = "n" > kernel_size< / span >
< span class = "n" > flops< / span > < span class = "o" > =< / span > < span class = "n" > flops_per_ele< / span > < span class = "o" > *< / span > < span class = "n" > y< / span > < span class = "o" > .< / span > < span class = "n" > numel< / span > < span class = "p" > ()< / span >
< span class = "k" > return< / span > < span class = "nb" > int< / span > < span class = "p" > (< / span > < span class = "n" > flops< / span > < span class = "p" > )< / span >
< span class = "k" > def< / span > < span class = "nf" > hook_avgpool2d< / span > < span class = "p" > (< / span > < span class = "n" > m< / span > < span class = "p" > ,< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "n" > y< / span > < span class = "p" > ):< / span >
< span class = "n" > k< / span > < span class = "o" > =< / span > < span class = "n" > _pair< / span > < span class = "p" > (< / span > < span class = "n" > m< / span > < span class = "o" > .< / span > < span class = "n" > kernel_size< / span > < span class = "p" > )< / span >
< span class = "n" > k< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > prod< / span > < span class = "p" > (< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > Tensor< / span > < span class = "p" > (< / span > < span class = "n" > k< / span > < span class = "p" > ))< / span > < span class = "o" > .< / span > < span class = "n" > item< / span > < span class = "p" > ()< / span >
< span class = "n" > flops_per_ele< / span > < span class = "o" > =< / span > < span class = "n" > k< / span >
< span class = "n" > flops< / span > < span class = "o" > =< / span > < span class = "n" > flops_per_ele< / span > < span class = "o" > *< / span > < span class = "n" > y< / span > < span class = "o" > .< / span > < span class = "n" > numel< / span > < span class = "p" > ()< / span >
< span class = "k" > return< / span > < span class = "nb" > int< / span > < span class = "p" > (< / span > < span class = "n" > flops< / span > < span class = "p" > )< / span >
< span class = "k" > def< / span > < span class = "nf" > hook_avgpool3d< / span > < span class = "p" > (< / span > < span class = "n" > m< / span > < span class = "p" > ,< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "n" > y< / span > < span class = "p" > ):< / span >
< span class = "n" > k< / span > < span class = "o" > =< / span > < span class = "n" > _triple< / span > < span class = "p" > (< / span > < span class = "n" > m< / span > < span class = "o" > .< / span > < span class = "n" > kernel_size< / span > < span class = "p" > )< / span >
< span class = "n" > k< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > prod< / span > < span class = "p" > (< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > Tensor< / span > < span class = "p" > (< / span > < span class = "n" > k< / span > < span class = "p" > ))< / span > < span class = "o" > .< / span > < span class = "n" > item< / span > < span class = "p" > ()< / span >
< span class = "n" > flops_per_ele< / span > < span class = "o" > =< / span > < span class = "n" > k< / span >
< span class = "n" > flops< / span > < span class = "o" > =< / span > < span class = "n" > flops_per_ele< / span > < span class = "o" > *< / span > < span class = "n" > y< / span > < span class = "o" > .< / span > < span class = "n" > numel< / span > < span class = "p" > ()< / span >
< span class = "k" > return< / span > < span class = "nb" > int< / span > < span class = "p" > (< / span > < span class = "n" > flops< / span > < span class = "p" > )< / span >
< span class = "k" > def< / span > < span class = "nf" > hook_adapmaxpool1d< / span > < span class = "p" > (< / span > < span class = "n" > m< / span > < span class = "p" > ,< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "n" > y< / span > < span class = "p" > ):< / span >
< span class = "n" > x< / span > < span class = "o" > =< / span > < span class = "n" > x< / span > < span class = "p" > [< / span > < span class = "mi" > 0< / span > < span class = "p" > ]< / span >
< span class = "n" > out_size< / span > < span class = "o" > =< / span > < span class = "n" > m< / span > < span class = "o" > .< / span > < span class = "n" > output_size< / span >
< span class = "n" > k< / span > < span class = "o" > =< / span > < span class = "n" > math< / span > < span class = "o" > .< / span > < span class = "n" > ceil< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "o" > .< / span > < span class = "n" > size< / span > < span class = "p" > (< / span > < span class = "mi" > 2< / span > < span class = "p" > )< / span > < span class = "o" > /< / span > < span class = "n" > out_size< / span > < span class = "p" > )< / span >
< span class = "n" > flops_per_ele< / span > < span class = "o" > =< / span > < span class = "n" > k< / span > < span class = "o" > -< / span > < span class = "mi" > 1< / span >
< span class = "n" > flops< / span > < span class = "o" > =< / span > < span class = "n" > flops_per_ele< / span > < span class = "o" > *< / span > < span class = "n" > y< / span > < span class = "o" > .< / span > < span class = "n" > numel< / span > < span class = "p" > ()< / span >
< span class = "k" > return< / span > < span class = "nb" > int< / span > < span class = "p" > (< / span > < span class = "n" > flops< / span > < span class = "p" > )< / span >
< span class = "k" > def< / span > < span class = "nf" > hook_adapmaxpool2d< / span > < span class = "p" > (< / span > < span class = "n" > m< / span > < span class = "p" > ,< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "n" > y< / span > < span class = "p" > ):< / span >
< span class = "n" > x< / span > < span class = "o" > =< / span > < span class = "n" > x< / span > < span class = "p" > [< / span > < span class = "mi" > 0< / span > < span class = "p" > ]< / span >
< span class = "n" > out_size< / span > < span class = "o" > =< / span > < span class = "n" > _pair< / span > < span class = "p" > (< / span > < span class = "n" > m< / span > < span class = "o" > .< / span > < span class = "n" > output_size< / span > < span class = "p" > )< / span >
< span class = "n" > k< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > Tensor< / span > < span class = "p" > (< / span > < span class = "nb" > list< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "o" > .< / span > < span class = "n" > size< / span > < span class = "p" > ()[< / span > < span class = "mi" > 2< / span > < span class = "p" > :]))< / span > < span class = "o" > /< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > Tensor< / span > < span class = "p" > (< / span > < span class = "n" > out_size< / span > < span class = "p" > )< / span >
< span class = "n" > k< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > prod< / span > < span class = "p" > (< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > ceil< / span > < span class = "p" > (< / span > < span class = "n" > k< / span > < span class = "p" > ))< / span > < span class = "o" > .< / span > < span class = "n" > item< / span > < span class = "p" > ()< / span >
< span class = "n" > flops_per_ele< / span > < span class = "o" > =< / span > < span class = "n" > k< / span > < span class = "o" > -< / span > < span class = "mi" > 1< / span >
< span class = "n" > flops< / span > < span class = "o" > =< / span > < span class = "n" > flops_per_ele< / span > < span class = "o" > *< / span > < span class = "n" > y< / span > < span class = "o" > .< / span > < span class = "n" > numel< / span > < span class = "p" > ()< / span >
< span class = "k" > return< / span > < span class = "nb" > int< / span > < span class = "p" > (< / span > < span class = "n" > flops< / span > < span class = "p" > )< / span >
< span class = "k" > def< / span > < span class = "nf" > hook_adapmaxpool3d< / span > < span class = "p" > (< / span > < span class = "n" > m< / span > < span class = "p" > ,< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "n" > y< / span > < span class = "p" > ):< / span >
< span class = "n" > x< / span > < span class = "o" > =< / span > < span class = "n" > x< / span > < span class = "p" > [< / span > < span class = "mi" > 0< / span > < span class = "p" > ]< / span >
< span class = "n" > out_size< / span > < span class = "o" > =< / span > < span class = "n" > _triple< / span > < span class = "p" > (< / span > < span class = "n" > m< / span > < span class = "o" > .< / span > < span class = "n" > output_size< / span > < span class = "p" > )< / span >
< span class = "n" > k< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > Tensor< / span > < span class = "p" > (< / span > < span class = "nb" > list< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "o" > .< / span > < span class = "n" > size< / span > < span class = "p" > ()[< / span > < span class = "mi" > 2< / span > < span class = "p" > :]))< / span > < span class = "o" > /< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > Tensor< / span > < span class = "p" > (< / span > < span class = "n" > out_size< / span > < span class = "p" > )< / span >
< span class = "n" > k< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > prod< / span > < span class = "p" > (< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > ceil< / span > < span class = "p" > (< / span > < span class = "n" > k< / span > < span class = "p" > ))< / span > < span class = "o" > .< / span > < span class = "n" > item< / span > < span class = "p" > ()< / span >
< span class = "n" > flops_per_ele< / span > < span class = "o" > =< / span > < span class = "n" > k< / span > < span class = "o" > -< / span > < span class = "mi" > 1< / span >
< span class = "n" > flops< / span > < span class = "o" > =< / span > < span class = "n" > flops_per_ele< / span > < span class = "o" > *< / span > < span class = "n" > y< / span > < span class = "o" > .< / span > < span class = "n" > numel< / span > < span class = "p" > ()< / span >
< span class = "k" > return< / span > < span class = "nb" > int< / span > < span class = "p" > (< / span > < span class = "n" > flops< / span > < span class = "p" > )< / span >
< span class = "k" > def< / span > < span class = "nf" > hook_adapavgpool1d< / span > < span class = "p" > (< / span > < span class = "n" > m< / span > < span class = "p" > ,< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "n" > y< / span > < span class = "p" > ):< / span >
< span class = "n" > x< / span > < span class = "o" > =< / span > < span class = "n" > x< / span > < span class = "p" > [< / span > < span class = "mi" > 0< / span > < span class = "p" > ]< / span >
< span class = "n" > out_size< / span > < span class = "o" > =< / span > < span class = "n" > m< / span > < span class = "o" > .< / span > < span class = "n" > output_size< / span >
< span class = "n" > k< / span > < span class = "o" > =< / span > < span class = "n" > math< / span > < span class = "o" > .< / span > < span class = "n" > ceil< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "o" > .< / span > < span class = "n" > size< / span > < span class = "p" > (< / span > < span class = "mi" > 2< / span > < span class = "p" > )< / span > < span class = "o" > /< / span > < span class = "n" > out_size< / span > < span class = "p" > )< / span >
< span class = "n" > flops_per_ele< / span > < span class = "o" > =< / span > < span class = "n" > k< / span >
< span class = "n" > flops< / span > < span class = "o" > =< / span > < span class = "n" > flops_per_ele< / span > < span class = "o" > *< / span > < span class = "n" > y< / span > < span class = "o" > .< / span > < span class = "n" > numel< / span > < span class = "p" > ()< / span >
< span class = "k" > return< / span > < span class = "nb" > int< / span > < span class = "p" > (< / span > < span class = "n" > flops< / span > < span class = "p" > )< / span >
< span class = "k" > def< / span > < span class = "nf" > hook_adapavgpool2d< / span > < span class = "p" > (< / span > < span class = "n" > m< / span > < span class = "p" > ,< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "n" > y< / span > < span class = "p" > ):< / span >
< span class = "n" > x< / span > < span class = "o" > =< / span > < span class = "n" > x< / span > < span class = "p" > [< / span > < span class = "mi" > 0< / span > < span class = "p" > ]< / span >
< span class = "n" > out_size< / span > < span class = "o" > =< / span > < span class = "n" > _pair< / span > < span class = "p" > (< / span > < span class = "n" > m< / span > < span class = "o" > .< / span > < span class = "n" > output_size< / span > < span class = "p" > )< / span >
< span class = "n" > k< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > Tensor< / span > < span class = "p" > (< / span > < span class = "nb" > list< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "o" > .< / span > < span class = "n" > size< / span > < span class = "p" > ()[< / span > < span class = "mi" > 2< / span > < span class = "p" > :]))< / span > < span class = "o" > /< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > Tensor< / span > < span class = "p" > (< / span > < span class = "n" > out_size< / span > < span class = "p" > )< / span >
< span class = "n" > k< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > prod< / span > < span class = "p" > (< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > ceil< / span > < span class = "p" > (< / span > < span class = "n" > k< / span > < span class = "p" > ))< / span > < span class = "o" > .< / span > < span class = "n" > item< / span > < span class = "p" > ()< / span >
< span class = "n" > flops_per_ele< / span > < span class = "o" > =< / span > < span class = "n" > k< / span >
< span class = "n" > flops< / span > < span class = "o" > =< / span > < span class = "n" > flops_per_ele< / span > < span class = "o" > *< / span > < span class = "n" > y< / span > < span class = "o" > .< / span > < span class = "n" > numel< / span > < span class = "p" > ()< / span >
< span class = "k" > return< / span > < span class = "nb" > int< / span > < span class = "p" > (< / span > < span class = "n" > flops< / span > < span class = "p" > )< / span >
< span class = "k" > def< / span > < span class = "nf" > hook_adapavgpool3d< / span > < span class = "p" > (< / span > < span class = "n" > m< / span > < span class = "p" > ,< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "n" > y< / span > < span class = "p" > ):< / span >
< span class = "n" > x< / span > < span class = "o" > =< / span > < span class = "n" > x< / span > < span class = "p" > [< / span > < span class = "mi" > 0< / span > < span class = "p" > ]< / span >
< span class = "n" > out_size< / span > < span class = "o" > =< / span > < span class = "n" > _triple< / span > < span class = "p" > (< / span > < span class = "n" > m< / span > < span class = "o" > .< / span > < span class = "n" > output_size< / span > < span class = "p" > )< / span >
< span class = "n" > k< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > Tensor< / span > < span class = "p" > (< / span > < span class = "nb" > list< / span > < span class = "p" > (< / span > < span class = "n" > x< / span > < span class = "o" > .< / span > < span class = "n" > size< / span > < span class = "p" > ()[< / span > < span class = "mi" > 2< / span > < span class = "p" > :]))< / span > < span class = "o" > /< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > Tensor< / span > < span class = "p" > (< / span > < span class = "n" > out_size< / span > < span class = "p" > )< / span >
< span class = "n" > k< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > prod< / span > < span class = "p" > (< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > ceil< / span > < span class = "p" > (< / span > < span class = "n" > k< / span > < span class = "p" > ))< / span > < span class = "o" > .< / span > < span class = "n" > item< / span > < span class = "p" > ()< / span >
< span class = "n" > flops_per_ele< / span > < span class = "o" > =< / span > < span class = "n" > k< / span >
< span class = "n" > flops< / span > < span class = "o" > =< / span > < span class = "n" > flops_per_ele< / span > < span class = "o" > *< / span > < span class = "n" > y< / span > < span class = "o" > .< / span > < span class = "n" > numel< / span > < span class = "p" > ()< / span >
< span class = "k" > return< / span > < span class = "nb" > int< / span > < span class = "p" > (< / span > < span class = "n" > flops< / span > < span class = "p" > )< / span >
< span class = "sd" > " " " < / span >
< span class = "sd" > Non-linear activations< / span >
< span class = "sd" > " " " < / span >
< span class = "k" > def< / span > < span class = "nf" > hook_relu< / span > < span class = "p" > (< / span > < span class = "n" > m< / span > < span class = "p" > ,< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "n" > y< / span > < span class = "p" > ):< / span >
< span class = "c1" > # eq: max(0, x)< / span >
< span class = "n" > num_ele< / span > < span class = "o" > =< / span > < span class = "n" > y< / span > < span class = "o" > .< / span > < span class = "n" > numel< / span > < span class = "p" > ()< / span >
< span class = "k" > return< / span > < span class = "nb" > int< / span > < span class = "p" > (< / span > < span class = "n" > num_ele< / span > < span class = "p" > )< / span >
< span class = "k" > def< / span > < span class = "nf" > hook_leakyrelu< / span > < span class = "p" > (< / span > < span class = "n" > m< / span > < span class = "p" > ,< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "n" > y< / span > < span class = "p" > ):< / span >
< span class = "c1" > # eq: max(0, x) + negative_slope*min(0, x)< / span >
< span class = "n" > num_ele< / span > < span class = "o" > =< / span > < span class = "n" > y< / span > < span class = "o" > .< / span > < span class = "n" > numel< / span > < span class = "p" > ()< / span >
< span class = "n" > flops< / span > < span class = "o" > =< / span > < span class = "mi" > 3< / span > < span class = "o" > *< / span > < span class = "n" > num_ele< / span >
< span class = "k" > return< / span > < span class = "nb" > int< / span > < span class = "p" > (< / span > < span class = "n" > flops< / span > < span class = "p" > )< / span >
< span class = "sd" > " " " < / span >
< span class = "sd" > Normalization< / span >
< span class = "sd" > " " " < / span >
< span class = "k" > def< / span > < span class = "nf" > hook_batchnormNd< / span > < span class = "p" > (< / span > < span class = "n" > m< / span > < span class = "p" > ,< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "n" > y< / span > < span class = "p" > ):< / span >
< span class = "n" > num_ele< / span > < span class = "o" > =< / span > < span class = "n" > y< / span > < span class = "o" > .< / span > < span class = "n" > numel< / span > < span class = "p" > ()< / span >
< span class = "n" > flops< / span > < span class = "o" > =< / span > < span class = "mi" > 2< / span > < span class = "o" > *< / span > < span class = "n" > num_ele< / span > < span class = "c1" > # mean and std< / span >
< span class = "k" > if< / span > < span class = "n" > m< / span > < span class = "o" > .< / span > < span class = "n" > affine< / span > < span class = "p" > :< / span >
< span class = "n" > flops< / span > < span class = "o" > +=< / span > < span class = "mi" > 2< / span > < span class = "o" > *< / span > < span class = "n" > num_ele< / span > < span class = "c1" > # gamma and beta< / span >
< span class = "k" > return< / span > < span class = "nb" > int< / span > < span class = "p" > (< / span > < span class = "n" > flops< / span > < span class = "p" > )< / span >
< span class = "k" > def< / span > < span class = "nf" > hook_instancenormNd< / span > < span class = "p" > (< / span > < span class = "n" > m< / span > < span class = "p" > ,< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "n" > y< / span > < span class = "p" > ):< / span >
< span class = "k" > return< / span > < span class = "n" > hook_batchnormNd< / span > < span class = "p" > (< / span > < span class = "n" > m< / span > < span class = "p" > ,< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "n" > y< / span > < span class = "p" > )< / span >
< span class = "k" > def< / span > < span class = "nf" > hook_groupnorm< / span > < span class = "p" > (< / span > < span class = "n" > m< / span > < span class = "p" > ,< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "n" > y< / span > < span class = "p" > ):< / span >
< span class = "k" > return< / span > < span class = "n" > hook_batchnormNd< / span > < span class = "p" > (< / span > < span class = "n" > m< / span > < span class = "p" > ,< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "n" > y< / span > < span class = "p" > )< / span >
< span class = "k" > def< / span > < span class = "nf" > hook_layernorm< / span > < span class = "p" > (< / span > < span class = "n" > m< / span > < span class = "p" > ,< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "n" > y< / span > < span class = "p" > ):< / span >
< span class = "n" > num_ele< / span > < span class = "o" > =< / span > < span class = "n" > y< / span > < span class = "o" > .< / span > < span class = "n" > numel< / span > < span class = "p" > ()< / span >
< span class = "n" > flops< / span > < span class = "o" > =< / span > < span class = "mi" > 2< / span > < span class = "o" > *< / span > < span class = "n" > num_ele< / span > < span class = "c1" > # mean and std< / span >
< span class = "k" > if< / span > < span class = "n" > m< / span > < span class = "o" > .< / span > < span class = "n" > elementwise_affine< / span > < span class = "p" > :< / span >
< span class = "n" > flops< / span > < span class = "o" > +=< / span > < span class = "mi" > 2< / span > < span class = "o" > *< / span > < span class = "n" > num_ele< / span > < span class = "c1" > # gamma and beta< / span >
< span class = "k" > return< / span > < span class = "nb" > int< / span > < span class = "p" > (< / span > < span class = "n" > flops< / span > < span class = "p" > )< / span >
< span class = "sd" > " " " < / span >
< span class = "sd" > Linear< / span >
< span class = "sd" > " " " < / span >
< span class = "k" > def< / span > < span class = "nf" > hook_linear< / span > < span class = "p" > (< / span > < span class = "n" > m< / span > < span class = "p" > ,< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "n" > y< / span > < span class = "p" > ):< / span >
< span class = "n" > flops_per_ele< / span > < span class = "o" > =< / span > < span class = "n" > m< / span > < span class = "o" > .< / span > < span class = "n" > in_features< / span > < span class = "c1" > #+ (m.in_features-1)< / span >
< span class = "k" > if< / span > < span class = "n" > m< / span > < span class = "o" > .< / span > < span class = "n" > bias< / span > < span class = "ow" > is< / span > < span class = "ow" > not< / span > < span class = "kc" > None< / span > < span class = "p" > :< / span >
< span class = "n" > flops_per_ele< / span > < span class = "o" > +=< / span > < span class = "mi" > 1< / span >
< span class = "n" > flops< / span > < span class = "o" > =< / span > < span class = "n" > flops_per_ele< / span > < span class = "o" > *< / span > < span class = "n" > y< / span > < span class = "o" > .< / span > < span class = "n" > numel< / span > < span class = "p" > ()< / span >
< span class = "k" > return< / span > < span class = "nb" > int< / span > < span class = "p" > (< / span > < span class = "n" > flops< / span > < span class = "p" > )< / span >
< span class = "n" > __generic_flops_counter< / span > < span class = "o" > =< / span > < span class = "p" > {< / span >
< span class = "c1" > # Convolution< / span >
< span class = "s1" > ' Conv1d' < / span > < span class = "p" > :< / span > < span class = "n" > hook_convNd< / span > < span class = "p" > ,< / span >
< span class = "s1" > ' Conv2d' < / span > < span class = "p" > :< / span > < span class = "n" > hook_convNd< / span > < span class = "p" > ,< / span >
< span class = "s1" > ' Conv3d' < / span > < span class = "p" > :< / span > < span class = "n" > hook_convNd< / span > < span class = "p" > ,< / span >
< span class = "c1" > # Pooling< / span >
< span class = "s1" > ' MaxPool1d' < / span > < span class = "p" > :< / span > < span class = "n" > hook_maxpool1d< / span > < span class = "p" > ,< / span >
< span class = "s1" > ' MaxPool2d' < / span > < span class = "p" > :< / span > < span class = "n" > hook_maxpool2d< / span > < span class = "p" > ,< / span >
< span class = "s1" > ' MaxPool3d' < / span > < span class = "p" > :< / span > < span class = "n" > hook_maxpool3d< / span > < span class = "p" > ,< / span >
< span class = "s1" > ' AvgPool1d' < / span > < span class = "p" > :< / span > < span class = "n" > hook_avgpool1d< / span > < span class = "p" > ,< / span >
< span class = "s1" > ' AvgPool2d' < / span > < span class = "p" > :< / span > < span class = "n" > hook_avgpool2d< / span > < span class = "p" > ,< / span >
< span class = "s1" > ' AvgPool3d' < / span > < span class = "p" > :< / span > < span class = "n" > hook_avgpool3d< / span > < span class = "p" > ,< / span >
< span class = "s1" > ' AdaptiveMaxPool1d' < / span > < span class = "p" > :< / span > < span class = "n" > hook_adapmaxpool1d< / span > < span class = "p" > ,< / span >
< span class = "s1" > ' AdaptiveMaxPool2d' < / span > < span class = "p" > :< / span > < span class = "n" > hook_adapmaxpool2d< / span > < span class = "p" > ,< / span >
< span class = "s1" > ' AdaptiveMaxPool3d' < / span > < span class = "p" > :< / span > < span class = "n" > hook_adapmaxpool3d< / span > < span class = "p" > ,< / span >
< span class = "s1" > ' AdaptiveAvgPool1d' < / span > < span class = "p" > :< / span > < span class = "n" > hook_adapavgpool1d< / span > < span class = "p" > ,< / span >
< span class = "s1" > ' AdaptiveAvgPool2d' < / span > < span class = "p" > :< / span > < span class = "n" > hook_adapavgpool2d< / span > < span class = "p" > ,< / span >
< span class = "s1" > ' AdaptiveAvgPool3d' < / span > < span class = "p" > :< / span > < span class = "n" > hook_adapavgpool3d< / span > < span class = "p" > ,< / span >
< span class = "c1" > # Non-linear activations< / span >
< span class = "s1" > ' ReLU' < / span > < span class = "p" > :< / span > < span class = "n" > hook_relu< / span > < span class = "p" > ,< / span >
< span class = "s1" > ' ReLU6' < / span > < span class = "p" > :< / span > < span class = "n" > hook_relu< / span > < span class = "p" > ,< / span >
< span class = "s1" > ' LeakyReLU' < / span > < span class = "p" > :< / span > < span class = "n" > hook_leakyrelu< / span > < span class = "p" > ,< / span >
< span class = "c1" > # Normalization< / span >
< span class = "s1" > ' BatchNorm1d' < / span > < span class = "p" > :< / span > < span class = "n" > hook_batchnormNd< / span > < span class = "p" > ,< / span >
< span class = "s1" > ' BatchNorm2d' < / span > < span class = "p" > :< / span > < span class = "n" > hook_batchnormNd< / span > < span class = "p" > ,< / span >
< span class = "s1" > ' BatchNorm3d' < / span > < span class = "p" > :< / span > < span class = "n" > hook_batchnormNd< / span > < span class = "p" > ,< / span >
< span class = "s1" > ' InstanceNorm1d' < / span > < span class = "p" > :< / span > < span class = "n" > hook_instancenormNd< / span > < span class = "p" > ,< / span >
< span class = "s1" > ' InstanceNorm2d' < / span > < span class = "p" > :< / span > < span class = "n" > hook_instancenormNd< / span > < span class = "p" > ,< / span >
< span class = "s1" > ' InstanceNorm3d' < / span > < span class = "p" > :< / span > < span class = "n" > hook_instancenormNd< / span > < span class = "p" > ,< / span >
< span class = "s1" > ' GroupNorm' < / span > < span class = "p" > :< / span > < span class = "n" > hook_groupnorm< / span > < span class = "p" > ,< / span >
< span class = "s1" > ' LayerNorm' < / span > < span class = "p" > :< / span > < span class = "n" > hook_layernorm< / span > < span class = "p" > ,< / span >
< span class = "c1" > # Linear< / span >
< span class = "s1" > ' Linear' < / span > < span class = "p" > :< / span > < span class = "n" > hook_linear< / span > < span class = "p" > ,< / span >
< span class = "p" > }< / span >
< span class = "n" > __conv_linear_flops_counter< / span > < span class = "o" > =< / span > < span class = "p" > {< / span >
< span class = "c1" > # Convolution< / span >
< span class = "s1" > ' Conv1d' < / span > < span class = "p" > :< / span > < span class = "n" > hook_convNd< / span > < span class = "p" > ,< / span >
< span class = "s1" > ' Conv2d' < / span > < span class = "p" > :< / span > < span class = "n" > hook_convNd< / span > < span class = "p" > ,< / span >
< span class = "s1" > ' Conv3d' < / span > < span class = "p" > :< / span > < span class = "n" > hook_convNd< / span > < span class = "p" > ,< / span >
< span class = "c1" > # Linear< / span >
< span class = "s1" > ' Linear' < / span > < span class = "p" > :< / span > < span class = "n" > hook_linear< / span > < span class = "p" > ,< / span >
< span class = "p" > }< / span >
< span class = "k" > def< / span > < span class = "nf" > _get_flops_counter< / span > < span class = "p" > (< / span > < span class = "n" > only_conv_linear< / span > < span class = "p" > ):< / span >
< span class = "k" > if< / span > < span class = "n" > only_conv_linear< / span > < span class = "p" > :< / span >
< span class = "k" > return< / span > < span class = "n" > __conv_linear_flops_counter< / span >
< span class = "k" > return< / span > < span class = "n" > __generic_flops_counter< / span >
< div class = "viewcode-block" id = "compute_model_complexity" > < a class = "viewcode-back" href = "../../../pkg/utils.html#torchreid.utils.model_complexity.compute_model_complexity" > [docs]< / a > < span class = "k" > def< / span > < span class = "nf" > compute_model_complexity< / span > < span class = "p" > (< / span > < span class = "n" > model< / span > < span class = "p" > ,< / span > < span class = "n" > input_size< / span > < span class = "p" > ,< / span > < span class = "n" > verbose< / span > < span class = "o" > =< / span > < span class = "kc" > False< / span > < span class = "p" > ,< / span > < span class = "n" > only_conv_linear< / span > < span class = "o" > =< / span > < span class = "kc" > True< / span > < span class = "p" > ):< / span >
< span class = "sd" > " " " Returns number of parameters and FLOPs.< / span >
< span class = "sd" > .. note::< / span >
2019-05-23 05:23:33 +08:00
< span class = "sd" > (1) this function only provides an estimate of the theoretical time complexity< / span >
< span class = "sd" > rather than the actual running time which depends on implementations and hardware,< / span >
< span class = "sd" > and (2) the FLOPs is only counted for layers that are used at test time. This means< / span >
< span class = "sd" > that redundant layers such as person ID classification layer will be ignored as it< / span >
< span class = "sd" > is discarded when doing feature extraction. Note that the inference graph depends on< / span >
< span class = "sd" > how you construct the computations in ``forward()``.< / span >
2019-05-22 23:18:39 +08:00
< span class = "sd" > Args:< / span >
< span class = "sd" > model (nn.Module): network model.< / span >
< span class = "sd" > input_size (tuple): input size, e.g. (1, 3, 256, 128).< / span >
< span class = "sd" > verbose (bool, optional): shows detailed complexity of< / span >
< span class = "sd" > each module. Default is False.< / span >
< span class = "sd" > only_conv_linear (bool, optional): only considers convolution< / span >
< span class = "sd" > and linear layers when counting flops. Default is True.< / span >
< span class = "sd" > If set to False, flops of all layers will be counted.< / span >
< span class = "sd" > Examples::< / span >
< span class = "sd" > > > > from torchreid import models, utils< / span >
< span class = "sd" > > > > model = models.build_model(name=' resnet50' , num_classes=1000)< / span >
< span class = "sd" > > > > num_params, flops = utils.compute_model_complexity(model, (1, 3, 256, 128), verbose=True)< / span >
< span class = "sd" > " " " < / span >
< span class = "n" > registered_handles< / span > < span class = "o" > =< / span > < span class = "p" > []< / span >
< span class = "n" > layer_list< / span > < span class = "o" > =< / span > < span class = "p" > []< / span >
< span class = "n" > layer< / span > < span class = "o" > =< / span > < span class = "n" > namedtuple< / span > < span class = "p" > (< / span > < span class = "s1" > ' layer' < / span > < span class = "p" > ,< / span > < span class = "p" > [< / span > < span class = "s1" > ' class_name' < / span > < span class = "p" > ,< / span > < span class = "s1" > ' params' < / span > < span class = "p" > ,< / span > < span class = "s1" > ' flops' < / span > < span class = "p" > ])< / span >
< span class = "k" > def< / span > < span class = "nf" > _add_hooks< / span > < span class = "p" > (< / span > < span class = "n" > m< / span > < span class = "p" > ):< / span >
< span class = "k" > def< / span > < span class = "nf" > _has_submodule< / span > < span class = "p" > (< / span > < span class = "n" > m< / span > < span class = "p" > ):< / span >
< span class = "k" > return< / span > < span class = "nb" > len< / span > < span class = "p" > (< / span > < span class = "nb" > list< / span > < span class = "p" > (< / span > < span class = "n" > m< / span > < span class = "o" > .< / span > < span class = "n" > children< / span > < span class = "p" > ()))< / span > < span class = "o" > > < / span > < span class = "mi" > 0< / span >
< span class = "k" > def< / span > < span class = "nf" > _hook< / span > < span class = "p" > (< / span > < span class = "n" > m< / span > < span class = "p" > ,< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "n" > y< / span > < span class = "p" > ):< / span >
< span class = "n" > params< / span > < span class = "o" > =< / span > < span class = "nb" > sum< / span > < span class = "p" > (< / span > < span class = "n" > p< / span > < span class = "o" > .< / span > < span class = "n" > numel< / span > < span class = "p" > ()< / span > < span class = "k" > for< / span > < span class = "n" > p< / span > < span class = "ow" > in< / span > < span class = "n" > m< / span > < span class = "o" > .< / span > < span class = "n" > parameters< / span > < span class = "p" > ())< / span >
< span class = "n" > class_name< / span > < span class = "o" > =< / span > < span class = "nb" > str< / span > < span class = "p" > (< / span > < span class = "n" > m< / span > < span class = "o" > .< / span > < span class = "vm" > __class__< / span > < span class = "o" > .< / span > < span class = "vm" > __name__< / span > < span class = "p" > )< / span >
< span class = "n" > flops_counter< / span > < span class = "o" > =< / span > < span class = "n" > _get_flops_counter< / span > < span class = "p" > (< / span > < span class = "n" > only_conv_linear< / span > < span class = "p" > )< / span >
< span class = "k" > if< / span > < span class = "n" > class_name< / span > < span class = "ow" > in< / span > < span class = "n" > flops_counter< / span > < span class = "p" > :< / span >
< span class = "n" > flops< / span > < span class = "o" > =< / span > < span class = "n" > flops_counter< / span > < span class = "p" > [< / span > < span class = "n" > class_name< / span > < span class = "p" > ](< / span > < span class = "n" > m< / span > < span class = "p" > ,< / span > < span class = "n" > x< / span > < span class = "p" > ,< / span > < span class = "n" > y< / span > < span class = "p" > )< / span >
< span class = "k" > else< / span > < span class = "p" > :< / span >
< span class = "n" > flops< / span > < span class = "o" > =< / span > < span class = "mi" > 0< / span >
< span class = "n" > layer_list< / span > < span class = "o" > .< / span > < span class = "n" > append< / span > < span class = "p" > (< / span >
< span class = "n" > layer< / span > < span class = "p" > (< / span >
< span class = "n" > class_name< / span > < span class = "o" > =< / span > < span class = "n" > class_name< / span > < span class = "p" > ,< / span >
< span class = "n" > params< / span > < span class = "o" > =< / span > < span class = "n" > params< / span > < span class = "p" > ,< / span >
< span class = "n" > flops< / span > < span class = "o" > =< / span > < span class = "n" > flops< / span >
< span class = "p" > )< / span >
< span class = "p" > )< / span >
< span class = "c1" > # only consider the very basic nn layer< / span >
< span class = "k" > if< / span > < span class = "n" > _has_submodule< / span > < span class = "p" > (< / span > < span class = "n" > m< / span > < span class = "p" > ):< / span >
< span class = "k" > return< / span >
< span class = "n" > handle< / span > < span class = "o" > =< / span > < span class = "n" > m< / span > < span class = "o" > .< / span > < span class = "n" > register_forward_hook< / span > < span class = "p" > (< / span > < span class = "n" > _hook< / span > < span class = "p" > )< / span >
< span class = "n" > registered_handles< / span > < span class = "o" > .< / span > < span class = "n" > append< / span > < span class = "p" > (< / span > < span class = "n" > handle< / span > < span class = "p" > )< / span >
< span class = "n" > default_train_mode< / span > < span class = "o" > =< / span > < span class = "n" > model< / span > < span class = "o" > .< / span > < span class = "n" > training< / span >
< span class = "n" > model< / span > < span class = "o" > .< / span > < span class = "n" > eval< / span > < span class = "p" > ()< / span > < span class = "o" > .< / span > < span class = "n" > apply< / span > < span class = "p" > (< / span > < span class = "n" > _add_hooks< / span > < span class = "p" > )< / span >
< span class = "nb" > input< / span > < span class = "o" > =< / span > < span class = "n" > torch< / span > < span class = "o" > .< / span > < span class = "n" > rand< / span > < span class = "p" > (< / span > < span class = "n" > input_size< / span > < span class = "p" > )< / span >
< span class = "k" > if< / span > < span class = "nb" > next< / span > < span class = "p" > (< / span > < span class = "n" > model< / span > < span class = "o" > .< / span > < span class = "n" > parameters< / span > < span class = "p" > ())< / span > < span class = "o" > .< / span > < span class = "n" > is_cuda< / span > < span class = "p" > :< / span >
< span class = "nb" > input< / span > < span class = "o" > =< / span > < span class = "nb" > input< / span > < span class = "o" > .< / span > < span class = "n" > cuda< / span > < span class = "p" > ()< / span >
< span class = "n" > model< / span > < span class = "p" > (< / span > < span class = "nb" > input< / span > < span class = "p" > )< / span > < span class = "c1" > # forward< / span >
< span class = "k" > for< / span > < span class = "n" > handle< / span > < span class = "ow" > in< / span > < span class = "n" > registered_handles< / span > < span class = "p" > :< / span >
< span class = "n" > handle< / span > < span class = "o" > .< / span > < span class = "n" > remove< / span > < span class = "p" > ()< / span >
< span class = "n" > model< / span > < span class = "o" > .< / span > < span class = "n" > train< / span > < span class = "p" > (< / span > < span class = "n" > default_train_mode< / span > < span class = "p" > )< / span >
< span class = "k" > if< / span > < span class = "n" > verbose< / span > < span class = "p" > :< / span >
< span class = "n" > per_module_params< / span > < span class = "o" > =< / span > < span class = "n" > defaultdict< / span > < span class = "p" > (< / span > < span class = "nb" > list< / span > < span class = "p" > )< / span >
< span class = "n" > per_module_flops< / span > < span class = "o" > =< / span > < span class = "n" > defaultdict< / span > < span class = "p" > (< / span > < span class = "nb" > list< / span > < span class = "p" > )< / span >
< span class = "n" > total_params< / span > < span class = "p" > ,< / span > < span class = "n" > total_flops< / span > < span class = "o" > =< / span > < span class = "mi" > 0< / span > < span class = "p" > ,< / span > < span class = "mi" > 0< / span >
< span class = "k" > for< / span > < span class = "n" > layer< / span > < span class = "ow" > in< / span > < span class = "n" > layer_list< / span > < span class = "p" > :< / span >
< span class = "n" > total_params< / span > < span class = "o" > +=< / span > < span class = "n" > layer< / span > < span class = "o" > .< / span > < span class = "n" > params< / span >
< span class = "n" > total_flops< / span > < span class = "o" > +=< / span > < span class = "n" > layer< / span > < span class = "o" > .< / span > < span class = "n" > flops< / span >
< span class = "k" > if< / span > < span class = "n" > verbose< / span > < span class = "p" > :< / span >
< span class = "n" > per_module_params< / span > < span class = "p" > [< / span > < span class = "n" > layer< / span > < span class = "o" > .< / span > < span class = "n" > class_name< / span > < span class = "p" > ]< / span > < span class = "o" > .< / span > < span class = "n" > append< / span > < span class = "p" > (< / span > < span class = "n" > layer< / span > < span class = "o" > .< / span > < span class = "n" > params< / span > < span class = "p" > )< / span >
< span class = "n" > per_module_flops< / span > < span class = "p" > [< / span > < span class = "n" > layer< / span > < span class = "o" > .< / span > < span class = "n" > class_name< / span > < span class = "p" > ]< / span > < span class = "o" > .< / span > < span class = "n" > append< / span > < span class = "p" > (< / span > < span class = "n" > layer< / span > < span class = "o" > .< / span > < span class = "n" > flops< / span > < span class = "p" > )< / span >
< span class = "k" > if< / span > < span class = "n" > verbose< / span > < span class = "p" > :< / span >
2019-05-23 05:10:02 +08:00
< span class = "n" > num_udscore< / span > < span class = "o" > =< / span > < span class = "mi" > 55< / span >
< span class = "nb" > print< / span > < span class = "p" > (< / span > < span class = "s1" > ' < / span > < span class = "si" > {}< / span > < span class = "s1" > ' < / span > < span class = "o" > .< / span > < span class = "n" > format< / span > < span class = "p" > (< / span > < span class = "s1" > ' -' < / span > < span class = "o" > *< / span > < span class = "n" > num_udscore< / span > < span class = "p" > ))< / span >
< span class = "nb" > print< / span > < span class = "p" > (< / span > < span class = "s1" > ' Model complexity with input size < / span > < span class = "si" > {}< / span > < span class = "s1" > ' < / span > < span class = "o" > .< / span > < span class = "n" > format< / span > < span class = "p" > (< / span > < span class = "n" > input_size< / span > < span class = "p" > ))< / span >
< span class = "nb" > print< / span > < span class = "p" > (< / span > < span class = "s1" > ' < / span > < span class = "si" > {}< / span > < span class = "s1" > ' < / span > < span class = "o" > .< / span > < span class = "n" > format< / span > < span class = "p" > (< / span > < span class = "s1" > ' -' < / span > < span class = "o" > *< / span > < span class = "n" > num_udscore< / span > < span class = "p" > ))< / span >
2019-05-22 23:18:39 +08:00
< span class = "k" > for< / span > < span class = "n" > class_name< / span > < span class = "ow" > in< / span > < span class = "n" > per_module_params< / span > < span class = "p" > :< / span >
< span class = "n" > params< / span > < span class = "o" > =< / span > < span class = "nb" > int< / span > < span class = "p" > (< / span > < span class = "n" > np< / span > < span class = "o" > .< / span > < span class = "n" > sum< / span > < span class = "p" > (< / span > < span class = "n" > per_module_params< / span > < span class = "p" > [< / span > < span class = "n" > class_name< / span > < span class = "p" > ]))< / span >
< span class = "n" > flops< / span > < span class = "o" > =< / span > < span class = "nb" > int< / span > < span class = "p" > (< / span > < span class = "n" > np< / span > < span class = "o" > .< / span > < span class = "n" > sum< / span > < span class = "p" > (< / span > < span class = "n" > per_module_flops< / span > < span class = "p" > [< / span > < span class = "n" > class_name< / span > < span class = "p" > ]))< / span >
< span class = "nb" > print< / span > < span class = "p" > (< / span > < span class = "s1" > ' < / span > < span class = "si" > {}< / span > < span class = "s1" > (params=< / span > < span class = "si" > {:,}< / span > < span class = "s1" > , flops=< / span > < span class = "si" > {:,}< / span > < span class = "s1" > )' < / span > < span class = "o" > .< / span > < span class = "n" > format< / span > < span class = "p" > (< / span > < span class = "n" > class_name< / span > < span class = "p" > ,< / span > < span class = "n" > params< / span > < span class = "p" > ,< / span > < span class = "n" > flops< / span > < span class = "p" > ))< / span >
2019-05-23 05:10:02 +08:00
< span class = "nb" > print< / span > < span class = "p" > (< / span > < span class = "s1" > ' < / span > < span class = "si" > {}< / span > < span class = "s1" > ' < / span > < span class = "o" > .< / span > < span class = "n" > format< / span > < span class = "p" > (< / span > < span class = "s1" > ' -' < / span > < span class = "o" > *< / span > < span class = "n" > num_udscore< / span > < span class = "p" > ))< / span >
2019-05-22 23:18:39 +08:00
< span class = "nb" > print< / span > < span class = "p" > (< / span > < span class = "s1" > ' Total (params=< / span > < span class = "si" > {:,}< / span > < span class = "s1" > , flops=< / span > < span class = "si" > {:,}< / span > < span class = "s1" > )' < / span > < span class = "o" > .< / span > < span class = "n" > format< / span > < span class = "p" > (< / span > < span class = "n" > total_params< / span > < span class = "p" > ,< / span > < span class = "n" > total_flops< / span > < span class = "p" > ))< / span >
2019-05-23 05:10:02 +08:00
< span class = "nb" > print< / span > < span class = "p" > (< / span > < span class = "s1" > ' < / span > < span class = "si" > {}< / span > < span class = "s1" > ' < / span > < span class = "o" > .< / span > < span class = "n" > format< / span > < span class = "p" > (< / span > < span class = "s1" > ' -' < / span > < span class = "o" > *< / span > < span class = "n" > num_udscore< / span > < span class = "p" > ))< / span >
2019-05-22 23:18:39 +08:00
< span class = "k" > return< / span > < span class = "n" > total_params< / span > < span class = "p" > ,< / span > < span class = "n" > total_flops< / span > < / div >
< / pre > < / div >
< / div >
< / div >
< footer >
< hr / >
< div role = "contentinfo" >
< p >
© Copyright 2019, Kaiyang Zhou
< / p >
< / div >
Built with < a href = "http://sphinx-doc.org/" > Sphinx< / a > using a < a href = "https://github.com/rtfd/sphinx_rtd_theme" > theme< / a > provided by < a href = "https://readthedocs.org" > Read the Docs< / a > .
< / footer >
< / div >
< / div >
< / section >
< / div >
< script type = "text/javascript" >
jQuery(function () {
SphinxRtdTheme.Navigation.enable(true);
});
< / script >
< / body >
< / html >