Skip to content

Commit

Permalink
fix: no grad
Browse files Browse the repository at this point in the history
  • Loading branch information
stephantul committed Dec 24, 2024
1 parent c4ba272 commit 4713bfa
Showing 1 changed file with 1 addition and 0 deletions.
1 change: 1 addition & 0 deletions model2vec/train/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def predict(self, texts: list[str]) -> list[str]:

return [self.classes[idx] for idx in logits.argmax(1)]

@torch.no_grad()
def _predict(self, texts: list[str]) -> torch.Tensor:
input_ids = self.tokenize(texts)
vectors, _ = self.forward(input_ids)
Expand Down

0 comments on commit 4713bfa

Please sign in to comment.