菜单

分片与分布式

相关源文件

本页面解释了 Grok-1 模型如何在多个计算设备上进行分区和分布。由于模型规模庞大(3140亿参数),高效的分片策略对于训练和推理都至关重要。本文档重点介绍这些策略在代码库中的实现细节。

有关推理管道的信息,请参阅 推理管道。有关检查点加载的信息,请参阅 检查点系统

设备网格架构

Grok-1 使用 JAX 的设备网格(device mesh)抽象将可用的计算设备组织成一个具有两个维度的逻辑 2D 网格:

  1. data - 用于数据并行
  2. model - 用于模型并行

此网格配置支持数据并行(并行处理多个批次)和模型并行(跨设备分割模型参数)。

网格是通过 make_mesh 函数创建的,该函数配置了本地设备和跨主机设备分配。

来源: runners.py580-593 runners.py149-182

分区规范

Grok-1 使用 JAX 的 PartitionSpec(简称为 P)来描述张量应如何分布在设备网格中。这些规范控制模型参数、激活和计算如何分配到设备。

分区规则

代码库定义了两套主要的分区规则:

  1. TRANSFORMER_PARTITION_RULES - Transformer 组件的规则
  2. LM_PARTITION_RULES - 语言模型组件的规则

这些规则用于创建一个映射函数(apply_rules),该函数根据参数在参数树中的路径确定每个参数的适当分区规范。

来源: model.py92-109 model.py112-160 model.py162-174

参数分片

不同类型的参数根据其在模型中的功能进行不同的分片。

组件分区规范描述
Attention QKV 权重P("data", "model")在两个维度上都进行了分割。
Attention 输出P("model", "data")分割,第一个维度跨模型,第二个维度跨数据。
MoE 路由权重P("data")仅在数据维度上进行分割。
MoE 专家P(None, "data", "model")第一个维度(专家)不分割,其他维度分割。
嵌入P(None, ("data", "model"))词汇维度不分割,嵌入维度在两者之间分割。
层归一化P(None)不分割(复制)。

要更深入地了解模型的架构如何与这些分片决策相关联,请参阅 模型架构

来源: model.py112-160 model.py162-174 model.py477-486

分片约束

模型代码使用 with_sharding_constraint 函数将分片约束应用于张量,以确保张量按照指定的方式分布。这对于优化内存使用和最小化设备间的通信至关重要。

来源: model.py71-75 model.py846-848 model.py1046 model.py1060-1061 model.py1097

计算分布

Grok-1 使用 JAX 的 Parallel JIT(pjit)和 shard_map 根据分片规范将计算分布到设备上。

PJIT 用法

PJIT 用于转换函数,以实现具有特定输入和输出分片模式的分布式执行。

InferenceRunner 将 PJIT 应用于几个关键函数:

  1. sample_step - 用于 token 采样
  2. prefill_memory - 用于 KV 缓存初始化
  3. new_memory - 用于内存分配

来源: runners.py413-440 runners.py190-191

MoE 的 shard_map

对于专家混合(Mixture of Experts)计算,shard_map 用于根据专家分配高效地将计算映射到特定设备。

shard_map 允许专家特定的计算发生在拥有该专家的设备上,从而最大限度地减少跨设备通信。

来源: model.py319-337 model.py339-357 model.py572-573 model.py812-824

内存分片

KV 缓存(用于 Attention 的键值内存)被分片以优化内存使用。get_memory_sharding 方法定义了该内存应如何分布。

来源: model.py477-486 model.py178-181 model.py184-200

专家混合分片

专家混合(Mixture of Experts,MoE)架构需要特殊的分片策略,因为并非所有专家都对每个 token 活跃。Grok-1 通过以下方式实现了高效的 MoE 计算:

  1. 分片的路由计算
  2. 专家特定的计算分布
  3. 使用 shard_map 进行优化的权重应用

Router 选择应处理每个 token 的专家,然后使用 shard_map 在适当的设备上执行专家计算。

来源: model.py208-248 model.py272-273 model.py319-357 model.py1073-1091

网格配置与初始化

网格配置在初始化过程中指定,并决定了模型的分布方式。

local_mesh_config 确定单个主机内的设备如何排列,而 between_hosts_config 控制跨多个主机的分布。

来源:runners.py580-593 runners.py165-182 runners.py262-286

实现细节

状态分片

在初始化模型时,get_state_sharding 用于根据分区规则确定参数应如何分片。

来源:runners.py202-210 runners.py231-246

带有分片的张量运算

对张量的操作必须遵守分片约束。例如,在更新 KV 缓存时

来源:model.py805-827 model.py845-849

性能考量

高效的分片对于性能至关重要。代码库进行了多项优化

  1. 最小化跨设备通信:仔细放置分片约束以减少通信
  2. 平衡内存使用:参数被分配以适合设备内存
  3. 优化专家局部性:MoE 计算被分配以最小化数据移动
  4. 策略性复制:一些较小的参数(如层归一化)会在设备之间复制,而不是分片

来源:model.py112-160 model.py319-357 model.py813-824

与模型组件的集成

分片系统与所有模型组件深度集成。每个组件都知道如何处理分片计算

组件分片集成
LanguageModel管理高级分片上下文
Transformer处理层级分片
MultiHeadAttention使用分片 KV 缓存实现注意力机制
MoELayer执行专家路由和分片计算
线性层支持权重分片和分布式矩阵乘法

来源:model.py1194 model.py1291-1398 model.py694-891 model.py272-397

通过仔细管理参数和计算分配,Grok-1 尽管规模庞大,仍实现了高效的推理性能,使其能够在分布式硬件配置上运行。