
Tutorial 1: Optimal Control for Discrete States#

Week 3, Day 3: Optimal Control

By Neuromatch Academy

Content creators: Zhengwei Wu, Itzel Olivos Castillo, Shreya Saxena, Xaq Pitkow

Content reviewers: Karolina Stosio, Roozbeh Farhoodi, Saeed Salehi, Ella Batty, Spiros Chavlis, Matt Krause, Michael Waskom, Melisa Maidana Capitan

Production editors: Gagana B, Spiros Chavlis


Tutorial Objectives#

Estimated timing of tutorial: 60 min

In this tutorial, we will implement a binary control task: a Partially Observable Markov Decision Process (POMDP) that describes fishing. The agent (you) seeks reward from two fishing sites without directly observing where the school of fish is (yes, a group of fish is called a school!). This makes the world a Hidden Markov Model (HMM), just like in the Hidden Dynamics day. Based on when and where you catch fish, you keep updating your belief about the fish location, i.e., the posterior of the fish given past observations. You should control your position to get the most fish while minimizing the cost of switching sides.
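
As a concrete preview, here is a minimal sketch of that belief update for the binary fishing task. The parameter names are illustrative assumptions, not the tutorial's exact interface: p_switch is the chance the fish switch sides on a given step, and p_high / p_low are the catch probabilities when you fish at the fish's side versus the other side.

# A minimal sketch of the belief (posterior) update, assuming hypothetical
# parameters: p_switch (fish switch sides per step), p_high / p_low (catch
# probability when fishing at the fish's side vs. the other side).
import numpy as np

def update_belief(belief_right, caught, fishing_right,
                  p_switch=0.1, p_high=0.4, p_low=0.1):
  """Return p(fish on right | all catches so far) after one time step."""
  # Prediction step: the fish may have switched sides since the last step.
  prior_right = belief_right * (1 - p_switch) + (1 - belief_right) * p_switch

  # Likelihood of this step's observation under each possible fish location.
  p_catch_if_right = p_high if fishing_right else p_low
  p_catch_if_left = p_low if fishing_right else p_high
  like_right = p_catch_if_right if caught else 1 - p_catch_if_right
  like_left = p_catch_if_left if caught else 1 - p_catch_if_left

  # Bayes' rule: combine prior and likelihood, then normalize.
  post_right = like_right * prior_right
  post_left = like_left * (1 - prior_right)
  return post_right / (post_right + post_left)

# Example: start uncertain, fish on the right side, and catch nothing.
belief = update_belief(0.5, caught=False, fishing_right=True)
print(belief)  # 0.4: the failed catch shifts belief toward "left"

The tutorial builds this up step by step with its own helper code; the point here is just the prediction-then-Bayes structure of the update.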

You’ve already learned about stochastic dynamics, latent states, and measurements. These first exercises largely repeat your previous work. Now we introduce actions, based on the new concepts of control, utility, and policy. This general structure provides a foundational model for the brain’s computations because it includes a perception-action loop where the animal can gather information, draw inferences about its environment, and select actions with the greatest benefit. How neurons might mechanistically implement these computations is a separate question that we do not address in this lesson.

In this tutorial, you will:

  • Use the Hidden Markov Models you learned about previously to model the world state.

  • Use the observations (fish caught) to build beliefs (posterior distributions) about the fish location.

  • Evaluate the quality of different control policies for choosing actions.

  • Discover the policy that maximizes utility (a rough preview of these last two steps is sketched right after this list).
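
To make those last two objectives concrete, here is a rough sketch (not the tutorial's implementation) of scoring a simple threshold policy by simulation: keep fishing at your current side while your belief that the fish are there stays above a threshold, and switch sides otherwise. It reuses the hypothetical update_belief and parameter names from the sketch above; the switching cost and catch probabilities are illustrative.

# A rough sketch of evaluating a threshold policy by simulation; all
# parameters and the helper update_belief (defined above) are illustrative.
def run_threshold_policy(threshold, T=1000, cost_switch=2.0,
                         p_switch=0.1, p_high=0.4, p_low=0.1, seed=0):
  rng = np.random.default_rng(seed)
  fish_right = rng.random() < 0.5   # latent world state: which side the fish are on
  agent_right = True                # which side we are fishing from
  belief_right = 0.5                # p(fish on right), updated every step
  total_fish, total_cost = 0, 0.0

  for _ in range(T):
    # World dynamics: the fish occasionally switch sides.
    if rng.random() < p_switch:
      fish_right = not fish_right
    # Observation: catching is more likely when fishing at the fish's side.
    p_catch = p_high if agent_right == fish_right else p_low
    caught = rng.random() < p_catch
    total_fish += caught
    # Update the belief, then act: switch sides if the belief that the
    # fish are at our current side has fallen below the threshold.
    belief_right = update_belief(belief_right, caught, agent_right,
                                 p_switch, p_high, p_low)
    belief_here = belief_right if agent_right else 1 - belief_right
    if belief_here < threshold:
      agent_right = not agent_right
      total_cost += cost_switch

  return total_fish - total_cost  # utility = fish caught minus switching costs

# Compare a few thresholds; a higher score means a better policy here.
for th in (0.2, 0.4, 0.6):
  print(th, run_threshold_policy(th))

Sweeping the threshold (or the switching cost) like this is the kind of policy comparison the tutorial carries out more carefully below.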


Setup#

Install and import feedback gadget#

# @title Install and import feedback gadget

!pip3 install vibecheck datatops --quiet

from vibecheck import DatatopsContentReviewContainer
def content_review(notebook_section: str):
    return DatatopsContentReviewContainer(
        "",  # No text prompt
        notebook_section,
        {
            "url": "https://pmyvdlilci.execute-api.us-east-1.amazonaws.com/klab",
            "name": "neuromatch_cn",
            "user_key": "y1x3mpx5",
        },
    ).render()


feedback_prefix = "W3D3_T1"
# Imports
import numpy as np
from math import isclose
import matplotlib.pyplot as plt

Figure Settings#

# @title Figure Settings
import logging
logging.getLogger('matplotlib.font_manager').disabled = True

import ipywidgets as widgets
from IPython.display import HTML
%config InlineBackend.figure_format = 'retina'
plt.style.use("https://raw.githubusercontent.com/NeuromatchAcademy/course-content/main/nma.mplstyle")

Plotting Functions#

# @title Plotting Functions

def plot_fish(fish_state, ax=None, show=True):
  """
  Plot the fish dynamics (states across time)
  """
  T = len(fish_state)

  offset = 3

  if not ax:
    fig, ax = plt.subplots(1, 1, figsize=(12, 3.5))

  x = np.arange(0, T, 1)
  y = offset * (fish_state*2 - 1)

  ax.plot(y, color='cornflowerblue', markersize=10, linewidth=3.0, zorder=0)
  ax.fill_between(x, y, color='cornflowerblue', alpha=.3)

  ax.set_xlabel('time')
  ax.set_ylabel('fish location')

  ax.set_xlim([0, T])
  ax.set_xticks([])
  ax.xaxis.set_label_coords(1.05, .54)

  ax.set_ylim([-(offset+.5), offset+.5])
  ax.set_yticks([-offset, offset])
  ax.set_yticklabels(['left', 'right'])

  ax.spines['bottom'].set_position('center')
  if show:
    plt.show()


def plot_measurement(measurement, ax=None, show=True):
  """
  Plot the measurements
  """
  T = len(measurement)

  rel_pos = 3
  red_y = []
  blue_y = []
  for idx, value in enumerate(measurement):
    if value == 0:
      blue_y.append([idx, -rel_pos])
    else:
      red_y.append([idx, rel_pos])

  red_y = np.asarray(red_y)