def diagnose(
    self,
    inits: Union[Dict[str, Any], str, os.PathLike, None] = None,
    data: Union[Mapping[str, Any], str, os.PathLike, None] = None,
    *,
    epsilon: Optional[float] = None,
    error: Optional[float] = None,
    require_gradients_ok: bool = True,
    sig_figs: Optional[int] = None,
) -> pd.DataFrame:
    """
    Run CmdStan's ``diagnose test=gradient`` method: evaluate the
    autodiff gradient of the log density at the given parameter values
    and compare it against a finite-difference approximation.

    Gradients are evaluated in the unconstrained space.

    :param inits: Specifies how the sampler initializes parameter values.
        Initialization is either uniform random on a range centered on 0,
        exactly 0, or a dictionary or file of initial values for some or
        all parameters in the model. The default initialization behavior
        will initialize all parameter values on range [-2, 2] on the
        *unconstrained* support. The following value types are allowed:

        * Single number, n > 0 - initialization range is [-n, n].
        * 0 - all parameters are initialized to 0.
        * dictionary - pairs parameter name : initial value.
        * string - pathname to a JSON or Rdump data file.

    :param data: Values for all data variables in the model, specified
        either as a dictionary with entries matching the data variables,
        or as the path of a data file in JSON or Rdump format.

    :param epsilon: Step size for finite difference gradients.

    :param error: Absolute error threshold for comparing autodiff and
        finite difference gradients.

    :param require_gradients_ok: Whether or not to raise an error if Stan
        reports that the difference between autodiff gradients and finite
        difference gradients exceeds the error threshold.

    :param sig_figs: Numerical precision used for output CSV and text files.
        Must be an integer between 1 and 18. If unspecified, the default
        precision for the system file I/O is used; the usual value is 6.

    :return: A pandas.DataFrame containing columns

        * "param_idx": increasing parameter index.
        * "value": Parameter value.
        * "model": Gradients evaluated using autodiff.
        * "finite_diff": Gradients evaluated using finite differences.
        * "error": Delta between autodiff and finite difference gradients.

    :raises RuntimeError: If gradients may exceed the error threshold and
        ``require_gradients_ok`` is True, or if CmdStan produced no output
        file at all.
    """
    # Serialize dict-valued `data`/`inits` to temporary JSON files;
    # path-like values pass through unchanged.
    with temp_single_json(data) as _data, \
            temp_single_json(inits) as _inits:
        cmd = [
            str(self.exe_file),
            "diagnose",
            "test=gradient",
        ]
        if epsilon is not None:
            cmd.append(f"epsilon={epsilon}")
        if error is not None:
            # Fixed: this previously emitted `epsilon={error}`, which
            # clobbered the step size and never set the threshold.
            cmd.append(f"error={error}")
        if _data is not None:
            cmd += ["data", f"file={_data}"]
        if _inits is not None:
            # CmdStan's argument name is `init`, not `inits`.
            cmd.append(f"init={_inits}")

        # Write CmdStan output to a fresh scratch directory so repeated
        # calls never collide.
        output_dir = tempfile.mkdtemp(prefix=self.name, dir=_TMPDIR)

        output = os.path.join(output_dir, "output.csv")
        cmd += ["output", f"file={output}"]
        if sig_figs is not None:
            cmd.append(f"sig_figs={sig_figs}")

        get_logger().debug("Cmd: %s", str(cmd))

        proc = subprocess.run(
            cmd, capture_output=True, check=False, text=True
        )
        if proc.returncode:
            # Surface CmdStan's own output first; the return code alone
            # does not say why the check failed.
            get_logger().error(
                "'diagnose' command failed!\nstdout:%s\nstderr:%s",
                proc.stdout,
                proc.stderr,
            )
            if require_gradients_ok:
                raise RuntimeError(
                    "The difference between autodiff and finite difference "
                    "gradients may exceed the error threshold. If you "
                    "would like to inspect the output, re-call with "
                    "`require_gradients_ok=False`."
                )
            get_logger().warning(
                "The difference between autodiff and finite difference "
                "gradients may exceed the error threshold. Proceeding "
                "because `require_gradients_ok` is set to `False`."
            )

        # A non-zero return code may still leave a usable CSV behind;
        # only a missing file is unrecoverable.
        try:
            with open(output) as handle:
                text = handle.read()
        except FileNotFoundError as exc:
            raise RuntimeError(
                "Output of 'diagnose' command does not exist."
            ) from exc

        # The gradient table is the last chunk separated by a lone '#'
        # line; strip the leading '# ' comment markers and normalize the
        # two-word headers so they parse as single column names.
        *_, table = re.split(r"#\s*\n", text)
        table = (
            re.sub(r"^#\s*", "", table, flags=re.M)
            .replace("param idx", "param_idx")
            .replace("finite diff", "finite_diff")
        )
        return pd.read_csv(io.StringIO(table), sep=r"\s+")
