From 68eb8deeded459ea352f80701012eb93d8df260f Mon Sep 17 00:00:00 2001 From: Matthew Creamer Date: Tue, 7 Mar 2023 17:27:35 -0500 Subject: [PATCH] Modified lgssm_filter to allow it to processes data with NaNs This will also allow it to process data of different lengths by padding with NaNs --- dynamax/linear_gaussian_ssm/inference.py | 21 ++++++++++++++++++- dynamax/linear_gaussian_ssm/inference_test.py | 20 ++++++++++++++---- 2 files changed, 36 insertions(+), 5 deletions(-) diff --git a/dynamax/linear_gaussian_ssm/inference.py b/dynamax/linear_gaussian_ssm/inference.py index 8529eb45..4e5203d8 100644 --- a/dynamax/linear_gaussian_ssm/inference.py +++ b/dynamax/linear_gaussian_ssm/inference.py @@ -370,7 +370,8 @@ def _step(prev_state, args): def lgssm_filter( params: ParamsLGSSM, emissions: Float[Array, "ntime emission_dim"], - inputs: Optional[Float[Array, "ntime input_dim"]]=None + inputs: Optional[Float[Array, "ntime input_dim"]]=None, + nan_fill_multiplier: float=1e8 ) -> PosteriorGSSMFiltered: r"""Run a Kalman filter to produce the marginal likelihood and filtered state estimates. @@ -386,6 +387,18 @@ def lgssm_filter( num_timesteps = len(emissions) inputs = jnp.zeros((num_timesteps, 0)) if inputs is None else inputs + # Create a vector to replace nans in the emissions + # if an entire time trace is nan, replace with the mean across all values + # if the entire emission is nan, replace with 0s + nan_fill_mean = jnp.nanmean(emissions, axis=0) + nan_fill_mean = jnp.where(jnp.isnan(nan_fill_mean), jnp.nanmean(nan_fill_mean), nan_fill_mean) + nan_fill_mean = jnp.where(jnp.isnan(nan_fill_mean), 0, nan_fill_mean) + + # Create a vector to set the diagonal of the covariance of the emissions to a large number + # wherever there are NaNs in the emissions + nan_fill_cov = jnp.nanmax(jnp.nanvar(emissions)) * nan_fill_multiplier + nan_fill_cov = jnp.where(jnp.isnan(nan_fill_cov), nan_fill_multiplier, nan_fill_cov) + def _step(carry, t): ll, pred_mean, pred_cov = carry @@ -401,6 +414,12 @@ def _step(carry, t): u = inputs[t] y = emissions[t] + # Find NaNs in the emissions and replace them with nan_fill_mean + # Then, set the emission covariance to nan_fill_cov to push the filter to ignore these emissions + nan_loc = jnp.isnan(y) + y = jnp.where(nan_loc, nan_fill_mean, y) + R = jnp.where(jnp.diag(nan_loc), nan_fill_cov, R) + # Update the log likelihood ll += MVN(H @ pred_mean + D @ u + d, H @ pred_cov @ H.T + R).log_prob(y) diff --git a/dynamax/linear_gaussian_ssm/inference_test.py b/dynamax/linear_gaussian_ssm/inference_test.py index 92a86690..a4f7c022 100644 --- a/dynamax/linear_gaussian_ssm/inference_test.py +++ b/dynamax/linear_gaussian_ssm/inference_test.py @@ -10,11 +10,11 @@ from dynamax.utils.utils import has_tpu if has_tpu(): - def allclose(x, y): - return jnp.allclose(x, y, atol=1e-1) + def allclose(x, y, atol=1e-1): + return jnp.allclose(x, y, atol=atol) else: - def allclose(x,y): - return jnp.allclose(x, y, atol=1e-1) + def allclose(x, y, atol=1e-1): + return jnp.allclose(x, y, atol=atol) def joint_posterior_mvn(params, emissions): """Construct the joint posterior MVN of a LGSSM, by inverting the joint precision matrix which @@ -165,6 +165,12 @@ class TestFilteringAndSmoothing(): print(ssm_posterior.filtered_means.shape) print(ssm_posterior.smoothed_means.shape) + # repeat sampling with NaNs in the emissions + nan_x = (0, emissions.shape[0], 0, emissions.shape[0]) + nan_y = (0, emissions.shape[0], emissions.shape[1], emissions.shape[1]) + emissions_nan = emissions.at[nan_x, nan_y].set(jnp.nan) + ssm_posterior_nan = lgssm.smoother(params, emissions_nan) + # TensorFlow Probability posteriors tfp_lgssm = lgssm_dynamax_to_tfp(num_timesteps, params) tfp_lls, tfp_filtered_means, tfp_filtered_covs, *_ = tfp_lgssm.forward_filter(emissions) @@ -200,6 +206,12 @@ def test_kalman_tfp(self): assert allclose(self.ssm_posterior.smoothed_covariances, self.tfp_smoothed_covs) assert allclose(self.ssm_posterior.marginal_loglik, self.tfp_lls.sum()) + def test_kalman_tfp_nan(self): + assert allclose(self.ssm_posterior_nan.filtered_means, self.tfp_filtered_means, atol=1e0) + assert allclose(self.ssm_posterior_nan.filtered_covariances, self.tfp_filtered_covs, atol=1e0) + assert allclose(self.ssm_posterior_nan.smoothed_means, self.tfp_smoothed_means, atol=1e0) + assert allclose(self.ssm_posterior_nan.smoothed_covariances, self.tfp_smoothed_covs, atol=1e0) + def test_kalman_vs_joint(self): assert allclose(self.ssm_posterior.smoothed_means, self.joint_means) assert allclose(self.ssm_posterior.smoothed_covariances, self.joint_covs)