mirror of https://github.com/alibaba/EasyCV.git
477 lines
15 KiB
Python
477 lines
15 KiB
Python
# Modified from flops-counter.pytorch by Vladislav Sovrasov
|
|
# original repo: https://github.com/sovrasov/flops-counter.pytorch
|
|
|
|
# MIT License
|
|
|
|
# Copyright (c) 2018 Vladislav Sovrasov
|
|
|
|
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
# of this software and associated documentation files (the "Software"), to deal
|
|
# in the Software without restriction, including without limitation the rights
|
|
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
# copies of the Software, and to permit persons to whom the Software is
|
|
# furnished to do so, subject to the following conditions:
|
|
|
|
# The above copyright notice and this permission notice shall be included in
|
|
# all copies or substantial portions of the Software.
|
|
|
|
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
# SOFTWARE.
|
|
|
|
import sys
|
|
from copy import deepcopy
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
from thop import profile
|
|
from torch.nn.modules.batchnorm import _BatchNorm
|
|
from torch.nn.modules.conv import _ConvNd, _ConvTransposeMixin
|
|
from torch.nn.modules.pooling import (_AdaptiveAvgPoolNd, _AdaptiveMaxPoolNd,
|
|
_AvgPoolNd, _MaxPoolNd)
|
|
|
|
|
|
def get_model_info(model, input_size, model_config, logger):
    """Log the parameter count and GFLOPs of ``model`` and enforce budgets.

    The model is profiled with ``thop.profile`` on a deep copy, using a
    zero-filled ``(1, 3, 64, 64)`` probe image placed on the device of the
    model's first parameter and the extra positional input ``'test'``.
    The measured cost is then rescaled to ``input_size`` and doubled.

    Args:
        model: module to profile; it must own at least one parameter.
        input_size: indexable whose first two entries are multiplied to get
            the target input area.
        model_config: mapping read for ``'model_type'`` and, when that is
            ``'customized'``, for ``'max_model_params'`` (compared against
            millions of parameters) and ``'max_model_flops'`` (compared
            against the scaled GFLOPs figure).
        logger: object whose ``info`` method receives the summary line.

    Raises:
        AssertionError: if a positive budget is not strictly greater than
            the measured value for a ``'customized'`` model.
    """
    probe_side = 64
    probe_device = next(model.parameters()).device
    probe = torch.zeros((1, 3, probe_side, probe_side), device=probe_device)

    raw_flops, raw_params = profile(
        deepcopy(model), inputs=(probe, 'test'), verbose=False)

    params = raw_params / 1e6
    # Rescale the cost measured on the probe to the requested input area;
    # the trailing "* 2" doubles the figure reported by thop.
    flops = raw_flops / 1e9
    flops *= (input_size[0] * input_size[1]) / probe_side / probe_side * 2

    info = 'Params: {:.2f}M, Gflops: {:.2f}'.format(params, flops)
    logger.info('Model Summary: {}'.format(info))

    # Budgets are only enforced for customized models.
    if model_config['model_type'] != 'customized':
        return

    if model_config['max_model_params'] > 0:
        assert model_config['max_model_params'] > params, \
            'model params is larger than set parameters, please reset model!'
    if model_config['max_model_flops'] > 0:
        assert model_config['max_model_flops'] > flops, \
            'model flops is larger than set parameters, please reset model!'
