Skip to content

Commit

Permalink
feat: add resample method
Browse files Browse the repository at this point in the history
Add resample method to resample from log_weights
  • Loading branch information
alecksphillips committed Aug 17, 2023
1 parent 327e301 commit f5398eb
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 1 deletion.
2 changes: 1 addition & 1 deletion retrospectr/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
all = ["importance_weights"]
all = ["importance_weights", "resampling"]
26 changes: 26 additions & 0 deletions retrospectr/resampling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import numpy as np
import cmdstanpy
from retrospectr.importance_weights import extract_samples


def resample(samples, log_weights, seed=0):

if isinstance(samples, cmdstanpy.CmdStanMCMC):
samples = extract_samples(samples)

rng = np.random.default_rng(seed=seed)
niters = log_weights.shape[0]
nchains = log_weights.shape[1]
nparams = samples.shape[2]

nsamples = niters*nchains
flat_log_weights = log_weights.reshape((nsamples))

resampled_iterations = rng.choice(
nsamples,
size=nsamples,
p=np.exp(flat_log_weights))

flat_samples = samples.reshape(nsamples, 1, nparams)
resampled_samples = flat_samples[resampled_iterations, :]
return resampled_samples

0 comments on commit f5398eb

Please sign in to comment.