Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add temporal metrics; add temporal versions of MajajHong2015 #1109

Merged
merged 121 commits into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from 120 commits
Commits
Show all changes
121 commits
Select commit Hold shift + click to select a range
d97d452
feature: support temporal models for neural alignment by chaning Temp…
Feb 25, 2024
2a32f46
Merge branch 'brain-score:master' into master
YingtianDt Feb 25, 2024
c372d12
add example temporal submission
YingtianDt Feb 26, 2024
40d2d2b
Merge branch 'master' of https://github.com/YingtianDt/brain-score
YingtianDt Feb 26, 2024
c0827b3
Merge branch 'brain-score:master' into master
YingtianDt Feb 26, 2024
92cede9
complete new framework
YingtianDt Feb 27, 2024
e74015d
Merge branch 'master' of https://github.com/YingtianDt/brain-score
YingtianDt Feb 27, 2024
088b9af
new module: temporal model helpers
YingtianDt Mar 3, 2024
c5b9097
change the arch of temporal; add tutorials
YingtianDt Mar 5, 2024
8f301a3
improve: better naming
YingtianDt Mar 5, 2024
3c3e475
update: wrapper tutorial on brain model
YingtianDt Mar 5, 2024
34fbc08
add feature: inferencer identifier tracked by extractor for result ca…
YingtianDt Mar 5, 2024
f6c2fe4
fix: video fps sampling; need more tests!
YingtianDt Mar 6, 2024
7e948ed
fix bugs: video sampling based on fps was wrong.
YingtianDt Mar 6, 2024
3998027
Merge branch 'brain-score:master' into master
YingtianDt Mar 6, 2024
17ecc5d
add mmaction2 models; add more features to the inferencers
YingtianDt Mar 8, 2024
5fac41c
Merge branch 'master' of https://github.com/YingtianDt/brain-score
YingtianDt Mar 8, 2024
20941fe
Merge branch 'brain-score:master' into master
YingtianDt Mar 11, 2024
89bf58c
PR: temporal model helpers
YingtianDt Mar 11, 2024
1433635
PR fix: not including gitmodules for now
YingtianDt Mar 11, 2024
b678e95
Update brainscore_vision/model_helpers/brain_transformation/temporal.py
YingtianDt Mar 11, 2024
905c964
Update brainscore_vision/model_helpers/brain_transformation/temporal.py
YingtianDt Mar 11, 2024
e1e77ca
Update brainscore_vision/model_helpers/brain_transformation/temporal.py
YingtianDt Mar 11, 2024
817b448
Update brainscore_vision/models/temporal_models/test.py
YingtianDt Mar 11, 2024
ad8cf5a
add mae_st; add ding2012
YingtianDt Mar 11, 2024
749c371
Merge branch 'master' of https://github.com/YingtianDt/brain-score
YingtianDt Mar 11, 2024
95e413e
try new arch
YingtianDt Mar 12, 2024
d020363
merge
YingtianDt Mar 12, 2024
d6960fc
init ding2012
YingtianDt Mar 12, 2024
227e801
add tests for temporal model helpers; add block inferencer
YingtianDt Mar 12, 2024
7906ddd
Delete tests/test_model_helpers/temporal/test___init__.py
YingtianDt Mar 16, 2024
9316398
add benchmark ding2012
YingtianDt Mar 18, 2024
e0ba781
merge pr version
YingtianDt Mar 18, 2024
6151364
resolve merge conflict
YingtianDt Mar 18, 2024
8fcb49f
add mutliple libs for temporal models
YingtianDt Mar 19, 2024
f6b6554
change executor output format; add more inference tests; init load_we…
YingtianDt Mar 19, 2024
bf82677
add openstl
YingtianDt Mar 20, 2024
4d2b669
merge
YingtianDt Mar 20, 2024
66dafe0
update backend for executor
YingtianDt Mar 23, 2024
c7b2a84
feat:load_weight_file and corresponding test
YingtianDt Mar 25, 2024
4b77eab
change:resize strategy changed from bilinear to pooling
YingtianDt Mar 25, 2024
d56f164
change:resize strategy changed from bilinear to pooling
YingtianDt Mar 25, 2024
70fdb9e
fix mae_st submission
YingtianDt Mar 25, 2024
e7c37fc
minor
YingtianDt Mar 25, 2024
beba210
fix:dtype in assembly time align
YingtianDt Mar 26, 2024
83585a1
minor
YingtianDt Mar 26, 2024
d2d2673
update model submissions
YingtianDt Mar 26, 2024
8e84c09
Merge branch 'brain-score:master' into master
YingtianDt Mar 26, 2024
0b0d598
fix dependency
YingtianDt Mar 27, 2024
89d5424
Merge branch 'master' of https://github.com/YingtianDt/brain-score
YingtianDt Mar 27, 2024
df51aea
refactor: simplify the inferencer methods
YingtianDt Mar 27, 2024
f6f3105
fix:block inferencer, neuroid coord while merging
YingtianDt Mar 28, 2024
323eb8a
fix:inferencer identifier
YingtianDt Mar 29, 2024
244d38d
fix:weigh download
YingtianDt Apr 2, 2024
657996f
Merge branch 'brain-score:master' into master
YingtianDt Apr 3, 2024
bfbb470
change tests to have max_workers=1
YingtianDt Apr 4, 2024
a062a07
Merge branch 'master' of https://github.com/YingtianDt/brain-score
YingtianDt Apr 4, 2024
af3cbbb
revert screen.py
YingtianDt Apr 4, 2024
01fb289
not submit region_layer_map
YingtianDt Apr 4, 2024
cb18a88
remove torch dependency
YingtianDt Apr 4, 2024
3e47774
make fake modules in tests
YingtianDt Apr 4, 2024
64f2ec4
add torch to requirements; avoid torch in tests
YingtianDt Apr 4, 2024
20aca51
minor
YingtianDt Apr 4, 2024
2f7a3d8
minor
YingtianDt Apr 4, 2024
2d411de
Merge branch 'brain-score:master' into master
YingtianDt Apr 4, 2024
0dbf122
np.object changed to object
YingtianDt Apr 4, 2024
d19a633
Merge branch 'master' of https://github.com/YingtianDt/brain-score
YingtianDt Apr 4, 2024
4ac951c
remove return in tests
YingtianDt Apr 4, 2024
f0c6e39
fix insertion position bug
YingtianDt Apr 6, 2024
e2a2ce3
Merge branch 'master' into master
mschrimpf Apr 8, 2024
64079b8
Apply suggestions from code review
YingtianDt Apr 8, 2024
465fb07
add: more type hints and comments
YingtianDt Apr 8, 2024
d430141
minor
YingtianDt Apr 8, 2024
effb0f8
pr:only commit temporal model helpers
YingtianDt Apr 8, 2024
20e1b0a
pr: add one model for example
YingtianDt Apr 8, 2024
c0e7b13
Merge branch 'master' into master
mschrimpf Apr 9, 2024
95065a5
undo whole_brain in Brainodel.RecordingTarget
YingtianDt Apr 9, 2024
61737bd
Merge branch 'master' of https://github.com/YingtianDt/brain-score
YingtianDt Apr 9, 2024
0febc8e
Merge branch 'master' into master
mschrimpf Apr 11, 2024
28a5f40
use logger and fix newlines
mschrimpf Apr 11, 2024
b560af3
fix: video fps with copy was wrong
YingtianDt Apr 26, 2024
bea3d26
feat:fractional max_spatial_size
YingtianDt Apr 26, 2024
70b2370
downsample layers in VideoMAE
YingtianDt Apr 26, 2024
8b7a733
Merge branch 'master' of https://github.com/YingtianDt/brain-score
YingtianDt Apr 26, 2024
bd6f436
fix:video sampling wrong duration
YingtianDt Apr 26, 2024
2ed5123
add more tests
YingtianDt Apr 29, 2024
0ae615c
merge upstream
YingtianDt Apr 30, 2024
05fd380
fix merge
YingtianDt Apr 30, 2024
e4b11bb
fix merge
YingtianDt Apr 30, 2024
ebd3c0a
module refactor; add more input test
YingtianDt May 14, 2024
85f5217
add more temporal models
YingtianDt May 14, 2024
faf7fdd
Merge branch 'master' of https://github.com/brain-score/vision
YingtianDt May 14, 2024
3e14419
fix videomaev2 sha
YingtianDt May 15, 2024
427ae93
fix:temporal_modelmae_st
YingtianDt May 17, 2024
f3bd1fe
change:video conservative loading; rename:image to pil image
YingtianDt May 27, 2024
e204c66
fix:video last frame sampling; fix_time_naming
YingtianDt May 30, 2024
4343a73
Merge branch 'master' of https://github.com/brain-score/vision
YingtianDt May 30, 2024
7c7a23b
ignore pytest_cache
YingtianDt May 30, 2024
ccfdf87
Merge branch 'master' into master
mschrimpf Jun 7, 2024
a6ba086
Merge branch 'master' into master
YingtianDt Jun 17, 2024
7164de4
gerge branch 'master' of https://github.com/YingtianDt/brain-score
YingtianDt Jun 17, 2024
94bb305
re-trigger tests
YingtianDt Jun 18, 2024
342779c
Merge branch 'brain-score:master' into master
YingtianDt Jun 18, 2024
6acb243
add joblib pool error management; fix video/image path recognizer
YingtianDt Jun 18, 2024
ea732d7
Merge branch 'master' of https://github.com/YingtianDt/brain-score
YingtianDt Jun 18, 2024
96bc38b
Merge branch 'master' into master
YingtianDt Jun 20, 2024
fab8741
Merge branch 'master' into master
deirdre-k Jun 21, 2024
c39233b
Merge branch 'master' into master
YingtianDt Jun 24, 2024
91505cb
update: naming of failed to pickle func in joblibmapper
YingtianDt Jun 26, 2024
516492a
Merge branch 'master' into master
YingtianDt Jun 26, 2024
f87d9ac
Merge branch 'brain-score:master' into master
YingtianDt Jun 26, 2024
0f8fee4
add temporal metric helpers
YingtianDt Jul 30, 2024
e06853e
add temporal version of mamjajhong2015
YingtianDt Jul 31, 2024
b1dcaf7
Merge branch 'master' of https://github.com/YingtianDt/brain-score
YingtianDt Jul 31, 2024
57ce832
Merge branch 'master' into master
mike-ferguson Aug 7, 2024
e7dd0ef
Merge branch 'master' into master
YingtianDt Aug 15, 2024
f0243a0
Update benchmark.py
YingtianDt Sep 6, 2024
5542996
Update benchmark.py
YingtianDt Sep 6, 2024
149a6b0
Update brainscore_vision/metric_helpers/temporal.py
YingtianDt Sep 6, 2024
9cdebde
Update brainscore_vision/metrics/internal_consistency/__init__.py
YingtianDt Sep 6, 2024
790851b
Update benchmark.py
YingtianDt Sep 6, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions brainscore_vision/benchmarks/majajhong2015/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,8 @@

