Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

A prototype of the Markov model #171

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 155 additions & 0 deletions Markov.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
# ---

# NOTE(review): CI residue from the PR page — GitHub Actions job
# "build (3.11, 1.3.2)" failed on line 1 of Markov.py with:
# "Imports are incorrectly sorted and/or formatted."
# jupyter:
# jupytext:
# text_representation:
# extension: .py
# format_name: light
# format_version: '1.5'
# jupytext_version: 1.16.2
# kernelspec:
# display_name: Python 3 (ipykernel)
# language: python
# name: python3
# ---

# +
import jax
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

# +
key = jax.random.PRNGKey(42)

# Initial distribution over the 3 hidden states.
initial_distribution = tfd.Categorical(probs=jnp.array([0.5, 0.3, 0.2]))

# Per-state Gaussian emissions: state i emits N(mean_i, 1) with
# means (0.0, 0.5, 1.0) — a batch of 3 Normals, indexable by state.
observation_distribution = tfd.Normal(jnp.asarray([0., 0.5, 1.0]), 1.0)

# Length of every sampled trajectory.
num_steps = 5

# Row-stochastic transition matrix: state 0 jumps to 1 or 2 uniformly;
# states 1 and 2 return to state 0 deterministically.
transition_probs =jnp.asarray([[0.0, 0.5, 0.5],
                               [1.0, 0.0, 0.0],
                               [1.0, 0.0, 0.0],
                               ])
transition_distribution = tfd.Categorical(probs=transition_probs)


# Prior over hidden trajectories X: a length-num_steps Markov chain whose
# transition kernel indexes a row of transition_probs by the current state.
dist_x = tfd.MarkovChain(
    initial_state_prior=initial_distribution,
    transition_fn=lambda _, x: tfd.Categorical(probs=transition_probs[x, :]),
    num_steps=num_steps,
)

# Marginal over observations Y: the same chain with hidden states
# integrated out by tfd.HiddenMarkovModel.
dist_y = tfd.HiddenMarkovModel(
    initial_distribution=initial_distribution,
    transition_distribution=transition_distribution,
    observation_distribution=observation_distribution,
    num_steps=num_steps)

def dist_xy_log_prob(xs, ys):
    """Log-density of (xs, ys) under the joint Markov model.

    Computes log p(xs) + sum_t log p(ys[..., t] | xs[..., t]).

    Args:
        xs: integer array of hidden states — assumed shape
            (..., num_steps); TODO confirm against callers.
        ys: float array of observations, same trailing shape as ``xs``.

    Returns:
        Array of joint log-probabilities with the batch shape of the inputs.
    """
    # Prior log-probability of the hidden trajectory.
    log_prob_xs = dist_x.log_prob(xs)

    # Slice the batched emission Normal by the states ONCE and reuse it.
    # (Original built `observation_distribution[xs]` twice and left the
    # first result, `ys_dists`, unused.)
    ys_dists = observation_distribution[xs]

    # Emissions are conditionally independent given the states, so the
    # per-step log-probs sum over the time axis.
    log_prob_ys_given_xs = jnp.sum(ys_dists.log_prob(ys), axis=-1)

    return log_prob_xs + log_prob_ys_given_xs

def dist_xy_sample(n, key):
    """Draw ``n`` joint trajectories (xs, ys) from the Markov model.

    Args:
        n: number of independent trajectories to sample.
        key: JAX PRNG key; split internally so the hidden-state draw and
            the emission draw use independent subkeys.

    Returns:
        Tuple ``(xs, ys)`` — hidden states and their Gaussian emissions.
    """
    key1, key2 = jax.random.split(key)

    # Bug fix: the original sampled with the parent `key` and never used
    # `key1`, so the chain draw and the split subkeys were correlated.
    xs = dist_x.sample(n, key1)

    # Index the batched emission Normal by the sampled states and draw
    # one observation per state with the second subkey.
    ys_dists = observation_distribution[xs]
    ys = ys_dists.sample((), key2)

    return xs, ys

# dist_xy = tfd.Distribution(
# sample_fn=,
# log_prob_fn=,
# (


# +
# Joint (X, Y) built generatively: sample the chain, then index the
# batched emission Normal by the sampled states.  tfd.Independent folds
# the num_steps emission dimensions into a single event so log_prob
# sums over time.
dist_xy = tfd.JointDistributionSequential([
    dist_x,
    lambda xs: tfd.Independent(observation_distribution[xs]),
])

