From 9cb8ae91899e250efd58f9566bb2f296c2a2477a Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Thu, 18 Jan 2024 13:37:46 +0100 Subject: [PATCH] Add "xpu" device tag to the device list to support Intel GPU (#428) * Support XPU device * Fmt. --------- Co-authored-by: Jiong Gong --- bindings/python/src/lib.rs | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/bindings/python/src/lib.rs b/bindings/python/src/lib.rs index 3ea62ff1..d4cd89fb 100644 --- a/bindings/python/src/lib.rs +++ b/bindings/python/src/lib.rs @@ -226,6 +226,7 @@ enum Device { Cuda(usize), Mps, Npu(usize), + Xpu(usize), } impl<'source> FromPyObject<'source> for Device { @@ -236,6 +237,7 @@ impl<'source> FromPyObject<'source> for Device { "cuda" => Ok(Device::Cuda(0)), "mps" => Ok(Device::Mps), "npu" => Ok(Device::Npu(0)), + "xpu" => Ok(Device::Xpu(0)), name if name.starts_with("cuda:") => { let tokens: Vec<_> = name.split(':').collect(); if tokens.len() == 2 { @@ -258,6 +260,17 @@ impl<'source> FromPyObject<'source> for Device { ))) } } + name if name.starts_with("xpu:") => { + let tokens: Vec<_> = name.split(':').collect(); + if tokens.len() == 2 { + let device: usize = tokens[1].parse()?; + Ok(Device::Xpu(device)) + } else { + Err(SafetensorError::new_err(format!( + "device {name} is invalid" + ))) + } + } name => Err(SafetensorError::new_err(format!( "device {name} is invalid" ))), @@ -277,6 +290,7 @@ impl IntoPy for Device { Device::Cuda(n) => format!("cuda:{n}").into_py(py), Device::Mps => "mps".into_py(py), Device::Npu(n) => format!("npu:{n}").into_py(py), + Device::Xpu(n) => format!("xpu:{n}").into_py(py), } } }