Extension of DV Formula

\(\def\abs#1{\left\lvert #1 \right\rvert} \def\Set#1{\left\{ #1 \right\}} \def\mc#1{\mathcal{#1}} \def\M#1{\boldsymbol{#1}} \def\R#1{\mathsf{#1}} \def\RM#1{\boldsymbol{\mathsf{#1}}} \def\op#1{\operatorname{#1}} \def\E{\op{E}} \def\d{\mathrm{\mathstrut d}}\)

\(f\)-divergence

Consider the more general problem of estimating the \(f\)-divergence in (3):

\[ D_f(P_{\R{Z}}\|P_{\R{Z}'}) = E\left[f\left(\frac{dP_{\R{Z}}(\R{Z}')}{dP_{\R{Z}'}(\R{Z}')}\right)\right]. \]

Exercise

How to estimate \(f\)-divergence using the DV formula?

Solution

Note that the neural network approximates the density ratio (5),

\[ t \approx \frac{dP_{\R{Z}}}{dP_{\R{Z}'}}, \]

which can then be used to evaluate the sample average (4),

\[ D_f(P_{\R{Z}}\| P_{\R{Z}'}) \approx \frac{1}{n} \sum_{i\in [n]} f\left(t(\R{Z}'_i)\right), \]
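where \(\R{Z}'_i\) are samples drawn from \(P_{\R{Z}'}\). As a minimal sketch (not part of the original notes), the plug-in estimate can be computed as follows, assuming a callable `t` approximating the density ratio, i.i.d. samples `z_prime_samples` drawn from \(P_{\R{Z}'}\), and the convex function `f` defining the divergence; all names here are hypothetical:

```python
import numpy as np

def estimate_f_divergence(f, t, z_prime_samples):
    """Plug-in estimate of D_f(P_Z || P_Z') from samples of Z' ~ P_Z'.

    f : convex function defining the f-divergence (e.g. lambda u: u * np.log(u) for KL)
    t : (trained) approximation of the density ratio dP_Z/dP_Z', mapping z to a positive scalar
    z_prime_samples : i.i.d. samples drawn from P_Z'
    """
    # Evaluate the approximate density ratio at each reference sample
    ratios = np.array([t(z) for z in z_prime_samples])
    # Average f(ratio) over the reference samples, i.e. the sample average above
    return float(np.mean(f(ratios)))

if __name__ == "__main__":
    # Sanity check with the exact density ratio of two Gaussians: dN(1,1)/dN(0,1)(z) = exp(z - 1/2)
    rng = np.random.default_rng(0)
    z_prime = rng.normal(loc=0.0, scale=1.0, size=100_000)  # samples from P_Z' = N(0, 1)
    ratio = lambda z: np.exp(z - 0.5)
    kl = estimate_f_divergence(lambda u: u * np.log(u), ratio, z_prime)
    print(kl)  # ≈ 0.5, the KL divergence D(N(1,1) || N(0,1))
```

The sanity check uses the exact density ratio between \(N(1,1)\) and \(N(0,1)\), for which the KL divergence is \(1/2\).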

Instead of using the DV bound, it is desirable to train a network to optimize a tight bound on the \(f\)-divergence because:

  • Estimating KL divergence well does not imply the underlying neural network approximates the density ratio well:

    • While KL divergence is just a non-negative real number,

    • the density ratio is in a high dimensional function space.

  • The DV formula does not directly maximize a bound on the \(f\)-divergence, i.e.,
    it does not directly minimize the error in estimating the \(f\)-divergence.

  • The \(f\)-divergence may have bounds that are easier or more stable to optimize when training a neural network.

How to extend the DV formula to \(f\)-divergence?

The idea is to think of the \(f\)-divergence as a convex function(al) evaluated at the density ratio:

Proposition 6

The \(f\)-divergence (3) can be written as

(15)\[ \begin{align} D_f(P_{\R{Z}}\|P_{\R{Z}'}) = F\left[ \frac{dP_{\R{Z}}}{dP_{\R{Z}'}}\right] \end{align} \]

where

(16)\[ \begin{align} F[r] := E [ \underbrace{(f \circ r)(\R{Z}')}_{f(r(\R{Z}'))}] \end{align} \]

for any function \(r:\mc{Z} \to \mathbb{R}\).

This is more like a re-definition than a proposition, as the proof is immediate:
substituting \(r=\frac{dP_{\R{Z}}}{dP_{\R{Z}'}}\) into (16) recovers the expectation in (3), which establishes (15).
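
Written out, the substitution reads:

\[ F\left[\frac{dP_{\R{Z}}}{dP_{\R{Z}'}}\right] = E\left[f\left(\frac{dP_{\R{Z}}}{dP_{\R{Z}'}}(\R{Z}')\right)\right] = D_f(P_{\R{Z}}\|P_{\R{Z}'}). \]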

As mentioned before, the KL divergence \(D(P_{\R{Z}}\|P_{\R{Z}'})\) is a special case of \(f\)-divergence:

\[ D(P_{\R{Z}}\|P_{\R{Z}'}) = F\left[\frac{dP_{\R{Z}}}{dP_{\R{Z}'}}\right] \]

where

(17)\[ \begin{align*} F[r] &:= E\left[ r(\R{Z}')\log r(\R{Z}')\right]. \end{align*} \]
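
Indeed, with \(f(u)=u\log u\), evaluating (17) at the density ratio and changing the measure from \(P_{\R{Z}'}\) to \(P_{\R{Z}}\) recovers the KL divergence:

\[\begin{split} \begin{align*} F\left[\frac{dP_{\R{Z}}}{dP_{\R{Z}'}}\right] &= E\left[\frac{dP_{\R{Z}}}{dP_{\R{Z}'}}(\R{Z}')\log \frac{dP_{\R{Z}}}{dP_{\R{Z}'}}(\R{Z}')\right]\\ &= E\left[\log \frac{dP_{\R{Z}}}{dP_{\R{Z}'}}(\R{Z})\right] = D(P_{\R{Z}}\|P_{\R{Z}'}). \end{align*} \end{split}\]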

Exercise

When is \(F[r]\) equal to \(0\)?

Solution

Restricting to valid density ratios, i.e., \(r\geq 0\) with \(E[r(\R{Z}')]=1\), Jensen's inequality applied to the convex function \(u\mapsto u\log u\) gives \(F[r]\geq E[r(\R{Z}')]\log E[r(\R{Z}')]=0\), with equality iff \(r(\R{Z}')=1\) almost surely. We may write the condition more explicitly as:

\[ r(z) = 1 \qquad \forall z\in \op{supp}(p_{\R{Z}'}):= \Set{z\in \mc{Z}\mid p_{\R{Z}'}(z) > 0}, \]

namely, \(r\) is the constant function \(1\) over the support set of \(\R{Z}'\).

Exercise

Show using the properties of \(f\) that \(F\) is strictly convex.

Solution

For \(\lambda\in [0,1]\) and functions \(r_1, r_2\in \Set{r:\mc{Z}\to \mathbb{R}}\),

\[\begin{split} \begin{align*} F[\lambda r_1 + (1-\lambda) r_2] &= E[\underbrace{ f(\lambda r_1(\R{Z}') + (1-\lambda) r_2(\R{Z}'))}_{\stackrel{\text{(a)}}{\leq} \lambda f(r_1(\R{Z}'))+(1-\lambda) f(r_2(\R{Z}'))}]\\ &\stackrel{\text{(b)}}{\leq} \lambda E[f(r_1(\R{Z}'))] + (1-\lambda) E[f(r_2(\R{Z}'))]\\ &= \lambda F[r_1] + (1-\lambda) F[r_2] \end{align*} \end{split}\]

where (a) is by the convexity of \(f\), and (b) follows by taking expectations on both sides of (a) and using the linearity of expectation. This shows that \(F\) is convex. Moreover, \(F\) is strictly convex: for \(\lambda\in(0,1)\), equality holds in (b) iff (a) holds with equality almost surely, which, by the strict convexity of \(f\), requires \(r_1(\R{Z}')=r_2(\R{Z}')\) almost surely.

For a clearer understanding, consider a different choice of functional, denoted \(F'\), for the KL divergence:

\[ D(P_{\R{Z}}\|P_{\R{Z}'}) = F'\left[\frac{dP_{\R{Z}}}{dP_{\R{Z}'}}\right] \]

where

(18)\[ \begin{align*} F'[r] &:= E\left[ \log r(\R{Z})\right]. \end{align*} \]
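
Unlike \(F\) in (17), this functional curves the other way: by the concavity of \(\log\) and the monotonicity and linearity of expectation, for \(\lambda\in[0,1]\) and positive functions \(r_1, r_2\),

\[\begin{split} \begin{align*} F'[\lambda r_1 + (1-\lambda) r_2] &= E\left[ \log \left(\lambda r_1(\R{Z}) + (1-\lambda) r_2(\R{Z})\right)\right]\\ &\geq \lambda E[\log r_1(\R{Z})] + (1-\lambda) E[\log r_2(\R{Z})]\\ &= \lambda F'[r_1] + (1-\lambda) F'[r_2]. \end{align*} \end{split}\]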

Note that \(F'\) defined in (18) is concave in \(r\). Consequently, (9) in Proposition 4

\[\begin{split} \begin{align*} D(P_{\R{Z}}\|P_{\R{Z}'}) & = \sup_{\substack{r:\mc{Z}\to \mathbb{R}_+\\ E[r(\R{Z}')]=1}} E \left[ \log r(\R{Z}) \right] \end{align*} \end{split}\]

is the maximization of a concave functional over a convex feasible set; since \(\log\) is strictly concave, the maximizer is unique, namely, \(r=\frac{dP_{\R{Z}}}{dP_{\R{Z}'}}\). Here comes the tricky question:

Exercise

Is the KL divergence concave or convex in the density ratio \(\frac{dP_{\R{Z}}}{dP_{\R{Z}'}}\)? Note that \(F\) defined in (17) is convex in \(r\), while \(F'\) defined in (18) is concave.

Solution

The question is not well-posed: the KL divergence is not purely a function of the density ratio, but of both \(P_{\R{Z}}\) and \(P_{\R{Z}'}\). The expectation in (18), in particular, is taken with respect to \(P_{\R{Z}}\), which is why \(F\) and \(F'\) can both evaluate to the KL divergence at the density ratio and yet have opposite curvatures.

Convex conjugation

Given \(P_{\R{Z}'}\in \mc{P}(\mc{Z})\), consider

  • a function space \(\mc{R}\),

(19)\[ \begin{align} \mc{R} &\supseteq \Set{r:\mathcal{Z}\to \mathbb{R}_+\mid E\left[r(\R{Z}')\right] = 1}, \end{align} \]
  • a dual space \(\mc{T}\), and

(20)\[ \begin{align} \mc{T} &\subseteq \Set{t:\mc{Z} \to \mathbb{R}} \end{align} \]
  • the corresponding inner product \(\langle\cdot,\cdot \rangle\):

(21)\[ \begin{align} \langle t,r \rangle &= \int_{z\in \mc{Z}} t(z) r(z) dP_{\R{Z}'}(z) = E\left[ t(\R{Z}') r(\R{Z}') \right]. \end{align} \]
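
In particular, evaluating the inner product at the density ratio converts it into an expectation with respect to \(P_{\R{Z}}\), which is how the term \(E[t(\R{Z})]\) in (22) below arises:

\[ \left\langle t, \frac{dP_{\R{Z}}}{dP_{\R{Z}'}} \right\rangle = \int_{z\in \mc{Z}} t(z)\,\frac{dP_{\R{Z}}}{dP_{\R{Z}'}}(z)\, dP_{\R{Z}'}(z) = \int_{z\in \mc{Z}} t(z)\, dP_{\R{Z}}(z) = E\left[t(\R{Z})\right]. \]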

The following is a generalization of the DV formula for estimating the \(f\)-divergence [2][3]:

Proposition 7

(22)\[ \begin{align} D_{f}(P_{\R{Z}} \| P_{\R{Z}'}) = \sup _{t\in \mc{T}} E[t(\R{Z})] - F^*[t], \end{align} \]

where

(23)\[ \begin{align} F^*[t] = \sup_{r\in \mc{R}} E[t(\R{Z}') r(\R{Z}')] - F[r]. \end{align} \]

Proof. Note that the suprema in (23) and (22) are Fenchel-Legendre transforms. Denoting the transform by \([\cdot]^*\),

\[\underbrace{[[F]^*]^*}_{=F}\left[\frac{dP_{\R{Z}}}{dP_{\R{Z}'}}\right]\]

gives (22) by expanding the outer transform, with \(F^*\) in (23) being the inner transform. The equality \([[F]^*]^*=F\) holds because the Fenchel-Legendre transform is its own inverse on convex (lower semi-continuous) functionals such as \(F\). This completes the proof by (15).
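
For concreteness, expanding the outer transform and using \(\left\langle t, \frac{dP_{\R{Z}}}{dP_{\R{Z}'}}\right\rangle = E[t(\R{Z})]\) from (21),

\[\begin{split} \begin{align*} [[F]^*]^*\left[\frac{dP_{\R{Z}}}{dP_{\R{Z}'}}\right] &= \sup_{t\in \mc{T}} \left\langle t, \frac{dP_{\R{Z}}}{dP_{\R{Z}'}} \right\rangle - F^*[t]\\ &= \sup_{t\in \mc{T}} E[t(\R{Z})] - F^*[t], \end{align*} \end{split}\]

which is the right-hand side of (22).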

The proof is illustrated in the following figure:

[Figure: f-Divergence]

Let’s break down the details:

Step 1

For the purpose of illustration, visualize the convex functional \(F\) simply as a curve in 2D, with \(r\) on the horizontal axis and \(F[r]\) on the vertical axis.

[Figure: f-Divergence 1]

By (15), the \(f\)-divergence is then the \(y\)-coordinate of the point on the curve at \(r=\frac{dP_{\R{Z}}}{dP_{\R{Z}'}}\), the density ratio.

Step 2

To obtain a lower bound on \(F\), consider a tangent to the curve with slope \(t\cdot dP_{\R{Z}'}\), for an arbitrary \(t\in \mc{T}\):

[Figure: f-Divergence 2]

The lower bound is given by the \(y\)-coordinate of the point on the tangent at \(r=\frac{dP_{\R{Z}}}{dP_{\R{Z}'}}\), the density ratio.

Exercise

Why is the \(y\)-coordinate of the tangent a lower bound on the \(f\)-divergence?

Solution

By the convexity of \(F\), the curve lies on or above every tangent. In particular, at \(r=\frac{dP_{\R{Z}}}{dP_{\R{Z}'}}\) the tangent is no larger than \(F\), which equals the \(f\)-divergence by (15).

Step 3

To calculate the lower bound, denote the \(y\)-intercept as \(-F^*[t]\):

[Figure: f-Divergence 3]

Thinking of a function as nothing but a vector, the displacement from the \(y\)-intercept to the lower bound is given by the inner product of the slope and the density ratio.
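
In symbols, the tangent-line lower bound reads:

\[ D_f(P_{\R{Z}}\|P_{\R{Z}'}) = F\left[\frac{dP_{\R{Z}}}{dP_{\R{Z}'}}\right] \geq -F^*[t] + \left\langle t, \frac{dP_{\R{Z}}}{dP_{\R{Z}'}} \right\rangle = E[t(\R{Z})] - F^*[t]. \]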

Step 4

To make the bound tight, maximize the bound over the choice of the slope or \(t\):

[Figure: f-Divergence 4]

This gives the bound in (22). It remains to show (23).

Step 5

To compute the \(y\)-intercept or \(F^*[t]\), let \(r^*\) be the value of \(r\) where the tangent touches the convex curve:

[Figure: f-Divergence 5]

The (vertical) displacement from the \(y\)-intercept to the point at \(r^*\) equals the inner product of the slope and \(r^*\), from which the \(y\)-intercept, and hence \(F^*[t]\), can be computed.
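
Reading the figure off in symbols, with \(r^*\) denoting the point of tangency,

\[ -F^*[t] = F[r^*] - \langle t, r^*\rangle \quad\Longleftrightarrow\quad F^*[t] = E[t(\R{Z}')r^*(\R{Z}')] - F[r^*], \]

and since the tangent lies below the curve everywhere, \(r^*\) attains the supremum of \(\langle t,r\rangle - F[r]\), which gives (23).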

Exercise

Show that for the functional \(F\) (17) defined for KL divergence,

\[F^*[t]=\log E[e^{t(\R{Z}')}]\]

with \(\mc{R}=\Set{r:\mc{Z}\to \mathbb{R}_+\mid E[r(\R{Z}')]=1}\), and hence that (22) gives the DV formula (12) as a special case.

Solution

See [3].
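
A sketch of the calculation (assuming the supremum in (23) is attained and can be found by a Lagrange-multiplier argument for the constraint \(E[r(\R{Z}')]=1\)): the maximizer is \(r^*(z)=e^{t(z)}/E[e^{t(\R{Z}')}]\), and substituting it back gives

\[\begin{split} \begin{align*} F^*[t] &= E\left[t(\R{Z}')r^*(\R{Z}')\right] - E\left[r^*(\R{Z}')\log r^*(\R{Z}')\right]\\ &= E\left[r^*(\R{Z}')\left(t(\R{Z}') - \log r^*(\R{Z}')\right)\right]\\ &= E\left[r^*(\R{Z}')\right]\log E\left[e^{t(\R{Z}')}\right] = \log E\left[e^{t(\R{Z}')}\right], \end{align*} \end{split}\]

where the last line uses \(\log r^*(\R{Z}') = t(\R{Z}') - \log E[e^{t(\R{Z}')}]\) and \(E[r^*(\R{Z}')]=1\). Substituting this \(F^*\) into (22) recovers the DV formula (12).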