Bonus Tutorial 4: The Kalman Filter, part 2
Week 3, Day 2: Hidden Dynamics
By Neuromatch Academy
Content creators: Caroline Haimerl and Byron Galbraith
Content reviewers: Jesse Livezey, Matt Krause, Michael Waskom, and Xaq Pitkow
Post-production team: Gagana B, Spiros Chavlis
Important note: This is bonus material, included from NMA 2020. It has not been substantially revised for 2021. This means that the notation and standards are slightly different. We include it here because it provides additional information about how the Kalman filter works in two dimensions.
Useful references:
Roweis & Ghahramani (1998): A unifying review of linear Gaussian models
Bishop (2006): Pattern Recognition and Machine Learning
Acknowledgements:
This tutorial is in part based on code originally created by Caroline Haimerl for Dr. Cristina Savin’s Probabilistic Time Series class at the Center for Data Science, New York University
Video 1: Introduction
Video available at https://youtu.be/6f_51L3i5aQ
Tutorial Objectives
In the previous tutorial we gained intuition for the Kalman filter in one dimension. In this tutorial, we will examine the two-dimensional Kalman filter and more of its mathematical foundations.
In this tutorial, you will:
Review linear dynamical systems
Implement the Kalman filter
Explore how the Kalman filter can be used to smooth data from an eye-tracking experiment
import sys
!conda install -c conda-forge ipywidgets --yes
!conda install numpy matplotlib scipy requests --yes
# Install PyKalman (https://pykalman.github.io/)
!pip install pykalman --quiet
# Imports
import numpy as np
import matplotlib.pyplot as plt
import pykalman
from scipy import stats
Figure settings
#@title Figure settings
import ipywidgets as widgets # interactive display
%config InlineBackend.figure_format = 'retina'
plt.style.use("https://raw.githubusercontent.com/NeuromatchAcademy/course-content/main/nma.mplstyle")
Data retrieval and loading
#@title Data retrieval and loading
import io
import os
import hashlib
import requests
fname = "W2D3_mit_eyetracking_2009.npz"
url = "https://osf.io/jfk8w/download"
expected_md5 = "20c7bc4a6f61f49450997e381cf5e0dd"
# Download the data once and verify its checksum before saving it
if not os.path.isfile(fname):
    try:
        r = requests.get(url)
    except requests.ConnectionError:
        print("!!! Failed to download data !!!")
    else:
        if r.status_code != requests.codes.ok:
            print("!!! Failed to download data !!!")
        elif hashlib.md5(r.content).hexdigest() != expected_md5:
            print("!!! Data download appears corrupted !!!")
        else:
            with open(fname, "wb") as fid:
                fid.write(r.content)
def load_eyetracking_data(data_fname=fname):
    # Stimuli are stored as JPEG-encoded bytes; decode them into image arrays
    with np.load(data_fname, allow_pickle=True) as dobj:
        data = dict(**dobj)
    images = [plt.imread(io.BytesIO(stim), format='JPG')
              for stim in data['stimuli']]
    subjects = data['subjects']
    return subjects, images
Helper functions
#@title Helper functions
np.set_printoptions(precision=3)
def plot_kalman(state, observation, estimate=None, label='filter', color='r',
                title='LDS', axes=None):
    # color must be a plain color spec ('r'), not a format string ('r-'),
    # because it is passed to matplotlib's color= keyword below
    if axes is None:
        fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(16, 6))
        ax1.plot(state[:, 0], state[:, 1], 'g-', label='true latent')
        ax1.plot(observation[:, 0], observation[:, 1], 'k.', label='data')
    else:
        ax1, ax2 = axes
    # left panel: trajectories in the 2-D plane
    if estimate is not None:
        ax1.plot(estimate[:, 0], estimate[:, 1], color=color, label=label)
    ax1.set(title=title, xlabel='X position', ylabel='Y position')
    ax1.legend()
    # right panel: latent vs. measured (or estimated) values per dimension
    if estimate is None:
        ax2.plot(state[:, 0], observation[:, 0], '.k', label='dim 1')
        ax2.plot(state[:, 1], observation[:, 1], '.', color='grey', label='dim 2')
        ax2.set(title='correlation', xlabel='latent', ylabel='measured')
    else:
        ax2.plot(state[:, 0], estimate[:, 0], '.', color=color,
                 label='latent dim 1')
        ax2.plot(state[:, 1], estimate[:, 1], 'x', color=color,
                 label='latent dim 2')
        ax2.set(title='correlation',
                xlabel='real latent',
                ylabel='estimated latent')
    ax2.legend()
    return ax1, ax2
def plot_gaze_data(data, img=None, ax=None):
    # overlay gaze on stimulus
    if ax is None:
        fig, ax = plt.subplots(figsize=(8, 6))
    xlim = None
    ylim = None
    if img is not None:
        ax.imshow(img, aspect='auto')
        ylim = (img.shape[0], 0)
        xlim = (0, img.shape[1])
    ax.scatter(data[:, 0], data[:, 1], c='m', s=100, alpha=0.7)
    ax.set(xlim=xlim, ylim=ylim)
    return ax
def plot_kf_state(kf, data, ax):
    # initialize the observed dimensions of the state mean at the first data point
    mu_0 = np.ones(kf.n_dim_state)
    mu_0[:data.shape[1]] = data[0]
    kf.initial_state_mean = mu_0
    mu, sigma = kf.smooth(data)
    ax.plot(mu[:, 0], mu[:, 1], 'limegreen', linewidth=3, zorder=1)
    # mark the start (triangle) and end (square) of the smoothed trajectory
    ax.scatter(mu[0, 0], mu[0, 1], c='orange', marker='>', s=200, zorder=2)
    ax.scatter(mu[-1, 0], mu[-1, 1], c='orange', marker='s', s=200, zorder=2)
Section 1: Linear Dynamical System (LDS)
Video 2: Linear Dynamical Systems
Video available at https://youtu.be/2SWh639YgEg
Kalman filter definitions:
The latent state \(s_t\) evolves as a stochastic linear dynamical system in discrete time, with a dynamics matrix \(D\):
\[ s_t = D s_{t-1} + w_t \]
Just as in the HMM, the structure is a Markov chain where the state at time point \(t\) is conditionally independent of previous states given the state at time point \(t-1\).
Sensory measurements \(m_t\) (observations) are noisy linear projections of the latent state through an observation matrix \(H\):
\[ m_t = H s_t + \eta_t \]
Both states and measurements have Gaussian variability, often called noise: ‘process noise’ \(w_t\) for the states, and ‘measurement’ or ‘observation noise’ \(\eta_t\) for the measurements. The initial state is also Gaussian distributed. These quantities have means and covariances:
\[ w_t \sim \mathcal{N}(0, Q), \qquad \eta_t \sim \mathcal{N}(0, R), \qquad s_0 \sim \mathcal{N}(\mu_0, \Sigma_0) \]
As a consequence, \(s_t\), \(m_t\) and their joint distributions are Gaussian. This makes all of the math analytically tractable using linear algebra, so we can easily compute the marginal and conditional distributions we will use for inferring the current state given the entire history of measurements.
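Because every step is linear and Gaussian, both distributions mentioned here can be computed in closed form. The marginal of \(s_t\) follows from pushing the mean and covariance of \(s_{t-1}\) through \(s_t = D s_{t-1} + w_t\). The conditional of \(s_t\) given a measurement \(m_t = H s_t + \eta_t\) is standard Gaussian conditioning. The NumPy sketch below shows both computations. It assumes this tutorial's parameter names \(D, Q, H, R\). The helpers `propagate_moments` and `condition_on_measurement` are hypothetical illustrations, not functions from the tutorial or from pykalman.

```python
import numpy as np


def propagate_moments(mu, sigma, D, Q):
    """Marginal N(mu_next, sigma_next) of s_t, given s_{t-1} ~ N(mu, sigma).

    With s_t = D s_{t-1} + w_t and w_t ~ N(0, Q) independent of s_{t-1},
    the mean maps linearly and the two covariances add.
    """
    mu_next = D @ mu
    sigma_next = D @ sigma @ D.T + Q
    return mu_next, sigma_next


def condition_on_measurement(mu, sigma, m, H, R):
    """Posterior of s given prior N(mu, sigma) and m = H s + eta, eta ~ N(0, R)."""
    S = H @ sigma @ H.T + R              # covariance of the predicted measurement
    # gain K = sigma H^T S^{-1}; solve() avoids an explicit inverse and is
    # valid here because sigma and S are symmetric
    K = np.linalg.solve(S, H @ sigma).T
    mu_post = mu + K @ (m - H @ mu)
    sigma_post = sigma - K @ H @ sigma
    return mu_post, sigma_post


# Example with the 2-D system used in this tutorial (D = 0.9 I, Q = H = R = I)
D = 0.9 * np.eye(2)
Q = np.eye(2)
H = np.eye(2)
R = np.eye(2)
mu, sigma = np.zeros(2), 0.1 * np.eye(2)

mu1, sigma1 = propagate_moments(mu, sigma, D, Q)   # sigma1 equals 1.081 * I
mu_post, sigma_post = condition_on_measurement(mu1, sigma1,
                                               np.array([1.0, -1.0]), H, R)
```

In this isotropic example the posterior mean moves about halfway from the prior mean toward the measurement. That happens because the prior variance (1.081) and the measurement noise variance (1.0) are nearly equal.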
Please note: we are trying to create uniform notation across tutorials. In some videos created in 2020, measurements \(m_t\) were denoted \(y_t\), and the Dynamics matrix \(D\) was denoted \(F\). We apologize for any confusion!
Section 1.1: Sampling from a latent linear dynamical system¶
The first thing we will investigate is how to generate timecourse samples from a linear dynamical system given its parameters. We will start by defining the following system:
# task dimensions
n_dim_state = 2
n_dim_obs = 2
# initialize model parameters
params = {
    'D': 0.9 * np.eye(n_dim_state),  # state transition matrix
    'Q': np.eye(n_dim_state),  # state noise covariance
    'H': np.eye(n_dim_obs, n_dim_state),  # observation matrix
    'R': 1.0 * np.eye(n_dim_obs),  # observation noise covariance
    'mu_0': np.zeros(n_dim_state),  # initial state mean
    'sigma_0': 0.1 * np.eye(n_dim_state),  # initial state noise covariance
}
Coding note: We used a parameter dictionary params above. As the number of parameters we need to provide to our functions increases, it can be beneficial to condense them into a data structure like this to reduce the number of inputs we pass in. The trade-off is that we have to know what is in our data structure to use those values, rather than reading them directly from the function signature.
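If that trade-off becomes a problem, one alternative is a typed container that documents the expected fields and checks their shapes. The sketch below uses a standard-library dataclass. The class LDSParams and its as_dict method are hypothetical and are not used anywhere else in this tutorial.

```python
from dataclasses import dataclass
import numpy as np

@dataclass(frozen=True)
class LDSParams:
    """Hypothetical typed container for LDS parameters (not part of the tutorial)."""
    D: np.ndarray        # state transition matrix, (n_dim_state, n_dim_state)
    Q: np.ndarray        # state noise covariance, (n_dim_state, n_dim_state)
    H: np.ndarray        # observation matrix, (n_dim_obs, n_dim_state)
    R: np.ndarray        # observation noise covariance, (n_dim_obs, n_dim_obs)
    mu_0: np.ndarray     # initial state mean, (n_dim_state,)
    sigma_0: np.ndarray  # initial state covariance, (n_dim_state, n_dim_state)

    def __post_init__(self):
        # Check every shape against D and H so mistakes surface immediately
        n_s, n_o = self.D.shape[0], self.H.shape[0]
        expected = {'D': (n_s, n_s), 'Q': (n_s, n_s), 'H': (n_o, n_s),
                    'R': (n_o, n_o), 'mu_0': (n_s,), 'sigma_0': (n_s, n_s)}
        for name, shape in expected.items():
            actual = getattr(self, name).shape
            if actual != shape:
                raise ValueError(f"{name} has shape {actual}, expected {shape}")

    def as_dict(self):
        """Return the plain dict expected by functions such as sample_lds."""
        return {name: getattr(self, name)
                for name in ('D', 'Q', 'H', 'R', 'mu_0', 'sigma_0')}

p = LDSParams(D=0.9 * np.eye(2), Q=np.eye(2), H=np.eye(2), R=np.eye(2),
              mu_0=np.zeros(2), sigma_0=0.1 * np.eye(2))
params = p.as_dict()
```

The field list now lives in one documented place. as_dict keeps the dictionary interface the tutorial's functions expect.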
Exercise 1: Sampling from a linear dynamical system¶
In this exercise you will implement the dynamics functions of a linear dynamical system to sample both a latent space trajectory (given parameters set above) and noisy measurements.
def sample_lds(n_timesteps, params, seed=0):
    """ Generate samples from a Linear Dynamical System specified by the provided
    parameters.
    Args:
      n_timesteps (int): the number of time steps to simulate
      params (dict): a dictionary of model parameters: (D, Q, H, R, mu_0, sigma_0)
      seed (int): a random seed to use for reproducibility checks
    Returns:
      ndarray, ndarray: the generated state and observation data
    """
    n_dim_state = params['D'].shape[0]
    n_dim_obs = params['H'].shape[0]
    # set seed
    np.random.seed(seed)
    # precompute random samples from the provided covariance matrices
    # mean defaults to 0
    mi = stats.multivariate_normal(cov=params['Q']).rvs(n_timesteps)
    eta = stats.multivariate_normal(cov=params['R']).rvs(n_timesteps)
    # initialize state and observation arrays
    state = np.zeros((n_timesteps, n_dim_state))
    obs = np.zeros((n_timesteps, n_dim_obs))
    ###################################################################
    ## TODO for students: compute the next state and observation values
    # Fill out function and remove
    raise NotImplementedError("Student exercise: compute the next state and observation values")
    ###################################################################
    # simulate the system
    for t in range(n_timesteps):
        # write the expressions for computing state values given the time step
        if t == 0:
            state[t] = ...
        else:
            state[t] = ...
        # write the expression for computing the observation
        obs[t] = ...
    return state, obs
# Uncomment below to test your function
# state, obs = sample_lds(100, params)
# print('sample at t=3 ', state[3])
# plot_kalman(state, obs, title='sample')
Example output:

Interactive Demo: Adjusting System Dynamics¶
To test your understanding of the parameters of a linear dynamical system, think about what you would expect if you made the following changes:
Reduce the observation noise \(R\)
Increase the temporal dynamics \(D\) (the diagonal entries of \(D\))
Use the interactive widget below to vary the values of \(R\) and \(D\).
#@title
#@markdown Make sure you execute this cell to enable the widget!
@widgets.interact(R=widgets.FloatLogSlider(1., min=-2, max=2),
                  D=widgets.FloatSlider(0.9, min=0.0, max=1.0, step=.01))
def explore_dynamics(R=1., D=0.9):
    params = {
        'D': D * np.eye(n_dim_state),  # state transition matrix
        'Q': np.eye(n_dim_state),  # state noise covariance
        'H': np.eye(n_dim_obs, n_dim_state),  # observation matrix
        'R': R * np.eye(n_dim_obs),  # observation noise covariance
        'mu_0': np.zeros(n_dim_state),  # initial state mean
        'sigma_0': 0.1 * np.eye(n_dim_state),  # initial state noise covariance
    }
    state, obs = sample_lds(100, params)
    plot_kalman(state, obs, title='sample')
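The effect of \(D\) can be made concrete. For \(D = dI\) with \(|d| < 1\), the state settles into a stationary distribution whose covariance solves \(\Sigma = D\Sigma D^T + Q\), which here gives \(\Sigma = Q/(1-d^2)\). The short sketch below uses scipy.linalg.solve_discrete_lyapunov, which solves exactly this equation. The values of \(d\) are illustrative choices.

```python
import numpy as np
from scipy.linalg import solve_discrete_lyapunov

Q = np.eye(2)  # state noise covariance, as in the widget
stationary_var = {}
for d in [0.0, 0.5, 0.9, 0.99]:
    D = d * np.eye(2)
    # Solves D @ Sigma @ D.T - Sigma + Q = 0 for the stationary covariance
    sigma_inf = solve_discrete_lyapunov(D, Q)
    stationary_var[d] = sigma_inf[0, 0]
    print(f"D = {d:.2f} * I  ->  stationary state variance {sigma_inf[0, 0]:.2f}")
```

Larger \(d\) produces larger and slower state fluctuations, because the lag-one autocorrelation of each coordinate is \(d\). At \(d = 1\) there is no stationary distribution. \(R\) does not enter this calculation: it only controls how noisy the measurements are around the state.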
Section 2: Kalman Filtering¶
Video 3: Kalman Filtering¶
Video available at https://youtu.be/VboZOV9QMOI
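We want to infer the latent state variable \(s_t\) given the measured (observed) variable \(m_t\).
First we obtain estimates of the latent state by running the filter forward in time, for \(t = 0, \dots, T\):
\[s_t^{\rm pred} \sim \mathcal{N}(\hat{\mu}_t^{\rm pred}, \hat{\Sigma}_t^{\rm pred})\]
where \(\hat{\mu}_t^{\rm pred}\) and \(\hat{\Sigma}_t^{\rm pred}\) are derived as follows (at \(t=0\) the prediction is simply the prior, \(\hat{\mu}_0^{\rm pred} = \mu_0\) and \(\hat{\Sigma}_0^{\rm pred} = \Sigma_0\)):
\[\hat{\mu}_t^{\rm pred} = D\hat{\mu}_{t-1}^{\rm filter}\]
This is the prediction for \(s_t\), obtained simply by taking the filtered estimate of \(s_{t-1}\) and projecting it forward one step using the transition matrix \(D\). We do the same for the covariance, taking into account the noise covariance \(Q\) and the fact that scaling a variable by \(D\) scales its covariance \(\Sigma\) as \(D\Sigma D^T\):
\[\hat{\Sigma}_t^{\rm pred} = D\hat{\Sigma}_{t-1}^{\rm filter}D^T + Q\]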
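The covariance rule \(D\Sigma D^T + Q\) can be checked numerically by pushing samples through one step of the dynamics. The minimal Monte Carlo sketch below uses the oscillatory \(D\) defined later in this section. The covariance assumed for \(s_{t-1}\) is an arbitrary, made-up value.

```python
import numpy as np

rng = np.random.default_rng(0)
D = np.array([[1., 1.], [-(2 * np.pi / 20.) ** 2, .9]])
Q = np.eye(2)
sigma = np.array([[2.0, 0.5], [0.5, 1.0]])   # arbitrary covariance of s_{t-1}

n = 200_000
s_prev = rng.multivariate_normal(np.zeros(2), sigma, size=n)
w = rng.multivariate_normal(np.zeros(2), Q, size=n)
s_next = s_prev @ D.T + w     # row-wise version of s_t = D s_{t-1} + w_t

empirical = np.cov(s_next, rowvar=False)
analytic = D @ sigma @ D.T + Q
print(np.round(empirical, 2))
print(np.round(analytic, 2))
```

The two matrices should agree closely. The remaining differences shrink as the number of samples grows.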
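We then use a Bayesian update from the newest measurement to obtain \(\hat{\mu}_t^{\rm filter}\) and \(\hat{\Sigma}_t^{\rm filter}\).
Project our prediction to observational space:
\[m_t^{\rm pred} \sim \mathcal{N}(H\hat{\mu}_t^{\rm pred}, H\hat{\Sigma}_t^{\rm pred}H^T + R)\]
Update the prediction using the actual data:
\[\hat{\mu}_t^{\rm filter} = \hat{\mu}_t^{\rm pred} + K_t(m_t - H\hat{\mu}_t^{\rm pred})\]
\[\hat{\Sigma}_t^{\rm filter} = (I - K_tH)\hat{\Sigma}_t^{\rm pred}\]
Kalman gain matrix:
\[K_t = \hat{\Sigma}_t^{\rm pred}H^T(H\hat{\Sigma}_t^{\rm pred}H^T + R)^{-1}\]
We take the latent-only prediction, project it into observation space, and compute a correction proportional to the error \(m_t - HD\hat{\mu}_{t-1}^{\rm filter}\) between prediction and data. The coefficient of this correction is the Kalman gain matrix.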
Interpretations
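If the measurement noise is small relative to the uncertainty of the predicted state (for example, because the process noise is large and the state changes quickly), the Kalman gain is large and the estimate depends mostly on the currently observed data. If the measurement noise is large, the gain is small and the Kalman filter relies more on past observations, combining them as long as the underlying state is at least somewhat predictable.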
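This interpretation is easy to check in the scalar case (\(H = 1\)). There the gain reduces to \(K = \sigma^2_{\rm pred}/(\sigma^2_{\rm pred} + R)\), and the variance recursion converges to a steady-state gain. The sketch below is our own. The values \(d = 0.9\) and \(q = 1\) mirror the first system in Section 1.1 and are illustrative.

```python
def steady_state_gain(d, q, r, n_iter=500):
    """Scalar Kalman filter (H = 1): iterate the variance recursion to convergence."""
    p_filt = q  # arbitrary starting variance; the recursion forgets it
    for _ in range(n_iter):
        p_pred = d * p_filt * d + q   # prediction: D Sigma D^T + Q
        k = p_pred / (p_pred + r)     # gain: Sigma_pred H^T (H Sigma_pred H^T + R)^-1
        p_filt = (1 - k) * p_pred     # update: (I - K H) Sigma_pred
    return k

gains = {}
for r in [0.01, 1.0, 100.0]:
    gains[r] = steady_state_gain(d=0.9, q=1.0, r=r)
    print(f"R = {r:>6}: steady-state Kalman gain K = {gains[r]:.3f}")
```

A gain near 1 means the estimate essentially follows the current measurement. A gain near 0 means it relies mostly on the prediction built from past data.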
In order to explore the impact of filtering, we will use the following noisy oscillatory system:
# task dimensions
n_dim_state = 2
n_dim_obs = 2
T=100
# initialize model parameters
params = {
    'D': np.array([[1., 1.], [-(2*np.pi/20.)**2., .9]]),  # state transition matrix
    'Q': np.eye(n_dim_state),  # state noise covariance
    'H': np.eye(n_dim_obs, n_dim_state),  # observation matrix
    'R': 100.0 * np.eye(n_dim_obs),  # observation noise covariance
    'mu_0': np.zeros(n_dim_state),  # initial state mean
    'sigma_0': 0.1 * np.eye(n_dim_state),  # initial state noise covariance
}
state, obs = sample_lds(T, params)
plot_kalman(state, obs, title='sample')
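One way to see why this \(D\) produces a noisy oscillation, rather than the plain decay of Section 1.1, is to inspect its eigenvalues. A complex-conjugate pair rotates the state by a fixed angle per step. A modulus just below 1 makes that rotation decay slowly. The short sketch below computes both quantities.

```python
import numpy as np

D = np.array([[1., 1.], [-(2 * np.pi / 20.) ** 2, .9]])
eigvals = np.linalg.eigvals(D)
radius = np.abs(eigvals)                # modulus < 1: the oscillation decays
angle = np.abs(np.angle(eigvals[0]))    # rotation per time step, in radians
period = 2 * np.pi / angle              # time steps per cycle
print(eigvals)
print(f"modulus {radius[0]:.4f}, period {period:.1f} steps")
```

The period comes out close to 20 time steps, matching the \(2\pi/20\) term in \(D\). The modulus is only slightly below 1, so the oscillation persists over the 100 simulated steps while the noise keeps re-exciting it.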
Exercise 2: Implement Kalman filtering¶
In this exercise you will implement the Kalman filter (forward) process. Your focus will be on writing the expressions for the Kalman gain, filter mean, and filter covariance at each time step (refer to the equations above).
def kalman_filter(data, params):
    """ Perform Kalman filtering (forward pass) on the data given the provided
    system parameters.
    Args:
      data (ndarray): a sequence of observations of shape(n_timesteps, n_dim_obs)
      params (dict): a dictionary of model parameters: (D, Q, H, R, mu_0, sigma_0)
    Returns:
      ndarray, ndarray: the filtered system means and noise covariance values
    """
    # pulled out of the params dict for convenience
    D = params['D']
    Q = params['Q']
    H = params['H']
    R = params['R']
    n_dim_state = D.shape[0]
    n_dim_obs = H.shape[0]
    I = np.eye(n_dim_state)  # identity matrix
    # state tracking arrays
    mu = np.zeros((len(data), n_dim_state))
    sigma = np.zeros((len(data), n_dim_state, n_dim_state))
    # filter the data
    for t, y in enumerate(data):
        if t == 0:
            mu_pred = params['mu_0']
            sigma_pred = params['sigma_0']
        else:
            mu_pred = D @ mu[t-1]
            sigma_pred = D @ sigma[t-1] @ D.T + Q
        ###########################################################################
        ## TODO for students: compute the filtered state mean and covariance values
        # Fill out function and remove
        raise NotImplementedError("Student exercise: compute the filtered state mean and covariance values")
        ###########################################################################
        # write the expression for computing the Kalman gain
        K = ...
        # write the expression for computing the filtered state mean
        mu[t] = ...
        # write the expression for computing the filtered state noise covariance
        sigma[t] = ...
    return mu, sigma
# Uncomment below to test your function
# filtered_state_means, filtered_state_covariances = kalman_filter(obs, params)
# plot_kalman(state, obs, filtered_state_means, title="my kf-filter",
# color='r', label='my kf-filter')
Example output:

Section 3: Fitting Eye Gaze Data¶
Video 4: Fitting Eye Gaze Data¶
Video available at https://youtu.be/M7OuXmVWHGI
Tracking eye gaze is used in both experimental and user interface applications. Getting an accurate estimation of where someone is looking on a screen in pixel coordinates can be challenging, however, due to the various sources of noise inherent in obtaining these measurements. A main source of noise is the general accuracy of the eye tracker device itself and how well it maintains calibration over time. Changes in ambient light or subject position can further reduce accuracy of the sensor. Eye blinks introduce a different form of noise as interruptions in the data stream which also need to be addressed.
Fortunately, the Kalman filter we just learned about is a candidate solution for handling noisy eye gaze data. Let’s look at how we can apply these methods to a small subset of data taken from the MIT Eyetracking Database [Judd et al. 2009]. This data was collected as part of an effort to model visual saliency: given an image, can we predict where a person is most likely to look?
# load eyetracking data
subjects, images = load_eyetracking_data()
Interactive Demo: Tracking Eye Gaze¶
We have three stimulus images and gaze data from five different subjects. Each subject fixated on the center of the screen before the image appeared, then had a few seconds to look around freely. You can use the widget below to see how different subjects visually scanned the presented image. A subject ID of -1 shows the stimulus images without any overlaid gaze trace.
Note that the images below are rescaled for display purposes; they were shown in their original aspect ratio during the task itself.
#@title
#@markdown Make sure you execute this cell to enable the widget!
@widgets.interact(subject_id=widgets.IntSlider(-1, min=-1, max=4),
                  image_id=widgets.IntSlider(0, min=0, max=2))
def plot_subject_trace(subject_id=-1, image_id=0):
    if subject_id == -1:
        subject = np.zeros((3, 0, 2))
    else:
        subject = subjects[subject_id]
    data = subject[image_id]
    img = images[image_id]
    fig, ax = plt.subplots()
    ax.imshow(img, aspect='auto')
    ax.scatter(data[:, 0], data[:, 1], c='m', s=100, alpha=0.7)
    ax.set(xlim=(0, img.shape[1]), ylim=(img.shape[0], 0))
Section 3.1: Fitting data with pykalman¶
Now that we have data, we’d like to use Kalman filtering to give us a better estimate of the true gaze. Up until this point we’ve known the parameters of our LDS, but here we need to estimate them from data directly. We will use the pykalman package to handle this estimation using the EM algorithm, a useful and influential learning algorithm described briefly in the bonus material.
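Before exploring fitting models with pykalman, it’s worth pointing out some naming conventions used by the library:
\(D\): transition_matrices
\(Q\): transition_covariance
\(H\): observation_matrices
\(R\): observation_covariance
\(\mu_0\): initial_state_mean
\(\Sigma_0\): initial_state_covariance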
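One way to keep the two naming schemes straight is to write the mapping down in code. The sketch below is a hypothetical helper, not part of pykalman or of this tutorial. It renames this tutorial's params dictionary into keyword arguments for the pykalman.KalmanFilter constructor, whose argument names match those used in the pykalman code in this section.

```python
import numpy as np

# Hypothetical mapping: tutorial parameter name -> pykalman.KalmanFilter keyword
PYKALMAN_NAMES = {
    'D': 'transition_matrices',
    'Q': 'transition_covariance',
    'H': 'observation_matrices',
    'R': 'observation_covariance',
    'mu_0': 'initial_state_mean',
    'sigma_0': 'initial_state_covariance',
}

def to_pykalman_kwargs(params):
    """Rename a tutorial-style params dict into pykalman keyword arguments."""
    return {PYKALMAN_NAMES[name]: value for name, value in params.items()}

params = {'D': 0.9 * np.eye(2), 'Q': np.eye(2), 'H': np.eye(2),
          'R': np.eye(2), 'mu_0': np.zeros(2), 'sigma_0': 0.1 * np.eye(2)}
kwargs = to_pykalman_kwargs(params)
# With pykalman installed: kf = pykalman.KalmanFilter(**kwargs)
print(sorted(kwargs))
```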
The first thing we need to do is provide a guess at the dimensionality of the latent state. Let’s start by assuming the dynamics line up directly with the observation data (pixel x,y-coordinates), so that the state dimension is 2.
We also need to decide which parameters we want the EM algorithm to fit. In this case, we will let the EM algorithm discover the dynamics parameters, i.e., the \(D\), \(Q\), \(H\), and \(R\) matrices.
We set up our pykalman KalmanFilter object with these settings using the code below.
# set up our KalmanFilter object and tell it which parameters we want to
# estimate
np.random.seed(1)
n_dim_obs = 2
n_dim_state = 2
kf = pykalman.KalmanFilter(
    n_dim_state=n_dim_state,
    n_dim_obs=n_dim_obs,
    em_vars=['transition_matrices', 'transition_covariance',
             'observation_matrices', 'observation_covariance']
)
Because we know from the reported experimental design that subjects fixated in the center of the screen right before the image appears, we can set the initial starting state estimate \(\mu_0\) as being the center pixel of the stimulus image (the first data point in this sample dataset) with a correspondingly low initial noise covariance \(\Sigma_0\). Once we have everything set, it’s time to fit some data.
# Choose a subject and stimulus image
subject_id = 1
image_id = 2
data = subjects[subject_id][image_id]
# Provide the initial states
kf.initial_state_mean = data[0]
kf.initial_state_covariance = 0.1*np.eye(n_dim_state)
# Estimate the parameters from data using the EM algorithm
kf.em(data)
print(f'D=\n{kf.transition_matrices}')
print(f'Q =\n{kf.transition_covariance}')
print(f'H =\n{kf.observation_matrices}')
print(f'R =\n{kf.observation_covariance}')
D=
[[ 1.004 -0.01 ]
[ 0.005 0.989]]
Q =
[[278.016 219.292]
[219.292 389.774]]
H =
[[ 0.999 0.003]
[-0.004 1.01 ]]
R =
[[26.026 19.596]
[19.596 26.745]]
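We see that the EM algorithm has found fits for the various dynamics parameters. Note that both the transition matrix \(D\) and the observation matrix \(H\) are close to the identity matrix. Each coordinate therefore mostly carries over from one time step to the next, and there is little coupling between x and y in the dynamics themselves. The relationship between the x- and y-coordinates enters mainly through the noise covariances \(Q\) and \(R\), whose large off-diagonal entries indicate positively correlated noise.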
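These observations can be quantified from the printed values. The eigenvalues of \(D\) show how close the dynamics are to a random walk. The correlation coefficient \(\Sigma_{xy}/\sqrt{\Sigma_{xx}\Sigma_{yy}}\) of each covariance shows how strongly the x and y noise are coupled. The short sketch below uses numbers copied from the rounded EM output above.

```python
import numpy as np

# Values copied from the (rounded) EM output above
D = np.array([[1.004, -0.01], [0.005, 0.989]])
Q = np.array([[278.016, 219.292], [219.292, 389.774]])
R = np.array([[26.026, 19.596], [19.596, 26.745]])

def correlation(cov):
    """Correlation between x and y implied by a 2x2 covariance matrix."""
    return cov[0, 1] / np.sqrt(cov[0, 0] * cov[1, 1])

print(np.abs(np.linalg.eigvals(D)))   # both close to 1: near random-walk dynamics
print(f"state noise correlation:       {correlation(Q):.2f}")
print(f"observation noise correlation: {correlation(R):.2f}")
```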
We can now use this model to smooth the observed data from the subject. Beyond the image used for fitting, we can also check how the model works on the gaze recorded from the same subject on the other images, or even from different subjects.
Below are the three stimulus images overlaid with the recorded gaze in magenta and the smoothed state estimate in green, with gaze-start (orange triangle) and gaze-end (orange square) markers.
#@title
#@markdown Make sure you execute this cell to enable the widget!
@widgets.interact(subject_id=widgets.IntSlider(1, min=0, max=4))
def plot_smoothed_traces(subject_id=0):
    subject = subjects[subject_id]
    fig, axes = plt.subplots(ncols=3, figsize=(18, 4))
    for data, img, ax in zip(subject, images, axes):
        ax = plot_gaze_data(data, img=img, ax=ax)
        plot_kf_state(kf, data, ax)
Discussion questions:¶
Why do you think one trace from one subject was sufficient to provide a decent fit across all subjects? If you were to go back and change the subject_id and/or image_id for when we fit the data using EM, do you think the fits would be different?
We don’t think the eye is exactly following a linear dynamical system. Nonetheless, that is what we assumed for this exercise when we applied a Kalman filter. Despite the mismatch, these algorithms do perform well. Discuss what differences we might find between the true and assumed processes. What mistakes might be likely consequences of these differences?
Finally, recall that the original task was to use this data to help develop models of visual salience. While our Kalman filter is able to provide smooth estimates of observed gaze data, it’s not telling us anything about why the gaze is going in a certain direction. In fact, if we sample data from our parameters and plot them, we get what amounts to a random walk.
kf_state, kf_data = kf.sample(len(data))
ax = plot_gaze_data(kf_data, img=images[2])
plot_kf_state(kf, kf_data, ax)

This should not be surprising, as we have given the model no observed data beyond the pixels at which gaze was detected. We expect that something other than the previous fixation location drives where the subject looks next.
In summary, while the Kalman filter is a good option for smoothing the gaze trajectory itself, especially if using a lower-quality eye tracker or in noisy environmental conditions, a linear dynamical system may not be the right way to approach the much more challenging task of modeling visual saliency.
Handling Eye Blinks¶
In the MIT Eyetracking Database, raw tracking data includes times when the subject blinked. The way this is represented in the data stream is via negative pixel coordinate values.
We could try to handle these samples by simply deleting them from the stream, though this introduces other issues. For instance, if each sample corresponds to a fixed time step and you arbitrarily remove some samples, the remaining samples no longer share a consistent time step. Especially with time series data, it is often better to flag data as missing than to pretend it was never there.
Another solution is to use masked arrays. In numpy, a masked array is an ndarray with an additional embedded boolean mask array that indicates which elements should be masked. When computations are performed on the array, the masked elements are ignored. Both matplotlib and pykalman work with masked arrays, and, in fact, this is the approach taken with the data we explore in this notebook.
In preparing the dataset for this notebook, the original dataset was preprocessed to store all gaze data as masked arrays, with the mask enabled for any gaze sample with a negative x or y coordinate.
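This masking scheme can be sketched with numpy.ma. The gaze values below are made up for illustration. A sample is masked in both coordinates if either coordinate is negative. Because masked samples stay in place, the time axis, and with it the fixed time step, is preserved.

```python
import numpy as np

# Made-up gaze samples (x, y) in pixels; negative values mark a blink
gaze = np.array([[512., 384.], [515., 380.], [-1., -1.], [-1., 390.], [520., 377.]])

# Mask the whole sample if either coordinate is negative
blink = (gaze < 0).any(axis=1)
masked_gaze = np.ma.masked_array(gaze, mask=np.repeat(blink[:, None], 2, axis=1))

print(masked_gaze)
print(masked_gaze.mean(axis=0))   # blink samples are ignored
print(len(masked_gaze))           # all 5 time steps are kept
```

An array masked like this can be passed to pykalman in place of a plain array. This is how the notebook's data is handled.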
Bonus¶
Review on Gaussian joint, marginal and conditional distributions¶
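Assume
\[\begin{bmatrix} a \\ b \end{bmatrix} \sim \mathcal{N}\left(\begin{bmatrix} \mu_a \\ \mu_b \end{bmatrix}, \begin{bmatrix} \Sigma_{aa} & \Sigma_{ab} \\ \Sigma_{ba} & \Sigma_{bb} \end{bmatrix}\right)\]
then the marginal distributions are
\[a \sim \mathcal{N}(\mu_a, \Sigma_{aa}), \qquad b \sim \mathcal{N}(\mu_b, \Sigma_{bb})\]
and the conditional distributions are
\[a \mid b \sim \mathcal{N}\left(\mu_a + \Sigma_{ab}\Sigma_{bb}^{-1}(b - \mu_b),\; \Sigma_{aa} - \Sigma_{ab}\Sigma_{bb}^{-1}\Sigma_{ba}\right)\]
\[b \mid a \sim \mathcal{N}\left(\mu_b + \Sigma_{ba}\Sigma_{aa}^{-1}(a - \mu_a),\; \Sigma_{bb} - \Sigma_{ba}\Sigma_{aa}^{-1}\Sigma_{ab}\right)\]
Important takeaway: given the joint Gaussian distribution, we can derive the marginals and conditionals in closed form.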
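The conditional formulas can be checked numerically with scalar blocks \(a\) and \(b\), using values made up for illustration. The conditional variance of \(a\) given \(b\) should also equal the inverse of the \(aa\) entry of the precision matrix \(\Sigma^{-1}\), which gives an independent cross-check.

```python
import numpy as np

# Made-up joint Gaussian over scalar a and b
mu_a, mu_b = 1.0, -2.0
sigma = np.array([[2.0, 0.8],
                  [0.8, 1.0]])
s_aa, s_ab, s_ba, s_bb = sigma[0, 0], sigma[0, 1], sigma[1, 0], sigma[1, 1]

b_obs = 0.5   # observed value of b
cond_mean = mu_a + s_ab / s_bb * (b_obs - mu_b)   # mu_a + S_ab S_bb^-1 (b - mu_b)
cond_var = s_aa - s_ab / s_bb * s_ba              # S_aa - S_ab S_bb^-1 S_ba

# Cross-check: conditional variance = 1 / (precision matrix)_aa
precision = np.linalg.inv(sigma)
print(cond_mean, cond_var, 1.0 / precision[0, 0])
```

Here the conditional mean is \(1 + 0.8 \times 2.5 = 3\) and the conditional variance is \(2 - 0.64 = 1.36\). The precision-matrix route gives the same variance.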
Kalman Smoothing¶
Video 5: Kalman Smoothing and the EM Algorithm¶
Video available at https://youtu.be/4Ar2mYz1Nms
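Obtain estimates by propagating backward from \(t=T\) to \(t=0\), using the results of the forward pass (\(\hat{\mu}_t^{\rm filter}, \hat{\Sigma}_t^{\rm filter}, P_t=\hat{\Sigma}_{t+1}^{\rm pred}\)):
\[s_t \sim \mathcal{N}(\hat{\mu}_t^{\rm smooth}, \hat{\Sigma}_t^{\rm smooth})\]
\[\hat{\mu}_t^{\rm smooth} = \hat{\mu}_t^{\rm filter} + J_t(\hat{\mu}_{t+1}^{\rm smooth} - D\hat{\mu}_t^{\rm filter})\]
\[\hat{\Sigma}_t^{\rm smooth} = \hat{\Sigma}_t^{\rm filter} + J_t(\hat{\Sigma}_{t+1}^{\rm smooth} - P_t)J_t^T\]
\[J_t = \hat{\Sigma}_t^{\rm filter}D^T P_t^{-1}\]
This gives us the final estimate for \(s_t\). The recursion starts from \(\hat{\mu}_T^{\rm smooth} = \hat{\mu}_T^{\rm filter}\) and \(\hat{\Sigma}_T^{\rm smooth} = \hat{\Sigma}_T^{\rm filter}\).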
Exercise 3: Implement Kalman smoothing¶
In this exercise you will implement the Kalman smoothing (backward) process. Again you will focus on writing the expressions for computing the smoothed mean, smoothed covariance, and \(J_t\) values.
def kalman_smooth(data, params):
    """ Perform Kalman smoothing (backward pass) on the data given the provided
    system parameters.
    Args:
      data (ndarray): a sequence of observations of shape(n_timesteps, n_dim_obs)
      params (dict): a dictionary of model parameters: (D, Q, H, R, mu_0, sigma_0)
    Returns:
      ndarray, ndarray: the smoothed system means and noise covariance values
    """
    # pulled out of the params dict for convenience
    D = params['D']
    Q = params['Q']
    H = params['H']
    R = params['R']
    n_dim_state = D.shape[0]
    n_dim_obs = H.shape[0]
    # first run the forward pass to get the filtered means and covariances
    mu, sigma = kalman_filter(data, params)
    # initialize state mean and covariance estimates
    mu_hat = np.zeros_like(mu)
    sigma_hat = np.zeros_like(sigma)
    mu_hat[-1] = mu[-1]
    sigma_hat[-1] = sigma[-1]
    # smooth the data
    for t in reversed(range(len(data)-1)):
        sigma_pred = D @ sigma[t] @ D.T + Q  # sigma_pred at t+1
        ###########################################################################
        ## TODO for students: compute the smoothed state mean and covariance values
        # Fill out function and remove
        raise NotImplementedError("Student exercise: compute the smoothed state mean and covariance values")
        ###########################################################################
        # write the expression to compute the Kalman gain for the backward process
        J = ...
        # write the expression to compute the smoothed state mean estimate
        mu_hat[t] = ...
        # write the expression to compute the smoothed state noise covariance estimate
        sigma_hat[t] = ...
    return mu_hat, sigma_hat
# Uncomment once the kalman_smooth function is complete
# smoothed_state_means, smoothed_state_covariances = kalman_smooth(obs, params)
# axes = plot_kalman(state, obs, filtered_state_means, color="r",
# label="my kf-filter")
# plot_kalman(state, obs, smoothed_state_means, color="b",
# label="my kf-smoothed", axes=axes)
Example output:

Forward vs Backward
Now that we have implementations for both, let’s compare their performance by computing the MSE between the filtered (forward) and smoothed (backward) estimated states and the true latent state.
print(f"Filtered MSE: {np.mean((state - filtered_state_means)**2):.3f}")
print(f"Smoothed MSE: {np.mean((state - smoothed_state_means)**2):.3f}")
In this example, the smoothed estimate is clearly better than the filtered one. This makes sense: the forward pass uses only past measurements, whereas the backward pass also uses future measurements, correcting the forward-pass estimates given all the data we’ve collected.
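The advantage of smoothing is visible even without data. In a linear Gaussian model, the posterior variances depend only on the parameters, not on the measured values. The minimal scalar sketch below is our own. It uses \(H = 1\) and the illustrative values \(d = 0.9\), \(q = 1\), \(r = 100\), \(p_0 = 0.1\), and applies the scalar forms of the filtering and smoothing recursions to compute both sets of variances.

```python
import numpy as np

def filter_and_smooth_variances(d, q, r, p0, n_timesteps):
    """Scalar LDS (H = 1): posterior variances from the forward and backward passes."""
    p_pred = np.zeros(n_timesteps)
    p_filt = np.zeros(n_timesteps)
    for t in range(n_timesteps):
        p_pred[t] = p0 if t == 0 else d * p_filt[t - 1] * d + q   # predict
        k = p_pred[t] / (p_pred[t] + r)                           # Kalman gain
        p_filt[t] = (1 - k) * p_pred[t]                           # filter update
    p_smooth = p_filt.copy()                                      # p_smooth[T] = p_filt[T]
    for t in reversed(range(n_timesteps - 1)):
        j = p_filt[t] * d / p_pred[t + 1]                         # J_t = Sigma_filt D^T P_t^-1
        p_smooth[t] = p_filt[t] + j * (p_smooth[t + 1] - p_pred[t + 1]) * j
    return p_filt, p_smooth

p_filt, p_smooth = filter_and_smooth_variances(d=0.9, q=1.0, r=100.0, p0=0.1,
                                               n_timesteps=100)
print(f"t = 50: filtered variance {p_filt[50]:.2f}, smoothed variance {p_smooth[50]:.2f}")
```

The smoothed variance is never larger than the filtered one, and the two coincide at the last time step, where there is no future data to use.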
So why would you ever use Kalman filtering alone, without smoothing? Because Kalman filtering depends only on already-observed data (i.e., the past), it can be run in a streaming, or online, setting. Kalman smoothing relies on "future" data relative to each time step, so it can only be applied in a batch, or offline, setting. Use Kalman filtering if you need real-time corrections, and Kalman smoothing if you are analyzing already-collected data.
This online case is typically what the brain faces.
The Expectation-Maximization (EM) Algorithm¶
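want to maximize \(\log p(m|\theta)\)
this requires marginalizing out the latent state, \(p(m|\theta) = \int p(m,s|\theta)\,ds\), which in general is intractable or at least hard to maximize directly with respect to \(\theta\)
add a probability distribution \(q(s)\) that will approximate the posterior distribution of the latent state; since \(\int q(s)\,ds = 1\),
\[\log p(m|\theta) = \int q(s)\log p(m|\theta)\,ds\]
which can be rewritten as
\[\log p(m|\theta) = \mathcal{L}(q,\theta) + KL(q||p)\]
\[\mathcal{L}(q,\theta) = \int q(s)\log\frac{p(m,s|\theta)}{q(s)}\,ds, \qquad KL(q||p) = -\int q(s)\log\frac{p(s|m,\theta)}{q(s)}\,ds\]
\(\mathcal{L}(q,\theta)\) contains the joint distribution of \(m\) and \(s\)
\(KL(q||p)\) contains the conditional distribution of \(s\) given \(m\); because \(KL(q||p) \ge 0\), \(\mathcal{L}(q,\theta)\) is a lower bound on \(\log p(m|\theta)\)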
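The decomposition \(\log p(m|\theta) = \mathcal{L}(q,\theta) + KL(q||p)\) holds for any choice of \(q\). It is easy to verify numerically for a toy model with a discrete latent state, where the integrals become sums. The probabilities below are made up for illustration.

```python
import numpy as np

# Toy model: latent s in {0, 1, 2}, one fixed observation m
p_s = np.array([0.5, 0.3, 0.2])            # prior p(s | theta)
p_m_given_s = np.array([0.1, 0.6, 0.3])    # likelihood p(m | s, theta) for the observed m

p_joint = p_s * p_m_given_s                # p(m, s | theta)
p_m = p_joint.sum()                        # p(m | theta): marginalize out s
posterior = p_joint / p_m                  # p(s | m, theta)

q = np.array([0.2, 0.5, 0.3])              # an arbitrary distribution over s
elbo = np.sum(q * np.log(p_joint / q))     # lower bound L(q, theta)
kl = np.sum(q * np.log(q / posterior))     # KL(q || p(s | m, theta))

print(np.log(p_m), elbo + kl)              # identical for every q
```

If \(q\) equals the posterior, the KL term vanishes and the bound becomes tight. The E-step exploits exactly this property.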
Expectation step¶
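parameters \(\theta\) are kept fixed
find a good approximation \(q(s)\): maximize the lower bound \(\mathcal{L}(q,\theta)\) with respect to \(q(s)\); the bound is tight when \(q(s) = p(s|m,\theta)\), because the KL term is then zero
(already implemented: for the LDS, the Kalman filter + smoother compute this posterior)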
Maximization step¶
keep the distribution \(q(s)\) fixed
change the parameters to maximize the lower bound: \(\theta^{\rm new} = \arg\max_\theta \mathcal{L}(q,\theta)\)
As mentioned, our Kalman filter and smoother already effectively solve the E-step. The M-step requires further derivation; the resulting updates are summarized below. Rather than having you implement the M-step yourselves, Section 3 used a library (pykalman) that already implements EM to explore experimental data from cognitive neuroscience.
The M-step for a LDS¶
(see Bishop, chapter 13.3.2 Learning in LDS) Update parameters of the probability distribution
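For the updates in the M-step we need the following posterior moments, obtained from the Kalman smoothing results \(\hat{\mu}_t^{\rm smooth}, \hat{\Sigma}_t^{\rm smooth}\) and the smoother gains \(J_t\):
\[\mathbb{E}[s_t] = \hat{\mu}_t^{\rm smooth}\]
\[\mathbb{E}[s_t s_t^T] = \hat{\Sigma}_t^{\rm smooth} + \hat{\mu}_t^{\rm smooth}(\hat{\mu}_t^{\rm smooth})^T\]
\[\mathbb{E}[s_t s_{t-1}^T] = \hat{\Sigma}_t^{\rm smooth}J_{t-1}^T + \hat{\mu}_t^{\rm smooth}(\hat{\mu}_{t-1}^{\rm smooth})^T\]
Update parameters (for measurements \(m_0, \dots, m_T\))
Initial parameters
\[\mu_0^{\rm new} = \mathbb{E}[s_0], \qquad \Sigma_0^{\rm new} = \mathbb{E}[s_0 s_0^T] - \mathbb{E}[s_0]\mathbb{E}[s_0]^T\]
Hidden (latent) state parameters
\[D^{\rm new} = \left(\sum_{t=1}^{T}\mathbb{E}[s_t s_{t-1}^T]\right)\left(\sum_{t=1}^{T}\mathbb{E}[s_{t-1}s_{t-1}^T]\right)^{-1}\]
\[Q^{\rm new} = \frac{1}{T}\sum_{t=1}^{T}\left(\mathbb{E}[s_t s_t^T] - D^{\rm new}\mathbb{E}[s_{t-1}s_t^T] - \mathbb{E}[s_t s_{t-1}^T](D^{\rm new})^T + D^{\rm new}\mathbb{E}[s_{t-1}s_{t-1}^T](D^{\rm new})^T\right)\]
Observable (measured) space parameters
\[H^{\rm new} = \left(\sum_{t=0}^{T} m_t\,\mathbb{E}[s_t]^T\right)\left(\sum_{t=0}^{T}\mathbb{E}[s_t s_t^T]\right)^{-1}\]
\[R^{\rm new} = \frac{1}{T+1}\sum_{t=0}^{T}\left(m_t m_t^T - H^{\rm new}\mathbb{E}[s_t]m_t^T - m_t\mathbb{E}[s_t]^T(H^{\rm new})^T + H^{\rm new}\mathbb{E}[s_t s_t^T](H^{\rm new})^T\right)\]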
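The M-step updates from Bishop (Section 13.3.2) can be sketched in NumPy, translated to this tutorial's notation: time steps \(t = 0, \dots, T\) and parameters D, Q, H, R, mu_0, sigma_0. This is our own minimal sketch, not pykalman's implementation, and the function name lds_m_step is hypothetical. It assumes the smoother also returns the backward gains \(J_t\); the tutorial's kalman_smooth returns only the means and covariances. The cross-moment uses the standard result \(\mathrm{Cov}[s_{t-1}, s_t \mid m] = J_{t-1}\hat{\Sigma}_t^{\rm smooth}\).

```python
import numpy as np

def _outer(a, b):
    """Batch of outer products a_t b_t^T for arrays of shape (n, k)."""
    return np.einsum('ti,tj->tij', a, b)

def lds_m_step(m, mu_s, sigma_s, J):
    """One M-step for a linear dynamical system (sketch after Bishop, Sec. 13.3.2).

    m:       measurements m_0..m_T, shape (T+1, n_dim_obs)
    mu_s:    smoothed means, shape (T+1, n_dim_state)
    sigma_s: smoothed covariances, shape (T+1, n_dim_state, n_dim_state)
    J:       smoother gains J_t, same shape as sigma_s (J[T] is unused)
    """
    n_steps = len(m)
    # Posterior moments E[s_t s_t^T] and, for t = 1..T, E[s_t s_{t-1}^T]
    E_ss = sigma_s + _outer(mu_s, mu_s)
    E_ss_lag = sigma_s[1:] @ np.transpose(J[:-1], (0, 2, 1)) + _outer(mu_s[1:], mu_s[:-1])

    # Initial state: posterior mean and covariance of s_0
    mu_0, sigma_0 = mu_s[0], sigma_s[0]

    # Latent dynamics D and Q
    S_lag = E_ss_lag.sum(axis=0)                 # sum_t E[s_t s_{t-1}^T]
    S_prev = E_ss[:-1].sum(axis=0)               # sum_t E[s_{t-1} s_{t-1}^T]
    D = np.linalg.solve(S_prev.T, S_lag.T).T     # S_lag @ inv(S_prev)
    Q = (E_ss[1:].sum(axis=0) - D @ S_lag.T - S_lag @ D.T
         + D @ S_prev @ D.T) / (n_steps - 1)

    # Observation model H and R
    S_ms = m.T @ mu_s                            # sum_t m_t E[s_t]^T
    S_all = E_ss.sum(axis=0)                     # sum_t E[s_t s_t^T]
    H = np.linalg.solve(S_all.T, S_ms.T).T       # S_ms @ inv(S_all)
    R = (m.T @ m - H @ S_ms.T - S_ms @ H.T + H @ S_all @ H.T) / n_steps

    return {'D': D, 'Q': Q, 'H': H, 'R': R, 'mu_0': mu_0, 'sigma_0': sigma_0}

# Sanity check under a hypothetical setup: if the latent states were known exactly
# (zero posterior covariance), the M-step reduces to least-squares regressions.
rng = np.random.default_rng(0)
n_steps = 20_000
D_true, R_true = 0.9 * np.eye(2), 0.5 * np.eye(2)
s = np.zeros((n_steps, 2))
for t in range(1, n_steps):
    s[t] = D_true @ s[t - 1] + rng.standard_normal(2)   # Q_true = I
m = s + rng.multivariate_normal(np.zeros(2), R_true, size=n_steps)
zeros = np.zeros((n_steps, 2, 2))
est = lds_m_step(m, s, zeros, zeros)
print(np.round(est['D'], 2))   # close to 0.9 * I
print(np.round(est['R'], 2))   # close to 0.5 * I
```

In real use, mu_s, sigma_s and J would come from a smoother run under the current parameters, and the E- and M-steps would alternate until the log-likelihood stops improving.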