mmrazor/tools/pruning/get_flops.py

56 lines
1.6 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
import argparse
from mmengine import Config
from mmrazor.models.algorithms import ItePruneAlgorithm
from mmrazor.models.task_modules import ResourceEstimator
from mmrazor.models.task_modules.demo_inputs import DefaultDemoInput
from mmrazor.registry import MODELS
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('config')
parser.add_argument('-H', default=224, type=int)
parser.add_argument('-W', default=224, type=int)
args = parser.parse_args()
return args
def input_generator_wrapper(model, shape, training, scope=None):
def input_generator(input_shape):
inputs = DefaultDemoInput(scope=scope).get_data(
model, input_shape=input_shape, training=training)
if isinstance(input, dict) and 'mode' in inputs:
inputs['mode'] = 'tensor'
return inputs
return input_generator
if __name__ == '__main__':
args = parse_args()
config = Config.fromfile(args.config)
H = args.H
W = args.W
default_scope = config['default_scope']
model_config = config['model']
# model_config['_scope_'] = default_scope
model: ItePruneAlgorithm = MODELS.build(model_config)
estimator = ResourceEstimator(
flops_params_cfg=dict(
input_shape=(1, 3, H, W),
print_per_layer_stat=False,
input_constructor=input_generator_wrapper(
model,
(1, 3, H, W),
training=False,
scope=default_scope,
)))
result = estimator.estimate(model)
print(result)