diff --git a/README.md b/README.md
index 4cf0b29..7dc3eb7 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,14 @@
-# AAAI21-HGSL
-Source code of AAAI21-Heterogeneous Graph Structure Learning for Graph Neural Networks
-To be released soon...
+# HGSL
+Source code of the AAAI 2021 paper "Heterogeneous Graph Structure Learning for Graph Neural Networks".
+# Requirements
+## Python Packages
+- Python >= 3.6.8
+- PyTorch >= 1.3.0
+- DGL == 0.4.3
+## GPU Memory Requirements
+- ACM >= 8 GB
+- DBLP >= 5 GB
+- Yelp >= 3 GB
+# Usage
+Take the DBLP dataset as an example:
+`python train.py --dataset='dblp'`
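+
+Other datasets follow the same pattern; the `--gpu_id` and `--seed` flags defined in `src/train.py` can also be set, e.g. `python train.py --dataset='acm' --gpu_id=0 --seed=0`.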
diff --git a/data/acm/edges.pkl b/data/acm/edges.pkl
new file mode 100644
index 0000000..2af3175
Binary files /dev/null and b/data/acm/edges.pkl differ
diff --git a/data/acm/labels.pkl b/data/acm/labels.pkl
new file mode 100644
index 0000000..ccdaa85
Binary files /dev/null and b/data/acm/labels.pkl differ
diff --git a/data/acm/meta_data.pkl b/data/acm/meta_data.pkl
new file mode 100644
index 0000000..1dc25fb
Binary files /dev/null and b/data/acm/meta_data.pkl differ
diff --git a/data/acm/node_features.pkl b/data/acm/node_features.pkl
new file mode 100644
index 0000000..2ebd2d1
Binary files /dev/null and b/data/acm/node_features.pkl differ
diff --git a/data/acm/pap_emb.pkl b/data/acm/pap_emb.pkl
new file mode 100644
index 0000000..cb32505
Binary files /dev/null and b/data/acm/pap_emb.pkl differ
diff --git a/data/acm/psp_emb.pkl b/data/acm/psp_emb.pkl
new file mode 100644
index 0000000..8179be5
Binary files /dev/null and b/data/acm/psp_emb.pkl differ
diff --git a/data/acm/pspap_emb.pkl b/data/acm/pspap_emb.pkl
new file mode 100644
index 0000000..05554f1
Binary files /dev/null and b/data/acm/pspap_emb.pkl differ
diff --git a/data/dblp/apcpa_emb.pkl b/data/dblp/apcpa_emb.pkl
new file mode 100644
index 0000000..5727e77
Binary files /dev/null and b/data/dblp/apcpa_emb.pkl differ
diff --git a/data/dblp/edges.pkl b/data/dblp/edges.pkl
new file mode 100644
index 0000000..20d9a88
Binary files /dev/null and b/data/dblp/edges.pkl differ
diff --git a/data/dblp/labels.pkl b/data/dblp/labels.pkl
new file mode 100644
index 0000000..bd6023b
Binary files /dev/null and b/data/dblp/labels.pkl differ
diff --git a/data/dblp/meta_data.pkl b/data/dblp/meta_data.pkl
new file mode 100644
index 0000000..bbf3b24
Binary files /dev/null and b/data/dblp/meta_data.pkl differ
diff --git a/data/dblp/node_features.pkl b/data/dblp/node_features.pkl
new file mode 100644
index 0000000..549bc6c
Binary files /dev/null and b/data/dblp/node_features.pkl differ
diff --git a/data/yelp/bsb_emb.pkl b/data/yelp/bsb_emb.pkl
new file mode 100644
index 0000000..b7eed31
Binary files /dev/null and b/data/yelp/bsb_emb.pkl differ
diff --git a/data/yelp/bub_emb.pkl b/data/yelp/bub_emb.pkl
new file mode 100644
index 0000000..60008e1
Binary files /dev/null and b/data/yelp/bub_emb.pkl differ
diff --git a/data/yelp/bublb_emb.pkl b/data/yelp/bublb_emb.pkl
new file mode 100644
index 0000000..b653300
Binary files /dev/null and b/data/yelp/bublb_emb.pkl differ
diff --git a/data/yelp/bubsb_emb.pkl b/data/yelp/bubsb_emb.pkl
new file mode 100644
index 0000000..60ae4d3
Binary files /dev/null and b/data/yelp/bubsb_emb.pkl differ
diff --git a/data/yelp/edges.pkl b/data/yelp/edges.pkl
new file mode 100644
index 0000000..c6741a8
Binary files /dev/null and b/data/yelp/edges.pkl differ
diff --git a/data/yelp/labels.pkl b/data/yelp/labels.pkl
new file mode 100644
index 0000000..12ce561
Binary files /dev/null and b/data/yelp/labels.pkl differ
diff --git a/data/yelp/meta_data.pkl b/data/yelp/meta_data.pkl
new file mode 100644
index 0000000..c909861
Binary files /dev/null and b/data/yelp/meta_data.pkl differ
diff --git a/data/yelp/node_features.pkl b/data/yelp/node_features.pkl
new file mode 100644
index 0000000..0240038
Binary files /dev/null and b/data/yelp/node_features.pkl differ
diff --git a/src/HGSL.py b/src/HGSL.py
new file mode 100644
index 0000000..f6982b5
--- /dev/null
+++ b/src/HGSL.py
@@ -0,0 +1,200 @@
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.parameter import Parameter
+
+from util_funcs import cos_sim
+
+
+class HGSL(nn.Module):
+    """
+    Decode neighbors of input graph.
+    """
+
+    def __init__(self, cf, g):
+        super(HGSL, self).__init__()
+        self.__dict__.update(cf.get_model_conf())
+        # ! Init variables
+        self.dev = cf.dev
+        self.ti, self.ri, self.types, self.ud_rels = g.t_info, g.r_info, g.types, g.undirected_relations
+        feat_dim, mp_emb_dim = g.features.shape[1], list(g.mp_emb_dict.values())[0].shape[1]
+        self.non_linear = nn.ReLU()
+        # ! Graph Structure Learning
+        MD = nn.ModuleDict
+        self.fgg_direct, self.fgg_left, self.fgg_right, self.fg_agg, self.sgg_gen, self.sg_agg, self.overall_g_agg = \
+            MD({}), MD({}), MD({}), MD({}), MD({}), MD({}), MD({})
+        # Feature encoder
+        self.encoder = MD(dict(zip(g.types, [nn.Linear(g.features.shape[1], cf.com_feat_dim) for _ in g.types])))
+
+        for r in g.undirected_relations:
+            # ! Feature Graph Generator
+            self.fgg_direct[r] = GraphGenerator(cf.com_feat_dim, cf.num_head, cf.fgd_th, self.dev)
+            self.fgg_left[r] = GraphGenerator(feat_dim, cf.num_head, cf.fgh_th, self.dev)
+            self.fgg_right[r] = GraphGenerator(feat_dim, cf.num_head, cf.fgh_th, self.dev)
+            self.fg_agg[r] = GraphChannelAttLayer(3)  # 3 = 1 (first-order/direct) + 2 (second-order)
+
+            # ! Semantic Graph Generator
+            self.sgg_gen[r] = MD(dict(
+                zip(cf.mp_list, [GraphGenerator(mp_emb_dim, cf.num_head, cf.sem_th, self.dev) for _ in cf.mp_list])))
+            self.sg_agg[r] = GraphChannelAttLayer(len(cf.mp_list))
+
+            # ! Overall Graph Generator
+            self.overall_g_agg[r] = GraphChannelAttLayer(3, [1, 1, 10])  # 3 = feat-graph + sem-graph + ori_graph
+            # self.overall_g_agg[r] = GraphChannelAttLayer(3)  # 3 = feat-graph + sem-graph + ori_graph
+
+        # ! Graph Convolution
+        if cf.conv_method == 'gcn':
+            self.GCN = GCN(g.n_feat, cf.emb_dim, g.n_class, cf.dropout)
+        self.norm_order = cf.adj_norm_order
+
+    def forward(self, features, adj_ori, mp_emb):
+        def get_rel_mat(mat, r):
+            return mat[self.ri[r][0]:self.ri[r][1], self.ri[r][2]:self.ri[r][3]]
+
+        def get_type_rows(mat, type):
+            return mat[self.ti[type]['ind'], :]
+
+        def gen_g_via_feat(graph_gen_func, mat, r):
+            return graph_gen_func(get_type_rows(mat, r[0]), get_type_rows(mat, r[-1]))
+
+        # ! Heterogeneous Feature Mapping
+        com_feat_mat = torch.cat([self.non_linear(
+            self.encoder[t](features[self.ti[t]['ind']])) for t in self.types])
+
+        # ! Heterogeneous Graph Generation
+        new_adj = torch.zeros_like(adj_ori).to(self.dev)
+        for r in self.ud_rels:
+            ori_g = get_rel_mat(adj_ori, r)
+            # ! Feature Graph Generation
+            fg_direct = gen_g_via_feat(self.fgg_direct[r], com_feat_mat, r)
+
+            fmat_l, fmat_r = features[self.ti[r[0]]['ind']], features[self.ti[r[-1]]['ind']]
+            sim_l, sim_r = self.fgg_left[r](fmat_l, fmat_l), self.fgg_right[r](fmat_r, fmat_r)
+            fg_left, fg_right = sim_l.mm(ori_g), sim_r.mm(ori_g.t()).t()
+
+            feat_g = self.fg_agg[r]([fg_direct, fg_left, fg_right])
+
+            # ! Semantic Graph Generation
+            sem_g_list = [gen_g_via_feat(self.sgg_gen[r][mp], mp_emb[mp], r) for mp in mp_emb]
+            sem_g = self.sg_agg[r](sem_g_list)
+            # ! Overall Graph
+            # Update relation sub-matrices
+            new_adj[self.ri[r][0]:self.ri[r][1], self.ri[r][2]:self.ri[r][3]] = \
+                self.overall_g_agg[r]([feat_g, sem_g, ori_g])  # update edges, e.g. A-P
+
+        new_adj += new_adj.t()  # symmetrize
+        # ! Aggregate
+        new_adj = F.normalize(new_adj, dim=0, p=self.norm_order)
+        logits = self.GCN(features, new_adj)
+        return logits, new_adj
+
+
+class MetricCalcLayer(nn.Module):
+    def __init__(self, nhid):
+        super().__init__()
+        self.weight = nn.Parameter(torch.FloatTensor(1, nhid))
+        nn.init.xavier_uniform_(self.weight)
+
+    def forward(self, h):
+        return h * self.weight
+
+
+class GraphGenerator(nn.Module):
+    """
+    Generate graph using similarity.
+    """
+
+    def __init__(self, dim, num_head=2, threshold=0.1, dev=None):
+        super(GraphGenerator, self).__init__()
+        self.threshold = threshold
+        self.metric_layer = nn.ModuleList()
+        for i in range(num_head):
+            self.metric_layer.append(MetricCalcLayer(dim))
+        self.num_head = num_head
+        self.dev = dev
+
+    def forward(self, left_h, right_h):
+        """
+
+        Args:
+            left_h: left_node_num * hidden_dim/feat_dim
+            right_h: right_node_num * hidden_dim/feat_dim
+        Returns:
+
+        """
+        if torch.sum(left_h) == 0 or torch.sum(right_h) == 0:
+            return torch.zeros((left_h.shape[0], right_h.shape[0])).to(self.dev)
+        s = torch.zeros((left_h.shape[0], right_h.shape[0])).to(self.dev)
+        zero_lines = torch.nonzero(torch.sum(left_h, 1) == 0)
+        # ReLU can produce all-zero rows, which cause NaN (division by zero) in the cosine similarity.
+        if len(zero_lines) > 0:
+            left_h[zero_lines, :] += 1e-8
+        for i in range(self.num_head):
+            weighted_left_h = self.metric_layer[i](left_h)
+            weighted_right_h = self.metric_layer[i](right_h)
+            s += cos_sim(weighted_left_h, weighted_right_h)
+        s /= self.num_head
+        s = torch.where(s < self.threshold, torch.zeros_like(s), s)
+        return s
+
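+# A minimal usage sketch of GraphGenerator (hypothetical shapes, not called by
+# the training code): each metric head re-weights the inputs, and the averaged
+# cosine similarity is thresholded into a candidate graph.
+#
+#   gen = GraphGenerator(dim=16, num_head=2, threshold=0.1, dev=torch.device('cpu'))
+#   s = gen(torch.rand(4, 16), torch.rand(5, 16))  # -> 4 x 5, entries < 0.1 zeroed
+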
+
+class GCN(nn.Module):
+    def __init__(self, nfeat, nhid, nclass, dropout):
+        super(GCN, self).__init__()
+        self.gc1 = GraphConvolution(nfeat, nhid)
+        self.gc2 = GraphConvolution(nhid, nclass)
+        self.dropout = dropout
+
+    def forward(self, x, adj):
+        x = F.relu(self.gc1(x, adj))
+        x = F.dropout(x, self.dropout, training=self.training)
+        x = self.gc2(x, adj)
+        return F.log_softmax(x, dim=1)
+        # return x
+
+
+class GraphChannelAttLayer(nn.Module):
+
+    def __init__(self, num_channel, weights=None):
+        super(GraphChannelAttLayer, self).__init__()
+        self.weight = nn.Parameter(torch.Tensor(num_channel, 1, 1))
+        nn.init.constant_(self.weight, 0.1)  # equal weight
+        # if weights != None:
+        #     # self.weight.data = nn.Parameter(torch.Tensor(weights).reshape(self.weight.shape))
+        #     with torch.no_grad():
+        #         w = torch.Tensor(weights).reshape(self.weight.shape)
+        #         self.weight.copy_(w)
+
+    def forward(self, adj_list):
+        adj_list = torch.stack(adj_list)
+        # Row normalization of all graphs generated
+        adj_list = F.normalize(adj_list, dim=1, p=1)
+        # Hadamard product + summation -> Conv
+        return torch.sum(adj_list * F.softmax(self.weight, dim=0), dim=0)
+
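+# A minimal usage sketch of GraphChannelAttLayer (hypothetical tensors): the
+# stacked channels are row-normalized, then fused with softmax channel weights.
+#
+#   att = GraphChannelAttLayer(num_channel=3)
+#   fused = att([torch.rand(5, 5), torch.rand(5, 5), torch.eye(5)])  # -> 5 x 5
+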
+class GraphConvolution(nn.Module):  # GCN AHW
+    def __init__(self, in_features, out_features, bias=True):
+        super(GraphConvolution, self).__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
+        if bias:
+            self.bias = Parameter(torch.FloatTensor(out_features))
+        else:
+            self.register_parameter('bias', None)
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        stdv = 1. / math.sqrt(self.weight.size(1))
+        self.weight.data.uniform_(-stdv, stdv)
+        if self.bias is not None:
+            self.bias.data.uniform_(-stdv, stdv)
+
+    def forward(self, inputs, adj):
+        support = torch.spmm(inputs, self.weight)  # HW in GCN
+        output = torch.spmm(adj, support)  # AHW
+        if self.bias is not None:
+            return output + self.bias
+        else:
+            return output
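+
+# A minimal usage sketch of GraphConvolution (hypothetical tensors), mirroring
+# the A(HW) propagation implemented above:
+#
+#   gc = GraphConvolution(in_features=8, out_features=4)
+#   out = gc(torch.rand(5, 8), torch.eye(5))  # -> 5 x 4 node representations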
diff --git a/src/__init__.py b/src/__init__.py
new file mode 100644
index 0000000..99dbaa4
--- /dev/null
+++ b/src/__init__.py
@@ -0,0 +1,2 @@
+from .config import *
+from .HGSL import HGSL
\ No newline at end of file
diff --git a/src/config.py b/src/config.py
new file mode 100644
index 0000000..6c16a61
--- /dev/null
+++ b/src/config.py
@@ -0,0 +1,35 @@
+from shared_configs import ModelConfig, DataConfig
+
+e = 2.71828
+
+
+class HGSLConfig(ModelConfig):
+    def __init__(self, dataset, seed=0):
+        super(HGSLConfig, self).__init__('HGSL')
+        default_settings = \
+            {'acm': {'alpha': 1, 'dropout': 0, 'fgd_th': 0.8, 'fgh_th': 0.2, 'sem_th': 0.6,
+                     'mp_list': ['psp', 'pap', 'pspap']},
+             'dblp': {'alpha': 4.5, 'dropout': 0.2, 'fgd_th': 0.99, 'fgh_th': 0.99, 'sem_th': 0.4, 'mp_list': ['apcpa']},
+             'yelp': {'alpha': 0.5, 'dropout': 0.2, 'fgd_th': 0.8, 'fgh_th': 0.1, 'sem_th': 0.2,
+                      'mp_list': ['bub', 'bsb', 'bublb', 'bubsb']}
+             }
+        self.dataset = dataset
+        self.__dict__.update(default_settings[dataset])
+        # ! Model settings
+        self.lr = 0.01
+        self.seed = seed
+        self.save_model_conf_list()  # * Save the model config list keys
+        self.conv_method = 'gcn'
+        self.num_head = 2
+        self.early_stop = 80
+        self.adj_norm_order = 1
+        self.feat_norm = -1
+        self.emb_dim = 64
+        self.com_feat_dim = 16
+        self.weight_decay = 5e-4
+        self.model = 'HGSL'
+        self.epochs = 200
+        self.exp_name = 'debug'
+        self.save_weights = False
+        d_conf = DataConfig(dataset)
+        self.__dict__.update(d_conf.__dict__)
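+
+# A minimal usage sketch: dataset-specific defaults are applied first, then the
+# shared model settings, e.g.
+#
+#   cf = HGSLConfig('dblp', seed=0)
+#   cf.mp_list  # -> ['apcpa']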
diff --git a/src/early_stopper.py b/src/early_stopper.py
new file mode 100644
index 0000000..51ff844
--- /dev/null
+++ b/src/early_stopper.py
@@ -0,0 +1,39 @@
+"""
+Early stop provided by DGL
+"""
+import numpy as np
+import torch
+
+
+class EarlyStopping:
+    def __init__(self, patience=10, path='es_checkpoint.pt'):
+        self.patience = patience
+        self.counter = 0
+        self.best_score = None
+        self.best_epoch = None
+        self.early_stop = False
+        self.path = path
+
+    def step(self, acc, model, epoch):
+        score = acc
+        if self.best_score is None:
+            self.best_score = score
+            self.best_epoch = epoch
+            self.save_checkpoint(model)
+        elif score < self.best_score:
+            self.counter += 1
+            print(
+                f'EarlyStopping counter: {self.counter}/{self.patience}, best_val_score:{self.best_score:.4f} at E{self.best_epoch}')
+            if self.counter >= self.patience:
+                self.early_stop = True
+        else:
+            self.best_score = score
+            self.best_epoch = epoch
+            self.save_checkpoint(model)
+            self.counter = 0
+
+        return self.early_stop
+
+    def save_checkpoint(self, model):
+        '''Saves the model when the validation score improves.'''
+        torch.save(model.state_dict(), self.path)
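+
+# A minimal usage sketch (hypothetical model and scores): step() returns True
+# once the validation score has not improved for `patience` consecutive epochs.
+#
+#   stopper = EarlyStopping(patience=2, path='es_checkpoint.pt')
+#   for epoch, acc in enumerate([0.5, 0.6, 0.55, 0.58]):
+#       if stopper.step(acc, model, epoch):
+#           break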
diff --git a/src/evaluation.py b/src/evaluation.py
new file mode 100644
index 0000000..8f4d327
--- /dev/null
+++ b/src/evaluation.py
@@ -0,0 +1,87 @@
+import torch
+import util_funcs as uf
+
+
+def torch_f1_score(pred, target, n_class):
+    '''
+    Returns macro-F1 and micro-F1 scores.
+    Args:
+        pred: predicted labels
+        target: ground-truth labels
+        n_class: number of classes
+
+    Returns:
+        ma_f1, mi_f1: numpy values of the macro-F1 and micro-F1 scores.
+    '''
+
+    def true_positive(pred, target, n_class):
+        return torch.tensor([((pred == i) & (target == i)).sum()
+                             for i in range(n_class)])
+
+    def false_positive(pred, target, n_class):
+        return torch.tensor([((pred == i) & (target != i)).sum()
+                             for i in range(n_class)])
+
+    def false_negative(pred, target, n_class):
+        return torch.tensor([((pred != i) & (target == i)).sum()
+                             for i in range(n_class)])
+
+    def precision(tp, fp):
+        res = tp / (tp + fp)
+        res[torch.isnan(res)] = 0
+        return res
+
+    def recall(tp, fn):
+        res = tp / (tp + fn)
+        res[torch.isnan(res)] = 0
+        return res
+
+    def f1_score(prec, rec):
+        f1_score = 2 * (prec * rec) / (prec + rec)
+        f1_score[torch.isnan(f1_score)] = 0
+        return f1_score
+
+    def cal_maf1(tp, fp, fn):
+        prec = precision(tp, fp)
+        rec = recall(tp, fn)
+        ma_f1 = f1_score(prec, rec)
+        return torch.mean(ma_f1).cpu().numpy()
+
+    def cal_mif1(tp, fp, fn):
+        gl_tp, gl_fp, gl_fn = torch.sum(tp), torch.sum(fp), torch.sum(fn)
+        gl_prec = precision(gl_tp, gl_fp)
+        gl_rec = recall(gl_tp, gl_fn)
+        mi_f1 = f1_score(gl_prec, gl_rec)
+        return mi_f1.cpu().numpy()
+
+    tp = true_positive(pred, target, n_class).to(torch.float)
+    fn = false_negative(pred, target, n_class).to(torch.float)
+    fp = false_positive(pred, target, n_class).to(torch.float)
+
+    ma_f1, mi_f1 = cal_maf1(tp, fp, fn), cal_mif1(tp, fp, fn)
+    return ma_f1, mi_f1
+
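+# A minimal sanity check (hypothetical labels): perfect predictions yield both
+# macro-F1 and micro-F1 of 1.0.
+#
+#   y = torch.tensor([0, 1, 2, 1])
+#   ma_f1, mi_f1 = torch_f1_score(y, y, n_class=3)  # -> 1.0, 1.0
+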
+
+def eval_logits(logits, target_x, target_y):
+    pred_y = torch.argmax(logits[target_x], dim=1)
+    return torch_f1_score(pred_y, target_y, n_class=logits.shape[1])
+
+
+def eval_and_save(cf, logits, test_x, test_y, val_x, val_y, stopper=None, res=None):
+    res = {} if res is None else res  # avoid the mutable default argument pitfall
+    test_f1, test_mif1 = eval_logits(logits, test_x, test_y)
+    val_f1, val_mif1 = eval_logits(logits, val_x, val_y)
+    save_results(cf, test_f1, val_f1, test_mif1, val_mif1, stopper, res)
+
+
+def save_results(cf, test_f1, val_f1, test_mif1=0, val_mif1=0, stopper=None, res=None):
+    res = {} if res is None else res  # avoid the mutable default argument pitfall
+    if stopper is not None:
+        res.update({'test_f1': f'{test_f1:.4f}', 'test_mif1': f'{test_mif1:.4f}',
+                    'val_f1': f'{val_f1:.4f}', 'val_mif1': f'{val_mif1:.4f}',
+                    'best_epoch': stopper.best_epoch})
+    else:
+        res.update({'test_f1': f'{test_f1:.4f}', 'test_mif1': f'{test_mif1:.4f}',
+                    'val_f1': f'{val_f1:.4f}', 'val_mif1': f'{val_mif1:.4f}'})
+    print(f"Seed{cf.seed}")
+    res_dict = {'res': res, 'parameters': cf.get_model_conf()}
+    print(f'\n\n\nTrain finished, results:{res_dict}')
+    uf.write_nested_dict(res_dict, cf.res_file)
diff --git a/src/hin_loader.py b/src/hin_loader.py
new file mode 100644
index 0000000..9dacef9
--- /dev/null
+++ b/src/hin_loader.py
@@ -0,0 +1,79 @@
+import pickle
+import numpy as np
+import torch
+import torch.nn.functional as F
+import scipy
+
+
+class HIN(object):
+
+    def __init__(self, dataset):
+        data_path = f'data/{dataset}/'
+        with open(f'{data_path}node_features.pkl', 'rb') as f:
+            self.features = pickle.load(f)
+        with open(f'{data_path}edges.pkl', 'rb') as f:
+            self.edges = pickle.load(f)
+        with open(f'{data_path}labels.pkl', 'rb') as f:
+            self.labels = pickle.load(f)
+        with open(f'{data_path}meta_data.pkl', 'rb') as f:
+            self.__dict__.update(pickle.load(f))
+        if scipy.sparse.issparse(self.features):
+            self.features = self.features.todense()
+
+    def to_torch(self, cf):
+        '''
+        Returns the torch tensors of the graph.
+        Args:
+            cf: the ModelConfig object.
+        Returns:
+            features, adj: feature and adjacency matrices
+            mp_emb: metapath embeddings (only available for models that use mp_list)
+            train_x, train_y, val_x, val_y, test_x, test_y: train/val/test indices and labels
+        '''
+        features = torch.from_numpy(self.features).type(torch.FloatTensor).to(cf.dev)
+        train_x, train_y, val_x, val_y, test_x, test_y = self.get_label(cf.dev)
+
+        adj = np.sum(list(self.edges.values())).todense()
+        adj = torch.from_numpy(adj).type(torch.FloatTensor).to(cf.dev)
+        adj = F.normalize(adj, dim=1, p=2)
+
+        mp_emb = {}
+        if hasattr(cf, 'mp_list'):
+            for mp in cf.mp_list:
+                mp_emb[mp] = torch.from_numpy(self.mp_emb_dict[mp]).type(torch.FloatTensor).to(cf.dev)
+        if hasattr(cf, 'feat_norm'):
+            if cf.feat_norm > 0:
+                features = F.normalize(features, dim=1, p=cf.feat_norm)
+                for mp in cf.mp_list:
+                    mp_emb[mp] = F.normalize(mp_emb[mp], dim=1, p=cf.feat_norm)
+        return features, adj, mp_emb, train_x, train_y, val_x, val_y, test_x, test_y
+
+    def load_mp_embedding(self, cf):
+        '''Load the pretrained metapath embeddings.'''
+        self.mp_emb_dict = {}
+        for mp in cf.mp_list:
+            f_name = f'{cf.data_path}{mp}_emb.pkl'
+            with open(f_name, 'rb') as f:
+                z = pickle.load(f)
+                zero_lines = np.nonzero(np.sum(z, 1) == 0)[0]
+                if len(zero_lines) > 0:
+                    # Add a small epsilon to all-zero rows to avoid NaN similarities downstream.
+                    z[zero_lines, :] += 1e-8
+                self.mp_emb_dict[mp] = z
+        return self
+
+    def get_label(self, dev):
+        '''
+        Args:
+            dev: device (cpu or gpu)
+
+        Returns:
+            train_x, train_y, val_x, val_y, test_x, test_y: train/val/test index and labels
+        '''
+        train_x = torch.from_numpy(np.array(self.labels[0])[:, 0]).type(torch.LongTensor).to(dev)
+        train_y = torch.from_numpy(np.array(self.labels[0])[:, 1]).type(torch.LongTensor).to(dev)
+        val_x = torch.from_numpy(np.array(self.labels[1])[:, 0]).type(torch.LongTensor).to(dev)
+        val_y = torch.from_numpy(np.array(self.labels[1])[:, 1]).type(torch.LongTensor).to(dev)
+        test_x = torch.from_numpy(np.array(self.labels[2])[:, 0]).type(torch.LongTensor).to(dev)
+        test_y = torch.from_numpy(np.array(self.labels[2])[:, 1]).type(torch.LongTensor).to(dev)
+        return train_x, train_y, val_x, val_y, test_x, test_y
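+
+# A minimal usage sketch mirroring train.py (cf is an HGSLConfig with mp_list,
+# dev, and feat_norm set):
+#
+#   g = HIN('dblp').load_mp_embedding(cf)
+#   features, adj, mp_emb, train_x, train_y, val_x, val_y, test_x, test_y = g.to_torch(cf)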
diff --git a/src/shared_configs.py b/src/shared_configs.py
new file mode 100644
index 0000000..792ef3b
--- /dev/null
+++ b/src/shared_configs.py
@@ -0,0 +1,65 @@
+import util_funcs as uf
+
+
+class ModelConfig():
+    def __init__(self, model):
+        self.model = model
+        self.exp_name = 'default'
+        self.model_conf_list = None
+
+    def __str__(self):
+        # Print all attributes including data and other path settings added to the config object.
+        return str(self.__dict__)
+
+    def save_model_conf_list(self):
+        self.model_conf_list = list(self.__dict__.copy().keys())
+        self.model_conf_list.remove('model_conf_list')
+
+    def update_file_conf(self):
+        f_conf = FileConfig(self)
+        self.__dict__.update(f_conf.__dict__)
+        return self
+
+    def model_conf_to_str(self):
+        # Print the model settings only.
+        return str({k: self.__dict__[k] for k in self.model_conf_list})
+
+    def get_model_conf(self):
+        # Return the model settings only.
+        return {k: self.__dict__[k] for k in self.model_conf_list}
+
+    def update(self, conf_dict):
+        self.__dict__.update(conf_dict)
+        self.update_file_conf()
+        return self
+
+
+class DataConfig:
+    def __init__(self, dataset):
+        data_conf = {
+            'acm': {'data_type': 'pas', 'relation_list': 'p-a+a-p+p-s+s-p'},
+            'dblp': {'data_type': 'apc', 'relation_list': 'p-a+a-p+p-c+c-p'},
+            'imdb': {'data_type': 'mad', 'relation_list': 'm-a+a-m+m-d+d-m'},
+            'aminer': {'data_type': 'apr', 'relation_list': 'p-a+p-r+a-p+r-p'},
+            'yelp': {'data_type': 'busl', 'relation_list': 'b-u+u-b+b-s+s-b+b-l+l-b'}
+        }
+        self.__dict__.update(data_conf[dataset])
+        self.dataset = dataset
+        self.data_path = f'data/{dataset}/'
+
+        return
+
+
+class FileConfig:
+
+    def __init__(self, cf: ModelConfig):
+        '''
+            1. Set f_prefix for each model. The f_prefix stores the important hyperparameters (tuned parameters) of the model.
+            2. Generate the file names using f_prefix.
+            3. Create the required directories.
+        '''
+        if cf.model[:4] == 'HGSL':
+            f_prefix = f'do{cf.dropout}_lr{cf.lr}_a{cf.alpha}_tr{cf.fgd_th}-{cf.fgh_th}-{cf.sem_th}_mpl{uf.mp_list_str(cf.mp_list)}'
+        self.res_file = f'results/{cf.dataset}/{cf.model}/{cf.exp_name}/<{cf.model}>{f_prefix}.txt'
+        self.checkpoint_file = f'temp/{cf.model}/{cf.dataset}/{f_prefix}{uf.get_cur_time()}.pt'
+        uf.mkdir_list(self.__dict__.values())
diff --git a/src/train.py b/src/train.py
new file mode 100644
index 0000000..464aef8
--- /dev/null
+++ b/src/train.py
@@ -0,0 +1,94 @@
+import os
+import sys
+
+cur_path = os.path.abspath(os.path.dirname(__file__))
+root_path = cur_path.split('src')[0]
+sys.path.append(root_path + 'src')
+os.chdir(root_path)
+from early_stopper import *
+from hin_loader import HIN
+from evaluation import *
+import util_funcs as uf
+from config import HGSLConfig
+from HGSL import HGSL
+import warnings
+import time
+import torch
+import argparse
+
+warnings.filterwarnings('ignore')
+root_path = os.path.abspath(os.path.dirname(__file__)).split('src')[0]
+
+
+def train_hgsl(args, gpu_id=0, log_on=True):
+    uf.seed_init(args.seed)
+    uf.shell_init(gpu_id=gpu_id)
+    cf = HGSLConfig(args.dataset)
+
+    # ! Modify config
+    cf.update(args.__dict__)
+    cf.dev = torch.device("cuda:0" if gpu_id >= 0 else "cpu")
+
+    # ! Load Graph
+    g = HIN(cf.dataset).load_mp_embedding(cf)
+    print(f'Dataset: {cf.dataset}, {g.t_info}')
+    features, adj, mp_emb, train_x, train_y, val_x, val_y, test_x, test_y = g.to_torch(cf)
+
+    # ! Train Init
+    if not log_on: uf.block_logs()
+    print(f'{cf}\nStart training...')
+    cla_loss = torch.nn.NLLLoss()
+    model = HGSL(cf, g)
+    model.to(cf.dev)
+    optimizer = torch.optim.Adam(
+        model.parameters(), lr=cf.lr, weight_decay=cf.weight_decay)
+    stopper = EarlyStopping(patience=cf.early_stop, path=cf.checkpoint_file)
+
+    dur = []
+    w_list = []
+    for epoch in range(cf.epochs):
+        # ! Train
+        t0 = time.time()
+        model.train()
+        logits, adj_new = model(features, adj, mp_emb)
+        train_f1, train_mif1 = eval_logits(logits, train_x, train_y)
+        w_list.append(uf.print_weights(model))
+
+        l_pred = cla_loss(logits[train_x], train_y)
+        # L1 regularization on the learned graph structure; the original input
+        # `adj` is constant w.r.t. the parameters, so `adj_new` is penalized here.
+        l_reg = cf.alpha * torch.norm(adj_new, 1)
+        loss = l_pred + l_reg
+        optimizer.zero_grad()
+        with torch.autograd.detect_anomaly():
+            loss.backward()
+        optimizer.step()
+
+        # ! Valid
+        model.eval()
+        with torch.no_grad():
+            logits = model.GCN(features, adj_new)
+            val_f1, val_mif1 = eval_logits(logits, val_x, val_y)
+        dur.append(time.time() - t0)
+        uf.print_train_log(epoch, dur, loss, train_f1, val_f1)
+
+        if cf.early_stop > 0:
+            if stopper.step(val_f1, model, epoch):
+                print(f'Early stopped, loading model from epoch-{stopper.best_epoch}')
+                break
+
+    if cf.early_stop > 0:
+        model.load_state_dict(torch.load(cf.checkpoint_file))
+    logits, _ = model(features, adj, mp_emb)
+    cf.update(w_list[stopper.best_epoch])
+    eval_and_save(cf, logits, test_x, test_y, val_x, val_y, stopper)
+    if not log_on: uf.enable_logs()
+    return cf
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    dataset = 'yelp'
+    parser.add_argument('--dataset', type=str, default=dataset)
+    parser.add_argument('--gpu_id', type=int, default=0)
+    parser.add_argument('--seed', type=int, default=0)
+    args = parser.parse_args()
+    train_hgsl(args, gpu_id=args.gpu_id)
diff --git a/src/util_funcs.py b/src/util_funcs.py
new file mode 100644
index 0000000..b72c9d8
--- /dev/null
+++ b/src/util_funcs.py
@@ -0,0 +1,264 @@
+import logging
+import os
+import pickle
+import sys
+import time
+from pathlib import Path
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+
+# * ============================= Init =============================
+def shell_init(server='S5', gpu_id=0):
+    '''
+    Features:
+    1. Specify the server-specific source and python command
+    2. Fix Pycharm LD_LIBRARY_ISSUE
+    3. Block warnings
+    4. Block TF useless messages
+    5. Set paths
+    '''
+    import warnings
+    np.seterr(invalid='ignore')
+    warnings.filterwarnings("ignore", category=FutureWarning)
+    warnings.filterwarnings("ignore", category=UserWarning)
+    warnings.filterwarnings("ignore", category=RuntimeWarning)
+
+    if server == 'Xy':
+        python_command = '/home/chopin/zja/anaconda/bin/python'
+    elif server == 'Colab':
+        python_command = 'python'
+    else:
+        python_command = '~/anaconda3/bin/python'
+        if gpu_id > 0:
+            os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)
+        os.environ['LD_LIBRARY_PATH'] = '/usr/local/cuda/lib64/'  # Extremely useful for Pycharm users
+    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # Block TF messages
+
+    return python_command
+
+
+def seed_init(seed):
+    import torch
+    import random
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+
+
+# * ============================= Torch =============================
+
+def exists_zero_lines(h):
+    zero_lines = torch.where(torch.sum(h, 1) == 0)[0]
+    if len(zero_lines) > 0:
+        # raise ValueError('{} zero lines in {}s!\nZero lines:{}'.format(len(zero_lines), 'emb', zero_lines))
+        print(f'{len(zero_lines)} zero lines !\nZero lines:{zero_lines}')
+        return True
+    return False
+
+
+def cos_sim(a, b, eps=1e-8):
+    """
+    calculate cosine similarity between matrix a and b
+    """
+    a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None]
+    a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n))
+    b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n))
+    sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1))
+    return sim_mt
+
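+# A minimal sketch (hypothetical tensors): row norms are clamped by eps to avoid
+# division by zero, and the entries of the result lie in [-1, 1].
+#
+#   sim = cos_sim(torch.rand(3, 8), torch.rand(4, 8))  # -> 3 x 4 matrix
+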
+
+# * ============================= Print Related =============================
+
+def print_dict(d, end_string='\n\n'):
+    for key in d.keys():
+        if isinstance(d[key], dict):
+            print('\n', end='')
+            print_dict(d[key], end_string='')
+        elif isinstance(d[key], int):
+            print('{}: {:04d}'.format(key, d[key]), end=', ')
+        elif isinstance(d[key], float):
+            print('{}: {:.4f}'.format(key, d[key]), end=', ')
+        else:
+            print('{}: {}'.format(key, d[key]), end=', ')
+    print(end_string, end='')
+
+
+def block_logs():
+    sys.stdout = open(os.devnull, 'w')
+    logger = logging.getLogger()
+    logger.disabled = True
+
+
+def enable_logs():
+    # Restore
+    sys.stdout = sys.__stdout__
+    logger = logging.getLogger()
+    logger.disabled = False
+
+
+def progress_bar(prefix, start_time, i, max_i, postfix):
+    """
+    Generates progress bar AFTER the ith epoch.
+    Args:
+        prefix: the prefix of printed string
+        start_time: start time of the loop
+        i: finished epoch index
+        max_i: total iteration times
+        postfix: the postfix of printed string
+
+    Returns: prints the generated progress bar
+
+    """
+    cur_run_time = time.time() - start_time
+    i += 1
+    if i != 0:
+        total_estimated_time = cur_run_time * max_i / i
+    else:
+        total_estimated_time = 0
+    print(
+        f'{prefix} :  {i}/{max_i} [{time2str(cur_run_time)}/{time2str(total_estimated_time)}, {time2str(total_estimated_time - cur_run_time)} left] - {postfix}-{get_cur_time()}')
+
+
+def print_train_log(epoch, dur, loss, train_f1, val_f1):
+    print(
+        f"Epoch {epoch:05d} | Time(s) {np.mean(dur):.4f} | Loss {loss.item():.4f} | TrainF1 {train_f1:.4f} | ValF1 {val_f1:.4f}")
+
+
+def mp_list_str(mp_list):
+    return '_'.join(mp_list)
+
+
+# * ============================= File Operations =============================
+
+def write_nested_dict(d, f_path):
+    def _write_dict(d, f):
+        for key in d.keys():
+            if isinstance(d[key], dict):
+                f.write(str(d[key]) + '\n')
+
+    with open(f_path, 'a+') as f:
+        f.write('\n')
+        _write_dict(d, f)
+
+
+def save_pickle(var, f_name):
+    pickle.dump(var, open(f_name, 'wb'))
+
+
+def load_pickle(f_name):
+    return pickle.load(open(f_name, 'rb'))
+
+
+def clear_results(dataset, model):
+    res_path = f'results/{dataset}/{model}/'
+    os.system(f'rm -rf {res_path}')
+    print(f'Results in {res_path} are cleared.')
+
+
+# * ============================= Path Operations =============================
+
+def check_path(path):
+    if not os.path.exists(path):
+        os.makedirs(path)
+
+
+def get_dir_of_file(f_name):
+    return os.path.dirname(f_name) + '/'
+
+
+def get_grand_parent_dir(f_name):
+    if '.' in f_name.split('/')[-1]:  # File
+        return get_grand_parent_dir(get_dir_of_file(f_name))
+    else:  # Path
+        return f'{Path(f_name).parent}/'
+
+
+def get_abs_path(f_name, style='command_line'):
+    # Python treats spaces in file paths literally, while the command line needs
+    # them escaped as '\ ', so command-line style paths require replace(' ', '\ ').
+    if style == 'python':
+        cur_path = os.path.abspath(os.path.dirname(__file__))
+    elif style == 'command_line':
+        cur_path = os.path.abspath(os.path.dirname(__file__)).replace(' ', '\ ')
+
+    root_path = cur_path.split('src')[0]
+    return os.path.join(root_path, f_name)
+
+
+def mkdir_p(path, log=True):
+    """Create a directory for the specified path.
+    Parameters
+    ----------
+    path : str
+        Path name
+    log : bool
+        Whether to print result for directory creation
+    """
+    import errno
+    if os.path.exists(path): return
+    # print(path)
+    # path = path.replace('\ ',' ')
+    # print(path)
+    try:
+
+        os.makedirs(path)
+        if log:
+            print('Created directory {}'.format(path))
+    except OSError as exc:
+        if exc.errno == errno.EEXIST and os.path.isdir(path) and log:
+            print('Directory {} already exists.'.format(path))
+        else:
+            raise
+
+
+def mkdir_list(p_list, use_relative_path=True, log=True):
+    """Create directories for the specified path lists.
+        Parameters
+        ----------
+        p_list :Path lists
+
+    """
+    # ! Note that the paths MUST END WITH '/' !!!
+    root_path = os.path.abspath(os.path.dirname(__file__)).split('src')[0]
+    for p in p_list:
+        p = os.path.join(root_path, p) if use_relative_path else p
+        p = os.path.dirname(p)
+        mkdir_p(p, log)
+
+
+# * ============================= Time Related =============================
+
+def time2str(t):
+    if t > 86400:
+        return '{:.2f}day'.format(t / 86400)
+    if t > 3600:
+        return '{:.2f}h'.format(t / 3600)
+    elif t > 60:
+        return '{:.2f}min'.format(t / 60)
+    else:
+        return '{:.2f}s'.format(t)
+
+
+def get_cur_time():
+    import datetime
+    dt = datetime.datetime.now()
+    return f'{dt.date()}_{dt.hour:02d}-{dt.minute:02d}-{dt.second:02d}'
+
+
+# * ============================= Others =============================
+def print_weights(model, interested_para='_agg'):
+    w_dict = {}
+    for name, W in model.named_parameters():
+        if interested_para in name:
+            data = F.softmax(W.data.squeeze(), dim=0).cpu().numpy()
+            # print(f'{name}:{data}')
+            w_dict[name] = data
+    return w_dict
+
+
+def count_avg_neighbors(adj):
+    return len(torch.where(adj > 0)[0]) / adj.shape[0]
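+
+# e.g. count_avg_neighbors(torch.eye(4)) -> 1.0 (one positive entry per row)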