Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Book 3 (advanced loading + hub) #263

Merged
merged 10 commits into from
Aug 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions .github/workflows/book-cd.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
name: Deploy Rust book
on:
# TODO put this back only when merging after this PR lands.
pull_request:
push:
branches:
- main
Expand Down
26 changes: 13 additions & 13 deletions candle-book/src/SUMMARY.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,16 @@

- [Running a model](inference/README.md)
- [Using the hub](inference/hub.md)
- [Serialization](inference/serialization.md)
- [Advanced Cuda usage](inference/cuda/README.md)
- [Writing a custom kernel](inference/cuda/writing.md)
- [Porting a custom kernel](inference/cuda/porting.md)
- [Error management](error_manage.md)
- [Creating apps](apps/README.md)
- [Creating a WASM app](apps/wasm.md)
- [Creating a REST api webserver](apps/rest.md)
- [Creating a desktop Tauri app](apps/dekstop.md)
- [Training](training/README.md)
- [MNIST](training/mnist.md)
- [Fine-tuning](training/finetuning.md)
- [Using MKL](advanced/mkl.md)
- [Error management]()
- [Advanced Cuda usage]()
- [Writing a custom kernel]()
- [Porting a custom kernel]()
- [Using MKL]()
- [Creating apps]()
- [Creating a WASM app]()
- [Creating a REST api webserver]()
- [Creating a desktop Tauri app]()
- [Training]()
- [MNIST]()
- [Fine-tuning]()
- [Serialization]()
1 change: 1 addition & 0 deletions candle-book/src/cuda/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Advanced Cuda usage
1 change: 1 addition & 0 deletions candle-book/src/cuda/porting.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Porting a custom kernel
1 change: 1 addition & 0 deletions candle-book/src/cuda/writing.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Writing a custom kernel
50 changes: 50 additions & 0 deletions candle-book/src/error_manage.md
Original file line number Diff line number Diff line change
@@ -1 +1,51 @@
# Error management

You might have seen in the code base a lot of `.unwrap()` or `?`.
If you're unfamiliar with Rust check out the [Rust book](https://doc.rust-lang.org/book/ch09-02-recoverable-errors-with-result.html)
for more information.

What's important to know though, is that if you want to know *where* a particular operation failed
You can simply use `RUST_BACKTRACE=1` to get the location of where the model actually failed.

Let's see on failing code:

```rust,ignore
let x = Tensor::zeros((1, 784), DType::F32, &device)?;
let y = Tensor::zeros((1, 784), DType::F32, &device)?;
let z = x.matmul(&y)?;
```

Will print at runtime:

```bash
Error: ShapeMismatchBinaryOp { lhs: [1, 784], rhs: [1, 784], op: "matmul" }
```


After adding `RUST_BACKTRACE=1`:


```bash
Error: WithBacktrace { inner: ShapeMismatchBinaryOp { lhs: [1, 784], rhs: [1, 784], op: "matmul" }, backtrace: Backtrace [{ fn: "candle::error::Error::bt", file: "/home/nicolas/.cargo/git/checkouts/candle-5bb8ef7e0626d693/f291065/candle-core/src/error.rs", line: 200 }, { fn: "candle::tensor::Tensor::matmul", file: "/home/nicolas/.cargo/git/checkouts/candle-5bb8ef7e0626d693/f291065/candle-core/src/tensor.rs", line: 816 }, { fn: "myapp::main", file: "./src/main.rs", line: 29 }, { fn: "core::ops::function::FnOnce::call_once", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/core/src/ops/function.rs", line: 250 }, { fn: "std::sys_common::backtrace::__rust_begin_short_backtrace", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/sys_common/backtrace.rs", line: 135 }, { fn: "std::rt::lang_start::{{closure}}", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 166 }, { fn: "core::ops::function::impls::<impl core::ops::function::FnOnce<A> for &F>::call_once", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/core/src/ops/function.rs", line: 284 }, { fn: "std::panicking::try::do_call", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 500 }, { fn: "std::panicking::try", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 464 }, { fn: "std::panic::catch_unwind", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panic.rs", line: 142 }, { fn: "std::rt::lang_start_internal::{{closure}}", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 148 }, { fn: "std::panicking::try::do_call", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 500 }, { fn: "std::panicking::try", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panicking.rs", line: 464 }, { fn: "std::panic::catch_unwind", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/panic.rs", line: 142 }, { fn: "std::rt::lang_start_internal", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 148 }, { fn: "std::rt::lang_start", file: "/rustc/8ede3aae28fe6e4d52b38157d7bfe0d3bceef225/library/std/src/rt.rs", line: 165 }, { fn: "main" }, { fn: "__libc_start_main" }, { fn: "_start" }] }
```

Not super pretty at the moment, but we can see error occured on `{ fn: "myapp::main", file: "./src/main.rs", line: 29 }`


