语言模型解码采样策略

语言模型解码/采样策略

贪心

核心思想: 每一步取当前最可能的结果,作为最终结果

具体方法:获得新生成的词是vocab中各个词的概率,取argmax作为需要生成的词向量索引,继而生成后一个词

beamsearch

核心思想: beam search尝试在广度优先基础上进行进行搜索空间的优化(类似于剪枝)达到减少内存消耗的目的

具体方法:在decoding的每个步骤,我们都保留着 top K 个可能的候选单词,然后到了下一个步骤的时候,我们对这 K 个单词都做下一步 decoding,分别选出 top K,然后对这 K^2 个候选句子再挑选出 top K 个句子。以此类推一直到 decoding 结束为止。当然 Beam Search 本质上也是一个 greedy decoding 的方法,所以我们无法保证自己一定可以得到最好的 decoding 结果

缺点:会生成出空洞、重复、前后矛盾的文本。

随机sampling

我们可以在生成文本的时候引入一些随机性。例如现在语言模型告诉我们下一个单词在整个单词表上的概率分布是 p = (p_1, p_2, … p_|V|),那么我们就可以按照这个概率分布进行随机采样,然后决定下一个单词生成什么。采样相对于greedy方法的好处是,我们生成的文字开始有了一些随机性,不会总是生成很机械的回复了。

随机采样容易产生前后不一致的问题。而在开放闲聊领域,生成文本的长度都比较短,这种问题就被自然的淡化了。

Temperature sampling

采样的时候有一个可以控制的超参数,称为温度(temperature, )T。解码器的输出层后面通常会跟一个softmax函数来将输出概率归一化,通过改变T可以控制概率分布的形貌。softmax的公式如下,当T大的时候,概率分布趋向平均,随机性增大;当T小的时候,概率密度趋向于集中,即强者愈强,随机性降低,会更多地采样出“放之四海而皆准”的词汇。

存在的问题

①生成的话容易不连贯,上下文比较矛盾。 ②容易生成奇怪的话,出现罕见词。

top-k sampling

取概率最大的K个词,之后对这K个词概率归一化之后再进行sampling,但K的大小不太好选,因为不同的句子,概率分布的变化有很大的区别,有的时候比较平,有的时候比较集中,分布均衡时,K小了容易丢失优质的词,分布集中时,K大了容易引入奇怪的词,就和随机采样没什么区别了。

top-p(nucleus) sampling核采样

The Curious Case of Neural Text Degeneration

好处:不需要手动的选取K,作者选取p为0.95 对当前的所有词的概率按照从大到小开始累加,当累加的值大于阈值P的时候,后面小的概率词就不使用,对前面的词再进行sampling,如设置阈值p为0.95,则相当于对左上选用top 4,右上选用top 2

参考:https://zhuanlan.zhihu.com/p/115076102

其实上述各种采样方式在HuggingFace的库里都已经实现了(感动!),我们来看一下代码。

先看top-k和top-p采样

自动选取超参-p&k

目标是通过top k 和 top p来最大化下一个预测最大概率的token为真实token。对于k, 可以直接找到真实token对应的sorted之后的index, 对于p, 可以看真实token对应的累计之后的位置。比如"我喜欢吃热",真实token是“狗”,而模型top 1置信度对应的token是"煎饼",top 1对应的累加概率为60%,往低概率的token继续查找,如果发现”狗“对应的index是3,此时对应的累加概率是85%,这时候就找到了最优的p了。

超参搜索。

重复惩罚

为了解决重复问题,还可以通过惩罚因子将出现过词的概率变小或者强制不使用重复词来解决。惩罚因子来自于同样广为流传的《CTRL: A Conditional Transformer Language Model for Controllable Generation》[2]。

重复词去除

参考资料

解码策略(介绍了几个解码策略,包括代码)

十分钟读懂Beam Search(2/2) 看ICLR2020论文教你如何提升(和解码策略这篇文章类似)

香侬读 | 采样算法哪家强:一个针对主流采样算法的比较(比较了当前主流的几个采样算法Top-K, Nucleus, Tempered,发现他们都满足三个关键性质(1)减熵性;(2)保序性;(3)保斜率性。)

从不同解码策略看机器如何生成文本(以GPT2为例举例说明各种策略)

语言模型采样策略(介绍+代码)

文本生成中的decoding strategy整理(有Class-factored Softmax和Pointer-generator Network)

Last updated

Was this helpful?