本文档详细介绍了多头注意力(Multi-Head Attention, MHA)的实现,它是Transformer架构中的一个关键组成部分。MHA使Transformer能够同时关注来自不同表示子空间和不同位置的信息,从而使模型能够捕获输入序列的各个方面。
该实现基于原始Transformer论文“Attention Is All You Need”,是本代码库中所有Transformer变体的基本构建块。有关完整的Transformer架构信息,请参阅Transformer架构。
来源:labml_nn/transformers/mha.py1-22
注意力机制的核心是根据查询(Q)和键(K)之间的相似度分数,计算值(V)的加权和。这可以通过以下数学公式表示:
$$\mathop{Attention}(Q, K, V) = \underset{seq}{\mathop{softmax}}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)V$$
其中
多头注意力在多个“头”上并行运行此注意力函数,使模型能够同时关注输入序列的不同部分。
来源:labml_nn/transformers/mha.py69-86
来源:labml_nn/transformers/mha.py69-206
该实现包含两个主要类
PrepareForMultiHeadAttention:处理线性变换并分割向量以进行并行处理MultiHeadAttention:实现跨多个头的核心注意力机制此类通过以下方式为多头注意力准备输入向量:
该类的构造函数接受以下参数:
d_model:模型维度(输入特征大小)heads: 注意力头数d_k:每个头的维度(d_model / heads)bias:是否在线性变换中包含偏置来源:labml_nn/transformers/mha.py33-66
此类通过以下关键步骤实现并行注意力计算:
PrepareForMultiHeadAttention转换查询、键和值来源:labml_nn/transformers/mha.py69-206
MultiHeadAttention类初始化时使用以下参数:
| 参数 | 类型 | 描述 |
|---|---|---|
| 头数(heads) | 整数 | 注意力头的数量 |
| d_model | 整数 | 模型维度(输入特征大小) |
| dropout_prob | 浮点数 | Dropout概率(默认值:0.1) |
| 偏置(bias) | 布尔值 | 是否在线性层中使用偏置(默认值:True) |
来源:labml_nn/transformers/mha.py90-118
前向方法接受以下参数:
query:形状为[seq_len, batch_size, d_model]的查询向量key:形状为[seq_len, batch_size, d_model]的键向量value:形状为[seq_len, batch_size, d_model]的值向量mask(可选):形状为[seq_len, seq_len, batch_size]的张量,用于屏蔽某些位置计算步骤如下:
准备查询、键和值用于注意力计算
使用get_scores方法计算注意力分数
缩放分数并(如果提供)应用掩码
使用softmax计算注意力权重并应用dropout
计算值的加权和
重塑并应用最终输出变换
来源:labml_nn/transformers/mha.py147-206
分数计算在get_scores方法中实现,该方法计算查询和键之间的点积
这使用了爱因斯坦求和约定来高效地计算点积,得到形状为[seq_len_q, seq_len_k, batch_size, heads]的分数。
来源:labml_nn/transformers/mha.py121-129
多头注意力在Transformer架构中以不同的方式使用
在此实现中,MultiHeadAttention类提供了一个灵活的接口,允许将不同的张量作为查询、键和值传递,从而同时处理自注意力和交叉注意力。
来源:labml_nn/transformers/mha.py147-206
掩码是注意力中的一个关键特性,特别是对于自回归模型,其中未来信息不应可访问。该实现通过以下方式支持掩码:
mask参数prepare_mask方法,用于格式化掩码以进行注意力计算对于因果(自回归)模型,通常使用后续掩码,该掩码是使用transformers/utils.py中的subsequent_mask创建的。
来源:labml_nn/transformers/mha.py131-145 labml_nn/transformers/utils.py13-18
多头注意力在整个代码库的各种Transformer实现中都有使用
在Feedback Transformer的实现中,使用了注意力的一种专用版本(FeedbackAttention),它扩展了多头注意力的核心原理,但通过让模型关注前一时间步的输出来增加了循环性。
来源:labml_nn/transformers/feedback/__init__.py54-195
多头注意力是Transformer架构中的一个基本构建块,它使模型能够并行地关注输入序列的不同部分。本代码库中的实现提供了一个灵活高效的实现,可用于各种Transformer变体。
有关MHA如何融入更广泛的Transformer架构的更多信息,请参阅基本Transformer模型。