Skip to content

Commit

Permalink
style: fix formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
alecksphillips committed Aug 17, 2023
1 parent f8ef727 commit 9a71451
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 12 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1 @@
*/__pycache__
*/__pycache__
1 change: 1 addition & 0 deletions retrospectr/importance_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from scipy.special import logsumexp
import json


def calculate_log_weights(model, old_samples, old_data, new_data):
"""
Calculate importance weights of old_samples generated by model with old_data
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
[flake8]
max-line-length = 127
max-line-length = 127
23 changes: 13 additions & 10 deletions test/test_importance_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,10 @@ def eight_schools_data_json(eight_schools_data_file):
@pytest.fixture
def eight_schools_data_dict(eight_schools_data_file):
with open(eight_schools_data_file) as f:
json_dict= json.load(f)
json_dict = json.load(f)
return json_dict


@pytest.fixture
def eight_schools_samples():
return np.load(os.path.join(
Expand Down Expand Up @@ -71,18 +72,21 @@ def eight_schools_new_data_file():
TEST_MODELS_PATH, 'eight_schools', 'eight_schools.new_data.json'
)


@pytest.fixture
def eight_schools_new_data_json(eight_schools_new_data_file):
with open(eight_schools_new_data_file) as f:
json_data = f.read()
return json_data


@pytest.fixture
def eight_schools_new_data_dict(eight_schools_new_data_file):
with open(eight_schools_new_data_file) as f:
json_dict = json.load(f)
return json_dict


@pytest.fixture
def seven_schools_data_file():
return os.path.join(
Expand Down Expand Up @@ -155,39 +159,38 @@ def test_good(self, eight_schools_model_file, eight_schools_samples, eight_schoo
eight_schools_data_file, eight_schools_new_data_file)
np.testing.assert_almost_equal(log_weights, eight_schools_log_weights)


def test_good_json_string_data(self, eight_schools_model_file, eight_schools_samples, eight_schools_data_json, eight_schools_new_data_json,
eight_schools_log_weights):
def test_good_json_string_data(self, eight_schools_model_file, eight_schools_samples, eight_schools_data_json,
eight_schools_new_data_json, eight_schools_log_weights):
log_weights = calculate_log_weights(
eight_schools_model_file, eight_schools_samples,
eight_schools_data_json, eight_schools_new_data_json)
np.testing.assert_almost_equal(log_weights, eight_schools_log_weights)

def test_good_python_dict_data(self, eight_schools_model_file, eight_schools_samples, eight_schools_data_dict, eight_schools_new_data_dict,
eight_schools_log_weights):
def test_good_python_dict_data(self, eight_schools_model_file, eight_schools_samples, eight_schools_data_dict,
eight_schools_new_data_dict, eight_schools_log_weights):
log_weights = calculate_log_weights(
eight_schools_model_file, eight_schools_samples,
eight_schools_data_dict, eight_schools_new_data_dict)
np.testing.assert_almost_equal(log_weights, eight_schools_log_weights)


# Should get RuntimeError from bridgestan
def test_invalid_old_data(self, eight_schools_model_file, eight_schools_samples, eight_schools_bad_data_file,
eight_schools_new_data_file):
# Should get RuntimeError from bridgestan
with np.testing.assert_raises(RuntimeError):
calculate_log_weights(
eight_schools_model_file, eight_schools_samples,
eight_schools_bad_data_file, eight_schools_new_data_file)

# Should get RuntimeError from bridgestan
def test_invalid_new_data(self, eight_schools_model_file, eight_schools_samples, eight_schools_data_file,
eight_schools_bad_data_file):
# Should get RuntimeError from bridgestan
with np.testing.assert_raises(RuntimeError):
calculate_log_weights(
eight_schools_model_file, eight_schools_samples,
eight_schools_data_file, eight_schools_bad_data_file)

def test_invalid_stan_model(self, invalid_model, eight_schools_samples, eight_schools_data_file, eight_schools_new_data_file):
def test_invalid_stan_model(self, invalid_model, eight_schools_samples, eight_schools_data_file,
eight_schools_new_data_file):
with np.testing.assert_raises(ValueError):
calculate_log_weights(
invalid_model, eight_schools_samples,
Expand Down

0 comments on commit 9a71451

Please sign in to comment.