菜单

强化学习

相关源文件

本文概述了 labml_nn 仓库中的强化学习(RL)实现。该代码库包含了关键 RL 算法的实现:近端策略优化(PPO)和深度Q网络(DQN),以及广义优势估计(GAE)和优先经验回放等支持组件。这些实现旨在与 Atari 游戏(特别是 Breakout)配合使用。

架构概述

本仓库中的 RL 实现围绕两种主要算法构建,并共享游戏环境基础设施。

来源: labml_nn/rl/__init__.py labml_nn/rl/ppo/__init__.py labml_nn/rl/dqn/__init__.py labml_nn/rl/game.py

游戏环境

该仓库为 Atari 游戏提供了一个带有多进程支持的封装器,允许高效并行采样经验。

游戏封装器

Game 类封装了 OpenAI Gym 的 Atari 游戏环境,并带有多个预处理步骤:

  1. 对四帧应用相同的动作并返回最后一帧
  2. 将观测帧转换为灰度图并缩放到 84×84 像素
  3. 堆叠最后四个动作的四帧
  4. 跟踪回合信息(总奖励,回合长度)
  5. 在每次生命损失后重置环境(对于有多条生命的游戏)

来源: labml_nn/rl/game.py17-133

多进程工作器

为了高效采样,该实现使用 Python 的多进程功能并行运行多个游戏环境。

  1. Worker 类创建一个带有管道连接的新进程
  2. worker_process 函数在子进程中运行并管理 Game 实例
  3. 主进程通过管道连接与工作器通信

来源: labml_nn/rl/game.py135-168 labml_nn/rl/ppo/experiment.py134-141 labml_nn/rl/dqn/experiment.py91-102

近端策略优化 (PPO)

PPO 是一种策略梯度算法,允许每次采样进行多次梯度更新,同时防止策略发生大幅变化。

算法组件

PPO 的实现包含以下主要组件:

来源: labml_nn/rl/ppo/__init__.py labml_nn/rl/ppo/gae.py labml_nn/rl/ppo/experiment.py

裁剪 PPO 损失

PPO 的核心思想是使用裁剪替代目标来限制每次更新中的策略变化。

此方法将新旧策略之间的比率限制在 [1-ε, 1+ε] 范围内,从而防止不稳定的剧烈策略更新。

来源: labml_nn/rl/ppo/__init__.py140-179

广义优势估计 (GAE)

GAE 通过对不同 n 步回报进行加权平均来计算优势,平衡了偏差和方差。

  1. 计算时间差 (TD) 误差

  2. 递归计算优势

λ=1 时的 GAE 等同于蒙特卡洛回报,而 λ=0 时等同于 TD(0) 回报。

来源: labml_nn/rl/ppo/gae.py docs/rl/ppo/gae.html

PPO 实验

实验实现展示了在 Atari Breakout 上训练 PPO 智能体

  1. 创建多个工作进程进行并行环境采样
  2. 使用当前策略采样轨迹
  3. 使用 GAE 计算优势
  4. 使用小批量梯度下降更新策略
  5. 在训练期间跟踪指标

神经网络模型由处理游戏帧的卷积骨干网以及用于策略(动作概率)和值函数的独立头部组成。

来源: labml_nn/rl/ppo/experiment.py

深度Q网络 (DQN)

DQN 是一种基于价值的强化学习算法,它使用神经网络来近似 Q 值。

DQN 关键组件

来源: labml_nn/rl/dqn/__init__.py labml_nn/rl/dqn/model.py labml_nn/rl/dqn/replay_buffer.py labml_nn/rl/dqn/experiment.py

Q 函数损失

DQN 旨在最小化当前 Q 值估计与目标之间的时间差 (TD) 误差

DQN 使用单独的目标网络(更新频率较低)来稳定学习。

来源: labml_nn/rl/dqn/__init__.py34-165

双头网络架构

该实现使用双头网络架构,将状态值和动作优势函数分开。

这种架构允许网络学习哪些状态有价值,而无需为每个状态学习每个动作的效果。

来源: labml_nn/rl/dqn/model.py

优先经验回放

该实现使用优先经验回放,以更频繁地采样重要转换。

  1. 转换以基于其 TD 误差的优先级存储: p_i = |δ_i| + ε
  2. 采样概率与优先级的 α 次方成正比: P(i) ∝ p_i^α
  3. 重要性采样权重校正更新中的偏差: w_i = (1/N * 1/P(i))^β
  4. 二叉线段树可实现高效的求和和最小操作

缓冲区实现使用高效的二叉线段树进行优先级求和和基于优先级的采样等操作。

来源: labml_nn/rl/dqn/replay_buffer.py

DQN 实验

该实验展示了在 Atari Breakout 上训练 DQN 智能体

  1. 创建工作进程进行环境交互
  2. 使用 ε-贪婪探索收集经验
  3. 将转换存储在优先回放缓冲区中
  4. 采样批次并更新网络
  5. 定期更新目标网络
  6. 调度探索和优先回放参数

来源: labml_nn/rl/dqn/experiment.py

训练基础设施

PPO 和 DQN 的实现都共享通用的训练基础设施模式

关键的共享模式包括:

  • 用于并行采样的工作进程
  • 通用观测预处理
  • 指标跟踪和日志记录
  • 动态超参数(学习率、探索率等)
  • 结构化的训练循环

来源: labml_nn/rl/ppo/experiment.py labml_nn/rl/dqn/experiment.py labml_nn/rl/game.py

总结

本仓库中的强化学习实现提供了 PPO 和 DQN 算法清晰、模块化且文档齐全的实现。它们包含关键的增强功能,例如:

  1. PPO 的广义优势估计
  2. DQN 的双头网络架构
  3. DQN 的优先经验回放
  4. DQN 中用于稳定性的双 Q 学习
  5. 高效的并行环境采样

这些实现可作为理解强化学习算法的教育参考和进一步研究与应用的构建块。

来源: labml_nn/rl/__init__.py