查看原文
其他

硬干货!!如何看懂ChatGPT里的RLHF公式以及相关实现

ShuYini AINLPer 2023-07-10

点击上方AINLPer,设为星标
更多干货,第一时间送达
| 转自知乎中森

引言

最近开源社区里的基于ChatGPT的问答和LLAMA模型微调的羊驼系列非常火爆。「而笔者所看到的大部分低成本复现ChatGPT项目(除了ClossalAI)都只包含了基于人类偏好回复的SFT阶段,而不包括后面的RLHF阶段」。同时网上有几个开源的使用PPO(或类PPO算法)来更新语言模型的代码库,他们的实现略有不同,有将问答建模成基于词级别的马尔科夫决策过程的,也有基于句子级别的建模成bandit problem的方式。也有添加了不同约束项的各种实现。如何看明白他们的涵义和挑选出适合当前任务/实验所需的部分加以改进,是个现实的问题。

所以这篇笔记将会「记载笔者为了入门RLHF看懂他们的公式设计意图的历程」,并整理笔者最近一段时间在学习跟ChatGPT相关的PPO知识时读过的一些直接相关的技术博客,论文等资料,做简单的点评以供未来笔者回忆时查询。

强化学习基础概念

笔者除了在硕士时了解过value-based RL之后三四年没碰过强化学习了,Lilian-Weng的技术博客非常全面,细致地梳理了RL的基本概念,value-based到policy-based的发展脉络。非常适合有类似需要的复习一遍。其中value-based 的方法里重点复习一下td-lambda的思想和推导过程,不仅非常核心且推导过程会在下面的Generalized Advantage Estimation(GAE)里再见到。

参考:https://lilianweng.github.io/posts/2018-02-19-rl-overview/#monte-carlo-methods

策略梯度基础入门

笔者阅读上一部份lilian-weng的 Long Peek into RL里的policy gradients章节时觉得不够清楚。于是在这个部分推荐读一位 UCB的博士(现CMU博士后的robotics大佬)Daniel Seita在17年整理的策略梯度推导。即详尽,又清晰,读起来神清气爽,「读完后对策略梯度推导过程,其训练时会有哪些问题(如多步采样引起的方差过大不稳定),如何解决(引入baseline),为什么这样解决(为什么这样引入baseline既无偏又能降低方差),有个比较宏观的了解」

参考:https://danieltakeshi.github.io/2017/03/28/going-deeper-into-reinforcement-learning-fundamentals-of-policy-gradients/

GAE论文

在读Daniel Seita的策略梯度推导时,其实已经包括了GAE将要涵盖的一些概念,即如果baseline是advantage function我们可以获得一个最低方差的无偏估计。可如何有效地预估和近似这个优势方程,则需要构建一个近似的有偏优势方程(gamma-just-estimator)。以下是笔者总结的GAE几点摘要速览

1:策略梯度的朴素估计方差过大,难以收敛。通过引入baseline function可以很好地减弱这个问题的影响。其中如果baseline function是优势方程advantage function时理论上可以得到最低方差。(即每次策略梯度更新的方向必然是增大那些比当前状态的“平均”动作更优的动作,降低比平均动作更劣的动作)

2:注意在策略梯度的推导里(GAE论文和上面Daniel Seita的推导),我们并没有引入衰减因子discount factor。这样得到的使用优势方程作为基线方程的策略梯度估计虽然无偏且方差小,但是优势方程不好预估。我们转而引入衰减因子,来预估一个无偏的优势方程。这将使我们得到无偏的带衰减因子的有偏策略梯度。但注意我们的目标是预估一个没有引入衰减因子的无偏策略梯度,但引入了衰减因子的策略梯度其本身是有偏的(哪怕对其的预估是无偏的),所以GAE是一个对无偏策略梯度g使用了有偏策略梯度的近似。

3:关于gamma-just-estimator的证明可以看Daniel-seita 对这篇论文的点评笔记。

参考:https://danieltakeshi.github.io/2017/04/02/notes-on-the-generalized-advantage-estimation-paper

4:因为GAE论文里的公式(10)如下图,为了得到优势方程的预估,我们需要预估值方程V(st)。且该值方程需满足定义:当且仅当他们相等时,该预估是gamma-just,且我们得到衰减优势方程的无偏估计。注意这里对定义在无限步长上的值方程的加权配比以及求和用到了TD-lambda一模一样的推理过程,只是这次是用来预估优势方程而不是值方程。5:lambda和gamma都是衰减系数,只是涵义不同。我们知道lambda=0在TD-lambda里代表的是Q-learning的一步近似,而Lambda=1是等价于MC的无偏估计。所以lambda值的调整在GAE的语境下是类似的调整bias-variance的一个权衡值。gamma的情况非常类似,gamma是我们选取的奖励值的一个衰减因子,它的引入不仅影响值方程的尺度大小同时也带来了偏差。但作者提到当gamma<1时,无论值方程的准确度如何都会带来偏差(值方程的准确度实际上也受gamma的很大影响,因为其决定了值方程所涵盖的步数),而lambda<1时仅当值方程不准确时会带来偏差。这是由他们的作用与涵义所决定的,同时作者提到他们实验里lambda的最佳值一般显著低于gamma值。至此我们所希望估计的有偏策略梯度就变成了:

