菜单

模型加载与生成

相关源文件

本文档涵盖 Llama 模型的模型加载过程、量化转换和文本生成管线。它重点关注 Llama4 类的初始化、权重量化以及包括采样策略在内的核心生成循环。

有关模型架构详情,请参阅模型架构。有关分词细节,请参阅分词与量化。有关脚本使用示例,请参阅文本生成脚本

模型加载过程

模型加载过程由 Llama4.build() 静态方法处理,该方法初始化分布式训练,加载检查点,并选择性地应用量化。

构建流程

构建过程遵循以下关键步骤:

  1. 分布式设置:使用 fairscale 初始化 NCCL 进程组和模型并行
  2. 检查点发现:在检查点目录中查找 .pth 文件
  3. 配置加载:读取 params.json 以创建 ModelArgs
  4. 状态字典加载:使用 maybe_reshard_state_dict() 处理模型并行分片
  5. 可选量化:如果指定,转换为量化权重
  6. 模型实例化:创建最终的 Llama4 包装器实例

来源: models/llama4/generation.py36-110

关键组件

组件目的位置
ModelArgs来自 params.json 的配置参数models/llama4/generation.py70-74
Tokenizer处理文本编码/解码models/llama4/generation.py75
ChatFormat格式化消息以完成聊天models/llama4/generation.py116
Transformer核心模型架构models/llama4/generation.py93-104

来源: models/llama4/generation.py36-117

量化转换

量化系统将全精度权重转换为降低精度格式(FP8 或 INT4),以减少内存使用并提高推理速度。

量化管线

量化模式

系统支持在 QuantizationMode 中定义的两种量化模式

  • fp8_mixed:使用 FP8 量化,跳过第一层和最后一层
  • int4_mixed:使用 INT4 量化,对路由专家和共享专家都进行量化

should_quantize_block() 函数确定要量化的层

来源: models/llama4/quantization/loader.py46-170

权重处理

对于每个量化块,系统处理 MoE 专家的三个权重矩阵(w1w3w2

  1. 路由专家:对 moe.experts.{w1,w3,w2} 参数应用量化
  2. 共享专家:对于 INT4 模式,也量化 moe.shared_expert.{w1,w3,w2} 参数
  3. 函数替换:用优化的量化实现替换前向方法

来源: models/llama4/quantization/loader.py129-153

文本生成管线

生成管线通过模型转换输入 token 以生成输出文本,处理补全和聊天补全两种场景。

生成架构

核心生成组件

组件类型目的
LLMInput输入包含 token 和可选图像
TransformerInput模型输入Token、位置、图像嵌入
MaskedEmbedding视觉处理图像块嵌入
GenerationResult输出每步的 token、文本、元数据

来源: models/llama4/generation.py118-246

生成方法

Llama4 类提供了三种生成接口

来源: models/llama4/generation.py247-290

采样策略

生成系统支持多种采样策略来控制输出的随机性和质量。

采样决策流程

Top-p (核心) 采样

sample_top_p() 函数实现了核心采样

  1. 排序概率: torch.sort(probs, descending=True)
  2. 累积和: torch.cumsum(probs_sort)
  3. 应用阈值: 遮蔽 probs_sum - probs_sort > p 的 token
  4. 重新归一化: probs_sort.div_(probs_sort.sum())
  5. 采样: torch.multinomial(probs_sort, num_samples=1)

这种方法从累积概率超过阈值 p 的最小 token 集合中进行选择,提供比传统 top-k 采样更好的质量。

来源: models/llama4/generation.py292-314

采样参数

参数范围效果
temperature0.0-2.0+控制随机性(0=贪婪,数值越高=越随机)
top_p0.0-1.0核心采样阈值
max_gen_len1-max_seq_len最大生成 token 数量

来源: models/llama4/generation.py119-128 models/llama4/generation.py207-211