mmrazor/tools/pruning/get_channel_units.py
LKJacky f886821ba1
Add get_prune_config and a demo config_pruning (#389)
* update tools and test

* add demo

* disable test doc

* add switch for test tools and test_doc

* fix bug

* update doc

* update tools name

* mv get_channel_units

Co-authored-by: liukai <your_email@abc.example>
2022-12-13 10:56:29 +08:00

85 lines
2.4 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import json
import sys
import torch.nn as nn
from mmengine import MODELS
from mmengine.config import Config
from mmrazor.models import BaseAlgorithm
from mmrazor.models.mutators import ChannelMutator
# Raise the interpreter's recursion ceiling to 2**20 (1048576) frames.
# NOTE(review): presumably required because building/tracing large models
# recurses deeply — confirm before lowering.
sys.setrecursionlimit(2**20)
def parse_args():
    """Parse the command line of the channel-unit export tool.

    Returns:
        argparse.Namespace: ``config`` (path of the model config),
        the boolean switches ``with_channel``, ``with_init_args`` and
        ``choice``, and ``output_path`` (empty string means "print to
        stdout").
    """
    cli = argparse.ArgumentParser(
        description='Get channel unit of a model.')
    cli.add_argument('config', help='config of the model')

    # Boolean switches, registered in the same order as they appear in
    # ``--help``.
    switch_specs = (
        (('-c', '--with-channel'), 'output with channel config'),
        (('-i', '--with-init-args'), 'output with init args'),
        (('--choice', ),
         ('output choices template. When this flag is activated, '
          '-c and -i will be ignored')),
    )
    for option_strings, help_text in switch_specs:
        cli.add_argument(
            *option_strings, action='store_true', help=help_text)

    cli.add_argument(
        '-o',
        '--output-path',
        default='',
        help='the file path to store channel unit info')
    return cli.parse_args()
def _get_mutator(model, config):
    """Return the channel mutator used to produce templates for ``model``.

    Args:
        model (nn.Module): Model built from ``config['model']``.
        config (Config): The loaded config. ``config['default_scope']`` is
            only read when a standalone mutator has to be created.

    Returns:
        ChannelMutator: ``model.mutator`` when ``model`` is a
        ``BaseAlgorithm``; otherwise a new ``ChannelMutator`` that has been
        prepared from ``model``.

    Raises:
        TypeError: If ``model`` is not an ``nn.Module``.
    """
    if isinstance(model, BaseAlgorithm):
        # The algorithm already owns a mutator; reuse it as-is.
        # NOTE(review): assumes the algorithm prepared it on construction.
        return model.mutator
    if not isinstance(model, nn.Module):
        # Previously this fell through and crashed later with an
        # UnboundLocalError on ``mutator``.
        raise TypeError(
            'Expected the built model to be a BaseAlgorithm or nn.Module, '
            f'but got {type(model).__name__}.')
    # A plain model needs a standalone mutator. ``default_scope`` is passed
    # as the demo input's scope, so it is only required on this path.
    mutator: ChannelMutator = ChannelMutator(
        channel_unit_cfg=dict(
            type='L1MutableChannelUnit',
            default_args=dict(choice_mode='ratio'),
        ),
        parse_cfg={
            'type': 'ChannelAnalyzer',
            'demo_input': {
                'type': 'DefaultDemoInput',
                'scope': config['default_scope']
            },
            'tracer_type': 'FxTracer'
        })
    mutator.prepare_from_supernet(model)
    return mutator


def main():
    """Dump a channel-unit config template for the model in a config file.

    Builds the model described by the config given on the command line and
    obtains a channel mutator for it. It then serializes either the
    mutator's choice template (``--choice``) or its config template
    (optionally with channels / unit init args) as JSON. The JSON is
    printed to stdout, or written to ``--output-path`` when one is given.
    """
    args = parse_args()
    cfg = Config.fromfile(args.config)
    model = MODELS.build(cfg['model'])
    mutator = _get_mutator(model, cfg)

    if args.choice:
        # --choice takes precedence; -c and -i are ignored (see --help).
        template = mutator.choice_template
    else:
        template = mutator.config_template(
            with_channels=args.with_channel,
            with_unit_init_args=args.with_init_args)
    json_config = json.dumps(template, indent=4, separators=(',', ':'))

    if not args.output_path:
        print('=' * 100)
        print('config template')
        print('=' * 100)
        print(json_config)
    else:
        with open(args.output_path, 'w', encoding='utf-8') as file:
            file.write(json_config)


if __name__ == '__main__':
    main()