这篇文章给大家介绍怎样使用Keras和Tensorflow学习图形数据,内容非常详细,感兴趣的小伙伴们可以参考借鉴,希望对大家能有所帮助。
动机:
有很多数据可以在实际应用中以图表的形式表示,如引文网络,社交网络(追随者图,朋友网络......),生物网络或电信。
使用Graph提取的特征可以通过依赖相邻节点之间的信息流来提高预测模型的性能。但是,表示图形数据并不简单,特别是如果不打算实现手工制作的特征,因为大多数ML模型都期望固定大小或线性输入,而图形数据不是这种情况。
在这篇文章中,将探讨一些处理通用图的方法,以便根据直接从数据中学习的图表表示进行节点分类。
数据集:
在Keras引用网络数据集将作为基地,在整个这个职位的实现和实验。每个节点代表科学论文,节点之间的边缘代表两篇论文之间的引用关系。
每个节点由一组二进制特征(字袋)以及将其链接到其他节点的一组边表示。
该数据集有2708个节点,分为七个类别之一。该网络有5429个链接。每个节点也由二进制字特征表示,指示相应字的存在。总体而言,每个节点有1433个二进制(稀疏)功能。以下我们只使用140 样品用于培训,其余用于验证/测试。
问题设定:
问题:在没有训练样本的情况下为图中的节点分配类标签。
直觉 / 假设:图中接近的节点更可能具有相似的标签。
解决方案:找到一种从图中提取特征的方法,以帮助对新节点进行分类。
拟议方法:
基线模型:
简单的基线模型
首先尝试使用最简单的模型,该模型学习仅使用二进制特征预测节点类并丢弃所有图形信息。该模型是一个完全连接的神经网络,它将二进制特征作为输入,并输出每个节点的类概率。
def get_features_only_model(n_features, n_classes):
in_ = Input((n_features,))
x = Dense(10, activation="relu", kernel_regularizer=l1(0.001))(in_)
x = Dropout(0.5)(x)
x = Dense(n_classes, activation="softmax")(x)
model = Model(in_, x)
model.compile(loss="sparse_categorical_crossentropy", metrics=['acc'], optimizer="adam")
model.summary()
return model
基线模型准确度:53.28%
这是将通过添加基于图形的功能来尝试改进的初始准确度。
添加图表功能:
通过在预测两个输入节点之间的最短路径长度的倒数的辅助任务上训练网络,通过将每个节点嵌入到矢量中来自动学习图形特征的一种方法,如下图和下面的代码片段所示:
学习每个节点的嵌入向量
def get_graph_embedding_model(n_nodes):
in_1 = Input((1,))
in_2 = Input((1,))
emb = Embedding(n_nodes, 100, name="node1")
x1 = emb(in_1)
x2 = emb(in_2)
x1 = Flatten()(x1)
x1 = Dropout(0.1)(x1)
x2 = Flatten()(x2)
x2 = Dropout(0.1)(x2)
x = Multiply()([x1, x2])
x = Dropout(0.1)(x)
x = Dense(1, activation="linear", name="spl")(x)
model = Model([in_1, in_2], x)
model.compile(loss="mae", optimizer="adam")
model.summary()
return model
下一步是使用预先训练的节点嵌入作为分类模型的输入。还使用学习的嵌入向量的距离添加附加输入,该输入是相邻节点的平均二进制特征。
生成的分类网络如下图所示:
使用预训练嵌入来进行节点分类
图嵌入分类模型准确度:73.06%
可以看到,添加学习图形特征作为分类模型的输入有助于显着提高分类准确性,与基线模型相比,从53.28%到73.06%。
改进图形功能学习:
可以通过进一步推进预训练并使用节点嵌入网络中的二进制特征,然后除了节点嵌入向量之外重新使用来自二进制特征的预训练权重,来进一步改进先前的模型。这导致模型依赖于从图结构中学习的二进制特征的更有用的表示。
改进的图嵌入分类模型准确度:76.35%
与以前的方法相比,这种额外的改进增加了几个百分点。
关于怎样使用Keras和Tensorflow学习图形数据就分享到这里了,希望以上内容可以对大家有一定的帮助,可以学到更多知识。如果觉得文章不错,可以把它分享出去让更多的人看到。
免责声明:本站发布的内容(图片、视频和文字)以原创、转载和分享为主,文章观点不代表本网站立场,如果涉及侵权请联系站长邮箱:is@yisu.com进行举报,并提供相关证据,一经查实,将立刻删除涉嫌侵权内容。