# 5_000 joint trajectories; samples come back as an (xs, ys) tuple.
xys = dist_xy.sample(5_000, key)
xs, ys = xys



# +
from bmi.samplers._tfp._core import JointDistribution, monte_carlo_mi_estimate


# Wrap the marginals and the joint for bmi's MI machinery.
# unwrap=False: dist_joint is a tfd.JointDistribution whose samples are
# already an (xs, ys) tuple, so they must not be split by column index.
our_dist = JointDistribution(
    dist_x=dist_x,
    dist_y=dist_y,
    dist_joint=dist_xy,
    dim_x=num_steps,
    dim_y=num_steps,
    unwrap=False,
)
# -

# Monte Carlo estimate of I(X; Y) from 10_000 joint samples.
# NOTE(review): `key + 3` derives a "fresh" key by elementwise addition;
# jax.random.fold_in(key, 3) would be the canonical way — confirm intent.
mi, std_err = monte_carlo_mi_estimate(key + 3, our_dist, 10_000)
mi

std_err

# Compare the Monte Carlo MI against bmi's generic estimators,
# all fitted on the 5_000 joint samples drawn above.
import bmi

# Linear baseline (canonical correlation analysis).
estimator = bmi.estimators.CCAMutualInformationEstimator()
estimator.estimate(xs, ys)

# Neural-critic estimator (NWJ variational bound).
estimator = bmi.estimators.NWJEstimator()
estimator.estimate(xs, ys)



# +
# KSG k-nearest-neighbour estimator with k = 5.
estimator = bmi.estimators.KSGEnsembleFirstEstimator(neighborhoods=(5,))

# KSG assumes continuous samples: jitter the discrete states with tiny
# Gaussian noise so nearest-neighbour distances are well defined.
es = 1e-4 * jax.random.normal(key+3, shape=xs.shape)
estimator.estimate(xs + es, ys)
# -

# Neural estimators on the raw (discrete xs, continuous ys) samples.
estimator = bmi.estimators.MINEEstimator()
estimator.estimate(xs, ys)

estimator = bmi.estimators.InfoNCEEstimator()
estimator.estimate(xs, ys)

xs, ys = dist_xy_maybe.sample(3, key)
xys = dist_xy_maybe.sample(3, key)

# +
# xys = jnp.stack([xs, ys], axis=0)
# -

# NOTE(review): `dist_xy_maybe` is not defined anywhere in this file —
# this scratch cell raises NameError as written; confirm what was meant
# (presumably `dist_xy`).
dist_xy_maybe.log_prob(xys)

# Smoke-test the manual sampler defined above.
xs, ys = dist_xy_sample(3, key)

ys.shape

# +
# Evaluate the manual joint log-prob on *independently* drawn xs and ys —
# a shape sanity check, not a sample from the joint distribution.
num_samples = 3

sample_x = dist_x.sample(num_samples, key)
sample_y = dist_y.sample(num_samples, key)

dist_xy_log_prob(sample_x, sample_y)
# -


13 changes: 12 additions & 1 deletion src/bmi/samplers/_tfp/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class JointDistribution:
dim_x: int
dim_y: int
analytic_mi: Optional[float] = None
unwrap: bool = True

def sample(self, n_points: int, key: jax.Array) -> tuple[jnp.ndarray, jnp.ndarray]:
"""Sample from the joint distribution $P_{XY}$.
Expand All @@ -43,7 +44,12 @@ def sample(self, n_points: int, key: jax.Array) -> tuple[jnp.ndarray, jnp.ndarra
if n_points < 1:
raise ValueError("n must be positive")

# TODO(Pawel): Ensure that it works with JointDistribution
xy = self.dist_joint.sample(seed=key, sample_shape=(n_points,))

if not self.unwrap:
return xy

return xy[..., : self.dim_x], xy[..., self.dim_x :] # noqa: E203 (formatting discrepancy)

def pmi(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
Expand All @@ -60,7 +66,12 @@ def pmi(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
Note:
This function is vectorized, i.e. it can calculate PMI for multiple points at once.
"""
log_pxy = self.dist_joint.log_prob(jnp.hstack([x, y]))
# TODO(Pawel): Ensure it works with tfd.JointDistribution
if self.unwrap:
log_pxy = self.dist_joint.log_prob(jnp.hstack([x, y]))
else:
log_pxy = self.dist_joint.log_prob((x, y))

log_px = self.dist_x.log_prob(x)
log_py = self.dist_y.log_prob(y)

Expand Down
Loading