adds extended scaling inf_data
when running mcmc, create another version of the inf_data whose
predictive distributions have a finer and broader sampling of scaling
values
billbrod committed Apr 26, 2022
1 parent def3a48 commit febe0b4
Showing 2 changed files with 59 additions and 14 deletions.
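
The core of the change, as a minimal sketch (the scaling values below are illustrative, not the experiment's): instead of evaluating the predictive distributions only at the scaling values that were actually tested, extend the range by a third in each direction and sample it log-uniformly at np.logspace's default of 50 points.

    import numpy as np

    # hypothetical tested scaling values, for illustration only
    scaling = np.array([0.063, 0.1, 0.16, 0.25])
    # extend a third below the min and a third above the max; logspace
    # defaults to 50 log-spaced samples
    extended = np.logspace(np.log10(scaling.min() - scaling.min() / 3),
                           np.log10(scaling.max() + scaling.max() / 3))
    assert extended.size == 50
    assert extended.min() < scaling.min() and extended.max() > scaling.max()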
Snakefile (10 additions, 0 deletions)
@@ -1323,6 +1323,9 @@ rule mcmc:
         op.join(config["DATA_DIR"], 'mcmc', '{model_name}', 'task-split_comp-{comp}',
                 'task-split_comp-{comp}_mcmc_{mcmc_model}_step-{step_size}_prob-{accept_prob}_depth-{tree_depth}_'
                 'c-{num_chains}_d-{num_draws}_w-{num_warmup}_s-{seed}.nc'),
+        op.join(config["DATA_DIR"], 'mcmc', '{model_name}', 'task-split_comp-{comp}',
+                'task-split_comp-{comp}_mcmc_{mcmc_model}_step-{step_size}_prob-{accept_prob}_depth-{tree_depth}_'
+                'c-{num_chains}_d-{num_draws}_w-{num_warmup}_s-{seed}_scaling-extended.nc'),
     log:
         op.join(config["DATA_DIR"], 'logs', 'mcmc', '{model_name}', 'task-split_comp-{comp}',
                 'task-split_comp-{comp}_mcmc_{mcmc_model}_step-{step_size}_prob-{accept_prob}_depth-{tree_depth}_'
@@ -1354,6 +1357,13 @@ rule mcmc:
                                               wildcards.mcmc_model,
                                               int(wildcards.seed)+1)
         inf_data.to_netcdf(output[0])
+        # want to have a different seed for constructing the inference
+        # data object than we did for inference itself
+        inf_data_extended = fov.mcmc.assemble_inf_data(mcmc, dataset,
+                                                       wildcards.mcmc_model,
+                                                       int(wildcards.seed)+10,
+                                                       extend_scaling=True)
+        inf_data_extended.to_netcdf(output[1])


 rule mcmc_plots:
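Downstream code can read either file; a sketch of loading the two outputs (paths abbreviated, assuming the .nc files are read back with arviz, which wrote them via to_netcdf):

    import arviz as az

    # original scaling grid (the tested values)
    inf_data = az.from_netcdf('..._s-0.nc')
    # same posterior, but predictive groups on the extended 50-point grid
    inf_data_extended = az.from_netcdf('..._s-0_scaling-extended.nc')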
foveated_metamers/mcmc.py (49 additions, 14 deletions)
@@ -579,7 +579,8 @@ def simulate_dataset(critical_scaling, proportionality_factor,
                            coords)


-def _assign_inf_dims(samples_dict, dataset, dummy_dim=None):
+def _assign_inf_dims(samples_dict, dataset, dummy_dim=None,
+                     extend_scaling=False):
     """Figure out the mapping between vars and coords.

     It's annoying to line up variables and coordinates. this does the best it
@@ -592,11 +593,17 @@ def _assign_inf_dims(samples_dict, dataset, dummy_dim=None):
     corresponded to (trials, scaling, subject_name) or (trials, scaling,
     image_name) -- would assume the latter.

+    if extend_scaling is True, we assume that there are 50 scaling values and
+    ignore the length of scaling in dataset
+
     """
     dims = {}
     if dummy_dim is not None and not hasattr(dummy_dim, '__iter__'):
         dummy_dim = [dummy_dim]
-    if sum([dataset.dims['scaling'] == v for v in dataset.dims.values()]) > 1:
+    scaling_dims = dataset.dims['scaling']
+    if extend_scaling:
+        scaling_dims = 50
+    if sum([scaling_dims == v for v in dataset.dims.values()]) > 1:
         # then we have something that's the same size as scaling and need to do
         # the complicated check below
         scaling_check = True
@@ -607,15 +614,20 @@ def _assign_inf_dims(samples_dict, dataset, dummy_dim=None):
         dims[k] = []
         i = 1
         for d in dataset.observed_responses.dims:
+            if d == 'scaling' and extend_scaling:
+                # if we've extended scaling, then scaling will have length 50
+                coord_len = 50
+            else:
+                coord_len = len(dataset.coords[d])
             if i >= len(var_shape):
                 break
-            if len(dataset.coords[d]) == var_shape[i]:
+            if coord_len == var_shape[i]:
                 # if scaling is the same shape as one of the other dims, can
                 # sometimes mis-assign. but we know that if there's only one
                 # thing that has the same shape as scaling, then it's not
                 # scaling (parameters won't have scaling, responses or
                 # probability correct will have scaling and the other coords)
-                if d == 'scaling' and scaling_check and sum([len(dataset.coords[d]) == s for s in var_shape]) == 1:
+                if d == 'scaling' and scaling_check and sum([coord_len == s for s in var_shape]) == 1:
                     continue
                 dims[k] += [d]
                 i += 1
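A hedged illustration of the check above, with hypothetical shapes (not from the repo): the function matches coordinate lengths against a variable's shape, so when some other coordinate happens to share scaling's length, a match is ambiguous. The sum(...) == 1 test resolves it by the reasoning in the comment: parameters never carry a scaling dimension, while responses carry scaling plus the other coordinates, so a lone match cannot be scaling.

    # hypothetical posterior-predictive shape: (chain, draw, coord)
    var_shape = (4, 1000, 50)
    coord_len = 50  # length of 'scaling' when extend_scaling=True
    if sum([coord_len == s for s in var_shape]) == 1:
        # only one axis matches scaling's length, so treat it as the other
        # same-length coordinate and skip assigning 'scaling' here
        print("skipping 'scaling' for this variable")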
@@ -627,14 +639,20 @@ def _assign_inf_dims(samples_dict, dataset, dummy_dim=None):
     return dims


-def _arrange_vars(dataset):
+def _arrange_vars(dataset, extend_scaling=False):
     """Get and reshape scaling and observed responses from dataset."""
     if dataset.observed_responses.dims[:2] != ('trials', 'scaling'):
         raise Exception("First two dimensions of observed responses must "
                         "be trials and scaling!")
     observed_responses = jnp.array(dataset.observed_responses.values,
                                    dtype=jnp.float32)
-    scaling = jnp.array(dataset.scaling.values, dtype=jnp.float32)
+    scaling = dataset.scaling.values
+    if not extend_scaling:
+        scaling = jnp.array(scaling, dtype=jnp.float32)
+    else:
+        scaling = jnp.logspace(np.log10(scaling.min() - scaling.min() / 3),
+                               np.log10(scaling.max() + scaling.max() / 3),
+                               dtype=jnp.float32)
     # get scaling into the appropriate shape -- scaling on the first dimension,
     # and then repeated to match the shape of observed_responses after that
     scaling = jnp.expand_dims(scaling, tuple(-(i+1) for i in
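Note that jnp.logspace is called without a num argument; like np.logspace, it defaults to 50 samples, which is where the hard-coded 50 in _assign_inf_dims comes from. A quick check of that assumption:

    import numpy as np
    import jax.numpy as jnp

    # both default to num=50 log-spaced samples
    assert np.logspace(np.log10(0.04), np.log10(0.33)).size == 50
    assert jnp.logspace(np.log10(0.04), np.log10(0.33)).size == 50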
@@ -718,7 +736,7 @@ def run_inference(dataset, mcmc_model_type='partially-pooled', step_size=.1,


 def assemble_inf_data(mcmc, dataset, mcmc_model_type='partially-pooled',
-                      seed=1):
+                      seed=1, extend_scaling=False):
     """Convert mcmc into properly-formatted inference data object.

     Parameters
@@ -732,6 +750,10 @@ def assemble_inf_data(mcmc, dataset, mcmc_model_type='partially-pooled',
         coordinates trials and scaling (must be first two).
     seed : int, optional
         RNG seed.
+    extend_scaling : bool, optional
+        Whether to use the original scaling values (False) or extend the range
+        in both directions and sample it more finely (True), which will lead to
+        prettier plots

     Returns
     -------
@@ -755,7 +777,12 @@ def assemble_inf_data(mcmc, dataset, mcmc_model_type='partially-pooled',
         response_model = unpooled_response_model
     else:
         raise Exception(f"Don't know how to handle mcmc_model_type {mcmc_model_type}!")
-    scaling, obs = _arrange_vars(dataset)
+    scaling, obs = _arrange_vars(dataset, extend_scaling)
+    coords = {k: v.values for k, v in dataset.coords.items()}
+    if extend_scaling:
+        coords['scaling'] = scaling.squeeze()
+        while coords['scaling'].ndim > 1:
+            coords['scaling'] = coords['scaling'][..., 0]
     model = dataset.model.values[0].split('_')[0]
     if model == 'simulated':
         # then it's simulate_{actual_model_name}
@@ -777,13 +804,15 @@ def assemble_inf_data(mcmc, dataset, mcmc_model_type='partially-pooled',
     # the subject-level variables have a dummy dimension at the same place as
     # the image_name dimension, in order to allow broadcasting. we allow it
     # here, and then drop it later
-    prior_dims = _assign_inf_dims(prior, dataset, dummy_dim=dummy_dims[0])
-    prior = az.from_numpyro(prior=prior, coords=dataset.coords, dims=prior_dims)
+    prior_dims = _assign_inf_dims(prior, dataset, dummy_dim=dummy_dims[0],
+                                  extend_scaling=extend_scaling)
+    prior = az.from_numpyro(prior=prior, coords=coords, dims=prior_dims)
     posterior_pred = posterior_pred(PRNGKey(seed+1), scaling, model)
     post_dims = _assign_inf_dims(posterior_pred, dataset,
-                                 dummy_dim=dummy_dims[1])
+                                 dummy_dim=dummy_dims[1],
+                                 extend_scaling=extend_scaling)
     posterior_pred = az.from_numpyro(posterior_predictive=posterior_pred,
-                                     coords=dataset.coords, dims=post_dims)
+                                     coords=coords, dims=post_dims)
     # the observed data will have a trials dim first
     post_dims['responses'][0] = 'trials'
     # in this case, there was only one trial_type, and the shape of the
@@ -794,12 +823,18 @@ def assemble_inf_data(mcmc, dataset, mcmc_model_type='partially-pooled',
         post_dims['responses'].pop(1)
     # the subject-level variables these have a dummy dimension at the same
     # place as the image_name dimension, in order to allow broadcasting. we
-    # allow it here, and then drop it later
+    # allow it here, and then drop it later. we don't use extend_scaling here,
+    # because this is for the posterior, which is based on the actual samples
+    # we drew (which have the limited scaling values) rather than the
+    # predictive ones
     variable_dims = _assign_inf_dims(mcmc.get_samples(), dataset,
                                      dummy_dim=dummy_dims[2])
     variable_dims.update(post_dims)
     # if there was missing data, it will need to the imputation in order to
-    # compute log-likelihood, so we need the seed handler
+    # compute log-likelihood, so we need the seed handler. we use
+    # dataset.coords here for the same reason we don't pass extend_scaling
+    # above -- the posterior doesn't use the extended scaling values but the
+    # actual ones
     with numpyro.handlers.seed(rng_seed=seed+2):
         inf_data = (az.from_numpyro(mcmc, coords=dataset.coords, dims=variable_dims) +
                     prior + posterior_pred)
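Taken together, the new call pattern mirrors the Snakefile's run block above; a sketch, assuming mcmc and dataset are the fitted numpyro MCMC object and the loaded psychophysical dataset, that paths and seed stand in for the rule's wildcards, and that the Snakefile's fov alias comes from importing foveated_metamers:

    import foveated_metamers as fov

    inf_data = fov.mcmc.assemble_inf_data(mcmc, dataset, 'partially-pooled',
                                          seed=seed + 1)
    inf_data.to_netcdf(paths[0])
    # same posterior; prior and posterior predictive evaluated on the
    # extended 50-point scaling grid, with a different seed
    inf_data_extended = fov.mcmc.assemble_inf_data(mcmc, dataset,
                                                   'partially-pooled',
                                                   seed=seed + 10,
                                                   extend_scaling=True)
    inf_data_extended.to_netcdf(paths[1])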
