菜单

超网络

相关源文件

超网络 (HyperNetworks) 是为其他网络生成权重的神经网络。本文档介绍了 labml_nn 库中动态超网络的实现,重点关注 HyperLSTM。这一概念由 David Ha、Andrew Dai 和 Quoc V. Le 在论文《HyperNetworks》中提出。

超网络概述

超网络使用一个较小的网络来生成较大网络的权重。主要有两种变体:

  1. 静态超网络 (Static HyperNetworks):为卷积网络生成权重(核)。
  2. 动态超网络 (Dynamic HyperNetworks):为循环神经网络的每一步生成参数。

在传统循环网络中,参数在每一步都保持不变。然而,动态超网络在每一步生成不同的参数,从而以增加复杂性为代价实现更灵活的计算。

来源:labml_nn/hypernetworks/hyper_lstm.py7-35

HyperLSTM 架构

HyperLSTM 是一种动态超网络的具体实现,其中一个较小的 LSTM 网络(超网络)为较大的 LSTM(主网络)的权重矩阵生成缩放因子。

HyperLSTM 具有标准 LSTM 的结构,但每一步的参数都会被较小的超网络 LSTM 修改。这使得网络能够根据当前上下文调整其权重。

来源:labml_nn/hypernetworks/hyper_lstm.py81-134 labml_nn/hypernetworks/hyper_lstm.py35-70

权重缩放优化

动态计算整个权重矩阵的开销会非常大。相反,HyperLSTM 使用权重缩放方法,其中超网络生成向量来缩放静态权重矩阵的行。

对于形状为 $N_h \times N_h$ 的权重矩阵 $W_h$,直接计算将要求超网络输出一个大小为 $N_h \times N_h$ 的张量。权重缩放方法效率更高。

这种优化显著减少了参数数量,同时保留了动态调整权重的能力。

来源:labml_nn/hypernetworks/hyper_lstm.py42-69

HyperLSTM 实现细节

该实现包含两个主要类

  1. HyperLSTMCell:实现超网络机制的核心单元。
  2. HyperLSTM:一个创建 HyperLSTMCell 层网络的封装器。

HyperLSTMCell

HyperLSTMCell 类处理以下组件:

  1. 一个较小的 LSTM(超 LSTM),它接收连接的输入 [h_t-1, x_t]
  2. 用于生成权重特征向量的线性变换。
  3. 用于从特征向量生成缩放因子的线性变换。
  4. 与缩放权重矩阵的计算。
组件目的实现
超 LSTM处理输入和先前的状态self.hyper = LSTMCell(...)
特征向量生成创建 z 向量self.z_h, self.z_x, self.z_b
缩放因子生成创建 d 向量self.d_h, self.d_x, self.d_b
权重矩阵要缩放的静态权重self.w_h, self.w_x
层归一化归一化激活self.layer_norm, self.layer_norm_c

来源:labml_nn/hypernetworks/hyper_lstm.py85-149

前向计算

HyperLSTMCell 的前向传播遵循以下步骤:

  1. 连接当前输入和先前输出:x_hat = [h, x]
  2. 通过超 LSTM 处理:h_hat, c_hat = self.hyper(x_hat, h_hat, c_hat)
  3. 生成特征向量:z_h, z_x, z_b
  4. 生成缩放因子:d_h, d_x, d_b
  5. 应用缩放变换计算门控:i, f, g, o
  6. 更新细胞状态和输出:c_next, h_next

来源:labml_nn/hypernetworks/hyper_lstm.py151-198

HyperLSTM 类

HyperLSTM 类负责创建 HyperLSTMCell 实例的堆叠网络,从而能够创建深度 HyperLSTM 网络。它管理:

  1. 多层初始化
  2. 跨层的状态管理
  3. 通过堆叠单元处理序列

来源:labml_nn/hypernetworks/hyper_lstm.py202-273

使用示例

以下是 HyperLSTM 在文本生成模型中的典型用法:

该实现包含一个实验,训练 HyperLSTM 在莎士比亚数据集上预测文本。该实验将 HyperLSTM 与标准 LSTM 进行比较,从而对两种方法的性能进行比较。

来源:labml_nn/hypernetworks/experiment.py13-61 labml_nn/hypernetworks/hyper_lstm.py15-16

与标准 LSTM 的比较

HyperLSTM 增加了标准 LSTM 的复杂性,但提供了根据输入上下文动态调整权重的能力。

功能标准 LSTMHyperLSTM
参数所有时间步都固定在每个步骤动态调整
网络大小单一 LSTM 网络主 LSTM + 较小的超 LSTM
计算成本较低因额外的 LSTM 计算而更高
适应性限于细胞状态变化可以调整实际的权重矩阵
实现复杂性更简单通过权重生成逻辑更复杂

库中的标准 LSTM 实现遵循相同的基本结构,但没有用于权重生成的超网络组件。

来源:labml_nn/lstm/__init__.py20-100 labml_nn/hypernetworks/experiment.py57-65

与其他循环网络的关系

HyperLSTM 与其他高级循环网络(如循环高速网络 (Recurrent Highway Networks, RHN))相比,采取了不同的方法:

网络类型参数适应方法深度处理
HyperLSTM通过超网络动态生成权重多层动态加权 LSTM
标准 LSTM固定权重与门控多层堆叠
循环高速网络带有门控机制的高速连接通过循环深度在每层内进行深度处理

循环高速网络侧重于在每层内创建深度路径以实现梯度流,而 HyperLSTM 则侧重于根据输入上下文动态调整参数。

来源:labml_nn/recurrent_highway_networks/__init__.py19-108 labml_nn/hypernetworks/hyper_lstm.py30-34