{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"execution": {}
},
"source": [
"# Tutorial 2: Statistical Inference\n",
"\n",
"**Week 0, Day 5: Probability & Statistics**\n",
"\n",
"**By Neuromatch Academy**\n",
"\n",
"**Content creators:** Ulrik Beierholm\n",
"\n",
"**Content reviewers:** Natalie Schaworonkow, Keith van Antwerp, Anoop Kulkarni, Pooya Pakarian, Hyosub Kim\n",
"\n",
"**Production editors:** Ethan Cheng, Ella Batty"
]
},
{
"cell_type": "markdown",
"metadata": {
"execution": {}
},
"source": [
"---\n",
"#Tutorial Objectives\n",
"\n",
"This tutorial builds on Tutorial 1 by explaining how to do inference through inverting the generative process.\n",
"\n",
"By completing the exercises in this tutorial, you should:\n",
"* understand what the likelihood function is, and have some intuition of why it is important\n",
"* know how to summarise the Gaussian distribution using mean and variance\n",
"* know how to maximise a likelihood function\n",
"* be able to do simple inference in both classical and Bayesian ways\n",
"* (Optional) understand how Bayes Net can be used to model causal relationships"
]
},
{
"cell_type": "markdown",
"metadata": {
"execution": {}
},
"source": [
"---\n",
"# Setup"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Install and import feedback gadget\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"execution": {},
"tags": [
"hide-input"
]
},
"outputs": [],
"source": [
"# @title Install and import feedback gadget\n",
"\n",
"!pip3 install vibecheck datatops --quiet\n",
"\n",
"from vibecheck import DatatopsContentReviewContainer\n",
"def content_review(notebook_section: str):\n",
" return DatatopsContentReviewContainer(\n",
" \"\", # No text prompt\n",
" notebook_section,\n",
" {\n",
" \"url\": \"https://pmyvdlilci.execute-api.us-east-1.amazonaws.com/klab\",\n",
" \"name\": \"neuromatch-precourse\",\n",
" \"user_key\": \"8zxfvwxw\",\n",
" },\n",
" ).render()\n",
"\n",
"\n",
"feedback_prefix = \"W0D5_T2\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "code",
"execution": {}
},
"outputs": [],
"source": [
"# Imports\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import scipy as sp\n",
"from scipy.stats import norm\n",
"from numpy.random import default_rng # a default random number generator"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Figure settings\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"execution": {},
"tags": [
"hide-input"
]
},
"outputs": [],
"source": [
"# @title Figure settings\n",
"import logging\n",
"logging.getLogger('matplotlib.font_manager').disabled = True\n",
"import ipywidgets as widgets # interactive display\n",
"from ipywidgets import interact, fixed, HBox, Layout, VBox, interactive, Label, interact_manual\n",
"%config InlineBackend.figure_format = 'retina'\n",
"plt.style.use(\"https://raw.githubusercontent.com/NeuromatchAcademy/content-creation/main/nma.mplstyle\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Plotting functions\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"execution": {},
"tags": [
"hide-input"
]
},
"outputs": [],
"source": [
"# @title Plotting functions\n",
"\n",
"def plot_hist(data, xlabel, figtitle = None, num_bins = None):\n",
" \"\"\" Plot the given data as a histogram.\n",
"\n",
" Args:\n",
" data (ndarray): array with data to plot as histogram\n",
" xlabel (str): label of x-axis\n",
" figtitle (str): title of histogram plot (default is no title)\n",
" num_bins (int): number of bins for histogram (default is 10)\n",
"\n",
" Returns:\n",
" count (ndarray): number of samples in each histogram bin\n",
" bins (ndarray): center of each histogram bin\n",
" \"\"\"\n",
" fig, ax = plt.subplots()\n",
" ax.set_xlabel(xlabel)\n",
" ax.set_ylabel('Count')\n",
" if num_bins is not None:\n",
" count, bins, _ = plt.hist(data, max(data), bins=num_bins)\n",
" else:\n",
" count, bins, _ = plt.hist(data, max(data)) # 10 bins default\n",
" if figtitle is not None:\n",
" fig.suptitle(figtitle, size=16)\n",
" plt.show()\n",
" return count, bins\n",
"\n",
"\n",
"def plot_gaussian_samples_true(samples, xspace, mu, sigma, xlabel, ylabel):\n",
" \"\"\" Plot a histogram of the data samples on the same plot as the gaussian\n",
" distribution specified by the give mu and sigma values.\n",
"\n",
" Args:\n",
" samples (ndarray): data samples for gaussian distribution\n",
" xspace (ndarray): x values to sample from normal distribution\n",
" mu (scalar): mean parameter of normal distribution\n",
" sigma (scalar): variance parameter of normal distribution\n",
" xlabel (str): the label of the x-axis of the histogram\n",
" ylabel (str): the label of the y-axis of the histogram\n",
"\n",
" Returns:\n",
" Nothing.\n",
" \"\"\"\n",
" fig, ax = plt.subplots()\n",
" ax.set_xlabel(xlabel)\n",
" ax.set_ylabel(ylabel)\n",
" # num_samples = samples.shape[0]\n",
"\n",
" count, bins, _ = plt.hist(samples, density=True) # probability density function\n",
"\n",
" plt.plot(xspace, norm.pdf(xspace, mu, sigma), 'r-')\n",
" plt.show()\n",
"\n",
"\n",
"def plot_likelihoods(likelihoods, mean_vals, variance_vals):\n",
" \"\"\" Plot the likelihood values on a heatmap plot where the x and y axes match\n",
" the mean and variance parameter values the likelihoods were computed for.\n",
"\n",
" Args:\n",
" likelihoods (ndarray): array of computed likelihood values\n",
" mean_vals (ndarray): array of mean parameter values for which the\n",
" likelihood was computed\n",
" variance_vals (ndarray): array of variance parameter values for which the\n",
" likelihood was computed\n",
"\n",
" Returns:\n",
" Nothing.\n",
" \"\"\"\n",
" fig, ax = plt.subplots()\n",
" im = ax.imshow(likelihoods)\n",
"\n",
" cbar = ax.figure.colorbar(im, ax=ax)\n",
" cbar.ax.set_ylabel('log likelihood', rotation=-90, va=\"bottom\")\n",
"\n",
" ax.set_xticks(np.arange(len(mean_vals)))\n",
" ax.set_yticks(np.arange(len(variance_vals)))\n",
" ax.set_xticklabels(mean_vals)\n",
" ax.set_yticklabels(variance_vals)\n",
" ax.set_xlabel('Mean')\n",
" ax.set_ylabel('Variance')\n",
" plt.show()\n",
"\n",
"\n",
"def posterior_plot(x, likelihood=None, prior=None,\n",
" posterior_pointwise=None, ax=None):\n",
" \"\"\"\n",
" Plots normalized Gaussian distributions and posterior.\n",
"\n",
" Args:\n",
" x (numpy array of floats): points at which the likelihood has been evaluated\n",
" auditory (numpy array of floats): normalized probabilities for auditory likelihood evaluated at each `x`\n",
" visual (numpy array of floats): normalized probabilities for visual likelihood evaluated at each `x`\n",
" posterior (numpy array of floats): normalized probabilities for the posterior evaluated at each `x`\n",
" ax: Axis in which to plot. If None, create new axis.\n",
"\n",
" Returns:\n",
" Nothing.\n",
" \"\"\"\n",
" if likelihood is None:\n",
" likelihood = np.zeros_like(x)\n",
"\n",
" if prior is None:\n",
" prior = np.zeros_like(x)\n",
"\n",
" if posterior_pointwise is None:\n",
" posterior_pointwise = np.zeros_like(x)\n",
"\n",
" if ax is None:\n",
" fig, ax = plt.subplots()\n",
"\n",
" ax.plot(x, likelihood, '-C1', linewidth=2, label='Auditory')\n",
" ax.plot(x, prior, '-C0', linewidth=2, label='Visual')\n",
" ax.plot(x, posterior_pointwise, '-C2', linewidth=2, label='Posterior')\n",
" ax.legend()\n",
" ax.set_ylabel('Probability')\n",
" ax.set_xlabel('Orientation (Degrees)')\n",
" plt.show()\n",
"\n",
" return ax\n",
"\n",
"\n",
"def plot_classical_vs_bayesian_normal(num_points, mu_classic, var_classic,\n",
" mu_bayes, var_bayes):\n",
" \"\"\" Helper function to plot optimal normal distribution parameters for varying\n",
" observed sample sizes using both classic and Bayesian inference methods.\n",
"\n",
" Args:\n",
" num_points (int): max observed sample size to perform inference with\n",
" mu_classic (ndarray): estimated mean parameter for each observed sample size\n",
" using classic inference method\n",
" var_classic (ndarray): estimated variance parameter for each observed sample size\n",
" using classic inference method\n",
" mu_bayes (ndarray): estimated mean parameter for each observed sample size\n",
" using Bayesian inference method\n",
" var_bayes (ndarray): estimated variance parameter for each observed sample size\n",
" using Bayesian inference method\n",
"\n",
" Returns:\n",
" Nothing.\n",
" \"\"\"\n",
" xspace = np.linspace(0, num_points, num_points)\n",
" fig, ax = plt.subplots()\n",
" ax.set_xlabel('n data points')\n",
" ax.set_ylabel('mu')\n",
" plt.plot(xspace, mu_classic,'r-', label=\"Classical\")\n",
" plt.plot(xspace, mu_bayes,'b-', label=\"Bayes\")\n",
" plt.legend()\n",
" plt.show()\n",
"\n",
" fig, ax = plt.subplots()\n",
" ax.set_xlabel('n data points')\n",
" ax.set_ylabel('sigma^2')\n",
" plt.plot(xspace, var_classic,'r-', label=\"Classical\")\n",
" plt.plot(xspace, var_bayes,'b-', label=\"Bayes\")\n",
" plt.legend()\n",
" plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"execution": {}
},
"source": [
" ---\n",
"# Section 1: Basic probability"
]
},
{
"cell_type": "markdown",
"metadata": {
"execution": {}
},
"source": [
"## Section 1.1: Basic probability theory"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Video 1: Basic Probability\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"execution": {},
"tags": [
"remove-input"
]
},
"outputs": [],
"source": [
"# @title Video 1: Basic Probability\n",
"from ipywidgets import widgets\n",
"from IPython.display import YouTubeVideo\n",
"from IPython.display import IFrame\n",
"from IPython.display import display\n",
"\n",
"\n",
"class PlayVideo(IFrame):\n",
" def __init__(self, id, source, page=1, width=400, height=300, **kwargs):\n",
" self.id = id\n",
" if source == 'Bilibili':\n",
" src = f'https://player.bilibili.com/player.html?bvid={id}&page={page}'\n",
" elif source == 'Osf':\n",
" src = f'https://mfr.ca-1.osf.io/render?url=https://osf.io/download/{id}/?direct%26mode=render'\n",
" super(PlayVideo, self).__init__(src, width, height, **kwargs)\n",
"\n",
"\n",
"def display_videos(video_ids, W=400, H=300, fs=1):\n",
" tab_contents = []\n",
" for i, video_id in enumerate(video_ids):\n",
" out = widgets.Output()\n",
" with out:\n",
" if video_ids[i][0] == 'Youtube':\n",
" video = YouTubeVideo(id=video_ids[i][1], width=W,\n",
" height=H, fs=fs, rel=0)\n",
" print(f'Video available at https://youtube.com/watch?v={video.id}')\n",
" else:\n",
" video = PlayVideo(id=video_ids[i][1], source=video_ids[i][0], width=W,\n",
" height=H, fs=fs, autoplay=False)\n",
" if video_ids[i][0] == 'Bilibili':\n",
" print(f'Video available at https://www.bilibili.com/video/{video.id}')\n",
" elif video_ids[i][0] == 'Osf':\n",
" print(f'Video available at https://osf.io/{video.id}')\n",
" display(video)\n",
" tab_contents.append(out)\n",
" return tab_contents\n",
"\n",
"\n",
"video_ids = [('Youtube', 'SL0_6rw8zrM'), ('Bilibili', 'BV1bw411o7HR')]\n",
"tab_contents = display_videos(video_ids, W=730, H=410)\n",
"tabs = widgets.Tab()\n",
"tabs.children = tab_contents\n",
"for i in range(len(tab_contents)):\n",
" tabs.set_title(i, video_ids[i][0])\n",
"display(tabs)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Submit your feedback\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"execution": {},
"tags": [
"hide-input"
]
},
"outputs": [],
"source": [
"# @title Submit your feedback\n",
"content_review(f\"{feedback_prefix}_Basic_Probability_Video\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"execution": {}
},
"source": [
"This video covers basic probability theory, including complementary probability, conditional probability, joint probability, and marginalisation.\n",
"\n",
" Click here for text recap of video
\n",
"\n",
"Previously we were only looking at sampling or properties of a single variables, but as we will now move on to statistical inference, it is useful to go over basic probability theory.\n",
"\n",
"\n",
"As a reminder, probability has to be in the range 0 to 1\n",
"$P(A) \\in [0,1] $\n",
"\n",
"and the complementary can always be defined as\n",
"\n",
"$P(\\neg A) = 1-P(A)$\n",
"\n",
"\n",
"When we have two variables, the *conditional probability* of $A$ given $B$ is\n",
"\n",
"$P (A|B) = P (A \\cap B)/P (B)=P (A, B)/P (B)$\n",
"\n",
"while the *joint probability* of $A$ and $B$ is\n",
"\n",
"$P(A \\cap B)=P(A,B) = P(B|A)P(A) = P(A|B)P(B) $\n",
"\n",
"We can then also define the process of *marginalisation* (for discrete variables) as\n",
"\n",
"$P(A)=\\sum P(A,B)=\\sum P(A|B)P(B)$\n",
"\n",
"where the summation is over the possible values of $B$.\n",
"\n",
"As an example if $B$ is a binary variable that can take values $B+$ or $B0$ then\n",
"$P(A)=\\sum P(A,B)=P(A|B+)P(B+)+ P(A|B0)P(B0) $.\n",
"\n",
"For continuous variables marginalization is given as\n",
"$P(A)=\\int P(A,B) dB=\\int P(A|B)P(B) dB$\n",
" Click here for text recap of video
\n",
"\n",
"A generative model (such as the Gaussian distribution from the previous tutorial) allows us to make predictions about outcomes.\n",
"\n",
"However, after we observe $n$ data points, we can also evaluate our model (and any of its associated parameters) by calculating the **likelihood** of our model having generated each of those data points $x_i$.\n",
"\n",
"\\begin{equation}\n",
"P(x_i|\\mu,\\sigma)=\\mathcal{N}(x_i,\\mu,\\sigma)\n",
"\\end{equation}\n",
"\n",
"For all data points $\\mathbf{x}=(x_1, x_2, x_3, ...x_n) $ we can then calculate the likelihood for the whole dataset by computing the product of the likelihood for each single data point.\n",
"\n",
"\\begin{equation}\n",
"P(\\mathbf{x}|\\mu,\\sigma)=\\prod_{i=1}^n \\mathcal{N}(x_i,\\mu,\\sigma)\n",
"\\end{equation}\n",
"\n",
"