From febe0b462168d55dd8295762f31cb2315b00c50d Mon Sep 17 00:00:00 2001
From: "William F. Broderick"
Date: Tue, 26 Apr 2022 16:45:31 -0400
Subject: [PATCH] adds extended scaling inf_data

When running mcmc, create another version of the inf_data whose predictive
distributions have a finer and broader sampling of scaling values.
---
 Snakefile                 | 10 +++++++
 foveated_metamers/mcmc.py | 63 ++++++++++++++++++++++++++++++---------
 2 files changed, 59 insertions(+), 14 deletions(-)

diff --git a/Snakefile b/Snakefile
index 4467eba..c91315d 100644
--- a/Snakefile
+++ b/Snakefile
@@ -1323,6 +1323,9 @@ rule mcmc:
         op.join(config["DATA_DIR"], 'mcmc', '{model_name}', 'task-split_comp-{comp}',
                 'task-split_comp-{comp}_mcmc_{mcmc_model}_step-{step_size}_prob-{accept_prob}_depth-{tree_depth}_'
                 'c-{num_chains}_d-{num_draws}_w-{num_warmup}_s-{seed}.nc'),
+        op.join(config["DATA_DIR"], 'mcmc', '{model_name}', 'task-split_comp-{comp}',
+                'task-split_comp-{comp}_mcmc_{mcmc_model}_step-{step_size}_prob-{accept_prob}_depth-{tree_depth}_'
+                'c-{num_chains}_d-{num_draws}_w-{num_warmup}_s-{seed}_scaling-extended.nc'),
     log:
         op.join(config["DATA_DIR"], 'logs', 'mcmc', '{model_name}', 'task-split_comp-{comp}',
                 'task-split_comp-{comp}_mcmc_{mcmc_model}_step-{step_size}_prob-{accept_prob}_depth-{tree_depth}_'
@@ -1354,6 +1357,13 @@
                                               wildcards.mcmc_model,
                                               int(wildcards.seed)+1)
         inf_data.to_netcdf(output[0])
+        # want to have a different seed for constructing the inference
+        # data object than we did for inference itself
+        inf_data_extended = fov.mcmc.assemble_inf_data(mcmc, dataset,
+                                                       wildcards.mcmc_model,
+                                                       int(wildcards.seed)+10,
+                                                       extend_scaling=True)
+        inf_data_extended.to_netcdf(output[1])
 
 
 rule mcmc_plots:
diff --git a/foveated_metamers/mcmc.py b/foveated_metamers/mcmc.py
index 2cae2d3..5990559 100644
--- a/foveated_metamers/mcmc.py
+++ b/foveated_metamers/mcmc.py
@@ -579,7 +579,8 @@ def simulate_dataset(critical_scaling, proportionality_factor,
                      coords)
 
 
-def _assign_inf_dims(samples_dict, dataset, dummy_dim=None):
+def _assign_inf_dims(samples_dict, dataset, dummy_dim=None,
+                     extend_scaling=False):
     """Figure out the mapping between vars and coords.
 
     It's annoying to line up variables and coordinates. this does the best it
@@ -592,11 +593,17 @@
     corresponded to (trials, scaling, subject_name) or (trials, scaling,
     image_name) -- would assume the latter.
 
+    If extend_scaling is True, we assume that there are 50 scaling values and
+    ignore the length of the scaling dimension in dataset.
+
     """
     dims = {}
     if dummy_dim is not None and not hasattr(dummy_dim, '__iter__'):
         dummy_dim = [dummy_dim]
-    if sum([dataset.dims['scaling'] == v for v in dataset.dims.values()]) > 1:
+    scaling_dims = dataset.dims['scaling']
+    if extend_scaling:
+        scaling_dims = 50
+    if sum([scaling_dims == v for v in dataset.dims.values()]) > 1:
         # then we have something that's the same size as scaling and need to do
         # the complicated check below
         scaling_check = True
@@ -607,15 +614,20 @@
         dims[k] = []
         i = 1
         for d in dataset.observed_responses.dims:
+            if d == 'scaling' and extend_scaling:
+                # if we've extended scaling, then scaling will have length 50
+                coord_len = 50
+            else:
+                coord_len = len(dataset.coords[d])
             if i >= len(var_shape):
                 break
-            if len(dataset.coords[d]) == var_shape[i]:
+            if coord_len == var_shape[i]:
                 # if scaling is the same shape as one of the other dims, can
                 # sometimes mis-assign. but we know that if there's only one
                 # thing that has the same shape as scaling, then it's not
                 # scaling (parameters won't have scaling, responses or
                 # probability correct will have scaling and the other coords)
-                if d == 'scaling' and scaling_check and sum([len(dataset.coords[d]) == s for s in var_shape]) == 1:
+                if d == 'scaling' and scaling_check and sum([coord_len == s for s in var_shape]) == 1:
                     continue
                 dims[k] += [d]
                 i += 1
@@ -627,14 +639,20 @@
     return dims
 
 
-def _arrange_vars(dataset):
+def _arrange_vars(dataset, extend_scaling=False):
     """Get and reshape scaling and observed responses from dataset."""
     if dataset.observed_responses.dims[:2] != ('trials', 'scaling'):
         raise Exception("First two dimensions of observed responses must "
                         "be trials and scaling!")
     observed_responses = jnp.array(dataset.observed_responses.values,
                                    dtype=jnp.float32)
-    scaling = jnp.array(dataset.scaling.values, dtype=jnp.float32)
+    scaling = dataset.scaling.values
+    if not extend_scaling:
+        scaling = jnp.array(scaling, dtype=jnp.float32)
+    else:
+        scaling = jnp.logspace(np.log10(scaling.min() - scaling.min() / 3),
+                               np.log10(scaling.max() + scaling.max() / 3),
+                               num=50, dtype=jnp.float32)
     # get scaling into the appropriate shape -- scaling on the first dimension,
     # and then repeated to match the shape of observed_responses after that
     scaling = jnp.expand_dims(scaling, tuple(-(i+1) for i in
@@ -718,7 +736,7 @@ def run_inference(dataset, mcmc_model_type='partially-pooled', step_size=.1,
 
 
 def assemble_inf_data(mcmc, dataset, mcmc_model_type='partially-pooled',
-                      seed=1):
+                      seed=1, extend_scaling=False):
     """Convert mcmc into properly-formatted inference data object.
 
     Parameters
@@ -732,6 +750,10 @@
         coordinates trials and scaling (must be first two).
     seed : int, optional
         RNG seed.
+    extend_scaling : bool, optional
+        Whether to use the original scaling values (False) or extend the range
+        in both directions and sample it more finely (True), which will lead to
+        prettier plots.
 
     Returns
     -------
@@ -755,7 +777,12 @@
         response_model = unpooled_response_model
     else:
         raise Exception(f"Don't know how to handle mcmc_model_type {mcmc_model_type}!")
-    scaling, obs = _arrange_vars(dataset)
+    scaling, obs = _arrange_vars(dataset, extend_scaling)
+    coords = {k: v.values for k, v in dataset.coords.items()}
+    if extend_scaling:
+        coords['scaling'] = scaling.squeeze()
+        while coords['scaling'].ndim > 1:
+            coords['scaling'] = coords['scaling'][..., 0]
     model = dataset.model.values[0].split('_')[0]
     if model == 'simulated':
         # then it's simulate_{actual_model_name}
@@ -777,13 +804,15 @@
     # the subject-level variables have a dummy dimension at the same place as
     # the image_name dimension, in order to allow broadcasting. we allow it
     # here, and then drop it later
-    prior_dims = _assign_inf_dims(prior, dataset, dummy_dim=dummy_dims[0])
-    prior = az.from_numpyro(prior=prior, coords=dataset.coords, dims=prior_dims)
+    prior_dims = _assign_inf_dims(prior, dataset, dummy_dim=dummy_dims[0],
+                                  extend_scaling=extend_scaling)
+    prior = az.from_numpyro(prior=prior, coords=coords, dims=prior_dims)
     posterior_pred = posterior_pred(PRNGKey(seed+1), scaling, model)
     post_dims = _assign_inf_dims(posterior_pred, dataset,
-                                 dummy_dim=dummy_dims[1])
+                                 dummy_dim=dummy_dims[1],
+                                 extend_scaling=extend_scaling)
     posterior_pred = az.from_numpyro(posterior_predictive=posterior_pred,
-                                     coords=dataset.coords, dims=post_dims)
+                                     coords=coords, dims=post_dims)
     # the observed data will have a trials dim first
     post_dims['responses'][0] = 'trials'
     # in this case, there was only one trial_type, and the shape of the
@@ -794,12 +823,18 @@
     post_dims['responses'].pop(1)
     # the subject-level variables these have a dummy dimension at the same
     # place as the image_name dimension, in order to allow broadcasting. we
-    # allow it here, and then drop it later
+    # allow it here, and then drop it later. we don't use extend_scaling here,
+    # because this is for the posterior, which is based on the actual samples
+    # we drew (which have the limited scaling values) rather than the
+    # predictive ones
     variable_dims = _assign_inf_dims(mcmc.get_samples(), dataset,
                                      dummy_dim=dummy_dims[2])
     variable_dims.update(post_dims)
     # if there was missing data, it will need to the imputation in order to
-    # compute log-likelihood, so we need the seed handler
+    # compute log-likelihood, so we need the seed handler. we use
+    # dataset.coords here for the same reason we don't pass extend_scaling
+    # above -- the posterior doesn't use the extended scaling values but the
+    # actual ones
     with numpyro.handlers.seed(rng_seed=seed+2):
         inf_data = (az.from_numpyro(mcmc, coords=dataset.coords,
                                     dims=variable_dims) + prior + posterior_pred)
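
Note on the extended grid: the extend_scaling branch of _arrange_vars relies on
jnp.logspace, whose output length must agree with the 50 hard-coded in
_assign_inf_dims (made explicit above via num=50, which is also jnp.logspace's
default). A minimal sketch of what it produces -- the scaling values here are
hypothetical stand-ins for dataset.scaling.values:

    import numpy as np
    import jax.numpy as jnp

    # hypothetical stand-in for dataset.scaling.values
    scaling = np.array([.063, .106, .179, .3, .5])

    # same construction as the extend_scaling branch of _arrange_vars: extend
    # the range by a third in each direction, sampled with 50 log-spaced points
    extended = jnp.logspace(np.log10(scaling.min() - scaling.min() / 3),
                            np.log10(scaling.max() + scaling.max() / 3),
                            num=50, dtype=jnp.float32)

    print(extended.shape)             # (50,)
    print(extended[0], extended[-1])  # ~0.042 and ~0.667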
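
Similarly, a sketch of the coords['scaling'] clean-up in assemble_inf_data:
_arrange_vars returns scaling expanded and repeated against the trailing dims
of observed_responses, so squeeze plus the [..., 0] loop recovers the 1d
coordinate. The shape below is an assumed example, not what any particular
dataset produces:

    import numpy as np

    # assumed example: 50 extended scaling values broadcast against trailing
    # dims of observed_responses
    scaling = np.broadcast_to(np.linspace(.042, .667, 50)[:, None, None, None],
                              (50, 1, 4, 3))

    coord = scaling.squeeze()  # drops the singleton dim -> (50, 4, 3)
    while coord.ndim > 1:
        coord = coord[..., 0]  # peels repeated trailing dims -> (50,)
    print(coord.shape)         # (50,)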