Hugging Face Transformers 中的生成系统为语言模型的文本生成提供了全面的框架。它处理各种生成策略,包括贪婪搜索、束搜索、采样方法以及对比搜索和辅助生成等专门技术。
本文档涵盖了生成系统的核心组件、它们如何交互以及如何有效使用它们。有关特定模型实现的信息,请参阅模型实现。有关训练模型的信息,请参阅训练系统。
生成系统通过顺序预测token来使模型生成文本。它主要通过GenerationMixin类实现,该类提供了大多数模型继承的generate()方法。该系统由几个关键组件组成:
来源
GenerationMixin类是生成系统的核心组件。它提供了大多数模型继承的generate()方法以执行文本生成。
generate()方法高度可配置,支持各种参数来控制生成过程。这些参数可以直接提供给方法,也可以通过GenerationConfig对象提供。
来源
生成系统支持多种策略,用于在生成过程中选择下一个token
在每个步骤中选择概率最高的token。这是最简单、最快的策略,但可能无法产生最多样化或高质量的输出。
维护多个候选序列(束),并选择累积概率最高的序列。这通常比贪婪搜索产生更好的结果,但计算成本更高。
根据概率分布随机选择token。这引入了随机性,可以产生更多样化的输出。
通过惩罚与先前生成内容过于相似的token来平衡高概率token与多样性。
使用一个较小的“辅助”模型生成候选token,然后由主模型验证,从而加快生成速度。
来源
Logits处理器在token选择之前修改预测分数(logits)。它们可用于在生成过程中实现各种约束和偏差。
| 处理器 | 目的 |
|---|---|
MinLengthLogitsProcessor | 在达到最小长度之前阻止EOS token |
TemperatureLogitsWarper | 通过缩放logits调整token概率 |
TopKLogitsWarper | 只保留概率最高的top-k个token |
TopPLogitsWarper | 保留累积概率小于等于top_p的token |
RepetitionPenaltyLogitsProcessor | 惩罚重复的词元 |
NoRepeatNGramLogitsProcessor | 防止n-gram重复 |
ForcedBOSTokenLogitsProcessor | 在开头强制使用特定token |
ForcedEOSTokenLogitsProcessor | 在结尾强制使用特定token |
The LogitsProcessorList类管理一个logits处理器列表,并按顺序将它们应用于模型的输出logits。
来源
停止标准决定何时停止生成过程。可以使用StoppingCriteriaList组合多个标准。
| 标准 | 目的 |
|---|---|
MaxLengthCriteria | 当输出达到最大长度时停止 |
EosTokenCriteria | 当生成EOS token时停止 |
MaxTimeCriteria | 在指定时间限制后停止 |
StopStringCriteria | 当生成特定字符串时停止 |
来源
缓存机制存储先前前向传播中的键值对,以避免在自回归生成过程中进行冗余计算。Transformers库提供了多种针对不同用例优化的缓存实现。
| 缓存类型 | 描述 | 最佳用途 |
|---|---|---|
DynamicCache | 随着token生成动态增长 | 通用用途,可变长度输出 |
StaticCache | 预分配固定大小的张量 | 内存高效,固定上下文长度 |
SlidingWindowCache | 维护近期token的窗口 | 在有限内存下进行长上下文生成 |
HybridCache | 对不同层使用不同的缓存类型 | 具有异构层的模型 |
QuantizedCache | 以更低精度存储值 | 内存受限环境 |
来源
生成系统提供结构化输出类型来组织生成结果。输出类型取决于生成策略和模型类型。
输出类型提供了一种结构化方式来访问生成的序列以及分数、注意力权重和隐藏状态等可选元数据。
来源
GenerationConfig类提供了一种集中式方法来配置生成参数。它可以从模型的配置中加载,也可以单独创建。
关键配置参数包括
| 参数 | 描述 |
|---|---|
max_length | 生成序列的最大长度 |
max_new_tokens | 要生成的新token的最大数量 |
do_sample | 是否使用采样而不是贪婪解码 |
num_beams | 束搜索的束数 |
temperature | 采样的温度 |
top_k | 为top-k采样保留的最高概率token的数量 |
top_p | top-p采样的累积概率阈值 |
repetition_penalty | 重复token的惩罚 |
no_repeat_ngram_size | 防止重复的n-gram大小 |
cache_implementation | 生成过程中使用的缓存类型 |
来源
下图说明了生成过程的完整流程
来源
生成系统支持可以提高生成速度和质量的候选生成技术
来源
生成系统支持受限生成,其中输出必须包含特定短语或遵循特定模式
来源
生成系统与Transformers库的其他组件集成
生成系统被text-generation管道使用,该管道提供了文本生成的高级接口
生成系统支持量化KV缓存,以减少生成过程中的内存使用
来源
Hugging Face Transformers 中的生成系统为语言模型的文本生成提供了灵活而强大的框架。它支持各种生成策略、logits处理器、停止标准和缓存机制,以优化质量和性能。
要点
GenerationMixin类是提供generate()方法的中心组件。GenerationConfig类提供了一种集中式方法来配置生成参数。通过理解这些组件及其交互方式,您可以有效地使用和自定义生成系统以满足您的特定需求。