diff --git a/mmrazor/models/pruners/structure_pruning.py b/mmrazor/models/pruners/structure_pruning.py index b755c2e5..9bb69097 100644 --- a/mmrazor/models/pruners/structure_pruning.py +++ b/mmrazor/models/pruners/structure_pruning.py @@ -126,10 +126,11 @@ class StructurePruner(BaseModule, metaclass=ABCMeta): for name, module in supernet.model.named_modules(): if isinstance(module, nn.GroupNorm): min_required_version = '1.6.0' - assert digit_version(torch.__version__) >= digit_version( - min_required_version - ), f'Requires pytorch>={min_required_version} to auto-trace' \ - f'GroupNorm correctly.' + assert digit_version( + torch.__version__ + ) >= digit_version(min_required_version), ( + f'Requires pytorch>={min_required_version} to auto-trace' + f'GroupNorm correctly.') if hasattr(module, 'weight'): # trace shared modules module.cnt = 0 @@ -407,13 +408,14 @@ class StructurePruner(BaseModule, metaclass=ABCMeta): same_in_channel_groups, same_out_channel_groups = {}, {} for node_name, parents_name in node2parents.items(): parser = self.find_make_group_parser(node_name, name2module) - idx, same_in_channel_groups, same_out_channel_groups = \ - parser(self, - node_name=node_name, - parents_name=parents_name, - group_idx=idx, - same_in_channel_groups=same_in_channel_groups, - same_out_channel_groups=same_out_channel_groups) + idx, same_in_channel_groups, same_out_channel_groups = parser( + self, + node_name=node_name, + parents_name=parents_name, + group_idx=idx, + same_in_channel_groups=same_in_channel_groups, + same_out_channel_groups=same_out_channel_groups, + ) groups = dict() idx = 0 @@ -455,23 +457,29 @@ class StructurePruner(BaseModule, metaclass=ABCMeta): if isinstance(module, nn.Conv2d): module.register_buffer( 'in_mask', - module.weight.new_ones((1, module.in_channels, 1, 1), )) + module.weight.new_ones((1, module.in_channels, 1, 1), ), + ) module.register_buffer( 'out_mask', - module.weight.new_ones((1, module.out_channels, 1, 1), )) + module.weight.new_ones((1, module.out_channels, 1, 1), ), + ) module.forward = self.modify_conv_forward(module) if isinstance(module, nn.Linear): module.register_buffer( - 'in_mask', module.weight.new_ones((1, module.in_features), )) - module.register_buffer( - 'out_mask', module.weight.new_ones((1, module.out_features), )) - module.forward = self.modify_fc_forward(module) - if (isinstance(module, _BatchNorm) - or isinstance(module, _InstanceNorm) - or isinstance(module, GroupNorm)): + 'in_mask', + module.weight.new_ones((1, module.in_features), ), + ) module.register_buffer( 'out_mask', - module.weight.new_ones((1, len(module.weight), 1, 1), )) + module.weight.new_ones((1, module.out_features), ), + ) + module.forward = self.modify_fc_forward(module) + if isinstance(module, _BatchNorm) or isinstance( + module, _InstanceNorm) or isinstance(module, GroupNorm): + module.register_buffer( + 'out_mask', + module.weight.new_ones((1, len(module.weight), 1, 1), ), + ) def find_node_parents(self, paths): """Find the parent node of a node. @@ -565,11 +573,12 @@ class StructurePruner(BaseModule, metaclass=ABCMeta): if getattr(module, 'groups', in_channels) > 1: module.groups = in_channels - module.weight = nn.Parameter(temp_weight.data) + module.weight = nn.Parameter(temp_weight.data.clone()) module.weight.requires_grad = requires_grad if hasattr(module, 'bias') and module.bias is not None: - module.bias = nn.Parameter(module.bias.data[:out_channels]) + module.bias = nn.Parameter( + module.bias.data[:out_channels].clone()) module.bias.requires_grad = requires_grad if hasattr(module, 'running_mean'):