From ef79a938cf060e6725528be69b4717d78706e9c8 Mon Sep 17 00:00:00 2001 From: "William F. Broderick" Date: Fri, 18 Jun 2021 18:32:43 -0500 Subject: [PATCH] adds computation and plot of orientation distribution orientation anisotropy seems like a possible way of separating pink noise from natural images (pink noise has equal power at all orientations). so this adds a way to compute that and then plot it we don't mean / median the vaules in each slice, because the distributions are weird. instead we keep the full set of values and, for now, we plot the median using catplot. might change --- Snakefile | 18 +++-- foveated_metamers/figures.py | 124 ++++++++++++++++++++++++++++---- foveated_metamers/plotting.py | 39 ++++++++++ foveated_metamers/statistics.py | 118 ++++++++++++++++++++++++++++-- 4 files changed, 273 insertions(+), 26 deletions(-) diff --git a/Snakefile b/Snakefile index 11a0423..c7fcbb5 100644 --- a/Snakefile +++ b/Snakefile @@ -1515,7 +1515,8 @@ rule compute_amplitude_spectra: ims = sorted(ims, key=lambda x: LINEAR_IMAGES.index([i for i in LINEAR_IMAGES if i in x][0])) assert len(ims) == len(LINEAR_IMAGES), f"Have too many images! Expected {len(LINEAR_IMAGES)}, but got {ims}" ref_image_spectra = fov.statistics.image_set_amplitude_spectra(ims, LINEAR_IMAGES, metadata) - ref_image_spectra = ref_image_spectra.rename({'sf_amplitude': 'ref_image_sf_amplitude'}) + ref_image_spectra = ref_image_spectra.rename({'sf_amplitude': 'ref_image_sf_amplitude', + 'orientation_amplitude': 'ref_image_orientation_amplitude'}) spectra = [] for scaling in scalings: tmp_ims = [i for i in input if len(re.findall(f'scaling-{scaling}', i)) == 1] @@ -1530,7 +1531,8 @@ rule compute_amplitude_spectra: tmp_spectra.append(fov.statistics.image_set_amplitude_spectra(ims, LINEAR_IMAGES, metadata)) spectra.append(xarray.concat(tmp_spectra, 'seed_n')) spectra = xarray.concat(spectra, 'scaling') - spectra = xarray.merge([spectra.rename({'sf_amplitude': 'metamer_sf_amplitude'}), + spectra = xarray.merge([spectra.rename({'sf_amplitude': 'metamer_sf_amplitude', + 'orientation_amplitude': 'metamer_orientation_amplitude'}), ref_image_spectra]) spectra.to_netcdf(output[0]) @@ -1541,13 +1543,13 @@ rule plot_amplitude_spectra: 'task-split_comp-{comp}_amplitude-spectra.nc') output: op.join(config['DATA_DIR'], 'statistics', 'amplitude_spectra', '{model_name}', 'task-split_comp-{comp}', - 'task-split_comp-{comp}_amplitude-spectra.svg') + 'task-split_comp-{comp}_{amplitude_type}-spectra.svg') log: op.join(config['DATA_DIR'], 'logs', 'statistics', 'amplitude_spectra', '{model_name}', 'task-split_comp-{comp}', - 'task-split_comp-{comp}_amplitude-spectra_plot.log') + 'task-split_comp-{comp}_{amplitude_type}-spectra_plot.log') benchmark: op.join(config['DATA_DIR'], 'logs', 'statistics', 'amplitude_spectra', '{model_name}', 'task-split_comp-{comp}', - 'task-split_comp-{comp}_amplitude-spectra_plot_benchmark.txt') + 'task-split_comp-{comp}_{amplitude_type}-spectra_plot_benchmark.txt') run: import foveated_metamers as fov import xarray @@ -1555,9 +1557,13 @@ rule plot_amplitude_spectra: with open(log[0], 'w', buffering=1) as log_file: with contextlib.redirect_stdout(log_file), contextlib.redirect_stderr(log_file): ds = xarray.load_dataset(input[0]) - g = fov.figures.amplitude_spectra(ds) + if wildcards.amplitude_type == 'sf': + g = fov.figures.amplitude_spectra(ds) + elif wildcards.amplitude_type == 'orientation': + g = fov.figures.amplitude_orientation(ds) g.savefig(output[0], bbox_inches='tight') + rule simulate_optimization: output: op.join(config['DATA_DIR'], 'simulate', 'optimization', 'a0-{a0}_s0-{s0}_seeds-{n_seeds}_iter-{max_iter}.svg'), diff --git a/foveated_metamers/figures.py b/foveated_metamers/figures.py index 3c7b018..acf9e34 100644 --- a/foveated_metamers/figures.py +++ b/foveated_metamers/figures.py @@ -4,7 +4,7 @@ import itertools import torch import re -import yaml +from fractions import Fraction import pandas as pd import numpy as np import pyrtools as pt @@ -1758,7 +1758,7 @@ def amplitude_spectra(spectra, hue='scaling', style=None, col='image_name', row=None, col_wrap=5, kind='line', estimator=None, height=2.5, aspect=1, **kwargs): """Compare amplitude spectra of natural and synthesized images. - + Parameters ---------- spectra : xarray.Dataset @@ -1789,19 +1789,7 @@ def amplitude_spectra(spectra, hue='scaling', style=None, col='image_name', Facetgrid with the plot. """ - df = spectra.ref_image_sf_amplitude.to_dataframe().reset_index() - met_df = spectra.metamer_sf_amplitude.to_dataframe().reset_index() - # give the ref image rows dummy values - df['scaling'] = 'ref_image' - df['seed_n'] = 0 - df = df.melt(['freq_n', 'image_name', 'model', 'scaling', 'seed_n', - 'trial_type'], - var_name='image_type', value_name='sf_amplitude') - met_df = met_df.melt(['freq_n', 'image_name', 'model', 'scaling', 'seed_n', - 'trial_type'], - var_name='image_type', value_name='sf_amplitude') - df = pd.concat([df, met_df]) - df.image_type = df.image_type.apply(lambda x: x.replace('_sf_amplitude', '')) + df = plotting._spectra_dataset_to_dataframe(spectra, 'sf') # seaborn raises an error if col_wrap is non-None when col is None or if # row is not None, so prevent that possibility if col is None or row is not None: @@ -1860,3 +1848,109 @@ def amplitude_spectra(spectra, hue='scaling', style=None, col='image_name', " comparisons\n") g.fig.suptitle(title_str, va='bottom') return g + + +def amplitude_orientation(spectra, hue='scaling', style=None, col='image_name', + row=None, col_wrap=5, kind='point', + estimator=np.median, height=2.5, aspect=2, **kwargs): + """Compare orientation distributions of natural and synthesized images. + + Note this is fairly memory intensive. Setting `n_boot` to a lower number + (defaults to 1000) seems to help with this. + + Parameters + ---------- + spectra : xarray.Dataset + Dataset containing the spectra for synthesized metamers and our natural + reference images. + hue, style, col, row : str or None, optional + The dimensions in spectra to facet along the columns, rows, hues, and + styles, respectively. + col_wrap : int or None, optional + If row is None, how many columns to have before wrapping to the next + row. Ignored if col=None. + kind : {'line', 'scatter'}, optional + Type of plot to make. + estimator : name of pandas method or callable + Method for aggregating across multiple observations of the y variable + at the same x level. median is recommended because of the outliers in + the dataset + height : float, optional + Height of the axes. + aspect : float, optional + Aspect of the axes. + kwargs : + Passed to sns.catplot + + Returns + ------- + g : sns.FacetGrid + Facetgrid with the plot. + + """ + df = plotting._spectra_dataset_to_dataframe(spectra, 'orientation') + # seaborn raises an error if col_wrap is non-None when col is None or if + # row is not None, so prevent that possibility + if col is None or row is not None: + col_wrap = None + # remap the image names to be better for plotting + df = plotting._remap_image_names(df) + img_order = plotting.get_order('image_name') + if col == 'image_name': + kwargs.setdefault('col_order', img_order) + if row == 'image_name': + kwargs.setdefault('row_order', img_order) + if hue is not None: + kwargs.setdefault('palette', plotting.get_palette(hue, df[hue].unique())) + else: + kwargs.setdefault('color', 'k') + marker_adjust = {} + dashes_dict = {} + if style is not None: + style_dict = plotting.get_style(style, df[style].unique()) + dashes_dict = style_dict.pop('dashes_dict', {}) + marker_adjust = style_dict.pop('marker_adjust', {}) + kwargs.setdefault('dashes', dashes_dict) + kwargs.update(style_dict) + + kwargs.setdefault('join', False) + kwargs.setdefault('sharey', True) + g = sns.catplot(x='orientation_slice', y='orientation_amplitude', hue=hue, + col=col, row=row, col_wrap=col_wrap, estimator=estimator, + height=height, aspect=aspect, data=df, kind=kind, + legend=False, **kwargs) + + if marker_adjust: + labels = {v: k for k, v in kwargs.get('markers', {}).items()} + final_markers = plotting._marker_adjust(g.axes.flatten(), + marker_adjust, labels) + else: + final_markers = {} + g.set_ylabels('Amplitude') + g.set_xlabels('Orientation') + + # make some nice xticklabels + angles = [a/np.pi for a in sorted(df.orientation_slice.dropna().unique())] + ticklabels = [] + for a in angles: + if a % .25 == 0: + f = Fraction(int(a*4), 4) + if f.denominator == 1: + ticklabels.append(str(int(a))) + else: + ticklabels.append(r"$\frac{%s\pi}{%s}$" % + (f.numerator, f.denominator)) + else: + ticklabels.append('') + g.set_xticklabels(ticklabels) + # create the legend + plotting._add_legend(df, g, None, hue, style, + kwargs.get('palette', {}), final_markers, + dashes_dict) + # we use spectra because it doesn't include np.nan from dummy rows + title_str = (f"Orientation energy for {' and '.join(spectra.model.values)}" + f" metamers, {' and '.join(spectra.trial_type.values)}" + " comparisons\n") + g.fig.suptitle(title_str, va='bottom') + g.fig.subplots_adjust(wpsace=.1) + return g diff --git a/foveated_metamers/plotting.py b/foveated_metamers/plotting.py index 9fe04ce..31cd548 100644 --- a/foveated_metamers/plotting.py +++ b/foveated_metamers/plotting.py @@ -1521,3 +1521,42 @@ def add_physiological_scaling_arrows(ax, side_length=.05, midget_rgc=True, ax.text(xy[0][0], xy[0][1] - triangle_height - triangle_height/4, label, ha='center', va='top', transform=ax.transAxes) return xy + + +def _spectra_dataset_to_dataframe(spectra, data='sf'): + """Convert spectra xarray dataset to pandas dataframe. + + Parameters + ---------- + spectra : xarray.Dataset + Dataset containing the spectra for synthesized metamers and our natural + reference images. + data : {'sf', 'orientation'}, optional + Whether to grab the spatial frequency or orientation info + + Returns + ------- + df : pd.DataFrame + + """ + if data == 'sf': + cols = ['freq_n'] + elif data == 'orientation': + cols = ['orientation_slice', 'samples'] + else: + raise Exception("data must be one of {'sf', 'orientation'} but " + f"got {data}") + df = spectra[f'ref_image_{data}_amplitude'].to_dataframe().reset_index() + met_df = spectra[f'metamer_{data}_amplitude'].to_dataframe().reset_index() + # give the ref image rows dummy values + df['scaling'] = 'ref_image' + df['seed_n'] = 0 + df = df.melt(cols + ['image_name', 'model', 'scaling', 'seed_n', + 'trial_type'], + var_name='image_type', value_name=f'{data}_amplitude') + met_df = met_df.melt(cols + ['image_name', 'model', 'scaling', 'seed_n', + 'trial_type'], + var_name='image_type', value_name=f'{data}_amplitude') + df = pd.concat([df, met_df]) + df.image_type = df.image_type.apply(lambda x: x.replace(f'_{data}_amplitude', '')) + return df diff --git a/foveated_metamers/statistics.py b/foveated_metamers/statistics.py index 630a875..3952dd5 100644 --- a/foveated_metamers/statistics.py +++ b/foveated_metamers/statistics.py @@ -78,7 +78,9 @@ def amplitude_spectra(image): We compute the 2d Fourier transform of an image, take its magnitude, and then radially average it. This averages across orientations and also - discretizes the frequency. + discretizes the frequency. We also drop a disk in frequency space to + exclude the highest frequencies (that is, those where we don't have + cardinal directions). Parameters ---------- @@ -104,11 +106,112 @@ def amplitude_spectra(image): # Note the tutorial excludes label=0, but we include it (corresponds to the # DC term). rbin = pt.synthetic_images.polar_radius(frq.shape).astype(np.int) + # we ignore all frequencies outside a disk centered at the origin that + # reaches to the first edge (in frequency space). This means we get all + # frequencies that we can measure in each orientation (you can't get any + # frequencies in the cardinal directions beyond this disk) + frq_disk = pt.synthetic_images.polar_radius(frq.shape) + frq_thresh = min(frq.shape)//2 + frq_disk = frq_disk < frq_thresh + rbin[~frq_disk] = rbin.max()+1 spectra = scipy.ndimage.mean(np.abs(frq), labels=rbin, - index=np.arange(rbin.max()+1)) + index=np.arange(frq_thresh-1)) return spectra +def amplitude_orientation(image, n_angle_slices=32, metadata=OrderedDict()): + """Compute orientation energy of an image. + + We compute the 2d Fourier transform of an image, take its magnitude, and + compile the amplitudes in angular slices. Note that, unlike + amplitude_spectra(), we do not average within these slices to get a single + number. That's because the distributions here can have much larger outliers + -- whichever slice gets the DC term, for example, will have a way higher + average energy, but that's spurious and a reflection of the pixel grid's + alignment rather than anything meaningful. Right now, it's recommended to + use median to summarize these, but all values are returned. + + We also drop a disk in frequency space to exclude the highest frequencies + (that is, those where we don't have cardinal directions). + + Note that we don't window the image before taking the Fourier transform, + and thus there may be extra vertical and horizontal energy from boundary + artifacts. Thus, this should only be considered "relative" orientation + energy and used in comparison across images, rather than to infer cardinal + bias or the like. + + Parameters + ---------- + image : np.ndarray + The 2d array containing the image + n_angle_slices : int, optional + Number of slices between 0 and 2pi to break orientation into. Note that + we only return half these slices (because orientation is symmetric, + e.g., and orientation of 0 and pi is the same thing) + metadata: OrderedDict, optional + OrderedDict of extra coordinates to add to data (e.g., the model name). + Should be an OrderedDict so we get the proper ordering of dimensions. + + Returns + ------- + amplitude : xarray.Dataset + Dataset containing the amplitudes in each orientation slice. + + """ + frq = scipy.fft.fftshift(scipy.fft.fft2(image)) + theta = pt.synthetic_images.polar_angle(frq.shape, np.pi/n_angle_slices) + # to get this all positive and between 0 and 2pi + theta += np.abs(theta.min()) + # following similar logic to amplitude_spectra() above + theta = (n_angle_slices * theta/theta.max()).astype(int) + # this will be 1 or a very small number of pixels, and we want to lump them + # into the 0th bin (2pi is equivalent to 0) + theta[theta == theta.max()] = 0 + # we ignore all frequencies outside a disk centered at the origin that + # reaches to the first edge (in frequency space). This means we get all + # frequencies that we can measure in each orientation (you can't get any + # frequencies in the cardinal directions beyond this disk). + frq_disk = pt.synthetic_images.polar_radius(frq.shape) + frq_thresh = min(frq.shape)//2 + frq_disk = frq_disk < frq_thresh + theta[~frq_disk] = theta.max()+1 + + # convert this to NaN so we can use it for masking below + frq_disk = frq_disk.astype(float) + frq_disk[frq_disk == 0] = np.nan + # mask out the high frequencies + frq = frq_disk * np.abs(frq) + + slices = [] + # only need to go halfway around, because orientation is symmetric (an + # orientation 0 is the same as an orientation of pi, i.e., up is the same + # orientation as down) + th = np.linspace(0, np.pi, n_angle_slices//2, endpoint=False) + for i in range(theta.max()//2): + # grab data from this slice... + s = frq[theta == i] + # ... and drop everything beyond the frequency disk + slices.append(s[~np.isnan(s)]) + # now we want to concatenate this into a single array, which requires + # making each slice the same length (they're slightly different because of + # how they align with the pixel lattice). + max_len = max([len(s) for s in slices]) + slices = np.stack([np.pad(s, (0, max_len-len(s)), constant_values=np.nan) + for s in slices]) + # coords need to be lists when creating a DataArray + for k, v in metadata.items(): + if isinstance(v, str) or not hasattr(v, '__iter__'): + metadata[k] = [v] + metadata.update({'orientation_slice': th, + 'samples': np.arange(slices.shape[-1])}) + # add extra dimensions to the front of slices for metadata. + slices = np.expand_dims(slices, + tuple(np.arange(len(metadata.keys())-2))) + ds = xarray.DataArray(slices, metadata, metadata.keys(), + name='orientation_amplitude') + return ds.to_dataset() + + def image_set_amplitude_spectra(images, names, metadata=OrderedDict(), name_dim='image_name'): """Compute amplitude spectra of a set of images. @@ -136,10 +239,15 @@ def image_set_amplitude_spectra(images, names, metadata=OrderedDict(), """ spectra = [] - for im in images: + ori = [] + ori_metadata = metadata.copy() + for n, im in zip(names, images): if isinstance(im, str): im = po.to_numpy(po.load_images(im)).squeeze() + ori_metadata[name_dim] = n spectra.append(amplitude_spectra(im)) + ori.append(amplitude_orientation(im, metadata=ori_metadata)) + ori = xarray.concat(ori, 'image_name') for k, v in metadata.items(): if isinstance(v, str) or not hasattr(v, '__iter__'): metadata[k] = [v] @@ -149,5 +257,5 @@ def image_set_amplitude_spectra(images, names, metadata=OrderedDict(), spectra = np.expand_dims(spectra, tuple(np.arange(len(metadata.keys())-2))) data = xarray.DataArray(spectra, metadata, metadata.keys(), - name='sf_amplitude') - return data.to_dataset() + name='sf_amplitude').to_dataset() + return xarray.merge([data, ori])