From 08bad9dd69528f277d6236cac7619c1847cfc608 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Matthias=20K=C3=BCmmerer?= <matthias@matthias-k.org>
Date: Sat, 21 Sep 2024 21:11:05 +0200
Subject: [PATCH] ENH: Subsets of stimuli take over existing stimulus ids
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Previously, when creating a subset of Stimuli, neither stimulus data
nor stimulus ids were carried over, so they had to be reloaded or
recomputed if required. Now, at least stimulus ids are propagated,
which can save a lot of memory, e.g. when splitting a large stimulus
set into many small subsets.

Signed-off-by: Matthias Kümmerer <matthias@matthias-k.org>
---
 CHANGELOG.md                   |  1 +
 pysaliency/datasets/stimuli.py | 28 ++++++++++++++++++++++++----
 tests/datasets/test_stimuli.py | 21 ++++++++++++++++++++-
 3 files changed, 45 insertions(+), 5 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 55c50ac..2ef0a06 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -22,6 +22,7 @@
   * Bugfix!: The download location of the RARE2012 model changed. The new source code results in slightly different predictions.
   * Feature: The RARE2007 model is now available as `pysaliency.external_models.RARE2007`. It's execution requires MATLAB.
   * matlab scripts are now called with the `-batch` option instead of `-nodisplay -nosplash -r`, which should behave better.
+  * Enhancement: preloaded stimulus ids are passed on to subsets of Stimuli and FileStimuli.
 
 
 * 0.2.22:
diff --git a/pysaliency/datasets/stimuli.py b/pysaliency/datasets/stimuli.py
index 2de5ce5..0029d90 100644
--- a/pysaliency/datasets/stimuli.py
+++ b/pysaliency/datasets/stimuli.py
@@ -2,7 +2,7 @@
 import os
 from collections.abc import Sequence
 from hashlib import sha1
-from typing import Union
+from typing import List, Union
 
 import numpy as np
 
@@ -152,7 +152,12 @@ def _get_attribute_for_stimulus_subset(self, index):
     def __getitem__(self, index):
         if isinstance(index, slice):
             attributes = self._get_attribute_for_stimulus_subset(index)
-            return ObjectStimuli([self.stimulus_objects[i] for i in range(len(self))[index]], attributes=attributes)
+            sub_stimuli = ObjectStimuli([self.stimulus_objects[i] for i in range(len(self))[index]], attributes=attributes)
+
+            # populate stimulus_id cache with existing entries
+            self._propagate_stimulus_ids(sub_stimuli, range(len(self))[index])
+
+            return sub_stimuli
         elif isinstance(index, (list, np.ndarray)):
             index = np.asarray(index)
             if index.dtype == bool:
@@ -161,10 +166,20 @@ def __getitem__(self, index):
                 index = np.nonzero(index)[0]
 
             attributes = self._get_attribute_for_stimulus_subset(index)
-            return ObjectStimuli([self.stimulus_objects[i] for i in index], attributes=attributes)
+            sub_stimuli = ObjectStimuli([self.stimulus_objects[i] for i in index], attributes=attributes)
+
+            # populate stimulus_id cache with existing entries
+            self._propagate_stimulus_ids(sub_stimuli, index)
+
+            return sub_stimuli
         else:
             return self.stimulus_objects[index]
 
+    def _propagate_stimulus_ids(self, sub_stimuli: "Stimuli", index: List[int]):
+        for new_index, old_index in enumerate(index):
+            if old_index in self.stimulus_ids._cache:
+                sub_stimuli.stimulus_ids._cache[new_index] = self.stimulus_ids._cache[old_index]
+
     @hdf5_wrapper(mode='w')
     def to_hdf5(self, target, verbose=False, compression='gzip', compression_opts=9):
         """ Write stimuli to hdf5 file or hdf5 group
@@ -343,7 +358,12 @@ def __getitem__(self, index):
             filenames = [self.filenames[i] for i in index]
             shapes = [self.shapes[i] for i in index]
             attributes = self._get_attribute_for_stimulus_subset(index)
-            return type(self)(filenames=filenames, shapes=shapes, attributes=attributes, cached=self.cached)
+            sub_stimuli = type(self)(filenames=filenames, shapes=shapes, attributes=attributes, cached=self.cached)
+
+            # populate stimulus_id cache with existing entries
+            self._propagate_stimulus_ids(sub_stimuli, index)
+
+            return sub_stimuli
         else:
             return self.stimulus_objects[index]
 
diff --git a/tests/datasets/test_stimuli.py b/tests/datasets/test_stimuli.py
index 6f63c32..53d84e7 100644
--- a/tests/datasets/test_stimuli.py
+++ b/tests/datasets/test_stimuli.py
@@ -291,4 +291,23 @@ def test_check_prediction_shape():
     stimulus = Stimulus(np.random.rand(10, 11))
     with pytest.raises(ValueError) as excinfo:
         check_prediction_shape(prediction, stimulus)
-    assert str(excinfo.value) == "Prediction shape (10, 10) does not match stimulus shape (10, 11)"
\ No newline at end of file
+    assert str(excinfo.value) == "Prediction shape (10, 10) does not match stimulus shape (10, 11)"
+
+
+@pytest.mark.parametrize(
+        'stimuli',
+        ['stimuli_with_attributes', 'file_stimuli_with_attributes']
+)
+def test_substimuli_inherit_cachedstimulus_ids(stimuli, request):
+    _stimuli = request.getfixturevalue(stimuli)
+    # load some stimulus ids
+    cache_stimulus_indices = [1, 2, 5]
+    # make sure the ids are cached
+    for i in cache_stimulus_indices:
+        _stimuli.stimulus_ids[i]
+
+    assert len(_stimuli.stimulus_ids._cache) == len(cache_stimulus_indices)
+
+    sub_stimuli = _stimuli[1:5]
+    assert set(sub_stimuli.stimulus_ids._cache.keys()) == {0, 1}
+