From e30429b31455d75b733c702170e93ea0bcccc34b Mon Sep 17 00:00:00 2001
From: Connor Brinton
Date: Thu, 2 Jan 2025 12:35:44 -0500
Subject: [PATCH] [ENH] ✨ Support MPS accelerated OpenCLIP embeddings (#3295)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Description of changes

Currently, attempting to use OpenCLIP embeddings on a Metal Performance Shaders (MPS) device results in the following error:

```
RuntimeError: slow_conv2d_forward_mps: input(device='cpu') and weight(device='mps:0') must be on the same device in add.
```

These changes fix the error by explicitly moving input tensors to the model's device in the `OpenCLIPEmbeddingFunction` encoding methods. Running on MPS provides a significant speedup (~2x on my M1 MacBook Pro) compared to running on CPU.

## Test plan

It's difficult to run CI on real macOS hardware, so these changes don't include tests specific to MPS devices. They have, however, been tested successfully on an M1 device, both with `device="cpu"` (the default) and with `device="mps"`.

- [x] Tests pass locally with `pytest` for python, `yarn test` for js, `cargo test` for rust

## Documentation Changes

No user-facing APIs are changed; `device="mps"` now works as users would expect.

---
 .../embedding_functions/open_clip_embedding_function.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/chromadb/utils/embedding_functions/open_clip_embedding_function.py b/chromadb/utils/embedding_functions/open_clip_embedding_function.py
index 0d05b6c27b6..a261e4ed2e7 100644
--- a/chromadb/utils/embedding_functions/open_clip_embedding_function.py
+++ b/chromadb/utils/embedding_functions/open_clip_embedding_function.py
@@ -47,6 +47,7 @@ def __init__(
         model, _, preprocess = open_clip.create_model_and_transforms(
             model_name=model_name, pretrained=checkpoint
         )
+        self._device = device
         self._model = model
         self._model.to(device)
         self._preprocess = preprocess
@@ -56,14 +57,16 @@ def _encode_image(self, image: Image) -> Embedding:
         pil_image = self._PILImage.fromarray(image)
         with self._torch.no_grad():
             image_features = self._model.encode_image(
-                self._preprocess(pil_image).unsqueeze(0)
+                self._preprocess(pil_image).unsqueeze(0).to(self._device)
             )
             image_features /= image_features.norm(dim=-1, keepdim=True)
             return cast(Embedding, image_features.squeeze().cpu().numpy())
 
     def _encode_text(self, text: Document) -> Embedding:
         with self._torch.no_grad():
-            text_features = self._model.encode_text(self._tokenizer(text))
+            text_features = self._model.encode_text(
+                self._tokenizer(text).to(self._device)
+            )
             text_features /= text_features.norm(dim=-1, keepdim=True)
             return cast(Embedding, text_features.squeeze().cpu().numpy())
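
For anyone who wants to exercise the fixed code path locally, here is a minimal sketch (not part of the patch). It assumes a PyTorch build with MPS support, the OpenCLIP dependencies installed alongside `chromadb`, and that `OpenCLIPEmbeddingFunction` accepts both text documents and NumPy image arrays through its `__call__`; the CPU fallback guard is just a suggestion for non-Apple hardware:

```python
# Sketch only: verifies that OpenCLIP embeddings run on MPS after this patch.
import numpy as np
import torch
from chromadb.utils.embedding_functions import OpenCLIPEmbeddingFunction

# Fall back to CPU when MPS isn't available (e.g. Linux CI machines).
device = "mps" if torch.backends.mps.is_available() else "cpu"

embed = OpenCLIPEmbeddingFunction(device=device)

# With this patch, inputs are moved to the model's device internally,
# so both text and image inputs work regardless of `device`.
text_embeddings = embed(["a photo of a cat"])
image_embeddings = embed([np.zeros((224, 224, 3), dtype=np.uint8)])
```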