Commit 85f5217 (parent: ebd3c0a)
Showing 37 changed files with 1,424 additions and 0 deletions.
brainscore_vision/models/temporal_model_AVID-CMA/__init__.py: 17 additions, 0 deletions
@@ -0,0 +1,17 @@
from brainscore_vision import model_registry
from brainscore_vision.model_helpers.brain_transformation import ModelCommitment
from brainscore_vision.model_helpers.activations.temporal.utils import get_specified_layers
from brainscore_vision.model_interface import BrainModel
from . import model


def commit_model(identifier):
    activations_model=model.get_model(identifier)
    layers=get_specified_layers(activations_model)
    return ModelCommitment(identifier=identifier, activations_model=activations_model, layers=layers)


model_registry["AVID-CMA-Kinetics400"] = lambda: commit_model("AVID-CMA-Kinetics400")
model_registry["AVID-CMA-Audioset"] = lambda: commit_model("AVID-CMA-Audioset")
model_registry["AVID-Kinetics400"] = lambda: commit_model("AVID-Kinetics400")
model_registry["AVID-Audioset"] = lambda: commit_model("AVID-Audioset")
brainscore_vision/models/temporal_model_AVID-CMA/model.py: 92 additions, 0 deletions
@@ -0,0 +1,92 @@
import yaml
import os

import torch

import avid_cma
from avid_cma.utils.logger import Logger
from avid_cma.utils import main_utils
from avid_cma.datasets import preprocessing

from brainscore_vision.model_helpers.activations.temporal.model import PytorchWrapper
from brainscore_vision.model_helpers.s3 import load_weight_file


HOME = os.path.dirname(os.path.abspath(avid_cma.__file__))


def get_model(identifier):

    if identifier == 'AVID-CMA-Kinetics400':
        cfg_path = os.path.join(HOME, "configs/main/avid-cma/kinetics/InstX-N1024-PosW-N64-Top32.yaml")
        weight_path = load_weight_file(
            bucket="brainscore-vision",
            relative_path="temporal_model_AVID-CMA/AVID-CMA_Kinetics_InstX-N1024-PosW-N64-Top32_checkpoint.pth.tar",
            version_id="yx9Pbq3SuNOOd4sX7csTolaHD1iTCx8y",
            sha1="6efe4464ca654a56affff766acf24e89e6f3ffbf"
        )

    elif identifier == 'AVID-CMA-Audioset':
        cfg_path = os.path.join(HOME, "configs/main/avid-cma/audioset/InstX-N1024-PosW-N64-Top32.yaml")
        weight_path = load_weight_file(
            bucket="brainscore-vision",
            relative_path="temporal_model_AVID-CMA/AVID-CMA_Audioset_InstX-N1024-PosW-N64-Top32_checkpoint.pth.tar",
            version_id="jSaZgbUohM0ZeoEUUKZiLBo6iz_v8VvQ",
            sha1="9db5eba9aab6bdbb74025be57ab532df808fe3f6"
        )

    elif identifier == 'AVID-Kinetics400':
        cfg_path = os.path.join(HOME, "configs/main/avid/kinetics/Cross-N1024.yaml")
        weight_path = load_weight_file(
            bucket="brainscore-vision",
            relative_path="temporal_model_AVID-CMA/AVID_Kinetics_Cross-N1024_checkpoint.pth.tar",
            version_id="XyKt0UOUFsuuyrl6ZREivK8FadRPx34u",
            sha1="d3a04f856d29421ba8de37808593a3fad4d4794f"
        )

    elif identifier == 'AVID-Audioset':
        cfg_path = os.path.join(HOME, "configs/main/avid/audioset/Cross-N1024.yaml")
        weight_path = load_weight_file(
            bucket="brainscore-vision",
            relative_path="temporal_model_AVID-CMA/AVID_Audioset_Cross-N1024_checkpoint.pth.tar",
            version_id="0Sxuhn8LsYXQC4FnPfJ7rw7uU6kDlKgc",
            sha1="b48d8428a1a2526ccca070f810333df18bfce5fd"
        )

    else:
        raise ValueError(f"Unknown model identifier: {identifier}")


    cfg = yaml.safe_load(open(cfg_path))
    cfg['model']['args']['checkpoint'] = weight_path
    logger = Logger()

    # Define model
    model = main_utils.build_model(cfg['model'], logger)

    # take only video model
    model = model.video_model

    # Define dataloaders
    db_cfg = cfg['dataset']
    print(db_cfg)

    num_frames = int(db_cfg['video_clip_duration'] * db_cfg['video_fps'])

    _video_transform = preprocessing.VideoPrep_Crop_CJ(
        resize=(256, 256),
        crop=(db_cfg['crop_size'], db_cfg['crop_size']),
        augment=False,
        num_frames=num_frames,
        pad_missing=True,
    )

    def video_transform(video):
        frames = video.to_pil_imgs()
        return _video_transform(frames)

    layer_activation_format = {
        'conv1': 'CTHW',
        **{f"conv{i}x": 'CTHW' for i in range(2, 6)},
    }

    return PytorchWrapper(identifier, model, video_transform, fps=db_cfg['video_fps'], layer_activation_format=layer_activation_format)
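
For reference, the layer_activation_format comprehension above evaluates to the mapping below (plain Python, shown only as an illustration; 'CTHW' presumably denotes channel, time, height, and width axes):

layer_activation_format = {
    'conv1': 'CTHW',
    'conv2x': 'CTHW',
    'conv3x': 'CTHW',
    'conv4x': 'CTHW',
    'conv5x': 'CTHW',
}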
brainscore_vision/models/temporal_model_AVID-CMA/requirements.txt: 3 additions, 0 deletions
@@ -0,0 +1,3 @@
avid_cma @ git+https://github.com/YingtianDt/AVID-CMA.git
torch
torchvision
brainscore_vision/models/temporal_model_AVID-CMA/test.py: 18 additions, 0 deletions
@@ -0,0 +1,18 @@
import pytest

from brainscore_vision import load_model


model_list = [
    "AVID-CMA-Kinetics400",
    "AVID-CMA-Audioset",
    "AVID-Kinetics400",
    "AVID-Audioset"
]

@pytest.mark.private_access
@pytest.mark.memory_intense
@pytest.mark.parametrize("model_identifier", model_list)
def test_load(model_identifier):
    model = load_model(model_identifier)
    assert model is not None
brainscore_vision/models/temporal_model_GDT/__init__.py: 16 additions, 0 deletions
@@ -0,0 +1,16 @@
from brainscore_vision import model_registry
from brainscore_vision.model_helpers.brain_transformation import ModelCommitment
from brainscore_vision.model_helpers.activations.temporal.utils import get_specified_layers
from brainscore_vision.model_interface import BrainModel
from . import model


def commit_model(identifier):
    activations_model=model.get_model(identifier)
    layers=get_specified_layers(activations_model)
    return ModelCommitment(identifier=identifier, activations_model=activations_model, layers=layers)


model_registry["GDT-Kinetics400"] = lambda: commit_model("GDT-Kinetics400")
model_registry["GDT-HowTo100M"] = lambda: commit_model("GDT-HowTo100M")
model_registry["GDT-IG65M"] = lambda: commit_model("GDT-IG65M")
brainscore_vision/models/temporal_model_GDT/model.py: 72 additions, 0 deletions
@@ -0,0 +1,72 @@
import torch

from gdt_model.model import GDT
from gdt_model.video_transforms import clip_augmentation

from brainscore_vision.model_helpers.activations.temporal.model import PytorchWrapper
from brainscore_vision.model_helpers.s3 import load_weight_file


def transform_video(video):
    arr = video.to_numpy()
    arr = torch.as_tensor(arr)
    return clip_augmentation(arr)


def get_model(identifier):

    assert identifier.startswith("GDT-")
    dataset = "-".join(identifier.split("-")[1:])

    if dataset == "Kinetics400":
        pth = load_weight_file(
            bucket="brainscore-vision",
            relative_path="temporal_model_GDT/gdt_K400.pth",
            version_id="JpU_tnCzrbTejn6sOrQMk8eRsJ97yFgt",
            sha1="7f12c60670346b1aab15194eb44c341906e1bca6"
        )
    elif dataset == "IG65M":
        pth = load_weight_file(
            bucket="brainscore-vision",
            relative_path="temporal_model_GDT/gdt_IG65M.pth",
            version_id="R.NoD6VAbFbJdf8tg5jnXIWB3hQ8GlSD",
            sha1="3dcee3af61691e1e7e47e4b115be6808f4ea8172"
        )
    elif dataset == "HowTo100M":
        pth = load_weight_file(
            bucket="brainscore-vision",
            relative_path="temporal_model_GDT/gdt_HT100M.pth",
            version_id="BVRl9t_134PoKZCn9W54cyfkImCW2ioq",
            sha1="a9a979c82e83b955794814923af736eb34e6f080"
        )
    else:
        raise ValueError(f"Unknown dataset: {dataset}")

    # Load model
    model = GDT(
        vid_base_arch="r2plus1d_18",
        aud_base_arch="resnet9",
        pretrained=False,
        norm_feat=False,
        use_mlp=False,
        num_classes=256,
    )

    model = model.video_network  # Remove audio network

    # Load weights: keep only the video_network entries, with the prefix stripped
    state_dict_ = torch.load(pth, map_location="cpu")['model']
    state_dict = {}
    for k, v in list(state_dict_.items()):
        if k.startswith("video_network."):
            k = k[len("video_network."):]
            state_dict[k] = v
    model.load_state_dict(state_dict)

    layer_activation_format = {
        "base.stem": "CTHW",
        **{f"base.layer{i}": "CTHW" for i in range(1, 5)},
        # "base.fc": "C",  # no fc
    }

    return PytorchWrapper(identifier, model, transform_video, fps=30, layer_activation_format=layer_activation_format)
brainscore_vision/models/temporal_model_GDT/requirements.txt: 3 additions, 0 deletions
@@ -0,0 +1,3 @@
gdt_model @ git+https://github.com/YingtianDt/GDT.git
torch
torchvision
brainscore_vision/models/temporal_model_GDT/test.py: 17 additions, 0 deletions
@@ -0,0 +1,17 @@
import pytest

from brainscore_vision import load_model


model_list = [
    "GDT-Kinetics400",
    "GDT-HowTo100M",
    "GDT-IG65M",
]

@pytest.mark.private_access
@pytest.mark.memory_intense
@pytest.mark.parametrize("model_identifier", model_list)
def test_load(model_identifier):
    model = load_model(model_identifier)
    assert model is not None
brainscore_vision/models/temporal_model_S3D_text_video/__init__.py: 14 additions, 0 deletions
@@ -0,0 +1,14 @@
from brainscore_vision import model_registry
from brainscore_vision.model_helpers.brain_transformation import ModelCommitment
from brainscore_vision.model_helpers.activations.temporal.utils import get_specified_layers
from brainscore_vision.model_interface import BrainModel
from . import model


def commit_model(identifier):
    activations_model=model.get_model(identifier)
    layers=get_specified_layers(activations_model)
    return ModelCommitment(identifier=identifier, activations_model=activations_model, layers=layers)


model_registry["s3d-HowTo100M"] = lambda: commit_model("s3d-HowTo100M")
brainscore_vision/models/temporal_model_S3D_text_video/model.py: 65 additions, 0 deletions
@@ -0,0 +1,65 @@
import torch
import numpy as np
from torchvision import transforms
from s3dg_howto100m import S3D

from brainscore_vision.model_helpers.activations.temporal.model.pytorch import PytorchWrapper
from brainscore_vision.model_helpers.s3 import load_weight_file


img_transform = transforms.Compose([
    transforms.Resize((256, 256)),
])

def transform_video(video):
    frames = video.to_numpy() / 255.
    frames = torch.Tensor(frames)
    frames = frames.permute(0, 3, 1, 2)
    frames = img_transform(frames)
    return frames.permute(1, 0, 2, 3)


def get_model(identifier="s3d-HowTo100M"):
    inferencer_kwargs = {
        "fps": 24,  # common YouTube frame rate
        "layer_activation_format":
        {
            "conv1": "CTHW",
            "conv_2c": "CTHW",
            "mixed_3c": "CTHW",
            "mixed_4b": "CTHW",
            "mixed_4d": "CTHW",
            "mixed_4f": "CTHW",
            "mixed_5c": "CTHW",
            "fc": "C"
        },
    }
    process_output = None

    model_name = identifier

    model_pth = load_weight_file(
        bucket="brainscore-vision",
        relative_path="temporal_model_S3D_text_video/s3d_howto100m.pth",
        version_id="hRp6I8bpwreIMUVL0H.zCdK0hqRggL7n",
        sha1="31e99d2a1cd48f2259ca75e719ac82c8b751ea75"
    )

    dict_pth = load_weight_file(
        bucket="brainscore-vision",
        relative_path="temporal_model_S3D_text_video/s3d_dict.npy",
        version_id="4NxVLe8DSL6Uue0F7e2rz8HZuOk.tkBI",
        sha1="d368ff7d397ec8240f1f963b5efe8ff245bac35f"
    )

    # Instantiate the model
    model = S3D(dict_pth, 512)

    # Load the model weights
    model.load_state_dict(torch.load(model_pth))

    wrapper = PytorchWrapper(identifier, model, transform_video,
                             process_output=process_output,
                             **inferencer_kwargs)

    return wrapper
brainscore_vision/models/temporal_model_S3D_text_video/requirements.txt: 1 addition, 0 deletions
@@ -0,0 +1 @@
S3D_HowTo100M @ git+https://github.com/YingtianDt/S3D_HowTo100M
brainscore_vision/models/temporal_model_S3D_text_video/test.py: 15 additions, 0 deletions
@@ -0,0 +1,15 @@
import pytest

from brainscore_vision import load_model


model_list = [
    "s3d-HowTo100M",
]

@pytest.mark.private_access
@pytest.mark.memory_intense
@pytest.mark.parametrize("model_identifier", model_list)
def test_load(model_identifier):
    model = load_model(model_identifier)
    assert model is not None
brainscore_vision/models/temporal_model_SeLaVi/__init__.py: 17 additions, 0 deletions
@@ -0,0 +1,17 @@
from brainscore_vision import model_registry
from brainscore_vision.model_helpers.brain_transformation import ModelCommitment
from brainscore_vision.model_helpers.activations.temporal.utils import get_specified_layers
from brainscore_vision.model_interface import BrainModel
from . import model


def commit_model(identifier):
    activations_model=model.get_model(identifier)
    layers=get_specified_layers(activations_model)
    return ModelCommitment(identifier=identifier, activations_model=activations_model, layers=layers)


model_registry["SeLaVi-Kinetics400"] = lambda: commit_model("SeLaVi-Kinetics400")
model_registry["SeLaVi-Kinetics-Sound"] = lambda: commit_model("SeLaVi-Kinetics-Sound")
model_registry["SeLaVi-VGG-Sound"] = lambda: commit_model("SeLaVi-VGG-Sound")
model_registry["SeLaVi-AVE"] = lambda: commit_model("SeLaVi-AVE")