Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Created demo colab notebook, and link in readme #352

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 89 additions & 0 deletions DynamaxSSM.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"private_outputs": true,
"provenance": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Ru6SwSvrCsbl"
},
"outputs": [],
"source": [
"!pip install dynamax[notebooks]"
]
},
{
"cell_type": "code",
"source": [
"import jax.numpy as jnp\n",
"import jax.random as jr\n",
"import matplotlib.pyplot as plt\n",
"from dynamax.hidden_markov_model import GaussianHMM\n",
"\n",
"key1, key2, key3 = jr.split(jr.PRNGKey(0), 3)\n",
"num_states = 3\n",
"emission_dim = 2\n",
"num_timesteps = 1000\n",
"\n",
"# Make a Gaussian HMM and sample data from it\n",
"hmm = GaussianHMM(num_states, emission_dim)\n",
"true_params, _ = hmm.initialize(key1)\n",
"true_states, emissions = hmm.sample(true_params, key2, num_timesteps)\n",
"\n",
"# Make a new Gaussian HMM and fit it with EM\n",
"params, props = hmm.initialize(key3, method=\"kmeans\", emissions=emissions)\n",
"params, lls = hmm.fit_em(params, props, emissions, num_iters=20)\n",
"\n",
"# Plot the marginal log probs across EM iterations\n",
"plt.plot(lls)\n",
"plt.xlabel(\"EM iterations\")\n",
"plt.ylabel(\"marginal log prob.\")\n",
"\n",
"# Use fitted model for posterior inference\n",
"post = hmm.smoother(params, emissions)\n",
"print(post.smoothed_probs.shape) # (1000, 3)"
],
"metadata": {
"id": "CRtRwx4LD5oe"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"source": [
"from functools import partial\n",
"from jax import vmap\n",
"\n",
"num_seq = 200\n",
"batch_true_states, batch_emissions = \\\n",
" vmap(partial(hmm.sample, true_params, num_timesteps=num_timesteps))(\n",
" jr.split(key2, num_seq))\n",
"print(batch_true_states.shape, batch_emissions.shape) # (200,1000) and (200,1000,2)\n",
"\n",
"# Make a new Gaussian HMM and fit it with EM\n",
"params, props = hmm.initialize(key3, method=\"kmeans\", emissions=batch_emissions)\n",
"params, lls = hmm.fit_em(params, props, batch_emissions, num_iters=20)"
],
"metadata": {
"id": "IfgGkfmeFh8R"
},
"execution_count": null,
"outputs": []
}
]
}
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,10 @@ pytest dynamax/hmm/inference_test.py # Run a specific test
pytest -k lgssm # Run tests with lgssm in the name
```

Run example in Colab: <a target="_blank" href="https://colab.research.google.com/github/evelynmitchell/dynamax/blob/master/DynamaxSSM.ipynb">
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

## What are state space models?

A state space model or SSM is a partially observed Markov model, in
Expand Down