{ "cells": [ { "cell_type": "markdown", "metadata": { "colab_type": "text", "execution": {}, "id": "view-in-github" }, "source": [ "  " ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "# Bonus Tutorial 5: Expectation Maximization for spiking neurons\n", "\n", "**Week 3, Day 2: Hidden Dynamics**\n", "\n", "**By Neuromatch Academy**\n", "\n", "**Content creators:** Yicheng Fei with help from Jesse Livezey\n", "\n", "**Content reviewers:** John Butler, Matt Krause, Meenakshi Khosla, Spiros Chavlis, Michael Waskom\n", "\n", "**Production editors:** Gagana B, Spiros Chavlis\n", "\n", "
\n", "\n", "**Important Note:** this material was developed in NMA 2020 and has not been revised according to the standards of the Hidden Dynamics material.\n", "\n", "
\n", "\n", "**Acknowledgements:** This tutorial is based on code originally created by Sean Escola." ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "---\n", "# Tutorial objectives\n", "\n", "The Expectation-Maximization (EM) algorithm is a powerful and widely used optimization tool that is much more general than HMMs. Since it is typically taught in the context of Hidden Markov Models, we include it here.\n", "\n", "You will implement an HMM of a network of Poisson spiking neurons mentioned in today's intro and:\n", "\n", "* Implement the forward-backward algorithm\n", "* Complete the E-step and M-step\n", "* Learn parameters for the example problem using the EM algorithm\n", "* Get an intuition of how the EM algorithm monotonically increases data likelihood" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "---\n", "# Setup" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Install and import feedback gadget\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Install and import feedback gadget\n", "\n", "!pip3 install vibecheck datatops --quiet\n", "\n", "from vibecheck import DatatopsContentReviewContainer\n", "def content_review(notebook_section: str):\n", " return DatatopsContentReviewContainer(\n", " \"\", # No text prompt\n", " notebook_section,\n", " {\n", " \"url\": \"https://pmyvdlilci.execute-api.us-east-1.amazonaws.com/klab\",\n", " \"name\": \"neuromatch_cn\",\n", " \"user_key\": \"y1x3mpx5\",\n", " },\n", " ).render()\n", "\n", "\n", "feedback_prefix = \"W3D2_T5_Bonus\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "both", "execution": {} }, "outputs": [], "source": [ "import numpy as np\n", "from scipy import stats\n", "from scipy.optimize import linear_sum_assignment\n", "from collections import namedtuple\n", "\n", "import 
matplotlib.pyplot as plt\n", "from matplotlib import patches\n", "\n", "GaussianHMM1D = namedtuple('GaussianHMM1D', ['startprob', 'transmat','means','vars','n_components'])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Figure Settings\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Figure Settings\n", "import logging\n", "logging.getLogger('matplotlib.font_manager').disabled = True\n", "\n", "from ipywidgets import widgets, interactive, interact, HBox, Layout,VBox\n", "from IPython.display import HTML\n", "%config InlineBackend.figure_format = 'retina'\n", "plt.style.use(\"https://raw.githubusercontent.com/NeuromatchAcademy/course-content/NMA2020/nma.mplstyle\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Plotting functions\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Plotting functions\n", "def plot_spike_train(X, Y, dt):\n", " \"\"\"Plots the spike train for cells across trials and overlay the state.\n", "\n", " Args:\n", " X: (2d numpy array of binary values): The state sequence in a one-hot\n", " representation. 
(T, K)\n", " Y: (3d numpy array of floats): The spike sequence.\n", " (trials, T, C)\n", " dt (float): Interval for a bin.\n", " \"\"\"\n", " n_trials, T, C = Y.shape\n", " trial_T = T * dt\n", " fig = plt.figure(figsize=(.7 * (12.8 + 6.4), .7 * 9.6))\n", "\n", " # plot state sequence: find the bins where the state changes\n", " starts = [0] + list(np.diff(X.nonzero()[1]).nonzero()[0])\n", " stops = list(np.diff(X.nonzero()[1]).nonzero()[0]) + [T]\n", " states = [X[i + 1].nonzero()[0][0] for i in starts]\n", " for a, b, i in zip(starts, stops, states):\n", " rect = patches.Rectangle((a * dt, 0), (b - a) * dt, n_trials * C,\n", " facecolor=plt.get_cmap('tab10').colors[i],\n", " alpha=0.15)\n", " plt.gca().add_patch(rect)\n", "\n", " # plot rasters\n", " for c in range(C):\n", " if c > 0:\n", " plt.plot([0, trial_T], [c * n_trials, c * n_trials],\n", " color='gray') # separator between cells\n", " for r in range(n_trials):\n", " tmp = Y[r, :, c].nonzero()[0]\n", " if len(tmp) > 0:\n", " plt.plot(np.stack((tmp, tmp)) * dt, (c * n_trials + r + 0.1, c * n_trials + r + .9), color='k')\n", "\n", " ax = plt.gca()\n", " plt.yticks(np.arange(0, n_trials * C, n_trials),\n", " labels=np.arange(C, dtype=int))\n", " plt.xlabel('time (s)', fontsize=16)\n", " plt.ylabel('Cell number', fontsize=16)\n", " plt.show(fig)\n", "\n", "\n", "def plot_lls(lls):\n", " \"\"\"Plots log likelihoods at each epoch.\n", " Args:\n", " lls (list of floats): Log likelihoods at each epoch.\n", " \"\"\"\n", " epochs = len(lls)\n", " fig, ax = plt.subplots()\n", " ax.plot(range(epochs) , lls, linewidth=3)\n", " span = max(lls) - min(lls)\n", " ax.set_ylim(min(lls) - span * 0.05, max(lls) + span * 0.05)\n", " plt.xlabel('iteration')\n", " plt.ylabel('log likelihood')\n", " plt.show(fig)\n", "\n", "\n", "def plot_lls_eclls(plot_epochs, save_vals):\n", " \"\"\"Plots the LL and the (shifted) ECLL around each saved M-step.\n", " Args:\n", " plot_epochs (list of ints): Which epochs were saved to plot.\n", " save_vals (lists of floats): Different likelihoods from EM for plotting.\n", "
\"\"\"\n", " rows = int(np.ceil(min(len(plot_epochs), len(save_vals)) / 3))\n", " fig, axes = plt.subplots(rows, 3, figsize=(.7 * 6.4 * 3, .7 * 4.8 * rows))\n", " axes = axes.flatten()\n", "\n", " minll, maxll = np.inf, -np.inf\n", " for i, (ax, (bs, lls_for_plot, eclls_for_plot)) in enumerate(zip(axes, save_vals)):\n", " ax.set_xlim([-1.15, 2.15])\n", " min_val = np.stack((lls_for_plot, eclls_for_plot)).min()\n", " max_val = np.stack((lls_for_plot, eclls_for_plot)).max()\n", "\n", " ax.plot([0, 0], [min_val, lls_for_plot[bs == 0].item()], '--b')\n", " ax.plot([1, 1], [min_val, lls_for_plot[bs == 1].item()], '--b')\n", " ax.set_xticks([0, 1])\n", " ax.set_xticklabels([f'$\\\\theta^{plot_epochs[i]}$',\n", " f'$\\\\theta^{plot_epochs[i] + 1}$'])\n", " ax.tick_params(axis='y')\n", " ax.tick_params(axis='x')\n", "\n", " ax.plot(bs, lls_for_plot)\n", " ax.plot(bs, eclls_for_plot)\n", "\n", " if min_val < minll: minll = min_val\n", " if max_val > maxll: maxll = max_val\n", "\n", " if i % 3 == 0: ax.set_ylabel('log likelihood')\n", " if i == 4:\n", " l = ax.legend(ax.lines[-2:], ['LL', 'ECLL'], framealpha=1)\n", " plt.show(fig)\n", "\n", "\n", "def plot_learnt_vs_true(L_true, L, A_true, A, dt):\n", " \"\"\"Plot and compare the true and learnt parameters.\n", "\n", " Args:\n", " L_true (numpy array): True L.\n", " L (numpy array): Estimated L.\n", " A_true (numpy array): True A.\n", " A (numpy array): Estimated A.\n", " dt (float): Bin length.\n", " \"\"\"\n", " C, K = L.shape\n", " fig = plt.figure(figsize=(8, 4))\n", " plt.subplot(121)\n", " plt.plot([0, L_true.max() * 1.05], [0, L_true.max() * 1.05], '--b')\n", " for i in range(K):\n", " for c in range(C):\n", " plt.plot(L_true[c, i], L[c, i], color='C{}'.format(c),\n", " marker=['o', '*', 'd'][i]) # this line will fail for K > 3\n", " ax = plt.gca()\n", " ax.axis('equal')\n", " plt.xlabel('True firing rate (Hz)')\n", " plt.ylabel('Inferred firing rate (Hz)')\n", " xlim, ylim = ax.get_xlim(), ax.get_ylim()\n", " for c 
in range(C):\n", " plt.plot([-1e6], [-1e6], 'o', color='C{}'.format(c)) # off-screen point, used only for the legend\n", " for i in range(K):\n", " plt.plot([-1e6], [-1e6], marker=['o', '*', 'd'][i], c=\"black\")\n", " l = plt.legend(ax.lines[-C - K:], [f'cell {c + 1}' for c in range(C)] + [f'state {i + 1}' for i in range(K)])\n", " ax.set_xlim(xlim), ax.set_ylim(ylim)\n", "\n", " plt.subplot(122)\n", " ymax = np.max(A_true - np.diag(np.diag(A_true))) / dt * 1.05\n", " plt.plot([0, ymax], [0, ymax], '--b')\n", " for i in range(K):\n", " for j in range(K):\n", " if i == j:\n", " continue\n", " plt.plot(A_true[i, j] / dt, A[i, j] / dt, 'o') # A[i, j]: from state i to state j\n", " ax = plt.gca()\n", " ax.axis('equal')\n", " plt.xlabel('True transition rate (Hz)')\n", " plt.ylabel('Inferred transition rate (Hz)')\n", " l = plt.legend(ax.lines[1:], ['state 1 -> 2',\n", " 'state 1 -> 3',\n", " 'state 2 -> 1',\n", " 'state 2 -> 3',\n", " 'state 3 -> 1',\n", " 'state 3 -> 2'\n", " ])\n", " plt.show(fig)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Helper functions\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Helper functions\n", "def run_em(epochs, Y, psi, A, L, dt):\n", " \"\"\"Run EM for the HMM spiking model.\n", "\n", " Args:\n", " epochs (int): Number of epochs of EM to run\n", " Y (numpy 3d array): Tensor of recordings, has shape (n_trials, T, C)\n", " psi (numpy vector): Initial probabilities for each state\n", " A (numpy matrix): Transition matrix, A[i,j] represents the prob to switch\n", " from i to j. 
Has shape (K,K)\n", " L (numpy matrix): Poisson rate parameter for different cells.\n", " Has shape (C,K)\n", " dt (float): Duration of a time bin\n", "\n", " Returns:\n", " save_vals (lists of floats): Data for later plotting\n", " lls (list of floats): Data log likelihood before each EM step, plus the final value\n", " psi (numpy vector): Estimated initial probabilities for each state\n", " A (numpy matrix): Estimated transition matrix, A[i,j] represents\n", " the prob to switch from i to j. Has shape (K,K)\n", " L (numpy matrix): Estimated Poisson rate parameter for different\n", " cells. Has shape (C,K)\n", " \"\"\"\n", " save_vals = []\n", " lls = []\n", " for e in range(epochs):\n", "\n", " # Run E-step\n", " ll, gamma, xi = e_step(Y, psi, A, L, dt)\n", " lls.append(ll) # log the data log likelihood for current cycle\n", "\n", " if e % print_every == 0: print(f'epoch: {e:3d}, ll = {ll}') # log progress\n", " # Run M-step\n", " psi_new, A_new, L_new = m_step(gamma, xi, dt)\n", "\n", " \"\"\"Bookkeeping for later plotting\n", " Calculate the difference of parameters for later\n", " interpolation/extrapolation\n", " \"\"\"\n", " dp, dA, dL = psi_new - psi, A_new - A, L_new - L\n", " # Calculate LLs and ECLLs for later plotting\n", " if e in plot_epochs:\n", " # range of b that keeps all interpolated parameters positive\n", " b_min = -min([np.min(psi[dp > 0] / dp[dp > 0]),\n", " np.min(A[dA > 0] / dA[dA > 0]),\n", " np.min(L[dL > 0] / dL[dL > 0])])\n", " b_max = -max([np.max(psi[dp < 0] / dp[dp < 0]),\n", " np.max(A[dA < 0] / dA[dA < 0]),\n", " np.max(L[dL < 0] / dL[dL < 0])])\n", " b_min = np.max([.99 * b_min, b_lims[0]])\n", " b_max = np.min([.99 * b_max, b_lims[1]])\n", " bs = np.linspace(b_min, b_max, num_plot_vals)\n", " bs = sorted(list(set(np.hstack((bs, [0, 1])))))\n", " bs = np.array(bs)\n", " lls_for_plot = []\n", " eclls_for_plot = []\n", " for i, b in enumerate(bs):\n", " ll = e_step(Y, psi + b * dp, A + b * dA, L + b * dL, dt)[0]\n", " lls_for_plot.append(ll)\n", " rate = (L + b * dL) * dt\n", " ecll = ((gamma[:, 0] @ np.log(psi + b * dp) +\n", " (xi * 
np.log(A + b * dA)).sum(axis=(-1, -2, -3)) +\n", " (gamma * stats.poisson(rate).logpmf(Y[..., np.newaxis]).sum(-2)\n", " ).sum(axis=(-1, -2))).mean() / T / dt)\n", " eclls_for_plot.append(ecll)\n", " if b == 0:\n", " diff_ll = ll - ecll\n", " lls_for_plot = np.array(lls_for_plot)\n", " eclls_for_plot = np.array(eclls_for_plot) + diff_ll\n", " save_vals.append((bs, lls_for_plot, eclls_for_plot))\n", " # update parameters for the next epoch\n", " psi, A, L = psi_new, A_new, L_new\n", "\n", " ll = e_step(Y, psi, A, L, dt)[0]\n", " lls.append(ll)\n", " print(f'epoch: {epochs:3d}, ll = {ll}')\n", " return save_vals, lls, psi, A, L" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "---\n", "# Section 0: Introduction" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Video 1: Introduction\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "remove-input" ] }, "outputs": [], "source": [ "#@title Video 1: Introduction\n", "# Insert the ID of the corresponding youtube video\n", "from IPython.display import YouTubeVideo\n", "video = YouTubeVideo(id=\"ceQXN0OUaFo\", width=730, height=410, fs=1)\n", "print(\"Video available at https://youtu.be/\" + video.id)\n", "video" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_Introduction_Video\")" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "---\n", "\n", "# Section 1: HMM for Poisson spiking neuronal network" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Video 2: HMM for Poisson spiking neurons case study\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ 
"remove-input" ] }, "outputs": [], "source": [ "#@title Video 2: HMM for Poisson spiking neurons case study\n", "# Insert the ID of the corresponding youtube video\n", "from IPython.display import YouTubeVideo\n", "video = YouTubeVideo(id=\"Wb8mf5chmyI\", width=730, height=410, fs=1)\n", "print(\"Video available at https://youtu.be/\" + video.id)\n", "video" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_HMM_for_Poisson_spiking_neurons_Video\")" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "Given noisy neural or behavioral measurements, we as neuroscientists often want to infer the unobserved latent variables as they change over time. Thalamic relay neurons fire in two distinct modes: a tonic mode where spikes are produced one at a time, and a 'burst mode' where several action potentials are produced in rapid succession. These modes are thought to differentially encode how the neurons relay information from sensory receptors to cortex. A distinct molecular mechanism, T-type calcium channels, switches neurons between modes, but it is very challenging to measure in the brain of a living monkey. However, statistical approaches let us recover the hidden state of those calcium channels purely from their spiking activity, which can be measured in a behaving monkey.\n", "\n", "Here, we're going to tackle a simplified version of that problem.\n", "\n", "\n", "Let's consider the formulation mentioned in the intro lecture.\n", "We have a network of $C$ neurons switching between $K$ states. Neuron $c$ has firing rate $\\lambda_i^c$ in state $i$. 
Transitions between states are governed by the $K\\times K$ transition matrix $A$, and the state at time $t=1$ is drawn from the initial probability vector $\\psi$ of length $K$.\n", "\n", "Let $y_t^c$ be the number of spikes for cell $c$ in time bin $t$. Given $x_t=i$, it is Poisson distributed with mean $\\lambda_i^c\\,dt$, independently across cells.\n", "\n", "
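\n" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "The cell below is a small illustration of this emission model. It computes $\\log p(y_t|x_t=i)$ for every time bin and state at once. It assumes, as the data-generation code later in this notebook does, that given $x_t=i$ each count $y_t^c$ is Poisson with mean $\\lambda_i^c\\,dt$ and that cells are conditionally independent. Under that assumption, $\\log p(y_t|x_t=i)=\\sum_c \\log \\mathrm{Poisson}(y_t^c;\\lambda_i^c\\,dt)$. The sizes, rates, and variable names (such as `C_toy`, `lam`, and `log_obs`) are hypothetical toy choices and are not part of the tutorial's model. The result is checked against the closed-form Poisson log pmf $y\\log\\mu-\\mu-\\log y!$.\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "# Hypothetical toy example: Poisson emission log likelihoods for every (time bin, state)\n", "import numpy as np\n", "from scipy import stats\n", "from scipy.special import gammaln\n", "\n", "C_toy, K_toy, T_toy = 2, 3, 4  # toy numbers of cells, states and time bins\n", "dt_toy = 0.002  # bin width in seconds\n", "rng = np.random.default_rng(0)\n", "lam = rng.uniform(5, 50, size=(C_toy, K_toy))  # firing rates in Hz, shape (C, K)\n", "y = rng.poisson(lam[:, 0] * dt_toy, size=(T_toy, C_toy))  # counts drawn from state 0, shape (T, C)\n", "\n", "# means (C, K) broadcast against counts (T, C, 1) -> (T, C, K), then sum over cells\n", "log_obs = stats.poisson(lam * dt_toy).logpmf(y[:, :, None]).sum(axis=1)  # (T, K)\n", "\n", "# the same quantity written out: y * log(mu) - mu - log(y!)\n", "mu = lam * dt_toy\n", "manual = (y[:, :, None] * np.log(mu) - mu - gammaln(y[:, :, None] + 1)).sum(axis=1)\n", "\n", "assert log_obs.shape == (T_toy, K_toy)\n", "assert np.allclose(log_obs, manual)\n", "print(log_obs.shape)" ] },
{ "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "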
\n", "\n", "In the following exercise and sections, you will\n", "\n", "* Define an instance of such a model with $C=5$ and $K=3$\n", "* Generate a dataset from this model\n", "* (**Exercise 1**) Implement the M-step for this HMM\n", "* Run EM to estimate all parameters $A,\\psi,\\lambda_i^c$\n", "* Plot the log likelihood across EM iterations (the learning curve)\n", "* Plot the expected complete log likelihood versus the data log likelihood\n", "* Compare the learnt parameters with the true parameters" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "## Define model and generate data\n", "\n", "Let's first generate a random state sequence from the hidden Markov chain. Then we generate n_frozen_trials trials of spike trains for each cell, all driven by that same state sequence.\n", "\n", "**Suggestions**\n", "\n", "1. Run the following two sections, **Model and simulation parameters** and **Initialize true model**. They define the true model and the parameters used in the following exercises. Take a look at the parameters, and come back to these two cells if you later encounter a variable you don't recognize.\n", "\n", "2. 
Run the provided code to convert a given state sequence to the corresponding spike rates for all cells at all times, and use the provided code to visualize all spike trains.\n", "\n" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "#### Model and simulation parameters" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "# model and data parameters\n", "C = 5 # number of cells\n", "K = 3 # number of states\n", "dt = 0.002 # seconds\n", "trial_T = 2.0 # seconds\n", "n_frozen_trials = 20 # used to plot multiple trials with the same state sequence\n", "n_trials = 300 # number of trials (each has its own state sequence)\n", "\n", "# for random data\n", "max_firing_rate = 50 # Hz\n", "max_transition_rate = 3 # Hz\n", "\n", "# needed to plot LL and ECLL for every M-step\n", "# **This substantially slows things down!!**\n", "num_plot_vals = 10 # resolution of the plot (this is the expensive part)\n", "b_lims = (-1, 2) # lower and upper limits of b on the graph (b = 0 is start-of-M-step LL; b = 1 is end-of-M-step LL)\n", "plot_epochs = list(range(9)) # list of epochs to plot" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "#### Initialize true model" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "np.random.seed(101)\n", "T = round(trial_T / dt)\n", "ts = np.arange(T)\n", "\n", "# initial state distribution\n", "psi = np.arange(1, K + 1)\n", "psi = psi / psi.sum()\n", "\n", "# off-diagonal transition rates sampled uniformly\n", "A = np.random.rand(K, K) * max_transition_rate * dt\n", "A = (1. 
- np.eye(K)) * A\n", "A = A + np.diag(1 - A.sum(1))\n", "\n", "# hand-crafted firing rates make good plots\n", "L = np.array([\n", " [.02, .8, .37],\n", " [1., .7, .1],\n", " [.92, .07, .5],\n", " [.25, .42, .75],\n", " [.15, .2, .85]\n", "]) * max_firing_rate # (C,K)\n", "\n", "# Save true parameters for comparison later\n", "psi_true = psi\n", "A_true = A\n", "L_true = L" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "#### Generate data with frozen sequence and plot\n", "Given a state sequence [0,1,1,3,2,...], we'll first convert each state into a vector using the so-called \"one-hot\" coding. For example, with 5 total states, the one-hot coding of state 0 is [1,0,0,0,0] and the coding for state 3 is [0,0,0,1,0]. For a sequence of length T, the one-hot coding Xf therefore has shape (T,K)." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "np.random.seed(101)\n", "# sample a single state sequence, shared by all n_frozen_trials trials\n", "Xf = np.zeros(T, dtype=int)\n", "Xf[0] = (psi.cumsum() > np.random.rand()).argmax()\n", "for t in range(1, T):\n", " Xf[t] = (A[Xf[t - 1],:].cumsum() > np.random.rand()).argmax()\n", "\n", "# switch to one-hot encoding of the state\n", "Xf = np.eye(K, dtype=int)[Xf] # (T,K)\n", "\n", "# get the Y values\n", "Rates = np.squeeze(L @ Xf[..., None]) * dt # (T,C)\n", "\n", "Rates = np.tile(Rates, [n_frozen_trials, 1, 1]) # (n_frozen_trials, T, C)\n", "Yf = stats.poisson(Rates).rvs()\n", "\n", "plot_spike_train(Xf, Yf, dt)" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "#### Generate data for EM learning\n", "\n", "The previous dataset uses a single shared state sequence, for visualization. 
Now let's generate n_trials trials of observations, each with its own randomly generated state sequence." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "np.random.seed(101)\n", "# sample n_trials state sequences\n", "X = np.zeros((n_trials, T), dtype=int)\n", "X[:, 0] = (psi_true.cumsum(0)[:, None] > np.random.rand(n_trials)).argmax(0)\n", "for t in range(1, T):\n", " X[:, t] = (A_true[X[:, t - 1], :].T.cumsum(0) > np.random.rand(n_trials)).argmax(0)\n", "\n", "# switch to one-hot encoding of the state\n", "one_hot = np.eye(K)[np.array(X).reshape(-1)]\n", "X = one_hot.reshape(list(X.shape) + [K])\n", "\n", "# get the Y values\n", "Y = stats.poisson(np.squeeze(L_true @ X[..., None]) * dt).rvs() # (n_trials, T, C)\n", "print(\"Y has shape: (n_trial={},T={},C={})\".format(*Y.shape))" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "---\n", "\n", "# Section 2: EM algorithm for HMM" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "##### Video 3: EM Tutorial\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "remove-input" ] }, "outputs": [], "source": [ "#@title Video 3: EM Tutorial\n", "# Insert the ID of the corresponding youtube video\n", "from IPython.display import YouTubeVideo\n", "video = YouTubeVideo(id=\"umU4wUWlKvg\", width=730, height=410, fs=1)\n", "print(\"Video available at https://youtu.be/\" + video.id)\n", "video" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "##### Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_EM_tutorial_Video\")" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "Finding the optimal values of parameters that maximize 
the data likelihood directly is impractical. It requires integrating out all latent variables $x_{1:T}$, and naively summing over every possible state sequence takes time exponential in $T$. Instead, we use the Expectation-Maximization algorithm. It alternates an E-step and an M-step and is guaranteed never to decrease (and usually to increase) the data likelihood in each EM cycle.\n", "\n", "\n", "In this section we will briefly review the EM algorithm for HMMs and list\n", "\n", "* Recursive equations for forward and backward probabilities $a_i(t)$ and $b_i(t)$\n", "* Expressions for singleton and pairwise marginal distributions after seeing data: $\\gamma_{i}(t):=p_{\\theta}\\left(x_{t}=i | Y_{1: T}\\right)$ and $\\xi_{i j}(t) = p_{\\theta}(x_t=i,x_{t+1}=j|Y_{1:T})$\n", "* Closed-form solutions for updated values of $A,\\psi,\\lambda$ that increase the data likelihood\n", "\n", "\n", "### E-step: Forward-backward algorithm\n", "In the forward pass, we calculate the **forward probabilities** $a_i(t):=p_{\\theta}(x_t=i,Y_{1:t})$. These are the joint probabilities of $x_t$ and the current and past data $Y_{1:t}$, computed recursively by\n", "\n", "\\begin{equation}\n", "a_i(t) = p_{\\theta}(y_t|x_t=i)\\sum_j A_{ji} a_j(t-1)\n", "\\end{equation}\n", "\n", "with the initial condition $a_i(1)=\\psi_i\\,p_{\\theta}(y_1|x_1=i)$.\n", "\n", "In contrast to the intro, $A_{ji}$ here means **the transition probability from state $j$ to state $i$.**\n", "\n", "The backward pass calculates the **backward probabilities** $b_i(t):=p_{\\theta}(Y_{t+1:T}|x_t=i)$. This is the likelihood of observing all future data points given the current state $x_t=i$. 
The recursion for $b_i(t)$, starting from $b_i(T)=1$, is given by\n", "\n", "\\begin{equation}\n", "b_i(t) = \\sum_j p_{\\theta}(y_{t+1}|x_{t+1}=j)b_j(t+1)A_{ij}\n", "\\end{equation}\n", "\n", "Combining all past and future information, the **singleton and pairwise marginal distributions** are given by\n", "\n", "\\begin{equation}\n", "\\gamma_{i}(t):=p_{\\theta}\\left(x_{t}=i | Y_{1: T}\\right)=\\frac{a_{i}(t) b_{i}(t)}{p_{\\theta}\\left(Y_{1: T}\\right)}\n", "\\end{equation}\n", "\n", "\\begin{equation}\n", "\\xi_{i j}(t) = p_{\\theta}(x_t=i,x_{t+1}=j|Y_{1:T}) =\\frac{b_{j}(t+1)p_{\\theta}\\left(y_{t+1} | x_{t+1}=j\\right) A_{i j} a_{i}(t)}{p_{\\theta}\\left(Y_{1: T}\\right)}\n", "\\end{equation}\n", "\n", "where $p_{\\theta}(Y_{1:T})=\\sum_i a_i(T)$.\n", "\n", "### M-step\n", "\n", "The M-step for an HMM has a closed-form solution. First, the new transition matrix is given by\n", "\n", "\\begin{equation}\n", "A_{ij} =\\frac{\\sum_{t=1}^{T-1} \\xi_{i j}(t)}{\\sum_{t=1}^{T-1} \\gamma_{i}(t)}\n", "\\end{equation}\n", "\n", "This is the expected number of $i\\to j$ transitions divided by the expected number of visits to state $i$ (excluding the last bin).\n", "The new initial probabilities and emission parameters are likewise given by their expected empirical values under the singleton marginals, where $N$ is the number of trials:\n", "\n", "\\begin{align}\n", "\\psi_i &= \\frac{1}{N}\\sum_{\\text{trials}}\\gamma_i(1) \\\\\n", "\\lambda_{i}^{c} &= \\frac{\\sum_{t} \\gamma_{i}(t) y_{t}^{c}}{\\sum_{t} \\gamma_{i}(t) d t}\n", "\\end{align}" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "### E-step: forward and backward algorithm\n", "\n", "**(Optional)**\n", "\n", "In this section you will read through the code for the forward-backward algorithm. You will see how to implement the computation efficiently in numpy by calculating the recursion for all trials at once.\n", "\n", "Let's rewrite the forward and backward recursions in a more compact form:\n", "\n", "\\begin{eqnarray}\n", "a_i^t &=& o_i^t\\sum_j A_{ji} a_j^{t-1}\\\\\n", "b^t_i &=& \\sum_j A_{ij} 
o_j^{t+1}b_j^{t+1} \\text{, where } o_j^{t}=p(y_{t}|x_{t}=j)\n", "\\end{eqnarray}\n", "\n", "\n", "Let's take the backward recursion as an example. In practice we handle all trials together, since they are independent of each other and share the same parameters. After adding a trial index $l$ to the recursion equations, the backward recursion becomes:\n", "\n", "\\begin{equation}\n", "b^t_{li} = \\sum_j A_{ij} o_{lj}^{t+1}b_{lj}^{t+1}\n", "\\end{equation}\n", "\n", "What we have in hand are:\n", "* A: matrix of size (K,K)\n", "* o^{t+1}: array of size (N,K), the data likelihood $p(y_{t+1}|x_{t+1}=j)$ for all trials at a given time\n", "* b^{t+1}: array of size (N,K), the backward probability for all trials at a given time\n", "\n", "where N stands for the number of trials.\n", "\n", "The index sizes and meanings don't match across these three arrays. The first index is $i$ for $A$ but $l$ for $o$ and $b$, so we can't just multiply them together. However, we can view the vectors $o^{t+1}_{l\\cdot}$ and $b^{t+1}_{l\\cdot}$ as matrices with 1 row and rewrite the backward equation as:\n", "\n", "\\begin{equation}\n", "b^t_{li} = \\sum_j A_{ij} o_{l1j}^{t+1}b_{l1j}^{t+1}\n", "\\end{equation}\n", "\n", "Now we can multiply these three arrays element-wise, letting numpy broadcast the singleton dimension, and sum over the last dimension.\n", "\n", "In numpy, we insert a dimension by indexing the array with None at that location. For example, if b has size (N,T,K), then b[:,t,:] has shape (N,K), b[:,t,None,:] has shape (N,1,K) and b[:,t,:,None] has shape (N,K,1).\n", "\n", "So the backward recursion can be implemented as\n", "\n", "```python\n", "b[:,t,:] = (A * o[:,t+1,None,:] * b[:,t+1,None,:]).sum(-1)\n", "```\n", "\n", "
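\n" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "To see the broadcasting trick end to end, the cell below runs a minimal unscaled forward-backward pass on a hypothetical toy problem (2 trials, 4 time bins, 2 states). Random numbers stand in for $o$. The sizes, values, and variable names are assumptions and are not the tutorial's own E-step. The sketch uses the convention $A_{ij}=p(x_{t+1}=j|x_t=i)$ from above. It checks $p_{\\theta}(Y_{1:T})=\\sum_i a_i(T)$ against a brute-force sum over all $K^T$ state sequences. Multiplying raw probabilities underflows for long trials such as the 1000-bin ones simulated here, so a practical version would rescale $a$ and $b$ at every step or work in log space.\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "# Hypothetical toy example: vectorized forward-backward, checked by brute force\n", "import itertools\n", "import numpy as np\n", "\n", "rng = np.random.default_rng(1)\n", "N_toy, T_toy, K_toy = 2, 4, 2  # trials, time bins, states\n", "psi_toy = np.array([0.6, 0.4])\n", "A_toy = np.array([[0.9, 0.1],   # A_toy[i, j] = p(x_{t+1} = j | x_t = i)\n", "                  [0.2, 0.8]])\n", "o = rng.uniform(0.1, 1.0, size=(N_toy, T_toy, K_toy))  # stand-in for p(y_t | x_t = j)\n", "\n", "a = np.zeros((N_toy, T_toy, K_toy))\n", "b = np.ones((N_toy, T_toy, K_toy))  # b_i(T) = 1\n", "a[:, 0] = psi_toy * o[:, 0]\n", "for t in range(1, T_toy):\n", "    # a_i^t = o_i^t * sum_j A_ji a_j^{t-1}\n", "    a[:, t] = o[:, t] * (a[:, t - 1] @ A_toy)\n", "for t in range(T_toy - 2, -1, -1):\n", "    # b_i^t = sum_j A_ij o_j^{t+1} b_j^{t+1}: (K,K) * (N,1,K) -> (N,K,K), then sum over j\n", "    b[:, t] = (A_toy * o[:, t + 1, None, :] * b[:, t + 1, None, :]).sum(-1)\n", "\n", "lik = a[:, -1].sum(-1)  # p(Y_{1:T}) for each trial\n", "gamma_toy = a * b / lik[:, None, None]  # singleton marginals, shape (N, T, K)\n", "\n", "# brute force: sum the joint probability over all K**T state sequences\n", "brute = np.zeros(N_toy)\n", "for x in itertools.product(range(K_toy), repeat=T_toy):\n", "    p = psi_toy[x[0]] * o[:, 0, x[0]]\n", "    for t in range(1, T_toy):\n", "        p = p * A_toy[x[t - 1], x[t]] * o[:, t, x[t]]\n", "    brute += p\n", "\n", "assert np.allclose(lik, brute)\n", "assert np.allclose((a * b).sum(-1), lik[:, None])  # sum_i a_i(t) b_i(t) = p(Y) at every t\n", "assert np.allclose(gamma_toy.sum(-1), 1.0)\n", "print(lik)" ] },
{ "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "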
\n", "\n", "**For reference:**\n", "\n", "The new transition matrix is calculated from the expected counts of transition events under the marginals\n", "\n", "\\begin{equation}\n", "A_{ij} =\\frac{\\sum_{t=1}^{T-1} \\xi_{i j}(t)}{\\sum_{t=1}^{T-1} \\gamma_{i}(t)}\n", "\\end{equation}\n", "\n", "New spiking rates for each cell and each state are given by\n", "\n", "\\begin{equation}\n", "\\lambda_{i}^{c}=\\frac{\\sum_{t} \\gamma_{i}(t) y_{t}^{c}}{\\sum_{t} \\gamma_{i}(t) d t}\n", "\\end{equation}" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "def m_step(gamma, xi, dt):\n", " \"\"\"Calculate the M-step updates for the HMM spiking model.\n", "\n", " Args:\n", " gamma (numpy 3d array): singleton marginal distribution.\n", " Has shape (n_trials, T, K)\n", " xi (numpy 4d array): pairwise marginal distribution.\n", " Has shape (n_trials, T-1, K, K)\n", " dt (float): Duration of a time bin\n", "\n", " Returns:\n", " psi_new (numpy vector): Updated initial probabilities for each state\n", " A_new (numpy matrix): Updated transition matrix, A[i,j] represents the\n", " prob. to switch from i to j. Has shape (K,K)\n", " L_new (numpy matrix): Updated Poisson rate parameter for different\n", " cells. 
Has shape (C,K)\n", " \"\"\"\n", " raise NotImplementedError(\"m_step needs to be implemented\")\n", " ############################################################################\n", " # Insert your code here to:\n", " # Calculate the new prior probabilities in each state at time 0\n", " # Hint: Take the first time step and average over all trials\n", " ###########################################################################\n", " psi_new = ...\n", " # Make sure the probabilities are normalized\n", " psi_new /= psi_new.sum()\n", "\n", " # Calculate new transition matrix\n", " A_new = xi.sum(axis=(0, 1)) / gamma[:, :-1].sum(axis=(0, 1))[:, np.newaxis]\n", " # Calculate new firing rates (note: this uses the global spike-count array Y)\n", " L_new = (np.swapaxes(Y, -1, -2) @ gamma).sum(axis=0) / gamma.sum(axis=(0, 1)) / dt\n", " return psi_new, A_new, L_new" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "execution": {} }, "source": [ "[*Click for solution*](https://github.com/NeuromatchAcademy/course-content/tree/main/tutorials/W3D2_HiddenDynamics/solutions/W3D2_Tutorial5_Solution_a471c4a4.py)\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "##### Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_Implement_M_step_Exercise\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "##### Video 5: Running and plotting EM\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "remove-input" ] }, "outputs": [], "source": [ "#@title Video 5: Running and plotting EM\n", "# Insert the ID of the corresponding youtube video\n", "from IPython.display import YouTubeVideo\n", "video = YouTubeVideo(id=\"6UTsXxE3hG0\", width=730, height=410, fs=1)\n", "print(\"Video available at https://youtu.be/\" + video.id)\n", 
"video" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "##### Submit your feedback\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Submit your feedback\n", "content_review(f\"{feedback_prefix}_Running_and_plotting_EM_Video\")" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "---\n", "\n", "### Run EM\n", "\n", "#### Initialization for parameters\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "np.random.seed(101)\n", "# number of EM steps\n", "epochs = 9\n", "print_every = 1\n", "\n", "# initial state distribution\n", "psi = np.arange(1, K + 1)\n", "psi = psi / psi.sum()\n", "\n", "# off-diagonal transition rates initialized to a constant (half the maximum rate)\n", "A = np.ones((K, K)) * max_transition_rate * dt / 2\n", "A = (1 - np.eye(K)) * A\n", "A = A + np.diag(1 - A.sum(1))\n", "\n", "# firing rates sampled uniformly\n", "L = np.random.rand(C, K) * max_firing_rate" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "# LL for true vs. 
initial parameters\n", "print(f'LL for true 𝜃: {e_step(Y, psi_true, A_true, L_true, dt)[0]}')\n", "print(f'LL for initial 𝜃: {e_step(Y, psi, A, L, dt)[0]}\\n')\n", "\n", "# Run EM\n", "save_vals, lls, psi, A, L = run_em(epochs, Y, psi, A, L, dt)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "# EM doesn't guarantee that the learnt latent states come out in the same order as the true ones,\n", "# so we need to reorder the learnt parameters\n", "\n", "# Match true and estimated states by the squared difference of their firing rates across cells\n", "cost_mat = np.sum((L_true[..., np.newaxis] - L[:, np.newaxis])**2, axis=0)\n", "true_ind, est_ind = linear_sum_assignment(cost_mat)\n", "\n", "psi = psi[est_ind]\n", "A = A[est_ind]\n", "A = A[:, est_ind]\n", "L = L[:, est_ind]" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "## Plotting the training process and learnt model" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "### Plotting progress during EM!\n", "\n", "Now you can\n", "\n", "* Plot the likelihood during training\n", "* Plot the log likelihood (LL) versus the expected complete log likelihood (ECLL) along each M-step, to build intuition for how EM works. Shifted by a constant, the ECLL is a lower bound on the LL that touches it at the current parameters (b = 0) and is maximized by the M-step (b = 1)\n", "* Plot learnt parameters versus true parameters" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "# Plot the log likelihood after each epoch of EM\n", "with plt.xkcd():\n", " plot_lls(lls)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "# For each saved epoch, plot the log likelihood and expected complete log likelihood\n", "# along the line from the parameters before (b = 0) to after (b = 1) the M-step\n", "with plt.xkcd():\n", " plot_lls_eclls(plot_epochs, save_vals)" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "### Plot learnt parameters vs. 
true parameters\n", "\n", "Now we will plot the (sorted) learnt parameters with true parameters to see if we successfully recovered all the parameters." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "# Compare true and learnt parameters\n", "with plt.xkcd():\n", " plot_learnt_vs_true(L_true, L, A_true, A, dt)" ] } ], "metadata": { "@webio": { "lastCommId": null, "lastKernelId": null }, "colab": { "collapsed_sections": [], "include_colab_link": true, "name": "W3D2_Tutorial5", "provenance": [], "toc_visible": true }, "kernel": { "display_name": "Python 3", "language": "python", "name": "python3" }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.17" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": false, "sideBar": true, "skip_h1_title": true, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": true } }, "nbformat": 4, "nbformat_minor": 0 }