Add weighted-resampling reference data and tests for retrospectr.resampling.resample.

  test/generate_test_data.py
      additionally writes eight_schools_resampled_samples.npy
  test/test_models/eight_schools/eight_schools_resampled_samples.npy
      new binary fixture (payload is not part of this text diff)
  test/test_resample.py
      new pytest module comparing resample() with the stored fixture

NOTE(review): the `index` lines keep the blob ids of the original patch;
regenerate the patch with git if exact ids matter.

diff --git a/test/generate_test_data.py b/test/generate_test_data.py
index db4a535..59e1682 100644
--- a/test/generate_test_data.py
+++ b/test/generate_test_data.py
@@ -52,3 +52,21 @@
 log_weights = new_logProbs - logProbs
 log_weights = log_weights - logsumexp(log_weights)
 np.save(os.path.join(model_path, "eight_schools_log_weights.npy"), log_weights)
+
+# Reference output for test/test_resample.py: draw `nsamples` rows with
+# replacement, with probabilities exp(log_weights), from the samples
+# flattened over their first two axes.
+nsamples = samples.shape[0]*samples.shape[1]
+tmp_samples = samples.reshape((nsamples, 1, samples.shape[2]))
+tmp_log_weights = log_weights.reshape((nsamples))
+
+# The tests call resample(..., seed=0) and expect exactly this array, so the
+# generator seed below must stay in sync with them.
+rng = np.random.default_rng(seed=0)
+resampled_iterations = rng.choice(
+    nsamples,
+    size=nsamples,
+    p=np.exp(tmp_log_weights))
+
+resampled_samples = tmp_samples[resampled_iterations, :, :]
+np.save(os.path.join(model_path, "eight_schools_resampled_samples.npy"), resampled_samples)
diff --git a/test/test_models/eight_schools/eight_schools_resampled_samples.npy b/test/test_models/eight_schools/eight_schools_resampled_samples.npy
new file mode 100644
index 0000000..3d07bd5
Binary files /dev/null and b/test/test_models/eight_schools/eight_schools_resampled_samples.npy differ
diff --git a/test/test_resample.py b/test/test_resample.py
new file mode 100644
index 0000000..d9a5893
--- /dev/null
+++ b/test/test_resample.py
@@ -0,0 +1,114 @@
+"""Tests for retrospectr.resampling.resample on the eight-schools fixtures."""
+
+from pathlib import Path
+import os
+import pytest
+import cmdstanpy
+import numpy as np
+import json
+
+from retrospectr.resampling import resample
+
+TEST_MODELS_PATH = os.path.join(Path(__file__).parent, 'test_models')
+
+
+@pytest.fixture
+def eight_schools_model_file():
+    return os.path.join(
+        TEST_MODELS_PATH, 'eight_schools', 'eight_schools.stan'
+    )
+
+
+@pytest.fixture
+def eight_schools_data_file():
+    return os.path.join(
+        TEST_MODELS_PATH, 'eight_schools', 'eight_schools.data.json'
+    )
+
+
+@pytest.fixture
+def eight_schools_data_json(eight_schools_data_file):
+    with open(eight_schools_data_file) as f:
+        json_data = f.read()
+    return json_data
+
+
+@pytest.fixture
+def eight_schools_data_dict(eight_schools_data_file):
+    with open(eight_schools_data_file) as f:
+        json_dict = json.load(f)
+    return json_dict
+
+
+@pytest.fixture
+def eight_schools_samples():
+    return np.load(os.path.join(
+        TEST_MODELS_PATH, 'eight_schools', 'eight_schools_samples.npy'
+    ))
+
+
+@pytest.fixture
+def eight_schools_log_weights():
+    return np.load(os.path.join(
+        TEST_MODELS_PATH, 'eight_schools', 'eight_schools_log_weights.npy'
+    ))
+
+
+@pytest.fixture
+def eight_schools_resampled_samples():
+    return np.load(os.path.join(
+        TEST_MODELS_PATH, 'eight_schools', 'eight_schools_resampled_samples.npy'
+    ))
+
+
+@pytest.fixture
+def seven_schools_data_file():
+    return os.path.join(
+        TEST_MODELS_PATH, 'eight_schools', 'seven_schools.data.json'
+    )
+
+
+@pytest.fixture
+def seven_schools_samples():
+    return np.load(os.path.join(
+        TEST_MODELS_PATH, 'eight_schools', 'seven_schools_samples.npy'
+    ))
+
+
+@pytest.fixture
+def eight_schools_bad_data_file():
+    return os.path.join(
+        TEST_MODELS_PATH, 'eight_schools', 'eight_schools.bad_data.json'
+    )
+
+
+@pytest.fixture
+def eight_schools_cmdstanpy_fit(eight_schools_model_file, eight_schools_data_file):
+    model = cmdstanpy.CmdStanModel(stan_file=eight_schools_model_file)
+    fit = model.sample(data=eight_schools_data_file, chains=2, iter_sampling=200, iter_warmup=200, seed=0)
+    return fit
+
+
+class TestResampled:
+    """resample() must match the stored reference array and raise ValueError on bad log weights."""
+
+    def test_good_sample_array(self, eight_schools_samples, eight_schools_log_weights, eight_schools_resampled_samples):
+        resampled_samples = resample(eight_schools_samples, eight_schools_log_weights, seed=0)
+        np.testing.assert_equal(resampled_samples, eight_schools_resampled_samples)
+
+    def test_good_cmdstanpy_fit(self, eight_schools_cmdstanpy_fit, eight_schools_log_weights, eight_schools_resampled_samples):
+        resampled_samples = resample(eight_schools_cmdstanpy_fit, eight_schools_log_weights, seed=0)
+        np.testing.assert_equal(resampled_samples, eight_schools_resampled_samples)
+
+    def test_bad_weights_wrong_len(self, eight_schools_samples, eight_schools_log_weights):
+        tmp_log_weights = eight_schools_log_weights[0:(len(eight_schools_log_weights)-1)]
+        with np.testing.assert_raises(ValueError):
+            resample(eight_schools_samples, tmp_log_weights, seed=0)
+
+    def test_bad_weights_not_sum_to_one(self, eight_schools_samples, eight_schools_log_weights):
+        # Copy first: plain assignment would alias the fixture array and corrupt it in place.
+        tmp_log_weights = eight_schools_log_weights.copy()
+        # exp(1.0) > 1, so the exponentiated weights can no longer sum to one.
+        tmp_log_weights[0] = 1.0
+        with np.testing.assert_raises(ValueError):
+            resample(eight_schools_samples, tmp_log_weights, seed=0)