Skip to content

Commit

Permalink
refactor methods to use FixationTrains.scanpaths where applicable
Browse files Browse the repository at this point in the history
Signed-off-by: Matthias Kümmerer <[email protected]>
  • Loading branch information
matthias-k committed Apr 2, 2024
1 parent 3f8538c commit 7f9a76a
Show file tree
Hide file tree
Showing 2 changed files with 160 additions and 88 deletions.
246 changes: 159 additions & 87 deletions pysaliency/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,10 +297,12 @@ def filter(self, inds):
other_attributes = {}

def filter_array(name):
print("Filtering", name)
kwargs[name] = getattr(self, name)[inds].copy()

for name in ['x', 'y', 't', 'x_hist', 'y_hist', 't_hist', 'n']:
filter_array(name)

for name in self.__attributes__:
filter_array(name)
if name != 'subjects':
Expand Down Expand Up @@ -481,19 +483,38 @@ def __init__(self, train_xs, train_ys, train_ts, train_ns, train_subjects, scanp
self.__attributes__ = list(self.__attributes__)
self.__attributes__.append('scanpath_index')

if not isinstance(train_xs, VariableLengthArray):
self.train_lengths = np.array([len(remove_trailing_nans(train_x)) for train_x in train_xs])
else:
self.train_lengths = train_xs.lengths.copy()
scanpath_attributes = scanpath_attributes or {}
scanpath_attribute_mapping = scanpath_attribute_mapping or {}
scanpath_fixation_attributes = scanpath_fixation_attributes or {}

if 'subject' in scanpath_attributes and train_subjects is not None:
raise ValueError("subject should not be in scanpath_attributes if train_subjects is specified")
if 'subject' not in scanpath_attributes:
scanpath_attributes['subject'] = train_subjects

if 'ts' in scanpath_fixation_attributes and train_ts is not None:
raise ValueError("ts should not be in scanpath_fixation_attributes if train_ts is specified")
if 'ts' not in scanpath_fixation_attributes:
scanpath_fixation_attributes['ts'] = train_ts
scanpath_attribute_mapping['ts'] = 't'

lengths = [len(remove_trailing_nans(xs)) for xs in train_xs]

scanpaths = Scanpaths(
xs=train_xs,
ys=train_ys,
n=train_ns,
lengths=lengths,
scanpath_attributes=scanpath_attributes,
fixation_attributes=scanpath_fixation_attributes,
attribute_mapping=scanpath_attribute_mapping,
)

self.scanpaths = scanpaths

self.train_xs = self._as_variable_length_scanpath_array(train_xs)
self.train_ys = self._as_variable_length_scanpath_array(train_ys)
self.train_ts = self._as_variable_length_scanpath_array(train_ts)
self.train_ns = train_ns
self.train_subjects = train_subjects
N_fixations = self.train_lengths.sum()
max_length_trains = self.train_lengths.max() if len(self.train_lengths) else 0
max_history_length = max(max_length_trains - 1, 0)
N_fixations = scanpaths.lengths.sum()
max_scanpath_length = scanpaths.lengths.max() if len(self.scanpaths) else 0
max_history_length = max(max_scanpath_length - 1, 0)


# Create conditional fixations
Expand All @@ -508,51 +529,51 @@ def __init__(self, train_xs, train_ys, train_ts, train_ns, train_subjects, scanp
self.t_hist[:] = np.nan
self.n = np.empty(N_fixations, dtype=int)
self.lengths = np.empty(N_fixations, dtype=int)
self.train_lengths = np.empty(len(self.train_xs), dtype=int)
# self.train_lengths = np.empty(len(self.train_xs), dtype=int)
self.subjects = np.empty(N_fixations, dtype=int)
self.scanpath_index = np.empty(N_fixations, dtype=int)

out_index = 0
# TODO: maybe implement in numba?
# probably best: have function fill_fixation_data(scanpath_data, fixation_data, hist_data=None)
for train_index in range(len(self.train_xs)):
fix_length = len(remove_trailing_nans(self.train_xs[train_index]))
self.train_lengths[train_index] = fix_length
for fix_index in range(fix_length):
self.x[out_index] = self.train_xs[train_index][fix_index]
self.y[out_index] = self.train_ys[train_index][fix_index]
self.t[out_index] = self.train_ts[train_index][fix_index]
self.n[out_index] = self.train_ns[train_index]
self.subjects[out_index] = self.train_subjects[train_index]
for train_index in range(len(self.scanpaths)):
#fix_length = len(remove_trailing_nans(self.train_xs[train_index]))
# self.train_lengths[train_index] = fix_length
for fix_index in range(self.scanpaths.lengths[train_index]):
self.x[out_index] = self.scanpaths.xs[train_index][fix_index]
self.y[out_index] = self.scanpaths.ys[train_index][fix_index]
self.t[out_index] = self.scanpaths.ts[train_index][fix_index]
self.n[out_index] = self.scanpaths.n[train_index]
self.subjects[out_index] = self.scanpaths.scanpath_attributes['subject'][train_index]
self.lengths[out_index] = fix_index
self.scanpath_index[out_index] = train_index
self.x_hist[out_index][:fix_index] = self.train_xs[train_index][:fix_index]
self.y_hist[out_index][:fix_index] = self.train_ys[train_index][:fix_index]
self.t_hist[out_index][:fix_index] = self.train_ts[train_index][:fix_index]
self.x_hist[out_index][:fix_index] = self.scanpaths.xs[train_index][:fix_index]
self.y_hist[out_index][:fix_index] = self.scanpaths.ys[train_index][:fix_index]
self.t_hist[out_index][:fix_index] = self.scanpaths.ts[train_index][:fix_index]
out_index += 1

# TODO: this should become irrelevant once FixationTrains is also completely upgraded to VariableLengthArrays
self.x_hist = VariableLengthArray(self.x_hist, self.lengths)
self.y_hist = VariableLengthArray(self.y_hist, self.lengths)
self.t_hist = VariableLengthArray(self.t_hist, self.lengths)

if scanpath_attributes is not None:
assert isinstance(scanpath_attributes, dict)
self.scanpath_attributes = {key: np.array(value) for key, value in scanpath_attributes.items()}
for key, value in self.scanpath_attributes.items():
assert len(value) == len(self.train_xs)
else:
self.scanpath_attributes = {}
#if scanpath_attributes is not None:
# assert isinstance(scanpath_attributes, dict)
# self.scanpath_attributes = {key: np.array(value) for key, value in scanpath_attributes.items()}
# for key, value in self.scanpath_attributes.items():
# assert len(value) == len(self.scanpaths)
#else:
# self.scanpath_attributes = {}

if scanpath_fixation_attributes is not None:
assert isinstance(scanpath_fixation_attributes, dict)
self.scanpath_fixation_attributes = {}
for key, value in scanpath_fixation_attributes.items():
self.scanpath_fixation_attributes[key] = self._as_variable_length_scanpath_array(value)
else:
self.scanpath_fixation_attributes = {}
#if scanpath_fixation_attributes is not None:
# assert isinstance(scanpath_fixation_attributes, dict)
# self.scanpath_fixation_attributes = {}
# for key, value in scanpath_fixation_attributes.items():
# self.scanpath_fixation_attributes[key] = self._as_variable_length_scanpath_array(value)
#else:
# self.scanpath_fixation_attributes = {}

self.scanpath_attribute_mapping = scanpath_attribute_mapping or {}
# self.scanpath_attribute_mapping = self.scanpaths.attribute_mapping


if attributes is None:
Expand All @@ -562,24 +583,27 @@ def __init__(self, train_xs, train_ys, train_ts, train_ns, train_subjects, scanp

self.auto_attributes = []

for attribute_name, value in self.scanpath_attributes.items():
new_attribute_name = self.scanpath_attribute_mapping.get(attribute_name, attribute_name)
for attribute_name, value in self.scanpaths.scanpath_attributes.items():
if attribute_name == 'subject':
continue
new_attribute_name = self.scanpaths.attribute_mapping.get(attribute_name, attribute_name)
if new_attribute_name in attributes:
raise ValueError("attribute name clash: {new_attribute_name}".format(new_attribute_name=new_attribute_name))
attribute_shape = [] if not value.any() else np.asarray(value[0]).shape
attributes[new_attribute_name] = np.empty([N_fixations] + list(attribute_shape), dtype=value.dtype)
self.auto_attributes.append(new_attribute_name)

out_index = 0
for train_index in range(len(self.train_xs)):
fix_length = (1 - np.isnan(self.train_xs[train_index])).sum()
for fix_index in range(fix_length):
attributes[new_attribute_name][out_index] = self.scanpath_attributes[attribute_name][train_index]
for train_index in range(len(self.scanpaths)):
for fix_index in range(self.scanpaths.lengths[train_index]):
attributes[new_attribute_name][out_index] = self.scanpaths.scanpath_attributes[attribute_name][train_index]
out_index += 1


for attribute_name, value in self.scanpath_fixation_attributes.items():
new_attribute_name = self.scanpath_attribute_mapping.get(attribute_name, attribute_name)
for attribute_name, value in self.scanpaths.fixation_attributes.items():
if attribute_name == 'ts':
continue
new_attribute_name = self.scanpaths.attribute_mapping.get(attribute_name, attribute_name)
if new_attribute_name in attributes:
raise ValueError("attribute name clash: {new_attribute_name}".format(new_attribute_name=new_attribute_name))
attributes[new_attribute_name] = np.empty(N_fixations)
Expand All @@ -588,15 +612,14 @@ def __init__(self, train_xs, train_ys, train_ts, train_ns, train_subjects, scanp
hist_attribute_name = new_attribute_name + '_hist'
if hist_attribute_name in attributes:
raise ValueError("attribute name clash: {hist_attribute_name}".format(hist_attribute_name=hist_attribute_name))
attributes[hist_attribute_name] = np.full_like(self.x_hist._data, fill_value=np.nan)
attributes[hist_attribute_name] = np.full((N_fixations, max_history_length), fill_value=np.nan)
self.auto_attributes.append(hist_attribute_name)

out_index = 0
for train_index in range(len(self.train_xs)):
fix_length = (1 - np.isnan(self.train_xs[train_index])).sum()
for fix_index in range(fix_length):
attributes[new_attribute_name][out_index] = self.scanpath_fixation_attributes[attribute_name][train_index, fix_index]
attributes[hist_attribute_name][out_index][:fix_index] = self.scanpath_fixation_attributes[attribute_name][train_index, :fix_index]
for train_index in range(len(self.scanpaths)):
for fix_index in range(self.scanpaths.lengths[train_index]):
attributes[new_attribute_name][out_index] = self.scanpaths.fixation_attributes[attribute_name][train_index, fix_index]
attributes[hist_attribute_name][out_index][:fix_index] = self.scanpaths.fixation_attributes[attribute_name][train_index, :fix_index]
out_index += 1

attributes[hist_attribute_name] = VariableLengthArray(attributes[hist_attribute_name], self.lengths)
Expand All @@ -606,6 +629,7 @@ def __init__(self, train_xs, train_ys, train_ts, train_ns, train_subjects, scanp
for key, value in attributes.items():
assert key != 'subjects'
assert key != 'scanpath_index'
assert key != 't'
assert len(value) == len(self.x)
self.__attributes__.append(key)
if not isinstance(value, VariableLengthArray):
Expand All @@ -614,19 +638,61 @@ def __init__(self, train_xs, train_ys, train_ts, train_ns, train_subjects, scanp

self.full_nonfixations = None

def _check_train_lengths(self, other: VariableLengthArray):
if not len(self.train_lengths) == len(other):
raise ValueError("Length of scanpaths has to match")
if not np.all(self.train_lengths == other.lengths):
raise ValueError("Lengths of scanpaths have to match")
# def _check_train_lengths(self, other: VariableLengthArray):
# if not len(self.train_lengths) == len(other):
# raise ValueError("Length of scanpaths has to match")
# if not np.all(self.train_lengths == other.lengths):
# raise ValueError("Lengths of scanpaths have to match")

def _as_variable_length_scanpath_array(self, data: Union[np.ndarray, VariableLengthArray]) -> VariableLengthArray:
if not isinstance(data, VariableLengthArray):
data = VariableLengthArray(data, self.train_lengths)
# def _as_variable_length_scanpath_array(self, data: Union[np.ndarray, VariableLengthArray]) -> VariableLengthArray:
# if not isinstance(data, VariableLengthArray):
# data = VariableLengthArray(data, self.train_lengths)

self._check_train_lengths(data)
# self._check_train_lengths(data)

return data
# return data

@property
def train_xs(self) -> VariableLengthArray:
    """X coordinates of the fixations, indexed by scanpath.

    Getter-only property that forwards to ``self.scanpaths.xs``.
    """
    scanpaths = self.scanpaths
    return scanpaths.xs

@property
def train_ys(self) -> VariableLengthArray:
    """Y coordinates of the fixations, indexed by scanpath.

    Getter-only property that forwards to ``self.scanpaths.ys``.
    """
    scanpaths = self.scanpaths
    return scanpaths.ys

@property
def train_ts(self) -> VariableLengthArray:
    """Fixation times, indexed by scanpath.

    Getter-only property that forwards to ``self.scanpaths.ts``.
    """
    scanpaths = self.scanpaths
    return scanpaths.ts

@property
def train_ns(self) -> np.ndarray:
    """Stimulus index of each scanpath.

    Getter-only property that forwards to ``self.scanpaths.n``.
    """
    scanpaths = self.scanpaths
    return scanpaths.n

@property
def train_subjects(self) -> np.ndarray:
    """Subject of each scanpath.

    Getter-only property that forwards to ``self.scanpaths.subject``.
    The subject is stored as a per-scanpath attribute (one value per
    scanpath), so the return annotation is an array rather than a
    ``VariableLengthArray``.
    """
    # NOTE(review): assumes Scanpaths exposes the 'subject' scanpath
    # attribute as `.subject` -- confirm against the Scanpaths class.
    return self.scanpaths.subject

@property
def train_lengths(self) -> np.ndarray:
    """Number of fixations in each scanpath.

    Getter-only property that forwards to ``self.scanpaths.lengths``.
    """
    scanpaths = self.scanpaths
    return scanpaths.lengths

@property
def scanpath_attributes(self) -> Dict[str, np.ndarray]:
    """Per-scanpath attributes, without the ``'subject'`` entry.

    A new dict is built on every access, so adding or deleting keys on
    the returned dict does not change ``self.scanpaths.scanpath_attributes``.
    """
    visible = {}
    for name, values in self.scanpaths.scanpath_attributes.items():
        if name == 'subject':
            # subjects are exposed separately and hidden from this view
            continue
        visible[name] = values
    return visible

@property
def scanpath_fixation_attributes(self) -> Dict[str, VariableLengthArray]:
    """Per-fixation scanpath attributes, without the ``'ts'`` entry.

    A new dict is built on every access, so adding or deleting keys on
    the returned dict does not change ``self.scanpaths.fixation_attributes``.
    """
    visible = {}
    for name, values in self.scanpaths.fixation_attributes.items():
        if name == 'ts':
            # fixation times are exposed separately and hidden from this view
            continue
        visible[name] = values
    return visible

@property
def scanpath_attribute_mapping(self) -> Dict[str, str]:
    """Attribute-name mapping of the scanpaths, without the ``'ts'`` key.

    A new dict is built on every access, so adding or deleting keys on
    the returned dict does not change ``self.scanpaths.attribute_mapping``.
    """
    visible = {}
    for source_name, target_name in self.scanpaths.attribute_mapping.items():
        if source_name == 'ts':
            # the 'ts' entry is hidden from this view
            continue
        visible[source_name] = target_name
    return visible

@classmethod
def concatenate(cls, fixation_trains):
Expand Down Expand Up @@ -734,33 +800,29 @@ def filter_fixation_trains(self, indices):
"""
Create new fixations object which contains only the fixation trains indicated.
"""
train_xs = self.train_xs[indices]
train_ys = self.train_ys[indices]
train_ts = self.train_ts[indices]
train_ns = self.train_ns[indices]
train_subjects = self.train_subjects[indices]
scanpath_attributes = {key: np.array(value)[indices] for key, value in self.scanpath_attributes.items()}
scanpath_fixation_attributes = {key: value[indices] for key, value in self.scanpath_fixation_attributes.items()}

scanpath_indices = np.arange(len(self.train_xs), dtype=int)[indices]

filtered_scanpaths = self.scanpaths[indices]

scanpath_indices = np.arange(len(self.scanpaths), dtype=int)[indices]
fixation_indices = np.in1d(self.scanpath_index, scanpath_indices)

attributes = {
attribute_name: getattr(self, attribute_name)[fixation_indices] for attribute_name in self.__attributes__ if attribute_name not in ['subjects', 'scanpath_index'] + self.auto_attributes
}

return type(self)(
train_xs,
train_ys,
train_ts,
train_ns,
train_subjects,
train_xs=filtered_scanpaths.xs,
train_ys=filtered_scanpaths.ys,
train_ts=None, # filtered_scanpaths.ts,
train_ns=filtered_scanpaths.n,
train_subjects=None, # filtered_scanpaths.subject,
scanpath_attributes=filtered_scanpaths.scanpath_attributes,
scanpath_fixation_attributes=filtered_scanpaths.fixation_attributes,
scanpath_attribute_mapping=dict(filtered_scanpaths.attribute_mapping),
attributes=attributes,
scanpath_attributes=scanpath_attributes,
scanpath_fixation_attributes=scanpath_fixation_attributes,
scanpath_attribute_mapping=dict(self.scanpath_attribute_mapping)
)


def fixation_trains(self):
"""Yield for every fixation train of the dataset:
xs, ys, ts, n, subject
Expand Down Expand Up @@ -1668,16 +1730,26 @@ def create_subset(stimuli, fixations, stimuli_indices):

new_stimuli = stimuli[stimuli_indices]
if isinstance(fixations, FixationTrains):
fix_inds = np.in1d(fixations.train_ns, stimuli_indices)
new_fixations = fixations.filter_fixation_trains(fix_inds)
fix_inds = np.in1d(fixations.scanpaths.n, stimuli_indices)

index_list = list(stimuli_indices)
new_pos = {i: index_list.index(i) for i in index_list}

new_fixation_train_ns = [new_pos[i] for i in new_fixations.train_ns]
new_fixations.train_ns = np.array(new_fixation_train_ns)
new_fixation_ns = [new_pos[i] for i in new_fixations.n]
new_fixations.n = np.array(new_fixation_ns)
new_image_indices = [new_pos[i] for i in fixations.scanpaths.n[fix_inds]]

new_scanpaths = fixations.scanpaths[fix_inds]

new_fixations = FixationTrains(
train_xs=new_scanpaths.xs,
train_ys=new_scanpaths.ys,
train_ts=None, # new_scanpaths.fixation_attributes['ts'],
train_ns=np.array(new_image_indices),
train_subjects=None, # new_scanpaths.scanpath_attributes['subject'],
scanpath_attributes=new_scanpaths.scanpath_attributes,
scanpath_fixation_attributes=new_scanpaths.fixation_attributes,
scanpath_attribute_mapping=new_scanpaths.attribute_mapping
)

else:
fix_inds = np.in1d(fixations.n, stimuli_indices)
new_fixations = fixations[fix_inds]
Expand Down
2 changes: 1 addition & 1 deletion tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,7 +746,7 @@ def test_concatenate_fixation_trains(fixation_trains):
def test_concatenate_scanpaths(fixation_trains):
fixation_trains2 = fixation_trains.copy()

del fixation_trains2.scanpath_attributes['task']
del fixation_trains2.scanpaths.scanpath_attributes['task']
delattr(fixation_trains2, 'task')
fixation_trains2.auto_attributes.remove('task')
fixation_trains2.__attributes__.remove('task')
Expand Down

0 comments on commit 7f9a76a

Please sign in to comment.