|
|
|
|
|
|
def get_model_complexity_info(model,
                              input_res,
                              print_per_layer_stat=True,
                              as_strings=True,
                              input_constructor=None,
                              ost=sys.stdout):
    """Run one forward pass through ``model`` and report its cost.

    The flops-counting methods are attached to ``model`` itself (it is not
    copied), the model is switched to eval mode and a single forward pass
    is executed with the counting hooks active.

    Args:
        model: module to analyse.
        input_res: tuple with at least two entries. Without an
            ``input_constructor`` the model is fed an uninitialised tensor
            of shape ``(1, *input_res)`` with the dtype and device of the
            model's first parameter.
        print_per_layer_stat: when true, print the model annotated with
            per-module costs to ``ost``.
        as_strings: when true, return formatted strings instead of numbers.
        input_constructor: optional callable receiving ``input_res`` and
            returning a mapping that is unpacked as keyword arguments of
            the forward call.
        ost: stream the per-layer report is written to.

    Returns:
        ``(flops, params)`` as formatted strings when ``as_strings`` is
        true, otherwise as raw numbers.

    Raises:
        AssertionError: if ``input_res`` is not a tuple of length >= 2.
    """
    assert type(input_res) is tuple
    assert len(input_res) >= 2
    flops_model = add_flops_counting_methods(model)
    flops_model.eval().start_flops_count()
    try:
        if input_constructor:
            model_kwargs = input_constructor(input_res)
            _ = flops_model(**model_kwargs)
        else:
            first_param = next(flops_model.parameters())
            batch = torch.ones(()).new_empty(
                (1, *input_res),
                dtype=first_param.dtype,
                device=first_param.device)
            flops_model(batch)

        if print_per_layer_stat:
            print_model_with_flops(flops_model, ost=ost)
        flops_count = flops_model.compute_average_flops_cost()
        params_count = get_model_parameters_number(flops_model)
    finally:
        # Always detach the forward hooks, even when the forward pass or
        # the report fails, so the caller's model is not left instrumented.
        flops_model.stop_flops_count()

    if as_strings:
        return flops_to_string(flops_count), params_to_string(params_count)

    return flops_count, params_count
|
|
|
|
|
|
def flops_to_string(flops, units='GMac', precision=2):
    """Format a MAC count as a human-readable string.

    Args:
        flops: number of multiply-accumulate operations.
        units: ``'GMac'``, ``'MMac'`` or ``'KMac'`` to force a unit, or
            ``None`` to pick the largest unit whose integer quotient is
            positive. Any other value yields the unscaled count with the
            suffix ``' Mac'``.
        precision: number of decimals kept by ``round`` when scaling.

    Returns:
        The formatted string, e.g. ``'1.5 MMac'``.
    """
    # (label, integer threshold used for auto-selection, float divisor)
    scales = (
        ('GMac', 10**9, 10.**9),
        ('MMac', 10**6, 10.**6),
        ('KMac', 10**3, 10.**3),
    )

    if units is None:
        for label, threshold, divisor in scales:
            if flops // threshold > 0:
                return str(round(flops / divisor, precision)) + ' ' + label
        return str(flops) + ' Mac'

    for label, _, divisor in scales:
        if units == label:
            return str(round(flops / divisor, precision)) + ' ' + units
    return str(flops) + ' Mac'
|
|
|
|
|
|
def params_to_string(params_num):
    """Format a parameter count with an ``M`` or ``k`` suffix.

    Counts of at least one million are shown in millions, counts whose
    integer quotient by one thousand is non-zero are shown in thousands
    (both rounded to two decimals); anything else is returned via ``str``.

    :param float params_num: number
    :returns str: number

    >>> params_to_string(1e9)
    '1000.0 M'
    >>> params_to_string(2e5)
    '200.0 k'
    >>> params_to_string(3e-9)
    '3e-09'
    """
    million = 10**6
    thousand = 10**3

    if params_num // million > 0:
        return str(round(params_num / million, 2)) + ' M'
    if params_num // thousand:
        return str(round(params_num / thousand, 2)) + ' k'
    return str(params_num)
|
|
|
|
|
|
def print_model_with_flops(model, units='GMac', precision=3, ost=sys.stdout):
    """Print ``model`` with each module's cost added to its repr.

    Every module temporarily gets an ``extra_repr`` that prefixes the
    original text with the module's accumulated cost and its share of the
    model total. The original ``extra_repr`` is restored after printing.

    Args:
        model: module that already carries the flops-counting methods and
            counters (``compute_average_flops_cost``, ``__batch_counter__``).
        units: unit passed through to ``flops_to_string``.
        precision: rounding precision passed to ``flops_to_string``.
        ost: stream the annotated model is printed to.
    """
    total_flops = model.compute_average_flops_cost()

    def accumulate_flops(self):
        # Counted modules report their own per-sample cost; containers
        # report the sum over their direct children.
        if is_supported_instance(self):
            return self.__flops__ / model.__batch_counter__
        return sum(child.accumulate_flops() for child in self.children())

    def flops_repr(self):
        module_flops = self.accumulate_flops()
        cost_text = flops_to_string(
            module_flops, units=units, precision=precision)
        share_text = '{:.3%} MACs'.format(module_flops / total_flops)
        return ', '.join([cost_text, share_text, self.original_extra_repr()])

    def add_extra_repr(m):
        m.accumulate_flops = accumulate_flops.__get__(m)
        patched_repr = flops_repr.__get__(m)
        # Only swap when the module is not already patched, so the saved
        # original is never overwritten with the patched version.
        if m.extra_repr != patched_repr:
            m.original_extra_repr = m.extra_repr
            m.extra_repr = patched_repr
            assert m.extra_repr != m.original_extra_repr

    def del_extra_repr(m):
        if hasattr(m, 'original_extra_repr'):
            m.extra_repr = m.original_extra_repr
            del m.original_extra_repr
        if hasattr(m, 'accumulate_flops'):
            del m.accumulate_flops

    model.apply(add_extra_repr)
    print(model, file=ost)
    model.apply(del_extra_repr)
