commit
1691aa5a7a
|
@ -15,7 +15,6 @@ class TheseusLayer(nn.Layer):
|
||||||
def __init__(self, *args, return_patterns=None, **kwargs):
|
def __init__(self, *args, return_patterns=None, **kwargs):
|
||||||
super(TheseusLayer, self).__init__()
|
super(TheseusLayer, self).__init__()
|
||||||
self.res_dict = None
|
self.res_dict = None
|
||||||
# self.register_forward_post_hook(self._disconnect_res_dict_hook)
|
|
||||||
if return_patterns is not None:
|
if return_patterns is not None:
|
||||||
self._update_res(return_patterns)
|
self._update_res(return_patterns)
|
||||||
|
|
||||||
|
@ -48,12 +47,6 @@ class TheseusLayer(nn.Layer):
|
||||||
self._sub_layers[layer_i].register_forward_post_hook(
|
self._sub_layers[layer_i].register_forward_post_hook(
|
||||||
self._save_sub_res_hook)
|
self._save_sub_res_hook)
|
||||||
|
|
||||||
# def _save_sub_res_hook(self, layer, input, output):
|
|
||||||
# self.res_dict[layer.full_name()] = output
|
|
||||||
#
|
|
||||||
# def _disconnect_res_dict_hook(self, input, output):
|
|
||||||
# self.res_dict = None
|
|
||||||
|
|
||||||
def replace_sub(self, layer_name_pattern, replace_function,
|
def replace_sub(self, layer_name_pattern, replace_function,
|
||||||
recursive=True):
|
recursive=True):
|
||||||
for k in self._sub_layers.keys():
|
for k in self._sub_layers.keys():
|
||||||
|
|
Loading…
Reference in New Issue