diff --git a/pysaliency/external_models/deepgaze.py b/pysaliency/external_models/deepgaze.py index f12c2f1..f821923 100644 --- a/pysaliency/external_models/deepgaze.py +++ b/pysaliency/external_models/deepgaze.py @@ -1,9 +1,9 @@ import numpy as np import torch -from ..models import Model, ScanpathModel from ..datasets import as_stimulus - +from ..models import Model +from ..utils import as_rgb class StaticDeepGazeModel(Model): @@ -22,9 +22,7 @@ def _log_density(self, stimulus): stimulus = as_stimulus(stimulus) stimulus_data = stimulus.stimulus_data - if stimulus_data.ndim == 2: - stimulus_data = np.dstack((stimulus_data, stimulus_data, stimulus_data)) - + stimulus_data = as_rgb(stimulus_data) stimulus_data = stimulus_data.transpose(2, 0, 1) centerbias_data = self.centerbias_model.log_density(stimulus) diff --git a/pysaliency/utils/__init__.py b/pysaliency/utils/__init__.py index 316eb7b..2046b22 100644 --- a/pysaliency/utils/__init__.py +++ b/pysaliency/utils/__init__.py @@ -19,6 +19,7 @@ import numpy as np import requests from boltons.cacheutils import LRU +from PIL import Image from scipy.interpolate import griddata from tqdm import tqdm @@ -509,4 +510,19 @@ def iterator_chunks(iterable, chunk_size=10): counter = count() for _, g in groupby(iterable, lambda _: next(counter) // chunk_size): - yield g \ No newline at end of file + yield g + + +def as_rgb(image: np.ndarray): + """makes sure that image data is in 8 bit RGB format""" + + pil_image = Image.fromarray(image) + + if pil_image.mode == 'I;16': + # convert 16 bit images to 8 bit + array = np.uint8(np.array(pil_image) / 256) + pil_image = Image.fromarray(array) + + rgb_image = pil_image.convert('RGB') + + return np.array(rgb_image) \ No newline at end of file diff --git a/tests/external_models/test_deepgaze.py b/tests/external_models/test_deepgaze.py index 72faddf..fd97eaa 100644 --- a/tests/external_models/test_deepgaze.py +++ b/tests/external_models/test_deepgaze.py @@ -1,11 +1,11 @@ import os import numpy as np +import pytest import pysaliency from pysaliency.external_models.deepgaze import DeepGazeI, DeepGazeIIE -import pytest @pytest.fixture(scope='module') def color_stimulus():