Neural Estimation via DV bound
\(\def\abs#1{\left\lvert #1 \right\rvert} \def\Set#1{\left\{ #1 \right\}} \def\mc#1{\mathcal{#1}} \def\M#1{\boldsymbol{#1}} \def\R#1{\mathsf{#1}} \def\RM#1{\boldsymbol{\mathsf{#1}}} \def\op#1{\operatorname{#1}} \def\E{\op{E}} \def\d{\mathrm{\mathstrut d}}\)
Estimating MI well neither requires nor implies that the divergence/density is estimated well. However,
MI estimation is often not the end goal, but rather an objective used to train a neural network that returns the divergence/density.
The features/representations learned by the neural network may be applicable to different downstream inference tasks.
Neural estimation of KL divergence
To explain the idea of neural estimation, consider the following characterization of divergence:
Proposition 3
\[
D(P_{\R{Z}}\|P_{\R{Z}'}) = \sup_{Q\in \mc{P}(\mc{Z})} \E\left[\log \frac{dQ(\R{Z})}{dP_{\R{Z}'}(\R{Z})}\right] \tag{7}
\]
where the unique optimal solution is \(Q=P_{\R{Z}}\).
(7) is (2) but with \(P_{\R{Z}}\) replaced by a parameter \(Q\) to be optimized.
The proposition essentially gives a tight lower bound on the KL divergence: the expectation in (7) is at most the divergence for every choice of \(Q\).
The unknown distribution is recovered as the unique optimal solution.
Proof. To prove (7), note that for any \(Q\in\mc{P}(\mc{Z})\),
\[
D(P_{\R{Z}}\|P_{\R{Z}'}) - \E\left[\log \frac{dQ(\R{Z})}{dP_{\R{Z}'}(\R{Z})}\right] = \E\left[\log \frac{dP_{\R{Z}}(\R{Z})}{dQ(\R{Z})}\right] = D(P_{\R{Z}}\|Q) \geq 0,
\]
which holds with equality if and only if \(Q=P_{\R{Z}}\), by the non-negativity of divergence. □
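For concreteness, here is a small worked example (the Gaussian choices below are illustrative and not part of the proposition): let \(P_{\R{Z}} = \mc{N}(\mu, 1)\), \(P_{\R{Z}'} = \mc{N}(0, 1)\), and restrict \(Q = \mc{N}(\nu, 1)\). Then
\[
\E\left[\log \frac{dQ(\R{Z})}{dP_{\R{Z}'}(\R{Z})}\right] = \E\left[\nu \R{Z} - \frac{\nu^2}{2}\right] = \nu\mu - \frac{\nu^2}{2},
\]
which is maximized at \(\nu=\mu\), attaining \(\frac{\mu^2}{2} = D(P_{\R{Z}}\|P_{\R{Z}'})\) as the proposition predicts.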
The idea of neural estimation is to

- estimate the expectation in (7) by the sample average
\[
\frac{1}{n}\sum_{i=1}^{n} \log \frac{dQ(\R{Z}_i)}{dP_{\R{Z}'}(\R{Z}_i)}
\]
where \(\R{Z}_1,\dots,\R{Z}_n\) are i.i.d. samples of \(P_{\R{Z}}\), and

- use a neural network to compute the density ratio \(\frac{dQ}{dP_{\R{Z}'}}\) in (7), and train the network to maximize the expectation, e.g., by gradient ascent on the above sample average.
Since (7) holds for an arbitrary \(Q\), the sample average above is a valid estimate of a lower bound on the divergence for any choice of \(Q\), and maximizing it over \(Q\) tightens the bound.
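As a minimal sketch of such a sample-average estimate (assuming, for illustration only, the Gaussian example above with known densities; the variable names are hypothetical):

```python
import numpy as np

rng = np.random.default_rng(0)
mu, nu, n = 1.0, 0.8, 100_000        # P_Z = N(mu, 1), P_Z' = N(0, 1), Q = N(nu, 1)

z = rng.normal(mu, 1.0, size=n)      # i.i.d. samples of P_Z

# log density ratio log dQ(z)/dP_Z'(z) for unit-variance Gaussians: nu*z - nu^2/2
log_ratio = nu * z - nu**2 / 2

estimate = log_ratio.mean()          # sample-average estimate of the lower bound in (7)
print(estimate, "<~", mu**2 / 2)     # true divergence D(P_Z || P_Z') = mu^2 / 2
```

Re-running with \(\nu\) closer to \(\mu\) (i.e., \(Q\) closer to the optimal \(P_{\R{Z}}\)) moves the estimate closer to the true divergence.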
But how to compute the density ratio?
We will first consider estimating the KL divergence \(D(P_{\R{Z}}\|P_{\R{Z}'})\) when both \(P_{\R{Z}}\) and \(P_{\R{Z}'}\) are unknown.
Donsker-Varadhan formula
If \(P_{\R{Z}'}\) is unknown, we can apply the change of variable
\[
r(z) := \frac{dQ(z)}{dP_{\R{Z}'}(z)}, \tag{8}
\]
which absorbs the unknown reference into the parameter.
Proposition 4
\[
D(P_{\R{Z}}\|P_{\R{Z}'}) = \sup_{\substack{r:\mc{Z}\to \mathbb{R}_{\geq 0}\\ \E[r(\R{Z}')]=1}} \E\left[\log r(\R{Z})\right] \tag{9}
\]
where the optimal \(r\) satisfies \( r(\R{Z}) = \frac{dP_{\R{Z}}(\R{Z})}{dP_{\R{Z}'}(\R{Z})}. \)
Exercise
Show using (8) that the optimal solution satisfies the constraint stated in the supremum (9).
Solution
The constraint on \(r\) is obtained from the constraint on \(Q\in \mc{P}(\mc{Z})\), i.e., with \(dQ(z)=r(z)dP_{\R{Z}'}(z)\),
\[
r \geq 0 \quad\text{and}\quad \E[r(\R{Z}')] = \int_{\mc{Z}} r\, dP_{\R{Z}'} = \int_{\mc{Z}} dQ = 1.
\]
The optimal solution \(r(\R{Z})=\frac{dP_{\R{Z}}(\R{Z})}{dP_{\R{Z}'}(\R{Z})}\) corresponds to \(Q=P_{\R{Z}}\in\mc{P}(\mc{Z})\) and therefore satisfies this constraint.
The next step is to train a neural network that computes \(r\). What about maximizing the sample average of \(\log r\) directly, i.e.,
\[
\sup_{\substack{r:\mc{Z}\to \mathbb{R}_{\geq 0}\\ \E[r(\R{Z}')]=1}} \frac{1}{n}\sum_{i=1}^{n} \log r(\R{Z}_i)? \tag{10}
\]
How to impose the constraint on \(r\) when training a neural network?
We can apply a change of variable: for a real-valued function \(t:\mc{Z}\to \mathbb{R}\), let
\[
r(z) := \frac{e^{t(z)}}{\E\left[e^{t(\R{Z}')}\right]}. \tag{11}
\]
Exercise
Show that \(r\) defined in (11) satisfies the constraint in (7) for every real-valued function \(t:\mc{Z}\to \mathbb{R}\).
Solution
Since \(e^{t(z)}>0\), we have \(r\geq 0\), and
\[
\E[r(\R{Z}')] = \frac{\E\left[e^{t(\R{Z}')}\right]}{\E\left[e^{t(\R{Z}')}\right]} = 1,
\]
so by (8) the corresponding \(Q\) with \(dQ(z)=r(z)dP_{\R{Z}'}(z)\) is a probability measure, i.e., the constraint in (7) (and the constraint in the supremum (9)) is satisfied.
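A tiny numerical check of this (assuming, purely for illustration, \(P_{\R{Z}'}=\mc{N}(0,1)\) and an arbitrary \(t\), with the expectation in (11) replaced by a sample average):

```python
import numpy as np

rng = np.random.default_rng(0)
z_ref = rng.normal(0.0, 1.0, size=100_000)       # i.i.d. samples of P_Z'

t = lambda x: np.sin(3 * x) + x**2                # an arbitrary real-valued function t
r = np.exp(t(z_ref)) / np.exp(t(z_ref)).mean()    # (11) with a sample-average normalization

print(r.min() >= 0, r.mean())                     # non-negative and normalized to (empirical) mean 1
```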
Substituting (11) into (7) gives the well-known Donsker-Varadhan (DV) formula [1]:
Corollary 5 (Donsker-Varadhan)
\[
D(P_{\R{Z}}\|P_{\R{Z}'}) = \sup_{t:\mc{Z}\to \mathbb{R}} \E[t(\R{Z})] - \log \E\left[e^{t(\R{Z}')}\right] \tag{12}
\]
where the optimal \(t\) satisfies
\[
t(\R{Z}) = c + \log \frac{dP_{\R{Z}}(\R{Z})}{dP_{\R{Z}'}(\R{Z})}
\]
almost surely for some constant \(c\).
The divergence can be estimated as follows instead of (10):
\[
\frac{1}{n}\sum_{i=1}^{n} t(\R{Z}_i) - \log \frac{1}{n'}\sum_{i=1}^{n'} e^{t(\R{Z}'_i)} \tag{13}
\]
where \(\R{Z}_1,\dots,\R{Z}_n\) and \(\R{Z}'_1,\dots,\R{Z}'_{n'}\) are i.i.d. samples of \(P_{\R{Z}}\) and \(P_{\R{Z}'}\) respectively.
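To illustrate (13) numerically (again with the hypothetical Gaussian example, and with \(t\) set to the optimal solution of Corollary 5 up to an additive constant, which cancels in (13)):

```python
import numpy as np

rng = np.random.default_rng(0)
mu, n = 1.0, 100_000
z = rng.normal(mu, 1.0, size=n)        # i.i.d. samples of P_Z = N(mu, 1)
z_ref = rng.normal(0.0, 1.0, size=n)   # i.i.d. samples of P_Z' = N(0, 1)

t = lambda x: mu * x                   # optimal t up to an additive constant
estimate = t(z).mean() - np.log(np.exp(t(z_ref)).mean())
print(estimate, "vs true divergence", mu**2 / 2)
```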
In summary, the neural estimation of KL divergence is a sample average of (2), but with the unknown density ratio replaced by (11), where \(t\) is trained as a neural network.
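Below is a minimal training sketch of this procedure, assuming PyTorch and the same toy Gaussian data as above; the architecture, optimizer settings, and names such as `dv_estimate` are illustrative assumptions rather than a prescribed implementation:

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy data (illustrative assumption): P_Z = N(1, 1) and P_Z' = N(0, 1),
# so the true divergence D(P_Z || P_Z') is 0.5.
n = 10_000
z = torch.randn(n, 1) + 1.0          # i.i.d. samples of P_Z
z_ref = torch.randn(n, 1)            # i.i.d. samples of P_Z'

# Neural network computing t: Z -> R in the DV formula.
t_net = nn.Sequential(nn.Linear(1, 64), nn.ReLU(), nn.Linear(64, 1))
optimizer = torch.optim.Adam(t_net.parameters(), lr=1e-3)

def dv_estimate(z_samples, z_ref_samples):
    """Sample-average DV bound (13): mean of t(Z) minus log of the sample mean of exp t(Z')."""
    mean_t = t_net(z_samples).mean()
    log_mean_exp = (torch.logsumexp(t_net(z_ref_samples).flatten(), dim=0)
                    - math.log(len(z_ref_samples)))
    return mean_t - log_mean_exp

# Gradient ascent on the DV bound, implemented as descent on its negation.
for step in range(500):
    optimizer.zero_grad()
    loss = -dv_estimate(z, z_ref)
    loss.backward()
    optimizer.step()

print(f"DV estimate: {dv_estimate(z, z_ref).item():.3f} (true divergence: 0.5)")
```

For simplicity the sketch uses full-batch gradient ascent on (13); mini-batch training follows the same pattern.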