From ebd3c0ac70e75674d15979277ba65580fed6c121 Mon Sep 17 00:00:00 2001 From: YingtianDt Date: Tue, 14 May 2024 10:09:40 +0200 Subject: [PATCH] module refactor; add more input test --- .../model_helpers/activations/temporal/core/executor.py | 9 ++++++++- .../model_helpers/activations/temporal/inputs/image.py | 5 ++++- .../model_helpers/activations/temporal/inputs/video.py | 1 - .../model_helpers/activations/temporal/utils.py | 4 ++-- .../temporal/activations/test_inputs.py | 1 + 5 files changed, 15 insertions(+), 5 deletions(-) diff --git a/brainscore_vision/model_helpers/activations/temporal/core/executor.py b/brainscore_vision/model_helpers/activations/temporal/core/executor.py index fbf9adcdb..1695ccb0c 100644 --- a/brainscore_vision/model_helpers/activations/temporal/core/executor.py +++ b/brainscore_vision/model_helpers/activations/temporal/core/executor.py @@ -25,9 +25,16 @@ class JoblibMapper: def __init__(self, num_threads: int): self._num_threads = num_threads self._pool = Parallel(n_jobs=num_threads, verbose=False, backend="loky") + self._not_supported = False def map(self, func, *data): - return self._pool(delayed(func)(*x) for x in zip(*data)) + from joblib.externals.loky.process_executor import TerminatedWorkerError + if not self._not_supported: + try: + return self._pool(delayed(func)(*x) for x in zip(*data)) + except TerminatedWorkerError: + self._not_supported = True + return [func(*x) for x in zip(*data)] class BatchExecutor: diff --git a/brainscore_vision/model_helpers/activations/temporal/inputs/image.py b/brainscore_vision/model_helpers/activations/temporal/inputs/image.py index adb85d559..8e2d8df0c 100644 --- a/brainscore_vision/model_helpers/activations/temporal/inputs/image.py +++ b/brainscore_vision/model_helpers/activations/temporal/inputs/image.py @@ -27,10 +27,13 @@ def from_path(path): def to_img(self): return PILImage.fromarray(self.to_numpy()) + + def get_frame(self): + return np.array(PILImage.open(self._path).convert('RGB')) # return (H, W, C[RGB]) def to_numpy(self): - arr = np.array(PILImage.open(self._path).convert('RGB')) + arr = self.get_frame() if arr.shape[:2][::-1] != self._size: arr = batch_2d_resize(arr[None,:], self._size, "bilinear")[0] diff --git a/brainscore_vision/model_helpers/activations/temporal/inputs/video.py b/brainscore_vision/model_helpers/activations/temporal/inputs/video.py index 81a2b5b57..ca9d4f075 100644 --- a/brainscore_vision/model_helpers/activations/temporal/inputs/video.py +++ b/brainscore_vision/model_helpers/activations/temporal/inputs/video.py @@ -120,7 +120,6 @@ def get_frames(self, indices): ### I/O def from_path(path): - path = path fps, end, size = get_video_stats(path) start = 0 return Video(path, fps, start, end, size) diff --git a/brainscore_vision/model_helpers/activations/temporal/utils.py b/brainscore_vision/model_helpers/activations/temporal/utils.py index 821c63cb2..d1f7a264c 100644 --- a/brainscore_vision/model_helpers/activations/temporal/utils.py +++ b/brainscore_vision/model_helpers/activations/temporal/utils.py @@ -43,7 +43,7 @@ def stack_with_nan_padding_(arr_list, axis=0, dtype=np.float16): return result -def stack_with_nan_padding(arr_list, axis=0, dtype=np.float16): +def stack_with_nan_padding(arr_list, axis=0, dtype=None): # Get shapes of all arrays shapes = [np.array(arr.shape) for arr in arr_list] max_shape = np.max(shapes, axis=0) @@ -58,7 +58,7 @@ def stack_with_nan_padding(arr_list, axis=0, dtype=np.float16): result = np.stack(results, axis=axis) result = np.swapaxes(result, 0, axis) - if result.dtype != dtype: + if dtype is not None and result.dtype != dtype: result = result.astype(dtype) return result diff --git a/tests/test_model_helpers/temporal/activations/test_inputs.py b/tests/test_model_helpers/temporal/activations/test_inputs.py index a49504f8a..fb2cb7843 100644 --- a/tests/test_model_helpers/temporal/activations/test_inputs.py +++ b/tests/test_model_helpers/temporal/activations/test_inputs.py @@ -40,6 +40,7 @@ def test_video(): video3 = video1.set_window(-10, 0, padding="repeat") video4 = video1.set_window(-20, -10, padding="repeat") assert (video3.to_numpy() == video4.to_numpy()).all() + assert (video3.to_numpy()[0] == video1.to_numpy()[0]).all() assert video2.fps == 30 assert video2.set_fps(1).to_numpy().shape[0] == 1