-
Notifications
You must be signed in to change notification settings - Fork 80
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add temporal metrics; add temporal versions of MajajHong2015 (#1109)
* feature: support temporal models for neural alignment by changing TemporalIgnore to Temporal Aligned * add example temporal submission * complete new framework * new module: temporal model helpers * change the arch of temporal; add tutorials * improve: better naming * update: wrapper tutorial on brain model * add feature: inferencer identifier tracked by extractor for result caching * fix: video fps sampling; need more tests! * fix bugs: video sampling based on fps was wrong. * add mmaction2 models; add more features to the inferencers * PR: temporal model helpers * PR fix: not including gitmodules for now * Update brainscore_vision/model_helpers/brain_transformation/temporal.py Co-authored-by: Martin Schrimpf <[email protected]> * Update brainscore_vision/model_helpers/brain_transformation/temporal.py Co-authored-by: Martin Schrimpf <[email protected]> * Update brainscore_vision/model_helpers/brain_transformation/temporal.py Co-authored-by: Martin Schrimpf <[email protected]> * Update brainscore_vision/models/temporal_models/test.py Co-authored-by: Martin Schrimpf <[email protected]> * add mae_st; add ding2012 * try new arch * init ding2012 * add tests for temporal model helpers; add block inferencer * Delete tests/test_model_helpers/temporal/test___init__.py delete the old test * add benchmark ding2012 * add multiple libs for temporal models * change executor output format; add more inference tests; init load_weight in s3 * add openstl * update backend for executor * feat:load_weight_file and corresponding test * change:resize strategy changed from bilinear to pooling * change:resize strategy changed from bilinear to pooling * fix mae_st submission * minor * fix:dtype in assembly time align * minor * update model submissions * fix dependency * refactor: simplify the inferencer methods * fix:block inferencer, neuroid coord while merging * fix:inferencer identifier * fix:weight download * change tests to have max_workers=1 * revert screen.py * not submit 
region_layer_map * remove torch dependency * make fake modules in tests * add torch to requirements; avoid torch in tests * minor * minor * np.object changed to object * remove return in tests * fix insertion position bug * Apply suggestions from code review add: more type hints Co-authored-by: Martin Schrimpf <[email protected]> * add: more type hints and comments * minor * pr:only commit temporal model helpers * pr: add one model for example * undo whole_brain in BrainModel.RecordingTarget * use logger and fix newlines * fix: video fps with copy was wrong * feat:fractional max_spatial_size * downsample layers in VideoMAE * fix:video sampling wrong duration * add more tests * fix merge * fix merge * module refactor; add more input test * add more temporal models * fix videomaev2 sha * fix:temporal_modelmae_st * change:video conservative loading; rename:image to pil image * fix:video last frame sampling; fix_time_naming * ignore pytest_cache * re-trigger tests * add joblib pool error management; fix video/image path recognizer * update: naming of failed to pickle func in joblibmapper * add temporal metric helpers * add temporal version of majajhong2015 * Update benchmark.py type hint * Update benchmark.py * Update brainscore_vision/metric_helpers/temporal.py Co-authored-by: Martin Schrimpf <[email protected]> * Update brainscore_vision/metrics/internal_consistency/__init__.py Co-authored-by: Martin Schrimpf <[email protected]> * Update benchmark.py --------- Co-authored-by: Yingtian Tang <[email protected]> Co-authored-by: Martin Schrimpf <[email protected]> Co-authored-by: Martin Schrimpf <[email protected]> Co-authored-by: deirdre-k <[email protected]> Co-authored-by: Michael Ferguson <[email protected]>
- Loading branch information
1 parent
4dd1d0f
commit b17ed0b
Showing
8 changed files
with
320 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
import xarray as xr | ||
import numpy as np | ||
|
||
from brainscore_vision.benchmark_helpers.neural_common import Score | ||
from brainscore_vision.metric_helpers.transformations import standard_error_of_the_mean | ||
|
||
from .xarray_utils import apply_over_dims, recursive_op | ||
|
||
|
||
# aggregate a score across stimuli; the unaveraged per-presentation scores stay accessible via attrs['raw']
def average_over_presentation(score: Score) -> Score:
    """Average *score* over its ``presentation`` dimension.

    The input score is preserved in the result's ``attrs['raw']`` so that
    downstream consumers can still inspect the per-presentation values.
    """
    unaveraged = score
    averaged = unaveraged.mean('presentation')
    averaged.attrs['raw'] = unaveraged
    return averaged
|
||
|
||
# PerOps is applied to every slice/chunk of the xarray along the specified dimensions
class PerOps:
    """Apply a callable to every slice of the input assemblies along given dimensions.

    :param callable: function applied per slice; receives one slice from each input assembly
    :param dims: dimensions to slice over (delegated to ``apply_over_dims``)
    :param check_coords: coordinates expected to agree across all input assemblies;
        assemblies are sorted by each such coordinate before the equality check
    """

    def __init__(self, callable, dims, check_coords=()):
        # NOTE: default changed from `[]` to `()` — a mutable default argument is
        # shared across calls and is a well-known Python pitfall; a tuple is safe
        # and behaves identically for the read-only iteration below.
        self.dims = dims
        self.callable = callable
        self.check_coords = check_coords

    def __call__(self, *asms):
        for check_coord in self.check_coords:
            # align every assembly on the checked coordinate, then require exact agreement
            asms = [asm.sortby(check_coord) for asm in asms]
            for asm in asms[1:]:
                assert (asm[check_coord].values == asms[0][check_coord].values).all()
        ret = apply_over_dims(self.callable, *asms, dims=self.dims)
        return ret
|
||
|
||
# SpanOps aggregates specified dimensions to one dimension
class SpanOps:
    """Stack several source dimensions of each input assembly into one aggregated
    dimension, then apply the wrapped callable to the stacked assemblies.

    :param callable: function applied to the stacked assemblies
    :param source_dims: dimensions to be stacked together
    :param aggregated_dim: name of the resulting stacked dimension
    :param resample: if True, randomly sample (with replacement) from the stacked
        dimension, keeping the assembly's original size along ``aggregated_dim``
    """

    def __init__(self, callable, source_dims, aggregated_dim, resample=False):
        # if resample, randomly choose samples from the aggregated dimension,
        # whose size is the same as the assembly.sizes[aggregated_dim]
        self.source_dims = source_dims
        self.aggregated_dim = aggregated_dim
        self.callable = callable
        self.resample = resample

    def __call__(self, *asms):
        # stack every input assembly the same way before delegating to the callable
        asms = [self._stack(asm) for asm in asms]
        return self.callable(*asms)

    def _stack(self, assembly):
        """Stack ``self.source_dims`` of *assembly* into ``self.aggregated_dim``,
        returning an object of the same assembly type as the input."""
        assembly_type = type(assembly)
        # pre-stack size along aggregated_dim — used as the resample draw count below
        size = assembly.sizes[self.aggregated_dim]
        assembly = xr.DataArray(assembly) # xarray cannot deal with stacking MultiIndex (pydata/xarray#1554)
        # drop the MultiIndex levels so the dims can be stacked
        assembly = assembly.reset_index(self.source_dims)
        # rename with a trailing "_" to avoid clashing with aggregated_dim
        assembly = assembly.rename({dim:dim+"_" for dim in self.source_dims}) # we'll call stacked timebins "presentation"
        assembly = assembly.stack({self.aggregated_dim : [dim+"_" for dim in self.source_dims]})
        if self.resample:
            # draw `size` indices with replacement from the (larger) stacked dimension
            indices = np.random.randint(0, assembly.sizes[self.aggregated_dim], size)
            assembly = assembly.isel({self.aggregated_dim: indices})
        return assembly_type(assembly)
|
||
class PerTime(PerOps):
    """Apply the wrapped callable independently to every time-bin slice."""

    def __init__(self, callable, time_dim="time_bin", check_coord="time_bin_start", **kwargs):
        self.time_bin = time_dim  # remember the dimension name for introspection
        kwargs.update(dims=[time_dim], check_coords=[check_coord])
        super().__init__(callable, **kwargs)
|
||
class PerPresentation(PerOps):
    """Apply the wrapped callable independently to every presentation slice."""

    def __init__(self, callable, presentation_dim="presentation", check_coord="stimulus_id", **kwargs):
        self.presentation_dim = presentation_dim  # remember the dimension name for introspection
        kwargs.update(dims=[presentation_dim], check_coords=[check_coord])
        super().__init__(callable, **kwargs)
|
||
class PerNeuroid(PerOps):
    """Apply the wrapped callable independently to every neuroid slice."""

    def __init__(self, callable, neuroid_dim="neuroid", check_coord="neuroid_id", **kwargs):
        self.neuroid_dim = neuroid_dim  # remember the dimension name for introspection
        kwargs.update(dims=[neuroid_dim], check_coords=[check_coord])
        super().__init__(callable, **kwargs)
|
||
class SpanTime(SpanOps):
    """Stack time bins into the presentation dimension before applying the callable,
    so the callable sees one combined (time x presentation) sample dimension."""

    def __init__(self, callable, time_dim="time_bin", presentation_dim="presentation", resample=False):
        self.time_dim = time_dim
        self.presentation_dim = presentation_dim
        super().__init__(callable,
                         source_dims=[time_dim, presentation_dim],
                         aggregated_dim=presentation_dim,
                         resample=resample)
|
||
class SpanTimeRegression:
    """Fit a single regression whose weights are shared across all time bins:
    training stacks time into the presentation dimension, prediction runs per time bin."""

    def __init__(self, regression):
        self._regression = regression

    def fit(self, source, target):
        # source and target must cover identical time bins for the shared fit to be meaningful
        assert (source['time_bin'].values == target['time_bin'].values).all()
        span_fit = SpanTime(self._regression.fit)
        span_fit(source, target)

    def predict(self, source):
        per_time_predict = PerTime(self._regression.predict)
        return per_time_predict(source)
|
||
class PerTimeRegression:
    """Fit a separate regression (different weights) for each time bin.

    Fitting is lazy: ``fit`` only stores the training data, and ``predict``
    performs the per-time-bin fit-then-predict in one pass.
    """

    def __init__(self, regression):
        self._regression = regression

    def fit(self, source, target):
        # Lazy fit until predict
        assert (source['time_bin'].values == target['time_bin'].values).all()
        self._train_source = source
        self._train_target = target

    def predict(self, source):
        regression = self._regression

        def fit_predict(train_source, train_target, test_source):
            # each time-bin slice gets its own freshly fitted weights
            regression.fit(train_source, train_target)
            return regression.predict(test_source)

        per_time = PerTime(fit_predict)
        return per_time(self._train_source, self._train_target, source)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,8 @@ | ||
from brainscore_vision import metric_registry
from .ceiling import InternalConsistency

from brainscore_vision.metric_helpers.temporal import PerTime


metric_registry['internal_consistency'] = InternalConsistency
# temporal variant: wraps the consistency metric in PerTime so it is
# computed separately for every time bin of the assemblies
metric_registry['internal_consistency_temporal'] = lambda *args, **kwargs: PerTime(InternalConsistency(*args, **kwargs))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
import numpy as np | ||
import scipy.stats | ||
from pytest import approx | ||
from sklearn.linear_model import LinearRegression | ||
|
||
from brainio.assemblies import NeuroidAssembly | ||
from brainscore_vision.metric_helpers.xarray_utils import XarrayRegression, XarrayCorrelation | ||
from brainscore_vision.metric_helpers.temporal import PerTime, SpanTime, PerTimeRegression, SpanTimeRegression | ||
|
||
|
||
class TestMetricHelpers:
    """Tests for the temporal metric helpers (PerTime, SpanTime, and the temporal regressions)."""

    @staticmethod
    def _jumbled_source():
        # shared fixture (was duplicated verbatim in all four tests):
        # a (presentation, neuroid, time_bin) assembly with deliberately
        # reversed stimulus/neuroid ids to exercise coordinate re-alignment
        return NeuroidAssembly(np.random.rand(500, 10, 20),
                               coords={'stimulus_id': ('presentation', list(reversed(range(500)))),
                                       'image_meta': ('presentation', [0] * 500),
                                       'neuroid_id': ('neuroid', list(reversed(range(10)))),
                                       'neuroid_meta': ('neuroid', [0] * 10),
                                       'time_bin_start': ('time_bin', np.arange(0, 400, 20)),
                                       'time_bin_end': ('time_bin', np.arange(20, 420, 20))},
                               dims=['presentation', 'neuroid', 'time_bin'])

    def test_pertime_ops(self):
        # PerTime applied to a per-neuroid mean must equal the plain mean over 'neuroid'
        jumbled_source = self._jumbled_source()
        mean_neuroid = lambda arr: arr.mean('neuroid')
        pertime_mean_neuroid = PerTime(mean_neuroid)
        output = pertime_mean_neuroid(jumbled_source)
        output = output.transpose('presentation', 'time_bin')
        target = jumbled_source.transpose('presentation', 'time_bin', 'neuroid').mean('neuroid')
        assert (output == approx(target)).all().item()

    def test_spantime_ops(self):
        # SpanTime stacks time into presentation, so a presentation-mean equals
        # the mean over both 'presentation' and 'time_bin'
        jumbled_source = self._jumbled_source()
        mean_presentation = lambda arr: arr.mean("presentation")
        spantime_mean_presentation = SpanTime(mean_presentation)
        output = spantime_mean_presentation(jumbled_source)
        output = output.transpose('neuroid')
        target = jumbled_source.transpose('presentation', 'time_bin', 'neuroid').mean('presentation').mean('time_bin')
        assert (output == approx(target)).all().item()

    def test_pertime_regression(self):
        # a per-time linear regression from an assembly onto a sorted copy of
        # itself should reconstruct the data (up to metadata alignment)
        jumbled_source = self._jumbled_source()
        target = jumbled_source.sortby(['stimulus_id', 'neuroid_id'])
        pertime_regression = PerTimeRegression(XarrayRegression(LinearRegression()))
        pertime_regression.fit(jumbled_source, target)
        prediction = pertime_regression.predict(jumbled_source)
        prediction = prediction.transpose(*target.dims)
        # do not test for alignment of metadata - it is only important that the data is well-aligned with the metadata.
        np.testing.assert_array_almost_equal(prediction.sortby(['stimulus_id', 'neuroid_id', 'time_bin']).values,
                                             target.sortby(['stimulus_id', 'neuroid_id', 'time_bin']).values)

    def test_spantime_regression(self):
        # same reconstruction check for the shared-weights (span-time) regression
        jumbled_source = self._jumbled_source()
        target = jumbled_source.sortby(['stimulus_id', 'neuroid_id'])
        spantime_regression = SpanTimeRegression(XarrayRegression(LinearRegression()))
        spantime_regression.fit(jumbled_source, target)
        prediction = spantime_regression.predict(jumbled_source)
        prediction = prediction.transpose(*target.dims)
        # do not test for alignment of metadata - it is only important that the data is well-aligned with the metadata.
        np.testing.assert_array_almost_equal(prediction.sortby(['stimulus_id', 'neuroid_id', 'time_bin']).values,
                                             target.sortby(['stimulus_id', 'neuroid_id', 'time_bin']).values)
|