Image to Image Translation
diffusion models
2024-01-25 · 30m

Over winter break, a friend and I implemented Stable Diffusion from scratch:

image.azliu.cc

There were many interesting and unexpected aspects of the project that made it a lot of fun. For example, we did not realize how difficult it would be to deploy a model that ended up being roughly 4GB. In the end, we are happy with the way it turned out, and luckily we don't have to pay too much money to keep this running forever!

One of the reasons we took on this project is that we wanted a better understanding of the math behind generative models. Below is a summary of some of the things that we learned. All of the math is taken directly from this page of our website.

1. DDPM

We begin with DDPMs (Ho et al. 2020). These are the foundation of most modern diffusion models. In particular, InstructPix2Pix is a fine-tuned version of Stable Diffusion, which works out of the box with DDPM.

1.1. Foundation

In these diffusion models, we noise images sampled from a distribution $q(x)$ through the forward process, defined via

$$q(x_{1:T}|x_0) = \prod_{t=1}^T q(x_t|x_{t-1}), \qquad q(x_t|x_{t-1}) = \mathcal{N}(x_t; \sqrt{1-\beta_t}x_{t-1}, \beta_t I),$$

where $\beta_t$ follows some noise schedule. We'll typically see $\beta_1 < \beta_2 < \cdots < \beta_T$ to ensure that the final image is pure noise.

Let $\alpha_t = 1-\beta_t$ and $\overline{\alpha}_t = \prod_{i=1}^t \alpha_i$. It can be shown that

$$q(x_t|x_0) = \mathcal{N}(x_t; \sqrt{\overline{\alpha}_t}x_0, (1-\overline{\alpha}_t)I),$$

which not only makes sampling the forward process very efficient, but also shows that we can think of the forward process as some linear combination of the original image and pure noise.
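
As a concrete illustration, here is a minimal sketch of one-shot forward sampling (we use PyTorch; the `alpha_bar` tensor of precomputed cumulative products is our own naming, not from any particular library):

```python
import torch

def forward_sample(x0, t, alpha_bar):
    """Sample x_t ~ q(x_t | x_0) in closed form: a linear combination of image and noise.

    x0:        batch of images, shape (B, C, H, W)
    t:         integer timesteps, shape (B,)
    alpha_bar: cumulative products of alpha, shape (T,)
    """
    eps = torch.randn_like(x0)               # pure noise
    ab = alpha_bar[t].view(-1, 1, 1, 1)      # broadcast over image dimensions
    xt = torch.sqrt(ab) * x0 + torch.sqrt(1 - ab) * eps
    return xt, eps                           # eps doubles as the training target
```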

In the backward process, the goal is to produce a model $p_{\theta}$ that approximates $q$, where we define

$$p_{\theta}(x_{0:T}) = p(x_T)\prod_{t=1}^T p_{\theta}(x_{t-1}|x_t), \qquad p_{\theta}(x_{t-1}|x_t) = \mathcal{N}(x_{t-1}; \mu_{\theta}(x_t,t), \Sigma_{\theta}(x_t,t)).$$

Given perfect $p_{\theta}$, we can perfectly recreate the original data distribution from pure noise, which is the magic of diffusion. It turns out that $q(x_{t-1}|x_t,x_0) = \mathcal{N}(x_{t-1};\tilde{\mu}(x_t,x_0), \tilde{\beta}_t I)$ is tractable, and it can be shown that

$$\tilde{\mu}(x_t,x_0) = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{1-\alpha_t}{\sqrt{1-\overline{\alpha}_t}}\varepsilon_t\right),$$

where $\varepsilon_t$ is the noise added to produce $x_t$ from $x_0$. This quantity is relevant because it can be shown that optimizing the variational lower bound is equivalent to minimizing

$$L = \mathbb{E}_{x_0, \varepsilon, t}\left[\frac{1}{2\lVert \Sigma_{\theta}(x_t,t)\rVert_2^2}\lVert \tilde{\mu}_t(x_t,x_0) - \mu_{\theta}(x_t,t)\rVert^2\right],$$

where $\mu_{\theta}(x_t,t)$ and $\Sigma_{\theta}(x_t,t)$ are the backwards mean and variance predicted by our model. Substituting our expression for $\tilde{\mu}(x_t, x_0)$, this is equivalent to

$$L = \mathbb{E}_{x_0,\varepsilon, t}\left[\frac{(1-\alpha_t)^2}{2\alpha_t(1-\overline{\alpha}_t)\lVert\Sigma_{\theta}(x_t,t)\rVert_2^2}\lVert \varepsilon_t - \varepsilon_{\theta}(x_t,t)\rVert^2\right].$$

In their paper, Ho et al. (2020) showed that ignoring the weighting constant in front can produce better results for training, so they eventually reduced the loss function to

$$L = \mathbb{E}_{x_0, \varepsilon, t}\left[\lVert \varepsilon_t - \varepsilon_{\theta}(x_t,t)\rVert^2\right].$$

The full derivation of all the facts leading up to this point can be found here. The importance of this fact is that training DDPM models boils down to sampling random images at random timesteps, and having a model (typically a UNet) learn to predict the noise added at timestep $t$.
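
In code, the simplified objective is just an MSE between the sampled and predicted noise. Here is a rough sketch of one training step (the `unet(x, t)` signature is an assumption about how the model is wrapped; timesteps are 0-indexed here):

```python
import torch
import torch.nn.functional as F

def train_step(unet, optimizer, x0, alpha_bar, T):
    """One DDPM training step on a batch x0, using the simplified (unweighted) loss."""
    t = torch.randint(0, T, (x0.shape[0],), device=x0.device)   # random timesteps
    eps = torch.randn_like(x0)                                  # noise to be predicted
    ab = alpha_bar[t].view(-1, 1, 1, 1)
    xt = torch.sqrt(ab) * x0 + torch.sqrt(1 - ab) * eps         # one-shot forward sample
    loss = F.mse_loss(unet(xt, t), eps)                         # || eps - eps_theta ||^2
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```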

1.2. Clarifying $x_t$

The input of our UNet is an image $x_t$, generated from $x_0$ using the distribution $q(x_t|x_0)$. The output of our UNet is the noise used to generate $x_t$, $\varepsilon_{\theta}(x_t,t)$.

This noise is sampled from $\mathcal{N}(0, I)$, per the reparameterization trick. It represents the normalized quantity of noise added from $x_0$ to $x_t$, not from $x_{t-1}$ to $x_t$. In other words, this noise corresponds to the forwards distribution $q(x_t|x_0)$, and not $q(x_t|x_{t-1})$.

Even though we are predicting noise from $x_0$ to $x_t$, this noise is used to generate the distribution $p(x_{t-1}|x_t)$. In other words, the distribution for predicting one timestep backwards from $t$ to $t-1$ is a function of the noise added from $0$ to $t$. This is because the original (tractable) distribution that we are trying to learn is $q(x_{t-1}|x_t,x_0)$, which also has access to information about $x_0$.

A concrete demonstration of this fact is the way that we simulate the backwards process once we have trained our UNet. We must make $T$ calls to the UNet to go from pure noise to the original data distribution (more accurately, our approximation of the distribution). In each of these $T$ calls, although the UNet estimates noise from timestep $0$ to the current timestep, we cannot jump directly to the beginning, because we do not know $p(x_0|x_t)$ as a function of this noise. Thus, although we have a shortcut for the forwards process, there is no equivalent shortcut for the backwards process, and so inference is quite expensive.
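
For concreteness, here is a sketch of that loop under the common choice $\Sigma_\theta = \beta_t I$ (again assuming a `unet(x, t)` wrapper; `alpha`, `alpha_bar`, and `beta` are the precomputed schedule tensors):

```python
import torch

@torch.no_grad()
def ddpm_sample(unet, shape, alpha, alpha_bar, beta, T, device="cpu"):
    """Simulate the backwards process: T sequential UNet calls, starting from pure noise."""
    x = torch.randn(shape, device=device)  # x_T ~ N(0, I)
    for t in reversed(range(T)):
        tt = torch.full((shape[0],), t, device=device, dtype=torch.long)
        eps = unet(x, tt)                  # estimated noise from timestep 0 to t
        mean = (x - (1 - alpha[t]) / torch.sqrt(1 - alpha_bar[t]) * eps) / torch.sqrt(alpha[t])
        z = torch.randn_like(x) if t > 0 else torch.zeros_like(x)
        x = mean + torch.sqrt(beta[t]) * z # one step backwards; no shortcut available
    return x
```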

2. Multivariate Gaussians

Before continuing, we first derive some useful identities for multivariate Gaussians that will be used later. Many of these identities and derivations can be found in Pattern Recognition and Machine Learning (Bishop 2006). They are recreated here for completeness and for our own learning.

Suppose we represent a Gaussian $p(x)$ as some jointly defined distribution $p(x_a,x_b)$, where $x_a$ and $x_b$ arbitrarily partition the dimensions in $x$. $x_a$ and $x_b$ are distributions in their own right, with means and covariances; write

$$x = \begin{pmatrix}x_a \\ x_b\end{pmatrix}, \qquad \mu = \begin{pmatrix}\mu_a \\ \mu_b\end{pmatrix}, \qquad \Sigma = \begin{pmatrix}\Sigma_{aa} & \Sigma_{ab} \\ \Sigma_{ba} & \Sigma_{bb}\end{pmatrix}.$$

Also, let

$$\Lambda \equiv \Sigma^{-1} = \begin{pmatrix}\Lambda_{aa} & \Lambda_{ab} \\ \Lambda_{ba} & \Lambda_{bb}\end{pmatrix}$$

be the precision matrix corresponding to $x$.

2.1. Computing marginal gaussian from joint gaussian

The first question we will focus on is how to compute the marginal

$$p(x_a) = \int p(x_a,x_b)\,\mathrm{d}x_b.$$

The purpose of decomposing everything into $x_a$ and $x_b$ components is that we may now write

$$\begin{align*} -\frac{1}{2}(x-\mu)^T\Sigma^{-1}(x-\mu) &= -\frac{1}{2}(x_a-\mu_a)^T\Lambda_{aa}(x_a-\mu_a)-\frac{1}{2}(x_a-\mu_a)^T\Lambda_{ab}(x_b-\mu_b) \\ &\quad -\frac{1}{2}(x_b-\mu_b)^T\Lambda_{ba}(x_a-\mu_a)-\frac{1}{2}(x_b-\mu_b)^T\Lambda_{bb}(x_b-\mu_b). \end{align*}$$

Our strategy will be to isolate the integrating variable $x_b$. Collecting all terms in the above that contain $x_b$:

$$-\frac{1}{2}x_b^T\Lambda_{bb}x_b + x_b^T(\Lambda_{bb}\mu_b - \Lambda_{ba}x_a + \Lambda_{ba}\mu_a),$$

where we use the fact that $\Lambda_{ab} = \Lambda_{ba}^T$ to combine like terms. After completing the square, our integral becomes

$$\begin{align*} f(x_a) \cdot \int & \exp\biggl\{-\frac{1}{2}(x_b-\Lambda_{bb}^{-1}(\Lambda_{bb}\mu_b - \Lambda_{ba}x_a + \Lambda_{ba}\mu_a))^T\\ &\quad \cdot \Lambda_{bb}(x_b-\Lambda_{bb}^{-1}(\Lambda_{bb}\mu_b - \Lambda_{ba}x_a + \Lambda_{ba}\mu_a))\biggr\}\,\mathrm{d}x_b, \end{align*}$$

where $f(x_a)$ is some function of $x_a$ independent of $x_b$. There are two key observations here:

  • $f(x_a)$ is the exponential of a quadratic form in $x_a$
  • the integrand is just an unnormalized Gaussian, so it will integrate to the inverse normalization factor. In this case, this normalization factor is only a function of $\det\Lambda_{bb}$, which is not a function of $x_a$

Together, these two observations imply that $p(x_a)$ is itself Gaussian, and thus we can ignore the constant that we get from the integral. Instead, given that the distribution is Gaussian, we can read off $\mu_a$ and $\Sigma_a$ by comparing coefficients with the general Gaussian expansion

$$-\frac{1}{2}(x-\mu)^T\Sigma^{-1}(x-\mu) = -\frac{1}{2}x^T\Sigma^{-1}x + x^T\Sigma^{-1}\mu + \text{const}.$$

Alternatively, we could manually expand the integral and compute everything, but comparing coefficients is much easier.

The full expression for $f(x_a)$ has terms from our original expansion, and the leftover terms from completing the square on $x_b$. The terms coming from our original expansion are

$$-\frac{1}{2}x_a^T\Lambda_{aa}x_a + x_a^T(\Lambda_{aa}\mu_a + \Lambda_{ab}\mu_b),$$

while the terms leftover from completing the square are

$$\begin{align*} &\frac{1}{2}(\Lambda_{bb}\mu_b - \Lambda_{ba}x_a + \Lambda_{ba}\mu_a)^T\Lambda_{bb}^{-1}(\Lambda_{bb}\mu_b - \Lambda_{ba}x_a + \Lambda_{ba}\mu_a)\\ &= \frac{1}{2}\left(-\mu_b^T\Lambda_{ba}x_a - x_a^T\Lambda_{ba}^T\mu_b + x_a^T\Lambda_{ba}^T\Lambda_{bb}^{-1}\Lambda_{ba}x_a - x_a^T\Lambda_{ba}^T\Lambda_{bb}^{-1}\Lambda_{ba}\mu_a - \mu_a^T\Lambda_{ba}^T\Lambda_{bb}^{-1}\Lambda_{ba}x_a\right) + \text{const.} \\ &= \frac{1}{2}x_a^T\Lambda_{ab}\Lambda_{bb}^{-1}\Lambda_{ba}x_a + x_a^T(-\Lambda_{ab}\mu_b-\Lambda_{ab}\Lambda_{bb}^{-1}\Lambda_{ba}\mu_a) + \text{const.}, \end{align*}$$

where we combine like terms using $\Lambda_{ab}=\Lambda_{ba}^T$ and the fact that all individual terms are scalars, i.e., we can transpose terms freely. Combining all terms together now gives

$$-\frac{1}{2}x_a^T(\Lambda_{aa}-\Lambda_{ab}\Lambda_{bb}^{-1}\Lambda_{ba})x_a + x_a^T(\Lambda_{aa} - \Lambda_{ab}\Lambda_{bb}^{-1}\Lambda_{ba})\mu_a + \text{const.},$$

so comparing coefficients with the general gaussian expansion gives

$$\text{Cov}[p(x_a)] = (\Lambda_{aa} - \Lambda_{ab}\Lambda_{bb}^{-1}\Lambda_{ba})^{-1},$$

and thus

$$\mathbb{E}[p(x_a)] = \text{Cov}[p(x_a)]\,(\Lambda_{aa} - \Lambda_{ab}\Lambda_{bb}^{-1}\Lambda_{ba})\mu_a = \mu_a.$$

Lastly, we note that $\text{Cov}[p(x_a)]$ is the inverse of the Schur complement of $\Lambda_{bb}$ in $\Lambda$, so the result is more cleanly written $\text{Cov}[p(x_a)] = \Sigma_{aa}$.

Note how intuitively nice this result is. It’s exactly what we expect, even if it takes some work to prove it rigorously.
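
As a quick numerical sanity check of this identity (a sketch; the dimensions and partition are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(0)
M = rng.standard_normal((5, 5))
Sigma = M @ M.T + 5 * np.eye(5)            # a random SPD covariance
Lam = np.linalg.inv(Sigma)                 # the corresponding precision matrix

a, b = slice(0, 2), slice(2, 5)            # partition into x_a (2 dims) and x_b (3 dims)
schur = Lam[a, a] - Lam[a, b] @ np.linalg.inv(Lam[b, b]) @ Lam[b, a]
assert np.allclose(np.linalg.inv(schur), Sigma[a, a])  # Cov[p(x_a)] == Sigma_aa
```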

2.2. Computing marginal gaussian given the other marginal and a conditional

The next question we will focus on is computing the marginal $p(y)$ given

$$p(x) = \mathcal{N}(x; \mu, \Lambda^{-1}), \qquad p(y|x) = \mathcal{N}(y; Ax+b, L^{-1}).$$

This form seems a little contrived but becomes useful when we discuss DDIMs.

The first observation is that the joint $p(x,y)$ is Gaussian, since $\mathbb{P}[x=x_0,y=y_0] = \mathbb{P}[x=x_0]\cdot \mathbb{P}[y=y_0|x=x_0]$ for all $(x_0,y_0)$. Note that it is not true in general that the joint of two Gaussian random variables is Gaussian; it works out here because the pdf of the joint distribution at every point is equal to the product of the pdfs of $p(x)$ and $p(y|x)$, and a product of Gaussian pdfs is the exponential of a quadratic, hence Gaussian.

Thus, if we can find the mean and covariance of $p(x,y)$, we can use our results from 2.1 to obtain the other marginal.

We have

$$\begin{align*} \ln p(x,y) &= -\frac{1}{2}(x-\mu)^T\Lambda (x-\mu) - \frac{1}{2}(y-Ax-b)^TL(y-Ax-b) + \text{const} \\ &= -\frac{1}{2}\left(x^T\Lambda x + x^TA^TLAx + y^TLy - y^TLAx - x^TA^TLy \right. \\ &\qquad\qquad \left. - 2x^T\Lambda \mu + 2x^TA^TLb - 2y^TLb\right) + \text{const} \\ &= -\frac{1}{2}\begin{pmatrix}x \\ y\end{pmatrix}^T\begin{pmatrix}\Lambda + A^TLA & -A^TL \\ -LA & L\end{pmatrix}\begin{pmatrix}x \\ y\end{pmatrix} + \begin{pmatrix}x \\ y\end{pmatrix}^T\begin{pmatrix}\Lambda\mu - A^TLb \\ Lb\end{pmatrix} + \text{const}. \end{align*}$$

Comparing coefficients with the general gaussian expansion, we have

$$\text{Cov}(x,y) = \begin{pmatrix}\Lambda + A^TLA & -A^TL \\ -LA & L\end{pmatrix}^{-1} = \begin{pmatrix}\Lambda^{-1} & \Lambda^{-1}A^T \\ A\Lambda^{-1} & L^{-1} + A\Lambda^{-1}A^T\end{pmatrix},$$

and thus

$$\mathbb{E}(x,y) = \begin{pmatrix}\Lambda^{-1} & \Lambda^{-1}A^T \\ A\Lambda^{-1} & L^{-1} + A\Lambda^{-1}A^T\end{pmatrix}\begin{pmatrix}\Lambda\mu -A^TLb \\ Lb\end{pmatrix} = \begin{pmatrix}\mu \\ A\mu + b\end{pmatrix}.$$

Finally, using 2.1,

$$\mathbb{E}(y) = A\mu + b, \qquad \text{Cov}(y) = L^{-1} + A\Lambda^{-1}A^T.$$

With some more work, we could also extract the other conditional, but we won’t need this result.

Like our result in 2.1, this final expression is nice because it aligns with what we expect. We have $y|x$ sampled from a distribution with mean $f(x)$, where $f$ is linear; therefore, the fact that $\mathbb{E}[y] = f(\mathbb{E}[x])$ makes intuitive sense. Further, since we have $f(x) = Ax+b$, we expect $\text{Cov}(f(x)) = A\,\text{Cov}(x)A^T$. The only “dependence” that $y$ has on $x$ is through their means; $\text{Cov}(y|x) = L^{-1}$ is a source of noise that is essentially independent from the noise associated with $x$, so it makes intuitive sense that the two contributions add: $\text{Cov}(y) = A\,\text{Cov}(x)A^T + L^{-1}$.
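
The identities in this subsection are also easy to check by simulation; a sketch (Monte Carlo, so the tolerances are loose):

```python
import numpy as np

rng = np.random.default_rng(0)
d, n = 3, 1_000_000
mu = rng.standard_normal(d)
Lam_inv = np.diag(rng.uniform(0.5, 2.0, d))   # Cov(x) = Lambda^{-1}
L_inv = np.diag(rng.uniform(0.5, 2.0, d))     # Cov(y|x) = L^{-1}
A = rng.standard_normal((d, d))
b = rng.standard_normal(d)

x = rng.multivariate_normal(mu, Lam_inv, size=n)
y = x @ A.T + b + rng.multivariate_normal(np.zeros(d), L_inv, size=n)

assert np.allclose(y.mean(axis=0), A @ mu + b, atol=1e-2)              # E(y) = A mu + b
assert np.allclose(np.cov(y.T), L_inv + A @ Lam_inv @ A.T, atol=2e-2)  # Cov(y)
```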

2.3. KL Divergence

The last thing we will examine is how to compute the KL divergence, or relative entropy, between two multivariate gaussians. The general expression is given by

$$D_{KL}(P \,\Vert\, Q) = \mathbb{E}_{x\sim p(x)}\left[\log \frac{p(x)}{q(x)}\right].$$

Applying this to multivariate $P=\mathcal{N}(\mu_1, \Sigma_1)$ and $Q=\mathcal{N}(\mu_2, \Sigma_2)$, we have

$$\begin{align*} &\mathbb{E}_{x\sim p(x)}\left[\log \frac{p(x)}{q(x)}\right] \\ &\quad =\frac{1}{2}\mathbb{E}_{x\sim p(x)}\left[\log \frac{\vert \Sigma_2\vert}{\vert \Sigma_1\vert} - (x-\mu_1)^T\Sigma_1^{-1}(x-\mu_1) + (x-\mu_2)^T\Sigma_2^{-1}(x-\mu_2)\right]. \end{align*}$$

To simplify this, we can apply the trace trick; since a quadratic form $x^TAx$ is a scalar, it is equal to its trace, and since the trace is invariant under cyclic permutations, it is thus also equal to $\text{Tr}(Axx^T)$ and $\text{Tr}(xx^TA)$. So, we can simplify our expression as follows:

$$\begin{align*} &\frac{1}{2}\log\frac{\vert \Sigma_2\vert}{\vert \Sigma_1\vert} + \frac{1}{2}\mathbb{E}_{x\sim p(x)}\left[-\text{Tr}(\Sigma_1^{-1}(x-\mu_1)(x-\mu_1)^T) + \text{Tr}(\Sigma_2^{-1}(x-\mu_2)(x-\mu_2)^T)\right] \\ &= \frac{1}{2}\log\frac{\vert \Sigma_2\vert}{\vert \Sigma_1\vert} - \frac{1}{2}\text{Tr}(\Sigma_1^{-1}\Sigma_1) + \frac{1}{2}\text{Tr}\left(\Sigma_2^{-1}\,\mathbb{E}_{x\sim p(x)}\left[xx^T - 2x\mu_2^T + \mu_2\mu_2^T\right]\right) \\ &= \frac{1}{2}\log\frac{\vert \Sigma_2\vert}{\vert \Sigma_1\vert} - \frac{1}{2}d + \frac{1}{2}\text{Tr}\left(\Sigma_2^{-1}(\Sigma_1 + \mu_1\mu_1^T - 2\mu_1\mu_2^T + \mu_2\mu_2^T)\right) \\ &= \frac{1}{2}\log\frac{\vert \Sigma_2\vert}{\vert \Sigma_1\vert} - \frac{1}{2}d + \frac{1}{2}\text{Tr}(\Sigma_2^{-1}\Sigma_1) + \frac{1}{2}(\mu_1-\mu_2)^T\Sigma_2^{-1}(\mu_1-\mu_2), \end{align*}$$

where $d$ is the dimension and we use $\mathbb{E}[xx^T] = \Sigma_1 + \mu_1\mu_1^T$.
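
This closed form is short enough to transcribe directly; a sketch in NumPy, with a trivial sanity check:

```python
import numpy as np

def gaussian_kl(mu1, Sigma1, mu2, Sigma2):
    """KL(N(mu1, Sigma1) || N(mu2, Sigma2)), following the closed form above."""
    d = mu1.shape[0]
    Sigma2_inv = np.linalg.inv(Sigma2)
    diff = mu1 - mu2
    _, logdet1 = np.linalg.slogdet(Sigma1)   # log-determinants, computed stably
    _, logdet2 = np.linalg.slogdet(Sigma2)
    return 0.5 * (logdet2 - logdet1 - d
                  + np.trace(Sigma2_inv @ Sigma1)
                  + diff @ Sigma2_inv @ diff)

# KL between identical Gaussians is zero.
assert abs(gaussian_kl(np.zeros(3), np.eye(3), np.zeros(3), np.eye(3))) < 1e-12
```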

3. DDIM

Next, we turn to DDIMs (Song et al. 2020), since this variation on DDPMs is an important component of the Stable Diffusion models. A key motivation for these models is the fact referenced above that inference, i.e., simulating the backwards diffusion process, is quite expensive.

This paper introduces a family of inference distributions $\mathcal{Q}$, parameterized by $\sigma$, fixing the variance added during denoising:

$$q_{\sigma}(x_{t-1}|x_t,x_0) = \mathcal{N}\left(\sqrt{\overline{\alpha}_{t-1}}x_0 + \sqrt{1-\overline{\alpha}_{t-1}-\sigma_t^2}\cdot \frac{x_t - \sqrt{\overline{\alpha}_t}x_0}{\sqrt{1-\overline{\alpha}_t}},\; \sigma_t^2 I\right).$$

This distribution was constructed so that forward sampling still works as expected for the purpose of training; it can be proven that $q_{\sigma}(x_t|x_0) = \mathcal{N}(\sqrt{\overline{\alpha}_t}x_0, (1-\overline{\alpha}_t)I)$, which is the same sampling distribution as in DDPMs (see here). It is not true that the normal forwards process $q(x_t|x_{t-1})$ stays intact, and in fact $q_{\sigma}(x_t|x_{t-1},x_0)\neq q_{\sigma}(x_t|x_{t-1})$, hence this family of distributions is non-Markovian.

To see why this new family of distributions intuitively captures the spirit of the backwards diffusion process derived in DDPMs, note that

$$\frac{x_t - \sqrt{\overline{\alpha}_t}x_0}{\sqrt{1-\overline{\alpha}_t}} = \varepsilon_{t},$$

so the two terms $\sqrt{1-\overline{\alpha}_{t-1}-\sigma_t^2}$ (the coefficient inside the mean) and $\sigma_t$ (the actual standard deviation) can be seen as splitting a total noise level of $\sqrt{1-\overline{\alpha}_{t-1}}$, which matches the noise expression for the forward process, i.e., the distribution $q(x_{t-1}|x_0)$. The paper introduces this “splitting” of the noise factors to control the actual amount of noise that is injected during the backwards inference step.

3.1. Proof that $\mathcal{Q}$ satisfies the forwards definition

To prove that forwards sampling from $x_0$ remains the same, we can use an inductive argument, inducting downwards on the timestep. For the base case $t = T$, we assume that $q_{\sigma}(x_T|x_0)$ is normally distributed (i.e., pure noise), so

$$q_{\sigma}(x_T|x_0) = \mathcal{N}(\sqrt{\overline{\alpha}_T}x_0, (1-\overline{\alpha}_T)I) = \mathcal{N}(0,I),$$

and our base case holds. Now, by our inductive hypothesis, assume that we have

$$q_{\sigma}(x_t|x_0) = \mathcal{N}(\sqrt{\overline{\alpha}_t}x_0, (1-\overline{\alpha}_t)I),$$

and we also have

$$q_{\sigma}(x_{t-1}|x_t,x_0) = \mathcal{N}\left(\sqrt{\overline{\alpha}_{t-1}}x_0 + \sqrt{1 - \overline{\alpha}_{t-1} - \sigma_t^2}\cdot \frac{x_t - \sqrt{\overline{\alpha}_t}x_0}{\sqrt{1-\overline{\alpha}_t}},\; \sigma_t^2 I\right).$$

Now we have a marginal distribution $q_{\sigma}(x_t|x_0)$ and a distribution conditioned on it, $q_{\sigma}(x_{t-1}|x_t,x_0)$. We wish to find the other marginal $q_{\sigma}(x_{t-1}|x_0)$, and luckily the setup is the same as our setup from 2.2. Thus, using the results from our derivation, we have

$$q_{\sigma}(x_{t-1}|x_0) = \mathcal{N}(A\mu + b,\; L^{-1} + A\Lambda^{-1}A^T),$$

where

$$A\mu + b = \sqrt{\overline{\alpha}_{t-1}}x_0 + \sqrt{1 - \overline{\alpha}_{t-1}-\sigma_t^2}\cdot \frac{\sqrt{\overline{\alpha}_t}x_0 - \sqrt{\overline{\alpha}_t}x_0}{\sqrt{1-\overline{\alpha}_t}} = \sqrt{\overline{\alpha}_{t-1}}x_0,$$

and

$$L^{-1} + A\Lambda^{-1}A^T = \sigma_t^2I + \frac{1 - \overline{\alpha}_{t-1} - \sigma_t^2}{1 - \overline{\alpha}_t}(1-\overline{\alpha}_t)I = (1-\overline{\alpha}_{t-1})I.$$

Thus,

$$q_{\sigma}(x_{t-1}|x_0) = \mathcal{N}(\sqrt{\overline{\alpha}_{t-1}}x_0, (1-\overline{\alpha}_{t-1})I),$$

which completes the proof.
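
The induction step can also be checked numerically; a one-dimensional Monte Carlo sketch (the schedule values are arbitrary, subject to $1 - \overline{\alpha}_{t-1} - \sigma_t^2 \geq 0$):

```python
import numpy as np

rng = np.random.default_rng(0)
x0, ab_t, ab_prev, sigma = 0.7, 0.4, 0.6, 0.3
n = 1_000_000

# x_t ~ q_sigma(x_t | x_0), then x_{t-1} ~ q_sigma(x_{t-1} | x_t, x_0).
xt = np.sqrt(ab_t) * x0 + np.sqrt(1 - ab_t) * rng.standard_normal(n)
mean = (np.sqrt(ab_prev) * x0
        + np.sqrt(1 - ab_prev - sigma**2) * (xt - np.sqrt(ab_t) * x0) / np.sqrt(1 - ab_t))
xprev = mean + sigma * rng.standard_normal(n)

# The marginal should be N(sqrt(ab_prev) * x0, 1 - ab_prev).
assert abs(xprev.mean() - np.sqrt(ab_prev) * x0) < 1e-2
assert abs(xprev.var() - (1 - ab_prev)) < 1e-2
```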

3.2. Proof that $\mathcal{Q}$ can be applied to DDPM-trained models

One of the key properties of the inference distributions $\mathcal{Q}$ is that they can be applied even to models that were originally trained with the DDPM objective.

In order to prove this, we first introduce some more notation from the paper. Let $\mathcal{L}$ be a family of loss functions generalizing the DDPM training process, such that for any $L_{\gamma}\in \mathcal{L}$,

$$L_{\gamma}(\varepsilon_t) = \mathbb{E}_{x_0, \varepsilon_t, x_t = \sqrt{\overline{\alpha}_t}x_0 + \sqrt{1-\overline{\alpha}_t}\varepsilon_t}\left[\gamma \lVert\varepsilon_t - \varepsilon_{\theta}(x_t)\rVert_2^2\right].$$

For example, in DDPM, it was shown that the mathematically optimal loss was given by

$$\gamma = \frac{(1-\alpha_t)^2}{2\alpha_t(1-\overline{\alpha}_t)\lVert \Sigma_{\theta}(x_t,t)\rVert^2}$$

(see here), while $\gamma = 1$ was shown to be good for training. Now, let $J_{\sigma}$ be the optimal objective for learning $q_{\sigma}$. To show that $\mathcal{Q}$ inference can be effectively applied to DDPM-trained models, it suffices to show that $J_{\sigma}\in \mathcal{L}$.

We first utilize the variational inference objective from (Ho et al. 2020):

$$L_{VLB} = \mathbb{E}_{x_{0:T}\sim q_{\sigma}(x_{0:T})}\left[\log \frac{q_{\sigma}(x_{1:T}\vert x_0)}{p_{\theta}(x_{0:T})}\right].$$

Using results from (Sohl-Dickstein et al. 2015) and the same derivations from DDPM, we have

$$J_{\sigma}(\varepsilon_{\theta}) \equiv \mathbb{E}_{x_{0:T}\sim q_{\sigma}(x_{0:T})}\left[\sum_{t=2}^{T}D_{KL}(q_{\sigma}(x_{t-1}\vert x_t, x_0)\,\Vert\, p_{\theta}^{(t)}(x_{t-1}\vert x_t)) - \log p_{\theta}^{(1)}(x_0 \vert x_1)\right]$$

when keeping only the terms $L_0, L_1,\dots,L_{T-1}$ (in the notation of the paper, we use $\equiv$ instead of $=$ when we take steps that throw away constant factors).

Now, per the paper, we define the actual generative process $p_{\theta}(x_{0:T})$ as a function of the derived distribution $q_{\sigma}(x_{t-1}\vert x_t, x_0)$. Since we don't know $x_0$ during inference, we replace this term in $q_{\sigma}$ with the (derived) output of the neural net.

From the definition of the forward process (see here), our predicted denoised observation of $x_0$ given $x_t$ is given by

$$f_{\theta}^{(t)}(x_t) := \frac{x_t - \sqrt{1-\overline{\alpha}_t} \cdot \varepsilon_{\theta}^{(t)}(x_t)}{\sqrt{\overline{\alpha}_t}}.$$

Thus, we can define the reverse generative process, with prior distribution $p_{\theta}(x_T) = \mathcal{N}(0, I)$, as

$$p_{\theta}^{(t)}(x_{t-1} \vert x_{t}) = q_{\sigma}(x_{t-1} \vert x_t, f_{\theta}^{(t)}(x_t)).$$

Finally, we have enough to evaluate $J_{\sigma}$. First note that we can rewrite the expected value:

$$J_{\sigma}(\varepsilon_{\theta}) \equiv \mathbb{E}_{x_0,\varepsilon,x_t=\sqrt{\overline{\alpha}_t}x_0+\sqrt{1-\overline{\alpha}_t}\varepsilon}\left[D_{KL}(q_{\sigma}(x_{t-1}|x_t,x_0)\,\Vert\, q_{\sigma}(x_{t-1} \vert x_t, f_{\theta}^{(t)}(x_t)))\right].$$

Using 2.3, we see that this is equivalent to minimizing the MSE between the two means (under $\equiv$, since both distributions share the same fixed covariance $\sigma_t^2 I$), so the part of our objective inside of the expected value becomes

$$\begin{align*} & \mathbb{E}_{x_0,\varepsilon,x_t=\sqrt{\overline{\alpha}_t}x_0 + \sqrt{1-\overline{\alpha}_t}\varepsilon}\left[ \left\lVert\left(\sqrt{\overline{\alpha}_{t-1}}x_0 + \sqrt{1-\overline{\alpha}_{t-1}-\sigma_t^2}\cdot \frac{x_t - \sqrt{\overline{\alpha}_t}x_0}{\sqrt{1-\overline{\alpha}_t}}\right)\right.\right. \\ &\qquad\qquad \left.\left. -\left(\sqrt{\overline{\alpha}_{t-1}}f_{\theta}^{(t)}(x_t) + \sqrt{1-\overline{\alpha}_{t-1}-\sigma_t^2}\cdot \frac{x_t - \sqrt{\overline{\alpha}_t}f_{\theta}^{(t)}(x_t)}{\sqrt{1-\overline{\alpha}_t}}\right)\right\rVert^2\right]. \end{align*}$$

Further simplifying,

$$\begin{align*} J_{\sigma}(\varepsilon_{\theta}) &\equiv \mathbb{E}_{x_0,\varepsilon,x_t=\sqrt{\overline{\alpha}_t}x_0 + \sqrt{1-\overline{\alpha}_t}\varepsilon}\left[\lVert x_0 - f_{\theta}^{(t)}(x_t)\rVert^2\right] \\ &\equiv \mathbb{E}_{x_0,\varepsilon,x_t=\sqrt{\overline{\alpha}_t}x_0 + \sqrt{1-\overline{\alpha}_t}\varepsilon}\left[\left\lVert \frac{x_t - \sqrt{1-\overline{\alpha}_t}\varepsilon}{\sqrt{\overline{\alpha}_t}} - \frac{x_t - \sqrt{1-\overline{\alpha}_t}\varepsilon_{\theta}^{(t)}(x_t)}{\sqrt{\overline{\alpha}_t}} \right\rVert^2\right]\\ &\equiv \mathbb{E}_{x_0,\varepsilon,x_t=\sqrt{\overline{\alpha}_t}x_0 + \sqrt{1-\overline{\alpha}_t}\varepsilon}\left[\lVert \varepsilon - \varepsilon_{\theta}^{(t)}(x_t)\rVert^2\right] \in \mathcal{L}, \end{align*}$$

as desired.
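
Putting the pieces together, one inference step under $q_\sigma$ looks roughly like the following sketch (same assumed `unet(x, t)` wrapper as before). Setting `sigma = 0` makes the step deterministic, and since only the marginals are constrained, `t_prev` is free to skip timesteps, which is what makes DDIM inference fast:

```python
import torch

@torch.no_grad()
def ddim_step(unet, x, t, t_prev, alpha_bar, sigma):
    """One q_sigma inference step using a DDPM-trained noise predictor."""
    tt = torch.full((x.shape[0],), t, device=x.device, dtype=torch.long)
    eps = unet(x, tt)
    ab_t, ab_prev = alpha_bar[t], alpha_bar[t_prev]
    x0_pred = (x - torch.sqrt(1 - ab_t) * eps) / torch.sqrt(ab_t)  # f_theta(x_t)
    mean = torch.sqrt(ab_prev) * x0_pred + torch.sqrt(1 - ab_prev - sigma**2) * eps
    return mean + sigma * torch.randn_like(x)                      # sigma = 0: deterministic
```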

4. Conditional generation

Eventually, our goal is not just to generate images from noise; we would also like to generate images conditioned on text labels. More specifically, our final task is to generate images conditioned on both text labels and an input reference image, but we'll discuss this more in the next section. In this section, we'll discuss methods for conditional generation in an easier subtask, which starts with generating images from discrete input classes. As a simple example of this subtask, we might have a diffusion model try to generate one of the ten digits from the MNIST dataset.

4.1. Classifier Guided

The general approach is as follows. At each step during inference, we are trying to approximate $\nabla_{x_t}\log q(x_t, y)$, also known as the score function of the joint distribution $q(x_t, y)$.

We first consider the case when we have an external classifier $f_{\phi}(y|x_t,t)$. We want to use the gradient of the classifier, $\nabla_{x_t} \log f_{\phi}(y|x_t)$, to alter the noise prediction in order to guide our diffusion process. Intuitively, taking the gradient of the classifier can essentially be described as calculating the log difference of the class probabilities between two images at adjacent timesteps.

Note in general that the gradient of the log-density of a multivariate normal $p(x) = \mathcal{N}(x; \mu, \Sigma)$ is given by

$$-\frac{1}{2}\frac{\partial}{\partial x}(x-\mu)^T\Sigma^{-1}(x-\mu) = -\Sigma^{-1}(x-\mu).$$

Applying this to $q(x_t|x_0)$, whose mean is $\sqrt{\overline{\alpha}_t}x_0$ and covariance is $(1-\overline{\alpha}_t)I$, we may approximate $\nabla_{x_t}\log q(x_t)$ in terms of the predicted noise $\varepsilon_{\theta}(x_t, t)$ with

$$\nabla_{x_t}\log q(x_t) \approx -\frac{1}{\sqrt{1-\overline{\alpha}_t}}\varepsilon_\theta(x_t, t).$$

Now, to approximate $\nabla_{x_t}\log q(x_t, y)$, we have

$$\begin{align*} \nabla_{x_t}\log q(x_t, y) &= \nabla_{x_t}\log (q(x_t)q(y \vert x_t)) \\ &=\nabla_{x_t}\log q(x_t) + \nabla_{x_t}\log q(y \vert x_t) \\ &\approx -\frac{1}{\sqrt{1-\overline{\alpha}_t}}\varepsilon_\theta(x_t, t) + \nabla_{x_t} \log f_{\phi}(y|x_t) \\ &= -\frac{1}{\sqrt{1-\overline{\alpha}_t}}\left(\varepsilon_\theta(x_t, t) - \sqrt{1-\overline{\alpha}_t}\,\nabla_{x_t} \log f_{\phi}(y|x_t)\right). \end{align*}$$

Thus, during inference, we can use the new noise predictor

$$\overline{\varepsilon}_\theta(x_t, t) = \varepsilon_\theta(x_t, t) - \sqrt{1-\overline{\alpha}_t}\,\nabla_{x_t} \log f_{\phi}(y|x_t).$$

Given the external classifier $f_{\phi}(y|x_t,t)$, the optimal way to calculate the gradient with respect to $x_t$ wasn't entirely clear to us. In practice, we feel that this could be approximated by taking the difference in the value of the prediction between two adjacent timesteps.
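
Another route is to compute the gradient with automatic differentiation; a sketch of that approach (assuming a hypothetical `classifier(x, t)` that returns class logits):

```python
import torch

def guided_eps(unet, classifier, x, t, y, alpha_bar):
    """Shift the predicted noise by the classifier gradient (classifier guidance)."""
    x = x.detach().requires_grad_(True)
    log_probs = torch.log_softmax(classifier(x, t), dim=-1)   # log f_phi(. | x_t)
    selected = log_probs[torch.arange(x.shape[0]), y].sum()   # log f_phi(y | x_t)
    grad = torch.autograd.grad(selected, x)[0]                # gradient w.r.t. x_t
    ab = alpha_bar[t].view(-1, 1, 1, 1)
    return unet(x, t) - torch.sqrt(1 - ab) * grad
```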

4.2. Classifier-free

Assuming that we don't have a classifier, we can instead modify our original noise predictor $\varepsilon_{\theta}$ to learn both conditional and unconditional generation.

The conditional diffusion model $\varepsilon_{\theta}(x_t,t,y)$ is trained in the exact same way as the normal model, but with class labels $y$ added as an additional embedding to the input image. Some of these class labels are omitted during training so that the model still learns how to generate images unconditionally.

Absent a true classifier, the only thing we are missing in the above derivation is $\nabla_{x_t}\log q(y \vert x_t)$. We can approximate this value with our new noise predictors (by Bayes' rule, the factor $q(y)$ is constant in $x_t$ and vanishes under the gradient):

$$\begin{align*} \nabla_{x_t}\log q(y \vert x_t) &= \nabla_{x_t}\log \frac{q(x_t \vert y)}{q(x_t)} \\ &=\nabla_{x_t}\log q(x_t \vert y) - \nabla_{x_t} \log q(x_t) \\ &\approx -\frac{1}{\sqrt{1-\overline{\alpha}_t}}\left(\varepsilon_\theta(x_t, t, y) -\varepsilon_\theta(x_t, t)\right). \end{align*}$$

The ultimate objective of classifier-free guidance is to improve the quality of images as well as their correspondence with the conditioning class label. Intuitively, the gradient of $q(y \vert x_t)$ tries to maximize the likelihood of the condition $y$ in order to make images match their corresponding label.

In practice, it makes sense to scale this gradient by a weight $w$ so that we can control how much influence it has over the course of inference. Therefore, our new predictor is given by

$$\begin{align*} \overline{\varepsilon}_\theta(x_t, t, y) &= \varepsilon_\theta(x_t, t, y) - \sqrt{1-\overline{\alpha}_t}\,w\,\nabla_{x_t}\log q(y | x_t)\\ &= \varepsilon_\theta(x_t, t, y) + w(\varepsilon_\theta(x_t, t, y) - \varepsilon_\theta(x_t, t)). \end{align*}$$

If we reframe the weight with a scaling factor $s := w+1$, we have

$$\overline{\varepsilon}_\theta(x_t, t, y) = \varepsilon_\theta(x_t, t) + s(\varepsilon_\theta(x_t, t, y) - \varepsilon_\theta(x_t, t)),$$

and it then becomes intuitively clear how our weight influences inference; when we place more weight on the gradient, we get an output closer to the theoretical conditional gradient, and conversely, when the weight on the gradient is smaller, we get an output that is closer to unconditional diffusion.
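
In code, this is just two forward passes blended together (a sketch; representing the dropped label by `None` is our own convention):

```python
import torch

def cfg_eps(unet, x, t, y, s):
    """Classifier-free guidance with scale s (s = 1 recovers conditional diffusion)."""
    eps_cond = unet(x, t, y)        # epsilon_theta(x_t, t, y)
    eps_uncond = unet(x, t, None)   # unconditional prediction, label dropped
    return eps_uncond + s * (eps_cond - eps_uncond)
```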

5. Modifying the conditional objective for Pix2Pix

We can apply similar logic to obtain classifier-free guidance with two conditionings: a previous image $c_I$ and the class label $c_T$.

In Brooks et al. (2022), the authors chose to omit only $c_I$ in $5\%$ of examples, only $c_T$ in $5\%$ of examples, and both $c_I$ and $c_T$ in $5\%$ of examples. This produces a robust noise estimator for different combinations of conditionings.

For inference, we can similarly introduce scaling factors $s_I, s_T$ to adjust how strongly we want the generated image to correlate with the input image condition and the class label condition, respectively. Our final noise estimator becomes

$$\begin{align*} \overline{\varepsilon}_\theta(x_t, c_I, c_T) &= \varepsilon_\theta(x_t, \varnothing, \varnothing) + s_I(\varepsilon_\theta(x_t, c_I, \varnothing) - \varepsilon_\theta(x_t, \varnothing, \varnothing)) \\ &\qquad + s_T(\varepsilon_\theta(x_t, c_I, c_T) - \varepsilon_\theta(x_t, c_I, \varnothing)). \end{align*}$$
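
In code, the estimator is three forward passes (a sketch, with `None` again standing in for a dropped conditioning):

```python
import torch

def pix2pix_eps(unet, x_t, c_img, c_txt, s_img, s_txt):
    """Two-condition classifier-free guidance, matching the estimator above."""
    eps_uncond = unet(x_t, None, None)   # both conditionings dropped
    eps_img = unet(x_t, c_img, None)     # image conditioning only
    eps_full = unet(x_t, c_img, c_txt)   # image + text conditioning
    return (eps_uncond
            + s_img * (eps_img - eps_uncond)
            + s_txt * (eps_full - eps_img))
```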

References

[1] Brooks et al. “InstructPix2Pix: Learning to Follow Image Editing Instructions” (2022).

[2] Rombach et al. “High-Resolution Image Synthesis with Latent Diffusion Models” (2021).

[3] Brown et al. “Language Models are Few-Shot Learners” (2020).

[4] Ho et al. “Denoising Diffusion Probabilistic Models” (2020).

[5] Song et al. “Denoising Diffusion Implicit Models” (2020).

[6] Weng, Lilian. “What are diffusion models?” Lil’Log (2021).

[7] Bishop, Christopher. “Pattern Recognition and Machine Learning” (2006).

[8] Sohl-Dickstein et al. “Deep unsupervised learning using nonequilibrium thermodynamics.” (2015).