Progressive Distillation For Fast Sampling
ABSTRACT
arXiv:2202.00512v2 [cs.LG] 7 Jun 2022
Diffusion models have recently shown great promise for generative modeling, out-
performing GANs on perceptual quality and autoregressive models at density es-
timation. A remaining downside is their slow sampling time: generating high
quality samples takes many hundreds or thousands of model evaluations. Here
we make two contributions to help eliminate this downside: First, we present new
parameterizations of diffusion models that provide increased stability when using
few sampling steps. Second, we present a method to distill a trained deterministic
diffusion sampler, using many steps, into a new diffusion model that takes half as
many sampling steps. We then keep progressively applying this distillation proce-
dure to our model, halving the number of required sampling steps each time. On
standard image generation benchmarks like CIFAR-10, ImageNet, and LSUN, we
start out with state-of-the-art samplers taking as many as 8192 steps, and are able
to distill down to models taking as few as 4 steps without losing much perceptual
quality; achieving, for example, a FID of 3.0 on CIFAR-10 in 4 steps. Finally,
we show that the full progressive distillation procedure does not take more time
than it takes to train the original model, thus representing an efficient solution for
generative modeling using diffusion at both train and test time.
1 INTRODUCTION
Diffusion models (Sohl-Dickstein et al., 2015; Song & Ermon, 2019; Ho et al., 2020) are an emerg-
ing class of generative models that has recently delivered impressive results on many standard gen-
erative modeling benchmarks. These models have achieved ImageNet generation results outper-
forming BigGAN-deep and VQ-VAE-2 in terms of FID score and classification accuracy score (Ho
et al., 2021; Dhariwal & Nichol, 2021), and they have achieved likelihoods outperforming autore-
gressive image models (Kingma et al., 2021; Song et al., 2021b). They have also succeeded in image
super-resolution (Saharia et al., 2021; Li et al., 2021) and image inpainting (Song et al., 2021c), and
there have been promising results in shape generation (Cai et al., 2020), graph generation (Niu et al.,
2020), and text generation (Hoogeboom et al., 2021; Austin et al., 2021).
A major barrier remains to practical adoption of diffusion models: sampling speed. While sam-
pling can be accomplished in relatively few steps in strongly conditioned settings, such as text-to-
speech (Chen et al., 2021) and image super-resolution (Saharia et al., 2021), or when guiding the
sampler using an auxiliary classifier (Dhariwal & Nichol, 2021), the situation is substantially differ-
ent in settings in which there is less conditioning information available. Examples of such settings
are unconditional and standard class-conditional image generation, which currently require hundreds
or thousands of steps using network evaluations that are not amenable to the caching optimizations
of other types of generative models (Ramachandran et al., 2017).
In this paper, we reduce the sampling time of diffusion models by orders of magnitude in uncondi-
tional and class-conditional image generation, which represent the setting in which diffusion models
have been slowest in previous work. We present a procedure to distill the behavior of an N-step DDIM
sampler (Song et al., 2021a) for a pretrained diffusion model into a new model with N/2 steps, with
little degradation in sample quality. In what we call progressive distillation, we repeat this distilla-
tion procedure to produce models that generate in as few as 4 steps, still maintaining sample quality
competitive with state-of-the-art models using thousands of steps.
Published as a conference paper at ICLR 2022
[Figure 1: image not recovered; the only surviving labels read "Distillation".]
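$$\tilde{\mu}_{s|t}(z_t, x) = e^{\lambda_t - \lambda_s}(\alpha_s/\alpha_t)\, z_t + (1 - e^{\lambda_t - \lambda_s})\,\alpha_s x, \qquad \tilde{\sigma}^2_{s|t} = (1 - e^{\lambda_t - \lambda_s})\,\sigma_s^2 \tag{3}$$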
We use this reversed description of the forward process to define the ancestral sampler. Starting at
z1 ∼ N (0, I), the ancestral sampler follows the rule
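\begin{align}
z_s &= \tilde{\mu}_{s|t}(z_t, \hat{x}_\theta(z_t)) + \sqrt{(\tilde{\sigma}^2_{s|t})^{1-\gamma}(\sigma^2_{t|s})^{\gamma}}\,\epsilon \tag{4} \\
&= e^{\lambda_t - \lambda_s}(\alpha_s/\alpha_t)\, z_t + (1 - e^{\lambda_t - \lambda_s})\,\alpha_s \hat{x}_\theta(z_t) + \sqrt{(\tilde{\sigma}^2_{s|t})^{1-\gamma}(\sigma^2_{t|s})^{\gamma}}\,\epsilon, \tag{5}
\end{align}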
where ε is standard Gaussian noise, and γ is a hyperparameter that controls how much noise is added
during sampling, following Nichol & Dhariwal (2021).
Alternatively, Song et al. (2021c) show that our denoising model x̂θ (zt ) can be used to determinis-
tically map noise z1 ∼ N (0, I) to samples x by numerically solving the probability flow ODE:
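$$dz_t = \left[f(z_t, t) - \tfrac{1}{2}\, g^2(t)\, \nabla_z \log \hat{p}_\theta(z_t)\right] dt, \tag{6}$$
where $\nabla_z \log \hat{p}_\theta(z_t) = \frac{\alpha_t \hat{x}_\theta(z_t) - z_t}{\sigma_t^2}$. Following Kingma et al. (2021), we have $f(z_t, t) = \frac{d \log \alpha_t}{dt}\, z_t$
and $g^2(t) = \frac{d\sigma_t^2}{dt} - 2\, \frac{d \log \alpha_t}{dt}\, \sigma_t^2$. Since x̂θ(zt) is parameterized by a neural network, this equation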
is a special case of a neural ODE (Chen et al., 2018), also called a continuous normalizing flow
(Grathwohl et al., 2018).
Solving the ODE in Equation 6 numerically can be done with standard methods like the Euler
rule or the Runge-Kutta method. The DDIM sampler proposed by Song et al. (2021a) can also
be understood as an integration rule for this ODE, as we show in Appendix B, even though it was
originally proposed with a different motivation. The update rule specified by DDIM is
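\begin{align}
z_s &= \alpha_s \hat{x}_\theta(z_t) + \sigma_s\, \frac{z_t - \alpha_t \hat{x}_\theta(z_t)}{\sigma_t} \tag{7} \\
&= e^{(\lambda_t - \lambda_s)/2}(\alpha_s/\alpha_t)\, z_t + (1 - e^{(\lambda_t - \lambda_s)/2})\,\alpha_s \hat{x}_\theta(z_t), \tag{8}
\end{align}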
and in practice this rule performs better than the aforementioned standard ODE integration rules in
our case, as we show in Appendix C.
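The DDIM update in Equation 7 can be sketched in a few lines of NumPy. This is only an illustrative sketch, not the authors' implementation: it assumes the variance-preserving cosine schedule (αt = cos(0.5πt), σt = sin(0.5πt)) used elsewhere in the paper, and the names `alpha_sigma`, `ddim_step`, `ddim_sample`, and `denoise_fn` are hypothetical.

```python
import numpy as np

def alpha_sigma(t):
    """Assumed cosine schedule: alpha_t = cos(0.5*pi*t), sigma_t = sin(0.5*pi*t)."""
    return np.cos(0.5 * np.pi * t), np.sin(0.5 * np.pi * t)

def ddim_step(z_t, x_hat, t, s):
    """One DDIM update from time t to an earlier time s (Equation 7)."""
    alpha_t, sigma_t = alpha_sigma(t)
    alpha_s, sigma_s = alpha_sigma(s)
    eps_hat = (z_t - alpha_t * x_hat) / sigma_t  # noise implied by the x-prediction
    return alpha_s * x_hat + sigma_s * eps_hat

def ddim_sample(denoise_fn, z_1, num_steps):
    """Deterministic sampling on the uniform grid t = 1, (N-1)/N, ..., 1/N, ending at 0."""
    z = z_1
    for i in range(num_steps, 0, -1):
        t, s = i / num_steps, (i - 1) / num_steps
        z = ddim_step(z, denoise_fn(z, t), t, s)
    return z
```

Here `denoise_fn(z, t)` stands in for the trained denoising model x̂θ(zt); each sampling step costs exactly one call to it, which is why the number of steps dominates sampling time.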
If x̂θ (zt ) satisfies mild smoothness conditions, the error introduced by numerical integration of the
probability flow ODE is guaranteed to vanish as the number of integration steps grows infinitely
large, i.e. N → ∞. This leads to a trade-off in practice between the accuracy of the numerical
integration, and hence the quality of the produced samples from our model, and the time needed to
produce these samples. So far, most models in the literature have needed hundreds or thousands of
integration steps to produce their highest quality samples, which is prohibitive for many practical
applications of generative modeling. Here, we therefore propose a method to distill these accurate,
but slow, ODE integrators into much faster models that are still very accurate. This idea is visualized
in Figure 1, and described in detail in the next section.
3 PROGRESSIVE DISTILLATION
To make diffusion models more efficient at sampling time, we propose progressive distillation: an
algorithm that iteratively halves the number of required sampling steps by distilling a slow teacher
diffusion model into a faster student model. Our implementation of progressive distillation stays
very close to the implementation for training the original diffusion model, as described by e.g.
Ho et al. (2020). Algorithm 1 and Algorithm 2 present diffusion model training and progressive
distillation side-by-side, with the relative changes in progressive distillation highlighted in green.
We start the progressive distillation procedure with a teacher diffusion model that is obtained by
training in the standard way. At every iteration of progressive distillation, we then initialize the
student model with a copy of the teacher, using both the same parameters and same model definition.
Like in standard training, we then sample data from the training set and add noise to it, before
forming the training loss by applying the student denoising model to this noisy data zt . The main
difference in progressive distillation is in how we set the target for the denoising model: instead
of the original data x, we have the student model denoise towards a target x̃ that makes a single
student DDIM step match 2 teacher DDIM steps. We calculate this target value by running 2 DDIM
sampling steps using the teacher, starting from zt and ending at z_{t−1/N}, with N being the number of
student sampling steps. By inverting a single step of DDIM, we then calculate the value the student
model would need to predict in order to move from zt to z_{t−1/N} in a single step, as we show in
detail in Appendix G. The resulting target value x̃(zt ) is fully determined given the teacher model
and starting point zt , which allows the student model to make a sharp prediction when evaluated at
zt . In contrast, the original data point x is not fully determined given zt , since multiple different
data points x can produce the same noisy data zt : this means that the original denoising model is
predicting a weighted average of possible x values, which produces a blurry prediction. By making
sharper predictions, the student model can make faster progress during sampling.
After running distillation to learn a student model taking N sampling steps, we can repeat the pro-
cedure with N/2 steps: The student model then becomes the new teacher, and a new student model
is initialized by making a copy of this model.
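The construction of the distillation target described in this section can be sketched as follows. The sketch takes two teacher DDIM steps and then solves the single-step DDIM update for the x-prediction that would reach the same point. It is derived directly from the DDIM update rule and assumes a cosine schedule; the helper names (`alpha_sigma`, `ddim_step`, `distillation_target`, `teacher_fn`) are hypothetical and not taken from the paper's code.

```python
import numpy as np

def alpha_sigma(t):
    """Assumed cosine schedule: alpha_t = cos(0.5*pi*t), sigma_t = sin(0.5*pi*t)."""
    return np.cos(0.5 * np.pi * t), np.sin(0.5 * np.pi * t)

def ddim_step(z_t, x_hat, t, s):
    """One DDIM update from time t to an earlier time s."""
    alpha_t, sigma_t = alpha_sigma(t)
    alpha_s, sigma_s = alpha_sigma(s)
    return alpha_s * x_hat + sigma_s * (z_t - alpha_t * x_hat) / sigma_t

def distillation_target(teacher_fn, z_t, t, num_student_steps):
    """Return (x_tilde, z_end): the student target and the point two teacher steps reach."""
    t_mid = t - 0.5 / num_student_steps
    t_end = t - 1.0 / num_student_steps
    # Two teacher DDIM steps: t -> t_mid -> t_end.
    z_mid = ddim_step(z_t, teacher_fn(z_t, t), t, t_mid)
    z_end = ddim_step(z_mid, teacher_fn(z_mid, t_mid), t_mid, t_end)
    # Invert a single DDIM step t -> t_end:
    # z_end = alpha_end * x + (sigma_end / sigma_t) * (z_t - alpha_t * x), solved for x.
    alpha_t, sigma_t = alpha_sigma(t)
    alpha_end, sigma_end = alpha_sigma(t_end)
    ratio = sigma_end / sigma_t
    x_tilde = (z_end - ratio * z_t) / (alpha_end - ratio * alpha_t)
    return x_tilde, z_end
```

By construction, `ddim_step(z_t, x_tilde, t, t_end)` reproduces `z_end`, so a student that predicts `x_tilde` exactly covers two teacher steps in one. Each target costs two teacher evaluations regardless of how many steps the original model used.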
Unlike our procedure for training the original model, we always run progressive distillation in dis-
crete time: we sample this discrete time such that the highest time index corresponds to a signal-to-
noise ratio of zero, i.e. α1 = 0, which exactly matches the distribution of input noise z1 ∼ N (0, I)
that is used at test time. We found this to work slightly better than starting from a non-zero signal-
to-noise ratio as used by e.g. Ho et al. (2020), both for training the original model as well as when
performing progressive distillation.
In this section, we discuss how to parameterize the denoising model x̂θ , and how to specify the
reconstruction loss weight w(λt). We assume a standard variance-preserving diffusion process for
which σt² = 1 − αt². This is without loss of generality, as shown by Kingma et al. (2021,
Appendix G): different specifications of the diffusion process, such as the variance-exploding spec-
ification, can be considered equivalent to this specification, up to rescaling of the noisy latents zt.
We use a cosine schedule αt = cos(0.5πt), similar to that introduced by Nichol & Dhariwal (2021).
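As a small illustration of this schedule, the sketch below computes αt, σt, and the log signal-to-noise ratio λt = log[αt²/σt²] for the variance-preserving cosine schedule. The function names are hypothetical.

```python
import numpy as np

def cosine_schedule(t):
    """Variance-preserving cosine schedule: alpha_t = cos(0.5*pi*t), sigma_t = sin(0.5*pi*t)."""
    t = np.asarray(t, dtype=np.float64)
    return np.cos(0.5 * np.pi * t), np.sin(0.5 * np.pi * t)

def log_snr(t):
    """Log signal-to-noise ratio log(alpha_t^2 / sigma_t^2); finite only for 0 < t < 1."""
    alpha, sigma = cosine_schedule(t)
    return 2.0 * (np.log(alpha) - np.log(sigma))
```

Under this schedule σt² = 1 − αt² holds for every t, the log-SNR is zero at t = 0.5, and it decreases monotonically as t moves toward 1 (pure noise).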
Ho et al. (2020) and much of the following work choose to parameterize the denoising model through
directly predicting ε with a neural network ε̂θ(zt), which implicitly sets x̂θ(zt) = (1/αt)(zt − σt ε̂θ(zt)).
In this case, the training loss is also usually defined as mean squared error in the ε-space:
$$L_\theta = \|\epsilon - \hat{\epsilon}_\theta(z_t)\|_2^2 = \left\|\frac{1}{\sigma_t}(z_t - \alpha_t x) - \frac{1}{\sigma_t}(z_t - \alpha_t \hat{x}_\theta(z_t))\right\|_2^2 = \frac{\alpha_t^2}{\sigma_t^2}\,\|x - \hat{x}_\theta(z_t)\|_2^2, \tag{9}$$
which can thus equivalently be seen as a weighted reconstruction loss in x-space, where the weight-
ing function is given by w(λt ) = exp(λt ), for log signal-to-noise ratio λt = log[αt2 /σt2 ].
Although this standard specification works well for training the original model, it is not well suited
for distillation: when training the original diffusion model, and at the start of progressive distillation,
the model is evaluated at a wide range of signal-to-noise ratios αt2 /σt2 , but as distillation progresses
we increasingly evaluate at lower and lower signal-to-noise ratios. As the signal-to-noise ratio goes
to zero, the effect of small changes in the neural network output ε̂θ(zt) on the implied prediction in
x-space is increasingly amplified, since x̂θ(zt) = (1/αt)(zt − σt ε̂θ(zt)) divides by αt → 0. This is not
much of a problem when taking many steps, since the effect of early missteps is limited by clipping
of the zt iterates, and later updates can correct any mistakes, but it becomes increasingly important
as we decrease the number of sampling steps. Eventually, if we distill all the way down to a single
sampling step, the input to the model is only pure noise ε, which corresponds to a signal-to-noise
ratio of zero, i.e. αt = 0, σt = 1. At this extreme, the link between ε-prediction and x-prediction
breaks down completely: observed data zt = ε is no longer informative of x and predictions ε̂θ(zt)
no longer implicitly predict x. Examining our reconstruction loss (equation 9), we see that the
weighting function w(λt ) gives zero weight to the reconstruction loss at this signal-to-noise ratio.
For distillation to work, we thus need to parameterize the diffusion model in a way for which the
implied prediction x̂θ (zt ) remains stable as λt = log[αt2 /σt2 ] varies. We tried the following options,
and found all to work well with progressive distillation:
• Predicting x directly.
• Predicting both x and ε, via separate output channels {x̃θ(zt), ε̃θ(zt)} of the neural network,
and then merging the predictions via x̂ = σt² x̃θ(zt) + αt(zt − σt ε̃θ(zt)), thus
smoothly interpolating between predicting x directly and predicting via ε.
• Predicting v ≡ αt ε − σt x, which gives x̂ = αt zt − σt v̂θ(zt), as we show in Appendix D.
In Section 5.1 we test all three parameterizations on training an original diffusion model (no distil-
lation), and find them to work well there also.
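The three options listed above, together with the standard ε-parameterization, differ only in how the network output is mapped to the implied prediction x̂. A minimal NumPy sketch of these mappings follows, assuming a variance-preserving process (αt² + σt² = 1); the function names are hypothetical.

```python
import numpy as np

def x_from_eps(z_t, eps_hat, alpha_t, sigma_t):
    """Standard eps-parameterization; divides by alpha_t, so unstable as alpha_t -> 0."""
    return (z_t - sigma_t * eps_hat) / alpha_t

def x_from_x_and_eps(z_t, x_tilde, eps_tilde, alpha_t, sigma_t):
    """Merge separate x and eps outputs: sigma_t^2 * x_tilde + alpha_t * (z_t - sigma_t * eps_tilde)."""
    return sigma_t**2 * x_tilde + alpha_t * (z_t - sigma_t * eps_tilde)

def x_from_v(z_t, v_hat, alpha_t, sigma_t):
    """v-parameterization: x_hat = alpha_t * z_t - sigma_t * v_hat."""
    return alpha_t * z_t - sigma_t * v_hat

def v_target(x, eps, alpha_t, sigma_t):
    """Regression target for v-prediction: v = alpha_t * eps - sigma_t * x."""
    return alpha_t * eps - sigma_t * x
```

With exact inputs all three mappings return the same x. Only `x_from_eps` contains a division by αt; the other two remain finite at αt = 0, σt = 1, which is the regime a single-step sampler operates in.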
In addition to determining an appropriate parameterization, we also need to decide on a reconstruc-
tion loss weighting w(λt). The setup of Ho et al. (2020) weights the reconstruction loss by the
signal-to-noise ratio, which implicitly gives a weight of zero to data with zero SNR, and is therefore not a
suitable choice for distillation. We consider two alternative training loss weightings:
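• $L_\theta = \max(\|x - \hat{x}_t\|_2^2,\ \|\epsilon - \hat{\epsilon}_t\|_2^2) = \max\!\left(\frac{\alpha_t^2}{\sigma_t^2},\ 1\right)\|x - \hat{x}_t\|_2^2$; ‘truncated SNR’ weighting.
• $L_\theta = \|v_t - \hat{v}_t\|_2^2 = \left(1 + \frac{\alpha_t^2}{\sigma_t^2}\right)\|x - \hat{x}_t\|_2^2$; ‘SNR+1’ weighting.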
We examine both choices in our ablation study in Section 5.1, and find both to be good choices for
training diffusion models. In practice, the choice of loss weighting also has to take into account
how αt , σt are sampled during training, as this sampling distribution strongly determines the weight
the expected loss gives to each signal-to-noise ratio. Our results are for a cosine schedule αt =
cos(0.5πt), where time is sampled uniformly from [0, 1]. In Figure 2 we visualize the resulting loss
weightings, both including and excluding the effect of the cosine schedule.
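The weightings discussed in this section can be written as functions of αt and σt acting on the x-space reconstruction error. The sketch below is illustrative only and assumes a variance-preserving process; the function names are hypothetical.

```python
import numpy as np

def snr_weight(alpha_t, sigma_t):
    """Weight implied by eps-space MSE (Ho et al., 2020): alpha_t^2 / sigma_t^2."""
    return alpha_t**2 / sigma_t**2

def truncated_snr_weight(alpha_t, sigma_t):
    """'Truncated SNR' weighting: max(alpha_t^2 / sigma_t^2, 1)."""
    return np.maximum(alpha_t**2 / sigma_t**2, 1.0)

def snr_plus_one_weight(alpha_t, sigma_t):
    """'SNR+1' weighting, equivalent to MSE in v-space: 1 + alpha_t^2 / sigma_t^2."""
    return 1.0 + alpha_t**2 / sigma_t**2

def weighted_x_loss(x, x_hat, weight):
    """Weighted squared reconstruction error in x-space."""
    return weight * np.sum((x - x_hat) ** 2)
```

At zero signal-to-noise ratio (αt = 0, σt = 1) the plain SNR weight is 0, whereas both alternatives give weight 1, so the loss still constrains the prediction at pure noise.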
5 EXPERIMENTS
In this section we empirically validate the progressive distillation algorithm proposed in Section 3,
as well as the parameterizations and loss weightings considered in Section 4. We consider various
image generation benchmarks, with resolution varying from 32 × 32 to 128 × 128. All experiments
use the cosine schedule αt = cos(0.5πt), and all models use a U-Net architecture similar to that
introduced by Ho et al. (2020), but with BigGAN-style up- and downsampling (Brock et al., 2019),
as used in the diffusion modeling setting by Nichol & Dhariwal (2021); Song et al. (2021c). Our
training setup closely matches the open source code by Ho et al. (2020). Exact details are given in
Appendix E.
As explained in Section 4, the standard method of having our model predict ε, and minimizing mean
squared error in the ε-space (Ho et al., 2020), is not appropriate for use with progressive distillation.
[Figure 2 plots not recovered. Both panels use the x-axis "log SNR", ranging from −6 to 6.]
Figure 2: Left: Log weight assigned to reconstruction loss ‖x − x̂λ‖₂² as a function of the log-SNR
λ = log[α²/σ²], for each of our considered training loss weightings, excluding the influence of
the αt, σt schedule. Right: Weights assigned to the reconstruction loss including the effect of the
cosine schedule αt = cos(0.5πt), with t ∼ U[0, 1]. The weights are only defined up to a constant,
and we have adjusted these constants to fit this graph.
Table 1: Generated sample quality as measured by FID and Inception Score (FID/IS) on uncondi-
tional CIFAR-10, training the original model (no distillation), and comparing different parameteri-
zations and loss weightings discussed in Section 4. All reported results are averages over 3 random
seeds of the best metrics obtained over 2 million training steps; nevertheless, we find that results still
vary by ±0.1 due to the noise inherent in training our models. Taking the neural network output to represent
a prediction of ε in combination with the Truncated SNR loss weighting leads to divergence.
We therefore proposed various alternative parameterizations of the denoising diffusion model that
are stable under the progressive distillation procedure, as well as various weighting functions for
the reconstruction error in x-space. Here, we perform a complete ablation experiment of all pa-
rameterizations and loss weightings considered in Section 4. For computational efficiency, and for
comparisons to established methods in the literature, we use unconditional CIFAR-10 as the bench-
mark. We measure performance of undistilled models trained from scratch, to avoid introducing too
many factors of variation into our analysis.
Table 1 lists the results of the ablation study. Overall results are fairly close across different pa-
rameterizations and loss weights. All proposed stable model specifications achieve excellent perfor-
mance, with the exception of the combination of outputting ε with the neural network and weighting
the loss with the truncated SNR, which we find to be unstable. Both predicting x directly, as well
as predicting v, or the combination (ε, x), could thus be recommended for specification of diffusion
models. Here, predicting v is the most stable option, as it has the unique property of making DDIM
step-sizes independent of the SNR (see Appendix D), but predicting x gives slightly better empirical
results in this ablation study.
(a) 256 sampling steps (b) 4 sampling steps (c) 1 sampling step
Figure 3: Random samples from our distilled 64 × 64 ImageNet models, conditioned on the ‘mala-
mute’ class, for fixed random seed and for varying number of sampling steps. The mapping from
input noise to output image is well preserved as the number of sampling steps is reduced.
[Figure 4 plots not recovered. Surviving panel titles: 128x128 LSUN Bedrooms, 128x128 LSUN Church-Outdoor. x-axis: sampling steps (1 to 512); y-axis: FID; legend: Distilled, DDIM, Stochastic.]
Figure 4: Sample quality results as measured by FID for our distilled model on unconditional
CIFAR-10, class-conditional 64x64 ImageNet, 128x128 LSUN bedrooms, and 128x128 LSUN
church-outdoor. We compare against the DDIM sampler and against an optimized stochastic sam-
pler, each evaluated using the same models that were used to initialize the progressive distillation
procedure. For CIFAR-10 we report an average over 4 random seeds. For the other data sets we
only use a single run because of their computational demand. For the stochastic sampler we set the
variance as a log-scale interpolation between an upper and lower bound on the variance, follow-
ing Nichol & Dhariwal (2021), but we use a single interpolation coefficient rather than a learned
coefficient. We then tune this interpolation coefficient separately for each number of sampling steps
and report only the best result for that number of steps: this way we obtained better results than with
the learned interpolation.
that it requires constructing a large data set by running the original model at its full number of
sampling steps: their cost of distillation thus scales linearly with this number of steps, which can
be prohibitive. In contrast, our method never needs to run the original model at the full number
of sampling steps: at every iteration of progressive distillation, the number of model evaluations
is independent of the number of teacher sampling steps, allowing our method to scale up to large
numbers of teacher steps at a logarithmic cost in total distillation time.
DDIM (Song et al., 2021a) was originally shown to be effective for few-step sampling, as was the
probability flow sampler (Song et al., 2021c). Jolicoeur-Martineau et al. (2021) study fast SDE
integrators for reverse diffusion processes, and Tzen & Raginsky (2019b) study unbiased samplers
which may be useful for fast, high quality sampling as well.
Other work on fast sampling can be viewed as manual or automated methods to adjust samplers or
diffusion processes for fast generation. Nichol & Dhariwal (2021); Kong & Ping (2021) describe
methods to adjust a discrete time diffusion model trained on many timesteps into models that can
sample in few timesteps. Watson et al. (2021) describe a dynamic programming algorithm to reduce
the number of timesteps for a diffusion model in a way that is optimal for log likelihood. Chen et al.
(2021); Saharia et al. (2021); Ho et al. (2021) train diffusion models over continuous noise levels and
tune samplers post training by adjusting the noise levels of a few-step discrete time reverse diffusion
process. Their method is effective in highly conditioned settings such as text-to-speech and image
super-resolution. San-Roman et al. (2021) train a new network to estimate the noise level of noisy
data and show how to use this estimate to speed up sampling.
Alternative specifications of the diffusion model can also lend themselves to fast sampling, such
as modified forward and reverse processes (Nachmani et al., 2021; Lam et al., 2021) and training
diffusion models in latent space (Vahdat et al., 2021).
Table 2: Comparison of fast sampling results on CIFAR-10 for diffusion models in the literature.
7 DISCUSSION
We have presented progressive distillation, a method to drastically reduce the number of sampling
steps required for high quality generation of images, and potentially other data, using diffusion
models with deterministic samplers like DDIM (Song et al., 2021a). By making these models cheaper
to run at test time, we hope to increase their usefulness for practical applications, for which running
time and computational requirements often represent important constraints.
In the current work we limited ourselves to setups where the student model has the same architecture
and number of parameters as the teacher model: in future work we hope to relax this constraint
and explore settings where the student model is smaller, potentially enabling further gains in test
time computational requirements. In addition, we hope to move past the generation of images and
also explore progressive distillation of diffusion models for different data modalities such as
audio (Chen et al., 2021).
In addition to the proposed distillation procedure, some of our progress was realized through differ-
ent parameterizations of the diffusion model and its training loss. We expect to see more progress in
this direction as the community further explores this model class.
REPRODUCIBILITY STATEMENT
We provide full details on model architectures, training procedures, and hyperparameters in Ap-
pendix E, in addition to our discussion in Section 5. In Algorithm 2 we provide fairly detailed pseudocode
that closely matches our actual implementation, which is available in open source at
https://github.com/google-research/google-research/tree/master/diffusion_distillation.
ETHICS STATEMENT
In general, generative models can have unethical uses, such as fake content generation, and they
can suffer from bias if applied to data sets that are not carefully curated. The focus of this paper
specifically is on speeding up generative models at test time in order to reduce their computational
demands; we do not have specific concerns with regard to this contribution.
REFERENCES
Jacob Austin, Daniel D. Johnson, Jonathan Ho, Daniel Tarlow, and Rianne van den Berg. Structured
denoising diffusion models in discrete state-spaces. CoRR, abs/2107.03006, 2021.
Andrew Brock, Jeff Donahue, and Karen Simonyan. Large scale GAN training for high fidelity
natural image synthesis. In International Conference on Learning Representations, 2019.
Ruojin Cai, Guandao Yang, Hadar Averbuch-Elor, Zekun Hao, Serge Belongie, Noah Snavely,
and Bharath Hariharan. Learning gradient fields for shape generation. arXiv preprint
arXiv:2008.06520, 2020.
Nanxin Chen, Yu Zhang, Heiga Zen, Ron J Weiss, Mohammad Norouzi, and William Chan. Wave-
Grad: Estimating gradients for waveform generation. International Conference on Learning Rep-
resentations, 2021.
Tian Qi Chen, Yulia Rubanova, Jesse Bettencourt, and David K Duvenaud. Neural ordinary differ-
ential equations. In Advances in Neural Information Processing Systems, pp. 6571–6583, 2018.
Prafulla Dhariwal and Alex Nichol. Diffusion models beat GANs on image synthesis. arXiv preprint
arXiv:2105.05233, 2021.
Will Grathwohl, Ricky TQ Chen, Jesse Bettencourt, Ilya Sutskever, and David Duvenaud. Ffjord:
Free-form continuous dynamics for scalable reversible generative models. arXiv preprint
arXiv:1810.01367, 2018.
Martin Heusel, Hubert Ramsauer, Thomas Unterthiner, Bernhard Nessler, and Sepp Hochreiter.
GANs trained by a two time-scale update rule converge to a local Nash equilibrium. In Advances
in Neural Information Processing Systems, pp. 6626–6637, 2017.
Jonathan Ho, Ajay Jain, and Pieter Abbeel. Denoising diffusion probabilistic models. In Advances
in Neural Information Processing Systems, pp. 6840–6851, 2020.
Jonathan Ho, Chitwan Saharia, William Chan, David J Fleet, Mohammad Norouzi, and Tim
Salimans. Cascaded diffusion models for high fidelity image generation. arXiv preprint
arXiv:2106.15282, 2021.
Emiel Hoogeboom, Didrik Nielsen, Priyank Jaini, Patrick Forré, and Max Welling. Argmax
flows and multinomial diffusion: Towards non-autoregressive language models. arXiv preprint
arXiv:2102.05379, 2021.
Alexia Jolicoeur-Martineau, Ke Li, Rémi Piché-Taillefer, Tal Kachman, and Ioannis Mitliagkas.
Gotta go fast when generating data with score-based models. arXiv preprint arXiv:2105.14080,
2021.
Diederik P Kingma, Tim Salimans, Ben Poole, and Jonathan Ho. Variational diffusion models.
arXiv preprint arXiv:2107.00630, 2021.
Zhifeng Kong and Wei Ping. On fast sampling of diffusion probabilistic models. arXiv preprint
arXiv:2106.00132, 2021.
Max WY Lam, Jun Wang, Rongjie Huang, Dan Su, and Dong Yu. Bilateral denoising diffusion
models. arXiv preprint arXiv:2108.11514, 2021.
Haoying Li, Yifan Yang, Meng Chang, Huajun Feng, Zhihai Xu, Qi Li, and Yueting Chen.
Srdiff: Single image super-resolution with diffusion probabilistic models. arXiv preprint
arXiv:2104.14951, 2021.
Ilya Loshchilov and Frank Hutter. Decoupled weight decay regularization. arXiv preprint
arXiv:1711.05101, 2017.
Eric Luhman and Troy Luhman. Knowledge distillation in iterative generative models for improved
sampling speed. arXiv preprint arXiv:2101.02388, 2021.
Eliya Nachmani, Robin San Roman, and Lior Wolf. Non gaussian denoising diffusion models. arXiv
preprint arXiv:2106.07582, 2021.
Alexander Quinn Nichol and Prafulla Dhariwal. Improved denoising diffusion probabilistic models.
In Marina Meila and Tong Zhang (eds.), Proceedings of the 38th International Conference on
Machine Learning, ICML, 2021.
Chenhao Niu, Yang Song, Jiaming Song, Shengjia Zhao, Aditya Grover, and Stefano Ermon. Per-
mutation invariant graph generation via score-based generative modeling. In International Con-
ference on Artificial Intelligence and Statistics, pp. 4474–4484. PMLR, 2020.
Prajit Ramachandran, Tom Le Paine, Pooya Khorrami, Mohammad Babaeizadeh, Shiyu Chang,
Yang Zhang, Mark A Hasegawa-Johnson, Roy H Campbell, and Thomas S Huang. Fast genera-
tion for convolutional autoregressive models. arXiv preprint arXiv:1704.06001, 2017.
Chitwan Saharia, Jonathan Ho, William Chan, Tim Salimans, David J Fleet, and Mohammad
Norouzi. Image super-resolution via iterative refinement. arXiv preprint arXiv:2104.07636, 2021.
Robin San-Roman, Eliya Nachmani, and Lior Wolf. Noise estimation for generative diffusion mod-
els. arXiv preprint arXiv:2104.02600, 2021.
Jascha Sohl-Dickstein, Eric Weiss, Niru Maheswaranathan, and Surya Ganguli. Deep unsupervised
learning using nonequilibrium thermodynamics. In International Conference on Machine Learn-
ing, pp. 2256–2265, 2015.
Jiaming Song, Chenlin Meng, and Stefano Ermon. Denoising diffusion implicit models. Interna-
tional Conference on Learning Representations, 2021a.
Yang Song and Stefano Ermon. Generative modeling by estimating gradients of the data distribution.
In Advances in Neural Information Processing Systems, pp. 11895–11907, 2019.
Yang Song and Stefano Ermon. Improved techniques for training score-based generative models. Advances
in Neural Information Processing Systems, 2020.
Yang Song, Conor Durkan, Iain Murray, and Stefano Ermon. Maximum likelihood training of score-
based diffusion models. arXiv e-prints, pp. arXiv–2101, 2021b.
Yang Song, Jascha Sohl-Dickstein, Diederik P Kingma, Abhishek Kumar, Stefano Ermon, and Ben
Poole. Score-based generative modeling through stochastic differential equations. International
Conference on Learning Representations, 2021c.
Yuxuan Song, Qiwei Ye, Minkai Xu, and Tie-Yan Liu. Discriminator contrastive divergence:
Semi-amortized generative modeling by exploring energy of the discriminator. arXiv preprint
arXiv:2004.01704, 2020.
Belinda Tzen and Maxim Raginsky. Neural stochastic differential equations: Deep latent gaussian
models in the diffusion limit. arXiv preprint arXiv:1905.09883, 2019a.
Belinda Tzen and Maxim Raginsky. Theoretical guarantees for sampling and inference in generative
models with latent diffusions. In Conference on Learning Theory, pp. 3084–3114. PMLR, 2019b.
Arash Vahdat, Karsten Kreis, and Jan Kautz. Score-based generative modeling in latent space. arXiv
preprint arXiv:2106.05931, 2021.
Pascal Vincent. A connection between score matching and denoising autoencoders. Neural Compu-
tation, 23(7):1661–1674, 2011.
Daniel Watson, Jonathan Ho, Mohammad Norouzi, and William Chan. Learning to efficiently sam-
ple from diffusion probabilistic models. arXiv preprint arXiv:2106.03802, 2021.
ing a variance preserving diffusion process with αt2 = 1 − σt2 = sigmoid(λt ) for λt = log[αt2 /σt2 ]
(without loss of generality, see Kingma et al. (2021)), we get
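$$f(z, t) = \frac{d \log \alpha_t}{dt}\, z_t = \frac{1}{2}\,\frac{d \log \alpha_\lambda^2}{d\lambda}\,\frac{d\lambda}{dt}\, z_t = \frac{1}{2}(1 - \alpha_t^2)\,\frac{d\lambda}{dt}\, z_t = \frac{1}{2}\,\sigma_t^2\,\frac{d\lambda}{dt}\, z_t. \tag{13}$$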
Similarly, we get
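$$g^2(t) = \frac{d\sigma_t^2}{dt} - 2\,\frac{d \log \alpha_t}{dt}\,\sigma_t^2 = \frac{d\sigma_\lambda^2}{d\lambda}\,\frac{d\lambda}{dt} - \sigma_t^4\,\frac{d\lambda}{dt} = (\sigma_t^4 - \sigma_t^2)\,\frac{d\lambda}{dt} - \sigma_t^4\,\frac{d\lambda}{dt} = -\sigma_t^2\,\frac{d\lambda}{dt}. \tag{14}$$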
Plugging these into the probability flow ODE then gives
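\begin{align}
dz &= \left[f(z, t) - \tfrac{1}{2}\, g^2(t)\, \nabla_z \log p_t(z)\right] dt \tag{15} \\
&= \tfrac{1}{2}\,\sigma_\lambda^2\left[z_\lambda + \nabla_z \log p_\lambda(z)\right] d\lambda. \tag{16}
\end{align}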
Plugging in our function approximation from Equation 12 gives
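\begin{align}
dz &= \tfrac{1}{2}\,\sigma_\lambda^2\left[z_\lambda + \frac{\alpha_\lambda \hat{x}_\theta(z_\lambda) - z_\lambda}{\sigma_\lambda^2}\right] d\lambda \tag{17} \\
&= \tfrac{1}{2}\left[\alpha_\lambda \hat{x}_\theta(z_\lambda) + (\sigma_\lambda^2 - 1)\, z_\lambda\right] d\lambda \tag{18} \\
&= \tfrac{1}{2}\left[\alpha_\lambda \hat{x}_\theta(z_\lambda) - \alpha_\lambda^2\, z_\lambda\right] d\lambda. \tag{19}
\end{align}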
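The two derivative identities this derivation relies on, d log α²/dλ = 1 − α² = σ² and dσ²/dλ = σ⁴ − σ², can be checked numerically with finite differences. The sketch below does so under the stated assumption α² = sigmoid(λ), and also writes out the final drift term of the ODE in λ; all function names are hypothetical.

```python
import numpy as np

def sigmoid(u):
    return 1.0 / (1.0 + np.exp(-u))

def alpha_sq(lam):
    """alpha^2 as a function of log-SNR for a variance-preserving process."""
    return sigmoid(lam)

def sigma_sq(lam):
    """sigma^2 = 1 - alpha^2 = sigmoid(-lam)."""
    return sigmoid(-lam)

def central_diff(fn, lam, h=1e-5):
    """Second-order central finite difference of fn at lam."""
    return (fn(lam + h) - fn(lam - h)) / (2.0 * h)

def probability_flow_drift(z, x_hat, lam):
    """dz/dlambda = 0.5 * (alpha * x_hat - alpha^2 * z), the last line of the derivation."""
    a_sq = alpha_sq(lam)
    return 0.5 * (np.sqrt(a_sq) * x_hat - a_sq * z)

lam_grid = np.linspace(-6.0, 6.0, 25)
dlog_alpha_sq = central_diff(lambda u: np.log(alpha_sq(u)), lam_grid)
dsigma_sq = central_diff(sigma_sq, lam_grid)
```

On this grid `dlog_alpha_sq` agrees with `sigma_sq(lam_grid)` and `dsigma_sq` agrees with `sigma_sq(lam_grid)**2 - sigma_sq(lam_grid)` up to finite-difference error, matching the substitutions used in Equations 13 and 14.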
In a preliminary investigation we tried several numerical integrators for the probability flow ODE.
As our model we used a pre-trained class-conditional 128x128 ImageNet model following the de-
scription in Ho et al. (2020). We tried a simple Euler integrator, RK4 (the “classic” 4th order
Runge–Kutta integrator), and DDIM (Song et al., 2021a). In addition we compared to a Gaussian
sampler with variance equal to the lower bound given by Ho et al. (2020). We calculated FID scores
on just 5000 samples, hence our results in this experiment are not comparable to results reported in
the literature. This preliminary investigation gave the results listed in Table 3 and identified DDIM
as the best integrator in terms of resulting sample quality.
Table 3: Preliminary FID scores on 128 × 128 ImageNet for various integrators of the probability
flow ODE, and compared against a stochastic sampler. Model specification and noise schedule
follow Ho et al. (2020).
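To make the difference between integrators concrete, the following toy sketch compares an Euler step on the log-SNR form of the probability flow ODE (Equation 19) with a DDIM step. It does not reproduce the experiment above: it assumes one-dimensional data concentrated on a single point `x0`, so the ideal denoiser is simply `x0`, and all numerical values are hypothetical. In this special case the DDIM update is exact for any step size, while Euler incurs discretization error.

```python
import math


def alpha_sigma(lam):
    # Variance preserving schedule: alpha^2 = sigmoid(lam), sigma^2 = sigmoid(-lam).
    alpha = math.sqrt(1.0 / (1.0 + math.exp(-lam)))
    sigma = math.sqrt(1.0 / (1.0 + math.exp(lam)))
    return alpha, sigma


def euler_step(z, x_hat, lam_t, lam_s):
    # Explicit Euler on dz = 0.5 * [alpha * x_hat - alpha^2 * z] d(lambda).
    alpha_t, _ = alpha_sigma(lam_t)
    return z + 0.5 * (alpha_t * x_hat - alpha_t ** 2 * z) * (lam_s - lam_t)


def ddim_step(z, x_hat, lam_t, lam_s):
    # DDIM: keep the implied noise fixed and move to the new noise level.
    alpha_t, sigma_t = alpha_sigma(lam_t)
    alpha_s, sigma_s = alpha_sigma(lam_s)
    return alpha_s * x_hat + (sigma_s / sigma_t) * (z - alpha_t * x_hat)


def integrate(step_fn, z_start, x_hat, lam_start, lam_end, num_steps):
    z = z_start
    for i in range(num_steps):
        lam_t = lam_start + (lam_end - lam_start) * i / num_steps
        lam_s = lam_start + (lam_end - lam_start) * (i + 1) / num_steps
        z = step_fn(z, x_hat, lam_t, lam_s)
    return z


# Hypothetical toy setup: point-mass data at x0, fixed noise draw eps.
x0, eps = 1.5, 0.7
lam_start, lam_end, num_steps = -6.0, 6.0, 8
a_start, s_start = alpha_sigma(lam_start)
a_end, s_end = alpha_sigma(lam_end)
z_start = a_start * x0 + s_start * eps
z_exact = a_end * x0 + s_end * eps  # exact ODE solution for point-mass data

euler_err = abs(integrate(euler_step, z_start, x0, lam_start, lam_end, num_steps) - z_exact)
ddim_err = abs(integrate(ddim_step, z_start, x0, lam_start, lam_end, num_steps) - z_exact)
print("Euler error:", euler_err, "DDIM error:", ddim_err)
```

With a learned denoiser neither integrator is exact, so this only illustrates one reason DDIM can behave well with few steps; it says nothing about the FID ranking reported for the real model.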
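We can simplify the DDIM update rule by expressing it in terms of $\phi_t = \arctan(\sigma_t/\alpha_t)$, rather than
in terms of time $t$ or log-SNR $\lambda_t$, as we show here.
Given our definition of $\phi$, and assuming a variance preserving diffusion process, we have $\alpha_\phi = \cos(\phi)$,
$\sigma_\phi = \sin(\phi)$, and hence $z_\phi = \cos(\phi)x + \sin(\phi)\epsilon$. We can now define the velocity of $z_\phi$ as
$$v_\phi \equiv \frac{dz_\phi}{d\phi} = \frac{d\cos(\phi)}{d\phi} x + \frac{d\sin(\phi)}{d\phi} \epsilon = \cos(\phi)\epsilon - \sin(\phi)x. \tag{26}$$
Rearranging $\epsilon$, $x$, $v$, we then get
\begin{align}
\sin(\phi)x &= \cos(\phi)\epsilon - v_\phi \tag{27} \\
            &= \frac{\cos(\phi)}{\sin(\phi)}(z - \cos(\phi)x) - v_\phi \tag{28} \\
\sin^2(\phi)x &= \cos(\phi)z - \cos^2(\phi)x - \sin(\phi)v_\phi \tag{29} \\
(\sin^2(\phi) + \cos^2(\phi))x = x &= \cos(\phi)z - \sin(\phi)v_\phi, \tag{30}
\end{align}
and similarly we get $\epsilon = \sin(\phi)z_\phi + \cos(\phi)v_\phi$.
Furthermore, we define the predicted velocity as
$$\hat{v}_\theta(z_\phi) \equiv \cos(\phi)\hat{\epsilon}_\theta(z_\phi) - \sin(\phi)\hat{x}_\theta(z_\phi), \tag{31}$$
where $\hat{\epsilon}_\theta(z_\phi) = (z_\phi - \cos(\phi)\hat{x}_\theta(z_\phi))/\sin(\phi)$.
Rewriting the DDIM update rule in the introduced terms then gives
\begin{align}
z_{\phi_s} &= \cos(\phi_s)\hat{x}_\theta(z_{\phi_t}) + \sin(\phi_s)\hat{\epsilon}_\theta(z_{\phi_t}) \tag{32} \\
           &= \cos(\phi_s)(\cos(\phi_t)z_{\phi_t} - \sin(\phi_t)\hat{v}_\theta(z_{\phi_t})) + \sin(\phi_s)(\sin(\phi_t)z_{\phi_t} + \cos(\phi_t)\hat{v}_\theta(z_{\phi_t})) \tag{33} \\
           &= [\cos(\phi_s)\cos(\phi_t) + \sin(\phi_s)\sin(\phi_t)]z_{\phi_t} + [\sin(\phi_s)\cos(\phi_t) - \cos(\phi_s)\sin(\phi_t)]\hat{v}_\theta(z_{\phi_t}). \tag{34}
\end{align}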
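As a sanity check on this rewriting, the sketch below compares the DDIM update written in terms of $\hat{x}$ and $\hat{\epsilon}$ (Equation 32) with a rotation of $(z_{\phi_t}, \hat{v}_\theta)$ by the angle $\phi_s - \phi_t$, which is what the expanded form amounts to after applying the angle-difference identities. The scalar inputs and the function names are illustrative assumptions; in practice $z$ and the predictions are tensors and the same arithmetic applies elementwise.

```python
import math


def ddim_step_standard(z_t, x_hat, phi_t, phi_s):
    # Eq. 32: z_s = cos(phi_s) * x_hat + sin(phi_s) * eps_hat.
    eps_hat = (z_t - math.cos(phi_t) * x_hat) / math.sin(phi_t)
    return math.cos(phi_s) * x_hat + math.sin(phi_s) * eps_hat


def ddim_step_angular(z_t, x_hat, phi_t, phi_s):
    # Predicted velocity (Eq. 31), then a rotation by phi_s - phi_t.
    eps_hat = (z_t - math.cos(phi_t) * x_hat) / math.sin(phi_t)
    v_hat = math.cos(phi_t) * eps_hat - math.sin(phi_t) * x_hat
    delta = phi_s - phi_t
    return math.cos(delta) * z_t + math.sin(delta) * v_hat


# Hypothetical values: a noisy scalar, a denoiser output, and two angles in (0, pi/2).
z_t, x_hat = 0.8, -0.3
phi_t, phi_s = 1.2, 0.7
standard = ddim_step_standard(z_t, x_hat, phi_t, phi_s)
angular = ddim_step_angular(z_t, x_hat, phi_t, phi_s)
print(standard, angular)
```

Both functions should agree up to floating-point rounding, which reflects that the angular form is only a change of variables and not a different sampler.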
We use a batch size of 128 for CIFAR-10 and 2048 for the other data sets. We run our experiments
on TPUv4, using 8 TPU chips for CIFAR-10, and 64 chips for the other data sets. The total time
required to first train and then distill a model varies from about a day for CIFAR-10, to about 5 days
for ImageNet.
Our progressive distillation procedure was designed to be used with the DDIM sampler, but the
resulting distilled model could in principle also be used with a stochastic sampler. Here we evaluate a
distilled model for 64x64 ImageNet using the optimized stochastic sampler also used in Section 5.2.
The results are presented in Figure 6.
[Figure 6 plot: 64x64 ImageNet; FID (y-axis) versus sampling steps (x-axis, 1 to 256); curves: Distilled DDIM, Distilled Stochastic, Undistilled Stochastic.]
Figure 6: FID of generated samples from distilled and undistilled models, using DDIM or stochastic
sampling. For the stochastic sampling results we present the best FID obtained by a grid-search over
11 possible noise levels, spaced log-uniformly between the upper and lower bound on the variance
as derived by Ho et al. (2020). The performance of the distilled model with stochastic sampling
is found to lie in between the undistilled original model with stochastic sampling and the distilled
DDIM sampler: For small numbers of sampling steps the DDIM sampler performs better with the
distilled model, for large numbers of steps the stochastic sampler performs better.
The key difference between our progressive distillation algorithm proposed in Section 3 and the
standard diffusion training procedure is in how we determine the target value for our denoising
model. In standard diffusion training, the target for denoising is the clean data x. In progressive
distillation it is the value x̃ the student denoising model would need to predict in order to match the
teacher model when sampling. Here we derive what this target needs to be.
Using notation $t' = t - 0.5/N$ and $t'' = t - 1/N$, when training a student with $N$ sampling steps,
we have that the teacher model samples the next set of noisy data $z_{t''}$ given the current noisy data $z_t$
by taking two steps of DDIM. The student tries to sample the same value in only one step of DDIM.
Denoting the student denoising prediction by $\tilde{x}$, and its one-step sample by $\tilde{z}_{t''}$, application of the
DDIM sampler (see equation 8), gives:
$$\tilde{z}_{t''} = \alpha_{t''}\tilde{x} + \frac{\sigma_{t''}}{\sigma_t}(z_t - \alpha_t\tilde{x}). \tag{39}$$
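In order for the student sampler to match the teacher sampler, we must set $\tilde{z}_{t''}$ equal to $z_{t''}$. This
gives
\begin{align}
\tilde{z}_{t''} = \alpha_{t''}\tilde{x} + \frac{\sigma_{t''}}{\sigma_t}(z_t - \alpha_t\tilde{x}) &= z_{t''} \tag{40} \\
\left(\alpha_{t''} - \frac{\sigma_{t''}}{\sigma_t}\alpha_t\right)\tilde{x} + \frac{\sigma_{t''}}{\sigma_t}z_t &= z_{t''} \tag{41} \\
\left(\alpha_{t''} - \frac{\sigma_{t''}}{\sigma_t}\alpha_t\right)\tilde{x} &= z_{t''} - \frac{\sigma_{t''}}{\sigma_t}z_t \tag{42} \\
\tilde{x} &= \frac{z_{t''} - \frac{\sigma_{t''}}{\sigma_t}z_t}{\alpha_{t''} - \frac{\sigma_{t''}}{\sigma_t}\alpha_t} \tag{43}
\end{align}
In other words, if our student denoising model exactly predicts $\tilde{x}$ as defined in equation 43 above,
then the one-step student sample $\tilde{z}_{t''}$ is identical to the two-step teacher sample $z_{t''}$. In order to have
our student model approximate this ideal outcome, we thus train it to predict $\tilde{x}$ from $z_t$ as well as
possible, using the standard squared error denoising loss (see Equation 9).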
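The target computation can be sketched in a few lines. The example below is a minimal scalar illustration, not the training code: it assumes a cosine schedule with alpha_t = cos(pi t / 2) and sigma_t = sin(pi t / 2), and `toy_teacher` is a hypothetical stand-in for a trained denoiser. It runs two teacher DDIM steps, solves for the target as in equation 43, and confirms that a single student DDIM step using that target lands on the teacher's two-step result.

```python
import math


def alpha_sigma(t):
    # Assumed cosine schedule; any variance preserving schedule would work here.
    return math.cos(0.5 * math.pi * t), math.sin(0.5 * math.pi * t)


def ddim_step(z_t, x_pred, t, s):
    # Deterministic DDIM update from time t to time s given a denoised prediction.
    alpha_t, sigma_t = alpha_sigma(t)
    alpha_s, sigma_s = alpha_sigma(s)
    return alpha_s * x_pred + (sigma_s / sigma_t) * (z_t - alpha_t * x_pred)


def toy_teacher(z, t):
    # Hypothetical denoiser used only for illustration.
    return math.tanh(z) * (1.0 - 0.3 * t)


def distillation_target(z_t, t, num_student_steps, teacher):
    t_mid = t - 0.5 / num_student_steps  # t'
    t_end = t - 1.0 / num_student_steps  # t''
    # Two teacher DDIM steps: t -> t' -> t''.
    z_mid = ddim_step(z_t, teacher(z_t, t), t, t_mid)
    z_end = ddim_step(z_mid, teacher(z_mid, t_mid), t_mid, t_end)
    # Solve the one-step DDIM equation for the student target (equation 43).
    alpha_t, sigma_t = alpha_sigma(t)
    alpha_end, sigma_end = alpha_sigma(t_end)
    ratio = sigma_end / sigma_t
    x_target = (z_end - ratio * z_t) / (alpha_end - ratio * alpha_t)
    return x_target, z_end, t_end


# Hypothetical inputs: current noisy value, current time, student step count.
z_t, t, num_student_steps = 0.4, 0.75, 4
x_target, z_teacher, t_end = distillation_target(z_t, t, num_student_steps, toy_teacher)
z_student = ddim_step(z_t, x_target, t, t_end)
print(z_teacher, z_student)
```

In actual training the student network would be regressed toward `x_target` rather than handed it directly, so the match is only approximate in practice.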
Note that this possibility of matching the two-step teacher model with a one-step student model is
unique to deterministic samplers like DDIM: the composition of two standard stochastic DDPM
sampling steps (Equation 5) forms a non-Gaussian distribution that falls outside the family of Gaus-
sian distributions that can be modelled by a single DDPM student step: A multi-step stochastic
DDPM sampler can thus not be distilled into a few-step sampler without some loss in fidelity. This
is in contrast with the deterministic DDIM sampler: here both the two-step DDIM teacher update
and the one-step DDIM student update represent deterministic mappings implemented by a neural
net, which is why the student is able to accurately match the teacher.
Finally, note that we do lose something during the progressive distillation process: while the original
model was trained to denoise zt for any given continuous time t, the distilled student models are
only ever evaluated on a small discrete set of times t. The student models thus lose generality as
distillation progresses. At the same time, it’s this loss of generality that allows the student models
to free up enough modeling capacity to accurately match the teacher model without increasing their
model size.
In this section we present additional random samples from our diffusion models obtained through
progressive distillation. We show samples for distilled models taking 256, 4, and 1 sampling steps.
All samples are uncurated.
As explained in Section 3, our distilled samplers implement a deterministic mapping from input
noise to output samples (also see Appendix G). To facilitate comparison of this mapping for varying
numbers of sampling steps, we generate all samples using the same random input noise, and we
present the samples side-by-side. As these samples show, the mapping is mostly preserved when
moving from many steps to a single step: The same input noise is mapped to the same output image,
with a slight loss in image quality, as the number of steps is reduced. Since the mapping is preserved
while reducing the number of steps, our distilled models also preserve the excellent sample diversity
of diffusion models (see e.g. Kingma et al. (2021)).
(a) 256 sampling steps (b) 4 sampling steps (c) 1 sampling step
Figure 7: Random samples from our distilled CIFAR-10 models, for fixed random seed and for
varying number of sampling steps.
(a) 256 sampling steps (b) 4 sampling steps (c) 1 sampling step
Figure 8: Random samples from our distilled 64 × 64 ImageNet models, conditioned on the ‘coral
reef’ class, for fixed random seed and for varying number of sampling steps.
(a) 256 sampling steps (b) 4 sampling steps (c) 1 sampling step
Figure 9: Random samples from our distilled 64 × 64 ImageNet models, conditioned on the ‘sports
car’ class, for fixed random seed and for varying number of sampling steps.
(a) 256 sampling steps (b) 4 sampling steps (c) 1 sampling step
Figure 10: Random samples from our distilled LSUN bedrooms models, for fixed random seed and
for varying number of sampling steps.
(a) 256 sampling steps (b) 4 sampling steps (c) 1 sampling step
Figure 11: Random samples from our distilled LSUN church-outdoor models, for fixed random seed
and for varying number of sampling steps.
[Figure 12 plot: CIFAR-10; FID (y-axis) versus sampling steps (x-axis, 1 to 512); curves: Distilled 50k, Distilled 25k, Distilled 10k, Distilled 5k, Distilled 4×.]
Figure 12: Comparing our proposed schedule for progressive distillation taking 50k parameter up-
dates to train a new student every time the number of steps is halved, versus fast sampling schedules
taking fewer parameter updates (25k, 10k, 5k), and a fast schedule dividing the number of steps by
4 for every new student instead of by 2. All reported numbers are averages over 4 random seeds.
For each schedule we selected the optimal learning rate from [5e-5, 1e-4, 2e-4, 3e-4].
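The training cost of these schedules follows from simple counting: the number of distillation stages times the parameter updates per stage. The sketch below illustrates the arithmetic; the starting point of 1024 steps is a hypothetical value chosen so that both halving and dividing by 4 reach a single step exactly, and the function names are not from the paper.

```python
def num_stages(start_steps, end_steps, factor):
    # Count how many students are trained when the number of sampling steps
    # is divided by `factor` at every stage.
    stages = 0
    steps = start_steps
    while steps > end_steps:
        if steps % factor != 0:
            raise ValueError("step count is not divisible by the reduction factor")
        steps //= factor
        stages += 1
    return stages


def total_updates(start_steps, end_steps, factor, updates_per_stage):
    return num_stages(start_steps, end_steps, factor) * updates_per_stage


# Hypothetical example: distilling from 1024 sampling steps down to 1.
start_steps, end_steps = 1024, 1
for updates_per_stage in (50_000, 25_000, 10_000, 5_000):
    total = total_updates(start_steps, end_steps, 2, updates_per_stage)
    print(f"halving, {updates_per_stage} updates per student: {total} updates in total")
print("stages when dividing by 4:", num_stages(start_steps, end_steps, 4))
```

Under these assumptions, halving needs 10 students and dividing by 4 needs 5, so the faster schedules trade total training cost against the quality differences shown in the figure.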
[Figure 13 plots: three panels, each showing FID (y-axis) versus sampling steps (x-axis, 1 to 256).]
Figure 13: Comparing our proposed schedule for progressive distillation taking 50k parameter up-
dates to train a new student every time the number of steps is halved, versus a fast sampling schedule
taking 10k parameter updates. For each reported number of steps we selected the optimal learning
rate from [5e-5, 1e-4, 2e-4, 3e-4]. Results are for a single random seed.