本文概述了 labml_nn 仓库中的强化学习(RL)实现。该代码库包含了关键 RL 算法的实现:近端策略优化(PPO)和深度Q网络(DQN),以及广义优势估计(GAE)和优先经验回放等支持组件。这些实现旨在与 Atari 游戏(特别是 Breakout)配合使用。
本仓库中的 RL 实现围绕两种主要算法构建,并共享游戏环境基础设施。
来源: labml_nn/rl/__init__.py labml_nn/rl/ppo/__init__.py labml_nn/rl/dqn/__init__.py labml_nn/rl/game.py
该仓库为 Atari 游戏提供了一个带有多进程支持的封装器,允许高效并行采样经验。
Game 类封装了 OpenAI Gym 的 Atari 游戏环境,并带有多个预处理步骤:
为了高效采样,该实现使用 Python 的多进程功能并行运行多个游戏环境。
Worker 类创建一个带有管道连接的新进程worker_process 函数在子进程中运行并管理 Game 实例来源: labml_nn/rl/game.py135-168 labml_nn/rl/ppo/experiment.py134-141 labml_nn/rl/dqn/experiment.py91-102
PPO 是一种策略梯度算法,允许每次采样进行多次梯度更新,同时防止策略发生大幅变化。
PPO 的实现包含以下主要组件:
来源: labml_nn/rl/ppo/__init__.py labml_nn/rl/ppo/gae.py labml_nn/rl/ppo/experiment.py
PPO 的核心思想是使用裁剪替代目标来限制每次更新中的策略变化。
此方法将新旧策略之间的比率限制在 [1-ε, 1+ε] 范围内,从而防止不稳定的剧烈策略更新。
来源: labml_nn/rl/ppo/__init__.py140-179
GAE 通过对不同 n 步回报进行加权平均来计算优势,平衡了偏差和方差。
计算时间差 (TD) 误差
递归计算优势
λ=1 时的 GAE 等同于蒙特卡洛回报,而 λ=0 时等同于 TD(0) 回报。
来源: labml_nn/rl/ppo/gae.py docs/rl/ppo/gae.html
实验实现展示了在 Atari Breakout 上训练 PPO 智能体
神经网络模型由处理游戏帧的卷积骨干网以及用于策略(动作概率)和值函数的独立头部组成。
来源: labml_nn/rl/ppo/experiment.py
DQN 是一种基于价值的强化学习算法,它使用神经网络来近似 Q 值。
来源: labml_nn/rl/dqn/__init__.py labml_nn/rl/dqn/model.py labml_nn/rl/dqn/replay_buffer.py labml_nn/rl/dqn/experiment.py
DQN 旨在最小化当前 Q 值估计与目标之间的时间差 (TD) 误差
DQN 使用单独的目标网络(更新频率较低)来稳定学习。
来源: labml_nn/rl/dqn/__init__.py34-165
该实现使用双头网络架构,将状态值和动作优势函数分开。
这种架构允许网络学习哪些状态有价值,而无需为每个状态学习每个动作的效果。
该实现使用优先经验回放,以更频繁地采样重要转换。
p_i = |δ_i| + εP(i) ∝ p_i^αw_i = (1/N * 1/P(i))^β缓冲区实现使用高效的二叉线段树进行优先级求和和基于优先级的采样等操作。
来源: labml_nn/rl/dqn/replay_buffer.py
该实验展示了在 Atari Breakout 上训练 DQN 智能体
来源: labml_nn/rl/dqn/experiment.py
PPO 和 DQN 的实现都共享通用的训练基础设施模式
关键的共享模式包括:
来源: labml_nn/rl/ppo/experiment.py labml_nn/rl/dqn/experiment.py labml_nn/rl/game.py
本仓库中的强化学习实现提供了 PPO 和 DQN 算法清晰、模块化且文档齐全的实现。它们包含关键的增强功能,例如:
这些实现可作为理解强化学习算法的教育参考和进一步研究与应用的构建块。