菜单

图注意力网络

相关源文件

目的与范围

本文档提供了对labml_nn库中图注意力网络(Graph Attention Networks,GAT)和图注意力网络v2(Graph Attention Networks v2,GATv2)实现的技​​术概述。它涵盖了这两种模型的架构细节、注意力机制和实现细节。我们重点关注这些网络如何通过利用注意力机制自适应地加权邻居节点特征来处理图结构数据。

来源: labml_nn/graphs/gat/__init__.py8-26 labml_nn/graphs/gatv2/__init__.py7-15

图注意力网络简介

图注意力网络(GAT)是为处理图结构数据而设计的神经网络。与处理网格状数据(图像)或序列数据(文本)的标准神经网络不同,GAT在节点通过边连接的图上操作。GAT的关键创新在于其注意力机制,该机制允许每个节点在更新其特征时,以不同的重要性关注不同的邻居节点。

在此实现中,每个GAT由多个图注意力层堆叠而成。第一个实现遵循Veličković等人撰写的原始论文《图注意力网络》,而GATv2实现则基于Brody等人撰写的论文《图注意力网络有多专注?》解决了原始模型中的局限性。

图注意力网络架构

来源: labml_nn/graphs/gat/__init__.py34-198 labml_nn/graphs/gat/experiment.py109-150

注意力机制细节

GAT注意力机制

原始GAT的注意力计算方式如下:

  1. 首先,节点特征通过一个线性层进行转换

    g = W·h
    
  2. 对于每对连接的节点 i 和 j,计算注意力分数

    e_ij = LeakyReLU(a^T [g_i || g_j])
    

    其中 || 表示拼接,a 是一个可学习的权重向量。

  3. 注意力分数使用邻接矩阵进行掩码(将非连接节点的得分设置为 -∞)。

  4. 分数使用 softmax 进行归一化,得到注意力系数

    α_ij = softmax_j(e_ij)
    
  5. 每个节点的输出特征被计算为邻居特征的加权和

    h'_i = ∑_j α_ij·g_j
    

这在 GraphAttentionLayer.forward 方法中实现。

来源: labml_nn/graphs/gat/__init__.py86-198

GATv2注意力机制

GATv2通过允许动态注意力来解决原始GAT中的一个局限性。主要区别在于注意力分数的计算方式:

GAT:   e_ij = LeakyReLU(a^T [W·h_i || W·h_j])
GATv2: e_ij = a^T·LeakyReLU(W_l·h_i + W_r·h_j)

这种看似微小的改变使得GATv2能够针对不同的查询节点具有不同的注意力排名,而GAT无论查询节点是什么,总是分配相同的排名。

来源: labml_nn/graphs/gatv2/__init__.py63-237

多头注意力

GAT和GATv2都使用多头注意力机制来稳定学习过程。在多头设置中:

  1. 多个独立的注意力机制(头)并行计算。
  2. 这些头的输出要么被拼接(在中间层),要么被平均(在最终层)。

该实现通过 is_concat 参数支持这两种策略:

  • is_concat=True 时,头的输出被拼接:h'i = ||{k=1}^K h'_i^k
  • is_concat=False 时,头的输出被平均:h'i = (1/K) ∑{k=1}^K h'_i^k

来源: labml_nn/graphs/gat/__init__.py62-72 labml_nn/graphs/gat/__init__.py191-198 labml_nn/graphs/gatv2/__init__.py92-103 labml_nn/graphs/gatv2/__init__.py230-237

实现细节

核心组件

该实现包含以下关键组件:

  1. GraphAttentionLayer:GAT的基本构建块
  2. GraphAttentionV2Layer:GATv2中的改进版本
  3. GAT:一个两层GAT模型
  4. GATv2:一个两层GATv2模型

组件关系

来源: labml_nn/graphs/gat/__init__.py34-198 labml_nn/graphs/gatv2/__init__.py63-237 labml_nn/graphs/gat/experiment.py109-150 labml_nn/graphs/gatv2/experiment.py21-66

GAT和GATv2类结构

GAT和GATv2类都遵循类似的结构——它们由以下部分组成:

  1. 一个使用多头拼接的输入注意力层(layer1
  2. 一个激活函数(ELU)
  3. 一个平均多个头的输出注意力层(output
  4. 在每个层之前应用Dropout正则化

主要区别在于GAT使用 GraphAttentionLayer,而GATv2使用 GraphAttentionV2Layer

以下是两者的比较表:

功能GATGATv2
基本注意力层GraphAttentionLayerGraphAttentionV2Layer
注意力机制LeakyReLU(a^T [W·h_i || W·h_j])a^T·LeakyReLU(W_l·h_i + W_r·h_j)
权重共享不适用通过share_weights参数可选
激活函数ELUELU
Dropout

来源: labml_nn/graphs/gat/experiment.py109-150 labml_nn/graphs/gatv2/experiment.py21-66

在 Cora 数据集上训练

该实现包含了在 Cora 数据集上训练GAT和GATv2的代码,Cora 数据集是图学习任务的标准基准数据集。

  1. 数据集:Cora 包含2708篇科学出版物(节点),5429条引用关系(边)。每篇论文都有一个1433维的特征向量,表示单词出现频率,并属于7个类别中的一个。

  2. 训练过程:

    • 数据集被划分为训练集(500个节点)和验证集
    • 使用全批量梯度下降(所有节点一次性处理)
    • 使用带权重衰减的Adam优化器
    • 用于节点分类的交叉熵损失
  3. 配置:实验使用以下超参数:

    • 学习率:5e-3
    • 权重衰减:5e-4
    • Dropout率:0.6 (GAT) 或 0.7 (GATv2)
    • 隐藏单元数:64
    • 注意力头数:8

来源: labml_nn/graphs/gat/experiment.py195-254 labml_nn/graphs/gatv2/experiment.py91-109

性能与静态注意力 vs. 动态注意力

GATv2论文的一个关键见解是识别出GAT中的“静态注意力”问题。在原始GAT中:

  • 对于任何查询节点,键(邻居节点)的排名仅由键节点本身决定
  • 这意味着注意力机制无法根据查询节点学习优先考虑不同的邻居
  • GATv2通过改变注意力机制中的操作顺序来解决此问题

这一局限性在字典查找等任务中变得显而易见,GAT在此类任务中失败,而GATv2则成功。在Cora等更标准的基准测试中,性能差异不那么显著,但GATv2通常表现更好。

来源: labml_nn/graphs/gatv2/__init__.py16-54

使用示例

要在您自己的图数据上使用GAT或GATv2模型,您需要:

  1. 节点特征:形状为 [n_nodes, in_features] 的张量
  2. 邻接矩阵:形状为 [n_nodes, n_nodes][n_nodes, n_nodes, 1] 的布尔张量

模型可以按如下方式使用

输出 logits 可用于节点分类或其他下游任务。

来源: labml_nn/graphs/gat/experiment.py135-150 labml_nn/graphs/gatv2/experiment.py51-66