菜单

高级 RNN

相关源文件

此页面介绍了D2L.ai代码库中实现的高级循环神经网络(RNN)架构。这些架构解决了基本RNN的局限性,例如梯度消失问题和捕获长期依赖关系的困难。有关基本RNN的信息,请参阅循环神经网络

架构概述

D2L.ai的实现包括几种高级RNN变体及其上构建的技术。

来源

门控循环单元 (GRU)

门控循环单元引入了门控机制来控制信息在网络中的流动,并缓解了梯度消失问题。

关键组件

  1. 重置门 (R_t):控制在计算候选隐状态时使用多少前一时刻的隐状态。

    • 公式:R_t = σ(X_t·W_xr + H_{t-1}·W_hr + b_r)
  2. 更新门 (Z_t):控制使用新的候选状态与前一状态的比例。

    • 公式:Z_t = σ(X_t·W_xz + H_{t-1}·W_hz + b_z)
  3. 候选隐状态 (H̃_t):提出的新隐状态。

    • 公式:H̃_t = tanh(X_t·W_xh + (R_t ⊙ H_{t-1})·W_hh + b_h)
  4. 隐状态更新:

    • 公式:H_t = Z_t ⊙ H_{t-1} + (1-Z_t) ⊙ H̃_t

在实现中,GRU的计算方式如chapter_recurrent-modern/gru.md341-369所示。

来源

长短期记忆 (LSTM)

LSTM网络引入了具有三种不同类型门的记忆单元,以在长时间内维护信息。

关键组件

  1. 输入门 (I_t):控制何时写入记忆单元。

    • 公式:I_t = σ(X_t·W_xi + H_{t-1}·W_hi + b_i)
  2. 遗忘门 (F_t):控制何时擦除记忆单元。

    • 公式:F_t = σ(X_t·W_xf + H_{t-1}·W_hf + b_f)
  3. 输出门 (O_t):控制何时从记忆单元读取。

    • 公式:O_t = σ(X_t·W_xo + H_{t-1}·W_ho + b_o)
  4. 候选记忆单元 (C̃_t):提出的新记忆单元内容。

    • 公式:C̃_t = tanh(X_t·W_xc + H_{t-1}·W_hc + b_c)
  5. 记忆单元更新:

    • 公式:C_t = F_t ⊙ C_{t-1} + I_t ⊙ C̃_t
  6. 隐状态更新:

    • 公式:H_t = O_t ⊙ tanh(C_t)

chapter_recurrent-modern/lstm.md322-356中的LSTM实现直接遵循了这些方程。

来源

深度循环神经网络

深度RNN堆叠多个RNN层以增加模型容量,并捕获序列中的分层模式。

实现

通过在初始化RNN时设置num_layers参数来实现深度RNN。

第 l 个隐层的数学公式为:

  • H_t^(l) = φ_l(H_t^(l-1)·W_xh^(l) + H_{t-1}^(l)·W_hh^(l) + b_h^(l))

其中 H_t^(0) = X_t(时间 t 的输入)

来源

双向RNN

双向RNN同时处理序列的正向和反向,以捕获来自过去和未来的上下文。

实现

通过设置bidirectional=True参数来创建双向RNN。

数学公式为:

  • 正向:→H_t = φ(X_t·W_xh^(f) + →H_{t-1}·W_hh^(f) + b_h^(f))
  • 反向:←H_t = φ(X_t·W_xh^(b) + ←H_{t+1}·W_hh^(b) + b_h^(b))
  • 合并:H_t = [→H_t, ←H_t](拼接)

来源

编码器-解码器架构

编码器-解码器架构通过将输入序列编码为固定表示,然后解码为输出序列,从而实现序列到序列的转换。

实现

编码器-解码器架构通过三个主要类实现:

  1. Encoder Interface:

  2. Decoder Interface:

  3. EncoderDecoder Class:

来源

序列到序列学习

序列到序列(Seq2Seq)学习使用RNN实现编码器-解码器架构,用于机器翻译等任务。

关键组件

  1. Seq2SeqEncoder:使用RNN对输入序列进行编码。

  2. Seq2SeqDecoder:使用另一个RNN生成输出序列。

来源

Beam search 通过在解码过程中维护多个候选序列来改进序列生成。

算法步骤

  1. 在时间步1,选择条件概率最高的k个词。
  2. 在随后的每个步骤中,对于k个序列中的每个序列,考虑所有可能的下一个词。
  3. 从这 k×|Y| 种可能的续写中,选择总概率最高的k个序列。
  4. 继续此过程,直到每个候选序列都产生一个结束符或达到最大长度。
  5. 选择具有最高得分的序列:(1/L^α) × log P(y_1, ..., y_L)

Beam search 的实现可以在chapter_recurrent-modern/beam-search.md97-152中可以找到。

来源

总结

本页面介绍了D2L.ai框架中实现的高级RNN架构。

架构主要功能主要用途
GRU重置门和更新门高效捕获长期依赖
LSTM记忆单元和三个门(输入、遗忘、输出)序列中的长期记忆
深度RNN多个堆叠的RNN层分层表示学习
双向RNN正反两个方向处理序列需要过去和未来上下文的任务
编码器-解码器转换可变长度序列序列到序列任务
Seq2Seq编码器-解码器的RNN实现机器翻译、文本摘要
束搜索维护多个候选序列改进的序列生成质量

这些组件构成了序列处理中众多应用的基础,包括机器翻译、语音识别和文本生成。