{ "cells": [ { "cell_type": "markdown", "metadata": { "colab_type": "text", "execution": {}, "id": "view-in-github" }, "source": [ "  " ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "# Tutorial 1: Sequential Probability Ratio Test\n", "**Week 3, Day 2: Hidden Dynamics**\n", "\n", "**By Neuromatch Academy**\n", "\n", "__Content creators:__ Yicheng Fei and Xaq Pitkow\n", "\n", "__Content reviewers:__ John Butler, Matt Krause, Spiros Chavlis, Melvin Selim Atay, Keith van Antwerp, Michael Waskom, Jesse Livezey, and Byron Galbraith\n", "\n", "__Production Editor:__ Ella Batty\n", "\n", "__Post-production team:__ Gagana B, Spiros Chavlis" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ " " ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "# Tutorial Objectives\n", "\n", "*Estimated timing of tutorial: 45 minutes*\n", "\n", "On Bayes Day, we learned how to combine the sensory measurement $m$ about a latent variable $s$ with our prior knowledge, using Bayes' Theorem. This produced a posterior probability distribution $p(s|m)$. Today we will allow for _dynamic_ world states and measurements.\n", "\n", "In Tutorial 1 we will assume that the world state is _binary_ ($\\pm 1$) and _constant_ over time, but allow for multiple observations over time. We will use the *Sequential Probability Ratio Test* (SPRT) to infer which state is true. This leads to the *Drift Diffusion Model (DDM)* where evidence accumulates until reaching a stopping criterion.\n", "\n", "By the end of this tutorial, you should be able to:\n", "- Define and implement the Sequential Probability Ratio Test for a series of measurements\n", "- Define what drift and diffusion mean in a drift-diffusion model\n", "- Explain the speed-accuracy trade-off in a drift diffusion model\n", "\n", "**Summary of Exercises**\n", "\n", "0. Bonus (math): derive the Drift Diffusion Model mathematically from SPRT\n", "\n", "1. 
Simulate the DDM\n",
"   1. _Code_: Accumulate evidence and make a decision (DDM)\n",
"   2. _Interactive_: Manipulate parameters and interpret\n",
"\n",
"2. Analyze the DDM\n",
"   1. _Code_: Quantify speed-accuracy tradeoff\n",
"   2. _Interactive_: Manipulate parameters and interpret" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Video 1: Overview of Tutorials on Hidden Dynamics\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "remove-input" ] }, "outputs": [], "source": [ "# @title Video 1: Overview of Tutorials on Hidden Dynamics\n",
"from ipywidgets import widgets\n",
"from IPython.display import YouTubeVideo\n",
"from IPython.display import IFrame\n",
"from IPython.display import display\n",
"\n",
"\n",
"class PlayVideo(IFrame):\n",
"  def __init__(self, id, source, page=1, width=400, height=300, **kwargs):\n",
"    self.id = id\n",
"    if source == 'Bilibili':\n",
"      src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'\n",
"    elif source == 'Osf':\n",
"      src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'\n",
"    super(PlayVideo, self).__init__(src, width, height, **kwargs)\n",
"\n",
"\n",
"def display_videos(video_ids, W=400, H=300, fs=1):\n",
"  tab_contents = []\n",
"  for i, video_id in enumerate(video_ids):\n",
"    out = widgets.Output()\n",
"    with out:\n",
"      if video_ids[i][0] == 'Youtube':\n",
"        video = YouTubeVideo(id=video_ids[i][1], width=W,\n",
"                             height=H, fs=fs, rel=0)\n",
"        print(f'Video available at https://youtube.com/watch?v={video.id}')\n",
"      else:\n",
"        video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,\n",
"                          height=H, fs=fs, autoplay=False)\n",
"        if video_ids[i][0] == 'Bilibili':\n",
"          print(f'Video available at https://www.bilibili.com/video/{video.id}')\n",
"        elif video_ids[i][0] == 'Osf':\n",
"          print(f'Video available at https://osf.io/{video.id}')\n",
"      display(video)\n",
"    tab_contents.append(out)\n",
"  return tab_contents\n",
"\n",
"\n",
"video_ids = [('Youtube', 'HH7HkQ1kv5M'), ('Bilibili', 'BV1Eh411r7hm')]\n",
"tab_contents = display_videos(video_ids, W=730, H=410)\n",
"tabs = widgets.Tab()\n",
"tabs.children = tab_contents\n",
"for i in range(len(tab_contents)):\n",
"  tabs.set_title(i, video_ids[i][0])\n",
"display(tabs)" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "---\n", "# Setup" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "# Imports\n",
"import numpy as np\n",
"from scipy import stats\n",
"import matplotlib.pyplot as plt\n",
"from scipy.special import erf" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Figure Settings\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Figure Settings\n",
"import ipywidgets as widgets # interactive display\n",
"%config InlineBackend.figure_format = 'retina'\n",
"plt.style.use(\"https://raw.githubusercontent.com/NeuromatchAcademy/course-content/NMA2020/nma.mplstyle\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Helper Functions\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @title Helper Functions\n",
"\n",
"def simulate_and_plot_SPRT_fixedtime(mu, sigma, stop_time, num_sample,\n",
"                                     verbose=True):\n",
"  \"\"\"Simulate and plot a SPRT for a fixed amount of time given a std.\n",
"\n",
"  Args:\n",
"    mu (float): absolute mean value of the symmetric observation distributions\n",
"    sigma (float): Standard deviation of the observations.\n",
"    stop_time (int): Number of steps to run before stopping.\n",
"    num_sample (int): The number of samples to plot.\n",
"    verbose (bool): If True, print the final evidence and decision of each trial.\n",
"  \"\"\"\n",
"\n",
"  evidence_history_list = []\n",
"  if verbose:\n",
"    print(\"#Trial\\tTotal_Evidence\\tDecision\")\n",
"  for i in range(num_sample):\n",
"    evidence_history, decision, Mvec = simulate_SPRT_fixedtime(mu, sigma, stop_time)\n",
"    if verbose:\n",
"      print(\"{}\\t{:f}\\t{}\".format(i, evidence_history[-1], decision))\n",
"    evidence_history_list.append(evidence_history)\n",
"\n",
"  fig, ax = plt.subplots()\n",
"  maxlen_evidence = np.max(list(map(len,evidence_history_list)))\n",
"  ax.plot(np.zeros(maxlen_evidence), '--', c='red', alpha=1.0)\n",
"  for evidences in evidence_history_list:\n",
"    ax.plot(np.arange(len(evidences)), evidences)\n",
"  ax.set_xlabel(\"Time\")\n",
"  ax.set_ylabel(\"Accumulated log likelihood ratio\")\n",
"  ax.set_title(\"Log likelihood ratio trajectories under the fixed-time \" +\n",
"               \"stopping rule\")\n",
"\n",
"  plt.show(fig)\n",
"\n",
"\n",
"def plot_accuracy_vs_stoptime(mu, sigma, stop_time_list, accuracy_analytical_list, accuracy_list=None):\n",
"  \"\"\"Plot the average decision accuracy vs. stop time for the fixed-time SPRT.\n",
"\n",
"  Args:\n",
"    mu (float): absolute mean value of the symmetric observation distributions\n",
"    sigma (float): Standard deviation of the observations.\n",
"    stop_time_list (list of int): List of number of steps to run before stopping.\n",
"    accuracy_analytical_list (list of float): List of analytical accuracies for each stop time\n",
"    accuracy_list (list of float, optional): List of simulated accuracies for each stop time\n",
"  \"\"\"\n",
"  T = stop_time_list[-1]\n",
"  fig, ax = plt.subplots(figsize=(12,8))\n",
"  ax.set_xlabel('Stop Time')\n",
"  ax.set_ylabel('Average Accuracy')\n",
"  ax.plot(stop_time_list, accuracy_analytical_list)\n",
"  if accuracy_list is not None:\n",
"    ax.plot(stop_time_list, accuracy_list)\n",
"  ax.legend(['analytical','simulated'], loc='upper center')\n",
"\n",
"  # Show the summed-evidence distributions at two stop times\n",
"  stop_time_list_plot = [max(1,T//10), T*2//3]\n",
"  sigma_st_max = 2*mu*np.sqrt(stop_time_list_plot[-1])/sigma\n",
"  domain = np.linspace(-3*sigma_st_max,3*sigma_st_max,50)\n",
"  for stop_time in stop_time_list_plot:\n",
"    ins = ax.inset_axes([stop_time/T,0.05,0.2,0.3])\n",
"    for pos in ['right', 'top', 'bottom', 'left']:\n",
"      ins.spines[pos].set_visible(False)\n",
"    ins.axis('off')\n",
"    ins.set_title(f\"stop_time={stop_time}\")\n",
"\n",
"    left = np.zeros_like(domain)\n",
"    mu_st = 4*mu*mu*stop_time/2/sigma**2\n",
"    sigma_st = 2*mu*np.sqrt(stop_time)/sigma\n",
"    for i, mu1 in enumerate([-mu_st,mu_st]):\n",
"      rv = stats.norm(mu1, sigma_st)\n",
"      offset = rv.pdf(domain)\n",
"      # lbl = \"measurement distribution\" if i==0 else \"\"\n",
"      lbl = \"summed evidence\" if i==1 else \"\"\n",
"      color = \"crimson\"\n",
"      ls = \"solid\" if i==1 else \"dashed\"\n",
"      ins.plot(domain, left+offset, label=lbl, color=color,ls=ls)\n",
"\n",
"    rv = stats.norm(mu_st, sigma_st)\n",
"    domain0 = np.linspace(-3*sigma_st_max,0,50)\n",
"    offset = rv.pdf(domain0)\n",
"    ins.fill_between(domain0, np.zeros_like(domain0), offset, color=\"crimson\", label=\"error\")\n",
"    ins.legend(bbox_to_anchor=(1.05, 1.0), loc='upper left')\n",
"\n",
"\n",
"    # ins.legend(loc=\"upper right\")\n",
"\n",
"  plt.show(fig)\n",
"\n",
"\n",
"def simulate_and_plot_SPRT_fixedthreshold(mu, sigma, num_sample, alpha,\n",
"                                          verbose=True):\n",
"  \"\"\"Simulate and plot a SPRT with a fixed decision threshold given a std.\n",
"\n",
"  Args:\n",
"    mu (float): absolute mean value of the symmetric observation distributions\n",
"    sigma (float): Standard deviation of the observations.\n",
"    num_sample (int): The number of samples to plot.\n",
"    alpha (float): Desired error rate, converted into the decision threshold.\n",
"    verbose (bool): If True, print the decision time, final evidence and decision of each trial.\n",
"  \"\"\"\n",
"  # calculate evidence threshold from error rate\n",
"  threshold = threshold_from_errorrate(alpha)\n",
"\n",
"  # run simulation\n",
"  evidence_history_list = []\n",
"  if verbose:\n",
"    print(\"#Trial\\tTime\\tAccumulated Evidence\\tDecision\")\n",
"  for i in range(num_sample):\n",
"    evidence_history, decision, Mvec = simulate_SPRT_threshold(mu, sigma, threshold)\n",
"    if verbose:\n",
"      print(\"{}\\t{}\\t{:f}\\t{}\".format(i, len(Mvec), evidence_history[-1],\n",
"                                      decision))\n",
"    evidence_history_list.append(evidence_history)\n",
"\n",
"  fig, ax = plt.subplots()\n",
"  maxlen_evidence = np.max(list(map(len,evidence_history_list)))\n",
"  ax.plot(np.repeat(threshold,maxlen_evidence + 1), c=\"red\")\n",
"  ax.plot(-np.repeat(threshold,maxlen_evidence + 1), c=\"red\")\n",
"  ax.plot(np.zeros(maxlen_evidence + 1), '--', c='red', alpha=0.5)\n",
"\n",
"  for evidences in evidence_history_list:\n",
"    ax.plot(np.arange(len(evidences) + 1), np.concatenate([[0], evidences]))\n",
"\n",
"  ax.set_xlabel(\"Time\")\n",
"  ax.set_ylabel(\"Accumulated log likelihood ratio\")\n",
"  ax.set_title(\"Log likelihood ratio trajectories under the threshold rule\")\n",
"\n",
"  plt.show(fig)\n",
"\n",
"\n",
"def simulate_and_plot_accuracy_vs_threshold(mu, sigma, threshold_list, num_sample):\n",
"  \"\"\"Simulate and plot a SPRT for a set of thresholds given a std.\n",
"\n",
"  Args:\n",
"    mu (float): absolute mean value of the symmetric observation distributions\n",
"    sigma (float): Standard deviation of the observations.\n",
"    threshold_list (list of float): List of thresholds for making a decision.\n",
"    num_sample (int): The number of samples to plot.\n",
"  \"\"\"\n",
"  accuracies, decision_speeds = simulate_accuracy_vs_threshold(mu, sigma,\n",
"                                                               threshold_list,\n",
"                                                               num_sample)\n",
"\n",
"  # Plotting\n",
"  fig, ax = plt.subplots()\n",
"  ax.plot(decision_speeds, accuracies, linestyle=\"--\", marker=\"o\")\n",
"  ax.plot([np.amin(decision_speeds), np.amax(decision_speeds)],\n",
"          [0.5, 0.5], c='red')\n",
"  ax.set_xlabel(\"Average Decision speed\")\n",
"  ax.set_ylabel('Average Accuracy')\n",
"  ax.set_title(\"Speed/Accuracy Tradeoff\")\n",
"  ax.set_ylim(0.45, 1.05)\n",
"\n",
"  plt.show(fig)\n",
"\n",
"\n",
"def threshold_from_errorrate(alpha):\n",
"  \"\"\"Calculate log likelihood ratio threshold from desired error rate alpha\n",
"\n",
"  Args:\n",
"    alpha (float): in (0,1), the desired error rate\n",
"\n",
"  Return:\n",
"    threshold: corresponding evidence threshold\n",
"  \"\"\"\n",
"  threshold = np.log((1. - alpha) / alpha)\n",
"  return threshold" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "---\n", "\n", "# Section 1: Sequential Probability Ratio Test as a Drift Diffusion Model\n", "\n", "*Estimated timing to here from start of tutorial: 8 min*" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Video 2: Sequential Probability Ratio Test\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "remove-input" ] }, "outputs": [], "source": [ "# @title Video 2: Sequential Probability Ratio Test\n",
"from ipywidgets import widgets\n",
"from IPython.display import YouTubeVideo\n",
"from IPython.display import IFrame\n",
"from IPython.display import display\n",
"\n",
"\n",
"class PlayVideo(IFrame):\n",
"  def __init__(self, id, source, page=1, width=400, height=300, **kwargs):\n",
"    self.id = id\n",
"    if source == 'Bilibili':\n",
"      src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'\n",
"    elif source == 'Osf':\n",
"      src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'\n",
"    super(PlayVideo, self).__init__(src, width, height, **kwargs)\n",
"\n",
"\n",
"def display_videos(video_ids, W=400, H=300, fs=1):\n",
"  tab_contents = []\n",
"  for i, video_id in enumerate(video_ids):\n",
"    out = widgets.Output()\n",
"    with out:\n",
"      if video_ids[i][0] == 'Youtube':\n",
"        video = YouTubeVideo(id=video_ids[i][1], width=W,\n",
"                             height=H, fs=fs, rel=0)\n",
"        print(f'Video available at https://youtube.com/watch?v={video.id}')\n",
"      else:\n",
"        video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,\n",
"                          height=H, fs=fs, autoplay=False)\n",
"        if video_ids[i][0] == 'Bilibili':\n",
"          print(f'Video available at https://www.bilibili.com/video/{video.id}')\n",
"        elif video_ids[i][0] == 'Osf':\n",
"          print(f'Video available at https://osf.io/{video.id}')\n",
"      display(video)\n",
"    tab_contents.append(out)\n",
"  return tab_contents\n",
"\n",
"\n",
"video_ids = [('Youtube', 'vv0yukRSTT0'), ('Bilibili', 'BV1Yo4y1D7Be')]\n",
"tab_contents = display_videos(video_ids, W=730, H=410)\n",
"tabs = widgets.Tab()\n",
"tabs.children = tab_contents\n",
"for i in range(len(tab_contents)):\n",
"  tabs.set_title(i, video_ids[i][0])\n",
"display(tabs)" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "This video covers the definition of and math behind the sequential probability ratio test (SPRT), and introduces the idea of the SPRT as a drift diffusion model." ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "
\n",
"Click here for text recap of video\n",
"\n",
"\n",
"**Sequential Probability Ratio Test**\n",
"\n",
"The Sequential Probability Ratio Test is a likelihood ratio test for determining which of two hypotheses is more likely. It is appropriate for sequential independent and identically distributed (iid) data: each measurement is drawn from the same distribution, independently of all the others.\n",
"\n",
"Let's return to what we learned yesterday. We had probabilities of our measurement ($m$) given a state of the world ($s$). For example, we knew the probability of seeing someone catch a fish while fishing on the left side given that the fish were on the left side $P(m = \\textrm{catch fish} | s = \\textrm{left})$.\n",
"\n",
"Now let's extend this slightly: assume we take a series of measurements from time 1 up to time $T$ ($m_{1:T}$), and that our state is either +1 or -1. We want to figure out what the state is, given our measurements. To do this, we can compare the total evidence up to time $T$ for our two hypotheses (that the state is +1 or that the state is -1). We do this by computing a likelihood ratio: the ratio of the likelihood of all these measurements given the state is +1, $p(m_{1:T}|s=+1)$, to the likelihood of the measurements given the state is -1, $p(m_{1:T}|s=-1)$. This is our likelihood ratio test. In practice, we take the log of this ratio, which gives the log likelihood ratio $L_T$.\n",
"\n",
"\\begin{align*}\n",
"L_T &= \\log\\frac{p(m_{1:T}|s=+1)}{p(m_{1:T}|s=-1)}\n",
"\\end{align*}\n",
"\n",
"Since our data are independent and identically distributed, the probability of all measurements given the state equals the product of the separate probabilities of each measurement given the state ($p(m_{1:T}|s) = \\prod_{t=1}^T p(m_t | s)$). We can substitute this in and use the properties of logarithms to convert the product into a sum.\n",
"\n",
"\\begin{align*}\n",
"L_T &= \\log\\frac{p(m_{1:T}|s=+1)}{p(m_{1:T}|s=-1)}\\\\\n",
"&= \\log\\frac{\\prod_{t=1}^Tp(m_{t}|s=+1)}{\\prod_{t=1}^Tp(m_{t}|s=-1)}\\\\\n",
"&= \\sum_{t=1}^T \\log\\frac{p(m_{t}|s=+1)}{p(m_{t}|s=-1)}\\\\\n",
"&= \\sum_{t=1}^T \\Delta_t\n",
"\\end{align*}\n",
"\n",
"In the last line, we have used $\\Delta_t = \\log\\frac{p(m_{t}|s=+1)}{p(m_{t}|s=-1)}$.\n",
"\n",
"To get the full log likelihood ratio, we sum the log likelihood ratios from every time step. The log likelihood ratio at time $T$ ($L_T$) therefore equals the log likelihood ratio at the previous time step ($L_{T-1}$) plus the log likelihood ratio for the new measurement, $\\Delta_T$:\n",
"\n",
"\\begin{align*}\n",
"L_T = L_{T-1} + \\Delta_T\n",
"\\end{align*}\n",
"\n",
"The SPRT states that if $L_T$ is positive, then the state $s=+1$ is more likely than $s=-1$ (assuming both states are equally likely a priori).\n",
"\n",
"\n",
"**Sequential Probability Ratio Test as a Drift Diffusion Model**\n",
"\n",
"Let's assume that the probability of seeing a measurement given the state is a Gaussian (Normal) distribution whose mean is $+\\mu$ for $s=+1$ and $-\\mu$ for $s=-1$, while the standard deviation ($\\sigma$) is the same for both states:\n",
"\n",
"\\begin{align*}\n",
"p(m_t | s = +1) &= \\mathcal{N}(\\mu, \\sigma^2)\\\\\n",
"p(m_t | s = -1) &= \\mathcal{N}(-\\mu, \\sigma^2)\n",
"\\end{align*}\n",
"\n",
"When the true state is $s=+1$, we can write the new evidence (the log likelihood ratio for the measurement at time $t$) as\n",
"\n",
"$$\\Delta_t=b+c\\epsilon_t$$\n",
"\n",
"The first term, $b$, is a constant value, $b=2\\mu^2/\\sigma^2$. This *drift* term pushes the accumulated evidence toward the true hidden state (for $s=-1$ its sign flips, giving $\\Delta_t=-b+c\\epsilon_t$). The second term, $c\\epsilon_t$ where $\\epsilon_t\\sim\\mathcal{N}(0,1)$, is a standard normal random variable scaled by the diffusion coefficient $c=2\\mu/\\sigma$. 
You can work through proving this in the bonus exercise 0 below if you wish!\n", "\n", "The accumulation of evidence will thus \"drift\" toward one outcome, while \"diffusing\" in random directions, hence the term \"drift-diffusion model\" (DDM). The process is most likely (but not guaranteed) to reach the correct outcome eventually.\n", "\n", "\n", "
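To check the drift and diffusion terms numerically, here is a minimal sketch that is not part of the original tutorial. It assumes the Gaussian observation model above (means `+mu` and `-mu`, shared standard deviation `sigma`), takes the true state to be s = +1, and computes each increment directly from `scipy.stats.norm` log densities. The parameter values and the seed are arbitrary illustrative choices.

```python
import numpy as np
from scipy import stats

# Illustrative parameters (assumed; any mu > 0 and sigma > 0 work)
mu, sigma = 0.2, 3.5
b = 2 * mu**2 / sigma**2   # drift: constant part of each increment
c = 2 * mu / sigma         # diffusion: scale of the random part

rng = np.random.default_rng(0)
eps = rng.standard_normal(1000)   # one standard normal draw per time step
m = mu + sigma * eps              # measurements generated under s = +1

# Per-step log likelihood ratio computed from the two Gaussian densities
p_pos = stats.norm(loc=mu, scale=sigma)
p_neg = stats.norm(loc=-mu, scale=sigma)
delta = p_pos.logpdf(m) - p_neg.logpdf(m)

# Each increment should equal b + c * eps, up to floating-point rounding
print(np.allclose(delta, b + c * eps))
```

If you generate the measurements as `m = -mu + sigma * eps` instead (true state s = -1), the identity should become `delta = -b + c * eps`. Only the drift changes sign.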
" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "
\n",
"Bonus math exercise 0: derive Drift Diffusion Model from SPRT\n",
"\n",
"\n",
"We can do a little math to find the SPRT update $\\Delta_t$ to the log-likelihood ratio. You can derive this yourself, filling in the steps below, or skip to the end result.\n",
"\n",
"Assume measurements are Gaussian-distributed with different means depending on the discrete latent variable $s$:\n",
"\n",
"\\begin{equation}\n",
"p(m|s=\\pm 1) = \\mathcal{N}\\left(\\mu_\\pm,\\sigma^2\\right)=\\frac{1}{\\sqrt{2\\pi\\sigma^2}}\\exp{\\left[-\\frac{(m-\\mu_\\pm)^2}{2\\sigma^2}\\right]}\n",
"\\end{equation}\n",
"\n",
"In the log likelihood ratio for a single data point $m_t$, the normalizations cancel to give\n",
"\n",
"\\begin{equation}\n",
"\\Delta_t=\\log \\frac{p(m_t|s=+1)}{p(m_t|s=-1)} = \\frac{1}{2\\sigma^2}\\left[-\\left(m_t-\\mu_+\\right)^2 + (m_t-\\mu_-)^2\\right]\n",
"\\end{equation}\n",
"\n",
"It's convenient to rewrite $m_t=\\mu_\\pm + \\sigma \\epsilon_t$, where $\\epsilon_t\\sim \\mathcal{N}(0,1)$ is a standard Gaussian variable with zero mean and unit variance. (Why does this give the correct probability for $m_t$?) The preceding formula can then be rewritten as\n",
"$$\\Delta_t = \\frac{1}{2\\sigma^2}\\left( -((\\mu_\\pm+\\sigma\\epsilon_t)-\\mu_+)^2 + ((\\mu_\\pm+\\sigma\\epsilon_t)-\\mu_-)^2\\right)$$\n",
"Let's assume that $s=+1$ so $\\mu_\\pm=\\mu_+$ (if $s=-1$, the same steps give the same result with the sign of the drift term reversed). In that case, the means in the first term $m_t-\\mu_+$ cancel, and expanding the second square leaves\n",
"\n",
"\\begin{equation}\n",
"\\Delta_t = \\frac{(\\delta\\mu)^2}{2\\sigma^2}+\\frac{\\delta\\mu}{\\sigma}\\epsilon_t\n",
"\\end{equation}\n",
"\n",
"where $\\delta\\mu=\\mu_+-\\mu_-$. If we take $\\mu_\\pm=\\pm\\mu$, then $\\delta\\mu=2\\mu$, and\n",
"\n",
"\\begin{equation}\n",
"\\Delta_t=2\\frac{\\mu^2}{\\sigma^2}+2\\frac{\\mu}{\\sigma}\\epsilon_t\n",
"\\end{equation}\n",
"\n",
"The first term is a constant *drift*, and the second term is a random *diffusion*.\n",
"\n",
"The SPRT says that we should add up these increments of evidence, $L_T=\\sum_{t=1}^T \\Delta_t$. Note that the $\\Delta_t$ are independent. Recall that the mean of a sum is the sum of the means and that, for independent random variables, the variance of a sum is also the sum of the variances.\n",
"\n", "
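These two facts can be checked with a short Monte Carlo simulation before reading off the result. The sketch below is an illustrative addition, not part of the original exercise. It assumes s = +1, means `+mu` and `-mu`, and the same arbitrary values `mu = 0.2` and `sigma = 3.5` used in Coding Exercise 1.1. It accumulates the per-step log likelihood ratios with `np.cumsum`, then compares the empirical mean and variance of the accumulated evidence at a few times `t` with `b * t` and `c**2 * t`, where `b = 2 * mu**2 / sigma**2` and `c = 2 * mu / sigma`.

```python
import numpy as np
from scipy import stats

# Assumed parameters, matching the values used in Coding Exercise 1.1
mu, sigma = 0.2, 3.5
b = 2 * mu**2 / sigma**2   # mean of each increment
c = 2 * mu / sigma         # standard deviation of each increment
n_trials, T = 5000, 150

rng = np.random.default_rng(1)
p_pos = stats.norm(loc=mu, scale=sigma)
p_neg = stats.norm(loc=-mu, scale=sigma)

# Independent measurement sequences, one row per trial, true state s = +1
m = rng.normal(loc=mu, scale=sigma, size=(n_trials, T))
delta = p_pos.logpdf(m) - p_neg.logpdf(m)   # per-step evidence
L = np.cumsum(delta, axis=1)                # column t-1 holds the sum of the first t steps

for t in (10, 50, T):
  L_t = L[:, t - 1]
  print(f't={t}: mean {L_t.mean():.3f} (b*t = {b * t:.3f}), '
        f'var {L_t.var():.3f} (c**2*t = {c**2 * t:.3f})')
```

With more trials, the empirical values should settle closer to the predictions. Both the mean and the variance grow linearly with `t`, which is the signature of a drift-diffusion process.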
\n",
"\n",
"Adding these $\\Delta_t$ over time gives\n",
"\n",
"\\begin{equation}\n",
"L_T\\sim\\mathcal{N}\\left(2\\frac{\\mu^2}{\\sigma^2}T,\\ 4\\frac{\\mu^2}{\\sigma^2}T\\right)=\\mathcal{N}(bT,c^2T)\n",
"\\end{equation}\n",
"\n",
"as claimed. The log-likelihood ratio $L_T$ is a biased random walk: at each time $T$ it is normally distributed, with a mean and a variance that both grow linearly with $T$. This is the **Drift Diffusion Model**." ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "## Coding Exercise 1.1: Simulating an SPRT model\n",
"\n",
"Let's now generate simulated data with $s=+1$ and see if the SPRT can infer the state correctly.\n",
"\n",
"We will implement a function simulate_SPRT_fixedtime, which will generate measurements based on $\\mu$, $\\sigma$, and the true state. It will then accumulate evidence over the time steps and output a decision on the state. The decision will be the state that is more likely according to the accumulated evidence. We will use the helper function log_likelihood_ratio, implemented in the next cell, which computes, for each measurement, the log of the ratio between its likelihood under $s=+1$ and its likelihood under $s=-1$.\n",
"\n",
"**Your coding tasks are:**\n",
"\n",
"**Step 1**: accumulate evidence.\n",
"\n",
"**Step 2**: make a decision at the last time point.\n",
"\n",
"We will then visualize 10 simulations of the DDM. 
In the next exercise you'll see how the parameters affect performance.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ " Execute this cell to enable the helper function log_likelihood_ratio\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @markdown Execute this cell to enable the helper function log_likelihood_ratio\n", "\n", "def log_likelihood_ratio(Mvec, p0, p1):\n", " \"\"\"Given a sequence(vector) of observed data, calculate the log of\n", " likelihood ratio of p1 and p0\n", "\n", " Args:\n", " Mvec (numpy vector): A vector of scalar measurements\n", " p0 (Gaussian random variable): A normal random variable with logpdf'\n", " method\n", " p1 (Gaussian random variable): A normal random variable with logpdf\n", " method\n", "\n", " Returns:\n", " llvec: a vector of log likelihood ratios for each input data point\n", " \"\"\"\n", " return p1.logpdf(Mvec) - p0.logpdf(Mvec)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "def simulate_SPRT_fixedtime(mu, sigma, stop_time, true_dist = 1):\n", " \"\"\"Simulate a Sequential Probability Ratio Test with fixed time stopping\n", " rule. 
Two observation models are 1D Gaussian distributions N(1,sigma^2) and\n", " N(-1,sigma^2).\n", "\n", " Args:\n", " mu (float): absolute mean value of the symmetric observation distributions\n", " sigma (float): Standard deviation of observation models\n", " stop_time (int): Number of samples to take before stopping\n", " true_dist (1 or -1): Which state is the true state.\n", "\n", " Returns:\n", " evidence_history (numpy vector): the history of cumulated evidence given\n", " generated data\n", " decision (int): 1 for s = 1, -1 for s = -1\n", " Mvec (numpy vector): the generated sequences of measurement data in this trial\n", " \"\"\"\n", "\n", " #################################################\n", " ## TODO for students ##\n", " # Fill out function and remove\n", " raise NotImplementedError(\"Student exercise: complete simulate_SPRT_fixedtime\")\n", " #################################################\n", "\n", " # Set means of observation distributions\n", " assert mu > 0, \"Mu should be > 0\"\n", " mu_pos = mu\n", " mu_neg = -mu\n", "\n", " # Make observation distributions\n", " p_pos = stats.norm(loc = mu_pos, scale = sigma)\n", " p_neg = stats.norm(loc = mu_neg, scale = sigma)\n", "\n", " # Generate a random sequence of measurements\n", " if true_dist == 1:\n", " Mvec = p_pos.rvs(size = stop_time)\n", " else:\n", " Mvec = p_neg.rvs(size = stop_time)\n", "\n", " # Calculate log likelihood ratio for each measurement (delta_t)\n", " ll_ratio_vec = log_likelihood_ratio(Mvec, p_neg, p_pos)\n", "\n", " # STEP 1: Calculate accumulated evidence (S) given a time series of evidence (hint: np.cumsum)\n", " evidence_history = ...\n", "\n", " # STEP 2: Make decision based on the sign of the evidence at the final time.\n", " decision = ...\n", "\n", " return evidence_history, decision, Mvec\n", "\n", "\n", "# Set random seed\n", "np.random.seed(100)\n", "\n", "# Set model parameters\n", "mu = .2\n", "sigma = 3.5 # standard deviation for p+ and p-\n", "num_sample = 10 # 
number of simulations to run\n", "stop_time = 150 # number of steps before stopping\n", "\n", "# Simulate and visualize\n", "simulate_and_plot_SPRT_fixedtime(mu, sigma, stop_time, num_sample)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "execution": {} }, "source": [ "[*Click for solution*](https://github.com/NeuromatchAcademy/course-content/tree/main//tutorials/W3D2_HiddenDynamics/solutions/W3D2_Tutorial1_Solution_985833af.py)\n", "\n", "*Example output:*\n", "\n", " \n", "\n" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "## Interactive Demo 1.2: Trajectories under the fixed-time stopping rule\n", "\n", "\n", "In the following demo, you can change the drift level (mu), noise level (sigma) in the observation model and the number of time steps before stopping (stop_time) using the sliders. You will then observe 10 simulations with those parameters. As in the previous exercise, the true state is +1.\n", " \n", "\n", "\n", "1. Are you more likely to make the wrong decision (choose the incorrect state) with high or low noise?\n", "2. What happens when sigma is very small? Why?\n", "3. Are you more likely to make the wrong decision (choose the incorrect state) with fewer or more time steps before stopping?\n", "\n", "\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ " Make sure you execute this cell to enable the widget!\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @markdown Make sure you execute this cell to enable the widget!\n", "\n", "def simulate_SPRT_fixedtime(mu, sigma, stop_time, true_dist = 1):\n", " \"\"\"Simulate a Sequential Probability Ratio Test with fixed time stopping\n", " rule. 
Two observation models are 1D Gaussian distributions N(1,sigma^2) and\n", " N(-1,sigma^2).\n", "\n", " Args:\n", " mu (float): absolute mean value of the symmetric observation distributions\n", " sigma (float): Standard deviation of observation models\n", " stop_time (int): Number of samples to take before stopping\n", " true_dist (1 or -1): Which state is the true state.\n", "\n", " Returns:\n", " evidence_history (numpy vector): the history of cumulated evidence given\n", " generated data\n", " decision (int): 1 for s = 1, -1 for s = -1\n", " Mvec (numpy vector): the generated sequences of measurement data in this trial\n", " \"\"\"\n", "\n", " # Set means of observation distributions\n", " assert mu > 0, \"Mu should be >0\"\n", " mu_pos = mu\n", " mu_neg = -mu\n", "\n", " # Make observation distributions\n", " p_pos = stats.norm(loc = mu_pos, scale = sigma)\n", " p_neg = stats.norm(loc = mu_neg, scale = sigma)\n", "\n", " # Generate a random sequence of measurements\n", " if true_dist == 1:\n", " Mvec = p_pos.rvs(size = stop_time)\n", " else:\n", " Mvec = p_neg.rvs(size = stop_time)\n", "\n", " # Calculate log likelihood ratio for each measurement (delta_t)\n", " ll_ratio_vec = log_likelihood_ratio(Mvec, p_neg, p_pos)\n", "\n", " # STEP 1: Calculate accumulated evidence (S) given a time series of evidence (hint: np.cumsum)\n", " evidence_history = np.cumsum(ll_ratio_vec)\n", "\n", " # STEP 2: Make decision based on the sign of the evidence at the final time.\n", " decision = np.sign(evidence_history[-1])\n", "\n", " return evidence_history, decision, Mvec\n", "\n", "np.random.seed(100)\n", "num_sample = 10\n", "\n", "@widgets.interact(mu=widgets.FloatSlider(min=0.1, max=5.0, step=0.1, value=0.5), sigma=(0.05, 10.0, 0.05), stop_time=(5, 500, 1))\n", "def plot(mu, sigma, stop_time):\n", " simulate_and_plot_SPRT_fixedtime(mu, sigma, stop_time, num_sample, verbose=False)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "execution": {} }, 
"source": [ "[*Click for solution*](https://github.com/NeuromatchAcademy/course-content/tree/main//tutorials/W3D2_HiddenDynamics/solutions/W3D2_Tutorial1_Solution_923ea26b.py)\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Video 3: Section 1 Exercises Discussion\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "remove-input" ] }, "outputs": [], "source": [ "# @title Video 3: Section 1 Exercises Discussion\n", "from ipywidgets import widgets\n", "from IPython.display import YouTubeVideo\n", "from IPython.display import IFrame\n", "from IPython.display import display\n", "\n", "\n", "class PlayVideo(IFrame):\n", " def __init__(self, id, source, page=1, width=400, height=300, **kwargs):\n", " self.id = id\n", " if source == 'Bilibili':\n", " src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'\n", " elif source == 'Osf':\n", " src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'\n", " super(PlayVideo, self).__init__(src, width, height, **kwargs)\n", "\n", "\n", "def display_videos(video_ids, W=400, H=300, fs=1):\n", " tab_contents = []\n", " for i, video_id in enumerate(video_ids):\n", " out = widgets.Output()\n", " with out:\n", " if video_ids[i] == 'Youtube':\n", " video = YouTubeVideo(id=video_ids[i], width=W,\n", " height=H, fs=fs, rel=0)\n", " print(f'Video available at https://youtube.com/watch?v={video.id}')\n", " else:\n", " video = PlayVideo(id=video_ids[i], source=video_ids[i], width=W,\n", " height=H, fs=fs, autoplay=False)\n", " if video_ids[i] == 'Bilibili':\n", " print(f'Video available at https://www.bilibili.com/video/{video.id}')\n", " elif video_ids[i] == 'Osf':\n", " print(f'Video available at https://osf.io/{video.id}') \n", " display(video)\n", " tab_contents.append(out)\n", " return tab_contents\n", "\n", "\n", "video_ids = [('Youtube', 'P6xuOS5TB7Q'), ('Bilibili', 'BV1h54y1E7UC')]\n", 
"tab_contents = display_videos(video_ids, W=730, H=410)\n", "tabs = widgets.Tab()\n", "tabs.children = tab_contents\n", "for i in range(len(tab_contents)):\n", " tabs.set_title(i, video_ids[i])\n", "display(tabs)" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "---\n", "# Section 2: Analyzing the DDM: accuracy vs stopping time\n", "\n", "*Estimated timing to here from start of tutorial: 28 min*" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Video 4: Speed vs Accuracy Tradeoff\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "remove-input" ] }, "outputs": [], "source": [ "# @title Video 4: Speed vs Accuracy Tradeoff\n", "from ipywidgets import widgets\n", "from IPython.display import YouTubeVideo\n", "from IPython.display import IFrame\n", "from IPython.display import display\n", "\n", "\n", "class PlayVideo(IFrame):\n", " def __init__(self, id, source, page=1, width=400, height=300, **kwargs):\n", " self.id = id\n", " if source == 'Bilibili':\n", " src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'\n", " elif source == 'Osf':\n", " src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'\n", " super(PlayVideo, self).__init__(src, width, height, **kwargs)\n", "\n", "\n", "def display_videos(video_ids, W=400, H=300, fs=1):\n", " tab_contents = []\n", " for i, video_id in enumerate(video_ids):\n", " out = widgets.Output()\n", " with out:\n", " if video_ids[i] == 'Youtube':\n", " video = YouTubeVideo(id=video_ids[i], width=W,\n", " height=H, fs=fs, rel=0)\n", " print(f'Video available at https://youtube.com/watch?v={video.id}')\n", " else:\n", " video = PlayVideo(id=video_ids[i], source=video_ids[i], width=W,\n", " height=H, fs=fs, autoplay=False)\n", " if video_ids[i] == 'Bilibili':\n", " print(f'Video available at https://www.bilibili.com/video/{video.id}')\n", " elif video_ids[i] == 'Osf':\n", " 
print(f'Video available at https://osf.io/{video.id}')\n", "      display(video)\n", "    tab_contents.append(out)\n", "  return tab_contents\n", "\n", "\n", "video_ids = [('Youtube', 'Hc3uXQiKvZA'), ('Bilibili', 'BV1s54y1E7yT')]\n", "tab_contents = display_videos(video_ids, W=730, H=410)\n", "tabs = widgets.Tab()\n", "tabs.children = tab_contents\n", "for i in range(len(tab_contents)):\n", "  tabs.set_title(i, video_ids[i][0])\n", "display(tabs)" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "If you make a hasty decision (e.g., after seeing only 2 samples), or if observation noise buries the signal, the accumulated log likelihood ratio may have the wrong sign (e.g., be negative when the true state is $+1$), and you will make the wrong decision. Let's plot how decision accuracy varies with the number of samples. Accuracy is the proportion of correct decisions across our repeated simulations: $\\frac{\\# \\textrm{ correct decisions}}{\\# \\textrm{ total decisions}}$." ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "### Coding Exercise 2.1: The Speed/Accuracy Tradeoff\n", "\n", "We will keep the observation noise level fixed. In this exercise you will implement a function that runs many simulations for each stopping time and calculates the _average decision accuracy_. We will then visualize the relationship between average decision accuracy and stopping time." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "def simulate_accuracy_vs_stoptime(mu, sigma, stop_time_list, num_sample, no_numerical=False):\n", "  \"\"\"Calculate the average decision accuracy vs. 
stopping time by running\n", "  repeated SPRT simulations for each stop time.\n", "\n", "  Args:\n", "    mu (float): absolute mean value of the symmetric observation distributions\n", "    sigma (float): standard deviation for observation model\n", "    stop_time_list (list-like object): a list of stopping times to run over\n", "    num_sample (int): number of simulations to run per stopping time\n", "    no_numerical (bool): if True, skip the simulations and return only the\n", "      analytical accuracies\n", "\n", "  Returns:\n", "    accuracies: a list of average simulated accuracies corresponding to input\n", "      stop_time_list\n", "    accuracies_analytical: a list of analytical accuracies for the same\n", "      stopping times\n", "    decisions_list: a list of decisions made in all trials\n", "  \"\"\"\n", "\n", "  #################################################\n", "  ## TODO for students ##\n", "  # Fill out function and remove\n", "  raise NotImplementedError(\"Student exercise: complete simulate_accuracy_vs_stoptime\")\n", "  #################################################\n", "\n", "  # Determine true state (1 or -1)\n", "  true_dist = 1\n", "\n", "  # Set up tracker of accuracy and decisions\n", "  accuracies = np.zeros(len(stop_time_list),)\n", "  accuracies_analytical = np.zeros(len(stop_time_list),)\n", "  decisions_list = []\n", "\n", "  # Loop over stop times\n", "  for i_stop_time, stop_time in enumerate(stop_time_list):\n", "\n", "    if not no_numerical:\n", "      # Set up tracker of decisions for this stop time\n", "      decisions = np.zeros((num_sample,))\n", "\n", "      # Loop over samples\n", "      for i in range(num_sample):\n", "\n", "        # STEP 1: Simulate run for this stop time (hint: use output from last exercise)\n", "        _, decision, _ = ...\n", "\n", "        # Log decision\n", "        decisions[i] = decision\n", "\n", "      # STEP 2: Calculate accuracy by averaging over trials\n", "      accuracies[i_stop_time] = ...\n", "\n", "      # Log decisions\n", "      decisions_list.append(decisions)\n", "\n", "    # Calculate analytical accuracy\n", "    sigma_sum_gaussian = sigma / np.sqrt(stop_time)\n", "    accuracies_analytical[i_stop_time] = 0.5 + 0.5 * 
erf(mu / np.sqrt(2) / sigma_sum_gaussian)\n", "\n", "  return accuracies, accuracies_analytical, decisions_list\n", "\n", "\n", "# Set random seed\n", "np.random.seed(100)\n", "\n", "# Set parameters of model\n", "mu = 0.5\n", "sigma = 4.65  # standard deviation for observation noise\n", "num_sample = 100  # number of simulations to run for each stopping time\n", "stop_time_list = np.arange(1, 150, 10)  # Array of stopping times to use\n", "\n", "\n", "# Calculate accuracies for each stop time\n", "accuracies, accuracies_analytical, _ = simulate_accuracy_vs_stoptime(\n", "    mu, sigma, stop_time_list, num_sample)\n", "\n", "# Visualize\n", "plot_accuracy_vs_stoptime(mu, sigma, stop_time_list, accuracies_analytical, accuracies)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "execution": {} }, "source": [ "[*Click for solution*](https://github.com/NeuromatchAcademy/course-content/tree/main//tutorials/W3D2_HiddenDynamics/solutions/W3D2_Tutorial1_Solution_98b93ad3.py)\n", "\n", "*Example output:*\n", "\n", " \n", "\n" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "In the figure above, we plot the simulated accuracies in orange. In this specific case we can also find an analytical expression for the average accuracy, which we plot in blue. We will not derive this analytical solution in detail here, but if you ran many simulations and averaged the resulting orange curves, that average would approach the blue line.\n", "\n", "The insets show the evidence distributions for the two states at a given time point. Recall from Section 1 that, when the true state is $+1$, the accumulated log likelihood ratio at time $T$ is distributed as:\n", "\n", "\\begin{equation}\n", "L_T\\sim\\mathcal{N}\\left(2\\frac{\\mu^2}{\\sigma^2}T,\\ 4\\frac{\\mu^2}{\\sigma^2}T\\right)=\\mathcal{N}(bT,c^2T)\n", "\\end{equation}\n", "\n", "If the state is $-1$, the mean has the opposite sign. 
The insets show these Gaussian distributions for the state equal to $-1$ (dashed line) and the state equal to $+1$ (solid line). The red area reflects the error rate: it corresponds to $L_T$ falling below 0 even though the true state is $+1$, so you would decide on the wrong state. As time goes on, the two distributions separate further (their means grow in proportion to $T$, while their standard deviations grow only as $\\sqrt{T}$), so the error rate decreases. Since the error is the probability that $L_T$ falls below 0, the accuracy is $P(L_T>0)=\\frac{1}{2}+\\frac{1}{2}\\,\\textrm{erf}\\left(\\frac{\\mu\\sqrt{T}}{\\sqrt{2}\\,\\sigma}\\right)$, which is the analytical curve plotted in blue." ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "### Interactive Demo 2.2: Accuracy versus stop-time\n", "\n", "For this same visualization, now vary the mean $\\mu$ and standard deviation $\\sigma$ of the evidence. What do you predict the accuracy vs. stopping time plot will look like for low noise and for high noise?" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ " Make sure you execute this cell to enable the widget!\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "#@markdown Make sure you execute this cell to enable the widget!\n", "def simulate_accuracy_vs_stoptime(mu, sigma, stop_time_list, num_sample, no_numerical=False):\n", "  \"\"\"Calculate the average decision accuracy vs. 
stopping time by running\n", "  repeated SPRT simulations for each stop time.\n", "\n", "  Args:\n", "    mu (float): absolute mean value of the symmetric observation distributions\n", "    sigma (float): standard deviation for observation model\n", "    stop_time_list (list-like object): a list of stopping times to run over\n", "    num_sample (int): number of simulations to run per stopping time\n", "    no_numerical (bool): if True, skip the simulations and return only the\n", "      analytical accuracies\n", "\n", "  Returns:\n", "    accuracies: a list of average simulated accuracies corresponding to input\n", "      stop_time_list\n", "    accuracies_analytical: a list of analytical accuracies for the same\n", "      stopping times\n", "    decisions_list: a list of decisions made in all trials\n", "  \"\"\"\n", "\n", "  # Determine true state (1 or -1)\n", "  true_dist = 1\n", "\n", "  # Set up tracker of accuracy and decisions\n", "  accuracies = np.zeros(len(stop_time_list),)\n", "  accuracies_analytical = np.zeros(len(stop_time_list),)\n", "  decisions_list = []\n", "\n", "  # Loop over stop times\n", "  for i_stop_time, stop_time in enumerate(stop_time_list):\n", "\n", "    if not no_numerical:\n", "      # Set up tracker of decisions for this stop time\n", "      decisions = np.zeros((num_sample,))\n", "\n", "      # Loop over samples\n", "      for i in range(num_sample):\n", "\n", "        # Simulate run for this stop time (hint: last exercise)\n", "        _, decision, _ = simulate_SPRT_fixedtime(mu, sigma, stop_time, true_dist)\n", "\n", "        # Log decision\n", "        decisions[i] = decision\n", "\n", "      # Calculate accuracy as the fraction of correct decisions\n", "      accuracies[i_stop_time] = np.sum(decisions == true_dist) / decisions.shape[0]\n", "      # Log decisions\n", "      decisions_list.append(decisions)\n", "\n", "    # Calculate analytical accuracy\n", "    sigma_sum_gaussian = sigma / np.sqrt(stop_time)\n", "    accuracies_analytical[i_stop_time] = 0.5 + 0.5 * erf(mu / np.sqrt(2) / sigma_sum_gaussian)\n", "\n", "  return accuracies, accuracies_analytical, decisions_list\n", "\n", "\n", "np.random.seed(100)\n", "num_sample = 100\n", "stop_time_list = np.arange(1, 100, 1)\n", "\n", "@widgets.interact\n",
"def plot(mu=widgets.FloatSlider(min=0.1, max=5.0, step=0.1, value=1.0), sigma=(0.05, 10.0, 0.05)):\n", "  # Calculate analytical accuracies for each stop time (simulations are skipped)\n", "  _, accuracies_analytical, _ = simulate_accuracy_vs_stoptime(mu, sigma, stop_time_list, num_sample, no_numerical=True)\n", "\n", "  # Visualize\n", "  plot_accuracy_vs_stoptime(mu, sigma, stop_time_list, accuracies_analytical)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "execution": {} }, "source": [ "[*Click for solution*](https://github.com/NeuromatchAcademy/course-content/tree/main//tutorials/W3D2_HiddenDynamics/solutions/W3D2_Tutorial1_Solution_f243732b.py)\n", "\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Video 5: Section 2 Exercises Discussion\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "remove-input" ] }, "outputs": [], "source": [ "# @title Video 5: Section 2 Exercises Discussion\n", "from ipywidgets import widgets\n", "from IPython.display import YouTubeVideo\n", "from IPython.display import IFrame\n", "from IPython.display import display\n", "\n", "\n", "class PlayVideo(IFrame):\n", "  def __init__(self, id, source, page=1, width=400, height=300, **kwargs):\n", "    self.id = id\n", "    if source == 'Bilibili':\n", "      src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'\n", "    elif source == 'Osf':\n", "      src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'\n", "    super(PlayVideo, self).__init__(src, width, height, **kwargs)\n", "\n", "\n", "def display_videos(video_ids, W=400, H=300, fs=1):\n", "  tab_contents = []\n", "  for i, video_id in enumerate(video_ids):\n", "    out = widgets.Output()\n", "    with out:\n", "      if video_ids[i][0] == 'Youtube':\n", "        video = YouTubeVideo(id=video_ids[i][1], width=W,\n", "                             height=H, fs=fs, rel=0)\n", "        print(f'Video available at https://youtube.com/watch?v={video.id}')\n", "      else:\n", "        video = 
PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,\n", "                          height=H, fs=fs, autoplay=False)\n", "        if video_ids[i][0] == 'Bilibili':\n", "          print(f'Video available at https://www.bilibili.com/video/{video.id}')\n", "        elif video_ids[i][0] == 'Osf':\n", "          print(f'Video available at https://osf.io/{video.id}')\n", "      display(video)\n", "    tab_contents.append(out)\n", "  return tab_contents\n", "\n", "\n", "video_ids = [('Youtube', 'OBDv6nB6a2g'), ('Bilibili', 'BV11g411M7Lm')]\n", "tab_contents = display_videos(video_ids, W=730, H=410)\n", "tabs = widgets.Tab()\n", "tabs.children = tab_contents\n", "for i in range(len(tab_contents)):\n", "  tabs.set_title(i, video_ids[i][0])\n", "display(tabs)" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "**Application**\n", "\n", "We have looked at the drift diffusion model of decisions in the context of the fishing problem. This model has many uses in neuroscience! As one example, a classic experimental task in neuroscience is the random dot kinematogram ([Newsome, Britten, Movshon 1989](https://www.nature.com/articles/341052a0.pdf)), in which a field of dots moves in random directions, but with a weak coherence that favors a net rightward or leftward motion. The observer must guess the direction. Neurons in the brain carry information about this task, and their responses correlate with the choice, as predicted by the Drift Diffusion Model (Huk and Shadlen 2005).\n", "\n", "Below is a video by Pamela Reinagel of a rat guessing the direction of motion in such a task."
] }, { "cell_type": "markdown", "metadata": {}, "source": [ " Rat performing random dot motion task\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "remove-input" ] }, "outputs": [], "source": [ "# @markdown Rat performing random dot motion task\n", "from IPython.display import YouTubeVideo\n", "video = YouTubeVideo(id=\"oDxcyTn-0os\", width=730, height=410, fs=1)\n", "print(\"Video available at https://youtu.be/\" + video.id)\n", "video" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "After you finish the other tutorials, come back to the Bonus material to learn about a different stopping rule for DDMs: a fixed threshold on confidence." ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "---\n", "# Summary\n", "\n", "*Estimated timing of tutorial: 45 minutes*\n", "\n", "Good job! By simulating Drift Diffusion Models, you have learnt how to:\n", "\n", "* Calculate individual sample evidence as the log likelihood ratio of two candidate models\n", "* Accumulate evidence from new data points, and compute the posterior using a recursive formula\n", "* Run repeated simulations to estimate decision accuracies\n", "* Measure the speed-accuracy tradeoff" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "---\n", "# Bonus" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "---\n", "## Bonus Section 1: DDM with fixed thresholds on confidence" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Video 6: Fixed threshold on confidence\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "remove-input" ] }, "outputs": [], "source": [ "# @title Video 6: Fixed threshold on confidence\n", "from ipywidgets import widgets\n", "from IPython.display import YouTubeVideo\n", "from IPython.display import IFrame\n", "from 
IPython.display import display\n", "\n", "\n", "class PlayVideo(IFrame):\n", "  def __init__(self, id, source, page=1, width=400, height=300, **kwargs):\n", "    self.id = id\n", "    if source == 'Bilibili':\n", "      src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'\n", "    elif source == 'Osf':\n", "      src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'\n", "    super(PlayVideo, self).__init__(src, width, height, **kwargs)\n", "\n", "\n", "def display_videos(video_ids, W=400, H=300, fs=1):\n", "  tab_contents = []\n", "  for i, video_id in enumerate(video_ids):\n", "    out = widgets.Output()\n", "    with out:\n", "      if video_ids[i][0] == 'Youtube':\n", "        video = YouTubeVideo(id=video_ids[i][1], width=W,\n", "                             height=H, fs=fs, rel=0)\n", "        print(f'Video available at https://youtube.com/watch?v={video.id}')\n", "      else:\n", "        video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,\n", "                          height=H, fs=fs, autoplay=False)\n", "        if video_ids[i][0] == 'Bilibili':\n", "          print(f'Video available at https://www.bilibili.com/video/{video.id}')\n", "        elif video_ids[i][0] == 'Osf':\n", "          print(f'Video available at https://osf.io/{video.id}')\n", "      display(video)\n", "    tab_contents.append(out)\n", "  return tab_contents\n", "\n", "\n", "video_ids = [('Youtube', 'E8lvgFeIGQM'), ('Bilibili', 'BV1Ya4y1a7c1')]\n", "tab_contents = display_videos(video_ids, W=730, H=410)\n", "tabs = widgets.Tab()\n", "tabs.children = tab_contents\n", "for i in range(len(tab_contents)):\n", "  tabs.set_title(i, video_ids[i][0])\n", "display(tabs)" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "The next exercises consider a variant of the DDM that stops at a fixed confidence threshold instead of a fixed decision time. This may be a better description of neural integration. If you would like extra information about this topic, please complete this material after you have finished the main content of all tutorials."
] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "### Bonus Coding Exercise 1.1: Simulating the DDM with fixed confidence thresholds\n", "\n", "*Referred to as exercise 3 in video*\n", "\n", "In this exercise, we will use a threshold as our stopping rule and observe the behavior of the DDM.\n", "\n", "With a thresholding stopping rule, we choose a desired error rate and keep taking measurements until the accumulated evidence crosses the threshold corresponding to that error rate. Experimental evidence suggests that evidence accumulation and a threshold-based stopping strategy are implemented at the neuronal level (see [this article](https://www.annualreviews.org/doi/full/10.1146/annurev.neuro.29.051605.113038) for further reading).\n", "\n", "* Complete the function threshold_from_errorrate to calculate the evidence threshold from the desired error rate $\\alpha$, as described in the formulas below. The evidence thresholds $th_R$ and $th_L$ for $p_R$ and $p_L$ are opposites of each other, so you can just return the absolute value.\n", "\n", "\\begin{align}\n", "th_{L} &= \\log \\frac{\\alpha}{1-\\alpha} = -th_{R} \\\\\n", "th_{R} &= \\log \\frac{1-\\alpha}{\\alpha} = -th_{L}\n", "\\end{align}\n", "\n", "* Complete the function simulate_SPRT_threshold to simulate an SPRT with a thresholding stopping rule, given the noise level and the desired threshold.\n", "\n", "* Run repeated simulations for a given noise level and desired error rate, and visualize the DDM traces using our provided code." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "def simulate_SPRT_threshold(mu, sigma, threshold, true_dist=1):\n", "  \"\"\"Simulate a Sequential Probability Ratio Test with thresholding stopping\n", "  rule. 
Two observation models are 1D Gaussian distributions N(mu,sigma^2) and\n", "  N(-mu,sigma^2).\n", "\n", "  Args:\n", "    mu (float): absolute mean value of the symmetric observation distributions\n", "    sigma (float): Standard deviation\n", "    threshold (float): Desired log likelihood ratio threshold to achieve\n", "      before making decision\n", "    true_dist (int): which distribution generates the data (1 for pR, otherwise pL)\n", "\n", "  Returns:\n", "    evidence_history (numpy vector): the history of cumulated evidence given\n", "      generated data\n", "    decision (int): 1 for pR, 0 for pL\n", "    data_history (numpy vector): the generated sequences of data in this trial\n", "  \"\"\"\n", "  assert mu > 0, \"Mu should be > 0\"\n", "  muL = -mu\n", "  muR = mu\n", "\n", "  pL = stats.norm(muL, sigma)\n", "  pR = stats.norm(muR, sigma)\n", "\n", "  has_enough_data = False\n", "\n", "  data_history = []\n", "  evidence_history = []\n", "  current_evidence = 0.0\n", "\n", "  # Keep sampling data until threshold is crossed\n", "  while not has_enough_data:\n", "    if true_dist == 1:\n", "      Mvec = pR.rvs()\n", "    else:\n", "      Mvec = pL.rvs()\n", "\n", "    ########################################################################\n", "    # Insert your code here to:\n", "    #   * Calculate the log-likelihood ratio for the new sample\n", "    #   * Update the accumulated evidence\n", "    raise NotImplementedError(\"simulate_SPRT_threshold is incomplete\")\n", "    ########################################################################\n", "\n", "    # STEP 1: individual log likelihood ratios\n", "    ll_ratio = log_likelihood_ratio(...)\n", "\n", "    # STEP 2: accumulated evidence for this chunk\n", "    evidence_history.append(...)\n", "\n", "    # update the collection of all data\n", "    data_history.append(Mvec)\n", "    current_evidence = evidence_history[-1]\n", "\n", "    # check if we've got enough data\n", "    if abs(current_evidence) > threshold:\n", "      has_enough_data = True\n", "\n", "  data_history = np.array(data_history)\n", "  evidence_history = np.array(evidence_history)\n", "\n", "  # Make decision\n", "  if evidence_history[-1] 
>= 0:\n", "    decision = 1\n", "  else:\n", "    decision = 0\n", "\n", "  return evidence_history, decision, data_history\n", "\n", "\n", "# Set parameters\n", "np.random.seed(100)\n", "mu = 1.0\n", "sigma = 2.8\n", "num_sample = 10\n", "log10_alpha = -3  # log10(alpha)\n", "alpha = np.power(10.0, log10_alpha)\n", "\n", "# Simulate and visualize\n", "simulate_and_plot_SPRT_fixedthreshold(mu, sigma, num_sample, alpha)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "execution": {} }, "source": [ "[*Click for solution*](https://github.com/NeuromatchAcademy/course-content/tree/main//tutorials/W3D2_HiddenDynamics/solutions/W3D2_Tutorial1_Solution_3559a6a0.py)\n", "\n", "*Example output:*\n", "\n", " \n", "\n" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "### Bonus Interactive Demo 1.2: DDM with fixed confidence threshold\n", "\n", "Play with different values of $\\alpha$, $\\mu$, and $\\sigma$ and observe how they affect the dynamics of the Drift Diffusion Model." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ " Make sure you execute this cell to enable the widget!\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @markdown Make sure you execute this cell to enable the widget!\n", "def simulate_SPRT_threshold(mu, sigma, threshold, true_dist=1):\n", "  \"\"\"Simulate a Sequential Probability Ratio Test with thresholding stopping\n", "  rule. 
Two observation models are 1D Gaussian distributions N(mu,sigma^2) and\n", "  N(-mu,sigma^2).\n", "\n", "  Args:\n", "    mu (float): absolute mean value of the symmetric observation distributions\n", "    sigma (float): Standard deviation\n", "    threshold (float): Desired log likelihood ratio threshold to achieve\n", "      before making decision\n", "    true_dist (int): which distribution generates the data (1 for pR, otherwise pL)\n", "\n", "  Returns:\n", "    evidence_history (numpy vector): the history of cumulated evidence given\n", "      generated data\n", "    decision (int): 1 for pR, 0 for pL\n", "    data_history (numpy vector): the generated sequences of data in this trial\n", "  \"\"\"\n", "  assert mu > 0, \"Mu should be > 0\"\n", "  muL = -mu\n", "  muR = mu\n", "\n", "  pL = stats.norm(muL, sigma)\n", "  pR = stats.norm(muR, sigma)\n", "\n", "  has_enough_data = False\n", "\n", "  data_history = []\n", "  evidence_history = []\n", "  current_evidence = 0.0\n", "\n", "  # Keep sampling data until threshold is crossed\n", "  while not has_enough_data:\n", "    if true_dist == 1:\n", "      Mvec = pR.rvs()\n", "    else:\n", "      Mvec = pL.rvs()\n", "\n", "    # STEP 1: individual log likelihood ratios\n", "    ll_ratio = log_likelihood_ratio(Mvec, pL, pR)\n", "\n", "    # STEP 2: accumulated evidence for this chunk\n", "    evidence_history.append(ll_ratio + current_evidence)\n", "\n", "    # update the collection of all data\n", "    data_history.append(Mvec)\n", "    current_evidence = evidence_history[-1]\n", "\n", "    # check if we've got enough data\n", "    if abs(current_evidence) > threshold:\n", "      has_enough_data = True\n", "\n", "  data_history = np.array(data_history)\n", "  evidence_history = np.array(evidence_history)\n", "\n", "  # Make decision\n", "  if evidence_history[-1] >= 0:\n", "    decision = 1\n", "  else:\n", "    decision = 0\n", "\n", "  return evidence_history, decision, data_history\n", "\n", "\n", "np.random.seed(100)\n", "num_sample = 10\n", "\n", "@widgets.interact\n", "def plot(mu=(0.1, 5.0, 0.1), sigma=(0.05, 10.0, 0.05), log10_alpha=(-8, -1, .1)):\n", "  alpha = np.power(10.0, 
log10_alpha)\n", "  simulate_and_plot_SPRT_fixedthreshold(mu, sigma, num_sample, alpha, verbose=False)" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "### Bonus Coding Exercise 1.3: Speed/Accuracy Tradeoff Revisited\n", "\n", "The faster you make a decision, the lower your accuracy often is. This phenomenon is known as the **speed/accuracy tradeoff**. Humans make this tradeoff in a wide range of situations, and many animal species, including ants, bees, rodents, and monkeys, show similar effects.\n", "\n", "To illustrate the speed/accuracy tradeoff under a thresholding stopping rule, let's run simulations with different thresholds and look at how average decision \"speed\" (1/length) changes with average decision accuracy. We characterize decisions by their speed rather than by the threshold itself because in real experiments, subjects can be incentivized to respond faster or slower, whereas it is much harder to precisely control their decision time or error threshold.\n", "\n", "* Complete the function simulate_accuracy_vs_threshold to simulate and compute average accuracies vs. average decision speeds for a list of evidence thresholds. You will need to supply code to calculate the average decision 'speed' from the trial lengths. You should also calculate the overall accuracy across these trials.\n", "\n", "* We've set up a list of error rates and converted them into evidence thresholds. Run repeated simulations, collect the average accuracy and average speed for each threshold in this list, and use our provided code to visualize the speed/accuracy tradeoff. You should see a positive correlation between decision length and accuracy, i.e., faster decisions are less accurate.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "execution": {} }, "outputs": [], "source": [ "def simulate_accuracy_vs_threshold(mu, sigma, threshold_list, num_sample):\n", "  \"\"\"Calculate the average decision accuracy vs. 
average decision speed by\n", "  running repeated SPRT simulations with thresholding stopping rule for each\n", "  threshold.\n", "\n", "  Args:\n", "    mu (float): absolute mean value of the symmetric observation distributions\n", "    sigma (float): standard deviation for observation model\n", "    threshold_list (list-like object): a list of evidence thresholds to run\n", "      over\n", "    num_sample (int): number of simulations to run per threshold\n", "\n", "  Returns:\n", "    accuracy_list: a list of average accuracies corresponding to input\n", "      threshold_list\n", "    decision_speed_list: a list of average decision speeds\n", "  \"\"\"\n", "  decision_speed_list = []\n", "  accuracy_list = []\n", "  for threshold in threshold_list:\n", "    decision_time_list = []\n", "    decision_list = []\n", "    for i in range(num_sample):\n", "      # run simulation and get decision of current simulation\n", "      _, decision, Mvec = simulate_SPRT_threshold(mu, sigma, threshold)\n", "      decision_time = len(Mvec)\n", "      decision_list.append(decision)\n", "      decision_time_list.append(decision_time)\n", "\n", "    ########################################################################\n", "    # Insert your code here to:\n", "    #   * Calculate mean decision speed given a list of decision times\n", "    #   * Hint: Think about speed as being inversely proportional\n", "    #     to decision_time. 
If it takes 10 seconds to make one decision,\n", "    #     our \"decision speed\" is 0.1 decisions per second.\n", "    #   * Calculate the decision accuracy\n", "    raise NotImplementedError(\"simulate_accuracy_vs_threshold is incomplete\")\n", "    ########################################################################\n", "    # Calculate and store average decision speed and accuracy\n", "    decision_speed = ...\n", "    decision_accuracy = ...\n", "    decision_speed_list.append(decision_speed)\n", "    accuracy_list.append(decision_accuracy)\n", "\n", "  return accuracy_list, decision_speed_list\n", "\n", "\n", "# Set parameters\n", "np.random.seed(100)\n", "mu = 1.0\n", "sigma = 3.75\n", "num_sample = 200\n", "alpha_list = np.logspace(-2, -0.1, 8)\n", "threshold_list = threshold_from_errorrate(alpha_list)\n", "\n", "# Simulate and visualize\n", "simulate_and_plot_accuracy_vs_threshold(mu, sigma, threshold_list, num_sample)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "execution": {} }, "source": [ "[*Click for solution*](https://github.com/NeuromatchAcademy/course-content/tree/main//tutorials/W3D2_HiddenDynamics/solutions/W3D2_Tutorial1_Solution_87825db1.py)\n", "\n", "*Example output:*\n", "\n", " \n", "\n" ] }, { "cell_type": "markdown", "metadata": { "execution": {} }, "source": [ "### Bonus Interactive Demo 1.4: Speed/Accuracy with a threshold rule\n", "\n", "Manipulate the noise level $\\sigma$ and observe how it affects the speed/accuracy tradeoff." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ " Make sure you execute this cell to enable the widget!\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "execution": {}, "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# @markdown Make sure you execute this cell to enable the widget!\n", "def simulate_accuracy_vs_threshold(mu, sigma, threshold_list, num_sample):\n", "  \"\"\"Calculate the average decision accuracy vs. 
average decision speed by\n", "  running repeated SPRT simulations with thresholding stopping rule for each\n", "  threshold.\n", "\n", "  Args:\n", "    mu (float): absolute mean value of the symmetric observation distributions\n", "    sigma (float): standard deviation for observation model\n", "    threshold_list (list-like object): a list of evidence thresholds to run\n", "      over\n", "    num_sample (int): number of simulations to run per threshold\n", "\n", "  Returns:\n", "    accuracy_list: a list of average accuracies corresponding to input\n", "      threshold_list\n", "    decision_speed_list: a list of average decision speeds\n", "  \"\"\"\n", "  decision_speed_list = []\n", "  accuracy_list = []\n", "  for threshold in threshold_list:\n", "    decision_time_list = []\n", "    decision_list = []\n", "    for i in range(num_sample):\n", "      # run simulation and get decision of current simulation\n", "      _, decision, Mvec = simulate_SPRT_threshold(mu, sigma, threshold)\n", "      decision_time = len(Mvec)\n", "      decision_list.append(decision)\n", "      decision_time_list.append(decision_time)\n", "\n", "    # Calculate and store average decision speed and accuracy\n", "    decision_speed = np.mean(1. 
/ np.array(decision_time_list))\n", "    # decisions are 1 (pR, the true state) or 0 (pL), so their mean is the accuracy\n", "    decision_accuracy = sum(decision_list) / len(decision_list)\n", "    decision_speed_list.append(decision_speed)\n", "    accuracy_list.append(decision_accuracy)\n", "\n", "  return accuracy_list, decision_speed_list\n", "\n", "\n", "np.random.seed(100)\n", "num_sample = 100\n", "alpha_list = np.logspace(-2, -0.1, 8)\n", "threshold_list = threshold_from_errorrate(alpha_list)\n", "\n", "@widgets.interact\n", "def plot(mu=(0.1, 5.0, 0.1), sigma=(0.05, 10.0, 0.05)):\n", "  simulate_and_plot_accuracy_vs_threshold(mu, sigma, threshold_list, num_sample)" ] } ], "metadata": { "@webio": { "lastCommId": null, "lastKernelId": null }, "colab": { "collapsed_sections": [], "include_colab_link": true, "name": "W3D2_Tutorial1", "provenance": [], "toc_visible": true }, "kernel": { "display_name": "Python 3", "language": "python", "name": "python3" }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.13" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": true, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": true } }, "nbformat": 4, "nbformat_minor": 0 }