本页面解释了 Grok-1 模型如何在多个计算设备上进行分区和分布。由于模型规模庞大(3140亿参数),高效的分片策略对于训练和推理都至关重要。本文档重点介绍这些策略在代码库中的实现细节。
有关推理管道的信息,请参阅 推理管道。有关检查点加载的信息,请参阅 检查点系统。
Grok-1 使用 JAX 的设备网格(device mesh)抽象将可用的计算设备组织成一个具有两个维度的逻辑 2D 网格:
data - 用于数据并行model - 用于模型并行此网格配置支持数据并行(并行处理多个批次)和模型并行(跨设备分割模型参数)。
网格是通过 make_mesh 函数创建的,该函数配置了本地设备和跨主机设备分配。
来源: runners.py580-593 runners.py149-182
Grok-1 使用 JAX 的 PartitionSpec(简称为 P)来描述张量应如何分布在设备网格中。这些规范控制模型参数、激活和计算如何分配到设备。
代码库定义了两套主要的分区规则:
TRANSFORMER_PARTITION_RULES - Transformer 组件的规则LM_PARTITION_RULES - 语言模型组件的规则这些规则用于创建一个映射函数(apply_rules),该函数根据参数在参数树中的路径确定每个参数的适当分区规范。
来源: model.py92-109 model.py112-160 model.py162-174
不同类型的参数根据其在模型中的功能进行不同的分片。
| 组件 | 分区规范 | 描述 |
|---|---|---|
| Attention QKV 权重 | P("data", "model") | 在两个维度上都进行了分割。 |
| Attention 输出 | P("model", "data") | 分割,第一个维度跨模型,第二个维度跨数据。 |
| MoE 路由权重 | P("data") | 仅在数据维度上进行分割。 |
| MoE 专家 | P(None, "data", "model") | 第一个维度(专家)不分割,其他维度分割。 |
| 嵌入 | P(None, ("data", "model")) | 词汇维度不分割,嵌入维度在两者之间分割。 |
| 层归一化 | P(None) | 不分割(复制)。 |
要更深入地了解模型的架构如何与这些分片决策相关联,请参阅 模型架构。
来源: model.py112-160 model.py162-174 model.py477-486
模型代码使用 with_sharding_constraint 函数将分片约束应用于张量,以确保张量按照指定的方式分布。这对于优化内存使用和最小化设备间的通信至关重要。
来源: model.py71-75 model.py846-848 model.py1046 model.py1060-1061 model.py1097
Grok-1 使用 JAX 的 Parallel JIT(pjit)和 shard_map 根据分片规范将计算分布到设备上。
PJIT 用于转换函数,以实现具有特定输入和输出分片模式的分布式执行。
InferenceRunner 将 PJIT 应用于几个关键函数:
sample_step - 用于 token 采样prefill_memory - 用于 KV 缓存初始化new_memory - 用于内存分配来源: runners.py413-440 runners.py190-191
对于专家混合(Mixture of Experts)计算,shard_map 用于根据专家分配高效地将计算映射到特定设备。
shard_map 允许专家特定的计算发生在拥有该专家的设备上,从而最大限度地减少跨设备通信。
来源: model.py319-337 model.py339-357 model.py572-573 model.py812-824
KV 缓存(用于 Attention 的键值内存)被分片以优化内存使用。get_memory_sharding 方法定义了该内存应如何分布。
来源: model.py477-486 model.py178-181 model.py184-200
专家混合(Mixture of Experts,MoE)架构需要特殊的分片策略,因为并非所有专家都对每个 token 活跃。Grok-1 通过以下方式实现了高效的 MoE 计算:
shard_map 进行优化的权重应用Router 选择应处理每个 token 的专家,然后使用 shard_map 在适当的设备上执行专家计算。
来源: model.py208-248 model.py272-273 model.py319-357 model.py1073-1091
网格配置在初始化过程中指定,并决定了模型的分布方式。
local_mesh_config 确定单个主机内的设备如何排列,而 between_hosts_config 控制跨多个主机的分布。
来源:runners.py580-593 runners.py165-182 runners.py262-286
在初始化模型时,get_state_sharding 用于根据分区规则确定参数应如何分片。
来源:runners.py202-210 runners.py231-246
对张量的操作必须遵守分片约束。例如,在更新 KV 缓存时
来源:model.py805-827 model.py845-849
高效的分片对于性能至关重要。代码库进行了多项优化
来源:model.py112-160 model.py319-357 model.py813-824
分片系统与所有模型组件深度集成。每个组件都知道如何处理分片计算
| 组件 | 分片集成 |
|---|---|
| LanguageModel | 管理高级分片上下文 |
| Transformer | 处理层级分片 |
| MultiHeadAttention | 使用分片 KV 缓存实现注意力机制 |
| MoELayer | 执行专家路由和分片计算 |
| 线性层 | 支持权重分片和分布式矩阵乘法 |
来源:model.py1194 model.py1291-1398 model.py694-891 model.py272-397
通过仔细管理参数和计算分配,Grok-1 尽管规模庞大,仍实现了高效的推理性能,使其能够在分布式硬件配置上运行。