本文档解释了代码库中 k-最近邻语言模型 (kNN-LM) 的实现。kNN-LM 通过对先前看到的上下文-标记对进行最近邻搜索来增强自回归 Transformer 模型,从而提高预测质量。它在无需预训练的情况下进行领域适应方面特别有效。此实现遵循论文 “通过记忆进行泛化:最近邻语言模型”。
有关基本 Transformer 模型的信息,请参见 基本 Transformer 模型。有关专用 Transformer 变体的信息,请参见 专用 Transformer 变体。
一个标准的自回归语言模型估计 $p(w_t | c_t)$,其中 $w_t$ 是步骤 $t$ 的标记,而 $c_t = (w_1, w_2, ..., w_{t-1})$ 是上下文。kNN-LM 通过在由键值对 $(f(c_i), w_i)$ 组成的数据存储中搜索相似上下文来改进此估计,其中
此实现使用最终 Transformer 层的 FFN(前馈网络)输入作为上下文嵌入 $f(c_t)$。FAISS 用于高效的最近邻搜索。
来源:labml_nn/transformers/knn/__init__.py11-43
kNN-LM 系统由三个按顺序运行的主要组件组成
来源:labml_nn/transformers/knn/__init__.py30-43 labml_nn/transformers/knn/train_model.py23-42 labml_nn/transformers/knn/build_index.py8-13 labml_nn/transformers/knn/eval_knn.py9-10
自回归 Transformer 模型是 kNN-LM 系统的基础。它生成用于最近邻搜索的上下文嵌入。
启用 kNN-LM 的关键修改是在最后一个 Transformer 层上设置 is_save_ff_input = True,这会捕获输入序列中每个位置的上下文嵌入 $f(c_t)$。
来源:labml_nn/transformers/knn/train_model.py23-59 labml_nn/transformers/knn/__init__.py25-26
为了创建 kNN 搜索的数据存储,系统
此实现使用 FAISS 中的 IndexIVFPQ,它结合了倒排文件索引(Inverted File Index)和乘积量化(Product Quantization),用于高效存储和搜索高维向量。
来源:labml_nn/transformers/knn/build_index.py53-139
在评估期间,通过将 Transformer 模型的输出与 k-最近邻搜索的结果相结合来生成预测
kNN 推理过程中的关键步骤是
插值权重 λ 平衡了 Transformer 模型和 kNN 组件的贡献。此实现评估不同的 λ 值以找到最佳设置。
来源:labml_nn/transformers/knn/eval_knn.py22-63 labml_nn/transformers/knn/eval_knn.py66-109
完整的实现包括三个主要脚本
训练模型 (train_model.py)
构建索引 (build_index.py)
评估 (eval_knn.py)
来源:labml_nn/transformers/knn/train_model.py103-144 labml_nn/transformers/knn/build_index.py142-156 labml_nn/transformers/knn/eval_knn.py137-157
kNN-LM 方法的核心是最近邻搜索和 kNN 分布的计算
此实现使用余弦相似度来衡量检索到的上下文与当前上下文的相关性。然后,相似度分数用于加权每个检索到的标记的贡献。
一个关键方面是 Transformer 和 kNN 预测之间的插值
这使得模型能够在 Transformer 的泛化能力和 kNN 的记忆优势之间取得平衡。
来源:labml_nn/transformers/knn/eval_knn.py22-63 labml_nn/transformers/knn/eval_knn.py104-107
要使用 kNN-LM 实现,您需要
训练一个 Transformer 模型:
python -m labml_nn.transformers.knn.train_model
构建索引:
python -m labml_nn.transformers.knn.build_index
评估 kNN-LM:
python -m labml_nn.transformers.knn.eval_knn
此实现默认使用字符级分词和 tiny Shakespeare 数据集进行演示。对于更大的数据集,索引将需要更多的存储空间。
关键配置参数包括
transformer.d_model:上下文嵌入的维度来源:labml_nn/transformers/knn/train_model.py111-129 labml_nn/transformers/knn/eval_knn.py149
kNN-LM 方法提供了几个优点
此实现评估不同的插值权重,以在 Transformer 和 kNN 预测之间找到最佳平衡,展示了该方法如何优于基线 Transformer 模型。
来源:labml_nn/transformers/knn/__init__.py8-10 labml_nn/transformers/knn/__init__.py152-153