6:Reward shaping 只是以另外一种角度看待上面推出的GAE。因为将奖励重新映射后上面所推的所有公式不变,唯一改变的只是lambda和gamma被统一到一个参数上。

7:作者使用了TRPO的思想来更新策略梯度,但同时也使用了Trust Region的思想来更新值方程。作者认为这样能够更好地防止值方程过度拟合最近批次的数据。

GAE的主要贡献在于提出了一个使用gamma, lambda来权衡偏差和方差的泛化优势预测器GAE,并且使用trust region的算法来优化策略和值方程,最终使得一些更复杂(状态空间和动作空间维度更广)的问题求解成为可能

Natural Policy Gradient + TRPO + PPO

前面几节的内容都在讨论策略梯度的整体更新公式,但具体到如何策略更新时我们会面临很多问题,比如采样效率的问题,比如选择步长的问题(这两个实际是关联问题)。步长的选择之所以是问题是因为参数更新的距离并不等价于我们的策略分布更新的距离。试考虑以下两组高斯分布间的参数距离和分布距离,他们的参数差(均值和方差的均方误差MSE)相等,但明显他们的分布差极大。「而强化学习相比于常见的有监督学习更不稳定,其中一部分原因便是因为策略的更新所导致的观测状态,奖励的分布变化。无法很好地决定训练步长的一个副作用就是样本利用率不高,收敛慢或难以收敛。而natural policy gradient , TRPO 和PPO则是试图解决这个问题的几种思路」

Natural Policy Gradient

其中要看懂natural policy gradient需要几项一些优化理论的前置知识(包括常见的拉格朗日松弛法,Fisher-info-Matrix以及它和KL散度,Hessian矩阵的关联等),如果有不熟悉的读者可以先查阅这几个基础概念的定义,再看natural policy gradient就水到渠成了(btw 扩散模型里score matching相关的内容涉及了大量这方面的优化概念)。

具体来说因为在参数空间上的欧式距离不等价于我们所希望的策略分布上的距离,我们转而使用KL散度来约束我们每次梯度更新的大小。这个约束可以用拉格朗日松弛法来将约束转化为惩罚项。对这个惩罚项里的KL散度求解可以用泰勒展开来近似(因为我们不知道策略分布的具体形式,如果遍历状态和动作空间来求解过于复杂)。其中会用到KL散度的泰勒二阶展开在theta值附近等于Fisher Matrix的性质。而求二阶导我们也不用特地求解Hessian矩阵(计算复杂度在左右, 其中K为单个元素计算复杂度),而是可以利用score即的梯度来求得, 计算复杂度大大下降。最终,我们的计算会变成以下结果:梯度的形式由原先的损失梯度变为使用费舍矩阵的逆乘上原梯度(使得梯度考虑到KL散度的分布限制),再乘上我们的动态步长(约束我们的更新必须小于设定的)。其中以KL散度约束梯度下降并使用FIM求解的方法就叫做自然梯度法,natural gradients method 也是该论文名称的由来。

参考:https://zhuanlan.zhihu.com/p/563212799

参考:https://towardsdatascience.com/natural-policy-gradients-in-reinforcement-learning-explained-2265864cf43c

TRPO

自然策略梯度听起来很美好,但实际上使用有非常多的问题。其中几个问题在于对费舍矩阵(theta的outer product)求逆的计算量,以及该算法其本身的不稳定(采样来近似期望,泰勒展开的近似,局域临近性的假设等)。这些问题导致了策略梯度方法往往在很多问题上效果不佳或者不收敛。而TRPO对这几个问题提出了针对性的改进,使得其能够在大量参数下(以往的方法需要求解计算复杂度极高的Hessian 矩阵)且相对复杂的场景下取得优秀的效果。

在此笔者记录几个笔者认为关键的要点:

1:自然梯度法所求解的策略梯度更新方法有着极多的限制,为了折中计算的复杂度使得计算更加可行,TRPO进行了多次近似。

