菜单

多头注意力

相关源文件

目的和概述

本文档详细介绍了多头注意力(Multi-Head Attention, MHA)的实现,它是Transformer架构中的一个关键组成部分。MHA使Transformer能够同时关注来自不同表示子空间和不同位置的信息,从而使模型能够捕获输入序列的各个方面。

该实现基于原始Transformer论文“Attention Is All You Need”,是本代码库中所有Transformer变体的基本构建块。有关完整的Transformer架构信息,请参阅Transformer架构

来源:labml_nn/transformers/mha.py1-22

注意力机制

注意力机制的核心是根据查询(Q)和键(K)之间的相似度分数,计算值(V)的加权和。这可以通过以下数学公式表示:

$$\mathop{Attention}(Q, K, V) = \underset{seq}{\mathop{softmax}}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)V$$

其中

  • Q(查询)、K(键)和V(值)是经过转换的输入向量
  • $d_k$ 是键向量的维度
  • 缩放因子 $\frac{1}{\sqrt{d_k}}$ 防止点积结果过大

多头注意力在多个“头”上并行运行此注意力函数,使模型能够同时关注输入序列的不同部分。

来源:labml_nn/transformers/mha.py69-86

架构概述

来源:labml_nn/transformers/mha.py69-206

实现组件

该实现包含两个主要类

  1. PrepareForMultiHeadAttention:处理线性变换并分割向量以进行并行处理
  2. MultiHeadAttention:实现跨多个头的核心注意力机制

PrepareForMultiHeadAttention

此类通过以下方式为多头注意力准备输入向量:

  1. 对输入应用线性变换
  2. 重塑结果以分离出不同的头

该类的构造函数接受以下参数:

  • d_model:模型维度(输入特征大小)
  • heads: 注意力头数
  • d_k:每个头的维度(d_model / heads)
  • bias:是否在线性变换中包含偏置

来源:labml_nn/transformers/mha.py33-66

MultiHeadAttention

此类通过以下关键步骤实现并行注意力计算:

  1. 使用PrepareForMultiHeadAttention转换查询、键和值
  2. 计算注意力分数(查询和键的点积)
  3. 通过 $\frac{1}{\sqrt{d_k}}$ 缩放分数
  4. 应用可选的掩码(用于自回归模型)
  5. 应用softmax以获得注意力权重
  6. 计算值的加权和
  7. 连接所有头的结果
  8. 应用最终线性变换

来源:labml_nn/transformers/mha.py69-206

实现细节

类结构和参数

MultiHeadAttention类初始化时使用以下参数:

参数类型描述
头数(heads)整数注意力头的数量
d_model整数模型维度(输入特征大小)
dropout_prob浮点数Dropout概率(默认值:0.1)
偏置(bias)布尔值是否在线性层中使用偏置(默认值:True)

来源:labml_nn/transformers/mha.py90-118

前向传播计算

前向方法接受以下参数:

  • query:形状为[seq_len, batch_size, d_model]的查询向量
  • key:形状为[seq_len, batch_size, d_model]的键向量
  • value:形状为[seq_len, batch_size, d_model]的值向量
  • mask(可选):形状为[seq_len, seq_len, batch_size]的张量,用于屏蔽某些位置

计算步骤如下:

  1. 准备查询、键和值用于注意力计算

  2. 使用get_scores方法计算注意力分数

  3. 缩放分数并(如果提供)应用掩码

  4. 使用softmax计算注意力权重并应用dropout

  5. 计算值的加权和

  6. 重塑并应用最终输出变换

来源:labml_nn/transformers/mha.py147-206

分数计算

分数计算在get_scores方法中实现,该方法计算查询和键之间的点积

这使用了爱因斯坦求和约定来高效地计算点积,得到形状为[seq_len_q, seq_len_k, batch_size, heads]的分数。

来源:labml_nn/transformers/mha.py121-129

与 Transformer 架构集成

多头注意力在Transformer架构中以不同的方式使用

  1. 自注意力 - 查询、键和值都来自同一来源
  2. 交叉注意力 - 查询来自一个来源,而键和值来自另一个来源

在此实现中,MultiHeadAttention类提供了一个灵活的接口,允许将不同的张量作为查询、键和值传递,从而同时处理自注意力和交叉注意力。

来源:labml_nn/transformers/mha.py147-206

注意力中的掩码

掩码是注意力中的一个关键特性,特别是对于自回归模型,其中未来信息不应可访问。该实现通过以下方式支持掩码:

  1. 前向方法中的可选mask参数
  2. 一个prepare_mask方法,用于格式化掩码以进行注意力计算

对于因果(自回归)模型,通常使用后续掩码,该掩码是使用transformers/utils.py中的subsequent_mask创建的。

来源:labml_nn/transformers/mha.py131-145 labml_nn/transformers/utils.py13-18

使用示例

多头注意力在整个代码库的各种Transformer实现中都有使用

  1. 在基本的Transformer模型(编码器-解码器架构)中
  2. 在像Feedback Transformer这样的专用变体中
  3. 在像GPT变体这样的自回归模型中

示例:Feedback Transformer

在Feedback Transformer的实现中,使用了注意力的一种专用版本(FeedbackAttention),它扩展了多头注意力的核心原理,但通过让模型关注前一时间步的输出来增加了循环性。

来源:labml_nn/transformers/feedback/__init__.py54-195

结论

多头注意力是Transformer架构中的一个基本构建块,它使模型能够并行地关注输入序列的不同部分。本代码库中的实现提供了一个灵活高效的实现,可用于各种Transformer变体。

有关MHA如何融入更广泛的Transformer架构的更多信息,请参阅基本Transformer模型