菜单

内存与KV缓存

相关源文件

本文档提供了 Grok-1 模型中内存管理和键值缓存机制的技术概述。这些系统对于高效的自回归文本生成至关重要,它允许模型在不为每个新 token 重新计算注意力模式的情况下,保持对先前生成 token 的上下文感知。

有关 Transformer 架构整体的信息,请参阅 Transformer 架构。有关专家混合(Mixture of Experts)实现细节,请参阅 专家混合 (MoE)

核心数据结构

Grok-1 模型使用两种主要数据结构来管理键值缓存。

KVMemory:一个 NamedTuple,用于存储:

  • k:到目前为止处理过的所有位置的键(Key)张量(形状:[batch_size, sequence_len, num_kv_heads, key_size]
  • v:到目前为止处理过的所有位置的值(Value)张量(形状:[batch_size, sequence_len, num_kv_heads, key_size]
  • step:一个整数张量,指示序列中的当前位置(形状:[batch_size]

Memory:一个 NamedTuple,包含:

  • layers:一个 KVMemory 对象列表,每个 Transformer 层对应一个对象。

来源:model.py178-182 model.py203-205

内存初始化

模型在处理输入 token 之前,使用零初始化内存。`init_layer_memories` 函数创建初始内存结构。

内存初始化由以下函数处理:

  • init_layer_memories:创建具有零张量的 KVMemory 对象数组。
  • Transformer.init_memory:包装器,调用 `init_layer_memories`。
  • LanguageModel.init_memory:公共 API,委托给 Transformer 的初始化。

来源:model.py184-200 model.py1281-1282 model.py1313-1324

推理期间的 KV 缓存

在推理过程中,KV 缓存对于高效的自回归生成起着关键作用。

此过程中的关键操作包括:

  1. 计算新 token 的查询(Query)、键(Key)和值(Value)投影。
  2. 根据序列中的位置应用旋转位置嵌入。
  3. 从内存中检索现有的键和值。
  4. 计算新查询与所有键(缓存+当前)之间的注意力分数。
  5. 使用新的键和值更新 KV 缓存。
  6. 在内存中递增位置计数器(step)。

来源:model.py720-891 model.py1325-1398

内存更新机制

当处理新 token 时,KV 缓存通过动态切片操作进行更新。这确保了内存随着每个 token 的添加而增长,同时保持正确的定位信息。

更新是通过 JAX 的 `dynamic_update_slice_in_dim` 函数执行的,该函数被包装在 `vmap` 函数中以处理批处理操作。

此函数将新的键值向量插入到缓存中的正确位置,该位置由 `step` 计数器确定。

来源:model.py805-808 model.py826-831

内存分片用于分布式计算

为了高效的分布式计算,内存使用 JAX 的分片系统在设备之间进行分片。

分片配置确保:

  • 键值缓存跨数据和模型维度进行分区。
  • 步计数器仅跨数据维度进行分区。
  • 分片与模型参数分片保持一致。

当在多个设备上运行时,使用 `shard_map` 的更新函数的专用版本来高效处理分布式内存更新。

来源:model.py476-486 model.py812-824

生成期间的内存管理

`InferenceRunner` 类在整个文本生成过程中管理内存。

生成期间内存管理的关键方面:

  1. 内存初始化:为最大序列长度创建零填充的内存结构。
  2. 提示处理:高效处理所有提示 token 并更新内存。
  3. Token 生成:逐个生成新 token,使用并更新内存。
  4. 活跃请求跟踪:内存跟踪哪些批次元素正在积极生成文本。
  5. 内存重用:生成完成后,内存槽会被重用于新请求。

来源:runners.py442-578 runners.py323-329 runners.py333-394

预填充 KV 缓存

系统使用“预填充”过程来高效处理提示 token,在开始逐 token 生成之前填充 KV 缓存。

`prefill_memory` 函数一次性处理提示 token,并使用它们的键值投影更新内存。这比逐个处理提示 token 更有效。预填充后,内存中的 `step` 计数器被设置为提示长度,以确保新 token 在正确的位置添加。

来源:runners.py333-394 model.py1284-1288

内存与注意力掩码的交互

KV 缓存系统的关键部分在于它如何与注意力掩码进行交互,以确保 token 只关注正确的 pos。

在使用 KV 缓存时,系统会根据缓存中有效的位置创建一个“内存掩码”。此掩码确保注意力只为有效的、先前处理过的位置计算。它与因果掩码(用于自回归生成)结合,形成最终的注意力掩码。

来源:model.py832-838 model.py867-875

内存和旋转位置嵌入

Grok-1 使用旋转位置嵌入(RoPE),该嵌入与 KV 缓存系统集成。

旋转位置嵌入使用内存中的 `step` 计数器来计算每个 token 的绝对位置。这确保了 token 在 Transformer 中保持其正确的定位信息,即使在使用 KV 缓存进行高效生成时也是如此。

来源:model.py635-691 model.py802-803

总结

Grok-1 中的内存和 KV 缓存系统通过以下方式实现高效的自回归 token 生成:

  1. 存储先前计算的键值投影,以避免重复计算。
  2. 使用内存中的步计数器跟踪序列位置。
  3. 与注意力掩码集成,以确保正确的注意力模式。
  4. 通过内存分片支持分布式计算。
  5. 通过矢量化操作高效处理批处理。

该系统对于模型在推理期间的性能至关重要,允许它在序列变长时以最小的计算开销生成文本。

来源: model.py178-205 model.py720-891 runners.py442-578