[Bug] AutoSlim: different checkpoints have the same size (#193)
* fix: split AutoSlim checkpoints all having equal model size
* chore: pre-commit
* chore: pre-commit

Co-authored-by: Lance(Yongle) Wang <lance.wang@vastaitech.com>
parent 3cc359e364
commit 1abad087eb
branch pull/325/head
@@ -126,10 +126,11 @@ class StructurePruner(BaseModule, metaclass=ABCMeta):
         for name, module in supernet.model.named_modules():
             if isinstance(module, nn.GroupNorm):
                 min_required_version = '1.6.0'
-                assert digit_version(torch.__version__) >= digit_version(
-                    min_required_version
-                ), f'Requires pytorch>={min_required_version} to auto-trace' \
-                    f'GroupNorm correctly.'
+                assert digit_version(
+                    torch.__version__
+                ) >= digit_version(min_required_version), (
+                    f'Requires pytorch>={min_required_version} to auto-trace '
+                    f'GroupNorm correctly.')
             if hasattr(module, 'weight'):
                 # trace shared modules
                 module.cnt = 0
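This first hunk is a pure yapf reflow of the PyTorch version guard; the assert's logic is unchanged. For context, `digit_version` (from mmcv) parses a version string into a tuple of ints so the comparison is numeric; a quick sketch of why the guard cannot simply compare strings:

```python
from mmcv import digit_version

# Numeric, not lexicographic: digit_version parses '1.10.0' into a tuple
# of ints, so newer releases compare as greater.
assert digit_version('1.10.0') > digit_version('1.6.0')
# The plain string comparison the guard must avoid: '1.10.0' sorts
# *before* '1.6.0' lexicographically.
assert '1.10.0' < '1.6.0'
```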
@@ -407,13 +408,14 @@ class StructurePruner(BaseModule, metaclass=ABCMeta):
         same_in_channel_groups, same_out_channel_groups = {}, {}
         for node_name, parents_name in node2parents.items():
             parser = self.find_make_group_parser(node_name, name2module)
-            idx, same_in_channel_groups, same_out_channel_groups = \
-                parser(self,
-                       node_name=node_name,
-                       parents_name=parents_name,
-                       group_idx=idx,
-                       same_in_channel_groups=same_in_channel_groups,
-                       same_out_channel_groups=same_out_channel_groups)
+            idx, same_in_channel_groups, same_out_channel_groups = parser(
+                self,
+                node_name=node_name,
+                parents_name=parents_name,
+                group_idx=idx,
+                same_in_channel_groups=same_in_channel_groups,
+                same_out_channel_groups=same_out_channel_groups,
+            )

         groups = dict()
         idx = 0
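The second hunk likewise only reflows the `parser(...)` call into one argument per line. The contract it relies on: each make-group parser threads `group_idx` and the two accumulating dicts through the loop and returns updated versions of all three. A hypothetical sketch of a parser with that signature (illustrative names and logic, not mmrazor's actual implementation):

```python
def demo_group_parser(pruner, node_name, parents_name, group_idx,
                      same_in_channel_groups, same_out_channel_groups):
    # Nodes that consume the same parents must prune their input channels
    # together, so record them under a shared group index.
    key = tuple(sorted(parents_name))
    if key not in same_in_channel_groups:
        same_in_channel_groups[key] = group_idx
        group_idx += 1
    same_out_channel_groups.setdefault(node_name,
                                       same_in_channel_groups[key])
    # Return all three so the caller can thread them into the next call.
    return group_idx, same_in_channel_groups, same_out_channel_groups
```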
@@ -455,23 +457,29 @@ class StructurePruner(BaseModule, metaclass=ABCMeta):
         if isinstance(module, nn.Conv2d):
             module.register_buffer(
                 'in_mask',
-                module.weight.new_ones((1, module.in_channels, 1, 1), ))
+                module.weight.new_ones((1, module.in_channels, 1, 1), ),
+            )
             module.register_buffer(
                 'out_mask',
-                module.weight.new_ones((1, module.out_channels, 1, 1), ))
+                module.weight.new_ones((1, module.out_channels, 1, 1), ),
+            )
             module.forward = self.modify_conv_forward(module)
         if isinstance(module, nn.Linear):
             module.register_buffer(
-                'in_mask', module.weight.new_ones((1, module.in_features), ))
-            module.register_buffer(
-                'out_mask', module.weight.new_ones((1, module.out_features), ))
-            module.forward = self.modify_fc_forward(module)
-        if (isinstance(module, _BatchNorm)
-                or isinstance(module, _InstanceNorm)
-                or isinstance(module, GroupNorm)):
+                'in_mask',
+                module.weight.new_ones((1, module.in_features), ),
+            )
             module.register_buffer(
                 'out_mask',
-                module.weight.new_ones((1, len(module.weight), 1, 1), ))
+                module.weight.new_ones((1, module.out_features), ),
+            )
+            module.forward = self.modify_fc_forward(module)
+        if isinstance(module, _BatchNorm) or isinstance(
+                module, _InstanceNorm) or isinstance(module, GroupNorm):
+            module.register_buffer(
+                'out_mask',
+                module.weight.new_ones((1, len(module.weight), 1, 1), ),
+            )

     def find_node_parents(self, paths):
         """Find the parent node of a node.
@@ -565,11 +573,12 @@ class StructurePruner(BaseModule, metaclass=ABCMeta):
         if getattr(module, 'groups', in_channels) > 1:
             module.groups = in_channels

-        module.weight = nn.Parameter(temp_weight.data)
+        module.weight = nn.Parameter(temp_weight.data.clone())
         module.weight.requires_grad = requires_grad

         if hasattr(module, 'bias') and module.bias is not None:
-            module.bias = nn.Parameter(module.bias.data[:out_channels])
+            module.bias = nn.Parameter(
+                module.bias.data[:out_channels].clone())
             module.bias.requires_grad = requires_grad

         if hasattr(module, 'running_mean'):
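This last hunk is the substance of the fix for #193. `temp_weight.data` and `module.bias.data[:out_channels]` are views into the supernet's full parameters, and `torch.save` serializes a tensor's entire underlying storage, not just the viewed slice, so every exported subnet checkpoint came out the same size as the supernet regardless of its channel configuration. `clone()` gives each sliced parameter its own, correctly sized storage. A self-contained repro of the storage-sharing effect:

```python
import io
import torch

def saved_bytes(t: torch.Tensor) -> int:
    buf = io.BytesIO()
    torch.save(t, buf)
    return buf.getbuffer().nbytes

full = torch.randn(512, 512)
view = full[:16]           # a view sharing full's 512x512 storage
owned = full[:16].clone()  # owns its own 16x512 storage

# The view serializes all ~1 MiB of shared storage; the clone only ~32 KiB.
print(saved_bytes(view), saved_bytes(owned))
```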