diff --git a/fortuna/hallucination/base.py b/fortuna/hallucination/base.py index 3703fb10..0e9e1074 100644 --- a/fortuna/hallucination/base.py +++ b/fortuna/hallucination/base.py @@ -110,10 +110,6 @@ def fit( Dict The status returned by fitting the multicalibrator. """ - self.generative_model.to( - torch.device("cuda" if torch.cuda.is_available() else "cpu") - ) - ( scores, embeddings, @@ -177,10 +173,6 @@ def predict_proba( if self.multicalibrator is None: raise ValueError("`fit` must be called before this method.") - self.generative_model.to( - torch.device("cuda" if torch.cuda.is_available() else "cpu") - ) - ( scores, embeddings,