Fundamentals
For any distribution over $(X,Y)$, we have
$$
\operatorname{argmin}_{f}\ \mathbb{E}\lVert f(X)-Y\rVert^2=\mathbb{E}[Y\mid X]
$$
$$
\begin{aligned}
\operatorname{argmin}_{f}\ \mathbb{E}\lVert f(X)-Y\rVert^2
&=\operatorname{argmin}_{f}\ \mathbb{E}_{X}\,\mathbb{E}_{Y\mid X}\left[\lVert f(x)-Y\rVert^2 \mid X=x\right]\\
&=\operatorname{argmin}_{f}\ \mathbb{E}_{X}\left[\lVert f(x)\rVert^2-2\langle f(x),\mathbb{E}[Y\mid X=x]\rangle+\mathbb{E}[\lVert Y\rVert^2\mid X=x]\right]
\end{aligned}
$$
Because $f(x)$ can be chosen independently for each $x$, the outer expectation is minimized by minimizing the bracket pointwise in $x$. Differentiating the bracket with respect to $f(x)$ gives
$$
2f(x)-2\,\mathbb{E}[Y\mid X=x]=0 \iff f(x)=\mathbb{E}[Y\mid X=x]
$$
Therefore $f(x)=\mathbb{E}[Y\mid X=x]$, i.e. $f=\mathbb{E}[Y\mid X]$.
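A quick numerical sanity check of this result, as a minimal sketch under an assumed toy distribution (not from the source): $X$ uniform on $\{0,1,2\}$ and $Y=X^2+\mathcal{N}(0,1)$, so $\mathbb{E}[Y\mid X=x]=x^2$. The conditional mean should reach a lower mean squared error than a perturbed predictor or the best constant; its MSE should be close to the noise variance 1.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 100_000
x = rng.integers(0, 3, n)                   # X uniform on {0, 1, 2} (assumed toy setup)
y = x**2 + rng.normal(0.0, 1.0, n)          # Y = X^2 + N(0, 1), so E[Y | X = x] = x^2

def mse(f):
    return np.mean((f(x) - y) ** 2)

def cond_mean(xv):
    return xv.astype(float) ** 2            # the regression function E[Y | X]

def perturbed(xv):
    return cond_mean(xv) + 0.3 * (xv == 1)  # shift the prediction at X = 1 only

def constant(xv):
    return np.full(xv.shape, y.mean())      # best constant predictor

print(mse(cond_mean), mse(perturbed), mse(constant))
```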
Gradient Formula of Gaussian Distribution
$$
\begin{aligned}
&y\sim\mathcal{N}(x,\sigma^2 I)\text{ i.e. }p(y\mid x)=\frac{1}{(2\pi\sigma^2)^{d/2}}\exp\left(-\frac{1}{2\sigma^2}\lVert y-x\rVert^2\right)\\
\implies&\nabla_y p(y\mid x)=-\frac{1}{\sigma^2}(y-x)\,p(y\mid x)
\end{aligned}
$$
Since $p(y)=\int p(y\mid x)\,p(x)\,\mathrm{d}x$, differentiating under the integral gives
$$
\nabla_y p(y)=\int\nabla_y p(y\mid x)\,p(x)\,\mathrm{d}x
$$
$$
\begin{aligned}
\nabla_y p(y\mid x)&=p(y\mid x)\cdot\nabla_y\log p(y\mid x)\\
&=p(y\mid x)\cdot\nabla_y\left(-\frac{d}{2}\log(2\pi\sigma^2)-\frac{1}{2\sigma^2}\lVert y-x\rVert^2\right)\\
&=p(y\mid x)\cdot\nabla_y\left(-\frac{1}{2\sigma^2}\lVert y-x\rVert^2\right)\\
&=p(y\mid x)\cdot\left(-\frac{1}{\sigma^2}(y-x)\right)
\end{aligned}
$$
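The gradient formula can be checked against a central finite difference. This is a minimal sketch with arbitrarily chosen (assumed) values of $x$, $y$ and $\sigma$ in $d=3$ dimensions.

```python
import numpy as np

def p_cond(y, x, sigma):
    # density of N(x, sigma^2 I) evaluated at y (d = len(y))
    d = y.size
    return np.exp(-np.sum((y - x) ** 2) / (2 * sigma**2)) / (2 * np.pi * sigma**2) ** (d / 2)

def grad_fd(y, x, sigma, h=1e-6):
    # central finite difference of p(y|x) with respect to each coordinate of y
    g = np.zeros_like(y)
    for i in range(y.size):
        e = np.zeros_like(y)
        e[i] = h
        g[i] = (p_cond(y + e, x, sigma) - p_cond(y - e, x, sigma)) / (2 * h)
    return g

x = np.array([0.3, -1.0, 2.0])
y = np.array([0.5, -0.2, 1.4])
sigma = 0.8
analytic = -(y - x) / sigma**2 * p_cond(y, x, sigma)   # the closed-form gradient above
print(analytic, grad_fd(y, x, sigma))
```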
If $Y\mid X=x\sim\mathcal{N}(x,\sigma^2 I)$, then (Tweedie's formula)
$$
\mathbb{E}[X\mid Y=y]=y+\sigma^2\nabla_y\log p(y)
$$
where
$p(y)$: the marginal density of $Y$ (the distribution of the observed $Y$)
$\nabla_y\log p(y)$: the gradient of the log-density of $Y$ with respect to $y$, a.k.a. the score function
$$
p(x\mid y)=\frac{p(y\mid x)\,p(x)}{p(y)}
$$
Since $Y\mid X=x\sim\mathcal{N}(x,\sigma^2 I)$,
$$
p(y\mid x)=\frac{1}{(2\pi\sigma^2)^{d/2}}\exp\left(-\frac{1}{2\sigma^2}\lVert y-x\rVert^2\right)
$$
$$
\begin{aligned}
\mathbb{E}[X\mid Y=y]&=\int x\,p(x\mid y)\,\mathrm{d}x\\
&=\frac{1}{p(y)}\int x\,p(y\mid x)\,p(x)\,\mathrm{d}x
\end{aligned}
$$
$$
\nabla_y p(y)=\int\nabla_y p(y\mid x)\,p(x)\,\mathrm{d}x
$$
By the Gaussian gradient formula,
$$
\nabla_y p(y\mid x)=-\frac{1}{\sigma^2}(y-x)\,p(y\mid x)
$$
Substituting, we get
$$
\begin{aligned}
\nabla_y p(y)&=\int\left(-\frac{1}{\sigma^2}(y-x)\,p(y\mid x)\right)p(x)\,\mathrm{d}x\\
&=-\frac{1}{\sigma^2}\left(y\int p(y\mid x)p(x)\,\mathrm{d}x-\int x\,p(y\mid x)p(x)\,\mathrm{d}x\right)\\
&=-\frac{1}{\sigma^2}\left(y\,p(y)-\int x\,p(y\mid x)p(x)\,\mathrm{d}x\right)
\end{aligned}
$$
Rearranging,
$$
y\,p(y)+\sigma^2\nabla_y p(y)=\int x\,p(y\mid x)p(x)\,\mathrm{d}x
$$
Therefore, for $\mathbb{E}[X\mid Y=y]$,
$$
\mathbb{E}[X\mid Y=y]=\int x\frac{p(y\mid x)p(x)}{p(y)}\,\mathrm{d}x=\frac{1}{p(y)}\int x\,p(y\mid x)p(x)\,\mathrm{d}x=\frac{y\,p(y)+\sigma^2\nabla_y p(y)}{p(y)}=y+\sigma^2\nabla_y\log p(y)
$$
using $\nabla_y\log p(y)=\nabla_y p(y)/p(y)$ in the last step.
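A minimal numerical check of Tweedie's formula, assuming a hypothetical discrete prior for $X$ (three atoms with chosen weights) and 1-D Gaussian noise. The posterior mean is computed directly with Bayes' rule and compared with $y+\sigma^2\,\frac{\mathrm{d}}{\mathrm{d}y}\log p(y)$, where the derivative is taken by central finite differences.

```python
import numpy as np

def gauss(y, mean, var):
    return np.exp(-(y - mean) ** 2 / (2 * var)) / np.sqrt(2 * np.pi * var)

# Hypothetical discrete prior: X takes values xs with probabilities ws
xs = np.array([-1.0, 0.5, 2.0])
ws = np.array([0.2, 0.5, 0.3])
sigma2 = 0.4                                   # noise variance of Y | X = x

def p_y(y):
    # marginal p(y) = sum_x p(y|x) p(x)
    return np.sum(ws * gauss(y, xs, sigma2))

def posterior_mean(y):
    # E[X | Y = y] via Bayes' rule
    lik = ws * gauss(y, xs, sigma2)
    return np.sum(xs * lik) / np.sum(lik)

def tweedie(y, h=1e-5):
    # y + sigma^2 * d/dy log p(y), derivative by central difference
    score = (np.log(p_y(y + h)) - np.log(p_y(y - h))) / (2 * h)
    return y + sigma2 * score

for y in [-2.0, 0.0, 0.7, 3.0]:
    print(y, posterior_mean(y), tweedie(y))
```

The formula needs only the score of the noisy marginal, not the prior itself, which is why learning $\nabla\log p_t$ is enough to denoise.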
Gaussian Diffusion
$$
x_{t+1}:=x_t+\eta_t,\quad\eta_t\sim\mathcal{N}(0,\sigma^2)
$$
The goal is to learn to reverse each intermediate step.
At time $t$, given an input $z$ (a sample from $p_t$), output a sample from the conditional distribution $p(x_{t-1}\mid x_t=z)$.
Learning the mean of $p(x_{t-1}\mid x_t)$ is much simpler:
$$
\mu_{t-1}(z):=\mathbb{E}[x_{t-1}\mid x_t=z]
$$
$$
\begin{aligned}
\implies\mu_{t-1}&=\operatorname{argmin}_{f:\mathbb{R}^d\to\mathbb{R}^d}\ \mathbb{E}_{x_t,x_{t-1}}\lVert f(x_t)-x_{t-1}\rVert^2\\
&=\operatorname{argmin}_{f}\ \mathbb{E}_{x_{t-1},\eta}\lVert f(x_{t-1}+\eta)-x_{t-1}\rVert^2
\end{aligned}
$$
So $\mathbb{E}[x_{t-1}\mid x_t]$ can be estimated with a standard (squared-error) regression loss.
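A minimal sketch of this regression view, under an assumed toy case not taken from the source: $x_{t-1}\sim\mathcal{N}(0,s^2)$ and $x_t=x_{t-1}+\eta$ with $\eta\sim\mathcal{N}(0,\sigma^2)$. Here $\mathbb{E}[x_{t-1}\mid x_t]=\frac{s^2}{s^2+\sigma^2}x_t$ is linear, so a least-squares line fit of $x_{t-1}$ on $x_t$ should recover that slope.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 200_000
s2, sigma2 = 1.0, 0.25                                 # assumed prior variance and step-noise variance
x_prev = rng.normal(0.0, np.sqrt(s2), n)               # x_{t-1} ~ N(0, s2)
x_t = x_prev + rng.normal(0.0, np.sqrt(sigma2), n)     # x_t = x_{t-1} + eta

# Squared-loss regression of x_{t-1} on x_t with a linear model f(z) = a z + b
a, b = np.polyfit(x_t, x_prev, deg=1)                  # coefficients, highest degree first
print(a, b, s2 / (s2 + sigma2))
```

In practice $f$ is a neural network and the data are not Gaussian; the linear model here only works because the toy conditional mean happens to be linear.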
A reverse sampler for step $t$ is a function $F_t$ such that if $x_t\sim p_t$, then the marginal distribution of $F_t(x_t)$ is $p_{t-1}$:
$$
\{F_t(z): z\sim p_t\}\equiv p_{t-1}
$$
The notation $\{F(x): x\sim D\}$ means applying a function $F$ to a random variable $x$ that follows distribution $D$, which yields a new distribution.
Variance scaling:
$p(x,k\Delta t)=p_k(x)$, where $\Delta t=\frac{1}{T}$ and $T$ is the number of discretization steps.
If $x_k=x_{k-1}+\mathcal{N}(0,\sigma^2)$, then $x_T\sim\mathcal{N}(x_0,T\sigma^2)$. So we scale the per-step standard deviation as $\sigma=\sigma_q\sqrt{\Delta t}$, so that $\sigma_q^2=T\sigma^2$ is the desired terminal variance.
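A quick simulation of the variance scaling, as a sketch with assumed values ($\sigma_q=2$, $T=100$, all paths starting at $x_0=0$): with per-step standard deviation $\sigma_q\sqrt{\Delta t}$, the empirical variance of $x_T$ should come out close to $\sigma_q^2=4$ regardless of $T$.

```python
import numpy as np

rng = np.random.default_rng(0)
sigma_q, T = 2.0, 100            # assumed terminal std and number of steps
dt = 1.0 / T
sigma = sigma_q * np.sqrt(dt)    # per-step noise std

x = np.zeros(50_000)             # x_0 = 0 for every sample path
for _ in range(T):
    x = x + rng.normal(0.0, sigma, x.shape)
print(x.var(), sigma_q**2)
```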
Notation: below, $t$ represents a continuous value in the interval $[0,1]$, and subscripts indicate time rather than an index.
DDPM: Stochastic Sampling
DDPM stands for Denoising Diffusion Probabilistic Models
Stochastic Reverse Sampler
For input $x_t$ and timestep $t$, output
$$
\hat{x}_{t-\Delta t}\leftarrow\mu_{t-\Delta t}(x_t)+\mathcal{N}(0,\sigma_q^2\Delta t)
$$
For small $\Delta t$, there exists $\mu_z$ such that
$$
p(x_{t-\Delta t}\mid x_t=z)\approx\mathcal{N}(x_{t-\Delta t};\,\mu_z,\,\sigma_q^2\Delta t)
$$
Since the constant $\mu_z$ depends only on $z$, we can take
$$
\begin{aligned}
\mu_z&:=\mathbb{E}_{x_{t-\Delta t},x_t}[x_{t-\Delta t}\mid x_t=z]\\
&=z+(\sigma_q^2\Delta t)\,\nabla\log p_t(z)
\end{aligned}
$$
By Bayes' rule:
$$
p(x_{t-\Delta t}\mid x_t)=\frac{p(x_t\mid x_{t-\Delta t})\,p_{t-\Delta t}(x_{t-\Delta t})}{p_t(x_t)}
$$
Taking the log of both sides (here $C$ collects terms that do not depend on $x_{t-\Delta t}$):
$$
\begin{aligned}
&\log p(x_{t-\Delta t}\mid x_t)\\
=&\log p(x_t\mid x_{t-\Delta t})+\log p_{t-\Delta t}(x_{t-\Delta t})\cancel{-\log p_t(x_t)} &&\text{drop terms not involving }x_{t-\Delta t}\\
=&\log p(x_t\mid x_{t-\Delta t})+\log p_t(x_{t-\Delta t})+\mathcal{O}(\Delta t) &&\text{since }p_{t-\Delta t}=p_t-\Delta t\,\tfrac{\partial}{\partial t}p_t+\mathcal{O}(\Delta t^2)\\
=&-\frac{1}{2\sigma_q^2\Delta t}\lVert x_{t-\Delta t}-x_t\rVert^2+\log p_t(x_{t-\Delta t})+C+\mathcal{O}(\Delta t) &&\text{substitute }p(x_t\mid x_{t-\Delta t})=\mathcal{N}(x_t;\,x_{t-\Delta t},\sigma_q^2\Delta t)\\
=&-\frac{1}{2\sigma_q^2\Delta t}\lVert x_{t-\Delta t}-x_t\rVert^2+\cancel{\log p_t(x_t)}+\langle\nabla_x\log p_t(x_t),\,x_{t-\Delta t}-x_t\rangle+C+\mathcal{O}(\Delta t) &&\text{first-order Taylor expansion; }\langle\cdot,\cdot\rangle\text{ is the inner product}\\
=&-\frac{1}{2\sigma_q^2\Delta t}\lVert x_{t-\Delta t}-x_t-\sigma_q^2\Delta t\,\nabla_x\log p_t(x_t)\rVert^2+C+\mathcal{O}(\Delta t) &&\text{complete the square}\\
=&-\frac{1}{2\sigma_q^2\Delta t}\lVert x_{t-\Delta t}-\mu\rVert^2+C+\mathcal{O}(\Delta t),\quad\mu:=x_t+\sigma_q^2\Delta t\,\nabla_x\log p_t(x_t)
\end{aligned}
$$
Up to a constant, this is the log-density of $\mathcal{N}(x_{t-\Delta t};\,\mu,\,\sigma_q^2\Delta t)$ with $\mu=x_t+\sigma_q^2\Delta t\,\nabla_x\log p_t(x_t)$, which matches $\mu_z$ above.
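The DDPM training loss:
```python
import numpy as np

rng = np.random.default_rng()

def train_loss(f_theta, p, sigma_q, dt):
    x0 = p.sample()                                               # x_0 ~ p (data)
    t = rng.uniform(0, 1)
    x = x0 + rng.normal(0, sigma_q * np.sqrt(t), np.shape(x0))    # x_t ~ N(x_0, sigma_q^2 t); normal() takes a std
    x_ = x + rng.normal(0, sigma_q * np.sqrt(dt), np.shape(x0))   # x_{t+dt} = x_t + N(0, sigma_q^2 dt)
    return np.sum((f_theta(x_, t + dt) - x) ** 2)                 # regress x_t on x_{t+dt}
```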
Sampling:
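```python
import numpy as np

rng = np.random.default_rng()

def DDPM(f_theta, sigma_q, dt, shape=()):
    x = rng.normal(0, sigma_q, shape)                          # x_1 ~ N(0, sigma_q^2)
    for k in range(round(1 / dt), 0, -1):                      # t = 1, 1 - dt, ..., dt
        t = k * dt
        eta = rng.normal(0, sigma_q * np.sqrt(dt), shape)
        x = f_theta(x, t) + eta                                # x_{t-dt} = mu_{t-dt}(x_t) + noise
    return x
```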
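```python
import numpy as np

rng = np.random.default_rng()

def DDIM(f_theta, sigma_q, dt, shape=()):
    x = rng.normal(0, sigma_q, shape)                          # x_1 ~ N(0, sigma_q^2)
    for k in range(round(1 / dt), 0, -1):                      # t = 1, 1 - dt, ..., dt
        t = k * dt
        weight = np.sqrt(t) / (np.sqrt(t - dt) + np.sqrt(t))   # lambda = sigma_t / (sigma_{t-dt} + sigma_t)
        x = x + weight * (f_theta(x, t) - x)                   # deterministic step toward mu_{t-dt}(x_t)
    return x
```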
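To see both samplers working end to end, here is a toy sketch under assumed conditions (not from the source): the data distribution is $\mathcal{N}(m,s^2)$, so $x_t\sim\mathcal{N}(m,v_t)$ with $v_t=s^2+\sigma_q^2 t$ and the exact $\mu_{t-\Delta t}(z)=m+\frac{v_{t-\Delta t}}{v_t}(z-m)$ stands in for a trained `f_theta`. Unlike the samplers above, this sketch starts from the exact $p_1$ rather than $\mathcal{N}(0,\sigma_q^2)$, so that only the reverse steps are being tested; both outputs should have mean near $m=0$ and standard deviation near $s=0.5$.

```python
import numpy as np

rng = np.random.default_rng(0)
m, s, sigma_q, dt = 0.0, 0.5, 1.0, 2e-3   # assumed toy setup: data ~ N(m, s^2)
T = round(1 / dt)

def v(t):
    # variance of x_t when x_0 ~ N(m, s^2) and x_t = x_0 + N(0, sigma_q^2 t)
    return s**2 + sigma_q**2 * t

def mu(z, t):
    # exact E[x_{t-dt} | x_t = z] for Gaussian data (stands in for a trained f_theta)
    return m + v(t - dt) / v(t) * (z - m)

def ddpm(n):
    x = rng.normal(m, np.sqrt(v(1.0)), n)     # start from the exact p_1
    for k in range(T, 0, -1):
        t = k * dt
        x = mu(x, t) + rng.normal(0, sigma_q * np.sqrt(dt), n)
    return x

def ddim(n):
    x = rng.normal(m, np.sqrt(v(1.0)), n)
    for k in range(T, 0, -1):
        t = k * dt
        lam = np.sqrt(t) / (np.sqrt(t - dt) + np.sqrt(t))
        x = x + lam * (mu(x, t) - x)
    return x

x_ddpm, x_ddim = ddpm(10_000), ddim(10_000)
print(x_ddpm.mean(), x_ddpm.std(), x_ddim.mean(), x_ddim.std())
```

DDPM reaches the target by adding fresh noise at each step, while DDIM applies a deterministic contraction that shrinks the spread by roughly $\sqrt{v_{t-\Delta t}/v_t}$ per step.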
In the Gaussian diffusion setting, given $x_0$ and $x_t$, the intermediate point $x_{t-\Delta t}$ follows a Brownian bridge with mean $x_0+\frac{t-\Delta t}{t}(x_t-x_0)$; averaging over $x_0\mid x_t$ gives
$$
\mathbb{E}[x_{t-\Delta t}\mid x_t]=\frac{\Delta t}{t}\mathbb{E}[x_0\mid x_t]+\left(1-\frac{\Delta t}{t}\right)x_t
$$
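This identity can be checked in closed form for an assumed Gaussian data distribution $x_0\sim\mathcal{N}(m,s^2)$ (a sketch; the values are arbitrary). Both conditional means then follow from Gaussian conditioning, using $\mathrm{Var}(x_t)=s^2+\sigma_q^2 t$ and $\mathrm{Cov}(x_{t-\Delta t},x_t)=\mathrm{Var}(x_{t-\Delta t})$.

```python
import numpy as np

# Assumed toy setup: x_0 ~ N(m, s^2), x_t = x_0 + N(0, sigma_q^2 t)
m, s, sigma_q = 1.5, 0.7, 2.0
t, dt = 0.6, 0.05

def var_at(time):
    return s**2 + sigma_q**2 * time          # Var(x_time)

def post_mean_x0(z):
    # E[x_0 | x_t = z]; Cov(x_0, x_t) = s^2
    return m + s**2 / var_at(t) * (z - m)

def post_mean_prev(z):
    # E[x_{t-dt} | x_t = z]; Cov(x_{t-dt}, x_t) = Var(x_{t-dt})
    return m + var_at(t - dt) / var_at(t) * (z - m)

z = np.linspace(-3, 5, 9)
rhs = dt / t * post_mean_x0(z) + (1 - dt / t) * z
print(np.max(np.abs(post_mean_prev(z) - rhs)))   # zero up to floating-point error
```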
DDIM: Deterministic Sampling
Deterministic Reverse Sampler
Given the joint distribution of $(x_0,x_{\Delta t},\dots,x_1)$ and the conditional expectation $\mu_t(z):=\mathbb{E}[x_t\mid x_{t+\Delta t}=z]$, the DDIM sampler takes a sample $x_t$ and step index $t$ as input and outputs
$$
\hat{x}_{t-\Delta t}\leftarrow x_t+\lambda\left(\mu_{t-\Delta t}(x_t)-x_t\right),\quad\lambda:=\frac{\sigma_t}{\sigma_{t-\Delta t}+\sigma_t},\quad\sigma_t\equiv\sigma_q\sqrt{t}
$$
Flow Matching
A flow is a collection of time-indexed vector fields $v=\{v_t\}_{t\in[0,1]}$, where $v_t$ can be pictured as the velocity field of a gas at time $t$.
For a flow $v$ and an initial point $x_1$, the trajectory $x_t$ satisfies $\dfrac{\mathrm{d}x_t}{\mathrm{d}t}=-v_t(x_t)$ (the flow runs from $t=1$ down to $t=0$).
The Goal of Flow Matching
Learn a flow $v^*$ that transports $q$ to $p$, where $p$ is the target distribution and $q$ is some easy-to-sample base distribution (e.g. a Gaussian).
The DDIM algorithm is a special case of this.
A pointwise flow $v^{[x_1,x_0]}$ is a flow $\{v_t\}_t$ that satisfies $\dfrac{\mathrm{d}x_t}{\mathrm{d}t}=-v_t(x_t)$ with boundary conditions $x_1$ at $t=1$ and $x_0$ at $t=0$, i.e. it transports the single point $x_1$ to $x_0$.
Take the weighted average of all individual particle velocities $v_t^{[x_1,x_0]}$:
$$
\mathbb{E}_{x_0,x_1\mid x_t}\left[v_t^{[x_1,x_0]}(x_t)\mid x_t\right]
$$
Here $(x_1,x_0,x_t)$ is induced by sampling $(x_1,x_0)\sim\Pi_{q,p}$, a coupling of $q$ and $p$, and setting $x_t\leftarrow\text{RunFlow}(v^{[x_1,x_0]},x_1,t)$.
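`RunFlow` is not spelled out here; one natural reading is numerical integration of the ODE. The sketch below is a hypothetical `run_flow` that uses Euler steps from $t=1$ down to a target time, applied to an assumed straight-line pointwise flow $v_t^{[x_1,x_0]}(x)=x_0-x_1$, whose trajectory is $x_t=t\,x_1+(1-t)\,x_0$.

```python
import numpy as np

def run_flow(v, x1, t_end, n_steps=1000):
    # Hypothetical RunFlow: Euler-integrate dx_t/dt = -v(x_t, t) from t = 1 down to t = t_end
    x = np.asarray(x1, dtype=float)
    ts = np.linspace(1.0, t_end, n_steps + 1)
    for t_cur, t_next in zip(ts[:-1], ts[1:]):
        h = t_cur - t_next               # positive step, since time decreases
        x = x + h * v(x, t_cur)          # x_{t-h} ≈ x_t + h * v_t(x_t)
    return x

# Assumed example: straight-line pointwise flow with constant velocity x0 - x1
x1 = np.array([1.0, -2.0])
x0 = np.array([3.0, 0.5])

def v_line(x, t):
    return x0 - x1

print(run_flow(v_line, x1, 0.0))     # ends at x0
print(run_flow(v_line, x1, 0.25))    # 0.25 * x1 + 0.75 * x0
```

For a constant velocity the Euler scheme is exact; for curved pointwise flows a smaller step or a higher-order integrator would be needed.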
Flow Matching:
$$
\begin{aligned}
&v_t^*(x_t):=\mathbb{E}_{x_0,x_1\mid x_t}\left[v_t^{[x_1,x_0]}(x_t)\mid x_t\right]\\
\implies&v_t^*=\operatorname{argmin}_{f:\mathbb{R}^d\to\mathbb{R}^d}\ \mathbb{E}_{(x_0,x_1,x_t)}\left\lVert f(x_t)-v_t^{[x_1,x_0]}(x_t)\right\rVert^2
\end{aligned}
$$
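Flow-matching training loss:
```python
import numpy as np

def train_loss(f_theta, q_p_dist, pointwise_flow, run_flow):
    x1, x0 = q_p_dist.sample()                  # (x1, x0) ~ Pi_{q,p}
    t = np.random.uniform(0, 1)
    v = pointwise_flow(x1, x0)                  # v(x, t) evaluates v_t^{[x1, x0]}(x)
    xt = run_flow(v, x1, t)                     # follow the pointwise flow from x1 to time t
    return np.sum((f_theta(xt, t) - v(xt, t)) ** 2)
```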
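```python
def sample(f_theta, base_dist, step_size):
    x = base_dist.sample()                      # x_1 ~ q
    for k in range(round(1 / step_size), 0, -1):
        t = k * step_size
        x = x + f_theta(x, t) * step_size       # Euler step of dx_t/dt = -v_t(x_t), from t to t - step_size
    return x                                    # approximately a sample from p
```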
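A toy end-to-end sketch, under assumptions not in the source: base $q=\mathcal{N}(0,1)$, target $p=\mathcal{N}(2,0.5^2)$, independent coupling, and the straight-line pointwise flow $v^{[x_1,x_0]}_t=x_0-x_1$ with $x_t=t\,x_1+(1-t)\,x_0$. In this Gaussian case $v_t^*(x)=\mathbb{E}[x_0-x_1\mid x_t=x]$ has a closed form, which stands in for a trained `f_theta`. The sketch shows (1) that squared-error regression on $(x_t,\,x_0-x_1)$ pairs recovers $v_t^*$, and (2) that Euler sampling with $v^*$ turns samples of $q$ into samples with roughly the mean and spread of $p$.

```python
import numpy as np

rng = np.random.default_rng(0)
m, s = 2.0, 0.5          # assumed toy target p = N(m, s^2); base q = N(0, 1)
n = 20_000

def v_star(x, t):
    # E[x0 - x1 | x_t = x] for independent x1 ~ q, x0 ~ p and x_t = t*x1 + (1-t)*x0
    cov = (1 - t) * s**2 - t                 # Cov(x0 - x1, x_t)
    var = t**2 + (1 - t) ** 2 * s**2         # Var(x_t)
    return m + cov / var * (x - (1 - t) * m)

# 1) Regression target: fit a line to (x_t, x0 - x1) pairs at t = 0.5
t = 0.5
x1 = rng.normal(0.0, 1.0, n)
x0 = rng.normal(m, s, n)
xt = t * x1 + (1 - t) * x0
slope, intercept = np.polyfit(xt, x0 - x1, deg=1)
print(slope, intercept, v_star(1.0, t) - v_star(0.0, t), v_star(0.0, t))

# 2) Euler sampling with the (here exact) marginal velocity
def sample(v, n, step_size=1e-3):
    x = rng.normal(0.0, 1.0, n)                       # x_1 ~ q
    for k in range(round(1 / step_size), 0, -1):
        x = x + v(x, k * step_size) * step_size       # x_{t-h} = x_t + h * v_t(x_t)
    return x

samples = sample(v_star, n)
print(samples.mean(), samples.std())                  # should be near m and s
```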
VAE
VAE: Variational Auto-Encoder
Latent variables $z$ are variables that we do not observe and hence are not part of the training dataset.
Encoder: converts an input $x$ into latent variables $z$.
Decoder: converts $z$ into a generated vector $\hat{x}$.
Variational: as in the calculus of variations, i.e. optimization over functions.
A VAE searches for the optimal probability distributions to describe $x$ and $z$.
$p(x)$: the true distribution of $x$. THE ULTIMATE GOAL of diffusion is to draw a sample from $p(x)$.
$p(z)$: the distribution of the latent variable, typically chosen to be $\mathcal{N}(0,I)$. Any distribution can be generated by mapping a Gaussian through a sufficiently complicated function.
$p(z\mid x)$: the conditional distribution associated with the encoder, i.e. how likely $z$ is given $x$ (the posterior over the latent).
$p(x\mid z)$: the conditional distribution associated with the decoder, i.e. the probability of getting $x$ given $z$ (the likelihood).
$q_\phi(z\mid x)$: the proxy for $p(z\mid x)$, which can be parameterized with deep neural networks, e.g.
$$
(\mu,\sigma^2)=\text{EncoderNetwork}_\phi(x),\quad q_\phi(z\mid x)=\mathcal{N}(z\mid\mu,\operatorname{diag}(\sigma^2))
$$
$p_\theta(x\mid z)$: the proxy for $p(x\mid z)$
![[Pasted image 20250503161111.png]]
ELBO
ELBO: Evidence Lower Bound
$$
\text{ELBO}(x)\stackrel{\text{def}}{=}\mathbb{E}_{q_\phi(z\mid x)}\left[\log\frac{p(x,z)}{q_\phi(z\mid x)}\right]
$$
$$
\mathbb{D}_{\text{KL}}(P\,\|\,Q)=\mathbb{E}_{x\sim P}\left[\ln\frac{p(x)}{q(x)}\right]
$$
Decomposition of Log-Likelihood
$$
\log p(x)=\mathbb{E}_{q_\phi(z\mid x)}\left[\log\frac{p(x,z)}{q_\phi(z\mid x)}\right]+\mathbb{D}_{\text{KL}}\big(q_\phi(z\mid x)\,\|\,p(z\mid x)\big)
$$
$$
\begin{aligned}
\log p(x)&=\log p(x)\times\underbrace{\int q_\phi(z\mid x)\,\mathrm{d}z}_{=1}\\
&=\int\log p(x)\,q_\phi(z\mid x)\,\mathrm{d}z\\
&=\mathbb{E}_{q_\phi(z\mid x)}[\log p(x)]\\
&=\mathbb{E}_{q_\phi(z\mid x)}\left[\log\frac{p(x,z)}{p(z\mid x)}\right] &&\text{Bayes' theorem}\\
&=\mathbb{E}_{q_\phi(z\mid x)}\left[\log\frac{p(x,z)}{p(z\mid x)}\cdot\frac{q_\phi(z\mid x)}{q_\phi(z\mid x)}\right]\\
&=\underbrace{\mathbb{E}_{q_\phi(z\mid x)}\left[\log\frac{p(x,z)}{q_\phi(z\mid x)}\right]}_{\text{ELBO}}+\underbrace{\mathbb{E}_{q_\phi(z\mid x)}\left[\log\frac{q_\phi(z\mid x)}{p(z\mid x)}\right]}_{\mathbb{D}_{\text{KL}}(q_\phi(z\mid x)\,\|\,p(z\mid x))}
\end{aligned}
$$
Since the KL divergence is non-negative, the ELBO is a lower bound on $\log p(x)$, so maximizing the ELBO pushes up $\log p(x)$.
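The decomposition can be checked exactly on a small discrete-latent model. This is a sketch with an assumed toy model (three latent values, unit-variance Gaussian decoder) and an arbitrary $q_\phi(z\mid x)$: ELBO plus KL should equal $\log p(x)$, the ELBO should stay below it, and choosing $q$ equal to the true posterior should close the gap.

```python
import numpy as np

# Toy latent-variable model (assumed): z in {0, 1, 2}, x | z ~ N(means[z], 1)
prior = np.array([0.5, 0.3, 0.2])                   # p(z)
means = np.array([-2.0, 0.0, 3.0])

def log_normal(x, mean, var=1.0):
    return -0.5 * np.log(2 * np.pi * var) - (x - mean) ** 2 / (2 * var)

x = 0.7
log_joint = np.log(prior) + log_normal(x, means)    # log p(x, z) for each z
log_px = np.log(np.sum(np.exp(log_joint)))          # log p(x)
log_post = log_joint - log_px                       # log p(z | x)

q = np.array([0.2, 0.6, 0.2])                       # an arbitrary q_phi(z | x)
elbo = np.sum(q * (log_joint - np.log(q)))
kl = np.sum(q * (np.log(q) - log_post))

post = np.exp(log_post)                             # q equal to the true posterior
elbo_at_post = np.sum(post * (log_joint - log_post))
print(log_px, elbo + kl, elbo, elbo_at_post)
```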
When the KL divergence is zero, $q_\phi(z\mid x)=p(z\mid x)$. In the Gaussian example below ($x\sim\mathcal{N}(\mu,\sigma^2 I)$ with $z=\frac{x-\mu}{\sigma}$), $p(z\mid x)$ is a delta function, so
$$
q_\phi(z\mid x)=\mathcal{N}\left(z\,\middle|\,\frac{x-\mu}{\sigma},0\right)=\delta\left(z-\frac{x-\mu}{\sigma}\right)
$$
In this form the ELBO is still not useful, because it involves $p(x,z)$, to which we do not have access. Rewriting it:
$$
\text{ELBO}(x)=\underbrace{\mathbb{E}_{q_\phi(z\mid x)}[\log p_\theta(x\mid z)]}_{\text{how good your decoder is}}\ \underbrace{-\ \mathbb{D}_{\text{KL}}\big(q_\phi(z\mid x)\,\|\,p(z)\big)}_{\text{how good your encoder is}}
$$
$p_\theta(x\mid z)$, $q_\phi(z\mid x)$ and $p(z)$ are all Gaussian.
$$
\begin{aligned}
\text{ELBO}(x)&\stackrel{\text{def}}{=}\mathbb{E}_{q_\phi(z\mid x)}\left[\log\frac{p(x,z)}{q_\phi(z\mid x)}\right]\\
&=\mathbb{E}_{q_\phi(z\mid x)}\left[\log\frac{p(x\mid z)\,p(z)}{q_\phi(z\mid x)}\right]\\
&=\mathbb{E}_{q_\phi(z\mid x)}[\log p_\theta(x\mid z)]+\mathbb{E}_{q_\phi(z\mid x)}\left[\log\frac{p(z)}{q_\phi(z\mid x)}\right]\\
&=\mathbb{E}_{q_\phi(z\mid x)}[\log p_\theta(x\mid z)]-\mathbb{D}_{\text{KL}}\big(q_\phi(z\mid x)\,\|\,p(z)\big)
\end{aligned}
$$
Note that we replaced $p(x\mid z)$ by $p_\theta(x\mid z)$ since the latter is accessible.
The meaning of each term:
Reconstruction:
$\mathbb{E}_{q_\phi(\mathbf{z}\mid\mathbf{x})}[\log p_\theta(\mathbf{x}\mid\mathbf{z})]$. It is similar to maximum likelihood, where we want to find the model parameters that maximize the likelihood. The expectation is taken w.r.t. samples $\mathbf{z}$ drawn from $q_\phi(\mathbf{z}\mid\mathbf{x})$.
Prior Matching:
The KL divergence for the encoder: we want the encoder to turn $x$ into a latent vector $z$ that follows the chosen latent distribution $p(\mathbf{z})$.
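When $q_\phi(\mathbf z\mid\mathbf x)=\mathcal{N}(\boldsymbol\mu,\operatorname{diag}(\boldsymbol\sigma^2))$ and $p(\mathbf z)=\mathcal{N}(0,\mathbf I)$, this prior-matching term has the well-known closed form $\frac12\sum_i(\sigma_i^2+\mu_i^2-1-\log\sigma_i^2)$. The sketch below compares it with a Monte Carlo estimate; the encoder outputs $\boldsymbol\mu,\boldsymbol\sigma$ are assumed values.

```python
import numpy as np

rng = np.random.default_rng(0)
mu = np.array([0.5, -1.0])        # encoder mean (assumed values)
sigma = np.array([0.8, 1.5])      # encoder std  (assumed values)

# Closed form: KL(N(mu, diag(sigma^2)) || N(0, I))
kl_closed = 0.5 * np.sum(sigma**2 + mu**2 - 1 - np.log(sigma**2))

# Monte Carlo: E_q[log q(z) - log p(z)] with z = mu + sigma * eps
eps = rng.standard_normal((200_000, 2))
z = mu + sigma * eps
log_q = np.sum(-0.5 * np.log(2 * np.pi * sigma**2) - 0.5 * eps**2, axis=1)
log_p = np.sum(-0.5 * np.log(2 * np.pi) - 0.5 * z**2, axis=1)
kl_mc = np.mean(log_q - log_p)
print(kl_closed, kl_mc)
```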
$$
\mathbf{x}\sim p(\mathbf{x})=\mathcal{N}(\mathbf{x}\mid\boldsymbol{\mu},\sigma^2\mathbf{I})
$$
$$
\mathbf{z}\sim p(\mathbf{z})=\mathcal{N}(\mathbf{z}\mid 0,\mathbf{I})
$$
Then $z$ has the trivial solution $z=\dfrac{x-\mu}{\sigma}$, and $\hat{x}=\mu+\sigma z$.
$$
p(\mathbf{x}\mid\mathbf{z})=\delta(\mathbf{x}-(\sigma\mathbf{z}+\boldsymbol{\mu}))
$$
$$
p(\mathbf{z}\mid\mathbf{x})=\delta\left(\mathbf{z}-\frac{\mathbf{x}-\boldsymbol{\mu}}{\sigma}\right)
$$
Suppose we don't know $p(\mathbf{x})$, so we need to estimate $z$ and $x$.
$$
\begin{aligned}
(\hat{\boldsymbol{\mu}}(\mathbf{x}),\hat{\sigma}(\mathbf{x})^2)&=\text{Encoder}_\phi(\mathbf{x})\\
q_\phi(\mathbf{z}\mid\mathbf{x})&=\mathcal{N}(\mathbf{z}\mid a\mathbf{x}+\mathbf{b},t^2\mathbf{I})
\end{aligned}
$$
Assume $\hat{\boldsymbol{\mu}}$ is an affine function of $\mathbf{x}$ and $\hat{\sigma}$ is a constant $t$:
$$
q_\phi(\mathbf{z}\mid\mathbf{x})=\mathcal{N}(\mathbf{z}\mid a\mathbf{x}+\mathbf{b},t^2\mathbf{I})
$$
For the decoder, we have
$$
p_\theta(\mathbf{x}\mid\mathbf{z})=\mathcal{N}(\mathbf{x}\mid c\mathbf{z}+\mathbf{v},s^2\mathbf{I})
$$
For the KL divergence $\mathbb{D}_{\text{KL}}(q_\phi(\mathbf{z}\mid\mathbf{x})\,\|\,p(\mathbf{z}\mid\mathbf{x}))$ to be zero, we need
$$
q_\phi(\mathbf{z}\mid\mathbf{x})=\mathcal{N}\left(\mathbf{z}\,\middle|\,\frac{\mathbf{x}-\boldsymbol{\mu}}{\sigma},0\right)=\delta\left(\mathbf{z}-\frac{\mathbf{x}-\boldsymbol{\mu}}{\sigma}\right)
$$
Substituting into $\mathbb{E}_{q_\phi(\mathbf{z}\mid\mathbf{x})}[\log p_\theta(\mathbf{x}\mid\mathbf{z})]$:
$$
\begin{aligned}
\mathbb{E}_{q_\phi(\mathbf{z}\mid\mathbf{x})}[\log p_\theta(\mathbf{x}\mid\mathbf{z})]&=\mathbb{E}_{q_\phi(\mathbf{z}\mid\mathbf{x})}[\log\mathcal{N}(\mathbf{x}\mid c\mathbf{z}+\mathbf{v},s^2\mathbf{I})]\\
&=-\frac{1}{2}\log2\pi-\log s-\frac{c^2}{2s^2}\left\lVert\frac{\mathbf{x}-\boldsymbol{\mu}}{\sigma}-\frac{\mathbf{x}-\mathbf{v}}{c}\right\rVert^2\\
&\le-\frac{1}{2}\log2\pi-\log s
\end{aligned}
$$
When $\mathbf{v}=\boldsymbol{\mu}$ and $c=\sigma$, equality holds. The bound $-\frac{1}{2}\log2\pi-\log s$ grows without limit as $s\to0$, so this term is maximized in the limit $s=0$. This implies that $p_\theta(\mathbf{x}\mid\mathbf{z})=\mathcal{N}(\mathbf{x}\mid\sigma\mathbf{z}+\boldsymbol{\mu},0)=\delta(\mathbf{x}-(\sigma\mathbf{z}+\boldsymbol{\mu}))$.
The ELBO has a limitation: $q_\phi(\mathbf{z}\mid\mathbf{x})$ may not equal $p(\mathbf{z}\mid\mathbf{x})$, in which case the ELBO differs from $\log p(\mathbf{x})$.
If we don't know $p(\mathbf{z}\mid\mathbf{x})$, we need to train the VAE by maximizing the ELBO.
$$
\begin{aligned}
&q_\phi(\mathbf{z}\mid\mathbf{x})=\mathcal{N}\left(\mathbf{z}\,\middle|\,\frac{\mathbf{x}-\boldsymbol{\mu}}{\sigma},t^2\mathbf{I}\right)\\
&p_\theta(\mathbf{x}\mid\mathbf{z})=\mathcal{N}(\mathbf{x}\mid\sigma\mathbf{z}+\boldsymbol{\mu},s^2\mathbf{I})
\end{aligned}
$$
After minimizing $\mathbb{D}_{\text{KL}}(q_\phi(\mathbf{z}\mid\mathbf{x})\,\|\,p(\mathbf{z}))$ and maximizing $\mathbb{E}_{q_\phi(\mathbf{z}\mid\mathbf{x})}[\log p_\theta(\mathbf{x}\mid\mathbf{z})]$ (i.e. maximizing the ELBO), we have
$$
\begin{aligned}
q_\phi(\mathbf{z}\mid\mathbf{x})&=\mathcal{N}\left(\mathbf{z}\,\middle|\,\frac{\mathbf{x}-\boldsymbol{\mu}}{\sigma},\mathbf{I}\right)\\
p_\theta(\mathbf{x}\mid\mathbf{z})&=\mathcal{N}(\mathbf{x}\mid\sigma\mathbf{z}+\boldsymbol{\mu},\sigma^2\mathbf{I})
\end{aligned}
$$
Compared to the former result, the result here has variances $\mathbf{I}$ and $\sigma^2\mathbf{I}$, which add randomness to the samples.
Thus maximizing the ELBO is not the same as maximizing $\log p(x)$.
Optimizing VAE
Monte Carlo cannot directly estimate the gradient of $\mathbb{E}_{z\sim P_\phi(z)}[f(z)]$, because the sampling distribution itself depends on $\phi$:
$$
\int\nabla_\phi\{f(z)P_\phi(z)\}\,\mathrm{d}z\neq\int\nabla_\phi\{f(z)\}\,P_\phi(z)\,\mathrm{d}z\approx\frac{1}{N}\sum_{i=1}^N\nabla_\phi f(z_i)
$$
We therefore introduce the reparameterization trick: express $z$ as a differentiable transformation of another random variable $\varepsilon$ that is independent of the parameter $\phi$.
In the ELBO context, we define a function $\mathbf{g}$ such that $\mathbf{z}=\mathbf{g}(\boldsymbol{\varepsilon},\boldsymbol{\phi},\mathbf{x})$ for a random variable $\boldsymbol{\varepsilon}\sim p(\boldsymbol{\varepsilon})$, and
$$
q_\phi(\mathbf{z}\mid\mathbf{x})\cdot\left\lvert\det\left(\frac{\partial\mathbf{z}}{\partial\boldsymbol{\varepsilon}}\right)\right\rvert=p(\boldsymbol{\varepsilon})
$$
$$
\begin{aligned}
\mathbb{E}_{q_\phi(\mathbf{z}\mid\mathbf{x})}[f(\mathbf{z})]
&=\int f(\mathbf{z})\,q_\phi(\mathbf{z}\mid\mathbf{x})\,\mathrm{d}\mathbf{z}\\
&=\int f(\mathbf{g}(\boldsymbol{\varepsilon}))\,q_\phi(\mathbf{g}(\boldsymbol{\varepsilon})\mid\mathbf{x})\,\mathrm{d}\mathbf{g}(\boldsymbol{\varepsilon})\\
&=\int f(\mathbf{g}(\boldsymbol{\varepsilon}))\,q_\phi(\mathbf{g}(\boldsymbol{\varepsilon})\mid\mathbf{x})\left\lvert\det\left(\frac{\partial\mathbf{g}(\boldsymbol{\varepsilon})}{\partial\boldsymbol{\varepsilon}}\right)\right\rvert\mathrm{d}\boldsymbol{\varepsilon} &&\text{(change of variables)}\\
&=\int f(\mathbf{z})\,p(\boldsymbol{\varepsilon})\,\mathrm{d}\boldsymbol{\varepsilon}\\
&=\mathbb{E}_{p(\boldsymbol{\varepsilon})}[f(\mathbf{z})]
\end{aligned}
$$
Because $p(\boldsymbol{\varepsilon})$ does not depend on $\phi$, the gradient can be moved inside the expectation:
$$
\nabla_\phi\mathbb{E}_{q_\phi(\mathbf{z}\mid\mathbf{x})}[f(\mathbf{z})]=\mathbb{E}_{p(\boldsymbol{\varepsilon})}[\nabla_\phi f(\mathbf{z})]
$$
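A minimal sketch of the reparameterized gradient estimator, assuming a 1-D $q=\mathcal{N}(\mu,\sigma^2)$ with $z=\mu+\sigma\varepsilon$ and test function $f(z)=z^2$ (all chosen for illustration). Since $\mathbb{E}_q[z^2]=\mu^2+\sigma^2$, the exact gradients are $2\mu$ and $2\sigma$; averaging $\nabla_\phi f(g(\varepsilon,\phi))$ over samples of $\varepsilon$ should approach them.

```python
import numpy as np

rng = np.random.default_rng(0)
mu, sigma = 1.2, 0.5                 # parameters phi = (mu, sigma) of q = N(mu, sigma^2) (assumed)
eps = rng.standard_normal(200_000)   # eps ~ p(eps) = N(0, 1), independent of phi
z = mu + sigma * eps                 # z = g(eps, phi)

# f(z) = z^2; chain rule through g: df/dmu = 2z * dz/dmu = 2z, df/dsigma = 2z * dz/dsigma = 2z * eps
grad_mu = np.mean(2 * z)
grad_sigma = np.mean(2 * z * eps)
print(grad_mu, 2 * mu, grad_sigma, 2 * sigma)
```

In a VAE the same chain rule is applied by automatic differentiation through the encoder that outputs $\mu$ and $\sigma$.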
Recalling the ELBO formula, we can substitute $f(\mathbf{z})=-\log q_\phi(\mathbf{z}\mid\mathbf{x})$. Since $\log q_\phi(\mathbf{z}\mid\mathbf{x})=\log p(\boldsymbol{\varepsilon})-\log\left\lvert\det\frac{\partial\mathbf{z}}{\partial\boldsymbol{\varepsilon}}\right\rvert$ and $p(\boldsymbol{\varepsilon})$ does not depend on $\phi$, a Monte Carlo estimate with $L$ samples $\boldsymbol{\varepsilon}^{(l)}\sim p(\boldsymbol{\varepsilon})$ is
$$
\nabla_\phi\mathbb{E}_{q_\phi(\mathbf{z}\mid\mathbf{x})}[-\log q_\phi(\mathbf{z}\mid\mathbf{x})]\approx\frac{1}{L}\sum_{l=1}^L\nabla_\phi\left[\log\left\lvert\det\frac{\partial\mathbf{z}^{(l)}}{\partial\boldsymbol{\varepsilon}^{(l)}}\right\rvert\right]
$$
Integration by Substitution
$$
\int_D f(\mathbf{z})\,\mathrm{d}\mathbf{z}=\int_{\mathbf{g}^{-1}(D)}f(\mathbf{g}(\boldsymbol{\varepsilon}))\left\lvert\det\left(\frac{\partial\mathbf{z}}{\partial\boldsymbol{\varepsilon}}\right)\right\rvert\mathrm{d}\boldsymbol{\varepsilon}
$$
Reference
Step-by-Step Diffusion: An Elementary Tutorial
Tutorial on Diffusion Models for Imaging and Vision