Skip to content

Commit

Permalink
adds computation and plot of orientation distribution
Browse files Browse the repository at this point in the history
orientation anisotropy seems like a possible way of separating pink noise
from natural images (pink noise has equal power at all orientations). so
this adds a way to compute that and then plot it

we don't mean / median the vaules in each slice, because the
distributions are weird. instead we keep the full set of values and, for
now, we plot the median using catplot. might change
  • Loading branch information
billbrod committed Jun 18, 2021
1 parent d8fddc1 commit ef79a93
Show file tree
Hide file tree
Showing 4 changed files with 273 additions and 26 deletions.
18 changes: 12 additions & 6 deletions Snakefile
Original file line number Diff line number Diff line change
Expand Up @@ -1515,7 +1515,8 @@ rule compute_amplitude_spectra:
ims = sorted(ims, key=lambda x: LINEAR_IMAGES.index([i for i in LINEAR_IMAGES if i in x][0]))
assert len(ims) == len(LINEAR_IMAGES), f"Have too many images! Expected {len(LINEAR_IMAGES)}, but got {ims}"
ref_image_spectra = fov.statistics.image_set_amplitude_spectra(ims, LINEAR_IMAGES, metadata)
ref_image_spectra = ref_image_spectra.rename({'sf_amplitude': 'ref_image_sf_amplitude'})
ref_image_spectra = ref_image_spectra.rename({'sf_amplitude': 'ref_image_sf_amplitude',
'orientation_amplitude': 'ref_image_orientation_amplitude'})
spectra = []
for scaling in scalings:
tmp_ims = [i for i in input if len(re.findall(f'scaling-{scaling}', i)) == 1]
Expand All @@ -1530,7 +1531,8 @@ rule compute_amplitude_spectra:
tmp_spectra.append(fov.statistics.image_set_amplitude_spectra(ims, LINEAR_IMAGES, metadata))
spectra.append(xarray.concat(tmp_spectra, 'seed_n'))
spectra = xarray.concat(spectra, 'scaling')
spectra = xarray.merge([spectra.rename({'sf_amplitude': 'metamer_sf_amplitude'}),
spectra = xarray.merge([spectra.rename({'sf_amplitude': 'metamer_sf_amplitude',
'orientation_amplitude': 'metamer_orientation_amplitude'}),
ref_image_spectra])
spectra.to_netcdf(output[0])

Expand All @@ -1541,23 +1543,27 @@ rule plot_amplitude_spectra:
'task-split_comp-{comp}_amplitude-spectra.nc')
output:
op.join(config['DATA_DIR'], 'statistics', 'amplitude_spectra', '{model_name}', 'task-split_comp-{comp}',
'task-split_comp-{comp}_amplitude-spectra.svg')
'task-split_comp-{comp}_{amplitude_type}-spectra.svg')
log:
op.join(config['DATA_DIR'], 'logs', 'statistics', 'amplitude_spectra', '{model_name}', 'task-split_comp-{comp}',
'task-split_comp-{comp}_amplitude-spectra_plot.log')
'task-split_comp-{comp}_{amplitude_type}-spectra_plot.log')
benchmark:
op.join(config['DATA_DIR'], 'logs', 'statistics', 'amplitude_spectra', '{model_name}', 'task-split_comp-{comp}',
'task-split_comp-{comp}_amplitude-spectra_plot_benchmark.txt')
'task-split_comp-{comp}_{amplitude_type}-spectra_plot_benchmark.txt')
run:
import foveated_metamers as fov
import xarray
import contextlib
with open(log[0], 'w', buffering=1) as log_file:
with contextlib.redirect_stdout(log_file), contextlib.redirect_stderr(log_file):
ds = xarray.load_dataset(input[0])
g = fov.figures.amplitude_spectra(ds)
if wildcards.amplitude_type == 'sf':
g = fov.figures.amplitude_spectra(ds)
elif wildcards.amplitude_type == 'orientation':
g = fov.figures.amplitude_orientation(ds)
g.savefig(output[0], bbox_inches='tight')