|
|
|
|
|
|
def get_model_parameters_number(model):
    """Return the total element count of the parameters that require grad.

    Args:
        model: object exposing ``parameters()``.

    Returns:
        Sum of ``numel()`` over parameters with ``requires_grad`` set.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
|
|
|
|
|
|
def add_flops_counting_methods(net_main_module):
    """Attach the flops-counting API to ``net_main_module`` and reset it.

    The functions are bound to the module instance itself (rather than to
    its class) so that each of them receives the module as ``self``.

    Args:
        net_main_module: module to instrument; it is modified in place.

    Returns:
        The same ``net_main_module`` object.
    """
    bound_api = (
        start_flops_count,
        stop_flops_count,
        reset_flops_count,
        compute_average_flops_cost,
    )
    for func in bound_api:
        # Each function is exposed under its own name as a bound method.
        setattr(net_main_module, func.__name__, func.__get__(net_main_module))

    net_main_module.reset_flops_count()

    # Adding variables necessary for masked flops computation
    net_main_module.apply(add_flops_mask_variable_or_reset)

    return net_main_module
|
|
|
|
|
|
def compute_average_flops_cost(self):
    """Return the mean cost per counted sample.

    Available as a method after ``add_flops_counting_methods()`` has been
    called on the net. Sums ``__flops__`` over every supported module
    reachable through ``self.modules()`` and divides the total by
    ``self.__batch_counter__``.
    """
    batches_count = self.__batch_counter__
    flops_sum = sum(
        module.__flops__ for module in self.modules()
        if is_supported_instance(module))
    return flops_sum / batches_count
|
|
|
|
|
|
def start_flops_count(self):
    """Activate flops counting; call before running the network.

    Available as a method after ``add_flops_counting_methods()`` has been
    called on the net. Registers the batch counter hook on ``self`` and
    applies the flops counter hook registration to every submodule.
    """
    # The batch counter lives on the root module only; the per-layer
    # counters are installed through apply().
    add_batch_counter_hook_function(self)
    self.apply(add_flops_counter_hook_function)
|
|
|
|
|
|
def stop_flops_count(self):
    """Pause flops counting by removing the counting hooks.

    Available as a method after ``add_flops_counting_methods()`` has been
    called on the net. Removes the batch counter hook from ``self`` and
    applies the flops counter hook removal to every submodule.
    """
    # Mirror of start_flops_count: root-level hook first, then per layer.
    remove_batch_counter_hook_function(self)
    self.apply(remove_flops_counter_hook_function)
|
|
|
|
|
|
def reset_flops_count(self):
    """Reset the statistics gathered so far.

    Available as a method after ``add_flops_counting_methods()`` has been
    called on the net. Resets the batch counter on ``self`` and applies
    the flops counter reset to every submodule.
    """
    # Root-level batch counter first, then the per-layer flops counters.
    add_batch_counter_variables_or_reset(self)
    self.apply(add_flops_counter_variable_or_reset)
|
|
|
|
|
|
def add_flops_mask(module, mask):
    """Store ``mask`` as ``__mask__`` on every ``Conv2d`` inside ``module``.

    Args:
        module: root module; visited together with its submodules through
            ``apply``.
        mask: object assigned unchanged to each ``torch.nn.Conv2d``.
    """

    def _assign_mask(submodule):
        # Only 2-D convolutions receive the mask; other layers are skipped.
        if isinstance(submodule, torch.nn.Conv2d):
            submodule.__mask__ = mask

    module.apply(_assign_mask)
|
|
|
|
|
|
def remove_flops_mask(module):
    """Reset the flops mask on ``module`` and all of its submodules.

    Delegates to ``add_flops_mask_variable_or_reset`` through ``apply``.
    """
    module.apply(add_flops_mask_variable_or_reset)
|
|
|
|
|
|
def is_supported_instance(module):
    """Tell whether ``module`` has a flops counter hook registered for it.

    Returns:
        ``True`` if the type of ``module`` is a subclass of any key of
        ``hook_mapping``, otherwise ``False``.
    """
    module_type = type(module)
    return any(issubclass(module_type, known) for known in hook_mapping)
|
|
|
|
|
|
def empty_flops_counter_hook(module, input, output):
    """Forward hook that adds zero to ``module.__flops__``.

    ``input`` and ``output`` are accepted for hook-signature compatibility
    and are not used.
    """
    module.__flops__ = module.__flops__ + 0
|
|
|
|
|
|
def upsample_flops_counter_hook(module, input, output):
    """Forward hook charging one operation per output element.

    The whole output tensor is counted, batch dimension included, which
    matches the other hooks: the accumulated total is later divided by the
    number of processed samples. Indexing ``output[0]`` instead would
    select the first sample and drop the batch dimension, under-counting
    whenever the batch holds more than one sample.

    Args:
        module: module whose ``__flops__`` counter is incremented.
        input: positional inputs of the forward call (unused).
        output: output tensor of the forward call.
    """
    module.__flops__ += int(output.numel())
|
|
|
|
|
|
def relu_flops_counter_hook(module, input, output):
    """Forward hook charging one operation per output element.

    Args:
        module: module whose ``__flops__`` counter is incremented.
        input: positional inputs of the forward call (unused).
        output: output of the forward call; its ``numel()`` is the cost.
    """
    module.__flops__ += int(output.numel())
|
|
|
|
|
|
def linear_flops_counter_hook(module, input, output):
    """Forward hook counting the multiply-accumulates of a linear layer.

    The cost is the number of input elements times the size of the last
    output dimension. For a 2-D ``(batch, in_features)`` input this equals
    ``batch * in_features * out_features``; it also stays correct when the
    input carries extra leading dimensions (e.g. ``(batch, tokens,
    in_features)``), where ``shape[1]`` is not the feature dimension.

    Args:
        module: module whose ``__flops__`` counter is incremented.
        input: positional inputs of the forward call; only the first one
            is used.
        output: output tensor of the forward call.
    """
    features = input[0]
    module.__flops__ += int(features.numel() * output.shape[-1])
|
|
|
|
|
|
def pool_flops_counter_hook(module, input, output):
    """Forward hook charging one operation per input element.

    Args:
        module: module whose ``__flops__`` counter is incremented.
        input: positional inputs of the forward call; only the first one
            is used.
        output: output of the forward call (unused).
    """
    first_input_shape = input[0].shape
    module.__flops__ += int(np.prod(first_input_shape))
|
|
|
|
|
|
def bn_flops_counter_hook(module, input, output):
    """Forward hook counting the cost of a batch-norm layer.

    One operation is charged per input element, doubled when the module
    has ``affine`` set.

    Args:
        module: module whose ``__flops__`` counter is incremented.
        input: positional inputs of the forward call; only the first one
            is used.
        output: output of the forward call (unused).
    """
    per_element_cost = 2 if module.affine else 1
    element_count = np.prod(input[0].shape)
    module.__flops__ += int(element_count * per_element_cost)
|
|
|
|
|
|
def gn_flops_counter_hook(module, input, output):
    """Forward hook counting the cost of a group-norm layer.

    Three operations are charged per input element, plus one more per
    element when the module has ``affine`` set.

    Args:
        module: module whose ``__flops__`` counter is incremented.
        input: positional inputs of the forward call; only the first one
            is used.
        output: output of the forward call (unused).
    """
    element_count = np.prod(input[0].shape)
    # There is no precise FLOPs estimate for computing mean and variance;
    # the upstream comment budgets 2 * elems for them (half for the means,
    # half for the variances).
    # NOTE(review): the code charges 3 * elems; the third presumably covers
    # the normalization itself - confirm.
    per_element_cost = 4 if module.affine else 3
    module.__flops__ += int(per_element_cost * element_count)
|
|
|
|
|
|
def deconv_flops_counter_hook(conv_module, input, output):
    """Forward hook counting the cost of a transposed convolution.

    Every input position is charged
    ``prod(kernel_size) * in_channels * (out_channels // groups)``
    operations. When the layer has a bias, one extra operation is charged
    per output element (``out_channels * batch * prod(output spatial
    dims)``).

    The spatial dimensions are handled generically, so 1-D, 2-D and 3-D
    layers are all supported.

    Args:
        conv_module: module whose ``__flops__`` counter is incremented; it
            is read for ``kernel_size``, ``in_channels``, ``out_channels``,
            ``groups`` and ``bias``.
        input: positional inputs of the forward call; only the first one
            is used.
        output: output tensor of the forward call.
    """
    # Can have multiple inputs, getting the first one
    features = input[0]

    batch_size = features.shape[0]
    input_positions = int(np.prod(features.shape[2:]))

    kernel_volume = int(np.prod(conv_module.kernel_size))
    in_channels = conv_module.in_channels
    out_channels = conv_module.out_channels
    groups = conv_module.groups

    filters_per_channel = out_channels // groups
    conv_per_position_flops = (
        kernel_volume * in_channels * filters_per_channel)

    active_elements_count = batch_size * input_positions
    overall_conv_flops = conv_per_position_flops * active_elements_count

    bias_flops = 0
    if conv_module.bias is not None:
        # One addition per output element: use every output spatial
        # dimension (height AND width), not the height twice.
        output_positions = int(np.prod(output.shape[2:]))
        bias_flops = out_channels * batch_size * output_positions

    conv_module.__flops__ += int(overall_conv_flops + bias_flops)
|
|
|
|
|
|
def conv_flops_counter_hook(conv_module, input, output):
    """Forward hook counting the cost of a convolution.

    Every active output position is charged
    ``prod(kernel_size) * in_channels * (out_channels // groups)``
    operations, plus ``out_channels`` more when the layer has a bias.
    Without a mask all ``batch * prod(output spatial dims)`` positions are
    active; with ``conv_module.__mask__`` set, the active count is the sum
    of the mask expanded to ``(batch, 1, out_height, out_width)``.

    Args:
        conv_module: module whose ``__flops__`` counter is incremented; it
            must already carry a ``__mask__`` attribute (possibly ``None``).
        input: positional inputs of the forward call; only the first one
            is used.
        output: output tensor of the forward call.
    """
    # Can have multiple inputs, getting the first one
    features = input[0]
    n_samples = features.shape[0]
    spatial_out = list(output.shape[2:])

    kernel_shape = list(conv_module.kernel_size)
    c_in = conv_module.in_channels
    c_out = conv_module.out_channels
    filters_per_group = c_out // conv_module.groups
    macs_per_position = np.prod(kernel_shape) * c_in * filters_per_group

    mask = conv_module.__mask__
    if mask is None:
        active_positions = n_samples * np.prod(spatial_out)
    else:
        # The mask is expanded to (b, 1, h, w); this path unpacks exactly
        # two spatial dimensions.
        out_h, out_w = output.shape[2:]
        active_positions = mask.expand(n_samples, 1, out_h, out_w).sum()

    conv_cost = macs_per_position * active_positions
    bias_cost = 0
    if conv_module.bias is not None:
        bias_cost = c_out * active_positions

    conv_module.__flops__ += int(conv_cost + bias_cost)
|
|
|
|
|
|
# Maps a module base class to the forward hook that counts its cost.
# Lookups walk this dict in insertion order and stop at the first class the
# module's type is a subclass of, so the order of the entries matters.
# NOTE(review): `_ConvNd` is listed before `_ConvTransposeMixin`; a layer
# that inherits from both would be matched by the conv hook first and never
# reach the deconv hook - confirm against the installed torch class
# hierarchy before reordering.
hook_mapping = {
    # conv
    _ConvNd: conv_flops_counter_hook,
    # deconv
    _ConvTransposeMixin: deconv_flops_counter_hook,
    # fc
    nn.Linear: linear_flops_counter_hook,
    # pooling
    _AvgPoolNd: pool_flops_counter_hook,
    _MaxPoolNd: pool_flops_counter_hook,
    _AdaptiveAvgPoolNd: pool_flops_counter_hook,
    _AdaptiveMaxPoolNd: pool_flops_counter_hook,
    # activation
    nn.ReLU: relu_flops_counter_hook,
    nn.PReLU: relu_flops_counter_hook,
    nn.ELU: relu_flops_counter_hook,
    nn.LeakyReLU: relu_flops_counter_hook,
    nn.ReLU6: relu_flops_counter_hook,
    # normalization
    _BatchNorm: bn_flops_counter_hook,
    nn.GroupNorm: gn_flops_counter_hook,
    # upsample
    nn.Upsample: upsample_flops_counter_hook,
}
|
|
|
|
|
|
def batch_counter_hook(module, input, output):
    """Forward hook adding the size of the processed batch to the counter.

    The batch size is ``len()`` of the first positional input. When the
    forward call received no positional inputs, a warning is printed and a
    batch size of 1 is assumed.

    Args:
        module: module whose ``__batch_counter__`` is incremented.
        input: tuple of positional inputs of the forward call.
        output: output of the forward call (unused).
    """
    if len(input) > 0:
        # Can have multiple inputs, getting the first one
        samples = len(input[0])
    else:
        print('Warning! No positional inputs found for a module, '
              'assuming batch size is 1.')
        samples = 1
    module.__batch_counter__ += samples
|
|
|
|
|
|
def add_batch_counter_variables_or_reset(module):
    """Create ``module.__batch_counter__`` or reset it to zero."""
    setattr(module, '__batch_counter__', 0)
|
|
|
|
|
|
def add_batch_counter_hook_function(module):
    """Register the batch counter forward hook on ``module`` once.

    The hook handle is kept in ``module.__batch_counter_handle__``; when
    that attribute already exists nothing is registered.
    """
    already_registered = hasattr(module, '__batch_counter_handle__')
    if not already_registered:
        module.__batch_counter_handle__ = module.register_forward_hook(
            batch_counter_hook)
|
|
|
|
|
|
def remove_batch_counter_hook_function(module):
    """Remove the batch counter hook from ``module`` if one is registered.

    Does nothing when ``module`` has no ``__batch_counter_handle__``.
    """
    if not hasattr(module, '__batch_counter_handle__'):
        return
    module.__batch_counter_handle__.remove()
    del module.__batch_counter_handle__
|
|
|
|
|
|
def add_flops_counter_variable_or_reset(module):
    """Create or zero ``module.__flops__`` for supported modules.

    Modules that are not supported instances are left untouched.
    """
    if not is_supported_instance(module):
        return
    module.__flops__ = 0
|
|
|
|
|
|
def add_flops_counter_hook_function(module):
    """Register the matching flops counter hook on a supported module.

    The hook of the first ``hook_mapping`` entry whose class the module's
    type is a subclass of is registered as a forward hook, and its handle
    is stored in ``module.__flops_handle__``. Unsupported modules and
    modules that already carry a handle are left untouched.
    """
    if not is_supported_instance(module):
        return
    if hasattr(module, '__flops_handle__'):
        return

    module_type = type(module)
    for mod_type, counter_hook in hook_mapping.items():
        if issubclass(module_type, mod_type):
            # First match wins; later entries are not considered.
            handle = module.register_forward_hook(counter_hook)
            break

    module.__flops_handle__ = handle
|
|
|
|
|
|
def remove_flops_counter_hook_function(module):
    """Remove the flops counter hook from a supported module, if present.

    Unsupported modules and modules without ``__flops_handle__`` are left
    untouched.
    """
    if not is_supported_instance(module):
        return
    if not hasattr(module, '__flops_handle__'):
        return
    module.__flops_handle__.remove()
    del module.__flops_handle__
|
|
|
|
|
|
# --- Masked flops counting
|
|
# Also being run in the initialization
|
|
def add_flops_mask_variable_or_reset(module):
    """Create ``module.__mask__`` or reset it to ``None`` on supported modules.

    Modules that are not supported instances are left untouched.
    """
    if not is_supported_instance(module):
        return
    module.__mask__ = None
|