[Doc] Complete doc strings of estimators (#275)
* complete doc strings of estimators * add Optional for params with default valuepull/225/head^2
parent
ece95a69ae
commit
1f1bcd181c
|
@ -216,6 +216,7 @@ def print_model_with_flops_params(model,
|
|||
"""
|
||||
|
||||
def accumulate_params(self):
|
||||
"""Accumulate params by recursion."""
|
||||
if is_supported_instance(self):
|
||||
return self.__params__
|
||||
else:
|
||||
|
@ -225,6 +226,7 @@ def print_model_with_flops_params(model,
|
|||
return sum
|
||||
|
||||
def accumulate_flops(self):
|
||||
"""Accumulate flops by recursion."""
|
||||
if is_supported_instance(self):
|
||||
return self.__flops__ / model.__batch_counter__
|
||||
else:
|
||||
|
@ -234,6 +236,7 @@ def print_model_with_flops_params(model,
|
|||
return sum
|
||||
|
||||
def flops_repr(self):
|
||||
"""A new extra_repr method of the input module."""
|
||||
accumulated_num_params = self.accumulate_params()
|
||||
accumulated_flops_cost = self.accumulate_flops()
|
||||
flops_string = str(
|
||||
|
@ -252,6 +255,7 @@ def print_model_with_flops_params(model,
|
|||
])
|
||||
|
||||
def add_extra_repr(m):
|
||||
"""Reload extra_repr method."""
|
||||
m.accumulate_flops = accumulate_flops.__get__(m)
|
||||
m.accumulate_params = accumulate_params.__get__(m)
|
||||
flops_extra_repr = flops_repr.__get__(m)
|
||||
|
@ -261,6 +265,7 @@ def print_model_with_flops_params(model,
|
|||
assert m.extra_repr != m.original_extra_repr
|
||||
|
||||
def del_extra_repr(m):
|
||||
"""Recover origin extra_repr method."""
|
||||
if hasattr(m, 'original_extra_repr'):
|
||||
m.extra_repr = m.original_extra_repr
|
||||
del m.original_extra_repr
|
||||
|
@ -281,6 +286,7 @@ def accumulate_sub_module_flops_params(model):
|
|||
"""
|
||||
|
||||
def accumulate_params(module):
|
||||
"""Accumulate params by recursion."""
|
||||
if is_supported_instance(module):
|
||||
return module.__params__
|
||||
else:
|
||||
|
@ -290,6 +296,7 @@ def accumulate_sub_module_flops_params(model):
|
|||
return sum
|
||||
|
||||
def accumulate_flops(module):
|
||||
"""Accumulate flops by recursion."""
|
||||
if is_supported_instance(module):
|
||||
return module.__flops__ / model.__batch_counter__
|
||||
else:
|
||||
|
@ -310,6 +317,7 @@ def get_model_parameters_number(model):
|
|||
|
||||
Args:
|
||||
model (nn.module): The model for parameter number calculation.
|
||||
|
||||
Returns:
|
||||
float: Parameter number of the model.
|
||||
"""
|
||||
|
@ -318,8 +326,10 @@ def get_model_parameters_number(model):
|
|||
|
||||
|
||||
def add_flops_params_counting_methods(net_main_module):
|
||||
# adding additional methods to the existing module object,
|
||||
# this is done this way so that each function has access to self object
|
||||
"""Add additional methods to the existing module object.
|
||||
|
||||
This is done this way so that each function has access to self object.
|
||||
"""
|
||||
net_main_module.start_flops_params_count = start_flops_params_count.__get__( # noqa: E501
|
||||
net_main_module)
|
||||
net_main_module.stop_flops_params_count = stop_flops_params_count.__get__(
|
||||
|
@ -339,6 +349,7 @@ def compute_average_flops_params_cost(self):
|
|||
|
||||
A method to compute average FLOPs cost, which will be available after
|
||||
`add_flops_params_counting_methods()` is called on a desired net object.
|
||||
|
||||
Returns:
|
||||
float: Current mean flops consumption per image.
|
||||
"""
|
||||
|
@ -406,16 +417,18 @@ def reset_flops_params_count(self):
|
|||
|
||||
# ---- Internal functions
|
||||
def empty_flops_params_counter_hook(module, input, output):
|
||||
"""Empty flops and params variables of the module."""
|
||||
module.__flops__ += 0
|
||||
module.__params__ += 0
|
||||
|
||||
|
||||
def add_batch_counter_variables_or_reset(module):
|
||||
|
||||
"""Add or reset the batch counter variable."""
|
||||
module.__batch_counter__ = 0
|
||||
|
||||
|
||||
def add_batch_counter_hook_function(module):
|
||||
"""Register the batch counter hook for the module."""
|
||||
if hasattr(module, '__batch_counter_handle__'):
|
||||
return
|
||||
|
||||
|
@ -424,6 +437,7 @@ def add_batch_counter_hook_function(module):
|
|||
|
||||
|
||||
def batch_counter_hook(module, input, output):
|
||||
"""Add batch counter variable based on the input size."""
|
||||
batch_size = 1
|
||||
if len(input) > 0:
|
||||
# Can have multiple inputs, getting the first one
|
||||
|
@ -437,12 +451,14 @@ def batch_counter_hook(module, input, output):
|
|||
|
||||
|
||||
def remove_batch_counter_hook_function(module):
|
||||
"""Remove batch counter handle variable."""
|
||||
if hasattr(module, '__batch_counter_handle__'):
|
||||
module.__batch_counter_handle__.remove()
|
||||
del module.__batch_counter_handle__
|
||||
|
||||
|
||||
def add_flops_params_counter_variable_or_reset(module):
|
||||
"""Add or reset flops and params variable of the module."""
|
||||
if is_supported_instance(module):
|
||||
if hasattr(module, '__flops__') or hasattr(module, '__params__'):
|
||||
print('Warning: variables __flops__ or __params__ are already '
|
||||
|
@ -453,16 +469,19 @@ def add_flops_params_counter_variable_or_reset(module):
|
|||
|
||||
|
||||
def get_counter_type(module):
|
||||
"""Get counter type of the module based on the module class name."""
|
||||
return module.__class__.__name__ + 'Counter'
|
||||
|
||||
|
||||
def is_supported_instance(module):
|
||||
"""Judge whether the module is in TASK_UTILS registry or not."""
|
||||
if get_counter_type(module) in TASK_UTILS._module_dict.keys():
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def remove_flops_params_counter_hook_function(module):
|
||||
"""Remove counter related variables after resource estimation."""
|
||||
if hasattr(module, '__flops_params_handle__'):
|
||||
module.__flops_params_handle__.remove()
|
||||
del module.__flops_params_handle__
|
||||
|
|
|
@ -1,17 +1,31 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Dict
|
||||
|
||||
import torch
|
||||
from mmengine.logging import print_log
|
||||
|
||||
|
||||
def repeat_measure_inference_speed(model,
|
||||
resource_args,
|
||||
def repeat_measure_inference_speed(model: torch.nn.Module,
|
||||
resource_args: Dict[str, Any],
|
||||
max_iter: int = 100,
|
||||
num_warmup: int = 5,
|
||||
log_interval: int = 100,
|
||||
repeat_num: int = 1) -> float:
|
||||
"""Repeat speed measure for multi-times to get more precise results."""
|
||||
"""Repeat speed measure for multi-times to get more precise results.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): The measured model.
|
||||
resource_args (Dict[str, float]): resources information.
|
||||
max_iter (Optional[int]): Max iteration num for inference speed test.
|
||||
num_warmup (Optional[int]): Iteration num for warm-up stage.
|
||||
log_interval (Optional[int]): Interval num for logging the results.
|
||||
repeat_num (Optional[int]): Num of times to repeat the measurement.
|
||||
|
||||
Returns:
|
||||
fps (float): The measured inference speed of the model.
|
||||
"""
|
||||
assert repeat_num >= 1
|
||||
|
||||
fps_list = []
|
||||
|
@ -19,7 +33,7 @@ def repeat_measure_inference_speed(model,
|
|||
for _ in range(repeat_num):
|
||||
|
||||
fps_list.append(
|
||||
measure_inference_speed(model, resource_args, max_iter,
|
||||
measure_inference_speed(model, resource_args, max_iter, num_warmup,
|
||||
log_interval))
|
||||
|
||||
if repeat_num > 1:
|
||||
|
@ -39,10 +53,24 @@ def repeat_measure_inference_speed(model,
|
|||
return latency
|
||||
|
||||
|
||||
def measure_inference_speed(model, resource_args, max_iter: int,
|
||||
log_interval: int) -> float:
|
||||
def measure_inference_speed(model: torch.nn.Module,
|
||||
resource_args: Dict[str, Any],
|
||||
max_iter: int = 100,
|
||||
num_warmup: int = 5,
|
||||
log_interval: int = 100) -> float:
|
||||
"""Measure inference speed on GPU devices.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): The measured model.
|
||||
resource_args (Dict[str, float]): resources information.
|
||||
max_iter (Optional[int]): Max iteration num for inference speed test.
|
||||
num_warmup (Optional[int]): Iteration num for warm-up stage.
|
||||
log_interval (Optional[int]): Interval num for logging the results.
|
||||
|
||||
Returns:
|
||||
fps (float): The measured inference speed of the model.
|
||||
"""
|
||||
# the first several iterations may be very slow so skip them
|
||||
num_warmup = 5
|
||||
pure_inf_time = 0.0
|
||||
fps = 0.0
|
||||
data = dict()
|
||||
|
@ -50,7 +78,7 @@ def measure_inference_speed(model, resource_args, max_iter: int,
|
|||
device = 'cuda'
|
||||
else:
|
||||
raise NotImplementedError('To use cpu to test latency not supported.')
|
||||
# benchmark with 100 image and take the average
|
||||
# benchmark with {max_iter} image and take the average
|
||||
for i in range(1, max_iter):
|
||||
if device == 'cuda':
|
||||
data = torch.rand(resource_args['input_shape']).cuda()
|
||||
|
|
|
@ -10,6 +10,7 @@ class ReLUCounter(BaseCounter):
|
|||
|
||||
@staticmethod
|
||||
def add_count_hook(module, input, output):
|
||||
"""Calculate FLOPs and params based on the size of input & output."""
|
||||
active_elements_count = output.numel()
|
||||
module.__flops__ += int(active_elements_count)
|
||||
module.__params__ += get_model_parameters_number(module)
|
||||
|
@ -17,19 +18,23 @@ class ReLUCounter(BaseCounter):
|
|||
|
||||
@TASK_UTILS.register_module()
|
||||
class PReLUCounter(ReLUCounter):
|
||||
"""FLOPs/params counter for PReLU function."""
|
||||
pass
|
||||
|
||||
|
||||
@TASK_UTILS.register_module()
|
||||
class ELUCounter(ReLUCounter):
|
||||
"""FLOPs/params counter for ELU function."""
|
||||
pass
|
||||
|
||||
|
||||
@TASK_UTILS.register_module()
|
||||
class LeakyReLUCounter(ReLUCounter):
|
||||
"""FLOPs/params counter for LeakyReLU function."""
|
||||
pass
|
||||
|
||||
|
||||
@TASK_UTILS.register_module()
|
||||
class ReLU6Counter(ReLUCounter):
|
||||
"""FLOPs/params counter for ReLU6 function."""
|
||||
pass
|
||||
|
|
|
@ -10,6 +10,7 @@ class ConvCounter(BaseCounter):
|
|||
|
||||
@staticmethod
|
||||
def add_count_hook(module, input, output):
|
||||
"""Calculate FLOPs and params based on the size of input & output."""
|
||||
# Can have multiple inputs, getting the first one
|
||||
input = input[0]
|
||||
|
||||
|
@ -44,14 +45,17 @@ class ConvCounter(BaseCounter):
|
|||
|
||||
@TASK_UTILS.register_module()
|
||||
class Conv1dCounter(ConvCounter):
|
||||
"""FLOPs/params counter for Conv1d module."""
|
||||
pass
|
||||
|
||||
|
||||
@TASK_UTILS.register_module()
|
||||
class Conv2dCounter(ConvCounter):
|
||||
"""FLOPs/params counter for Conv2d module."""
|
||||
pass
|
||||
|
||||
|
||||
@TASK_UTILS.register_module()
|
||||
class Conv3dCounter(ConvCounter):
|
||||
"""FLOPs/params counter for Conv3d module."""
|
||||
pass
|
||||
|
|
|
@ -6,10 +6,11 @@ from .base_counter import BaseCounter
|
|||
|
||||
@TASK_UTILS.register_module()
|
||||
class ConvTranspose2dCounter(BaseCounter):
|
||||
"""FLOPs/params counter for Decov module series."""
|
||||
"""FLOPs/params counter for Deconv module series."""
|
||||
|
||||
@staticmethod
|
||||
def add_count_hook(module, input, output):
|
||||
"""Compute FLOPs and params based on the size of input & output."""
|
||||
# Can have multiple inputs, getting the first one
|
||||
input = input[0]
|
||||
|
||||
|
|
|
@ -12,6 +12,7 @@ class LinearCounter(BaseCounter):
|
|||
|
||||
@staticmethod
|
||||
def add_count_hook(module, input, output):
|
||||
"""Calculate FLOPs and params based on the size of input & output."""
|
||||
input = input[0]
|
||||
output_last_dim = output.shape[
|
||||
-1] # pytorch checks dimensions, so here we don't care much
|
||||
|
|
|
@ -11,6 +11,7 @@ class BNCounter(BaseCounter):
|
|||
|
||||
@staticmethod
|
||||
def add_count_hook(module, input, output):
|
||||
"""Calculate FLOPs and params based on the size of input & output."""
|
||||
input = input[0]
|
||||
batch_flops = np.prod(input.shape)
|
||||
if getattr(module, 'affine', False):
|
||||
|
@ -21,39 +22,47 @@ class BNCounter(BaseCounter):
|
|||
|
||||
@TASK_UTILS.register_module()
|
||||
class BatchNorm1dCounter(BNCounter):
|
||||
"""FLOPs/params counter for BatchNorm1d module."""
|
||||
pass
|
||||
|
||||
|
||||
@TASK_UTILS.register_module()
|
||||
class BatchNorm2dCounter(BNCounter):
|
||||
"""FLOPs/params counter for BatchNorm2d module."""
|
||||
pass
|
||||
|
||||
|
||||
@TASK_UTILS.register_module()
|
||||
class BatchNorm3dCounter(BNCounter):
|
||||
"""FLOPs/params counter for BatchNorm3d module."""
|
||||
pass
|
||||
|
||||
|
||||
@TASK_UTILS.register_module()
|
||||
class InstanceNorm1dCounter(BNCounter):
|
||||
"""FLOPs/params counter for InstanceNorm1d module."""
|
||||
pass
|
||||
|
||||
|
||||
@TASK_UTILS.register_module()
|
||||
class InstanceNorm2dCounter(BNCounter):
|
||||
"""FLOPs/params counter for InstanceNorm2d module."""
|
||||
pass
|
||||
|
||||
|
||||
@TASK_UTILS.register_module()
|
||||
class InstanceNorm3dCounter(BNCounter):
|
||||
"""FLOPs/params counter for InstanceNorm3d module."""
|
||||
pass
|
||||
|
||||
|
||||
@TASK_UTILS.register_module()
|
||||
class LayerNormCounter(BNCounter):
|
||||
"""FLOPs/params counter for LayerNorm module."""
|
||||
pass
|
||||
|
||||
|
||||
@TASK_UTILS.register_module()
|
||||
class GroupNormCounter(BNCounter):
|
||||
"""FLOPs/params counter for GroupNorm module."""
|
||||
pass
|
||||
|
|
|
@ -11,6 +11,7 @@ class PoolCounter(BaseCounter):
|
|||
|
||||
@staticmethod
|
||||
def add_count_hook(module, input, output):
|
||||
"""Calculate FLOPs and params based on the size of input & output."""
|
||||
input = input[0]
|
||||
module.__flops__ += int(np.prod(input.shape))
|
||||
module.__params__ += get_model_parameters_number(module)
|
||||
|
@ -18,59 +19,71 @@ class PoolCounter(BaseCounter):
|
|||
|
||||
@TASK_UTILS.register_module()
|
||||
class MaxPool1dCounter(PoolCounter):
|
||||
"""FLOPs/params counter for MaxPool1d module."""
|
||||
pass
|
||||
|
||||
|
||||
@TASK_UTILS.register_module()
|
||||
class MaxPool2dCounter(PoolCounter):
|
||||
"""FLOPs/params counter for MaxPool2d module."""
|
||||
pass
|
||||
|
||||
|
||||
@TASK_UTILS.register_module()
|
||||
class MaxPool3dCounter(PoolCounter):
|
||||
"""FLOPs/params counter for MaxPool3d module."""
|
||||
pass
|
||||
|
||||
|
||||
@TASK_UTILS.register_module()
|
||||
class AvgPool1dCounter(PoolCounter):
|
||||
"""FLOPs/params counter for AvgPool1d module."""
|
||||
pass
|
||||
|
||||
|
||||
@TASK_UTILS.register_module()
|
||||
class AvgPool2dCounter(PoolCounter):
|
||||
"""FLOPs/params counter for AvgPool2d module."""
|
||||
pass
|
||||
|
||||
|
||||
@TASK_UTILS.register_module()
|
||||
class AvgPool3dCounter(PoolCounter):
|
||||
"""FLOPs/params counter for AvgPool3d module."""
|
||||
pass
|
||||
|
||||
|
||||
@TASK_UTILS.register_module()
|
||||
class AdaptiveMaxPool1dCounter(PoolCounter):
|
||||
"""FLOPs/params counter for AdaptiveMaxPool1d module."""
|
||||
pass
|
||||
|
||||
|
||||
@TASK_UTILS.register_module()
|
||||
class AdaptiveMaxPool2dCounter(PoolCounter):
|
||||
"""FLOPs/params counter for AdaptiveMaxPool2d module."""
|
||||
pass
|
||||
|
||||
|
||||
@TASK_UTILS.register_module()
|
||||
class AdaptiveMaxPool3dCounter(PoolCounter):
|
||||
"""FLOPs/params counter for AdaptiveMaxPool3d module."""
|
||||
pass
|
||||
|
||||
|
||||
@TASK_UTILS.register_module()
|
||||
class AdaptiveAvgPool1dCounter(PoolCounter):
|
||||
"""FLOPs/params counter for AdaptiveAvgPool1d module."""
|
||||
pass
|
||||
|
||||
|
||||
@TASK_UTILS.register_module()
|
||||
class AdaptiveAvgPool2dCounter(PoolCounter):
|
||||
"""FLOPs/params counter for AdaptiveAvgPool2d module."""
|
||||
pass
|
||||
|
||||
|
||||
@TASK_UTILS.register_module()
|
||||
class AdaptiveAvgPool3dCounter(PoolCounter):
|
||||
"""FLOPs/params counter for AdaptiveAvgPool3d module."""
|
||||
pass
|
||||
|
|
|
@ -10,6 +10,7 @@ class UpsampleCounter(BaseCounter):
|
|||
|
||||
@staticmethod
|
||||
def add_count_hook(module, input, output):
|
||||
"""Calculate FLOPs and params based on the size of input & output."""
|
||||
output_size = output[0]
|
||||
batch_size = output_size.shape[0]
|
||||
output_elements_count = batch_size
|
||||
|
|
Loading…
Reference in New Issue