# mirror of https://github.com/YifanXu74/MQ-Det.git
'''
Copyright (C) 2019 Sovrasov V. - All Rights Reserved
 * You may use, distribute and modify this code under the
 * terms of the MIT license.
 * You should have received a copy of the MIT license with
 * this file. If not visit https://opensource.org/licenses/MIT
'''
import sys
from functools import partial

import numpy as np
import torch
import torch.nn as nn

from maskrcnn_benchmark.layers import *

def get_model_complexity_info(model, input_res,
                              print_per_layer_stat=True,
                              as_strings=True,
                              input_constructor=None, ost=sys.stdout,
                              verbose=False, ignore_modules=None,
                              custom_modules_hooks=None):
    assert type(input_res) is tuple
    assert len(input_res) >= 1
    assert isinstance(model, nn.Module)
    global CUSTOM_MODULES_MAPPING
    CUSTOM_MODULES_MAPPING = custom_modules_hooks or {}
    flops_model = add_flops_counting_methods(model)
    flops_model.eval()
    flops_model.start_flops_count(ost=ost, verbose=verbose,
                                  ignore_list=ignore_modules or [])
    if input_constructor:
        inputs = input_constructor(input_res)
        _ = flops_model(**inputs)
    else:
        try:
            batch = torch.ones(()).new_empty((1, *input_res),
                                             dtype=next(flops_model.parameters()).dtype,
                                             device=next(flops_model.parameters()).device)
        except StopIteration:
            # the model has no parameters; fall back to the default dtype/device
            batch = torch.ones(()).new_empty((1, *input_res))

        _ = flops_model(batch)

    flops_count, params_count = flops_model.compute_average_flops_cost()
    if print_per_layer_stat:
        print_model_with_flops(flops_model, flops_count, params_count, ost=ost)
    flops_model.stop_flops_count()
    CUSTOM_MODULES_MAPPING = {}

    if as_strings:
        return flops_to_string(flops_count), params_to_string(params_count)

    return flops_count, params_count
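
# Example usage (a sketch; assumes torchvision is installed and the
# maskrcnn_benchmark import above resolves):
#
#   import torchvision.models as models
#   net = models.resnet18()
#   macs, params = get_model_complexity_info(net, (3, 224, 224),
#                                            as_strings=True,
#                                            print_per_layer_stat=False)
#   print(macs, params)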


def flops_to_string(flops, units='GMac', precision=2):
    if units is None:
        if flops // 10**9 > 0:
            return str(round(flops / 10.**9, precision)) + ' GMac'
        elif flops // 10**6 > 0:
            return str(round(flops / 10.**6, precision)) + ' MMac'
        elif flops // 10**3 > 0:
            return str(round(flops / 10.**3, precision)) + ' KMac'
        else:
            return str(flops) + ' Mac'
    else:
        if units == 'GMac':
            return str(round(flops / 10.**9, precision)) + ' ' + units
        elif units == 'MMac':
            return str(round(flops / 10.**6, precision)) + ' ' + units
        elif units == 'KMac':
            return str(round(flops / 10.**3, precision)) + ' ' + units
        else:
            return str(flops) + ' Mac'
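
# Quick sanity check of the formatter:
#   flops_to_string(2_300_000_000)          -> '2.3 GMac'
#   flops_to_string(2_300_000_000, 'MMac')  -> '2300.0 MMac'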


def params_to_string(params_num, units=None, precision=2):
    if units is None:
        if params_num // 10 ** 6 > 0:
            return str(round(params_num / 10 ** 6, precision)) + ' M'
        elif params_num // 10 ** 3 > 0:
            return str(round(params_num / 10 ** 3, precision)) + ' k'
        else:
            return str(params_num)
    else:
        if units == 'M':
            return str(round(params_num / 10.**6, precision)) + ' ' + units
        elif units == 'K':
            return str(round(params_num / 10.**3, precision)) + ' ' + units
        else:
            return str(params_num)
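
# Quick sanity check of the formatter:
#   params_to_string(11_689_512)  -> '11.69 M'
#   params_to_string(512)         -> '512'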


def accumulate_flops(self):
    if is_supported_instance(self):
        return self.__flops__
    else:
        total = 0
        for m in self.children():
            total += m.accumulate_flops()
        return total


def print_model_with_flops(model, total_flops, total_params, units='GMac',
                           precision=3, ost=sys.stdout):

    def accumulate_params(self):
        if is_supported_instance(self):
            return self.__params__
        else:
            total = 0
            for m in self.children():
                total += m.accumulate_params()
            return total

    def flops_repr(self):
        accumulated_params_num = self.accumulate_params()
        accumulated_flops_cost = self.accumulate_flops() / model.__batch_counter__
        return ', '.join([params_to_string(accumulated_params_num,
                                           units='M', precision=precision),
                          '{:.3%} Params'.format(accumulated_params_num / total_params),
                          flops_to_string(accumulated_flops_cost,
                                          units=units, precision=precision),
                          '{:.3%} MACs'.format(accumulated_flops_cost / total_flops),
                          self.original_extra_repr()])

    def add_extra_repr(m):
        m.accumulate_flops = accumulate_flops.__get__(m)
        m.accumulate_params = accumulate_params.__get__(m)
        flops_extra_repr = flops_repr.__get__(m)
        if m.extra_repr != flops_extra_repr:
            m.original_extra_repr = m.extra_repr
            m.extra_repr = flops_extra_repr
            assert m.extra_repr != m.original_extra_repr

    def del_extra_repr(m):
        if hasattr(m, 'original_extra_repr'):
            m.extra_repr = m.original_extra_repr
            del m.original_extra_repr
        if hasattr(m, 'accumulate_flops'):
            del m.accumulate_flops

    model.apply(add_extra_repr)
    print(repr(model), file=ost)
    model.apply(del_extra_repr)
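
# With the hooks installed, repr(model) prints one line per module of the form
# '<params>, <params %> Params, <MACs>, <MACs %> MACs, <original extra_repr>'.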


def get_model_parameters_number(model):
    params_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return params_num


def add_flops_counting_methods(net_main_module):
    # Bind the accounting methods onto the existing module object so that
    # each function has access to it as `self`.
    net_main_module.start_flops_count = start_flops_count.__get__(net_main_module)
    net_main_module.stop_flops_count = stop_flops_count.__get__(net_main_module)
    net_main_module.reset_flops_count = reset_flops_count.__get__(net_main_module)
    net_main_module.compute_average_flops_cost = compute_average_flops_cost.__get__(
        net_main_module)

    net_main_module.reset_flops_count()

    return net_main_module
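
# Typical lifecycle of the bound methods (a sketch; `net` is any nn.Module
# and `dummy` a correctly-shaped input batch):
#
#   net = add_flops_counting_methods(net)
#   net.eval()
#   net.start_flops_count(ost=sys.stdout, verbose=False, ignore_list=[])
#   _ = net(dummy)
#   flops, params = net.compute_average_flops_cost()
#   net.stop_flops_count()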


def compute_average_flops_cost(self):
    """
    A method that will be available after add_flops_counting_methods() is called
    on a desired net object.

    Returns current mean flops consumption per image.
    """
    for m in self.modules():
        m.accumulate_flops = accumulate_flops.__get__(m)

    flops_sum = self.accumulate_flops()

    for m in self.modules():
        if hasattr(m, 'accumulate_flops'):
            del m.accumulate_flops

    params_sum = get_model_parameters_number(self)
    return flops_sum / self.__batch_counter__, params_sum
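
# Note: __flops__ accumulates over the whole run while __batch_counter__
# counts samples, so e.g. two forward passes with batch size 8 leave
# __batch_counter__ == 16 and the returned value is still MACs per image.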


def start_flops_count(self, **kwargs):
    """
    A method that will be available after add_flops_counting_methods() is called
    on a desired net object.

    Activates the computation of mean flops consumption per image.
    Call it before you run the network.
    """
    add_batch_counter_hook_function(self)

    seen_types = set()

    def add_flops_counter_hook_function(module, ost, verbose, ignore_list):
        if type(module) in ignore_list:
            seen_types.add(type(module))
            if is_supported_instance(module):
                module.__params__ = 0
        elif is_supported_instance(module):
            if hasattr(module, '__flops_handle__'):
                return
            if type(module) in CUSTOM_MODULES_MAPPING:
                handle = module.register_forward_hook(
                    CUSTOM_MODULES_MAPPING[type(module)])
            elif getattr(module, 'compute_macs', False):
                handle = module.register_forward_hook(module.compute_macs)
            else:
                handle = module.register_forward_hook(MODULES_MAPPING[type(module)])
            module.__flops_handle__ = handle
            seen_types.add(type(module))
        else:
            if verbose and not type(module) in (nn.Sequential, nn.ModuleList) and \
                    not type(module) in seen_types:
                print('Warning: module ' + type(module).__name__ +
                      ' is treated as a zero-op.', file=ost)
            seen_types.add(type(module))

    self.apply(partial(add_flops_counter_hook_function, **kwargs))
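
# Custom layers can be counted by passing hooks through `custom_modules_hooks`
# of get_model_complexity_info (a sketch; `MyLayer` is hypothetical):
#
#   def my_layer_counter_hook(module, input, output):
#       # illustrative rule: one MAC per output element
#       module.__flops__ += int(output.numel())
#
#   get_model_complexity_info(net, (3, 224, 224),
#                             custom_modules_hooks={MyLayer: my_layer_counter_hook})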


def stop_flops_count(self):
    """
    A method that will be available after add_flops_counting_methods() is called
    on a desired net object.

    Stops computing the mean flops consumption per image.
    Call whenever you want to pause the computation.
    """
    remove_batch_counter_hook_function(self)
    self.apply(remove_flops_counter_hook_function)


def reset_flops_count(self):
    """
    A method that will be available after add_flops_counting_methods() is called
    on a desired net object.

    Resets statistics computed so far.
    """
    add_batch_counter_variables_or_reset(self)
    self.apply(add_flops_counter_variable_or_reset)


# ---- Internal functions
def empty_flops_counter_hook(module, input, output):
    module.__flops__ += 0


def upsample_flops_counter_hook(module, input, output):
    # one op per interpolated output element, over the whole batch
    module.__flops__ += int(output.numel())


def relu_flops_counter_hook(module, input, output):
    active_elements_count = output.numel()
    module.__flops__ += int(active_elements_count)


def linear_flops_counter_hook(module, input, output):
    input = input[0]
    # pytorch checks dimensions, so here we don't care much
    output_last_dim = output.shape[-1]
    bias_flops = output_last_dim if module.bias is not None else 0
    module.__flops__ += int(np.prod(input.shape) * output_last_dim + bias_flops)
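
# Worked example: nn.Linear(64, 10) with bias on a (1, 64) input counts
# 1 * 64 * 10 + 10 = 650 MACs.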


def pool_flops_counter_hook(module, input, output):
    input = input[0]
    module.__flops__ += int(np.prod(input.shape))


def bn_flops_counter_hook(module, input, output):
    input = input[0]

    batch_flops = np.prod(input.shape)
    if module.affine:
        batch_flops *= 2
    module.__flops__ += int(batch_flops)


def conv_flops_counter_hook(conv_module, input, output):
    # Can have multiple inputs, getting the first one
    input = input[0]

    batch_size = input.shape[0]
    output_dims = list(output.shape[2:])

    kernel_dims = list(conv_module.kernel_size)
    in_channels = conv_module.in_channels
    out_channels = conv_module.out_channels
    groups = conv_module.groups

    filters_per_channel = out_channels // groups
    conv_per_position_flops = int(np.prod(kernel_dims)) * \
        in_channels * filters_per_channel

    active_elements_count = batch_size * int(np.prod(output_dims))

    overall_conv_flops = conv_per_position_flops * active_elements_count

    bias_flops = 0
    if conv_module.bias is not None:
        bias_flops = out_channels * active_elements_count

    overall_flops = overall_conv_flops + bias_flops

    conv_module.__flops__ += int(overall_flops)
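
# Worked example: nn.Conv2d(3, 8, kernel_size=3, padding=1) on a
# (1, 3, 32, 32) input: per-position MACs = 3*3 * 3 * 8 = 216, output
# positions = 1 * 32 * 32 = 1024, so 216 * 1024 = 221,184 MACs plus a bias
# term of 8 * 1024 = 8,192, i.e. 229,376 in total.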


def batch_counter_hook(module, input, output):
    batch_size = 1
    if len(input) > 0:
        # Can have multiple inputs, getting the first one
        input = input[0]
        batch_size = len(input)
    else:
        print('Warning! No positional inputs found for a module,'
              ' assuming batch size is 1.')
    module.__batch_counter__ += batch_size


def rnn_flops(flops, rnn_module, w_ih, w_hh, input_size):
    # matrix-matrix mult of the input state and the internal state
    flops += w_ih.shape[0] * w_ih.shape[1]
    # matrix-matrix mult of the hidden state and the internal state
    flops += w_hh.shape[0] * w_hh.shape[1]
    if isinstance(rnn_module, (nn.RNN, nn.RNNCell)):
        # add both contributions
        flops += rnn_module.hidden_size
    elif isinstance(rnn_module, (nn.GRU, nn.GRUCell)):
        # hadamard product for r
        flops += rnn_module.hidden_size
        # adding operations from both states
        flops += rnn_module.hidden_size * 3
        # last two hadamard products and an add
        flops += rnn_module.hidden_size * 3
    elif isinstance(rnn_module, (nn.LSTM, nn.LSTMCell)):
        # adding operations from both states
        flops += rnn_module.hidden_size * 4
        # two hadamard products and an add for the C state
        flops += rnn_module.hidden_size * 3
        # final hadamard
        flops += rnn_module.hidden_size * 3
    return flops
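
# Worked example: an nn.LSTMCell with input_size=3, hidden_size=2 has
# w_ih of shape (8, 3) and w_hh of shape (8, 2), so one step counts
# 24 + 16 + (4 + 3 + 3) * 2 = 60 FLOPs per sample (biases excluded).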


def rnn_flops_counter_hook(rnn_module, input, output):
    """
    Assumes the batch dimension comes first, contrary to the pytorch default
    (though for the total count it does not actually matter).
    If sigmoid and tanh were replaced by hard variants, only one comparison
    FLOP per element would be accurate.
    """
    flops = 0
    # input is a tuple containing a sequence to process and (optionally) hidden state
    inp = input[0]
    batch_size = inp.shape[0]
    seq_length = inp.shape[1]
    num_layers = rnn_module.num_layers

    for i in range(num_layers):
        w_ih = getattr(rnn_module, 'weight_ih_l' + str(i))
        w_hh = getattr(rnn_module, 'weight_hh_l' + str(i))
        if i == 0:
            input_size = rnn_module.input_size
        else:
            input_size = rnn_module.hidden_size
        flops = rnn_flops(flops, rnn_module, w_ih, w_hh, input_size)
        if rnn_module.bias:
            b_ih = getattr(rnn_module, 'bias_ih_l' + str(i))
            b_hh = getattr(rnn_module, 'bias_hh_l' + str(i))
            flops += b_ih.shape[0] + b_hh.shape[0]

    flops *= batch_size
    flops *= seq_length
    if rnn_module.bidirectional:
        flops *= 2
    rnn_module.__flops__ += int(flops)


def rnn_cell_flops_counter_hook(rnn_cell_module, input, output):
    flops = 0
    inp = input[0]
    batch_size = inp.shape[0]
    w_ih = getattr(rnn_cell_module, 'weight_ih')
    w_hh = getattr(rnn_cell_module, 'weight_hh')
    input_size = inp.shape[1]
    flops = rnn_flops(flops, rnn_cell_module, w_ih, w_hh, input_size)
    if rnn_cell_module.bias:
        b_ih = getattr(rnn_cell_module, 'bias_ih')
        b_hh = getattr(rnn_cell_module, 'bias_hh')
        flops += b_ih.shape[0] + b_hh.shape[0]

    flops *= batch_size
    rnn_cell_module.__flops__ += int(flops)


def add_batch_counter_variables_or_reset(module):
    module.__batch_counter__ = 0


def add_batch_counter_hook_function(module):
    if hasattr(module, '__batch_counter_handle__'):
        return

    handle = module.register_forward_hook(batch_counter_hook)
    module.__batch_counter_handle__ = handle


def remove_batch_counter_hook_function(module):
    if hasattr(module, '__batch_counter_handle__'):
        module.__batch_counter_handle__.remove()
        del module.__batch_counter_handle__


def add_flops_counter_variable_or_reset(module):
    if is_supported_instance(module):
        if hasattr(module, '__flops__') or hasattr(module, '__params__'):
            print('Warning: variables __flops__ or __params__ are already '
                  'defined for the module ' + type(module).__name__ +
                  '. ptflops can affect your code!')
        module.__flops__ = 0
        module.__params__ = get_model_parameters_number(module)


CUSTOM_MODULES_MAPPING = {}

MODULES_MAPPING = {
    # convolutions
    nn.Conv1d: conv_flops_counter_hook,
    nn.Conv2d: conv_flops_counter_hook,
    nn.Conv3d: conv_flops_counter_hook,
    Conv2d: conv_flops_counter_hook,
    ModulatedDeformConv: conv_flops_counter_hook,
    # activations
    nn.ReLU: relu_flops_counter_hook,
    nn.PReLU: relu_flops_counter_hook,
    nn.ELU: relu_flops_counter_hook,
    nn.LeakyReLU: relu_flops_counter_hook,
    nn.ReLU6: relu_flops_counter_hook,
    # poolings
    nn.MaxPool1d: pool_flops_counter_hook,
    nn.AvgPool1d: pool_flops_counter_hook,
    nn.AvgPool2d: pool_flops_counter_hook,
    nn.MaxPool2d: pool_flops_counter_hook,
    nn.MaxPool3d: pool_flops_counter_hook,
    nn.AvgPool3d: pool_flops_counter_hook,
    nn.AdaptiveMaxPool1d: pool_flops_counter_hook,
    nn.AdaptiveAvgPool1d: pool_flops_counter_hook,
    nn.AdaptiveMaxPool2d: pool_flops_counter_hook,
    nn.AdaptiveAvgPool2d: pool_flops_counter_hook,
    nn.AdaptiveMaxPool3d: pool_flops_counter_hook,
    nn.AdaptiveAvgPool3d: pool_flops_counter_hook,
    # BNs
    nn.BatchNorm1d: bn_flops_counter_hook,
    nn.BatchNorm2d: bn_flops_counter_hook,
    nn.BatchNorm3d: bn_flops_counter_hook,
    nn.GroupNorm: bn_flops_counter_hook,
    # FC
    nn.Linear: linear_flops_counter_hook,
    # Upscale
    nn.Upsample: upsample_flops_counter_hook,
    # Deconvolution
    nn.ConvTranspose1d: conv_flops_counter_hook,
    nn.ConvTranspose2d: conv_flops_counter_hook,
    nn.ConvTranspose3d: conv_flops_counter_hook,
    ConvTranspose2d: conv_flops_counter_hook,
    # RNN
    nn.RNN: rnn_flops_counter_hook,
    nn.GRU: rnn_flops_counter_hook,
    nn.LSTM: rnn_flops_counter_hook,
    nn.RNNCell: rnn_cell_flops_counter_hook,
    nn.LSTMCell: rnn_cell_flops_counter_hook,
    nn.GRUCell: rnn_cell_flops_counter_hook,
}
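
# Modules that expose a `compute_macs` attribute are hooked directly in
# start_flops_count, so layers from maskrcnn_benchmark can report their own
# MACs without being registered in this table.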


def is_supported_instance(module):
    if type(module) in MODULES_MAPPING or type(module) in CUSTOM_MODULES_MAPPING \
            or getattr(module, 'compute_macs', False):
        return True
    return False


def remove_flops_counter_hook_function(module):
    if is_supported_instance(module):
        if hasattr(module, '__flops_handle__'):
            module.__flops_handle__.remove()
            del module.__flops_handle__
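

if __name__ == '__main__':
    # Minimal smoke test (a sketch; runs only if the maskrcnn_benchmark
    # import at the top of this file resolves). nn.Flatten is not in
    # MODULES_MAPPING and is counted as a zero-op.
    toy = nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),
        nn.Linear(8, 10),
    )
    macs, params = get_model_complexity_info(toy, (3, 32, 32),
                                             print_per_layer_stat=False)
    print(macs, params)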