继续浏览精彩内容
慕课网APP
程序员的梦工厂
打开
继续
感谢您的支持,我会继续努力的
赞赏金额会直接到老师账户
将二维码发送给自己后长按识别
微信支付
支付宝支付

扩散模型训练损失的推导详解

炎炎设计
关注TA
已关注
手记 335
粉丝 74
获赞 371

大家都知道,近年来扩散模型在各个领域产生了很大的影响。我想更了解这项技术,于是决定从零开始自己搭建它。

首先,我想了解一下训练损失,但是这比较复杂,而且论文中省略了一些细节,所以我试着自己研究和推导,并在这里总结了。基本上,它遵循论文去噪扩散概率模型中的方法。

首先,扩散模型有两个主要过程:正向过程逆向过程。在正向过程中,噪声会逐步添加到输入数据里。在逆向过程中,通过执行正向过程的逆操作,从含噪声的数据中恢复原始图像。下面的图示能帮助你更直观地理解这些过程。

扩散模型是潜在变量模型的形式,形式如下 ​(x 0​):=∫ ​(x 0:T ​)dx 1:T ​。这种模型的联合分布 ​(x 0:T ​) 也称为反向过程,并定义为从 p(xT ​)=N(xT ​;0,I) 出发的具有学得的高斯转移的马尔可夫链。

正向过程被设定为一个马尔可夫链,该链会根据方差计划 β 1, …, βT 逐渐向数据添加高斯噪声。

生成模型估计的数据概率如下。

在原论文中,积分无法计算,所以公式变换如下,如下所示。

(Note: There seems to be a redundancy in the sentence "如下,如下所示". It should be revised to remove the repetition. The corrected version is as follows.)

在原论文中,积分无法计算,所以公式变换如下。

尽管论文中没有详细说明,但我个人觉得公式转换可能是通过正向过程完成的。这可能是由于正向过程有已知的概率分布,而反向过程的概率分布则可能更复杂,积分计算起来也更难。

训练是通过优化负对数似然的变分界来进行的。下面方程通过詹森不等式得到一个上界,

这个方程还可以这样进一步变换:

在上述方程的变形中,q(xt ∣ xt −1) 进行了如下所示的转换。

在上述方程式的转换过程中,使用了如下所示的关系。

我也用了以下方程转换。

DKL 被称为 KL 散度,它是一种衡量概率分布 P 与另一个概率分布 Q 之间差异的统计度量。

我们现在来简化方程(6)。

前向步骤和LT(请参见上下文以了解“LT”的含义)

我们忽略了一个事实,即可以通过重新参数化来学习前向过程方差βt,而是将它们固定为常数。因此,在我们的实现中,近似后验_q_没有可学习的参数。因此,在训练过程中,_LT_是一个常数,可以忽略。

逆向和L1:T−1注:这里的"L1:T−1"是指特定的技术术语,可以理解为前一个时间点的状态。

现在我们讨论在 ​(xt −1​∣ xt ​)=N(xt −1​:μθ ​(xt ​,t),Σ θ ​(xt ​,t)) 中的选项,对于 1<tT 的情况。首先,我们将Σ θ ​(xt ​,t) 设置为未训练的时间依赖常数 σt 2​ I。实验表明,不论是 σt 2​=βt ​ 还是 σt 2​=βt ​~​,结果都差不多。因此,我们可以将 Lt −1​ 写成如下形式。

(C) 是一个不依赖于 (\theta) 的常数。因此,我们可以看出最直接的 (\mu_\theta) 参数化是将其直接设置为前向过程后验均值 (\mu_t) 的模型。我们可以使用以下方程来展开方程。这种变换被称为 重新参数化技巧

我们可以从上面的方程中得到 _x_0。

通过使用上述的 x 0,我们就可以更新 Lt -1。

我们使用下面的 ( \bar{u_t}(x_t, x_0) ) 表示上述变换。附录 中解释了 ( \bar{u_t}(x_t, x_0) ) 和 ( \bar{\beta_t} ) 的推导。

因为 xt 可以作为模型的输入,所以我们可以决定参数化。

其中 ϵθ 是一个用来从 xt 预测 ϵ 的函数估计器。

我们可以用上面的方程来简化 (_Lt^{-1} - C)。

总之,我们可以训练逆过程的均值函数近似器 μθ 来预测 μt ~,或者通过调整其参数,我们可以训练它来预测 ϵ

数据缩放(归一化),逆向过程解码器和L0

我们假设图像数据由0到255的整数线性缩放至[-1,1]之间。这确保神经网络逆过程从标准正态先验p(x_T)开始处理一致缩放的输入数据。为了获得离散对数似然,我们将逆向过程中的最后一步设置为一个独立的离散解码器,该解码器源自高斯分布N(x_0;μ_θ(x_1,1),σ_1^2 I)。

其中 D 是数据维度,而 i 上标表示提取一个坐标。上述方程计算了每个像素的同时概率。δ 表示裁剪界限,以此来将高斯概率密度限制在每个 _x_0i 的离散值对应的范围内。这样可以确保离散数据在连续框架中得到正确处理,以此确保在连续框架中正确处理离散数据。

简化后的训练目标

通过 Eq(17) 和 Eq(18),我们能更进一步地简化训练目标。

t=1 时,对应的是 L 0,其中离散解码器定义中的积分(公式(18))被近似为高斯概率密度函数乘以区间宽度的方式,忽略 _σ_12 和边缘效应的影响。当 t>1 时,这种情况对应的是公式(17)的未加权版本的形式。具体细节请参见文章,但使用这个简单的公式是因为移除公式(17)中的权重部分后,其准确性会更好。

附录:推导q(xt−1|x_t, x_0)的均值和方差

条件分布 q(xt −1​∣ xt ​,x 0​) 正比于以下两个分布的乘积。

这两个分布的定义如下。

在计算方差时,我们利用高斯分布的乘积法则。在高斯分布的乘积形式下,方差的倒数表示为各个分布倒数的总和。

如果我们把两边交换以得到方差 βt ~

这里,我们使用以下属性。

将其代入方差公式中,进行相应的转换。

接下来,我们考虑均值的推导过程。设每个分布的均值分别为 m 1​ 和 m 2​,方差分别为 σ 12​ 和 σ 22​,均值可以按如下方式得出。

因此,我们可以通过以下方法计算平均数。

在这里面,我们又使用了下面这个属性。

说实话,我真不确定这个推导是否正确。我也不知道为什么 xt −1 可以变成 xt 。这可能与我们首先用两个概率分布来近似 q(xt −1 ∣ xt , x 0) 有关。

参考文献:
打开App,阅读手记
0人推荐
发表评论
随时随地看视频慕课网APP