diff --git a/brainscore_vision/model_helpers/activations/temporal/core/inferencer/base.py b/brainscore_vision/model_helpers/activations/temporal/core/inferencer/base.py
index e1579c5e9..118845d8d 100644
--- a/brainscore_vision/model_helpers/activations/temporal/core/inferencer/base.py
+++ b/brainscore_vision/model_helpers/activations/temporal/core/inferencer/base.py
@@ -41,9 +41,10 @@ class Inferencer:
         For example, {"temp_conv": "TCHW", "spatial_conv": "CHW", "fc": "C"}.
     visual_degrees: float
         the visual degrees of the stimuli.
-    max_spatial_size: int
+    max_spatial_size: int/float
         the maximum spatial size of the activations. If the spatial size of the activations is larger than this value,
         the activations will be downsampled to this size. This is used to avoid the large memory consumption by the first layers of some model.
+        If a float, resize by this factor instead.
     dtype: np.dtype
         data type of the activations.
     batch_size: int
@@ -77,7 +78,7 @@ def __init__(
             layer_activation_format : dict,
             stimulus_type : Stimulus,
             visual_degrees : float = 8.,
-            max_spatial_size : int = None,
+            max_spatial_size : Union[int, float] = None,
             dtype : np.dtype = np.float16,
             batch_size : int = 64,
             batch_grouper : Callable[[Stimulus], Hashable] = None,
@@ -89,6 +90,8 @@ def __init__(
         self.stimulus_type = stimulus_type
         self.layer_activation_format = layer_activation_format
+        if isinstance(max_spatial_size, float):
+            assert max_spatial_size < 1, "a proportional max_spatial_size should be < 1."
         self.max_spatial_size = max_spatial_size
         self.visual_degrees = visual_degrees
         self.dtype = dtype
@@ -254,10 +257,18 @@ def _package(activation: np.array, dims):
     return ret

 def _compute_new_size(w, h, max_spatial_size):
-    if h > w:
-        new_h = max_spatial_size
-        new_w = int(w * new_h / h)
+    if isinstance(max_spatial_size, int):
+        if h > w:
+            new_h = max_spatial_size
+            new_w = int(w * new_h / h)
+        else:
+            new_w = max_spatial_size
+            new_h = int(h * new_w / w)
     else:
-        new_w = max_spatial_size
-        new_h = int(h * new_w / w)
+        new_h = int(h * max_spatial_size)
+        new_w = int(w * max_spatial_size)
+
+    new_h = max(1, new_h)
+    new_w = max(1, new_w)
+
     return new_h, new_w
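
For reference, a standalone sketch of the new resizing rule (the function body is copied from the hunk above; the example sizes are illustrative, not from the codebase):

```python
def _compute_new_size(w, h, max_spatial_size):
    if isinstance(max_spatial_size, int):
        # int: cap the longer side at max_spatial_size, keeping the aspect ratio
        if h > w:
            new_h = max_spatial_size
            new_w = int(w * new_h / h)
        else:
            new_w = max_spatial_size
            new_h = int(h * new_w / w)
    else:
        # float (< 1): scale both sides by the given factor
        new_h = int(h * max_spatial_size)
        new_w = int(w * max_spatial_size)
    # never collapse a dimension to zero
    new_h = max(1, new_h)
    new_w = max(1, new_w)
    return new_h, new_w

assert _compute_new_size(1920, 1080, 224) == (126, 224)  # int: longer side capped at 224
assert _compute_new_size(1920, 1080, 0.5) == (540, 960)  # float: both sides halved
assert _compute_new_size(8, 2, 0.1) == (1, 1)            # clamped to at least 1 pixel
```
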
f".img_dur={self.img_duration}" + id += f".img_dur={float(self.img_duration)}" return id def load_stimulus(self, path: Union[str, Path]) -> Video: @@ -129,6 +129,6 @@ def _make_range(self, num, type="num_frames"): def _check_video(self, video: Video): if self.num_frames is not None: estimated_num_frames = int(self.fps * video.duration / 1000) - assert self.num_frames[0] <= estimated_num_frames <= self.num_frames[1] + assert self.num_frames[0] <= estimated_num_frames <= self.num_frames[1], f"The number of frames must be within {self.num_frames}, but got {estimated_num_frames}" if self.duration is not None: - assert self.duration[0] <= video.duration <= self.duration[1] \ No newline at end of file + assert self.duration[0] <= video.duration <= self.duration[1], f"The duration must be within {self.duration}, but got {video.duration}" diff --git a/brainscore_vision/model_helpers/activations/temporal/inputs/video.py b/brainscore_vision/model_helpers/activations/temporal/inputs/video.py index ddea10f70..81a2b5b57 100644 --- a/brainscore_vision/model_helpers/activations/temporal/inputs/video.py +++ b/brainscore_vision/model_helpers/activations/temporal/inputs/video.py @@ -10,6 +10,8 @@ from brainscore_vision.model_helpers.activations.temporal.utils import batch_2d_resize +EPS = 1e-9 + def get_video_stats(video_path): cap = cv2.VideoCapture(video_path) length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) @@ -31,31 +33,53 @@ def get_image_stats(image_path): class Video(Stimulus): """Video object that represents a video clip.""" - def __init__(self, path: Union[str, Path], fps: float, start: float, end: float, size: Tuple[int, int]): + def __init__( + self, + path: Union[str, Path], + fps: float, + start: float, + end: float, + size: Tuple[int, int] + ): self._path = path self._fps = fps self._size = size - self._original_fps = self._fps self._start = start self._end = end + self._original_fps = None + self._original_duration = None + self._original_size = None + + def __getattribute__(self, key): + if key.startswith("_original_"): + if super().__getattribute__(key) is None: + self._original_fps, self._original_duration, self._original_size = get_video_stats(self._path) + return super().__getattribute__(key) def copy(self): # return view video = self.__class__(self._path, self._fps, self._start, self._end, self._size) + video._original_fps = self._original_fps + video._original_duration = self._original_duration + video._original_size = self._original_size return video @property def duration(self): # in ms return self._end - self._start - + @property def fps(self): return self._fps - + @property def num_frames(self): - return int(self.duration * self.fps/1000) + return int(self.duration * self.fps/1000 + EPS) + + @property + def original_num_frames(self): + return int(self._original_duration * self._original_fps/1000 + EPS) @property def frame_size(self): @@ -110,12 +134,13 @@ def to_numpy(self): # get the time stamps of frame samples start_frame = self._start * self._original_fps / 1000 end_frame = self._end * self._original_fps / 1000 - EPS = 1e-9 # avoid taking the last extra frame - samples = np.arange(start_frame, end_frame - EPS, self._original_fps/self._fps) + # avoid taking the last extra frame + samples = np.arange(start_frame, end_frame - EPS, self._original_fps/self.fps) sample_indices = samples.astype(int) # padding: repeat the first/last frame - sample_indices = np.clip(sample_indices, 0, self.num_frames-1) + original_num_frames = int(self._original_duration * self._original_fps/1000 + EPS) 
diff --git a/brainscore_vision/model_helpers/activations/temporal/inputs/video.py b/brainscore_vision/model_helpers/activations/temporal/inputs/video.py
index ddea10f70..81a2b5b57 100644
--- a/brainscore_vision/model_helpers/activations/temporal/inputs/video.py
+++ b/brainscore_vision/model_helpers/activations/temporal/inputs/video.py
@@ -10,6 +10,8 @@
 from brainscore_vision.model_helpers.activations.temporal.utils import batch_2d_resize

+EPS = 1e-9
+

 def get_video_stats(video_path):
     cap = cv2.VideoCapture(video_path)
     length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
@@ -31,31 +33,53 @@ def get_image_stats(image_path):
 class Video(Stimulus):
     """Video object that represents a video clip."""

-    def __init__(self, path: Union[str, Path], fps: float, start: float, end: float, size: Tuple[int, int]):
+    def __init__(
+            self,
+            path: Union[str, Path],
+            fps: float,
+            start: float,
+            end: float,
+            size: Tuple[int, int]
+    ):
         self._path = path
         self._fps = fps
         self._size = size
-        self._original_fps = self._fps
         self._start = start
         self._end = end
+        self._original_fps = None
+        self._original_duration = None
+        self._original_size = None
+
+    def __getattribute__(self, key):
+        if key.startswith("_original_"):
+            if super().__getattribute__(key) is None:
+                self._original_fps, self._original_duration, self._original_size = get_video_stats(self._path)
+        return super().__getattribute__(key)

     def copy(self):
         # return view
         video = self.__class__(self._path, self._fps, self._start, self._end, self._size)
+        video._original_fps = self._original_fps
+        video._original_duration = self._original_duration
+        video._original_size = self._original_size
         return video

     @property
     def duration(self):
         # in ms
         return self._end - self._start
-
+
     @property
     def fps(self):
         return self._fps
-
+
     @property
     def num_frames(self):
-        return int(self.duration * self.fps/1000)
+        return int(self.duration * self.fps/1000 + EPS)
+
+    @property
+    def original_num_frames(self):
+        return int(self._original_duration * self._original_fps/1000 + EPS)

     @property
     def frame_size(self):
@@ -110,12 +134,13 @@ def to_numpy(self):
         # get the time stamps of frame samples
         start_frame = self._start * self._original_fps / 1000
         end_frame = self._end * self._original_fps / 1000
-        EPS = 1e-9  # avoid taking the last extra frame
-        samples = np.arange(start_frame, end_frame - EPS, self._original_fps/self._fps)
+        # avoid taking the last extra frame
+        samples = np.arange(start_frame, end_frame - EPS, self._original_fps/self.fps)
         sample_indices = samples.astype(int)

         # padding: repeat the first/last frame
-        sample_indices = np.clip(sample_indices, 0, self.num_frames-1)
+        original_num_frames = int(self._original_duration * self._original_fps/1000 + EPS)
+        sample_indices = np.clip(sample_indices, 0, original_num_frames-1)

         # actual sampling
         frames = self.get_frames(sample_indices)
@@ -137,6 +162,21 @@ def to_path(self):
         path = None  # make a temporal file
         raise NotImplementedError()
         return path
+
+    def store_to_path(self, path):
+        # pick format based on path filename
+        if path.endswith(".avi"):
+            fourcc = cv2.VideoWriter_fourcc(*'XVID')
+        elif path.endswith(".mp4"):
+            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+        else:
+            raise ValueError("Unsupported video format.")
+
+        out = cv2.VideoWriter(path, fourcc, self._fps, self._size)
+        for frame in self.to_frames():
+            out.write(frame[..., ::-1])  # RGB -> BGR for OpenCV
+        out.release()
+        return path


 class VideoFromImage(Video):
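
The __getattribute__ hook above makes the _original_* stats lazy: they are probed from the file (via get_video_stats) only on first access, cached, and carried over into views by copy. A minimal standalone sketch of the same pattern (not the brainscore class; _probe and the numbers are hypothetical stand-ins):

```python
class LazyStats:
    def __init__(self, path):
        self._path = path
        self._original_fps = None
        self._original_duration = None

    def _probe(self):
        # stand-in for get_video_stats(self._path)
        return 30.0, 2000.0

    def __getattribute__(self, key):
        # only "_original_*" attributes trigger the (cached) probe
        if key.startswith("_original_") and super().__getattribute__(key) is None:
            self._original_fps, self._original_duration = self._probe()
        return super().__getattribute__(key)

v = LazyStats("clip.mp4")
print(v._original_fps)       # first access triggers the probe -> 30.0
print(v._original_duration)  # already cached -> 2000.0
```
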
diff --git a/brainscore_vision/models/temporal_model_VideoMAE/model.py b/brainscore_vision/models/temporal_model_VideoMAE/model.py
index 81acc36b3..716fceaf1 100644
--- a/brainscore_vision/models/temporal_model_VideoMAE/model.py
+++ b/brainscore_vision/models/temporal_model_VideoMAE/model.py
@@ -7,6 +7,8 @@
 from torchvision import transforms

+LAYER_SELECTION_STEP = 2
+

 class VideoMAEv1Wrapper(PytorchWrapper):
     def forward(self, inputs):
         tensor = th.stack(inputs)
@@ -78,7 +80,7 @@ def get_model(identifier, num_frames=16):
         "fps": 6.25,
         "layer_activation_format": {
             "encoder.patch_embed": "THWC",
-            **{f"encoder.blocks.{i}": "THWC" for i in range(num_blocks)},
+            **{f"encoder.blocks.{i}": "THWC" for i in range(0, num_blocks, LAYER_SELECTION_STEP)},
         },
         "num_frames": num_frames,
     }
diff --git a/tests/test_model_helpers/temporal/activations/test_inferencer.py b/tests/test_model_helpers/temporal/activations/test_inferencer.py
index 822310e82..06e8a775d 100644
--- a/tests/test_model_helpers/temporal/activations/test_inferencer.py
+++ b/tests/test_model_helpers/temporal/activations/test_inferencer.py
@@ -129,19 +129,19 @@ def test_compute_temporal_context():

 @pytest.mark.memory_intense
 @pytest.mark.parametrize("preprocess", ["normal", "downsample"])
-def test_causal_inferencer(preprocess):
+@pytest.mark.parametrize("fps", [1, 40])
+def test_causal_inferencer(preprocess, fps):
     if preprocess == "normal":
         preprocess = dummy_preprocess
     else:
         preprocess = time_down_sample_preprocess
-    fps = 10
     inferencer = CausalInferencer(dummy_get_features, dummy_preprocess, dummy_layer_activation_format,
                                   fps=fps, max_workers=1)
     model_assembly = inferencer(video_paths, layers=dummy_layers)
     assert model_assembly.sizes["time_bin"] == 6 * fps
     assert np.isclose(model_assembly['time_bin_end'].values[0]
                       - model_assembly['time_bin_start'].values[0], 1000/fps)
-    assert inferencer._compute_temporal_context() == (100, np.inf)
+    assert inferencer._compute_temporal_context() == (1000/fps, np.inf)

     # manual computation check
     output_values = model_assembly.sel(stimulus_path=video_paths[1])\
@@ -159,12 +159,12 @@

 @pytest.mark.memory_intense
 @pytest.mark.parametrize("preprocess", ["normal", "downsample"])
-def test_block_inferencer(preprocess):
+@pytest.mark.parametrize("fps", [1, 40])
+def test_block_inferencer(preprocess, fps):
     if preprocess == "normal":
         preprocessing = dummy_preprocess
     else:
         preprocessing = time_down_sample_preprocess
-    fps = 10
     inferencer = BlockInferencer(dummy_get_features, preprocessing, dummy_layer_activation_format, fps=fps,
                                  duration=(200, 4000), temporal_context_strategy="greedy", max_workers=1)
     model_assembly = inferencer(video_paths, layers=dummy_layers)
diff --git a/tests/test_model_helpers/temporal/activations/test_inputs.py b/tests/test_model_helpers/temporal/activations/test_inputs.py
index bedd655d6..a49504f8a 100644
--- a/tests/test_model_helpers/temporal/activations/test_inputs.py
+++ b/tests/test_model_helpers/temporal/activations/test_inputs.py
@@ -63,9 +63,28 @@ def test_video():
     assert video8.duration == 100
     assert (video8.to_numpy() == video1.set_window(300, 400).to_numpy()).all()

+    # test copy
+    video9 = video1.set_fps(5).copy().set_fps(30).copy()
+    assert (video9.to_numpy()[1] == video1.to_numpy()[2]).all()
+    assert (video9.to_numpy()[2] == video1.to_numpy()[4]).all()
+
+    for frame in [10, 50, 100]:
+        time_start = 1000 / video1.fps * frame
+        video10 = video1.set_window(time_start, time_start+1000/video1.fps)
+        assert video10.to_numpy().shape[0] == 1
+        assert (video10.to_numpy()[0] == video1.to_numpy()[frame]).all()
+
+        video10 = video1.set_window(0, time_start+1000/video1.fps)
+        assert video10.to_numpy().shape[0] == frame+1
+        assert (video10.to_numpy()[frame] == video1.to_numpy()[frame]).all()
+
+        video10 = video1.set_window(time_start, video1.duration)
+        assert video10.to_numpy().shape[0] == video1.to_numpy().shape[0] - frame
+        assert (video10.to_numpy()[0] == video1.to_numpy()[frame]).all()
+
     for fps in [7.5, 9, 1, 43, 1000/video1.duration, 1001/video1.duration]:
-        video9 = video1.set_fps(fps)
-        assert video9.to_numpy().shape[0] == np.ceil(video1.duration * fps / 1000)
+        video11 = video1.set_fps(fps)
+        assert video11.to_numpy().shape[0] == np.ceil(video1.duration * fps / 1000)

     for v in [video1, video2]:
         target_num_frames = 7
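
For the fps-parametrized tests, the expected assembly shape follows directly from the fps: assuming the dummy videos are 6 seconds long (as the existing 6 * fps check implies), each video yields 6 * fps time bins of width 1000 / fps ms, which is also the lower bound asserted for _compute_temporal_context. A quick sanity check:

```python
# Arithmetic behind the fps-parametrized assertions (assumes 6-second dummy videos).
for fps in [1, 40]:
    video_duration_ms = 6000
    n_time_bins = int(video_duration_ms / 1000 * fps)  # model_assembly.sizes["time_bin"]
    bin_width_ms = 1000 / fps                          # time_bin_end - time_bin_start
    print(fps, n_time_bins, bin_width_ms)              # 1 -> 6 bins of 1000.0 ms; 40 -> 240 bins of 25.0 ms
```
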