[Fix] Updating edge-embeddings after each GNN layer (#1134)

pull/1141/head
Amit Agarwal 2022-07-06 17:07:28 +05:30 committed by GitHub
parent 7800e13fc2
commit c4a2fa5eee
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 2 deletions

View File

@ -79,9 +79,9 @@ class SDMGRHead(BaseModule):
embed_edges = F.normalize(embed_edges)
for gnn_layer in self.gnn_layers:
nodes, cat_nodes = gnn_layer(nodes, embed_edges, node_nums)
nodes, embed_edges = gnn_layer(nodes, embed_edges, node_nums)
node_cls, edge_cls = self.node_cls(nodes), self.edge_cls(cat_nodes)
node_cls, edge_cls = self.node_cls(nodes), self.edge_cls(embed_edges)
return node_cls, edge_cls