菜单

推理管道

相关源文件

DeepSeek-V3 中的推理管道提供了高效的模型推理所需的组件和流程。本文档详细介绍了用于 FP8/BF16 计算的优化内核、令牌生成过程以及整体推理工作流。有关底层模型架构的信息,请参阅模型架构

流水线概览

DeepSeek-V3 推理管道旨在为现代硬件上的高效文本生成而设计。它包含多个协同工作的专用组件,以优化推理过程。

图示:DeepSeek-V3 推理管道概览

来源: inference/kernel.py1-191 inference/generate.py1-185

优化内核

推理管道利用专用的 Triton JIT 内核进行高效的 GPU 计算,实现在 kernel.py 文件中。

激活量化

激活量化将激活张量的精度从更高的精度(如 BF16)降低到 FP8 格式,从而在保持准确性的同时提高了计算效率。

图示:激活量化过程

act_quant 函数接收一个输入张量并执行分块量化,同时返回 FP8 格式的量化张量和缩放因子张量。

  • 输入:最后维度可被 block_size 整除的连续张量
  • 处理:计算缩放因子并使用 Triton 内核执行量化
  • 输出:FP8 张量和相应的缩放因子

来源: inference/kernel.py9-52

权重反量化

权重反量化通过应用存储的缩放因子将量化后的权重恢复到其原始精度。

图示:权重反量化过程

weight_dequant 函数接收量化权重张量及其对应的缩放因子,并生成一个反量化张量。

  • 输入:量化权重张量和缩放张量
  • 处理:将缩放因子应用于量化张量的块
  • 输出:默认精度(BF16)的反量化张量

来源: inference/kernel.py55-105

FP8 矩阵乘法

FP8 矩阵乘法是一种优化操作,它以 FP8 精度输入执行矩阵乘法,并应用缩放因子以获得准确的结果。

图示:FP8 矩阵乘法

fp8_gemm 函数使用 FP8 精度执行矩阵乘法。

  • 输入:两个 FP8 矩阵及其相应的缩放因子
  • 处理:执行自动调优的 Triton 内核以实现高效的矩阵乘法
  • 输出:默认精度(BF16)的结果矩阵

内核通过各种配置进行自动调优,以找到给定矩阵维度的最佳参数。

参数可能的值
BLOCK_SIZE_M16, 32, 64
BLOCK_SIZE_N32, 64, 128
NUM_STAGES3, 4, 5, 6
NUM_WARPS8(固定)

来源: inference/kernel.py107-191

文本生成过程

文本生成组件负责根据输入提示生成输出令牌,实现在 generate.py 文件中。

图示:文本生成过程流程

来源: inference/generate.py30-78

令牌生成函数

文本生成过程的核心是 generate 函数,它

  1. 接收提示令牌序列列表和模型
  2. 以自回归方式处理输入
  3. 处理同时生成多个提示的批次
  4. 遵守模型的最大序列长度
  5. 当遇到 EOS 令牌或达到最大令牌数时终止生成

生成过程的关键参数

参数描述
modelTransformer 模型实例
prompt_tokens每个提示的令牌序列列表
max_new_tokens要生成的最大 token 数
eos_id结束序列令牌 ID
temperature控制令牌采样的随机性

来源: inference/generate.py30-78

采样机制

sample 函数实现基于温度的采样机制来选择令牌。

  • 输入:来自模型的 Logits 和温度参数
  • 处理:应用温度缩放并计算概率
  • 输出:基于修改后的 softmax 分布选择的令牌

对于确定性生成,将温度设置为 0 会导致函数选择概率最高的令牌。

来源: inference/generate.py14-27

交互式和批量模式

推理管道支持两种操作模式:

图示:推理操作模式

  1. 交互模式:

    • 以对话方式处理用户输入
    • 维护与用户和助手的聊天记录
    • 支持跨多个 GPU 的分布式操作
  2. 批处理模式:

    • 处理来自输入文件的多个提示
    • 在单个批次中为所有提示生成补全
    • 根据模型配置强制执行批次大小限制

这两种模式都使用相同的底层 generate 函数,但在输入收集和结果呈现方式上有所不同。

来源: inference/generate.py81-158

分布式推理支持

推理管道支持跨多个 GPU 进行分布式操作。

  • 使用 PyTorch 的分布式通信包(torch.distributed)。
  • 协调节点之间的提示分发和结果收集。
  • 初始化 NCCL 后端以实现高效的 GPU 到 GPU 通信。
  • 处理 world size、rank 和 local rank 以进行正确的设备分配。

来源: inference/generate.py100-107 inference/generate.py156-158