Skip to content

Commit

Permalink
make sure deepgaze inputs are 8bit RGB
Browse files Browse the repository at this point in the history
Signed-off-by: Matthias Kümmerer <[email protected]>
  • Loading branch information
matthias-k committed Apr 26, 2024
1 parent 131c9da commit 9584f1f
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 7 deletions.
8 changes: 3 additions & 5 deletions pysaliency/external_models/deepgaze.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import numpy as np
import torch

from ..models import Model, ScanpathModel
from ..datasets import as_stimulus

from ..models import Model
from ..utils import as_rgb


class StaticDeepGazeModel(Model):
Expand All @@ -22,9 +22,7 @@ def _log_density(self, stimulus):
stimulus = as_stimulus(stimulus)
stimulus_data = stimulus.stimulus_data

if stimulus_data.ndim == 2:
stimulus_data = np.dstack((stimulus_data, stimulus_data, stimulus_data))

stimulus_data = as_rgb(stimulus_data)
stimulus_data = stimulus_data.transpose(2, 0, 1)

centerbias_data = self.centerbias_model.log_density(stimulus)
Expand Down
18 changes: 17 additions & 1 deletion pysaliency/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import numpy as np
import requests
from boltons.cacheutils import LRU
from PIL import Image
from scipy.interpolate import griddata
from tqdm import tqdm

Expand Down Expand Up @@ -509,4 +510,19 @@ def iterator_chunks(iterable, chunk_size=10):

counter = count()
for _, g in groupby(iterable, lambda _: next(counter) // chunk_size):
yield g
yield g


def as_rgb(image: np.ndarray):
"""makes sure that image data is in 8 bit RGB format"""

pil_image = Image.fromarray(image)

if pil_image.mode == 'I;16':
# convert 16 bit images to 8 bit
array = np.uint8(np.array(pil_image) / 256)
pil_image = Image.fromarray(array)

rgb_image = pil_image.convert('RGB')

return np.array(rgb_image)
2 changes: 1 addition & 1 deletion tests/external_models/test_deepgaze.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import os

import numpy as np
import pytest

import pysaliency
from pysaliency.external_models.deepgaze import DeepGazeI, DeepGazeIIE

import pytest

@pytest.fixture(scope='module')
def color_stimulus():
Expand Down

0 comments on commit 9584f1f

Please sign in to comment.