菜单

kNN 语言模型

相关源文件

本文档解释了代码库中 k-最近邻语言模型 (kNN-LM) 的实现。kNN-LM 通过对先前看到的上下文-标记对进行最近邻搜索来增强自回归 Transformer 模型,从而提高预测质量。它在无需预训练的情况下进行领域适应方面特别有效。此实现遵循论文 “通过记忆进行泛化:最近邻语言模型”

有关基本 Transformer 模型的信息,请参见 基本 Transformer 模型。有关专用 Transformer 变体的信息,请参见 专用 Transformer 变体

概述

一个标准的自回归语言模型估计 $p(w_t | c_t)$,其中 $w_t$ 是步骤 $t$ 的标记,而 $c_t = (w_1, w_2, ..., w_{t-1})$ 是上下文。kNN-LM 通过在由键值对 $(f(c_i), w_i)$ 组成的数据存储中搜索相似上下文来改进此估计,其中

  • $f(c_i)$ 是上下文 $c_i$ 的嵌入
  • $w_i$ 是训练数据中跟在上下文 $c_i$ 后面的标记

此实现使用最终 Transformer 层的 FFN(前馈网络)输入作为上下文嵌入 $f(c_t)$。FAISS 用于高效的最近邻搜索。

来源:labml_nn/transformers/knn/__init__.py11-43

系统架构

kNN-LM 系统由三个按顺序运行的主要组件组成

  1. 训练 Transformer 模型:在目标数据集上训练一个标准的自回归 Transformer。
  2. 构建索引:处理训练数据以提取上下文嵌入 $f(c_i)$ 及其对应的下一个标记 $w_i$,并将它们存储在高效的索引中。
  3. 评估 kNN-LM:在推理过程中,将 Transformer 的预测与 k-最近邻搜索的结果相结合。

来源:labml_nn/transformers/knn/__init__.py30-43 labml_nn/transformers/knn/train_model.py23-42 labml_nn/transformers/knn/build_index.py8-13 labml_nn/transformers/knn/eval_knn.py9-10

自回归模型

自回归 Transformer 模型是 kNN-LM 系统的基础。它生成用于最近邻搜索的上下文嵌入。

启用 kNN-LM 的关键修改是在最后一个 Transformer 层上设置 is_save_ff_input = True,这会捕获输入序列中每个位置的上下文嵌入 $f(c_t)$。

来源:labml_nn/transformers/knn/train_model.py23-59 labml_nn/transformers/knn/__init__.py25-26

构建 FAISS 索引

为了创建 kNN 搜索的数据存储,系统

  1. 通过 Transformer 模型处理训练数据
  2. 提取 $f(c_i)$(上下文嵌入)和 $w_i$(下一个标记)
  3. 将它们存储在内存映射的 NumPy 数组中
  4. 构建 FAISS 索引以进行高效的相似性搜索

此实现使用 FAISS 中的 IndexIVFPQ,它结合了倒排文件索引(Inverted File Index)和乘积量化(Product Quantization),用于高效存储和搜索高维向量。

来源:labml_nn/transformers/knn/build_index.py53-139

kNN 推理和插值

在评估期间,通过将 Transformer 模型的输出与 k-最近邻搜索的结果相结合来生成预测

kNN 推理过程中的关键步骤是

  1. 从 Transformer 生成上下文嵌入 $f(c_t)$
  2. 在索引中查找 k 个最近邻
  3. 使用余弦相似度计算相似度分数
  4. 从检索到的标记创建分布
  5. 使用权重 λ 与 Transformer 的预测进行插值

插值权重 λ 平衡了 Transformer 模型和 kNN 组件的贡献。此实现评估不同的 λ 值以找到最佳设置。

来源:labml_nn/transformers/knn/eval_knn.py22-63 labml_nn/transformers/knn/eval_knn.py66-109

实现工作流程

完整的实现包括三个主要脚本

  1. 训练模型 (train_model.py)

    • 配置并训练一个自回归 Transformer
    • 启用在最后一层保存前馈输入
  2. 构建索引 (build_index.py)

    • 加载已训练的模型
    • 处理训练数据以提取上下文嵌入
    • 构建并保存 FAISS 索引
  3. 评估 (eval_knn.py)

    • 加载模型和索引
    • 实现 kNN 搜索和插值
    • 评估不同插值权重下的性能

来源:labml_nn/transformers/knn/train_model.py103-144 labml_nn/transformers/knn/build_index.py142-156 labml_nn/transformers/knn/eval_knn.py137-157

kNN 搜索实现

kNN-LM 方法的核心是最近邻搜索和 kNN 分布的计算

此实现使用余弦相似度来衡量检索到的上下文与当前上下文的相关性。然后,相似度分数用于加权每个检索到的标记的贡献。

一个关键方面是 Transformer 和 kNN 预测之间的插值

这使得模型能够在 Transformer 的泛化能力和 kNN 的记忆优势之间取得平衡。

来源:labml_nn/transformers/knn/eval_knn.py22-63 labml_nn/transformers/knn/eval_knn.py104-107

配置和使用

要使用 kNN-LM 实现,您需要

  1. 训练一个 Transformer 模型:

    python -m labml_nn.transformers.knn.train_model
    
  2. 构建索引:

    python -m labml_nn.transformers.knn.build_index
    
  3. 评估 kNN-LM:

    python -m labml_nn.transformers.knn.eval_knn
    

此实现默认使用字符级分词和 tiny Shakespeare 数据集进行演示。对于更大的数据集,索引将需要更多的存储空间。

关键配置参数包括

  • transformer.d_model:上下文嵌入的维度
  • 要检索的最近邻数量(在实现中设置为 10)
  • 插值权重(评估范围从 0 到 0.45 的值)

来源:labml_nn/transformers/knn/train_model.py111-129 labml_nn/transformers/knn/eval_knn.py149

性能优势

kNN-LM 方法提供了几个优点

  1. 改进的预测质量:通过将精确的上下文匹配与 Transformer 的通用知识相结合
  2. 领域适应:无需完全重新训练即可适应新领域
  3. 内存效率:FAISS 索引提供了高效的高维向量存储和检索

此实现评估不同的插值权重,以在 Transformer 和 kNN 预测之间找到最佳平衡,展示了该方法如何优于基线 Transformer 模型。

来源:labml_nn/transformers/knn/__init__.py8-10 labml_nn/transformers/knn/__init__.py152-153