There were a lot of interesting and unexpected aspects of the project that made it a lot of fun. For example, we did not realize how difficult it would be to deploy a model that ended up being roughly 4GB. In the end, we are happy with the way it turned out, and luckily we don’t have to pay too much money to keep this running forever!
One of the reasons we were motivated to take on this project is that we wanted a better understanding of the math behind generative models. Below is a summary of some of the things that we learned. All of the math is taken directly from this page of our website.
1. DDPM
We begin with DDPMs (Ho et al. 2020). These are the foundation of most modern diffusion models. In particular, InstructPix2Pix is a fine-tuned version of Stable Diffusion, which works out of the box with DDPM sampling.
1.1. Foundation
In these diffusion models, we noise images sampled from a distribution $q(x)$ through the forward process, defined via
$$q(x_t \mid x_{t-1}) = \mathcal{N}\left(x_t;\; \sqrt{1-\beta_t}\, x_{t-1},\; \beta_t I\right),$$
where $\beta_t$ is some noise schedule. We typically have $\beta_1 < \beta_2 < \dots < \beta_T$ to ensure that the final image is pure noise.
Let $\alpha_t = 1 - \beta_t$ and $\bar\alpha_t = \prod_{i=1}^t \alpha_i$. It can be shown that
$$q(x_t \mid x_0) = \mathcal{N}\left(x_t;\; \sqrt{\bar\alpha_t}\, x_0,\; (1 - \bar\alpha_t) I\right),$$
which not only makes sampling the forward process very efficient, but also shows that we can think of the forward process as some linear combination of the original image and pure noise.
In the backward process, the goal is to produce a model $p_\theta$ that approximates $q$, where we define
$$p_\theta(x_{t-1} \mid x_t) = \mathcal{N}\left(x_{t-1};\; \mu_\theta(x_t, t),\; \Sigma_\theta(x_t, t)\right).$$
Given perfect $p_\theta$, we can perfectly recreate the original data distribution from pure noise, which is the magic of diffusion. It turns out that $q(x_{t-1} \mid x_t, x_0) = \mathcal{N}(x_{t-1}; \tilde\mu_t(x_t, x_0), \tilde\beta_t I)$ is tractable, and it can be shown that
$$\tilde\mu_t(x_t, x_0) = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar\alpha_t}}\, \varepsilon_t\right),$$
where $\varepsilon_t$ is the noise added to produce $x_t$ from $x_0$. This quantity is relevant because it can be shown that optimizing the variational lower bound is equivalent to minimizing
$$L_t = \mathbb{E}_{x_0, \varepsilon, t}\left[\frac{(1 - \alpha_t)^2}{2\alpha_t (1 - \bar\alpha_t) \left\|\Sigma_\theta(x_t, t)\right\|_2^2}\, \left\|\varepsilon_t - \varepsilon_\theta(x_t, t)\right\|^2\right].$$
In their paper, Ho et al. (2020) showed that ignoring the constant in front can produce better results in training, so they eventually reduced the loss function to
$$L = \mathbb{E}_{x_0, \varepsilon, t}\left[\left\|\varepsilon_t - \varepsilon_\theta(x_t, t)\right\|^2\right].$$
The full derivation of all the facts leading up to this point can be found here. The importance of this fact is that training DDPM models boils down to sampling random images at random timesteps, and having a model (typically a UNet) learn to predict the noise added at timestep $t$.
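To make this concrete, here is a minimal sketch of what one such training step might look like in PyTorch. The `unet(x, t)` callable and the schedule constants are our own placeholders, not taken from any particular library:

```python
import torch

# Noise schedule: increasing beta_t, alpha_t = 1 - beta_t,
# alpha_bar_t = prod_{i<=t} alpha_i. (Values are illustrative.)
T = 1000
beta = torch.linspace(1e-4, 0.02, T)
alpha = 1.0 - beta
alpha_bar = torch.cumprod(alpha, dim=0)

def ddpm_loss(unet, x0):
    """One step of the simplified DDPM objective: sample a random
    timestep and noise, form x_t in closed form, and regress the noise."""
    b = x0.shape[0]
    t = torch.randint(0, T, (b,), device=x0.device)       # random timesteps
    eps = torch.randn_like(x0)                            # eps ~ N(0, I)
    ab = alpha_bar.to(x0.device)[t].view(b, 1, 1, 1)
    # Closed-form forward sample: x_t = sqrt(abar_t) x0 + sqrt(1 - abar_t) eps
    xt = ab.sqrt() * x0 + (1.0 - ab).sqrt() * eps
    return ((eps - unet(xt, t)) ** 2).mean()              # L_simple
```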
1.2. Clarifying xt
The input of our UNet is an image $x_t$, generated from $x_0$ using the distribution $q(x_t \mid x_0)$. The output of our UNet is the noise used to generate $x_t$, $\varepsilon_\theta(x_t, t)$.
This noise is sampled from $\mathcal{N}(0, I)$, per the reparameterization trick. It represents the normalized quantity of noise added from $x_0$ to $x_t$, not from $x_{t-1}$ to $x_t$. In other words, this noise corresponds to the forwards distribution $q(x_t \mid x_0)$, and not $q(x_t \mid x_{t-1})$.
Even though we are predicting noise from $x_0$ to $x_t$, this noise is used to generate the distribution $p(x_{t-1} \mid x_t)$. In other words, the distribution for predicting one timestep backwards from $t$ to $t-1$ is a function of the noise added from $0$ to $t$. This is because the original (tractable) distribution that we are trying to learn is $q(x_{t-1} \mid x_t, x_0)$, which also has access to information about $x_0$.
A concrete demonstration of this fact is the way that we simulate the backwards process once we have trained our UNet. We must make T calls to the UNet to go from pure noise to the original data distribution (more accurately, our approximation of the distribution). In each of these T calls, although the UNet estimates noise from timestep 0 to the current timestep, we cannot jump directly to the beginning, because we do not know p(x0∣xt) as a function of this noise. Thus, although we have a shortcut for the forwards process, there is no equivalent shortcut for the backwards process, and so inference is quite expensive.
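Concretely, the backwards simulation might look like the following sketch (reusing the schedule tensors from the training sketch above; again our own illustration, not a reference implementation):

```python
@torch.no_grad()
def ddpm_sample(unet, shape):
    """Ancestral sampling: T sequential UNet calls take pure noise
    back to (our approximation of) the data distribution."""
    x = torch.randn(shape)                                # x_T ~ N(0, I)
    for t in reversed(range(T)):
        tt = torch.full((shape[0],), t)
        eps = unet(x, tt)                                 # predicted noise from 0 to t
        # Posterior mean: (x_t - (1-alpha_t)/sqrt(1-abar_t) * eps) / sqrt(alpha_t)
        mean = (x - (1 - alpha[t]) / (1 - alpha_bar[t]).sqrt() * eps) / alpha[t].sqrt()
        z = torch.randn_like(x) if t > 0 else torch.zeros_like(x)
        x = mean + beta[t].sqrt() * z   # sigma_t = sqrt(beta_t) is one standard choice
    return x
```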
2. Multivariate Gaussians
Before continuing, we first derive some useful identities for multivariate Gaussians that will be used later. Many of these identities and derivations can be found in Pattern Recognition and Machine Learning (Bishop 2006). They are recreated here for completeness and for our own learning.
Suppose we represent a Gaussian $p(x)$ as a jointly defined distribution $p(x_a, x_b)$, where $x_a$ and $x_b$ arbitrarily partition the dimensions of $x$. Then $x_a$ and $x_b$ are Gaussian random vectors in their own right, with means and covariances; write
$$x = \begin{pmatrix} x_a \\ x_b \end{pmatrix}, \qquad \mu = \begin{pmatrix} \mu_a \\ \mu_b \end{pmatrix}, \qquad \Sigma = \begin{pmatrix} \Sigma_{aa} & \Sigma_{ab} \\ \Sigma_{ba} & \Sigma_{bb} \end{pmatrix}.$$
Also, let
$$\Lambda \equiv \Sigma^{-1} = \begin{pmatrix} \Lambda_{aa} & \Lambda_{ab} \\ \Lambda_{ba} & \Lambda_{bb} \end{pmatrix}$$
be the precision matrix corresponding to x.
2.1. Computing a marginal Gaussian from a joint Gaussian
The first question we will focus on is how to compute the marginal
$$p(x_a) = \int p(x_a, x_b)\, dx_b.$$
The purpose of decomposing everything into $x_a$ and $x_b$ components is that we may now complete the square in $x_b$ and write
$$p(x_a) = f(x_a) \int \exp\left(-\tfrac{1}{2}\left(x_b - \Lambda_{bb}^{-1} m\right)^T \Lambda_{bb} \left(x_b - \Lambda_{bb}^{-1} m\right)\right) dx_b, \qquad m = \Lambda_{bb}\,\mu_b - \Lambda_{ba}(x_a - \mu_a),$$
where $f(x_a)$ is some function of $x_a$ independent of $x_b$. There are two key observations here:
- the exponent of $f(x_a)$ is a quadratic form in $x_a$
- the integrand is just an unnormalized Gaussian, so it integrates to the inverse normalization factor. In this case, this normalization factor is only a function of $\det \Lambda_{bb}$, which is not a function of $x_a$
Together, these two observations imply that $p(x_a)$ is itself Gaussian, and thus we can ignore the constant that we get from the integral. Instead, given that the distribution is Gaussian, we can pick out its mean and covariance by comparing coefficients with the general Gaussian expansion
$$-\tfrac{1}{2}(x - \mu)^T \Sigma^{-1} (x - \mu) = -\tfrac{1}{2}\, x^T \Sigma^{-1} x + x^T \Sigma^{-1} \mu + \text{const}.$$
Alternatively, we could manually expand the integral and compute everything, but comparing coefficients is much easier.
The exponent of $f(x_a)$ has terms from our original expansion, plus the leftover terms from completing the square on $x_b$. The terms coming from our original expansion are
$$-\tfrac{1}{2}\, x_a^T \Lambda_{aa}\, x_a + x_a^T\left(\Lambda_{aa}\,\mu_a + \Lambda_{ab}\,\mu_b\right),$$
while the terms leftover from completing the square are
$$\tfrac{1}{2}\, x_a^T \Lambda_{ab} \Lambda_{bb}^{-1} \Lambda_{ba}\, x_a - x_a^T \Lambda_{ab} \Lambda_{bb}^{-1}\left(\Lambda_{bb}\,\mu_b + \Lambda_{ba}\,\mu_a\right),$$
where we combine like terms using $\Lambda_{ab} = \Lambda_{ba}^T$ and the fact that all individual terms are scalars, i.e., we can transpose terms freely. Combining all terms together now gives
$$-\tfrac{1}{2}\, x_a^T \left(\Lambda_{aa} - \Lambda_{ab}\Lambda_{bb}^{-1}\Lambda_{ba}\right) x_a + x_a^T \left(\Lambda_{aa} - \Lambda_{ab}\Lambda_{bb}^{-1}\Lambda_{ba}\right) \mu_a,$$
so comparing coefficients with the general expansion above yields $\mathbb{E}[p(x_a)] = \mu_a$ and $\mathrm{Cov}[p(x_a)] = (\Lambda_{aa} - \Lambda_{ab}\Lambda_{bb}^{-1}\Lambda_{ba})^{-1}$. Lastly, we note that $\mathrm{Cov}[p(x_a)]$ is the inverse of the Schur complement of $\Lambda_{bb}$, which is exactly the block $\Sigma_{aa}$, so the result is more cleanly written $\mathrm{Cov}[p(x_a)] = \Sigma_{aa}$.
Note how intuitively nice this result is. It’s exactly what we expect, even if it takes some work to prove it rigorously.
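As a quick numerical sanity check of this result (our own illustration, not from Bishop), we can sample a random joint Gaussian with numpy and confirm that the empirical marginal of $x_a$ matches $\mu_a$ and $\Sigma_{aa}$:

```python
import numpy as np

rng = np.random.default_rng(0)
# A random joint Gaussian over (x_a, x_b) with dim(x_a) = 2, dim(x_b) = 3.
mu = rng.normal(size=5)
M = rng.normal(size=(5, 5))
Sigma = M @ M.T + 5 * np.eye(5)          # symmetric positive definite
samples = rng.multivariate_normal(mu, Sigma, size=200_000)

# The marginal of x_a should be N(mu_a, Sigma_aa): just the leading blocks.
xa = samples[:, :2]
print(np.abs(xa.mean(axis=0) - mu[:2]).max())        # ~0
print(np.abs(np.cov(xa.T) - Sigma[:2, :2]).max())    # ~0
```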
2.2. Computing a marginal Gaussian given the other marginal and a conditional
The next question we will focus on is computing the marginal p(y) given
$$p(x) = \mathcal{N}\left(x;\; \mu,\; \Lambda^{-1}\right), \qquad p(y \mid x) = \mathcal{N}\left(y;\; Ax + b,\; L^{-1}\right).$$
This form seems a little contrived but becomes useful when we discuss DDIMs.
The first observation is that the joint $p(x, y)$ is Gaussian, since $p(x, y) = p(x) \cdot p(y \mid x)$ at every point. Note that it is not true in general that the joint distribution of two Gaussian random variables is Gaussian; it works out here because the joint pdf is the product of the Gaussian pdfs $p(x)$ and $p(y \mid x)$, each of whose exponents is quadratic in $(x, y)$, and such a product is always an (unnormalized) Gaussian.
Thus, if we can find the mean and covariance of $p(x, y)$, we can use our results from 2.1 to obtain the marginal $p(y)$. Carrying this out gives
$$p(y) = \mathcal{N}\left(y;\; A\mu + b,\; L^{-1} + A\Lambda^{-1}A^T\right).$$
With some more work, we could also extract the other conditional, but we won’t need this result.
Like our result in 2.1, this final expression is nice because it aligns with what we expect. We have $y \mid x$ sampled from a distribution with mean $f(x)$, where $f$ is linear; therefore, the fact that $\mathbb{E}[y] = f(\mathbb{E}[x])$ makes intuitive sense. Further, since we have $f(x) = Ax + b$, we expect $\mathrm{Cov}(f(x)) = A\, \mathrm{Cov}(x)\, A^T$. The only "dependence" that $y$ has on $x$ is through their means; $\mathrm{Cov}(y \mid x) = L^{-1}$ is a source of noise essentially independent from the noise associated with $x$, so by linearity of variance it makes intuitive sense that $\mathrm{Cov}(y) = A\, \mathrm{Cov}(x)\, A^T + L^{-1} = A\Lambda^{-1}A^T + L^{-1}$.
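The same kind of Monte Carlo check works here (again our own illustration): simulate the linear-Gaussian model and compare the empirical statistics of $y$ against $A\mu + b$ and $L^{-1} + A\Lambda^{-1}A^T$:

```python
import numpy as np

rng = np.random.default_rng(1)
d = 3
mu, b = rng.normal(size=d), rng.normal(size=d)
A = rng.normal(size=(d, d))
Lam_inv = 0.5 * np.eye(d)     # Cov(x)   = Lambda^{-1}
L_inv = 0.25 * np.eye(d)      # Cov(y|x) = L^{-1}

x = rng.multivariate_normal(mu, Lam_inv, size=200_000)
noise = rng.multivariate_normal(np.zeros(d), L_inv, size=200_000)
y = x @ A.T + b + noise       # y | x ~ N(Ax + b, L^{-1})

print(np.abs(y.mean(axis=0) - (A @ mu + b)).max())               # ~0
print(np.abs(np.cov(y.T) - (L_inv + A @ Lam_inv @ A.T)).max())   # ~0
```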
2.3. KL Divergence
The last thing we will examine is how to compute the KL divergence, or relative entropy, between two multivariate Gaussians. The general expression is given by
$$D_{KL}(P \,\|\, Q) = \mathbb{E}_{x \sim p(x)}\left[\log \frac{p(x)}{q(x)}\right].$$
Applying this to multivariate $P = \mathcal{N}(\mu_1, \Sigma_1)$ and $Q = \mathcal{N}(\mu_2, \Sigma_2)$, we have
$$D_{KL}(P \,\|\, Q) = \mathbb{E}_{x \sim p}\left[\tfrac{1}{2}\log\frac{\det\Sigma_2}{\det\Sigma_1} - \tfrac{1}{2}(x - \mu_1)^T \Sigma_1^{-1}(x - \mu_1) + \tfrac{1}{2}(x - \mu_2)^T \Sigma_2^{-1}(x - \mu_2)\right].$$
To simplify this, we can apply the trace trick; since a quadratic form $x^T A x$ is a scalar, it is equal to its own trace, and since the trace is invariant under cyclic permutations, it is also equal to $\mathrm{Tr}((Ax)x^T)$ and $\mathrm{Tr}(x(x^T A))$. This lets us pull the expectations inside the traces, and the expression simplifies to the closed form
$$D_{KL}(P \,\|\, Q) = \tfrac{1}{2}\left[\log\frac{\det\Sigma_2}{\det\Sigma_1} - d + \mathrm{Tr}\left(\Sigma_2^{-1}\Sigma_1\right) + (\mu_2 - \mu_1)^T \Sigma_2^{-1} (\mu_2 - \mu_1)\right],$$
where $d$ is the dimension.
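For reference, here is the closed form as a small numpy helper, with the trivial sanity check that the KL of a distribution with itself is zero:

```python
import numpy as np

def gaussian_kl(mu1, S1, mu2, S2):
    """Closed-form KL( N(mu1, S1) || N(mu2, S2) ) from the trace-trick derivation."""
    d = mu1.shape[0]
    S2_inv = np.linalg.inv(S2)
    diff = mu2 - mu1
    _, logdet1 = np.linalg.slogdet(S1)   # log-determinants, numerically stable
    _, logdet2 = np.linalg.slogdet(S2)
    return 0.5 * (logdet2 - logdet1 - d
                  + np.trace(S2_inv @ S1)
                  + diff @ S2_inv @ diff)

# Sanity check: the KL of a distribution with itself is zero.
assert abs(gaussian_kl(np.zeros(2), np.eye(2), np.zeros(2), np.eye(2))) < 1e-12
```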
3. DDIM
Next, we turn to DDIMs (Song et al. 2020), since this variation on DDPMs is an important component of Stable Diffusion models. A key motivation for these models is the fact referenced above that inference, i.e., simulating the backwards diffusion process, is quite expensive.
This paper introduces a family of inference distributions $Q$, parameterized by $\sigma$ fixing the variance added during denoising:
$$q_\sigma(x_{t-1} \mid x_t, x_0) = \mathcal{N}\left(x_{t-1};\; \sqrt{\bar\alpha_{t-1}}\, x_0 + \sqrt{1 - \bar\alpha_{t-1} - \sigma_t^2}\cdot \frac{x_t - \sqrt{\bar\alpha_t}\, x_0}{\sqrt{1 - \bar\alpha_t}},\; \sigma_t^2 I\right).$$
This distribution was constructed so that forward sampling still works as expected for the purpose of training; it can be proven that $q_\sigma(x_t \mid x_0) = \mathcal{N}(\sqrt{\bar\alpha_t}\, x_0, (1 - \bar\alpha_t) I)$, which is the same sampling distribution as for DDPMs (see here). It is not true, however, that the normal forwards process $q(x_t \mid x_{t-1})$ stays intact; in fact $q_\sigma(x_t \mid x_{t-1}, x_0) \neq q_\sigma(x_t \mid x_{t-1})$, hence this family of distributions is non-Markovian.
To see why this new family of distributions intuitively captures the spirit of the backwards diffusion process derived in DDPMs, note that
$$\frac{x_t - \sqrt{\bar\alpha_t}\, x_0}{\sqrt{1 - \bar\alpha_t}} = \varepsilon_t,$$
so the two coefficients $\sqrt{1 - \bar\alpha_{t-1} - \sigma_t^2}$ (under the mean) and $\sigma_t$ (the actual standard deviation) can be seen as contributing total variance $1 - \bar\alpha_{t-1}$, which matches the noise expression for the forward process, i.e., the distribution $q(x_{t-1} \mid x_0)$. The paper introduces this "splitting" of the noise factors to control the actual amount of noise that is injected during the backwards inference step.
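Concretely, one denoising step under $q_\sigma$ might look like the sketch below. Following the paper, we parameterize $\sigma_t$ by a factor $\eta$, where $\eta = 0$ gives the deterministic DDIM update and $\eta = 1$ recovers DDPM-like variance (schedule tensors reused from the earlier sketches; `unet` remains a placeholder):

```python
import torch

@torch.no_grad()
def ddim_step(unet, x, t, t_prev, eta=0.0):
    """One q_sigma denoising step from x_t to x_{t_prev}."""
    eps = unet(x, torch.full((x.shape[0],), t))
    ab_t, ab_prev = alpha_bar[t], alpha_bar[t_prev]
    # eta = 0: deterministic DDIM; eta = 1: DDPM-like variance.
    sigma = eta * ((1 - ab_prev) / (1 - ab_t)).sqrt() * (1 - ab_t / ab_prev).sqrt()
    x0_pred = (x - (1 - ab_t).sqrt() * eps) / ab_t.sqrt()   # predicted x_0
    # Mean: sqrt(abar_prev) * x0_pred + sqrt(1 - abar_prev - sigma^2) * eps
    mean = ab_prev.sqrt() * x0_pred + (1 - ab_prev - sigma ** 2).sqrt() * eps
    return mean + sigma * torch.randn_like(x)
```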
3.1. Proof that Q satisfies forwards definition
To prove that forwards sampling from $x_0$ remains the same, we can use an inductive argument, inducting downwards on the timestep. In the base case $t = T$, the distribution $q_\sigma(x_T \mid x_0)$ is defined to be pure noise, so
$$q_\sigma(x_T \mid x_0) = \mathcal{N}\left(\sqrt{\bar\alpha_T}\, x_0,\; (1 - \bar\alpha_T) I\right) \approx \mathcal{N}(0, I)$$
(since $\bar\alpha_T \approx 0$), and our base case holds.
and our base case holds. Now, by our inductive hypothesis, assume that we have
Now we have a marginal distribution $q_\sigma(x_t \mid x_0)$ and a distribution conditioned on it, $q_\sigma(x_{t-1} \mid x_t, x_0)$. We wish to find the other marginal $q_\sigma(x_{t-1} \mid x_0)$, and luckily this is exactly the setup from 2.2 (with $A = \frac{\sqrt{1 - \bar\alpha_{t-1} - \sigma_t^2}}{\sqrt{1 - \bar\alpha_t}}\, I$). Thus, using the results from that derivation, we have
$$q_\sigma(x_{t-1} \mid x_0) = \mathcal{N}\left(\sqrt{\bar\alpha_{t-1}}\, x_0,\; (1 - \bar\alpha_{t-1}) I\right),$$
which completes the induction.
3.2. Proof that Q can be applied to DDPM trained models
One of the key properties of the inference distributions $Q$ is that they can be applied even to models that were originally trained with the DDPM objective.
In order to prove this, we first introduce some more notation from the paper. Let $L$ be a family of loss functions generalizing the DDPM training process, such that for any $L_\gamma \in L$,
$$L_\gamma(\varepsilon_\theta) = \sum_{t=1}^T \gamma_t\, \mathbb{E}_{x_0, \varepsilon_t}\left[\left\|\varepsilon_\theta^{(t)}(x_t) - \varepsilon_t\right\|^2\right],$$
where $\gamma \in \mathbb{R}_{>0}^T$ is a vector of positive weights and $x_t = \sqrt{\bar\alpha_t}\, x_0 + \sqrt{1 - \bar\alpha_t}\, \varepsilon_t$.
For example, in DDPM, it was shown that the loss derived from the variational lower bound corresponds to
$$\gamma_t = \frac{(1 - \alpha_t)^2}{2\alpha_t (1 - \bar\alpha_t) \left\|\Sigma_\theta(x_t, t)\right\|_2^2},$$
(see here), while $\gamma = 1$ was shown to work well in training. Now, let $J_\sigma$ be the variational objective for learning $q_\sigma$. To show that $Q$ inference can be effectively applied to DDPM-trained models, it suffices to show that $J_\sigma \in L$ (up to an additive constant).
We first utilize the variational inference objective from (Ho et al. 2020):
$$J_\sigma \equiv \sum_{t=2}^T \mathbb{E}_{x_0, x_t}\left[D_{KL}\left(q_\sigma(x_{t-1} \mid x_t, x_0)\, \big\|\, p_\theta^{(t)}(x_{t-1} \mid x_t)\right)\right],$$
when only taking terms $L_1, \dots, L_{T-1}$ (in the notation of the paper, we use $\equiv$ instead of $=$ when we take steps that throw away constant factors).
Now, per the paper, we define the actual generative process pθ(x0:T) as a function of the derived distribution qσ(xt−1∣xt,x0). Since we don’t know x0 during inference, we replace this term in qσ with the (derived) output of the neural net.
From the definition of the forward process (see here), our predicted denoised observation of $x_0$ given $x_t$ is given by
$$f_\theta^{(t)}(x_t) := \frac{x_t - \sqrt{1 - \bar\alpha_t}\cdot \varepsilon_\theta^{(t)}(x_t)}{\sqrt{\bar\alpha_t}}.$$
Thus, we can define the reverse generative process with prior distribution $p_\theta(x_T) = \mathcal{N}(0, I)$ as
$$p_\theta^{(t)}(x_{t-1} \mid x_t) = q_\sigma\left(x_{t-1} \mid x_t, f_\theta^{(t)}(x_t)\right).$$
Finally, we have enough to evaluate $J_\sigma$. First note that we can rewrite the expected value:
$$\mathbb{E}_{x_0, x_t}\left[D_{KL}\left(q_\sigma(x_{t-1} \mid x_t, x_0)\, \big\|\, q_\sigma\left(x_{t-1} \mid x_t, f_\theta^{(t)}(x_t)\right)\right)\right].$$
Using 2.3, we see that this is equivalent to optimizing the MSE between the two means (under $\equiv$), since both Gaussians share the covariance $\sigma_t^2 I$. So the part of our objective inside the expected value becomes
$$\frac{1}{2\sigma_t^2}\left\|\mu_\sigma(x_t, x_0) - \mu_\sigma\left(x_t, f_\theta^{(t)}(x_t)\right)\right\|^2 \equiv \gamma_t \left\|\varepsilon_t - \varepsilon_\theta^{(t)}(x_t)\right\|^2,$$
where $\mu_\sigma(x_t, \cdot)$ denotes the mean of $q_\sigma(x_{t-1} \mid x_t, \cdot)$ and $\gamma_t$ depends only on $t$ and $\sigma$. Summing over $t$ shows $J_\sigma \in L$, as desired.
4. Conditional generation
Eventually, our goal is not just to generate images from noise; we would also like to generate images conditioned on text labels. More specifically, our final task is to generate images conditioned on both text labels and an input reference image, but we'll discuss this more in the next section. In this section, we'll discuss methods for conditional generation in an easier subtask, which starts with generating images from discrete input classes. As a simple example of this subtask, we might have a diffusion model try to generate one of the ten digits from the MNIST dataset.
4.1 Classifier-guided
The general approach is as follows. At each step during inference, we are trying to approximate $\nabla_{x_t} \log q(x_t, y)$, the score function of the joint distribution $q(x_t, y)$.
We first consider the case where we have an external classifier $f_\phi(y \mid x_t, t)$. We want to use the gradient of the classifier, $\nabla_{x_t} \log f_\phi(y \mid x_t)$, to alter the noise prediction and thereby guide our diffusion process. Intuitively, this gradient tells us how to perturb the current image so that the classifier assigns higher log-probability to the label $y$.
Note in general that the gradient of the log-density of a multivariate normal $p(x) = \mathcal{N}(x; \mu, \Sigma)$ is given by
$$\nabla_x \log p(x) = -\tfrac{1}{2}\, \frac{\partial}{\partial x}(x - \mu)^T \Sigma^{-1} (x - \mu) = -\Sigma^{-1}(x - \mu).$$
Thus we may approximate $\nabla_{x_t} \log q(x_t)$ in terms of the predicted noise $\varepsilon_\theta(x_t)$ with
$$\nabla_{x_t} \log q(x_t) \approx -\frac{1}{1 - \bar\alpha_t}\left(x_t - \sqrt{\bar\alpha_t}\, x_0\right) = -\frac{\varepsilon_\theta(x_t)}{\sqrt{1 - \bar\alpha_t}}.$$
Thus, during inference, we can use the new noise predictor
$$\hat\varepsilon_\theta(x_t, t) = \varepsilon_\theta(x_t, t) - \sqrt{1 - \bar\alpha_t}\; \nabla_{x_t} \log f_\phi(y \mid x_t).$$
Given the external classifier $f_\phi(y \mid x_t, t)$, it wasn't entirely clear to us what the optimal way to calculate the gradient with respect to $x_t$ is. In practice, we feel that this could be approximated by taking the difference in the value of the prediction between two adjacent timesteps.
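One natural option, if the classifier is differentiable, is to compute the gradient by backpropagating through it. A sketch of what that could look like, with hypothetical `unet` and `classifier` callables (the classifier returning logits over labels) and the schedule tensors from earlier:

```python
import torch

def guided_eps(unet, classifier, x, t, y):
    """Classifier-guided noise prediction:
    eps_hat = eps - sqrt(1 - abar_t) * grad_x log f_phi(y | x_t)."""
    eps = unet(x, t)
    with torch.enable_grad():
        x_in = x.detach().requires_grad_(True)
        log_probs = torch.log_softmax(classifier(x_in, t), dim=-1)
        selected = log_probs[torch.arange(len(y)), y].sum()  # sum of log f(y|x) over batch
        grad = torch.autograd.grad(selected, x_in)[0]        # gradient w.r.t. x_t
    ab = alpha_bar[t].view(-1, 1, 1, 1)
    return eps - (1 - ab).sqrt() * grad
```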
4.2 Classifier-free
Assuming that we don’t have a classifier, we can instead modify our original noise predictor εθ to learn conditional and unconditional generation.
The conditional diffusion model εθ(xt,t,y) is trained in the exact same way as the normal model, but with class labels y added as an additional embedding to the input image. Some of these class labels are omitted during training so that the model still learns how to generate images unconditionally.
Absent a true classifier, the only thing we are missing in the above derivation is $\nabla_{x_t} \log q(y \mid x_t)$. We can approximate this value with our new noise predictors:
$$\nabla_{x_t} \log q(y \mid x_t) = \nabla_{x_t} \log q(x_t \mid y) - \nabla_{x_t} \log q(x_t) \approx -\frac{1}{\sqrt{1 - \bar\alpha_t}}\left(\varepsilon_\theta(x_t, t, y) - \varepsilon_\theta(x_t, t)\right).$$
The ultimate objective of classifier-free guidance is to improve the quality of images as well as their correspondence with the conditioning class label. Intuitively, following the gradient of $\log q(y \mid x_t)$ pushes samples toward regions where the condition $y$ is more likely, making images match their corresponding label.
In practice, it makes sense to weight this gradient with a weight $w$ so that we can control how much influence it has over the course of inference. Therefore, our new predictor is given by
$$\hat\varepsilon_\theta(x_t, t, y) = \varepsilon_\theta(x_t, t, y) + w\left(\varepsilon_\theta(x_t, t, y) - \varepsilon_\theta(x_t, t)\right) = (1 + w)\, \varepsilon_\theta(x_t, t, y) - w\, \varepsilon_\theta(x_t, t),$$
and it then becomes intuitively clear how our weight influences inference; when we place more weight on the gradient, the output is pushed further in the direction of the conditional signal, and as the weight shrinks, the output falls back toward ordinary, unguided diffusion.
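In code, each denoising step then combines two UNet calls; a sketch, assuming a `unet` that accepts an optional label (with `None` standing in for the null label):

```python
def cfg_eps(unet, x, t, y, w):
    """Classifier-free guidance: (1 + w) * conditional - w * unconditional."""
    eps_cond = unet(x, t, y)        # label embedding supplied
    eps_uncond = unet(x, t, None)   # null label: unconditional prediction
    return (1 + w) * eps_cond - w * eps_uncond
```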
5. Modifying the conditional objective for Pix2Pix
We can apply similar logic to obtain classifier-free guidance with two conditionings: an input image $c_I$ and a text label $c_T$.
In Brooks et al., the authors chose to omit only $c_I$ in 5% of examples, only $c_T$ in 5% of examples, and both $c_I$ and $c_T$ in 5% of examples. This produces a noise estimator that is robust to different combinations of conditionings.
For inference, we can similarly introduce scaling factors $s_I, s_T$ to adjust how strongly we want the generated image to correlate with the input image condition and the text condition, respectively. Our final noise estimator becomes
$$\tilde\varepsilon_\theta(x_t, c_I, c_T) = \varepsilon_\theta(x_t, \varnothing, \varnothing) + s_I\left(\varepsilon_\theta(x_t, c_I, \varnothing) - \varepsilon_\theta(x_t, \varnothing, \varnothing)\right) + s_T\left(\varepsilon_\theta(x_t, c_I, c_T) - \varepsilon_\theta(x_t, c_I, \varnothing)\right).$$
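A sketch of the corresponding inference-time combination, with three UNet calls per step and `None` again standing in for an omitted conditioning (the two-condition `unet` signature is our own placeholder):

```python
def pix2pix_eps(unet, x, t, c_img, c_txt, s_img, s_txt):
    """Two-condition guidance: three UNet calls, with s_img and s_txt
    separately controlling image and text adherence."""
    eps_uncond = unet(x, t, None, None)
    eps_img = unet(x, t, c_img, None)
    eps_full = unet(x, t, c_img, c_txt)
    return (eps_uncond
            + s_img * (eps_img - eps_uncond)
            + s_txt * (eps_full - eps_img))
```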