Skip to content

Commit

Permalink
chore: simplify README example + example.py
Browse files Browse the repository at this point in the history
Simplify README example and example.py to use `resample` and data dict input
  • Loading branch information
alecksphillips committed Aug 17, 2023
1 parent c69b7cd commit 4899268
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 34 deletions.
20 changes: 3 additions & 17 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,11 @@ pip install -U git+https://github.com/alecksphillips/retrospectr.git
## Example
```python
import cmdstanpy
import numpy as np
import matplotlib.pyplot as plt
import seaborn
import pandas as pd
import json
from retrospectr.importance_weights import calculate_log_weights, extract_samples
from retrospectr.resampling import resample


model_file = "test/test_models/bernoulli/bernoulli.stan"
Expand All @@ -44,35 +43,22 @@ original_data = {
"N": 10,
"y": [0, 1, 0, 0, 0, 0, 0, 0, 0, 1]
}

original_data_file_path = "original_data.json"
with open(original_data_file_path, "w") as f:
json.dump(original_data, f)

original_fit = stan_model.sample(data=original_data, chains=1)

new_data = {
"N": 20,
"y": [0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1]
}

new_data_file_path = "new_data.json"
with open(new_data_file_path, "w") as f:
json.dump(new_data, f)


new_fit = stan_model.sample(data=new_data, chains=1)

original_samples = extract_samples(original_fit)

new_samples = extract_samples(new_fit)

log_weights = calculate_log_weights(model_file, original_samples, original_data_file_path, new_data_file_path)

resampled_iterations = np.random.choice(
len(log_weights), size=len(log_weights), p=np.exp(log_weights.reshape(len(log_weights))))
log_weights = calculate_log_weights(model_file, original_samples, original_data, new_data)

resampled_original_samples = original_samples[resampled_iterations, :]
resampled_original_samples = resample(original_samples, log_weights)

df_original = pd.DataFrame({
"theta": original_samples.reshape(len(original_samples)),
Expand Down
20 changes: 3 additions & 17 deletions example.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import cmdstanpy
import numpy as np
import matplotlib.pyplot as plt
import seaborn
import pandas as pd
import json
from retrospectr.importance_weights import calculate_log_weights, extract_samples
from retrospectr.resampling import resample


model_file = "test/test_models/bernoulli/bernoulli.stan"
Expand All @@ -14,35 +13,22 @@
"N": 10,
"y": [0, 1, 0, 0, 0, 0, 0, 0, 0, 1]
}

original_data_file_path = "original_data.json"
with open(original_data_file_path, "w") as f:
json.dump(original_data, f)

original_fit = stan_model.sample(data=original_data, chains=1)

new_data = {
"N": 20,
"y": [0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1]
}

new_data_file_path = "new_data.json"
with open(new_data_file_path, "w") as f:
json.dump(new_data, f)


new_fit = stan_model.sample(data=new_data, chains=1)

original_samples = extract_samples(original_fit)

new_samples = extract_samples(new_fit)

log_weights = calculate_log_weights(model_file, original_samples, original_data_file_path, new_data_file_path)

resampled_iterations = np.random.choice(
len(log_weights), size=len(log_weights), p=np.exp(log_weights.reshape(len(log_weights))))
log_weights = calculate_log_weights(model_file, original_samples, original_data, new_data)

resampled_original_samples = original_samples[resampled_iterations, :]
resampled_original_samples = resample(original_samples, log_weights)

df_original = pd.DataFrame({
"theta": original_samples.reshape(len(original_samples)),
Expand Down

0 comments on commit 4899268

Please sign in to comment.