mirror of https://github.com/alibaba/EasyCV.git
90 lines
3.2 KiB
Python
90 lines
3.2 KiB
Python
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
try:
|
|
from nni.algorithms.compression.pytorch.pruning import AGPPrunerV2
|
|
except ImportError:
|
|
raise ImportError(
|
|
'Please read docs and run "pip install https://pai-nni.oss-cn-zhangjiakou.aliyuncs.com/release/2.5/pai_nni-2.5-py3-none-manylinux1_x86_64.whl" '
|
|
'to install pai_nni')
|
|
|
|
|
|
def get_prune_layer(model):
    """Return the list of layer selectors to prune for ``model``.

    Args:
        model: Model identifier string. Only ``'YOLOX_EDGE'`` has a
            dedicated, hand-enumerated list of layer names.

    Returns:
        list[str]: For ``'YOLOX_EDGE'``, dotted layer names in this order:
        ``dark2``, ``dark3``, ``dark4``, ``dark5``, then ``C3_p4``,
        ``C3_p3``, ``C3_n3``, ``C3_n4``. For any other value, the
        single-entry list ``['Conv2d']``.
    """
    if model != 'YOLOX_EDGE':
        # default layer
        return ['Conv2d']

    backbone_name = 'backbone'

    # How many ``m.<idx>`` entries are enumerated for each stage / C3 block.
    dark_block_num = {
        2: 3,
        3: 9,
        4: 9,
        5: 3,
        'C3_p4': 3,
        'C3_p3': 3,
        'C3_n3': 3,
        'C3_n4': 3,
    }

    prune_layer_names = []

    # dark2..dark4: entry ``0`` conv, then every ``1.m.<idx>.conv1`` and the
    # closing ``1.conv3``.
    for dark_id in (2, 3, 4):
        stage = f'{backbone_name}.backbone.dark{dark_id}'
        prune_layer_names.append(f'{stage}.0.pconv.conv')
        prune_layer_names.extend(
            f'{stage}.1.m.{m_id}.conv1.conv'
            for m_id in range(dark_block_num[dark_id]))
        prune_layer_names.append(f'{stage}.1.conv3.conv')

    # dark5 has one extra sub-module (index ``1``), so its ``m`` entries live
    # under index ``2`` and both convs of every ``m`` entry are listed.
    stage = f'{backbone_name}.backbone.dark5'
    prune_layer_names.append(f'{stage}.0.pconv.conv')
    prune_layer_names.append(f'{stage}.1.conv2.conv')
    prune_layer_names.append(f'{stage}.2.conv1.conv')
    for m_id in range(dark_block_num[5]):
        prune_layer_names.append(f'{stage}.2.m.{m_id}.conv1.conv')
        prune_layer_names.append(f'{stage}.2.m.{m_id}.conv2.pconv.conv')
    prune_layer_names.append(f'{stage}.2.conv3.conv')

    # C3 blocks sit directly under ``backbone_name`` (no ``.backbone.darkN``).
    for block_id in ('C3_p4', 'C3_p3', 'C3_n3', 'C3_n4'):
        block = f'{backbone_name}.{block_id}'
        prune_layer_names.append(f'{block}.conv1.conv')
        for m_id in range(dark_block_num[block_id]):
            prune_layer_names.append(f'{block}.m.{m_id}.conv1.conv')
            prune_layer_names.append(f'{block}.m.{m_id}.conv2.pconv.conv')
        prune_layer_names.append(f'{block}.conv2.conv')
        prune_layer_names.append(f'{block}.conv3.conv')

    return prune_layer_names
|
|
|
|
|
|
def load_pruner(model,
                pruner_config,
                optimizer,
                pruning_class='AGP',
                pruning_algorithm='taylorfo'):
    """Construct the pruner selected by ``pruning_class``.

    Args:
        model: Forwarded to the pruner as ``model``.
        pruner_config: Forwarded to the pruner as ``config_list``.
        optimizer: Forwarded to the pruner as ``optimizer``.
        pruning_class: Name of the pruner family. Only ``'AGP'`` (the
            default) is supported; it maps to ``AGPPrunerV2``.
        pruning_algorithm: Forwarded to the pruner as ``pruning_algorithm``;
            defaults to ``'taylorfo'``.

    Returns:
        The constructed ``AGPPrunerV2`` instance.

    Raises:
        ValueError: If ``pruning_class`` is anything other than ``'AGP'``.
    """
    # Reject unknown pruner families up front. ValueError is still an
    # Exception subclass, so callers catching Exception keep working.
    if pruning_class != 'AGP':
        raise ValueError(
            'pruning class {} is not supported'.format(pruning_class))

    # default pruning class is AGPPrunerV2, default algorithm is taylorfo
    return AGPPrunerV2(
        model=model,
        config_list=pruner_config,
        optimizer=optimizer,
        pruning_algorithm=pruning_algorithm)
|