Skip to content

Commit

Permalink
Allow automatically choose adapter in examples.
Browse files Browse the repository at this point in the history
  • Loading branch information
cryscan committed Mar 29, 2024
1 parent 0c118c6 commit b016a76
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 16 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ keywords = ["deep-learning", "language", "model", "rwkv"]
license = "MIT OR Apache-2.0"
name = "web-rwkv"
repository = "https://github.com/cryscan/web-rwkv"
version = "0.6.35"
version = "0.6.36"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

Expand Down
12 changes: 9 additions & 3 deletions examples/batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,14 @@ fn sample(probs: Vec<f32>, _top_p: f32) -> u16 {
.0 as u16
}

async fn create_context(info: &ModelInfo) -> Result<Context> {
async fn create_context(info: &ModelInfo, _auto: bool) -> Result<Context> {
let instance = Instance::new();
#[cfg(not(debug_assertions))]
let adapter = {
let adapter = if _auto {
instance
.adapter(wgpu::PowerPreference::HighPerformance)
.await?
} else {
let backends = wgpu::Backends::all();
let adapters = instance
.enumerate_adapters(backends)
Expand Down Expand Up @@ -176,7 +180,7 @@ async fn run(cli: Cli) -> Result<()> {
};
let lora = lora.as_deref();

let context = create_context(&info).await?;
let context = create_context(&info, cli.adapter).await?;
match info.version {
ModelVersion::V4 => {
let (model, state) = load_model(
Expand Down Expand Up @@ -375,6 +379,8 @@ struct Cli {
token_chunk_size: usize,
#[arg(short, long, default_value_t = 4)]
batch: usize,
#[arg(short, long, action)]
adapter: bool,
}

#[tokio::main]
Expand Down
12 changes: 9 additions & 3 deletions examples/chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,14 @@ impl Sampler {
}
}

async fn create_context(info: &ModelInfo) -> Result<Context> {
async fn create_context(info: &ModelInfo, _auto: bool) -> Result<Context> {
let instance = Instance::new();
#[cfg(not(debug_assertions))]
let adapter = {
let adapter = if _auto {
instance
.adapter(wgpu::PowerPreference::HighPerformance)
.await?
} else {
let backends = wgpu::Backends::all();
let adapters = instance
.enumerate_adapters(backends)
Expand Down Expand Up @@ -211,7 +215,7 @@ async fn run(cli: Cli) -> Result<()> {
};
let lora = lora.as_deref();

let context = create_context(&info).await?;
let context = create_context(&info, cli.adapter).await?;
match info.version {
ModelVersion::V4 => {
let (model, state) = load_model(
Expand Down Expand Up @@ -407,6 +411,8 @@ struct Cli {
token_chunk_size: usize,
#[command(flatten)]
sampler: Sampler,
#[arg(short, long, action)]
adapter: bool,
}

#[derive(Debug, Deserialize)]
Expand Down
12 changes: 9 additions & 3 deletions examples/gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,14 @@ fn sample(probs: &[f32], _top_p: f32) -> u16 {
.0 as u16
}

async fn create_context(info: &ModelInfo) -> Result<Context> {
async fn create_context(info: &ModelInfo, _auto: bool) -> Result<Context> {
let instance = Instance::new();
#[cfg(not(debug_assertions))]
let adapter = {
let adapter = if _auto {
instance
.adapter(wgpu::PowerPreference::HighPerformance)
.await?
} else {
let backends = wgpu::Backends::all();
let adapters = instance
.enumerate_adapters(backends)
Expand Down Expand Up @@ -145,7 +149,7 @@ async fn run(cli: Cli) -> Result<()> {
};
let lora = lora.as_deref();

let context = create_context(&info).await?;
let context = create_context(&info, cli.adapter).await?;
match info.version {
ModelVersion::V4 => {
let (model, state) = load_model(
Expand Down Expand Up @@ -293,6 +297,8 @@ struct Cli {
turbo: bool,
#[arg(long, default_value_t = 32)]
token_chunk_size: usize,
#[arg(short, long, action)]
adapter: bool,
}

#[tokio::main]
Expand Down
12 changes: 9 additions & 3 deletions examples/inspector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,14 @@ fn sample(probs: &[f32], _top_p: f32) -> u16 {
.0 as u16
}

async fn create_context(info: &ModelInfo) -> Result<Context> {
async fn create_context(info: &ModelInfo, _auto: bool) -> Result<Context> {
let instance = Instance::new();
#[cfg(not(debug_assertions))]
let adapter = {
let adapter = if _auto {
instance
.adapter(wgpu::PowerPreference::HighPerformance)
.await?
} else {
let backends = wgpu::Backends::all();
let adapters = instance
.enumerate_adapters(backends)
Expand Down Expand Up @@ -169,7 +173,7 @@ async fn run(cli: Cli) -> Result<()> {
};
let lora = lora.as_deref();

let context = create_context(&info).await?;
let context = create_context(&info, cli.adapter).await?;
let (model, state) = load_model::<v5::Model<f16>, _>(
&context,
&data,
Expand Down Expand Up @@ -325,6 +329,8 @@ struct Cli {
token_chunk_size: usize,
#[arg(short, long)]
prompt: Option<String>,
#[arg(short, long, action)]
adapter: bool,
}

#[tokio::main]
Expand Down
12 changes: 9 additions & 3 deletions examples/serialization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,14 @@ fn sample(probs: &[f32], _top_p: f32) -> u16 {
.0 as u16
}

async fn create_context(info: &ModelInfo) -> Result<Context> {
async fn create_context(info: &ModelInfo, _auto: bool) -> Result<Context> {
let instance = Instance::new();
#[cfg(not(debug_assertions))]
let adapter = {
let adapter = if _auto {
instance
.adapter(wgpu::PowerPreference::HighPerformance)
.await?
} else {
let backends = wgpu::Backends::all();
let adapters = instance
.enumerate_adapters(backends)
Expand Down Expand Up @@ -146,7 +150,7 @@ async fn run(cli: Cli) -> Result<()> {
};
let lora = lora.as_deref();

let context = create_context(&info).await?;
let context = create_context(&info, cli.adapter).await?;
match info.version {
ModelVersion::V4 => {
let (model, state) = load_model(
Expand Down Expand Up @@ -319,6 +323,8 @@ struct Cli {
token_chunk_size: usize,
#[arg(short, long, value_name = "FILE")]
output: Option<PathBuf>,
#[arg(short, long, action)]
adapter: bool,
}

#[tokio::main]
Expand Down

0 comments on commit b016a76

Please sign in to comment.