菜单

优化器

相关源文件

本页面介绍了 labml_nn/optimizers 模块中实现的基于梯度的优化算法。它涵盖了诸如 Adam、AMSGrad、RAdam、AdaBelief 等自适应优化器,以及 Noam 和 AdamWarmup 等学习率调度优化器。这些实现既是神经网络优化的实用工具,也是具有底层数学和算法注释解释的教育资源。

有关使用这些优化器进行训练的信息,请参阅训练基础设施实验配置

优化器架构

此模块中的优化器通过分层类结构实现,将通用功能分解到基类中。

继承层次结构

来源:labml_nn/optimizers/__init__.py70-164 labml_nn/optimizers/adam.py50-214 labml_nn/optimizers/amsgrad.py27-109 labml_nn/optimizers/radam.py148-274 labml_nn/optimizers/ada_belief.py45-159 labml_nn/optimizers/noam.py20-64 labml_nn/optimizers/adam_warmup.py18-61

优化流程

所有优化器的优化步骤都遵循以下通用流程

来源:labml_nn/optimizers/__init__.py122-164 labml_nn/optimizers/adam.py194-214 labml_nn/optimizers/radam.py178-198

基类和实用工具

通用自适应优化器(GenericAdaptiveOptimizer)

这是模块中所有自适应优化器的基类。它定义了一个通用的结构,包含由子类实现的模板方法。

  • init_state:为参数初始化优化器状态
  • step_param:实现参数特定的更新逻辑
  • step:协调优化过程的模板方法

来源:labml_nn/optimizers/__init__.py70-164

权重衰减(WeightDecay)

WeightDecay 类实现了具有灵活选项的 L2 权重衰减:

模式描述
weight_decouple=True, absolute=True通过因子 (1-weight_decay) 进行直接参数衰减
weight_decouple=True, absolute=False通过因子 (1-lr*weight_decay) 进行衰减
weight_decouple=Falseweight_decay * parameter 添加到梯度中

这种分离使得优化器能够以一致的方式处理不同实现中的权重衰减。

来源:labml_nn/optimizers/__init__.py167-218

优化器实现

Adam

Adam(自适应矩估计)维护梯度及其平方的指数移动平均值,以独立地调整每个参数的学习率。

Adam 的更新规则是

m_t = β₁ * m_{t-1} + (1 - β₁) * g_t
v_t = β₂ * v_{t-1} + (1 - β₂) * g_t²
m̂_t = m_t / (1 - β₁ᵗ)
v̂_t = v_t / (1 - β₂ᵗ)
θ_t = θ_{t-1} - α * m̂_t / (√v̂_t + ε)

其中

  • m_t 是第一阶矩(动量)
  • v_t 是第二阶矩(非中心方差)
  • β₁, β₂ 是矩的衰减率
  • α 是学习率
  • ε 是一个小的常数,以防止除以零

该实现包括一个 optimized_update 标志,通过重新排列项来优化计算:

θ_t = θ_{t-1} - α * √(1-β₂ᵗ)/(1-β₁ᵗ) * m_t / (√v_t + ε̂)

其中 ε̂ = ε * (1-β₂ᵗ)

来源:labml_nn/optimizers/adam.py1-214

AMSGrad

AMSGrad 是 Adam 的一个变体,通过维护所有过去第二阶矩的最大值来解决收敛问题。

v̂_t = max(v̂_{t-1}, v_t)

这可以防止自适应学习率下降过快,这种情况在某些情况下可能会发生在 Adam 中。

AMSGrad 通过添加一个 max_exp_avg_sq 状态变量并使用 torch.maximum() 来维护最大值来实现这一点。

来源:labml_nn/optimizers/amsgrad.py1-109

RAdam (Rectified Adam)

RAdam 通过添加一个修正项来解决训练早期自适应学习率方差过大的问题。

r_t = sqrt(((ρ_t-2)*(ρ_t-4)*ρ_∞)/((ρ_∞-2)*(ρ_∞-4)*ρ_t))

其中

  • ρ_∞ = 2/(1-β₂) - 1
  • ρ_t = ρ_∞ - 2t*β₂ᵗ/(1-β₂ᵗ)

ρ_t ≥ 5 时,修正项应用于 Adam 更新。当修正项无法可靠计算时,RAdam 可以退回到带动量的 SGD。

来源:labml_nn/optimizers/radam.py1-294

AdaBelief

AdaBelief 通过跟踪梯度的方差而不是其平方来修改自适应学习率的计算方式。

s_t = β₂ * s_{t-1} + (1 - β₂) * (g_t - m_t)²

这有助于提高稳定性和泛化能力。AdaBelief 还可以通过 rectify 标志整合 RAdam 的修正机制。

来源:labml_nn/optimizers/ada_belief.py1-160

Noam 优化器

Noam 实现了“Attention Is All You Need”论文中的学习率调度,即先预热后衰减。

lr = α * (1/√d_model) * min(1/√t, t/warmup^(3/2))

其中

  • α 是基础学习率
  • d_model 是模型维度
  • warmup 是预热步数
  • t 是当前步数

此调度在 warmup 步数内线性增加学习率,然后按步数平方根的倒数比例降低学习率。

来源:labml_nn/optimizers/noam.py1-89

AdamWarmup

一种更简单的预热实现,在预热期内线性增加学习率。

lr = step*α/warmup  (during warmup)
lr = α              (after warmup)

这有助于稳定训练的早期阶段。

来源:labml_nn/optimizers/adam_warmup.py1-62

使用示例

优化器可以像标准的 PyTorch 优化器一样使用,这在 MNIST 示例中有所展示。

该模块在 mnist_experiment.py 中提供了一个可配置的实验设置,允许以最小的代码更改测试不同的优化器。

来源:labml_nn/optimizers/mnist_experiment.py22-137

参数组和状态

参数组

PyTorch 优化器支持参数组,允许对不同参数集使用不同的超参数。

参数组在迁移学习中特别有用,您可能希望对预训练层和新添加层使用不同的学习率。

来源:labml_nn/optimizers/__init__.py34-56

优化器状态

优化器为每个参数维护状态,包括:

  • 步数(更新次数)
  • 第一阶矩(动量)
  • 第二阶矩(梯度平方或方差)
  • 额外的算法特定状态(例如,AMSGrad 的过去第二阶矩的最大值)

这些状态通过 optimizer.state[parameter] 访问,并在首次使用参数时自动初始化。

来源:labml_nn/optimizers/__init__.py58-61

性能考量

自定义 Adam 实现与 PyTorch 内置 Adam 之间的性能比较显示出相似的性能,在某些情况下自定义实现稍快。

TorchAdam warmup: 222.59ms
TorchAdam: 1,356.01ms
MyAdam warmup: 119.15ms
MyAdam: 1,192.89ms

这表明该模块中的实现不仅具有教育意义,而且对于实际使用也是高效的。

来源:labml_nn/optimizers/performance_test.py1-56