56 lines
1.6 KiB
Python
56 lines
1.6 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import argparse
|
|
|
|
from mmengine import Config
|
|
|
|
from mmrazor.models.algorithms import ItePruneAlgorithm
|
|
from mmrazor.models.task_modules import ResourceEstimator
|
|
from mmrazor.models.task_modules.demo_inputs import DefaultDemoInput
|
|
from mmrazor.registry import MODELS
|
|
|
|
|
|
def parse_args():
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument('config')
|
|
parser.add_argument('-H', default=224, type=int)
|
|
parser.add_argument('-W', default=224, type=int)
|
|
args = parser.parse_args()
|
|
return args
|
|
|
|
|
|
def input_generator_wrapper(model, shape, training, scope=None):
|
|
|
|
def input_generator(input_shape):
|
|
inputs = DefaultDemoInput(scope=scope).get_data(
|
|
model, input_shape=input_shape, training=training)
|
|
if isinstance(input, dict) and 'mode' in inputs:
|
|
inputs['mode'] = 'tensor'
|
|
return inputs
|
|
|
|
return input_generator
|
|
|
|
|
|
if __name__ == '__main__':
|
|
args = parse_args()
|
|
config = Config.fromfile(args.config)
|
|
H = args.H
|
|
W = args.W
|
|
|
|
default_scope = config['default_scope']
|
|
model_config = config['model']
|
|
# model_config['_scope_'] = default_scope
|
|
model: ItePruneAlgorithm = MODELS.build(model_config)
|
|
|
|
estimator = ResourceEstimator(
|
|
flops_params_cfg=dict(
|
|
input_shape=(1, 3, H, W),
|
|
print_per_layer_stat=False,
|
|
input_constructor=input_generator_wrapper(
|
|
model,
|
|
(1, 3, H, W),
|
|
training=False,
|
|
scope=default_scope,
|
|
)))
|
|
result = estimator.estimate(model)
|
|
print(result)
|