Another thing to note, is that since Rust is compiled it is not necessarily as easy to recover proper stacktraces
especially in release builds. We're using [`anyhow`](https://docs.rs/anyhow/latest/anyhow/) for that.
The library is still young, please [report](https://github.com/LaurentMazare/candle/issues) any issues detecting where an error is coming from.

## Cuda error management

When running a model on Cuda, you might get a stacktrace not really representing the error.
The reason is that CUDA is async by nature, and therefore the error might be caught while you were sending totally different kernels.

One way to avoid this is to use `CUDA_LAUNCH_BLOCKING=1` as an environment variable. This will force every kernel to be launched sequentially.
You might still however see the error happening on other kernels as the faulty kernel might exit without an error but spoiling some pointer for which the error will happen when dropping the `CudaSlice` only.


If this occurs, you can use [`compute-sanitizer`](https://docs.nvidia.com/compute-sanitizer/ComputeSanitizer/index.html)
This tool is like `valgrind` but for cuda. It will help locate the errors in the kernels.


6 changes: 6 additions & 0 deletions candle-book/src/inference/README.md
Original file line number Diff line number Diff line change
@@ -1 +1,7 @@
# Running a model


In order to run an existing model, you will need to download and use existing weights.
Most models are already available on https://huggingface.co/ in [`safetensors`](https://github.com/huggingface/safetensors) format.

Let's get started by running an old model : `bert-base-uncased`.
103 changes: 103 additions & 0 deletions candle-book/src/inference/hub.md
Original file line number Diff line number Diff line change
@@ -1 +1,104 @@
# Using the hub

Install the [`hf-hub`](https://github.com/huggingface/hf-hub) crate:

```bash
cargo add hf-hub
```

Then let's start by downloading the [model file](https://huggingface.co/bert-base-uncased/tree/main).


```rust
# extern crate candle_core;
# extern crate hf_hub;
use hf_hub::api::sync::Api;
use candle_core::Device;

let api = Api::new().unwrap();
let repo = api.model("bert-base-uncased".to_string());

let weights = repo.get("model.safetensors").unwrap();

let weights = candle_core::safetensors::load(weights, &Device::Cpu);
```

We now have access to all the [tensors](https://huggingface.co/bert-base-uncased?show_tensors=true) within the file.

You can check all the names of the tensors [here](https://huggingface.co/bert-base-uncased?show_tensors=true)


## Using async

`hf-hub` comes with an async API.

```bash
cargo add hf-hub --features tokio
```

```rust,ignore
# This is tested directly in examples crate because it needs external dependencies unfortunately:
# See [this](https://github.com/rust-lang/mdBook/issues/706)
{{#include ../../../candle-examples/src/lib.rs:book_hub_1}}
```


## Using in a real model.

Now that we have our weights, we can use them in our bert architecture:

```rust
# extern crate candle_core;
# extern crate candle_nn;
# extern crate hf_hub;
# use hf_hub::api::sync::Api;
#
# let api = Api::new().unwrap();
# let repo = api.model("bert-base-uncased".to_string());
#
# let weights = repo.get("model.safetensors").unwrap();
use candle_core::{Device, Tensor, DType};
use candle_nn::Linear;

let weights = candle_core::safetensors::load(weights, &Device::Cpu).unwrap();

let weight = weights.get("bert.encoder.layer.0.attention.self.query.weight").unwrap();
let bias = weights.get("bert.encoder.layer.0.attention.self.query.bias").unwrap();

let linear = Linear::new(weight.clone(), Some(bias.clone()));

let input_ids = Tensor::zeros((3, 768), DType::F32, &Device::Cpu).unwrap();
let output = linear.forward(&input_ids).unwrap();
```

For a full reference, you can check out the full [bert](https://github.com/LaurentMazare/candle/tree/main/candle-examples/examples/bert) example.

## Memory mapping

For more efficient loading, instead of reading the file, you could use [`memmap2`](https://docs.rs/memmap2/latest/memmap2/)

**Note**: Be careful about memory mapping it seems to cause issues on [Windows, WSL](https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/5893)
and will definitely be slower on network mounted disk, because it will issue more read calls.

```rust,ignore
{{#include ../../../candle-examples/src/lib.rs:book_hub_2}}
```

**Note**: This operation is **unsafe**. [See the safety notice](https://docs.rs/memmap2/latest/memmap2/struct.Mmap.html#safety).
In practice model files should never be modified, and the mmaps should be mostly READONLY anyway, so the caveat most likely does not apply, but always keep it in mind.


## Tensor Parallel Sharding

When using multiple GPUs to use in Tensor Parallel in order to get good latency, you can load only the part of the Tensor you need.

For that you need to use [`safetensors`](https://crates.io/crates/safetensors) directly.

```bash
cargo add safetensors
```


```rust,ignore
{{#include ../../../candle-examples/src/lib.rs:book_hub_3}}
```
6 changes: 5 additions & 1 deletion candle-core/src/safetensors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,11 @@ fn convert_back(tensor: &Tensor) -> Result<Vec<u8>> {

pub fn load<P: AsRef<Path>>(filename: P, device: &Device) -> Result<HashMap<String, Tensor>> {
let data = std::fs::read(filename.as_ref())?;
let st = safetensors::SafeTensors::deserialize(&data)?;
load_buffer(&data[..], device)
}

pub fn load_buffer(data: &[u8], device: &Device) -> Result<HashMap<String, Tensor>> {
let st = safetensors::SafeTensors::deserialize(data)?;
st.tensors()
.into_iter()
.map(|(name, view)| Ok((name, view.load(device)?)))
Expand Down
4 changes: 3 additions & 1 deletion candle-examples/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,17 @@ half = { workspace = true, optional = true }
[dev-dependencies]
anyhow = { workspace = true }
byteorder = { workspace = true }
hf-hub = { workspace = true, features=["tokio"]}
clap = { workspace = true }
hf-hub = { workspace = true }
memmap2 = { workspace = true }
rand = { workspace = true }
tokenizers = { workspace = true, features = ["onig"] }
tracing = { workspace = true }
tracing-chrome = { workspace = true }
tracing-subscriber = { workspace = true }
wav = { workspace = true }
# Necessary to disambiguate with tokio in wasm examples which are 1.28.1
tokio = "1.29.1"

[build-dependencies]
anyhow = { workspace = true }
Expand Down
99 changes: 99 additions & 0 deletions candle-examples/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,102 @@ pub fn device(cpu: bool) -> Result<Device> {
Ok(device)
}
}

#[cfg(test)]
mod tests {
// NOTE: Waiting on https://github.com/rust-lang/mdBook/pull/1856
#[rustfmt::skip]
#[tokio::test]
async fn book_hub_1() {
// ANCHOR: book_hub_1
use candle::Device;
use hf_hub::api::tokio::Api;

let api = Api::new().unwrap();
let repo = api.model("bert-base-uncased".to_string());

let weights_filename = repo.get("model.safetensors").await.unwrap();

let weights = candle::safetensors::load(weights_filename, &Device::Cpu).unwrap();
// ANCHOR_END: book_hub_1
assert_eq!(weights.len(), 206);
}

#[rustfmt::skip]
#[test]
fn book_hub_2() {
// ANCHOR: book_hub_2
use candle::Device;
use hf_hub::api::sync::Api;
use memmap2::Mmap;
use std::fs;

let api = Api::new().unwrap();
let repo = api.model("bert-base-uncased".to_string());
let weights_filename = repo.get("model.safetensors").unwrap();

let file = fs::File::open(weights_filename).unwrap();
let mmap = unsafe { Mmap::map(&file).unwrap() };
let weights = candle::safetensors::load_buffer(&mmap[..], &Device::Cpu).unwrap();
// ANCHOR_END: book_hub_2
assert_eq!(weights.len(), 206);
}

#[rustfmt::skip]
#[test]
fn book_hub_3() {
// ANCHOR: book_hub_3
use candle::{DType, Device, Tensor};
use hf_hub::api::sync::Api;
use memmap2::Mmap;
use safetensors::slice::IndexOp;
use safetensors::SafeTensors;
use std::fs;

let api = Api::new().unwrap();
let repo = api.model("bert-base-uncased".to_string());
let weights_filename = repo.get("model.safetensors").unwrap();

let file = fs::File::open(weights_filename).unwrap();
let mmap = unsafe { Mmap::map(&file).unwrap() };

// Use safetensors directly
let tensors = SafeTensors::deserialize(&mmap[..]).unwrap();
let view = tensors
.tensor("bert.encoder.layer.0.attention.self.query.weight")
.unwrap();

// We're going to load shard with rank 1, within a world_size of 4
// We're going to split along dimension 0 doing VIEW[start..stop, :]
let rank = 1;
let world_size = 4;
let dim = 0;
let dtype = view.dtype();
let mut tp_shape = view.shape().to_vec();
let size = tp_shape[0];

if size % world_size != 0 {
panic!("The dimension is not divisble by `world_size`");
}
let block_size = size / world_size;
let start = rank * block_size;
let stop = (rank + 1) * block_size;

// Everything is expressed in tensor dimension
// bytes offsets is handled automatically for safetensors.

let iterator = view.slice(start..stop).unwrap();

tp_shape[dim] = block_size;

// Convert safetensors Dtype to candle DType
let dtype: DType = dtype.try_into().unwrap();

// TODO: Implement from_buffer_iterator so we can skip the extra CPU alloc.
let raw: Vec<u8> = iterator.into_iter().flatten().cloned().collect();
let tp_tensor = Tensor::from_raw_buffer(&raw, dtype, &tp_shape, &Device::Cpu).unwrap();
// ANCHOR_END: book_hub_3
assert_eq!(view.shape(), &[768, 768]);
assert_eq!(tp_tensor.dims(), &[192, 768]);
}
}