Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Resolve Issue 161 #165

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions src/bmi/samplers/_tfp/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@ class JointDistribution:
$P_X$ and $P_Y$.

Attributes:
dist: $P_{XY}$
dist_x: $P_X$
dist_y: $P_Y$
dist_joint: $P_{XY}$. Each sample is a *tuple* `(xs, ys)`
where `xs` is of shape `(n_samples, dim_x)` and
`ys` is of shape `(n_samples, dim_y)`.
dist_x: $P_X$. Samples are of shape `(n_samples, dim_x)`
        dist_y: $P_Y$. Samples are of shape `(n_samples, dim_y)`
dim_x: dimension of the support of $X$
dim_y: dimension of the support of $Y$
analytic_mi: analytical mutual information.
Expand All @@ -43,8 +45,7 @@ def sample(self, n_points: int, key: jax.Array) -> tuple[jnp.ndarray, jnp.ndarra
if n_points < 1:
raise ValueError("n must be positive")

xy = self.dist_joint.sample(seed=key, sample_shape=(n_points,))
return xy[..., : self.dim_x], xy[..., self.dim_x :] # noqa: E203 (formatting discrepancy)
return self.dist_joint.sample(n_points, key)

def pmi(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
"""Calculates pointwise mutual information at specified points.
Expand All @@ -60,7 +61,7 @@ def pmi(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
Note:
This function is vectorized, i.e. it can calculate PMI for multiple points at once.
"""
log_pxy = self.dist_joint.log_prob(jnp.hstack([x, y]))
log_pxy = self.dist_joint.log_prob((x, y))
log_px = self.dist_x.log_prob(x)
log_py = self.dist_y.log_prob(y)

Expand Down Expand Up @@ -136,9 +137,8 @@ def transform(
if y_transform is None:
y_transform = tfb.Identity()

product_bijector = tfb.Blockwise(
bijectors=[x_transform, y_transform], block_sizes=[dist.dim_x, dist.dim_y]
)
product_bijector = tfb.JointMap((x_transform, y_transform))

return JointDistribution(
dim_x=dist.dim_x,
dim_y=dist.dim_y,
Expand Down
8 changes: 7 additions & 1 deletion src/bmi/samplers/_tfp/_normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from bmi.samplers._tfp._core import JointDistribution

jtf = tfp.tf2jax
tfb = tfp.bijectors
tfd = tfp.distributions


Expand Down Expand Up @@ -55,7 +56,12 @@ def __init__(
# Now we need to define the TensorFlow Probability distributions
# using the information provided

dist_joint = construct_multivariate_normal_distribution(mean=mean, covariance=covariance)
_dist_joint = construct_multivariate_normal_distribution(mean=mean, covariance=covariance)
dist_joint = tfd.TransformedDistribution(
distribution=_dist_joint,
bijector=tfb.Split((dim_x, dim_y)),
)

dist_x = construct_multivariate_normal_distribution(
mean=mean[:dim_x], covariance=covariance[:dim_x, :dim_x]
)
Expand Down
2 changes: 1 addition & 1 deletion src/bmi/samplers/_tfp/_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __init__(self, dist_x: tfd.Distribution, dist_y: tfd.Distribution) -> None:
dim_x = int(dims_x[0])
dim_y = int(dims_y[0])

dist_joint = tfd.Blockwise([dist_x, dist_y])
dist_joint = tfd.JointDistributionSequential((dist_x, dist_y))

super().__init__(
dim_x=dim_x,
Expand Down
8 changes: 7 additions & 1 deletion src/bmi/samplers/_tfp/_student.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from bmi.samplers._tfp._core import JointDistribution

jtf = tfp.tf2jax
tfb = tfp.bijectors
tfd = tfp.distributions


Expand Down Expand Up @@ -78,9 +79,14 @@ def __init__(
# Now we need to define the TensorFlow Probability distributions
# using the information provided

dist_joint = construct_multivariate_student_distribution(
_dist_joint = construct_multivariate_student_distribution(
mean=mean, dispersion=dispersion, df=df
)
dist_joint = tfd.TransformedDistribution(
distribution=_dist_joint,
bijector=tfb.Split((dim_x, dim_y)),
)

dist_x = construct_multivariate_student_distribution(
mean=mean[:dim_x], dispersion=dispersion[:dim_x, :dim_x], df=df
)
Expand Down
Loading