Skip to content

Commit

Permalink
Bugfix: new torch versions do not support sparse tensors in default_c…
Browse files Browse the repository at this point in the history
…ollate

Signed-off-by: Matthias Kümmerer <[email protected]>
  • Loading branch information
matthias-k committed Jun 20, 2024
1 parent 74f4ee3 commit 7b1b094
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 2 deletions.
3 changes: 2 additions & 1 deletion pysaliency/saliency_map_conversion_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from .optpy import minimize
from .saliency_map_models import SaliencyMapModel
from .torch_utils import GaussianFilterNd, Nonlinearity, zero_grad, log_likelihood
from .torch_datasets import ImageDataset, ImageDatasetSampler, FixationMaskTransform
from .torch_datasets import ImageDataset, ImageDatasetSampler, FixationMaskTransform, collate_fn


class CenterBias(nn.Module):
Expand Down Expand Up @@ -275,6 +275,7 @@ def _optimize_saliency_map_conversion_over_multiple_models_and_datasets(
batch_sampler=ImageDatasetSampler(dataset, batch_size=batch_size, shuffle=False),
pin_memory=False,
num_workers=0, # doesn't work for sparse tensors yet. Might work soon.
collate_fn=collate_fn,
)

saliency_map_processing = SaliencyMapProcessing(
Expand Down
12 changes: 12 additions & 0 deletions pysaliency/torch_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np
import torch
from boltons.iterutils import chunked
from torch.utils.data.dataloader import default_collate
from tqdm import tqdm

from .models import Model
Expand Down Expand Up @@ -167,3 +168,14 @@ def __iter__(self):

def __len__(self):
return int(self.ratio_used * len(self.batches))


# we need to extend the defaut collate fn to handle sparse coo tensors
def collate_fn(batch):
result = {}
for key in batch[0]:
if isinstance(batch[0][key], torch.sparse.Tensor):
result[key] = torch.stack([item[key] for item in batch], 0)
else:
result[key] = default_collate([item[key] for item in batch])
return result
3 changes: 2 additions & 1 deletion tests/test_torch_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
SaliencyMapModelFromDirectory,
UniformModel
)
from pysaliency.torch_datasets import ImageDataset, ImageDatasetSampler, FixationMaskTransform
from pysaliency.torch_datasets import ImageDataset, ImageDatasetSampler, FixationMaskTransform, collate_fn
import torch


Expand Down Expand Up @@ -71,6 +71,7 @@ def test_dataset(stimuli, fixations, png_saliency_map_model):
batch_sampler=ImageDatasetSampler(dataset, batch_size=4, shuffle=False),
pin_memory=False,
num_workers=0, # doesn't work for sparse tensors yet. Might work soon.
collate_fn=collate_fn,
)

count = 0
Expand Down

0 comments on commit 7b1b094

Please sign in to comment.