基于世界模型的深度强化学习

266 阅读4分钟
原文链接: zhuanlan.zhihu.com

起源

"世界模型"被lecun认为是弥补RL不足和通向下一代AI的要点。虽然Model-Based RL不是新概念,但是世界模型的构建方法,以及提升其泛化能力、注意力能力和记忆容量的设计还是颇具挑战性。

DaH和LSTM的发明人在今年五月提出了基于非监督学习训练大规模RNN,用于表征世界模型,并针对强化学习框架设计了智能体架构和简单实现。

本文目的是学习DaH本文的核心要点,记录复现过程。

论文摘要和核心思想

本文探索构建流行的强化学习环境之下的生成神经网络。本文的「世界模型」可以无监督方式进行快速训练,以学习环境的稀疏时空表征。通过使用提取自世界模型的特征作为智能体的输入,训练面向任务的小规模控制器,用简单的策略。甚至可以完全通过由世界模型本身生成的虚幻梦境训练本文的智能体,并把从中学会的策略迁移进真实环境之中。

We explore building generative neural network models of popular reinforcement learning environments. Our world modelcan be trained quickly in an unsupervised manner to learn a compressed spatial and temporal representation of the environment. By using features extracted from the world model as inputs to an agent, we can train a very compact and simple policy that can solve the required task. We can even train our agent entirely inside of its own hallucinated dream generated by its world model, and transfer this policy back into the actual environment.

智能体的VMC架构

本研究中通过将智能体分为大型世界模型和小型控制器模型,来训练能够解决 RL 任务的大型神经网络。首先用无监督的方式训练一个大型神经网络V+M,来学习智能体世界的模型,然后训练小型控制器模型C来使用该世界模型执行任务。小型控制器使得算法聚焦于小搜索空间的信用分配问题,同时无需牺牲大型世界模型的容量和表达能力。通过世界模型来训练智能体,我们发现智能体学会一个高度紧凑的策略来执行任务。

  • Vision Model:采用Variational AutoEncoder,生成抽象、压缩的环境表征
  • Memory Model:采用RNN,可以结合历史信息,生成可预测未来状态的表征。
  • Controller:基于当下V的输出和M的预测,选择好的行动策略。



  • 下面的流程图展示了V、M和C如何与环境进行交互:首先每个时间步t原始的观察输入由V进行处理生成压缩后的z(t)。随后C的输入是z(t)和M的隐状态h(t)。随后C输出动作矢量a(t)影响环境。M以当前时刻的z(t)和a(t)作为输入,预测下一时刻的隐状态h(t+1)。



V模型采用VAE

环境在每一时间步上为智能体提供一个高维输入观测,这一输入通常是视频序列中的一个 2D 图像帧。VAE 模型的任务是学习每个已观测输入帧的抽象压缩表征z。



M模型采用RNN-MDN

让M模型预测未来,预测下一个时刻V产生的z 向量。由于自然中的很多复杂环境是随机的,我们训练 RNN 以输出一个概率密度函数 p(z) 而不是一个确定性预测z。

  • MDN是RNN的Mixed-Density-Network,输出的是预测的z的高斯混合模型。
  • h是hidden-state,用来表征智能体对自身行动所引发环境变化的预测。
  • T是temperature parameter,用来控制模型的不确定性,本文发现调节T对控制器C的训练有用。



C模型

在环境的展开过程中,控制器 (C) 负责决定动作进程以最大化智能体期望的累加奖励。在实验中,尽可能使 C 模型简单而小,并把 V 和 M 分开训练,从而智能体的绝大多数复杂度位于世界模型(V 和 M)之中。

复现过程


  • 依赖要求:目前版本的commit-e686342只能支持:python-3.5.2, gym-0.9.2(虽然说0.9.x都行,但是实测0.9.6就不行),tensorflow-1.8.0,numpy-1.13.3,box2d-2.3.2. 建议采用pip/pip3 install xxx==version_number 安装

参考资料

https://worldmodels.github.io/ 可交互论文