Skip to content

Commit

Permalink
adds mcmc_arviz_compare rule
Browse files Browse the repository at this point in the history
  • Loading branch information
billbrod committed Apr 26, 2022
1 parent 126c2c9 commit def3a48
Showing 1 changed file with 41 additions and 0 deletions.
41 changes: 41 additions & 0 deletions Snakefile
Original file line number Diff line number Diff line change
Expand Up @@ -1426,6 +1426,47 @@ rule mcmc_plots:
fig.savefig(output[0], bbox_inches='tight')


rule mcmc_arviz_compare:
# unlike the rule after this one, this uses arviz's built-in compare
# functionality
input:
[op.join(config["DATA_DIR"], 'mcmc', '{{model_name}}', 'task-split_comp-{{comp}}',
'task-split_comp-{{comp}}_mcmc_{mcmc_model}_step-{{step_size}}_prob-{{accept_prob}}_depth-{{tree_depth}}'
'_c-{{num_chains}}_d-{{num_draws}}_w-{{num_warmup}}_s-{{seed}}.nc').format(mcmc_model=m)
for m in ['unpooled', 'partially-pooled']]
output:
op.join(config["DATA_DIR"], 'mcmc', '{model_name}', 'task-split_comp-{comp}',
'task-split_comp-{comp}_mcmc_compare_step-{step_size}_prob-{accept_prob}_depth-{tree_depth}'
'_c-{num_chains}_d-{num_draws}_w-{num_warmup}_s-{seed}_ic-{ic}.csv'),
op.join(config["DATA_DIR"], 'mcmc', '{model_name}', 'task-split_comp-{comp}',
'task-split_comp-{comp}_mcmc_compare_step-{step_size}_prob-{accept_prob}_depth-{tree_depth}'
'_c-{num_chains}_d-{num_draws}_w-{num_warmup}_s-{seed}_ic-{ic}_arviz.png')
log:
op.join(config["DATA_DIR"], 'logs', 'mcmc', '{model_name}', 'task-split_comp-{comp}',
'task-split_comp-{comp}_mcmc_compare_step-{step_size}_prob-{accept_prob}_depth-{tree_depth}'
'_c-{num_chains}_d-{num_draws}_w-{num_warmup}_s-{seed}_ic-{ic}_arviz.log')
benchmark:
op.join(config["DATA_DIR"], 'logs', 'mcmc', '{model_name}', 'task-split_comp-{comp}',
'task-split_comp-{comp}_mcmc_compare_step-{step_size}_prob-{accept_prob}_depth-{tree_depth}'
'_c-{num_chains}_d-{num_draws}_w-{num_warmup}_s-{seed}_ic-{ic}_arviz_benchmark.txt')
run:
import foveated_metamers as fov
import arviz as az
import re
import contextlib
with open(log[0], 'w', buffering=1) as log_file:
with contextlib.redirect_stdout(log_file), contextlib.redirect_stderr(log_file):
models = {}
for i in input:
inf = az.from_netcdf(i)
name = re.findall('mcmc_([a-z-]+)_step', i)[0]
models[name] = inf
comp_df = az.compare(models, ic=wildcards.ic)
comp_df.to_csv(output[0], index=False)
fig = az.plot_compare(comp_df)
fig.savefig(output[1], bbox_inches='tight')


rule mcmc_compare_plot:
input:
[op.join(config["DATA_DIR"], 'mcmc', '{{model_name}}', 'task-split_comp-{{comp}}',
Expand Down

0 comments on commit def3a48

Please sign in to comment.