mmrazor/tools/pruning/get_flops.py
LKJacky 7acc046678
Add GroupFisher pruning algorithm. (#459)
* init

* support expand dwconv

* add tools

* init

* add import

* add configs

* add ut and fix bug

* update

* update finetune config

* update impl imports

* add deploy configs and result

* add _train_step

* detla_type -> normalization_type

* change img link

* add prune to config

* add json dump when GroupFisherSubModel init

* update prune config

* update finetune config

* update deploy config

* update prune config

* update readme

* mutable_cfg -> fix_subnet

* update readme

* impl -> implementations

* update script.sh

* rm gen_fake_cfg

* add Implementation to readme

* update docstring

* add finetune_lr to config

* update readme

* fix error in config

* update links

* update configs

* refine

* fix spell error

* add test to readme

* update README

* update readme

* update readme

* update cite format

* fix for ci

* update to pass ci

* update readme

---------

Co-authored-by: liukai <your_email@abc.example>
Co-authored-by: Your Name <you@example.com>
2023-02-20 14:29:42 +08:00

56 lines
1.6 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
import argparse
from mmengine import Config
from mmrazor.models.algorithms import ItePruneAlgorithm
from mmrazor.models.task_modules import ResourceEstimator
from mmrazor.models.task_modules.demo_inputs import DefaultDemoInput
from mmrazor.registry import MODELS
def parse_args():
    """Read the command-line options of this FLOPs/params script.

    Returns:
        argparse.Namespace: with ``config`` (positional path to the config
        file) and ``H`` / ``W`` (integer input height and width, each
        defaulting to 224).
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('config')
    # Height and width share the same default and type.
    for flag in ('-H', '-W'):
        cli.add_argument(flag, default=224, type=int)
    return cli.parse_args()
def input_generator_wrapper(model, shape, training, scope=None):
    """Build an ``input_constructor`` callback for the FLOPs/params counter.

    Args:
        model: Model the demo inputs are generated for.
        shape: Unused; the shape passed to the returned callback is used
            instead. Kept so existing callers keep working.
        training (bool): Forwarded to ``DefaultDemoInput.get_data``.
        scope (str, optional): Registry scope forwarded to
            ``DefaultDemoInput``.

    Returns:
        Callable: ``input_generator(input_shape)``, which builds demo inputs
        for ``model``. If those inputs are a dict that contains a ``'mode'``
        key, the key is overwritten with ``'tensor'``.
    """

    def input_generator(input_shape):
        inputs = DefaultDemoInput(scope=scope).get_data(
            model, input_shape=input_shape, training=training)
        # Check the generated ``inputs``; the previous code checked the
        # builtin ``input`` function, so this override never ran.
        # 'tensor' mode presumably makes forward return raw tensors, which
        # is what the FLOPs counter needs (confirm against the model's
        # forward signature).
        if isinstance(inputs, dict) and 'mode' in inputs:
            inputs['mode'] = 'tensor'
        return inputs

    return input_generator
if __name__ == '__main__':
    # Build the model described by the given config and print its
    # estimated FLOPs / parameter count.
    args = parse_args()
    config = Config.fromfile(args.config)
    H, W = args.H, args.W
    # Input shape used both by the counter and by the demo-input builder;
    # presumably one 3-channel H x W image in NCHW layout.
    input_shape = (1, 3, H, W)

    default_scope = config['default_scope']
    model_config = config['model']
    model: ItePruneAlgorithm = MODELS.build(model_config)

    demo_input_builder = input_generator_wrapper(
        model, input_shape, training=False, scope=default_scope)
    flops_cfg = dict(
        input_shape=input_shape,
        print_per_layer_stat=False,
        input_constructor=demo_input_builder,
    )
    estimator = ResourceEstimator(flops_params_cfg=flops_cfg)
    print(estimator.estimate(model))