Skip to content

Commit

Permalink
computing log likelihood via pymc now, there is something going on wi…
Browse files Browse the repository at this point in the history
…th the bambi version
  • Loading branch information
AlexanderFengler committed Aug 20, 2024
1 parent ea10e10 commit 28361ce
Showing 1 changed file with 9 additions and 18 deletions.
27 changes: 9 additions & 18 deletions src/hssm/hssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,6 +592,12 @@ def sample(
"pymc" if sampler == "mcmc" else sampler.split("_")[1]
)

# Don't compute likelihood directly through pymc sampler
compute_likelihood = True
if "idata_kwargs" in kwargs:
if "log_likelihood" in kwargs["idata_kwargs"]:
compute_likelihood = kwargs["idata_kwargs"].pop("log_likelihood", True)

self._inference_obj = self.model.fit(
inference_method=(
"mcmc"
Expand All @@ -603,24 +609,9 @@ def sample(
**kwargs,
)

# The parent was previously not part of deterministics --> compute it via
# posterior_predictive (works because it acts as the 'mu' parameter
# in the GLM as far as bambi is concerned)
if self._inference_obj is not None:
if self._parent not in self._inference_obj.posterior.data_vars:
# self.model.predict(self._inference_obj, kind="mean", inplace=True)
# rename 'rt,response_mean' to 'v' so in the traces everything
# looks the way it should
self._inference_obj.rename_vars(
{"rt,response_mean": self._parent}, inplace=True
)
elif (
self._parent in self._inference_obj.posterior.data_vars
and "rt,response_mean" in self._inference_obj.posterior.data_vars
):
# drop redundant 'rt,response_mean' variable,
# if parent already in posterior
del self._inference_obj.posterior["rt,response_mean"]
if compute_likelihood:
with self.pymc_model:
pm.compute_log_likelihood(self._inference_obj)

# Subset data vars in posterior
if hasattr(self, "pymc_model") and self._inference_obj is not None:
Expand Down

0 comments on commit 28361ce

Please sign in to comment.