diff --git a/src/bmi/samplers/_tfp/_core.py b/src/bmi/samplers/_tfp/_core.py index ed075af5..dba12490 100644 --- a/src/bmi/samplers/_tfp/_core.py +++ b/src/bmi/samplers/_tfp/_core.py @@ -17,9 +17,11 @@ class JointDistribution: $P_X$ and $P_Y$. Attributes: - dist: $P_{XY}$ - dist_x: $P_X$ - dist_y: $P_Y$ + dist_joint: $P_{XY}$. Each sample is a *tuple* `(xs, ys)` + where `xs` is of shape `(n_samples, dim_x)` and + `ys` is of shape `(n_samples, dim_y)`. + dist_x: $P_X$. Samples are of shape `(n_samples, dim_x)` + dist_y: $P_Y$. Samples are of shape `(n_samples, dim_y,)` dim_x: dimension of the support of $X$ dim_y: dimension of the support of $Y$ analytic_mi: analytical mutual information. @@ -43,8 +45,7 @@ def sample(self, n_points: int, key: jax.Array) -> tuple[jnp.ndarray, jnp.ndarra if n_points < 1: raise ValueError("n must be positive") - xy = self.dist_joint.sample(seed=key, sample_shape=(n_points,)) - return xy[..., : self.dim_x], xy[..., self.dim_x :] # noqa: E203 (formatting discrepancy) + return self.dist_joint.sample(n_points, key) def pmi(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: """Calculates pointwise mutual information at specified points. @@ -60,7 +61,7 @@ def pmi(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: Note: This function is vectorized, i.e. it can calculate PMI for multiple points at once. """ - log_pxy = self.dist_joint.log_prob(jnp.hstack([x, y])) + log_pxy = self.dist_joint.log_prob((x, y)) log_px = self.dist_x.log_prob(x) log_py = self.dist_y.log_prob(y) @@ -136,9 +137,8 @@ def transform( if y_transform is None: y_transform = tfb.Identity() - product_bijector = tfb.Blockwise( - bijectors=[x_transform, y_transform], block_sizes=[dist.dim_x, dist.dim_y] - ) + product_bijector = tfb.JointMap((x_transform, y_transform)) + return JointDistribution( dim_x=dist.dim_x, dim_y=dist.dim_y, diff --git a/src/bmi/samplers/_tfp/_normal.py b/src/bmi/samplers/_tfp/_normal.py index a41bfef8..5f408fcb 100644 --- a/src/bmi/samplers/_tfp/_normal.py +++ b/src/bmi/samplers/_tfp/_normal.py @@ -8,6 +8,7 @@ from bmi.samplers._tfp._core import JointDistribution jtf = tfp.tf2jax +tfb = tfp.bijectors tfd = tfp.distributions @@ -55,7 +56,12 @@ def __init__( # Now we need to define the TensorFlow Probability distributions # using the information provided - dist_joint = construct_multivariate_normal_distribution(mean=mean, covariance=covariance) + _dist_joint = construct_multivariate_normal_distribution(mean=mean, covariance=covariance) + dist_joint = tfd.TransformedDistribution( + distribution=_dist_joint, + bijector=tfb.Split((dim_x, dim_y)), + ) + dist_x = construct_multivariate_normal_distribution( mean=mean[:dim_x], covariance=covariance[:dim_x, :dim_x] ) diff --git a/src/bmi/samplers/_tfp/_product.py b/src/bmi/samplers/_tfp/_product.py index 0b09d575..802eb7b1 100644 --- a/src/bmi/samplers/_tfp/_product.py +++ b/src/bmi/samplers/_tfp/_product.py @@ -29,7 +29,7 @@ def __init__(self, dist_x: tfd.Distribution, dist_y: tfd.Distribution) -> None: dim_x = int(dims_x[0]) dim_y = int(dims_y[0]) - dist_joint = tfd.Blockwise([dist_x, dist_y]) + dist_joint = tfd.JointDistributionSequential((dist_x, dist_y)) super().__init__( dim_x=dim_x, diff --git a/src/bmi/samplers/_tfp/_student.py b/src/bmi/samplers/_tfp/_student.py index 31172530..a80251b1 100644 --- a/src/bmi/samplers/_tfp/_student.py +++ b/src/bmi/samplers/_tfp/_student.py @@ -8,6 +8,7 @@ from bmi.samplers._tfp._core import JointDistribution jtf = tfp.tf2jax +tfb = tfp.bijectors tfd = tfp.distributions @@ -78,9 +79,14 @@ def __init__( # Now we need to define the TensorFlow Probability distributions # using the information provided - dist_joint = construct_multivariate_student_distribution( + _dist_joint = construct_multivariate_student_distribution( mean=mean, dispersion=dispersion, df=df ) + dist_joint = tfd.TransformedDistribution( + distribution=_dist_joint, + bijector=tfb.Split((dim_x, dim_y)), + ) + dist_x = construct_multivariate_student_distribution( mean=mean[:dim_x], dispersion=dispersion[:dim_x, :dim_x], df=df )