
Commit

Merge pull request #6 from alecksphillips/feat/data-inputs
Handle data dict inputs
alecksphillips authored Aug 17, 2023
2 parents a2304d6 + 9a71451 commit 30fd841
Showing 4 changed files with 82 additions and 30 deletions.
2 changes: 1 addition & 1 deletion .gitignore
@@ -1 +1 @@
*/__pycache__
*/__pycache__
12 changes: 10 additions & 2 deletions retrospectr/importance_weights.py
@@ -2,6 +2,7 @@
import numpy as np
import cmdstanpy
from scipy.special import logsumexp
import json


def calculate_log_weights(model, old_samples, old_data, new_data):
@@ -11,11 +12,18 @@ def calculate_log_weights(model, old_samples, old_data, new_data):
param: model: Path to .stan file containing the model
param: old_samples: numpy array of samples from model with original/old data
param: old_data: Path to json file containing old data
param: new_data: Path to json file containing new data
param: old_data: Old data, one of: python dict; json string literal; path to .json file
param: new_data: New data, one of: python dict; json string literal; path to .json file
returns: log importance weights of old_samples conditioned on new data
"""

if isinstance(old_data, dict):
old_data = json.dumps(old_data)

if isinstance(new_data, dict):
new_data = json.dumps(new_data)

# check old samples match the model+old_data
old_model = bs.StanModel.from_stan_file(model, model_data=old_data)
check_sample_dim(old_model, old_samples)
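
With this change, the old_data and new_data arguments of calculate_log_weights can each be a Python dict, a JSON string, or a path to a .json file; dicts are serialized with json.dumps before being handed to bridgestan. A minimal usage sketch, assuming the standard eight-schools data layout (keys J, y, sigma) and hypothetical file paths:

import numpy as np
from retrospectr.importance_weights import calculate_log_weights

# Old observations: the classic eight-schools data as a plain Python dict.
old_data = {
    "J": 8,
    "y": [28, 8, -3, 7, -1, 1, 18, 12],
    "sigma": [15, 10, 16, 11, 9, 11, 10, 18],
}

# New observations: illustrative values only.
new_data = {
    "J": 8,
    "y": [30, 10, -1, 9, 1, 3, 20, 14],
    "sigma": [15, 10, 16, 11, 9, 11, 10, 18],
}

# Samples previously drawn from the model conditioned on old_data
# (e.g. with cmdstanpy), stored as a numpy array.
old_samples = np.load("eight_schools.samples.npy")  # hypothetical path

log_weights = calculate_log_weights(
    "eight_schools.stan",  # hypothetical path to the .stan model file
    old_samples,
    old_data,   # a dict; a JSON string or a .json path would work equally well
    new_data,
)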
2 changes: 1 addition & 1 deletion setup.cfg
@@ -1,2 +1,2 @@
[flake8]
max-line-length = 127
max-line-length = 127
96 changes: 70 additions & 26 deletions test/test_importance_weights.py
@@ -4,6 +4,7 @@
import bridgestan as bs
import cmdstanpy
import numpy as np
import json

from retrospectr.importance_weights import (
calculate_log_weights,
@@ -24,12 +25,26 @@ def eight_schools_model_file():


@pytest.fixture
def eight_schools_data():
def eight_schools_data_file():
return os.path.join(
TEST_MODELS_PATH, 'eight_schools', 'eight_schools.data.json'
)


@pytest.fixture
def eight_schools_data_json(eight_schools_data_file):
with open(eight_schools_data_file) as f:
json_data = f.read()
return json_data


@pytest.fixture
def eight_schools_data_dict(eight_schools_data_file):
with open(eight_schools_data_file) as f:
json_dict = json.load(f)
return json_dict


@pytest.fixture
def eight_schools_samples():
return np.load(os.path.join(
@@ -52,14 +67,28 @@ def eight_schools_log_weights():


@pytest.fixture
def eight_schools_new_data():
def eight_schools_new_data_file():
return os.path.join(
TEST_MODELS_PATH, 'eight_schools', 'eight_schools.new_data.json'
)


@pytest.fixture
def seven_schools_data():
def eight_schools_new_data_json(eight_schools_new_data_file):
with open(eight_schools_new_data_file) as f:
json_data = f.read()
return json_data


@pytest.fixture
def eight_schools_new_data_dict(eight_schools_new_data_file):
with open(eight_schools_new_data_file) as f:
json_dict = json.load(f)
return json_dict


@pytest.fixture
def seven_schools_data_file():
return os.path.join(
TEST_MODELS_PATH, 'eight_schools', 'seven_schools.data.json'
)
@@ -73,32 +102,32 @@ def seven_schools_samples():


@pytest.fixture
def eight_schools_bad_data():
def eight_schools_bad_data_file():
return os.path.join(
TEST_MODELS_PATH, 'eight_schools', 'eight_schools.bad_data.json'
)


@pytest.fixture
def eight_schools_cmdstanpy_fit(eight_schools_model_file, eight_schools_data):
def eight_schools_cmdstanpy_fit(eight_schools_model_file, eight_schools_data_file):
model = cmdstanpy.CmdStanModel(stan_file=eight_schools_model_file)
fit = model.sample(data=eight_schools_data, chains=2, iter_sampling=200, iter_warmup=200, seed=0)
fit = model.sample(data=eight_schools_data_file, chains=2, iter_sampling=200, iter_warmup=200, seed=0)
return fit


@pytest.fixture
def eight_schools_bs_model(eight_schools_model_file, eight_schools_data):
return bs.StanModel.from_stan_file(eight_schools_model_file, model_data=eight_schools_data)
def eight_schools_bs_model(eight_schools_model_file, eight_schools_data_file):
return bs.StanModel.from_stan_file(eight_schools_model_file, model_data=eight_schools_data_file)


@pytest.fixture
def eight_schools_new_bs_model(eight_schools_model_file, eight_schools_new_data):
return bs.StanModel.from_stan_file(eight_schools_model_file, model_data=eight_schools_new_data)
def eight_schools_new_bs_model(eight_schools_model_file, eight_schools_new_data_file):
return bs.StanModel.from_stan_file(eight_schools_model_file, model_data=eight_schools_new_data_file)


@pytest.fixture
def seven_schools_bs_model(eight_schools_model_file, seven_schools_data):
return bs.StanModel.from_stan_file(eight_schools_model_file, model_data=seven_schools_data)
def seven_schools_bs_model(eight_schools_model_file, seven_schools_data_file):
return bs.StanModel.from_stan_file(eight_schools_model_file, model_data=seven_schools_data_file)


@pytest.fixture
@@ -123,40 +152,55 @@ def invalid_model():


class TestCalculateLogWeights:
def test_good(self, eight_schools_model_file, eight_schools_samples, eight_schools_data, eight_schools_new_data,
def test_good(self, eight_schools_model_file, eight_schools_samples, eight_schools_data_file, eight_schools_new_data_file,
eight_schools_log_weights):
log_weights = calculate_log_weights(
eight_schools_model_file, eight_schools_samples,
eight_schools_data, eight_schools_new_data)
eight_schools_data_file, eight_schools_new_data_file)
np.testing.assert_almost_equal(log_weights, eight_schools_log_weights)

def test_good_json_string_data(self, eight_schools_model_file, eight_schools_samples, eight_schools_data_json,
eight_schools_new_data_json, eight_schools_log_weights):
log_weights = calculate_log_weights(
eight_schools_model_file, eight_schools_samples,
eight_schools_data_json, eight_schools_new_data_json)
np.testing.assert_almost_equal(log_weights, eight_schools_log_weights)

def test_good_python_dict_data(self, eight_schools_model_file, eight_schools_samples, eight_schools_data_dict,
eight_schools_new_data_dict, eight_schools_log_weights):
log_weights = calculate_log_weights(
eight_schools_model_file, eight_schools_samples,
eight_schools_data_dict, eight_schools_new_data_dict)
np.testing.assert_almost_equal(log_weights, eight_schools_log_weights)

# Should get RuntimeError from bridgestan
def test_invalid_old_data(self, eight_schools_model_file, eight_schools_samples, eight_schools_bad_data,
eight_schools_new_data):
def test_invalid_old_data(self, eight_schools_model_file, eight_schools_samples, eight_schools_bad_data_file,
eight_schools_new_data_file):
# Should get RuntimeError from bridgestan
with np.testing.assert_raises(RuntimeError):
calculate_log_weights(
eight_schools_model_file, eight_schools_samples,
eight_schools_bad_data, eight_schools_new_data)
eight_schools_bad_data_file, eight_schools_new_data_file)

# Should get RuntimeError from bridgestan
def test_invalid_new_data(self, eight_schools_model_file, eight_schools_samples, eight_schools_data,
eight_schools_bad_data):
def test_invalid_new_data(self, eight_schools_model_file, eight_schools_samples, eight_schools_data_file,
eight_schools_bad_data_file):
# Should get RuntimeError from bridgestan
with np.testing.assert_raises(RuntimeError):
calculate_log_weights(
eight_schools_model_file, eight_schools_samples,
eight_schools_data, eight_schools_bad_data)
eight_schools_data_file, eight_schools_bad_data_file)

def test_invalid_stan_model(self, invalid_model, eight_schools_samples, eight_schools_data, eight_schools_new_data):
def test_invalid_stan_model(self, invalid_model, eight_schools_samples, eight_schools_data_file,
eight_schools_new_data_file):
with np.testing.assert_raises(ValueError):
calculate_log_weights(
invalid_model, eight_schools_samples,
eight_schools_data, eight_schools_new_data)
eight_schools_data_file, eight_schools_new_data_file)

def test_invalid_samples(self, invalid_model, seven_schools_samples, eight_schools_data, eight_schools_new_data):
def test_invalid_samples(self, invalid_model, seven_schools_samples, eight_schools_data_file, eight_schools_new_data_file):
with np.testing.assert_raises(ValueError):
calculate_log_weights(
invalid_model, seven_schools_samples,
eight_schools_data, eight_schools_new_data)
eight_schools_data_file, eight_schools_new_data_file)


class TestEvaluateLogProb():
