本页面提供使用提供的 JAX 实现设置和运行 Grok-1 模型的说明。有关 Grok-1 架构和功能的详细信息,请参阅 Grok-1 概述。有关具体的实现细节,请参阅 实现细节。
Grok-1 是一个拥有 314B 参数的大型语言模型,推理需要大量的计算资源。
运行 Grok-1 需要以下依赖项
dm_haiku==0.0.12
jax[cuda12-pip]==0.4.25
numpy==1.26.4
sentencepiece==0.2.0
请按照以下步骤在您的系统上设置 Grok-1
克隆仓库
安装依赖项
Grok-1 模型权重必须下载并放置在适当的目录中。有两种方法
使用支持以下磁力链接的 torrent 客户端
magnet:?xt=urn:btih:5f96d43576e3d386c9ba65b883210a393b68210e&tr=https%3A%2F%2Facademictorrents.com%2Fannounce.php&tr=udp%3A%2F%2Ftracker.coppersurfer.tk%3A6969&tr=udp%3A%2F%2Ftracker.opentrackr.org%3A1337%2Fannounce
下载后,请确保模型检查点目录(ckpt-0)已放置在存储库的 checkpoints 目录中。
安装完依赖并下载模型权重后,您可以使用提供的脚本来运行推理
该脚本
run.py 中的示例脚本使用提示“The answer to life the universe and everything is of course”,并以 0.01 的温度生成响应。
来源:run.py15-72
以下图表说明了 Grok-1 的初始化和推理过程
来源:run.py24-67
以下图表显示了推理过程中关键组件之间的关系
Grok-1 模型使用以下配置进行初始化
| 参数 | 值 | 描述 |
|---|---|---|
| vocab_size | 131,072 | 词汇表大小(128K) |
| sequence_len | 8,192 | 最大序列长度(上下文窗口) |
| emb_size | 6,144 | 嵌入维度(48 * 128) |
| num_q_heads | 48 | 查询注意力头的数量 |
| num_kv_heads | 8 | 键/值注意力头的数量 |
| num_layers | 64 | Transformer 层数 |
| num_experts | 8 | MoE 层中的专家数量 |
| num_selected_experts | 2 | 每个 token 使用的专家数量 |
要使用您自己的提示运行模型,您可以修改 run.py 中的示例
inp 变量中的输入文本max_len 和 temperature 等生成参数python run.py 运行脚本来源:run.py66-67
默认情况下,示例代码配置为在单个主机上使用 8 个 GPU。此配置可以通过在初始化 InferenceRunner 时修改 local_mesh_config 参数来调整
local_mesh_config=(1, 8) # (number of hosts, number of devices per host)
您可能需要根据可用硬件调整此配置。
来源:run.py50-62
成功设置并运行基本示例后,您可能希望探索