Skip to content

Commit

Permalink
Adapter selection.
Browse files Browse the repository at this point in the history
  • Loading branch information
cryscan committed Aug 1, 2023
1 parent 4c3f430 commit 8781b9e
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 29 deletions.
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "web-rwkv"
version = "0.1.16"
version = "0.1.17"
edition = "2021"
authors = ["Zhenyuan Zhang <[email protected]>"]
license = "MIT OR Apache-2.0"
Expand Down Expand Up @@ -33,6 +33,7 @@ memmap2 = "0.7"
itertools = "0.11"
fastrand = "2.0"
clap = { version = "4.3", features = ["derive"] }
dialoguer = "0.10"

[target.'cfg(target_arch = "wasm32")'.dev-dependencies]
wasm-bindgen = "0.2"
Expand Down
14 changes: 12 additions & 2 deletions examples/chat.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
use ahash::{HashMap, HashMapExt};
use anyhow::Result;
use clap::{Args, Parser};
use dialoguer::{theme::ColorfulTheme, Select};
use itertools::Itertools;
use memmap2::Mmap;
use std::{
fs::File,
io::{BufReader, Read, Write},
path::PathBuf,
};
use web_rwkv::{Environment, LayerFlags, Model, ModelBuilder, Quantization, Tokenizer};
use web_rwkv::{Environment, Instance, LayerFlags, Model, ModelBuilder, Quantization, Tokenizer};

#[derive(Debug, Clone, Args)]
struct Sampler {
Expand Down Expand Up @@ -60,7 +61,16 @@ impl Sampler {
}

async fn create_environment() -> Result<Environment> {
let env = Environment::create().await?;
let instance = Instance::new();
let adapters = instance.adapters();
let selection = Select::with_theme(&ColorfulTheme::default())
.with_prompt("Please select an adapter")
.default(0)
.items(&adapters)
.interact()?;

let adapter = instance.select_adapter(selection)?;
let env = Environment::new(adapter).await?;
println!("{:#?}", env.adapter.get_info());
Ok(env)
}
Expand Down
14 changes: 12 additions & 2 deletions examples/gen.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use anyhow::Result;
use clap::Parser;
use dialoguer::{theme::ColorfulTheme, Select};
use itertools::Itertools;
use memmap2::Mmap;
use std::{
Expand All @@ -8,7 +9,7 @@ use std::{
path::PathBuf,
time::Instant,
};
use web_rwkv::{Environment, LayerFlags, Model, ModelBuilder, Quantization, Tokenizer};
use web_rwkv::{Environment, Instance, LayerFlags, Model, ModelBuilder, Quantization, Tokenizer};

fn sample(probs: Vec<f32>, top_p: f32) -> u16 {
let sorted = probs
Expand Down Expand Up @@ -37,7 +38,16 @@ fn sample(probs: Vec<f32>, top_p: f32) -> u16 {
}

async fn create_environment() -> Result<Environment> {
let env = Environment::create().await?;
let instance = Instance::new();
let adapters = instance.adapters();
let selection = Select::with_theme(&ColorfulTheme::default())
.with_prompt("Please select an adapter")
.default(0)
.items(&adapters)
.interact()?;

let adapter = instance.select_adapter(selection)?;
let env = Environment::new(adapter).await?;
println!("{:#?}", env.adapter.get_info());
Ok(env)
}
Expand Down
72 changes: 49 additions & 23 deletions src/environment.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,55 @@
use anyhow::Result;
use std::sync::Arc;
use wgpu::{
Adapter, Backends, Device, DeviceDescriptor, Dx12Compiler, Instance, InstanceDescriptor,
PowerPreference, Queue, RequestAdapterOptions,
};
use wgpu::{Adapter, Backends, Device, DeviceDescriptor, Dx12Compiler, InstanceDescriptor, Queue};

#[derive(Clone)]
pub struct Instance(pub Arc<wgpu::Instance>);

impl std::ops::Deref for Instance {
type Target = wgpu::Instance;

fn deref(&self) -> &Self::Target {
&self.0
}
}

impl Default for Instance {
fn default() -> Self {
Self::new()
}
}

impl Instance {
pub const BACKENDS: Backends = Backends::PRIMARY;

pub fn new() -> Self {
let instance = wgpu::Instance::new(InstanceDescriptor {
backends: Self::BACKENDS,
dx12_shader_compiler: Dx12Compiler::Dxc {
dxil_path: None,
dxc_path: None,
},
});
Self(Arc::new(instance))
}

pub fn adapters(&self) -> Vec<String> {
self.enumerate_adapters(Self::BACKENDS)
.map(|adapter| {
let info = adapter.get_info();
format!("{} ({:?})", info.name, info.backend)
})
.collect()
}

pub fn select_adapter(&self, selection: usize) -> Result<Adapter, CreateEnvironmentError> {
self.enumerate_adapters(Self::BACKENDS)
.nth(selection)
.ok_or(CreateEnvironmentError::RequestAdapterFailed)
}
}

#[derive(Clone)]
pub struct Environment {
pub instance: Arc<Instance>,
pub adapter: Arc<Adapter>,
pub device: Arc<Device>,
pub queue: Arc<Queue>,
Expand All @@ -31,22 +73,7 @@ impl std::fmt::Display for CreateEnvironmentError {
impl std::error::Error for CreateEnvironmentError {}

impl Environment {
pub async fn create() -> Result<Self> {
let instance = Instance::new(InstanceDescriptor {
backends: Backends::PRIMARY,
dx12_shader_compiler: Dx12Compiler::Dxc {
dxil_path: None,
dxc_path: None,
},
});
let adapter = instance
.request_adapter(&RequestAdapterOptions {
power_preference: PowerPreference::HighPerformance,
force_fallback_adapter: false,
compatible_surface: None,
})
.await
.ok_or(CreateEnvironmentError::RequestAdapterFailed)?;
pub async fn new(adapter: Adapter) -> Result<Self, CreateEnvironmentError> {
let (device, queue) = adapter
.request_device(
&DeviceDescriptor {
Expand All @@ -60,7 +87,6 @@ impl Environment {
.map_err(|_| CreateEnvironmentError::RequestDeviceFailed)?;

Ok(Self {
instance: Arc::new(instance),
adapter: Arc::new(adapter),
device: Arc::new(device),
queue: Arc::new(queue),
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ mod environment;
mod model;
mod tokenizer;

pub use environment::{CreateEnvironmentError, Environment};
pub use environment::{CreateEnvironmentError, Environment, Instance};
pub use model::{
BackedModelState, LayerFlags, Model, ModelBuffer, ModelBuilder, ModelState, Quantization,
};
Expand Down

0 comments on commit 8781b9e

Please sign in to comment.