From 7f9a76a0cddb2be30d125413dff259dcfea2dc9b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Matthias=20K=C3=BCmmerer?= Date: Tue, 2 Apr 2024 13:14:11 +0200 Subject: [PATCH] refactor methods to use FixationTrains.scanpaths where applicable MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Matthias Kümmerer --- pysaliency/datasets.py | 246 ++++++++++++++++++++++++++--------------- tests/test_datasets.py | 2 +- 2 files changed, 160 insertions(+), 88 deletions(-) diff --git a/pysaliency/datasets.py b/pysaliency/datasets.py index 19a0e74..49d8c2d 100644 --- a/pysaliency/datasets.py +++ b/pysaliency/datasets.py @@ -297,10 +297,12 @@ def filter(self, inds): other_attributes = {} def filter_array(name): + print("Filtering", name) kwargs[name] = getattr(self, name)[inds].copy() for name in ['x', 'y', 't', 'x_hist', 'y_hist', 't_hist', 'n']: filter_array(name) + for name in self.__attributes__: filter_array(name) if name != 'subjects': @@ -481,19 +483,38 @@ def __init__(self, train_xs, train_ys, train_ts, train_ns, train_subjects, scanp self.__attributes__ = list(self.__attributes__) self.__attributes__.append('scanpath_index') - if not isinstance(train_xs, VariableLengthArray): - self.train_lengths = np.array([len(remove_trailing_nans(train_x)) for train_x in train_xs]) - else: - self.train_lengths = train_xs.lengths.copy() + scanpath_attributes = scanpath_attributes or {} + scanpath_attribute_mapping = scanpath_attribute_mapping or {} + scanpath_fixation_attributes = scanpath_fixation_attributes or {} + + if 'subject' in scanpath_attributes and train_subjects is not None: + raise ValueError("subject should not be in scanpath_attributes if train_subjects is specified") + if 'subject' not in scanpath_attributes: + scanpath_attributes['subject'] = train_subjects + + if 'ts' in scanpath_fixation_attributes and train_ts is not None: + raise ValueError("ts should not be in scanpath_fixation_attributes if train_ts is specified") + if 'ts' not in scanpath_fixation_attributes: + scanpath_fixation_attributes['ts'] = train_ts + scanpath_attribute_mapping['ts'] = 't' + + lengths = [len(remove_trailing_nans(xs)) for xs in train_xs] + + scanpaths = Scanpaths( + xs=train_xs, + ys=train_ys, + n=train_ns, + lengths=lengths, + scanpath_attributes=scanpath_attributes, + fixation_attributes=scanpath_fixation_attributes, + attribute_mapping=scanpath_attribute_mapping, + ) + + self.scanpaths = scanpaths - self.train_xs = self._as_variable_length_scanpath_array(train_xs) - self.train_ys = self._as_variable_length_scanpath_array(train_ys) - self.train_ts = self._as_variable_length_scanpath_array(train_ts) - self.train_ns = train_ns - self.train_subjects = train_subjects - N_fixations = self.train_lengths.sum() - max_length_trains = self.train_lengths.max() if len(self.train_lengths) else 0 - max_history_length = max(max_length_trains - 1, 0) + N_fixations = scanpaths.lengths.sum() + max_scanpath_length = scanpaths.lengths.max() if len(self.scanpaths) else 0 + max_history_length = max(max_scanpath_length - 1, 0) # Create conditional fixations @@ -508,27 +529,27 @@ def __init__(self, train_xs, train_ys, train_ts, train_ns, train_subjects, scanp self.t_hist[:] = np.nan self.n = np.empty(N_fixations, dtype=int) self.lengths = np.empty(N_fixations, dtype=int) - self.train_lengths = np.empty(len(self.train_xs), dtype=int) + # self.train_lengths = np.empty(len(self.train_xs), dtype=int) self.subjects = np.empty(N_fixations, dtype=int) self.scanpath_index = np.empty(N_fixations, dtype=int) out_index = 0 # TODO: maybe implement in numba? # probably best: have function fill_fixation_data(scanpath_data, fixation_data, hist_data=None) - for train_index in range(len(self.train_xs)): - fix_length = len(remove_trailing_nans(self.train_xs[train_index])) - self.train_lengths[train_index] = fix_length - for fix_index in range(fix_length): - self.x[out_index] = self.train_xs[train_index][fix_index] - self.y[out_index] = self.train_ys[train_index][fix_index] - self.t[out_index] = self.train_ts[train_index][fix_index] - self.n[out_index] = self.train_ns[train_index] - self.subjects[out_index] = self.train_subjects[train_index] + for train_index in range(len(self.scanpaths)): + #fix_length = len(remove_trailing_nans(self.train_xs[train_index])) + # self.train_lengths[train_index] = fix_length + for fix_index in range(self.scanpaths.lengths[train_index]): + self.x[out_index] = self.scanpaths.xs[train_index][fix_index] + self.y[out_index] = self.scanpaths.ys[train_index][fix_index] + self.t[out_index] = self.scanpaths.ts[train_index][fix_index] + self.n[out_index] = self.scanpaths.n[train_index] + self.subjects[out_index] = self.scanpaths.scanpath_attributes['subject'][train_index] self.lengths[out_index] = fix_index self.scanpath_index[out_index] = train_index - self.x_hist[out_index][:fix_index] = self.train_xs[train_index][:fix_index] - self.y_hist[out_index][:fix_index] = self.train_ys[train_index][:fix_index] - self.t_hist[out_index][:fix_index] = self.train_ts[train_index][:fix_index] + self.x_hist[out_index][:fix_index] = self.scanpaths.xs[train_index][:fix_index] + self.y_hist[out_index][:fix_index] = self.scanpaths.ys[train_index][:fix_index] + self.t_hist[out_index][:fix_index] = self.scanpaths.ts[train_index][:fix_index] out_index += 1 # TODO: this should become irrelevant once FixationTrains is also completely upgraded to VariableLengthArrays @@ -536,23 +557,23 @@ def __init__(self, train_xs, train_ys, train_ts, train_ns, train_subjects, scanp self.y_hist = VariableLengthArray(self.y_hist, self.lengths) self.t_hist = VariableLengthArray(self.t_hist, self.lengths) - if scanpath_attributes is not None: - assert isinstance(scanpath_attributes, dict) - self.scanpath_attributes = {key: np.array(value) for key, value in scanpath_attributes.items()} - for key, value in self.scanpath_attributes.items(): - assert len(value) == len(self.train_xs) - else: - self.scanpath_attributes = {} + #if scanpath_attributes is not None: + # assert isinstance(scanpath_attributes, dict) + # self.scanpath_attributes = {key: np.array(value) for key, value in scanpath_attributes.items()} + # for key, value in self.scanpath_attributes.items(): + # assert len(value) == len(self.scanpaths) + #else: + # self.scanpath_attributes = {} - if scanpath_fixation_attributes is not None: - assert isinstance(scanpath_fixation_attributes, dict) - self.scanpath_fixation_attributes = {} - for key, value in scanpath_fixation_attributes.items(): - self.scanpath_fixation_attributes[key] = self._as_variable_length_scanpath_array(value) - else: - self.scanpath_fixation_attributes = {} + #if scanpath_fixation_attributes is not None: + # assert isinstance(scanpath_fixation_attributes, dict) + # self.scanpath_fixation_attributes = {} + # for key, value in scanpath_fixation_attributes.items(): + # self.scanpath_fixation_attributes[key] = self._as_variable_length_scanpath_array(value) + #else: + # self.scanpath_fixation_attributes = {} - self.scanpath_attribute_mapping = scanpath_attribute_mapping or {} + # self.scanpath_attribute_mapping = self.scanpaths.attribute_mapp if attributes is None: @@ -562,8 +583,10 @@ def __init__(self, train_xs, train_ys, train_ts, train_ns, train_subjects, scanp self.auto_attributes = [] - for attribute_name, value in self.scanpath_attributes.items(): - new_attribute_name = self.scanpath_attribute_mapping.get(attribute_name, attribute_name) + for attribute_name, value in self.scanpaths.scanpath_attributes.items(): + if attribute_name == 'subject': + continue + new_attribute_name = self.scanpaths.attribute_mapping.get(attribute_name, attribute_name) if new_attribute_name in attributes: raise ValueError("attribute name clash: {new_attribute_name}".format(new_attribute_name=new_attribute_name)) attribute_shape = [] if not value.any() else np.asarray(value[0]).shape @@ -571,15 +594,16 @@ def __init__(self, train_xs, train_ys, train_ts, train_ns, train_subjects, scanp self.auto_attributes.append(new_attribute_name) out_index = 0 - for train_index in range(len(self.train_xs)): - fix_length = (1 - np.isnan(self.train_xs[train_index])).sum() - for fix_index in range(fix_length): - attributes[new_attribute_name][out_index] = self.scanpath_attributes[attribute_name][train_index] + for train_index in range(len(self.scanpaths)): + for fix_index in range(self.scanpaths.lengths[train_index]): + attributes[new_attribute_name][out_index] = self.scanpaths.scanpath_attributes[attribute_name][train_index] out_index += 1 - for attribute_name, value in self.scanpath_fixation_attributes.items(): - new_attribute_name = self.scanpath_attribute_mapping.get(attribute_name, attribute_name) + for attribute_name, value in self.scanpaths.fixation_attributes.items(): + if attribute_name == 'ts': + continue + new_attribute_name = self.scanpaths.attribute_mapping.get(attribute_name, attribute_name) if new_attribute_name in attributes: raise ValueError("attribute name clash: {new_attribute_name}".format(new_attribute_name=new_attribute_name)) attributes[new_attribute_name] = np.empty(N_fixations) @@ -588,15 +612,14 @@ def __init__(self, train_xs, train_ys, train_ts, train_ns, train_subjects, scanp hist_attribute_name = new_attribute_name + '_hist' if hist_attribute_name in attributes: raise ValueError("attribute name clash: {hist_attribute_name}".format(hist_attribute_name=hist_attribute_name)) - attributes[hist_attribute_name] = np.full_like(self.x_hist._data, fill_value=np.nan) + attributes[hist_attribute_name] = np.full((N_fixations, max_history_length), fill_value=np.nan) self.auto_attributes.append(hist_attribute_name) out_index = 0 - for train_index in range(len(self.train_xs)): - fix_length = (1 - np.isnan(self.train_xs[train_index])).sum() - for fix_index in range(fix_length): - attributes[new_attribute_name][out_index] = self.scanpath_fixation_attributes[attribute_name][train_index, fix_index] - attributes[hist_attribute_name][out_index][:fix_index] = self.scanpath_fixation_attributes[attribute_name][train_index, :fix_index] + for train_index in range(len(self.scanpaths)): + for fix_index in range(self.scanpaths.lengths[train_index]): + attributes[new_attribute_name][out_index] = self.scanpaths.fixation_attributes[attribute_name][train_index, fix_index] + attributes[hist_attribute_name][out_index][:fix_index] = self.scanpaths.fixation_attributes[attribute_name][train_index, :fix_index] out_index += 1 attributes[hist_attribute_name] = VariableLengthArray(attributes[hist_attribute_name], self.lengths) @@ -606,6 +629,7 @@ def __init__(self, train_xs, train_ys, train_ts, train_ns, train_subjects, scanp for key, value in attributes.items(): assert key != 'subjects' assert key != 'scanpath_index' + assert key != 't' assert len(value) == len(self.x) self.__attributes__.append(key) if not isinstance(value, VariableLengthArray): @@ -614,19 +638,61 @@ def __init__(self, train_xs, train_ys, train_ts, train_ns, train_subjects, scanp self.full_nonfixations = None - def _check_train_lengths(self, other: VariableLengthArray): - if not len(self.train_lengths) == len(other): - raise ValueError("Length of scanpaths has to match") - if not np.all(self.train_lengths == other.lengths): - raise ValueError("Lengths of scanpaths have to match") + # def _check_train_lengths(self, other: VariableLengthArray): + # if not len(self.train_lengths) == len(other): + # raise ValueError("Length of scanpaths has to match") + # if not np.all(self.train_lengths == other.lengths): + # raise ValueError("Lengths of scanpaths have to match") - def _as_variable_length_scanpath_array(self, data: Union[np.ndarray, VariableLengthArray]) -> VariableLengthArray: - if not isinstance(data, VariableLengthArray): - data = VariableLengthArray(data, self.train_lengths) + # def _as_variable_length_scanpath_array(self, data: Union[np.ndarray, VariableLengthArray]) -> VariableLengthArray: + # if not isinstance(data, VariableLengthArray): + # data = VariableLengthArray(data, self.train_lengths) - self._check_train_lengths(data) + # self._check_train_lengths(data) - return data + # return data + + @property + def train_xs(self) -> VariableLengthArray: + return self.scanpaths.xs + + @property + def train_ys(self) -> VariableLengthArray: + return self.scanpaths.ys + + @property + def train_ts(self) -> VariableLengthArray: + return self.scanpaths.ts + + @property + def train_ns(self) -> np.ndarray: + return self.scanpaths.n + + @property + def train_subjects(self) -> VariableLengthArray: + return self.scanpaths.subject + + @property + def train_lengths(self) -> np.ndarray: + return self.scanpaths.lengths + + @property + def scanpath_attributes(self) -> Dict[str, np.ndarray]: + return { + key: value for key, value in self.scanpaths.scanpath_attributes.items() if key != 'subject' + } + + @property + def scanpath_fixation_attributes(self) -> Dict[str, VariableLengthArray]: + return { + key: value for key, value in self.scanpaths.fixation_attributes.items() if key != 'ts' + } + + @property + def scanpath_attribute_mapping(self) -> Dict[str, str]: + return { + key: value for key, value in self.scanpaths.attribute_mapping.items() if key != 'ts' + } @classmethod def concatenate(cls, fixation_trains): @@ -734,15 +800,10 @@ def filter_fixation_trains(self, indices): """ Create new fixations object which contains only the fixation trains indicated. """ - train_xs = self.train_xs[indices] - train_ys = self.train_ys[indices] - train_ts = self.train_ts[indices] - train_ns = self.train_ns[indices] - train_subjects = self.train_subjects[indices] - scanpath_attributes = {key: np.array(value)[indices] for key, value in self.scanpath_attributes.items()} - scanpath_fixation_attributes = {key: value[indices] for key, value in self.scanpath_fixation_attributes.items()} - - scanpath_indices = np.arange(len(self.train_xs), dtype=int)[indices] + + filtered_scanpaths = self.scanpaths[indices] + + scanpath_indices = np.arange(len(self.scanpaths), dtype=int)[indices] fixation_indices = np.in1d(self.scanpath_index, scanpath_indices) attributes = { @@ -750,17 +811,18 @@ def filter_fixation_trains(self, indices): } return type(self)( - train_xs, - train_ys, - train_ts, - train_ns, - train_subjects, + train_xs=filtered_scanpaths.xs, + train_ys=filtered_scanpaths.ys, + train_ts=None, # filtered_scanpaths.ts, + train_ns=filtered_scanpaths.n, + train_subjects=None, # filtered_scanpaths.subject, + scanpath_attributes=filtered_scanpaths.scanpath_attributes, + scanpath_fixation_attributes=filtered_scanpaths.fixation_attributes, + scanpath_attribute_mapping=dict(filtered_scanpaths.attribute_mapping), attributes=attributes, - scanpath_attributes=scanpath_attributes, - scanpath_fixation_attributes=scanpath_fixation_attributes, - scanpath_attribute_mapping=dict(self.scanpath_attribute_mapping) ) + def fixation_trains(self): """Yield for every fixation train of the dataset: xs, ys, ts, n, subject @@ -1668,16 +1730,26 @@ def create_subset(stimuli, fixations, stimuli_indices): new_stimuli = stimuli[stimuli_indices] if isinstance(fixations, FixationTrains): - fix_inds = np.in1d(fixations.train_ns, stimuli_indices) - new_fixations = fixations.filter_fixation_trains(fix_inds) + fix_inds = np.in1d(fixations.scanpaths.n, stimuli_indices) index_list = list(stimuli_indices) new_pos = {i: index_list.index(i) for i in index_list} - new_fixation_train_ns = [new_pos[i] for i in new_fixations.train_ns] - new_fixations.train_ns = np.array(new_fixation_train_ns) - new_fixation_ns = [new_pos[i] for i in new_fixations.n] - new_fixations.n = np.array(new_fixation_ns) + new_image_indices = [new_pos[i] for i in fixations.scanpaths.n[fix_inds]] + + new_scanpaths = fixations.scanpaths[fix_inds] + + new_fixations = FixationTrains( + train_xs=new_scanpaths.xs, + train_ys=new_scanpaths.ys, + train_ts=None, # new_scanpaths.fixation_attributes['ts'], + train_ns=np.array(new_image_indices), + train_subjects=None, # new_scanpaths.scanpath_attributes['subject'], + scanpath_attributes=new_scanpaths.scanpath_attributes, + scanpath_fixation_attributes=new_scanpaths.fixation_attributes, + scanpath_attribute_mapping=new_scanpaths.attribute_mapping + ) + else: fix_inds = np.in1d(fixations.n, stimuli_indices) new_fixations = fixations[fix_inds] diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 6a92fe1..3b65df7 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -746,7 +746,7 @@ def test_concatenate_fixation_trains(fixation_trains): def test_concatenate_scanpaths(fixation_trains): fixation_trains2 = fixation_trains.copy() - del fixation_trains2.scanpath_attributes['task'] + del fixation_trains2.scanpaths.scanpath_attributes['task'] delattr(fixation_trains2, 'task') fixation_trains2.auto_attributes.remove('task') fixation_trains2.__attributes__.remove('task')