Skip to content

Commit

Permalink
Merge pull request #8 from alecksphillips/feat/resample
Browse files Browse the repository at this point in the history
Resampling
  • Loading branch information
alecksphillips authored Aug 17, 2023
2 parents cf591f2 + f5398eb commit c69b7cd
Show file tree
Hide file tree
Showing 5 changed files with 148 additions and 1 deletion.
2 changes: 1 addition & 1 deletion retrospectr/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
all = ["importance_weights"]
all = ["importance_weights", "resampling"]
26 changes: 26 additions & 0 deletions retrospectr/resampling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import numpy as np
import cmdstanpy
from retrospectr.importance_weights import extract_samples


def resample(samples, log_weights, seed=0):

if isinstance(samples, cmdstanpy.CmdStanMCMC):
samples = extract_samples(samples)

rng = np.random.default_rng(seed=seed)
niters = log_weights.shape[0]
nchains = log_weights.shape[1]
nparams = samples.shape[2]

nsamples = niters*nchains
flat_log_weights = log_weights.reshape((nsamples))

resampled_iterations = rng.choice(
nsamples,
size=nsamples,
p=np.exp(flat_log_weights))

flat_samples = samples.reshape(nsamples, 1, nparams)
resampled_samples = flat_samples[resampled_iterations, :]
return resampled_samples
13 changes: 13 additions & 0 deletions test/generate_test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,16 @@
# Log importance weights: difference between the new and original log
# probabilities of each draw.
log_weights = new_logProbs - logProbs
# Subtract logsumexp so the weights are normalised in log space
# (presumably scipy.special.logsumexp -- its import is not visible here).
log_weights = log_weights - logsumexp(log_weights)
np.save(os.path.join(model_path, "eight_schools_log_weights.npy"), log_weights)

# Build the expected output for the resampling tests: flatten the first two
# axes of `samples` into a single axis of length nsamples, keeping a
# singleton middle axis, and flatten the weights to match.
nsamples = samples.shape[0]*samples.shape[1]
tmp_samples = samples.reshape((nsamples, 1, samples.shape[2]))
tmp_log_weights = log_weights.reshape((nsamples))

# Fixed seed so the saved array is reproducible.
rng = np.random.default_rng(seed=0)
# Draw nsamples indices (rng.choice samples with replacement by default)
# with probabilities exp(log weight).
resampled_iterations = rng.choice(
nsamples,
size=nsamples,
p=np.exp(tmp_log_weights))

# Pick the drawn rows and store them next to the other test arrays.
resampled_samples = tmp_samples[resampled_iterations, :, :]
np.save(os.path.join(model_path, "eight_schools_resampled_samples.npy"), resampled_samples)
Binary file not shown.
108 changes: 108 additions & 0 deletions test/test_resample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
from pathlib import Path
import os
import pytest
import cmdstanpy
import numpy as np
import json

from retrospectr.resampling import resample

# Directory named 'test_models' located next to this test file; every fixture
# path below is built from it.
TEST_MODELS_PATH = os.path.join(Path(__file__).parent, 'test_models')


@pytest.fixture
def eight_schools_model_file():
    """Return the path of ``eight_schools.stan`` in the test model folder."""
    model_dir = os.path.join(TEST_MODELS_PATH, 'eight_schools')
    return os.path.join(model_dir, 'eight_schools.stan')


@pytest.fixture
def eight_schools_data_file():
    """Return the path of ``eight_schools.data.json`` in the test model folder."""
    model_dir = os.path.join(TEST_MODELS_PATH, 'eight_schools')
    return os.path.join(model_dir, 'eight_schools.data.json')


@pytest.fixture
def eight_schools_data_json(eight_schools_data_file):
    """Return the unparsed text of the eight-schools JSON data file."""
    # Pin UTF-8 so the file is not decoded with a platform-dependent
    # default encoding.
    with open(eight_schools_data_file, encoding="utf-8") as f:
        json_data = f.read()
    return json_data


@pytest.fixture
def eight_schools_data_dict(eight_schools_data_file):
    """Return the eight-schools JSON data file parsed with ``json.load``."""
    # Pin UTF-8 so the file is not decoded with a platform-dependent
    # default encoding.
    with open(eight_schools_data_file, encoding="utf-8") as f:
        json_dict = json.load(f)
    return json_dict


