From f5398eb143b33c52d799e3f502b2de18cf83d53d Mon Sep 17 00:00:00 2001 From: Alex Phillips Date: Thu, 17 Aug 2023 12:01:34 +0100 Subject: [PATCH] feat: add resample method Add resample method to resample from log_weights --- retrospectr/__init__.py | 2 +- retrospectr/resampling.py | 26 ++++++++++++++++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) create mode 100644 retrospectr/resampling.py diff --git a/retrospectr/__init__.py b/retrospectr/__init__.py index 863d295..b53fce8 100644 --- a/retrospectr/__init__.py +++ b/retrospectr/__init__.py @@ -1 +1 @@ -all = ["importance_weights"] +all = ["importance_weights", "resampling"] diff --git a/retrospectr/resampling.py b/retrospectr/resampling.py new file mode 100644 index 0000000..917948d --- /dev/null +++ b/retrospectr/resampling.py @@ -0,0 +1,26 @@ +import numpy as np +import cmdstanpy +from retrospectr.importance_weights import extract_samples + + +def resample(samples, log_weights, seed=0): + + if isinstance(samples, cmdstanpy.CmdStanMCMC): + samples = extract_samples(samples) + + rng = np.random.default_rng(seed=seed) + niters = log_weights.shape[0] + nchains = log_weights.shape[1] + nparams = samples.shape[2] + + nsamples = niters*nchains + flat_log_weights = log_weights.reshape((nsamples)) + + resampled_iterations = rng.choice( + nsamples, + size=nsamples, + p=np.exp(flat_log_weights)) + + flat_samples = samples.reshape(nsamples, 1, nparams) + resampled_samples = flat_samples[resampled_iterations, :] + return resampled_samples