mirror of https://github.com/AndyJZhao/HGSL.git
Initial commit
parent
0299e34db9
commit
a56d6a3162
17
README.md
17
README.md
|
@ -1,3 +1,14 @@
|
|||
# AAAI21-HGSL
|
||||
Source code of AAAI21-Heterogeneous Graph Structure Learning for Graph Neural Networks
|
||||
To be released soon...
|
||||
# HGSL
|
||||
Source code of AAAI submission "Heterogeneous Graph Structure Learning for Graph Neural Networks"
|
||||
# Requirements
|
||||
## Python Packages
|
||||
- Python >= 3.6.8
|
||||
- Pytorch >= 1.3.0
|
||||
- DGL == 0.4.3
|
||||
## GPU Memory Requirements
|
||||
- ACM >= 8G
|
||||
- DBLP >=5G
|
||||
- Yelp >=3G
|
||||
# Usage
|
||||
Take DBLP dataset as an example:
|
||||
python train.py --dataset='dblp'
|
||||
|
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -0,0 +1,200 @@
|
|||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from util_funcs import cos_sim
|
||||
|
||||
|
||||
class HGSL(nn.Module):
    """
    Heterogeneous Graph Structure Learning model.

    Learns a new adjacency matrix for a heterogeneous graph by fusing, per
    undirected relation, (1) feature-similarity graphs, (2) semantic graphs
    built from meta-path embeddings, and (3) the original graph, then runs a
    GCN on the learned structure.
    """

    def __init__(self, cf, g):
        """
        Args:
            cf: model config; cf.get_model_conf() keys are copied onto self.
            g: loaded HIN with t_info, r_info, types, undirected_relations,
               features and mp_emb_dict.
        """
        super(HGSL, self).__init__()
        self.__dict__.update(cf.get_model_conf())
        # ! Init variables
        self.dev = cf.dev
        # type index info, relation index info, node types, undirected relations
        self.ti, self.ri, self.types, self.ud_rels = g.t_info, g.r_info, g.types, g.undirected_relations
        feat_dim, mp_emb_dim = g.features.shape[1], list(g.mp_emb_dict.values())[0].shape[1]
        self.non_linear = nn.ReLU()
        # ! Graph Structure Learning
        MD = nn.ModuleDict
        # One generator/aggregator per relation, keyed by relation name.
        self.fgg_direct, self.fgg_left, self.fgg_right, self.fg_agg, self.sgg_gen, self.sg_agg, self.overall_g_agg = \
            MD({}), MD({}), MD({}), MD({}), MD({}), MD({}), MD({})
        # Feature encoder: maps raw features of every node type to a common space.
        self.encoder = MD(dict(zip(g.types, [nn.Linear(g.features.shape[1], cf.com_feat_dim) for _ in g.types])))

        for r in g.undirected_relations:
            # ! Feature Graph Generator
            self.fgg_direct[r] = GraphGenerator(cf.com_feat_dim, cf.num_head, cf.fgd_th, self.dev)
            self.fgg_left[r] = GraphGenerator(feat_dim, cf.num_head, cf.fgh_th, self.dev)
            self.fgg_right[r] = GraphGenerator(feat_dim, cf.num_head, cf.fgh_th, self.dev)
            self.fg_agg[r] = GraphChannelAttLayer(3)  # 3 = 1 (first-order/direct) + 2 (second-order)

            # ! Semantic Graph Generator: one generator per meta-path embedding.
            self.sgg_gen[r] = MD(dict(
                zip(cf.mp_list, [GraphGenerator(mp_emb_dim, cf.num_head, cf.sem_th, self.dev) for _ in cf.mp_list])))
            self.sg_agg[r] = GraphChannelAttLayer(len(cf.mp_list))

            # ! Overall Graph Generator
            # NOTE(review): the [1, 1, 10] weights are passed but
            # GraphChannelAttLayer currently ignores them (its weight-copy code
            # is commented out) — confirm whether biasing toward ori_g is intended.
            self.overall_g_agg[r] = GraphChannelAttLayer(3, [1, 1, 10])  # 3 = feat-graph + sem-graph + ori_graph
            # self.overall_g_agg[r] = GraphChannelAttLayer(3)  # 3 = feat-graph + sem-graph + ori_graph

        # ! Graph Convolution
        if cf.conv_method == 'gcn':
            self.GCN = GCN(g.n_feat, cf.emb_dim, g.n_class, cf.dropout)
        self.norm_order = cf.adj_norm_order

    def forward(self, features, adj_ori, mp_emb):
        """
        Args:
            features: (n_nodes, feat_dim) stacked raw features of all node types.
            adj_ori: (n_nodes, n_nodes) original adjacency matrix.
            mp_emb: dict meta-path name -> (n_nodes, emb_dim) embedding tensor.
        Returns:
            logits: GCN log-probabilities on the learned graph.
            new_adj: the learned (symmetric, normalized) adjacency matrix.
        """
        def get_rel_mat(mat, r):
            # Slice the sub-matrix of `mat` corresponding to relation r.
            return mat[self.ri[r][0]:self.ri[r][1], self.ri[r][2]:self.ri[r][3]]

        def get_type_rows(mat, type):
            # Rows of `mat` belonging to the given node type.
            return mat[self.ti[type]['ind'], :]

        def gen_g_via_feat(graph_gen_func, mat, r):
            # Build a graph between the two end types of relation r
            # (r[0] = source type char, r[-1] = target type char).
            return graph_gen_func(get_type_rows(mat, r[0]), get_type_rows(mat, r[-1]))

        # ! Heterogeneous Feature Mapping: project every type into the common space.
        com_feat_mat = torch.cat([self.non_linear(
            self.encoder[t](features[self.ti[t]['ind']])) for t in self.types])

        # ! Heterogeneous Graph Generation
        new_adj = torch.zeros_like(adj_ori).to(self.dev)
        for r in self.ud_rels:
            ori_g = get_rel_mat(adj_ori, r)
            # ! Feature Graph Generation: direct similarity in the common space.
            fg_direct = gen_g_via_feat(self.fgg_direct[r], com_feat_mat, r)

            # Second-order (head/tail) feature graphs: propagate same-type
            # similarity through the original relation matrix.
            fmat_l, fmat_r = features[self.ti[r[0]]['ind']], features[self.ti[r[-1]]['ind']]
            sim_l, sim_r = self.fgg_left[r](fmat_l, fmat_l), self.fgg_right[r](fmat_r, fmat_r)
            fg_left, fg_right = sim_l.mm(ori_g), sim_r.mm(ori_g.t()).t()

            feat_g = self.fg_agg[r]([fg_direct, fg_left, fg_right])

            # ! Semantic Graph Generation from meta-path embeddings.
            sem_g_list = [gen_g_via_feat(self.sgg_gen[r][mp], mp_emb[mp], r) for mp in mp_emb]
            sem_g = self.sg_agg[r](sem_g_list)
            # ! Overall Graph
            # Update relation sub-matrices of the learned adjacency.
            new_adj[self.ri[r][0]:self.ri[r][1], self.ri[r][2]:self.ri[r][3]] = \
                self.overall_g_agg[r]([feat_g, sem_g, ori_g])  # update edge e.g. AP
        new_adj += new_adj.t()  # symmetrize (each relation was written one-way only)
        # ! Aggregate
        new_adj = F.normalize(new_adj, dim=0, p=self.norm_order)
        logits = self.GCN(features, new_adj)
        return logits, new_adj
