diff --git a/Cargo.toml b/Cargo.toml index bcfa326..99ba2a0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,7 +9,7 @@ keywords = ["deep-learning", "language", "model", "rwkv"] license = "MIT OR Apache-2.0" name = "web-rwkv" repository = "https://github.com/cryscan/web-rwkv" -version = "0.6.35" +version = "0.6.36" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html diff --git a/examples/batch.rs b/examples/batch.rs index bad7d7d..86f51c3 100644 --- a/examples/batch.rs +++ b/examples/batch.rs @@ -45,10 +45,14 @@ fn sample(probs: Vec, _top_p: f32) -> u16 { .0 as u16 } -async fn create_context(info: &ModelInfo) -> Result { +async fn create_context(info: &ModelInfo, _auto: bool) -> Result { let instance = Instance::new(); #[cfg(not(debug_assertions))] - let adapter = { + let adapter = if _auto { + instance + .adapter(wgpu::PowerPreference::HighPerformance) + .await? + } else { let backends = wgpu::Backends::all(); let adapters = instance .enumerate_adapters(backends) @@ -176,7 +180,7 @@ async fn run(cli: Cli) -> Result<()> { }; let lora = lora.as_deref(); - let context = create_context(&info).await?; + let context = create_context(&info, cli.adapter).await?; match info.version { ModelVersion::V4 => { let (model, state) = load_model( @@ -375,6 +379,8 @@ struct Cli { token_chunk_size: usize, #[arg(short, long, default_value_t = 4)] batch: usize, + #[arg(short, long, action)] + adapter: bool, } #[tokio::main] diff --git a/examples/chat.rs b/examples/chat.rs index 184e07e..bd97cd5 100644 --- a/examples/chat.rs +++ b/examples/chat.rs @@ -75,10 +75,14 @@ impl Sampler { } } -async fn create_context(info: &ModelInfo) -> Result { +async fn create_context(info: &ModelInfo, _auto: bool) -> Result { let instance = Instance::new(); #[cfg(not(debug_assertions))] - let adapter = { + let adapter = if _auto { + instance + .adapter(wgpu::PowerPreference::HighPerformance) + .await? + } else { let backends = wgpu::Backends::all(); let adapters = instance .enumerate_adapters(backends) @@ -211,7 +215,7 @@ async fn run(cli: Cli) -> Result<()> { }; let lora = lora.as_deref(); - let context = create_context(&info).await?; + let context = create_context(&info, cli.adapter).await?; match info.version { ModelVersion::V4 => { let (model, state) = load_model( @@ -407,6 +411,8 @@ struct Cli { token_chunk_size: usize, #[command(flatten)] sampler: Sampler, + #[arg(short, long, action)] + adapter: bool, } #[derive(Debug, Deserialize)] diff --git a/examples/gen.rs b/examples/gen.rs index 687fab1..efda636 100644 --- a/examples/gen.rs +++ b/examples/gen.rs @@ -34,10 +34,14 @@ fn sample(probs: &[f32], _top_p: f32) -> u16 { .0 as u16 } -async fn create_context(info: &ModelInfo) -> Result { +async fn create_context(info: &ModelInfo, _auto: bool) -> Result { let instance = Instance::new(); #[cfg(not(debug_assertions))] - let adapter = { + let adapter = if _auto { + instance + .adapter(wgpu::PowerPreference::HighPerformance) + .await? + } else { let backends = wgpu::Backends::all(); let adapters = instance .enumerate_adapters(backends) @@ -145,7 +149,7 @@ async fn run(cli: Cli) -> Result<()> { }; let lora = lora.as_deref(); - let context = create_context(&info).await?; + let context = create_context(&info, cli.adapter).await?; match info.version { ModelVersion::V4 => { let (model, state) = load_model( @@ -293,6 +297,8 @@ struct Cli { turbo: bool, #[arg(long, default_value_t = 32)] token_chunk_size: usize, + #[arg(short, long, action)] + adapter: bool, } #[tokio::main] diff --git a/examples/inspector.rs b/examples/inspector.rs index 76fd056..d1c1e40 100644 --- a/examples/inspector.rs +++ b/examples/inspector.rs @@ -55,10 +55,14 @@ fn sample(probs: &[f32], _top_p: f32) -> u16 { .0 as u16 } -async fn create_context(info: &ModelInfo) -> Result { +async fn create_context(info: &ModelInfo, _auto: bool) -> Result { let instance = Instance::new(); #[cfg(not(debug_assertions))] - let adapter = { + let adapter = if _auto { + instance + .adapter(wgpu::PowerPreference::HighPerformance) + .await? + } else { let backends = wgpu::Backends::all(); let adapters = instance .enumerate_adapters(backends) @@ -169,7 +173,7 @@ async fn run(cli: Cli) -> Result<()> { }; let lora = lora.as_deref(); - let context = create_context(&info).await?; + let context = create_context(&info, cli.adapter).await?; let (model, state) = load_model::, _>( &context, &data, @@ -325,6 +329,8 @@ struct Cli { token_chunk_size: usize, #[arg(short, long)] prompt: Option, + #[arg(short, long, action)] + adapter: bool, } #[tokio::main] diff --git a/examples/serialization.rs b/examples/serialization.rs index 2506c87..0496688 100644 --- a/examples/serialization.rs +++ b/examples/serialization.rs @@ -35,10 +35,14 @@ fn sample(probs: &[f32], _top_p: f32) -> u16 { .0 as u16 } -async fn create_context(info: &ModelInfo) -> Result { +async fn create_context(info: &ModelInfo, _auto: bool) -> Result { let instance = Instance::new(); #[cfg(not(debug_assertions))] - let adapter = { + let adapter = if _auto { + instance + .adapter(wgpu::PowerPreference::HighPerformance) + .await? + } else { let backends = wgpu::Backends::all(); let adapters = instance .enumerate_adapters(backends) @@ -146,7 +150,7 @@ async fn run(cli: Cli) -> Result<()> { }; let lora = lora.as_deref(); - let context = create_context(&info).await?; + let context = create_context(&info, cli.adapter).await?; match info.version { ModelVersion::V4 => { let (model, state) = load_model( @@ -319,6 +323,8 @@ struct Cli { token_chunk_size: usize, #[arg(short, long, value_name = "FILE")] output: Option, + #[arg(short, long, action)] + adapter: bool, } #[tokio::main]