
Tutorial 1: LIF Neuron Part II#

Week 0, Day 2: Python Workshop 2

By Neuromatch Academy

Content creators: Marco Brigham and the CCNSS team

Content reviewers: Michael Waskom, Karolina Stosio, Spiros Chavlis

Production editors: Ella Batty, Spiros Chavlis


Tutorial objectives#

We learned basic Python and NumPy concepts in the previous tutorial. These new and efficient coding techniques can be applied repeatedly in tutorials from the NMA course, and elsewhere.

In this tutorial, we’ll introduce spikes in our LIF neuron and evaluate the refractory period’s effect on spiking dynamics!


Setup#

Install and import feedback gadget#

# @title Install and import feedback gadget

!pip3 install vibecheck datatops --quiet

from vibecheck import DatatopsContentReviewContainer
def content_review(notebook_section: str):
    return DatatopsContentReviewContainer(
        "",  # No text prompt
        notebook_section,
        {
            "url": "https://pmyvdlilci.execute-api.us-east-1.amazonaws.com/klab",
            "name": "neuromatch-precourse",
            "user_key": "8zxfvwxw",
        },
    ).render()


feedback_prefix = "W0D2_T1"
# Imports
import numpy as np
import matplotlib.pyplot as plt

Figure settings#

# @title Figure settings
import logging
logging.getLogger('matplotlib.font_manager').disabled = True
import ipywidgets as widgets  # interactive display
%config InlineBackend.figure_format = 'retina'
plt.style.use("https://raw.githubusercontent.com/NeuromatchAcademy/content-creation/main/nma.mplstyle")

Helper functions#

# @title Helper functions

t_max = 150e-3   # second
dt = 1e-3        # second
tau = 20e-3      # second
el = -60e-3      # volt
vr = -70e-3      # volt
vth = -50e-3     # volt
r = 100e6        # ohm
i_mean = 25e-11  # ampere


def plot_all(t_range, v, raster=None, spikes=None, spikes_mean=None):
  """
  Plots Time evolution for
  (1) multiple realizations of membrane potential
  (2) spikes
  (3) mean spike rate (optional)

  Args:
    t_range (numpy array of floats)
        range of time steps for the plots of shape (time steps)

    v (numpy array of floats)
        membrane potential values of shape (neurons, time steps)

    raster (numpy array of floats)
        spike raster of shape (neurons, time steps)

    spikes (dictionary of lists)
        list with spike times indexed by neuron number

    spikes_mean (numpy array of floats)
        Mean spike rate for spikes as dictionary

  Returns:
    Nothing.
  """

  v_mean = np.mean(v, axis=0)
  fig_w, fig_h = plt.rcParams['figure.figsize']
  plt.figure(figsize=(fig_w, 1.5 * fig_h))

  ax1 = plt.subplot(3, 1, 1)
  for j in range(v.shape[0]):  # one row of v per neuron
    plt.scatter(t_range, v[j], color="k", marker=".", alpha=0.01)
  plt.plot(t_range, v_mean, 'C1', alpha=0.8, linewidth=3)
  plt.xticks([])
  plt.ylabel(r'$V_m$ (V)')

  if raster is not None:
    plt.subplot(3, 1, 2)
    spikes_mean = np.mean(raster, axis=0)
    plt.imshow(raster, cmap='Greys', origin='lower', aspect='auto')

  else:
    plt.subplot(3, 1, 2, sharex=ax1)
    for j in range(v.shape[0]):  # one row of v per neuron
      times = np.array(spikes[j])
      plt.scatter(times, j * np.ones_like(times), color="C0", marker=".", alpha=0.2)

  plt.xticks([])
  plt.ylabel('neuron')

  if spikes_mean is not None:
    plt.subplot(3, 1, 3, sharex=ax1)
    plt.plot(t_range, spikes_mean)
    plt.xlabel('time (s)')
    plt.ylabel('rate (Hz)')

  plt.tight_layout()
  plt.show()

Section 1: Histograms#

Video 1: Histograms#

Submit your feedback#

# @title Submit your feedback
content_review(f"{feedback_prefix}_Histograms_Video")

Another important statistic is the sample histogram. For our LIF neuron, it provides an approximate representation of the distribution of membrane potential \(V_m(t)\) at time \(t=t_k\in[0,t_{max}]\). For \(N\) realizations \(V\left(t_k\right)\) and \(J\) bins, it is given by:


\begin{equation} N = \sum_{j=1}^{J} m_j \end{equation}

where \(m_j\) is a function that counts the number of samples \(V\left(t_k\right)\) that fall into bin \(j\).
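As a minimal sketch of this counting identity, the snippet below bins synthetic samples with np.histogram and checks that the bin counts add up to \(N\). The Gaussian samples and the names samples, counts and edges are illustrative assumptions, not the LIF simulation itself.

```python
import numpy as np

# Illustrative stand-in for N realizations of V(t_k): synthetic Gaussian samples
rng = np.random.default_rng(0)
n_samples = 10000   # N
n_bins = 50         # J
samples = -60e-3 + 2e-3 * rng.standard_normal(n_samples)

# counts[j] plays the role of m_j; edges holds the J + 1 bin edges
counts, edges = np.histogram(samples, bins=n_bins)

print(counts.sum() == n_samples)   # every sample falls in exactly one bin
print(len(counts), len(edges))
```

With the default range, the bins span the minimum to the maximum of the data, so no sample is left out of the sum.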

The function plt.hist(data, nbins) plots a histogram of data in nbins bins. The argument label defines a label for data, and plt.legend() adds all labels to the plot.

plt.hist(data, bins, label='my data')
plt.legend()
plt.show()

The parameters histtype='stepfilled' and linewidth=0 may improve histogram appearance (depending on taste). You can read more about different histtype settings.

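The function plt.hist returns three values, named pdf, bins, and patches below: the value of each histogram bin (raw counts by default, a probability density only when density=True is passed), the edges of the bins, and the individual patches used to draw the histogram.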

pdf, bins, patches = plt.hist(data, bins)

Video 2: Nano recap of histograms#

Submit your feedback#

# @title Submit your feedback
content_review(f"{feedback_prefix}_Nano_recap_of_histograms_Video")

Coding Exercise 1: Plotting a histogram#

Plot a histogram of \(J=50\) bins of \(N=10000\) realizations of \(V(t)\) for \(t=t_{max}/10\) and \(t=t_{max}\).

We’ll make a small correction in the definition of t_range to ensure increments of dt by using np.arange instead of np.linspace.

numpy.arange(start, stop, step)
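
To see why the switch matters, here is a small sketch comparing the two constructors with the tutorial's t_max and dt. The names t_arange and t_linspace are illustrative only.

```python
import numpy as np

t_max = 150e-3  # second
dt = 1e-3       # second

# np.arange takes the step directly; the stop value is excluded
t_arange = np.arange(0, t_max, dt)

# np.linspace takes the number of points; the stop value is included by default,
# so the same number of points yields a slightly larger spacing
t_linspace = np.linspace(0, t_max, len(t_arange))

print(np.allclose(np.diff(t_arange), dt))   # spacing matches dt
print(t_linspace[1] - t_linspace[0])        # spacing is t_max / (points - 1)
```
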
#################################################
## TODO for students: fill out code to plot histogram ##
# Fill out code and comment or remove the next line
raise NotImplementedError("Student exercise: You need to plot histogram")
#################################################

# Set random number generator
np.random.seed(2020)

# Initialize t_range, step_end, n, v_n, i and nbins
t_range = np.arange(0, t_max, dt)
step_end = len(t_range)
n = 10000
v_n = el * np.ones([n, step_end])
i = i_mean * (1 + 0.1 * (t_max / dt)**(0.5) * (2 * np.random.random([n, step_end]) - 1))
nbins = 50

# Loop over time steps
for step, t in enumerate(t_range):

  # Skip first iteration
  if step==0:
    continue

  # Compute v_n
  v_n[:, step] =  v_n[:, step - 1] + (dt / tau) * (el - v_n[:, step - 1] + r * i[:, step])

# Initialize the figure
plt.figure()
plt.ylabel('Frequency')
plt.xlabel('$V_m$ (V)')

# Plot a histogram at t_max/10 (add labels and parameters histtype='stepfilled' and linewidth=0)
plt.hist(...)

# Plot a histogram at t_max (add labels and parameters histtype='stepfilled' and linewidth=0)
plt.hist(...)

# Add legend
plt.legend()
plt.show()


Submit your feedback#

# @title Submit your feedback
content_review(f"{feedback_prefix}_Plotting_a_histogram_Exercise")

Section 2: Dictionaries & introducing spikes#

Video 3: Dictionaries & introducing spikes#

Submit your feedback#

# @title Submit your feedback
content_review(f"{feedback_prefix}_Dictionaries_&_introducing_spikes_Video")

A spike occurs whenever \(V(t)\) crosses \(V_{th}\). In that case, a spike is recorded, and \(V(t)\) resets to the \(V_{reset}\) value. This is summarized in the reset condition:

\[\begin{equation} V(t) = V_{reset}\quad \text{ if } V(t)\geq V_{th} \end{equation}\]

For more information about spikes or action potentials see here and here.

Video 4: Nano recap of dictionaries#

Submit your feedback#

# @title Submit your feedback
content_review(f"{feedback_prefix}_Nano_recap_of_dictionaries_Video")

Coding Exercise 2: Adding spiking to the LIF neuron#

Insert the reset condition, and collect the spike times of each realization in a dictionary variable spikes, with \(N=500\).
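
A dictionary of lists is a convenient container here because each neuron can accumulate a different number of spike times. The toy sketch below shows the pattern with made-up events; the names spike_log and n_toy and the event values are hypothetical and are not part of the exercise.

```python
# Toy example: three neurons, spike times in seconds (made-up values)
n_toy = 3
spike_log = {j: [] for j in range(n_toy)}

# Record a few hypothetical (neuron, time) events
for j, t in [(0, 0.012), (2, 0.020), (0, 0.047)]:
  spike_log[j] += [t]   # same effect as spike_log[j].append(t)

print(spike_log)
print({j: len(times) for j, times in spike_log.items()})  # spike count per neuron
```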

We’ve used plt.plot for plotting lines and also for plotting dots at (x,y) coordinates, which is a scatter plot. From here on, we’ll use plt.plot for plotting lines and plt.scatter for scatter plots:

plt.scatter(x, y, color="k", marker=".")

A raster plot represents spikes from multiple neurons by plotting dots at spike times from neuron j at plot height j, i.e.

plt.scatter(spike_times, j*np.ones_like(spike_times))

In this exercise, we use plt.subplot for multiple plots in the same figure. These plots can share the same x or y axis by specifying the parameter sharex or sharey. Add plt.tight_layout() at the end to automatically adjust subplot parameters to fit the figure area better. Please see the example below for a row of two plots sharing axis y.

# initialize the figure
plt.figure()

# collect axis of 1st figure in ax1
ax1 = plt.subplot(1, 2, 1)
plt.plot(t_range, my_data_left)
plt.ylabel('ylabel')

# share axis y with 1st subplot
plt.subplot(1, 2, 2, sharey=ax1)
plt.plot(t_range, my_data_right)

# automatically adjust subplot parameters to figure
plt.tight_layout()
plt.show()

# Set random number generator
np.random.seed(2020)

# Initialize step_end, t_range, n, v_n and i
t_range = np.arange(0, t_max, dt)
step_end = len(t_range)
n = 500
v_n = el * np.ones([n, step_end])
i = i_mean * (1 + 0.1 * (t_max / dt)**(0.5) * (2 * np.random.random([n, step_end]) - 1))

# Initialize spikes and spikes_n
spikes = {j: [] for j in range(n)}
spikes_n = np.zeros([step_end])

#################################################
## TODO for students: add spikes to LIF neuron ##
# Fill out function and remove
raise NotImplementedError("Student exercise: add spikes to LIF neuron")
#################################################

# Loop over time steps
for step, t in enumerate(t_range):

  # Skip first iteration
  if step == 0:
    continue

  # Compute v_n
  v_n[:, step] = v_n[:, step - 1] + (dt / tau) * (el - v_n[:, step - 1] + r*i[:, step])

  # Loop over simulations
  for j in range(n):

    # Check if voltage above threshold
    if v_n[j, step] >= vth:

      # Reset to reset voltage
      v_n[j, step] = ...

      # Add this spike time
      spikes[j] += ...

      # Add spike count to this step
      spikes_n[step] += ...

# Collect mean Vm and mean spiking rate
v_mean = np.mean(v_n, axis=0)
spikes_mean =  spikes_n / n

# Initialize the figure
plt.figure()

# Plot simulations and sample mean
ax1 = plt.subplot(3, 1, 1)
for j in range(n):
  plt.scatter(t_range, v_n[j], color="k", marker=".", alpha=0.01)
plt.plot(t_range, v_mean, 'C1', alpha=0.8, linewidth=3)
plt.ylabel('$V_m$ (V)')

# Plot spikes
plt.subplot(3, 1, 2, sharex=ax1)
# for each neuron j: collect spike times and plot them at height j
for j in range(n):
  times = ...
  plt.scatter(...)

plt.ylabel('neuron')

# Plot firing rate
plt.subplot(3, 1, 3, sharex=ax1)
plt.plot(t_range, spikes_mean)
plt.xlabel('time (s)')
plt.ylabel('rate (Hz)')

plt.tight_layout()


Submit your feedback#

# @title Submit your feedback
content_review(f"{feedback_prefix}_Adding_spiking_to_the_LIF_neuron_Exercise")

Section 3: Boolean indexes#

Video 5: Boolean indexes#

Submit your feedback#

# @title Submit your feedback
content_review(f"{feedback_prefix}_Boolean_indexes_Video")

Boolean arrays can index NumPy arrays to select a subset of elements (also works with lists of booleans).

The boolean array itself can be initiated by a condition, as shown in the example below.

a = np.array([1, 2, 3])
b = a>=2
print(b)
--> [False  True  True]

print(a[b])
--> [2 3]

print(a[a>=2])
--> [2 3]

Video 6: Nano recap of Boolean indexes#

Submit your feedback#

# @title Submit your feedback
content_review(f"{feedback_prefix}_Nano_recap_of_Boolean_indexes_Video")

Coding Exercise 3: Using Boolean indexing#

We can avoid looping over all neurons at each time step by using boolean arrays to identify the indexes of neurons that spiked in that step.

In the example below, v_rest is a boolean numpy array with True in each index of v_n with value vr at time index step:

v_rest = (v_n[:,step] == vr)
print(v_n[v_rest,step])
  --> [vr, ..., vr]

The function np.where returns indexes of boolean arrays with True values.
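
As a brief sketch of this call, the snippet below applies np.where to a small boolean array. The array values and the names mask and idx are made up for illustration.

```python
import numpy as np

a = np.array([1, 5, 2, 7])
mask = a >= 3                 # boolean array: [False  True False  True]

# For a 1-D input, np.where(mask) returns a tuple holding one index array
idx = np.where(mask)[0]
print(idx)                    # indexes where mask is True
print(a[idx])                 # same elements as a[mask]
```

Indexing with [0] unpacks the tuple, which is why the same pattern is needed when iterating over the indexes of a one-dimensional boolean array.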

You may use the helper function plot_all that implements the figure from the previous exercise.

help(plot_all)
Help on function plot_all in module __main__:

plot_all(t_range, v, raster=None, spikes=None, spikes_mean=None)
    Plots Time evolution for
    (1) multiple realizations of membrane potential
    (2) spikes
    (3) mean spike rate (optional)
    
    Args:
      t_range (numpy array of floats)
          range of time steps for the plots of shape (time steps)
    
      v (numpy array of floats)
          membrane potential values of shape (neurons, time steps)
    
      raster (numpy array of floats)
          spike raster of shape (neurons, time steps)
    
      spikes (dictionary of lists)
          list with spike times indexed by neuron number
    
      spikes_mean (numpy array of floats)
          Mean spike rate for spikes as dictionary
    
    Returns:
      Nothing.

# Set random number generator
np.random.seed(2020)

# Initialize step_end, t_range, n, v_n and i
t_range = np.arange(0, t_max, dt)
step_end = len(t_range)
n = 500
v_n = el * np.ones([n, step_end])
i = i_mean * (1 + 0.1 * (t_max / dt)**(0.5) * (2 * np.random.random([n, step_end]) - 1))

# Initialize spikes and spikes_n
spikes = {j: [] for j in range(n)}
spikes_n = np.zeros([step_end])

#################################################
## TODO for students: use Boolean indexing ##
# Fill out function and remove
raise NotImplementedError("Student exercise: using Boolean indexing")
#################################################

# Loop over time steps
for step, t in enumerate(t_range):

  # Skip first iteration
  if step == 0:
    continue

  # Compute v_n
  v_n[:, step] = v_n[:, step - 1] + (dt / tau) * (el - v_n[:, step - 1] + r*i[:, step])

  # Initialize boolean numpy array `spiked` with v_n >= vth
  spiked = ...

  # Set relevant values of v_n to reset potential using spiked
  v_n[spiked,step] = ...

  # Collect spike times
  for j in np.where(spiked)[0]:
    spikes[j] += [t]
    spikes_n[step] += 1

# Collect mean spiking rate
spikes_mean = spikes_n / n

# Plot multiple realizations of Vm, spikes and mean spike rate
plot_all(t_range, v_n, spikes=spikes, spikes_mean=spikes_mean)


Submit your feedback#

# @title Submit your feedback
content_review(f"{feedback_prefix}_Using_Boolean_indexing_Exercise")

Coding Exercise 4: Making a binary raster plot#

A binary raster plot represents spike times as 1s in a binary grid initialized with 0s. We start with a numpy array raster of zeros with shape (neurons, time steps), and represent a spike from neuron 5 at time step 20 as raster[5, 20] = 1, for example.

Plotting the numpy array raster as an image makes the binary raster plot much more efficient than the previous method:

plt.imshow(raster, cmap='Greys', origin='lower', aspect='auto')

Suggestions

  • At each time step:

    • Initialize boolean numpy array spiked with \(V_n(t)\geq V_{th}\)

    • Set to vr indexes of v_n using spiked

    • Set to 1 indexes of numpy array raster using spiked
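
The marking step in the suggestions above can be sketched on a toy array. The values in v_toy are made up, the reset step is left out so that only the raster bookkeeping is shown, and all names ending in _toy are hypothetical.

```python
import numpy as np

# Toy membrane potentials (volts) for 3 neurons over 4 time steps; made-up values
v_toy = np.array([[-60e-3, -49e-3, -62e-3, -55e-3],
                  [-58e-3, -57e-3, -48e-3, -47e-3],
                  [-61e-3, -60e-3, -59e-3, -58e-3]])
vth_toy = -50e-3

raster_toy = np.zeros(v_toy.shape)
for step in range(v_toy.shape[1]):
  spiked = v_toy[:, step] >= vth_toy   # boolean array over neurons
  raster_toy[spiked, step] = 1.        # mark spikes at this time step

print(raster_toy)
print(raster_toy.mean(axis=0))         # fraction of neurons spiking per step
```

Averaging the raster over the neuron axis gives the fraction of neurons spiking at each step, which is the quantity plot_all derives from raster.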

#################################################
## TODO for students: make a raster ##
# Fill out function and remove
raise NotImplementedError("Student exercise: make a raster ")
#################################################

# Set random number generator
np.random.seed(2020)

# Initialize step_end, t_range, n, v_n and i
t_range = np.arange(0, t_max, dt)
step_end = len(t_range)
n = 500
v_n = el * np.ones([n, step_end])
i = i_mean * (1 + 0.1 * (t_max / dt)**(0.5) * (2 * np.random.random([n, step_end]) - 1))

# Initialize binary numpy array for raster plot
raster = np.zeros([n,step_end])

# Loop over time steps
for step, t in enumerate(t_range):

  # Skip first iteration
  if step==0:
    continue

  # Compute v_n
  v_n[:, step] = v_n[:, step - 1] + (dt / tau) * (el - v_n[:, step - 1] + r*i[:, step])

  # Initialize boolean numpy array `spiked` with v_n >= vth
  spiked = (v_n[:,step] >= vth)

  # Set relevant values of v_n to v_reset using spiked
  v_n[spiked,step] = vr

  # Set relevant elements in raster to 1 using spiked
  raster[spiked,step] = ...

# Plot multiple realizations of Vm, spikes and mean spike rate
plot_all(t_range, v_n, raster)


Submit your feedback#

# @title Submit your feedback
content_review(f"{feedback_prefix}_Making_a_binary_raster_plot_Exercise")

Section 4: Refractory period#

Video 7: Refractory period#

Submit your feedback#

# @title Submit your feedback
content_review(f"{feedback_prefix}_Refractory_period_Video")

The absolute refractory period is a time interval on the order of a few milliseconds during which synaptic input will not lead to a second spike, no matter how strong. This effect is due to the biophysics of the neuron membrane channels, and you can read more about absolute and relative refractory periods here and here.

Video 8: Nano recap of refractory period#

Submit your feedback#

# @title Submit your feedback
content_review(f"{feedback_prefix}_Nano_recap_of_refractory_period_Video")

Coding Exercise 5: Investigating refractory periods#

Investigate the effect of (absolute) refractory period \(t_{ref}\) on the evolution of output rate \(\lambda(t)\). Add refractory period \(t_{ref}=10\) ms after each spike, during which \(V(t)\) is clamped to \(V_{reset}\).
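
The clamping rule can be sketched on a toy state before wiring it into the simulation. The comparison mirrors the one used in the interactive demo later in this tutorial; the values and the names ending in _toy are made up for illustration.

```python
import numpy as np

t_ref_toy = 0.01                                  # refractory period (s)
t_now = 0.050                                     # current time (s)
last_spike_toy = np.array([0.045, 0.030, -0.01])  # last spike time per neuron (s)

# A neuron is clamped while less than t_ref has elapsed since its last spike
clamped = (last_spike_toy + t_ref_toy > t_now)

v_toy = np.array([-55e-3, -52e-3, -58e-3])
v_toy[clamped] = -70e-3                           # hold clamped neurons at the reset value

print(clamped)
print(v_toy)
```

Only the first neuron spiked less than 10 ms ago, so only its potential is held at the reset value.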

#################################################
## TODO for students: add refractory period ##
# Fill out function and remove
raise NotImplementedError("Student exercise: add refractory period")
#################################################

# Set random number generator
np.random.seed(2020)

# Initialize step_end, t_range, n, v_n and i
t_range = np.arange(0, t_max, dt)
step_end = len(t_range)
n = 500
v_n = el * np.ones([n, step_end])
i = i_mean * (1 + 0.1 * (t_max / dt)**(0.5) * (2 * np.random.random([n, step_end]) - 1))

# Initialize binary numpy array for raster plot
raster = np.zeros([n,step_end])

# Initialize t_ref and last_spike
t_ref = 0.01
last_spike = -t_ref * np.ones([n])

# Loop over time steps
for step, t in enumerate(t_range):

  # Skip first iteration
  if step == 0:
    continue

  # Compute v_n
  v_n[:, step] = v_n[:, step - 1] + (dt / tau) * (el - v_n[:, step - 1] + r*i[:, step])

  # Initialize boolean numpy array `spiked` with v_n >= vth
  spiked = (v_n[:,step] >= vth)

  # Set relevant values of v_n to v_reset using spiked
  v_n[spiked,step] = vr

  # Set relevant elements in raster to 1 using spiked
  raster[spiked,step] = 1.

  # Initialize boolean numpy array clamped using last_spike, t and t_ref
  clamped = ...

  # Reset clamped neurons to vr using clamped
  v_n[clamped,step] = ...

  # Update numpy array last_spike with time t for spiking neurons
  last_spike[spiked] = ...

# Plot multiple realizations of Vm, spikes and mean spike rate
plot_all(t_range, v_n, raster)


Submit your feedback#

# @title Submit your feedback
content_review(f"{feedback_prefix}_Investigating_refractory_periods_Exercise")

Interactive Demo 1: Random refractory period#

In the following interactive demo, we will investigate the effect of random refractory periods. We will use random refractory periods \(t_{ref}\) with \(t_{ref} = \mu + \sigma\,\mathcal{N}\), where \(\mathcal{N}\) is a sample from the standard normal distribution, \(\mu=0.01\) s and \(\sigma=0.007\) s.

The refractory period samples t_ref of size n are initialized with np.random.normal. We clip negative values to 0 with boolean indexes. (Why?) You can double-click the cell to see the hidden code.
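
As a small sketch of the clipping step, the snippet below draws samples with the same \(\mu\) and \(\sigma\) and zeroes the negative ones; a negative value would mean a negative duration, which has no physical meaning. The generator rng, the sample size and the names ending in _toy are assumptions for illustration (the demo itself uses np.random.normal).

```python
import numpy as np

rng = np.random.default_rng(2020)
mu, sigma, n_toy = 0.01, 0.007, 1000

t_ref_toy = mu + sigma * rng.standard_normal(n_toy)
n_negative = np.sum(t_ref_toy < 0)     # samples that would be negative durations

t_ref_toy[t_ref_toy < 0] = 0           # boolean-index assignment, as in the demo
# np.clip(t_ref_toy, 0, None) would give the same result in a single call

print(n_negative, t_ref_toy.min())
```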

You can play with the parameters mu and sigma and visualize the resulting simulation. What is the effect of different \(\sigma\) values?

Execute this cell to enable the demo

# @markdown Execute this cell to enable the demo

def random_ref_period(mu, sigma):
  # set random number generator
  np.random.seed(2020)

  # initialize step_end, t_range, n, v_n, syn and raster
  t_range = np.arange(0, t_max, dt)
  step_end = len(t_range)
  n = 500
  v_n = el * np.ones([n,step_end])
  syn = i_mean * (1 + 0.1*(t_max/dt)**(0.5)*(2*np.random.random([n,step_end])-1))
  raster = np.zeros([n,step_end])

  # initialize t_ref and last_spike
  t_ref = mu + sigma*np.random.normal(size=n)
  t_ref[t_ref<0] = 0
  last_spike = -t_ref * np.ones([n])

  # loop time steps
  for step, t in enumerate(t_range):
    if step==0:
      continue

    v_n[:,step] = v_n[:,step-1] + dt/tau * (el - v_n[:,step-1] + r*syn[:,step])

    # boolean array spiked indexes neurons with v>=vth
    spiked = (v_n[:,step] >= vth)
    v_n[spiked,step] = vr
    raster[spiked,step] = 1.

    # boolean array clamped indexes refractory neurons
    clamped = (last_spike + t_ref > t)
    v_n[clamped,step] = vr
    last_spike[spiked] = t

  # plot multiple realizations of Vm, spikes and mean spike rate
  plot_all(t_range, v_n, raster)

  # plot histogram of t_ref
  plt.figure(figsize=(8,4))
  plt.hist(t_ref, bins=32, histtype='stepfilled', linewidth=0, color='C1')
  plt.xlabel(r'$t_{ref}$ (s)')
  plt.ylabel('count')
  plt.tight_layout()

_ = widgets.interact(random_ref_period, mu = (0.01, 0.05, 0.01), \
                              sigma = (0.001, 0.02, 0.001))

Submit your feedback#

# @title Submit your feedback
content_review(f"{feedback_prefix}_Random_refractory_period_Interactive_Demo")

Section 5: Using functions#

Running key parts of your code inside functions improves your coding narrative by making it clearer and more flexible to future changes.

Video 9: Functions#

Submit your feedback#

# @title Submit your feedback
content_review(f"{feedback_prefix}_Functions_Video")

Video 10: Nano recap of functions#

Submit your feedback#

# @title Submit your feedback
content_review(f"{feedback_prefix}_Nano_recap_of_functions_Video")

Coding Exercise 6: Rewriting code with functions#

We will now re-organize parts of the code from the previous exercise with functions. You need to complete the function spike_clamp() to update \(V(t)\) and deal with spiking and refractoriness.

def ode_step(v, i, dt):
  """
  Evolves membrane potential by one step of discrete time integration

  Args:
    v (numpy array of floats)
      membrane potential at previous time step of shape (neurons)

    i (numpy array of floats)
      synaptic input at current time step of shape (neurons)

    dt (float)
      time step increment

  Returns:
    v (numpy array of floats)
      membrane potential at current time step of shape (neurons)
  """
  v = v + dt/tau * (el - v + r*i)

  return v


def spike_clamp(v, delta_spike):
  """
  Resets membrane potential of neurons if v>= vth
  and clamps to vr if interval of time since last spike < t_ref

  Args:
    v (numpy array of floats)
      membrane potential of shape (neurons)

    delta_spike (numpy array of floats)
      interval of time since last spike of shape (neurons)

  Returns:
    v (numpy array of floats)
      membrane potential of shape (neurons)
    spiked (numpy array of bools)
      boolean array of neurons that spiked, of shape (neurons)
  """

  ####################################################
  ## TODO for students: complete spike_clamp
  # Fill out function and remove
  raise NotImplementedError("Student exercise: complete spike_clamp")
  ####################################################
  # Boolean array spiked indexes neurons with v>=vth
  spiked = ...
  v[spiked] = ...

  # Boolean array clamped indexes refractory neurons
  clamped = ...
  v[clamped] = ...

  return v, spiked


# Set random number generator
np.random.seed(2020)

# Initialize step_end, t_range, n, v_n and i
t_range = np.arange(0, t_max, dt)
step_end = len(t_range)
n = 500
v_n = el * np.ones([n, step_end])
i = i_mean * (1 + 0.1 * (t_max / dt)**(0.5) * (2 * np.random.random([n, step_end]) - 1))

# Initialize binary numpy array for raster plot
raster = np.zeros([n,step_end])

# Initialize t_ref and last_spike
mu = 0.01
sigma = 0.007
t_ref = mu + sigma*np.random.normal(size=n)
t_ref[t_ref<0] = 0
last_spike = -t_ref * np.ones([n])

# Loop over time steps
for step, t in enumerate(t_range):

  # Skip first iteration
  if step==0:
    continue

  # Compute v_n
  v_n[:,step] = ode_step(v_n[:,step-1], i[:,step], dt)

  # Reset membrane potential and clamp
  v_n[:,step], spiked = spike_clamp(v_n[:,step], t - last_spike)

  # Update raster and last_spike
  raster[spiked,step] = 1.
  last_spike[spiked] = t

# Plot multiple realizations of Vm, spikes and mean spike rate
plot_all(t_range, v_n, raster)


Submit your feedback#

# @title Submit your feedback
content_review(f"{feedback_prefix}_Rewriting_code_with_functions_Exercise")

Section 6: Using classes#

Video 11: Classes#

Submit your feedback#

# @title Submit your feedback
content_review(f"{feedback_prefix}_Classes_Video")

Using classes helps with code reuse and reliability. Well-designed classes are like black boxes, receiving inputs and providing expected outputs. The details of how the class processes inputs and produces outputs are unimportant.

See additional details here: A Beginner’s Python Tutorial/Classes

Attributes are variables internal to the class, and methods are functions internal to the class.
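
To make the distinction concrete, here is a minimal toy class. SpikeCounter, its attributes and its observe method are hypothetical and unrelated to the class built in the exercise; the sketch only shows attributes holding state and a method updating it.

```python
class SpikeCounter:
  """Toy class (hypothetical, for illustration): counts threshold crossings."""

  def __init__(self, threshold):
    # Attributes: variables stored on the instance
    self.threshold = threshold
    self.count = 0

  def observe(self, v):
    # Method: a function that reads and updates the instance's attributes
    if v >= self.threshold:
      self.count += 1
    return self.count


counter = SpikeCounter(threshold=-50e-3)
for v in [-60e-3, -49e-3, -50e-3, -55e-3]:
  counter.observe(v)

print(counter.count)
```

The caller only passes values to observe and reads count; how the counting is done stays inside the class.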

Video 12: Nano recap of classes#

Submit your feedback#

# @title Submit your feedback
content_review(f"{feedback_prefix}_Nano_recap_of_classes_Video")

Coding Exercise 7: Making a LIF class#

In this exercise we’ll practice with Python classes by implementing LIFNeurons, a class that evolves and keeps state of multiple realizations of LIF neurons.

Several attributes are used to keep state of our neurons:

self.v             current membrane potential
self.spiked        neurons that spiked
self.last_spike    last spike time of each neuron
self.t             running time of the simulation
self.steps         simulation step

There is a single method:

self.ode_step()    performs single step discrete time integration
                   for provided synaptic current and dt

Complete the spike and clamp part of method self.ode_step (should be similar to function spike_clamp seen before).

# Simulation class
class LIFNeurons:
  """
  Keeps track of membrane potential for multiple realizations of LIF neuron,
  and performs single step discrete time integration.
  """
  def __init__(self, n, t_ref_mu=0.01, t_ref_sigma=0.002,
               tau=20e-3, el=-60e-3, vr=-70e-3, vth=-50e-3, r=100e6):

    # Neuron count
    self.n = n

    # Neuron parameters
    self.tau = tau        # second
    self.el = el          # volt
    self.vr = vr          # volt
    self.vth = vth        # volt
    self.r = r            # ohm

    # Initializes refractory period distribution
    self.t_ref_mu = t_ref_mu
    self.t_ref_sigma = t_ref_sigma
    self.t_ref = self.t_ref_mu + self.t_ref_sigma * np.random.normal(size=self.n)
    self.t_ref[self.t_ref<0] = 0

    # State variables
    self.v = self.el * np.ones(self.n)
    self.spiked = self.v >= self.vth
    self.last_spike = -self.t_ref * np.ones([self.n])
    self.t = 0.
    self.steps = 0


  def ode_step(self, dt, i):

    # Update running time and steps
    self.t += dt
    self.steps += 1

    # One step of discrete time integration of dt
    self.v = self.v + dt / self.tau * (self.el - self.v + self.r * i)

    ####################################################
    ## TODO for students: complete the `ode_step` method
    # Fill out function and remove
    raise NotImplementedError("Student exercise: complete the ode_step method")
    ####################################################

    # Spike and clamp
    self.spiked = ...
    self.v[self.spiked] = ...
    self.last_spike[self.spiked] = ...
    clamped = ...
    self.v[clamped] = ...

    self.last_spike[self.spiked] = ...

# Set random number generator
np.random.seed(2020)

# Initialize step_end, t_range, n, v_n and i
t_range = np.arange(0, t_max, dt)
step_end = len(t_range)
n = 500
v_n = el * np.ones([n, step_end])
i = i_mean * (1 + 0.1 * (t_max / dt)**(0.5) * (2 * np.random.random([n, step_end]) - 1))

# Initialize binary numpy array for raster plot
raster = np.zeros([n,step_end])

# Initialize neurons
neurons = LIFNeurons(n)

# Loop over time steps
for step, t in enumerate(t_range):

  # Call ode_step method
  neurons.ode_step(dt, i[:,step])

  # Log v_n and spike history
  v_n[:,step] = neurons.v
  raster[neurons.spiked, step] = 1.

# Report running time and steps
print(f'Ran for {neurons.t:.3}s in {neurons.steps} steps.')

# Plot multiple realizations of Vm, spikes and mean spike rate
plot_all(t_range, v_n, raster)


Submit your feedback#

# @title Submit your feedback
content_review(f"{feedback_prefix}_Making_a_LIF_class_Exercise")

Summary#

Video 13: Last concepts & recap#

Submit your feedback#

# @title Submit your feedback
content_review(f"{feedback_prefix}_Last_concepts_&_recap_Video")