rule simulate_optimization:
output:
op.join(config['DATA_DIR'], 'simulate', 'optimization', 'a0-{a0}_s0-{s0}_seeds-{n_seeds}_iter-{max_iter}.svg'),
Expand Down
124 changes: 109 additions & 15 deletions foveated_metamers/figures.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import itertools
import torch
import re
import yaml
from fractions import Fraction
import pandas as pd
import numpy as np
import pyrtools as pt
Expand Down Expand Up @@ -1758,7 +1758,7 @@ def amplitude_spectra(spectra, hue='scaling', style=None, col='image_name',
row=None, col_wrap=5, kind='line', estimator=None,
height=2.5, aspect=1, **kwargs):
"""Compare amplitude spectra of natural and synthesized images.
Parameters
----------
spectra : xarray.Dataset
Expand Down Expand Up @@ -1789,19 +1789,7 @@ def amplitude_spectra(spectra, hue='scaling', style=None, col='image_name',
Facetgrid with the plot.
"""
df = spectra.ref_image_sf_amplitude.to_dataframe().reset_index()
met_df = spectra.metamer_sf_amplitude.to_dataframe().reset_index()
# give the ref image rows dummy values
df['scaling'] = 'ref_image'
df['seed_n'] = 0
df = df.melt(['freq_n', 'image_name', 'model', 'scaling', 'seed_n',
'trial_type'],
var_name='image_type', value_name='sf_amplitude')
met_df = met_df.melt(['freq_n', 'image_name', 'model', 'scaling', 'seed_n',
'trial_type'],
var_name='image_type', value_name='sf_amplitude')
df = pd.concat([df, met_df])
df.image_type = df.image_type.apply(lambda x: x.replace('_sf_amplitude', ''))
df = plotting._spectra_dataset_to_dataframe(spectra, 'sf')
# seaborn raises an error if col_wrap is non-None when col is None or if
# row is not None, so prevent that possibility
if col is None or row is not None:
Expand Down Expand Up @@ -1860,3 +1848,109 @@ def amplitude_spectra(spectra, hue='scaling', style=None, col='image_name',
" comparisons\n")
g.fig.suptitle(title_str, va='bottom')
return g


def amplitude_orientation(spectra, hue='scaling', style=None, col='image_name',
row=None, col_wrap=5, kind='point',
estimator=np.median, height=2.5, aspect=2, **kwargs):
"""Compare orientation distributions of natural and synthesized images.
Note this is fairly memory intensive. Setting `n_boot` to a lower number
(defaults to 1000) seems to help with this.
Parameters
----------
spectra : xarray.Dataset
Dataset containing the spectra for synthesized metamers and our natural
reference images.
hue, style, col, row : str or None, optional
The dimensions in spectra to facet along the columns, rows, hues, and
styles, respectively.
col_wrap : int or None, optional
If row is None, how many columns to have before wrapping to the next
row. Ignored if col=None.
kind : {'line', 'scatter'}, optional
Type of plot to make.
estimator : name of pandas method or callable
Method for aggregating across multiple observations of the y variable
at the same x level. median is recommended because of the outliers in
the dataset
height : float, optional
Height of the axes.
aspect : float, optional
Aspect of the axes.
kwargs :
Passed to sns.catplot
Returns
-------
g : sns.FacetGrid
Facetgrid with the plot.
"""
df = plotting._spectra_dataset_to_dataframe(spectra, 'orientation')
# seaborn raises an error if col_wrap is non-None when col is None or if
# row is not None, so prevent that possibility
if col is None or row is not None:
col_wrap = None
# remap the image names to be better for plotting
df = plotting._remap_image_names(df)
img_order = plotting.get_order('image_name')
if col == 'image_name':
kwargs.setdefault('col_order', img_order)
if row == 'image_name':
kwargs.setdefault('row_order', img_order)
if hue is not None:
kwargs.setdefault('palette', plotting.get_palette(hue, df[hue].unique()))
else:
kwargs.setdefault('color', 'k')
marker_adjust = {}
dashes_dict = {}
if style is not None:
style_dict = plotting.get_style(style, df[style].unique())
dashes_dict = style_dict.pop('dashes_dict', {})
marker_adjust = style_dict.pop('marker_adjust', {})
kwargs.setdefault('dashes', dashes_dict)
kwargs.update(style_dict)

kwargs.setdefault('join', False)
kwargs.setdefault('sharey', True)
g = sns.catplot(x='orientation_slice', y='orientation_amplitude', hue=hue,
col=col, row=row, col_wrap=col_wrap, estimator=estimator,
height=height, aspect=aspect, data=df, kind=kind,
legend=False, **kwargs)

if marker_adjust:
labels = {v: k for k, v in kwargs.get('markers', {}).items()}
final_markers = plotting._marker_adjust(g.axes.flatten(),
marker_adjust, labels)
else:
final_markers = {}
g.set_ylabels('Amplitude')
g.set_xlabels('Orientation')

