Add temporal metrics; add temporal versions of MajajHong2015 (#1109)
* feature: support temporal models for neural alignment by changing TemporalIgnore to Temporal Aligned

* add example temporal submission

* complete new framework

* new module: temporal model helpers

* change the arch of temporal; add tutorials

* improve: better naming

* update: wrapper tutorial on brain model

* add feature: inferencer identifier tracked by extractor for result caching

* fix: video fps sampling; need more tests!

* fix bugs: video sampling based on fps was wrong.

* add mmaction2 models; add more features to the inferencers

* PR: temporal model helpers

* PR fix: not including gitmodules for now

* Update brainscore_vision/model_helpers/brain_transformation/temporal.py

Co-authored-by: Martin Schrimpf <[email protected]>

* Update brainscore_vision/model_helpers/brain_transformation/temporal.py

Co-authored-by: Martin Schrimpf <[email protected]>

* Update brainscore_vision/model_helpers/brain_transformation/temporal.py

Co-authored-by: Martin Schrimpf <[email protected]>

* Update brainscore_vision/models/temporal_models/test.py

Co-authored-by: Martin Schrimpf <[email protected]>

* add mae_st; add ding2012

* try new arch

* init ding2012

* add tests for temporal model helpers; add block inferencer

* Delete tests/test_model_helpers/temporal/test___init__.py

delete the old test

* add benchmark ding2012

* add multiple libs for temporal models

* change executor output format; add more inference tests; init load_weight in s3

* add openstl

* update backend for executor

* feat:load_weight_file and corresponding test

* change:resize strategy changed from bilinear to pooling

* change:resize strategy changed from bilinear to pooling

* fix mae_st submission

* minor

* fix:dtype in assembly time align

* minor

* update model submissions

* fix dependency

* refactor: simplify the inferencer methods

* fix:block inferencer, neuroid coord while merging

* fix:inferencer identifier

* fix:weight download

* change tests to have max_workers=1

* revert screen.py

* not submit region_layer_map

* remove torch dependency

* make fake modules in tests

* add torch to requirements; avoid torch in tests

* minor

* minor

* np.object changed to object

* remove return in tests

* fix insertion position bug

* Apply suggestions from code review

add: more type hints

Co-authored-by: Martin Schrimpf <[email protected]>

* add: more type hints and comments

* minor

* pr:only commit temporal model helpers

* pr: add one model for example

* undo whole_brain in BrainModel.RecordingTarget

* use logger and fix newlines

* fix: video fps with copy was wrong

* feat:fractional max_spatial_size

* downsample layers in VideoMAE

* fix:video sampling wrong duration

* add more tests

* fix merge

* fix merge

* module refactor; add more input test

* add more temporal models

* fix videomaev2 sha

* fix:temporal_modelmae_st

* change:video conservative loading; rename:image to pil image

* fix:video last frame sampling; fix_time_naming

* ignore pytest_cache

* re-trigger tests

* add joblib pool error management; fix video/image path recognizer

* update: naming of failed to pickle func in joblibmapper

* add temporal metric helpers

* add temporal version of MajajHong2015

* Update benchmark.py

type hint

* Update benchmark.py

* Update brainscore_vision/metric_helpers/temporal.py

Co-authored-by: Martin Schrimpf <[email protected]>

* Update brainscore_vision/metrics/internal_consistency/__init__.py

Co-authored-by: Martin Schrimpf <[email protected]>

* Update benchmark.py

---------

Co-authored-by: Yingtian Tang <[email protected]>
Co-authored-by: Martin Schrimpf <[email protected]>
Co-authored-by: Martin Schrimpf <[email protected]>
Co-authored-by: deirdre-k <[email protected]>
Co-authored-by: Michael Ferguson <[email protected]>
6 people authored Sep 10, 2024
1 parent 4dd1d0f commit b17ed0b
Showing 8 changed files with 320 additions and 10 deletions.
5 changes: 5 additions & 0 deletions brainscore_vision/benchmarks/majajhong2015/__init__.py
@@ -11,3 +11,8 @@

benchmark_registry['MajajHong2015public.V4-pls'] = MajajHongV4PublicBenchmark
benchmark_registry['MajajHong2015public.IT-pls'] = MajajHongITPublicBenchmark

# temporal
from .benchmark import MajajHongV4TemporalPublicBenchmark, MajajHongITTemporalPublicBenchmark
benchmark_registry['MajajHong2015public.V4-temporal-pls'] = lambda: MajajHongV4TemporalPublicBenchmark(time_interval=10)
benchmark_registry['MajajHong2015public.IT-temporal-pls'] = lambda: MajajHongITTemporalPublicBenchmark(time_interval=10)
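The two temporal entries are registered as zero-argument lambdas, so `time_interval=10` is bound at registration time while the benchmark itself is only built when the entry is looked up. A minimal sketch of that pattern follows. It uses a plain dict and a stand-in constructor; `make_temporal_benchmark` and `load_benchmark` are hypothetical names for illustration, not Brain-Score API.

```python
from typing import Callable, Dict

# Hypothetical stand-in for a benchmark registry: identifier -> zero-argument factory.
benchmark_registry: Dict[str, Callable[[], dict]] = {}


def make_temporal_benchmark(region: str, time_interval: int) -> dict:
    # Placeholder for an expensive constructor (loading data, building metrics, ...).
    return {"region": region, "time_interval": time_interval}


# The lambda fixes the keyword argument now but defers construction until lookup.
benchmark_registry["Example.V4-temporal"] = lambda: make_temporal_benchmark("V4", time_interval=10)
benchmark_registry["Example.IT-temporal"] = lambda: make_temporal_benchmark("IT", time_interval=10)


def load_benchmark(identifier: str) -> dict:
    # Calling the stored factory is what actually builds the benchmark.
    return benchmark_registry[identifier]()
```

Registering the class directly, as the non-temporal entries do, works for constructors without arguments; the lambda is only needed here because an argument has to be supplied.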
44 changes: 34 additions & 10 deletions brainscore_vision/benchmarks/majajhong2015/benchmark.py
@@ -1,7 +1,8 @@
from brainscore_core import Metric

from brainscore_vision import load_metric, Ceiling, load_ceiling, load_dataset
-from brainscore_vision.benchmark_helpers.neural_common import NeuralBenchmark, average_repetition
+from brainscore_vision.benchmark_helpers.neural_common import NeuralBenchmark, average_repetition, apply_keep_attrs
+from brainscore_vision.model_helpers.brain_transformation.temporal import assembly_time_align

VISUAL_DEGREES = 8
NUMBER_OF_TRIALS = 50
@@ -20,13 +21,14 @@
eprint = {https://www.jneurosci.org/content/35/39/13402.full.pdf},
journal = {Journal of Neuroscience}}"""

-pls_metric = lambda: load_metric('pls', crossvalidation_kwargs=dict(stratification_coord='object_name'))
-
+crossvalidation_kwargs = dict(stratification_coord='object_name')
+pls_metric = lambda: load_metric('pls', crossvalidation_kwargs=crossvalidation_kwargs)
+spantime_pls_metric = lambda: load_metric('spantime_pls', crossvalidation_kwargs=crossvalidation_kwargs)

 def _DicarloMajajHong2015Region(region: str, access: str, identifier_metric_suffix: str,
-                                similarity_metric: Metric, ceiler: Ceiling):
-    assembly_repetition = load_assembly(average_repetitions=False, region=region, access=access)
-    assembly = load_assembly(average_repetitions=True, region=region, access=access)
+                                similarity_metric: Metric, ceiler: Ceiling, time_interval: float = None):
+    assembly_repetition = load_assembly(average_repetitions=False, region=region, access=access, time_interval=time_interval)
+    assembly = load_assembly(average_repetitions=True, region=region, access=access, time_interval=time_interval)
     benchmark_identifier = f'MajajHong2015.{region}' + ('.public' if access == 'public' else '')
     return NeuralBenchmark(identifier=f'{benchmark_identifier}-{identifier_metric_suffix}', version=3,
                            assembly=assembly, similarity_metric=similarity_metric,
@@ -60,13 +62,35 @@ def MajajHongITPublicBenchmark():
ceiler=load_ceiling('internal_consistency'))


-def load_assembly(average_repetitions, region, access='private'):
-    assembly = load_dataset(f'MajajHong2015.{access}')
+def MajajHongV4TemporalPublicBenchmark(time_interval: float = None):
+    return _DicarloMajajHong2015Region(region='V4', access='public', identifier_metric_suffix='pls',
+                                       similarity_metric=spantime_pls_metric(), time_interval=time_interval,
+                                       ceiler=load_ceiling('internal_consistency_temporal'))
+
+
+def MajajHongITTemporalPublicBenchmark(time_interval: float = None):
+    return _DicarloMajajHong2015Region(region='IT', access='public', identifier_metric_suffix='pls',
+                                       similarity_metric=spantime_pls_metric(), time_interval=time_interval,
+                                       ceiler=load_ceiling('internal_consistency_temporal'))
+
+
+def load_assembly(average_repetitions: bool, region: str, access: str = 'private', time_interval: float = None):
+    temporal = time_interval is not None
+    if not temporal:
+        assembly = load_dataset(f'MajajHong2015.{access}')
+        assembly = assembly.squeeze("time_bin")
+    else:
+        assembly = load_dataset(f'MajajHong2015.temporal.{access}')
+        assembly = assembly.__class__(assembly)
+        target_time_bins = [
+            (t, t+time_interval) for t in range(0, assembly.time_bin_end.max().item()-time_interval, time_interval)
+        ]
+        assembly = apply_keep_attrs(assembly, lambda assembly: assembly_time_align(assembly, target_time_bins))
+
     assembly = assembly.sel(region=region)
     assembly['region'] = 'neuroid', [region] * len(assembly['neuroid'])
-    assembly = assembly.squeeze("time_bin")
     assembly.load()
-    assembly = assembly.transpose('presentation', 'neuroid')
+    assembly = assembly.transpose('presentation', 'neuroid', ...)
     if average_repetitions:
         assembly = average_repetition(assembly)
     return assembly
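To see which bins the list comprehension in `load_assembly` produces, the sketch below reproduces the same `range` expression in isolation. `make_time_bins` is a hypothetical helper, and the maximum bin end of 300 is an assumed value rather than one read from the dataset. Since `range` only accepts integers, an integer `time_interval` is assumed here even though the signature annotates it as `float`.

```python
from typing import List, Tuple


def make_time_bins(max_time_bin_end: int, time_interval: int) -> List[Tuple[int, int]]:
    # Same expression as in load_assembly: the range stops at max - interval (exclusive).
    return [(t, t + time_interval)
            for t in range(0, max_time_bin_end - time_interval, time_interval)]


# Assumed example values: recordings up to 300 (e.g. ms), 10-unit bins.
bins = make_time_bins(max_time_bin_end=300, time_interval=10)
first_bin, last_bin = bins[0], bins[-1]
```

With these assumed values the last bin ends one interval before the maximum, because the exclusive stop of the `range` is `max - time_interval`. The commit does not say whether dropping the final interval is intended.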
119 changes: 119 additions & 0 deletions brainscore_vision/metric_helpers/temporal.py
@@ -0,0 +1,119 @@
import xarray as xr
import numpy as np

from brainscore_vision.benchmark_helpers.neural_common import Score
from brainscore_vision.metric_helpers.transformations import standard_error_of_the_mean

from .xarray_utils import apply_over_dims, recursive_op


# take the mean of scores (medians of single neuron scores) over time


def average_over_presentation(score: Score) -> Score:
    raw = score
    score = raw.mean('presentation')
    score.attrs['raw'] = raw
    return score


# PerOps is applied to every slice/chunk of the xarray along the specified dimensions
class PerOps:
    def __init__(self, callable, dims, check_coords=[]):
        # for coordinate checking, they are supposed to be the same across assemblies
        self.dims = dims
        self.callable = callable
        self.check_coords = check_coords

    def __call__(self, *asms):
        for check_coord in self.check_coords:
            asms = [asm.sortby(check_coord) for asm in asms]
            for asm in asms[1:]:
                assert (asm[check_coord].values == asms[0][check_coord].values).all()
        ret = apply_over_dims(self.callable, *asms, dims=self.dims)
        return ret


# SpanOps aggregates specified dimensions to one dimension
class SpanOps:
    def __init__(self, callable, source_dims, aggregated_dim, resample=False):
        # if resample, randomly choose samples from the aggregated dimension,
        # whose size is the same as the assembly.sizes[aggregated_dim]
        self.source_dims = source_dims
        self.aggregated_dim = aggregated_dim
        self.callable = callable
        self.resample = resample

    def __call__(self, *asms):
        asms = [self._stack(asm) for asm in asms]
        return self.callable(*asms)

    def _stack(self, assembly):
        assembly_type = type(assembly)
        size = assembly.sizes[self.aggregated_dim]
        assembly = xr.DataArray(assembly)  # xarray cannot deal with stacking MultiIndex (pydata/xarray#1554)
        assembly = assembly.reset_index(self.source_dims)
        assembly = assembly.rename({dim:dim+"_" for dim in self.source_dims})  # we'll call stacked timebins "presentation"
        assembly = assembly.stack({self.aggregated_dim : [dim+"_" for dim in self.source_dims]})
        if self.resample:
            indices = np.random.randint(0, assembly.sizes[self.aggregated_dim], size)
            assembly = assembly.isel({self.aggregated_dim: indices})
        return assembly_type(assembly)

class PerTime(PerOps):
    def __init__(self, callable, time_dim="time_bin", check_coord="time_bin_start", **kwargs):
        self.time_bin = time_dim
        super().__init__(callable, dims=[time_dim], check_coords=[check_coord], **kwargs)

class PerPresentation(PerOps):
    def __init__(self, callable, presentation_dim="presentation", check_coord="stimulus_id", **kwargs):
        self.presentation_dim = presentation_dim
        super().__init__(callable, dims=[presentation_dim], check_coords=[check_coord], **kwargs)

class PerNeuroid(PerOps):
    def __init__(self, callable, neuroid_dim="neuroid", check_coord="neuroid_id", **kwargs):
        self.neuroid_dim = neuroid_dim
        super().__init__(callable, dims=[neuroid_dim], check_coords=[check_coord], **kwargs)

class SpanTime(SpanOps):
    def __init__(self, callable, time_dim="time_bin", presentation_dim="presentation", resample=False):
        self.time_dim = time_dim
        self.presentation_dim = presentation_dim
        source_dims = [self.time_dim, self.presentation_dim]
        aggregated_dim = self.presentation_dim
        super().__init__(callable, source_dims, aggregated_dim, resample=resample)

class SpanTimeRegression:
    """
    Fits a regression with weights shared across the time bins.
    """

    def __init__(self, regression):
        self._regression = regression

    def fit(self, source, target):
        assert (source['time_bin'].values == target['time_bin'].values).all()
        SpanTime(self._regression.fit)(source, target)

    def predict(self, source):
        return PerTime(self._regression.predict)(source)

class PerTimeRegression:
    """
    Fits a regression with different weights for each time bin.
    """

    def __init__(self, regression):
        self._regression = regression

    def fit(self, source, target):
        # Lazy fit until predict
        assert (source['time_bin'].values == target['time_bin'].values).all()
        self._train_source = source
        self._train_target = target

    def predict(self, source):
        def fit_predict(train_source, train_target, test_source):
            self._regression.fit(train_source, train_target)
            return self._regression.predict(test_source)
        return PerTime(fit_predict)(self._train_source, self._train_target, source)
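The idea behind `SpanTimeRegression` (one set of weights fitted on all time bins at once, then applied bin by bin) can be illustrated without xarray. The following is a rough NumPy sketch under assumptions: arrays are laid out as `(presentation, feature_or_neuroid, time_bin)`, ordinary least squares stands in for the wrapped regression, and `span_time_fit` and `per_time_predict` are hypothetical names that do not appear in the commit.

```python
import numpy as np

rng = np.random.default_rng(0)
n_presentations, n_features, n_neuroids, n_time_bins = 50, 4, 3, 5

# Synthetic data: the same linear map generates the target in every time bin.
source = rng.normal(size=(n_presentations, n_features, n_time_bins))
true_weights = rng.normal(size=(n_features, n_neuroids))
target = np.einsum('pft,fn->pnt', source, true_weights)


def span_time_fit(source: np.ndarray, target: np.ndarray) -> np.ndarray:
    # Merge presentation and time into one sample axis, then fit a single weight matrix.
    stacked_source = source.transpose(0, 2, 1).reshape(-1, source.shape[1])
    stacked_target = target.transpose(0, 2, 1).reshape(-1, target.shape[1])
    weights, *_ = np.linalg.lstsq(stacked_source, stacked_target, rcond=None)
    return weights


def per_time_predict(weights: np.ndarray, source: np.ndarray) -> np.ndarray:
    # Apply the shared weights separately to each time bin and re-stack along time.
    per_bin = [source[:, :, t] @ weights for t in range(source.shape[2])]
    return np.stack(per_bin, axis=-1)


weights = span_time_fit(source, target)
prediction = per_time_predict(weights, source)
```

A per-time variant would instead call the least-squares fit inside the loop over `t`, giving one weight matrix per bin, which is what `PerTimeRegression` defers until `predict`.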
59 changes: 59 additions & 0 deletions brainscore_vision/metric_helpers/xarray_utils.py
@@ -1,4 +1,5 @@
import numpy as np
import xarray as xr

from brainio.assemblies import NeuroidAssembly, array_is_element, walk_coords
from brainscore_vision.metric_helpers import Defaults
@@ -90,3 +91,61 @@ def __call__(self, prediction, target):
for coord, dims, values in walk_coords(target) if dims == neuroid_dims},
dims=neuroid_dims)
return result


# op that also applies to attrs (and attrs of attrs) that are xarrays
def recursive_op(*arrs, op=lambda x:x):
    # the attrs structure of each arr must be the same
    val = op(*arrs)
    attrs = arrs[0].attrs
    for attr in attrs:
        attr_val = arrs[0].attrs[attr]
        if isinstance(attr_val, xr.DataArray):
            attr_arrs = [arr.attrs[attr] for arr in arrs]
            attr_val = recursive_op(*attr_arrs, op=op)
        val.attrs[attr] = attr_val
    return val


# apply a callable to every slice of the xarray along the specified dimensions
def apply_over_dims(callable, *asms, dims, njobs=-1):
    asms = [asm.transpose(*dims, ...) for asm in asms]
    sizes = [asms[0].sizes[dim] for dim in dims]

    def apply_helper(sizes, dims, *asms):
        xarr = []
        attrs = {}
        size = sizes[0]
        rsizes = sizes[1:]
        dim = dims[0]
        rdims = dims[1:]

        if len(sizes) == 1:
            # parallel execution on the last applied dimension
            from joblib import Parallel, delayed
            results = Parallel(n_jobs=njobs)(delayed(callable)(*[asm.isel({dim:s}) for asm in asms]) for s in range(size))
        else:
            results = []
            for s in range(size):
                arr = apply_helper(rsizes, rdims, *[asm.isel({dim:s}) for asm in asms])
                results.append(arr)

        for arr in results:
            if arr is not None:
                for k,v in arr.attrs.items():
                    assert isinstance(v, xr.DataArray)
                    attrs.setdefault(k, []).append(v.expand_dims(dim))
                xarr.append(arr)

        if not xarr:
            return
        else:
            xarr = xr.concat(xarr, dim=dim)
            attrs = {k: xr.concat(vs, dim=dim) for k,vs in attrs.items()}
            xarr.coords[dim] = asms[0].coords[dim]
            for k,v in attrs.items():
                attrs[k].coords[dim] = asms[0].coords[dim]
                xarr.attrs[k] = attrs[k]
            return xarr

    return apply_helper(sizes, dims, *asms)
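Stripped of coordinates, attrs and parallelism, the core of `apply_over_dims` for a single dimension is: take each slice along that dimension, call the function on it, and concatenate the results along the same dimension. A simplified NumPy sketch of that step, with `apply_over_axis` as a hypothetical name and a serial loop in place of joblib:

```python
import numpy as np


def apply_over_axis(func, array: np.ndarray, axis: int) -> np.ndarray:
    """Apply func to every slice along `axis`; stack the results on a new leading axis."""
    slices = [func(np.take(array, index, axis=axis)) for index in range(array.shape[axis])]
    return np.stack(slices, axis=0)


# Assumed layout for illustration: (presentation, neuroid, time_bin).
data = np.arange(6 * 4 * 3, dtype=float).reshape(6, 4, 3)

# For each time bin, average over neuroids; each slice has shape (presentation, neuroid).
per_time_mean = apply_over_axis(lambda time_slice: time_slice.mean(axis=1), data, axis=2)
```

The result carries the sliced dimension first, which mirrors the `asm.transpose(*dims, ...)` call in the original; callers such as the tests transpose back to the order they expect.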
4 changes: 4 additions & 0 deletions brainscore_vision/metrics/internal_consistency/__init__.py
@@ -1,4 +1,8 @@
from brainscore_vision import metric_registry
from .ceiling import InternalConsistency

from brainscore_vision.metric_helpers.temporal import PerTime


metric_registry['internal_consistency'] = InternalConsistency
metric_registry['internal_consistency_temporal'] = lambda *args, **kwargs: PerTime(InternalConsistency(*args, **kwargs))
9 changes: 9 additions & 0 deletions brainscore_vision/metrics/regression_correlation/__init__.py
@@ -11,6 +11,15 @@
metric_registry['linear_predictivity'] = lambda *args, **kwargs: CrossRegressedCorrelation(
regression=linear_regression(), correlation=pearsonr_correlation(), *args, **kwargs)

# temporal metrics
from .metric import SpanTimeCrossRegressedCorrelation

metric_registry['spantime_pls'] = lambda *args, **kwargs: SpanTimeCrossRegressedCorrelation(
    regression=pls_regression(), correlation=pearsonr_correlation(), *args, **kwargs)
metric_registry['spantime_ridge'] = lambda *args, **kwargs: SpanTimeCrossRegressedCorrelation(
    regression=ridge_regression(), correlation=pearsonr_correlation(), *args, **kwargs)


BIBTEX = """@article{schrimpf2018brain,
title={Brain-score: Which artificial neural network for object recognition is most brain-like?},
author={Schrimpf, Martin and Kubilius, Jonas and Hong, Ha and Majaj, Najib J and Rajalingham, Rishi and Issa, Elias B and Kar, Kohitij and Bashivan, Pouya and Prescott-Roy, Jonathan and Geiger, Franziska and others},
10 changes: 10 additions & 0 deletions brainscore_vision/metrics/regression_correlation/metric.py
@@ -8,6 +8,7 @@
from brainscore_core.metrics import Metric, Score
from brainscore_vision.metric_helpers.transformations import CrossValidation
from brainscore_vision.metric_helpers.xarray_utils import XarrayRegression, XarrayCorrelation
from brainscore_vision.metric_helpers.temporal import SpanTimeRegression, PerTime


class CrossRegressedCorrelation(Metric):
@@ -65,6 +66,15 @@ def predict(self, X):
return Ypred


# make the crc to consider time as a sample dimension
def SpanTimeCrossRegressedCorrelation(regression, correlation, *args, **kwargs):
    return CrossRegressedCorrelation(
        regression=SpanTimeRegression(regression),
        correlation=PerTime(correlation),
        *args, **kwargs
    )


def pls_regression(regression_kwargs=None, xarray_kwargs=None):
regression_defaults = dict(n_components=25, scale=False)
regression_kwargs = {**regression_defaults, **(regression_kwargs or {})}
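In `SpanTimeCrossRegressedCorrelation` the correlation is wrapped in `PerTime`, so predictions and targets are compared separately in each time bin. The NumPy sketch below shows what a per-time, per-neuroid Pearson correlation amounts to. The `(presentation, neuroid, time_bin)` layout and the name `pearsonr_per_time` are assumptions for illustration; the commit's own correlation goes through `XarrayCorrelation`.

```python
import numpy as np


def pearsonr_per_time(prediction: np.ndarray, target: np.ndarray) -> np.ndarray:
    """Pearson r over presentations, separately for every (neuroid, time_bin) pair."""
    centered_prediction = prediction - prediction.mean(axis=0, keepdims=True)
    centered_target = target - target.mean(axis=0, keepdims=True)
    covariance = (centered_prediction * centered_target).sum(axis=0)
    norm = np.sqrt((centered_prediction ** 2).sum(axis=0) * (centered_target ** 2).sum(axis=0))
    return covariance / norm


rng = np.random.default_rng(1)
target = rng.normal(size=(40, 3, 5))  # (presentation, neuroid, time_bin)
perfect = pearsonr_per_time(target, target)
inverted = pearsonr_per_time(-target, target)
```

Because the presentation axis is reduced independently for each time bin, the scores keep a `time_bin` axis instead of collapsing to a single number per neuroid.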
80 changes: 80 additions & 0 deletions tests/test_metric_helpers/test_temporal.py
@@ -0,0 +1,80 @@
import numpy as np
import scipy.stats
from pytest import approx
from sklearn.linear_model import LinearRegression

from brainio.assemblies import NeuroidAssembly
from brainscore_vision.metric_helpers.xarray_utils import XarrayRegression, XarrayCorrelation
from brainscore_vision.metric_helpers.temporal import PerTime, SpanTime, PerTimeRegression, SpanTimeRegression


class TestMetricHelpers:
    def test_pertime_ops(self):
        jumbled_source = NeuroidAssembly(np.random.rand(500, 10, 20),
                                         coords={'stimulus_id': ('presentation', list(reversed(range(500)))),
                                                 'image_meta': ('presentation', [0] * 500),
                                                 'neuroid_id': ('neuroid', list(reversed(range(10)))),
                                                 'neuroid_meta': ('neuroid', [0] * 10),
                                                 'time_bin_start': ('time_bin', np.arange(0, 400, 20)),
                                                 'time_bin_end': ('time_bin', np.arange(20, 420, 20))},
                                         dims=['presentation', 'neuroid', 'time_bin'])
        mean_neuroid = lambda arr: arr.mean('neuroid')
        pertime_mean_neuroid = PerTime(mean_neuroid)
        output = pertime_mean_neuroid(jumbled_source)
        output = output.transpose('presentation', 'time_bin')
        target = jumbled_source.transpose('presentation', 'time_bin', 'neuroid').mean('neuroid')
        assert (output == approx(target)).all().item()

    def test_spantime_ops(self):
        jumbled_source = NeuroidAssembly(np.random.rand(500, 10, 20),
                                         coords={'stimulus_id': ('presentation', list(reversed(range(500)))),
                                                 'image_meta': ('presentation', [0] * 500),
                                                 'neuroid_id': ('neuroid', list(reversed(range(10)))),
                                                 'neuroid_meta': ('neuroid', [0] * 10),
                                                 'time_bin_start': ('time_bin', np.arange(0, 400, 20)),
                                                 'time_bin_end': ('time_bin', np.arange(20, 420, 20))},
                                         dims=['presentation', 'neuroid', 'time_bin'])
        mean_presentation = lambda arr: arr.mean("presentation")
        spantime_mean_presentation = SpanTime(mean_presentation)
        output = spantime_mean_presentation(jumbled_source)
        output = output.transpose('neuroid')
        target = jumbled_source.transpose('presentation', 'time_bin', 'neuroid').mean('presentation').mean('time_bin')
        assert (output == approx(target)).all().item()

    def test_pertime_regression(self):
        jumbled_source = NeuroidAssembly(np.random.rand(500, 10, 20),
                                         coords={'stimulus_id': ('presentation', list(reversed(range(500)))),
                                                 'image_meta': ('presentation', [0] * 500),
                                                 'neuroid_id': ('neuroid', list(reversed(range(10)))),
                                                 'neuroid_meta': ('neuroid', [0] * 10),
                                                 'time_bin_start': ('time_bin', np.arange(0, 400, 20)),
                                                 'time_bin_end': ('time_bin', np.arange(20, 420, 20))},
                                         dims=['presentation', 'neuroid', 'time_bin'])
        target = jumbled_source.sortby(['stimulus_id', 'neuroid_id'])
        pertime_regression = PerTimeRegression(XarrayRegression(LinearRegression()))
        pertime_regression.fit(jumbled_source, target)
        prediction = pertime_regression.predict(jumbled_source)
        prediction = prediction.transpose(*target.dims)
        # do not test for alignment of metadata - it is only important that the data is well-aligned with the metadata.
        np.testing.assert_array_almost_equal(prediction.sortby(['stimulus_id', 'neuroid_id', 'time_bin']).values,
                                             target.sortby(['stimulus_id', 'neuroid_id', 'time_bin']).values)


    def test_spantime_regression(self):
        jumbled_source = NeuroidAssembly(np.random.rand(500, 10, 20),
                                         coords={'stimulus_id': ('presentation', list(reversed(range(500)))),
                                                 'image_meta': ('presentation', [0] * 500),
                                                 'neuroid_id': ('neuroid', list(reversed(range(10)))),
                                                 'neuroid_meta': ('neuroid', [0] * 10),
                                                 'time_bin_start': ('time_bin', np.arange(0, 400, 20)),
                                                 'time_bin_end': ('time_bin', np.arange(20, 420, 20))},
                                         dims=['presentation', 'neuroid', 'time_bin'])
        target = jumbled_source.sortby(['stimulus_id', 'neuroid_id'])
        spantime_regression = SpanTimeRegression(XarrayRegression(LinearRegression()))
        spantime_regression.fit(jumbled_source, target)
        prediction = spantime_regression.predict(jumbled_source)
        prediction = prediction.transpose(*target.dims)
        # do not test for alignment of metadata - it is only important that the data is well-aligned with the metadata.
        np.testing.assert_array_almost_equal(prediction.sortby(['stimulus_id', 'neuroid_id', 'time_bin']).values,
                                             target.sortby(['stimulus_id', 'neuroid_id', 'time_bin']).values)
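The tests build their assemblies with reversed `stimulus_id` and `neuroid_id` values, so the comparisons only hold after both sides are sorted by those ids. The sketch below shows that alignment step on its own with plain NumPy arrays. It is loosely modelled on the sort-then-assert logic of `PerOps.__call__`; `sort_by_ids` is a hypothetical helper, not part of the commit.

```python
import numpy as np


def sort_by_ids(values: np.ndarray, ids: np.ndarray):
    # Reorder values so that their ids are ascending; return both in the new order.
    order = np.argsort(ids)
    return values[order], ids[order]


# Two recordings of the same four stimuli, stored in opposite orders.
ids_a, values_a = np.array([3, 2, 1, 0]), np.array([30.0, 20.0, 10.0, 0.0])
ids_b, values_b = np.array([0, 1, 2, 3]), np.array([0.0, 10.0, 20.0, 30.0])

sorted_values_a, sorted_ids_a = sort_by_ids(values_a, ids_a)
sorted_values_b, sorted_ids_b = sort_by_ids(values_b, ids_b)

# After sorting, the ids must match before any element-wise comparison makes sense.
assert (sorted_ids_a == sorted_ids_b).all()
```

Comparing `values_a` and `values_b` without this step would pair each stimulus with a different one, which is the failure mode the jumbled fixtures are meant to catch.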
