mirror of https://github.com/alibaba/EasyCV.git
114 lines
3.8 KiB
Python
114 lines
3.8 KiB
Python
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
import argparse
|
|
import inspect
|
|
from collections import Counter
|
|
|
|
import numpy as np
|
|
import tqdm
|
|
from fvcore.nn import FlopCountAnalysis, flop_count_table
|
|
from mmcv.parallel import scatter_kwargs
|
|
from prettytable import PrettyTable
|
|
|
|
from easycv.datasets.builder import build_dataset
|
|
from easycv.datasets.loader import build_dataloader
|
|
from easycv.datasets.utils import is_dali_dataset_type
|
|
from easycv.models.builder import build_model
|
|
from easycv.utils.config_tools import mmcv_config_fromfile
|
|
from easycv.utils.mmlab_utils import dynamic_adapt_for_mmlab
|
|
|
|
|
|
def parse_args():
|
|
parser = argparse.ArgumentParser(description='count model flops')
|
|
parser.add_argument('config', help='config file path')
|
|
parser.add_argument(
|
|
'--repeat_num', default=10, type=int, help='repeat number')
|
|
args = parser.parse_args()
|
|
return args
|
|
|
|
|
|
def flatten_inputs(model, inputs):
|
|
full_args_spec = inspect.getfullargspec(model.forward)
|
|
args = [] if not full_args_spec.args else full_args_spec.args
|
|
args.pop(0) if (args and args[0] in ['self', 'cls']) else args
|
|
|
|
default_values = [] if not full_args_spec.defaults else full_args_spec.defaults
|
|
args_has_default = args[len(args) - len(default_values):]
|
|
args_with_default = dict(zip(
|
|
args_has_default, default_values)) if len(args_has_default) else {}
|
|
|
|
flat_inputs = []
|
|
for arg_i in args:
|
|
if inputs.get(arg_i, None) is not None:
|
|
flat_inputs.append(inputs.get(arg_i))
|
|
else:
|
|
flat_inputs.append(args_with_default.get(arg_i))
|
|
|
|
return tuple(flat_inputs)
|
|
|
|
|
|
def count_flop():
|
|
args = parse_args()
|
|
|
|
device = 'cuda'
|
|
cfg = mmcv_config_fromfile(args.config)
|
|
|
|
# dynamic adapt mmdet models
|
|
dynamic_adapt_for_mmlab(cfg)
|
|
|
|
model = build_model(cfg.model)
|
|
model.to(device)
|
|
model.eval()
|
|
|
|
if cfg.data.get('val', None) is not None:
|
|
cfg.data.val.pop('imgs_per_gpu', None) # pop useless params
|
|
data_cfg = cfg.data.val
|
|
else:
|
|
data_cfg = cfg.data.train
|
|
|
|
if is_dali_dataset_type(data_cfg['type']):
|
|
data_cfg.distributed = False
|
|
data_cfg.batch_size = 1
|
|
data_cfg.workers_per_gpu = 1
|
|
dataset = build_dataset(data_cfg)
|
|
data_loader = dataset.get_dataloader()
|
|
else:
|
|
dataset = build_dataset(data_cfg)
|
|
data_loader = build_dataloader(
|
|
dataset, imgs_per_gpu=1, workers_per_gpu=0)
|
|
|
|
handlers = {} # mapping from operator names to handles.
|
|
counts = Counter()
|
|
gflop_unit = 1e9
|
|
total_flops = []
|
|
for idx, data in zip(tqdm.trange(args.repeat_num), data_loader):
|
|
# use scatter_kwargs to unpack DataContainer data for raw torch.nn.module
|
|
_, kwargs = scatter_kwargs(None, data, [0])
|
|
kwargs[0].update({'mode': 'test'})
|
|
inputs = flatten_inputs(model, kwargs[0])
|
|
|
|
# Provides access to per-submodule model flop count obtained by
|
|
# tracing a model with pytorch's jit tracing functionality.
|
|
# So models that donot support jit tracing may fail.
|
|
flops = FlopCountAnalysis(model, inputs)
|
|
flops.set_op_handle(**handlers)
|
|
if idx > 0:
|
|
flops.unsupported_ops_warnings(False).uncalled_modules_warnings(
|
|
False)
|
|
counts += flops.by_operator()
|
|
total_flops.append(flops.total())
|
|
|
|
print('Flops from only one sample is:\n' + flop_count_table(flops))
|
|
ops_show = PrettyTable()
|
|
ops_show.field_names = ['operator type', 'Gflops']
|
|
for k, v in counts.items():
|
|
ops_show.add_row([k, round(v / (idx + 1) / gflop_unit, 3)])
|
|
print('Average Gflops of each operator type is:')
|
|
print(ops_show)
|
|
print('Total flops: {:.1f}G ± {:.1f}G'.format(
|
|
np.mean(total_flops) / gflop_unit,
|
|
np.std(total_flops) / gflop_unit))
|
|
|
|
|
|
if __name__ == '__main__':
|
|
count_flop()
|