# make some nice xticklabels
angles = [a/np.pi for a in sorted(df.orientation_slice.dropna().unique())]
ticklabels = []
for a in angles:
if a % .25 == 0:
f = Fraction(int(a*4), 4)
if f.denominator == 1:
ticklabels.append(str(int(a)))
else:
ticklabels.append(r"$\frac{%s\pi}{%s}$" %
(f.numerator, f.denominator))
else:
ticklabels.append('')
g.set_xticklabels(ticklabels)
# create the legend
plotting._add_legend(df, g, None, hue, style,
kwargs.get('palette', {}), final_markers,
dashes_dict)
# we use spectra because it doesn't include np.nan from dummy rows
title_str = (f"Orientation energy for {' and '.join(spectra.model.values)}"
f" metamers, {' and '.join(spectra.trial_type.values)}"
" comparisons\n")
g.fig.suptitle(title_str, va='bottom')
g.fig.subplots_adjust(wpsace=.1)
return g
39 changes: 39 additions & 0 deletions foveated_metamers/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -1521,3 +1521,42 @@ def add_physiological_scaling_arrows(ax, side_length=.05, midget_rgc=True,
ax.text(xy[0][0], xy[0][1] - triangle_height - triangle_height/4,
label, ha='center', va='top', transform=ax.transAxes)
return xy


def _spectra_dataset_to_dataframe(spectra, data='sf'):
"""Convert spectra xarray dataset to pandas dataframe.
Parameters
----------
spectra : xarray.Dataset
Dataset containing the spectra for synthesized metamers and our natural
reference images.
data : {'sf', 'orientation'}, optional
Whether to grab the spatial frequency or orientation info
Returns
-------
df : pd.DataFrame
"""
if data == 'sf':
cols = ['freq_n']
elif data == 'orientation':
cols = ['orientation_slice', 'samples']
else:
raise Exception("data must be one of {'sf', 'orientation'} but "
f"got {data}")
df = spectra[f'ref_image_{data}_amplitude'].to_dataframe().reset_index()
met_df = spectra[f'metamer_{data}_amplitude'].to_dataframe().reset_index()
# give the ref image rows dummy values
df['scaling'] = 'ref_image'
df['seed_n'] = 0
df = df.melt(cols + ['image_name', 'model', 'scaling', 'seed_n',
'trial_type'],
var_name='image_type', value_name=f'{data}_amplitude')
met_df = met_df.melt(cols + ['image_name', 'model', 'scaling', 'seed_n',
'trial_type'],
var_name='image_type', value_name=f'{data}_amplitude')
df = pd.concat([df, met_df])
df.image_type = df.image_type.apply(lambda x: x.replace(f'_{data}_amplitude', ''))
return df
118 changes: 113 additions & 5 deletions foveated_metamers/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,9 @@ def amplitude_spectra(image):
We compute the 2d Fourier transform of an image, take its magnitude, and
then radially average it. This averages across orientations and also
discretizes the frequency.
discretizes the frequency. We also drop a disk in frequency space to
exclude the highest frequencies (that is, those where we don't have
cardinal directions).
Parameters
----------
Expand All @@ -104,11 +106,112 @@ def amplitude_spectra(image):
# Note the tutorial excludes label=0, but we include it (corresponds to the
# DC term).
rbin = pt.synthetic_images.polar_radius(frq.shape).astype(np.int)
# we ignore all frequencies outside a disk centered at the origin that
# reaches to the first edge (in frequency space). This means we get all
# frequencies that we can measure in each orientation (you can't get any
# frequencies in the cardinal directions beyond this disk)
frq_disk = pt.synthetic_images.polar_radius(frq.shape)
frq_thresh = min(frq.shape)//2
frq_disk = frq_disk < frq_thresh
rbin[~frq_disk] = rbin.max()+1
spectra = scipy.ndimage.mean(np.abs(frq), labels=rbin,
index=np.arange(rbin.max()+1))
index=np.arange(frq_thresh-1))
return spectra


