diff --git a/besskge/pipeline.py b/besskge/pipeline.py index 49e8507..4908bdf 100644 --- a/besskge/pipeline.py +++ b/besskge/pipeline.py @@ -271,7 +271,7 @@ def forward(self) -> Dict[str, Any]: if self.return_scores: scores.append(batch_scores_filt) if self.return_topk: - topk_ids.append(torch.topk(batch_scores_filt, k=self.k, dim=-1).indices) + topk_ids.append(torch.topk(batch_scores_filt.to(torch.float32), k=self.k, dim=-1).indices) out = dict() if scores: