{ "cells": [ { "cell_type": "markdown", "id": "7090f537", "metadata": { "slideshow": { "slide_type": "slide" }, "tags": [] }, "source": [ "# Extension of DV Formula" ] }, { "cell_type": "markdown", "id": "0c3f3519", "metadata": {}, "source": [ "$\\def\\abs#1{\\left\\lvert #1 \\right\\rvert}\n", "\\def\\Set#1{\\left\\{ #1 \\right\\}}\n", "\\def\\mc#1{\\mathcal{#1}}\n", "\\def\\M#1{\\boldsymbol{#1}}\n", "\\def\\R#1{\\mathsf{#1}}\n", "\\def\\RM#1{\\boldsymbol{\\mathsf{#1}}}\n", "\\def\\op#1{\\operatorname{#1}}\n", "\\def\\E{\\op{E}}\n", "\\def\\d{\\mathrm{\\mathstrut d}}$" ] }, { "cell_type": "markdown", "id": "7a0d2ef8", "metadata": {}, "source": [ "## $f$-divergence" ] }, { "cell_type": "markdown", "id": "eaded920", "metadata": {}, "source": [ "Consider the more general problem of estimating the $f$-divergence in {eq}`f-D`:\n", "\n", "$$\n", "D_f(P_{\\R{Z}}\\|P_{\\R{Z}'}) = E\\left[f\\left(\\frac{dP_{\\R{Z}}(\\R{Z}')}{dP_{\\R{Z}'}(\\R{Z}')}\\right)\\right].\n", "$$" ] }, { "cell_type": "markdown", "id": "41091e78", "metadata": {}, "source": [ "**Exercise** \n", "\n", "How to estimate $f$-divergence using the DV formula?" ] }, { "cell_type": "markdown", "id": "d23fa1a8", "metadata": { "nbgrader": { "grade": true, "grade_id": "DV-f-D", "locked": false, "points": 1, "schema_version": 3, "solution": true, "task": false } }, "source": [ "````{toggle}\n", "**Solution**\n", "\n", "Note that the neural network approximates the density ratio {eq}`dP-ratio`, \n", "\n", "$$\n", "t \\approx \\frac{dP_{\\R{Z}}}{dP_{\\R{Z}'}},\n", "$$\n", "\n", "which can then be used to evaluate the sample average {eq}`avg-f-D`,\n", "\n", "$$\n", "D_f(P_{\\R{Z}}\\| P_{\\R{Z}'}) \\approx \\frac{1}{n} \\sum_{i\\in [n]} f\\left(t(\\R{Z})\\right).\n", "$$\n", "\n", "````" ] }, { "cell_type": "markdown", "id": "e122e4c4", "metadata": {}, "source": [ "Instead of using the DV bound, it is desirable to train a network to optimize a tight bound on the $f$-divergence because:" ] }, { "cell_type": "markdown", "id": "469ec6a6", "metadata": {}, "source": [ "- *Estimating KL divergence well does not imply the underlying neural network approximates the density ratio well*: \n", " - While KL divergence is just a non-negative real number, \n", " - the density ratio is in a high dimensional function space." ] }, { "cell_type": "markdown", "id": "caa9a900", "metadata": {}, "source": [ "- DV formula does not directly maximizes a bound on $f$-divergence, i.e. \n", " it does not directly minimize the error in estimating $f$-divergence." ] }, { "cell_type": "markdown", "id": "bbabcd82", "metadata": {}, "source": [ "- $f$-divergence may have bounds that are easier/more stable for training a neural network." 
] }, { "cell_type": "markdown", "id": "c85331f4", "metadata": {}, "source": [ "**How to extend the DV formula to $f$-divergence?**" ] }, { "cell_type": "markdown", "id": "71ba6086", "metadata": {}, "source": [ "The idea is to think of the $f$-divergence as a convex *function(al)* evaluated at the density ratio:" ] }, { "cell_type": "markdown", "id": "871bdba9", "metadata": {}, "source": [ "````{prf:proposition} \n", ":label: D->F\n", "\n", "$f$-divergence {eq}`f-D` is\n", "\n", "$$\n", "\\begin{align}\n", "D_f(P_{\\R{Z}}\\|P_{\\R{Z}'}) = F\\left[ \\frac{P_{\\R{Z}}}{P_{\\R{Z}'}}\\right]\n", "\\end{align}\n", "$$ (D->F)\n", "\n", "where \n", "\n", "$$\n", "\\begin{align}\n", "F[r] := E [ \\underbrace{(f \\circ r)(\\R{Z}')}_{f(r(\\R{Z}'))}]\n", "\\end{align}\n", "$$ (F)\n", "\n", "for any function $r:\\mc{Z} \\to \\mathbb{R}$.\n", "\n", "````" ] }, { "cell_type": "markdown", "id": "9f00ecfb", "metadata": {}, "source": [ "This is more like a re-definition than a proposition as the proof is immediate: \n", "{eq}`f-D` is obtained from {eq}`D->F` by substituting $r=\\frac{dP_{\\R{Z}}}{dP_{\\R{Z}'}}$." ] }, { "cell_type": "markdown", "id": "dccb6b5c", "metadata": {}, "source": [ "As mentioned before, the KL divergence $D(P_{\\R{Z}}\\|P_{\\R{Z}'})$ is a special case of $f$-divergence:\n", "\n", "$$\n", "D(P_{\\R{Z}}\\|P_{\\R{Z}'}) = F\\left[r\\right]\n", "$$ \n", "\n", "where\n", "\n", "$$\n", "\\begin{align*}\n", "F[r] &:= E\\left[ r(\\R{Z}')\\log r(\\R{Z}')\\right].\n", "\\end{align*}\n", "$$ (KL:F)" ] }, { "cell_type": "markdown", "id": "e9f76e67", "metadata": {}, "source": [ "**Exercise** When is $F[r]=0$ equal to $0$?" ] }, { "cell_type": "markdown", "id": "2b399394", "metadata": { "nbgrader": { "grade": true, "grade_id": "zero-F", "locked": false, "points": 1, "schema_version": 3, "solution": true, "task": false } }, "source": [ "````{toggle}\n", "**Solution** \n", "\n", "$F[r]=0$ iff $f(r(\\R{Z}'))=0$ almost surely, which happens iff $r(\\R{Z}')=1$ almost surely. We may also write it more explicitly as:\n", "\n", "$$\n", "r(z) = z \\qquad \\forall z\\in \\op{supp}(p_{\\R{Z}'}):= \\Set{z\\in \\mc{Z}\\mid p_{\\R{Z}'}(z) > 0},\n", "$$\n", "\n", "namely an identity function over the support set of $\\R{Z}'$.\n", "\n", "````" ] }, { "cell_type": "markdown", "id": "b2c68cb0", "metadata": {}, "source": [ "**Exercise** \n", "\n", "Show using the properties of $f$ that $F$ is strictly convex." ] }, { "cell_type": "markdown", "id": "a5e9c68c", "metadata": {}, "source": [ "````{toggle}\n", "**Solution**\n", "\n", "For $\\lambda\\in [0,1]$ and functions $r_1, r_2\\in \\Set{r:\\mc{Z}\\to \\mathbb{R}}$,\n", "\n", "$$\n", "\\begin{align*}\n", "F[\\lambda r_1 + (1-\\lambda) r_2] \n", "&= E[\\underbrace{ f(\\lambda r_1(\\R{Z}') + (1-\\lambda) r_2(\\R{Z}'))}_{\\stackrel{\\text{(a)}}{\\geq} \\lambda f(r_1(\\R{Z}'))+(1-\\lambda) f(r_2(\\R{Z}'))}]\\\\\n", "&\\stackrel{\\text{(b)}}{\\geq} \\lambda E[f(r_1(\\R{Z}'))] + (1-\\lambda) E[f(r_2(\\R{Z}'))]\n", "\\end{align*}\n", "$$\n", "\n", "where (a) is by the convexity of $f$, and (b) is by the linearity of expectation. 
] },
{ "cell_type": "markdown", "id": "787242a1", "metadata": {}, "source": [ "For a clearer understanding, consider a different functional $F'$ for the KL divergence:\n", "\n", "$$\n", "D(P_{\\R{Z}}\\|P_{\\R{Z}'}) = F'\\left[r\\right]\n", "$$ \n", "\n", "where\n", "\n", "$$\n", "\\begin{align*}\n", "F'[r] &:= E\\left[ \\log r(\\R{Z})\\right].\n", "\\end{align*}\n", "$$ (rev-KL:F)" ] },
{ "cell_type": "markdown", "id": "a9875417", "metadata": {}, "source": [ "Note that $F'$ in {eq}`rev-KL:F` defined above is concave in $r$. In other words, {eq}`D2` in {prf:ref}`DV2`\n", "\n", "$$\n", "\\begin{align*}\n", "D(P_{\\R{Z}}\\|P_{\\R{Z}'}) & = \\sup_{\\substack{r:\\mc{Z}\\to \\mathbb{R}_+\\\\ E[r(\\R{Z}')]=1}} E \\left[ \\log r(\\R{Z}) \\right] \n", "\\end{align*}\n", "$$\n", "\n", "maximizes a concave functional, and the maximum is attained at $r=\\frac{dP_{\\R{Z}}}{dP_{\\R{Z}'}}$. Here comes the tricky question:" ] },
{ "cell_type": "markdown", "id": "ad6e03b0", "metadata": {}, "source": [ "**Exercise**\n", "\n", "Is the KL divergence concave or convex in the density ratio $\\frac{dP_{\\R{Z}}}{dP_{\\R{Z}'}}$? Note that $F$ defined in {eq}`KL:F` is convex in $r$." ] },
{ "cell_type": "markdown", "id": "d5d01916", "metadata": { "nbgrader": { "grade": true, "grade_id": "rev-KL", "locked": false, "points": 1, "schema_version": 3, "solution": true, "task": false } }, "source": [ "````{toggle}\n", "**Solution**\n", "\n", "The question is ill-posed because the KL divergence is not purely a function of the density ratio but of both $P_{\\R{Z}}$ and $P_{\\R{Z}'}$. The expectation in {eq}`rev-KL:F`, in particular, is taken with respect to $P_{\\R{Z}}$, which is why the same divergence can be written both as the convex functional {eq}`KL:F` and as the concave functional {eq}`rev-KL:F` of $r$.\n", "\n", "````" ] },
{ "cell_type": "markdown", "id": "8d115a5b", "metadata": {}, "source": [ "## Convex conjugation" ] },
{ "cell_type": "markdown", "id": "27c6a617", "metadata": {}, "source": [ "Given $P_{\\R{Z}'}\\in \\mc{P}(\\mc{Z})$, consider \n", "- a function space $\\mc{R}$, \n", "\n", "$$\n", "\\begin{align}\n", "\\mc{R} &\\supseteq \\Set{r:\\mathcal{Z}\\to \\mathbb{R}_+\\mid E\\left[r(\\R{Z}')\\right] = 1},\n", "\\end{align}\n", "$$ (R)\n", "\n", "- a dual space $\\mc{T}$, and \n", "\n", "$$\n", "\\begin{align}\n", "\\mc{T} &\\subseteq \\Set{t:\\mc{Z} \\to \\mathbb{R}}\n", "\\end{align}\n", "$$ (T)\n", "\n", "- the corresponding inner product $\\langle\\cdot,\\cdot \\rangle$:\n", "\n", "$$\n", "\\begin{align}\n", "\\langle t,r \\rangle &= \\int_{z\\in \\mc{Z}} t(z) r(z) dP_{\\R{Z}'}(z) = E\\left[ t(\\R{Z}') r(\\R{Z}') \\right].\n", "\\end{align}\n", "$$ (inner-prod)" ] },
{ "cell_type": "markdown", "id": "407b3d30", "metadata": {}, "source": [ "The following is a generalization of the DV formula for estimating the $f$-divergence {cite}`nguyen2010estimating`{cite}`ruderman2012tighter`:" ] },
{ "cell_type": "markdown", "id": "afab2159", "metadata": {}, "source": [ "````{prf:proposition} \n", ":label: convex-conjugate\n", "\n", "$$\n", "\\begin{align}\n", "D_{f}(P_{\\R{Z}} \\| P_{\\R{Z}'}) = \\sup _{t\\in \\mc{T}} E[t(\\R{Z})] - F^*[t],\n", "\\end{align} \n", "$$ (convex-conjugate2)\n", "\n", "where \n", "\n", "$$\n", "\\begin{align}\n", "F^*[t] = \\sup_{r\\in \\mc{R}} E[t(\\R{Z}') r(\\R{Z}')] - F[r].\n", "\\end{align}\n", "$$ (convex-conjugate1)\n", "\n", "````" ] },
{ "cell_type": "markdown", "id": "f96c87a5", "metadata": {}, "source": [ "````{prf:proof} \n", "\n", "Note that the supremums in {eq}`convex-conjugate1` and {eq}`convex-conjugate2` are [Fenchel-Legendre transforms][FL]. Denoting the transform as $[\\cdot]^*$,\n",
"\n", "$$\\underbrace{[[F]^*]^*}_{=F}\\left[\\frac{dP_{\\R{Z}}}{dP_{\\R{Z}'}}\\right]$$\n", "\n", "gives {eq}`convex-conjugate2` by expanding the outer (second) transform. The equality holds because the Fenchel-Legendre transform is its own inverse for a strictly convex functional $F$. This completes the proof by {eq}`D->F`.\n", "\n", "[FL]: https://en.wikipedia.org/wiki/Convex_conjugate\n", "\n", "````" ] },
{ "cell_type": "markdown", "id": "5fb95daf", "metadata": {}, "source": [ "The proof is illustrated in the following figure:" ] },
{ "cell_type": "markdown", "id": "1ab202ad", "metadata": {}, "source": [ "![$f$-Divergence](images/f-D.dio.svg)" ] },
{ "cell_type": "markdown", "id": "de73f070", "metadata": {}, "source": [ "Let's break down the details:" ] },
{ "cell_type": "markdown", "id": "6e7ea66e", "metadata": {}, "source": [ "**Step 1**\n", "\n", "For the purpose of the illustration, visualize the convex functional $F$ simply as a curve in 2D.\n", "\n", "![$f$-Divergence 1](images/f-D-Copy1.dio.svg) \n", "\n", "The $f$-divergence is then the $y$-coordinate of a point on the curve indicated above, with $r$ being the density ratio $\\frac{dP_{\\R{Z}}}{dP_{\\R{Z}'}}$." ] },
{ "cell_type": "markdown", "id": "d2bd307c", "metadata": {}, "source": [ "**Step 2**\n", "\n", "To obtain a lower bound on $F$, consider any tangent of the curve with an arbitrary slope $t\\cdot dP_{\\R{Z}'}$:\n", "\n", "![$f$-Divergence 2](images/f-D-Copy2.dio.svg)\n", "\n", "The lower bound is given by the $y$-coordinate of a point on the tangent with $r$ being the density ratio." ] },
{ "cell_type": "markdown", "id": "64b5c6ab", "metadata": {}, "source": [ "**Exercise**\n", "\n", "Why is the $y$-coordinate of the tangent a lower bound on the $f$-divergence?" ] },
{ "cell_type": "markdown", "id": "1032d063", "metadata": {}, "source": [ "````{toggle}\n", "**Solution**\n", "\n", "By the convexity of $F$, the curve lies above all of its tangents, so the tangent's $y$-coordinate at the density ratio is at most $F\\left[\\frac{dP_{\\R{Z}}}{dP_{\\R{Z}'}}\\right]$, the $f$-divergence.\n", "\n", "````" ] },
{ "cell_type": "markdown", "id": "a18de647", "metadata": {}, "source": [ "**Step 3**\n", "\n", "To calculate the lower bound, denote the $y$-intercept as $-F^*[t]$:\n", "\n", "![$f$-Divergence 3](images/f-D-Copy3.dio.svg) \n", "\n", "Thinking of a function as nothing but a vector, the displacement from the $y$-intercept to the lower bound is given by the inner product of the slope and the density ratio." ] },
{ "cell_type": "markdown", "id": "b582faa7", "metadata": {}, "source": [ "**Step 4**\n", "\n", "To make the bound tight, maximize the bound over the choice of the slope or $t$:\n", "\n", "![$f$-Divergence 4](images/f-D-Copy4.dio.svg) \n", "\n", "This gives the bound in {eq}`convex-conjugate2`. It remains to show {eq}`convex-conjugate1`." ] },
{ "cell_type": "markdown", "id": "31349dd2", "metadata": {}, "source": [ "**Step 5**\n", "\n", "To compute the $y$-intercept or $F^*[t]$, let $r^*$ be the value of $r$ where the tangent touches the convex curve:\n", "\n", "![$f$-Divergence 5](images/f-D.dio.svg) \n", "\n", "The displacement from the point at $r^*$ to the $y$-intercept can be computed as the inner product of the slope and $r^*$, so $F^*[t] = \\langle t, r^*\\rangle - F[r^*]$. Since the tangent point $r^*$ maximizes $\\langle t, r\\rangle - F[r]$ over $r$, this is {eq}`convex-conjugate1`."
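] },
{ "cell_type": "markdown", "id": "b4e0c0d1", "metadata": {}, "source": [ "The following minimal sketch, assuming a Gaussian toy example and the form $F^*[t]=\\log E[e^{t(\\R{Z}')}]$ derived in the exercise below, evaluates the bound in {eq}`convex-conjugate2` for the KL divergence. For a few choices of $t$, the sample version of $E[t(\\R{Z})]-\\log E[e^{t(\\R{Z}')}]$ stays below the true divergence and becomes tight, up to sampling error, at $t=\\log\\frac{dP_{\\R{Z}}}{dP_{\\R{Z}'}}$. The distributions, sample size, and candidate functions are assumptions made for this sketch only." ] },
{ "cell_type": "code", "execution_count": null, "id": "b4e0c0d2", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "rng = np.random.default_rng(2)\n", "n, mu = 200_000, 1.0\n", "z = mu + rng.standard_normal(n)   # i.i.d. samples of Z ~ N(mu, 1) (assumed)\n", "z_prime = rng.standard_normal(n)  # i.i.d. samples of Z' ~ N(0, 1) (assumed)\n", "\n", "def dv_bound(t):\n", "    # Sample version of E[t(Z)] - log E[exp(t(Z'))].\n", "    return t(z).mean() - np.log(np.mean(np.exp(t(z_prime))))\n", "\n", "candidates = {\n", "    't = 0': lambda x: np.zeros_like(x),\n", "    't = z/2': lambda x: x / 2,\n", "    't = log density ratio': lambda x: mu * x - mu**2 / 2,  # the optimal slope\n", "}\n", "print('true KL divergence:', mu**2 / 2)\n", "for name, t in candidates.items():\n", "    print(name, '-> bound =', dv_bound(t))" ] },
{ "cell_type": "markdown", "id": "b4e0c0d3", "metadata": {}, "source": [ "Training a neural network $t$ to maximize this bound recovers the DV approach to estimating the KL divergence; the exercise below makes the connection precise."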
] }, { "cell_type": "markdown", "id": "d39a3b80", "metadata": {}, "source": [ "**Exercise**\n", "\n", "Show that for the functional $F$ {eq}`KL:F` defined for KL divergence,\n", "\n", "$$F^*[t]=\\log E[e^{t(\\R{Z}')}]$$\n", "\n", "with $\\mc{R}=\\Set{r:\\mc{Z}\\to \\mathbb{R}_+}$ and so {eq}`convex-conjugate2` gives the DV formula {eq}`DV` as a special case." ] }, { "cell_type": "markdown", "id": "d400cf39", "metadata": { "nbgrader": { "grade": true, "grade_id": "f-D_KL", "locked": false, "points": 1, "schema_version": 3, "solution": true, "task": false } }, "source": [ "````{toggle}\n", "**Solution**\n", "\n", "See {cite}`ruderman2012tighter`.````" ] } ], "metadata": { "jupytext": { "text_representation": { "extension": ".md", "format_name": "myst", "format_version": 0.13, "jupytext_version": "1.10.3" } }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "source_map": [ 14, 18, 30, 34, 42, 48, 67, 71, 77, 82, 86, 90, 94, 119, 124, 140, 144, 159, 165, 184, 200, 212, 218, 227, 231, 258, 262, 283, 297, 301, 305, 309, 319, 329, 335, 344, 354, 364, 374, 384 ] }, "nbformat": 4, "nbformat_minor": 5 }