菜单

运行推理

相关源文件

本文档提供了运行 Grok-1 大型语言模型进行推理的详细技术说明。它涵盖了从初始化到文本生成的整个过程,重点关注模型在推理任务中的实际应用。

有关下载模型权重的更多信息,请参阅下载模型权重

推理过程概述

运行 Grok-1 推理涉及多个关键组件的协同工作

  1. 初始化:加载模型权重、设置设备网格并编译推理函数
  2. 分词:将文本提示转换为 token 序列
  3. 前向传播:通过 314B 参数的 MoE transformer 处理 token
  4. 采样:从模型输出的概率分布中选择 token
  5. 文本生成:将生成的 token 转换回文本

来源:runners.py442-577 run.py24-67

核心推理组件

推理系统围绕两个主要类构建

  1. InferenceRunner:用于运行推理的高级接口
  2. ModelRunner:处理模型初始化和检查点加载

来源:runners.py252-270 runners.py136-249

设置推理环境

要运行 Grok-1 推理,您需要

  1. 配置模型参数
  2. 创建一个 ModelRunner 实例
  3. 创建一个 InferenceRunner 实例
  4. 初始化运行器
  5. 使用您的提示运行推理

以下代码演示了设置过程

来源:run.py50-67 runners.py275-441

基本使用示例

这是一个运行 Grok-1 推理的最小示例

# Set up model configuration
grok_1_model = LanguageModelConfig(...)  # Configuration omitted for brevity

# Create InferenceRunner
inference_runner = InferenceRunner(
    pad_sizes=(1024,),
    runner=ModelRunner(
        model=grok_1_model,
        bs_per_device=0.125,
        checkpoint_path=CKPT_PATH,
    ),
    name="local",
    load=CKPT_PATH,
    tokenizer_path="./tokenizer.model",
    local_mesh_config=(1, 8),
    between_hosts_config=(1, 1),
)

# Initialize
inference_runner.initialize()
gen = inference_runner.run()

# Run inference
prompt = "The answer to life the universe and everything is of course"
output = sample_from_model(gen, prompt, max_len=100, temperature=0.01)
print(output)

来源:run.py24-67 runners.py596-605

关键推理参数

运行推理时,您可以通过多个参数来控制生成过程

参数描述默认示例
prompt输入给模型的文本“生命的答案是...”
max_len生成的序列的最大长度100
temperature控制随机性(值越低,结果越确定)0.01
nucleus_p控制 nucleus 采样的截止点1.0
rng_seed用于可复现生成的随机种子42

这些参数会传递给 sample_from_model 函数,或封装在 Request 对象中。

来源:runners.py253-259 runners.py596-605

详细的推理过程

1. 分词

输入文本使用 SentencePiece 分词器进行分词

来源:runners.py288 runners.py513-521

2. 内存初始化

模型使用键值缓存来存储先前 token 的注意力状态

来源:runners.py65-74 runners.py330-332

3. 预填充阶段

在处理提示时,模型首先会执行一个“预填充”阶段,一次性处理所有提示 token

来源:runners.py343-379

4. 自回归生成

预填充阶段之后,模型逐个生成 token

来源:runners.py324-328 runners.py549-577

5. 采样过程

Token 采样涉及几个步骤

来源:runners.py84-97 runners.py100-133

分布式推理

Grok-1 使用 JAX 的 pjit 在多个设备上进行分布式计算

网格是通过两个配置参数创建的

  • local_mesh_config:控制主机内的设备排列
  • between_hosts_config:控制跨主机的排列

来源:runners.py580-593 run.py60-61

性能考量

运行 Grok-1 推理时,请注意以下事项

  1. 内存需求:Grok-1 是一个拥有 314B 参数的模型,需要大量的 GPU 内存
  2. 设备配置:正确配置设备网格对于高效执行至关重要
  3. 批次大小:由 bs_per_device 参数控制(值越小,内存占用越少)
  4. 提示长度:更长的提示需要更多的内存用于 KV 缓存
  5. 填充pad_sizes 参数定义了可能影响内存使用的填充桶

来源:README.md17-19 run.py51-56

故障排除

运行推理的常见问题

  1. 内存不足:减小批次大小、使用更短的提示或将模型分布到更多设备上
  2. 生成速度慢:检查 temperature 设置(非常低的值可能由于数值问题导致速度变慢)
  3. 检查点加载错误:确保检查点已正确下载且路径设置正确

来源:README.md4-14 run.py21

总结

本文档涵盖了运行 Grok-1 模型进行推理的实际操作,包括初始化、核心组件、参数设置以及详细的推理过程。有关模型架构的信息,请参阅 模型架构

来源:README.md21-36 run.py24-67 runners.py252-605