The challenge is that we'd ideally like to maximize E[\log p(z, y; \theta)], but for the Poisson GLM likelihood this involves expectations that are not analytically tractable. Specifically, for a mean function f, we have:
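Writing the linear predictor as \eta_t = c^\top z_t + d (notation assumed here just for illustration), the per-timestep expected log likelihood expands as

E_q[\log p(y_t \mid z_t)] = y_t \, E_q[\log f(\eta_t)] - E_q[f(\eta_t)] - \log y_t!

where \eta_t is Gaussian under the posterior q(z_t), say \eta_t \sim N(\mu_t, \sigma_t^2). The two expectations E_q[\log f(\eta_t)] and E_q[f(\eta_t)] are the problematic terms.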
If f(x) = e^x then we can compute both in closed form, but not in the general case. We just defaulted to a Monte Carlo approximation instead, but we could consider alternatives. For example, we could take a first- or second-order Taylor approximation of f and \log f to reduce these to Gaussian integrals. That would be pretty straightforward with JAX and could be more efficient and/or lead to nicer convergence.
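A minimal sketch of the second-order version, assuming the linear predictor \eta is Gaussian with mean mu and variance var under the posterior; the helper names here (gauss_expectation_2nd_order, approx_expected_poisson_ll) are hypothetical, not part of ssm-jax:

```python
import jax
import jax.numpy as jnp

def gauss_expectation_2nd_order(g, mu, var):
    """Approximate E[g(eta)] for eta ~ N(mu, var) using a second-order
    Taylor expansion of g around mu: E[g(eta)] ~ g(mu) + 0.5 * g''(mu) * var."""
    g_hess = jax.grad(jax.grad(g))(mu)
    return g(mu) + 0.5 * g_hess * var

def approx_expected_poisson_ll(f, y, mu, var):
    """Approximate E[y * log f(eta) - f(eta)] (dropping the -log y! constant)
    for a single observation y, with eta ~ N(mu, var)."""
    e_log_f = gauss_expectation_2nd_order(lambda x: jnp.log(f(x)), mu, var)
    e_f = gauss_expectation_2nd_order(f, mu, var)
    return y * e_log_f - e_f

# Example with a softplus mean function. For f = exp the exact values are
# E[log f(eta)] = mu and E[f(eta)] = exp(mu + var / 2); the Taylor version
# matches the first exactly and the second up to O(var^2) error.
print(approx_expected_poisson_ll(jax.nn.softplus, 3.0, 0.5, 0.1))
```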
In https://github.com/lindermanlab/ssm-jax-refactor/blob/main/ssm/lds/emissions.py#L263, you take the Gaussian expected sufficient statistics E[z_t y_t], and then sample from them, before fitting the Poisson model on this sampled data (IIUC). Is the sampling step necessary? Can you use weighted MLE?
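One way to read "weighted MLE" here: keep the posterior samples (or quadrature / sigma points) but fold them into a single weighted Poisson regression, rather than fitting the GLM on the sampled data as if it were real. A minimal sketch under that assumption; the emission form f(C z_t + d) and the names C, d, weighted_poisson_nll are illustrative, not taken from emissions.py:

```python
import jax
import jax.numpy as jnp

def weighted_poisson_nll(params, z_nodes, node_weights, y, f=jnp.exp):
    """Negative expected Poisson log likelihood, with the expectation over the
    Gaussian posterior q(z_t) replaced by weighted nodes z_t^{(s)}: posterior
    samples with weights 1/S, or quadrature/sigma points with their weights.

    params:       (C, d) emission weights (N, K) and bias (N,)  -- names assumed
    z_nodes:      (T, S, K) latent nodes per time step
    node_weights: (S,) nonnegative weights summing to one
    y:            (T, N) observed counts
    """
    C, d = params
    rates = f(jnp.einsum("tsk,nk->tsn", z_nodes, C) + d)   # (T, S, N)
    ll = y[:, None, :] * jnp.log(rates) - rates            # dropping the log y! constant
    return -jnp.sum(node_weights[None, :, None] * ll)

# Toy usage: one full-batch gradient step (any optimizer would do).
T, S, K, N = 100, 10, 3, 5
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
z_nodes = jax.random.normal(k1, (T, S, K))
node_weights = jnp.full((S,), 1.0 / S)
y = jax.random.poisson(k2, 2.0, (T, N)).astype(jnp.float32)
params = (0.1 * jax.random.normal(k3, (N, K)), jnp.zeros(N))
loss, grads = jax.value_and_grad(weighted_poisson_nll)(params, z_nodes, node_weights, y)
params = jax.tree_util.tree_map(lambda p, g: p - 1e-3 * g, params, grads)
```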