From da2395f25c895af96b0b52f0887a930f21eb4783 Mon Sep 17 00:00:00 2001 From: gileshd Date: Mon, 23 Sep 2024 12:58:32 +0100 Subject: [PATCH] Add further type annotations to poisson hmm --- .../hidden_markov_model/models/poisson_hmm.py | 51 +++++++++++++------ 1 file changed, 35 insertions(+), 16 deletions(-) diff --git a/dynamax/hidden_markov_model/models/poisson_hmm.py b/dynamax/hidden_markov_model/models/poisson_hmm.py index 6dc77090..b5c79203 100644 --- a/dynamax/hidden_markov_model/models/poisson_hmm.py +++ b/dynamax/hidden_markov_model/models/poisson_hmm.py @@ -1,4 +1,4 @@ -from typing import NamedTuple, Optional, Tuple, Union +from typing import Any, Dict, NamedTuple, Optional, Tuple, Union import jax.numpy as jnp import jax.random as jr @@ -6,13 +6,14 @@ import tensorflow_probability.substrates.jax.distributions as tfd from jaxtyping import Array, Float +from dynamax.hidden_markov_model.inference import HMMPosterior from dynamax.hidden_markov_model.models.abstractions import HMM, HMMEmissions from dynamax.hidden_markov_model.models.initial import ParamsStandardHMMInitialState from dynamax.hidden_markov_model.models.initial import StandardHMMInitialState from dynamax.hidden_markov_model.models.transitions import ParamsStandardHMMTransitions from dynamax.hidden_markov_model.models.transitions import StandardHMMTransitions from dynamax.parameters import ParameterProperties, ParameterSet, PropertySet -from dynamax.types import Scalar +from dynamax.types import IntScalar, Scalar from dynamax.utils.utils import pytree_sum @@ -23,10 +24,10 @@ class ParamsPoissonHMMEmissions(NamedTuple): class PoissonHMMEmissions(HMMEmissions): def __init__(self, - num_states, - emission_dim, - emission_prior_concentration=1.1, - emission_prior_rate=0.1): + num_states: int, + emission_dim: int, + emission_prior_concentration: Scalar = 1.1, + emission_prior_rate: Scalar = 0.1): """_summary_ Args: @@ -40,12 +41,13 @@ def __init__(self, self.emission_prior_rate = emission_prior_rate @property - def emission_shape(self): + def emission_shape(self) -> Tuple[int]: return (self.emission_dim,) - def initialize(self, key=jr.PRNGKey(0), - method="prior", - emission_rates=None): + def initialize(self, key: Array=jr.PRNGKey(0), + method: str = "prior", + emission_rates: Optional[Float[Array, "num_states emission_dim"]] = None + ) -> Tuple[ParamsPoissonHMMEmissions, ParamsPoissonHMMEmissions]: # Initialize the emission probabilities if emission_rates is None: if method.lower() == "prior": @@ -64,24 +66,41 @@ def initialize(self, key=jr.PRNGKey(0), props = ParamsPoissonHMMEmissions(rates=ParameterProperties(constrainer=tfb.Softplus())) return params, props - def distribution(self, params, state, inputs=None): + def distribution( + self, + params: ParamsPoissonHMMEmissions, + state: IntScalar, + inputs: Optional[Array] = None + ) -> tfd.Distribution: return tfd.Independent(tfd.Poisson(rate=params.rates[state]), reinterpreted_batch_ndims=1) - def log_prior(self, params): + def log_prior(self, params: ParamsPoissonHMMEmissions) -> Float[Array, ""]: prior = tfd.Gamma(self.emission_prior_concentration, self.emission_prior_rate) return prior.log_prob(params.rates).sum() - def collect_suff_stats(self, params, posterior, emissions, inputs=None): + def collect_suff_stats( + self, + params: ParamsPoissonHMMEmissions, + posterior: HMMPosterior, + emissions: Float[Array, "num_timesteps emission_dim"], + inputs: Optional[Array] = None + ) -> Dict[str, Float[Array, "..."]]: expected_states = posterior.smoothed_probs sum_w = jnp.einsum("tk->k", expected_states)[:, None] sum_x = jnp.einsum("tk, ti->ki", expected_states, emissions) return dict(sum_w=sum_w, sum_x=sum_x) - def initialize_m_step_state(self, params, props): + def initialize_m_step_state(self, params: ParamsPoissonHMMEmissions, props: ParamsPoissonHMMEmissions) -> None: return None - def m_step(self, params, props, batch_stats, m_step_state): + def m_step( + self, + params: ParamsPoissonHMMEmissions, + props: ParamsPoissonHMMEmissions, + batch_stats: Dict[str, Float[Array, "..."]], + m_step_state: Any + ) -> Tuple[ParamsPoissonHMMEmissions, Any]: if props.rates.trainable: emission_stats = pytree_sum(batch_stats, axis=0) post_concentration = self.emission_prior_concentration + emission_stats['sum_x'] @@ -132,7 +151,7 @@ def __init__(self, emission_component = PoissonHMMEmissions(num_states, emission_dim, emission_prior_concentration=emission_prior_concentration, emission_prior_rate=emission_prior_rate) super().__init__(num_states, initial_component, transition_component, emission_component) - def initialize(self, key=jr.PRNGKey(0), + def initialize(self, key: Array=jr.PRNGKey(0), method="prior", initial_probs: Optional[Float[Array, " num_states"]]=None, transition_matrix: Optional[Float[Array, "num_states num_states"]]=None,