diff --git a/chromadb/utils/embedding_functions/open_clip_embedding_function.py b/chromadb/utils/embedding_functions/open_clip_embedding_function.py index 0d05b6c27b6..a261e4ed2e7 100644 --- a/chromadb/utils/embedding_functions/open_clip_embedding_function.py +++ b/chromadb/utils/embedding_functions/open_clip_embedding_function.py @@ -47,6 +47,7 @@ def __init__( model, _, preprocess = open_clip.create_model_and_transforms( model_name=model_name, pretrained=checkpoint ) + self._device = device self._model = model self._model.to(device) self._preprocess = preprocess @@ -56,14 +57,16 @@ def _encode_image(self, image: Image) -> Embedding: pil_image = self._PILImage.fromarray(image) with self._torch.no_grad(): image_features = self._model.encode_image( - self._preprocess(pil_image).unsqueeze(0) + self._preprocess(pil_image).unsqueeze(0).to(self._device) ) image_features /= image_features.norm(dim=-1, keepdim=True) return cast(Embedding, image_features.squeeze().cpu().numpy()) def _encode_text(self, text: Document) -> Embedding: with self._torch.no_grad(): - text_features = self._model.encode_text(self._tokenizer(text)) + text_features = self._model.encode_text( + self._tokenizer(text).to(self._device) + ) text_features /= text_features.norm(dim=-1, keepdim=True) return cast(Embedding, text_features.squeeze().cpu().numpy())