Commit
[ENH] ✨ Support MPS accelerated OpenCLIP embeddings (#3295)
## Description of changes

Currently, attempting to use OpenCLIP embeddings on a Metal Performance Shaders (MPS) device results in the following error:

```
RuntimeError: slow_conv2d_forward_mps: input(device='cpu') and weight(device=mps:0') must be on the same device in add.
```

These changes fix the error by explicitly moving input tensors to the model device in the `OpenCLIPEmbeddingFunction` embedding methods. Running on MPS gives a significant speedup (~2x on my M1 MacBook Pro) compared to running on CPU.

## Test plan

It's quite difficult to run CI tests on real macOS machines, so these changes don't include any MPS-specific tests. However, they have been tested successfully on an M1 device, both with `device="cpu"` (the default) and `device="mps"`.

- [x] Tests pass locally with `pytest` for python, `yarn test` for js, `cargo test` for rust

## Documentation Changes

No user-facing APIs are updated. `device="mps"` should now work as users expect.
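For readers unfamiliar with the failure mode, the idea behind the fix (moving each input tensor to the device that holds the model weights before the forward pass) can be sketched roughly as below. This is an illustrative sketch only, not the code from this commit: `model_device`, `encode_on_model_device`, and the small convolutional `encoder` are hypothetical stand-ins for the OpenCLIP model and the `OpenCLIPEmbeddingFunction` methods, and the sketch falls back to CPU when MPS is unavailable.

```python
import torch
from torch import nn


def model_device(model: nn.Module) -> torch.device:
    """Return the device that holds the model's first parameter."""
    return next(model.parameters()).device


def encode_on_model_device(model: nn.Module, batch: torch.Tensor) -> torch.Tensor:
    """Run a forward pass after moving the input to the model's device.

    Hypothetical helper: without the `.to(device)` call, a CPU input fed
    to a model on "mps" raises a device-mismatch RuntimeError.
    """
    device = model_device(model)
    with torch.no_grad():
        return model(batch.to(device))


# Use MPS when it is available, otherwise fall back to CPU.
device = "mps" if torch.backends.mps.is_available() else "cpu"

# Hypothetical stand-in for an image encoder (not the OpenCLIP model).
encoder = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
).to(device)

# Preprocessed inputs are typically created on the CPU.
images = torch.rand(2, 3, 32, 32)

embeddings = encode_on_model_device(encoder, images)
print(embeddings.shape, embeddings.device)
```

Reading the device from the model's parameters, rather than passing it around separately, keeps the input and weights in agreement even if the model is later moved to another device.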