|
||||
|
||||
|
||||
class MetricCalcLayer(nn.Module):
    """Learnable per-dimension scaling: multiplies inputs elementwise by a (1, nhid) weight."""

    def __init__(self, nhid):
        super().__init__()
        raw = torch.FloatTensor(1, nhid)
        self.weight = nn.Parameter(raw)
        nn.init.xavier_uniform_(self.weight)

    def forward(self, h):
        """Scale every row of h by the learned per-feature weights (broadcast over rows)."""
        return self.weight * h
|
||||
|
||||
|
||||
class GraphGenerator(nn.Module):
    """
    Generate a (left_num x right_num) graph from node features via multi-head
    metric-weighted cosine similarity; entries below `threshold` are zeroed.
    """

    def __init__(self, dim, num_head=2, threshold=0.1, dev=None):
        """
        Args:
            dim: feature dimension of both inputs.
            num_head: number of metric heads whose similarities are averaged.
            threshold: similarity cut-off; values below it become 0.
            dev: torch device for produced graphs.
        """
        super(GraphGenerator, self).__init__()
        self.threshold = threshold
        self.metric_layer = nn.ModuleList()
        for i in range(num_head):
            self.metric_layer.append(MetricCalcLayer(dim))
        self.num_head = num_head
        self.dev = dev

    def forward(self, left_h, right_h):
        """
        Args:
            left_h: left_node_num * hidden_dim/feat_dim
            right_h: right_node_num * hidden_dim/feat_dim
        Returns:
            Thresholded, head-averaged cosine-similarity matrix
            of shape (left_node_num, right_node_num).
        """
        # All-zero inputs carry no signal: short-circuit to an empty graph.
        if torch.sum(left_h) == 0 or torch.sum(right_h) == 0:
            return torch.zeros((left_h.shape[0], right_h.shape[0])).to(self.dev)
        s = torch.zeros((left_h.shape[0], right_h.shape[0])).to(self.dev)
        # The upstream ReLU can produce all-zero rows, which would make the
        # cosine similarity divide by zero (nan); nudge those rows off zero.
        zero_lines = torch.nonzero(torch.sum(left_h, 1) == 0)
        if len(zero_lines) > 0:
            # BUGFIX: work on a copy — the original mutated the caller's
            # feature tensor in place, permanently perturbing it on every
            # forward pass.
            left_h = left_h.clone()
            left_h[zero_lines, :] += 1e-8
        for i in range(self.num_head):
            weighted_left_h = self.metric_layer[i](left_h)
            weighted_right_h = self.metric_layer[i](right_h)
            s += cos_sim(weighted_left_h, weighted_right_h)
        s /= self.num_head
        # Sparsify: drop similarities below the threshold.
        s = torch.where(s < self.threshold, torch.zeros_like(s), s)
        return s
|
||||
|
||||
|
||||
class GCN(nn.Module):
    """Two-layer GCN: ReLU + dropout after the first layer, log-softmax output."""

    def __init__(self, nfeat, nhid, nclass, dropout):
        super(GCN, self).__init__()
        self.gc1 = GraphConvolution(nfeat, nhid)
        self.gc2 = GraphConvolution(nhid, nclass)
        self.dropout = dropout  # dropout probability applied between layers

    def forward(self, x, adj):
        """Return per-node class log-probabilities for features x on graph adj."""
        hidden = F.relu(self.gc1(x, adj))
        hidden = F.dropout(hidden, self.dropout, training=self.training)
        return F.log_softmax(self.gc2(hidden, adj), dim=1)
