
[wip] inference service #12

Merged · 30 commits · Apr 2, 2024
Commits
8237fb9
Refactor core_thread.rs and remove service.rs
jorgeantonio21 Mar 23, 2024
ddbdc1e
Refactor: Moved , , and related types to a separate module
jorgeantonio21 Mar 23, 2024
df7a457
feat: Update dependencies and refactor configuration handling
jorgeantonio21 Mar 25, 2024
919298a
add hugging face client logic
jorgeantonio21 Mar 25, 2024
b8c8edc
feat: Add hf-hub crate version 0.3.2
jorgeantonio21 Mar 25, 2024
9807bc9
refactor: Rename storage_base_path to storage_folder
jorgeantonio21 Mar 26, 2024
ab64ad3
Refactor core.rs and main.rs, introducing tracing for improved debugg…
jorgeantonio21 Mar 26, 2024
7a22bd3
Add tracing-subscriber crate to Cargo.toml
jorgeantonio21 Mar 26, 2024
65f6e01
address PR comments
jorgeantonio21 Mar 26, 2024
1ce0c12
refactor core thread to model thread, to facilitate models running in…
jorgeantonio21 Mar 27, 2024
156b398
remove core, rename core_thread to model_thread, and work on setting …
jorgeantonio21 Mar 27, 2024
6990a12
add model_thread.rs, after renaming
jorgeantonio21 Mar 27, 2024
d52f113
intermediate steps
jorgeantonio21 Mar 27, 2024
618bea8
intermediate steps
jorgeantonio21 Mar 27, 2024
54e0abd
intermediate steps
jorgeantonio21 Mar 27, 2024
e60b586
address new PR comments
jorgeantonio21 Mar 27, 2024
58cfca5
add test to config construction
jorgeantonio21 Mar 27, 2024
1cdb66a
remove unused code
jorgeantonio21 Mar 27, 2024
b56d0b5
remove full dependency of std::sync
jorgeantonio21 Mar 27, 2024
04b6d6c
change to main branch
jorgeantonio21 Mar 28, 2024
d0f6dff
Merge pull request #13 from atoma-network/experiments
jorgeantonio21 Mar 30, 2024
a19817e
add model trait interface and refactor code to be more general
jorgeantonio21 Mar 31, 2024
4a12b71
rename InferenceService to ModelService
jorgeantonio21 Mar 31, 2024
cddb534
simplify code
jorgeantonio21 Mar 31, 2024
f673dea
remove fetch method from ModelTrait
jorgeantonio21 Mar 31, 2024
e403ba9
cargo fmt
jorgeantonio21 Mar 31, 2024
b8a51ac
rename
jorgeantonio21 Mar 31, 2024
6193465
remove unused error fields
jorgeantonio21 Mar 31, 2024
303aedf
removed unused Builder from ModelTrait associated type
jorgeantonio21 Apr 1, 2024
415974f
merge main and resolve conflicts
jorgeantonio21 Apr 2, 2024
8 changes: 8 additions & 0 deletions Cargo.toml
@@ -12,9 +12,17 @@ candle = { git = "https://github.com/huggingface/candle", package = "candle-core", version = "0.4.2" }
candle-flash-attn = { git = "https://github.com/huggingface/candle", package = "candle-flash-attn", version = "0.4.2" }
candle-nn = { git = "https://github.com/huggingface/candle", package = "candle-nn", version = "0.4.2" }
candle-transformers = { git = "https://github.com/huggingface/candle", package = "candle-transformers", version = "0.4.2" }
config = "0.14.0"
ed25519-consensus = "2.1.0"
futures = "0.3.30"
hf-hub = "0.3.2"
serde = "1.0.197"
serde_json = "1.0.114"
rand = "0.8.5"
reqwest = "0.12.1"
thiserror = "1.0.58"
tokenizers = "0.15.2"
tokio = "1.36.0"
toml = "0.8.12"
tracing = "0.1.40"
tracing-subscriber = "0.3.18"
11 changes: 11 additions & 0 deletions atoma-inference/Cargo.toml
@@ -11,12 +11,23 @@ candle.workspace = true
candle-flash-attn = { workspace = true, optional = true }
candle-nn.workspace = true
candle-transformers.workspace = true
config.workspace = true
ed25519-consensus.workspace = true
futures.workspace = true
hf-hub.workspace = true
reqwest = { workspace = true, features = ["json"] }
serde = { workspace = true, features = ["derive"] }
serde_json.workspace = true
thiserror.workspace = true
tokenizers.workspace = true
tokio = { workspace = true, features = ["full", "tracing"] }
tracing.workspace = true
tracing-subscriber.workspace = true

[dev-dependencies]
rand.workspace = true
toml.workspace = true


[features]
accelerate = ["candle/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
117 changes: 117 additions & 0 deletions atoma-inference/src/apis/hugging_face.rs
@@ -0,0 +1,117 @@
use std::path::PathBuf;

use async_trait::async_trait;
use hf_hub::api::sync::{Api, ApiBuilder};

use crate::models::ModelId;

use super::{ApiError, ApiTrait};

struct FilePaths {
    file_paths: Vec<String>,
}

fn get_model_safe_tensors_from_hf(model_id: &ModelId) -> (String, FilePaths) {
    match model_id.as_str() {
        "Llama2_7b" => (
            String::from("meta-llama/Llama-2-7b-hf"),
            FilePaths {
                file_paths: vec![
                    "model-00001-of-00002.safetensors".to_string(),
                    "model-00002-of-00002.safetensors".to_string(),
                ],
            },
        ),
        "Mamba3b" => (
            String::from("state-spaces/mamba-2.8b-hf"),
            FilePaths {
                file_paths: vec![
                    "model-00001-of-00003.safetensors".to_string(),
                    "model-00002-of-00003.safetensors".to_string(),
                    "model-00003-of-00003.safetensors".to_string(),
                ],
            },
        ),
        "Mistral7b" => (
            String::from("mistralai/Mistral-7B-Instruct-v0.2"),
            FilePaths {
                file_paths: vec![
                    "model-00001-of-00003.safetensors".to_string(),
                    "model-00002-of-00003.safetensors".to_string(),
                    "model-00003-of-00003.safetensors".to_string(),
                ],
            },
        ),
        "Mixtral8x7b" => (
            String::from("mistralai/Mixtral-8x7B-Instruct-v0.1"),
            FilePaths {
                file_paths: vec![
                    "model-00001-of-00019.safetensors".to_string(),
                    "model-00002-of-00019.safetensors".to_string(),
                    "model-00003-of-00019.safetensors".to_string(),
                    "model-00004-of-00019.safetensors".to_string(),
                    "model-00005-of-00019.safetensors".to_string(),
                    "model-00006-of-00019.safetensors".to_string(),
                    "model-00007-of-00019.safetensors".to_string(),
                    "model-00008-of-00019.safetensors".to_string(),
                    "model-00009-of-00019.safetensors".to_string(),
                    "model-00010-of-00019.safetensors".to_string(),
                    "model-00011-of-00019.safetensors".to_string(),
                    "model-00012-of-00019.safetensors".to_string(),
                    "model-00013-of-00019.safetensors".to_string(),
                    "model-00014-of-00019.safetensors".to_string(),
                    "model-00015-of-00019.safetensors".to_string(),
                    "model-00016-of-00019.safetensors".to_string(),
                    "model-00017-of-00019.safetensors".to_string(),
                    "model-00018-of-00019.safetensors".to_string(),
                    "model-00019-of-00019.safetensors".to_string(),
                ],
            },
        ),
        "StableDiffusion2" => (
            String::from("stabilityai/stable-diffusion-2"),
            FilePaths {
                file_paths: vec!["768-v-ema.safetensors".to_string()],
            },
        ),
        "StableDiffusionXl" => (
            String::from("stabilityai/stable-diffusion-xl-base-1.0"),
            FilePaths {
                file_paths: vec![
                    "sd_xl_base_1.0.safetensors".to_string(),
                    "sd_xl_base_1.0_0.9vae.safetensors".to_string(),
                    "sd_xl_offset_example-lora_1.0.safetensors".to_string(),
                ],
            },
        ),
        _ => {
            panic!("Invalid model id")
        }
    }
}

#[async_trait]
impl ApiTrait for Api {
    fn create(api_key: String, cache_dir: PathBuf) -> Result<Self, ApiError>
    where
        Self: Sized,
    {
        Ok(ApiBuilder::new()
            .with_progress(true)
            .with_token(Some(api_key))
            .with_cache_dir(cache_dir)
            .build()?)
    }

    fn fetch(&self, model_id: &ModelId) -> Result<Vec<PathBuf>, ApiError> {
        let (model_path, files) = get_model_safe_tensors_from_hf(model_id);
        let api_repo = self.model(model_path);
        let mut path_bufs = Vec::with_capacity(files.file_paths.len());

        for file in files.file_paths {
            path_bufs.push(api_repo.get(&file)?);
        }

        Ok(path_bufs)
    }
}
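
Reviewer note: a minimal sketch of how this Api implementation would be exercised end to end. It assumes the library crate is named inference (as in the commented imports in main.rs below) and that ModelId is a String alias; the token and cache path are placeholders, not values from this PR.

use std::path::PathBuf;

use hf_hub::api::sync::Api;
use inference::apis::{ApiError, ApiTrait};

fn run() -> Result<(), ApiError> {
    // Build the client once; the token authenticates gated repos such as Llama 2.
    let api = Api::create("<hf-api-token>".to_string(), PathBuf::from("./model-cache"))?;

    // The id is matched against the table in get_model_safe_tensors_from_hf.
    let shards = api.fetch(&"Mistral7b".to_string())?;
    assert_eq!(shards.len(), 3); // Mistral7b is registered with three shards
    Ok(())
}
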
29 changes: 29 additions & 0 deletions atoma-inference/src/apis/mod.rs
@@ -0,0 +1,29 @@
pub mod hugging_face;
use hf_hub::api::sync::ApiError as HuggingFaceError;

use std::path::PathBuf;

use thiserror::Error;

use crate::models::ModelId;

#[derive(Debug, Error)]
pub enum ApiError {
    #[error("Api Error: `{0}`")]
    ApiError(String),
    #[error("HuggingFace API error: `{0}`")]
    HuggingFaceError(HuggingFaceError),
}

impl From<HuggingFaceError> for ApiError {
    fn from(error: HuggingFaceError) -> Self {
        Self::HuggingFaceError(error)
    }
}

pub trait ApiTrait: Send {
    fn fetch(&self, model_id: &ModelId) -> Result<Vec<PathBuf>, ApiError>;
    fn create(api_key: String, cache_dir: PathBuf) -> Result<Self, ApiError>
    where
        Self: Sized;
}
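
Reviewer note: fetch takes &self, so the trait stays object-safe; create returns Self and is therefore gated behind where Self: Sized, which keeps it out of any vtable while still allowing generic construction (as in the commented ModelService::start::<Model, Api> call in main.rs). A hypothetical second backend, with MockApi invented purely for illustration, would slot in like this:

use std::path::PathBuf;

use inference::apis::{ApiError, ApiTrait};
use inference::models::ModelId;

// Invented test double: serves "weights" straight from a local directory.
struct MockApi {
    cache_dir: PathBuf,
}

impl ApiTrait for MockApi {
    fn create(_api_key: String, cache_dir: PathBuf) -> Result<Self, ApiError> {
        Ok(Self { cache_dir })
    }

    fn fetch(&self, model_id: &ModelId) -> Result<Vec<PathBuf>, ApiError> {
        // No network involved: assume <cache_dir>/<model_id> already holds the file.
        Ok(vec![self.cache_dir.join(model_id)])
    }
}
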
66 changes: 0 additions & 66 deletions atoma-inference/src/config.rs

This file was deleted.

89 changes: 0 additions & 89 deletions atoma-inference/src/core_thread.rs

This file was deleted.

7 changes: 4 additions & 3 deletions atoma-inference/src/lib.rs
@@ -1,6 +1,7 @@
-pub mod config;
-pub mod core_thread;
-pub mod models;
+pub mod model_thread;
 pub mod service;
 pub mod specs;
 pub mod types;
+
+pub mod apis;
+pub mod models;
17 changes: 15 additions & 2 deletions atoma-inference/src/main.rs
@@ -1,3 +1,16 @@
-fn main() {
-    println!("Hello, world!");
+// use hf_hub::api::sync::Api;
+// use inference::service::ModelService;
+
+#[tokio::main]
+async fn main() {
+    tracing_subscriber::fmt::init();
+
+    // let (_, receiver) = tokio::sync::mpsc::channel(32);
+
+    // let _ = ModelService::start::<Model, Api>(
+    //     "../inference.toml".parse().unwrap(),
+    //     "../private_key".parse().unwrap(),
+    //     receiver,
+    // )
+    // .expect("Failed to start inference service");
 }
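
Reviewer note: with tracing_subscriber::fmt::init() installed, anything the service logs through the tracing macros reaches stdout; a minimal probe of the new setup (not part of this PR, and assuming the fmt subscriber's default INFO level) would be:

use tracing::info;

#[tokio::main]
async fn main() {
    tracing_subscriber::fmt::init();

    // Rendered by the fmt subscriber installed on the line above.
    info!(version = env!("CARGO_PKG_VERSION"), "inference service starting");
}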