菜单

旋转嵌入

相关源文件

本文档解释了 Llama 2 模型架构中旋转嵌入(RoPE - Rotary Position Embeddings)的实现和作用。旋转嵌入作为 Transformer 注意力系统中的位置编码机制,使模型能够理解词元之间的相对位置。

有关更广泛的 Transformer 架构信息,请参阅Transformer 实现

1. 概述

旋转嵌入通过对查询(query)和键(key)向量应用与位置相关的旋转,直接在注意力机制中编码位置信息。与将位置信息添加到词元嵌入中的传统位置嵌入不同,旋转嵌入以与位置相关的方式旋转向量。

来源:llama/model.py80-104 llama/model.py132-161

2. 实现细节

2.1 频率预计算

Llama 2 的实现在模型初始化期间使用 precompute_freqs_cis 函数预计算频率张量。这些频率控制在每个位置应用多少旋转。

核心实现

来源:llama/model.py80-104 llama/model.py450-454

2.2 注意力机制中的应用

在注意力计算过程中,旋转嵌入通过 apply_rotary_emb 函数应用于查询(query)和键(key)向量。

来源:llama/model.py132-161

3. 数学基础

旋转嵌入通过在复平面中应用旋转来实现。对于每个头部维度对 (d, d+1),都会应用一个与位置相关的旋转。

Llama 2 中使用的核心公式:

  1. 计算基频: freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
  2. 计算与位置相关的旋转: freqs = torch.outer(t, freqs)
  3. 表示为复数: freqs_cis = torch.polar(torch.ones_like(freqs), freqs)

来源:llama/model.py80-104

4. 与注意力机制的集成

旋转嵌入集成到 Attention 类的 forward 方法中的注意力机制中。

Attention 类中的流程:

  1. 将输入投影到查询(query)、键(key)和值(value)向量
  2. 重塑张量以分离注意力头
  3. 对查询(query)和键(key)向量应用旋转嵌入
  4. 使用旋转后的向量计算注意力分数
  5. 应用 softmax 并计算最终输出

来源:llama/model.py176-304 llama/model.py280

5. 数据流图

下图展示了旋转嵌入如何融入更广泛的 Transformer 架构。

来源:llama/model.py351-410 llama/model.py176-304

6. 技术实现细节

6.1 用于广播的重塑

reshape_for_broadcast 函数重塑频率张量,使其在复数乘法期间与查询(query)和键(key)张量正确对齐。

来源:llama/model.py107-129

6.2 复数处理

该实现使用了 PyTorch 的复数支持:

  1. 使用 torch.view_as_complex 将实数张量转换为复数
  2. 直接执行复数乘法
  3. 使用 torch.view_as_real 转换回实数

来源:llama/model.py156-160

7. 旋转嵌入的优势

旋转嵌入在 Llama 2 架构中提供了几个优势:

  1. 相对位置敏感性:注意力机制自然地理解词元之间的相对位置。
  2. 外推能力:模型可以推广到比训练期间更长的序列长度。
  3. 无额外参数:位置信息通过旋转编码,而非通过学习的嵌入。
  4. 线性复杂度:计算量随序列长度线性扩展。

来源:llama/model.py132-161 llama/model.py450-454