2:其中原文中非常关键的一个式子如下。注意两点。其一是新旧策略的差别是定义在所有的状态和动作空间上的优势函数求和(参考Kakade的CPI Conservative Policy Iteration,只是CPI里的优势函数是定义在时间步上求和,而TRPO是定义在状态空间上求和)。即如果可以保证下面式3的右侧优势函数每轮更新都单调递增的话,我们就保证了策略更新所带来的期望奖励单调递增。其二 是对新策略的期望奖励的近似是通过旧策略来得到的。其依据是如果新旧策略差别不明显时,可以用旧策略对所有状态的的访问频率来近似估计新策略的期望折扣奖励。这样的近似避免了要计算所有可能的新策略对所有状态所对应的访问频率,所带来的优化难题。3:TRPO是一种minorization-maximization(MM)算法,作者通过拓展了conservative-policy-iteration里提出的下界的形式, 实现策略的目标在迭代中只增不减。其中CPI的优化下界如下图公式6所显示。TRPO通过将alpha替换为Total_Variance的距离测度后,推导证明了新的下界来取代CPI的下界,其中因为total varaicne的测度距离一定小于等于KL测度距离,最终得到了式9。4:以上的理论部分证明的是如果我们按照以下的式子优化,则可以保证我们每步更新的策略的单调提升(即期望奖励值 的提升)。但是如果按照上述所推导的策略更新方式收敛过慢,步幅太小。于是还是回到自然梯度法里的引入硬约束trust region 。于是得到式11。但式11要求基于每个状态的策略分布的最大KL距离小于硬约束,实际太难求解于是转而近似要求所有状态上的策略分布距离的KL距离均值小于约束值得到式12。将式12里的损失函数展开成其定义,即上述式3里提到的最大化每个状态里的优势函数值的提升,得到式13。将式13里的优势函数转化为Q,不影响其单调性。同时将新策略的动作空间上的求和转换为重要性采样的预估值后我们得到式14。也就是我们的最终优化目标。5: 「但是需要注意的是TRPO的理论更新方式和实际实现的更新方式有较大差异」(论文第六章):如果我们严格按照TRPO的公式推导来更新每步的梯度,则更新幅度会特别小收敛速度特别慢(因为惩罚项C过大)。所以TRPO的实际更新方式(采用基于硬约束项)如下:1. 「使用了共轭梯度法」(Conjugate Gradients)来近似求解硬约束项的更新步幅,取代了自然梯度法里用费舍矩阵的逆求解的方法。2.「用线性搜索(Line Search)的方法」,来检验共轭梯度法求解的更新步幅是否符合硬约束项的要求。如果不符合,则指数缩小步幅直到符合约束(自然梯度法本身并不检查这一项,且由于泰勒展开的近似性,局域临近性的要求等,硬约束本身常被违背)。3. 除了检验更新前后的策略分布的KL散度的距离外,该算法还实际检测了更新后我们的代理损失函数(surrogate loss)式14,是否确实提升。「理解TRPO的最终做法对理解PPO的实现方式为何简单有效十分重要,值得仔细阅读论文」

PPO

前面提到TRPO虽然用了详细的推导论证了基于惩罚项的理论更新值可以保证策略回报单调提升(即TRPO式9),但实际实现时依然使用了类似自然梯度法的基于硬约束的方式来决定梯度更新的步长。只是在计算硬约束的情况下,TRPO使用了共轭梯度法,线性搜索和策略检验等方法来检测更新效果。PPO的实现也分两种:

第一种是在式9的基础上动态调整KL约束项的惩罚系数C,来达到约束参数更新的幅度的目的。即参数的更新应尽可能的小以保证训练的稳定,但同时应在分布空间更新得足够的大以使得策略分布发生改变。如下图的算法所示,对于更新前后的KL距离,我们设定一个目标约束值(一个可调整的超参)。如果当前的更新大于1.5倍目标约束值(1.5是个经验设定值,并非推导所得)则策略分布波动过大,我们增大惩罚系数(乘2),若KL距离小于目标约束值的2/3,则策略分布更新过小,我们减小惩罚系数(除以2)。虽然PPO的这种更新方式并没有严格的数学推导,并且时不时地更新幅度会过大或过小,违背了策略单调提升的要求。但从实际部署的效果来看,PPO往往能够快速调整惩罚系数的权重,来快速适应训练的不同阶段的要求,并且效果良好。第二种方法同样对应TRPO里使用trust region 直接设置KL散度的最大更新约束值。但是和使用KL散度的约束值不同的是该方法是对优势函数做了限制。其中当重要性采样的系数大于或小于时,该更新会被忽略(根据裁剪后的损失不依赖于参数所以不产生任何梯度信息)。本质上是忽略了差异过大的新策略所产生的优势函数值,保证了训练的稳定性和梯度更新的单调递增所需的步幅小的要求。注意因为优势函数的值可以取正负,且重要性系数是新策略的可能性除以旧策略的可能性,为了最大化优势函数的期望值我们希望当优势函数A>0时,系数大于一(但不能过大),优势函数A<0时,系数小于一(但不能过小)。理论上因为采用min函数来选取优化目标当A>0时,系数的argmax值取小于的时候,以及A<0时,系数的argmax值取大于的时候均不受裁剪的影响。但这种情况是否会发生,多经常发生,是否影响训练效果笔者尚不清楚

