diff --git a/.DS_Store b/.DS_Store
new file mode 100644
index 0000000..e9ab4fc
Binary files /dev/null and b/.DS_Store differ
diff --git a/README.md b/README.md
index 4cf0b29..7dc3eb7 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,14 @@
-# AAAI21-HGSL
-Source code of AAAI21-Heterogeneous Graph Structure Learning for Graph Neural Networks
-To be released soon...
+# HGSL
+Source code of the AAAI submission "Heterogeneous Graph Structure Learning for Graph Neural Networks"
+# Requirements
+## Python Packages
+- Python >= 3.6.8
+- PyTorch >= 1.3.0
+- DGL == 0.4.3
+## GPU Memory Requirements
+- ACM >= 8G
+- DBLP >= 5G
+- Yelp >= 3G
+# Usage
+Take the DBLP dataset as an example:
+python src/train.py --dataset='dblp'
diff --git a/data/.DS_Store b/data/.DS_Store
new file mode 100644
index 0000000..fd1406b
Binary files /dev/null and b/data/.DS_Store differ
diff --git a/data/acm/.DS_Store b/data/acm/.DS_Store
new file mode 100644
index 0000000..5008ddf
Binary files /dev/null and b/data/acm/.DS_Store differ
diff --git a/data/acm/edges.pkl b/data/acm/edges.pkl
new file mode 100644
index 0000000..2af3175
Binary files /dev/null and b/data/acm/edges.pkl differ
diff --git a/data/acm/labels.pkl b/data/acm/labels.pkl
new file mode 100644
index 0000000..ccdaa85
Binary files /dev/null and b/data/acm/labels.pkl differ
diff --git a/data/acm/meta_data.pkl b/data/acm/meta_data.pkl
new file mode 100644
index 0000000..1dc25fb
Binary files /dev/null and b/data/acm/meta_data.pkl differ
diff --git a/data/acm/node_features.pkl b/data/acm/node_features.pkl
new file mode 100644
index 0000000..2ebd2d1
Binary files /dev/null and b/data/acm/node_features.pkl differ
diff --git a/data/acm/pap_emb.pkl b/data/acm/pap_emb.pkl
new file mode 100644
index 0000000..cb32505
Binary files /dev/null and b/data/acm/pap_emb.pkl differ
diff --git a/data/acm/psp_emb.pkl b/data/acm/psp_emb.pkl
new file mode 100644
index 0000000..8179be5
Binary files /dev/null and b/data/acm/psp_emb.pkl differ
diff --git a/data/acm/pspap_emb.pkl b/data/acm/pspap_emb.pkl
new file mode 100644
index 0000000..05554f1
Binary files /dev/null and b/data/acm/pspap_emb.pkl differ
diff --git a/data/dblp/apcpa_emb.pkl b/data/dblp/apcpa_emb.pkl
new file mode 100644
index 0000000..5727e77
Binary files /dev/null and b/data/dblp/apcpa_emb.pkl differ
diff --git a/data/dblp/edges.pkl b/data/dblp/edges.pkl
new file mode 100644
index 0000000..20d9a88
Binary files /dev/null and b/data/dblp/edges.pkl differ
diff --git a/data/dblp/labels.pkl b/data/dblp/labels.pkl
new file mode 100644
index 0000000..bd6023b
Binary files /dev/null and b/data/dblp/labels.pkl differ
diff --git a/data/dblp/meta_data.pkl b/data/dblp/meta_data.pkl
new file mode 100644
index 0000000..bbf3b24
Binary files /dev/null and b/data/dblp/meta_data.pkl differ
diff --git a/data/dblp/node_features.pkl b/data/dblp/node_features.pkl
new file mode 100644
index 0000000..549bc6c
Binary files /dev/null and b/data/dblp/node_features.pkl differ
diff --git a/data/yelp/.DS_Store b/data/yelp/.DS_Store
new file mode 100644
index 0000000..5008ddf
Binary files /dev/null and b/data/yelp/.DS_Store differ
diff --git a/data/yelp/bsb_emb.pkl b/data/yelp/bsb_emb.pkl
new file mode 100644
index 0000000..b7eed31
Binary files /dev/null and b/data/yelp/bsb_emb.pkl differ
diff --git a/data/yelp/bub_emb.pkl b/data/yelp/bub_emb.pkl
new file mode 100644
index 0000000..60008e1
Binary files /dev/null and b/data/yelp/bub_emb.pkl differ
diff --git a/data/yelp/bublb_emb.pkl b/data/yelp/bublb_emb.pkl
new file mode 100644
index 0000000..b653300
Binary files /dev/null and b/data/yelp/bublb_emb.pkl differ
diff --git a/data/yelp/bubsb_emb.pkl b/data/yelp/bubsb_emb.pkl
new file mode 100644
index 0000000..60ae4d3
Binary files /dev/null and b/data/yelp/bubsb_emb.pkl differ
diff --git a/data/yelp/edges.pkl b/data/yelp/edges.pkl
new file mode 100644
index 0000000..c6741a8
Binary files /dev/null and b/data/yelp/edges.pkl differ
diff --git a/data/yelp/labels.pkl b/data/yelp/labels.pkl
new file mode 100644
index 0000000..12ce561
Binary files /dev/null and b/data/yelp/labels.pkl differ
diff --git a/data/yelp/meta_data.pkl b/data/yelp/meta_data.pkl
new file mode 100644
index 0000000..c909861
Binary files /dev/null and b/data/yelp/meta_data.pkl differ
diff --git a/data/yelp/node_features.pkl b/data/yelp/node_features.pkl
new file mode 100644
index 0000000..0240038
Binary files /dev/null and b/data/yelp/node_features.pkl differ
diff --git a/src/.DS_Store b/src/.DS_Store
new file mode 100644
index 0000000..46fe8d9
Binary files /dev/null and b/src/.DS_Store differ
diff --git a/src/HGSL.py b/src/HGSL.py
new file mode 100644
index 0000000..f6982b5
--- /dev/null
+++ b/src/HGSL.py
@@ -0,0 +1,200 @@
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.parameter import Parameter
+
+from util_funcs import cos_sim
+
+
+class HGSL(nn.Module):
+    """
+    Learns a new heterogeneous graph structure from node features, semantic (meta-path) embeddings
+    and the original topology, then classifies nodes with a GCN on the learned graph.
+    """
+
+    def __init__(self, cf, g):
+        super(HGSL, self).__init__()
+        self.__dict__.update(cf.get_model_conf())
+        # ! Init variables
+        self.dev = cf.dev
+        self.ti, self.ri, self.types, self.ud_rels = g.t_info, g.r_info, g.types, g.undirected_relations
+        feat_dim, mp_emb_dim = g.features.shape[1], list(g.mp_emb_dict.values())[0].shape[1]
+        self.non_linear = nn.ReLU()
+        # ! Graph Structure Learning
+        MD = nn.ModuleDict
+        self.fgg_direct, self.fgg_left, self.fgg_right, self.fg_agg, self.sgg_gen, self.sg_agg, self.overall_g_agg = \
+            MD({}), MD({}), MD({}), MD({}), MD({}), MD({}), MD({})
+        # Feature encoder
+        self.encoder = MD(dict(zip(g.types, [nn.Linear(g.features.shape[1], cf.com_feat_dim) for _ in g.types])))
+
+        for r in g.undirected_relations:
+            # ! Feature Graph Generator
+            self.fgg_direct[r] = GraphGenerator(cf.com_feat_dim, cf.num_head, cf.fgd_th, self.dev)
+            self.fgg_left[r] = GraphGenerator(feat_dim, cf.num_head, cf.fgh_th, self.dev)
+            self.fgg_right[r] = GraphGenerator(feat_dim, cf.num_head, cf.fgh_th, self.dev)
+            self.fg_agg[r] = GraphChannelAttLayer(3)  # 3 = 1 (first-order/direct) + 2 (second-order)
+
+            # ! Semantic Graph Generator
+            self.sgg_gen[r] = MD(dict(
+                zip(cf.mp_list, [GraphGenerator(mp_emb_dim, cf.num_head, cf.sem_th, self.dev) for _ in cf.mp_list])))
+            self.sg_agg[r] = GraphChannelAttLayer(len(cf.mp_list))
+
+            # ! Overall Graph Generator
+            self.overall_g_agg[r] = GraphChannelAttLayer(3, [1, 1, 10])  # 3 = feat-graph + sem-graph + ori_graph
+            # self.overall_g_agg[r] = GraphChannelAttLayer(3)  # 3 = feat-graph + sem-graph + ori_graph
+
+        # ! Graph Convolution
+        if cf.conv_method == 'gcn':
+            self.GCN = GCN(g.n_feat, cf.emb_dim, g.n_class, cf.dropout)
+        self.norm_order = cf.adj_norm_order
+
+    def forward(self, features, adj_ori, mp_emb):
+        def get_rel_mat(mat, r):
+            return mat[self.ri[r][0]:self.ri[r][1], self.ri[r][2]:self.ri[r][3]]
+
+        def get_type_rows(mat, node_type):
+            return mat[self.ti[node_type]['ind'], :]
+
+        def gen_g_via_feat(graph_gen_func, mat, r):
+            return graph_gen_func(get_type_rows(mat, r[0]), get_type_rows(mat, r[-1]))
+
+        # ! Heterogeneous Feature Mapping
+        com_feat_mat = torch.cat([self.non_linear(
+            self.encoder[t](features[self.ti[t]['ind']])) for t in self.types])
+
+        # ! Heterogeneous Graph Generation
+        new_adj = torch.zeros_like(adj_ori).to(self.dev)
+        for r in self.ud_rels:
+            ori_g = get_rel_mat(adj_ori, r)
+            # ! Feature Graph Generation
+            fg_direct = gen_g_via_feat(self.fgg_direct[r], com_feat_mat, r)
+
+            fmat_l, fmat_r = features[self.ti[r[0]]['ind']], features[self.ti[r[-1]]['ind']]
+            sim_l, sim_r = self.fgg_left[r](fmat_l, fmat_l), self.fgg_right[r](fmat_r, fmat_r)
+            fg_left, fg_right = sim_l.mm(ori_g), sim_r.mm(ori_g.t()).t()
+
+            feat_g = self.fg_agg[r]([fg_direct, fg_left, fg_right])
+
+            # ! Semantic Graph Generation
+            sem_g_list = [gen_g_via_feat(self.sgg_gen[r][mp], mp_emb[mp], r) for mp in mp_emb]
+            sem_g = self.sg_agg[r](sem_g_list)
+            # ! Overall Graph
+            # Update the relation sub-matrices
+            new_adj[self.ri[r][0]:self.ri[r][1], self.ri[r][2]:self.ri[r][3]] = \
+                self.overall_g_agg[r]([feat_g, sem_g, ori_g])  # update edge, e.g. AP
+
+        new_adj += new_adj.t()  # symmetrize
+        # ! Aggregate
+        new_adj = F.normalize(new_adj, dim=0, p=self.norm_order)
+        logits = self.GCN(features, new_adj)
+        return logits, new_adj
+
+
+class MetricCalcLayer(nn.Module):
+    def __init__(self, nhid):
+        super().__init__()
+        self.weight = nn.Parameter(torch.FloatTensor(1, nhid))
+        nn.init.xavier_uniform_(self.weight)
+
+    def forward(self, h):
+        return h * self.weight
+
+
+class GraphGenerator(nn.Module):
+    """
+    Generate a graph from node embeddings via multi-head weighted cosine similarity.
+    """
+
+    def __init__(self, dim, num_head=2, threshold=0.1, dev=None):
+        super(GraphGenerator, self).__init__()
+        self.threshold = threshold
+        self.metric_layer = nn.ModuleList()
+        for i in range(num_head):
+            self.metric_layer.append(MetricCalcLayer(dim))
+        self.num_head = num_head
+        self.dev = dev
+
+    def forward(self, left_h, right_h):
+        """
+        Args:
+            left_h: left_node_num * hidden_dim/feat_dim
+            right_h: right_node_num * hidden_dim/feat_dim
+        Returns:
+            s: left_node_num * right_node_num thresholded similarity matrix
+        """
+        if torch.sum(left_h) == 0 or torch.sum(right_h) == 0:
+            return torch.zeros((left_h.shape[0], right_h.shape[0])).to(self.dev)
+        s = torch.zeros((left_h.shape[0], right_h.shape[0])).to(self.dev)
+        zero_lines = torch.nonzero(torch.sum(left_h, 1) == 0)
+        # The ReLU activation can produce all-zero rows, which lead to the NaN (division by zero) problem.
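+        # Adding a small epsilon to those rows keeps cos_sim well-defined while leaving the similarities essentially unchanged.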
+        if len(zero_lines) > 0:
+            left_h[zero_lines, :] += 1e-8
+        for i in range(self.num_head):
+            weighted_left_h = self.metric_layer[i](left_h)
+            weighted_right_h = self.metric_layer[i](right_h)
+            s += cos_sim(weighted_left_h, weighted_right_h)
+        s /= self.num_head
+        s = torch.where(s < self.threshold, torch.zeros_like(s), s)
+        return s
+
+
+class GCN(nn.Module):
+    def __init__(self, nfeat, nhid, nclass, dropout):
+        super(GCN, self).__init__()
+        self.gc1 = GraphConvolution(nfeat, nhid)
+        self.gc2 = GraphConvolution(nhid, nclass)
+        self.dropout = dropout
+
+    def forward(self, x, adj):
+        x = F.relu(self.gc1(x, adj))
+        x = F.dropout(x, self.dropout, training=self.training)
+        x = self.gc2(x, adj)
+        return F.log_softmax(x, dim=1)
+
+
+class GraphChannelAttLayer(nn.Module):
+
+    def __init__(self, num_channel, weights=None):
+        super(GraphChannelAttLayer, self).__init__()
+        self.weight = nn.Parameter(torch.Tensor(num_channel, 1, 1))
+        nn.init.constant_(self.weight, 0.1)  # equal weight
+        # Note: the `weights` argument is currently unused; every channel starts from the same constant weight.
+        # if weights is not None:
+        #     with torch.no_grad():
+        #         w = torch.Tensor(weights).reshape(self.weight.shape)
+        #         self.weight.copy_(w)
+
+    def forward(self, adj_list):
+        adj_list = torch.stack(adj_list)
+        # Row normalization of all generated graphs
+        adj_list = F.normalize(adj_list, dim=1, p=1)
+        # Hadamard product with the softmax-normalized channel weights, then sum over channels
+        return torch.sum(adj_list * F.softmax(self.weight, dim=0), dim=0)
+
+
+class GraphConvolution(nn.Module):  # GCN layer: AHW
+    def __init__(self, in_features, out_features, bias=True):
+        super(GraphConvolution, self).__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
+        if bias:
+            self.bias = Parameter(torch.FloatTensor(out_features))
+        else:
+            self.register_parameter('bias', None)
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        stdv = 1. / math.sqrt(self.weight.size(1))
+        self.weight.data.uniform_(-stdv, stdv)
+        if self.bias is not None:
+            self.bias.data.uniform_(-stdv, stdv)
+
+    def forward(self, inputs, adj):
+        support = torch.spmm(inputs, self.weight)  # HW in GCN
+        output = torch.spmm(adj, support)  # AHW
+        if self.bias is not None:
+            return output + self.bias
+        else:
+            return output
diff --git a/src/__init__.py b/src/__init__.py
new file mode 100644
index 0000000..99dbaa4
--- /dev/null
+++ b/src/__init__.py
@@ -0,0 +1,2 @@
+from .config import *
+from .HGSL import HGSL
\ No newline at end of file
diff --git a/src/config.py b/src/config.py
new file mode 100644
index 0000000..6c16a61
--- /dev/null
+++ b/src/config.py
@@ -0,0 +1,35 @@
+from shared_configs import ModelConfig, DataConfig
+
+e = 2.71828
+
+
+class HGSLConfig(ModelConfig):
+    def __init__(self, dataset, seed=0):
+        super(HGSLConfig, self).__init__('HGSL')
+        default_settings = \
+            {'acm': {'alpha': 1, 'dropout': 0, 'fgd_th': 0.8, 'fgh_th': 0.2, 'sem_th': 0.6,
+                     'mp_list': ['psp', 'pap', 'pspap']},
+             'dblp': {'alpha': 4.5, 'dropout': 0.2, 'fgd_th': 0.99, 'fgh_th': 0.99, 'sem_th': 0.4, 'mp_list': ['apcpa']},
+             'yelp': {'alpha': 0.5, 'dropout': 0.2, 'fgd_th': 0.8, 'fgh_th': 0.1, 'sem_th': 0.2,
+                      'mp_list': ['bub', 'bsb', 'bublb', 'bubsb']}
+             }
+        self.dataset = dataset
+        self.__dict__.update(default_settings[dataset])
+        # ! Model settings
+        self.lr = 0.01
+        self.seed = seed
+        self.save_model_conf_list()  # * Save the model config list keys
+        self.conv_method = 'gcn'
+        self.num_head = 2
+        self.early_stop = 80
+        self.adj_norm_order = 1
+        self.feat_norm = -1
+        self.emb_dim = 64
+        self.com_feat_dim = 16
+        self.weight_decay = 5e-4
+        self.model = 'HGSL'
+        self.epochs = 200
+        self.exp_name = 'debug'
+        self.save_weights = False
+        d_conf = DataConfig(dataset)
+        self.__dict__.update(d_conf.__dict__)
diff --git a/src/early_stopper.py b/src/early_stopper.py
new file mode 100644
index 0000000..51ff844
--- /dev/null
+++ b/src/early_stopper.py
@@ -0,0 +1,39 @@
+"""
+Early stopping utility provided by DGL.
+"""
+import numpy as np
+import torch
+
+
+class EarlyStopping:
+    def __init__(self, patience=10, path='es_checkpoint.pt'):
+        self.patience = patience
+        self.counter = 0
+        self.best_score = None
+        self.best_epoch = None
+        self.early_stop = False
+        self.path = path
+
+    def step(self, acc, model, epoch):
+        score = acc
+        if self.best_score is None:
+            self.best_score = score
+            self.best_epoch = epoch
+            self.save_checkpoint(model)
+        elif score < self.best_score:
+            self.counter += 1
+            print(
+                f'EarlyStopping counter: {self.counter}/{self.patience}, best_val_score:{self.best_score:.4f} at E{self.best_epoch}')
+            if self.counter >= self.patience:
+                self.early_stop = True
+        else:
+            self.best_score = score
+            self.best_epoch = epoch
+            self.save_checkpoint(model)
+            self.counter = 0
+
+        return self.early_stop
+
+    def save_checkpoint(self, model):
+        '''Saves the model when the validation score improves.'''
+        torch.save(model.state_dict(), self.path)
diff --git a/src/evaluation.py b/src/evaluation.py
new file mode 100644
index 0000000..8f4d327
--- /dev/null
+++ b/src/evaluation.py
@@ -0,0 +1,87 @@
+import torch
+import util_funcs as uf
+
+
+def torch_f1_score(pred, target, n_class):
+    '''
+    Returns the macro-F1 and micro-F1 scores.
+    Args:
+        pred: predicted labels
+        target: ground-truth labels
+        n_class: number of classes
+
+    Returns:
+        ma_f1, mi_f1: numpy values of the macro-F1 and micro-F1 scores.
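+        Note: for single-label multi-class classification, the micro-F1 score is equal to the overall accuracy.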
+    '''
+
+    def true_positive(pred, target, n_class):
+        return torch.tensor([((pred == i) & (target == i)).sum()
+                             for i in range(n_class)])
+
+    def false_positive(pred, target, n_class):
+        return torch.tensor([((pred == i) & (target != i)).sum()
+                             for i in range(n_class)])
+
+    def false_negative(pred, target, n_class):
+        return torch.tensor([((pred != i) & (target == i)).sum()
+                             for i in range(n_class)])
+
+    def precision(tp, fp):
+        res = tp / (tp + fp)
+        res[torch.isnan(res)] = 0
+        return res
+
+    def recall(tp, fn):
+        res = tp / (tp + fn)
+        res[torch.isnan(res)] = 0
+        return res
+
+    def f1_score(prec, rec):
+        f1_score = 2 * (prec * rec) / (prec + rec)
+        f1_score[torch.isnan(f1_score)] = 0
+        return f1_score
+
+    def cal_maf1(tp, fp, fn):
+        prec = precision(tp, fp)
+        rec = recall(tp, fn)
+        ma_f1 = f1_score(prec, rec)
+        return torch.mean(ma_f1).cpu().numpy()
+
+    def cal_mif1(tp, fp, fn):
+        gl_tp, gl_fp, gl_fn = torch.sum(tp), torch.sum(fp), torch.sum(fn)
+        gl_prec = precision(gl_tp, gl_fp)
+        gl_rec = recall(gl_tp, gl_fn)
+        mi_f1 = f1_score(gl_prec, gl_rec)
+        return mi_f1.cpu().numpy()
+
+    tp = true_positive(pred, target, n_class).to(torch.float)
+    fn = false_negative(pred, target, n_class).to(torch.float)
+    fp = false_positive(pred, target, n_class).to(torch.float)
+
+    ma_f1, mi_f1 = cal_maf1(tp, fp, fn), cal_mif1(tp, fp, fn)
+    return ma_f1, mi_f1
+
+
+def eval_logits(logits, target_x, target_y):
+    pred_y = torch.argmax(logits[target_x], dim=1)
+    return torch_f1_score(pred_y, target_y, n_class=logits.shape[1])
+
+
+def eval_and_save(cf, logits, test_x, test_y, val_x, val_y, stopper=None, res=None):
+    res = {} if res is None else res
+    test_f1, test_mif1 = eval_logits(logits, test_x, test_y)
+    val_f1, val_mif1 = eval_logits(logits, val_x, val_y)
+    save_results(cf, test_f1, val_f1, test_mif1, val_mif1, stopper, res)
+
+
+def save_results(cf, test_f1, val_f1, test_mif1=0, val_mif1=0, stopper=None, res=None):
+    res = {} if res is None else res
+    if stopper is not None:
+        res.update({'test_f1': f'{test_f1:.4f}', 'test_mif1': f'{test_mif1:.4f}',
+                    'val_f1': f'{val_f1:.4f}', 'val_mif1': f'{val_mif1:.4f}',
+                    'best_epoch': stopper.best_epoch})
+    else:
+        res.update({'test_f1': f'{test_f1:.4f}', 'test_mif1': f'{test_mif1:.4f}',
+                    'val_f1': f'{val_f1:.4f}', 'val_mif1': f'{val_mif1:.4f}'})
+    print(f"Seed{cf.seed}")
+    res_dict = {'res': res, 'parameters': cf.get_model_conf()}
+    print(f'\n\n\nTrain finished, results:{res_dict}')
+    uf.write_nested_dict(res_dict, cf.res_file)
diff --git a/src/hin_loader.py b/src/hin_loader.py
new file mode 100644
index 0000000..9dacef9
--- /dev/null
+++ b/src/hin_loader.py
@@ -0,0 +1,79 @@
+import pickle
+import numpy as np
+import torch
+import torch.nn.functional as F
+import scipy
+
+
+class HIN(object):
+
+    def __init__(self, dataset):
+        data_path = f'data/{dataset}/'
+        with open(f'{data_path}node_features.pkl', 'rb') as f:
+            self.features = pickle.load(f)
+        with open(f'{data_path}edges.pkl', 'rb') as f:
+            self.edges = pickle.load(f)
+        with open(f'{data_path}labels.pkl', 'rb') as f:
+            self.labels = pickle.load(f)
+        with open(f'{data_path}meta_data.pkl', 'rb') as f:
+            self.__dict__.update(pickle.load(f))
+        if scipy.sparse.issparse(self.features):
+            self.features = self.features.todense()
+
+    def to_torch(self, cf):
+        '''
+        Returns the torch tensors of the graph.
+        Args:
+            cf: the ModelConfig object.
+        Returns:
+            features, adj: feature and adjacency matrices
+            mp_emb: meta-path embeddings, only available for models that use mp_list
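+                (an empty dict when cf defines no mp_list)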
+            train_x, train_y, val_x, val_y, test_x, test_y: train/val/test indices and labels
+        '''
+        features = torch.from_numpy(self.features).type(torch.FloatTensor).to(cf.dev)
+        train_x, train_y, val_x, val_y, test_x, test_y = self.get_label(cf.dev)
+
+        adj = np.sum(list(self.edges.values())).todense()
+        adj = torch.from_numpy(adj).type(torch.FloatTensor).to(cf.dev)
+        adj = F.normalize(adj, dim=1, p=2)
+
+        mp_emb = {}
+        if hasattr(cf, 'mp_list'):
+            for mp in cf.mp_list:
+                mp_emb[mp] = torch.from_numpy(self.mp_emb_dict[mp]).type(torch.FloatTensor).to(cf.dev)
+        if hasattr(cf, 'feat_norm'):
+            if cf.feat_norm > 0:
+                features = F.normalize(features, dim=1, p=cf.feat_norm)
+                for mp in cf.mp_list:
+                    mp_emb[mp] = F.normalize(mp_emb[mp], dim=1, p=cf.feat_norm)
+        return features, adj, mp_emb, train_x, train_y, val_x, val_y, test_x, test_y
+
+    def load_mp_embedding(self, cf):
+        '''Load the pretrained meta-path embeddings.'''
+        self.mp_emb_dict = {}
+        for mp in cf.mp_list:
+            f_name = f'{cf.data_path}{mp}_emb.pkl'
+            with open(f_name, 'rb') as f:
+                z = pickle.load(f)
+                zero_lines = np.nonzero(np.sum(z, 1) == 0)[0]
+                if len(zero_lines) > 0:
+                    z[zero_lines, :] += 1e-8
+                self.mp_emb_dict[mp] = z
+        return self
+
+    def get_label(self, dev):
+        '''
+        Args:
+            dev: device (cpu or gpu)
+
+        Returns:
+            train_x, train_y, val_x, val_y, test_x, test_y: train/val/test indices and labels
+        '''
+        train_x = torch.from_numpy(np.array(self.labels[0])[:, 0]).type(torch.LongTensor).to(dev)
+        train_y = torch.from_numpy(np.array(self.labels[0])[:, 1]).type(torch.LongTensor).to(dev)
+        val_x = torch.from_numpy(np.array(self.labels[1])[:, 0]).type(torch.LongTensor).to(dev)
+        val_y = torch.from_numpy(np.array(self.labels[1])[:, 1]).type(torch.LongTensor).to(dev)
+        test_x = torch.from_numpy(np.array(self.labels[2])[:, 0]).type(torch.LongTensor).to(dev)
+        test_y = torch.from_numpy(np.array(self.labels[2])[:, 1]).type(torch.LongTensor).to(dev)
+        return train_x, train_y, val_x, val_y, test_x, test_y
diff --git a/src/shared_configs.py b/src/shared_configs.py
new file mode 100644
index 0000000..792ef3b
--- /dev/null
+++ b/src/shared_configs.py
@@ -0,0 +1,65 @@
+import util_funcs as uf
+
+
+class ModelConfig:
+    def __init__(self, model):
+        self.model = model
+        self.exp_name = 'default'
+        self.model_conf_list = None
+
+    def __str__(self):
+        # Print all attributes, including data and other path settings added to the config object.
+        return str(self.__dict__)
+
+    def save_model_conf_list(self):
+        self.model_conf_list = list(self.__dict__.copy().keys())
+        self.model_conf_list.remove('model_conf_list')
+
+    def update_file_conf(self):
+        f_conf = FileConfig(self)
+        self.__dict__.update(f_conf.__dict__)
+        return self
+
+    def model_conf_to_str(self):
+        # Print the model settings only.
+        return str({k: self.__dict__[k] for k in self.model_conf_list})
+
+    def get_model_conf(self):
+        # Return the model settings only.
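+        # These hyperparameters are written into the result file together with the scores (see save_results in evaluation.py).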
+        return {k: self.__dict__[k] for k in self.model_conf_list}
+
+    def update(self, conf_dict):
+        self.__dict__.update(conf_dict)
+        self.update_file_conf()
+        return self
+
+
+class DataConfig:
+    def __init__(self, dataset):
+        data_conf = {
+            'acm': {'data_type': 'pas', 'relation_list': 'p-a+a-p+p-s+s-p'},
+            'dblp': {'data_type': 'apc', 'relation_list': 'p-a+a-p+p-c+c-p'},
+            'imdb': {'data_type': 'mad', 'relation_list': 'm-a+a-m+m-d+d-m'},
+            'aminer': {'data_type': 'apr', 'relation_list': 'p-a+p-r+a-p+r-p'},
+            'yelp': {'data_type': 'busl', 'relation_list': 'b-u+u-b+b-s+s-b+b-l+l-b'}
+        }
+        self.__dict__.update(data_conf[dataset])
+        self.dataset = dataset
+        self.data_path = f'data/{dataset}/'
+
+
+class FileConfig:
+
+    def __init__(self, cf: ModelConfig):
+        '''
+        1. Set the f_prefix for each model. The f_prefix stores the important (tuned) hyperparameters of the model.
+        2. Generate the file names using the f_prefix.
+        3. Create the required directories.
+        '''
+        if cf.model[:4] == 'HGSL':
+            f_prefix = f'do{cf.dropout}_lr{cf.lr}_a{cf.alpha}_tr{cf.fgd_th}-{cf.fgh_th}-{cf.sem_th}_mpl{uf.mp_list_str(cf.mp_list)}'
+        self.res_file = f'results/{cf.dataset}/{cf.model}/{cf.exp_name}/<{cf.model}>{f_prefix}.txt'
+        self.checkpoint_file = f'temp/{cf.model}/{cf.dataset}/{f_prefix}{uf.get_cur_time()}.pt'
+        uf.mkdir_list(self.__dict__.values())
diff --git a/src/train.py b/src/train.py
new file mode 100644
index 0000000..464aef8
--- /dev/null
+++ b/src/train.py
@@ -0,0 +1,94 @@
+import os
+import sys
+
+cur_path = os.path.abspath(os.path.dirname(__file__))
+root_path = cur_path.split('src')[0]
+sys.path.append(root_path + 'src')
+os.chdir(root_path)
+from early_stopper import *
+from hin_loader import HIN
+from evaluation import *
+import util_funcs as uf
+from config import HGSLConfig
+from HGSL import HGSL
+import warnings
+import time
+import torch
+import argparse
+
+warnings.filterwarnings('ignore')
+
+
+def train_hgsl(args, gpu_id=0, log_on=True):
+    uf.seed_init(args.seed)
+    uf.shell_init(gpu_id=gpu_id)
+    cf = HGSLConfig(args.dataset)
+
+    # ! Modify config
+    cf.update(args.__dict__)
+    cf.dev = torch.device("cuda:0" if gpu_id >= 0 else "cpu")
+
+    # ! Load Graph
+    g = HIN(cf.dataset).load_mp_embedding(cf)
+    print(f'Dataset: {cf.dataset}, {g.t_info}')
+    features, adj, mp_emb, train_x, train_y, val_x, val_y, test_x, test_y = g.to_torch(cf)
+
+    # ! Train Init
+    if not log_on: uf.block_logs()
+    print(f'{cf}\nStart training..')
+    cla_loss = torch.nn.NLLLoss()
+    model = HGSL(cf, g)
+    model.to(cf.dev)
+    optimizer = torch.optim.Adam(
+        model.parameters(), lr=cf.lr, weight_decay=cf.weight_decay)
+    stopper = EarlyStopping(patience=cf.early_stop, path=cf.checkpoint_file)
+
+    dur = []
+    w_list = []
+    for epoch in range(cf.epochs):
+        # ! Train
+        t0 = time.time()
+        model.train()
+        logits, adj_new = model(features, adj, mp_emb)
+        train_f1, train_mif1 = eval_logits(logits, train_x, train_y)
+        w_list.append(uf.print_weights(model))
+
+        l_pred = cla_loss(logits[train_x], train_y)
+        l_reg = cf.alpha * torch.norm(adj, 1)
+        loss = l_pred + l_reg
+        optimizer.zero_grad()
+        with torch.autograd.detect_anomaly():
+            loss.backward()
+        optimizer.step()
+
+        # ! Validate
+        model.eval()
+        with torch.no_grad():
+            logits = model.GCN(features, adj_new)
+            val_f1, val_mif1 = eval_logits(logits, val_x, val_y)
+        dur.append(time.time() - t0)
+        uf.print_train_log(epoch, dur, loss, train_f1, val_f1)
+
+        if cf.early_stop > 0:
+            if stopper.step(val_f1, model, epoch):
+                print(f'Early stopped, loading model from epoch-{stopper.best_epoch}')
+                break
+
+    if cf.early_stop > 0:
+        model.load_state_dict(torch.load(cf.checkpoint_file))
+    logits, _ = model(features, adj, mp_emb)
+    cf.update(w_list[stopper.best_epoch])
+    eval_and_save(cf, logits, test_x, test_y, val_x, val_y, stopper)
+    if not log_on: uf.enable_logs()
+    return cf
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    dataset = 'yelp'
+    parser.add_argument('--dataset', type=str, default=dataset)
+    parser.add_argument('--gpu_id', type=int, default=0)
+    parser.add_argument('--seed', type=int, default=0)
+    args = parser.parse_args()
+    train_hgsl(args, gpu_id=args.gpu_id)
diff --git a/src/util_funcs.py b/src/util_funcs.py
new file mode 100644
index 0000000..b72c9d8
--- /dev/null
+++ b/src/util_funcs.py
@@ -0,0 +1,264 @@
+import logging
+import os
+import pickle
+import sys
+import time
+from pathlib import Path
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+
+# * ============================= Init =============================
+def shell_init(server='S5', gpu_id=0):
+    '''
+    Features:
+    1. Specify the server-specific source and python command
+    2. Fix the PyCharm LD_LIBRARY_PATH issue
+    3. Block warnings
+    4. Block useless TF messages
+    5. Set paths
+    '''
+    import warnings
+    np.seterr(invalid='ignore')
+    warnings.filterwarnings("ignore", category=FutureWarning)
+    warnings.filterwarnings("ignore", category=UserWarning)
+    warnings.filterwarnings("ignore", category=RuntimeWarning)
+
+    if server == 'Xy':
+        python_command = '/home/chopin/zja/anaconda/bin/python'
+    elif server == 'Colab':
+        python_command = 'python'
+    else:
+        python_command = '~/anaconda3/bin/python'
+    if gpu_id > 0:
+        os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)
+    os.environ['LD_LIBRARY_PATH'] = '/usr/local/cuda/lib64/'  # Extremely useful for PyCharm users
+    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # Block TF messages
+
+    return python_command
+
+
+def seed_init(seed):
+    import torch
+    import random
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+
+
+# * ============================= Torch =============================
+
+def exists_zero_lines(h):
+    zero_lines = torch.where(torch.sum(h, 1) == 0)[0]
+    if len(zero_lines) > 0:
+        print(f'{len(zero_lines)} zero lines !\nZero lines:{zero_lines}')
+        return True
+    return False
+
+
+def cos_sim(a, b, eps=1e-8):
+    """
+    Calculate the pairwise cosine similarity between matrices a and b.
+    """
+    a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None]
+    a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n))
+    b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n))
+    sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1))
+    return sim_mt
+
+
+# * ============================= Print Related =============================
+
+def print_dict(d, end_string='\n\n'):
+    for key in d.keys():
+        if isinstance(d[key], dict):
+            print('\n', end='')
+            print_dict(d[key], end_string='')
+        elif isinstance(d[key], int):
+            print('{}: {:04d}'.format(key, d[key]), end=', ')
+        elif isinstance(d[key], float):
+            print('{}: {:.4f}'.format(key, d[key]), end=', ')
+        else:
+            print('{}: {}'.format(key, d[key]), end=', ')
+    print(end_string, end='')
+
+
+def block_logs():
+    sys.stdout = open(os.devnull, 'w')
+    logger = logging.getLogger()
+    logger.disabled = True
+
+
+def enable_logs():
+    # Restore stdout and logging
+    sys.stdout = sys.__stdout__
+    logger = logging.getLogger()
+    logger.disabled = False
+
+
+def progress_bar(prefix, start_time, i, max_i, postfix):
+    """
+    Generates a progress bar AFTER the i-th epoch.
+    Args:
+        prefix: the prefix of the printed string
+        start_time: start time of the loop
+        i: finished epoch index
+        max_i: total number of iterations
+        postfix: the postfix of the printed string
+
+    Returns: prints the generated progress bar
+    """
+    cur_run_time = time.time() - start_time
+    i += 1
+    if i != 0:
+        total_estimated_time = cur_run_time * max_i / i
+    else:
+        total_estimated_time = 0
+    print(
+        f'{prefix} : {i}/{max_i} [{time2str(cur_run_time)}/{time2str(total_estimated_time)}, {time2str(total_estimated_time - cur_run_time)} left] - {postfix}-{get_cur_time()}')
+
+
+def print_train_log(epoch, dur, loss, train_f1, val_f1):
+    print(
+        f"Epoch {epoch:05d} | Time(s) {np.mean(dur):.4f} | Loss {loss.item():.4f} | TrainF1 {train_f1:.4f} | ValF1 {val_f1:.4f}")
+
+
+def mp_list_str(mp_list):
+    return '_'.join(mp_list)
+
+
+# * ============================= File Operations =============================
+
+def write_nested_dict(d, f_path):
+    def _write_dict(d, f):
+        for key in d.keys():
+            if isinstance(d[key], dict):
+                f.write(str(d[key]) + '\n')
+
+    with open(f_path, 'a+') as f:
+        f.write('\n')
+        _write_dict(d, f)
+
+
+def save_pickle(var, f_name):
+    pickle.dump(var, open(f_name, 'wb'))
+
+
+def load_pickle(f_name):
+    return pickle.load(open(f_name, 'rb'))
+
+
+def clear_results(dataset, model):
+    res_path = f'results/{dataset}/{model}/'
+    os.system(f'rm -rf {res_path}')
+    print(f'Results in {res_path} are cleared.')
+
+
+# * ============================= Path Operations =============================
+
+def check_path(path):
+    if not os.path.exists(path):
+        os.makedirs(path)
+
+
+def get_dir_of_file(f_name):
+    return os.path.dirname(f_name) + '/'
+
+
+def get_grand_parent_dir(f_name):
+    if '.' in f_name.split('/')[-1]:  # File
+        return get_grand_parent_dir(get_dir_of_file(f_name))
+    else:  # Path
+        return f'{Path(f_name).parent}/'
+
+
+def get_abs_path(f_name, style='command_line'):
+    # Python handles spaces in paths literally, while the shell escapes them as '\ ',
+    # so command-line paths need replace(' ', '\ ').
+    if style == 'python':
+        cur_path = os.path.abspath(os.path.dirname(__file__))
+    elif style == 'command_line':
+        cur_path = os.path.abspath(os.path.dirname(__file__)).replace(' ', '\ ')
+
+    root_path = cur_path.split('src')[0]
+    return os.path.join(root_path, f_name)
+
+
+def mkdir_p(path, log=True):
+    """Create a directory for the specified path.
+    Parameters
+    ----------
+    path : str
+        Path name
+    log : bool
+        Whether to print the result of the directory creation
+    """
+    import errno
+    if os.path.exists(path): return
+    try:
+        os.makedirs(path)
+        if log:
+            print('Created directory {}'.format(path))
+    except OSError as exc:
+        if exc.errno == errno.EEXIST and os.path.isdir(path) and log:
+            print('Directory {} already exists.'.format(path))
+        else:
+            raise
+
+
+def mkdir_list(p_list, use_relative_path=True, log=True):
+    """Create directories for the specified path lists.
+    Parameters
+    ----------
+    p_list : Path lists
+    """
+    # ! Note that the paths MUST END WITH '/' !!!
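+    # Without the trailing '/', os.path.dirname() below would strip the last directory instead of the file-name part.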
+    root_path = os.path.abspath(os.path.dirname(__file__)).split('src')[0]
+    for p in p_list:
+        p = os.path.join(root_path, p) if use_relative_path else p
+        p = os.path.dirname(p)
+        mkdir_p(p, log)
+
+
+# * ============================= Time Related =============================
+
+def time2str(t):
+    if t > 86400:
+        return '{:.2f}day'.format(t / 86400)
+    if t > 3600:
+        return '{:.2f}h'.format(t / 3600)
+    elif t > 60:
+        return '{:.2f}min'.format(t / 60)
+    else:
+        return '{:.2f}s'.format(t)
+
+
+def get_cur_time():
+    import datetime
+    dt = datetime.datetime.now()
+    return f'{dt.date()}_{dt.hour:02d}-{dt.minute:02d}-{dt.second:02d}'
+
+
+# * ============================= Others =============================
+def print_weights(model, interested_para='_agg'):
+    w_dict = {}
+    for name, W in model.named_parameters():
+        if interested_para in name:
+            data = F.softmax(W.data.squeeze(), dim=0).cpu().numpy()
+            w_dict[name] = data
+    return w_dict
+
+
+def count_avg_neighbors(adj):
+    return len(torch.where(adj > 0)[0]) / adj.shape[0]