Adapt FixationTrain attributes
Signed-off-by: Matthias Kümmerer <[email protected]>
matthias-k committed Mar 29, 2024
1 parent a2bb711 commit 6fb8c06
Showing 5 changed files with 279 additions and 111 deletions.
211 changes: 160 additions & 51 deletions pysaliency/datasets.py
@@ -5,13 +5,13 @@
from collections.abc import Sequence
from functools import wraps
from hashlib import sha1
from typing import Dict, Optional, Union
from typing import Dict, List, Optional, Union
from weakref import WeakValueDictionary

import numpy as np
from boltons.cacheutils import cached

from .utils.variable_length_array import VariableLengthArray
from .utils.variable_length_array import VariableLengthArray, concatenate_variable_length_arrays

try:
from imageio.v3 import imread
@@ -138,33 +138,75 @@ class Fixations(object):
"""
__attributes__ = ['subjects']

def __init__(self, x, y, t, x_hist, y_hist, t_hist, n, subjects, attributes=None):
x = np.asarray(x)
y = np.asarray(y)
t = np.asarray(t)
n = np.asarray(n)
x_hist = np.asarray(x_hist)
y_hist = np.asarray(y_hist)
t_hist = np.asarray(t_hist)
subjects = np.asarray(subjects)

self.x = x
self.y = y
self.t = t
def __init__(self,
x: Union[List, np.ndarray],
y: Union[List, np.ndarray],
t: Union[List, np.ndarray],
x_hist: Union[List, VariableLengthArray],
y_hist: Union[List, VariableLengthArray],
t_hist: Union[List, VariableLengthArray],
n: Union[List, np.ndarray],
subjects: Optional[Union[List, np.ndarray]] = None,
attributes: Optional[Dict[str, Union[np.ndarray, VariableLengthArray]]] = None):

self.x = np.asarray(x)
self.y = np.asarray(y)
self.t = np.asarray(t)
self.n = np.asarray(n)

# would be nice, but is not yet supported; for now, a VariableLengthArray can simply be passed instead
# if isinstance(x_hist, list):
# x_hist = VariableLengthArray(x_hist)
# self.lengths = x_hist.lengths
if isinstance(x_hist, (list, np.ndarray)):
x_hist = np.array(x_hist)
self.lengths = (1 - np.isnan(x_hist)).sum(axis=-1)
x_hist = VariableLengthArray(x_hist, lengths=self.lengths)
elif isinstance(x_hist, VariableLengthArray):
self.lengths = x_hist.lengths


y_hist = self._as_variable_length_array(y_hist)
t_hist = self._as_variable_length_array(t_hist)

if subjects is not None:
subjects = np.asarray(subjects)

self.x_hist = x_hist
self.y_hist = y_hist
self.t_hist = t_hist
self.n = n
self.subjects = subjects
self.lengths = (1 - np.isnan(self.x_hist)).sum(axis=-1)

if not len(self.x) == len(self.y) == len(self.t) == len(self.x_hist) == len(self.y_hist) == len(self.t_hist) == len(self.n):
raise ValueError("Lengths of fixations have to match")
if self.subjects is not None and not len(self.x) == len(self.subjects):
raise ValueError("Length of subjects has to match number of fixations")

if attributes is not None:
self.__attributes__ = list(self.__attributes__)
for name, value in attributes.items():
if name not in self.__attributes__:
self.__attributes__.append(name)
if not len(value) == len(self.x):
raise ValueError(f"Length of attribute '{name}' has to match number of fixations")
setattr(self, name, value)


def _check_lengths(self, other: VariableLengthArray):
if not len(self) == len(other):
raise ValueError("Length of scanpaths has to match")
if not np.all(self.lengths == other.lengths):
raise ValueError("Lengths of scanpaths have to match")

def _as_variable_length_array(self, data: Union[np.ndarray, VariableLengthArray]) -> VariableLengthArray:
if not isinstance(data, VariableLengthArray):
data = VariableLengthArray(data, self.lengths)

self._check_lengths(data)

return data
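
Note: a minimal sketch of the padded-array semantics the new constructor relies on (assuming VariableLengthArray wraps a NaN-padded 2D array together with per-row lengths; the values are illustrative):

    import numpy as np

    from pysaliency.utils.variable_length_array import VariableLengthArray

    # Two scanpath histories of lengths 2 and 3, NaN-padded to a common width.
    x_hist = np.array([
        [0.0, 1.0, np.nan],
        [2.0, 3.0, 4.0],
    ])

    # Count non-NaN entries per row, mirroring the constructor's fallback above.
    lengths = (1 - np.isnan(x_hist)).sum(axis=-1)  # -> array([2, 3])

    hist = VariableLengthArray(x_hist, lengths=lengths)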

@classmethod
def create_without_history(cls, x, y, n, subjects=None):
""" Create new fixation object from fixation data without time and optionally
@@ -256,6 +298,7 @@ def filter(self, inds):

def filter_array(name):
kwargs[name] = getattr(self, name)[inds].copy()

for name in ['x', 'y', 't', 'x_hist', 'y_hist', 't_hist', 'n']:
filter_array(name)
for name in self.__attributes__:
@@ -328,7 +371,7 @@ def subject_count(self):
def copy(self):
cfix = Fixations(self.x.copy(), self.y.copy(), self.t.copy(),
self.x_hist.copy(), self.y_hist.copy(), self.t_hist.copy(),
self.n.copy(), self.subjects.copy())
self.n.copy(), self.subjects.copy() if self.subjects is not None else None)
cfix.__attributes__ = list(self.__attributes__)
for name in self.__attributes__:
setattr(cfix, name, getattr(self, name).copy())
@@ -350,26 +393,34 @@ def to_hdf5(self, target):
"""

target.attrs['type'] = np.string_('Fixations')
target.attrs['version'] = np.string_('1.0')
target.attrs['version'] = np.string_('1.1')

for attribute in ['x', 'y', 't', 'x_hist', 'y_hist', 't_hist', 'n'] + self.__attributes__:
target.create_dataset(attribute, data=getattr(self, attribute))
variable_length_arrays = []

for attribute in ['x', 'y', 't', 'x_hist', 'y_hist', 't_hist', 'n', 'lengths'] + self.__attributes__:
data = getattr(self, attribute)
if isinstance(data, VariableLengthArray):
variable_length_arrays.append(attribute)
data = data._data
target.create_dataset(attribute, data=data)

#target.create_dataset('__attributes__', data=self.__attributes__)
target.attrs['__attributes__'] = np.string_(json.dumps(self.__attributes__))
target.attrs['__variable_length_arrays__'] = np.string_(json.dumps(sorted(variable_length_arrays)))

@classmethod
@hdf5_wrapper(mode='r')
def read_hdf5(cls, source):
""" Read fixations from hdf5 file or hdf5 group """

# TODO: rewrite to use constructor instead of manipulating the object directly

data_type = decode_string(source.attrs['type'])
data_version = decode_string(source.attrs['version'])

if data_type != 'Fixations':
raise ValueError("Invalid type! Expected 'Fixations', got", data_type)

if data_version != '1.0':
if data_version not in ['1.0', '1.1']:
raise ValueError("Invalid version! Expected '1.0', got", data_version)

data = {key: source[key][...] for key in ['x', 'y', 't', 'x_hist', 'y_hist', 't_hist', 'n', 'subjects']}
Expand All @@ -379,10 +430,27 @@ def read_hdf5(cls, source):
if not isinstance(json_attributes, str):
json_attributes = json_attributes.decode('utf8')
__attributes__ = json.loads(json_attributes)
fixations.__attributes__ == list(__attributes__)
fixations.__attributes__ = list(__attributes__)

if data_version == '1.1':
lengths = source['lengths'][...]

json_variable_length_arrays = source.attrs['__variable_length_arrays__']
if not isinstance(json_variable_length_arrays, str):
json_variable_length_arrays = json_variable_length_arrays.decode('utf8')
variable_length_arrays = json.loads(json_variable_length_arrays)

else:
lengths = fixations.lengths
variable_length_arrays = ['x_hist', 'y_hist', 't_hist'] + [key for key in __attributes__ if key.endswith('_hist')]

for key in __attributes__:
setattr(fixations, key, source[key][...])
data = source[key][...]

if key in variable_length_arrays:
data = VariableLengthArray(data, lengths)

setattr(fixations, key, data)

return fixations
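
As a usage sketch, the new format round-trips variable-length histories; this assumes (as the hdf5_wrapper decorators suggest) that to_hdf5 and read_hdf5 accept an open h5py file or group, and that fixations is an existing Fixations object:

    import h5py

    from pysaliency.datasets import Fixations
    from pysaliency.utils.variable_length_array import VariableLengthArray

    with h5py.File('fixations.hdf5', 'w') as f:
        fixations.to_hdf5(f.create_group('fixations'))

    with h5py.File('fixations.hdf5', 'r') as f:
        restored = Fixations.read_hdf5(f['fixations'])

    # Version 1.1 files store the padded data plus explicit lengths, so the
    # histories come back as VariableLengthArray instances.
    assert isinstance(restored.x_hist, VariableLengthArray)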

@@ -420,25 +488,6 @@ def __init__(self, train_xs, train_ys, train_ts, train_ns, train_subjects, scanp
N_trains = self.train_xs.shape[0] * self.train_xs.shape[1] - np.isnan(self.train_xs).sum()
max_length_trains = self.train_xs.shape[1]

if scanpath_attributes is not None:
assert isinstance(scanpath_attributes, dict)
self.scanpath_attributes = {key: np.array(value) for key, value in scanpath_attributes.items()}
for key, value in self.scanpath_attributes.items():
assert len(value) == len(self.train_xs)
else:
self.scanpath_attributes = {}

if scanpath_fixation_attributes is not None:
assert isinstance(scanpath_fixation_attributes, dict)
self.scanpath_fixation_attributes = {key: np.array(value) for key, value in scanpath_fixation_attributes.items()}
for key, value in self.scanpath_fixation_attributes.items():
assert len(value) == len(self.train_xs)
else:
self.scanpath_fixation_attributes = {}

self.scanpath_attribute_mapping = scanpath_attribute_mapping or {}



# Create conditional fixations
self.x = np.empty(N_trains)
@@ -475,6 +524,29 @@ def __init__(self, train_xs, train_ys, train_ts, train_ns, train_subjects, scanp
self.t_hist[out_index][:fix_index] = self.train_ts[train_index][:fix_index]
out_index += 1

# TODO: this should become irrelevant once FixationTrains is also completely upgraded to VariableLengthArrays
self.x_hist = VariableLengthArray(self.x_hist, self.lengths)
self.y_hist = VariableLengthArray(self.y_hist, self.lengths)
self.t_hist = VariableLengthArray(self.t_hist, self.lengths)

if scanpath_attributes is not None:
assert isinstance(scanpath_attributes, dict)
self.scanpath_attributes = {key: np.array(value) for key, value in scanpath_attributes.items()}
for key, value in self.scanpath_attributes.items():
assert len(value) == len(self.train_xs)
else:
self.scanpath_attributes = {}

if scanpath_fixation_attributes is not None:
assert isinstance(scanpath_fixation_attributes, dict)
self.scanpath_fixation_attributes = {}
for key, value in scanpath_fixation_attributes.items():
self.scanpath_fixation_attributes[key] = self._as_variable_length_scanpath_array(value)
else:
self.scanpath_fixation_attributes = {}

self.scanpath_attribute_mapping = scanpath_attribute_mapping or {}


if attributes is None:
attributes = {}
@@ -509,7 +581,7 @@ def __init__(self, train_xs, train_ys, train_ts, train_ns, train_subjects, scanp
hist_attribute_name = new_attribute_name + '_hist'
if hist_attribute_name in attributes:
raise ValueError("attribute name clash: {hist_attribute_name}".format(hist_attribute_name=hist_attribute_name))
attributes[hist_attribute_name] = np.full_like(self.x_hist, fill_value=np.nan)
attributes[hist_attribute_name] = np.full_like(self.x_hist._data, fill_value=np.nan)
self.auto_attributes.append(hist_attribute_name)

out_index = 0
@@ -520,18 +592,34 @@ def __init__(self, train_xs, train_ys, train_ts, train_ns, train_subjects, scanp
attributes[hist_attribute_name][out_index][:fix_index] = self.scanpath_fixation_attributes[attribute_name][train_index, :fix_index]
out_index += 1

attributes[hist_attribute_name] = VariableLengthArray(attributes[hist_attribute_name], self.lengths)

if attributes:
self.__attributes__ = list(self.__attributes__)
for key, value in attributes.items():
assert key != 'subjects'
assert key != 'scanpath_index'
assert len(value) == len(self.x)
self.__attributes__.append(key)
value = np.array(value)
if not isinstance(value, VariableLengthArray):
value = np.array(value)
setattr(self, key, value)

self.full_nonfixations = None
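
To illustrate the auto-attribute mechanism above: each per-fixation scanpath attribute is flattened to one value per conditional fixation, and its history becomes a VariableLengthArray. A sketch, assuming the existing from_fixation_trains constructor with these keyword names and an illustrative 'durations' -> 'duration' mapping:

    from pysaliency.datasets import FixationTrains

    trains = FixationTrains.from_fixation_trains(
        xs=[[0, 10, 20]], ys=[[5, 15, 25]], ts=[[0, 100, 200]],
        ns=[0], subjects=[0],
        scanpath_fixation_attributes={'durations': [[110, 120, 130]]},
        scanpath_attribute_mapping={'durations': 'duration'},
    )

    # One 'duration' per conditional fixation, e.g. [110, 120, 130], ...
    print(trains.duration)
    # ... plus an auto-generated 'duration_hist' holding the preceding
    # durations of each scanpath, e.g. [110, 120] for the third fixation.
    print(trains.duration_hist[2])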

def _check_train_lengths(self, other: VariableLengthArray):
if not len(self.train_xs) == len(other):
raise ValueError("Length of scanpaths has to match")
if not np.all(self.train_lengths == other.lengths):
raise ValueError("Lengths of scanpaths have to match")

def _as_variable_length_scanpath_array(self, data: Union[np.ndarray, VariableLengthArray]) -> VariableLengthArray:
if not isinstance(data, VariableLengthArray):
data = VariableLengthArray(data, self.train_lengths)

self._check_train_lengths(data)

return data

@classmethod
def concatenate(cls, fixation_trains):
@@ -645,7 +733,7 @@ def filter_fixation_trains(self, indices):
train_ns = self.train_ns[indices]
train_subjects = self.train_subjects[indices]
scanpath_attributes = {key: np.array(value)[indices] for key, value in self.scanpath_attributes.items()}
scanpath_fixation_attributes = {key: np.array(value)[indices] for key, value in self.scanpath_fixation_attributes.items()}
scanpath_fixation_attributes = {key: value[indices] for key, value in self.scanpath_fixation_attributes.items()}

scanpath_indices = np.arange(len(self.train_xs), dtype=int)[indices]
fixation_indices = np.in1d(self.scanpath_index, scanpath_indices)
@@ -907,12 +995,19 @@ def to_hdf5(self, target):
"""

target.attrs['type'] = np.string_('FixationTrains')
target.attrs['version'] = np.string_('1.2')
target.attrs['version'] = np.string_('1.3')

variable_length_arrays = []

for attribute in ['train_xs', 'train_ys', 'train_ts', 'train_ns', 'train_subjects'] + self.__attributes__:
for attribute in ['train_xs', 'train_ys', 'train_ts', 'train_ns', 'train_subjects', 'train_lengths'] + self.__attributes__:
if attribute in ['subjects', 'scanpath_index'] + self.auto_attributes:
continue
target.create_dataset(attribute, data=getattr(self, attribute))

data = getattr(self, attribute)
if isinstance(data, VariableLengthArray):
variable_length_arrays.append(attribute)
data = data._data
target.create_dataset(attribute, data=data)

saved_attributes = [attribute_name for attribute_name in self.__attributes__ if attribute_name not in self.auto_attributes]
target.attrs['__attributes__'] = np.string_(json.dumps(saved_attributes))
@@ -926,7 +1021,7 @@ def to_hdf5(self, target):

scanpath_fixation_attributes_group = target.create_group('scanpath_fixation_attributes')
for attribute_name, attribute_value in self.scanpath_fixation_attributes.items():
scanpath_fixation_attributes_group.create_dataset(attribute_name, data=attribute_value)
scanpath_fixation_attributes_group.create_dataset(attribute_name, data=attribute_value._data)
scanpath_fixation_attributes_group.attrs['__attributes__'] = np.string_(json.dumps(sorted(self.scanpath_fixation_attributes.keys())))


Expand All @@ -941,7 +1036,7 @@ def read_hdf5(cls, source):
if data_type != 'FixationTrains':
raise ValueError("Invalid type! Expected 'FixationTrains', got", data_type)

valid_versions = ['1.0', '1.1', '1.2']
valid_versions = ['1.0', '1.1', '1.2', '1.3']
if data_version not in valid_versions:
raise ValueError("Invalid version! Expected one of {}, got {}".format(', '.join(valid_versions), data_version))

@@ -972,6 +1067,15 @@ def read_hdf5(cls, source):
data['scanpath_fixation_attributes'] = _load_attribute_dict_from_hdf5(source['scanpath_fixation_attributes'])
data['scanpath_attribute_mapping'] = json.loads(decode_string(source.attrs['scanpath_attribute_mapping']))

if data_version < '1.3':
train_lengths = np.array([len(remove_trailing_nans(data['train_xs'][i])) for i in range(len(data['train_xs']))])
else:
train_lengths = source['train_lengths'][...]

data['scanpath_fixation_attributes'] = {
key: VariableLengthArray(value, train_lengths) for key, value in data['scanpath_fixation_attributes'].items()
}

fixations = cls(**data)

return fixations
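
The pre-1.3 fallback recovers the scanpath lengths from the NaN padding itself. A sketch of the idea, with a hypothetical stand-in for pysaliency's remove_trailing_nans:

    import numpy as np

    def _remove_trailing_nans(row):
        # Hypothetical stand-in: drop the NaN padding at the end of a row.
        valid = np.flatnonzero(~np.isnan(row))
        return row[:valid[-1] + 1] if len(valid) else row[:0]

    row = np.array([0.0, 1.0, np.nan, np.nan])
    assert len(_remove_trailing_nans(row)) == 2  # recovered scanpath length
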
Expand Down Expand Up @@ -1597,6 +1701,11 @@ def concatenate_stimuli(stimuli):


def concatenate_attributes(attributes):
attributes = list(attributes)

if isinstance(attributes[0], VariableLengthArray):
return concatenate_variable_length_arrays(attributes)

attributes = [np.array(a) for a in attributes]
for a in attributes:
assert len(a.shape) == len(attributes[0].shape)
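
A sketch of the dispatch this adds, assuming concatenate_variable_length_arrays joins the rows of its inputs and preserves their per-row lengths (arrays are illustrative):

    import numpy as np

    from pysaliency.datasets import concatenate_attributes
    from pysaliency.utils.variable_length_array import VariableLengthArray

    a = VariableLengthArray(np.array([[0.0, np.nan]]), lengths=np.array([1]))
    b = VariableLengthArray(np.array([[1.0, 2.0]]), lengths=np.array([2]))

    # VariableLengthArray inputs now take the new code path and stay variable-length.
    combined = concatenate_attributes([a, b])
    assert list(combined.lengths) == [1, 2]
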
5 changes: 3 additions & 2 deletions pysaliency/torch_datasets.py
@@ -1,8 +1,8 @@
import random

from boltons.iterutils import chunked
import numpy as np
import torch
from boltons.iterutils import chunked
from tqdm import tqdm

from .models import Model
@@ -126,7 +126,8 @@ def __call__(self, item):
inds = np.array([y, x])
values = np.ones(len(y), dtype=int)

mask = torch.sparse.IntTensor(torch.tensor(inds), torch.tensor(values), shape)
# mask = torch.sparse.IntTensor(torch.tensor(inds), torch.tensor(values), shape)
mask = torch.sparse_coo_tensor(torch.tensor(inds), torch.tensor(values), shape, dtype=torch.int)
mask = mask.coalesce()

item['fixation_mask'] = mask
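
torch.sparse.IntTensor is part of PyTorch's deprecated legacy constructor API; torch.sparse_coo_tensor is the current public replacement. A standalone sketch of the same pattern (coordinates are illustrative):

    import numpy as np
    import torch

    # Fixations at (row, col) = (0, 1) and (2, 3) on a 4x4 grid.
    y = np.array([0, 2])
    x = np.array([1, 3])
    inds = np.array([y, x])
    values = np.ones(len(y), dtype=int)

    mask = torch.sparse_coo_tensor(torch.tensor(inds), torch.tensor(values), (4, 4), dtype=torch.int)
    mask = mask.coalesce()  # merge duplicate coordinates before use
    print(mask.to_dense())
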
(Diff not shown for the remaining 3 changed files.)
