diff --git a/keybert/_model.py b/keybert/_model.py index 571118de..c442e416 100644 --- a/keybert/_model.py +++ b/keybert/_model.py @@ -256,7 +256,9 @@ def extract_keywords( # Fine-tune keywords using an LLM if self.llm is not None: import torch - doc_embeddings = torch.from_numpy(doc_embeddings).float().to("cuda") + doc_embeddings = torch.from_numpy(doc_embeddings).float() + if torch.cuda.is_available(): + doc_embeddings = doc_embeddings.to("cuda") if isinstance(all_keywords[0], tuple): candidate_keywords = [[keyword[0] for keyword in all_keywords]] else: