diff --git a/outlines/processors/base_logits_processor.py b/outlines/processors/base_logits_processor.py index 44b55af2e..800e69f79 100644 --- a/outlines/processors/base_logits_processor.py +++ b/outlines/processors/base_logits_processor.py @@ -110,9 +110,9 @@ def _to_torch(tensor_like: Array) -> torch.Tensor: import mlx.core as mx # https://ml-explore.github.io/mlx/build/html/usage/numpy.html#pytorch - return torch.from_dlpack( - np.array(tensor_like.astype(mx.float32), copy=False) - ) + if tensor_like.dtype == mx.bfloat16: + tensor_like = tensor_like.astype(mx.float32) + return torch.from_dlpack(np.array(tensor_like, copy=False)) elif is_jax_array_type(type(tensor_like)): import jax diff --git a/tests/processors/test_base_processor.py b/tests/processors/test_base_processor.py index cd9f48278..d2a1e1af2 100644 --- a/tests/processors/test_base_processor.py +++ b/tests/processors/test_base_processor.py @@ -18,6 +18,7 @@ import mlx.core as mx arrays["mlx"] = mx.array([[1, 2], [3, 4]], dtype=mx.float32) + arrays["mlx_bfloat16"] = mx.array([[1, 2], [3, 4]], dtype=mx.bfloat16) except ImportError: pass @@ -59,7 +60,12 @@ def test_from_torch(array_type, processor): torch_tensor = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32) data = processor._from_torch(torch_tensor, type(arrays[array_type])) assert isinstance(data, type(arrays[array_type])) - assert np.allclose(data, arrays[array_type]) + if array_type == "mlx_bfloat16": + # For bfloat16, we expect the output to be float32 due to the conversion + assert data.dtype == mx.float32 + assert np.allclose(np.array(data), np.array([[1, 2], [3, 4]], dtype=np.float32)) + else: + assert np.allclose(data, arrays[array_type]) @pytest.mark.parametrize("array_type", arrays.keys())