本文档解释了 Llama 2 模型架构中旋转嵌入(RoPE - Rotary Position Embeddings)的实现和作用。旋转嵌入作为 Transformer 注意力系统中的位置编码机制,使模型能够理解词元之间的相对位置。
有关更广泛的 Transformer 架构信息,请参阅Transformer 实现。
旋转嵌入通过对查询(query)和键(key)向量应用与位置相关的旋转,直接在注意力机制中编码位置信息。与将位置信息添加到词元嵌入中的传统位置嵌入不同,旋转嵌入以与位置相关的方式旋转向量。
来源:llama/model.py80-104 llama/model.py132-161
Llama 2 的实现在模型初始化期间使用 precompute_freqs_cis 函数预计算频率张量。这些频率控制在每个位置应用多少旋转。
核心实现
来源:llama/model.py80-104 llama/model.py450-454
在注意力计算过程中,旋转嵌入通过 apply_rotary_emb 函数应用于查询(query)和键(key)向量。
旋转嵌入通过在复平面中应用旋转来实现。对于每个头部维度对 (d, d+1),都会应用一个与位置相关的旋转。
Llama 2 中使用的核心公式:
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))freqs = torch.outer(t, freqs)freqs_cis = torch.polar(torch.ones_like(freqs), freqs)旋转嵌入集成到 Attention 类的 forward 方法中的注意力机制中。
Attention 类中的流程:
来源:llama/model.py176-304 llama/model.py280
下图展示了旋转嵌入如何融入更广泛的 Transformer 架构。
来源:llama/model.py351-410 llama/model.py176-304
reshape_for_broadcast 函数重塑频率张量,使其在复数乘法期间与查询(query)和键(key)张量正确对齐。
该实现使用了 PyTorch 的复数支持:
torch.view_as_complex 将实数张量转换为复数torch.view_as_real 转换回实数旋转嵌入在 Llama 2 架构中提供了几个优势:
刷新此 Wiki
最后索引时间2025 年 4 月 18 日(689c7f)