[ENH] ✨ Support MPS accelerated OpenCLIP embeddings (#3295)
## Description of changes

Currently, attempting to use OpenCLIP embeddings on a metal performance
shader (MPS) device results in the following error:

```
RuntimeError: slow_conv2d_forward_mps: input(device='cpu') and weight(device='mps:0') must be on the same device in add.
```

These changes fix this error by explicitly moving input tensors to the
model device in `OpenCLIPEmbeddingFunction` embedding methods. This
provides a significant speedup (~2x on my M1 MacBook Pro) compared to
running on CPU.
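
The pattern behind the fix can be shown with a minimal standalone sketch in plain PyTorch (this is illustrative, not Chroma's actual code): the model is moved to a device once, and every input tensor is moved to that same device before the forward pass, which avoids the cpu/mps mismatch error above.

```python
import torch

# Pick MPS when available, falling back to CPU (mirrors device="mps" vs "cpu").
device = "mps" if torch.backends.mps.is_available() else "cpu"

# The model's weights live on `device`, like self._model.to(device).
model = torch.nn.Conv2d(3, 8, kernel_size=3).to(device)

# Inputs are created on CPU, like a freshly preprocessed image tensor.
x = torch.randn(1, 3, 32, 32)

with torch.no_grad():
    # Without .to(device), this forward pass raises the device-mismatch
    # RuntimeError on MPS; moving the input fixes it.
    out = model(x.to(device))
```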

## Test plan

It's quite difficult to perform CI testing on real macOS machines, so
these changes don't include any tests for MPS devices specifically.
However, these changes have been tested successfully on an M1 device,
both with `device="cpu"` (the default) and `device="mps"`.

- [x] Tests pass locally with `pytest` for python, `yarn test` for js,
`cargo test` for rust

## Documentation Changes

No user-facing APIs are updated; `device="mps"` now simply works as users would expect.
connorbrinton authored Jan 2, 2025
1 parent 468f910 commit e30429b
Showing 1 changed file with 5 additions and 2 deletions.
```diff
@@ -47,6 +47,7 @@ def __init__(
         model, _, preprocess = open_clip.create_model_and_transforms(
             model_name=model_name, pretrained=checkpoint
         )
+        self._device = device
         self._model = model
         self._model.to(device)
         self._preprocess = preprocess
@@ -56,14 +57,16 @@ def _encode_image(self, image: Image) -> Embedding:
         pil_image = self._PILImage.fromarray(image)
         with self._torch.no_grad():
             image_features = self._model.encode_image(
-                self._preprocess(pil_image).unsqueeze(0)
+                self._preprocess(pil_image).unsqueeze(0).to(self._device)
             )
             image_features /= image_features.norm(dim=-1, keepdim=True)
         return cast(Embedding, image_features.squeeze().cpu().numpy())

     def _encode_text(self, text: Document) -> Embedding:
         with self._torch.no_grad():
-            text_features = self._model.encode_text(self._tokenizer(text))
+            text_features = self._model.encode_text(
+                self._tokenizer(text).to(self._device)
+            )
             text_features /= text_features.norm(dim=-1, keepdim=True)
         return cast(Embedding, text_features.squeeze().cpu().numpy())
```