参考:https://zhuanlan.zhihu.com/p/563166533

参考:https://towardsdatascience.com/proximal-policy-optimization-ppo-explained-abed1952457b

策略梯度发展梳理

在阅读完以上几篇策略梯度的技术博客后,再回过头来看lilian-weng在intro to RL之后的policy gradients 脉络梳理就比较容易抓住重点了。其中注意目前开源的大部分RLHF的实现都基于On Policy的Actor Critic算法

InstructGPT & RL4LM

关于这篇经典的RLHF论文,笔者在这里只讨论里面的RLHF公式以及相对应的一些代码库是如何实现的。首先在instructGPT里,我们将期望优化的语言模型用策略来表示(该语言模型的参数将在RLHF阶段更新),而经过SFT的语言模型用来表示(该语言模型不参与RLHF阶段的更新),同时反馈模型RM用表示。这个公式的涵义是我们希望最大化反馈模型的奖励值,但同时希望我们RLHF阶段的语言模型的输出分布,不要距离SFT指令微调后的分布太远(即虽然InstructGPT里将语言模型的回答建模成了bandit problem,但第二项KL散度的约束是基于词粒度的,防止语言模型过度针对奖励值优化)。除此之外,模型还需要做一些预训练任务(下式第三项)以保证其预训练的能力不会过分退化。

注意式2长得很像TRPO里的优化公式9,笔者曾疑惑为何约束更新的策略不是上一步的策略而是一个静态的SFT策略模型。这样是否依然能够满足策略梯度模型性能单调增强的KL临近性需求。但仔细回想上文里TRPO到PPO的演化进程就可以想明白,公式2里的KL散度约束,的确如原文所说,只是为了防止所优化的语言模型过度倾向于奖励模型。而真正策略梯度更新所要求的策略分布的变化值小于约束值是体现在了公式之外的PPO-clipping算法里。值得一提的是,在下面的RL4LM的消融实验里提出了一个很有意思也很符合直觉的观点,即如果SFT模型本身仍不具备RLHF所希望其习得的能力或能力较差,那么SFT的约束反而会有反效果。注意看以下clossal AI里的RLHF-PPO实现 ,里面Policy的损失很明显就是按照PPO的方式一字不差地实现。同时如果对RLHF的全流程不是特别清楚可以参照ClossalAI 的以下流程图。非常清晰地介绍了ActorCritic这个算法是如何被运用在在InstructGPT中的RLHF流程。值得注意的是笔者与研究院的同事在训练RLHF的时候发现要同时实现多个模型(比如actor,critic, reward model, ref model)的加载和切换比较麻烦(如果内存有限的话)。需要一些额外的工程设计(比如triton server 或者模型并行等技术)。同时不同代码库的实现方式和架构不同(笔者原以为部分代码库并没有实现SFT模型的约束,经评论指正后笔者在三个代码库分别找到了对应的位置实现,都放在了Advantage计算时的代码块里)。

参考:

ClossalAI:https://github.com/hpcaitech/ColossalAI

RL4LM:https://github.com/allenai/RL4LMs

TRL: https://github.com/lvwerra/trl/

其实近期有不少文章在探讨RLHF的效率和实现方式(比如Off policy的算法做RLHF等),其中包括如Pieter Abeel或者John Schulman的文章都非常值得一看。笔者最近在基于其中的一些想法做些实验,如果有空也会断断续续总结一下,并结合自己在最近和研究院里的小伙伴训练RLHF的一些心得谈谈看法。

推荐阅读

[1]麻省理工(MIT)的最新研究:重塑你对LLMs的理解!

[2]十分钟部署清华ChatGLM-6B,实测效果还可以!

[3]白泽:一个以中国神兽命名的大型自然语言模型(LLM)

[4]NLP突破界限,2023 十篇必读的顶级NLP论文!

[5]你必须要知道的 “ 十二个国际顶级会议 ”!

[6]2023年!自然语言处理 10 大预训练模型

点击下方链接🔗关注我们

「资料整理不易,点个再看吧」

您可能也对以下帖子感兴趣

文章有问题?点此查看未经处理的缓存