-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: simplify download management, model file should be able to indi…
…vidually introduced
- Loading branch information
Showing
13 changed files
with
224 additions
and
245 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
use std::{fs, path::PathBuf}; | ||
|
||
use anyhow::Result; | ||
use serde::{Deserialize, Serialize}; | ||
|
||
use crate::path::models_dir; | ||
|
||
#[derive(Serialize, Deserialize)] | ||
pub struct ModelInfo { | ||
pub name: String, | ||
#[serde(skip_serializing_if = "Option::is_none")] | ||
pub prompt_template: Option<String>, | ||
#[serde(skip_serializing_if = "Option::is_none")] | ||
pub chat_template: Option<String>, | ||
pub urls: Vec<String>, | ||
pub sha256: String, | ||
} | ||
|
||
fn models_json_file(registry: &str) -> PathBuf { | ||
models_dir().join(registry).join("models.json") | ||
} | ||
|
||
async fn load_remote_registry(registry: &str) -> Result<Vec<ModelInfo>> { | ||
let value = reqwest::get(format!( | ||
"https://raw.githubusercontent.com/{}/registry-tabby/main/models.json", | ||
registry | ||
)) | ||
.await? | ||
.json() | ||
.await?; | ||
fs::create_dir_all(models_dir().join(registry))?; | ||
serdeconv::to_json_file(&value, models_json_file(registry))?; | ||
Ok(value) | ||
} | ||
|
||
fn load_local_registry(registry: &str) -> Result<Vec<ModelInfo>> { | ||
Ok(serdeconv::from_json_file(models_json_file(registry))?) | ||
} | ||
|
||
#[derive(Default)] | ||
pub struct ModelRegistry { | ||
pub name: String, | ||
pub models: Vec<ModelInfo>, | ||
} | ||
|
||
impl ModelRegistry { | ||
pub async fn new(registry: &str) -> Self { | ||
Self { | ||
name: registry.to_owned(), | ||
models: load_remote_registry(registry).await.unwrap_or_else(|err| { | ||
load_local_registry(registry).unwrap_or_else(|_| { | ||
panic!( | ||
"Failed to fetch model organization <{}>: {:?}", | ||
registry, err | ||
) | ||
}) | ||
}), | ||
} | ||
} | ||
|
||
pub fn get_model_path(&self, name: &str) -> PathBuf { | ||
models_dir() | ||
.join(&self.name) | ||
.join(name) | ||
.join("ggml/q8_0.v2.gguf") | ||
} | ||
|
||
pub fn get_model_info(&self, name: &str) -> &ModelInfo { | ||
self.models | ||
.iter() | ||
.find(|x| x.name == name) | ||
.unwrap_or_else(|| panic!("Invalid model_id <{}/{}>", self.name, name)) | ||
} | ||
} | ||
|
||
pub fn parse_model_id(model_id: &str) -> (&str, &str) { | ||
let parts: Vec<_> = model_id.split('/').collect(); | ||
if parts.len() != 2 { | ||
panic!("Invalid model id {}", model_id); | ||
} | ||
|
||
(parts[0], parts[1]) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.