def test_diagnose():
    """End-to-end check of ``CmdStanModel.diagnose`` on the Bernoulli model."""
    model = CmdStanModel(stan_file=BERN_STAN)
    gradients = model.diagnose(data=BERN_DATA)

    # Check we have the right columns.
    assert set(gradients) == {
        "param_idx",
        "value",
        "model",
        "finite_diff",
        "error",
    }

    # Check gradients against the same value as in `log_prob`.
    inits = {"theta": 0.34903938392023830482}
    gradients = model.diagnose(data=BERN_DATA, inits=inits)
    np.testing.assert_allclose(gradients.model.iloc[0], -1.18847)

    # Simulate bad gradients by using large finite difference.
    with pytest.raises(RuntimeError, match="may exceed the error threshold"):
        model.diagnose(data=BERN_DATA, epsilon=3)

    # Check we get the results if we set require_gradients_ok=False.
    gradients = model.diagnose(
        data=BERN_DATA,
        epsilon=3,
        require_gradients_ok=False,
    )
    assert np.abs(gradients["error"]).max() > 1e-3
--- cmdstanpy/model.py | 4 +++- test/test_model.py | 5 +++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/cmdstanpy/model.py b/cmdstanpy/model.py index e59cc623..5add5704 100644 --- a/cmdstanpy/model.py +++ b/cmdstanpy/model.py @@ -2221,6 +2221,8 @@ def diagnose( * "model": Gradients evaluated using autodiff. * "finite_diff": Gradients evaluated using finite differences. * "error": Delta between autodiff and finite difference gradients. + + Gradients are evaluated in the unconstrained space. """ with temp_single_json(data) as _data, \ @@ -2237,7 +2239,7 @@ def diagnose( if _data is not None: cmd += ["data", f"file={_data}"] if _inits is not None: - cmd.append(f"inits={_inits}") + cmd.append(f"init={_inits}") output_dir = tempfile.mkdtemp(prefix=self.name, dir=_TMPDIR) diff --git a/test/test_model.py b/test/test_model.py index 87c48e01..61e07342 100644 --- a/test/test_model.py +++ b/test/test_model.py @@ -610,6 +610,11 @@ def test_diagnose(): "error", } + # Check gradients against the same value as in `log_prob`. + inits = {"theta": 0.34903938392023830482} + gradients = model.diagnose(data=BERN_DATA, inits=inits) + np.testing.assert_allclose(gradients.model.iloc[0], -1.18847) + # Simulate bad gradients by using large finite difference. with pytest.raises(RuntimeError, match="may exceed the error threshold"): model.diagnose(data=BERN_DATA, epsilon=3) From 2b19fc82352aebc8b976229921252dd3a2c5169b Mon Sep 17 00:00:00 2001 From: Till Hoffmann Date: Fri, 9 Feb 2024 18:36:12 -0500 Subject: [PATCH 3/3] Show `stdout` and `stderr` for non-zero return codes. 
--- cmdstanpy/model.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/cmdstanpy/model.py b/cmdstanpy/model.py index 5add5704..cd6e284f 100644 --- a/cmdstanpy/model.py +++ b/cmdstanpy/model.py @@ -2254,6 +2254,11 @@ def diagnose( cmd, capture_output=True, check=False, text=True ) if proc.returncode: + get_logger().error( + "'diagnose' command failed!\nstdout:%s\nstderr:%s", + proc.stdout, + proc.stderr, + ) if require_gradients_ok: raise RuntimeError( "The difference between autodiff and finite difference " @@ -2268,8 +2273,13 @@ def diagnose( ) # Read the text and get the last chunk separated by a single # char. - with open(output) as handle: - text = handle.read() + try: + with open(output) as handle: + text = handle.read() + except FileNotFoundError as exc: + raise RuntimeError( + "Output of 'diagnose' command does not exist." + ) from exc *_, table = re.split(r"#\s*\n", text) table = ( re.sub(r"^#\s*", "", table, flags=re.M)