本文档提供了对labml_nn库中图注意力网络(Graph Attention Networks,GAT)和图注意力网络v2(Graph Attention Networks v2,GATv2)实现的技术概述。它涵盖了这两种模型的架构细节、注意力机制和实现细节。我们重点关注这些网络如何通过利用注意力机制自适应地加权邻居节点特征来处理图结构数据。
来源: labml_nn/graphs/gat/__init__.py8-26 labml_nn/graphs/gatv2/__init__.py7-15
图注意力网络(GAT)是为处理图结构数据而设计的神经网络。与处理网格状数据(图像)或序列数据(文本)的标准神经网络不同,GAT在节点通过边连接的图上操作。GAT的关键创新在于其注意力机制,该机制允许每个节点在更新其特征时,以不同的重要性关注不同的邻居节点。
在此实现中,每个GAT由多个图注意力层堆叠而成。第一个实现遵循Veličković等人撰写的原始论文《图注意力网络》,而GATv2实现则基于Brody等人撰写的论文《图注意力网络有多专注?》解决了原始模型中的局限性。
图注意力网络架构
来源: labml_nn/graphs/gat/__init__.py34-198 labml_nn/graphs/gat/experiment.py109-150
原始GAT的注意力计算方式如下:
首先,节点特征通过一个线性层进行转换
g = W·h
对于每对连接的节点 i 和 j,计算注意力分数
e_ij = LeakyReLU(a^T [g_i || g_j])
其中 || 表示拼接,a 是一个可学习的权重向量。
注意力分数使用邻接矩阵进行掩码(将非连接节点的得分设置为 -∞)。
分数使用 softmax 进行归一化,得到注意力系数
α_ij = softmax_j(e_ij)
每个节点的输出特征被计算为邻居特征的加权和
h'_i = ∑_j α_ij·g_j
这在 GraphAttentionLayer.forward 方法中实现。
来源: labml_nn/graphs/gat/__init__.py86-198
GATv2通过允许动态注意力来解决原始GAT中的一个局限性。主要区别在于注意力分数的计算方式:
GAT: e_ij = LeakyReLU(a^T [W·h_i || W·h_j])
GATv2: e_ij = a^T·LeakyReLU(W_l·h_i + W_r·h_j)
这种看似微小的改变使得GATv2能够针对不同的查询节点具有不同的注意力排名,而GAT无论查询节点是什么,总是分配相同的排名。
来源: labml_nn/graphs/gatv2/__init__.py63-237
GAT和GATv2都使用多头注意力机制来稳定学习过程。在多头设置中:
该实现通过 is_concat 参数支持这两种策略:
is_concat=True 时,头的输出被拼接:h'i = ||{k=1}^K h'_i^kis_concat=False 时,头的输出被平均:h'i = (1/K) ∑{k=1}^K h'_i^k来源: labml_nn/graphs/gat/__init__.py62-72 labml_nn/graphs/gat/__init__.py191-198 labml_nn/graphs/gatv2/__init__.py92-103 labml_nn/graphs/gatv2/__init__.py230-237
该实现包含以下关键组件:
来源: labml_nn/graphs/gat/__init__.py34-198 labml_nn/graphs/gatv2/__init__.py63-237 labml_nn/graphs/gat/experiment.py109-150 labml_nn/graphs/gatv2/experiment.py21-66
GAT和GATv2类都遵循类似的结构——它们由以下部分组成:
layer1)output)主要区别在于GAT使用 GraphAttentionLayer,而GATv2使用 GraphAttentionV2Layer。
以下是两者的比较表:
| 功能 | GAT | GATv2 |
|---|---|---|
| 基本注意力层 | GraphAttentionLayer | GraphAttentionV2Layer |
| 注意力机制 | LeakyReLU(a^T [W·h_i || W·h_j]) | a^T·LeakyReLU(W_l·h_i + W_r·h_j) |
| 权重共享 | 不适用 | 通过share_weights参数可选 |
| 激活函数 | ELU | ELU |
| Dropout | 是 | 是 |
来源: labml_nn/graphs/gat/experiment.py109-150 labml_nn/graphs/gatv2/experiment.py21-66
该实现包含了在 Cora 数据集上训练GAT和GATv2的代码,Cora 数据集是图学习任务的标准基准数据集。
数据集:Cora 包含2708篇科学出版物(节点),5429条引用关系(边)。每篇论文都有一个1433维的特征向量,表示单词出现频率,并属于7个类别中的一个。
训练过程:
配置:实验使用以下超参数:
来源: labml_nn/graphs/gat/experiment.py195-254 labml_nn/graphs/gatv2/experiment.py91-109
GATv2论文的一个关键见解是识别出GAT中的“静态注意力”问题。在原始GAT中:
这一局限性在字典查找等任务中变得显而易见,GAT在此类任务中失败,而GATv2则成功。在Cora等更标准的基准测试中,性能差异不那么显著,但GATv2通常表现更好。
来源: labml_nn/graphs/gatv2/__init__.py16-54
要在您自己的图数据上使用GAT或GATv2模型,您需要:
[n_nodes, in_features] 的张量[n_nodes, n_nodes] 或 [n_nodes, n_nodes, 1] 的布尔张量模型可以按如下方式使用
输出 logits 可用于节点分类或其他下游任务。
来源: labml_nn/graphs/gat/experiment.py135-150 labml_nn/graphs/gatv2/experiment.py51-66
刷新此 Wiki
最后索引时间2025 年 4 月 18 日(90e21b)