[Doc] Complete doc strings of estimators (#275)

* complete doc strings of estimators

* add Optional for params with default value
pull/225/head^2
Yang Gao 2022-09-01 15:17:28 +08:00 committed by GitHub
parent ece95a69ae
commit 1f1bcd181c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 93 additions and 12 deletions

View File

@ -216,6 +216,7 @@ def print_model_with_flops_params(model,
"""
def accumulate_params(self):
"""Accumulate params by recursion."""
if is_supported_instance(self):
return self.__params__
else:
@ -225,6 +226,7 @@ def print_model_with_flops_params(model,
return sum
def accumulate_flops(self):
"""Accumulate flops by recursion."""
if is_supported_instance(self):
return self.__flops__ / model.__batch_counter__
else:
@ -234,6 +236,7 @@ def print_model_with_flops_params(model,
return sum
def flops_repr(self):
"""A new extra_repr method of the input module."""
accumulated_num_params = self.accumulate_params()
accumulated_flops_cost = self.accumulate_flops()
flops_string = str(
@ -252,6 +255,7 @@ def print_model_with_flops_params(model,
])
def add_extra_repr(m):
"""Reload extra_repr method."""
m.accumulate_flops = accumulate_flops.__get__(m)
m.accumulate_params = accumulate_params.__get__(m)
flops_extra_repr = flops_repr.__get__(m)
@ -261,6 +265,7 @@ def print_model_with_flops_params(model,
assert m.extra_repr != m.original_extra_repr
def del_extra_repr(m):
"""Recover origin extra_repr method."""
if hasattr(m, 'original_extra_repr'):
m.extra_repr = m.original_extra_repr
del m.original_extra_repr
@ -281,6 +286,7 @@ def accumulate_sub_module_flops_params(model):
"""
def accumulate_params(module):
"""Accumulate params by recursion."""
if is_supported_instance(module):
return module.__params__
else:
@ -290,6 +296,7 @@ def accumulate_sub_module_flops_params(model):
return sum
def accumulate_flops(module):
"""Accumulate flops by recursion."""
if is_supported_instance(module):
return module.__flops__ / model.__batch_counter__
else:
@ -310,6 +317,7 @@ def get_model_parameters_number(model):
Args:
model (nn.module): The model for parameter number calculation.
Returns:
float: Parameter number of the model.
"""
@ -318,8 +326,10 @@ def get_model_parameters_number(model):
def add_flops_params_counting_methods(net_main_module):
# adding additional methods to the existing module object,
# this is done this way so that each function has access to self object
"""Add additional methods to the existing module object.
This is done this way so that each function has access to self object.
"""
net_main_module.start_flops_params_count = start_flops_params_count.__get__( # noqa: E501
net_main_module)
net_main_module.stop_flops_params_count = stop_flops_params_count.__get__(
@ -339,6 +349,7 @@ def compute_average_flops_params_cost(self):
A method to compute average FLOPs cost, which will be available after
`add_flops_params_counting_methods()` is called on a desired net object.
Returns:
float: Current mean flops consumption per image.
"""
@ -406,16 +417,18 @@ def reset_flops_params_count(self):
# ---- Internal functions
def empty_flops_params_counter_hook(module, input, output):
"""Empty flops and params variables of the module."""
module.__flops__ += 0
module.__params__ += 0
def add_batch_counter_variables_or_reset(module):
"""Add or reset the batch counter variable."""
module.__batch_counter__ = 0
def add_batch_counter_hook_function(module):
"""Register the batch counter hook for the module."""
if hasattr(module, '__batch_counter_handle__'):
return
@ -424,6 +437,7 @@ def add_batch_counter_hook_function(module):
def batch_counter_hook(module, input, output):
"""Add batch counter variable based on the input size."""
batch_size = 1
if len(input) > 0:
# Can have multiple inputs, getting the first one
@ -437,12 +451,14 @@ def batch_counter_hook(module, input, output):
def remove_batch_counter_hook_function(module):
"""Remove batch counter handle variable."""
if hasattr(module, '__batch_counter_handle__'):
module.__batch_counter_handle__.remove()
del module.__batch_counter_handle__
def add_flops_params_counter_variable_or_reset(module):
"""Add or reset flops and params variable of the module."""
if is_supported_instance(module):
if hasattr(module, '__flops__') or hasattr(module, '__params__'):
print('Warning: variables __flops__ or __params__ are already '
@ -453,16 +469,19 @@ def add_flops_params_counter_variable_or_reset(module):
def get_counter_type(module):
"""Get counter type of the module based on the module class name."""
return module.__class__.__name__ + 'Counter'
def is_supported_instance(module):
"""Judge whether the module is in TASK_UTILS registry or not."""
if get_counter_type(module) in TASK_UTILS._module_dict.keys():
return True
return False
def remove_flops_params_counter_hook_function(module):
"""Remove counter related variables after resource estimation."""
if hasattr(module, '__flops_params_handle__'):
module.__flops_params_handle__.remove()
del module.__flops_params_handle__

View File

@ -1,17 +1,31 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging
import time
from typing import Any, Dict
import torch
from mmengine.logging import print_log
def repeat_measure_inference_speed(model,
resource_args,
def repeat_measure_inference_speed(model: torch.nn.Module,
resource_args: Dict[str, Any],
max_iter: int = 100,
num_warmup: int = 5,
log_interval: int = 100,
repeat_num: int = 1) -> float:
"""Repeat speed measure for multi-times to get more precise results."""
"""Repeat speed measure for multi-times to get more precise results.
Args:
model (torch.nn.Module): The measured model.
resource_args (Dict[str, float]): resources information.
max_iter (Optional[int]): Max iteration num for inference speed test.
num_warmup (Optional[int]): Iteration num for warm-up stage.
log_interval (Optional[int]): Interval num for logging the results.
repeat_num (Optional[int]): Num of times to repeat the measurement.
Returns:
fps (float): The measured inference speed of the model.
"""
assert repeat_num >= 1
fps_list = []
@ -19,7 +33,7 @@ def repeat_measure_inference_speed(model,
for _ in range(repeat_num):
fps_list.append(
measure_inference_speed(model, resource_args, max_iter,
measure_inference_speed(model, resource_args, max_iter, num_warmup,
log_interval))
if repeat_num > 1:
@ -39,10 +53,24 @@ def repeat_measure_inference_speed(model,
return latency
def measure_inference_speed(model, resource_args, max_iter: int,
log_interval: int) -> float:
def measure_inference_speed(model: torch.nn.Module,
resource_args: Dict[str, Any],
max_iter: int = 100,
num_warmup: int = 5,
log_interval: int = 100) -> float:
"""Measure inference speed on GPU devices.
Args:
model (torch.nn.Module): The measured model.
resource_args (Dict[str, float]): resources information.
max_iter (Optional[int]): Max iteration num for inference speed test.
num_warmup (Optional[int]): Iteration num for warm-up stage.
log_interval (Optional[int]): Interval num for logging the results.
Returns:
fps (float): The measured inference speed of the model.
"""
# the first several iterations may be very slow so skip them
num_warmup = 5
pure_inf_time = 0.0
fps = 0.0
data = dict()
@ -50,7 +78,7 @@ def measure_inference_speed(model, resource_args, max_iter: int,
device = 'cuda'
else:
raise NotImplementedError('To use cpu to test latency not supported.')
# benchmark with 100 image and take the average
# benchmark with {max_iter} image and take the average
for i in range(1, max_iter):
if device == 'cuda':
data = torch.rand(resource_args['input_shape']).cuda()

View File

@ -10,6 +10,7 @@ class ReLUCounter(BaseCounter):
@staticmethod
def add_count_hook(module, input, output):
"""Calculate FLOPs and params based on the size of input & output."""
active_elements_count = output.numel()
module.__flops__ += int(active_elements_count)
module.__params__ += get_model_parameters_number(module)
@ -17,19 +18,23 @@ class ReLUCounter(BaseCounter):
@TASK_UTILS.register_module()
class PReLUCounter(ReLUCounter):
"""FLOPs/params counter for PReLU function."""
pass
@TASK_UTILS.register_module()
class ELUCounter(ReLUCounter):
"""FLOPs/params counter for ELU function."""
pass
@TASK_UTILS.register_module()
class LeakyReLUCounter(ReLUCounter):
"""FLOPs/params counter for LeakyReLU function."""
pass
@TASK_UTILS.register_module()
class ReLU6Counter(ReLUCounter):
"""FLOPs/params counter for ReLU6 function."""
pass

View File

@ -10,6 +10,7 @@ class ConvCounter(BaseCounter):
@staticmethod
def add_count_hook(module, input, output):
"""Calculate FLOPs and params based on the size of input & output."""
# Can have multiple inputs, getting the first one
input = input[0]
@ -44,14 +45,17 @@ class ConvCounter(BaseCounter):
@TASK_UTILS.register_module()
class Conv1dCounter(ConvCounter):
"""FLOPs/params counter for Conv1d module."""
pass
@TASK_UTILS.register_module()
class Conv2dCounter(ConvCounter):
"""FLOPs/params counter for Conv2d module."""
pass
@TASK_UTILS.register_module()
class Conv3dCounter(ConvCounter):
"""FLOPs/params counter for Conv3d module."""
pass

View File

@ -6,10 +6,11 @@ from .base_counter import BaseCounter
@TASK_UTILS.register_module()
class ConvTranspose2dCounter(BaseCounter):
"""FLOPs/params counter for Decov module series."""
"""FLOPs/params counter for Deconv module series."""
@staticmethod
def add_count_hook(module, input, output):
"""Compute FLOPs and params based on the size of input & output."""
# Can have multiple inputs, getting the first one
input = input[0]

View File

@ -12,6 +12,7 @@ class LinearCounter(BaseCounter):
@staticmethod
def add_count_hook(module, input, output):
"""Calculate FLOPs and params based on the size of input & output."""
input = input[0]
output_last_dim = output.shape[
-1] # pytorch checks dimensions, so here we don't care much

View File

@ -11,6 +11,7 @@ class BNCounter(BaseCounter):
@staticmethod
def add_count_hook(module, input, output):
"""Calculate FLOPs and params based on the size of input & output."""
input = input[0]
batch_flops = np.prod(input.shape)
if getattr(module, 'affine', False):
@ -21,39 +22,47 @@ class BNCounter(BaseCounter):
@TASK_UTILS.register_module()
class BatchNorm1dCounter(BNCounter):
"""FLOPs/params counter for BatchNorm1d module."""
pass
@TASK_UTILS.register_module()
class BatchNorm2dCounter(BNCounter):
"""FLOPs/params counter for BatchNorm2d module."""
pass
@TASK_UTILS.register_module()
class BatchNorm3dCounter(BNCounter):
"""FLOPs/params counter for BatchNorm3d module."""
pass
@TASK_UTILS.register_module()
class InstanceNorm1dCounter(BNCounter):
"""FLOPs/params counter for InstanceNorm1d module."""
pass
@TASK_UTILS.register_module()
class InstanceNorm2dCounter(BNCounter):
"""FLOPs/params counter for InstanceNorm2d module."""
pass
@TASK_UTILS.register_module()
class InstanceNorm3dCounter(BNCounter):
"""FLOPs/params counter for InstanceNorm3d module."""
pass
@TASK_UTILS.register_module()
class LayerNormCounter(BNCounter):
"""FLOPs/params counter for LayerNorm module."""
pass
@TASK_UTILS.register_module()
class GroupNormCounter(BNCounter):
"""FLOPs/params counter for GroupNorm module."""
pass

View File

@ -11,6 +11,7 @@ class PoolCounter(BaseCounter):
@staticmethod
def add_count_hook(module, input, output):
"""Calculate FLOPs and params based on the size of input & output."""
input = input[0]
module.__flops__ += int(np.prod(input.shape))
module.__params__ += get_model_parameters_number(module)
@ -18,59 +19,71 @@ class PoolCounter(BaseCounter):
@TASK_UTILS.register_module()
class MaxPool1dCounter(PoolCounter):
"""FLOPs/params counter for MaxPool1d module."""
pass
@TASK_UTILS.register_module()
class MaxPool2dCounter(PoolCounter):
"""FLOPs/params counter for MaxPool2d module."""
pass
@TASK_UTILS.register_module()
class MaxPool3dCounter(PoolCounter):
"""FLOPs/params counter for MaxPool3d module."""
pass
@TASK_UTILS.register_module()
class AvgPool1dCounter(PoolCounter):
"""FLOPs/params counter for AvgPool1d module."""
pass
@TASK_UTILS.register_module()
class AvgPool2dCounter(PoolCounter):
"""FLOPs/params counter for AvgPool2d module."""
pass
@TASK_UTILS.register_module()
class AvgPool3dCounter(PoolCounter):
"""FLOPs/params counter for AvgPool3d module."""
pass
@TASK_UTILS.register_module()
class AdaptiveMaxPool1dCounter(PoolCounter):
"""FLOPs/params counter for AdaptiveMaxPool1d module."""
pass
@TASK_UTILS.register_module()
class AdaptiveMaxPool2dCounter(PoolCounter):
"""FLOPs/params counter for AdaptiveMaxPool2d module."""
pass
@TASK_UTILS.register_module()
class AdaptiveMaxPool3dCounter(PoolCounter):
"""FLOPs/params counter for AdaptiveMaxPool3d module."""
pass
@TASK_UTILS.register_module()
class AdaptiveAvgPool1dCounter(PoolCounter):
"""FLOPs/params counter for AdaptiveAvgPool1d module."""
pass
@TASK_UTILS.register_module()
class AdaptiveAvgPool2dCounter(PoolCounter):
"""FLOPs/params counter for AdaptiveAvgPool2d module."""
pass
@TASK_UTILS.register_module()
class AdaptiveAvgPool3dCounter(PoolCounter):
"""FLOPs/params counter for AdaptiveAvgPool3d module."""
pass

View File

@ -10,6 +10,7 @@ class UpsampleCounter(BaseCounter):
@staticmethod
def add_count_hook(module, input, output):
"""Calculate FLOPs and params based on the size of input & output."""
output_size = output[0]
batch_size = output_size.shape[0]
output_elements_count = batch_size