{ "cells": [ { "cell_type": "markdown", "id": "0a2a2a32", "metadata": { "slideshow": { "slide_type": "slide" } }, "source": [ "# Mutual Information Estimation" ] }, { "cell_type": "markdown", "id": "b5e19802", "metadata": {}, "source": [ "$\\def\\abs#1{\\left\\lvert #1 \\right\\rvert}\n", "\\def\\Set#1{\\left\\{ #1 \\right\\}}\n", "\\def\\mc#1{\\mathcal{#1}}\n", "\\def\\M#1{\\boldsymbol{#1}}\n", "\\def\\R#1{\\mathsf{#1}}\n", "\\def\\RM#1{\\boldsymbol{\\mathsf{#1}}}\n", "\\def\\op#1{\\operatorname{#1}}\n", "\\def\\E{\\op{E}}\n", "\\def\\d{\\mathrm{\\mathstrut d}}$" ] }, { "cell_type": "code", "execution_count": null, "id": "2a2fd0c6", "metadata": {}, "outputs": [], "source": [ "# init\n", "import ipywidgets as widgets\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import pandas as pd\n", "import seaborn as sns\n", "import tensorboard as tb\n", "import torch\n", "import torch.optim as optim\n", "from IPython import display\n", "from torch import Tensor, nn\n", "from torch.nn import functional as F\n", "from torch.utils.tensorboard import SummaryWriter\n", "\n", "%load_ext tensorboard\n", "%matplotlib inline\n", "\n", "\n", "def plot_samples_with_kde(df, **kwargs):\n", " p = sns.PairGrid(df, **kwargs)\n", " p.map_lower(sns.scatterplot, s=2) # scatter plot of samples\n", " p.map_upper(sns.kdeplot) # kernel density estimate for pXY\n", " p.map_diag(sns.kdeplot) # kde for pX and pY\n", " return p\n", "\n", "\n", "SEED = 0\n", "\n", "# set device\n", "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", "\n", "# create samples\n", "XY_rng = np.random.default_rng(SEED)\n", "rho = 1 - 0.19 * XY_rng.random()\n", "mean, cov, n = [0, 0], [[1, rho], [rho, 1]], 1000\n", "XY = XY_rng.multivariate_normal(mean, cov, n)\n", "\n", "XY_ref_rng = np.random.default_rng(SEED)\n", "cov_ref, n_ = [[1, 0], [0, 1]], n\n", "XY_ref = XY_ref_rng.multivariate_normal(mean, cov_ref, n_)\n", "\n", "Z = Tensor(XY).to(DEVICE)\n", "Z_ref = Tensor(XY_ref).to(DEVICE)\n", 
"\n", "ground_truth = -0.5 * np.log(1 - rho ** 2)" ] }, { "cell_type": "markdown", "id": "f1746d17", "metadata": {}, "source": [ "**How to estimate MI via KL divergence?**" ] }, { "cell_type": "markdown", "id": "c65b33f8", "metadata": {}, "source": [ "In this notebook, we will introduce a few methods of estimating the mutual information ({prf:ref}mi-estimation) via KL divergence." ] }, { "cell_type": "markdown", "id": "b26dc49b", "metadata": {}, "source": [ "We first introduce the Mutual Information Neural Estimation (MINE) method in {cite}belghazi2018mine." ] }, { "cell_type": "markdown", "id": "61a9fb99", "metadata": {}, "source": [ "## MINE" ] }, { "cell_type": "markdown", "id": "3ecd0c93", "metadata": {}, "source": [ "The idea is to obtain MI {eq}MI from KL divergence {eq}D as follows:" ] }, { "cell_type": "markdown", "id": "f1b1e406", "metadata": {}, "source": [ "\n", "\\begin{align*}\n", "I(\\R{X}\\wedge \\R{Y}) = D(\\underbrace{P_{\\R{X},\\R{Y}}}_{P_{\\R{Z}}}\\| \\underbrace{P_{\\R{X}}\\times P_{\\R{Y}}}_{P_{\\R{Z}'}}).\n", "\\end{align*}\n", "" ] }, { "cell_type": "markdown", "id": "09b5a1c6", "metadata": {}, "source": [ "and then apply the DV formula {eq}avg-DV to estimate the divergence:" ] }, { "cell_type": "markdown", "id": "c395ef67", "metadata": {}, "source": [ "{prf:definition} MINE \n", ":label: MINE\n", "\n", "The idea of mutual information neural estimation (MINE) is to estimate $I(\\R{X}\\wedge\\R{Y})$ by\n", "\n", "\n", "\\begin{align}\n", "\\sup_{t_{\\theta}: \\mc{Z} \\to \\mathbb{R}} \\overbrace{\\frac1n \\sum_{i\\in [n]} t_{\\theta}(\\R{X}_i,\\R{Y}_i) - \\frac1{n'}\\sum_{i\\in [n']} e^{t_{\\theta}(\\R{X}'_i,\\R{Y}'_i)}}^{\\R{I}_{\\text{MINE}}(\\theta):=}\n", "\\end{align}\n", " (MINE)\n", "\n", "where \n", "\n", "- the supremum is over $t_{\\theta}$ representable by a neural network with trainable/optimizable parameters $\\theta$,\n", "- $P_{\\R{X}',\\R{Y}'}:=P_{\\R{X}}\\times P_{\\R{Y}}$, and\n", "- $(\\R{X}'_i,\\R{Y}'_i\\mid i\\in [n'])$ is 
the sequence of i.i.d. samples of $P_{\\R{X}',\\R{Y}'}$.\n", "\n", "" ] }, { "cell_type": "markdown", "id": "058a79c7", "metadata": {}, "source": [ "The above is not a complete description of MINE because there are some implementation details yet to be filled in." ] }, { "cell_type": "markdown", "id": "32b6b1d5", "metadata": {}, "source": [ "### Resampling" ] }, { "cell_type": "markdown", "id": "d1b78795", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "source": [ "**How to obtain the reference samples ${\\R{Z}'}^{n'}$, i.e., ${\\R{X}'}^{n'}$ and ${\\R{Y}'}^{n'}$?**" ] }, { "cell_type": "markdown", "id": "3f51624e", "metadata": {}, "source": [ "We can approximate the i.i.d. sampling of $P_{\\R{X}}\\times P_{\\R{Y}}$ using samples from $P_{\\R{X},\\R{Y}}$ by a re-sampling trick:" ] }, { "cell_type": "markdown", "id": "8fe9f92b", "metadata": { "slideshow": { "slide_type": "fragment" }, "tags": [] }, "source": [ "\n", "\\begin{align}\n", "P_{\\R{Z}'^{n'}} &\\approx P_{((\\R{X}_{\\R{J}_i},\\R{Y}_{\\R{K}_i})\\mid i \\in [n'])}\n", "\\end{align}\n", " (resample)\n", "\n", "where $\\R{J}_i$ and $\\R{K}_i$ for $i\\in [n']$ are independent and uniformly random indices\n", "\n", "$$\n", "P_{\\R{J},\\R{K}} = \\op{Uniform}_{[n]\\times [n]}\n", "$$\n", "\n", "and $[n]:=\\Set{1,\\dots,n}$." ] }, { "cell_type": "markdown", "id": "29939c96", "metadata": { "slideshow": { "slide_type": "subslide" }, "tags": [] }, "source": [ "MINE {cite}belghazi2018mine uses the following implementation that samples $(\\R{J},\\R{K})$ but without replacement. You can change $n'$ using the slider for n_." 
] }, { "cell_type": "code", "execution_count": null, "id": "b7c857a2", "metadata": {}, "outputs": [], "source": [ "rng = np.random.default_rng(SEED)\n", "\n", "\n", "def resample(data, size, replace=False):\n", "    index = rng.choice(range(data.shape[0]), size=size, replace=replace)\n", "    return data[index]\n", "\n", "\n", "@widgets.interact(n_=widgets.IntSlider(n, 2, n, continuous_update=False))\n", "def plot_resampled_data_without_replacement(n_):\n", "    XY_ = np.block([resample(XY[:, [0]], n_), resample(XY[:, [1]], n_)])\n", "    resampled_data = pd.DataFrame(XY_, columns=[\"X'\", \"Y'\"])\n", "    plot_samples_with_kde(resampled_data)\n", "    plt.show()" ] }, { "cell_type": "markdown", "id": "6862a427", "metadata": {}, "source": [ "In the above, the function resample\n", "- uses choice to select indices uniformly at random\n", "- from the range of integers going from 0 (inclusive) to\n", "- the size of the first dimension of the data (exclusive)." ] }, { "cell_type": "markdown", "id": "85b80369", "metadata": {}, "source": [ "Note, however, that the sampling is without replacement." ] }, { "cell_type": "markdown", "id": "e6ffa90f", "metadata": {}, "source": [ "**Is it a good idea to sample without replacement?**" ] }, { "cell_type": "markdown", "id": "c1feeb4a", "metadata": {}, "source": [ "**Exercise**\n", "\n", "Sampling without replacement has an important restriction $n'\\leq n$. Why?" ] }, { "cell_type": "markdown", "id": "b5b9d252", "metadata": { "nbgrader": { "grade": true, "grade_id": "without-replacement", "locked": false, "points": 1, "schema_version": 3, "solution": true, "task": false } }, "source": [ "{toggle}\n", "**Solution**\n", "\n", "To allow $n'>n$, at least one sample $(\\R{X}_i,\\R{Y}_i)$ needs to be sampled more than once.\n", "\n", "" ] }, { "cell_type": "markdown", "id": "b998d28d", "metadata": {}, "source": [ "**Exercise** To allow $n'>n$, complete the following code to sample with replacement and observe what happens when $n' \\gg n$." 
] }, { "cell_type": "code", "execution_count": null, "id": "71dc175a", "metadata": { "nbgrader": { "grade": false, "grade_id": "resample", "locked": false, "schema_version": 3, "solution": true, "task": false }, "tags": [ "hide-cell" ] }, "outputs": [], "source": [ "@widgets.interact(n_=widgets.IntSlider(n, 2, 10 * n, continuous_update=False))\n", "def plot_resampled_data_with_replacement(n_):\n", " ### BEGIN SOLUTION\n", " XY_ = np.block(\n", " [resample(XY[:, [0]], n_, replace=True), resample(XY[:, [1]], n_, replace=True)]\n", " )\n", " ### END SOLUTION\n", " resampled_data = pd.DataFrame(XY_, columns=[\"X'\", \"Y'\"])\n", " plot_samples_with_kde(resampled_data)\n", " plt.show()" ] }, { "cell_type": "markdown", "id": "1934bc41", "metadata": { "slideshow": { "slide_type": "fragment" }, "tags": [] }, "source": [ "**Exercise** \n", "\n", "Explain whether the resampling trick gives i.i.d. samples $(\\R{X}_{\\R{J}_i},\\R{Y}_{\\R{K}_i})$ for the cases with replacement and without replacement respectively?" ] }, { "cell_type": "markdown", "id": "fcbce31a", "metadata": { "nbgrader": { "grade": true, "grade_id": "non-iid", "locked": false, "points": 1, "schema_version": 3, "solution": true, "task": false }, "slideshow": { "slide_type": "fragment" }, "tags": [ "hide-cell" ] }, "source": [ "{toggle}\n", "**Solution** \n", "\n", "The samples are identically distributed. However, they are not independent except in the trivial case $n=1$ or $n'=1$, regardless of whether the sample is with replacement or not. Consider $n=1$ and $n'=2$ as an example.\n", "\n", "" ] }, { "cell_type": "markdown", "id": "b584ca20", "metadata": {}, "source": [ "### Smoothing" ] }, { "cell_type": "markdown", "id": "5a3a562f", "metadata": {}, "source": [ "To improve the stability of the training, MINE applies additional smoothing to the gradient calculation." 
] }, { "cell_type": "markdown", "id": "43f66540", "metadata": {}, "source": [ "**How to train the neural network?**" ] }, { "cell_type": "markdown", "id": "1a7919c0", "metadata": {}, "source": [ "{eq}MINE can be solved iteratively by minibatch gradient descent using the loss function:\n", "\n", "\n", "\\begin{align}\n", "\\R{L}_{\\text{MINE}}(\\theta) &:= \\overbrace{- \\frac{1}{\\abs{\\R{B}}} \\sum_{i\\in \\R{B}} t_{\\theta} (\\R{X}_i, \\R{Y}_i) }^{\\R{L}(\\theta):=} + \\log \\overbrace{\\frac{1}{\\abs{\\R{B}'}} \\sum_{i\\in \\R{B}'} e^{t_{\\theta} (\\R{X}'_i, \\R{Y}'_i)}}^{\\R{L}'(\\theta):=}\n", "\\end{align}\n", "\n", "\n", "where $\\R{B}$ and $\\R{B}'$ are batches of uniformly randomly chosen indices from $[n]$ and $[n']$ respectively." ] }, { "cell_type": "markdown", "id": "6cdf8429", "metadata": {}, "source": [ "The gradient expressed in terms of $\\R{L}$ and $\\R{L}'$ is:\n", "\n", "\n", "\\begin{align}\n", "\\nabla \\R{L}_{\\text{MINE}}(\\theta) &= \\nabla \\R{L}(\\theta) + \\frac{\\nabla \\R{L}'(\\theta)}{\\R{L}'(\\theta)}.\n", "\\end{align}\n", "" ] }, { "cell_type": "markdown", "id": "ecf951b5", "metadata": {}, "source": [ "Variation in $\\nabla \\R{L}'(\\theta)$ leads to the variation of the overall gradient especially when $\\R{L}'(\\theta)$ is small. With minibatch gradient descent, the sample average is over a small batch and so the variance can be quite large." 
] }, { "cell_type": "markdown", "id": "968bf88c", "metadata": {}, "source": [ "**How to reduce the variation in the gradient approximation?**" ] }, { "cell_type": "markdown", "id": "8c008ab1", "metadata": {}, "source": [ "To alleviate the variation, MINE replaces the denominator $\\R{L}'(\\theta)$ by its moving average:" ] }, { "cell_type": "markdown", "id": "15b81f73", "metadata": {}, "source": [ "$$\n", "\\theta_{j+1} := \\theta_j - s_j \\underbrace{\\left(\\nabla \\R{L}_j (\\theta_j) + \\frac{\\nabla \\R{L}'_j(\\theta_j)}{\\overline{\\R{L}'}_j}\\right)}_{\\text{(*)}}\n", "$$ (MINE:update)\n", "\n", "where $\\R{L}_j$ and $\\R{L}'_j$ are the losses $\\R{L}$ and $\\R{L}'$ evaluated on the minibatches chosen at step $j$, and $\\overline{\\R{L}'}_j$ is the moving average\n", "\n", "$$\n", "\\overline{\\R{L}'}_j = \\beta \\overline{\\R{L}'}_{j-1} + (1-\\beta) \\R{L}'_j(\\theta_j)\n", "$$ (MINE:L2bar)\n", "\n", "for some smoothing factor $\\beta\\in [0,1]$." ] }, { "cell_type": "markdown", "id": "7f5de427", "metadata": {}, "source": [ "**How to implement the smoothing?**" ] }, { "cell_type": "markdown", "id": "f2f6b898", "metadata": {}, "source": [ "Instead of implementing a new optimizer, a simpler way is to redefine the loss at each step $j$ as follows {cite}belghazi2018mine:" ] }, { "cell_type": "markdown", "id": "9c3a7383", "metadata": {}, "source": [ "$$\n", "\\R{L}_{\\text{MINE},j}(\\theta) = \\R{L}_j(\\theta) + \\frac{\\R{L}'_j(\\theta)}{\\overline{\\R{L}'}_j}\n", "$$ (MINE:mv)\n", "\n", "where $\\overline{\\R{L}'}_j$ in {eq}MINE:L2bar is regarded as a constant in calculating the gradient. It is easy to verify that $\\nabla \\R{L}_{\\text{MINE},j}(\\theta_j)$ gives the desired gradient (*) in {eq}MINE:update." 
] }, { "cell_type": "markdown", "id": "04d70a42", "metadata": {}, "source": [ "In summary:" ] }, { "cell_type": "markdown", "id": "7da14f9a", "metadata": {}, "source": [ "{prf:definition} \n", "\n", "MINE estimates the mutual information $I(\\R{X}\\wedge\\R{Y})$ as $\\R{I}_{\\text{MINE}}(\\theta_j)$ {eq}MINE where $\\theta_j$ is the parameter obtained after $j$ steps of descending along the gradient of $\\R{L}_{\\text{MINE},j}$ {eq}MINE:mv.\n", "\n", "" ] }, { "cell_type": "markdown", "id": "8c95feba", "metadata": { "tags": [] }, "source": [ "### Implementation" ] }, { "cell_type": "markdown", "id": "28f5a212", "metadata": {}, "source": [ "Consider implementing MINE for the jointly Gaussian $\\R{X}$ and $\\R{Y}$:" ] }, { "cell_type": "code", "execution_count": null, "id": "6f93f85e", "metadata": {}, "outputs": [], "source": [ "SEED = 0\n", "\n", "# set device\n", "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", "\n", "# create samples\n", "XY_rng = np.random.default_rng(SEED)\n", "rho = 1 - 0.19 * XY_rng.random()\n", "mean, cov, n = [0, 0], [[1, rho], [rho, 1]], 1000\n", "XY = XY_rng.multivariate_normal(mean, cov, n)\n", "\n", "X = Tensor(XY[:, [0]]).to(DEVICE)\n", "Y = Tensor(XY[:, [1]]).to(DEVICE)\n", "ground_truth = -0.5 * np.log(1 - rho ** 2)\n", "X.shape, Y.shape, ground_truth" ] }, { "cell_type": "markdown", "id": "47da6f20", "metadata": {}, "source": [ "The tensors X and Y correspond to samples of $\\R{X}$ and $\\R{Y}$ respectively. The first dimension indexes the different samples. The ground_truth is the actual mutual information $I(\\R{X}\\wedge\\R{Y})$." 
] }, { "cell_type": "markdown", "id": "def547c0", "metadata": {}, "source": [ "**Exercise**\n", "\n", "Complete the definition of forward so that the neural network $t_{\\theta}$ is a vectorized function that takes samples x and y of $\\R{X}$ and $\\R{Y}$ as input and approximates the log density ratio $\\log \\frac{P_{\\R{X},\\R{Y}}}{P_{\\R{X}}\\times P_{\\R{Y}}}$ at (x, y), up to an additive constant:\n", "\n", "![Neural network](images/nn.dio.svg)\n", "\n", "- Use torch.cat to concatenate tensors x and y in the last dimension.\n", "- Use F.elu for the activation function $\\sigma$." ] }, { "cell_type": "code", "execution_count": null, "id": "d20c1722", "metadata": { "nbgrader": { "grade": false, "grade_id": "MINE-Net", "locked": false, "schema_version": 3, "solution": true, "task": false } }, "outputs": [], "source": [ "class Net(nn.Module):\n", "    def __init__(self, input_size=2, hidden_size=100, sigma=0.02):\n", "        super().__init__()\n", "        # fully-connected (fc) layers\n", "        self.fc1 = nn.Linear(input_size, hidden_size)  # layer 1\n", "        self.fc2 = nn.Linear(hidden_size, hidden_size)  # layer 2\n", "        self.fc3 = nn.Linear(hidden_size, 1)  # layer 3\n", "        # initialize weights as N(0, sigma^2) and biases as 0\n", "        nn.init.normal_(self.fc1.weight, std=sigma)\n", "        nn.init.constant_(self.fc1.bias, 0)\n", "        nn.init.normal_(self.fc2.weight, std=sigma)\n", "        nn.init.constant_(self.fc2.bias, 0)\n", "        nn.init.normal_(self.fc3.weight, std=sigma)\n", "        nn.init.constant_(self.fc3.bias, 0)\n", "\n", "    def forward(self, x, y):\n", "        \"\"\"\n", "        Vectorized function that implements the neural network t(x,y).\n", "\n", "        Parameters:\n", "        -----------\n", "        x, y: 2D Tensors where the first dimensions index different samples.\n", "        \"\"\"\n", "        ### BEGIN SOLUTION\n", "        a1 = F.elu(self.fc1(torch.cat([x, y], dim=-1)))\n", "        a2 = F.elu(self.fc2(a1))\n", "        t = self.fc3(a2)\n", "        ### END SOLUTION\n", "        return t\n", "\n", "    def plot(self, xmin=-5, xmax=5, ymin=-5, ymax=5, xgrids=50, ygrids=50, ax=None):\n", "        \"\"\"Plot a heat map of a neural network net. 
net can only have two inputs.\"\"\"\n", "        x, y = np.mgrid[xmin : xmax : xgrids * 1j, ymin : ymax : ygrids * 1j]\n", "        with torch.no_grad():\n", "            z = (\n", "                self(  # evaluate this network rather than the global variable net\n", "                    Tensor(x.reshape(-1, 1)).to(DEVICE),\n", "                    Tensor(y.reshape(-1, 1)).to(DEVICE),\n", "                )\n", "                .reshape(x.shape)\n", "                .cpu()\n", "            )\n", "        if ax is None:\n", "            ax = plt.gca()\n", "        im = ax.pcolormesh(x, y, z, cmap=\"RdBu_r\", shading=\"auto\")\n", "        ax.figure.colorbar(im)\n", "        ax.set(xlabel=r\"$x$\", ylabel=r\"$y$\", title=r\"Heatmap of $t(x,y)$\")" ] }, { "cell_type": "markdown", "id": "13d50d9d", "metadata": {}, "source": [ "To create and plot the neural network:" ] }, { "cell_type": "code", "execution_count": null, "id": "bc9d2162", "metadata": {}, "outputs": [], "source": [ "torch.manual_seed(SEED)\n", "net = Net().to(DEVICE)\n", "net.plot()" ] }, { "cell_type": "markdown", "id": "ce8253cf", "metadata": {}, "source": [ "**Exercise**\n", "\n", "To implement a neural network trainer for MINE similar to DVTrainer, complete the definition of loss according to {eq}MINE:mv:" ] }, { "cell_type": "code", "execution_count": null, "id": "9de4757f", "metadata": { "nbgrader": { "grade": false, "grade_id": "MINETrainer", "locked": false, "schema_version": 3, "solution": true, "task": false }, "tags": [] }, "outputs": [], "source": [ "class MINETrainer:\n", "    def __init__(\n", "        self, X, Y, net, n_iters_per_epoch, writer_params={}, m=1, beta=0.1, **kwargs\n", "    ):\n", "        \"\"\"\n", "        Neural estimator for Mutual Information based on MINE.\n", "\n", "        Estimate I(X;Y) using samples of X and Y by training a network t to maximize\n", "            avg(t(X,Y)) - avg(e^t(X',Y')) / m\n", "        where samples of (X',Y') approximate P_X * P_Y using the resampling trick, and\n", "        m is obtained by smoothing avg(e^t(X',Y')) with the factor beta.\n", "\n", "        Parameters:\n", "        -----------\n", "\n", "        X, Y : Tensors with first dimensions of the same size indexing the samples.\n", "        net : The neural network t that takes X and Y as input 
and outputs a real number for each sample.\n", "        n_iters_per_epoch : Number of iterations per epoch.\n", "        m : Initial value for the moving average.\n", "        beta : Smoothing/forgetting factor in [0,1].\n", "        writer_params : Parameters to be passed to SummaryWriter for logging.\n", "        \"\"\"\n", "        self.X = X\n", "        self.Y = Y\n", "        self.n = min(X.shape[0], Y.shape[0])\n", "        self.beta = beta\n", "        self.m = m\n", "        self.net = net\n", "        self.n_iters_per_epoch = n_iters_per_epoch\n", "\n", "        # set optimizer\n", "        self.optimizer = optim.Adam(self.net.parameters(), **kwargs)\n", "\n", "        # logging\n", "        self.writer = SummaryWriter(**writer_params)\n", "        self.n_iter = self.n_epoch = 0\n", "\n", "    def step(self, epochs=1):\n", "        \"\"\"\n", "        Carries out the gradient descent for a number of epochs and returns the MI estimate evaluated over the entire data.\n", "\n", "        Loss for each epoch is recorded into the log, but only one MI estimate is computed/logged using the entire dataset.\n", "        Rerun the method to continue to train the neural network and log the results.\n", "\n", "        Parameters:\n", "        -----------\n", "        epochs : number of epochs\n", "        \"\"\"\n", "        for i in range(epochs):\n", "            self.n_epoch += 1\n", "\n", "            # random indices for selecting samples for all batches in one epoch\n", "            idx = torch.randperm(self.n)\n", "\n", "            # resampling to approximate the sampling of (X',Y')\n", "            idx_X = torch.randperm(self.n)\n", "            idx_Y = torch.randperm(self.n)\n", "\n", "            for j in range(self.n_iters_per_epoch):  # loop through multiple batches\n", "                self.n_iter += 1\n", "                self.optimizer.zero_grad()\n", "\n", "                # obtain a random batch of samples\n", "                batch_X = self.X[idx[j : self.n : self.n_iters_per_epoch]]\n", "                batch_Y = self.Y[idx[j : self.n : self.n_iters_per_epoch]]\n", "                batch_X_ref = self.X[idx_X[j : self.n : self.n_iters_per_epoch]]\n", "                batch_Y_ref = self.Y[idx_Y[j : self.n : self.n_iters_per_epoch]]\n", "\n", "                # define the loss\n", "                # BEGIN 
SOLUTION\n", "                L = -self.net(batch_X, batch_Y).mean()\n", "                L_ = self.net(batch_X_ref, batch_Y_ref).exp().mean()\n", "                self.m = (1 - self.beta) * L_.detach() + self.beta * self.m\n", "                loss = L + L_ / self.m\n", "                # END SOLUTION\n", "                loss.backward()\n", "                self.optimizer.step()\n", "\n", "                self.writer.add_scalar(\"Loss/train\", loss.item(), global_step=self.n_iter)\n", "\n", "        with torch.no_grad():\n", "            idx_X = torch.randperm(self.n)\n", "            idx_Y = torch.randperm(self.n)\n", "            X_ref = self.X[idx_X]\n", "            Y_ref = self.Y[idx_Y]\n", "            estimate = (\n", "                self.net(self.X, self.Y).mean()\n", "                - self.net(X_ref, Y_ref).logsumexp(0)\n", "                + np.log(self.n)\n", "            ).item()\n", "            self.writer.add_scalar(\"Estimate\", estimate, global_step=self.n_epoch)\n", "        return estimate" ] }, { "cell_type": "markdown", "id": "fe031909", "metadata": {}, "source": [ "To create the trainer for MINE:" ] }, { "cell_type": "code", "execution_count": null, "id": "af16d6b6", "metadata": {}, "outputs": [], "source": [ "trainer = MINETrainer(X, Y, net, n_iters_per_epoch=10)" ] }, { "cell_type": "markdown", "id": "f58a983a", "metadata": {}, "source": [ "To train the neural network:" ] }, { "cell_type": "code", "execution_count": null, "id": "7b12ce11", "metadata": {}, "outputs": [], "source": [ "if input(\"Train? 
[Y/n]\").lower() != \"n\":\n", "    for i in range(10):\n", "        print('MI estimate:', trainer.step(10))\n", "    net.plot()" ] }, { "cell_type": "code", "execution_count": null, "id": "1cdd8395", "metadata": {}, "outputs": [], "source": [ "%tensorboard --logdir=runs" ] }, { "cell_type": "markdown", "id": "194571ac", "metadata": {}, "source": [ "**Exercise** \n", "\n", "See if you can get an estimate close to the ground truth by training the neural network repeatedly, as shown below.\n", "\n", "![MI estimate](images/MI_est.dio.svg)" ] }, { "cell_type": "markdown", "id": "abb068f0", "metadata": {}, "source": [ "To clean up:" ] }, { "cell_type": "code", "execution_count": null, "id": "5f2bb527", "metadata": {}, "outputs": [], "source": [ "if input('Delete logs? [y/N]').lower() == 'y':\n", "    !rm -rf ./runs" ] }, { "cell_type": "code", "execution_count": null, "id": "4da4a8d7", "metadata": {}, "outputs": [], "source": [ "tb.notebook.list() # list all the running TensorBoard notebooks.\n", "while (pid := input('pid to kill? (press enter to exit)')):\n", "    !kill {pid}" ] }, { "cell_type": "markdown", "id": "b6630b16", "metadata": { "tags": [] }, "source": [ "## MI-NEE" ] }, { "cell_type": "markdown", "id": "e2a542a0", "metadata": {}, "source": [ "**Is it possible to generate i.i.d. 
samples for ${\\R{Z}'}^{n'}$?**" ] }, { "cell_type": "markdown", "id": "3f0865c4", "metadata": {}, "source": [ "Consider another formula for mutual information:" ] }, { "cell_type": "markdown", "id": "ea514f8c", "metadata": {}, "source": [ "{prf:proposition} \n", ":label: MI-3D\n", "\n", "\n", "\\begin{align}\n", "I(\\R{X}\\wedge \\R{Y}) &= D(P_{\\R{X},\\R{Y}}\\|P_{\\R{X}'}\\times P_{\\R{Y}'}) - D(P_{\\R{X}}\\|P_{\\R{X}'}) - D(P_{\\R{Y}}\\|P_{\\R{Y}'})\n", "\\end{align}\n", " (MI-3D)\n", "\n", "for any product reference distribution $P_{\\R{X}'}\\times P_{\\R{Y}'}$ for which the divergences are finite.\n", "\n", "" ] }, { "cell_type": "markdown", "id": "734753f8", "metadata": {}, "source": [ "{prf:corollary} \n", ":label: MI-ub\n", "\n", "\n", "\n", "\\begin{align}\n", "I(\\R{X}\\wedge \\R{Y}) &= \\inf_{\\substack{P_{\\R{X}'}\\in \\mc{P}(\\mc{X})\\\\ P_{\\R{Y}'}\\in \\mc{P}(\\mc{Y})}} D(P_{\\R{X},\\R{Y}}\\|P_{\\R{X}'}\\times P_{\\R{Y}'}).\n", "\\end{align}\n", " (MI-ub)\n", "\n", "where the optimal solution is $P_{\\R{X}'}\\times P_{\\R{Y}'}=P_{\\R{X}}\\times P_{\\R{Y}}$, the product of marginal distributions of $\\R{X}$ and $\\R{Y}$. \n", "\n", "" ] }, { "cell_type": "markdown", "id": "2dff6de1", "metadata": {}, "source": [ "{prf:proof} \n", "\n", "{eq}MI-ub follows from {eq}MI-3D directly since the divergences are non-negative. 
To prove the proposition:\n", "\n", "\n", "\\begin{align}\n", "I(\\R{X}\\wedge \\R{Y}) &= H(\\R{X}) + H(\\R{Y}) - H(\\R{X},\\R{Y})\\\\\n", "&= E\\left[-\\log dP_{\\R{X}'}(\\R{X})\\right] - D(P_{\\R{X}}\\|P_{\\R{X}'})\\\\\n", "&\\quad+E\\left[-\\log dP_{\\R{Y}'}(\\R{Y})\\right] - D(P_{\\R{Y}}\\|P_{\\R{Y}'})\\\\\n", "&\\quad-E\\left[-\\log d(P_{\\R{X}'}\\times P_{\\R{Y}'})(\\R{X},\\R{Y})\\right] + D(P_{\\R{X},\\R{Y}}\\|P_{\\R{X}'}\\times P_{\\R{Y}'})\\\\\n", "&= D(P_{\\R{X},\\R{Y}}\\|P_{\\R{X}'}\\times P_{\\R{Y}'}) - D(P_{\\R{X}}\\|P_{\\R{X}'}) - D(P_{\\R{Y}}\\|P_{\\R{Y}'})\n", "\\end{align}\n", "\n", "\n", "" ] }, { "cell_type": "markdown", "id": "ed179778", "metadata": {}, "source": [ "*Mutual Information Neural Entropic Estimation (MI-NEE)* {cite}chan2019neural uses {eq}MI-3D to estimate MI by estimating the three divergences." ] }, { "cell_type": "markdown", "id": "718c5bc4", "metadata": {}, "source": [ "Applying {eq}avg-DV to each divergence in {eq}MI-3D:" ] }, { "cell_type": "markdown", "id": "f54a6097", "metadata": {}, "source": [ "\n", "\\begin{align}\n", "I(\\R{X}\\wedge \\R{Y}) &\\approx \\sup_{t_{\\R{X},\\R{Y}}: \\mc{Z} \\to \\mathbb{R}} \\left[\\frac1n \\sum_{i\\in [n]} t_{\\R{X},\\R{Y}}(\\R{X}_i,\\R{Y}_i) - \\log \\frac1{n'}\\sum_{i\\in [n']} e^{t_{\\R{X},\\R{Y}}(\\R{X}'_i,\\R{Y}'_i)}\\right]\\\\\n", "&\\quad - \\sup_{t_{\\R{X}}: \\mc{X} \\to \\mathbb{R}} \\left[\\frac1n \\sum_{i\\in [n]} t_{\\R{X}}(\\R{X}_i) - \\log \\frac1{n'}\\sum_{i\\in [n']} e^{t_{\\R{X}}(\\R{X}'_i)}\\right] \\\\\n", "&\\quad - \\sup_{t_{\\R{Y}}: \\mc{Y} \\to \\mathbb{R}} \\left[\\frac1n \\sum_{i\\in [n]} t_{\\R{Y}}(\\R{Y}_i) - \\log \\frac1{n'}\\sum_{i\\in [n']} e^{t_{\\R{Y}}(\\R{Y}'_i)}\\right]\n", "\\end{align}\n", " (MI-NEE)" ] }, { "cell_type": "markdown", "id": "29c4a582", "metadata": {}, "source": [ "$P_{\\R{X}'}$ and $P_{\\R{Y}'}$ are known distributions and so arbitrarily many i.i.d. samples can be drawn from them directly without using the resampling trick {eq}resample." 
] }, { "cell_type": "markdown", "id": "2277dc05", "metadata": {}, "source": [ "Indeed, since the choice of $P_{\\R{X}'}$ and $P_{\\R{Y}'}$ is arbitrary, we can also train neural networks to optimize them. In particular, {eq}MI-ub is a special case where we can train neural networks to approximate $P_{\\R{X}}$ and $P_{\\R{Y}}$." ] } ], "metadata": { "jupytext": { "text_representation": { "extension": ".md", "format_name": "myst", "format_version": 0.13, "jupytext_version": "1.10.3" } }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "source_map": [ 14, 18, 30, 76, 80, 84, 88, 92, 96, 104, 108, 129, 133, 137, 141, 145, 161, 165, 180, 187, 191, 195, 201, 210, 214, 237, 243, 252, 256, 260, 264, 276, 286, 290, 294, 298, 312, 316, 320, 328, 332, 340, 344, 348, 364, 368, 379, 435, 439, 443, 449, 552, 556, 558, 562, 569, 571, 579, 583, 588, 594, 598, 602, 606, 621, 637, 655, 659, 663, 673, 677 ] }, "nbformat": 4, "nbformat_minor": 5 }