@pytest.fixture
def eight_schools_samples():
    """Load the array stored in ``eight_schools_samples.npy``."""
    npy_path = os.path.join(
        TEST_MODELS_PATH, 'eight_schools', 'eight_schools_samples.npy'
    )
    return np.load(npy_path)


@pytest.fixture
def eight_schools_log_weights():
    """Load the array stored in ``eight_schools_log_weights.npy``."""
    npy_path = os.path.join(
        TEST_MODELS_PATH, 'eight_schools', 'eight_schools_log_weights.npy'
    )
    return np.load(npy_path)


@pytest.fixture
def eight_schools_resampled_samples():
    """Load the array stored in ``eight_schools_resampled_samples.npy``."""
    npy_path = os.path.join(
        TEST_MODELS_PATH, 'eight_schools', 'eight_schools_resampled_samples.npy'
    )
    return np.load(npy_path)


@pytest.fixture
def seven_schools_data_file():
    """Return the path of ``seven_schools.data.json`` in the test model folder."""
    model_dir = os.path.join(TEST_MODELS_PATH, 'eight_schools')
    return os.path.join(model_dir, 'seven_schools.data.json')


@pytest.fixture
def seven_schools_samples():
    """Load the array stored in ``seven_schools_samples.npy``."""
    npy_path = os.path.join(
        TEST_MODELS_PATH, 'eight_schools', 'seven_schools_samples.npy'
    )
    return np.load(npy_path)


@pytest.fixture
def eight_schools_bad_data_file():
    """Return the path of ``eight_schools.bad_data.json`` in the test model folder."""
    model_dir = os.path.join(TEST_MODELS_PATH, 'eight_schools')
    return os.path.join(model_dir, 'eight_schools.bad_data.json')


@pytest.fixture
def eight_schools_cmdstanpy_fit(eight_schools_model_file, eight_schools_data_file):
    """Build the eight-schools model and return a fit sampled from it.

    Sampling uses 2 chains with 200 warmup and 200 sampling iterations each,
    and a fixed seed of 0.
    """
    stan_model = cmdstanpy.CmdStanModel(stan_file=eight_schools_model_file)
    return stan_model.sample(
        data=eight_schools_data_file,
        chains=2,
        iter_sampling=200,
        iter_warmup=200,
        seed=0,
    )


class TestResampled:
    """Tests for ``resample`` using the stored eight-schools arrays."""

    def test_good_sample_array(self, eight_schools_samples, eight_schools_log_weights, eight_schools_resampled_samples):
        """Array input with seed 0 reproduces the stored resampled array."""
        actual = resample(eight_schools_samples, eight_schools_log_weights, seed=0)
        np.testing.assert_equal(actual, eight_schools_resampled_samples)

    def test_good_cmdstanpy_fit(self, eight_schools_cmdstanpy_fit, eight_schools_log_weights, eight_schools_resampled_samples):
        """A CmdStanPy fit as input reproduces the stored resampled array."""
        actual = resample(eight_schools_cmdstanpy_fit, eight_schools_log_weights, seed=0)
        np.testing.assert_equal(actual, eight_schools_resampled_samples)

    def test_bad_weights_wrong_len(self, eight_schools_samples, eight_schools_log_weights):
        """Dropping the last entry along the first axis must raise ValueError."""
        n_rows = len(eight_schools_log_weights)
        truncated = eight_schools_log_weights[:n_rows - 1]
        with np.testing.assert_raises(ValueError):
            resample(eight_schools_samples, truncated, seed=0)

    def test_bad_weights_not_sum_to_one(self, eight_schools_samples, eight_schools_log_weights):
        """Overwriting the first entry with 1.0 must raise ValueError."""
        # The fixture array itself is modified in place here.
        eight_schools_log_weights[0] = 1.0
        with np.testing.assert_raises(ValueError):
            resample(eight_schools_samples, eight_schools_log_weights, seed=0)

0 comments on commit c69b7cd

Please sign in to comment.