# Copyright (c) OpenMMLab. All rights reserved.
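# This script wraps a model config into a pruning config that prunes channels
# by L1 norm with ItePruneAlgorithm. A usage sketch (paths are illustrative):
#
#   python <path/to/this/script> path/to/model_config.py \
#       --checkpoint path/to/model.pth -o ./prune.py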
import argparse
import copy
from typing import Dict

from mmengine import Config, fileio

from mmrazor.models.mutators import ChannelMutator
from mmrazor.registry import MODELS


def parse_args():
    parser = argparse.ArgumentParser(
        description='Get the config to prune a model.')
    parser.add_argument('config', help='config of the model')
    parser.add_argument(
        '--checkpoint',
        default=None,
        type=str,
        help='checkpoint path of the model')
    parser.add_argument(
        '--subnet',
        default=None,
        type=str,
        help='pruning structure for the model')
    parser.add_argument(
        '-o',
        type=str,
        default='./prune.py',
        help='output path to store the pruning config.')
    args = parser.parse_args()
    return args


def wrap_prune_config(config: Config, prune_target: Dict,
                      checkpoint_path: str):
    config = copy.deepcopy(config)
    default_scope = config['default_scope']
    arch_config: Dict = config['model']

    # update checkpoint_path
    if checkpoint_path is not None:
        arch_config.update({
            'init_cfg': {
                'type': 'Pretrained',
                'checkpoint': checkpoint_path  # noqa
            },
        })

    # deal with data_preprocessor
    if 'data_preprocessor' in config:
        data_preprocessor = config['data_preprocessor']
        arch_config.update({'data_preprocessor': data_preprocessor})
        config['data_preprocessor'] = None
    else:
        data_preprocessor = None

    # prepare algorithm
    algorithm_config = dict(
        _scope_='mmrazor',
        type='ItePruneAlgorithm',
        architecture=arch_config,
        target_pruning_ratio=prune_target,
        mutator_cfg=dict(
            type='ChannelMutator',
            channel_unit_cfg=dict(
                type='L1MutableChannelUnit',
                default_args=dict(choice_mode='ratio')),
            parse_cfg=dict(
                type='ChannelAnalyzer',
                tracer_type='FxTracer',
                demo_input=dict(type='DefaultDemoInput',
                                scope=default_scope))))
    config['model'] = algorithm_config

    return config

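
# For reference, the dumped config replaces `model` with an algorithm wrapper
# shaped roughly like this (a sketch; values depend on the input config):
#
#   model = dict(
#       _scope_='mmrazor',
#       type='ItePruneAlgorithm',
#       architecture=<original model cfg, plus init_cfg/data_preprocessor>,
#       target_pruning_ratio=<unit name -> ratio mapping>,
#       mutator_cfg=dict(type='ChannelMutator', ...))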


def change_config(config):
    # Mark the model with its original default_scope (e.g. 'mmcls') so that
    # MODELS.build resolves it in the right registry.
    scope = config['default_scope']
    config['model']['_scope_'] = scope
    return config


if __name__ == '__main__':
    args = parse_args()
    config_path = args.config
    checkpoint_path = args.checkpoint
    target_path = args.o

    origin_config = Config.fromfile(config_path)
    origin_config = change_config(origin_config)
    default_scope = origin_config['default_scope']

    # get subnet config: trace the model to discover prunable channel units
    model = MODELS.build(copy.deepcopy(origin_config['model']))
    mutator: ChannelMutator = ChannelMutator(
        channel_unit_cfg=dict(
            type='L1MutableChannelUnit',
            default_args=dict(choice_mode='ratio'),
        ),
        parse_cfg={
            'type': 'ChannelAnalyzer',
            'demo_input': {
                'type': 'DefaultDemoInput',
                'scope': default_scope
            },
            'tracer_type': 'FxTracer'
        })
    mutator.prepare_from_supernet(model)
    if args.subnet is None:
        choice_template = mutator.choice_template
    else:
        input_choices = fileio.load(args.subnet)
        try:
            mutator.set_choices(input_choices)
            choice_template = input_choices
        except Exception as e:
            print(f'error when applying input subnet: {e}')
            choice_template = mutator.choice_template
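
    # For reference, `choice_template` (and a subnet file loaded above) maps
    # channel-unit names to pruning ratios, e.g. (unit names are hypothetical):
    #   {'backbone.conv1_(0, 64)_64': 1.0,
    #    'backbone.layer1.0.conv1_(0, 64)_64': 0.5}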

    # prune and finetune
    prune_config: Config = wrap_prune_config(origin_config, choice_template,
                                             checkpoint_path)
    prune_config.dump(target_path)
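
# The dumped config trains like any other config; assuming the repo's
# standard tools/train.py entry point, something like:
#   python tools/train.py ./prune.py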