From 15e296db9bd905046ce365ae556f0bfbb1b1c12e Mon Sep 17 00:00:00 2001 From: evelynmitchell Date: Thu, 25 Jan 2024 12:29:06 -0700 Subject: [PATCH 1/2] example colab notebook --- DynamaxSSM.ipynb | 89 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 89 insertions(+) create mode 100644 DynamaxSSM.ipynb diff --git a/DynamaxSSM.ipynb b/DynamaxSSM.ipynb new file mode 100644 index 00000000..28cc3797 --- /dev/null +++ b/DynamaxSSM.ipynb @@ -0,0 +1,89 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "private_outputs": true, + "provenance": [] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + } + }, + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Ru6SwSvrCsbl" + }, + "outputs": [], + "source": [ + "!pip install dynamax[notebooks]" + ] + }, + { + "cell_type": "code", + "source": [ + "import jax.numpy as jnp\n", + "import jax.random as jr\n", + "import matplotlib.pyplot as plt\n", + "from dynamax.hidden_markov_model import GaussianHMM\n", + "\n", + "key1, key2, key3 = jr.split(jr.PRNGKey(0), 3)\n", + "num_states = 3\n", + "emission_dim = 2\n", + "num_timesteps = 1000\n", + "\n", + "# Make a Gaussian HMM and sample data from it\n", + "hmm = GaussianHMM(num_states, emission_dim)\n", + "true_params, _ = hmm.initialize(key1)\n", + "true_states, emissions = hmm.sample(true_params, key2, num_timesteps)\n", + "\n", + "# Make a new Gaussian HMM and fit it with EM\n", + "params, props = hmm.initialize(key3, method=\"kmeans\", emissions=emissions)\n", + "params, lls = hmm.fit_em(params, props, emissions, num_iters=20)\n", + "\n", + "# Plot the marginal log probs across EM iterations\n", + "plt.plot(lls)\n", + "plt.xlabel(\"EM iterations\")\n", + "plt.ylabel(\"marginal log prob.\")\n", + "\n", + "# Use fitted model for posterior inference\n", + "post = hmm.smoother(params, emissions)\n", + "print(post.smoothed_probs.shape) # (1000, 3)" + ], + "metadata": { + "id": "CRtRwx4LD5oe" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "from functools import partial\n", + "from jax import vmap\n", + "\n", + "num_seq = 200\n", + "batch_true_states, batch_emissions = \\\n", + " vmap(partial(hmm.sample, true_params, num_timesteps=num_timesteps))(\n", + " jr.split(key2, num_seq))\n", + "print(batch_true_states.shape, batch_emissions.shape) # (200,1000) and (200,1000,2)\n", + "\n", + "# Make a new Gaussian HMM and fit it with EM\n", + "params, props = hmm.initialize(key3, method=\"kmeans\", emissions=batch_emissions)\n", + "params, lls = hmm.fit_em(params, props, batch_emissions, num_iters=20)" + ], + "metadata": { + "id": "IfgGkfmeFh8R" + }, + "execution_count": null, + "outputs": [] + } + ] +} \ No newline at end of file From bfe681af8f6a4be7743159b0fcd37605da0357e0 Mon Sep 17 00:00:00 2001 From: evelynmitchell Date: Thu, 25 Jan 2024 12:33:10 -0700 Subject: [PATCH 2/2] add link to colab notebook --- README.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/README.md b/README.md index b679bc40..0393ab0c 100644 --- a/README.md +++ b/README.md @@ -61,6 +61,10 @@ pytest dynamax/hmm/inference_test.py # Run a specific test pytest -k lgssm # Run tests with lgssm in the name ``` +Run example in Colab: +Open In Colab + + ## What are state space models? A state space model or SSM is a partially observed Markov model, in