Skip to content

Commit

Permalink
add more tests and improve code coverage (#273)
Browse files Browse the repository at this point in the history
* improve tests and code coverage

* notebook

* more tests
  • Loading branch information
LegrandNico authored and SylvainEstebe committed Jan 17, 2025
1 parent 6f55a36 commit de4acf6
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 15 deletions.
30 changes: 16 additions & 14 deletions docs/source/notebooks/0.3-Generalised_filtering.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -347,30 +347,32 @@
"tags": []
},
"source": [
"````{note} From sufficient statistics to distribution parameters and backwards\n",
":::{note} From sufficient statistics to distribution parameters and backwards\n",
":class: dropdown\n",
"\n",
"When using a 1-dimensional Gaussian distribution, Setting $\\xi = [0, \\frac{1}{8}]$ is equivalent to a mean $\\mu = 0.0$ and a variance $\\sigma^2 = \\frac{1}{8}$. You can convert between distribution parameters and expected sufficient statistics using the distribution classes from PyHGF (when implemented):\n",
"\n",
"```{code-cell} python\n",
"```python\n",
"from pyhgf.math import Normal\n",
"\n",
"# from an observation to sufficient statistics\n",
"Normal.sufficient_statistics_from_observations(x=1.5)\n",
"```\n",
"> Array([1.5 , 2.25], dtype=float32)\n",
"```{code-cell} python\n",
"\n",
"```python\n",
"# from distribution parameters to sufficient statistics\n",
"Normal.sufficient_statistics_from_parameters(mean=0.0, variance=4.0)\n",
"```\n",
"> Array([0., 4.], dtype=float32)\n",
"```{code-cell} python\n",
"\n",
"```python\n",
"# from sufficient statistics to distribution parameters\n",
"Normal.parameters_from_sufficient_statistics(xis=[0.0, 4.0])\n",
"```\n",
"> (0.0, 4.0)\n",
"\n",
"````"
":::"
]
},
{
Expand Down Expand Up @@ -417,7 +419,7 @@
"</svg>\n"
],
"text/plain": [
"<graphviz.sources.Source at 0x7f69845ec530>"
"<graphviz.sources.Source at 0x7fec53b15520>"
]
},
"execution_count": 7,
Expand Down Expand Up @@ -665,7 +667,7 @@
"</svg>\n"
],
"text/plain": [
"<graphviz.sources.Source at 0x7f6984591040>"
"<graphviz.sources.Source at 0x7fec4dea0950>"
]
},
"execution_count": 13,
Expand Down Expand Up @@ -1239,7 +1241,7 @@
"</svg>\n"
],
"text/plain": [
"<graphviz.sources.Source at 0x7f69784b16a0>"
"<graphviz.sources.Source at 0x7fec3c292780>"
]
},
"execution_count": 21,
Expand Down Expand Up @@ -1292,7 +1294,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Last updated: Tue Jan 14 2025\n",
"Last updated: Wed Jan 15 2025\n",
"\n",
"Python implementation: CPython\n",
"Python version : 3.12.3\n",
Expand All @@ -1302,13 +1304,13 @@
"jax : 0.4.31\n",
"jaxlib: 0.4.31\n",
"\n",
"numpy : 1.26.0\n",
"sys : 3.12.3 | packaged by conda-forge | (main, Apr 15 2024, 18:38:13) [GCC 12.3.0]\n",
"pyhgf : 0.2.1.post4.dev0+d49aafe9\n",
"matplotlib: 3.10.0\n",
"jax : 0.4.31\n",
"IPython : 8.31.0\n",
"seaborn : 0.13.2\n",
"matplotlib: 3.10.0\n",
"IPython : 8.31.0\n",
"sys : 3.12.3 | packaged by conda-forge | (main, Apr 15 2024, 18:38:13) [GCC 12.3.0]\n",
"pyhgf : 0.2.1.post4.dev0+d49aafe9\n",
"numpy : 1.26.0\n",
"\n",
"Watermark: 2.5.0\n",
"\n"
Expand Down
32 changes: 31 additions & 1 deletion tests/test_nodes/test_exponential_family.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,10 @@ def test_multivariate_gaussian():
+ np.random.randn(N, 2) * 2
)

# Python ---------------------------------------------------------------------------
# Python
# ----------------------------------------------------------------------------------

# generalised filtering
bivariate_normal = (
PyNetwork()
.add_nodes(
Expand All @@ -64,3 +67,30 @@ def test_multivariate_gaussian():
dtype="float32",
),
).all()

# hgf updates
bivariate_hgf = PyNetwork().add_nodes(
kind="ef-state",
learning="hgf-2",
distribution="multivariate-normal",
dimension=2,
)

# adapting prior parameter values to the sufficient statistics
# covariances statistics will have greater variability and amplitudes
for node_idx in [2, 5, 8, 11, 14]:
bivariate_hgf.attributes[node_idx]["tonic_volatility"] = -2.0
for node_idx in [1, 4, 7, 10, 13]:
bivariate_hgf.attributes[node_idx]["precision"] = 0.01
for node_idx in [9, 12, 15]:
bivariate_hgf.attributes[node_idx]["mean"] = 10.0

bivariate_hgf.input_data(input_data=spiral_data)

assert jnp.isclose(
bivariate_normal.node_trajectories[0]["xis"][-1],
jnp.array(
[3.4652710e01, -1.0609777e00, 1.2103647e03, -3.6398651e01, 3.3951855e00],
dtype="float32",
),
).all()

0 comments on commit de4acf6

Please sign in to comment.