def amplitude_orientation(image, n_angle_slices=32, metadata=OrderedDict()):
"""Compute orientation energy of an image.
We compute the 2d Fourier transform of an image, take its magnitude, and
compile the amplitudes in angular slices. Note that, unlike
amplitude_spectra(), we do not average within these slices to get a single
number. That's because the distributions here can have much larger outliers
-- whichever slice gets the DC term, for example, will have a way higher
average energy, but that's spurious and a reflection of the pixel grid's
alignment rather than anything meaningful. Right now, it's recommended to
use median to summarize these, but all values are returned.
We also drop a disk in frequency space to exclude the highest frequencies
(that is, those where we don't have cardinal directions).
Note that we don't window the image before taking the Fourier transform,
and thus there may be extra vertical and horizontal energy from boundary
artifacts. Thus, this should only be considered "relative" orientation
energy and used in comparison across images, rather than to infer cardinal
bias or the like.
Parameters
----------
image : np.ndarray
The 2d array containing the image
n_angle_slices : int, optional
Number of slices between 0 and 2pi to break orientation into. Note that
we only return half these slices (because orientation is symmetric,
e.g., and orientation of 0 and pi is the same thing)
metadata: OrderedDict, optional
OrderedDict of extra coordinates to add to data (e.g., the model name).
Should be an OrderedDict so we get the proper ordering of dimensions.
Returns
-------
amplitude : xarray.Dataset
Dataset containing the amplitudes in each orientation slice.
"""
frq = scipy.fft.fftshift(scipy.fft.fft2(image))
theta = pt.synthetic_images.polar_angle(frq.shape, np.pi/n_angle_slices)
# to get this all positive and between 0 and 2pi
theta += np.abs(theta.min())
# following similar logic to amplitude_spectra() above
theta = (n_angle_slices * theta/theta.max()).astype(int)
# this will be 1 or a very small number of pixels, and we want to lump them
# into the 0th bin (2pi is equivalent to 0)
theta[theta == theta.max()] = 0
# we ignore all frequencies outside a disk centered at the origin that
# reaches to the first edge (in frequency space). This means we get all
# frequencies that we can measure in each orientation (you can't get any
# frequencies in the cardinal directions beyond this disk).
frq_disk = pt.synthetic_images.polar_radius(frq.shape)
frq_thresh = min(frq.shape)//2
frq_disk = frq_disk < frq_thresh
theta[~frq_disk] = theta.max()+1

# convert this to NaN so we can use it for masking below
frq_disk = frq_disk.astype(float)
frq_disk[frq_disk == 0] = np.nan
# mask out the high frequencies
frq = frq_disk * np.abs(frq)

slices = []
# only need to go halfway around, because orientation is symmetric (an
# orientation 0 is the same as an orientation of pi, i.e., up is the same
# orientation as down)
th = np.linspace(0, np.pi, n_angle_slices//2, endpoint=False)
for i in range(theta.max()//2):
# grab data from this slice...
s = frq[theta == i]
# ... and drop everything beyond the frequency disk
slices.append(s[~np.isnan(s)])
# now we want to concatenate this into a single array, which requires
# making each slice the same length (they're slightly different because of
# how they align with the pixel lattice).
max_len = max([len(s) for s in slices])
slices = np.stack([np.pad(s, (0, max_len-len(s)), constant_values=np.nan)
for s in slices])
# coords need to be lists when creating a DataArray
for k, v in metadata.items():
if isinstance(v, str) or not hasattr(v, '__iter__'):
metadata[k] = [v]
metadata.update({'orientation_slice': th,
'samples': np.arange(slices.shape[-1])})
# add extra dimensions to the front of slices for metadata.
slices = np.expand_dims(slices,
tuple(np.arange(len(metadata.keys())-2)))
ds = xarray.DataArray(slices, metadata, metadata.keys(),
name='orientation_amplitude')
return ds.to_dataset()


def image_set_amplitude_spectra(images, names, metadata=OrderedDict(),
name_dim='image_name'):
"""Compute amplitude spectra of a set of images.
Expand Down Expand Up @@ -136,10 +239,15 @@ def image_set_amplitude_spectra(images, names, metadata=OrderedDict(),
"""
spectra = []
for im in images:
ori = []
ori_metadata = metadata.copy()
for n, im in zip(names, images):
if isinstance(im, str):
im = po.to_numpy(po.load_images(im)).squeeze()
ori_metadata[name_dim] = n
spectra.append(amplitude_spectra(im))
ori.append(amplitude_orientation(im, metadata=ori_metadata))
ori = xarray.concat(ori, 'image_name')
for k, v in metadata.items():
if isinstance(v, str) or not hasattr(v, '__iter__'):
metadata[k] = [v]
Expand All @@ -149,5 +257,5 @@ def image_set_amplitude_spectra(images, names, metadata=OrderedDict(),
spectra = np.expand_dims(spectra,
tuple(np.arange(len(metadata.keys())-2)))
data = xarray.DataArray(spectra, metadata, metadata.keys(),
name='sf_amplitude')
return data.to_dataset()
name='sf_amplitude').to_dataset()
return xarray.merge([data, ori])

0 comments on commit ef79a93

Please sign in to comment.