{ "cells": [ { "cell_type": "markdown", "id": "e2c2bea1", "metadata": { "slideshow": { "slide_type": "slide" }, "tags": [] }, "source": [ "# Neural Estimation via DV bound" ] }, { "cell_type": "markdown", "id": "6019983c", "metadata": {}, "source": [ "$\\def\\abs#1{\\left\\lvert #1 \\right\\rvert}\n", "\\def\\Set#1{\\left\\{ #1 \\right\\}}\n", "\\def\\mc#1{\\mathcal{#1}}\n", "\\def\\M#1{\\boldsymbol{#1}}\n", "\\def\\R#1{\\mathsf{#1}}\n", "\\def\\RM#1{\\boldsymbol{\\mathsf{#1}}}\n", "\\def\\op#1{\\operatorname{#1}}\n", "\\def\\E{\\op{E}}\n", "\\def\\d{\\mathrm{\\mathstrut d}}$" ] }, { "cell_type": "markdown", "id": "8799d4f7", "metadata": {}, "source": [ "Estimating MI well neither requires nor implies the divergence/density to be estimated well. However, \n", "- MI estimation is often not the end goal, but an objective to train a neural network to return the divergence/density. \n", "- The features/representations learned by the neural network may be applicable to different downstream inference tasks." ] }, { "cell_type": "markdown", "id": "db21d4f5", "metadata": { "tags": [] }, "source": [ "## Neural estimation of KL divergence" ] }, { "cell_type": "markdown", "id": "03c73a8d", "metadata": {}, "source": [ "To explain the idea of neural estimation, consider the following characterization of divergence:" ] }, { "cell_type": "markdown", "id": "005bbbf6", "metadata": {}, "source": [ "````{prf:proposition} \n", ":label: DV1\n", "\n", "$$\n", "\\begin{align}\n", "D(P_{\\R{Z}}\\|P_{\\R{Z}'}) & = \\sup_{Q\\in \\mc{P}(\\mc{Z})} E \\left[ \\log \\frac{dQ(\\R{Z})}{dP_{\\R{Z}'}(\\R{Z})} \\right] \n", "\\end{align}\n", "$$ (D1)\n", "\n", "where the unique optimal solution is $Q=P_{\\R{Z}}$.\n", "\n", "````" ] }, { "cell_type": "markdown", "id": "6874cef8", "metadata": {}, "source": [ "{eq}`D1` is {eq}`D` but with $P_{\\R{Z}}$ replaced by a parameter $Q$.\n", "\n", "- The proposition essentially gives a tight lower bound on KL divergence.\n", "- The unknown distribution is recovered as the optimal solution." ] }, { "cell_type": "markdown", "id": "c9813f88", "metadata": {}, "source": [ "````{prf:proof} \n", "\n", "To prove {eq}`D1`,\n", "\n", "$$\n", "\\begin{align*}\n", "D(P_{\\R{Z}}\\|P_{\\R{Z}'}) &= D(P_{\\R{Z}}\\|P_{\\R{Z}'}) - \\inf_{Q\\in \\mc{P}(\\mc{Z})} \\underbrace{D(P_{\\R{Z}}\\|Q)}_{\\geq 0 \\text{ with equality iff } Q=P_{\\R{Z}}\\kern-3em} \\\\\n", "&= \\sup_{Q\\in \\mc{P}(\\mc{Z})} \\underbrace{D(P_{\\R{Z}}\\|P_{\\R{Z}'})}_{=E \\left[\\log \\frac{dP_{\\R{Z}}(\\R{Z})}{dP_{\\R{Z}'}(\\R{Z})}\\right]} - \\underbrace{D(P_{\\R{Z}}\\|Q)}_{=E \\left[\\log \\frac{dP_{\\R{Z}}(\\R{Z})}{dQ(\\R{Z})}\\right]}\\\\\n", "&= \\sup_{Q\\in \\mc{P}(\\mc{Z})} E \\left[\\log \\frac{dQ(\\R{Z})}{dP_{\\R{Z}'}(\\R{Z})}\\right]\n", "\\end{align*}\n", "$$\n", "\n", "````" ] }, { "cell_type": "markdown", "id": "1500bb85", "metadata": {}, "source": [ "The idea of neural estimation is to \n", "\n", "- estimate the expectation in {eq}`D1` by the sample average \n", "\n", "$$\n", "\\frac1n \\sum_{i\\in [n]} \\log \\underbrace{\\frac{dQ(\\R{Z}_i)}{dP_{\\R{Z}'}(\\R{Z}_i)}}_{\\text{(*)}},\n", "$$" ] }, { "cell_type": "markdown", "id": "4d1843db", "metadata": {}, "source": [ "- use a neural network to compute the density ratio (*), and train the network to maximizes the expectation, e.g., by gradient ascent on the above sample average." ] }, { "cell_type": "markdown", "id": "c4c39d5b", "metadata": {}, "source": [ "Since $Q$ is arbitrary, the sample average above is a valid estimate." 
] }, { "cell_type": "markdown", "id": "042a97d0", "metadata": {}, "source": [ "**But how to compute the density ratio?**" ] }, { "cell_type": "markdown", "id": "dbfac31b", "metadata": {}, "source": [ "We will first consider estimating the KL divergence $D(P_{\\R{Z}}\\|P_{\\R{Z}'})$ when both $P_{\\R{Z}}$ and $P_{\\R{Z}'}$ are unknown." ] }, { "cell_type": "markdown", "id": "142b04d6", "metadata": {}, "source": [ "## Donsker-Varadhan formula" ] }, { "cell_type": "markdown", "id": "5ba9d24c", "metadata": {}, "source": [ "If $P_{\\R{Z}'}$ is unknown, we can apply a change of variable" ] }, { "cell_type": "markdown", "id": "466a9b1c", "metadata": {}, "source": [ "$$\n", "r(z) = \\frac{dQ(z)}{dP_{\\R{Z}'}(z)},\n", "$$ (Q->r)" ] }, { "cell_type": "markdown", "id": "b95c7074", "metadata": {}, "source": [ "which absorbs the unknown reference into the parameter." ] }, { "cell_type": "markdown", "id": "99a1118d", "metadata": {}, "source": [ "````{prf:proposition} \n", ":label: DV2\n", "\n", "$$\n", "\\begin{align}\n", "D(P_{\\R{Z}}\\|P_{\\R{Z}'}) & = \\sup_{\\substack{r:\\mc{Z}\\to \\mathbb{R}_+\\\\ E[r(\\R{Z}')]=1}} E \\left[ \\log r(\\R{Z}) \\right] \n", "\\end{align}\n", "$$ (D2)\n", "\n", "where the optimal $r$ satisfies \n", "$\n", "r(\\R{Z}) = \\frac{dP_{\\R{Z}}(\\R{Z})}{dP_{\\R{Z}'}(\\R{Z})}.\n", "$ \n", "\n", "````" ] }, { "cell_type": "markdown", "id": "b67fe3d3", "metadata": {}, "source": [ "**Exercise** \n", "\n", "Show using {eq}`Q->r` that the optimal solution satisfies the constraint stated in the supremum {eq}`D2`." ] }, { "cell_type": "markdown", "id": "3595ebeb", "metadata": { "nbgrader": { "grade": true, "grade_id": "optional-r", "locked": false, "points": 1, "schema_version": 3, "solution": true, "task": false } }, "source": [ "````{toggle}\n", "**Solution**\n", "\n", "The constraint on $r$ is obtained from the constraint on $Q\\in \\mc{P}(\\mc{Z})$, i.e., with $dQ(z)=r(z)dP_{\\R{Z}'}(z)$, \n", "\n", "$$\n", "\\begin{align*}\n", "dQ(z) \\geq 0 &\\iff r(z)\\geq 0\\\\\n", "\\int_{\\mc{Z}}dQ(z)=1 &\\iff E[r(\\R{Z}')]=1.\n", "\\end{align*}\n", "$$\n", "\n", "````" ] }, { "cell_type": "markdown", "id": "57aeb941", "metadata": {}, "source": [ "The next step is to train a neural network that computes $r$. What about?" ] }, { "cell_type": "markdown", "id": "35b8e82f", "metadata": {}, "source": [ "$$\n", "\\begin{align}\n", "D(P_{\\R{Z}}\\|P_{\\R{Z}'}) \\approx \\sup_{\\substack{r:\\mc{Z}\\to \\mathbb{R}_+\\\\ \\frac1{n'}\\sum_{i\\in [n']} r(\\R{Z}'_i)=1}} \\frac1n \\sum_{i\\in [n]} \\log r(\\R{Z}_i)\n", "\\end{align}\n", "$$ (avg-D1)" ] }, { "cell_type": "markdown", "id": "9c8e2db9", "metadata": {}, "source": [ "**How to impose the constraint on $r$ when training a neural network?**" ] }, { "cell_type": "markdown", "id": "ab30cf3e", "metadata": {}, "source": [ "We can apply a change of variable:\n", "\n", "$$\n", "\\begin{align}\n", "r(z)&=\\frac{e^{t(z)}}{E[e^{t(\\R{Z}')}]}.\n", "\\end{align}\n", "$$ (r->t)" ] }, { "cell_type": "markdown", "id": "ee8593b0", "metadata": {}, "source": [ "**Exercise** \n", "\n", "Show that $r$ defined in {eq}`r->t` satisfies the constraint in {eq}`D1` for all real-valued function $t:\\mc{Z}\\to \\mathbb{R}$." 
] }, { "cell_type": "markdown", "id": "bf5ed652", "metadata": { "nbgrader": { "grade": true, "grade_id": "r-t", "locked": false, "points": 1, "schema_version": 3, "solution": true, "task": false } }, "source": [ "````{toggle}\n", "**Solution** \n", "\n", "$$\n", "\\begin{align}\n", "E\\left[ \\frac{e^{t(\\R{Z}')}}{E[e^{t(\\R{Z}')}]} \\right] = \\frac{E\\left[ e^{t(\\R{Z}')} \\right]}{E[e^{t(\\R{Z}')}]} = 1.\n", "\\end{align}\n", "$$\n", "\n", "````" ] }, { "cell_type": "markdown", "id": "dd20c32a", "metadata": {}, "source": [ "Substituting {eq}`r->t` into {eq}`D1` gives the well-known *Donsker-Varadhan (DV)* formula {cite}`donsker1983asymptotic`:" ] }, { "cell_type": "markdown", "id": "b3d825c2", "metadata": {}, "source": [ "````{prf:corollary} Donsker-Varadhan \n", ":label: DV3\n", "\n", "$$\n", "\\begin{align}\n", "D(P_{\\R{Z}}\\|P_{\\R{Z}'}) = \\sup_{t: \\mc{Z} \\to \\mathbb{R}} E[t(\\R{Z})] - \\log E[e^{t(\\R{Z}')}]\n", "\\end{align}\n", "$$ (DV)\n", "\n", "where the optimal $t$ satisfies\n", "\n", "$$\n", "\\begin{align}\n", "t(\\R{Z}) = \\log \\frac{dP_{\\R{Z}}(\\R{Z})}{dP_{\\R{Z}'}(\\R{Z})} + c\n", "\\end{align}\n", "$$ (DV:sol)\n", "\n", "almost surely for some constant $c$.\n", "\n", "````" ] }, { "cell_type": "markdown", "id": "cdc65bde", "metadata": {}, "source": [ "The divergence can be estimated as follows instead of {eq}`avg-D1`:" ] }, { "cell_type": "markdown", "id": "6c81b5bd", "metadata": {}, "source": [ "$$\n", "\\begin{align}\n", "D(P_{\\R{Z}}\\|P_{\\R{Z}'}) \\approx \\sup_{t: \\mc{Z} \\to \\mathbb{R}} \\frac1n \\sum_{i\\in [n]} t(\\R{Z}_i) - \\frac1{n'}\\sum_{i\\in [n']} e^{t(\\R{Z}'_i)}\n", "\\end{align}\n", "$$ (avg-DV)" ] }, { "cell_type": "markdown", "id": "cdaa6ef8", "metadata": {}, "source": [ "In summary, the neural estimation of KL divergence is a sample average of {eq}`D` but \n", "\n", "$$\n", "D(P_{\\R{Z}}\\| P_{\\R{Z}'}) = \\underset{\\stackrel{\\uparrow}\\sup_t}{} \\overbrace{E}^{\\op{avg}} \\bigg[ \\log \\underbrace{\\frac{P_{\\R{Z}}(\\R{Z})}{P_{\\R{Z}'}(\\R{Z})}}_{\\frac{e^{t(\\R{Z})}}{\\underbrace{E}_{\\op{avg}}[e^{t(\\R{Z}')}]}} \\bigg].\n", "$$\n", "\n", "but with the unknown density ratio replaced by {eq}`r->t` trained as a neural network." ] } ], "metadata": { "jupytext": { "text_representation": { "extension": ".md", "format_name": "myst", "format_version": 0.13, "jupytext_version": "1.10.3" } }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "source_map": [ 14, 18, 30, 36, 40, 44, 59, 66, 82, 92, 96, 100, 104, 108, 112, 116, 122, 126, 144, 150, 166, 170, 178, 182, 192, 198, 211, 215, 238, 242, 250 ] }, "nbformat": 4, "nbformat_minor": 5 }