DMD (Distribution Matching Distillation)

笔者的数理基础不是很好,还在研磨统计基础中,有解释不到位的欢迎指出

主要贡献

  1. 提出一个新的Distribution matching loss,蒸馏模型的图像质量几乎没有损失。
  2. 能快速加速图像生成。

问题描述

有一个老师diffusion模型$\mu_{base}$,我们要把它蒸馏成一个一步diffusion模型 $G_{\theta}$,同时保证其图像质量。

Loss1:DMD Loss

文章引入了两个loss作为最终loss。

首先是作为大头的DMD loss。可以类比GAN,因为我们同样也有一个生成器,我们对

  1. $G_{\theta}$生成的图片记为$x_{fake}$
  2. 原老师生成器生成的图片记为$x_{real}$

DMD loss可以被翻译为分布匹配损失,可以简单地用KL Loss来估计,那么我们定义DMD loss如下。

我们对其进行简单的求导,得

这个loss梯度是比较难算的,作者提到了两个原因:

  1. 因为对于t很小的$\mu_{base}(x_t, t)$的分布(蓝色),是无法给DMD loss有效的信号。因此对于$p_{fake}(x)$ 很高的x在p_{real}(x)就会很小或者几乎没有。
  2. Diffusion本身只能返回 $\mu_{base}(x_t, t)$就是加过噪声的分布(多元高斯分布)而非初始数据分布,所以不能直接对分布进行求导。

解决这两个问题的方法借鉴了Song et al. 的Score-SDE,详细推导可以看这篇。

  1. 首先,我们对原始图片加噪,$\mu_{\mathrm{base}}(x_t, t)$是提供的,所以可以直接用;但是$\mu_{\mathrm{fake}}^\phi(x_t, t)$ 不行,我们要对他进行估计,同时要时刻对齐$G_{\theta}$输出的,因此我们要让$\mu_{\mathrm{fake}}^\phi(x_t, t)$从$G_{\theta}$中来, 直觉上作者用了用一个简单的denoising loss($\mathcal{L}_{\mathrm{denoise}}^\phi
    = \left| \mu_{\mathrm{fake}}^\phi(x_t, t) - x_0 \right|_2^2$ 让$\mu_{\mathrm{fake}}^\phi(x_t, t)$时刻学习到$G_{\theta}$的变化。
  2. score可以被估计为:

image.png

如上图所示,未加噪的分布很有可能在$\nabla \log p_{\text{fake}}$ 很高的时候,$\nabla \log p_{\text{real}}$趋近于0。但是在t步加噪之后的分布中,$\nabla \log p_{\text{fake}}$ $\nabla \log p_{\text{real}}$都是一个较为均匀平滑的值。

附:DMD loss的具体实现并不是简单的$\nabla_\theta D_{KL}
=
\mathbb{E}_{z \sim \mathcal{N}(0,I),\; x = G_\theta(z)}
\left[

  • \big( s_{\text{real}}(x) - s_{\text{fake}}(x) \big)
    \frac{dG}{d\theta}
    \right]$项,而带了不少的正则项,为了稳定在不同噪声大小的gradient大小,作者引入了一些权重系数的公式,详细可以读一下下面这一串解释,这篇文章整体都写的还是比较清晰的。

image-1.png

Loss2: Regression Loss

  1. 单单使用DMD loss会有mode collapse的问题,也就是fake distribution会只学到某些real distribution。
  2. 而且对于t靠近0的情况,$\nabla \log p_{\text{real}}$趋近于0。

作者通过引入一个regression loss,即$\mathcal{L}{\text{reg}} = \mathbb{E}{(z,y)\sim D}\,\ell\big(G_{\theta}(z),\, y\big)$,作为一个正则项帮助$\mu_{fake}(x_t, t)$逐步靠近所有$\mu_{base}(x_t, t)$的分布,在训练前,$(z, y)$是作为样本对预先构造的,因此需要先跑一遍完整的teacher模型的ODE solver。

下面列举了Loss不同构造方式时候,学习到的fake distribution的分布示意图。

image-2.png

image-3.png

实验中采样效果对比图。一个去掉了DMD loss,显然质量下架严重;另一个去掉了Reg loss,显然多样性不足。

算法实现

image-4.png

image-5.png

算法

  1. 准备材料:一个pretrained $\mu_{real}$,和一些噪声,预先用teacher跑的干净生成的图像对,记为$D$
  2. 先用teacher model的权重load到单步生成器$G_{\theta}$和多步生成器$\mu_{fake}$中。
  3. 采样batch噪声,然后喂给$G_{\theta}$生成。
  4. 算loss:$t \sim [T_{min}, T_{max}]$,算出$\mu_{real}(x_t, t)$和$\mu_{fake}(x_t, t)$,然后算DMD loss,再在用LPIPS Loss算下x_ref和y_ref的差距;用这两个loss更新$G_{\theta}$
  5. 之后更新$\mu_{fake}$,来动态的学习一个单步生成器的假分布。

杂谈

  1. 训练generator和$\mu_{fake}$像一个boostrap的流程,我在读这算法的时候愣了一下,因为一开始因为初始化$\mu_{base}$和$\mu_{fake}$的weight不是一样的吗,DMD loss不是0吗,之后才发现fake diffusion也是要逐步和generator对齐的,会逐渐学习到一个动态的fake的分布,和$\mu_{base}$也会逐渐不同,也就逐渐能给出有效的gradient了。

有一个点不太懂:

  1. 为什么reg loss能作为一个正则项的,就消融实验效果能看的出来但是统计直觉上不太理解希望佬们提点一下。