
[wip] inference service #12

Merged
merged 30 commits on Apr 2, 2024
Changes from 13 commits

30 commits
8237fb9
Refactor core_thread.rs and remove service.rs
jorgeantonio21 Mar 23, 2024
ddbdc1e
Refactor: Moved , , and related types to a separate module
jorgeantonio21 Mar 23, 2024
df7a457
feat: Update dependencies and refactor configuration handling
jorgeantonio21 Mar 25, 2024
919298a
add hugging face client logic
jorgeantonio21 Mar 25, 2024
b8c8edc
feat: Add hf-hub crate version 0.3.2
jorgeantonio21 Mar 25, 2024
9807bc9
refactor: Rename storage_base_path to storage_folder
jorgeantonio21 Mar 26, 2024
ab64ad3
Refactor core.rs and main.rs, introducing tracing for improved debugg…
jorgeantonio21 Mar 26, 2024
7a22bd3
Add tracing-subscriber crate to Cargo.toml
jorgeantonio21 Mar 26, 2024
65f6e01
address PR comments
jorgeantonio21 Mar 26, 2024
1ce0c12
refactor core thread to model thread, to facilitate models running in…
jorgeantonio21 Mar 27, 2024
156b398
remove core, rename core_thread to model_thread, and work on setting …
jorgeantonio21 Mar 27, 2024
6990a12
add model_thread.rs, after renaming
jorgeantonio21 Mar 27, 2024
d52f113
intermediate steps
jorgeantonio21 Mar 27, 2024
618bea8
intermediate steps
jorgeantonio21 Mar 27, 2024
54e0abd
intermediate steps
jorgeantonio21 Mar 27, 2024
e60b586
address new PR comments
jorgeantonio21 Mar 27, 2024
58cfca5
add test to config construction
jorgeantonio21 Mar 27, 2024
1cdb66a
remove unused code
jorgeantonio21 Mar 27, 2024
b56d0b5
remove full dependency of std::sync
jorgeantonio21 Mar 27, 2024
04b6d6c
change to main branch
jorgeantonio21 Mar 28, 2024
d0f6dff
Merge pull request #13 from atoma-network/experiments
jorgeantonio21 Mar 30, 2024
a19817e
add model trait interface and refactor code to be more general
jorgeantonio21 Mar 31, 2024
4a12b71
rename InferenceService to ModelService
jorgeantonio21 Mar 31, 2024
cddb534
simplify code
jorgeantonio21 Mar 31, 2024
f673dea
remove fetch method from ModelTrait
jorgeantonio21 Mar 31, 2024
e403ba9
cargo fmt
jorgeantonio21 Mar 31, 2024
b8a51ac
rename
jorgeantonio21 Mar 31, 2024
6193465
remove unused error fields
jorgeantonio21 Mar 31, 2024
303aedf
removed unused Builder from ModelTrait associated type
jorgeantonio21 Apr 1, 2024
415974f
merge main and resolve conflicts
jorgeantonio21 Apr 2, 2024
7 changes: 7 additions & 0 deletions Cargo.toml
@@ -11,9 +11,16 @@ async-trait = "0.1.78"
candle = { git = "https://github.com/huggingface/candle", package = "candle-core", version = "0.4.2" }
candle-nn = { git = "https://github.com/huggingface/candle", package = "candle-nn", version = "0.4.2" }
candle-transformers = { git = "https://github.com/huggingface/candle", package = "candle-transformers", version = "0.4.2" }
config = "0.14.0"
ed25519-consensus = "2.1.0"
hf-hub = "0.3.2"
serde = "1.0.197"
serde_json = "1.0.114"
rand = "0.8.5"
reqwest = "0.12.1"
thiserror = "1.0.58"
tokenizers = "0.15.2"
tokio = "1.36.0"
toml = "0.8.12"
tracing = "0.1.40"
tracing-subscriber = "0.3.18"
10 changes: 10 additions & 0 deletions atoma-inference/Cargo.toml
@@ -10,9 +10,19 @@ async-trait.workspace = true
candle.workspace = true
candle-nn.workspace = true
candle-transformers.workspace = true
config.workspace = true
ed25519-consensus.workspace = true
hf-hub.workspace = true
reqwest = { workspace = true, features = ["json"] }
serde = { workspace = true, features = ["derive"] }
serde_json.workspace = true
thiserror.workspace = true
tokenizers.workspace = true
tokio = { workspace = true, features = ["full", "tracing"] }
tracing.workspace = true
tracing-subscriber.workspace = true

[dev-dependencies]
rand.workspace = true
toml.workspace = true

116 changes: 116 additions & 0 deletions atoma-inference/src/apis/hugging_face.rs
@@ -0,0 +1,116 @@
use std::path::PathBuf;

use async_trait::async_trait;
use hf_hub::api::sync::{Api, ApiBuilder};

use crate::models::ModelType;

use super::ApiTrait;

struct FilePaths {
file_paths: Vec<String>,
}

impl ModelType {
fn get_hugging_face_model_path(&self) -> (String, FilePaths) {
match self {
Self::Llama2_7b => (
String::from("meta-llama/Llama-2-7b-hf"),
FilePaths {
file_paths: vec![
"model-00001-of-00002.safetensors".to_string(),
"model-00002-of-00002.safetensors".to_string(),
],
},
),
Self::Mamba3b => (
String::from("state-spaces/mamba-2.8b-hf"),
FilePaths {
file_paths: vec![
"model-00001-of-00003.safetensors".to_string(),
"model-00002-of-00003.safetensors".to_string(),
"model-00003-of-00003.safetensors".to_string(),
],
},
),
Self::Mistral7b => (
String::from("mistralai/Mistral-7B-Instruct-v0.2"),
FilePaths {
file_paths: vec![
"model-00001-of-00003.safetensors".to_string(),
"model-00002-of-00003.safetensors".to_string(),
"model-00003-of-00003.safetensors".to_string(),
],
},
),
Self::Mixtral8x7b => (
String::from("mistralai/Mixtral-8x7B-Instruct-v0.1"),
FilePaths {
file_paths: vec![
"model-00001-of-00019.safetensors".to_string(),
"model-00002-of-00019.safetensors".to_string(),
"model-00003-of-00019.safetensors".to_string(),
"model-00004-of-00019.safetensors".to_string(),
"model-00005-of-00019.safetensors".to_string(),
"model-00006-of-00019.safetensors".to_string(),
"model-00007-of-00019.safetensors".to_string(),
"model-00008-of-00019.safetensors".to_string(),
"model-00009-of-00019.safetensors".to_string(),
"model-000010-of-00019.safetensors".to_string(),
"model-000011-of-00019.safetensors".to_string(),
"model-000012-of-00019.safetensors".to_string(),
"model-000013-of-00019.safetensors".to_string(),
"model-000014-of-00019.safetensors".to_string(),
"model-000015-of-00019.safetensors".to_string(),
"model-000016-of-00019.safetensors".to_string(),
"model-000017-of-00019.safetensors".to_string(),
"model-000018-of-00019.safetensors".to_string(),
"model-000019-of-00019.safetensors".to_string(),
],
},
),
Self::StableDiffusion2 => (
String::from("stabilityai/stable-diffusion-2"),
FilePaths {
file_paths: vec!["768-v-ema.safetensors".to_string()],
},
),
Self::StableDiffusionXl => (
String::from("stabilityai/stable-diffusion-xl-base-1.0"),
FilePaths {
file_paths: vec![
"sd_xl_base_1.0.safetensors".to_string(),
"sd_xl_base_1.0_0.9vae.safetensors".to_string(),
"sd_xl_offset_example-lora_1.0.safetensors".to_string(),
],
},
),
}
}
}

#[async_trait]
impl ApiTrait for Api {
fn create(api_key: String, cache_dir: PathBuf) -> Result<Self, super::ApiError>
where
Self: Sized,
{
Ok(ApiBuilder::new()
.with_progress(true)
.with_token(Some(api_key))
.with_cache_dir(cache_dir)
.build()?)
}

fn fetch(&self, model: ModelType) -> Result<Vec<PathBuf>, super::ApiError> {
let (model_path, files) = model.get_hugging_face_model_path();
let api_repo = self.model(model_path);
let mut path_bufs = Vec::with_capacity(files.file_paths.len());

for file in files.file_paths {
path_bufs.push(api_repo.get(&file)?);
}

Ok(path_bufs)
}
}
29 changes: 29 additions & 0 deletions atoma-inference/src/apis/mod.rs
@@ -0,0 +1,29 @@
pub mod hugging_face;
use hf_hub::api::sync::ApiError as HuggingFaceError;

use std::path::PathBuf;

use thiserror::Error;

use crate::models::ModelType;

#[derive(Debug, Error)]
pub enum ApiError {
#[error("Api Error: `{0}`")]
ApiError(String),
#[error("HuggingFace API error: `{0}`")]
HuggingFaceError(HuggingFaceError),
}

impl From<HuggingFaceError> for ApiError {
fn from(error: HuggingFaceError) -> Self {
Self::HuggingFaceError(error)
}
}

pub trait ApiTrait {
fn fetch(&self, model: ModelType) -> Result<Vec<PathBuf>, ApiError>;
fn create(api_key: String, cache_dir: PathBuf) -> Result<Self, ApiError>
where
Self: Sized;
}
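
For orientation, here is a minimal usage sketch tying the two files above together. It is not part of this PR: the crate path atoma_inference, the token value, and the cache directory are assumptions; only the create and fetch signatures come from the diff.

use std::path::PathBuf;

use hf_hub::api::sync::Api;

use atoma_inference::apis::{ApiError, ApiTrait};
use atoma_inference::models::ModelType;

fn main() -> Result<(), ApiError> {
    // Build an hf-hub client with an access token and a local cache
    // directory (both values here are placeholders).
    let api = Api::create("hf_...".to_string(), PathBuf::from("./model_cache"))?;

    // Download every safetensors shard registered for the model and collect
    // the local paths; hf-hub reuses already-cached files.
    let shard_paths = api.fetch(ModelType::Mistral7b)?;
    for path in shard_paths {
        println!("fetched {}", path.display());
    }
    Ok(())
}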
59 changes: 36 additions & 23 deletions atoma-inference/src/config.rs
@@ -1,59 +1,52 @@
 use std::path::PathBuf;
 
-use crate::{
-    models::ModelType,
-    specs::{HardwareSpec, SoftwareSpec},
-};
+use config::Config;
 use serde::Deserialize;
 
+use crate::{models::ModelType, types::PrecisionBits};
+
 #[derive(Debug, Deserialize)]
 pub struct InferenceConfig {
     api_key: String,
-    hardware_specs: HardwareSpec,
     models: Vec<ModelType>,
-    software_specs: SoftwareSpec,
-    storage_base_path: PathBuf,
+    precision: PrecisionBits,
+    storage_folder: PathBuf,
     tokenizer_file_path: PathBuf,
     tracing: bool,
+    use_kv_cache: Option<bool>,
 }
 
 impl InferenceConfig {
     pub fn new(
         api_key: String,
-        hardware_specs: HardwareSpec,
         models: Vec<ModelType>,
-        software_specs: SoftwareSpec,
-        storage_base_path: PathBuf,
+        precision: PrecisionBits,
+        storage_folder: PathBuf,
         tokenizer_file_path: PathBuf,
         tracing: bool,
+        use_kv_cache: Option<bool>,
     ) -> Self {
         Self {
             api_key,
-            hardware_specs,
             models,
-            software_specs,
-            storage_base_path,
+            precision,
+            storage_folder,
             tokenizer_file_path,
             tracing,
+            use_kv_cache,
         }
     }
 
     pub fn api_key(&self) -> String {
         self.api_key.clone()
     }
 
-    pub fn hardware(&self) -> HardwareSpec {
-        self.hardware_specs.clone()
-    }
-
     pub fn models(&self) -> Vec<ModelType> {
         self.models.clone()
     }
 
-    pub fn software(&self) -> SoftwareSpec {
-        self.software_specs.clone()
-    }
-
-    pub fn storage_base_path(&self) -> PathBuf {
-        self.storage_base_path.clone()
+    pub fn storage_folder(&self) -> PathBuf {
+        self.storage_folder.clone()
     }
 
     pub fn tokenizer_file_path(&self) -> PathBuf {
@@ -63,4 +56,24 @@ impl InferenceConfig {
     pub fn tracing(&self) -> bool {
         self.tracing
     }
+
+    pub fn precision_bits(&self) -> PrecisionBits {
+        self.precision
+    }
+
+    pub fn use_kv_cache(&self) -> Option<bool> {
+        self.use_kv_cache
+    }
+
+    pub fn from_file_path(config_file_path: PathBuf) -> Self {
+        let builder = Config::builder().add_source(config::File::with_name(
+            config_file_path.to_str().as_ref().unwrap(),
+        ));
+        let config = builder
+            .build()
+            .expect("Failed to generate inference configuration file");
+        config
+            .try_deserialize::<Self>()
+            .expect("Failed to generate config file")
+    }
 }
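
Since from_file_path drives configuration loading, a short sketch of the expected on-disk format may help. The TOML below is an assumption reconstructed from the struct fields; in particular, the serialized variant spellings ("Mamba3b", "F16") are guesses, since types.rs is not shown in this diff.

use std::path::PathBuf;

use atoma_inference::config::InferenceConfig;

fn main() {
    // Hypothetical config file; keys mirror InferenceConfig's fields, but the
    // enum variant spellings ("Mamba3b", "F16") are assumptions.
    let toml_contents = r#"
        api_key = "hf_..."
        models = ["Mamba3b"]
        precision = "F16"
        storage_folder = "./storage"
        tokenizer_file_path = "./tokenizer.json"
        tracing = true
        use_kv_cache = true
    "#;
    std::fs::write("inference.toml", toml_contents).expect("failed to write config");

    // from_file_path panics (via expect) on a missing or malformed file.
    let config = InferenceConfig::from_file_path(PathBuf::from("inference.toml"));
    assert!(config.tracing());
    assert_eq!(config.use_kv_cache(), Some(true));
}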
89 changes: 0 additions & 89 deletions atoma-inference/src/core_thread.rs

This file was deleted.

4 changes: 3 additions & 1 deletion atoma-inference/src/lib.rs
@@ -1,6 +1,8 @@
 pub mod config;
-pub mod core_thread;
+pub mod model_thread;
 pub mod models;
 pub mod service;
 pub mod specs;
 pub mod types;
+
+pub mod apis;