[Bug] AutoSlim: different checkpoints have the same size (#193)

* fix: AutoSlim checkpoints split from the supernet all had the same model size

* chore: pre-commit

* chore: pre-commit

Co-authored-by: Lance(Yongle) Wang <lance.wang@vastaitech.com>
lance 2022-07-05 13:32:03 +08:00 committed by GitHub
parent 3cc359e364
commit 1abad087eb
1 changed file with 32 additions and 23 deletions


@@ -126,10 +126,11 @@ class StructurePruner(BaseModule, metaclass=ABCMeta):
         for name, module in supernet.model.named_modules():
             if isinstance(module, nn.GroupNorm):
                 min_required_version = '1.6.0'
-                assert digit_version(torch.__version__) >= digit_version(
-                    min_required_version
-                ), f'Requires pytorch>={min_required_version} to auto-trace' \
-                    f'GroupNorm correctly.'
+                assert digit_version(
+                    torch.__version__
+                ) >= digit_version(min_required_version), (
+                    f'Requires pytorch>={min_required_version} to auto-trace '
+                    f'GroupNorm correctly.')
             if hasattr(module, 'weight'):
                 # trace shared modules
                 module.cnt = 0
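For context, `digit_version` (here assumed to be mmcv's helper of that name) parses a version string into a comparable tuple of ints, which is what makes the `>=` check above correct; plain string comparison would get multi-digit components wrong. A minimal sketch:

from mmcv import digit_version

# '1.10.2' < '1.6.0' as strings (lexicographic), but digit_version parses
# each component into an int, so tuple comparison is version-correct.
assert '1.10.2' < '1.6.0'
assert digit_version('1.10.2') >= digit_version('1.6.0')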
@@ -407,13 +408,14 @@ class StructurePruner(BaseModule, metaclass=ABCMeta):
         same_in_channel_groups, same_out_channel_groups = {}, {}
         for node_name, parents_name in node2parents.items():
             parser = self.find_make_group_parser(node_name, name2module)
-            idx, same_in_channel_groups, same_out_channel_groups = \
-                parser(self,
-                       node_name=node_name,
-                       parents_name=parents_name,
-                       group_idx=idx,
-                       same_in_channel_groups=same_in_channel_groups,
-                       same_out_channel_groups=same_out_channel_groups)
+            idx, same_in_channel_groups, same_out_channel_groups = parser(
+                self,
+                node_name=node_name,
+                parents_name=parents_name,
+                group_idx=idx,
+                same_in_channel_groups=same_in_channel_groups,
+                same_out_channel_groups=same_out_channel_groups,
+            )

         groups = dict()
         idx = 0
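The reshaped call keeps the parser's calling convention visible: each make-group parser takes the pruner, the node name, its parents, a running group index, and the two group dicts, and returns the updated triple. A toy parser with that shape (purely illustrative, not one of mmrazor's real parsers):

def toy_make_group_parser(pruner, node_name, parents_name, group_idx,
                          same_in_channel_groups, same_out_channel_groups):
    # Put a node and its parents into one group so that the node's input
    # channels and the parents' output channels are pruned consistently.
    same_in_channel_groups[group_idx] = [node_name]
    same_out_channel_groups[group_idx] = list(parents_name)
    return group_idx + 1, same_in_channel_groups, same_out_channel_groups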
@@ -455,23 +457,29 @@ class StructurePruner(BaseModule, metaclass=ABCMeta):
             if isinstance(module, nn.Conv2d):
                 module.register_buffer(
                     'in_mask',
-                    module.weight.new_ones((1, module.in_channels, 1, 1), ))
+                    module.weight.new_ones((1, module.in_channels, 1, 1), ),
+                )
                 module.register_buffer(
                     'out_mask',
-                    module.weight.new_ones((1, module.out_channels, 1, 1), ))
+                    module.weight.new_ones((1, module.out_channels, 1, 1), ),
+                )
                 module.forward = self.modify_conv_forward(module)
             if isinstance(module, nn.Linear):
                 module.register_buffer(
-                    'in_mask', module.weight.new_ones((1, module.in_features), ))
-                module.register_buffer(
-                    'out_mask', module.weight.new_ones((1, module.out_features), ))
-                module.forward = self.modify_fc_forward(module)
-            if (isinstance(module, _BatchNorm)
-                    or isinstance(module, _InstanceNorm)
-                    or isinstance(module, GroupNorm)):
+                    'in_mask',
+                    module.weight.new_ones((1, module.in_features), ),
+                )
                 module.register_buffer(
                     'out_mask',
-                    module.weight.new_ones((1, len(module.weight), 1, 1), ))
+                    module.weight.new_ones((1, module.out_features), ),
+                )
+                module.forward = self.modify_fc_forward(module)
+            if isinstance(module, _BatchNorm) or isinstance(
+                    module, _InstanceNorm) or isinstance(module, GroupNorm):
+                module.register_buffer(
+                    'out_mask',
+                    module.weight.new_ones((1, len(module.weight), 1, 1), ),
+                )

     def find_node_parents(self, paths):
         """Find the parent node of a node.
@@ -565,11 +573,12 @@ class StructurePruner(BaseModule, metaclass=ABCMeta):
             if getattr(module, 'groups', in_channels) > 1:
                 module.groups = in_channels
-            module.weight = nn.Parameter(temp_weight.data)
+            module.weight = nn.Parameter(temp_weight.data.clone())
             module.weight.requires_grad = requires_grad
             if hasattr(module, 'bias') and module.bias is not None:
-                module.bias = nn.Parameter(module.bias.data[:out_channels])
+                module.bias = nn.Parameter(
+                    module.bias.data[:out_channels].clone())
                 module.bias.requires_grad = requires_grad
             if hasattr(module, 'running_mean'):
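This last hunk is the actual bug fix. `temp_weight.data` and `module.bias.data[:out_channels]` are views into the supernet's full-size tensors, and `torch.save` serializes a view's entire underlying storage, so every exported subnet checkpoint came out the same size as the supernet regardless of its channel configuration. Adding `.clone()` copies only the kept slice into fresh, compact storage. A standalone sketch (not mmrazor code) that shows the effect:

import io
import torch

def saved_bytes(tensor):
    buf = io.BytesIO()
    torch.save(tensor, buf)
    return buf.getbuffer().nbytes

full = torch.randn(512, 512)
view = full[:64]                  # a view: still backed by the 512x512 storage
print(saved_bytes(view))          # roughly the same as saved_bytes(full)
print(saved_bytes(view.clone()))  # ~8x smaller: owns only the 64x512 copy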