diff --git a/pysaliency/external_datasets/cat2000.py b/pysaliency/external_datasets/cat2000.py index 9d4a550..7e405d1 100644 --- a/pysaliency/external_datasets/cat2000.py +++ b/pysaliency/external_datasets/cat2000.py @@ -178,7 +178,9 @@ def _get_cat2000_train(name, location): # Stimuli print('Creating stimuli') f = zipfile.ZipFile(os.path.join(temp_dir, 'trainSet.zip')) - f.extractall(temp_dir) + namelist = f.namelist() + namelist = filter_files(namelist, ['Output']) + f.extractall(temp_dir, namelist) stimuli_src_location = os.path.join(temp_dir, 'trainSet', 'Stimuli') stimuli_target_location = os.path.join(location, 'Stimuli') if location else None @@ -304,7 +306,9 @@ def _get_cat2000_train_v1_1(name, location): # Stimuli print('Creating stimuli') f = zipfile.ZipFile(os.path.join(temp_dir, 'trainSet.zip')) - f.extractall(temp_dir) + namelist = f.namelist() + namelist = filter_files(namelist, ['Output']) + f.extractall(temp_dir, namelist) stimuli_src_location = os.path.join(temp_dir, 'trainSet', 'Stimuli') stimuli_target_location = os.path.join(location, 'Stimuli') if location else None diff --git a/tests/test_external_datasets.py b/tests/test_external_datasets.py index fdf2726..58b1059 100644 --- a/tests/test_external_datasets.py +++ b/tests/test_external_datasets.py @@ -1,5 +1,6 @@ import numpy as np import pytest +from pathlib import Path from pytest import approx from scipy.stats import kurtosis, skew @@ -71,7 +72,7 @@ def test_toronto(location): @pytest.mark.download @pytest.mark.matlab @pytest.mark.skip_octave -def test_cat2000_train(location, matlab): +def test_cat2000_train_v1_0(location, matlab): real_location = _location(location) stimuli, fixations = pysaliency.external_datasets.get_cat2000_train(location=real_location) @@ -83,6 +84,8 @@ def test_cat2000_train(location, matlab): assert isinstance(stimuli, pysaliency.FileStimuli) assert location.join('CAT2000_train/stimuli.hdf5').check() assert location.join('CAT2000_train/fixations.hdf5').check() + assert not list ((Path(location) / 'CAT2000_train' / 'Stimuli').glob('**/Output')) + assert not list ((Path(location) / 'CAT2000_train' / 'Stimuli').glob('**/*_SaliencyMap.jpg')) assert len(stimuli.stimuli) == 2000 assert set(stimuli.sizes) == {(1080, 1920)} @@ -118,6 +121,59 @@ def test_cat2000_train(location, matlab): assert len(fixations) == len(pysaliency.datasets.remove_out_of_stimulus_fixations(stimuli, fixations)) +@pytest.mark.slow +@pytest.mark.download +@pytest.mark.matlab +@pytest.mark.skip_octave +def test_cat2000_train_v1_1(location, matlab): + real_location = _location(location) + + stimuli, fixations = pysaliency.external_datasets.get_cat2000_train(location=real_location, version='1.1') + + if location is None: + assert isinstance(stimuli, pysaliency.Stimuli) + assert not isinstance(stimuli, pysaliency.FileStimuli) + else: + assert isinstance(stimuli, pysaliency.FileStimuli) + assert location.join('CAT2000_train_v1.1/stimuli.hdf5').check() + assert location.join('CAT2000_train_v1.1/fixations.hdf5').check() + assert not list ((Path(location) / 'CAT2000_train_v1.1' / 'Stimuli').glob('**/Output')) + assert not list ((Path(location) / 'CAT2000_train_v1.1' / 'Stimuli').glob('**/*_SaliencyMap.jpg')) + + assert len(stimuli.stimuli) == 2000 + assert set(stimuli.sizes) == {(1080, 1920)} + assert set(stimuli.attributes.keys()) == {'category'} + assert np.all(np.array(stimuli.attributes['category'][0:100]) == 0) + assert np.all(np.array(stimuli.attributes['category'][100:200]) == 1) + + assert len(fixations.x) == 667804 + + assert np.mean(fixations.x) == approx(977.048229720098) + assert np.mean(fixations.y) == approx(535.7335899455527) + assert np.mean(fixations.t) == approx(10.888694886523592) + assert np.mean(fixations.lengths) == approx(9.888694886523592) + + assert np.std(fixations.x) == approx(265.7561897117776) + assert np.std(fixations.y) == approx(200.47021508760227) + assert np.std(fixations.t) == approx(6.8276447542371805) + assert np.std(fixations.lengths) == approx(6.8276447542371805) + + assert kurtosis(fixations.x) == approx(0.8314129075001575) + assert kurtosis(fixations.y) == approx(0.16001475266665466) + assert kurtosis(fixations.t) == approx(0.07131517526032427) + assert kurtosis(fixations.lengths) == approx(0.07131517526032427) + + assert skew(fixations.x) == approx(0.07615972876511597) + assert skew(fixations.y) == approx(0.2770231691322164) + assert skew(fixations.t) == approx(0.5813051491385639) + assert skew(fixations.lengths) == approx(0.5813051491385639) + + assert entropy(fixations.n) == approx(10.955097604631638) + assert (fixations.n == 0).sum() == 304 + + assert len(fixations) == len(pysaliency.datasets.remove_out_of_stimulus_fixations(stimuli, fixations)) + + @pytest.mark.slow @pytest.mark.download @pytest.mark.skip_octave @@ -132,6 +188,9 @@ def test_cat2000_test(location): else: assert isinstance(stimuli, pysaliency.FileStimuli) assert location.join('CAT2000_test/stimuli.hdf5').check() + assert not list ((Path(location) / 'CAT2000_test' / 'Stimuli').glob('**/Output')) + assert not list ((Path(location) / 'CAT2000_test' / 'Stimuli').glob('**/*_SaliencyMap.jpg')) + assert len(stimuli.stimuli) == 2000 assert set(stimuli.sizes) == {(1080, 1920)}