diff --git a/cmdstanpy/utils/stancsv.py b/cmdstanpy/utils/stancsv.py index 1a095f48..b7a3b21c 100644 --- a/cmdstanpy/utils/stancsv.py +++ b/cmdstanpy/utils/stancsv.py @@ -193,7 +193,12 @@ def scan_config(fd: TextIO, config_dict: Dict[str, Any], lineno: int) -> int: try: val = float(raw_val) except ValueError: - val = raw_val + if raw_val == "true": + val = 1 + elif raw_val == "false": + val = 0 + else: + val = raw_val config_dict[key_val[0].strip()] = val cur_pos = fd.tell() line = fd.readline().strip() diff --git a/docsrc/changes.rst b/docsrc/changes.rst index 651c46f0..8427effa 100644 --- a/docsrc/changes.rst +++ b/docsrc/changes.rst @@ -7,6 +7,12 @@ What's New For full changes, see the `Releases page `_ on GitHub. +CmdStanPy 1.2.4 +--------------- + +- Fixed a bug in `from_csv` which prevented reading files created by CmdStan 2.35.0+ + +Reminder: The next non-bugfix release of CmdStanPy will be version 2.0, which will remove all existing deprecations. CmdStanPy 1.2.3 --------------- diff --git a/test/test_sample.py b/test/test_sample.py index 128d1625..369132dc 100644 --- a/test/test_sample.py +++ b/test/test_sample.py @@ -2030,6 +2030,26 @@ def test_tuple_data_in() -> None: data_model.sample(data, chains=1, iter_warmup=1, iter_sampling=1) +def test_csv_roundtrip(): + stan = os.path.join(DATAFILES_PATH, 'matrix_var.stan') + model = CmdStanModel(stan_file=stan) + fit = model.sample( + iter_sampling=10, iter_warmup=9, chains=2, save_warmup=True + ) + z = fit.stan_variable(var="z") + assert z.shape == (20, 4, 3) + z_with_warmup = fit.stan_variable(var="z", inc_warmup=True) + assert z_with_warmup.shape == (38, 4, 3) + + # mostly just asserting that from_csv always succeeds + # in parsing latest cmdstan headers + fit_from_csv = from_csv(fit.runset.csv_files) + z_from_csv = fit_from_csv.stan_variable(var="z") + assert z_from_csv.shape == (20, 4, 3) + z_with_warmup_from_csv = fit.stan_variable(var="z", inc_warmup=True) + assert z_with_warmup_from_csv.shape == (38, 4, 3) + + @pytest.mark.order(before="test_no_xarray") def test_serialization(stanfile='bernoulli.stan'): # This test must before any test that uses the `without_import` context