Merge pull request #1795 from fermyon/allow-other-models
Allow other AI models besides well-known ones
rylev authored Oct 5, 2023
2 parents 2269651 + e74b2a8 commit 362c75d
Showing 4 changed files with 135 additions and 24 deletions.
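For orientation before the diffs: the lookup order this change introduces, as read from the code below, is (1) well-known model names map straight to an architecture and may live at the registry root, and (2) any other name is located by walking a <registry>/<architecture>/<model-file> layout, with the architecture parsed from the directory name. The following is a minimal, synchronous sketch of that order using only the standard library; the helper names are hypothetical, and the real code in crates/llm-local is async and returns wasi_llm errors.

// Illustrative sketch only, not part of the diff.
use std::path::{Path, PathBuf};

fn well_known_arch(model: &str) -> Option<&'static str> {
    // Well-known names imply their architecture directly.
    match model {
        "llama2-chat" | "code_llama" => Some("llama"),
        _ => None,
    }
}

fn resolve_model(registry: &Path, model: &str) -> Option<(PathBuf, String)> {
    // 1. A well-known model may sit at the registry root under its own name.
    if let Some(arch) = well_known_arch(model) {
        let direct = registry.join(model);
        if direct.exists() {
            return Some((direct, arch.to_owned()));
        }
    }
    // 2. Otherwise scan <registry>/<arch>/ directories for a file whose name
    //    matches the requested model; the directory name supplies the arch.
    for arch_dir in std::fs::read_dir(registry).ok()?.flatten() {
        if !arch_dir.path().is_dir() {
            continue;
        }
        let candidate = arch_dir.path().join(model);
        if candidate.is_file() {
            return Some((candidate, arch_dir.file_name().to_string_lossy().into_owned()));
        }
    }
    None
}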
46 changes: 46 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default.

94 changes: 88 additions & 6 deletions crates/llm-local/src/lib.rs
@@ -6,11 +6,11 @@ use candle::DType;
use candle_nn::VarBuilder;
use llm::{
InferenceFeedback, InferenceParameters, InferenceResponse, InferenceSessionConfig, Model,
ModelKVMemoryType, ModelParameters,
ModelArchitecture, ModelKVMemoryType, ModelParameters,
};
use rand::SeedableRng;
use spin_core::async_trait;
use spin_llm::{model_arch, model_name, LlmEngine, MODEL_ALL_MINILM_L6_V2};
use spin_llm::{LlmEngine, MODEL_ALL_MINILM_L6_V2};
use spin_world::llm::{self as wasi_llm};
use std::{
collections::hash_map::Entry,
@@ -170,14 +170,22 @@ impl LocalLlmEngine {
&mut self,
model: wasi_llm::InferencingModel,
) -> Result<Arc<dyn Model>, wasi_llm::Error> {
let model_name = model_name(&model)?;
let use_gpu = self.use_gpu;
let progress_fn = |_| {};
let model = match self.inferencing_models.entry((model_name.into(), use_gpu)) {
let model = match self.inferencing_models.entry((model.clone(), use_gpu)) {
Entry::Occupied(o) => o.get().clone(),
Entry::Vacant(v) => v
.insert({
let path = self.registry.join(model_name);
let (path, arch) = if let Some(arch) = well_known_inferencing_model_arch(&model) {
let model_binary = self.registry.join(&model);
if model_binary.exists() {
(model_binary, arch.to_owned())
} else {
walk_registry_for_model(&self.registry, model).await?
}
} else {
walk_registry_for_model(&self.registry, model).await?
};
if !self.registry.exists() {
return Err(wasi_llm::Error::RuntimeError(
format!("The directory expected to house the inferencing model '{}' does not exist.", self.registry.display())
@@ -199,7 +207,7 @@ impl LocalLlmEngine {
n_gqa: None,
};
let model = llm::load_dynamic(
Some(model_arch(&model)?),
Some(arch),
&path,
llm::TokenizerSource::Embedded,
params,
@@ -223,6 +231,80 @@ impl LocalLlmEngine {
}
}

/// Get the model binary and arch from walking the registry file structure
async fn walk_registry_for_model(
registry_path: &Path,
model: String,
) -> Result<(PathBuf, ModelArchitecture), wasi_llm::Error> {
let mut arch_dirs = tokio::fs::read_dir(registry_path).await.map_err(|e| {
wasi_llm::Error::RuntimeError(format!(
"Could not read model registry directory '{}': {e}",
registry_path.display()
))
})?;
let mut result = None;
'outer: while let Some(arch_dir) = arch_dirs.next_entry().await.map_err(|e| {
wasi_llm::Error::RuntimeError(format!(
"Failed to read arch directory in model registry: {e}"
))
})? {
if arch_dir
.file_type()
.await
.map_err(|e| {
wasi_llm::Error::RuntimeError(format!(
"Could not read file type of '{}' dir: {e}",
arch_dir.path().display()
))
})?
.is_file()
{
continue;
}
let mut model_files = tokio::fs::read_dir(arch_dir.path()).await.map_err(|e| {
wasi_llm::Error::RuntimeError(format!(
"Error reading architecture directory in model registry: {e}"
))
})?;
while let Some(model_file) = model_files.next_entry().await.map_err(|e| {
wasi_llm::Error::RuntimeError(format!(
"Error reading model file in model registry: {e}"
))
})? {
if model_file
.file_name()
.to_str()
.map(|m| m == model)
.unwrap_or_default()
{
let arch = arch_dir.file_name();
let arch = arch
.to_str()
.ok_or(wasi_llm::Error::ModelNotSupported)?
.parse()
.map_err(|_| wasi_llm::Error::ModelNotSupported)?;
result = Some((model_file.path(), arch));
break 'outer;
}
}
}

result.ok_or_else(|| {
wasi_llm::Error::InvalidInput(format!(
"no model directory found in registry for model '{model}'"
))
})
}

fn well_known_inferencing_model_arch(
model: &wasi_llm::InferencingModel,
) -> Option<ModelArchitecture> {
match model.as_str() {
"llama2-chat" | "code_llama" => Some(ModelArchitecture::Llama),
_ => None,
}
}

async fn generate_embeddings(
data: Vec<String>,
model: Arc<(tokenizers::Tokenizer, BertModel)>,
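A note on walk_registry_for_model above: the architecture comes solely from the arch directory's name via .parse(), so the directory must be named something the llm crate's ModelArchitecture type can parse (it implements FromStr, which is what .parse() relies on); an unrecognized name is reported as ModelNotSupported. A minimal sketch of that single step, assuming that FromStr impl:

// Hypothetical helper isolating the parse step used in walk_registry_for_model.
use llm::ModelArchitecture;
use std::str::FromStr;

fn arch_from_dir_name(dir_name: &str) -> Option<ModelArchitecture> {
    // e.g. a directory named "llama" is expected to parse to the Llama architecture.
    ModelArchitecture::from_str(dir_name).ok()
}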
2 changes: 1 addition & 1 deletion crates/llm/Cargo.toml
@@ -9,7 +9,7 @@ anyhow = "1.0"
bytesize = "1.1"
llm = { git = "https://github.com/rustformers/llm", rev = "2f6ffd4435799ceaa1d1bcb5a8790e5b3e0c5663", features = [
"tokenizers-remote",
"llama",
"models",
], default-features = false }
spin-app = { path = "../app" }
spin-core = { path = "../core" }
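Design note on the feature switch above: replacing the "llama" feature with "models" is presumably what lets load_dynamic handle architectures beyond the Llama family, since the rustformers/llm "models" feature appears to enable all of the crate's bundled model implementations rather than only Llama; without it, the registry walk could resolve an architecture that the binary cannot actually load.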
17 changes: 0 additions & 17 deletions crates/llm/src/lib.rs
@@ -1,6 +1,5 @@
pub mod host_component;

use llm::ModelArchitecture;
use spin_app::MetadataKey;
use spin_core::async_trait;
use spin_world::llm::{self as wasi_llm};
@@ -72,22 +71,6 @@ impl wasi_llm::Host for LlmDispatch {
}
}

pub fn model_name(model: &wasi_llm::InferencingModel) -> Result<&str, wasi_llm::Error> {
match model.as_str() {
"llama2-chat" | "codellama-instruct" => Ok(model.as_str()),
_ => Err(wasi_llm::Error::ModelNotSupported),
}
}

pub fn model_arch(
model: &wasi_llm::InferencingModel,
) -> Result<ModelArchitecture, wasi_llm::Error> {
match model.as_str() {
"llama2-chat" | "codellama-instruct" => Ok(ModelArchitecture::Llama),
_ => Err(wasi_llm::Error::ModelNotSupported),
}
}

fn access_denied_error(model: &str) -> wasi_llm::Error {
wasi_llm::Error::InvalidInput(format!(
"The component does not have access to use '{model}'. To give the component access, add '{model}' to the 'ai_models' key for the component in your spin.toml manifest"
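(The removed model_name and model_arch helpers were the hard-coded allowlist this PR replaces; their role is taken over by well_known_inferencing_model_arch plus the registry walk in crates/llm-local above.)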
