Skip to content

Commit

Permalink
[rust] Use fused layer_norm (deepjavalibrary#3346)
Browse files Browse the repository at this point in the history
  • Loading branch information
xyang16 authored Jul 17, 2024
1 parent e811632 commit 6b642e1
Show file tree
Hide file tree
Showing 9 changed files with 229 additions and 974 deletions.
2 changes: 1 addition & 1 deletion extensions/tokenizers/build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ RUST_MANIFEST=rust/Cargo.toml
if [[ "$FLAVOR" = "cpu"* ]]; then
cargo build --manifest-path $RUST_MANIFEST --release
elif [[ "$FLAVOR" = "cu"* && "$FLAVOR" > "cu121" ]]; then
- cargo build --manifest-path $RUST_MANIFEST --release --features cuda,cublaslt,flash-attn
+ cargo build --manifest-path $RUST_MANIFEST --release --features cuda,flash-attn
else
cargo build --manifest-path $RUST_MANIFEST --release
fi
Expand Down
24 changes: 16 additions & 8 deletions extensions/tokenizers/rust/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@ edition = "2021"

[dependencies]
jni = "0.21.1"
- candle = { version = "0.5.1", package = "candle-core" }
- candle-nn = "0.5.1"
- candle-transformers = "0.5.1"
- candle-flash-attn = { version = "0.5.1", optional = true }
- cudarc = { version = "0.11.6", default-features = false, features = [ "cublaslt", "f16" ], optional = true }
+ candle = { version = "*", package = "candle-core" }
+ candle-nn = { version = "*" }
+ candle-transformers = { version = "*" }
+ candle-flash-attn = { version = "*", optional = true }
+ candle-cublaslt = { git = "https://github.com/huggingface/candle-cublaslt", rev = "cf789b7dd6d4abb19b03b9556442f94f0588b4a0", optional = true }
+ candle-layer-norm = { git = "https://github.com/xyang16/candle-layer-norm", rev = "e574de6a7f88bafbede8edf9ee43170c6a8ce51a", optional = true }
+ candle-rotary = { git = "https://github.com/huggingface/candle-rotary", rev = "0a718a0856569a92f3112e64f10d07e4447822e8", optional = true }
tokenizers = { path = "../tokenizers/tokenizers", version = "*", features = ["http"] }
half = "2.4.0"
tracing = "0.1.40"
Expand All @@ -19,13 +21,19 @@ thiserror = "1.0.58"
serde = { version = "1.0.198", features = ["serde_derive"] }
serde_json = "1.0.116"

+ [patch.crates-io]
+ cudarc = { git = "https://github.com/coreylowman/cudarc", rev = "c388e724af93a3e8fbe484f5ded2d8b3c1badd8e" }
+ candle = { git = "https://github.com/huggingface/candle", rev = "f76bb7794aa8659c5023797979a3392cdfc01f32", package = "candle-core" }
+ candle-nn = { git = "https://github.com/huggingface/candle", rev = "f76bb7794aa8659c5023797979a3392cdfc01f32", package = "candle-nn" }
+ candle-transformers = { git = "https://github.com/huggingface/candle", rev = "f76bb7794aa8659c5023797979a3392cdfc01f32", package = "candle-transformers" }
+ candle-flash-attn = { git = "https://github.com/huggingface/candle", rev = "f76bb7794aa8659c5023797979a3392cdfc01f32", package = "candle-flash-attn" }

[target.'cfg(target_os = "linux")'.dependencies]
openssl = { version = "0.10", features = ["vendored"] }

[lib]
- crate_type = ["cdylib"]
+ crate-type = ["cdylib"]

[features]
- cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"]
+ cuda = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda", "dep:candle-cublaslt", "dep:candle-layer-norm", "dep:candle-rotary"]
flash-attn = ["cuda", "candle-transformers/flash-attn", "dep:candle-flash-attn"]
- cublaslt = ["cudarc/cublaslt"]
Loading

0 comments on commit 6b642e1

Please sign in to comment.