diff --git a/Snakefile b/Snakefile index 4c3a5f8..4467eba 100644 --- a/Snakefile +++ b/Snakefile @@ -1426,6 +1426,47 @@ rule mcmc_plots: fig.savefig(output[0], bbox_inches='tight') +rule mcmc_arviz_compare: + # unlike the rule after this one, this uses arviz's built-in compare + # functionality + input: + [op.join(config["DATA_DIR"], 'mcmc', '{{model_name}}', 'task-split_comp-{{comp}}', + 'task-split_comp-{{comp}}_mcmc_{mcmc_model}_step-{{step_size}}_prob-{{accept_prob}}_depth-{{tree_depth}}' + '_c-{{num_chains}}_d-{{num_draws}}_w-{{num_warmup}}_s-{{seed}}.nc').format(mcmc_model=m) + for m in ['unpooled', 'partially-pooled']] + output: + op.join(config["DATA_DIR"], 'mcmc', '{model_name}', 'task-split_comp-{comp}', + 'task-split_comp-{comp}_mcmc_compare_step-{step_size}_prob-{accept_prob}_depth-{tree_depth}' + '_c-{num_chains}_d-{num_draws}_w-{num_warmup}_s-{seed}_ic-{ic}.csv'), + op.join(config["DATA_DIR"], 'mcmc', '{model_name}', 'task-split_comp-{comp}', + 'task-split_comp-{comp}_mcmc_compare_step-{step_size}_prob-{accept_prob}_depth-{tree_depth}' + '_c-{num_chains}_d-{num_draws}_w-{num_warmup}_s-{seed}_ic-{ic}_arviz.png') + log: + op.join(config["DATA_DIR"], 'logs', 'mcmc', '{model_name}', 'task-split_comp-{comp}', + 'task-split_comp-{comp}_mcmc_compare_step-{step_size}_prob-{accept_prob}_depth-{tree_depth}' + '_c-{num_chains}_d-{num_draws}_w-{num_warmup}_s-{seed}_ic-{ic}_arviz.log') + benchmark: + op.join(config["DATA_DIR"], 'logs', 'mcmc', '{model_name}', 'task-split_comp-{comp}', + 'task-split_comp-{comp}_mcmc_compare_step-{step_size}_prob-{accept_prob}_depth-{tree_depth}' + '_c-{num_chains}_d-{num_draws}_w-{num_warmup}_s-{seed}_ic-{ic}_arviz_benchmark.txt') + run: + import foveated_metamers as fov + import arviz as az + import re + import contextlib + with open(log[0], 'w', buffering=1) as log_file: + with contextlib.redirect_stdout(log_file), contextlib.redirect_stderr(log_file): + models = {} + for i in input: + inf = az.from_netcdf(i) + name = re.findall('mcmc_([a-z-]+)_step', i)[0] + models[name] = inf + comp_df = az.compare(models, ic=wildcards.ic) + comp_df.to_csv(output[0], index=False) + fig = az.plot_compare(comp_df) + fig.savefig(output[1], bbox_inches='tight') + + rule mcmc_compare_plot: input: [op.join(config["DATA_DIR"], 'mcmc', '{{model_name}}', 'task-split_comp-{{comp}}',