菜单

序列建模

相关源文件

本文档提供了d2l-ai/d2l-zh仓库中序列建模的技术概述,重点关注循环神经网络(RNN)及其处理文本、时间序列和其他有序信息等序列数据的实现。它涵盖了序列数据的基本概念、统计建模方法、RNN架构、实现细节以及评估指标。有关LSTM、GRU或Transformer等特定高级序列模型的信息,请参阅其他专门的维基页面。

什么是序列数据?

序列数据是指元素顺序很重要的信息。与传统机器学习中常用的独立同分布(i.i.d.)数据不同,序列数据表现出时间依赖性,在建模过程中必须保留这种依赖性。

序列数据的例子包括

  • 文本(文档中的单词或字符)
  • 时间序列(股票价格、天气数据)
  • 音频信号
  • 视频帧

序列数据的关键特征是元素不是独立分布的——特定元素出现的概率取决于序列中的前一个元素。

来源: chapter_recurrent-neural-networks/sequence.md4-42

序列建模的统计方法

仓库中实现了几种序列数据的统计建模方法

自回归模型

自回归模型根据固定窗口的先前观测值预测序列中的下一个元素

$$P(x_t \mid x_{t-1}, \ldots, x_{t-\tau})$$

其中 $\tau$ 是上下文窗口大小(嵌入维度)。由于随着序列的增长,考虑整个历史变得在计算上不可行,因此这种近似是必要的。

马尔可夫模型

马尔可夫模型通过假设未来仅取决于当前状态,而不取决于整个历史来简化

$$P(x_1, \ldots, x_T) = \prod_{t=1}^T P(x_t \mid x_{t-1})$$

这被称为一阶马尔可夫模型。虽然在计算上效率很高,但它限制了捕捉数据中长期依赖性的能力。

隐变量模型

这些模型维护一个隐藏状态,该状态捕获有关先前观测值的信息

$$h_t = g(h_{t-1}, x_{t-1})$$ $$\hat{x}t = P(x_t \mid h{t})$$

这种方法使得在保持模型计算可行性的同时,能够建模更长期的依赖关系。

来源: chapter_recurrent-neural-networks/sequence.md58-115 chapter_recurrent-neural-networks/sequence.md80-92

循环神经网络(RNN)

RNN通过维护一个捕获整个历史信息的隐藏状态,解决了传统统计模型的局限性。

RNN中的隐藏状态更新定义为

$$h_t = \tanh(x_t W_{xh} + h_{t-1} W_{hh} + b_h)$$

以及输出计算

$$o_t = h_t W_{hq} + b_q$$

其中 $W_{xh}$、$W_{hh}$ 和 $W_{hq}$ 是权重矩阵, $b_h$ 和 $b_q}$ 是偏置向量。

当沿着时间展开时,RNN看起来像

该图说明了RNN如何通过在不同时间步重用相同的权重来处理序列,这是RNN的关键特征。

来源: chapter_recurrent-neural-networks/rnn.md36-135

仓库中的RNN实现

仓库提供了从头开始的RNN实现和高级API的RNN实现。

从零开始实现

从零开始实现的核心组件包括

  1. 参数初始化 (get_params): 初始化RNN的权重矩阵和偏置向量。

    • W_xh: 输入到隐藏的权重
    • W_hh: 隐藏到隐藏的权重
    • b_h: 隐藏层偏置
    • W_hq: 隐藏到输出的权重
    • b_q: 输出层偏置
  2. 隐藏状态初始化 (init_rnn_state): 创建一个初始全零的隐藏状态。

  3. RNN前向传播 (rnn): 一次一个时间步地处理输入序列,更新隐藏状态并生成输出。

  4. 完整的RNN模型 (RNNModelScratch): 一个封装RNN功能的类。

这些组件被集成到 RNNModelScratch 类中

来源: chapter_recurrent-neural-networks/rnn-scratch.md130-175 chapter_recurrent-neural-networks/rnn-scratch.md222-285 chapter_recurrent-neural-networks/rnn-scratch.md317-354

高级API实现

高级API实现 (RNNModel) 利用了深度学习框架内置的RNN模块

高级API通过内部处理参数初始化和前向计算的细节,简化了实现。

来源: chapter_recurrent-neural-networks/rnn-concise.md51-156

字符级语言模型

该仓库中演示的一个关键应用是字符级语言模型,它根据先前的字符预测序列中的下一个字符。

实现字符级语言模型的完整流程包括

数据处理

  1. 加载文本:使用 read_time_machine() 加载H.G.威尔斯的《时间机器》数据集。
  2. 分词:使用 tokenize() 函数将文本转换为字符标记。
  3. 词汇表构建:使用 Vocab 类创建字符与数字索引之间的映射。
  4. 批次创建:使用随机抽样或顺序划分,以小批量的方式准备训练数据。

训练过程

RNN被训练为在给定先前字符的情况下预测序列中的下一个字符。训练循环包括

  1. 通过RNN进行前向传播
  2. 计算交叉熵损失
  3. 时间反向传播(BPTT)
  4. 梯度裁剪以防止梯度爆炸
  5. 参数更新

来源: chapter_recurrent-neural-networks/text-preprocessing.md45-99 chapter_recurrent-neural-networks/language-models-and-dataset.md274-312 chapter_recurrent-neural-networks/rnn-scratch.md652-809

时间反向传播(BPTT)

训练RNN需要时间反向传播,它将RNN沿着时间步展开,并将标准反向传播应用于由此产生的计算图。

BPTT可能导致梯度问题

  1. 梯度消失:对于长序列,梯度可能变得非常小,导致学习困难。
  2. 梯度爆炸:梯度也可能变得非常大,导致训练不稳定。

这些问题通过以下方式解决

  1. 梯度裁剪:当梯度的范数超过阈值时,对其进行缩放。

  2. 截断BPTT:将反向传播限制在固定数量的时间步。

来源: chapter_recurrent-neural-networks/bptt.md22-64 chapter_recurrent-neural-networks/rnn-scratch.md537-613

预测策略

仓库实现了两种预测策略

一步预测

预测下一个元素,给定真实的先前元素

多步预测

使用模型自身的预测作为未来预测的输入

多步预测具有挑战性,因为错误会随着时间的推移而累积,正如仓库中的示例所示。

来源: chapter_recurrent-neural-networks/sequence.md386-447 chapter_recurrent-neural-networks/rnn-scratch.md444-520

使用困惑度进行评估

仓库中语言模型的主要评估指标是困惑度

$$\exp\left(-\frac{1}{n} \sum_{t=1}^n \log P(x_t \mid x_{t-1}, \ldots, x_1)\right)$$

困惑度可以解释为模型在预测下一个标记时平均不确定的选择数量。较低的困惑度表示更好的模型性能。

实现通过以下方式计算困惑度

来源: chapter_recurrent-neural-networks/rnn.md234-297 chapter_recurrent-neural-networks/rnn-scratch.md652-682

总结

使用RNN进行序列建模为处理序列数据提供了一种强大的方法。d2l-ai仓库从零开始以及使用高级API实现了RNN,展示了它们在字符级语言建模中的应用。虽然RNN能有效捕获序列依赖性,但由于梯度问题,它们在长期依赖性方面面临挑战。梯度裁剪和截断BPTT等技术有助于缓解这些挑战,从而有效地训练序列模型。

来源: chapter_recurrent-neural-networks/index.md1-43 chapter_recurrent-neural-networks/rnn.md302-308