Skip to content

Commit

Permalink
Add more temporal models (#924)
Browse files Browse the repository at this point in the history
* feature: support temporal models for neural alignment by chaning TemporalIgnore to Temporal Aligned

* add example temporal submission

* complete new framework

* new module: temporal model helpers

* change the arch of temporal; add tutorials

* improve: better naming

* update: wrapper tutorial on brain model

* add feature: inferencer identifier tracked by extractor for result caching

* fix: video fps sampling; need more tests!

* fix bugs: video sampling based on fps was wrong.

* add mmaction2 models; add more features to the inferencers

* PR: temporal model helpers

* PR fix: not including gitmodules for now

* Update brainscore_vision/model_helpers/brain_transformation/temporal.py

Co-authored-by: Martin Schrimpf <[email protected]>

* Update brainscore_vision/model_helpers/brain_transformation/temporal.py

Co-authored-by: Martin Schrimpf <[email protected]>

* Update brainscore_vision/model_helpers/brain_transformation/temporal.py

Co-authored-by: Martin Schrimpf <[email protected]>

* Update brainscore_vision/models/temporal_models/test.py

Co-authored-by: Martin Schrimpf <[email protected]>

* add mae_st; add ding2012

* try new arch

* init ding2012

* add tests for temporal model helpers; add block inferencer

* Delete tests/test_model_helpers/temporal/test___init__.py

delete the old test

* add benchmark ding2012

* add mutliple libs for temporal models

* change executor output format; add more inference tests; init load_weight in s3

* add openstl

* update backend for executor

* feat:load_weight_file and corresponding test

* change:resize strategy changed from bilinear to pooling

* change:resize strategy changed from bilinear to pooling

* fix mae_st submission

* minor

* fix:dtype in assembly time align

* minor

* update model submissions

* fix dependency

* refactor: simplify the inferencer methods

* fix:block inferencer, neuroid coord while merging

* fix:inferencer identifier

* fix:weigh download

* change tests to have max_workers=1

* revert screen.py

* not submit region_layer_map

* remove torch dependency

* make fake modules in tests

* add torch to requirements; avoid torch in tests

* minor

* minor

* np.object changed to object

* remove return in tests

* fix insertion position bug

* Apply suggestions from code review

add: more type hints

Co-authored-by: Martin Schrimpf <[email protected]>

* add: more type hints and comments

* minor

* pr:only commit temporal model helpers

* pr: add one model for example

* undo whole_brain in Brainodel.RecordingTarget

* use logger and fix newlines

* fix: video fps with copy was wrong

* feat:fractional max_spatial_size

* downsample layers in VideoMAE

* fix:video sampling wrong duration

* add more tests

* fix merge

* fix merge

* module refactor; add more input test

* add more temporal models

* fix videomaev2 sha

* fix:temporal_modelmae_st

* change:video conservative loading; rename:image to pil image

* fix:video last frame sampling; fix_time_naming

* ignore pytest_cache

* re-trigger tests

* add joblib pool error management; fix video/image path recognizer

* update: naming of failed to pickle func in joblibmapper

---------

Co-authored-by: Yingtian Tang <[email protected]>
Co-authored-by: Martin Schrimpf <[email protected]>
Co-authored-by: Martin Schrimpf <[email protected]>
Co-authored-by: deirdre-k <[email protected]>
  • Loading branch information
5 people authored Jun 26, 2024
1 parent 573265c commit 7f99883
Show file tree
Hide file tree
Showing 45 changed files with 1,463 additions and 12 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ venv
.venv
build
.DS_Store
.pytest_cache

# Model Weights
*.pt
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,16 @@ class JoblibMapper:
def __init__(self, num_threads: int):
self._num_threads = num_threads
self._pool = Parallel(n_jobs=num_threads, verbose=False, backend="loky")
self._failed_to_pickle_func = False

def map(self, func, *data):
return self._pool(delayed(func)(*x) for x in zip(*data))
from joblib.externals.loky.process_executor import TerminatedWorkerError, BrokenProcessPool
if not self._failed_to_pickle_func:
try:
return self._pool(delayed(func)(*x) for x in zip(*data))
except (TerminatedWorkerError, BrokenProcessPool):
self._failed_to_pickle_func = True
return [func(*x) for x in zip(*data)]


class BatchExecutor:
Expand Down Expand Up @@ -209,4 +216,4 @@ def execute(self, layers):
layer_activations[layer] = [activations[i] for i in indices]

self.clear_stimuli()
return layer_activations
return layer_activations
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@ def from_path(self, path):

@staticmethod
def is_video_path(path: Union[str, Path]) -> bool:
extension = path.split('.')[-1]
extension = path.split('.')[-1].lower()
return extension in ['mp4', 'avi', 'mov', 'flv', 'wmv', 'webm', 'mkv', 'gif']

@staticmethod
def is_image_path(path: Union[str, Path]) -> bool:
extension = path.split('.')[-1]
extension = path.split('.')[-1].lower()
return extension in ['jpg', 'jpeg', 'png', 'bmp', 'tiff']

Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,25 @@ def set_size(self, size):
def from_path(path):
return Image(path, get_image_size(path))

def to_img(self):
def to_pil_img(self):
return PILImage.fromarray(self.to_numpy())

def get_frame(self):
return np.array(PILImage.open(self._path).convert('RGB'))

# return (H, W, C[RGB])
def to_numpy(self):
arr = np.array(PILImage.open(self._path).convert('RGB'))
arr = self.get_frame()

if arr.shape[:2][::-1] != self._size:
arr = batch_2d_resize(arr[None,:], self._size, "bilinear")[0]

return arr

def store_to_path(self, path):
self.to_img().save(path)
return path

def get_image_size(path):
with PILImage.open(path) as img:
size = img.size
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,6 @@ def get_frames(self, indices):

### I/O
def from_path(path):
path = path
fps, end, size = get_video_stats(path)
start = 0
return Video(path, fps, start, end, size)
Expand All @@ -139,7 +138,7 @@ def to_numpy(self):
sample_indices = samples.astype(int)

# padding: repeat the first/last frame
original_num_frames = int(self._original_duration * self._original_fps/1000 + EPS)
original_num_frames = int(self._original_duration * self._original_fps/1000 - EPS) # EPS to avoid last frame OOB error
sample_indices = np.clip(sample_indices, 0, original_num_frames-1)

# actual sampling
Expand All @@ -156,7 +155,7 @@ def to_frames(self):

def to_pil_imgs(self):
return [PILImage.fromarray(frame) for frame in self.to_numpy()]

def to_path(self):
# use context manager ?
path = None # make a temporal file
Expand Down
4 changes: 2 additions & 2 deletions brainscore_vision/model_helpers/activations/temporal/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def stack_with_nan_padding_(arr_list, axis=0, dtype=np.float16):
return result


def stack_with_nan_padding(arr_list, axis=0, dtype=np.float16):
def stack_with_nan_padding(arr_list, axis=0, dtype=None):
# Get shapes of all arrays
shapes = [np.array(arr.shape) for arr in arr_list]
max_shape = np.max(shapes, axis=0)
Expand All @@ -58,7 +58,7 @@ def stack_with_nan_padding(arr_list, axis=0, dtype=np.float16):

result = np.stack(results, axis=axis)
result = np.swapaxes(result, 0, axis)
if result.dtype != dtype:
if dtype is not None and result.dtype != dtype:
result = result.astype(dtype)

return result
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,13 +170,13 @@ def look_at(self, stimuli, number_of_trials=1):
bin_responses = bin_responses.stack(time_bin=['time_bin_start', 'time_bin_end'])
time_responses.append(bin_responses)
responses = merge_data_arrays(time_responses)
responses = fix_timebin_naming(responses)
else:
# for temporal models, align the time bins
responses = assembly_time_align(responses, self._time_bins)

if len(self._time_bins) == 1:
responses = responses.squeeze('time_bin')
responses = fix_timebin_naming(responses)
return responses

@property
Expand Down
17 changes: 17 additions & 0 deletions brainscore_vision/models/temporal_model_AVID-CMA/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from brainscore_vision import model_registry
from brainscore_vision.model_helpers.brain_transformation import ModelCommitment
from brainscore_vision.model_helpers.activations.temporal.utils import get_specified_layers
from brainscore_vision.model_interface import BrainModel
from . import model


def commit_model(identifier):
activations_model=model.get_model(identifier)
layers=get_specified_layers(activations_model)
return ModelCommitment(identifier=identifier, activations_model=activations_model, layers=layers)


model_registry["AVID-CMA-Kinetics400"] = lambda: commit_model("AVID-CMA-Kinetics400")
model_registry["AVID-CMA-Audioset"] = lambda: commit_model("AVID-CMA-Audioset")
model_registry["AVID-Kinetics400"] = lambda: commit_model("AVID-Kinetics400")
model_registry["AVID-Audioset"] = lambda: commit_model("AVID-Audioset")
92 changes: 92 additions & 0 deletions brainscore_vision/models/temporal_model_AVID-CMA/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import yaml
import os

import torch

import avid_cma
from avid_cma.utils.logger import Logger
from avid_cma.utils import main_utils
from avid_cma.datasets import preprocessing

from brainscore_vision.model_helpers.activations.temporal.model import PytorchWrapper
from brainscore_vision.model_helpers.s3 import load_weight_file


HOME = os.path.dirname(os.path.abspath(avid_cma.__file__))

def get_model(identifier):

if identifier == 'AVID-CMA-Kinetics400':
cfg_path = os.path.join(HOME, "configs/main/avid-cma/kinetics/InstX-N1024-PosW-N64-Top32.yaml")
weight_path = load_weight_file(
bucket="brainscore-vision",
relative_path="temporal_model_AVID-CMA/AVID-CMA_Kinetics_InstX-N1024-PosW-N64-Top32_checkpoint.pth.tar",
version_id="yx9Pbq3SuNOOd4sX7csTolaHD1iTCx8y",
sha1="6efe4464ca654a56affff766acf24e89e6f3ffbf"
)

elif identifier == 'AVID-CMA-Audioset':
cfg_path = os.path.join(HOME, "configs/main/avid-cma/audioset/InstX-N1024-PosW-N64-Top32.yaml")
weight_path = load_weight_file(
bucket="brainscore-vision",
relative_path="temporal_model_AVID-CMA/AVID-CMA_Audioset_InstX-N1024-PosW-N64-Top32_checkpoint.pth.tar",
version_id="jSaZgbUohM0ZeoEUUKZiLBo6iz_v8VvQ",
sha1="9db5eba9aab6bdbb74025be57ab532df808fe3f6"
)

elif identifier == 'AVID-Kinetics400':
cfg_path = os.path.join(HOME, "configs/main/avid/kinetics/Cross-N1024.yaml")
weight_path = load_weight_file(
bucket="brainscore-vision",
relative_path="temporal_model_AVID-CMA/AVID_Kinetics_Cross-N1024_checkpoint.pth.tar",
version_id="XyKt0UOUFsuuyrl6ZREivK8FadRPx34u",
sha1="d3a04f856d29421ba8de37808593a3fad4d4794f"
)

elif identifier == 'AVID-Audioset':
cfg_path = os.path.join(HOME, "configs/main/avid/audioset/Cross-N1024.yaml")
weight_path = load_weight_file(
bucket="brainscore-vision",
relative_path="temporal_model_AVID-CMA/AVID_Audioset_Cross-N1024_checkpoint.pth.tar",
version_id="0Sxuhn8LsYXQC4FnPfJ7rw7uU6kDlKgc",
sha1="b48d8428a1a2526ccca070f810333df18bfce5fd"
)

else:
raise ValueError(f"Unknown model identifier: {identifier}")


cfg = yaml.safe_load(open(cfg_path))
cfg['model']['args']['checkpoint'] = weight_path
logger = Logger()

# Define model
model = main_utils.build_model(cfg['model'], logger)

# take only video model
model = model.video_model

# Define dataloaders
db_cfg = cfg['dataset']
print(db_cfg)

num_frames = int(db_cfg['video_clip_duration'] * db_cfg['video_fps'])

_video_transform = preprocessing.VideoPrep_Crop_CJ(
resize=(256, 256),
crop=(db_cfg['crop_size'], db_cfg['crop_size']),
augment=False,
num_frames=num_frames,
pad_missing=True,
)

def video_transform(video):
frames = video.to_pil_imgs()
return _video_transform(frames)

layer_activation_format = {
'conv1': 'CTHW',
**{f"conv{i}x": 'CTHW' for i in range(2, 6)},
}

return PytorchWrapper(identifier, model, video_transform, fps=db_cfg['video_fps'], layer_activation_format=layer_activation_format)
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
avid_cma @ git+https://github.com/YingtianDt/AVID-CMA.git
torch
torchvision
18 changes: 18 additions & 0 deletions brainscore_vision/models/temporal_model_AVID-CMA/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import pytest

from brainscore_vision import load_model


model_list = [
"AVID-CMA-Kinetics400",
"AVID-CMA-Audioset",
"AVID-Kinetics400",
"AVID-Audioset"
]

@pytest.mark.private_access
@pytest.mark.memory_intense
@pytest.mark.parametrize("model_identifier", model_list)
def test_load(model_identifier):
model = load_model(model_identifier)
assert model is not None
16 changes: 16 additions & 0 deletions brainscore_vision/models/temporal_model_GDT/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from brainscore_vision import model_registry
from brainscore_vision.model_helpers.brain_transformation import ModelCommitment
from brainscore_vision.model_helpers.activations.temporal.utils import get_specified_layers
from brainscore_vision.model_interface import BrainModel
from . import model


def commit_model(identifier):
activations_model=model.get_model(identifier)
layers=get_specified_layers(activations_model)
return ModelCommitment(identifier=identifier, activations_model=activations_model, layers=layers)


model_registry["GDT-Kinetics400"] = lambda: commit_model("GDT-Kinetics400")
model_registry["GDT-HowTo100M"] = lambda: commit_model("GDT-HowTo100M")
model_registry["GDT-IG65M"] = lambda: commit_model("GDT-IG65M")
72 changes: 72 additions & 0 deletions brainscore_vision/models/temporal_model_GDT/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import torch

from gdt_model.model import GDT
from gdt_model.video_transforms import clip_augmentation

from brainscore_vision.model_helpers.activations.temporal.model import PytorchWrapper
from brainscore_vision.model_helpers.s3 import load_weight_file


def transform_video(video):
arr = video.to_numpy()
arr = torch.as_tensor(arr)
return clip_augmentation(arr)


def get_model(identifier):

assert identifier.startswith("GDT-")
dataset = "-".join(identifier.split("-")[1:])

if dataset == "Kinetics400":
pth = load_weight_file(
bucket="brainscore-vision",
relative_path="temporal_model_GDT/gdt_K400.pth",
version_id="JpU_tnCzrbTejn6sOrQMk8eRsJ97yFgt",
sha1="7f12c60670346b1aab15194eb44c341906e1bca6"
)
elif dataset == "IG65M":
pth = load_weight_file(
bucket="brainscore-vision",
relative_path="temporal_model_GDT/gdt_IG65M.pth",
version_id="R.NoD6VAbFbJdf8tg5jnXIWB3hQ8GlSD",
sha1="3dcee3af61691e1e7e47e4b115be6808f4ea8172"
)
elif dataset == "HowTo100M":
pth = load_weight_file(
bucket="brainscore-vision",
relative_path="temporal_model_GDT/gdt_HT100M.pth",
version_id="BVRl9t_134PoKZCn9W54cyfkImCW2ioq",
sha1="a9a979c82e83b955794814923af736eb34e6f080"
)
else:
raise ValueError(f"Unknown dataset: {dataset}")

# Load model
model = GDT(
vid_base_arch="r2plus1d_18",
aud_base_arch="resnet9",
pretrained=False,
norm_feat=False,
use_mlp=False,
num_classes=256,
)

model = model.video_network # Remove audio network

# Load weights
state_dict_ = torch.load(pth, map_location="cpu")['model']
state_dict = {}
for k, v in list(state_dict_.items()):
if k.startswith("video_network."):
k = k[len("video_network."):]
state_dict[k] = v
model.load_state_dict(state_dict)

layer_activation_format = {
"base.stem": "CTHW",
**{f"base.layer{i}": "CTHW" for i in range(1, 5)},
# "base.fc": "C", # no fc
}

return PytorchWrapper(identifier, model, transform_video, fps=30, layer_activation_format=layer_activation_format)
3 changes: 3 additions & 0 deletions brainscore_vision/models/temporal_model_GDT/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
gdt_model @ git+https://github.com/YingtianDt/GDT.git
torch
torchvision
17 changes: 17 additions & 0 deletions brainscore_vision/models/temporal_model_GDT/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import pytest

from brainscore_vision import load_model


model_list = [
"GDT-Kinetics400",
"GDT-HowTo100M",
"GDT-IG65M",
]

@pytest.mark.private_access
@pytest.mark.memory_intense
@pytest.mark.parametrize("model_identifier", model_list)
def test_load(model_identifier):
model = load_model(model_identifier)
assert model is not None
Loading

0 comments on commit 7f99883

Please sign in to comment.