benchmark_registry['MajajHong2015public.V4-pls'] = MajajHongV4PublicBenchmark
benchmark_registry['MajajHong2015public.IT-pls'] = MajajHongITPublicBenchmark

# temporal
from .benchmark import MajajHongV4TemporalPublicBenchmark, MajajHongITTemporalPublicBenchmark
benchmark_registry['MajajHong2015public.V4-temporal-pls'] = lambda: MajajHongV4TemporalPublicBenchmark(time_interval=10)
benchmark_registry['MajajHong2015public.IT-temporal-pls'] = lambda: MajajHongITTemporalPublicBenchmark(time_interval=10)
44 changes: 34 additions & 10 deletions brainscore_vision/benchmarks/majajhong2015/benchmark.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from brainscore_core import Metric

from brainscore_vision import load_metric, Ceiling, load_ceiling, load_dataset
from brainscore_vision.benchmark_helpers.neural_common import NeuralBenchmark, average_repetition
from brainscore_vision.benchmark_helpers.neural_common import NeuralBenchmark, average_repetition, apply_keep_attrs
from brainscore_vision.model_helpers.brain_transformation.temporal import assembly_time_align

VISUAL_DEGREES = 8
NUMBER_OF_TRIALS = 50
Expand All @@ -20,13 +21,14 @@
eprint = {https://www.jneurosci.org/content/35/39/13402.full.pdf},
journal = {Journal of Neuroscience}}"""

pls_metric = lambda: load_metric('pls', crossvalidation_kwargs=dict(stratification_coord='object_name'))

crossvalidation_kwargs = dict(stratification_coord='object_name')
pls_metric = lambda: load_metric('pls', crossvalidation_kwargs=crossvalidation_kwargs)
spantime_pls_metric = lambda: load_metric('spantime_pls', crossvalidation_kwargs=crossvalidation_kwargs)

def _DicarloMajajHong2015Region(region: str, access: str, identifier_metric_suffix: str,
similarity_metric: Metric, ceiler: Ceiling):
assembly_repetition = load_assembly(average_repetitions=False, region=region, access=access)
assembly = load_assembly(average_repetitions=True, region=region, access=access)
similarity_metric: Metric, ceiler: Ceiling, time_interval: float = None):
assembly_repetition = load_assembly(average_repetitions=False, region=region, access=access, time_interval=time_interval)
assembly = load_assembly(average_repetitions=True, region=region, access=access, time_interval=time_interval)
benchmark_identifier = f'MajajHong2015.{region}' + ('.public' if access == 'public' else '')
return NeuralBenchmark(identifier=f'{benchmark_identifier}-{identifier_metric_suffix}', version=3,
assembly=assembly, similarity_metric=similarity_metric,
Expand Down Expand Up @@ -60,13 +62,35 @@ def MajajHongITPublicBenchmark():
ceiler=load_ceiling('internal_consistency'))


def load_assembly(average_repetitions, region, access='private'):
assembly = load_dataset(f'MajajHong2015.{access}')
def MajajHongV4TemporalPublicBenchmark(time_interval: float = None):
return _DicarloMajajHong2015Region(region='V4', access='public', identifier_metric_suffix='pls',
similarity_metric=spantime_pls_metric(), time_interval=time_interval,
ceiler=load_ceiling('pertime_internal_consistency'))


def MajajHongITTemporalPublicBenchmark(time_interval: float = None):
return _DicarloMajajHong2015Region(region='IT', access='public', identifier_metric_suffix='pls',
similarity_metric=spantime_pls_metric(), time_interval=time_interval,
ceiler=load_ceiling('pertime_internal_consistency'))


def load_assembly(average_repetitions: bool, region: str, access: str = 'private', time_interval: float = None):
temporal = time_interval is not None
if not temporal:
assembly = load_dataset(f'MajajHong2015.{access}')
assembly = assembly.squeeze("time_bin")
else:
assembly = load_dataset(f'MajajHong2015.temporal.{access}')
assembly = assembly.__class__(assembly)
target_time_bins = [
(t, t+time_interval) for t in range(0, assembly.time_bin_end.max().item()-time_interval, time_interval)
]
assembly = apply_keep_attrs(assembly, lambda assembly: assembly_time_align(assembly, target_time_bins))

assembly = assembly.sel(region=region)
assembly['region'] = 'neuroid', [region] * len(assembly['neuroid'])
assembly = assembly.squeeze("time_bin")
assembly.load()
assembly = assembly.transpose('presentation', 'neuroid')
assembly = assembly.transpose('presentation', 'neuroid', ...)
if average_repetitions:
assembly = average_repetition(assembly)
return assembly
119 changes: 119 additions & 0 deletions brainscore_vision/metric_helpers/temporal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
import xarray as xr
import numpy as np

from brainscore_vision.benchmark_helpers.neural_common import Score
from brainscore_vision.metric_helpers.transformations import standard_error_of_the_mean

from .xarray_utils import apply_over_dims, recursive_op


# take the mean of scores (medians of single neuron scores) over time


def average_over_presentation(score: Score) -> Score:
raw = score
score = raw.mean('presentation')
score.attrs['raw'] = raw
return score


# PerOps is applied to every slice/chunk of the xarray along the specified dimensions
class PerOps:
def __init__(self, callable, dims, check_coords=[]):
# for coordinate checking, they are supposed to be the same across assemblies
self.dims = dims
self.callable = callable
self.check_coords = check_coords

def __call__(self, *asms):
for check_coord in self.check_coords:
asms = [asm.sortby(check_coord) for asm in asms]
for asm in asms[1:]:
assert (asm[check_coord].values == asms[0][check_coord].values).all()
ret = apply_over_dims(self.callable, *asms, dims=self.dims)
return ret


# SpanOps aggregates specified dimensions to one dimension
class SpanOps:
def __init__(self, callable, source_dims, aggregated_dim, resample=False):
# if resample, randomly choose samples from the aggregated dimension,
# whose size is the same as the assembly.sizes[aggregated_dim]
self.source_dims = source_dims
self.aggregated_dim = aggregated_dim
self.callable = callable
self.resample = resample

def __call__(self, *asms):
asms = [self._stack(asm) for asm in asms]
return self.callable(*asms)

def _stack(self, assembly):
assembly_type = type(assembly)
size = assembly.sizes[self.aggregated_dim]
assembly = xr.DataArray(assembly) # xarray cannot deal with stacking MultiIndex (pydata/xarray#1554)
assembly = assembly.reset_index(self.source_dims)
assembly = assembly.rename({dim:dim+"_" for dim in self.source_dims}) # we'll call stacked timebins "presentation"
assembly = assembly.stack({self.aggregated_dim : [dim+"_" for dim in self.source_dims]})
if self.resample:
indices = np.random.randint(0, assembly.sizes[self.aggregated_dim], size)
assembly = assembly.isel({self.aggregated_dim: indices})
return assembly_type(assembly)

class PerTime(PerOps):
def __init__(self, callable, time_dim="time_bin", check_coord="time_bin_start", **kwargs):
self.time_bin = time_dim
super().__init__(callable, dims=[time_dim], check_coords=[check_coord], **kwargs)

class PerPresentation(PerOps):
def __init__(self, callable, presentation_dim="presentation", check_coord="stimulus_id", **kwargs):
self.presentation_dim = presentation_dim
super().__init__(callable, dims=[presentation_dim], check_coords=[check_coord], **kwargs)

class PerNeuroid(PerOps):
def __init__(self, callable, neuroid_dim="neuroid", check_coord="neuroid_id", **kwargs):
self.neuroid_dim = neuroid_dim
super().__init__(callable, dims=[neuroid_dim], check_coords=[check_coord], **kwargs)

class SpanTime(SpanOps):
def __init__(self, callable, time_dim="time_bin", presentation_dim="presentation", resample=False):
self.time_dim = time_dim
self.presentation_dim = presentation_dim
source_dims = [self.time_dim, self.presentation_dim]
aggregated_dim = self.presentation_dim
super().__init__(callable, source_dims, aggregated_dim, resample=resample)

class SpanTimeRegression:
"""
Fits a regression with weights shared across the time bins.
"""

def __init__(self, regression):
self._regression = regression

def fit(self, source, target):
assert (source['time_bin'].values == target['time_bin'].values).all()
SpanTime(self._regression.fit)(source, target)

def predict(self, source):
return PerTime(self._regression.predict)(source)

class PerTimeRegression:
"""
Fits a regression with different weights for each time bins.
"""

def __init__(self, regression):
self._regression = regression

def fit(self, source, target):
# Lazy fit until predict
assert (source['time_bin'].values == target['time_bin'].values).all()
self._train_source = source
self._train_target = target

def predict(self, source):
def fit_predict(train_source, train_target, test_source):
self._regression.fit(train_source, train_target)
return self._regression.predict(test_source)
return PerTime(fit_predict)(self._train_source, self._train_target, source)
59 changes: 59 additions & 0 deletions brainscore_vision/metric_helpers/xarray_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
import xarray as xr

from brainio.assemblies import NeuroidAssembly, array_is_element, walk_coords
from brainscore_vision.metric_helpers import Defaults
Expand Down Expand Up @@ -90,3 +91,61 @@ def __call__(self, prediction, target):
for coord, dims, values in walk_coords(target) if dims == neuroid_dims},
dims=neuroid_dims)
return result


# ops that also applies to attrs (and attrs of attrs), which are xarrays
def recursive_op(*arrs, op=lambda x:x):
# the attrs structure of each arr must be the same
val = op(*arrs)
attrs = arrs[0].attrs
for attr in attrs:
attr_val = arrs[0].attrs[attr]
if isinstance(attr_val, xr.DataArray):
attr_arrs = [arr.attrs[attr] for arr in arrs]
attr_val = recursive_op(*attr_arrs, op=op)
val.attrs[attr] = attr_val
return val


# apply a callable to every slice of the xarray along the specified dimensions
def apply_over_dims(callable, *asms, dims, njobs=-1):
asms = [asm.transpose(*dims, ...) for asm in asms]
sizes = [asms[0].sizes[dim] for dim in dims]

def apply_helper(sizes, dims, *asms):
xarr = []
attrs = {}
size = sizes[0]
rsizes = sizes[1:]
dim = dims[0]
rdims = dims[1:]

if len(sizes) == 1:
# parallel execution on the last applied dimension
from joblib import Parallel, delayed
results = Parallel(n_jobs=njobs)(delayed(callable)(*[asm.isel({dim:s}) for asm in asms]) for s in range(size))
else:
results = []
for s in range(size):
arr = apply_helper(rsizes, rdims, *[asm.isel({dim:s}) for asm in asms])
results.append(arr)

for arr in results:
if arr is not None:
for k,v in arr.attrs.items():
assert isinstance(v, xr.DataArray)
attrs.setdefault(k, []).append(v.expand_dims(dim))
xarr.append(arr)

if not xarr:
return
else:
xarr = xr.concat(xarr, dim=dim)
attrs = {k: xr.concat(vs, dim=dim) for k,vs in attrs.items()}
xarr.coords[dim] = asms[0].coords[dim]
for k,v in attrs.items():
attrs[k].coords[dim] = asms[0].coords[dim]
xarr.attrs[k] = attrs[k]
return xarr

return apply_helper(sizes, dims, *asms)
4 changes: 4 additions & 0 deletions brainscore_vision/metrics/internal_consistency/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
from brainscore_vision import metric_registry
from .ceiling import InternalConsistency

from brainscore_vision.metric_helpers.temporal import PerTime


metric_registry['internal_consistency'] = InternalConsistency
metric_registry['internal_consistency_temporal'] = lambda *args, **kwargs: PerTime(InternalConsistency(*args, **kwargs))
9 changes: 9 additions & 0 deletions brainscore_vision/metrics/regression_correlation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,15 @@
metric_registry['linear_predictivity'] = lambda *args, **kwargs: CrossRegressedCorrelation(
regression=linear_regression(), correlation=pearsonr_correlation(), *args, **kwargs)

# temporal metrics
from .metric import SpanTimeCrossRegressedCorrelation

metric_registry['spantime_pls'] = lambda *args, **kwargs: SpanTimeCrossRegressedCorrelation(
regression=pls_regression(), correlation=pearsonr_correlation(), *args, **kwargs)
metric_registry['spantime_ridge'] = lambda *args, **kwargs: SpanTimeCrossRegressedCorrelation(
regression=ridge_regression(), correlation=pearsonr_correlation(), *args, **kwargs)
Comment on lines +17 to +20
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not just temporal_ridge?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because "spantime" and "pertime" are essentially different: the former shares the weights across time and the latter train separate weights for each time bin.



BIBTEX = """@article{schrimpf2018brain,
title={Brain-score: Which artificial neural network for object recognition is most brain-like?},
author={Schrimpf, Martin and Kubilius, Jonas and Hong, Ha and Majaj, Najib J and Rajalingham, Rishi and Issa, Elias B and Kar, Kohitij and Bashivan, Pouya and Prescott-Roy, Jonathan and Geiger, Franziska and others},
Expand Down
10 changes: 10 additions & 0 deletions brainscore_vision/metrics/regression_correlation/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from brainscore_core.metrics import Metric, Score
from brainscore_vision.metric_helpers.transformations import CrossValidation
from brainscore_vision.metric_helpers.xarray_utils import XarrayRegression, XarrayCorrelation
from brainscore_vision.metric_helpers.temporal import SpanTimeRegression, PerTime


class CrossRegressedCorrelation(Metric):
Expand Down Expand Up @@ -65,6 +66,15 @@ def predict(self, X):
return Ypred


# make the crc to consider time as a sample dimension
def SpanTimeCrossRegressedCorrelation(regression, correlation, *args, **kwargs):
return CrossRegressedCorrelation(
regression=SpanTimeRegression(regression),
correlation=PerTime(correlation),
*args, **kwargs
)


def pls_regression(regression_kwargs=None, xarray_kwargs=None):
regression_defaults = dict(n_components=25, scale=False)
regression_kwargs = {**regression_defaults, **(regression_kwargs or {})}
Expand Down
80 changes: 80 additions & 0 deletions tests/test_metric_helpers/test_temporal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import numpy as np
import scipy.stats
from pytest import approx
from sklearn.linear_model import LinearRegression

from brainio.assemblies import NeuroidAssembly
from brainscore_vision.metric_helpers.xarray_utils import XarrayRegression, XarrayCorrelation
from brainscore_vision.metric_helpers.temporal import PerTime, SpanTime, PerTimeRegression, SpanTimeRegression


class TestMetricHelpers:
def test_pertime_ops(self):
jumbled_source = NeuroidAssembly(np.random.rand(500, 10, 20),
coords={'stimulus_id': ('presentation', list(reversed(range(500)))),
'image_meta': ('presentation', [0] * 500),
'neuroid_id': ('neuroid', list(reversed(range(10)))),
'neuroid_meta': ('neuroid', [0] * 10),
'time_bin_start': ('time_bin', np.arange(0, 400, 20)),
'time_bin_end': ('time_bin', np.arange(20, 420, 20))},
dims=['presentation', 'neuroid', 'time_bin'])
mean_neuroid = lambda arr: arr.mean('neuroid')
pertime_mean_neuroid = PerTime(mean_neuroid)
output = pertime_mean_neuroid(jumbled_source)
output = output.transpose('presentation', 'time_bin')
target = jumbled_source.transpose('presentation', 'time_bin', 'neuroid').mean('neuroid')
assert (output == approx(target)).all().item()

def test_spantime_ops(self):
jumbled_source = NeuroidAssembly(np.random.rand(500, 10, 20),
coords={'stimulus_id': ('presentation', list(reversed(range(500)))),
'image_meta': ('presentation', [0] * 500),
'neuroid_id': ('neuroid', list(reversed(range(10)))),
'neuroid_meta': ('neuroid', [0] * 10),
'time_bin_start': ('time_bin', np.arange(0, 400, 20)),
'time_bin_end': ('time_bin', np.arange(20, 420, 20))},
dims=['presentation', 'neuroid', 'time_bin'])
mean_presentation = lambda arr: arr.mean("presentation")
spantime_mean_presentation = SpanTime(mean_presentation)
output = spantime_mean_presentation(jumbled_source)
output = output.transpose('neuroid')
target = jumbled_source.transpose('presentation', 'time_bin', 'neuroid').mean('presentation').mean('time_bin')
assert (output == approx(target)).all().item()

def test_pertime_regression(self):
jumbled_source = NeuroidAssembly(np.random.rand(500, 10, 20),
coords={'stimulus_id': ('presentation', list(reversed(range(500)))),
'image_meta': ('presentation', [0] * 500),
'neuroid_id': ('neuroid', list(reversed(range(10)))),
'neuroid_meta': ('neuroid', [0] * 10),
'time_bin_start': ('time_bin', np.arange(0, 400, 20)),
'time_bin_end': ('time_bin', np.arange(20, 420, 20))},
dims=['presentation', 'neuroid', 'time_bin'])
target = jumbled_source.sortby(['stimulus_id', 'neuroid_id'])
pertime_regression = PerTimeRegression(XarrayRegression(LinearRegression()))
pertime_regression.fit(jumbled_source, target)
prediction = pertime_regression.predict(jumbled_source)
prediction = prediction.transpose(*target.dims)
# do not test for alignment of metadata - it is only important that the data is well-aligned with the metadata.
np.testing.assert_array_almost_equal(prediction.sortby(['stimulus_id', 'neuroid_id', 'time_bin']).values,
target.sortby(['stimulus_id', 'neuroid_id', 'time_bin']).values)


def test_spantime_regression(self):
jumbled_source = NeuroidAssembly(np.random.rand(500, 10, 20),
coords={'stimulus_id': ('presentation', list(reversed(range(500)))),
'image_meta': ('presentation', [0] * 500),
'neuroid_id': ('neuroid', list(reversed(range(10)))),
'neuroid_meta': ('neuroid', [0] * 10),
'time_bin_start': ('time_bin', np.arange(0, 400, 20)),
'time_bin_end': ('time_bin', np.arange(20, 420, 20))},
dims=['presentation', 'neuroid', 'time_bin'])
target = jumbled_source.sortby(['stimulus_id', 'neuroid_id'])
spantime_regression = SpanTimeRegression(XarrayRegression(LinearRegression()))
spantime_regression.fit(jumbled_source, target)
prediction = spantime_regression.predict(jumbled_source)
prediction = prediction.transpose(*target.dims)
# do not test for alignment of metadata - it is only important that the data is well-aligned with the metadata.
np.testing.assert_array_almost_equal(prediction.sortby(['stimulus_id', 'neuroid_id', 'time_bin']).values,
target.sortby(['stimulus_id', 'neuroid_id', 'time_bin']).values)

Loading