diff --git a/dynamax/hidden_markov_model/inference.py b/dynamax/hidden_markov_model/inference.py index f39b6e04..7b1f914b 100644 --- a/dynamax/hidden_markov_model/inference.py +++ b/dynamax/hidden_markov_model/inference.py @@ -1,23 +1,29 @@ +from functools import partial +from typing import Callable, NamedTuple, Optional, Tuple, Union import jax.numpy as jnp import jax.random as jr from jax import jit, lax, vmap -from functools import partial - -from typing import Callable, Optional, Tuple, Union, NamedTuple from jaxtyping import Int, Float, Array -from dynamax.types import Scalar +from dynamax.types import IntScalar, Scalar _get_params = lambda x, dim, t: x[t] if x.ndim == dim + 1 else x -def get_trans_mat(transition_matrix, transition_fn, t): +def get_trans_mat( + transition_matrix: Optional[Union[Float[Array, "num_states num_states"], + Float[Array, "num_timesteps_minus_1 num_states num_states"]]], + transition_fn: Optional[Callable[[IntScalar], Float[Array, "num_states num_states"]]], + t: IntScalar +) -> Float[Array, "num_states num_states"]: if transition_fn is not None: return transition_fn(t) - else: - if transition_matrix.ndim == 3: # (T,K,K) + elif transition_matrix is not None: + if transition_matrix.ndim == 3: # (T-1,K,K) return transition_matrix[t] else: return transition_matrix + else: + raise ValueError("Either `transition_matrix` or `transition_fn` must be specified.") class HMMPosteriorFiltered(NamedTuple): r"""Simple wrapper for properties of an HMM filtering posterior. @@ -49,8 +55,8 @@ class HMMPosterior(NamedTuple): predicted_probs: Float[Array, "num_timesteps num_states"] smoothed_probs: Float[Array, "num_timesteps num_states"] initial_probs: Float[Array, " num_states"] - trans_probs: Optional[Union[Float[Array, "num_timesteps num_states num_states"], - Float[Array, "num_states num_states"]]] = None + trans_probs: Optional[Union[Float[Array, "num_states num_states"], + Float[Array, "num_timesteps_minus_1 num_states num_states"]]] = None def _normalize(u: Array, axis=0, eps=1e-15): @@ -96,10 +102,10 @@ def _predict(probs, A): @partial(jit, static_argnames=["transition_fn"]) def hmm_filter( initial_distribution: Float[Array, " num_states"], - transition_matrix: Union[Float[Array, "num_timesteps num_states num_states"], - Float[Array, "num_states num_states"]], + transition_matrix: Optional[Union[Float[Array, "num_states num_states"], + Float[Array, "num_timesteps_minus_1 num_states num_states"]]], log_likelihoods: Float[Array, "num_timesteps num_states"], - transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]] = None + transition_fn: Optional[Callable[[IntScalar], Float[Array, "num_states num_states"]]] = None ) -> HMMPosteriorFiltered: r"""Forwards filtering @@ -143,8 +149,8 @@ def _step(carry, t): @partial(jit, static_argnames=["transition_fn"]) def hmm_backward_filter( - transition_matrix: Union[Float[Array, "num_timesteps num_states num_states"], - Float[Array, "num_states num_states"]], + transition_matrix: Optional[Union[Float[Array, "num_states num_states"], + Float[Array, "num_timesteps_minus_1 num_states num_states"]]], log_likelihoods: Float[Array, "num_timesteps num_states"], transition_fn: Optional[Callable[[int], Float[Array, "num_states num_states"]]]= None ) -> Tuple[Scalar, Float[Array, "num_timesteps num_states"]]: @@ -190,10 +196,10 @@ def _step(carry, t): @partial(jit, static_argnames=["transition_fn"]) def hmm_two_filter_smoother( initial_distribution: Float[Array, " num_states"], - transition_matrix: Union[Float[Array, "num_timesteps num_states num_states"], - Float[Array, "num_states num_states"]], + transition_matrix: Optional[Union[Float[Array, "num_states num_states"], + Float[Array, "num_timesteps_minus_1 num_states num_states"]]], log_likelihoods: Float[Array, "num_timesteps num_states"], - transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]]= None, + transition_fn: Optional[Callable[[IntScalar], Float[Array, "num_states num_states"]]]= None, compute_trans_probs: bool = True ) -> HMMPosterior: r"""Computed the smoothed state probabilities using the two-filter @@ -244,10 +250,10 @@ def hmm_two_filter_smoother( @partial(jit, static_argnames=["transition_fn"]) def hmm_smoother( initial_distribution: Float[Array, " num_states"], - transition_matrix: Union[Float[Array, "num_timesteps num_states num_states"], - Float[Array, "num_states num_states"]], + transition_matrix: Optional[Union[Float[Array, "num_states num_states"], + Float[Array, "num_timesteps_minus_1 num_states num_states"]]], log_likelihoods: Float[Array, "num_timesteps num_states"], - transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]]= None, + transition_fn: Optional[Callable[[IntScalar], Float[Array, "num_states num_states"]]]= None, compute_trans_probs: bool = True ) -> HMMPosterior: r"""Computed the smoothed state probabilities using a general @@ -324,11 +330,11 @@ def _step(carry, args): @partial(jit, static_argnames=["transition_fn", "window_size"]) def hmm_fixed_lag_smoother( initial_distribution: Float[Array, " num_states"], - transition_matrix: Union[Float[Array, "num_timesteps num_states num_states"], - Float[Array, "num_states num_states"]], + transition_matrix: Optional[Union[Float[Array, "num_states num_states"], + Float[Array, "num_timesteps_minus_1 num_states num_states"]]], log_likelihoods: Float[Array, "num_timesteps num_states"], - window_size: Int, - transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]]= None + window_size: int, + transition_fn: Optional[Callable[[IntScalar], Float[Array, "num_states num_states"]]]= None ) -> HMMPosterior: r"""Compute the smoothed state probabilities using the fixed-lag smoother. @@ -438,10 +444,10 @@ def compute_posterior(filtered_probs, beta): @partial(jit, static_argnames=["transition_fn"]) def hmm_posterior_mode( initial_distribution: Float[Array, " num_states"], - transition_matrix: Union[Float[Array, "num_timesteps num_states num_states"], - Float[Array, "num_states num_states"]], + transition_matrix: Optional[Union[Float[Array, "num_states num_states"], + Float[Array, "num_timesteps_minus_1 num_states num_states"]]], log_likelihoods: Float[Array, "num_timesteps num_states"], - transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]]= None + transition_fn: Optional[Callable[[IntScalar], Float[Array, "num_states num_states"]]]= None ) -> Int[Array, " num_timesteps"]: r"""Compute the most likely state sequence. This is called the Viterbi algorithm. @@ -486,10 +492,10 @@ def _forward_pass(state, best_next_state): def hmm_posterior_sample( key: Array, initial_distribution: Float[Array, " num_states"], - transition_matrix: Union[Float[Array, "num_timesteps num_states num_states"], - Float[Array, "num_states num_states"]], + transition_matrix: Optional[Union[Float[Array, "num_states num_states"], + Float[Array, "num_timesteps_minus_1 num_states num_states"]]], log_likelihoods: Float[Array, "num_timesteps num_states"], - transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]] = None + transition_fn: Optional[Callable[[IntScalar], Float[Array, "num_states num_states"]]] = None ) -> Tuple[Scalar, Int[Array, " num_timesteps"]]: r"""Sample a latent sequence from the posterior. @@ -542,6 +548,7 @@ def _compute_sum_transition_probs( transition_matrix: Float[Array, "num_states num_states"], hmm_posterior: HMMPosterior) -> Float[Array, "num_states num_states"]: """Compute the transition probabilities from the HMM posterior messages. + Args: transition_matrix (_type_): _description_ hmm_posterior (_type_): _description_ @@ -578,11 +585,13 @@ def _step(carry, args: Tuple[Array, Array, Array, Int[Array, ""]]): def _compute_all_transition_probs( - transition_matrix: Float[Array, "num_timesteps num_states num_states"], + transition_matrix: Optional[Union[Float[Array, "num_states num_states"], + Float[Array, "num_timesteps_minus_1 num_states num_states"]]], hmm_posterior: HMMPosterior, - transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]] = None + transition_fn: Optional[Callable[[IntScalar], Float[Array, "num_states num_states"]]] = None ) -> Float[Array, "num_timesteps num_states num_states"]: """Compute the transition probabilities from the HMM posterior messages. + Args: transition_matrix (_type_): _description_ hmm_posterior (_type_): _description_ @@ -596,20 +605,21 @@ def _compute_probs(t): A = get_trans_mat(transition_matrix, transition_fn, t) return jnp.einsum('i,ij,j->ij', filtered_probs[t], A, relative_probs_next[t]) - transition_probs = vmap(_compute_probs)(jnp.arange(len(filtered_probs)-1)) + transition_probs = vmap(_compute_probs)(jnp.arange(len(filtered_probs))) return transition_probs -# TODO: Consider alternative annotation for return type: -# Float[Array, "*num_timesteps num_states num_states"] I think this would allow multiple prepended dims. -# Float[Array, "#num_timesteps num_states num_states"] this might accept (1, sd, sd) but not (sd, sd). +# TODO: This is a candidate for @overload however at present I think we would need to use +# `@beartype.typing.overload` and beartype is currently not a core dependency. +# Support for `typing.overload` might change in the future: +# https://github.com/beartype/beartype/issues/54 def compute_transition_probs( - transition_matrix: Union[Float[Array, "num_timesteps num_states num_states"], - Float[Array, "num_states num_states"]], + transition_matrix: Optional[Union[Float[Array, "num_states num_states"], + Float[Array, "num_timesteps_minus_1 num_states num_states"]]], hmm_posterior: HMMPosterior, - transition_fn: Optional[Callable[[Int], Float[Array, "num_states num_states"]]] = None -) -> Union[Float[Array, "num_timesteps num_states num_states"], - Float[Array, "num_states num_states"]]: + transition_fn: Optional[Callable[[IntScalar], Float[Array, "num_states num_states"]]] = None +) -> Union[Float[Array, "num_states num_states"], + Float[Array, "num_timesteps_minus_1 num_states num_states"]]: r"""Compute the posterior marginal distributions $p(z_{t+1}, z_t \mid y_{1:T}, u_{1:T}, \theta)$. Args: @@ -620,8 +630,10 @@ def compute_transition_probs( Returns: array of smoothed transition probabilities. """ - reduce_sum = transition_matrix is not None and transition_matrix.ndim == 2 - if reduce_sum: + if transition_matrix is None and transition_fn is None: + raise ValueError("Either `transition_matrix` or `transition_fn` must be specified.") + + if transition_matrix is not None and transition_matrix.ndim == 2: return _compute_sum_transition_probs(transition_matrix, hmm_posterior) else: return _compute_all_transition_probs(transition_matrix, hmm_posterior, transition_fn=transition_fn)