|
||||
|
||||
|
||||
class GraphChannelAttLayer(nn.Module):
    """Fuse a list of candidate graphs with learned per-channel attention.

    Each candidate adjacency is L1-normalized along its node dimension, then
    the channels are combined using softmax-normalized learned weights.
    Note: the *weights* argument is accepted for API compatibility but is
    currently unused — every channel starts from the same constant init.
    """

    def __init__(self, num_channel, weights=None):
        super(GraphChannelAttLayer, self).__init__()
        self.weight = nn.Parameter(torch.Tensor(num_channel, 1, 1))
        nn.init.constant_(self.weight, 0.1)  # equal weight for every channel

    def forward(self, adj_list):
        """Combine the graphs in adj_list into one attention-weighted graph."""
        stacked = torch.stack(adj_list)              # (C, N, N)
        stacked = F.normalize(stacked, dim=1, p=1)   # row-normalize each graph
        channel_att = F.softmax(self.weight, dim=0)  # (C, 1, 1), sums to 1
        # Hadamard product + summation == 1x1 conv over channels.
        return (stacked * channel_att).sum(dim=0)
|
||||
|
||||
class GraphConvolution(nn.Module):  # GCN AHW
    """One GCN layer computing adj @ (inputs @ weight) (+ bias) — the 'AHW' product."""

    def __init__(self, in_features, out_features, bias=True):
        super(GraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        if bias:
            self.bias = Parameter(torch.FloatTensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Uniform init in [-1/sqrt(out_features), 1/sqrt(out_features)]."""
        bound = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-bound, bound)
        if self.bias is not None:
            self.bias.data.uniform_(-bound, bound)

    def forward(self, inputs, adj):
        """Return adj @ inputs @ weight (+ bias); adj may be sparse."""
        propagated = torch.spmm(adj, torch.spmm(inputs, self.weight))  # A(HW)
        return propagated if self.bias is None else propagated + self.bias
|
|
@ -0,0 +1,2 @@
|
|||
from .config import *
|
||||
from .HGSL import HGSL
|
|
@ -0,0 +1,35 @@
|
|||
from shared_configs import ModelConfig, DataConfig
|
||||
|
||||
e = 2.71828
|
||||
|
||||
|
||||
class HGSLConfig(ModelConfig):
    """HGSL hyper-parameter configuration: per-dataset tuned values plus shared defaults."""

    def __init__(self, dataset, seed=0):
        super(HGSLConfig, self).__init__('HGSL')
        # Per-dataset tuned hyper-parameters: regularization alpha, dropout,
        # graph-generation thresholds, and meta-path list.
        per_dataset = \
            {'acm': {'alpha': 1, 'dropout': 0, 'fgd_th': 0.8, 'fgh_th': 0.2, 'sem_th': 0.6,
                     'mp_list': ['psp', 'pap', 'pspap']},
             'dblp': {'alpha': 4.5, 'dropout': 0.2, 'fgd_th': 0.99, 'fgh_th': 0.99, 'sem_th': 0.4, 'mp_list': ['apcpa']},
             'yelp': {'alpha': 0.5, 'dropout': 0.2, 'fgd_th': 0.8, 'fgh_th': 0.1, 'sem_th': 0.2,
                      'mp_list': ['bub', 'bsb', 'bublb', 'bubsb']}
             }
        self.dataset = dataset
        self.__dict__.update(per_dataset[dataset])
        # ! Model settings (set BEFORE save_model_conf_list so they are part
        # of the recorded model conf).
        self.lr = 0.01
        self.seed = seed
        self.save_model_conf_list()  # * Snapshot: only keys set so far count as model conf
        # Fixed settings — intentionally NOT part of the saved conf list.
        self.conv_method = 'gcn'
        self.num_head = 2
        self.early_stop = 80
        self.adj_norm_order = 1
        self.feat_norm = -1
        self.emb_dim = 64
        self.com_feat_dim = 16
        self.weight_decay = 5e-4
        self.model = 'HGSL'
        self.epochs = 200
        self.exp_name = 'debug'
        self.save_weights = False
        # Merge dataset-level settings (data_type, relation_list, data_path, ...).
        self.__dict__.update(DataConfig(dataset).__dict__)
|
|
@ -0,0 +1,39 @@
|
|||
"""
|
||||
Early stop provided by DGL
|
||||
"""
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
class EarlyStopping:
    """Early stopping on a validation score (higher is better), checkpointing the best model."""

    def __init__(self, patience=10, path='es_checkpoint.pt'):
        self.patience = patience      # epochs to tolerate without improvement
        self.counter = 0              # epochs since the last improvement
        self.best_score = None
        self.best_epoch = None
        self.early_stop = False
        self.path = path              # checkpoint destination

    def step(self, acc, model, epoch):
        """Record this epoch's score; return True once training should stop."""
        score = acc
        if self.best_score is None or score >= self.best_score:
            # New best (or first) score: remember it and checkpoint the model.
            self.best_score = score
            self.best_epoch = epoch
            self.save_checkpoint(model)
            self.counter = 0
        else:
            self.counter += 1
            print(
                f'EarlyStopping counter: {self.counter}/{self.patience}, best_val_score:{self.best_score:.4f} at E{self.best_epoch}')
            if self.counter >= self.patience:
                self.early_stop = True
        return self.early_stop

    def save_checkpoint(self, model):
        """Persist the model weights of the current best epoch."""
        torch.save(model.state_dict(), self.path)
|
|
@ -0,0 +1,87 @@
|
|||
import torch
|
||||
import util_funcs as uf
|
||||
|
||||
|
||||
def torch_f1_score(pred, target, n_class):
    '''
    Returns macro-f1 and micro-f1 score
    Args:
        pred: predicted class indices (1-D tensor).
        target: ground-truth class indices (1-D tensor).
        n_class: number of classes.

    Returns:
        ma_f1,mi_f1: numpy values of macro-f1 and micro-f1 scores.
    '''
    def _nan_to_zero(t):
        # 0/0 divisions produce nan; treat those precision/recall/f1 values as 0.
        return torch.where(torch.isnan(t), torch.zeros_like(t), t)

    def _f1(tp_, fp_, fn_):
        prec = _nan_to_zero(tp_ / (tp_ + fp_))
        rec = _nan_to_zero(tp_ / (tp_ + fn_))
        return _nan_to_zero(2 * prec * rec / (prec + rec))

    # One row per class: membership masks for predictions and targets.
    classes = torch.arange(n_class, device=pred.device).unsqueeze(1)
    pred_is = pred.unsqueeze(0) == classes    # (n_class, N)
    targ_is = target.unsqueeze(0) == classes  # (n_class, N)

    tp = (pred_is & targ_is).sum(1).to(torch.float)
    fp = (pred_is & ~targ_is).sum(1).to(torch.float)
    fn = (~pred_is & targ_is).sum(1).to(torch.float)

    # Macro: mean of per-class F1; micro: F1 of globally pooled counts.
    ma_f1 = _f1(tp, fp, fn).mean().cpu().numpy()
    mi_f1 = _f1(tp.sum(), fp.sum(), fn.sum()).cpu().numpy()
    return ma_f1, mi_f1
|
||||
|
||||
|
||||
def eval_logits(logits, target_x, target_y):
    """Macro/micro F1 of argmax predictions on the rows selected by target_x."""
    predictions = logits[target_x].argmax(dim=1)
    return torch_f1_score(predictions, target_y, n_class=logits.shape[1])
|
||||
|
||||
|
||||
def eval_and_save(cf, logits, test_x, test_y, val_x, val_y, stopper=None, res=None):
    """Evaluate *logits* on the test and validation splits and persist results.

    Args:
        cf: model config (provides seed, res_file, get_model_conf()).
        logits: (n_nodes, n_class) prediction scores.
        test_x/test_y, val_x/val_y: split indices and labels.
        stopper: optional EarlyStopping instance (records best_epoch).
        res: optional dict of extra result fields to merge in.

    BUGFIX: *res* previously defaulted to a shared mutable dict ({}), so
    results leaked across calls; a fresh dict is now created per call.
    """
    if res is None:
        res = {}
    test_f1, test_mif1 = eval_logits(logits, test_x, test_y)
    val_f1, val_mif1 = eval_logits(logits, val_x, val_y)
    save_results(cf, test_f1, val_f1, test_mif1, val_mif1, stopper, res)
|
||||
|
||||
|
||||
def save_results(cf, test_f1, val_f1, test_mif1=0, val_mif1=0, stopper=None, res=None):
    """Format metrics, print them, and append them to cf.res_file.

    Args:
        cf: model config (provides seed, res_file, get_model_conf()).
        test_f1/val_f1: macro-F1 scores; test_mif1/val_mif1: micro-F1 scores.
        stopper: optional EarlyStopping instance; if given, its best_epoch is saved.
        res: optional dict of extra result fields to merge into the output.

    BUGFIX: *res* previously defaulted to a shared mutable dict ({}), leaking
    state across calls; also replaced `!= None` with the idiomatic `is not None`.
    """
    if res is None:
        res = {}
    metrics = {'test_f1': f'{test_f1:.4f}', 'test_mif1': f'{test_mif1:.4f}',
               'val_f1': f'{val_f1:.4f}', 'val_mif1': f'{val_mif1:.4f}'}
    if stopper is not None:
        metrics['best_epoch'] = stopper.best_epoch
    res.update(metrics)
    print(f"Seed{cf.seed}")
    res_dict = {'res': res, 'parameters': cf.get_model_conf()}
    print(f'\n\n\nTrain finished, results:{res_dict}')
    uf.write_nested_dict(res_dict, cf.res_file)
|
|
@ -0,0 +1,79 @@
|
|||
import pickle
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import scipy
|
||||
|
||||
|
||||
class HIN(object):
    """Heterogeneous Information Network loaded from pickled files under data/<dataset>/."""

    def __init__(self, dataset):
        # Load node features, per-relation edges, split labels and meta-data.
        data_path = f'data/{dataset}/'
        with open(f'{data_path}node_features.pkl', 'rb') as f:
            self.features = pickle.load(f)
        with open(f'{data_path}edges.pkl', 'rb') as f:
            self.edges = pickle.load(f)
        with open(f'{data_path}labels.pkl', 'rb') as f:
            self.labels = pickle.load(f)
        with open(f'{data_path}meta_data.pkl', 'rb') as f:
            # NOTE(review): meta_data keys become attributes directly (callers use
            # t_info, r_info, types, undirected_relations, n_feat, n_class) —
            # confirm the pickled schema matches.
            self.__dict__.update(pickle.load(f))
        if scipy.sparse.issparse(self.features):
            # Work with a dense feature matrix downstream.
            self.features = self.features.todense()

    def to_torch(self, cf):
        '''
        Returns the torch tensor of the graph.
        Args:
            cf: The ModelConfig file.
        Returns:
            features, adj: feature and adj. matrix
            mp_emb: only available for models that uses mp_list.
            train_x, train_y, val_x, val_y, test_x, test_y: train/val/test index and labels
        '''
        features = torch.from_numpy(self.features).type(torch.FloatTensor).to(cf.dev)
        train_x, train_y, val_x, val_y, test_x, test_y = self.get_label(cf.dev)

        # Sum all per-relation (sparse) adjacency matrices into one overall
        # adjacency, densify, then L2-normalize each row.
        adj = np.sum(list(self.edges.values())).todense()
        adj = torch.from_numpy(adj).type(torch.FloatTensor).to(cf.dev)
        adj = F.normalize(adj, dim=1, p=2)

        mp_emb = {}
        if hasattr(cf, 'mp_list'):
            for mp in cf.mp_list:
                mp_emb[mp] = torch.from_numpy(self.mp_emb_dict[mp]).type(torch.FloatTensor).to(cf.dev)
        if hasattr(cf, 'feat_norm'):
            if cf.feat_norm > 0:  # feat_norm <= 0 disables normalization
                features = F.normalize(features, dim=1, p=cf.feat_norm)
                for mp in cf.mp_list:
                    mp_emb[mp] = F.normalize(mp_emb[mp], dim=1, p=cf.feat_norm)
        return features, adj, mp_emb, train_x, train_y, val_x, val_y, test_x, test_y

    def load_mp_embedding(self, cf):
        '''Load pretrained mp_embedding for every meta-path in cf.mp_list.'''
        self.mp_emb_dict = {}
        for mp in cf.mp_list:
            f_name = f'{cf.data_path}{mp}_emb.pkl'
            with open(f_name, 'rb') as f:
                z = pickle.load(f)
                # All-zero rows would break cosine similarity downstream;
                # nudge them slightly off zero instead of raising.
                zero_lines = np.nonzero(np.sum(z, 1) == 0)
                if len(zero_lines) > 0:
                    # raise ValueError('{} zero lines in {}s!\nZero lines:{}'.format(len(zero_lines), mode, zero_lines))
                    z[zero_lines, :] += 1e-8
                self.mp_emb_dict[mp] = z
        return self

    def get_label(self, dev):
        '''
        Args:
            dev: device (cpu or gpu)

        Returns:
            train_x, train_y, val_x, val_y, test_x, test_y: train/val/test index and labels
        '''
        # labels[i] holds (node_index, class) pairs for split i
        # (0 = train, 1 = val, 2 = test): column 0 is the index, column 1 the label.
        train_x = torch.from_numpy(np.array(self.labels[0])[:, 0]).type(torch.LongTensor).to(dev)
        train_y = torch.from_numpy(np.array(self.labels[0])[:, 1]).type(torch.LongTensor).to(dev)
        val_x = torch.from_numpy(np.array(self.labels[1])[:, 0]).type(torch.LongTensor).to(dev)
        val_y = torch.from_numpy(np.array(self.labels[1])[:, 1]).type(torch.LongTensor).to(dev)
        test_x = torch.from_numpy(np.array(self.labels[2])[:, 0]).type(torch.LongTensor).to(dev)
        test_y = torch.from_numpy(np.array(self.labels[2])[:, 1]).type(torch.LongTensor).to(dev)
        return train_x, train_y, val_x, val_y, test_x, test_y
|
|
@ -0,0 +1,65 @@
|
|||
import util_funcs as uf
|
||||
|
||||
|
||||
class ModelConfig():
    """Base configuration: tracks which attributes belong to the 'model conf'."""

    def __init__(self, model):
        self.model = model
        self.exp_name = 'default'
        self.model_conf_list = None  # populated by save_model_conf_list()

    def __str__(self):
        # Every attribute, including data/path settings merged in later.
        return str(self.__dict__)

    def save_model_conf_list(self):
        """Snapshot the current attribute names as the model-conf key list."""
        keys = list(self.__dict__.keys())
        keys.remove('model_conf_list')  # the bookkeeping attr is not a conf key
        self.model_conf_list = keys

    def update_file_conf(self):
        """Attach file-path settings (res_file, checkpoint_file) derived from this conf."""
        f_conf = FileConfig(self)
        self.__dict__.update(f_conf.__dict__)
        return self

    def model_conf_to_str(self):
        # String form of the model settings only.
        return str(self.get_model_conf())

    def get_model_conf(self):
        """Return only the attributes recorded as model conf."""
        return {key: self.__dict__[key] for key in self.model_conf_list}

    def update(self, conf_dict):
        """Merge conf_dict into this config and refresh the file paths."""
        self.__dict__.update(conf_dict)
        self.update_file_conf()
        return self
|
||||
|
||||
|
||||
class DataConfig:
    """Dataset-level settings: node-type string, relation list and data path."""

    def __init__(self, dataset):
        # Node-type string and '+'-joined directed relation list per dataset.
        per_dataset = {
            'acm': {'data_type': 'pas', 'relation_list': 'p-a+a-p+p-s+s-p'},
            'dblp': {'data_type': 'apc', 'relation_list': 'p-a+a-p+p-c+c-p'},
            'imdb': {'data_type': 'mad', 'relation_list': 'm-a+a-m+m-d+d-m'},
            'aminer': {'data_type': 'apr', 'relation_list': 'p-a+p-r+a-p+r-p'},
            'yelp': {'data_type': 'busl', 'relation_list': 'b-u+u-b+b-s+s-b+b-l+l-b'},
        }
        self.__dict__.update(per_dataset[dataset])
        self.dataset = dataset
        self.data_path = f'data/{dataset}/'
|
||||
|
||||
|
||||
class FileConfig:
    """Derives result/checkpoint file paths from a ModelConfig and creates their directories."""

    def __init__(self, cf: ModelConfig):
        '''
        1. Set f_prefix for each model. The f_prefix stores the important hyperparamters (tuned parameters) of the model.
        2. Generate the file names using f_prefix.
        3. Create required directories.
        '''
        # NOTE(review): f_prefix is only bound when cf.model starts with 'HGSL';
        # any other model name raises NameError two lines below — confirm intended.
        if cf.model[:4] == 'HGSL':
            f_prefix = f'do{cf.dropout}_lr{cf.lr}_a{cf.alpha}_tr{cf.fgd_th}-{cf.fgh_th}-{cf.sem_th}_mpl{uf.mp_list_str(cf.mp_list)}'
        self.res_file = f'results/{cf.dataset}/{cf.model}/{cf.exp_name}/<{cf.model}>{f_prefix}.txt'
        self.checkpoint_file = f'temp/{cf.model}/{cf.dataset}/{f_prefix}{uf.get_cur_time()}.pt'
        # Only res_file and checkpoint_file exist in __dict__ at this point;
        # create the parent directory of each.
        uf.mkdir_list(self.__dict__.values())
|
|
@ -0,0 +1,94 @@
|
|||
import os
|
||||
import sys
|
||||
|
||||
cur_path = os.path.abspath(os.path.dirname(__file__))
|
||||
root_path = cur_path.split('src')[0]
|
||||
sys.path.append(root_path + 'src')
|
||||
os.chdir(root_path)
|
||||
from early_stopper import *
|
||||
from hin_loader import HIN
|
||||
from evaluation import *
|
||||
import util_funcs as uf
|
||||
from config import HGSLConfig
|
||||
from HGSL import HGSL
|
||||
import warnings
|
||||
import time
|
||||
import torch
|
||||
import argparse
|
||||
|
||||
warnings.filterwarnings('ignore')
|
||||
root_path = os.path.abspath(os.path.dirname(__file__)).split('src')[0]
|
||||
|
||||
|
||||
def train_hgsl(args, gpu_id=0, log_on=True):
    """Train HGSL on args.dataset and save the evaluation results.

    Args:
        args: namespace with at least `dataset` and `seed` (merged into the config).
        gpu_id: GPU index; negative selects CPU. NOTE(review): the device is
            always "cuda:0" for any gpu_id >= 0 (shell_init sets
            CUDA_VISIBLE_DEVICES only for gpu_id > 0) — confirm intended.
        log_on: when False, stdout/logging are silenced during training.

    Returns:
        The final config object (with the best epoch's attention weights merged in).
    """
    uf.seed_init(args.seed)
    uf.shell_init(gpu_id=gpu_id)
    cf = HGSLConfig(args.dataset)

    # ! Modify config
    cf.update(args.__dict__)
    cf.dev = torch.device("cuda:0" if gpu_id >= 0 else "cpu")

    # ! Load Graph
    g = HIN(cf.dataset).load_mp_embedding(cf)
    print(f'Dataset: {cf.dataset}, {g.t_info}')
    features, adj, mp_emb, train_x, train_y, val_x, val_y, test_x, test_y = g.to_torch(cf)

    # ! Train Init
    if not log_on: uf.block_logs()
    print(f'{cf}\nStart training..')
    cla_loss = torch.nn.NLLLoss()
    model = HGSL(cf, g)
    model.to(cf.dev)
    optimizer = torch.optim.Adam(
        model.parameters(), lr=cf.lr, weight_decay=cf.weight_decay)
    stopper = EarlyStopping(patience=cf.early_stop, path=cf.checkpoint_file)

    dur = []      # per-epoch wall-clock durations
    w_list = []   # per-epoch snapshots of the channel-attention weights
    for epoch in range(cf.epochs):
        # ! Train
        t0 = time.time()
        model.train()
        logits, adj_new = model(features, adj, mp_emb)
        train_f1, train_mif1 = eval_logits(logits, train_x, train_y)
        w_list.append(uf.print_weights(model))

        l_pred = cla_loss(logits[train_x], train_y)
        # NOTE(review): the L1 regularizer is computed on the ORIGINAL `adj`,
        # which is constant w.r.t. the parameters (contributes no gradient) —
        # confirm whether `adj_new` was intended here.
        l_reg = cf.alpha * torch.norm(adj, 1)
        loss = l_pred + l_reg
        optimizer.zero_grad()
        # NOTE(review): detect_anomaly adds significant per-step overhead;
        # consider disabling it outside debugging.
        with torch.autograd.detect_anomaly():
            loss.backward()
        optimizer.step()

        # ! Valid — reuse the learned graph, GCN forward only.
        model.eval()
        with torch.no_grad():
            logits = model.GCN(features, adj_new)
            val_f1, val_mif1 = eval_logits(logits, val_x, val_y)
        dur.append(time.time() - t0)
        uf.print_train_log(epoch, dur, loss, train_f1, val_f1)

        if cf.early_stop > 0:
            if stopper.step(val_f1, model, epoch):
                print(f'Early stopped, loading model from epoch-{stopper.best_epoch}')
                break

    if cf.early_stop > 0:
        # Restore the best checkpoint and recompute logits on the full graph.
        model.load_state_dict(torch.load(cf.checkpoint_file))
        logits, _ = model(features, adj, mp_emb)
        cf.update(w_list[stopper.best_epoch])
    eval_and_save(cf, logits, test_x, test_y, val_x, val_y, stopper)
    if not log_on: uf.enable_logs()
    return cf
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # CLI entry point: dataset / GPU / seed selection.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--dataset', type=str, default='yelp')
    arg_parser.add_argument('--gpu_id', type=int, default=0)
    arg_parser.add_argument('--seed', type=int, default=0)
    cli_args = arg_parser.parse_args()
    train_hgsl(cli_args, gpu_id=cli_args.gpu_id)
|
|
@ -0,0 +1,264 @@
|
|||
import logging
|
||||
import os
|
||||
import pickle
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
# * ============================= Init =============================
|
||||
def shell_init(server='S5', gpu_id=0):
    '''
    Per-server shell setup.

    Silences numpy/warnings noise, sets CUDA/TF environment variables, and
    returns the python command to use on the given server.
    '''
    import warnings
    np.seterr(invalid='ignore')
    for category in (FutureWarning, UserWarning, RuntimeWarning):
        warnings.filterwarnings("ignore", category=category)

    server_python = {
        'Xy': '/home/chopin/zja/anaconda/bin/python',
        'Colab': 'python',
    }
    python_command = server_python.get(server, '~/anaconda3/bin/python')

    if gpu_id > 0:
        os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)
    os.environ['LD_LIBRARY_PATH'] = '/usr/local/cuda/lib64/'  # Extremely useful for Pycharm users
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # Block TF messages

    return python_command
|
||||
|
||||
|
||||
def seed_init(seed):
    """Seed python, numpy and torch (CPU + all GPUs) RNGs for reproducibility."""
    import torch
    import random
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
|
||||
|
||||
|
||||
# * ============================= Torch =============================
|
||||
|
||||
def exists_zero_lines(h):
    """Return True (and print a warning) if any row of h sums to zero."""
    zero_lines = torch.where(torch.sum(h, 1) == 0)[0]
    if len(zero_lines) == 0:
        return False
    # raise ValueError('{} zero lines in {}s!\nZero lines:{}'.format(len(zero_lines), 'emb', zero_lines))
    print(f'{len(zero_lines)} zero lines !\nZero lines:{zero_lines}')
    return True
|
||||
|
||||
|
||||
def cos_sim(a, b, eps=1e-8):
    """
    Pairwise cosine similarity between the rows of matrix a and the rows of b.

    Row norms are clamped to at least eps to avoid division by zero.
    """
    a_len = a.norm(dim=1, keepdim=True)
    b_len = b.norm(dim=1, keepdim=True)
    a_unit = a / a_len.clamp(min=eps)
    b_unit = b / b_len.clamp(min=eps)
    return a_unit @ b_unit.t()
|
||||
|
||||
|
||||
# * ============================= Print Related =============================
|
||||
|
||||
def print_dict(d, end_string='\n\n'):
    """Pretty-print a (possibly nested) dict: ints as :04d, floats as :.4f."""
    for key, value in d.items():
        if isinstance(value, dict):
            print('\n', end='')
            print_dict(value, end_string='')  # recurse without extra trailing blank lines
        elif isinstance(value, int):
            print('{}: {:04d}'.format(key, value), end=', ')
        elif isinstance(value, float):
            print('{}: {:.4f}'.format(key, value), end=', ')
        else:
            print('{}: {}'.format(key, value), end=', ')
    print(end_string, end='')
|
||||
|
||||
|
||||
def block_logs():
    """Silence stdout (redirect to devnull) and disable the root logger."""
    sys.stdout = open(os.devnull, 'w')
    logging.getLogger().disabled = True
|
||||
|
||||
|
||||
def enable_logs():
    """Undo block_logs: restore the real stdout and re-enable the root logger."""
    sys.stdout = sys.__stdout__
    logging.getLogger().disabled = False
|
||||
|
||||
|
||||
def progress_bar(prefix, start_time, i, max_i, postfix):
    """
    Print a progress line AFTER finishing the i-th (0-based) iteration.

    Args:
        prefix: leading label of the printed string.
        start_time: loop start time (time.time()).
        i: index of the iteration just finished.
        max_i: total number of iterations.
        postfix: trailing label of the printed string.
    """
    elapsed = time.time() - start_time
    i += 1  # report 1-based progress
    estimated_total = elapsed * max_i / i if i != 0 else 0
    print(
        f'{prefix} : {i}/{max_i} [{time2str(elapsed)}/{time2str(estimated_total)}, {time2str(estimated_total - elapsed)} left] - {postfix}-{get_cur_time()}')
|
||||
|
||||
|
||||
def print_train_log(epoch, dur, loss, train_f1, val_f1):
    """Print one training-log line: epoch, mean epoch time, loss and F1 scores."""
    mean_time = np.mean(dur)
    print(
        f"Epoch {epoch:05d} | Time(s) {mean_time:.4f} | Loss {loss.item():.4f} | TrainF1 {train_f1:.4f} | ValF1 {val_f1:.4f}")
|
||||
|
||||
|
||||
def mp_list_str(mp_list):
    """Join meta-path names with underscores, e.g. ['pap', 'psp'] -> 'pap_psp'."""
    return '_'.join(mp_list)
|
||||
|
||||
|
||||
# * ============================= File Operations =============================
|
||||
|
||||
def write_nested_dict(d, f_path):
    """Append each dict-valued entry of d as its own line to f_path.

    Non-dict values are intentionally skipped (only nested dicts are recorded).
    """
    with open(f_path, 'a+') as f:
        f.write('\n')
        for value in d.values():
            if isinstance(value, dict):
                f.write(str(value) + '\n')
|
||||
|
||||
|
||||
def save_pickle(var, f_name):
    """Pickle *var* to *f_name*.

    BUGFIX: the original passed an open() call directly to pickle.dump and
    never closed the file handle; use a context manager instead.
    """
    with open(f_name, 'wb') as f:
        pickle.dump(var, f)
|
||||
|
||||
|
||||
def load_pickle(f_name):
    """Unpickle and return the object stored in *f_name*.

    BUGFIX: the original leaked the file handle; use a context manager.
    NOTE: pickle is unsafe on untrusted input — only load project-generated files.
    """
    with open(f_name, 'rb') as f:
        return pickle.load(f)
|
||||
|
||||
|
||||
def clear_results(dataset, model):
    """Delete every saved result for (dataset, model) under results/."""
    res_path = f'results/{dataset}/{model}/'
    os.system('rm -rf ' + res_path)
    print(f'Results in {res_path} are cleared.')
|
||||
|
||||
|
||||
# * ============================= Path Operations =============================
|
||||
|
||||
def check_path(path):
    """Create *path* (including parents) if it does not already exist."""
    if os.path.exists(path):
        return
    os.makedirs(path)
|
||||
|
||||
|
||||
def get_dir_of_file(f_name):
    """Return the containing directory of *f_name*, with a trailing slash."""
    return '{}/'.format(os.path.dirname(f_name))
|
||||
|
||||
|
||||
def get_grand_parent_dir(f_name):
    """Return the parent directory (with trailing slash) of the directory holding *f_name*."""
    last_component = f_name.split('/')[-1]
    if '.' in last_component:  # looks like a file: first strip to its directory
        return get_grand_parent_dir(get_dir_of_file(f_name))
    return f'{Path(f_name).parent}/'  # already a path: go one level up
|
||||
|
||||
|
||||
def get_abs_path(f_name, style='command_line'):
    """Resolve *f_name* against the project root (the part of this file's path before 'src').

    style='command_line' escapes spaces with backslashes for shell use;
    style='python' leaves the path untouched.
    """
    if style == 'python':
        cur_path = os.path.abspath(os.path.dirname(__file__))
    elif style == 'command_line':
        cur_path = os.path.abspath(os.path.dirname(__file__)).replace(' ', '\\ ')

    return os.path.join(cur_path.split('src')[0], f_name)
|
||||
|
||||
|
||||
def mkdir_p(path, log=True):
    """Create a directory for the specified path.

    Parameters
    ----------
    path : str
        Path name
    log : bool
        Whether to print result for directory creation
    """
    import errno
    if os.path.exists(path):
        return
    try:
        os.makedirs(path)
        if log:
            print('Created directory {}'.format(path))
    except OSError as exc:
        # Tolerate a concurrent creation of the same directory (when logging);
        # anything else is a genuine failure.
        if exc.errno == errno.EEXIST and os.path.isdir(path) and log:
            print('Directory {} already exists.'.format(path))
        else:
            raise
|
||||
|
||||
|
||||
def mkdir_list(p_list, use_relative_path=True, log=True):
    """Create the parent directory of every path in *p_list*.

    Parameters
    ----------
    p_list : iterable of path strings

    Note: each entry is treated as a FILE path — only os.path.dirname(p) is
    created, so directory paths MUST end with '/'.
    """
    project_root = os.path.abspath(os.path.dirname(__file__)).split('src')[0]
    for raw in p_list:
        full = os.path.join(project_root, raw) if use_relative_path else raw
        mkdir_p(os.path.dirname(full), log)
|
||||
|
||||
|
||||
# * ============================= Time Related =============================
|
||||
|
||||
def time2str(t):
    """Format a duration in seconds with the largest fitting unit (day/h/min/s)."""
    for divisor, suffix in ((86400, 'day'), (3600, 'h'), (60, 'min')):
        if t > divisor:
            return '{:.2f}{}'.format(t / divisor, suffix)
    return '{:.2f}s'.format(t)
|
||||
|
||||
|
||||
def get_cur_time():
    """Current local time as 'YYYY-MM-DD_HH-MM-SS' (used in checkpoint filenames)."""
    from datetime import datetime
    now = datetime.now()
    return f'{now.date()}_{now.hour:02d}-{now.minute:02d}-{now.second:02d}'
|
||||
|
||||
|
||||
# * ============================= Others =============================
|
||||
def print_weights(model, interested_para='_agg'):
    """Collect softmax-normalized values of parameters whose name contains *interested_para*.

    Args:
        model: any nn.Module; its named_parameters() are scanned.
        interested_para: substring selecting the parameters of interest
            (default matches the channel-attention aggregators).

    Returns:
        dict mapping parameter name -> numpy array of softmax(weight.squeeze()).
    """
    w_dict = {}
    for name, W in model.named_parameters():
        if interested_para in name:
            # FIX: pass dim=0 explicitly — the squeezed channel-attention
            # weights are 1-D, and the original relied on F.softmax's
            # deprecated implicit-dim behavior.
            data = F.softmax(W.data.squeeze(), dim=0).cpu().numpy()
            # print(f'{name}:{data}')
            w_dict[name] = data
    return w_dict
|
||||
|
||||
|
||||
def count_avg_neighbors(adj):
    """Average number of positive entries (edges) per row of *adj*."""
    n_edges = int((adj > 0).sum())
    return n_edges / adj.shape[0]
|
Loading…
Reference in New Issue