From 35937e2e6316052222d6b30f2dbe99f752fd2c23 Mon Sep 17 00:00:00 2001 From: Daniel Justus Date: Fri, 27 Oct 2023 14:09:56 +0000 Subject: [PATCH] fp32 topk on cpu --- besskge/pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/besskge/pipeline.py b/besskge/pipeline.py index 49e8507..4908bdf 100644 --- a/besskge/pipeline.py +++ b/besskge/pipeline.py @@ -271,7 +271,7 @@ def forward(self) -> Dict[str, Any]: if self.return_scores: scores.append(batch_scores_filt) if self.return_topk: - topk_ids.append(torch.topk(batch_scores_filt, k=self.k, dim=-1).indices) + topk_ids.append(torch.topk(batch_scores_filt.to(torch.float32), k=self.k, dim=-1).indices) out = dict() if scores: