diff --git a/mmrazor/models/task_modules/estimators/counters/flops_params_counter.py b/mmrazor/models/task_modules/estimators/counters/flops_params_counter.py
index c788d3c8..31e998a2 100644
--- a/mmrazor/models/task_modules/estimators/counters/flops_params_counter.py
+++ b/mmrazor/models/task_modules/estimators/counters/flops_params_counter.py
@@ -216,6 +216,7 @@ def print_model_with_flops_params(model,
     """

     def accumulate_params(self):
+        """Accumulate params by recursion."""
         if is_supported_instance(self):
             return self.__params__
         else:
@@ -225,6 +226,7 @@ def print_model_with_flops_params(model,
             return sum

     def accumulate_flops(self):
+        """Accumulate flops by recursion."""
         if is_supported_instance(self):
             return self.__flops__ / model.__batch_counter__
         else:
@@ -234,6 +236,7 @@ def print_model_with_flops_params(model,
             return sum

     def flops_repr(self):
+        """A new extra_repr method of the input module."""
         accumulated_num_params = self.accumulate_params()
         accumulated_flops_cost = self.accumulate_flops()
         flops_string = str(
@@ -252,6 +255,7 @@ def print_model_with_flops_params(model,
         ])

     def add_extra_repr(m):
+        """Override the module's extra_repr method."""
         m.accumulate_flops = accumulate_flops.__get__(m)
         m.accumulate_params = accumulate_params.__get__(m)
         flops_extra_repr = flops_repr.__get__(m)
@@ -261,6 +265,7 @@ def print_model_with_flops_params(model,
         assert m.extra_repr != m.original_extra_repr

     def del_extra_repr(m):
+        """Restore the original extra_repr method."""
         if hasattr(m, 'original_extra_repr'):
             m.extra_repr = m.original_extra_repr
             del m.original_extra_repr
@@ -281,6 +286,7 @@ def accumulate_sub_module_flops_params(model):
     """

     def accumulate_params(module):
+        """Accumulate params by recursion."""
         if is_supported_instance(module):
             return module.__params__
         else:
@@ -290,6 +296,7 @@ def accumulate_sub_module_flops_params(model):
             return sum

     def accumulate_flops(module):
+        """Accumulate flops by recursion."""
         if is_supported_instance(module):
             return module.__flops__ / model.__batch_counter__
         else:
@@ -310,6 +317,7 @@ def get_model_parameters_number(model):

     Args:
         model (nn.module): The model for parameter number calculation.
+
     Returns:
         float: Parameter number of the model.
     """
@@ -318,8 +326,10 @@ def get_model_parameters_number(model):


 def add_flops_params_counting_methods(net_main_module):
-    # adding additional methods to the existing module object,
-    # this is done this way so that each function has access to self object
+    """Add additional methods to the existing module object.
+
+    This is done so that each function has access to the self object.
+    """
     net_main_module.start_flops_params_count = start_flops_params_count.__get__(  # noqa: E501
         net_main_module)
     net_main_module.stop_flops_params_count = stop_flops_params_count.__get__(
@@ -339,6 +349,7 @@ def compute_average_flops_params_cost(self):

     A method to compute average FLOPs cost, which will be available after
     `add_flops_params_counting_methods()` is called on a desired net object.
+
     Returns:
         float: Current mean flops consumption per image.
     """
@@ -406,16 +417,18 @@ def reset_flops_params_count(self):


 # ---- Internal functions
 def empty_flops_params_counter_hook(module, input, output):
+    """Empty hook that keeps the flops and params variables unchanged."""
     module.__flops__ += 0
     module.__params__ += 0


 def add_batch_counter_variables_or_reset(module):
-
+    """Add or reset the batch counter variable."""
     module.__batch_counter__ = 0


 def add_batch_counter_hook_function(module):
+    """Register the batch counter hook for the module."""
     if hasattr(module, '__batch_counter_handle__'):
         return

@@ -424,6 +437,7 @@ def add_batch_counter_hook_function(module):


 def batch_counter_hook(module, input, output):
+    """Update the batch counter based on the input batch size."""
     batch_size = 1
     if len(input) > 0:
         # Can have multiple inputs, getting the first one
@@ -437,12 +451,14 @@ def batch_counter_hook(module, input, output):


 def remove_batch_counter_hook_function(module):
+    """Remove the batch counter handle of the module."""
     if hasattr(module, '__batch_counter_handle__'):
         module.__batch_counter_handle__.remove()
         del module.__batch_counter_handle__


 def add_flops_params_counter_variable_or_reset(module):
+    """Add or reset the flops and params variables of the module."""
     if is_supported_instance(module):
         if hasattr(module, '__flops__') or hasattr(module, '__params__'):
             print('Warning: variables __flops__ or __params__ are already '
@@ -453,16 +469,19 @@ def add_flops_params_counter_variable_or_reset(module):


 def get_counter_type(module):
+    """Get the counter type of the module based on its class name."""
     return module.__class__.__name__ + 'Counter'


 def is_supported_instance(module):
+    """Check whether a counter for the module is registered in TASK_UTILS."""
     if get_counter_type(module) in TASK_UTILS._module_dict.keys():
         return True
     return False


 def remove_flops_params_counter_hook_function(module):
+    """Remove the counter hook handle after resource estimation."""
     if hasattr(module, '__flops_params_handle__'):
         module.__flops_params_handle__.remove()
         del module.__flops_params_handle__
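Reviewer note: the `__get__` binding that `add_extra_repr` and `add_flops_params_counting_methods` rely on can be hard to parse from the docstrings alone. A minimal, self-contained sketch of the pattern (the `describe` function and the `Linear` layer are illustrative only, not part of this patch):

import torch.nn as nn

def describe(self) -> str:
    # A free function; once bound, ``self`` is the module instance.
    num_params = sum(p.numel() for p in self.parameters())
    return f'{self.__class__.__name__} with {num_params} params'

layer = nn.Linear(4, 2)
# Functions are descriptors: __get__ turns the free function into a method
# bound to this one instance, which is exactly how accumulate_flops and the
# replaced extra_repr gain access to self in the patch above.
layer.describe = describe.__get__(layer)
print(layer.describe())  # -> 'Linear with 10 params'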
""" @@ -406,16 +417,18 @@ def reset_flops_params_count(self): # ---- Internal functions def empty_flops_params_counter_hook(module, input, output): + """Empty flops and params variables of the module.""" module.__flops__ += 0 module.__params__ += 0 def add_batch_counter_variables_or_reset(module): - + """Add or reset the batch counter variable.""" module.__batch_counter__ = 0 def add_batch_counter_hook_function(module): + """Register the batch counter hook for the module.""" if hasattr(module, '__batch_counter_handle__'): return @@ -424,6 +437,7 @@ def add_batch_counter_hook_function(module): def batch_counter_hook(module, input, output): + """Add batch counter variable based on the input size.""" batch_size = 1 if len(input) > 0: # Can have multiple inputs, getting the first one @@ -437,12 +451,14 @@ def batch_counter_hook(module, input, output): def remove_batch_counter_hook_function(module): + """Remove batch counter handle variable.""" if hasattr(module, '__batch_counter_handle__'): module.__batch_counter_handle__.remove() del module.__batch_counter_handle__ def add_flops_params_counter_variable_or_reset(module): + """Add or reset flops and params variable of the module.""" if is_supported_instance(module): if hasattr(module, '__flops__') or hasattr(module, '__params__'): print('Warning: variables __flops__ or __params__ are already ' @@ -453,16 +469,19 @@ def add_flops_params_counter_variable_or_reset(module): def get_counter_type(module): + """Get counter type of the module based on the module class name.""" return module.__class__.__name__ + 'Counter' def is_supported_instance(module): + """Judge whether the module is in TASK_UTILS registry or not.""" if get_counter_type(module) in TASK_UTILS._module_dict.keys(): return True return False def remove_flops_params_counter_hook_function(module): + """Remove counter related variables after resource estimation.""" if hasattr(module, '__flops_params_handle__'): module.__flops_params_handle__.remove() del module.__flops_params_handle__ diff --git a/mmrazor/models/task_modules/estimators/counters/latency_counter.py b/mmrazor/models/task_modules/estimators/counters/latency_counter.py index 55a145d0..e3e91c54 100644 --- a/mmrazor/models/task_modules/estimators/counters/latency_counter.py +++ b/mmrazor/models/task_modules/estimators/counters/latency_counter.py @@ -1,17 +1,31 @@ # Copyright (c) OpenMMLab. All rights reserved. import logging import time +from typing import Any, Dict import torch from mmengine.logging import print_log -def repeat_measure_inference_speed(model, - resource_args, +def repeat_measure_inference_speed(model: torch.nn.Module, + resource_args: Dict[str, Any], max_iter: int = 100, + num_warmup: int = 5, log_interval: int = 100, repeat_num: int = 1) -> float: - """Repeat speed measure for multi-times to get more precise results.""" + """Repeat speed measure for multi-times to get more precise results. + + Args: + model (torch.nn.Module): The measured model. + resource_args (Dict[str, float]): resources information. + max_iter (Optional[int]): Max iteration num for inference speed test. + num_warmup (Optional[int]): Iteration num for warm-up stage. + log_interval (Optional[int]): Interval num for logging the results. + repeat_num (Optional[int]): Num of times to repeat the measurement. + + Returns: + fps (float): The measured inference speed of the model. 
+ """ assert repeat_num >= 1 fps_list = [] @@ -19,7 +33,7 @@ def repeat_measure_inference_speed(model, for _ in range(repeat_num): fps_list.append( - measure_inference_speed(model, resource_args, max_iter, + measure_inference_speed(model, resource_args, max_iter, num_warmup, log_interval)) if repeat_num > 1: @@ -39,10 +53,24 @@ def repeat_measure_inference_speed(model, return latency -def measure_inference_speed(model, resource_args, max_iter: int, - log_interval: int) -> float: +def measure_inference_speed(model: torch.nn.Module, + resource_args: Dict[str, Any], + max_iter: int = 100, + num_warmup: int = 5, + log_interval: int = 100) -> float: + """Measure inference speed on GPU devices. + + Args: + model (torch.nn.Module): The measured model. + resource_args (Dict[str, float]): resources information. + max_iter (Optional[int]): Max iteration num for inference speed test. + num_warmup (Optional[int]): Iteration num for warm-up stage. + log_interval (Optional[int]): Interval num for logging the results. + + Returns: + fps (float): The measured inference speed of the model. + """ # the first several iterations may be very slow so skip them - num_warmup = 5 pure_inf_time = 0.0 fps = 0.0 data = dict() @@ -50,7 +78,7 @@ def measure_inference_speed(model, resource_args, max_iter: int, device = 'cuda' else: raise NotImplementedError('To use cpu to test latency not supported.') - # benchmark with 100 image and take the average + # benchmark with {max_iter} image and take the average for i in range(1, max_iter): if device == 'cuda': data = torch.rand(resource_args['input_shape']).cuda() diff --git a/mmrazor/models/task_modules/estimators/counters/op_counters/activation_layer_counter.py b/mmrazor/models/task_modules/estimators/counters/op_counters/activation_layer_counter.py index f124c0db..e32aa552 100644 --- a/mmrazor/models/task_modules/estimators/counters/op_counters/activation_layer_counter.py +++ b/mmrazor/models/task_modules/estimators/counters/op_counters/activation_layer_counter.py @@ -10,6 +10,7 @@ class ReLUCounter(BaseCounter): @staticmethod def add_count_hook(module, input, output): + """Calculate FLOPs and params based on the size of input & output.""" active_elements_count = output.numel() module.__flops__ += int(active_elements_count) module.__params__ += get_model_parameters_number(module) @@ -17,19 +18,23 @@ class ReLUCounter(BaseCounter): @TASK_UTILS.register_module() class PReLUCounter(ReLUCounter): + """FLOPs/params counter for PReLU function.""" pass @TASK_UTILS.register_module() class ELUCounter(ReLUCounter): + """FLOPs/params counter for ELU function.""" pass @TASK_UTILS.register_module() class LeakyReLUCounter(ReLUCounter): + """FLOPs/params counter for LeakyReLU function.""" pass @TASK_UTILS.register_module() class ReLU6Counter(ReLUCounter): + """FLOPs/params counter for ReLU6 function.""" pass diff --git a/mmrazor/models/task_modules/estimators/counters/op_counters/conv_layer_counter.py b/mmrazor/models/task_modules/estimators/counters/op_counters/conv_layer_counter.py index 0e9c6c77..959d88fa 100644 --- a/mmrazor/models/task_modules/estimators/counters/op_counters/conv_layer_counter.py +++ b/mmrazor/models/task_modules/estimators/counters/op_counters/conv_layer_counter.py @@ -10,6 +10,7 @@ class ConvCounter(BaseCounter): @staticmethod def add_count_hook(module, input, output): + """Calculate FLOPs and params based on the size of input & output.""" # Can have multiple inputs, getting the first one input = input[0] @@ -44,14 +45,17 @@ class ConvCounter(BaseCounter): 
diff --git a/mmrazor/models/task_modules/estimators/counters/op_counters/conv_layer_counter.py b/mmrazor/models/task_modules/estimators/counters/op_counters/conv_layer_counter.py
index 0e9c6c77..959d88fa 100644
--- a/mmrazor/models/task_modules/estimators/counters/op_counters/conv_layer_counter.py
+++ b/mmrazor/models/task_modules/estimators/counters/op_counters/conv_layer_counter.py
@@ -10,6 +10,7 @@ class ConvCounter(BaseCounter):

     @staticmethod
     def add_count_hook(module, input, output):
+        """Calculate FLOPs and params based on the size of input & output."""
         # Can have multiple inputs, getting the first one
         input = input[0]

@@ -44,14 +45,17 @@ class ConvCounter(BaseCounter):

 @TASK_UTILS.register_module()
 class Conv1dCounter(ConvCounter):
+    """FLOPs/params counter for Conv1d module."""
     pass


 @TASK_UTILS.register_module()
 class Conv2dCounter(ConvCounter):
+    """FLOPs/params counter for Conv2d module."""
     pass


 @TASK_UTILS.register_module()
 class Conv3dCounter(ConvCounter):
+    """FLOPs/params counter for Conv3d module."""
     pass
diff --git a/mmrazor/models/task_modules/estimators/counters/op_counters/deconv_layer_counter.py b/mmrazor/models/task_modules/estimators/counters/op_counters/deconv_layer_counter.py
index 73604243..0426fbb4 100644
--- a/mmrazor/models/task_modules/estimators/counters/op_counters/deconv_layer_counter.py
+++ b/mmrazor/models/task_modules/estimators/counters/op_counters/deconv_layer_counter.py
@@ -6,10 +6,11 @@ from .base_counter import BaseCounter

 @TASK_UTILS.register_module()
 class ConvTranspose2dCounter(BaseCounter):
-    """FLOPs/params counter for Decov module series."""
+    """FLOPs/params counter for Deconv module series."""

     @staticmethod
     def add_count_hook(module, input, output):
+        """Compute FLOPs and params based on the size of input & output."""
         # Can have multiple inputs, getting the first one
         input = input[0]
diff --git a/mmrazor/models/task_modules/estimators/counters/op_counters/linear_layer_counter.py b/mmrazor/models/task_modules/estimators/counters/op_counters/linear_layer_counter.py
index c4f6ac6e..f8e9ea8f 100644
--- a/mmrazor/models/task_modules/estimators/counters/op_counters/linear_layer_counter.py
+++ b/mmrazor/models/task_modules/estimators/counters/op_counters/linear_layer_counter.py
@@ -12,6 +12,7 @@ class LinearCounter(BaseCounter):

     @staticmethod
     def add_count_hook(module, input, output):
+        """Calculate FLOPs and params based on the size of input & output."""
         input = input[0]
         output_last_dim = output.shape[
             -1]  # pytorch checks dimensions, so here we don't care much
diff --git a/mmrazor/models/task_modules/estimators/counters/op_counters/norm_layer_counter.py b/mmrazor/models/task_modules/estimators/counters/op_counters/norm_layer_counter.py
index 5941f7c0..9b9a14ca 100644
--- a/mmrazor/models/task_modules/estimators/counters/op_counters/norm_layer_counter.py
+++ b/mmrazor/models/task_modules/estimators/counters/op_counters/norm_layer_counter.py
@@ -11,6 +11,7 @@ class BNCounter(BaseCounter):

     @staticmethod
     def add_count_hook(module, input, output):
+        """Calculate FLOPs and params based on the size of input & output."""
         input = input[0]
         batch_flops = np.prod(input.shape)
         if getattr(module, 'affine', False):
@@ -21,39 +22,47 @@ class BNCounter(BaseCounter):

 @TASK_UTILS.register_module()
 class BatchNorm1dCounter(BNCounter):
+    """FLOPs/params counter for BatchNorm1d module."""
     pass


 @TASK_UTILS.register_module()
 class BatchNorm2dCounter(BNCounter):
+    """FLOPs/params counter for BatchNorm2d module."""
     pass


 @TASK_UTILS.register_module()
 class BatchNorm3dCounter(BNCounter):
+    """FLOPs/params counter for BatchNorm3d module."""
     pass


 @TASK_UTILS.register_module()
 class InstanceNorm1dCounter(BNCounter):
+    """FLOPs/params counter for InstanceNorm1d module."""
     pass


 @TASK_UTILS.register_module()
 class InstanceNorm2dCounter(BNCounter):
+    """FLOPs/params counter for InstanceNorm2d module."""
     pass


 @TASK_UTILS.register_module()
 class InstanceNorm3dCounter(BNCounter):
+    """FLOPs/params counter for InstanceNorm3d module."""
     pass


 @TASK_UTILS.register_module()
 class LayerNormCounter(BNCounter):
+    """FLOPs/params counter for LayerNorm module."""
     pass


 @TASK_UTILS.register_module()
 class GroupNormCounter(BNCounter):
+    """FLOPs/params counter for GroupNorm module."""
     pass
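Reviewer note: a quick hand check of the `BNCounter` arithmetic documented above (the layer and input sizes are chosen only for illustration). For `BatchNorm2d(64, affine=True)` on a `(1, 64, 56, 56)` input, the normalization costs one op per element, 1 * 64 * 56 * 56 = 200704, doubled to 401408 by the affine scale-and-shift; the 128 parameters are the 64 weights plus 64 biases, since running stats are buffers, not parameters:

import numpy as np
import torch
import torch.nn as nn

bn = nn.BatchNorm2d(64, affine=True)
x = torch.rand(1, 64, 56, 56)

batch_flops = np.prod(x.shape)  # 200704: one op per input element
if getattr(bn, 'affine', False):
    batch_flops *= 2  # elementwise scale and shift
print(int(batch_flops))  # 401408
print(sum(p.numel() for p in bn.parameters()))  # 128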
diff --git a/mmrazor/models/task_modules/estimators/counters/op_counters/pooling_layer_counter.py b/mmrazor/models/task_modules/estimators/counters/op_counters/pooling_layer_counter.py
index c4e94cdc..27d9b605 100644
--- a/mmrazor/models/task_modules/estimators/counters/op_counters/pooling_layer_counter.py
+++ b/mmrazor/models/task_modules/estimators/counters/op_counters/pooling_layer_counter.py
@@ -11,6 +11,7 @@ class PoolCounter(BaseCounter):

     @staticmethod
     def add_count_hook(module, input, output):
+        """Calculate FLOPs and params based on the size of input & output."""
         input = input[0]
         module.__flops__ += int(np.prod(input.shape))
         module.__params__ += get_model_parameters_number(module)
@@ -18,59 +19,71 @@ class PoolCounter(BaseCounter):

 @TASK_UTILS.register_module()
 class MaxPool1dCounter(PoolCounter):
+    """FLOPs/params counter for MaxPool1d module."""
     pass


 @TASK_UTILS.register_module()
 class MaxPool2dCounter(PoolCounter):
+    """FLOPs/params counter for MaxPool2d module."""
     pass


 @TASK_UTILS.register_module()
 class MaxPool3dCounter(PoolCounter):
+    """FLOPs/params counter for MaxPool3d module."""
     pass


 @TASK_UTILS.register_module()
 class AvgPool1dCounter(PoolCounter):
+    """FLOPs/params counter for AvgPool1d module."""
     pass


 @TASK_UTILS.register_module()
 class AvgPool2dCounter(PoolCounter):
+    """FLOPs/params counter for AvgPool2d module."""
     pass


 @TASK_UTILS.register_module()
 class AvgPool3dCounter(PoolCounter):
+    """FLOPs/params counter for AvgPool3d module."""
     pass


 @TASK_UTILS.register_module()
 class AdaptiveMaxPool1dCounter(PoolCounter):
+    """FLOPs/params counter for AdaptiveMaxPool1d module."""
     pass


 @TASK_UTILS.register_module()
 class AdaptiveMaxPool2dCounter(PoolCounter):
+    """FLOPs/params counter for AdaptiveMaxPool2d module."""
     pass


 @TASK_UTILS.register_module()
 class AdaptiveMaxPool3dCounter(PoolCounter):
+    """FLOPs/params counter for AdaptiveMaxPool3d module."""
     pass


 @TASK_UTILS.register_module()
 class AdaptiveAvgPool1dCounter(PoolCounter):
+    """FLOPs/params counter for AdaptiveAvgPool1d module."""
     pass


 @TASK_UTILS.register_module()
 class AdaptiveAvgPool2dCounter(PoolCounter):
+    """FLOPs/params counter for AdaptiveAvgPool2d module."""
     pass


 @TASK_UTILS.register_module()
 class AdaptiveAvgPool3dCounter(PoolCounter):
+    """FLOPs/params counter for AdaptiveAvgPool3d module."""
     pass
diff --git a/mmrazor/models/task_modules/estimators/counters/op_counters/upsample_layer_counter.py b/mmrazor/models/task_modules/estimators/counters/op_counters/upsample_layer_counter.py
index 9442ac56..12958b6d 100644
--- a/mmrazor/models/task_modules/estimators/counters/op_counters/upsample_layer_counter.py
+++ b/mmrazor/models/task_modules/estimators/counters/op_counters/upsample_layer_counter.py
@@ -10,6 +10,7 @@ class UpsampleCounter(BaseCounter):

     @staticmethod
     def add_count_hook(module, input, output):
+        """Calculate FLOPs and params based on the size of input & output."""
         output_size = output[0]
         batch_size = output_size.shape[0]
         output_elements_count = batch_size