diff --git a/docs/tutorials/pymc.ipynb b/docs/tutorials/pymc.ipynb index a61b3133..bec12009 100644 --- a/docs/tutorials/pymc.ipynb +++ b/docs/tutorials/pymc.ipynb @@ -61,15 +61,10 @@ "import pytensor\n", "from ssms.basic_simulators.simulator import simulator\n", "\n", - "from hssm.distribution_utils import (\n", - " make_distribution, # A general function for making Distribution classes\n", - " make_distribution_from_onnx, # Makes Distribution classes from onnx files\n", - " make_distribution_from_blackbox, # Makes Distribution classes from callables\n", - ")\n", + "# A general function for making Distribution classes\n", + "from hssm.distribution_utils import make_distribution\n", "\n", - "# pm.Distributions that represents the top-level distribution for\n", - "# DDM models (the Wiener First-Passage Time distribution)\n", - "from hssm.likelihoods import logp_ddm_sdv, DDM\n", + "# A utility function for downloading the pre-trained models\n", "from hssm.utils import download_hf\n", "\n", "pytensor.config.floatX = \"float32\"" @@ -200,7 +195,7 @@ ], "source": [ "# Simulate some data\n", - "v_true, a_true, z_true, t_true = [0.5, 1.5, 0.5, 0.5]\n", + "v_true, a_true, z_true, t_true, sv_true = [0.5, 1.5, 0.5, 0.5, 0.1]\n", "obs_ddm = simulator([v_true, a_true, z_true, t_true], model=\"ddm\", n_samples=1000)\n", "obs_ddm = np.column_stack([obs_ddm[\"rts\"][:, 0], obs_ddm[\"choices\"][:, 0]])\n", "dataset = pd.DataFrame(obs_ddm, columns=[\"rt\", \"response\"])\n", @@ -235,6 +230,9 @@ } ], "source": [ + "# This is a pm.Distribution available in HSSM\n", + "from hssm.likelihoods import DDM\n", + "\n", "with pm.Model() as ddm_pymc:\n", " v = pm.Uniform(\"v\", lower=-10.0, upper=10.0)\n", " a = pm.HalfNormal(\"a\", sigma=2.0)\n", @@ -253,16 +251,18 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Building top-level distributions with `distribution_utils`\n", + "## Building custom top-level distributions with `make_distribution`\n", "\n", - "### 
`make_distribution` and `make_distribution_from_blackbox`\n", + "**Note**: This tutorial has undergone major updates in HSSM 0.3.0 following breaking changes in the `distribution_utils` API. Please follow this tutorial closely if your previous code no longer works.\n", "\n", - "The above example shows that, as long as the top-level distribution is known, modeling can be done in `PyMC` as well without using `Bambi`. However, as [this official `PyMC` tutorial](https://www.pymc.io/projects/docs/en/v4.0.1/contributing/developer_guide_implementing_distribution.html) shows, creating a distribution in PyMC can be a time consuming-task. You will need to create a `RandomVariable` first and then define your custom `Distribution` by extending `pm.Distribution` class. From `PyMC 5.0.0` on, `pm.CustomDist` simplifies this process, but the use case is not applicable to complex likelihoods in HSSM. Fortunately, HSSM provides many convenience functions in its `distribution_utils` submodule that make this process easy. Next, we use another example to show how we can use these functions to create custom `pm.Distribution`s to be used with `PyMC`.\n", + "### `make_distribution`\n", + "\n", + "The above example shows that, as long as the top-level distribution is known, modeling can be done in `PyMC` as well without using `Bambi`. However, as [this official `PyMC` tutorial](https://www.pymc.io/projects/docs/en/latest/contributing/implementing_distribution.html) shows, creating a distribution in PyMC can be a time-consuming task. You will need to create a `RandomVariable` first and then define your custom `Distribution` by extending the `pm.Distribution` class. From `PyMC 5.0.0` on, `pm.CustomDist` simplifies this process, but the use case is not applicable to complex likelihoods in HSSM. Therefore, HSSM provides convenience functions in its `distribution_utils` submodule that make this process easy.
Next, we use another example to show how we can use these functions to create custom `pm.Distribution`s to be used with `PyMC`.\n", "\n", "Suppose we have a likelihood function for DDM models with standard deviations for `v` written. This model has 5 parameters: `v, a, z, t, sv`, and we want to use this function as the likelihood to create a `pm.Distribution` for modeling with `PyMC`. We can use `make_distribution` for this purpose.\n", "\n", "**Note**\n", - "This distribution is already available in HSSM at `hssm.likelihoods.DDM_SDV`. For illustration purposes, we go through the process in which this distribution is created. We can use the same procedure for other distributions not currently available in HSSM." + "This distribution is already available in HSSM at `hssm.likelihoods.DDM_SDV`. For illustration purposes, we go through the same process by which this distribution is created. We can use the same procedure for other distributions not currently available in HSSM." ] }, { @@ -271,12 +271,30 @@ "metadata": {}, "outputs": [], "source": [ + "# This is a likelihood function for the DDM with SDV\n", + "# Unlike DDM, which we imported in the previous example, it is not a pm.Distribution\n", + "from hssm.likelihoods import logp_ddm_sdv\n", + "\n", + "# We use `make_distribution` to wrap the likelihood function into a pm.Distribution\n", "DDM_SDV = make_distribution(\n", " rv=\"ddm_sdv\",\n", " loglik=logp_ddm_sdv,\n", " list_params=[\"v\", \"a\", \"z\", \"t\", \"sv\"],\n", " bounds={\"t\": (0, 1)},\n", - ")" + ")\n", + "\n", + "with pm.Model() as ddm_sdv_model:\n", + " v = pm.Uniform(\"v\", lower=-10.0, upper=10.0)\n", + " a = pm.HalfNormal(\"a\", sigma=2.0)\n", + " z = pm.Uniform(\"z\", lower=0.01, upper=0.99)\n", + " t = pm.Uniform(\"t\", lower=0.0, upper=0.6, initval=0.1)\n", + " sv = pm.HalfNormal(\"sv\", sigma=2.0)\n", + "\n", + " ddm = DDM_SDV(\"ddm\", v=v, a=a, z=z, t=t, sv=sv, observed=dataset.values)\n", + "\n", + " ddm_sdv_trace = pm.sample()\n", +
"\n", + "az.plot_trace(ddm_sdv_trace)" ] }, { @@ -284,11 +302,11 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Now we have a `pm.Distribution` that can be used for modeling in `PyMC`. The procedure is very similar to the one described in the previous section, so we will omit it here. Let's instead focus on the parameters for `make_distribution`.\n", + "`make_distribution` creates a `pm.Distribution` class that can be used for modeling in `PyMC`, as in the code above. It takes the following parameters:\n", "\n", "- `rv`: a `str` or a `RandomVariable`. If a `str` is provided, a `RandomVariable` class will be created automatically. This `RandomVariable` will use the `str` to identify a simulator provided in the `ssm_simulators` package as its `rng_fn` (sampling) function. If this `str` is not one of the entries to the `model_config` `dict` specified [here](https://github.com/AlexanderFengler/ssm-simulators/blob/main/ssms/config/config.py), then the `Distribution` will still be created but with a warning that any attempt to sample from the `RandomVariable` will result in an Error. That includes sampling from the posterior distribution. The user could create his/her own `RandomVariable` class and define its `rng_fn` class method for sampling.\n", "\n", - "- `loglik`: an `Op` or a `Callable`. HSSM assumes that the `Callable` to be written with `pytensor` functions and can be compiled into a `pytensor` computation graph. If that is not the case, please use `make_distribution`'s close sibling `make_distribution_from_blackbox`, which will first wrap the `Callable` in an `Op` and then create the `pm.Distribution`. The signature of `make_distribution_from_blackbox` is almost identical to that of `make_distribution`.\n", + "- `loglik`: an `Op` or a `Callable`. HSSM assumes that the `Callable` is a function written in `pytensor` and can be compiled into a `pytensor` computation graph.
If that is not the case, please use `make_distribution`'s close sibling `make_distribution_from_blackbox`, which will first wrap the `Callable` in an `Op` and then create the `pm.Distribution`. The signature of `make_distribution_from_blackbox` is almost identical to that of `make_distribution`.\n", "\n", " The function signature for the `Op` or `Callable` has to follow a specific pattern. Please refer to [this section](likelihoods.ipynb#using-custom-likelihoods) for more details.\n", "\n",