diff --git a/candle-book/src/inference/hub.md b/candle-book/src/inference/hub.md index de514322eb..01492df198 100644 --- a/candle-book/src/inference/hub.md +++ b/candle-book/src/inference/hub.md @@ -25,6 +25,8 @@ let weights = candle::safetensors::load(weights, &Device::Cpu); We now have access to all the [tensors](https://huggingface.co/bert-base-uncased?show_tensors=true) within the file. +You can check all the names of the tensors [here](https://huggingface.co/bert-base-uncased?show_tensors=true) + ## Using async @@ -35,17 +37,9 @@ cargo add hf-hub --features tokio ``` ```rust,ignore -# extern crate candle; -# extern crate hf_hub; -use hf_hub::api::tokio::Api; -use candle::Device; - -let api = Api::new().unwrap(); -let repo = api.model("bert-base-uncased".to_string()); - -let weights = repo.get("model.safetensors").await.unwrap(); - -let weights = candle::safetensors::load(weights, &Device::Cpu); +# This is tested directly in examples crate because it needs external dependencies unfortunately: +# See [this](https://github.com/rust-lang/mdBook/issues/706) +{{#include ../../../candle-examples/src/lib.rs:book_hub_1}} ``` @@ -78,3 +72,33 @@ let output = linear.forward(&input_ids); ``` For a full reference, you can check out the full [bert](https://github.com/LaurentMazare/candle/tree/main/candle-examples/examples/bert) example. + +## Memory mapping + +For more efficient loading, instead of reading the file, you could use [`memmap2`](https://docs.rs/memmap2/latest/memmap2/) + +**Note**: Be careful about memory mapping it seems to cause issues on [Windows, WSL](https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/5893) +and will definitely be slower on network mounted disk, because it will issue more read calls. + +```rust,ignore +{{#include ../../../candle-examples/src/lib.rs:book_hub_2}} +``` + +**Note**: This operation is **unsafe**. [See the safety notice](https://docs.rs/memmap2/latest/memmap2/struct.Mmap.html#safety). +In practice model files should never be modified, and the mmaps should be mostly READONLY anyway, so the caveat most likely does not apply, but always keep it in mind. + + +## Tensor Parallel Sharding + +When using multiple GPUs to use in Tensor Parallel in order to get good latency, you can load only the part of the Tensor you need. + +For that you need to use [`safetensors`](https://crates.io/crates/safetensors) directly. + +```bash +cargo add safetensors +``` + + +```rust,ignore +{{#include ../../../candle-examples/src/lib.rs:book_hub_3}} +``` diff --git a/candle-core/src/safetensors.rs b/candle-core/src/safetensors.rs index 1880a0411d..132fb914e5 100644 --- a/candle-core/src/safetensors.rs +++ b/candle-core/src/safetensors.rs @@ -242,7 +242,11 @@ fn convert_back(tensor: &Tensor) -> Result> { pub fn load>(filename: P, device: &Device) -> Result> { let data = std::fs::read(filename.as_ref())?; - let st = safetensors::SafeTensors::deserialize(&data)?; + load_buffer(&data[..], device) +} + +pub fn load_buffer(data: &[u8], device: &Device) -> Result> { + let st = safetensors::SafeTensors::deserialize(data)?; st.tensors() .into_iter() .map(|(name, view)| Ok((name, view.load(device)?))) diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml index ff28c646be..6f7dee9985 100644 --- a/candle-examples/Cargo.toml +++ b/candle-examples/Cargo.toml @@ -26,7 +26,7 @@ half = { workspace = true, optional = true } [dev-dependencies] anyhow = { workspace = true } byteorder = { workspace = true } -hf-hub = { workspace = true} +hf-hub = { workspace = true, features=["tokio"]} clap = { workspace = true } rand = { workspace = true } tokenizers = { workspace = true, features = ["onig"] } @@ -34,6 +34,9 @@ tracing = { workspace = true } tracing-chrome = { workspace = true } tracing-subscriber = { workspace = true } wav = { workspace = true } +# Necessary to disambiguate with tokio in wasm examples which are 1.28.1 +tokio = "1.29.1" +memmap2.workspace = true [build-dependencies] anyhow = { workspace = true } diff --git a/candle-examples/src/lib.rs b/candle-examples/src/lib.rs index 285aee049d..3410026ee8 100644 --- a/candle-examples/src/lib.rs +++ b/candle-examples/src/lib.rs @@ -11,3 +11,102 @@ pub fn device(cpu: bool) -> Result { Ok(device) } } + +#[cfg(test)] +mod tests { + // NOTE: Waiting on https://github.com/rust-lang/mdBook/pull/1856 + #[rustfmt::skip] + #[tokio::test] + async fn book_hub_1() { +// ANCHOR: book_hub_1 +use candle::Device; +use hf_hub::api::tokio::Api; + +let api = Api::new().unwrap(); +let repo = api.model("bert-base-uncased".to_string()); + +let weights_filename = repo.get("model.safetensors").await.unwrap(); + +let weights = candle::safetensors::load(weights_filename, &Device::Cpu).unwrap(); +// ANCHOR_END: book_hub_1 + assert_eq!(weights.len(), 206); + } + + #[rustfmt::skip] + #[test] + fn book_hub_2() { +// ANCHOR: book_hub_2 +use candle::Device; +use hf_hub::api::sync::Api; +use memmap2::Mmap; +use std::fs; + +let api = Api::new().unwrap(); +let repo = api.model("bert-base-uncased".to_string()); +let weights_filename = repo.get("model.safetensors").unwrap(); + +let file = fs::File::open(weights_filename).unwrap(); +let mmap = unsafe { Mmap::map(&file).unwrap() }; +let weights = candle::safetensors::load_buffer(&mmap[..], &Device::Cpu).unwrap(); +// ANCHOR_END: book_hub_2 + assert_eq!(weights.len(), 206); + } + + #[rustfmt::skip] + #[test] + fn book_hub_3() { +// ANCHOR: book_hub_3 +use candle::{DType, Device, Tensor}; +use hf_hub::api::sync::Api; +use memmap2::Mmap; +use safetensors::slice::IndexOp; +use safetensors::SafeTensors; +use std::fs; + +let api = Api::new().unwrap(); +let repo = api.model("bert-base-uncased".to_string()); +let weights_filename = repo.get("model.safetensors").unwrap(); + +let file = fs::File::open(weights_filename).unwrap(); +let mmap = unsafe { Mmap::map(&file).unwrap() }; + +// Use safetensors directly +let tensors = SafeTensors::deserialize(&mmap[..]).unwrap(); +let view = tensors +.tensor("bert.encoder.layer.0.attention.self.query.weight") +.unwrap(); + +// We're going to load shard with rank 1, within a world_size of 4 +// We're going to split along dimension 0 doing VIEW[start..stop, :] +let rank = 1; +let world_size = 4; +let dim = 0; +let dtype = view.dtype(); +let mut tp_shape = view.shape().to_vec(); +let size = tp_shape[0]; + +if size % world_size != 0 { +panic!("The dimension is not divisble by `world_size`"); +} +let block_size = size / world_size; +let start = rank * block_size; +let stop = (rank + 1) * block_size; + +// Everything is expressed in tensor dimension +// bytes offsets is handled automatically for safetensors. + +let iterator = view.slice(start..stop).unwrap(); + +tp_shape[dim] = block_size; + +// Convert safetensors Dtype to candle DType +let dtype: DType = dtype.try_into().unwrap(); + +// TODO: Implement from_buffer_iterator to we can skip the extra CPU alloc. +let raw: Vec = iterator.into_iter().flatten().cloned().collect(); +let tp_tensor = Tensor::from_raw_buffer(&raw, dtype, &tp_shape, &Device::Cpu).unwrap(); +// ANCHOR_END: book_hub_3 + assert_eq!(view.shape(), &[768, 768]); + assert_eq!(tp_tensor.dims(), &[192, 768]); + } +}