KSD, SVGD, other computational Stein discrepancy methods
November 2, 2022 — April 2, 2024
approximation
Bayes
functional analysis
Markov processes
measure
metrics
Monte Carlo
optimization
probabilistic algorithms
probability
statistics
Stein’s method meets variational inference via kernels and probability measures. The result is a method of inference that maintains an ensemble of particles which notionally collectively sample from some target distribution. I should learn about this, as it is one of the methods I might use for low-assumption Bayes inference.
There seems to be a standard way of introducing the tools, which I find very confusing. Here I work through that standard with laborious worked examples, so that I can internalise the necessary intuitions for all this.
For a more comprehensive introduction (albeit brusquer), see Anastasiou et al. (2023), which combines a whole bunch of recent developments with consistent notation.
1 Stein operators
This seems to have been invented in Q. Liu, Lee, and Jordan (2016) and Chwialkowski, Strathmann, and Gretton (2016), weaponized in Q. Liu (2016b).
Let us introduce the bits we need. We start with the classic Stein identity, which turns out to be a useful trick for quantifying how well we have approximated some density.
Spoiler: later on it turns out that we can even use this as a target loss in order to improve how well we have approximated some density.
We care about a target density \(p\) and another density \(q\) (which will end up approximating it). \(p\) needs to be differentiable for this to work. By assumption, both are densities over \(\mathcal{X}\subseteq\mathbb{R}^d\). We also introduce a family \(\mathcal{F}\) of test functions from \(\mathbb{R}^d\) to \(\mathbb{R}^d\). We require that \(\lim_{t \to \pm \infty} p(tb)f(tb) = 0\) for every unit vector \(b\) (\(\|b\|=1\)) and every \(f\in\mathcal{F}\), plus some other stuff which we get to in a moment.
Should say more about the generic class \(\mathcal{F}\). Q. Liu, Lee, and Jordan (2016) does.
Next, we choose a Stein operator \(\mathcal{A}_p: \mathcal{F}\to \mathcal{G}\). \(\mathcal{F}\) and \(\mathcal{G}\) are spaces of functions from \(\mathbb{R}^d\) to \(\mathbb{R}^d\). I gave them different names because it is not clear to me that they are necessarily the same space, but we can probably ignore that detail for now. We do not go into the details of the requirements on these spaces, but their elements should be smooth functions that are square-integrable (with respect to the target distribution \(p\)? or Lebesgue measure? something else?) and whose derivatives also go to zero at infinity. For an operator \(\mathcal{A}_p\) to be a Stein operator for a target distribution \(p(x)\), it must satisfy the following key property:
Stein operator: \(\mathcal{A}_p\) is a Stein operator with respect to a suitable class of test functions \(\mathcal{F}\) if the expectation of \(\mathcal{A}_p f(X)\) under the target distribution \(p(x)\) is zero, i.e. if for all \(f\) in \(\mathcal{F}\), \[
\mathbb{E}_{x\sim p}\left[\mathcal{A}_p f(x)\right]=0.
\tag{1}\]
Note that \(p\) restricts our choice of \(\mathcal{A}_p\) but AFAICT does not uniquely determine it. I do not know if that is useful. I wonder if we can choose \(\mathcal{F}\) cunningly to admit some spicy alternate \(\mathcal{A}_p\).
For an \(\mathcal{F}\) that includes a non-trivial linear subspace, we can see that \(\mathcal{A}_p\) must be linear, because expectation is linear, and otherwise we could make linear changes to an \(f\) and end up violating the equality.
The classic choice is to set the Stein operator to \[
\mathcal{A}_p f(x):=f(x) \cdot \nabla_x \log p(x)+\nabla_x \cdot f(x).
\tag{2}\]
AFAICT we are essentially assuming an exponential family density here, or something?
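A minimal NumPy sketch of Equation 2 evaluated at a single point, assuming we are handed the score \(\nabla_x \log p\) and a test function \(f\) as plain Python callables. The helper name `stein_operator` is hypothetical, and the divergence uses central finite differences only to keep the sketch dependency-free; the JAX code further down uses automatic differentiation instead.

```python
import numpy as np

def stein_operator(f, score_p, x, eps=1e-5):
    """Evaluate (A_p f)(x) = f(x) . grad log p(x) + div f(x) at one point x in R^d.

    f:       callable R^d -> R^d, the test function
    score_p: callable R^d -> R^d, the score grad log p
    """
    x = np.asarray(x, dtype=float)
    d = x.shape[0]
    div = 0.0
    for i in range(d):
        e = np.zeros(d)
        e[i] = eps
        # central difference for the diagonal Jacobian entry d f_i / d x_i
        div += (f(x + e)[i] - f(x - e)[i]) / (2 * eps)
    return float(np.dot(f(x), score_p(x)) + div)

# Standard normal target on R^2 (score -x) with test function f(x) = x:
# f . score = -(1^2 + 2^2) = -5 and div f = 2, so the exact value is -3.
value = stein_operator(lambda x: x, lambda x: -x, [1.0, 2.0])
print(value)
```

Only the score of \(p\) enters, so an unnormalised density would do just as well.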
Anyway, let us make this more concrete by choosing a specific \(p\) which is not trivial but not baffling either. I reckon a 2d Gaussian with mean 0 and unit marginal standard deviations will do the trick. Let us give it a correlation \(\rho\), which we leave unspecified, to keep things spicy. This implies mean \(\mu = (0, 0)\) and covariance \(\Sigma = \left[\begin{smallmatrix} 1 & \rho \\ \rho & 1 \end{smallmatrix}\right]\), and thence inverse covariance \(\Sigma^{-1} = \tfrac{1}{1-\rho^2}\left[\begin{smallmatrix} 1 & -\rho \\ -\rho & 1 \end{smallmatrix}\right]\). The pdf for this distribution is \[
\begin{aligned}
p(\boldsymbol{x})
&= \frac{1}{2\pi \sqrt{|\Sigma|}} \exp\left(-\frac{1}{2}
(x-\mu)^{\top} \Sigma^{-1} (x-\mu)\right)\\
&= \frac{1}{2\pi \sqrt{1-\rho^2}} \exp\left(-\frac{1}{2}
\begin{bmatrix}
x_1 & x_2
\end{bmatrix} \frac{1}{1-\rho^2} \begin{bmatrix}
1 & -\rho \\ -\rho & 1
\end{bmatrix} \begin{bmatrix}
x_1 \\ x_2
\end{bmatrix}\right)\\
&= \frac{1}{2\pi \sqrt{1-\rho^2}} \exp\left(-\frac{1}{2(1-\rho^2)}\left(x_1^2 - 2\rho x_1 x_2 + x_2^2\right)\right).
\end{aligned}
\]
We can simplify the Stein operator for this choice of \(p\), since \[
\begin{aligned}
\nabla_x \log p(x)
&= \nabla_x \log \left(\frac{1}{2\pi \sqrt{1-\rho^2}} \exp\left(-\frac{1}{2(1-\rho^2)}\left(x_1^2 - 2\rho x_1 x_2 + x_2^2\right)\right)\right)\\
&= \nabla_x \left( -\frac{1}{2(1-\rho^2)} \left(
x_1^2 - 2\rho x_1 x_2 + x_2^2\right)\right)\\
&= -\frac{1}{2(1-\rho^2)}\nabla_x\left(x_1^2 - 2\rho x_1 x_2 + x_2^2\right)\\
&= -\frac{1}{1-\rho^2}\begin{bmatrix}
x_1 - \rho x_2\\
x_2 - \rho x_1
\end{bmatrix}
\end{aligned}
\]
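Sign slips are easy in this kind of hand calculation, so here is a quick NumPy check of the result against central finite differences of \(\log p\) and against the closed form \(-\Sigma^{-1}x\). The value \(\rho = 0.5\) is an assumption, matching the plotting code below; the function names are mine.

```python
import numpy as np

rho = 0.5
Sigma = np.array([[1.0, rho], [rho, 1.0]])

def log_p(x):
    # log-density of the zero-mean, unit-variance bivariate Gaussian with correlation rho
    quad = (x[0]**2 - 2 * rho * x[0] * x[1] + x[1]**2) / (1 - rho**2)
    return -0.5 * quad - np.log(2 * np.pi * np.sqrt(1 - rho**2))

def score_by_hand(x):
    # the hand-derived score: -(1/(1-rho^2)) [x1 - rho x2, x2 - rho x1]
    return -np.array([x[0] - rho * x[1], x[1] - rho * x[0]]) / (1 - rho**2)

def score_finite_diff(x, eps=1e-6):
    grad = np.zeros(2)
    for i in range(2):
        e = np.zeros(2)
        e[i] = eps
        grad[i] = (log_p(x + e) - log_p(x - e)) / (2 * eps)
    return grad

rng = np.random.default_rng(0)
points = rng.normal(size=(5, 2))
err_fd = max(np.max(np.abs(score_by_hand(x) - score_finite_diff(x))) for x in points)
err_closed = max(np.max(np.abs(score_by_hand(x) + np.linalg.solve(Sigma, x))) for x in points)
print(err_fd, err_closed)   # both tiny if the derivation is right
```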
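We can choose some stupid \(f\) for the purposes of the intuition building we claimed we were doing. A linear one is \(f(x) = \begin{bmatrix} a_1 x_1 + b_1 x_2 \\ a_2 x_1 + b_2 x_2 \end{bmatrix}\), with divergence \(\nabla_x\cdot f(x) = a_1 + b_2\). Using \(\mathbb{E}_p[x_1^2]=\mathbb{E}_p[x_2^2]=1\) and \(\mathbb{E}_p[x_1 x_2]=\rho\), the expectation of the Stein operator for our bivariate Gaussian is then \[
\begin{aligned}
\mathbb{E}_{x\sim p}\left[\mathcal{A}_p f(x)\right]
&= -\frac{1}{1-\rho^2}\mathbb{E}_p\big[(a_1 x_1 + b_1 x_2)(x_1 - \rho x_2) + (a_2 x_1 + b_2 x_2)(x_2 - \rho x_1)\big] + a_1 + b_2\\
&= -\frac{1}{1-\rho^2}\big(a_1 - a_1\rho^2 + b_1\rho - b_1\rho + a_2\rho - a_2\rho + b_2 - b_2\rho^2\big) + a_1 + b_2\\
&= -\frac{(a_1 + b_2)(1-\rho^2)}{1-\rho^2} + a_1 + b_2\\
&= 0.
\end{aligned}
\]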
Phew! OK that worked. I itch to plot these functions. Note that even in 2 dimensions that is tricky, because our functions are 2d functions of 2d inputs, so we need 4 dimensions to plot them.
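Before plotting, a quick Monte Carlo sanity check that \(\mathbb{E}_{x\sim p}[\mathcal{A}_p f(x)]=0\) for a linear \(f\), in NumPy rather than JAX for brevity, assuming \(\rho=0.5\). The plot below uses \(a_1 + b_2 = 0\), in which case both terms of the operator have zero mean separately, so this sketch swaps in hypothetical coefficients with nonzero divergence \(a_1+b_2\) to show the two terms cancelling.

```python
import numpy as np

rho = 0.5
Sigma = np.array([[1.0, rho], [rho, 1.0]])
# Hypothetical coefficients, chosen so that the divergence a1 + b2 is nonzero.
a1, b1, a2, b2 = 1.0, 0.25, -0.75, 0.5

rng = np.random.default_rng(1)
x = rng.multivariate_normal(np.zeros(2), Sigma, size=200_000)  # draws from p
x1, x2 = x[:, 0], x[:, 1]

f = np.stack([a1 * x1 + b1 * x2, a2 * x1 + b2 * x2], axis=1)
score = -np.stack([x1 - rho * x2, x2 - rho * x1], axis=1) / (1 - rho**2)
drift_term = np.sum(f * score, axis=1)   # f(x) . grad log p(x) at each sample
divergence = a1 + b2                     # div f, constant because f is linear

print(drift_term.mean())                 # close to -(a1 + b2) = -1.5
print(drift_term.mean() + divergence)    # close to 0, as Equation 1 says
```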
I could do a 1d example. We can find loads of those on the internet, and I did not find them helpful because it is too easy to lose sight of how this method handles higher-dimensional spaces, which is the only time it is interesting.
We can juuuuust about do it in 2 dimensions by using a quiver plot and treating it as a weird vector field; here the arrows show the output of an arbitrary stoopid linear \(f\), and the contours show the density of the target distribution \(p\).
```python
import jax.numpy as jnp
import numpy as np
from jax import grad, vmap, jacfwd, jacrev
import plotly.figure_factory as ff
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.io as pio
from _plotly_styles import textbook  # local plot-styling helper, not a public package

pio.templates.default = "none"

# Define shared parameters for scale
A_scale_factor = 1.5
f_scale_factor = 0.1
arrow_scale_factor = 0.3
n = 25

# Define the generic function f
def f(x, a1, b1, a2, b2):
    return jnp.stack([
        a1 * x[..., 0] + b1 * x[..., 1],
        a2 * x[..., 0] + b2 * x[..., 1]
    ], axis=-1)

# Define the log-density of the generic probability density function p
def log_p(x, rho):
    return -0.5 * (
        x[..., 0]**2 + x[..., 1]**2 - 2 * rho * x[..., 0] * x[..., 1]
    ) / (1 - rho**2) - jnp.log(2 * jnp.pi * jnp.sqrt(1 - rho**2))

# Fix specific values for rho and the parameters of f
rho = 0.5
a1, b1, a2, b2 = 1, 0.25, -0.75, -1
f_specific = lambda x: f(x, a1, b1, a2, b2)
log_p_specific = lambda x: log_p(x, rho)

# Define the Stein operator applied to a specific f and p.
# Returns an (N, 2, 2) array with entries f_i(x) d_j log p(x) + d_j f_i(x);
# its trace is the scalar operator of Equation 2.
def A_p_f(x, f, log_p):
    grad_log_p = vmap(grad(log_p))(x)
    jac_f = vmap(jacfwd(f))(x)
    return grad_log_p[:, None, :] * f(x)[:, :, None] + jac_f

# Create a grid of points
x1, x2 = np.meshgrid(
    np.linspace(-3, 3, n, endpoint=True),
    np.linspace(-3, 3, n, endpoint=True))
x = np.stack([x1, x2], axis=-1).reshape(-1, 2)

# Compute the function f at each point in x
f_x = f_specific(x)
p_x = np.exp(log_p_specific(x))

# Compute the Stein operator for f at each point in x
A_f_p_x = A_p_f(x, f_specific, log_p_specific)
p_A_f_p_x = A_f_p_x * p_x[:, None, None]

# Create a subplot with 2 rows and 1 column
fig = make_subplots(
    rows=2, cols=1,
    subplot_titles=('<i>f</i>', '<i>p A<sub>p</sub> f</i>'),
    vertical_spacing=0.1,
    row_heights=[0.5, 0.5])
textbook(fig)

# Add the quiver plot for the function f to the first subplot
fig.add_trace(
    ff.create_quiver(
        x1, x2,
        f_x[:, 0].reshape(x1.shape), f_x[:, 1].reshape(x2.shape),
        scale=f_scale_factor, arrow_scale=arrow_scale_factor,
        name='f', line_width=1).data[0],
    row=1, col=1)

# Add the quiver plot for the first component of the Stein operator to the second subplot
fig.add_trace(
    ff.create_quiver(
        x1, x2,
        p_A_f_p_x[:, 0, 0].reshape(x1.shape), p_A_f_p_x[:, 1, 0].reshape(x2.shape),
        scale=A_scale_factor, arrow_scale=arrow_scale_factor,
        name='Component 1', line_width=1, line_color='blue').data[0],
    row=2, col=1)

# Add the quiver plot for the second component of the Stein operator to the second subplot
fig.add_trace(
    ff.create_quiver(
        x1, x2,
        p_A_f_p_x[:, 0, 1].reshape(x1.shape), p_A_f_p_x[:, 1, 1].reshape(x2.shape),
        scale=A_scale_factor, arrow_scale=arrow_scale_factor,
        name='Component 2', line_width=1, line_color='red').data[0],
    row=2, col=1)

# Set the layout
fig.update_layout(
    width=400, height=800,
    font=dict(family="Alegreya, serif"),
    paper_bgcolor='rgba(0,0,0,0)',
    plot_bgcolor='rgba(0,0,0,0)',
    showlegend=False)

# Update axis ranges and aspect ratio for both subplots
axis_range = [-3, 3]
for i in range(1, 3):
    fig.update_xaxes(range=axis_range, row=i, col=1)
    fig.update_yaxes(range=axis_range, row=i, col=1, scaleanchor=f'x{i}', scaleratio=1)

# Show the plot
fig.show()
```
Did that help us? Well, kinda. It is not really clear to me that I should trust that the second figure should actually integrate to 0. Did it?
```python
p_A_f_p_x.sum().item()
```

```
-0.21152114868164062
```
Hm, not convincingly 0. Numerical approximation problem, truncation error or actual bug? TBD. For now I’m out of time; we really need to be moving along.
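Two suspects worth ruling out, sketched here in NumPy with the same \(\rho\) and \(f\) as the plot. First, the raw `.sum()` is not multiplied by the grid-cell area \(\Delta x^2\), so it is not yet an approximation of an integral. Second, each matrix entry is \(p\,(f_i\,\partial_j \log p + \partial_j f_i) = \partial_j (p f_i)\), so over the box \([-3,3]^2\) it integrates to a flux through the box edges rather than to exactly 0. Since every entry is such a derivative, each should integrate to 0 over the whole plane, so the sketch integrates the entries separately, once on the plot's grid and once on a wider, finer grid.

```python
import numpy as np

rho = 0.5
a1, b1, a2, b2 = 1, 0.25, -0.75, -1     # same f as the plot
jac_f = np.array([[a1, b1], [a2, b2]])  # d f_i / d x_j, constant because f is linear

def p_times_stein_matrix(x1, x2):
    """Entries p(x) [f_i(x) d_j log p(x) + d_j f_i(x)] on a grid, shape (..., 2, 2)."""
    p = np.exp(-0.5 * (x1**2 - 2 * rho * x1 * x2 + x2**2) / (1 - rho**2)) / (
        2 * np.pi * np.sqrt(1 - rho**2))
    score = -np.stack([x1 - rho * x2, x2 - rho * x1], axis=-1) / (1 - rho**2)
    f = np.stack([a1 * x1 + b1 * x2, a2 * x1 + b2 * x2], axis=-1)
    return p[..., None, None] * (f[..., :, None] * score[..., None, :] + jac_f)

def grid_integral(half_width, n):
    """Riemann-sum approximation of the integral of each matrix entry over a square."""
    grid = np.linspace(-half_width, half_width, n)
    dx = grid[1] - grid[0]
    x1, x2 = np.meshgrid(grid, grid)
    return p_times_stein_matrix(x1, x2).sum(axis=(0, 1)) * dx**2

plot_box = grid_integral(3.0, 25)   # the plot's grid, now scaled by the cell area
big_box = grid_integral(8.0, 401)   # wider and finer, so the boundary term is negligible
print(plot_box)
print(big_box)
```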
2 Stein discrepancy
We make Equation 1 into a quantity that depends on two potentially different densities by taking the expectation over a different density \(q\) than the one that generated the operator \(\mathcal{A}_p\), and seeing if that does something useful:
Spoiler: it turns out that this does do something useful.
\[
\begin{aligned}
\mathbb{E}_{x\sim q}[\mathcal{A}_p f(x)]
&=\mathbb{E}_{x\sim q}[\mathcal{A}_p f(x)] - \overbrace{\mathbb{E}_{x\sim q}[\mathcal{A}_q f(x)]}^{=0}\\
&=\mathbb{E}_{x\sim q}\big[f(x) \cdot \nabla_x \log p(x)+\nabla_x \cdot f(x)\\
&\qquad-f(x) \cdot \nabla_x \log q(x)-\nabla_x \cdot f(x)\big]\\
&=\mathbb{E}_{x\sim q}\left[f(x) \cdot \nabla_x \log p(x) -f(x) \cdot \nabla_x \log q(x)\right]\\
&=\mathbb{E}_{x\sim q}\left[f(x) \cdot \delta_{p,q}(x) \right]
\end{aligned}
\] where \(\delta_{p,q}(x):= \nabla_x \log p(x) -\nabla_x \log q(x)\) is the difference in score function between \(p\) and \(q\).
By choosing an \(f\) from some sufficiently rich \(\mathcal{F}\) we can make this non-zero unless \(p=q\) a.e., so this quantity tells us something about how distinct two densities \(p\) and \(q\) are, in this slightly weird but credible-seeming sense where we care about the difference in their score functions; i.e. this is some kind of score matching method.
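A minimal Monte Carlo sketch of the identity \(\mathbb{E}_{x\sim q}[\mathcal{A}_p f(x)]=\mathbb{E}_{x\sim q}[f(x)\cdot\delta_{p,q}(x)]\) for one fixed \(f\), assuming for illustration \(p=\mathcal{N}(0,I)\), \(q=\mathcal{N}(m,I)\) in two dimensions and the test function \(f(x)=x\). Then \(\delta_{p,q}(x) = -x + (x-m) = -m\) is constant and both sides equal \(-\|m\|^2\). Note that the left-hand side needs only the score of \(p\), while the right-hand side needs the scores of both densities.

```python
import numpy as np

rng = np.random.default_rng(2)
m = np.array([1.0, 0.5])
xs = rng.normal(size=(200_000, 2)) + m   # draws from q = N(m, I)

score_p = -xs                            # grad log p for p = N(0, I)
score_q = -(xs - m)                      # grad log q for q = N(m, I)
f = xs                                   # test function f(x) = x, with div f = 2

stein_side = np.mean(np.sum(f * score_p, axis=1) + 2.0)             # E_q[A_p f]
score_diff_side = np.mean(np.sum(f * (score_p - score_q), axis=1))  # E_q[f . delta_{p,q}]
exact = -(m @ m)                                                    # -|m|^2 = -1.25

print(stein_side, score_diff_side, exact)
```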
This looks neat. How can we calculate it in practice? Obstacle: we have not specified \(f\). We could fix some \(f\) and use it to measure how different are \(p\) and \(q\) in some sense. Or we could choose some stochastic process which generates some random \(f\)s and estimate it over many \(f\)s, I guess? I assume that has been done.
The Stein discrepancy takes a strong approach to controlling those \(f\)s: we control the supremum of that difference over all \(f\) in some function class \(\mathcal{F}\). Once we have found the Stein discrepancy, we know how bad the mismatch between \(p\) and \(q\) is for the worst \(f\), and hence that it is no worse than that for any \(f\): \[
S(q, p):=\left(\sup_{f\in\mathcal{F}}\mathbb{E}_{x\sim q}\left[\operatorname{trace}\mathcal{A}_p f(x)\right]\right)^2.
\]
Notice that we snuck in a trace there as well, to make it a scalar. This ended up being the most confusing thing for me: how many dimensions is anything even in, in this equation?
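On the dimension question, one way to read it (and the way the JAX plotting code above builds the operator) is that \(\mathcal{A}_p f(x)\) is a \(d\times d\) matrix, \(f(x)\,\nabla_x\log p(x)^\top\) plus the Jacobian of \(f\), whose trace is the scalar of Equation 2. A small NumPy check in \(d=3\), assuming a hypothetical linear \(f(x)=Wx\) and a standard normal target purely for illustration:

```python
import numpy as np

rng = np.random.default_rng(3)
d = 3
W = rng.normal(size=(d, d))   # hypothetical linear test function f(x) = W x
x = rng.normal(size=d)        # an arbitrary evaluation point
score = -x                    # score of a standard normal target

f_x = W @ x                   # f(x), shape (d,)
jac_f = W                     # Jacobian, entry [i, j] = d f_i / d x_j, shape (d, d)

A_matrix = np.outer(f_x, score) + jac_f   # matrix-valued operator, shape (d, d)
A_scalar = f_x @ score + np.trace(jac_f)  # Equation 2: f . score + div f, a scalar

print(A_matrix.shape)                            # (3, 3)
print(np.isclose(np.trace(A_matrix), A_scalar))  # True
```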
I said the Stein discrepancy, but it looks like each choice of \(\mathcal{F}\) and operator \(\mathcal{A}_p\) gives us a different Stein discrepancy, no? The literature seems to occasionally consider other choices of \(\mathcal{A}_p\), but sometimes takes the default Equation 2 as the only one.
Let us consider how we might find this ‘worst’ \(f\) which gives us this most powerful guarantee of the difference between \(p\) and \(q\). There are a few steps.
First, we use the linearity of Stein operator \(\mathcal{A}_p\), mentioned earlier. Suppose that \(f\) can be represented as a finite linear combination \(f(x)=\sum_i w_i f_i(x)\) of a set of basis functions \(f_i(x)\) for some coefficients \(w_i\) s.t. \(\|w\| \leq 1\). Then we can define the violation of Stein-ness by \[
\mathbb{E}_q\left[\mathcal{A}_p f\right]=\mathbb{E}_q\left[\mathcal{A}_p \sum_i w_i f_i(x)\right]=\sum_i w_i \beta_i,
\] where \[
\beta_i=\mathbb{E}_{x \sim q}\left[\mathcal{A}_p f_i(x)\right] .
\]
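Because this \(f\) is parameterised linearly by \(w\), the supremum of \(\sum_i w_i\beta_i\) over \(\|w\|\le 1\) has a closed form by Cauchy-Schwarz: it is \(\|\beta\|\), attained at \(w=\beta/\|\beta\|\). A NumPy sketch, assuming for illustration \(p=\mathcal{N}(0,I)\), \(q=\mathcal{N}(m,I)\) in two dimensions, and a hypothetical basis of three test functions: the constant unit vectors \(e_1\), \(e_2\) and the identity \(x\). For this basis \(\beta = (-m_1, -m_2, -\|m\|^2)\), so for \(m=(1, 0.5)\) the supremum should come out near \(\sqrt{2.8125}\approx 1.68\).

```python
import numpy as np

rng = np.random.default_rng(4)
m = np.array([1.0, 0.5])
xs = rng.normal(size=(200_000, 2)) + m     # draws from q
score_p = -xs                              # grad log p for p = N(0, I)

# A_p f_i at every sample, one column per basis function:
# f_1 = e_1 and f_2 = e_2 (constant, zero divergence), f_3(x) = x (divergence 2).
A_basis = np.column_stack([
    score_p[:, 0],
    score_p[:, 1],
    np.sum(xs * score_p, axis=1) + 2.0,
])
beta = A_basis.mean(axis=0)                # beta_i = E_q[A_p f_i]
worst_w = beta / np.linalg.norm(beta)      # maximiser over the unit ball
stein_disc = worst_w @ beta                # equals ||beta||

# Spot-check that no random unit-norm w does better.
random_w = rng.normal(size=(1000, 3))
random_w /= np.linalg.norm(random_w, axis=1, keepdims=True)
print(stein_disc, (random_w @ beta).max())
```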
So far this gives a scalar only for univariate densities. To make the discrepancy a scalar even for multivariate problems (in the sense of densities over multidimensional spaces), we take a trace. Writing \(\boldsymbol{s}_p(x):=\nabla_x \log p(x)\) for the score of \(p\), the violation becomes \[
\mathbb{E}_{x\sim q}\left[\operatorname{trace}\left(\mathcal{A}_p \boldsymbol{f}(x)\right)\right]=\mathbb{E}_{x\sim q}\left[\left(\boldsymbol{s}_p(x)-\boldsymbol{s}_q(x)\right)^{\top} \boldsymbol{f}(x)\right]
\]
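Jumping ahead to the kernelized version in the next section, where \(\mathcal{F}\) is the unit ball of an RKHS \(\mathcal{H}^d\) with kernel \(k\), the worst-case test function turns out to be \(\phi(x)=\phi_{q, p}^*(x) /\left\|\phi_{q, p}^*\right\|_{\mathcal{H}^d}\), where \[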
\phi_{q, p}^*(\cdot)=\mathbb{E}_{x \sim q}\left[\mathcal{A}_p k(x, \cdot)\right], \quad \text { for which we have } \quad \mathbb{S}(q, p)=\left\|\phi_{q, p}^*\right\|_{\mathcal{H}^d}^2
\]
OK, so this is notionally an optimisation problem we can solve, choosing the \(w_i\) values to be as terrible as possible, and then seeing how bad the most-terrible values are. That is a nested optimisation problem, which is a bit tedious. Can we do better?
3 Kernelized Stein Discrepancy
When we see a challenge of this kind — where we wish we had ‘more tricks’ in our function space — it typically suggests trying the kernel trick and seeing if that does anything cool. This involves, specifically, choosing the function class \(\mathcal{F}\) to be a reproducing kernel Hilbert space and seeing what that does to the problem. Frequently, as here, that makes life easier.
Kernelized Stein Discrepancy does exactly that, picking a special family \(\mathcal{F}\). Let us sidle up to KSD: \(\mathcal{H}\) is the RKHS with associated kernel \(k:\mathbb{R}^d \times \mathbb{R}^d \to \mathbb{R}\). We require that \(k(x, x^{\prime})\) be a positive definite kernel. The RKHS \(\mathcal{H}\) with kernel \(k\) includes functions of the form \(f(x)=\sum_i w_i k\left(x, x_i\right)\), equipped with RKHS inner product \(\langle f, g\rangle_{\mathcal{H}}=\sum_{i j} w_i v_j k\left(x_i, x_j\right)\) for \(g=\sum_j v_j k\left(x, x_j\right)\) and RKHS norm \(\|f\|_{\mathcal{H}}^2=\sum_{i j} w_i w_j k\left(x_i, x_j\right)\).
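A NumPy sketch of these finite-expansion formulas, assuming a Gaussian RBF kernel \(k(x,x')=\exp(-\|x-x'\|^2/2h^2)\) with \(h=1\); the kernel choice and all names here are mine. With Gram matrices \(K\), the inner product is \(w^\top K_{fg} v\) and the squared norm is \(w^\top K_{ff} w\). The reproducing property gives a pointwise bound, \(|f(x)| = |\langle f, k(x,\cdot)\rangle_{\mathcal{H}}| \le \|f\|_{\mathcal{H}}\sqrt{k(x,x)}\), and \(k(x,x)=1\) for this kernel.

```python
import numpy as np

def rbf_gram(a, b, h=1.0):
    """Gram matrix K[i, j] = exp(-|a_i - b_j|^2 / (2 h^2)) for point sets a (n, d), b (m, d)."""
    a, b = np.atleast_2d(a), np.atleast_2d(b)
    sq = np.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
    return np.exp(-sq / (2 * h**2))

rng = np.random.default_rng(5)
centres_f, w = rng.normal(size=(4, 2)), rng.normal(size=4)   # f = sum_i w_i k(., x_i)
centres_g, v = rng.normal(size=(3, 2)), rng.normal(size=3)   # g = sum_j v_j k(., x_j)

inner_fg = w @ rbf_gram(centres_f, centres_g) @ v             # <f, g>_H
norm_f = np.sqrt(w @ rbf_gram(centres_f, centres_f) @ w)      # ||f||_H
norm_g = np.sqrt(v @ rbf_gram(centres_g, centres_g) @ v)      # ||g||_H

x = np.array([0.3, -0.7])
f_at_x = w @ rbf_gram(centres_f, x)[:, 0]                     # f(x) = sum_i w_i k(x_i, x)

print(abs(inner_fg) <= norm_f * norm_g)   # Cauchy-Schwarz in H
print(abs(f_at_x) <= norm_f)              # pointwise bound from the reproducing property
```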
The problem is that this space doesn’t contain vector-valued functions of the kind we need to do “Stein stuff”. So… next we define \(\mathcal {H}^d\), a product space consisting of elements \(\boldsymbol {f} = (f_1, f_2, \ldots, f_d)\), where each \(f_i \in \mathcal {H}\). Each such \(f\) then would be an \(\mathbb {R}^d\to \mathbb {R}^d\) function, which is what we need for the Stein operator. We endow this little fella with inner product \(\langle \boldsymbol {f}, \boldsymbol {g}\rangle_{\mathcal {H}^d}=\sum_{i=1}^d \langle f_i, g_i\rangle_{\mathcal {H}}.\)
It turns out that we can calculate efficiently in this RKHS, because it gives us some extra structure: \[
\begin{aligned}
f(x)&=\langle f(\cdot), k(x, \cdot)\rangle_{\mathcal{H}} && \text{reproducing property}\\
\nabla_x f(x)&=\left\langle f(\cdot), \nabla_x k(x, \cdot)\right\rangle_{\mathcal{H}} && \text{gradient property}
\end{aligned}
\] With these properties, we have \[
\mathbb{E}_{x \sim q}\left[\mathcal{A}_p f(x)\right]=\left\langle f(\cdot), \mathbb{E}_{x \sim q}\left[\mathcal{A}_p k(\cdot, x)\right]\right\rangle_{\mathcal{H}}
\] where we shift both the expectation and Stein operator to the kernel function. Moreover, \[
\begin{aligned}
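\mathbb{E}_{x \sim q}\left[\operatorname{trace}\mathcal{A}_p \boldsymbol{f}(x)\right]
&=\sum_{i=1}^{d} \left\langle f_i(\cdot), \mathbb{E}_{x \sim q}\left[k(x, \cdot)\,\partial_{x_i}\log p(x)+\partial_{x_i} k(x, \cdot)\right]\right\rangle_{\mathcal{H}}\\
&=\left\langle \boldsymbol{f}(\cdot), \boldsymbol{\beta}_{q,p}(\cdot)\right\rangle_{\mathcal{H}^d}
\end{aligned}
\] where \(\boldsymbol{\beta}_{q,p}(\cdot):=\mathbb{E}_{x \sim q}\left[\mathcal{A}_p k(x, \cdot)\right]=\mathbb{E}_{x \sim q}\left[k(x, \cdot)\nabla_x\log p(x)+\nabla_x k(x, \cdot)\right]\) is an element of \(\mathcal{H}^d\).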
We take \(\mathcal{F}\) to be the unit ball in that RKHS, i.e. \(\mathcal{F}:=\{\boldsymbol{f};\|\boldsymbol{f}\|_{\mathcal{H}^d} \leq 1 \}\).
Then we can write
\[
\sqrt{S(q, p)}=\sup _{\boldsymbol{f} \in {\color{red}\mathcal{H}^d},\ \|\boldsymbol{f}\|_{\color{red}\mathcal{H}^d} \leq 1}\mathbb{E}_{x \sim q}\left[\operatorname{trace}\mathcal{A}_p \boldsymbol{f}(x)\right] .
\] i.e. it is just the same as before, but we have restricted the function class to the unit ball of an RKHS.
Finding that supremum is then equivalent to solving \[
\sup _{\boldsymbol{f}}\left\langle \boldsymbol{f}, \boldsymbol{\beta}_{q, p}\right\rangle_{\mathcal{H}^d}, \quad \text { s.t. }\|\boldsymbol{f}\|_{\mathcal{H}^d} \leq 1 .
\]
Equivalently, in terms of the score difference \(\delta_{p,q}(x)=\nabla_x\log p(x)-\nabla_x\log q(x)\) defined earlier, \[
S(q, p)=\mathbb{E}_{x, x^{\prime} \sim q}\left[\boldsymbol{\delta}_{p, q}(x)^{\top} k\left(x, x^{\prime}\right) \boldsymbol{\delta}_{p, q}\left(x^{\prime}\right)\right],
\] where \(x, x^{\prime}\) are i.i.d. draws from \(q(x)\). This form needs the score of \(q\), which we usually do not have, so it is more useful for analysis than for computation.
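We maximise this by setting \(\boldsymbol{f}=\boldsymbol{\beta}_{q, p} /\left\|\boldsymbol{\beta}_{q, p}\right\|_{\mathcal{H}^d}\), since by Cauchy-Schwarz \(\langle \boldsymbol{f}, \boldsymbol{\beta}_{q,p}\rangle_{\mathcal{H}^d} \le \|\boldsymbol{f}\|_{\mathcal{H}^d}\|\boldsymbol{\beta}_{q,p}\|_{\mathcal{H}^d}\le\|\boldsymbol{\beta}_{q,p}\|_{\mathcal{H}^d}\), with equality at that point. The maximiser lies on the boundary of the unit ball rather than in its interior because the objective is linear in \(\boldsymbol{f}\): rescaling any \(\boldsymbol{f}\) with a positive objective up to unit norm only increases the objective. Thus \[\begin{aligned}
S(q, p)
&=\left\|\boldsymbol{\beta}_{q, p}\right\|_{\mathcal{H}^d}^2\\
&=\mathbb{E}_{x, x^{\prime} \sim q}\left[\kappa_p\left(x, x^{\prime}\right)\right]
\end{aligned}\] where \[
\kappa_p\left(x, x^{\prime}\right):=\mathcal{A}_p^x \mathcal{A}_p^{x^{\prime}} k\left(x, x^{\prime}\right) .
\] Here \(\mathcal{A}_p^x\) and \(\mathcal{A}_p^{x^{\prime}}\) denote the Stein operator acting on the variable \(x\) and \(x^{\prime}\), respectively, so \(\kappa_p\left(x, x^{\prime}\right)\) is the “Steinalized” kernel obtained by applying the Stein operator to \(k\left(x, x^{\prime}\right)\) twice.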
It is a mess to write out in full though.
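For one concrete kernel, the Gaussian RBF \(k(x,x')=\exp(-\|x-x'\|^2/2h^2)\), and writing \(s_p:=\nabla\log p\) and \(r:=x-x'\), it works out to \[\kappa_p(x,x') = k(x,x')\left[s_p(x)^\top s_p(x') + \frac{(s_p(x)-s_p(x'))^\top r}{h^2} + \frac{d}{h^2} - \frac{\|r\|^2}{h^4}\right].\] Only the score of \(p\) appears, so \(S(q,p)\) can be estimated from samples of \(q\) alone, e.g. with the U-statistic \(\frac{1}{n(n-1)}\sum_{i\ne j}\kappa_p(x_i,x_j)\). The NumPy sketch below does this; the kernel, the bandwidth \(h=1\), the function name and the test case are my assumptions. As a cross-check it uses \(p=\mathcal{N}(0,I)\) and \(q=\mathcal{N}(m,I)\), where \(\delta_{p,q}=-m\) is constant, so the score-difference form gives \(S(q,p)=\|m\|^2\,\mathbb{E}[k(x,x')]=\|m\|^2(1+2/h^2)^{-d/2}\), because \(x-x'\sim\mathcal{N}(0,2I)\).

```python
import numpy as np

def ksd_u_statistic(xs, score_p, h=1.0):
    """U-statistic estimate of S(q, p) = E_{x, x' ~ q}[kappa_p(x, x')] from samples xs ~ q.

    Uses the Gaussian RBF kernel k(x, x') = exp(-|x - x'|^2 / (2 h^2)); only the
    score of p is needed, never the score of q.
    """
    n, d = xs.shape
    s = score_p(xs)                                   # (n, d)
    r = xs[:, None, :] - xs[None, :, :]               # r[i, j] = x_i - x_j
    sq = np.sum(r**2, axis=-1)
    k = np.exp(-sq / (2 * h**2))
    kappa = k * (
        s @ s.T                                       # s_p(x)^T s_p(x')
        + np.einsum("id,ijd->ij", s, r) / h**2        # s_p(x)^T grad_{x'} k / k
        - np.einsum("jd,ijd->ij", s, r) / h**2        # s_p(x')^T grad_x k / k
        + d / h**2 - sq / h**4                        # trace(grad_x grad_{x'} k) / k
    )
    return (kappa.sum() - np.trace(kappa)) / (n * (n - 1))   # drop the i == j terms

rng = np.random.default_rng(6)
h, d = 1.0, 2
m = np.array([1.0, 0.5])
score_std_normal = lambda x: -x                       # p = N(0, I)

ksd_shifted = ksd_u_statistic(rng.normal(size=(1500, d)) + m, score_std_normal, h)  # q = N(m, I)
ksd_matched = ksd_u_statistic(rng.normal(size=(1500, d)), score_std_normal, h)      # q = p
closed_form = (m @ m) * (1 + 2 / h**2) ** (-d / 2)    # 1.25 / 3

print(ksd_shifted, closed_form)   # should agree up to Monte Carlo error
print(ksd_matched)                # should be close to 0
```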
4 Stein Variational Gradient Descent
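The next bit comes from Q. Liu and Wang (2019); it turns out that we can use this Stein trick to sample from some interesting distributions, by using the Stein discrepancy as a loss function. In particular, this works for posterior distributions, because the method only touches the target through its score \(\nabla_x\log p\), which does not depend on the (usually intractable) normalising constant.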
We manufacture an empirical \(q\) by using a set of particles \(\{x_i\}_{i=1}^n\).
The gradient descent here is not SGD, where we take gradient steps by looking at examples; rather, it is a gradient descent in parameter space that moves the particles towards a good approximation of the posterior.
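For reference, the standard SVGD update moves each particle along \(\hat\phi(x)=\frac{1}{n}\sum_{j=1}^n\left[k(x_j,x)\nabla_{x_j}\log p(x_j)+\nabla_{x_j}k(x_j,x)\right]\), which is the worst-case direction \(\phi^*_{q,p}(\cdot)=\mathbb{E}_{x\sim q}[\mathcal{A}_p k(x,\cdot)]\) with \(q\) replaced by the empirical distribution of the particles. The first term drags particles towards high density via a kernel-smoothed score; the second pushes them apart. Below is a minimal NumPy sketch targeting the correlated Gaussian from Section 1 with \(\rho=0.5\). The fixed RBF bandwidth \(h=1\), step size, particle count and iteration count are all assumptions for illustration (a median heuristic for \(h\) is a common alternative), and the function names are mine.

```python
import numpy as np

rho = 0.5
Sigma_inv = np.linalg.inv(np.array([[1.0, rho], [rho, 1.0]]))

def score_p(X):
    # grad log p(x) = -Sigma^{-1} x, one row per particle (Sigma_inv is symmetric)
    return -X @ Sigma_inv

def svgd_step(X, step=0.1, h=1.0):
    n = X.shape[0]
    diff = X[:, None, :] - X[None, :, :]                # diff[j, i] = x_j - x_i
    K = np.exp(-np.sum(diff**2, axis=-1) / (2 * h**2))  # K[j, i] = k(x_j, x_i), symmetric
    grad_K = -diff / h**2 * K[:, :, None]               # grad_K[j, i] = grad_{x_j} k(x_j, x_i)
    # phi(x_i) = (1/n) sum_j [k(x_j, x_i) score(x_j) + grad_{x_j} k(x_j, x_i)]
    phi = (K.T @ score_p(X) + grad_K.sum(axis=0)) / n
    return X + step * phi

rng = np.random.default_rng(7)
X = rng.normal(size=(100, 2)) * 0.5 + 3.0   # particles start far from the target
for _ in range(1000):
    X = svgd_step(X)

print(X.mean(axis=0))            # should be near (0, 0)
print(X.std(axis=0))             # rough marginal spread; the target has unit marginals
print(np.corrcoef(X.T)[0, 1])    # rough correlation; the target has rho = 0.5
```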
A further trick: define a kernel over factors, and the Stein messages may then be passed locally. This was discovered simultaneously in 2018 by D. Wang, Zeng, and Liu (2018) and Zhuo et al. (2018).
Ambrogioni, Güçlü, Güçlütürk, et al. 2018. “Wasserstein Variational Inference.” In Proceedings of the 32nd International Conference on Neural Information Processing Systems. NIPS’18.
Chwialkowski, Strathmann, and Gretton. 2016. “A Kernel Test of Goodness of Fit.” In Proceedings of the 33rd International Conference on International Conference on Machine Learning - Volume 48. ICML’16.
Detommaso, Cui, Spantini, et al. 2018. “A Stein Variational Newton Method.” In Proceedings of the 32nd International Conference on Neural Information Processing Systems. NIPS’18.
Gorham and Mackey. 2015. “Measuring Sample Quality with Stein’s Method.” In Proceedings of the 28th International Conference on Neural Information Processing Systems - Volume 1. NIPS’15.
Liu, Xing, Zhu, Ton, et al. 2022. “Grassmann Stein Variational Gradient Descent.” In Proceedings of The 25th International Conference on Artificial Intelligence and Statistics.