菜单

Transformer 架构

相关源文件

目的与范围

本文档详细介绍了 Grok-1 中使用的 Transformer 架构,解释了其组件、注意力机制以及在代码库中的整体实现结构。有关专家混合(MoE)实现的具体信息,请参阅 专家混合(MoE)。有关内存处理和键值缓存,请参阅 内存和 KV 缓存

概述

Grok-1 实现了一种仅解码器的 Transformer 架构,包含 64 层,嵌入维度为 6,144,并在每层中都使用了专家混合的前馈网络。该模型采用了独特的注意力配置,具有 48 个查询头和 8 个键/值头,以及旋转位置嵌入(RoPE)。

来源: model.py1291-1398 model.py1201-1289 README.md21-36

组件和实现

Transformer 架构主要通过一组嵌套类来实现,这些类处理模型的不同方面。

来源: model.py1291-1398 model.py1010-1102 model.py694-911 model.py272-400 model.py489-496 model.py635-691

Transformer 类

Transformer 类作为 Transformer 架构的主要入口点。它管理着 64 个解码器层的堆栈,并处理键值缓存的内存初始化。

主要参数

  • num_q_heads: 48 (查询头)
  • num_kv_heads: 8 (键/值头)
  • key_size: 每个注意力头的尺寸
  • num_layers: 64
  • num_experts: 8 (用于 MoE 前馈网络)
  • num_selected_experts: 2 (每个 token 使用的专家数)

来源: model.py1291-1398 README.md21-36

解码器层

Grok-1 中的每个解码器层都遵循预规范化模式,并带有残差连接,它包含:

  1. RMSNorm → 多头注意力 → 残差连接
  2. RMSNorm → MoE 前馈 → 残差连接

实现位于 DecoderLayer 类中,该类结合了注意力和前馈组件。

来源: model.py1010-1102 model.py694-911 model.py272-400 model.py489-496 model.py635-691

多头注意力

Grok-1 采用独特的组查询注意力机制,包含 48 个查询头和 8 个键/值头,在 MultiHeadAttention 类中实现。这种“组查询注意力”配置通过使用更少的键/值头来实现更高效的计算,同时通过更多的查询头来维持模型的容量。

注意力机制执行以下操作:

  1. 查询、键和值的线性投影
  2. 应用旋转位置嵌入进行位置编码
  3. 计算注意力权重
  4. 值的加权求和
  5. 最终线性投影

来源: model.py694-911 model.py635-691

旋转位置嵌入

Grok-1 使用旋转位置嵌入(RoPE)来编码位置信息,该功能在 RotaryEmbedding 类中实现。RoPE 根据嵌入向量在序列中的位置对其应用旋转,这具有以下几个优点:

  1. 它能够更好地外插到更长的序列。
  2. 它保留了相对位置信息。
  3. 它允许在推理过程中进行高效的键值缓存。

实现使用 10,000 的基数指数计算旋转角度,并将其应用于注意力机制中的查询和键投影。

来源: model.py635-691 README.md32-33

归一化

Grok-1 使用 RMSNorm(均方根归一化)代替传统的 LayerNorm。RMSNorm 在 RMSNorm 类中实现,通过除以特征的均方根来归一化输入。

normed_inputs = inputs * rsqrt(mean(inputs^2) + epsilon)

RMSNorm 应用于注意力组件和前馈组件之前,遵循预规范化模式。

来源: model.py587-624 model.py489-496

模型维度和技术规格

Grok-1 的 Transformer 架构具有以下规格:

参数
层数64
嵌入大小6,144
查询注意力头48
键/值注意力头8
每个头的键/值大小768
每层 MoE 专家数8
每 token 选择的专家数2
前馈扩展因子4.0
最大序列长度8,192 个词元
位置编码旋转嵌入(RoPE)
归一化RMSNorm
总参数3140 亿

来源: README.md21-36 model.py420-486

Transformer 中的推理流程

在推理过程中,Transformer 的前向传播遵循以下顺序:

来源: model.py1326-1398 model.py1211-1289

配置与初始化

Transformer 架构使用 TransformerConfig 类进行配置,该类指定了嵌入大小、头数和层数等参数。此配置用于实例化 Transformer 类,并且在创建完整模型时通常与 LanguageModelConfig 配对使用。

定义 Transformer 架构的关键配置参数

来源: model.py420-486 README.md21-36

结论

Grok-1 中的 Transformer 架构遵循仅解码器的设计模式,并进行了一些优化和修改:

  1. 组查询注意力,具有 48 个查询头和 8 个键/值头
  2. 64 个解码器层,采用预规范化和残差连接
  3. RMSNorm 代替传统的 LayerNorm
  4. 旋转位置嵌入用于位置编码
  5. 每层的前馈网络使用专家混合(MoE)
  6. 键值缓存,用于高效的自回归推理

这些设计选择使得模型能够实现其 3140 亿参数规模,同时保持高效的推理能力。