From 8237fb99730bddc9b945ad24bdeae4c480b260f3 Mon Sep 17 00:00:00 2001
From: jorgeantonio21
Date: Sat, 23 Mar 2024 17:45:32 +0000
Subject: [PATCH 01/28] Refactor core_thread.rs and remove service.rs

In core_thread.rs:
- Added the `CORE_THREAD_COMMANDS_CHANNEL_SIZE` constant.
- Introduced the `CoreThreadHandle` and `CoreThreadDispatcher` structs.
- Modified the core thread setup so that commands flow through the new dispatcher.
- Added methods to `CoreThreadDispatcher` for sending commands and running inference.
- Removed unnecessary imports and improved error handling.

In service.rs:
- Renamed the file to core.rs; the old service.rs path is no longer needed after the refactoring.

These changes refactor the threading logic in core_thread.rs and eliminate the now-obsolete service.rs file, improving code organization and maintainability.
---
 atoma-inference/src/{service.rs => core.rs} |  0
 atoma-inference/src/core_thread.rs          | 84 ++++++++++++++++++++-
 2 files changed, 82 insertions(+), 2 deletions(-)
 rename atoma-inference/src/{service.rs => core.rs} (100%)

diff --git a/atoma-inference/src/service.rs b/atoma-inference/src/core.rs
similarity index 100%
rename from atoma-inference/src/service.rs
rename to atoma-inference/src/core.rs
diff --git a/atoma-inference/src/core_thread.rs b/atoma-inference/src/core_thread.rs
index 910a2eb1..2c28a0cf 100644
--- a/atoma-inference/src/core_thread.rs
+++ b/atoma-inference/src/core_thread.rs
@@ -1,12 +1,20 @@
 use thiserror::Error;
-use tokio::sync::{mpsc, oneshot, oneshot::error::RecvError};
-use tracing::{debug, error};
+use tokio::{
+    sync::{
+        mpsc,
+        oneshot::{self, error::RecvError},
+    },
+    task::JoinHandle,
+};
+use tracing::{debug, error, warn};
 
 use crate::{
     service::{ApiTrait, InferenceCore, InferenceCoreError},
     types::{InferenceRequest, InferenceResponse, ModelRequest, ModelResponse},
 };
 
+const CORE_THREAD_COMMANDS_CHANNEL_SIZE: usize = 32;
+
 pub enum CoreThreadCommand {
     RunInference(InferenceRequest, oneshot::Sender<InferenceResponse>),
     FetchModel(ModelRequest, oneshot::Sender<ModelResponse>),
@@ -22,6 +30,19 @@ pub enum CoreError {
     Shutdown(RecvError),
 }
 
+pub(crate) struct CoreThreadHandle {
+    sender: mpsc::Sender<CoreThreadCommand>,
+    join_handle: JoinHandle<()>,
+}
+
+impl CoreThreadHandle {
+    pub async fn stop(self) {
+        // Drop the sender; this forces all outstanding weak senders to fail to upgrade.
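+        // Once this strong `Sender` is dropped the channel closes: `CoreThread::run` drains any
+        // queued commands, its receive loop then observes the closed channel and returns, and
+        // the task awaited just below finishes.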
+        drop(self.sender);
+        self.join_handle.await.ok();
+    }
+}
+
 pub struct CoreThread {
     core: InferenceCore,
     receiver: mpsc::Receiver<CoreThreadCommand>,
@@ -76,6 +97,65 @@ impl CoreThread {
     }
 }
 
+#[derive(Clone)]
+pub struct CoreThreadDispatcher {
+    sender: mpsc::WeakSender<CoreThreadCommand>,
+}
+
+impl CoreThreadDispatcher {
+    pub(crate) fn start(
+        core: InferenceCore,
+    ) -> (Self, CoreThreadHandle) {
+        let (sender, receiver) = mpsc::channel(CORE_THREAD_COMMANDS_CHANNEL_SIZE);
+        let core_thread = CoreThread { core, receiver };
+
+        let join_handle = tokio::task::spawn(async move {
+            if let Err(e) = core_thread.run().await {
+                if !matches!(e, CoreError::Shutdown(_)) {
+                    panic!("Fatal error occurred: {e}");
+                }
+            }
+        });
+
+        let dispatcher = Self {
+            sender: sender.downgrade(),
+        };
+        let handle = CoreThreadHandle {
+            join_handle,
+            sender,
+        };
+
+        (dispatcher, handle)
+    }
+
+    async fn send(&self, command: CoreThreadCommand) {
+        if let Some(sender) = self.sender.upgrade() {
+            if let Err(e) = sender.send(command).await {
+                warn!("Could not send command to thread core, it might be shutting down: {e}");
+            }
+        }
+    }
+}
+
+impl CoreThreadDispatcher {
+    async fn fetch_model(&self, request: ModelRequest) -> Result<ModelResponse, CoreError> {
+        let (sender, receiver) = oneshot::channel();
+        self.send(CoreThreadCommand::FetchModel(request, sender))
+            .await;
+        receiver.await.map_err(CoreError::Shutdown)
+    }
+
+    async fn run_inference(
+        &self,
+        request: InferenceRequest,
+    ) -> Result<InferenceResponse, CoreError> {
+        let (sender, receiver) = oneshot::channel();
+        self.send(CoreThreadCommand::RunInference(request, sender))
+            .await;
+        receiver.await.map_err(CoreError::Shutdown)
+    }
+}
+
 impl From<InferenceCoreError> for CoreError {
     fn from(error: InferenceCoreError) -> Self {
         match error {

From ddbdc1e5f929f8ee55035bfe6533e60561a46291 Mon Sep 17 00:00:00 2001
From: jorgeantonio21
Date: Sat, 23 Mar 2024 18:27:32 +0000
Subject: [PATCH 02/28] Refactor: Moved `ApiTrait`, `InferenceCore`, and related types to a separate module

This commit refactors the project structure by moving `ApiTrait`, `InferenceCore`, and related types to a separate module named `core`. This change improves code organization and maintainability by grouping related functionality together.

Changes Made:
- Moved `ApiTrait`, `InferenceCore`, `InferenceCoreError`, and related types to the `core` module.
- Updated import paths in affected files to reflect the module restructuring.

Impact:
- Improved code organization and maintainability.
- Clear separation of concerns between different components of the application.
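
Illustrative usage (not part of this change): inside the crate, the pieces from this and the previous patch wire together roughly as sketched below. `MyApi` stands in for any `ApiTrait` implementation, the `config`, `private_key` and `request` values come from the caller, and error handling is elided:

    let core = InferenceCore::<MyApi>::new(config, private_key)?;
    let (dispatcher, handle) = CoreThreadDispatcher::start(core);
    let response = dispatcher.run_inference(request).await?;
    handle.stop().await;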
--- atoma-inference/src/core_thread.rs | 11 ++-- atoma-inference/src/lib.rs | 1 + atoma-inference/src/service.rs | 83 ++++++++++++++++++++++++++++++ 3 files changed, 91 insertions(+), 4 deletions(-) create mode 100644 atoma-inference/src/service.rs diff --git a/atoma-inference/src/core_thread.rs b/atoma-inference/src/core_thread.rs index 2c28a0cf..735158f3 100644 --- a/atoma-inference/src/core_thread.rs +++ b/atoma-inference/src/core_thread.rs @@ -9,7 +9,7 @@ use tokio::{ use tracing::{debug, error, warn}; use crate::{ - service::{ApiTrait, InferenceCore, InferenceCoreError}, + core::{ApiTrait, InferenceCore, InferenceCoreError}, types::{InferenceRequest, InferenceResponse, ModelRequest, ModelResponse}, }; @@ -30,7 +30,7 @@ pub enum CoreError { Shutdown(RecvError), } -pub(crate) struct CoreThreadHandle { +pub struct CoreThreadHandle { sender: mpsc::Sender, join_handle: JoinHandle<()>, } @@ -138,14 +138,17 @@ impl CoreThreadDispatcher { } impl CoreThreadDispatcher { - async fn fetch_model(&self, request: ModelRequest) -> Result { + pub(crate) async fn fetch_model( + &self, + request: ModelRequest, + ) -> Result { let (sender, receiver) = oneshot::channel(); self.send(CoreThreadCommand::FetchModel(request, sender)) .await; receiver.await.map_err(CoreError::Shutdown) } - async fn run_inference( + pub(crate) async fn run_inference( &self, request: InferenceRequest, ) -> Result { diff --git a/atoma-inference/src/lib.rs b/atoma-inference/src/lib.rs index 1ca7c952..0de5abb8 100644 --- a/atoma-inference/src/lib.rs +++ b/atoma-inference/src/lib.rs @@ -1,4 +1,5 @@ pub mod config; +pub mod core; pub mod core_thread; pub mod models; pub mod service; diff --git a/atoma-inference/src/service.rs b/atoma-inference/src/service.rs new file mode 100644 index 00000000..d3465f93 --- /dev/null +++ b/atoma-inference/src/service.rs @@ -0,0 +1,83 @@ +use ed25519_consensus::SigningKey as PrivateKey; +use std::{io, path::PathBuf}; + +use thiserror::Error; + +use crate::{ + config::InferenceConfig, + core::{ApiError, ApiTrait, InferenceCore, InferenceCoreError}, + core_thread::{CoreError, CoreThreadDispatcher, CoreThreadHandle}, + types::{InferenceRequest, InferenceResponse, ModelRequest, ModelResponse}, +}; + +pub struct InferenceService { + dispatcher: CoreThreadDispatcher, + core_thread_handle: CoreThreadHandle, +} + +impl InferenceService { + pub fn start( + config: InferenceConfig, + private_key_path: PathBuf, + ) -> Result { + let private_key_bytes = + std::fs::read(&private_key_path).map_err(InferenceServiceError::PrivateKeyError)?; + let private_key_bytes: [u8; 32] = private_key_bytes + .try_into() + .expect("Incorrect private key bytes length"); + + let private_key = PrivateKey::from(private_key_bytes); + let inference_core = InferenceCore::::new(config, private_key)?; + + let (dispatcher, core_thread_handle) = CoreThreadDispatcher::start(inference_core); + + Ok(Self { + dispatcher, + core_thread_handle, + }) + } + + async fn run_inference( + &self, + inference_request: InferenceRequest, + ) -> Result { + self.dispatcher + .run_inference(inference_request) + .await + .map_err(InferenceServiceError::CoreError) + } + + async fn fetch_model( + &self, + model_request: ModelRequest, + ) -> Result { + self.dispatcher + .fetch_model(model_request) + .await + .map_err(InferenceServiceError::CoreError) + } +} + +#[derive(Debug, Error)] +pub enum InferenceServiceError { + #[error("Failed to connect to API: `{0}`")] + FailedApiConnection(ApiError), + #[error("Failed to run inference: `{0}`")] + FailedInference(Box), + 
#[error("Failed to fecth model: `{0}`")] + FailedModelFetch(String), + #[error("Failed to generate private key: `{0}`")] + PrivateKeyError(io::Error), + #[error("Core error: `{0}`")] + CoreError(CoreError), +} + +impl From for InferenceServiceError { + fn from(error: InferenceCoreError) -> Self { + match error { + InferenceCoreError::FailedApiConnection(e) => Self::FailedApiConnection(e), + InferenceCoreError::FailedInference(e) => Self::FailedInference(e), + InferenceCoreError::FailedModelFetch(e) => Self::FailedModelFetch(e), + } + } +} From df7a457f4e50646b6d14f2f71bd66d6fbb242ba1 Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Mon, 25 Mar 2024 00:40:09 +0000 Subject: [PATCH 03/28] feat: Update dependencies and refactor configuration handling This commit includes several changes: - Updated dependencies in the file, including adding , , and crates. - Refactored the struct and its related methods to use the crate for loading configuration from a file. - Updated to properly initialize and fetch models during service startup. - Added deserialization support for enum to read from TOML configuration files. - Updated error handling in various parts of the codebase. - Added a test case to ensure proper initialization of the inference service with sample configuration data. This refactor ensures better modularity, improved error handling, and easier maintenance of the codebase. --- Cargo.toml | 3 + atoma-inference/Cargo.toml | 6 ++ atoma-inference/src/config.rs | 35 +++++----- atoma-inference/src/core.rs | 9 ++- atoma-inference/src/models.rs | 3 +- atoma-inference/src/service.rs | 118 +++++++++++++++++++++++++++++++-- atoma-inference/src/specs.rs | 24 ++++--- atoma-inference/src/types.rs | 1 + 8 files changed, 159 insertions(+), 40 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 0d1fdee6..06a34e0a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,9 +11,12 @@ async-trait = "0.1.78" candle = { git = "https://github.com/huggingface/candle", package = "candle-core", version = "0.4.2" } candle-nn = { git = "https://github.com/huggingface/candle", package = "candle-nn", version = "0.4.2" } candle-transformers = { git = "https://github.com/huggingface/candle", package = "candle-transformers", version = "0.4.2" } +config = "0.14.0" ed25519-consensus = "2.1.0" serde = "1.0.197" +rand = "0.8.5" thiserror = "1.0.58" tokenizers = "0.15.2" tokio = "1.36.0" +toml = "0.8.12" tracing = "0.1.40" diff --git a/atoma-inference/Cargo.toml b/atoma-inference/Cargo.toml index 82e2740f..d128b211 100644 --- a/atoma-inference/Cargo.toml +++ b/atoma-inference/Cargo.toml @@ -10,9 +10,15 @@ async-trait.workspace = true candle.workspace = true candle-nn.workspace = true candle-transformers.workspace = true +config.true = true ed25519-consensus.workspace = true serde = { workspace = true, features = ["derive"] } thiserror.workspace = true tokenizers.workspace = true tokio = { workspace = true, features = ["full", "tracing"] } tracing.workspace = true + +[dev-dependencies] +rand.workspace = true +toml.workspace = true + diff --git a/atoma-inference/src/config.rs b/atoma-inference/src/config.rs index 63d35627..230fcb60 100644 --- a/atoma-inference/src/config.rs +++ b/atoma-inference/src/config.rs @@ -1,15 +1,14 @@ use std::path::PathBuf; -use crate::{ - models::ModelType, - specs::{HardwareSpec, SoftwareSpec}, -}; +use config::Config; +use serde::Deserialize; +use crate::models::ModelType; + +#[derive(Debug, Deserialize)] pub struct InferenceConfig { api_key: String, - hardware_specs: HardwareSpec, models: Vec, - 
software_specs: SoftwareSpec, storage_base_path: PathBuf, tokenizer_file_path: PathBuf, tracing: bool, @@ -18,18 +17,14 @@ pub struct InferenceConfig { impl InferenceConfig { pub fn new( api_key: String, - hardware_specs: HardwareSpec, models: Vec, - software_specs: SoftwareSpec, storage_base_path: PathBuf, tokenizer_file_path: PathBuf, tracing: bool, ) -> Self { Self { api_key, - hardware_specs, models, - software_specs, storage_base_path, tokenizer_file_path, tracing, @@ -40,18 +35,10 @@ impl InferenceConfig { self.api_key.clone() } - pub fn hardware(&self) -> HardwareSpec { - self.hardware_specs.clone() - } - pub fn models(&self) -> Vec { self.models.clone() } - pub fn software(&self) -> SoftwareSpec { - self.software_specs.clone() - } - pub fn storage_base_path(&self) -> PathBuf { self.storage_base_path.clone() } @@ -63,4 +50,16 @@ impl InferenceConfig { pub fn tracing(&self) -> bool { self.tracing } + + pub fn from_file_path(config_file_path: PathBuf) -> Self { + let builder = Config::builder().add_source(config::File::with_name( + config_file_path.to_str().as_ref().unwrap(), + )); + let config = builder + .build() + .expect("Failed to generate inference configuration file"); + config + .try_deserialize::() + .expect("Failed to generated config file") + } } diff --git a/atoma-inference/src/core.rs b/atoma-inference/src/core.rs index a1aec840..c3df69d4 100644 --- a/atoma-inference/src/core.rs +++ b/atoma-inference/src/core.rs @@ -23,11 +23,11 @@ pub trait ApiTrait { #[allow(dead_code)] pub struct InferenceCore { - config: InferenceConfig, + pub(crate) config: InferenceConfig, // models: Vec, pub(crate) public_key: PublicKey, private_key: PrivateKey, - web2_api: T, + pub(crate) web2_api: T, } impl InferenceCore { @@ -76,7 +76,10 @@ impl InferenceCore { _model: ModelType, _quantization_method: Option, ) -> Result { - Ok(ModelResponse { is_success: true }) + Ok(ModelResponse { + is_success: true, + error: None, + }) } } diff --git a/atoma-inference/src/models.rs b/atoma-inference/src/models.rs index c2d4612b..af3bfc3d 100644 --- a/atoma-inference/src/models.rs +++ b/atoma-inference/src/models.rs @@ -13,6 +13,7 @@ use candle_transformers::{ stable_diffusion::StableDiffusionConfig, }, }; +use serde::Deserialize; use thiserror::Error; use tokenizers::Tokenizer; @@ -21,7 +22,7 @@ use crate::types::Temperature; const EOS_TOKEN: &str = ""; -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Deserialize)] pub enum ModelType { Llama(usize), Llama2(usize), diff --git a/atoma-inference/src/service.rs b/atoma-inference/src/service.rs index d3465f93..f5d86dcd 100644 --- a/atoma-inference/src/service.rs +++ b/atoma-inference/src/service.rs @@ -1,5 +1,7 @@ use ed25519_consensus::SigningKey as PrivateKey; -use std::{io, path::PathBuf}; +use std::{io, path::PathBuf, time::Instant}; +use tokio::sync::mpsc::{error::SendError, Receiver}; +use tracing::{info, warn}; use thiserror::Error; @@ -11,14 +13,17 @@ use crate::{ }; pub struct InferenceService { - dispatcher: CoreThreadDispatcher, core_thread_handle: CoreThreadHandle, + dispatcher: CoreThreadDispatcher, + start_time: Instant, + request_receiver: Receiver, } impl InferenceService { - pub fn start( - config: InferenceConfig, + pub async fn start( + config_file_path: PathBuf, private_key_path: PathBuf, + request_receiver: Receiver, ) -> Result { let private_key_bytes = std::fs::read(&private_key_path).map_err(InferenceServiceError::PrivateKeyError)?; @@ -27,14 +32,36 @@ impl InferenceService { .expect("Incorrect private key bytes length"); let private_key = 
PrivateKey::from(private_key_bytes); - let inference_core = InferenceCore::::new(config, private_key)?; + let inference_config = InferenceConfig::from_file_path(config_file_path); + let models = inference_config.models(); + let inference_core = InferenceCore::::new(inference_config, private_key)?; let (dispatcher, core_thread_handle) = CoreThreadDispatcher::start(inference_core); + let start_time = Instant::now(); - Ok(Self { + let inference_service = Self { dispatcher, core_thread_handle, - }) + start_time, + request_receiver, + }; + + for model in models { + let response = inference_service + .fetch_model(ModelRequest { + model: model.clone(), + quantization_method: None, + }) + .await?; + if !response.is_success { + warn!( + "Failed to fetch model: {:?}, with error: {:?}", + model, response.error + ); + } + } + + Ok(inference_service) } async fn run_inference( @@ -58,6 +85,17 @@ impl InferenceService { } } +impl InferenceService { + pub async fn stop(self) { + info!( + "Stopping Inference Service, running time: {:?}", + self.start_time.elapsed() + ); + + self.core_thread_handle.stop().await; + } +} + #[derive(Debug, Error)] pub enum InferenceServiceError { #[error("Failed to connect to API: `{0}`")] @@ -70,6 +108,8 @@ pub enum InferenceServiceError { PrivateKeyError(io::Error), #[error("Core error: `{0}`")] CoreError(CoreError), + #[error("Send error: `{0}`")] + SendError(SendError), } impl From for InferenceServiceError { @@ -81,3 +121,67 @@ impl From for InferenceServiceError { } } } + +#[cfg(test)] +mod tests { + use rand::rngs::OsRng; + use std::io::Write; + use toml::{toml, Value}; + + use super::*; + + struct TestApiInstance {} + + impl ApiTrait for TestApiInstance { + fn call(&mut self) -> Result<(), ApiError> { + Ok(()) + } + + fn connect(_: &str) -> Result + where + Self: Sized, + { + Ok(Self {}) + } + + fn fetch(&mut self) -> Result<(), ApiError> { + Ok(()) + } + } + + #[tokio::test] + async fn test_inference_service_initialization() { + const CONFIG_FILE_PATH: &str = "./config.toml"; + const PRIVATE_KEY_FILE_PATH: &str = "./private_key"; + + let private_key = PrivateKey::new(&mut OsRng); + std::fs::write(PRIVATE_KEY_FILE_PATH, private_key.to_bytes()).unwrap(); + + let config_data = Value::Table(toml! 
{ + api_key = "your_api_key" + models = [{ Mamba = 3 }] + storage_base_path = "./storage_base_path/" + tokenizer_file_path = "./tokenizer_file_path/" + tracing = true + }); + let toml_string = + toml::to_string_pretty(&config_data).expect("Failed to serialize to TOML"); + + let mut file = std::fs::File::create(CONFIG_FILE_PATH).expect("Failed to create file"); + file.write_all(toml_string.as_bytes()) + .expect("Failed to write to file"); + + let (_, receiver) = tokio::sync::mpsc::channel(1); + + let _ = InferenceService::start::( + PathBuf::try_from(CONFIG_FILE_PATH).unwrap(), + PathBuf::try_from(PRIVATE_KEY_FILE_PATH).unwrap(), + receiver, + ) + .await + .unwrap(); + + std::fs::remove_file(CONFIG_FILE_PATH).unwrap(); + std::fs::remove_file(PRIVATE_KEY_FILE_PATH).unwrap(); + } +} diff --git a/atoma-inference/src/specs.rs b/atoma-inference/src/specs.rs index 5e64f5c3..342f70e8 100644 --- a/atoma-inference/src/specs.rs +++ b/atoma-inference/src/specs.rs @@ -1,32 +1,34 @@ #![allow(non_camel_case_types)] -#[derive(Clone, Debug)] +use serde::Deserialize; + +#[derive(Clone, Debug, Deserialize)] pub enum HardwareSpec { Cpu(CpuModel), Gpu(GpuModel), } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Deserialize)] pub enum CpuModel { Intel(IntelModel), Arm(ArmModel), } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Deserialize)] pub enum GpuModel { Nvidia(NvidiaModel), Amd(AmdModel), } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Deserialize)] pub enum IntelModel { x86(usize), } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Deserialize)] pub enum ArmModel {} -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Deserialize)] pub enum NvidiaModel { Rtx3090(usize), Rtx4090(usize), @@ -35,23 +37,23 @@ pub enum NvidiaModel { H100(usize), } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Deserialize)] pub enum AmdModel {} -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Deserialize)] pub enum SoftwareSpec { Jax(JaxVersion), Cuda(CudaVersion), Xla(XlaVersion), } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Deserialize)] pub enum JaxVersion {} -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Deserialize)] pub enum CudaVersion { v11, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Deserialize)] pub enum XlaVersion {} diff --git a/atoma-inference/src/types.rs b/atoma-inference/src/types.rs index 9dcda742..8ee6baeb 100644 --- a/atoma-inference/src/types.rs +++ b/atoma-inference/src/types.rs @@ -36,6 +36,7 @@ pub struct ModelRequest { #[allow(dead_code)] pub struct ModelResponse { pub(crate) is_success: bool, + pub(crate) error: Option, } #[derive(Clone, Debug)] From 919298a41de9a9da4bf31b2101e590cec8bdd0f7 Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Mon, 25 Mar 2024 06:21:45 +0000 Subject: [PATCH 04/28] add hugging face client logic --- Cargo.toml | 2 + atoma-inference/Cargo.toml | 2 + atoma-inference/src/apis/hugging_face.rs | 106 +++++++++++++++++++++++ atoma-inference/src/apis/mod.rs | 17 ++++ atoma-inference/src/core.rs | 15 +--- atoma-inference/src/core_thread.rs | 3 +- atoma-inference/src/lib.rs | 2 + atoma-inference/src/service.rs | 3 +- 8 files changed, 134 insertions(+), 16 deletions(-) create mode 100644 atoma-inference/src/apis/hugging_face.rs create mode 100644 atoma-inference/src/apis/mod.rs diff --git a/Cargo.toml b/Cargo.toml index 06a34e0a..7fe25ce1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,7 +14,9 @@ candle-transformers = { git = "https://github.com/huggingface/candle", package = config = "0.14.0" ed25519-consensus = "2.1.0" serde = "1.0.197" +serde_json = "1.0.114" 
rand = "0.8.5" +reqwest = "0.12.1" thiserror = "1.0.58" tokenizers = "0.15.2" tokio = "1.36.0" diff --git a/atoma-inference/Cargo.toml b/atoma-inference/Cargo.toml index d128b211..a93645e8 100644 --- a/atoma-inference/Cargo.toml +++ b/atoma-inference/Cargo.toml @@ -12,7 +12,9 @@ candle-nn.workspace = true candle-transformers.workspace = true config.true = true ed25519-consensus.workspace = true +reqwest = { workspace = true, features = ["json"] } serde = { workspace = true, features = ["derive"] } +serde_json.workspace = true thiserror.workspace = true tokenizers.workspace = true tokio = { workspace = true, features = ["full", "tracing"] } diff --git a/atoma-inference/src/apis/hugging_face.rs b/atoma-inference/src/apis/hugging_face.rs new file mode 100644 index 00000000..a013f32e --- /dev/null +++ b/atoma-inference/src/apis/hugging_face.rs @@ -0,0 +1,106 @@ +use reqwest::{ + header::{self, HeaderMap}, + Client, IntoUrl, Url, +}; +use serde::{de::DeserializeOwned, Serialize}; +use thiserror::Error; + +pub struct HuggingFaceClient { + client: Client, + endpoint: Url, + request_id: i64, +} + +impl HuggingFaceClient { + pub fn connect(endpoint: T) -> Result { + let client = Client::builder() + .default_headers({ + let mut headers = HeaderMap::new(); + headers.insert(header::CONTENT_TYPE, "application/json".parse().unwrap()); + headers + }) + .build()?; + + Ok(Self { + client, + endpoint: endpoint.into_url()?, + request_id: 0, + }) + } + + fn next_request_id(&mut self) -> i64 { + self.request_id += 1; + self.request_id + } + + pub async fn send_request( + &mut self, + method: &str, + params: T, + ) -> Result { + let params = serde_json::to_value(params)?; + let request_json = serde_json::json!({ + "jsonrpc": "2.0", + "id": self.next_request_id(), + "method": method, + "params`": params, + }); + + let response = self + .client + .post(self.endpoint.clone()) + .body(request_json.to_string()) + .send() + .await?; + + let value = response.json().await?; + let response = extract_json_result(value)?; + + Ok(serde_json::from_value(response)?) 
+ } +} + +#[derive(Debug, Error)] +pub enum HuggingFaceError { + #[error("Connection error: `{0}`")] + ConnectionError(reqwest::Error), + #[error("Serialization error: `{0}`")] + SerializationError(serde_json::Error), + #[error("Failed request with code `{code}` and message `{message}`")] + RequestError { code: i64, message: String }, + #[error("Invalid response: `{message}`")] + InvalidResponse { message: String }, +} + +impl From for HuggingFaceError { + fn from(error: reqwest::Error) -> Self { + Self::ConnectionError(error) + } +} + +impl From for HuggingFaceError { + fn from(error: serde_json::Error) -> Self { + Self::SerializationError(error) + } +} + +fn extract_json_result(val: serde_json::Value) -> Result { + if let Some(err) = val.get("error") { + let code = err.get("code").and_then(|c| c.as_i64()).unwrap_or(-1); + let message = err + .get("message") + .and_then(|m| m.as_str()) + .unwrap_or("Unknown error"); + return Err(HuggingFaceError::RequestError { + code, + message: message.to_string(), + }); + } + + let result = val + .get("result") + .ok_or_else(|| HuggingFaceError::InvalidResponse { + message: "Missing result field".to_string(), + })?; + Ok(result.clone()) +} diff --git a/atoma-inference/src/apis/mod.rs b/atoma-inference/src/apis/mod.rs new file mode 100644 index 00000000..ef7c053a --- /dev/null +++ b/atoma-inference/src/apis/mod.rs @@ -0,0 +1,17 @@ +pub mod hugging_face; + +use thiserror::Error; + +#[derive(Debug, Error)] +pub enum ApiError { + #[error("Api Error: {0}")] + ApiError(String), +} + +pub trait ApiTrait { + fn call(&mut self) -> Result<(), ApiError>; + fn fetch(&mut self) -> Result<(), ApiError>; + fn connect(api_key: &str) -> Result + where + Self: Sized; +} diff --git a/atoma-inference/src/core.rs b/atoma-inference/src/core.rs index c3df69d4..185eb2a3 100644 --- a/atoma-inference/src/core.rs +++ b/atoma-inference/src/core.rs @@ -2,25 +2,12 @@ use ed25519_consensus::{SigningKey as PrivateKey, VerificationKey as PublicKey}; use thiserror::Error; use crate::{ + apis::{ApiError, ApiTrait}, config::InferenceConfig, models::ModelType, types::{InferenceResponse, ModelResponse, QuantizationMethod, Temperature}, }; -#[derive(Debug, Error)] -pub enum ApiError { - #[error("Api Error: {0}")] - ApiError(String), -} - -pub trait ApiTrait { - fn call(&mut self) -> Result<(), ApiError>; - fn fetch(&mut self) -> Result<(), ApiError>; - fn connect(api_key: &str) -> Result - where - Self: Sized; -} - #[allow(dead_code)] pub struct InferenceCore { pub(crate) config: InferenceConfig, diff --git a/atoma-inference/src/core_thread.rs b/atoma-inference/src/core_thread.rs index 735158f3..db63dccf 100644 --- a/atoma-inference/src/core_thread.rs +++ b/atoma-inference/src/core_thread.rs @@ -9,7 +9,8 @@ use tokio::{ use tracing::{debug, error, warn}; use crate::{ - core::{ApiTrait, InferenceCore, InferenceCoreError}, + apis::ApiTrait, + core::{InferenceCore, InferenceCoreError}, types::{InferenceRequest, InferenceResponse, ModelRequest, ModelResponse}, }; diff --git a/atoma-inference/src/lib.rs b/atoma-inference/src/lib.rs index 0de5abb8..433065ca 100644 --- a/atoma-inference/src/lib.rs +++ b/atoma-inference/src/lib.rs @@ -5,3 +5,5 @@ pub mod models; pub mod service; pub mod specs; pub mod types; + +pub mod apis; diff --git a/atoma-inference/src/service.rs b/atoma-inference/src/service.rs index f5d86dcd..307d160a 100644 --- a/atoma-inference/src/service.rs +++ b/atoma-inference/src/service.rs @@ -6,8 +6,9 @@ use tracing::{info, warn}; use thiserror::Error; use crate::{ + 
apis::{ApiError, ApiTrait}, config::InferenceConfig, - core::{ApiError, ApiTrait, InferenceCore, InferenceCoreError}, + core::{InferenceCore, InferenceCoreError}, core_thread::{CoreError, CoreThreadDispatcher, CoreThreadHandle}, types::{InferenceRequest, InferenceResponse, ModelRequest, ModelResponse}, }; From b8c8edc303a47a3790505a31f31a1a477ad8e849 Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Mon, 25 Mar 2024 20:53:07 +0000 Subject: [PATCH 05/28] feat: Add hf-hub crate version 0.3.2 - Added hf-hub crate version 0.3.2 to Cargo.toml for dependency management. - Updated atoma-inference/Cargo.toml to include hf-hub as a workspace with the tokio feature enabled. - Added hf-hub to the list of dependencies in atoma-inference/src/apis/hugging_face.rs. - Updated ModelType enum in atoma-inference/src/models.rs to include new model types supported by hf-hub. - Modified ApiTrait and Api implementations to incorporate hf-hub functionality for model fetching. - Implemented asynchronous model fetching in the ApiTrait trait using async_trait. - Added tests for InferenceService initialization and model fetching with hf-hub. --- Cargo.toml | 1 + atoma-inference/Cargo.toml | 1 + atoma-inference/src/apis/hugging_face.rs | 196 ++++++++++++----------- atoma-inference/src/apis/mod.rs | 21 ++- atoma-inference/src/core.rs | 11 +- atoma-inference/src/core_thread.rs | 2 +- atoma-inference/src/main.rs | 20 ++- atoma-inference/src/models.rs | 22 +-- atoma-inference/src/service.rs | 12 +- 9 files changed, 165 insertions(+), 121 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 7fe25ce1..d5f109d7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,6 +13,7 @@ candle-nn = { git = "https://github.com/huggingface/candle", package = "candle-n candle-transformers = { git = "https://github.com/huggingface/candle", package = "candle-transformers", version = "0.4.2" } config = "0.14.0" ed25519-consensus = "2.1.0" +hf-hub = "0.3.2" serde = "1.0.197" serde_json = "1.0.114" rand = "0.8.5" diff --git a/atoma-inference/Cargo.toml b/atoma-inference/Cargo.toml index a93645e8..99e1e1de 100644 --- a/atoma-inference/Cargo.toml +++ b/atoma-inference/Cargo.toml @@ -12,6 +12,7 @@ candle-nn.workspace = true candle-transformers.workspace = true config.true = true ed25519-consensus.workspace = true +hf-hub = { workspace = true, features = ["tokio"] } reqwest = { workspace = true, features = ["json"] } serde = { workspace = true, features = ["derive"] } serde_json.workspace = true diff --git a/atoma-inference/src/apis/hugging_face.rs b/atoma-inference/src/apis/hugging_face.rs index a013f32e..d49b19d7 100644 --- a/atoma-inference/src/apis/hugging_face.rs +++ b/atoma-inference/src/apis/hugging_face.rs @@ -1,106 +1,118 @@ -use reqwest::{ - header::{self, HeaderMap}, - Client, IntoUrl, Url, -}; -use serde::{de::DeserializeOwned, Serialize}; -use thiserror::Error; +use std::path::PathBuf; -pub struct HuggingFaceClient { - client: Client, - endpoint: Url, - request_id: i64, -} - -impl HuggingFaceClient { - pub fn connect(endpoint: T) -> Result { - let client = Client::builder() - .default_headers({ - let mut headers = HeaderMap::new(); - headers.insert(header::CONTENT_TYPE, "application/json".parse().unwrap()); - headers - }) - .build()?; - - Ok(Self { - client, - endpoint: endpoint.into_url()?, - request_id: 0, - }) - } - - fn next_request_id(&mut self) -> i64 { - self.request_id += 1; - self.request_id - } - - pub async fn send_request( - &mut self, - method: &str, - params: T, - ) -> Result { - let params = serde_json::to_value(params)?; 
- let request_json = serde_json::json!({ - "jsonrpc": "2.0", - "id": self.next_request_id(), - "method": method, - "params`": params, - }); +use async_trait::async_trait; +use hf_hub::api::tokio::{Api, ApiBuilder}; - let response = self - .client - .post(self.endpoint.clone()) - .body(request_json.to_string()) - .send() - .await?; +use crate::models::ModelType; - let value = response.json().await?; - let response = extract_json_result(value)?; +use super::ApiTrait; - Ok(serde_json::from_value(response)?) - } -} - -#[derive(Debug, Error)] -pub enum HuggingFaceError { - #[error("Connection error: `{0}`")] - ConnectionError(reqwest::Error), - #[error("Serialization error: `{0}`")] - SerializationError(serde_json::Error), - #[error("Failed request with code `{code}` and message `{message}`")] - RequestError { code: i64, message: String }, - #[error("Invalid response: `{message}`")] - InvalidResponse { message: String }, +struct FilePaths { + file_paths: Vec, } -impl From for HuggingFaceError { - fn from(error: reqwest::Error) -> Self { - Self::ConnectionError(error) +impl ModelType { + fn get_hugging_face_model_path(&self) -> (String, FilePaths) { + match self { + Self::Llama2_7b => ( + String::from("meta-llama/Llama-2-7b-hf"), + FilePaths { + file_paths: vec![ + "model-00001-of-00002.safetensors".to_string(), + "model-00002-of-00002.safetensors".to_string(), + ], + }, + ), + Self::Mamba3b => ( + String::from("state-spaces/mamba-2.8b-hf"), + FilePaths { + file_paths: vec![ + "model-00001-of-00003.safetensors".to_string(), + "model-00002-of-00003.safetensors".to_string(), + "model-00003-of-00003.safetensors".to_string(), + ], + }, + ), + Self::Mistral7b => ( + String::from("mistralai/Mistral-7B-Instruct-v0.2"), + FilePaths { + file_paths: vec![ + "model-00001-of-00003.safetensors".to_string(), + "model-00002-of-00003.safetensors".to_string(), + "model-00002-of-00003.safetensors".to_string(), + ], + }, + ), + Self::Mixtral8x7b => ( + String::from("mistralai/Mixtral-8x7B-Instruct-v0.1"), + FilePaths { + file_paths: vec![ + "model-00001-of-00019.safetensors".to_string(), + "model-00002-of-00019.safetensors".to_string(), + "model-00003-of-00019.safetensors".to_string(), + "model-00004-of-00019.safetensors".to_string(), + "model-00005-of-00019.safetensors".to_string(), + "model-00006-of-00019.safetensors".to_string(), + "model-00007-of-00019.safetensors".to_string(), + "model-00008-of-00019.safetensors".to_string(), + "model-00009-of-00019.safetensors".to_string(), + "model-000010-of-00019.safetensors".to_string(), + "model-000011-of-00019.safetensors".to_string(), + "model-000012-of-00019.safetensors".to_string(), + "model-000013-of-00019.safetensors".to_string(), + "model-000014-of-00019.safetensors".to_string(), + "model-000015-of-00019.safetensors".to_string(), + "model-000016-of-00019.safetensors".to_string(), + "model-000017-of-00019.safetensors".to_string(), + "model-000018-of-00019.safetensors".to_string(), + "model-000019-of-00019.safetensors".to_string(), + ], + }, + ), + Self::StableDiffusion2 => ( + String::from("stabilityai/stable-diffusion-2"), + FilePaths { + file_paths: vec!["768-v-ema.safetensors".to_string()], + }, + ), + Self::StableDiffusionXl => ( + String::from("stabilityai/stable-diffusion-xl-base-1.0"), + FilePaths { + file_paths: vec![ + "sd_xl_base_1.0.safetensors".to_string(), + "sd_xl_base_1.0_0.9vae.safetensors".to_string(), + "sd_xl_offset_example-lora_1.0.safetensors".to_string(), + ], + }, + ), + } } } -impl From for HuggingFaceError { - fn from(error: 
serde_json::Error) -> Self { - Self::SerializationError(error) +#[async_trait] +impl ApiTrait for Api { + fn call(&mut self) -> Result<(), super::ApiError> { + todo!() } -} -fn extract_json_result(val: serde_json::Value) -> Result { - if let Some(err) = val.get("error") { - let code = err.get("code").and_then(|c| c.as_i64()).unwrap_or(-1); - let message = err - .get("message") - .and_then(|m| m.as_str()) - .unwrap_or("Unknown error"); - return Err(HuggingFaceError::RequestError { - code, - message: message.to_string(), - }); + fn create(api_key: String, cache_dir: PathBuf) -> Result + where + Self: Sized, + { + Ok(ApiBuilder::new() + .with_progress(true) + .with_token(Some(api_key)) + .with_cache_dir(cache_dir) + .build()?) } - let result = val - .get("result") - .ok_or_else(|| HuggingFaceError::InvalidResponse { - message: "Missing result field".to_string(), - })?; - Ok(result.clone()) + async fn fetch(&mut self, model: ModelType) -> Result<(), super::ApiError> { + let (model_path, files) = model.get_hugging_face_model_path(); + let api_repo = self.model(model_path); + for file in files.file_paths { + api_repo.get(&file).await?; + } + + Ok(()) + } } diff --git a/atoma-inference/src/apis/mod.rs b/atoma-inference/src/apis/mod.rs index ef7c053a..2c598aa2 100644 --- a/atoma-inference/src/apis/mod.rs +++ b/atoma-inference/src/apis/mod.rs @@ -1,17 +1,32 @@ pub mod hugging_face; +use async_trait::async_trait; +use hf_hub::api::tokio::ApiError as HuggingFaceError; + +use std::path::PathBuf; use thiserror::Error; +use crate::models::ModelType; + #[derive(Debug, Error)] pub enum ApiError { - #[error("Api Error: {0}")] + #[error("Api Error: `{0}`")] ApiError(String), + #[error("HuggingFace API error: `{0}`")] + HuggingFaceError(HuggingFaceError), +} + +impl From for ApiError { + fn from(error: HuggingFaceError) -> Self { + Self::HuggingFaceError(error) + } } +#[async_trait] pub trait ApiTrait { fn call(&mut self) -> Result<(), ApiError>; - fn fetch(&mut self) -> Result<(), ApiError>; - fn connect(api_key: &str) -> Result + async fn fetch(&mut self, model: ModelType) -> Result<(), ApiError>; + fn create(api_key: String, cache_dir: PathBuf) -> Result where Self: Sized; } diff --git a/atoma-inference/src/core.rs b/atoma-inference/src/core.rs index 185eb2a3..db88acda 100644 --- a/atoma-inference/src/core.rs +++ b/atoma-inference/src/core.rs @@ -14,7 +14,7 @@ pub struct InferenceCore { // models: Vec, pub(crate) public_key: PublicKey, private_key: PrivateKey, - pub(crate) web2_api: T, + pub(crate) api: T, } impl InferenceCore { @@ -23,12 +23,12 @@ impl InferenceCore { private_key: PrivateKey, ) -> Result { let public_key = private_key.verification_key(); - let web2_api = T::connect(&config.api_key())?; + let api = T::create(config.api_key(), config.storage_base_path())?; Ok(Self { config, public_key, private_key, - web2_api, + api, }) } } @@ -58,11 +58,12 @@ impl InferenceCore { todo!() } - pub fn fetch_model( + pub async fn fetch_model( &mut self, - _model: ModelType, + model: ModelType, _quantization_method: Option, ) -> Result { + self.api.fetch(model).await?; Ok(ModelResponse { is_success: true, error: None, diff --git a/atoma-inference/src/core_thread.rs b/atoma-inference/src/core_thread.rs index db63dccf..94f3df03 100644 --- a/atoma-inference/src/core_thread.rs +++ b/atoma-inference/src/core_thread.rs @@ -88,7 +88,7 @@ impl CoreThread { model, quantization_method, } = request; - let response = self.core.fetch_model(model, quantization_method)?; + let response = self.core.fetch_model(model, 
quantization_method).await?; sender.send(response).ok(); } } diff --git a/atoma-inference/src/main.rs b/atoma-inference/src/main.rs index e7a11a96..f8a2aef3 100644 --- a/atoma-inference/src/main.rs +++ b/atoma-inference/src/main.rs @@ -1,3 +1,19 @@ -fn main() { - println!("Hello, world!"); +use hf_hub::api::tokio::Api; +use inference::service::InferenceService; + +#[tokio::main] +async fn main() { + let (_, request_receiver) = tokio::sync::mpsc::channel(32); + + let _ = InferenceService::start::( + "/Users/jorgeantonio/dev/atoma-node/inference.toml" + .parse() + .unwrap(), + "/Users/jorgeantonio/dev/atoma-node/atoma-inference/private_key" + .parse() + .unwrap(), + request_receiver, + ) + .await + .expect("Failed to start inference service"); } diff --git a/atoma-inference/src/models.rs b/atoma-inference/src/models.rs index af3bfc3d..44891c56 100644 --- a/atoma-inference/src/models.rs +++ b/atoma-inference/src/models.rs @@ -24,29 +24,23 @@ const EOS_TOKEN: &str = ""; #[derive(Clone, Debug, Deserialize)] pub enum ModelType { - Llama(usize), - Llama2(usize), - Mamba(usize), + Llama2_7b, + Mamba3b, Mixtral8x7b, - Mistral(usize), - StableDiffusionV1_5, - StableDiffusionV2_1, + Mistral7b, + StableDiffusion2, StableDiffusionXl, - StableDiffusionTurbo, } impl Display for ModelType { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - Self::Llama(size) => write!(f, "llama({})", size), - Self::Llama2(size) => write!(f, "llama2({})", size), - Self::Mamba(size) => write!(f, "mamba({})", size), + Self::Llama2_7b => write!(f, "llama2_7b"), + Self::Mamba3b => write!(f, "mamba_3b"), Self::Mixtral8x7b => write!(f, "mixtral_8x7b"), - Self::Mistral(size) => write!(f, "mistral({})", size), - Self::StableDiffusionV1_5 => write!(f, "stable_diffusion_v1_5"), - Self::StableDiffusionV2_1 => write!(f, "stable_diffusion_v2_1"), + Self::Mistral7b => write!(f, "mistral_7b"), + Self::StableDiffusion2 => write!(f, "stable_diffusion_2"), Self::StableDiffusionXl => write!(f, "stable_diffusion_xl"), - Self::StableDiffusionTurbo => write!(f, "stable_diffusion_turbo"), } } } diff --git a/atoma-inference/src/service.rs b/atoma-inference/src/service.rs index 307d160a..f120438e 100644 --- a/atoma-inference/src/service.rs +++ b/atoma-inference/src/service.rs @@ -125,34 +125,38 @@ impl From for InferenceServiceError { #[cfg(test)] mod tests { + use async_trait::async_trait; use rand::rngs::OsRng; use std::io::Write; use toml::{toml, Value}; + use crate::models::ModelType; + use super::*; struct TestApiInstance {} + #[async_trait] impl ApiTrait for TestApiInstance { fn call(&mut self) -> Result<(), ApiError> { Ok(()) } - fn connect(_: &str) -> Result + fn create(_: String, _: PathBuf) -> Result where Self: Sized, { Ok(Self {}) } - fn fetch(&mut self) -> Result<(), ApiError> { + async fn fetch(&mut self, _: ModelType) -> Result<(), ApiError> { Ok(()) } } #[tokio::test] async fn test_inference_service_initialization() { - const CONFIG_FILE_PATH: &str = "./config.toml"; + const CONFIG_FILE_PATH: &str = "./inference.toml"; const PRIVATE_KEY_FILE_PATH: &str = "./private_key"; let private_key = PrivateKey::new(&mut OsRng); @@ -160,7 +164,7 @@ mod tests { let config_data = Value::Table(toml! 
{ api_key = "your_api_key" - models = [{ Mamba = 3 }] + models = ["Mamba3b"] storage_base_path = "./storage_base_path/" tokenizer_file_path = "./tokenizer_file_path/" tracing = true From 9807bc9211874c2414cde0065d56c2bc5c49f53e Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Tue, 26 Mar 2024 00:19:48 +0000 Subject: [PATCH 06/28] refactor: Rename storage_base_path to storage_folder - Renamed the `storage_base_path` field in the `InferenceConfig` struct to `storage_folder` for improved clarity and consistency with the actual purpose of the field. - Updated references to `storage_base_path` to use `storage_folder` throughout the codebase in `config.rs` and `core.rs`. - Adjusted tests in `service.rs` to reflect the renaming of the field in the configuration. --- atoma-inference/src/config.rs | 10 +++++----- atoma-inference/src/core.rs | 4 ++-- atoma-inference/src/service.rs | 2 +- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/atoma-inference/src/config.rs b/atoma-inference/src/config.rs index 230fcb60..bdadcc75 100644 --- a/atoma-inference/src/config.rs +++ b/atoma-inference/src/config.rs @@ -9,7 +9,7 @@ use crate::models::ModelType; pub struct InferenceConfig { api_key: String, models: Vec, - storage_base_path: PathBuf, + storage_folder: PathBuf, tokenizer_file_path: PathBuf, tracing: bool, } @@ -18,14 +18,14 @@ impl InferenceConfig { pub fn new( api_key: String, models: Vec, - storage_base_path: PathBuf, + storage_folder: PathBuf, tokenizer_file_path: PathBuf, tracing: bool, ) -> Self { Self { api_key, models, - storage_base_path, + storage_folder, tokenizer_file_path, tracing, } @@ -39,8 +39,8 @@ impl InferenceConfig { self.models.clone() } - pub fn storage_base_path(&self) -> PathBuf { - self.storage_base_path.clone() + pub fn storage_folder(&self) -> PathBuf { + self.storage_folder.clone() } pub fn tokenizer_file_path(&self) -> PathBuf { diff --git a/atoma-inference/src/core.rs b/atoma-inference/src/core.rs index db88acda..fe1c14f8 100644 --- a/atoma-inference/src/core.rs +++ b/atoma-inference/src/core.rs @@ -23,7 +23,7 @@ impl InferenceCore { private_key: PrivateKey, ) -> Result { let public_key = private_key.verification_key(); - let api = T::create(config.api_key(), config.storage_base_path())?; + let api = T::create(config.api_key(), config.storage_folder())?; Ok(Self { config, public_key, @@ -46,7 +46,7 @@ impl InferenceCore { _top_p: Option, _top_k: usize, ) -> Result { - let mut model_path = self.config.storage_base_path().clone(); + let mut model_path = self.config.storage_folder().clone(); model_path.push(model.to_string()); // let tokenizer = Tokenizer::from_file(self.config.tokenizer_file_path()) diff --git a/atoma-inference/src/service.rs b/atoma-inference/src/service.rs index f120438e..3b186ff0 100644 --- a/atoma-inference/src/service.rs +++ b/atoma-inference/src/service.rs @@ -165,7 +165,7 @@ mod tests { let config_data = Value::Table(toml! { api_key = "your_api_key" models = ["Mamba3b"] - storage_base_path = "./storage_base_path/" + storage_folder = "./storage_folder/" tokenizer_file_path = "./tokenizer_file_path/" tracing = true }); From ab64ad39de9b3404d7582c3af73864f184860c62 Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Tue, 26 Mar 2024 00:54:17 +0000 Subject: [PATCH 07/28] Refactor core.rs and main.rs, introducing tracing for improved debugging and monitoring In this commit, changes have been made to core.rs and main.rs files to integrate tracing for enhanced debugging and monitoring capabilities. 
Specifically, the following modifications were implemented: - Imported the `tracing::info` module in core.rs to enable logging of informational messages. - Added logging statements using `info!` macro in core.rs to indicate the beginning of inference and provide information about the prompt and model being used. - Updated main.rs to remove unused variables and commented-out code, ensuring cleaner and more maintainable code. - Added logging statements in main.rs to indicate the start of the Core Dispatcher and fetching of models. These changes aim to improve the visibility into the execution flow of the inference service and facilitate easier debugging and monitoring during development and deployment. --- atoma-inference/src/core.rs | 10 +++------- atoma-inference/src/main.rs | 27 +++++++++++++++++++-------- atoma-inference/src/service.rs | 13 ++++++++----- atoma-inference/src/types.rs | 18 +++++++++--------- 4 files changed, 39 insertions(+), 29 deletions(-) diff --git a/atoma-inference/src/core.rs b/atoma-inference/src/core.rs index fe1c14f8..1c86f1c5 100644 --- a/atoma-inference/src/core.rs +++ b/atoma-inference/src/core.rs @@ -1,5 +1,6 @@ use ed25519_consensus::{SigningKey as PrivateKey, VerificationKey as PublicKey}; use thiserror::Error; +use tracing::info; use crate::{ apis::{ApiError, ApiTrait}, @@ -37,7 +38,7 @@ impl InferenceCore { #[allow(clippy::too_many_arguments)] pub fn inference( &mut self, - _prompt: String, + prompt: String, model: ModelType, _temperature: Option, _max_tokens: usize, @@ -46,15 +47,10 @@ impl InferenceCore { _top_p: Option, _top_k: usize, ) -> Result { + info!("Running inference on prompt: {prompt}, for model: {model}"); let mut model_path = self.config.storage_folder().clone(); model_path.push(model.to_string()); - // let tokenizer = Tokenizer::from_file(self.config.tokenizer_file_path()) - // .map_err(InferenceCoreError::FailedInference)?; - // let mut tokens = tokenizer - // .encode(prompt.0, true) - // .map_err(InferenceCoreError::FailedInference)?; - todo!() } diff --git a/atoma-inference/src/main.rs b/atoma-inference/src/main.rs index f8a2aef3..b8f50681 100644 --- a/atoma-inference/src/main.rs +++ b/atoma-inference/src/main.rs @@ -3,17 +3,28 @@ use inference::service::InferenceService; #[tokio::main] async fn main() { - let (_, request_receiver) = tokio::sync::mpsc::channel(32); + let (_, receiver) = tokio::sync::mpsc::channel(32); let _ = InferenceService::start::( - "/Users/jorgeantonio/dev/atoma-node/inference.toml" - .parse() - .unwrap(), - "/Users/jorgeantonio/dev/atoma-node/atoma-inference/private_key" - .parse() - .unwrap(), - request_receiver, + "../inference.toml".parse().unwrap(), + "../private_key".parse().unwrap(), + receiver, ) .await .expect("Failed to start inference service"); + + // inference_service + // .run_inference(InferenceRequest { + // prompt: String::from("Which protocols are faster, zk-STARKs or zk-SNARKs ?"), + // max_tokens: 512, + // model: inference::models::ModelType::Llama2_7b, + // random_seed: 42, + // sampled_nodes: vec![], + // repeat_penalty: 1.0, + // temperature: Some(0.6), + // top_k: 10, + // top_p: None, + // }) + // .await + // .unwrap(); } diff --git a/atoma-inference/src/service.rs b/atoma-inference/src/service.rs index 3b186ff0..a7432c34 100644 --- a/atoma-inference/src/service.rs +++ b/atoma-inference/src/service.rs @@ -17,14 +17,14 @@ pub struct InferenceService { core_thread_handle: CoreThreadHandle, dispatcher: CoreThreadDispatcher, start_time: Instant, - request_receiver: Receiver, + 
_request_receiver: Receiver, } impl InferenceService { pub async fn start( config_file_path: PathBuf, private_key_path: PathBuf, - request_receiver: Receiver, + _request_receiver: Receiver, ) -> Result { let private_key_bytes = std::fs::read(&private_key_path).map_err(InferenceServiceError::PrivateKeyError)?; @@ -37,6 +37,8 @@ impl InferenceService { let models = inference_config.models(); let inference_core = InferenceCore::::new(inference_config, private_key)?; + info!("Starting Core Dispatcher.."); + let (dispatcher, core_thread_handle) = CoreThreadDispatcher::start(inference_core); let start_time = Instant::now(); @@ -44,10 +46,11 @@ impl InferenceService { dispatcher, core_thread_handle, start_time, - request_receiver, + _request_receiver, }; for model in models { + info!("Fetching model {:?}", model); let response = inference_service .fetch_model(ModelRequest { model: model.clone(), @@ -65,7 +68,7 @@ impl InferenceService { Ok(inference_service) } - async fn run_inference( + pub async fn run_inference( &self, inference_request: InferenceRequest, ) -> Result { @@ -75,7 +78,7 @@ impl InferenceService { .map_err(InferenceServiceError::CoreError) } - async fn fetch_model( + pub async fn fetch_model( &self, model_request: ModelRequest, ) -> Result { diff --git a/atoma-inference/src/types.rs b/atoma-inference/src/types.rs index 8ee6baeb..1a2c252d 100644 --- a/atoma-inference/src/types.rs +++ b/atoma-inference/src/types.rs @@ -6,15 +6,15 @@ pub type Temperature = f32; #[derive(Clone, Debug)] pub struct InferenceRequest { - pub(crate) prompt: String, - pub(crate) model: ModelType, - pub(crate) max_tokens: usize, - pub(crate) random_seed: usize, - pub(crate) repeat_penalty: f32, - pub(crate) sampled_nodes: Vec, - pub(crate) temperature: Option, - pub(crate) top_k: usize, - pub(crate) top_p: Option, + pub prompt: String, + pub model: ModelType, + pub max_tokens: usize, + pub random_seed: usize, + pub repeat_penalty: f32, + pub sampled_nodes: Vec, + pub temperature: Option, + pub top_k: usize, + pub top_p: Option, } #[derive(Clone, Debug)] From 7a22bd3ef1b9183d90048d38e9756f728a828970 Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Tue, 26 Mar 2024 18:01:38 +0000 Subject: [PATCH 08/28] Add tracing-subscriber crate to Cargo.toml Add tracing-subscriber crate to the dependencies in Cargo.toml to enable structured logging in the project. 
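
For example, once the dependency is in place a binary can initialize the subscriber at startup and emit structured events; this mirrors the `tracing_subscriber::fmt::init()` call added to main.rs below, and the log line itself is only illustrative:

    tracing_subscriber::fmt::init();
    tracing::info!("starting the inference service");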
Changes: - Added tracing-subscriber = "0.3.18" to dependencies in Cargo.toml - Added tracing-subscriber.workspace = true to the workspace in Cargo.toml --- Cargo.toml | 1 + atoma-inference/Cargo.toml | 1 + atoma-inference/src/apis/hugging_face.rs | 2 +- atoma-inference/src/core.rs | 10 +++++----- atoma-inference/src/main.rs | 2 ++ 5 files changed, 10 insertions(+), 6 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index d5f109d7..924a79d5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,3 +23,4 @@ tokenizers = "0.15.2" tokio = "1.36.0" toml = "0.8.12" tracing = "0.1.40" +tracing-subscriber = "0.3.18" diff --git a/atoma-inference/Cargo.toml b/atoma-inference/Cargo.toml index 99e1e1de..e89fd514 100644 --- a/atoma-inference/Cargo.toml +++ b/atoma-inference/Cargo.toml @@ -20,6 +20,7 @@ thiserror.workspace = true tokenizers.workspace = true tokio = { workspace = true, features = ["full", "tracing"] } tracing.workspace = true +tracing-subscriber.workspace = true [dev-dependencies] rand.workspace = true diff --git a/atoma-inference/src/apis/hugging_face.rs b/atoma-inference/src/apis/hugging_face.rs index d49b19d7..5b486976 100644 --- a/atoma-inference/src/apis/hugging_face.rs +++ b/atoma-inference/src/apis/hugging_face.rs @@ -39,7 +39,7 @@ impl ModelType { file_paths: vec![ "model-00001-of-00003.safetensors".to_string(), "model-00002-of-00003.safetensors".to_string(), - "model-00002-of-00003.safetensors".to_string(), + "model-00003-of-00003.safetensors".to_string(), ], }, ), diff --git a/atoma-inference/src/core.rs b/atoma-inference/src/core.rs index 1c86f1c5..c5114b6a 100644 --- a/atoma-inference/src/core.rs +++ b/atoma-inference/src/core.rs @@ -10,21 +10,21 @@ use crate::{ }; #[allow(dead_code)] -pub struct InferenceCore { +pub struct InferenceCore { pub(crate) config: InferenceConfig, // models: Vec, pub(crate) public_key: PublicKey, private_key: PrivateKey, - pub(crate) api: T, + pub(crate) api: Api, } -impl InferenceCore { +impl InferenceCore { pub fn new( config: InferenceConfig, private_key: PrivateKey, ) -> Result { let public_key = private_key.verification_key(); - let api = T::create(config.api_key(), config.storage_folder())?; + let api = Api::create(config.api_key(), config.storage_folder())?; Ok(Self { config, public_key, @@ -34,7 +34,7 @@ impl InferenceCore { } } -impl InferenceCore { +impl InferenceCore { #[allow(clippy::too_many_arguments)] pub fn inference( &mut self, diff --git a/atoma-inference/src/main.rs b/atoma-inference/src/main.rs index b8f50681..bfee0121 100644 --- a/atoma-inference/src/main.rs +++ b/atoma-inference/src/main.rs @@ -3,6 +3,8 @@ use inference::service::InferenceService; #[tokio::main] async fn main() { + tracing_subscriber::fmt::init(); + let (_, receiver) = tokio::sync::mpsc::channel(32); let _ = InferenceService::start::( From 65f6e0136c313d0a91e3522a538a7be2c4ec3135 Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Tue, 26 Mar 2024 22:13:35 +0000 Subject: [PATCH 09/28] address PR comments --- atoma-inference/src/apis/hugging_face.rs | 6 +-- atoma-inference/src/apis/mod.rs | 5 +-- atoma-inference/src/core.rs | 2 +- atoma-inference/src/core_thread.rs | 11 +++++ atoma-inference/src/main.rs | 3 +- atoma-inference/src/service.rs | 52 ++++++++++++++---------- 6 files changed, 48 insertions(+), 31 deletions(-) diff --git a/atoma-inference/src/apis/hugging_face.rs b/atoma-inference/src/apis/hugging_face.rs index 5b486976..abca1c91 100644 --- a/atoma-inference/src/apis/hugging_face.rs +++ b/atoma-inference/src/apis/hugging_face.rs @@ -1,7 +1,7 @@ use 
std::path::PathBuf; use async_trait::async_trait; -use hf_hub::api::tokio::{Api, ApiBuilder}; +use hf_hub::api::sync::{Api, ApiBuilder}; use crate::models::ModelType; @@ -106,11 +106,11 @@ impl ApiTrait for Api { .build()?) } - async fn fetch(&mut self, model: ModelType) -> Result<(), super::ApiError> { + fn fetch(&self, model: ModelType) -> Result<(), super::ApiError> { let (model_path, files) = model.get_hugging_face_model_path(); let api_repo = self.model(model_path); for file in files.file_paths { - api_repo.get(&file).await?; + api_repo.get(&file)?; } Ok(()) diff --git a/atoma-inference/src/apis/mod.rs b/atoma-inference/src/apis/mod.rs index 2c598aa2..c7919568 100644 --- a/atoma-inference/src/apis/mod.rs +++ b/atoma-inference/src/apis/mod.rs @@ -1,6 +1,6 @@ pub mod hugging_face; use async_trait::async_trait; -use hf_hub::api::tokio::ApiError as HuggingFaceError; +use hf_hub::api::sync::ApiError as HuggingFaceError; use std::path::PathBuf; @@ -22,10 +22,9 @@ impl From for ApiError { } } -#[async_trait] pub trait ApiTrait { fn call(&mut self) -> Result<(), ApiError>; - async fn fetch(&mut self, model: ModelType) -> Result<(), ApiError>; + fn fetch(&self, model: ModelType) -> Result<(), ApiError>; fn create(api_key: String, cache_dir: PathBuf) -> Result where Self: Sized; diff --git a/atoma-inference/src/core.rs b/atoma-inference/src/core.rs index c5114b6a..1a400f5e 100644 --- a/atoma-inference/src/core.rs +++ b/atoma-inference/src/core.rs @@ -59,7 +59,7 @@ impl InferenceCore { model: ModelType, _quantization_method: Option, ) -> Result { - self.api.fetch(model).await?; + self.api.fetch(model)?; Ok(ModelResponse { is_success: true, error: None, diff --git a/atoma-inference/src/core_thread.rs b/atoma-inference/src/core_thread.rs index 94f3df03..b4a332c5 100644 --- a/atoma-inference/src/core_thread.rs +++ b/atoma-inference/src/core_thread.rs @@ -53,6 +53,17 @@ impl CoreThread { pub async fn run(mut self) -> Result<(), CoreError> { debug!("Starting Core thread"); + // let models = self.core.config.models(); + // for model_type in models { + // let (model_sender, model_receiver) = std::sync::mpsc::channel(); + // let + // std::thread::spawn(move || { + // while Ok(request) = model_receiver.recv() { + + // } + // }); + // } + while let Some(command) = self.receiver.recv().await { match command { CoreThreadCommand::RunInference(request, sender) => { diff --git a/atoma-inference/src/main.rs b/atoma-inference/src/main.rs index bfee0121..40b2806e 100644 --- a/atoma-inference/src/main.rs +++ b/atoma-inference/src/main.rs @@ -1,4 +1,4 @@ -use hf_hub::api::tokio::Api; +use hf_hub::api::sync::Api; use inference::service::InferenceService; #[tokio::main] @@ -12,7 +12,6 @@ async fn main() { "../private_key".parse().unwrap(), receiver, ) - .await .expect("Failed to start inference service"); // inference_service diff --git a/atoma-inference/src/service.rs b/atoma-inference/src/service.rs index a7432c34..7648dc05 100644 --- a/atoma-inference/src/service.rs +++ b/atoma-inference/src/service.rs @@ -1,6 +1,12 @@ use ed25519_consensus::SigningKey as PrivateKey; -use std::{io, path::PathBuf, time::Instant}; -use tokio::sync::mpsc::{error::SendError, Receiver}; +use std::{io, path::PathBuf, sync::Arc, time::Instant}; +use tokio::{ + io::join, + sync::{ + mpsc::{error::SendError, Receiver}, + RwLock, + }, +}; use tracing::{info, warn}; use thiserror::Error; @@ -21,7 +27,7 @@ pub struct InferenceService { } impl InferenceService { - pub async fn start( + pub fn start( config_file_path: PathBuf, private_key_path: 
PathBuf, _request_receiver: Receiver, @@ -37,6 +43,17 @@ impl InferenceService { let models = inference_config.models(); let inference_core = InferenceCore::::new(inference_config, private_key)?; + let mut handles = Vec::with_capacity(models.len()); + for model in models { + let api = inference_core.api.clone(); + let handle = std::thread::spawn(move || { + api.fetch(model).expect("Failed to fetch model"); + }); + handles.push(handle); + } + + handles.into_iter().for_each(|h| h.join().unwrap()); + info!("Starting Core Dispatcher.."); let (dispatcher, core_thread_handle) = CoreThreadDispatcher::start(inference_core); @@ -49,22 +66,6 @@ impl InferenceService { _request_receiver, }; - for model in models { - info!("Fetching model {:?}", model); - let response = inference_service - .fetch_model(ModelRequest { - model: model.clone(), - quantization_method: None, - }) - .await?; - if !response.is_success { - warn!( - "Failed to fetch model: {:?}, with error: {:?}", - model, response.error - ); - } - } - Ok(inference_service) } @@ -114,6 +115,8 @@ pub enum InferenceServiceError { CoreError(CoreError), #[error("Send error: `{0}`")] SendError(SendError), + #[error("Api error: `{0}`")] + ApiError(ApiError), } impl From for InferenceServiceError { @@ -126,6 +129,12 @@ impl From for InferenceServiceError { } } +impl From for InferenceServiceError { + fn from(error: ApiError) -> Self { + Self::ApiError(error) + } +} + #[cfg(test)] mod tests { use async_trait::async_trait; @@ -137,9 +146,9 @@ mod tests { use super::*; + #[derive(Clone)] struct TestApiInstance {} - #[async_trait] impl ApiTrait for TestApiInstance { fn call(&mut self) -> Result<(), ApiError> { Ok(()) @@ -152,7 +161,7 @@ mod tests { Ok(Self {}) } - async fn fetch(&mut self, _: ModelType) -> Result<(), ApiError> { + fn fetch(&self, _: ModelType) -> Result<(), ApiError> { Ok(()) } } @@ -186,7 +195,6 @@ mod tests { PathBuf::try_from(PRIVATE_KEY_FILE_PATH).unwrap(), receiver, ) - .await .unwrap(); std::fs::remove_file(CONFIG_FILE_PATH).unwrap(); From 1ce0c122d2dc218e4d62111fce2564e1d88d3a02 Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Wed, 27 Mar 2024 01:26:11 +0000 Subject: [PATCH 10/28] refactor core thread to model thread, to facilitate models running in parallel --- atoma-inference/Cargo.toml | 2 +- atoma-inference/src/apis/mod.rs | 1 - atoma-inference/src/core.rs | 21 +-- atoma-inference/src/core_thread.rs | 227 ++++++++++++++--------------- atoma-inference/src/models.rs | 2 +- atoma-inference/src/service.rs | 57 +++----- atoma-inference/src/types.rs | 7 +- 7 files changed, 143 insertions(+), 174 deletions(-) diff --git a/atoma-inference/Cargo.toml b/atoma-inference/Cargo.toml index e89fd514..04d5ff96 100644 --- a/atoma-inference/Cargo.toml +++ b/atoma-inference/Cargo.toml @@ -12,7 +12,7 @@ candle-nn.workspace = true candle-transformers.workspace = true config.true = true ed25519-consensus.workspace = true -hf-hub = { workspace = true, features = ["tokio"] } +hf-hub.workspace = true reqwest = { workspace = true, features = ["json"] } serde = { workspace = true, features = ["derive"] } serde_json.workspace = true diff --git a/atoma-inference/src/apis/mod.rs b/atoma-inference/src/apis/mod.rs index c7919568..355d5f46 100644 --- a/atoma-inference/src/apis/mod.rs +++ b/atoma-inference/src/apis/mod.rs @@ -1,5 +1,4 @@ pub mod hugging_face; -use async_trait::async_trait; use hf_hub::api::sync::ApiError as HuggingFaceError; use std::path::PathBuf; diff --git a/atoma-inference/src/core.rs b/atoma-inference/src/core.rs index 
1a400f5e..dd185eab 100644 --- a/atoma-inference/src/core.rs +++ b/atoma-inference/src/core.rs @@ -10,31 +10,28 @@ use crate::{ }; #[allow(dead_code)] -pub struct InferenceCore { +pub struct InferenceCore { pub(crate) config: InferenceConfig, // models: Vec, pub(crate) public_key: PublicKey, private_key: PrivateKey, - pub(crate) api: Api, } -impl InferenceCore { +impl InferenceCore { pub fn new( config: InferenceConfig, private_key: PrivateKey, ) -> Result { let public_key = private_key.verification_key(); - let api = Api::create(config.api_key(), config.storage_folder())?; Ok(Self { config, public_key, private_key, - api, }) } } -impl InferenceCore { +impl InferenceCore { #[allow(clippy::too_many_arguments)] pub fn inference( &mut self, @@ -53,18 +50,6 @@ impl InferenceCore { todo!() } - - pub async fn fetch_model( - &mut self, - model: ModelType, - _quantization_method: Option, - ) -> Result { - self.api.fetch(model)?; - Ok(ModelResponse { - is_success: true, - error: None, - }) - } } #[derive(Debug, Error)] diff --git a/atoma-inference/src/core_thread.rs b/atoma-inference/src/core_thread.rs index b4a332c5..f42d2c45 100644 --- a/atoma-inference/src/core_thread.rs +++ b/atoma-inference/src/core_thread.rs @@ -1,3 +1,7 @@ +use std::collections::HashMap; + +use candle_nn::VarBuilder; +use ed25519_consensus::VerificationKey as PublicKey; use thiserror::Error; use tokio::{ sync::{ @@ -9,100 +13,86 @@ use tokio::{ use tracing::{debug, error, warn}; use crate::{ - apis::ApiTrait, core::{InferenceCore, InferenceCoreError}, - types::{InferenceRequest, InferenceResponse, ModelRequest, ModelResponse}, + models::{ModelApi, ModelError, ModelSpecs, ModelType}, + types::{InferenceRequest, InferenceResponse}, }; const CORE_THREAD_COMMANDS_CHANNEL_SIZE: usize = 32; pub enum CoreThreadCommand { RunInference(InferenceRequest, oneshot::Sender), - FetchModel(ModelRequest, oneshot::Sender), } +pub struct ModelThreadCommand(InferenceRequest, oneshot::Sender); + #[derive(Debug, Error)] -pub enum CoreError { +pub enum ModelThreadError { #[error("Core thread shutdown: `{0}`")] FailedInference(InferenceCoreError), - #[error("Core thread shutdown: `{0}`")] - FailedModelFetch(InferenceCoreError), + #[error("Model thread shutdown: `{0}`")] + ModelError(ModelError), #[error("Core thread shutdown: `{0}`")] Shutdown(RecvError), } -pub struct CoreThreadHandle { - sender: mpsc::Sender, - join_handle: JoinHandle<()>, +pub struct ModelThreadHandle { + sender: std::sync::mpsc::Sender, + join_handle: std::thread::JoinHandle<()>, } -impl CoreThreadHandle { - pub async fn stop(self) { - // drop the sender, this will force all the other weak senders to not be able to upgrade. 
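The new ModelThreadHandle below shuts its thread down without an explicit stop message: dropping the only std::sync::mpsc::Sender makes recv() return Err, the while let Ok(command) loop in ModelThread::run ends, and join() then completes. A minimal, self-contained sketch of that pattern, using hypothetical names rather than the crate's types:

    use std::sync::mpsc;
    use std::thread;

    fn main() {
        let (tx, rx) = mpsc::channel::<String>();
        let worker = thread::spawn(move || {
            // Mirrors ModelThread::run: loop until every Sender has been dropped.
            while let Ok(msg) = rx.recv() {
                println!("processing {msg}");
            }
            println!("channel closed, worker exiting");
        });
        tx.send("one request".to_string()).ok();
        drop(tx);           // what ModelThreadHandle::stop does with its sender
        worker.join().ok(); // then join, as in join_handle.join().ok()
    }
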
+impl ModelThreadHandle { + pub fn stop(self) { drop(self.sender); - self.join_handle.await.ok(); + self.join_handle.join().ok(); } } -pub struct CoreThread { - core: InferenceCore, - receiver: mpsc::Receiver, +pub struct ModelThread { + model: T, + receiver: std::sync::mpsc::Receiver, } -impl CoreThread { - pub async fn run(mut self) -> Result<(), CoreError> { - debug!("Starting Core thread"); - - // let models = self.core.config.models(); - // for model_type in models { - // let (model_sender, model_receiver) = std::sync::mpsc::channel(); - // let - // std::thread::spawn(move || { - // while Ok(request) = model_receiver.recv() { - - // } - // }); - // } - - while let Some(command) = self.receiver.recv().await { - match command { - CoreThreadCommand::RunInference(request, sender) => { - let InferenceRequest { - prompt, - model, - max_tokens, - temperature, - random_seed, - repeat_penalty, - top_k, - top_p, - sampled_nodes, - } = request; - if !sampled_nodes.contains(&self.core.public_key) { - error!("Current node, with verification key = {:?} was not sampled from {sampled_nodes:?}", self.core.public_key); - continue; - } - let response = self.core.inference( - prompt, - model, - temperature, - max_tokens, - random_seed, - repeat_penalty, - top_p, - top_k, - )?; - sender.send(response).ok(); - } - CoreThreadCommand::FetchModel(request, sender) => { - let ModelRequest { - model, - quantization_method, - } = request; - let response = self.core.fetch_model(model, quantization_method).await?; - sender.send(response).ok(); - } +impl ModelThread +where + T: ModelApi, +{ + pub fn run(mut self, public_key: PublicKey) -> Result<(), ModelThreadError> { + debug!("Start Model thread"); + + while let Ok(command) = self.receiver.recv() { + let ModelThreadCommand(request, sender) = command; + + let InferenceRequest { + prompt, + model, + max_tokens, + temperature, + random_seed, + repeat_last_n, + repeat_penalty, + top_k, + top_p, + sampled_nodes, + } = request; + if !sampled_nodes.contains(&public_key) { + error!("Current node, with verification key = {:?} was not sampled from {sampled_nodes:?}", public_key); + continue; } + let response = self + .model + .run( + prompt, + max_tokens, + random_seed, + repeat_last_n, + repeat_penalty, + temperature.unwrap_or_default(), + top_p.unwrap_or_default(), + ) + .map_err(ModelThreadError::ModelError)?; + let response = InferenceResponse { response }; + sender.send(response).ok(); } Ok(()) @@ -110,72 +100,79 @@ impl CoreThread { } #[derive(Clone)] -pub struct CoreThreadDispatcher { - sender: mpsc::WeakSender, +pub struct ModelThreadDispatcher { + model_senders: HashMap>, } -impl CoreThreadDispatcher { - pub(crate) fn start( - core: InferenceCore, - ) -> (Self, CoreThreadHandle) { - let (sender, receiver) = mpsc::channel(CORE_THREAD_COMMANDS_CHANNEL_SIZE); - let core_thread = CoreThread { core, receiver }; - - let join_handle = tokio::task::spawn(async move { - if let Err(e) = core_thread.run().await { - if !matches!(e, CoreError::Shutdown(_)) { - panic!("Fatal error occurred: {e}"); +impl ModelThreadDispatcher { + pub(crate) fn start( + &self, + models: Vec<(ModelType, ModelSpecs, VarBuilder)>, + public_key: PublicKey, + ) -> Result<(Self, Vec), ModelThreadError> { + let (core_sender, core_receiver) = std::sync::mpsc::channel::(); + + let mut handles = Vec::with_capacity(models.len()); + let mut model_senders = HashMap::with_capacity(models.len()); + + for (model_type, model_specs, var_builder) in models { + let (model_sender, model_receiver) = 
std::sync::mpsc::channel::(); + let model = T::load(model_specs, var_builder); // TODO: for now this piece of code cannot be shared among threads safely + let model_thread = ModelThread { + model, + receiver: model_receiver, + }; + let join_handle = std::thread::spawn(move || { + if let Err(e) = model_thread.run(public_key) { + error!("Model thread error: {e}"); + if !matches!(e, ModelThreadError::Shutdown(_)) { + panic!("Fatal error occurred: {e}"); + } } - } - }); + }); + handles.push(ModelThreadHandle { + join_handle, + sender: model_sender.clone(), + }); + model_senders.insert(model_type, model_sender); + } - let dispatcher = Self { - sender: sender.downgrade(), - }; - let handle = CoreThreadHandle { - join_handle, - sender, - }; + let model_dispatcher = ModelThreadDispatcher { model_senders }; - (dispatcher, handle) + Ok((model_dispatcher, handles)) } - async fn send(&self, command: CoreThreadCommand) { - if let Some(sender) = self.sender.upgrade() { - if let Err(e) = sender.send(command).await { - warn!("Could not send command to thread core, it might be shutting down: {e}"); - } + fn send(&self, command: ModelThreadCommand) { + let request = command.0.clone(); + let model_type = request.model; + + let sender = self + .model_senders + .get(&model_type) + .expect("Failed to get model thread, this should not happen !"); + + if let Err(e) = sender.send(command) { + warn!("Could not send command to model core, it might be shutting down: {e}"); } } } -impl CoreThreadDispatcher { - pub(crate) async fn fetch_model( - &self, - request: ModelRequest, - ) -> Result { - let (sender, receiver) = oneshot::channel(); - self.send(CoreThreadCommand::FetchModel(request, sender)) - .await; - receiver.await.map_err(CoreError::Shutdown) - } - +impl ModelThreadDispatcher { pub(crate) async fn run_inference( &self, request: InferenceRequest, - ) -> Result { + ) -> Result { let (sender, receiver) = oneshot::channel(); - self.send(CoreThreadCommand::RunInference(request, sender)) - .await; - receiver.await.map_err(CoreError::Shutdown) + self.send(ModelThreadCommand(request, sender)); + receiver.await.map_err(ModelThreadError::Shutdown) } } -impl From for CoreError { +impl From for ModelThreadError { fn from(error: InferenceCoreError) -> Self { match error { - InferenceCoreError::FailedInference(_) => CoreError::FailedInference(error), - InferenceCoreError::FailedModelFetch(_) => CoreError::FailedModelFetch(error), + InferenceCoreError::FailedInference(_) => ModelThreadError::FailedInference(error), + InferenceCoreError::FailedModelFetch(_) => unreachable!(), InferenceCoreError::FailedApiConnection(_) => { panic!("API connection should have been already established") } diff --git a/atoma-inference/src/models.rs b/atoma-inference/src/models.rs index 44891c56..09b55ccd 100644 --- a/atoma-inference/src/models.rs +++ b/atoma-inference/src/models.rs @@ -22,7 +22,7 @@ use crate::types::Temperature; const EOS_TOKEN: &str = ""; -#[derive(Clone, Debug, Deserialize)] +#[derive(Clone, Debug, Deserialize, Eq, Hash, PartialEq)] pub enum ModelType { Llama2_7b, Mamba3b, diff --git a/atoma-inference/src/service.rs b/atoma-inference/src/service.rs index 7648dc05..3896b300 100644 --- a/atoma-inference/src/service.rs +++ b/atoma-inference/src/service.rs @@ -1,13 +1,8 @@ use ed25519_consensus::SigningKey as PrivateKey; -use std::{io, path::PathBuf, sync::Arc, time::Instant}; -use tokio::{ - io::join, - sync::{ - mpsc::{error::SendError, Receiver}, - RwLock, - }, -}; -use tracing::{info, warn}; +use hf_hub::api::sync::Api; +use 
std::{io, path::PathBuf, time::Instant}; +use tokio::sync::mpsc::{error::SendError, Receiver}; +use tracing::info; use thiserror::Error; @@ -15,13 +10,13 @@ use crate::{ apis::{ApiError, ApiTrait}, config::InferenceConfig, core::{InferenceCore, InferenceCoreError}, - core_thread::{CoreError, CoreThreadDispatcher, CoreThreadHandle}, - types::{InferenceRequest, InferenceResponse, ModelRequest, ModelResponse}, + core_thread::{ModelThreadDispatcher, ModelThreadError, ModelThreadHandle}, + types::{InferenceRequest, InferenceResponse}, }; pub struct InferenceService { - core_thread_handle: CoreThreadHandle, - dispatcher: CoreThreadDispatcher, + model_thread_handle: Vec, + dispatcher: ModelThreadDispatcher, start_time: Instant, _request_receiver: Receiver, } @@ -40,12 +35,16 @@ impl InferenceService { let private_key = PrivateKey::from(private_key_bytes); let inference_config = InferenceConfig::from_file_path(config_file_path); + let api_key = inference_config.api_key(); + let storage_folder = inference_config.storage_folder(); let models = inference_config.models(); - let inference_core = InferenceCore::::new(inference_config, private_key)?; + let inference_core = InferenceCore::new(inference_config, private_key)?; + + let api = Api::create(api_key, storage_folder)?; let mut handles = Vec::with_capacity(models.len()); for model in models { - let api = inference_core.api.clone(); + let api = api.clone(); let handle = std::thread::spawn(move || { api.fetch(model).expect("Failed to fetch model"); }); @@ -56,17 +55,15 @@ impl InferenceService { info!("Starting Core Dispatcher.."); - let (dispatcher, core_thread_handle) = CoreThreadDispatcher::start(inference_core); + let (dispatcher, model_thread_handle) = ModelThreadDispatcher::start(inference_core)?; let start_time = Instant::now(); - let inference_service = Self { + Ok(Self { dispatcher, - core_thread_handle, + model_thread_handle, start_time, _request_receiver, - }; - - Ok(inference_service) + }) } pub async fn run_inference( @@ -76,28 +73,18 @@ impl InferenceService { self.dispatcher .run_inference(inference_request) .await - .map_err(InferenceServiceError::CoreError) - } - - pub async fn fetch_model( - &self, - model_request: ModelRequest, - ) -> Result { - self.dispatcher - .fetch_model(model_request) - .await - .map_err(InferenceServiceError::CoreError) + .map_err(InferenceServiceError::ModelThreadError) } } impl InferenceService { - pub async fn stop(self) { + pub async fn stop(mut self) { info!( "Stopping Inference Service, running time: {:?}", self.start_time.elapsed() ); - self.core_thread_handle.stop().await; + self.model_thread_handle.drain(..).map(|h| h.stop()); } } @@ -112,7 +99,7 @@ pub enum InferenceServiceError { #[error("Failed to generate private key: `{0}`")] PrivateKeyError(io::Error), #[error("Core error: `{0}`")] - CoreError(CoreError), + ModelThreadError(ModelThreadError), #[error("Send error: `{0}`")] SendError(SendError), #[error("Api error: `{0}`")] diff --git a/atoma-inference/src/types.rs b/atoma-inference/src/types.rs index 1a2c252d..383347f0 100644 --- a/atoma-inference/src/types.rs +++ b/atoma-inference/src/types.rs @@ -10,6 +10,7 @@ pub struct InferenceRequest { pub model: ModelType, pub max_tokens: usize, pub random_seed: usize, + pub repeat_last_n: usize, pub repeat_penalty: f32, pub sampled_nodes: Vec, pub temperature: Option, @@ -21,9 +22,9 @@ pub struct InferenceRequest { #[allow(dead_code)] pub struct InferenceResponse { // TODO: possibly a Merkle root hash - pub(crate) response_hash: [u8; 32], - pub(crate) 
node_id: NodeId, - pub(crate) signature: Vec, + // pub(crate) response_hash: [u8; 32], + // pub(crate) node_id: NodeId, + // pub(crate) signature: Vec, pub(crate) response: String, } From 156b3981fc517e28b174f99df7a5b21506938736 Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Wed, 27 Mar 2024 02:12:44 +0000 Subject: [PATCH 11/28] remove core, rename core_thread to model_thread, and work on setting up the models --- atoma-inference/src/config.rs | 10 +- atoma-inference/src/core.rs | 69 ----------- atoma-inference/src/core_thread.rs | 181 ----------------------------- atoma-inference/src/lib.rs | 3 +- atoma-inference/src/models.rs | 28 ++++- atoma-inference/src/service.rs | 30 +++-- atoma-inference/src/types.rs | 3 +- 7 files changed, 48 insertions(+), 276 deletions(-) delete mode 100644 atoma-inference/src/core.rs delete mode 100644 atoma-inference/src/core_thread.rs diff --git a/atoma-inference/src/config.rs b/atoma-inference/src/config.rs index bdadcc75..546a0e85 100644 --- a/atoma-inference/src/config.rs +++ b/atoma-inference/src/config.rs @@ -1,14 +1,16 @@ use std::path::PathBuf; +use candle::DType; use config::Config; use serde::Deserialize; -use crate::models::ModelType; +use crate::{models::ModelType, types::PrecisionBits}; #[derive(Debug, Deserialize)] pub struct InferenceConfig { api_key: String, models: Vec, + precision: PrecisionBits, storage_folder: PathBuf, tokenizer_file_path: PathBuf, tracing: bool, @@ -18,6 +20,7 @@ impl InferenceConfig { pub fn new( api_key: String, models: Vec, + precision: PrecisionBits, storage_folder: PathBuf, tokenizer_file_path: PathBuf, tracing: bool, @@ -25,6 +28,7 @@ impl InferenceConfig { Self { api_key, models, + precision, storage_folder, tokenizer_file_path, tracing, @@ -51,6 +55,10 @@ impl InferenceConfig { self.tracing } + pub fn precision_bits(&self) -> PrecisionBits { + self.precision + } + pub fn from_file_path(config_file_path: PathBuf) -> Self { let builder = Config::builder().add_source(config::File::with_name( config_file_path.to_str().as_ref().unwrap(), diff --git a/atoma-inference/src/core.rs b/atoma-inference/src/core.rs deleted file mode 100644 index dd185eab..00000000 --- a/atoma-inference/src/core.rs +++ /dev/null @@ -1,69 +0,0 @@ -use ed25519_consensus::{SigningKey as PrivateKey, VerificationKey as PublicKey}; -use thiserror::Error; -use tracing::info; - -use crate::{ - apis::{ApiError, ApiTrait}, - config::InferenceConfig, - models::ModelType, - types::{InferenceResponse, ModelResponse, QuantizationMethod, Temperature}, -}; - -#[allow(dead_code)] -pub struct InferenceCore { - pub(crate) config: InferenceConfig, - // models: Vec, - pub(crate) public_key: PublicKey, - private_key: PrivateKey, -} - -impl InferenceCore { - pub fn new( - config: InferenceConfig, - private_key: PrivateKey, - ) -> Result { - let public_key = private_key.verification_key(); - Ok(Self { - config, - public_key, - private_key, - }) - } -} - -impl InferenceCore { - #[allow(clippy::too_many_arguments)] - pub fn inference( - &mut self, - prompt: String, - model: ModelType, - _temperature: Option, - _max_tokens: usize, - _random_seed: usize, - _repeat_penalty: f32, - _top_p: Option, - _top_k: usize, - ) -> Result { - info!("Running inference on prompt: {prompt}, for model: {model}"); - let mut model_path = self.config.storage_folder().clone(); - model_path.push(model.to_string()); - - todo!() - } -} - -#[derive(Debug, Error)] -pub enum InferenceCoreError { - #[error("Failed to generate inference output: `{0}`")] - FailedInference(Box), - #[error("Failed 
to fetch new AI model: `{0}`")] - FailedModelFetch(String), - #[error("Failed to connect to web2 API: `{0}`")] - FailedApiConnection(ApiError), -} - -impl From for InferenceCoreError { - fn from(error: ApiError) -> Self { - InferenceCoreError::FailedApiConnection(error) - } -} diff --git a/atoma-inference/src/core_thread.rs b/atoma-inference/src/core_thread.rs deleted file mode 100644 index f42d2c45..00000000 --- a/atoma-inference/src/core_thread.rs +++ /dev/null @@ -1,181 +0,0 @@ -use std::collections::HashMap; - -use candle_nn::VarBuilder; -use ed25519_consensus::VerificationKey as PublicKey; -use thiserror::Error; -use tokio::{ - sync::{ - mpsc, - oneshot::{self, error::RecvError}, - }, - task::JoinHandle, -}; -use tracing::{debug, error, warn}; - -use crate::{ - core::{InferenceCore, InferenceCoreError}, - models::{ModelApi, ModelError, ModelSpecs, ModelType}, - types::{InferenceRequest, InferenceResponse}, -}; - -const CORE_THREAD_COMMANDS_CHANNEL_SIZE: usize = 32; - -pub enum CoreThreadCommand { - RunInference(InferenceRequest, oneshot::Sender), -} - -pub struct ModelThreadCommand(InferenceRequest, oneshot::Sender); - -#[derive(Debug, Error)] -pub enum ModelThreadError { - #[error("Core thread shutdown: `{0}`")] - FailedInference(InferenceCoreError), - #[error("Model thread shutdown: `{0}`")] - ModelError(ModelError), - #[error("Core thread shutdown: `{0}`")] - Shutdown(RecvError), -} - -pub struct ModelThreadHandle { - sender: std::sync::mpsc::Sender, - join_handle: std::thread::JoinHandle<()>, -} - -impl ModelThreadHandle { - pub fn stop(self) { - drop(self.sender); - self.join_handle.join().ok(); - } -} - -pub struct ModelThread { - model: T, - receiver: std::sync::mpsc::Receiver, -} - -impl ModelThread -where - T: ModelApi, -{ - pub fn run(mut self, public_key: PublicKey) -> Result<(), ModelThreadError> { - debug!("Start Model thread"); - - while let Ok(command) = self.receiver.recv() { - let ModelThreadCommand(request, sender) = command; - - let InferenceRequest { - prompt, - model, - max_tokens, - temperature, - random_seed, - repeat_last_n, - repeat_penalty, - top_k, - top_p, - sampled_nodes, - } = request; - if !sampled_nodes.contains(&public_key) { - error!("Current node, with verification key = {:?} was not sampled from {sampled_nodes:?}", public_key); - continue; - } - let response = self - .model - .run( - prompt, - max_tokens, - random_seed, - repeat_last_n, - repeat_penalty, - temperature.unwrap_or_default(), - top_p.unwrap_or_default(), - ) - .map_err(ModelThreadError::ModelError)?; - let response = InferenceResponse { response }; - sender.send(response).ok(); - } - - Ok(()) - } -} - -#[derive(Clone)] -pub struct ModelThreadDispatcher { - model_senders: HashMap>, -} - -impl ModelThreadDispatcher { - pub(crate) fn start( - &self, - models: Vec<(ModelType, ModelSpecs, VarBuilder)>, - public_key: PublicKey, - ) -> Result<(Self, Vec), ModelThreadError> { - let (core_sender, core_receiver) = std::sync::mpsc::channel::(); - - let mut handles = Vec::with_capacity(models.len()); - let mut model_senders = HashMap::with_capacity(models.len()); - - for (model_type, model_specs, var_builder) in models { - let (model_sender, model_receiver) = std::sync::mpsc::channel::(); - let model = T::load(model_specs, var_builder); // TODO: for now this piece of code cannot be shared among threads safely - let model_thread = ModelThread { - model, - receiver: model_receiver, - }; - let join_handle = std::thread::spawn(move || { - if let Err(e) = model_thread.run(public_key) { - error!("Model 
thread error: {e}"); - if !matches!(e, ModelThreadError::Shutdown(_)) { - panic!("Fatal error occurred: {e}"); - } - } - }); - handles.push(ModelThreadHandle { - join_handle, - sender: model_sender.clone(), - }); - model_senders.insert(model_type, model_sender); - } - - let model_dispatcher = ModelThreadDispatcher { model_senders }; - - Ok((model_dispatcher, handles)) - } - - fn send(&self, command: ModelThreadCommand) { - let request = command.0.clone(); - let model_type = request.model; - - let sender = self - .model_senders - .get(&model_type) - .expect("Failed to get model thread, this should not happen !"); - - if let Err(e) = sender.send(command) { - warn!("Could not send command to model core, it might be shutting down: {e}"); - } - } -} - -impl ModelThreadDispatcher { - pub(crate) async fn run_inference( - &self, - request: InferenceRequest, - ) -> Result { - let (sender, receiver) = oneshot::channel(); - self.send(ModelThreadCommand(request, sender)); - receiver.await.map_err(ModelThreadError::Shutdown) - } -} - -impl From for ModelThreadError { - fn from(error: InferenceCoreError) -> Self { - match error { - InferenceCoreError::FailedInference(_) => ModelThreadError::FailedInference(error), - InferenceCoreError::FailedModelFetch(_) => unreachable!(), - InferenceCoreError::FailedApiConnection(_) => { - panic!("API connection should have been already established") - } - } - } -} diff --git a/atoma-inference/src/lib.rs b/atoma-inference/src/lib.rs index 433065ca..bcd084bd 100644 --- a/atoma-inference/src/lib.rs +++ b/atoma-inference/src/lib.rs @@ -1,6 +1,5 @@ pub mod config; -pub mod core; -pub mod core_thread; +pub mod model_thread; pub mod models; pub mod service; pub mod specs; diff --git a/atoma-inference/src/models.rs b/atoma-inference/src/models.rs index 09b55ccd..118b0eb5 100644 --- a/atoma-inference/src/models.rs +++ b/atoma-inference/src/models.rs @@ -6,7 +6,7 @@ use candle_transformers::{ generation::LogitsProcessor, models::{ llama::{Cache as LlamaCache, Config as LlamaConfig, Llama}, - llama2_c::{Cache as Llama2Cache, Config as Llama2Config, Llama as Llama2}, + llama2_c::{Cache as Llama2Cache, Llama as Llama2}, mamba::{Config as MambaConfig, Model as MambaModel}, mistral::{Config as MistralConfig, Model as MistralModel}, mixtral::{Config as MixtralConfig, Model as MixtralModel}, @@ -45,10 +45,26 @@ impl Display for ModelType { } } +impl ModelType { + pub(crate) fn model_config(&self) -> ModelConfig { + match self { + Self::Llama2_7b => ModelConfig::Llama(LlamaConfig::config_7b_v2(false)), // TODO: add the case for flash attention + Self::Mamba3b => todo!(), + Self::Mistral7b => ModelConfig::Mistral(MistralConfig::config_7b_v0_1(false)), // TODO: add the case for flash attention + Self::Mixtral8x7b => ModelConfig::Mixtral8x7b(MixtralConfig::v0_1_8x7b(false)), // TODO: add the case for flash attention + Self::StableDiffusion2 => { + ModelConfig::StableDiffusion(StableDiffusionConfig::v2_1(None, None, None)) + } + Self::StableDiffusionXl => { + ModelConfig::StableDiffusion(StableDiffusionConfig::sdxl_turbo(None, None, None)) + } + } + } +} + #[derive(Clone)] pub enum ModelConfig { Llama(LlamaConfig), - Llama2(Llama2Config), Mamba(MambaConfig), Mixtral8x7b(MixtralConfig), Mistral(MistralConfig), @@ -121,10 +137,10 @@ impl ModelApi for Model { let model = Llama::load(var_builder, &config).expect("Failed to load LlaMa model"); Self::Llama { model, model_specs } } - ModelConfig::Llama2(config) => { - let model = Llama2::load(var_builder, config).expect("Failed to load LlaMa2 
model"); - Self::Llama2 { model_specs, model } - } + // ModelConfig::Llama2(config) => { + // let model = Llama2::load(var_builder, config).expect("Failed to load LlaMa2 model"); + // Self::Llama2 { model_specs, model } + // } ModelConfig::Mamba(config) => { let model = MambaModel::new(&config, var_builder).expect("Failed to load Mamba model"); diff --git a/atoma-inference/src/service.rs b/atoma-inference/src/service.rs index 3896b300..483af19e 100644 --- a/atoma-inference/src/service.rs +++ b/atoma-inference/src/service.rs @@ -1,6 +1,8 @@ +use candle::{DType, Device}; use ed25519_consensus::SigningKey as PrivateKey; use hf_hub::api::sync::Api; use std::{io, path::PathBuf, time::Instant}; +use tokenizers::Tokenizer; use tokio::sync::mpsc::{error::SendError, Receiver}; use tracing::info; @@ -9,8 +11,7 @@ use thiserror::Error; use crate::{ apis::{ApiError, ApiTrait}, config::InferenceConfig, - core::{InferenceCore, InferenceCoreError}, - core_thread::{ModelThreadDispatcher, ModelThreadError, ModelThreadHandle}, + model_thread::{ModelThreadDispatcher, ModelThreadError, ModelThreadHandle}, types::{InferenceRequest, InferenceResponse}, }; @@ -38,15 +39,14 @@ impl InferenceService { let api_key = inference_config.api_key(); let storage_folder = inference_config.storage_folder(); let models = inference_config.models(); - let inference_core = InferenceCore::new(inference_config, private_key)?; let api = Api::create(api_key, storage_folder)?; let mut handles = Vec::with_capacity(models.len()); - for model in models { + for model in models.iter() { let api = api.clone(); let handle = std::thread::spawn(move || { - api.fetch(model).expect("Failed to fetch model"); + api.fetch(model.clone()).expect("Failed to fetch model"); }); handles.push(handle); } @@ -55,6 +55,13 @@ impl InferenceService { info!("Starting Core Dispatcher.."); + let model_configs = models.iter().map(|mt| mt.model_config()).collect(); + let device = Device::new_metal(ordinal); + + let tokenizer_file = inference_config.tokenizer_file_path(); + let tokenizer = + Tokenizer::from_file(tokenizer_file).map_err(InferenceServiceError::TokenizerError)?; + let (dispatcher, model_thread_handle) = ModelThreadDispatcher::start(inference_core)?; let start_time = Instant::now(); @@ -104,16 +111,8 @@ pub enum InferenceServiceError { SendError(SendError), #[error("Api error: `{0}`")] ApiError(ApiError), -} - -impl From for InferenceServiceError { - fn from(error: InferenceCoreError) -> Self { - match error { - InferenceCoreError::FailedApiConnection(e) => Self::FailedApiConnection(e), - InferenceCoreError::FailedInference(e) => Self::FailedInference(e), - InferenceCoreError::FailedModelFetch(e) => Self::FailedModelFetch(e), - } - } + #[error("Tokenizer error: `{0}`")] + TokenizerError(Box), } impl From for InferenceServiceError { @@ -124,7 +123,6 @@ impl From for InferenceServiceError { #[cfg(test)] mod tests { - use async_trait::async_trait; use rand::rngs::OsRng; use std::io::Write; use toml::{toml, Value}; diff --git a/atoma-inference/src/types.rs b/atoma-inference/src/types.rs index 383347f0..92bb9cb3 100644 --- a/atoma-inference/src/types.rs +++ b/atoma-inference/src/types.rs @@ -1,5 +1,6 @@ use crate::models::ModelType; use ed25519_consensus::VerificationKey; +use serde::Deserialize; pub type NodeId = VerificationKey; pub type Temperature = f32; @@ -46,7 +47,7 @@ pub enum QuantizationMethod { Gptq(PrecisionBits), } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Deserialize)] pub enum PrecisionBits { Q1, Q2, From 
6990a12f7acabc139d39b485a1188c117f9463f9 Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Wed, 27 Mar 2024 02:13:43 +0000 Subject: [PATCH 12/28] add model_thread.rs, after renaming --- atoma-inference/src/model_thread.rs | 160 ++++++++++++++++++++++++++++ 1 file changed, 160 insertions(+) create mode 100644 atoma-inference/src/model_thread.rs diff --git a/atoma-inference/src/model_thread.rs b/atoma-inference/src/model_thread.rs new file mode 100644 index 00000000..01ad9829 --- /dev/null +++ b/atoma-inference/src/model_thread.rs @@ -0,0 +1,160 @@ +use std::collections::HashMap; + +use candle_nn::VarBuilder; +use ed25519_consensus::VerificationKey as PublicKey; +use thiserror::Error; +use tokio::sync::oneshot::{self, error::RecvError}; +use tracing::{debug, error, warn}; + +use crate::{ + models::{ModelApi, ModelError, ModelSpecs, ModelType}, + types::{InferenceRequest, InferenceResponse}, +}; + +const CORE_THREAD_COMMANDS_CHANNEL_SIZE: usize = 32; + +pub enum CoreThreadCommand { + RunInference(InferenceRequest, oneshot::Sender), +} + +pub struct ModelThreadCommand(InferenceRequest, oneshot::Sender); + +#[derive(Debug, Error)] +pub enum ModelThreadError { + #[error("Model thread shutdown: `{0}`")] + ModelError(ModelError), + #[error("Core thread shutdown: `{0}`")] + Shutdown(RecvError), +} + +pub struct ModelThreadHandle { + sender: std::sync::mpsc::Sender, + join_handle: std::thread::JoinHandle<()>, +} + +impl ModelThreadHandle { + pub fn stop(self) { + drop(self.sender); + self.join_handle.join().ok(); + } +} + +pub struct ModelThread { + model: T, + receiver: std::sync::mpsc::Receiver, +} + +impl ModelThread +where + T: ModelApi, +{ + pub fn run(mut self, public_key: PublicKey) -> Result<(), ModelThreadError> { + debug!("Start Model thread"); + + while let Ok(command) = self.receiver.recv() { + let ModelThreadCommand(request, sender) = command; + + let InferenceRequest { + prompt, + model, + max_tokens, + temperature, + random_seed, + repeat_last_n, + repeat_penalty, + top_k, + top_p, + sampled_nodes, + } = request; + if !sampled_nodes.contains(&public_key) { + error!("Current node, with verification key = {:?} was not sampled from {sampled_nodes:?}", public_key); + continue; + } + let response = self + .model + .run( + prompt, + max_tokens, + random_seed, + repeat_last_n, + repeat_penalty, + temperature.unwrap_or_default(), + top_p.unwrap_or_default(), + ) + .map_err(ModelThreadError::ModelError)?; + let response = InferenceResponse { response }; + sender.send(response).ok(); + } + + Ok(()) + } +} + +#[derive(Clone)] +pub struct ModelThreadDispatcher { + model_senders: HashMap>, +} + +impl ModelThreadDispatcher { + pub(crate) fn start( + &self, + models: Vec<(ModelType, ModelSpecs, VarBuilder)>, + public_key: PublicKey, + ) -> Result<(Self, Vec), ModelThreadError> { + let (core_sender, core_receiver) = std::sync::mpsc::channel::(); + + let mut handles = Vec::with_capacity(models.len()); + let mut model_senders = HashMap::with_capacity(models.len()); + + for (model_type, model_specs, var_builder) in models { + let (model_sender, model_receiver) = std::sync::mpsc::channel::(); + let model = T::load(model_specs, var_builder); // TODO: for now this piece of code cannot be shared among threads safely + let model_thread = ModelThread { + model, + receiver: model_receiver, + }; + let join_handle = std::thread::spawn(move || { + if let Err(e) = model_thread.run(public_key) { + error!("Model thread error: {e}"); + if !matches!(e, ModelThreadError::Shutdown(_)) { + panic!("Fatal error occurred: 
{e}"); + } + } + }); + handles.push(ModelThreadHandle { + join_handle, + sender: model_sender.clone(), + }); + model_senders.insert(model_type, model_sender); + } + + let model_dispatcher = ModelThreadDispatcher { model_senders }; + + Ok((model_dispatcher, handles)) + } + + fn send(&self, command: ModelThreadCommand) { + let request = command.0.clone(); + let model_type = request.model; + + let sender = self + .model_senders + .get(&model_type) + .expect("Failed to get model thread, this should not happen !"); + + if let Err(e) = sender.send(command) { + warn!("Could not send command to model core, it might be shutting down: {e}"); + } + } +} + +impl ModelThreadDispatcher { + pub(crate) async fn run_inference( + &self, + request: InferenceRequest, + ) -> Result { + let (sender, receiver) = oneshot::channel(); + self.send(ModelThreadCommand(request, sender)); + receiver.await.map_err(ModelThreadError::Shutdown) + } +} From d52f113cc05f37dd7fc08fb25ba4b16705ff7f94 Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Wed, 27 Mar 2024 02:24:52 +0000 Subject: [PATCH 13/28] intermediate steps --- atoma-inference/src/apis/hugging_face.rs | 12 +++++------ atoma-inference/src/apis/mod.rs | 3 +-- atoma-inference/src/config.rs | 8 ++++++- atoma-inference/src/service.rs | 27 +++++++++++++++--------- 4 files changed, 30 insertions(+), 20 deletions(-) diff --git a/atoma-inference/src/apis/hugging_face.rs b/atoma-inference/src/apis/hugging_face.rs index abca1c91..db0be8f9 100644 --- a/atoma-inference/src/apis/hugging_face.rs +++ b/atoma-inference/src/apis/hugging_face.rs @@ -91,10 +91,6 @@ impl ModelType { #[async_trait] impl ApiTrait for Api { - fn call(&mut self) -> Result<(), super::ApiError> { - todo!() - } - fn create(api_key: String, cache_dir: PathBuf) -> Result where Self: Sized, @@ -106,13 +102,15 @@ impl ApiTrait for Api { .build()?) 
} - fn fetch(&self, model: ModelType) -> Result<(), super::ApiError> { + fn fetch(&self, model: ModelType) -> Result, super::ApiError> { let (model_path, files) = model.get_hugging_face_model_path(); let api_repo = self.model(model_path); + let mut path_bufs = Vec::with_capacity(files.file_paths.len()); + for file in files.file_paths { - api_repo.get(&file)?; + path_bufs.push(api_repo.get(&file)?); } - Ok(()) + Ok(path_bufs) } } diff --git a/atoma-inference/src/apis/mod.rs b/atoma-inference/src/apis/mod.rs index 355d5f46..29f34435 100644 --- a/atoma-inference/src/apis/mod.rs +++ b/atoma-inference/src/apis/mod.rs @@ -22,8 +22,7 @@ impl From for ApiError { } pub trait ApiTrait { - fn call(&mut self) -> Result<(), ApiError>; - fn fetch(&self, model: ModelType) -> Result<(), ApiError>; + fn fetch(&self, model: ModelType) -> Result, ApiError>; fn create(api_key: String, cache_dir: PathBuf) -> Result where Self: Sized; diff --git a/atoma-inference/src/config.rs b/atoma-inference/src/config.rs index 546a0e85..3b6a343c 100644 --- a/atoma-inference/src/config.rs +++ b/atoma-inference/src/config.rs @@ -1,6 +1,5 @@ use std::path::PathBuf; -use candle::DType; use config::Config; use serde::Deserialize; @@ -14,6 +13,7 @@ pub struct InferenceConfig { storage_folder: PathBuf, tokenizer_file_path: PathBuf, tracing: bool, + use_kv_cache: Option, } impl InferenceConfig { @@ -24,6 +24,7 @@ impl InferenceConfig { storage_folder: PathBuf, tokenizer_file_path: PathBuf, tracing: bool, + use_kv_cache: Option, ) -> Self { Self { api_key, @@ -32,6 +33,7 @@ impl InferenceConfig { storage_folder, tokenizer_file_path, tracing, + use_kv_cache, } } @@ -59,6 +61,10 @@ impl InferenceConfig { self.precision } + pub fn use_kv_cache(&self) -> Option { + self.use_kv_cache + } + pub fn from_file_path(config_file_path: PathBuf) -> Self { let builder = Config::builder().add_source(config::File::with_name( config_file_path.to_str().as_ref().unwrap(), diff --git a/atoma-inference/src/service.rs b/atoma-inference/src/service.rs index 483af19e..a6900d67 100644 --- a/atoma-inference/src/service.rs +++ b/atoma-inference/src/service.rs @@ -1,4 +1,5 @@ -use candle::{DType, Device}; +use candle::Device; +use candle_transformers::quantized_var_builder::VarBuilder; use ed25519_consensus::SigningKey as PrivateKey; use hf_hub::api::sync::Api; use std::{io, path::PathBuf, time::Instant}; @@ -35,6 +36,7 @@ impl InferenceService { .expect("Incorrect private key bytes length"); let private_key = PrivateKey::from(private_key_bytes); + let public_key = private_key.verification_key(); let inference_config = InferenceConfig::from_file_path(config_file_path); let api_key = inference_config.api_key(); let storage_folder = inference_config.storage_folder(); @@ -46,12 +48,17 @@ impl InferenceService { for model in models.iter() { let api = api.clone(); let handle = std::thread::spawn(move || { - api.fetch(model.clone()).expect("Failed to fetch model"); + api.fetch(model.clone()).expect("Failed to fetch model") }); handles.push(handle); } - handles.into_iter().for_each(|h| h.join().unwrap()); + let path_bufs = handles + .into_iter() + .map(|h| { + let path_bufs = h.join().unwrap(); + }) + .collect(); info!("Starting Core Dispatcher.."); @@ -62,7 +69,11 @@ impl InferenceService { let tokenizer = Tokenizer::from_file(tokenizer_file).map_err(InferenceServiceError::TokenizerError)?; - let (dispatcher, model_thread_handle) = ModelThreadDispatcher::start(inference_core)?; + let var_builder = VarBuilder::pp(&self, s); + + let (dispatcher, model_thread_handle) = 
+ ModelThreadDispatcher::start(inference_core, public_key) + .map_err(InferenceServiceError::ModelThreadError)?; let start_time = Instant::now(); Ok(Self { @@ -135,10 +146,6 @@ mod tests { struct TestApiInstance {} impl ApiTrait for TestApiInstance { - fn call(&mut self) -> Result<(), ApiError> { - Ok(()) - } - fn create(_: String, _: PathBuf) -> Result where Self: Sized, @@ -146,8 +153,8 @@ mod tests { Ok(Self {}) } - fn fetch(&self, _: ModelType) -> Result<(), ApiError> { - Ok(()) + fn fetch(&self, _: ModelType) -> Result, ApiError> { + Ok(vec![]) } } From 618bea8ea78653f872bcfe07ad2b3bbf7c82767b Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Wed, 27 Mar 2024 17:57:12 +0000 Subject: [PATCH 14/28] intermediate steps --- atoma-inference/src/config.rs | 34 +++----- atoma-inference/src/main.rs | 4 +- atoma-inference/src/model_thread.rs | 12 +-- atoma-inference/src/models.rs | 18 +--- atoma-inference/src/service.rs | 124 ++++++++++++++++++++-------- atoma-inference/src/types.rs | 27 ++++-- 6 files changed, 132 insertions(+), 87 deletions(-) diff --git a/atoma-inference/src/config.rs b/atoma-inference/src/config.rs index 3b6a343c..e1286d9b 100644 --- a/atoma-inference/src/config.rs +++ b/atoma-inference/src/config.rs @@ -5,35 +5,35 @@ use serde::Deserialize; use crate::{models::ModelType, types::PrecisionBits}; +#[derive(Clone, Debug, Deserialize)] +pub struct ModelTokenizer { + pub(crate) model_type: ModelType, + pub(crate) tokenizer: PathBuf, + pub(crate) precision: PrecisionBits, + pub(crate) use_kv_cache: Option, +} + #[derive(Debug, Deserialize)] pub struct InferenceConfig { api_key: String, - models: Vec, - precision: PrecisionBits, + models: Vec, storage_folder: PathBuf, - tokenizer_file_path: PathBuf, tracing: bool, - use_kv_cache: Option, } impl InferenceConfig { pub fn new( api_key: String, - models: Vec, - precision: PrecisionBits, + models: Vec, storage_folder: PathBuf, - tokenizer_file_path: PathBuf, tracing: bool, use_kv_cache: Option, ) -> Self { Self { api_key, models, - precision, storage_folder, - tokenizer_file_path, tracing, - use_kv_cache, } } @@ -41,7 +41,7 @@ impl InferenceConfig { self.api_key.clone() } - pub fn models(&self) -> Vec { + pub fn models(&self) -> Vec { self.models.clone() } @@ -49,22 +49,10 @@ impl InferenceConfig { self.storage_folder.clone() } - pub fn tokenizer_file_path(&self) -> PathBuf { - self.tokenizer_file_path.clone() - } - pub fn tracing(&self) -> bool { self.tracing } - pub fn precision_bits(&self) -> PrecisionBits { - self.precision - } - - pub fn use_kv_cache(&self) -> Option { - self.use_kv_cache - } - pub fn from_file_path(config_file_path: PathBuf) -> Self { let builder = Config::builder().add_source(config::File::with_name( config_file_path.to_str().as_ref().unwrap(), diff --git a/atoma-inference/src/main.rs b/atoma-inference/src/main.rs index 40b2806e..8c8d58a2 100644 --- a/atoma-inference/src/main.rs +++ b/atoma-inference/src/main.rs @@ -1,4 +1,4 @@ -use hf_hub::api::sync::Api; +use inference::models::Model; use inference::service::InferenceService; #[tokio::main] @@ -7,7 +7,7 @@ async fn main() { let (_, receiver) = tokio::sync::mpsc::channel(32); - let _ = InferenceService::start::( + let _ = InferenceService::start::( "../inference.toml".parse().unwrap(), "../private_key".parse().unwrap(), receiver, diff --git a/atoma-inference/src/model_thread.rs b/atoma-inference/src/model_thread.rs index 01ad9829..5f729268 100644 --- a/atoma-inference/src/model_thread.rs +++ b/atoma-inference/src/model_thread.rs @@ -48,7 +48,7 @@ impl 
ModelThread where T: ModelApi, { - pub fn run(mut self, public_key: PublicKey) -> Result<(), ModelThreadError> { + pub fn run(self, public_key: PublicKey) -> Result<(), ModelThreadError> { debug!("Start Model thread"); while let Ok(command) = self.receiver.recv() { @@ -96,13 +96,13 @@ pub struct ModelThreadDispatcher { } impl ModelThreadDispatcher { - pub(crate) fn start( - &self, + pub(crate) fn start( models: Vec<(ModelType, ModelSpecs, VarBuilder)>, public_key: PublicKey, - ) -> Result<(Self, Vec), ModelThreadError> { - let (core_sender, core_receiver) = std::sync::mpsc::channel::(); - + ) -> Result<(Self, Vec), ModelThreadError> + where + T: ModelApi + Send + 'static, + { let mut handles = Vec::with_capacity(models.len()); let mut model_senders = HashMap::with_capacity(models.len()); diff --git a/atoma-inference/src/models.rs b/atoma-inference/src/models.rs index 118b0eb5..9771e44b 100644 --- a/atoma-inference/src/models.rs +++ b/atoma-inference/src/models.rs @@ -77,12 +77,6 @@ impl From for ModelConfig { } } -#[derive(Clone)] -pub enum ModelCache { - Llama(LlamaCache), - Llama2(Llama2Cache), -} - pub trait ModelApi { fn load(model_specs: ModelSpecs, var_builder: VarBuilder) -> Self; fn run( @@ -99,7 +93,7 @@ pub trait ModelApi { #[allow(dead_code)] pub struct ModelSpecs { - pub(crate) cache: Option, + pub(crate) cache: Option, pub(crate) config: ModelConfig, pub(crate) device: Device, pub(crate) dtype: DType, @@ -174,15 +168,7 @@ impl ModelApi for Model { ) -> Result { match self { Self::Llama { model_specs, model } => { - let mut cache = if let ModelCache::Llama(cache) = - model_specs.cache.clone().expect("Failed to get cache") - { - cache - } else { - return Err(ModelError::CacheError(String::from( - "Failed to obtain correct cache", - ))); - }; + let mut cache = model_specs.cache.clone().expect("Failed to get cache"); let mut tokens = model_specs .tokenizer .encode(input, true) diff --git a/atoma-inference/src/service.rs b/atoma-inference/src/service.rs index a6900d67..aa6a238d 100644 --- a/atoma-inference/src/service.rs +++ b/atoma-inference/src/service.rs @@ -1,5 +1,6 @@ -use candle::Device; -use candle_transformers::quantized_var_builder::VarBuilder; +use candle::{Device, Error as CandleError}; +use candle_nn::var_builder::VarBuilder; +use candle_transformers::models::llama::Cache as LlamaCache; use ed25519_consensus::SigningKey as PrivateKey; use hf_hub::api::sync::Api; use std::{io, path::PathBuf, time::Instant}; @@ -11,8 +12,9 @@ use thiserror::Error; use crate::{ apis::{ApiError, ApiTrait}, - config::InferenceConfig, + config::{InferenceConfig, ModelTokenizer}, model_thread::{ModelThreadDispatcher, ModelThreadError, ModelThreadHandle}, + models::{ModelApi, ModelConfig, ModelSpecs, ModelType}, types::{InferenceRequest, InferenceResponse}, }; @@ -24,11 +26,14 @@ pub struct InferenceService { } impl InferenceService { - pub fn start( + pub fn start( config_file_path: PathBuf, private_key_path: PathBuf, _request_receiver: Receiver, - ) -> Result { + ) -> Result + where + T: ModelApi + Send + 'static, + { let private_key_bytes = std::fs::read(&private_key_path).map_err(InferenceServiceError::PrivateKeyError)?; let private_key_bytes: [u8; 32] = private_key_bytes @@ -45,34 +50,69 @@ impl InferenceService { let api = Api::create(api_key, storage_folder)?; let mut handles = Vec::with_capacity(models.len()); - for model in models.iter() { + for model in &models { let api = api.clone(); - let handle = std::thread::spawn(move || { - api.fetch(model.clone()).expect("Failed to fetch 
model") - }); + let model_type = model.model_type.clone(); + let handle = + std::thread::spawn(move || api.fetch(model_type).expect("Failed to fetch model")); handles.push(handle); } let path_bufs = handles .into_iter() - .map(|h| { + .zip(models) + .map(|(h, mt)| { let path_bufs = h.join().unwrap(); + (mt, path_bufs) }) - .collect(); + .collect::>(); info!("Starting Core Dispatcher.."); - let model_configs = models.iter().map(|mt| mt.model_config()).collect(); - let device = Device::new_metal(ordinal); - - let tokenizer_file = inference_config.tokenizer_file_path(); - let tokenizer = - Tokenizer::from_file(tokenizer_file).map_err(InferenceServiceError::TokenizerError)?; - - let var_builder = VarBuilder::pp(&self, s); + let device = Device::new_metal(0)?; // TODO: check this + let models = path_bufs + .iter() + .map(|(mt, paths)| { + let ModelTokenizer { + model_type, + tokenizer, + precision, + use_kv_cache, + } = mt; + let config = model_type.model_config(); + let tokenizer = Tokenizer::from_file(tokenizer) + .map_err(InferenceServiceError::TokenizerError)?; + let dtype = precision.into_dtype(); + let var_builder = + unsafe { VarBuilder::from_mmaped_safetensors(paths, dtype, &device)? }; + let cache = if let ModelType::Llama2_7b = model_type { + let llama_config = if let ModelConfig::Llama(cfg) = config.clone() { + cfg + } else { + panic!("Configuration for Llama model unexpected") + }; + Some(LlamaCache::new( + use_kv_cache.unwrap_or_default(), + dtype, + &llama_config, + &device, + )?) + } else { + None + }; + let model_specs = ModelSpecs { + cache, + config, + device: device.clone(), + dtype, + tokenizer, + }; + Ok::<_, InferenceServiceError>((model_type.clone(), model_specs, var_builder)) + }) + .collect::, _>>()?; let (dispatcher, model_thread_handle) = - ModelThreadDispatcher::start(inference_core, public_key) + ModelThreadDispatcher::start::(models, public_key) .map_err(InferenceServiceError::ModelThreadError)?; let start_time = Instant::now(); @@ -102,7 +142,11 @@ impl InferenceService { self.start_time.elapsed() ); - self.model_thread_handle.drain(..).map(|h| h.stop()); + let _ = self + .model_thread_handle + .drain(..) 
+ .map(|h| h.stop()) + .collect::>(); } } @@ -124,6 +168,8 @@ pub enum InferenceServiceError { ApiError(ApiError), #[error("Tokenizer error: `{0}`")] TokenizerError(Box), + #[error("Candle error: `{0}`")] + CandleError(CandleError), } impl From for InferenceServiceError { @@ -132,29 +178,39 @@ impl From for InferenceServiceError { } } +impl From for InferenceServiceError { + fn from(error: CandleError) -> Self { + Self::CandleError(error) + } +} + #[cfg(test)] mod tests { use rand::rngs::OsRng; use std::io::Write; use toml::{toml, Value}; - use crate::models::ModelType; - use super::*; #[derive(Clone)] - struct TestApiInstance {} - - impl ApiTrait for TestApiInstance { - fn create(_: String, _: PathBuf) -> Result - where - Self: Sized, - { - Ok(Self {}) + struct TestModelInstance {} + + impl ModelApi for TestModelInstance { + fn load(_model_specs: ModelSpecs, _var_builder: VarBuilder) -> Self { + Self {} } - fn fetch(&self, _: ModelType) -> Result, ApiError> { - Ok(vec![]) + fn run( + &self, + _input: String, + _max_tokens: usize, + _random_seed: usize, + _repeat_last_n: usize, + _repeat_penalty: f32, + _temperature: crate::types::Temperature, + _top_p: f32, + ) -> Result { + Ok(String::from("")) } } @@ -182,7 +238,7 @@ mod tests { let (_, receiver) = tokio::sync::mpsc::channel(1); - let _ = InferenceService::start::( + let _ = InferenceService::start::( PathBuf::try_from(CONFIG_FILE_PATH).unwrap(), PathBuf::try_from(PRIVATE_KEY_FILE_PATH).unwrap(), receiver, diff --git a/atoma-inference/src/types.rs b/atoma-inference/src/types.rs index 92bb9cb3..47ed8d48 100644 --- a/atoma-inference/src/types.rs +++ b/atoma-inference/src/types.rs @@ -1,4 +1,5 @@ use crate::models::ModelType; +use candle::DType; use ed25519_consensus::VerificationKey; use serde::Deserialize; @@ -47,13 +48,27 @@ pub enum QuantizationMethod { Gptq(PrecisionBits), } -#[derive(Clone, Debug, Deserialize)] +#[derive(Copy, Clone, Debug, Deserialize)] pub enum PrecisionBits { - Q1, - Q2, - Q4, - Q5, - Q8, + BF16, F16, F32, + F64, + I64, + U8, + U32, +} + +impl PrecisionBits { + pub(crate) fn into_dtype(self) -> DType { + match self { + Self::BF16 => DType::BF16, + Self::F16 => DType::F16, + Self::F32 => DType::F32, + Self::F64 => DType::F64, + Self::I64 => DType::I64, + Self::U8 => DType::U8, + Self::U32 => DType::U32, + } + } } From 54e0abd93ed69f403aa1a1d320103bd6ee28e8e7 Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Wed, 27 Mar 2024 18:13:00 +0000 Subject: [PATCH 15/28] intermediate steps --- Cargo.toml | 6 +++--- atoma-inference/src/config.rs | 1 - atoma-inference/src/model_thread.rs | 5 +---- atoma-inference/src/models.rs | 30 +++++++++++++---------------- atoma-inference/src/service.rs | 2 +- atoma-inference/src/types.rs | 12 ------------ 6 files changed, 18 insertions(+), 38 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 924a79d5..0bf17ca9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,9 +8,9 @@ version = "0.1.0" [workspace.dependencies] async-trait = "0.1.78" -candle = { git = "https://github.com/huggingface/candle", package = "candle-core", version = "0.4.2" } -candle-nn = { git = "https://github.com/huggingface/candle", package = "candle-nn", version = "0.4.2" } -candle-transformers = { git = "https://github.com/huggingface/candle", package = "candle-transformers", version = "0.4.2" } +candle = { git = "https://github.com/jorgeantonio21/candle/", package = "candle-core", branch = "ja-send-sync-sd-scheduler" } +candle-nn = { git = "https://github.com/jorgeantonio21/candle/", package = "candle-nn", branch = 
"ja-send-sync-sd-scheduler" } +candle-transformers = { git = "https://github.com/jorgeantonio21/candle/", package = "candle-transformers", branch = "ja-send-sync-sd-scheduler" } config = "0.14.0" ed25519-consensus = "2.1.0" hf-hub = "0.3.2" diff --git a/atoma-inference/src/config.rs b/atoma-inference/src/config.rs index e1286d9b..b8313d95 100644 --- a/atoma-inference/src/config.rs +++ b/atoma-inference/src/config.rs @@ -27,7 +27,6 @@ impl InferenceConfig { models: Vec, storage_folder: PathBuf, tracing: bool, - use_kv_cache: Option, ) -> Self { Self { api_key, diff --git a/atoma-inference/src/model_thread.rs b/atoma-inference/src/model_thread.rs index 5f729268..119284a6 100644 --- a/atoma-inference/src/model_thread.rs +++ b/atoma-inference/src/model_thread.rs @@ -11,8 +11,6 @@ use crate::{ types::{InferenceRequest, InferenceResponse}, }; -const CORE_THREAD_COMMANDS_CHANNEL_SIZE: usize = 32; - pub enum CoreThreadCommand { RunInference(InferenceRequest, oneshot::Sender), } @@ -56,15 +54,14 @@ where let InferenceRequest { prompt, - model, max_tokens, temperature, random_seed, repeat_last_n, repeat_penalty, - top_k, top_p, sampled_nodes, + .. } = request; if !sampled_nodes.contains(&public_key) { error!("Current node, with verification key = {:?} was not sampled from {sampled_nodes:?}", public_key); diff --git a/atoma-inference/src/models.rs b/atoma-inference/src/models.rs index 9771e44b..706b76eb 100644 --- a/atoma-inference/src/models.rs +++ b/atoma-inference/src/models.rs @@ -6,7 +6,6 @@ use candle_transformers::{ generation::LogitsProcessor, models::{ llama::{Cache as LlamaCache, Config as LlamaConfig, Llama}, - llama2_c::{Cache as Llama2Cache, Llama as Llama2}, mamba::{Config as MambaConfig, Model as MambaModel}, mistral::{Config as MistralConfig, Model as MistralModel}, mixtral::{Config as MixtralConfig, Model as MixtralModel}, @@ -51,13 +50,15 @@ impl ModelType { Self::Llama2_7b => ModelConfig::Llama(LlamaConfig::config_7b_v2(false)), // TODO: add the case for flash attention Self::Mamba3b => todo!(), Self::Mistral7b => ModelConfig::Mistral(MistralConfig::config_7b_v0_1(false)), // TODO: add the case for flash attention - Self::Mixtral8x7b => ModelConfig::Mixtral8x7b(MixtralConfig::v0_1_8x7b(false)), // TODO: add the case for flash attention - Self::StableDiffusion2 => { - ModelConfig::StableDiffusion(StableDiffusionConfig::v2_1(None, None, None)) - } - Self::StableDiffusionXl => { - ModelConfig::StableDiffusion(StableDiffusionConfig::sdxl_turbo(None, None, None)) - } + Self::Mixtral8x7b => { + ModelConfig::Mixtral8x7b(Box::new(MixtralConfig::v0_1_8x7b(false))) + } // TODO: add the case for flash attention + Self::StableDiffusion2 => ModelConfig::StableDiffusion(Box::new( + StableDiffusionConfig::v2_1(None, None, None), + )), + Self::StableDiffusionXl => ModelConfig::StableDiffusion(Box::new( + StableDiffusionConfig::sdxl_turbo(None, None, None), + )), } } } @@ -66,9 +67,9 @@ impl ModelType { pub enum ModelConfig { Llama(LlamaConfig), Mamba(MambaConfig), - Mixtral8x7b(MixtralConfig), + Mixtral8x7b(Box), Mistral(MistralConfig), - StableDiffusion(StableDiffusionConfig), + StableDiffusion(Box), } impl From for ModelConfig { @@ -79,6 +80,8 @@ impl From for ModelConfig { pub trait ModelApi { fn load(model_specs: ModelSpecs, var_builder: VarBuilder) -> Self; + + #[allow(clippy::too_many_arguments)] fn run( &self, input: String, @@ -105,10 +108,6 @@ pub enum Model { model_specs: ModelSpecs, model: Llama, }, - Llama2 { - model_specs: ModelSpecs, - model: Llama2, - }, Mamba { model_specs: 
ModelSpecs, model: MambaModel, @@ -240,9 +239,6 @@ impl ModelApi for Model { } Ok(output.join(" ")) } - Self::Llama2 { .. } => { - todo!() - } Self::Mamba { .. } => { todo!() } diff --git a/atoma-inference/src/service.rs b/atoma-inference/src/service.rs index aa6a238d..d41a4076 100644 --- a/atoma-inference/src/service.rs +++ b/atoma-inference/src/service.rs @@ -35,7 +35,7 @@ impl InferenceService { T: ModelApi + Send + 'static, { let private_key_bytes = - std::fs::read(&private_key_path).map_err(InferenceServiceError::PrivateKeyError)?; + std::fs::read(private_key_path).map_err(InferenceServiceError::PrivateKeyError)?; let private_key_bytes: [u8; 32] = private_key_bytes .try_into() .expect("Incorrect private key bytes length"); diff --git a/atoma-inference/src/types.rs b/atoma-inference/src/types.rs index 47ed8d48..7d459b63 100644 --- a/atoma-inference/src/types.rs +++ b/atoma-inference/src/types.rs @@ -30,18 +30,6 @@ pub struct InferenceResponse { pub(crate) response: String, } -#[derive(Clone, Debug)] -pub struct ModelRequest { - pub(crate) model: ModelType, - pub(crate) quantization_method: Option, -} - -#[allow(dead_code)] -pub struct ModelResponse { - pub(crate) is_success: bool, - pub(crate) error: Option, -} - #[derive(Clone, Debug)] pub enum QuantizationMethod { Ggml(PrecisionBits), From e60b586a5cfc33607fcb11693983dc80ded528c5 Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Wed, 27 Mar 2024 18:46:27 +0000 Subject: [PATCH 16/28] address new PR comments --- Cargo.toml | 1 + atoma-inference/Cargo.toml | 1 + atoma-inference/src/main.rs | 14 +++++------ atoma-inference/src/model_thread.rs | 15 +++++------ atoma-inference/src/service.rs | 39 ++++++++++++++++++++--------- atoma-inference/src/types.rs | 1 + 6 files changed, 45 insertions(+), 26 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 0bf17ca9..d9738cda 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,6 +13,7 @@ candle-nn = { git = "https://github.com/jorgeantonio21/candle/", package = "cand candle-transformers = { git = "https://github.com/jorgeantonio21/candle/", package = "candle-transformers", branch = "ja-send-sync-sd-scheduler" } config = "0.14.0" ed25519-consensus = "2.1.0" +futures = "0.3.30" hf-hub = "0.3.2" serde = "1.0.197" serde_json = "1.0.114" diff --git a/atoma-inference/Cargo.toml b/atoma-inference/Cargo.toml index 04d5ff96..f05bd27b 100644 --- a/atoma-inference/Cargo.toml +++ b/atoma-inference/Cargo.toml @@ -12,6 +12,7 @@ candle-nn.workspace = true candle-transformers.workspace = true config.true = true ed25519-consensus.workspace = true +futures.workspace = true hf-hub.workspace = true reqwest = { workspace = true, features = ["json"] } serde = { workspace = true, features = ["derive"] } diff --git a/atoma-inference/src/main.rs b/atoma-inference/src/main.rs index 8c8d58a2..7e2e1e8f 100644 --- a/atoma-inference/src/main.rs +++ b/atoma-inference/src/main.rs @@ -5,14 +5,14 @@ use inference::service::InferenceService; async fn main() { tracing_subscriber::fmt::init(); - let (_, receiver) = tokio::sync::mpsc::channel(32); + // let (_, receiver) = tokio::sync::mpsc::channel(32); - let _ = InferenceService::start::( - "../inference.toml".parse().unwrap(), - "../private_key".parse().unwrap(), - receiver, - ) - .expect("Failed to start inference service"); + // let _ = InferenceService::start::( + // "../inference.toml".parse().unwrap(), + // "../private_key".parse().unwrap(), + // receiver, + // ) + // .expect("Failed to start inference service"); // inference_service // .run_inference(InferenceRequest { diff 
--git a/atoma-inference/src/model_thread.rs b/atoma-inference/src/model_thread.rs index 119284a6..168e007c 100644 --- a/atoma-inference/src/model_thread.rs +++ b/atoma-inference/src/model_thread.rs @@ -2,6 +2,7 @@ use std::collections::HashMap; use candle_nn::VarBuilder; use ed25519_consensus::VerificationKey as PublicKey; +use futures::stream::FuturesUnordered; use thiserror::Error; use tokio::sync::oneshot::{self, error::RecvError}; use tracing::{debug, error, warn}; @@ -87,9 +88,9 @@ where } } -#[derive(Clone)] pub struct ModelThreadDispatcher { model_senders: HashMap>, + pub(crate) responses: FuturesUnordered>, } impl ModelThreadDispatcher { @@ -125,7 +126,10 @@ impl ModelThreadDispatcher { model_senders.insert(model_type, model_sender); } - let model_dispatcher = ModelThreadDispatcher { model_senders }; + let model_dispatcher = ModelThreadDispatcher { + model_senders, + responses: FuturesUnordered::new(), + }; Ok((model_dispatcher, handles)) } @@ -146,12 +150,9 @@ impl ModelThreadDispatcher { } impl ModelThreadDispatcher { - pub(crate) async fn run_inference( - &self, - request: InferenceRequest, - ) -> Result { + pub(crate) fn run_inference(&self, request: InferenceRequest) { let (sender, receiver) = oneshot::channel(); self.send(ModelThreadCommand(request, sender)); - receiver.await.map_err(ModelThreadError::Shutdown) + self.responses.push(receiver); } } diff --git a/atoma-inference/src/service.rs b/atoma-inference/src/service.rs index d41a4076..40781185 100644 --- a/atoma-inference/src/service.rs +++ b/atoma-inference/src/service.rs @@ -2,11 +2,12 @@ use candle::{Device, Error as CandleError}; use candle_nn::var_builder::VarBuilder; use candle_transformers::models::llama::Cache as LlamaCache; use ed25519_consensus::SigningKey as PrivateKey; +use futures::StreamExt; use hf_hub::api::sync::Api; use std::{io, path::PathBuf, time::Instant}; use tokenizers::Tokenizer; use tokio::sync::mpsc::{error::SendError, Receiver}; -use tracing::info; +use tracing::{error, info}; use thiserror::Error; @@ -22,14 +23,14 @@ pub struct InferenceService { model_thread_handle: Vec, dispatcher: ModelThreadDispatcher, start_time: Instant, - _request_receiver: Receiver, + request_receiver: Receiver, } impl InferenceService { pub fn start( config_file_path: PathBuf, private_key_path: PathBuf, - _request_receiver: Receiver, + request_receiver: Receiver, ) -> Result where T: ModelApi + Send + 'static, @@ -120,18 +121,32 @@ impl InferenceService { dispatcher, model_thread_handle, start_time, - _request_receiver, + request_receiver, }) } - pub async fn run_inference( - &self, - inference_request: InferenceRequest, - ) -> Result { - self.dispatcher - .run_inference(inference_request) - .await - .map_err(InferenceServiceError::ModelThreadError) + pub async fn run(&mut self) -> Result { + loop { + tokio::select! 
{ + message = self.request_receiver.recv() => { + if let Some(request) = message { + self.dispatcher.run_inference(request); + } + } + response = self.dispatcher.responses.next() => { + if let Some(resp) = response { + match resp { + Ok(response) => { + info!("Received a new inference response: {:?}", response); + } + Err(e) => { + error!("Found error in generating inference response: {e}"); + } + } + } + } + } + } } } diff --git a/atoma-inference/src/types.rs b/atoma-inference/src/types.rs index 7d459b63..c80506f5 100644 --- a/atoma-inference/src/types.rs +++ b/atoma-inference/src/types.rs @@ -8,6 +8,7 @@ pub type Temperature = f32; #[derive(Clone, Debug)] pub struct InferenceRequest { + pub request_id: u128, pub prompt: String, pub model: ModelType, pub max_tokens: usize, From 58cfca5a125d2b3788d2f1c94932e679074d83b2 Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Wed, 27 Mar 2024 19:41:30 +0000 Subject: [PATCH 17/28] add test to config construction --- atoma-inference/src/config.rs | 30 +++++++++++++++++++++++++++--- atoma-inference/src/main.rs | 14 +++++++------- atoma-inference/src/models.rs | 4 ++-- atoma-inference/src/types.rs | 4 ++-- 4 files changed, 38 insertions(+), 14 deletions(-) diff --git a/atoma-inference/src/config.rs b/atoma-inference/src/config.rs index b8313d95..2c33090b 100644 --- a/atoma-inference/src/config.rs +++ b/atoma-inference/src/config.rs @@ -1,11 +1,11 @@ use std::path::PathBuf; use config::Config; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; use crate::{models::ModelType, types::PrecisionBits}; -#[derive(Clone, Debug, Deserialize)] +#[derive(Clone, Debug, Deserialize, Serialize)] pub struct ModelTokenizer { pub(crate) model_type: ModelType, pub(crate) tokenizer: PathBuf, @@ -13,7 +13,7 @@ pub struct ModelTokenizer { pub(crate) use_kv_cache: Option, } -#[derive(Debug, Deserialize)] +#[derive(Debug, Deserialize, Serialize)] pub struct InferenceConfig { api_key: String, models: Vec, @@ -64,3 +64,27 @@ impl InferenceConfig { .expect("Failed to generated config file") } } + +#[cfg(test)] +pub mod tests { + use super::*; + + #[test] + fn test_config() { + let config = InferenceConfig::new( + String::from("my_key"), + vec![ModelTokenizer { + model_type: ModelType::Llama2_7b, + tokenizer: "tokenizer".parse().unwrap(), + precision: PrecisionBits::BF16, + use_kv_cache: Some(true), + }], + "storage_folder".parse().unwrap(), + true, + ); + + let toml_str = toml::to_string(&config).unwrap(); + let should_be_toml_str = "api_key = \"my_key\"\nstorage_folder = \"storage_folder\"\ntracing = true\n\n[[models]]\nmodel_type = \"Llama2_7b\"\ntokenizer = \"tokenizer\"\nprecision = \"BF16\"\nuse_kv_cache = true\n"; + assert_eq!(toml_str, should_be_toml_str); + } +} diff --git a/atoma-inference/src/main.rs b/atoma-inference/src/main.rs index 7e2e1e8f..8c8d58a2 100644 --- a/atoma-inference/src/main.rs +++ b/atoma-inference/src/main.rs @@ -5,14 +5,14 @@ use inference::service::InferenceService; async fn main() { tracing_subscriber::fmt::init(); - // let (_, receiver) = tokio::sync::mpsc::channel(32); + let (_, receiver) = tokio::sync::mpsc::channel(32); - // let _ = InferenceService::start::( - // "../inference.toml".parse().unwrap(), - // "../private_key".parse().unwrap(), - // receiver, - // ) - // .expect("Failed to start inference service"); + let _ = InferenceService::start::( + "../inference.toml".parse().unwrap(), + "../private_key".parse().unwrap(), + receiver, + ) + .expect("Failed to start inference service"); // inference_service // 
.run_inference(InferenceRequest { diff --git a/atoma-inference/src/models.rs b/atoma-inference/src/models.rs index 706b76eb..e42a95cb 100644 --- a/atoma-inference/src/models.rs +++ b/atoma-inference/src/models.rs @@ -12,7 +12,7 @@ use candle_transformers::{ stable_diffusion::StableDiffusionConfig, }, }; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; use thiserror::Error; use tokenizers::Tokenizer; @@ -21,7 +21,7 @@ use crate::types::Temperature; const EOS_TOKEN: &str = ""; -#[derive(Clone, Debug, Deserialize, Eq, Hash, PartialEq)] +#[derive(Clone, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)] pub enum ModelType { Llama2_7b, Mamba3b, diff --git a/atoma-inference/src/types.rs b/atoma-inference/src/types.rs index c80506f5..bd417a66 100644 --- a/atoma-inference/src/types.rs +++ b/atoma-inference/src/types.rs @@ -1,7 +1,7 @@ use crate::models::ModelType; use candle::DType; use ed25519_consensus::VerificationKey; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; pub type NodeId = VerificationKey; pub type Temperature = f32; @@ -37,7 +37,7 @@ pub enum QuantizationMethod { Gptq(PrecisionBits), } -#[derive(Copy, Clone, Debug, Deserialize)] +#[derive(Copy, Clone, Debug, Deserialize, Serialize)] pub enum PrecisionBits { BF16, F16, From 1cdb66acc02bea0bb1895938f373654cd260e762 Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Wed, 27 Mar 2024 22:15:38 +0000 Subject: [PATCH 18/28] remove unused code --- atoma-inference/src/model_thread.rs | 4 ---- 1 file changed, 4 deletions(-) diff --git a/atoma-inference/src/model_thread.rs b/atoma-inference/src/model_thread.rs index 168e007c..4a391026 100644 --- a/atoma-inference/src/model_thread.rs +++ b/atoma-inference/src/model_thread.rs @@ -12,10 +12,6 @@ use crate::{ types::{InferenceRequest, InferenceResponse}, }; -pub enum CoreThreadCommand { - RunInference(InferenceRequest, oneshot::Sender), -} - pub struct ModelThreadCommand(InferenceRequest, oneshot::Sender); #[derive(Debug, Error)] From b56d0b5cfb959823001be528f620be4bc5f73d12 Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Wed, 27 Mar 2024 22:18:52 +0000 Subject: [PATCH 19/28] remove full dependency of std::sync --- atoma-inference/src/model_thread.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/atoma-inference/src/model_thread.rs b/atoma-inference/src/model_thread.rs index 4a391026..1939ccf2 100644 --- a/atoma-inference/src/model_thread.rs +++ b/atoma-inference/src/model_thread.rs @@ -1,4 +1,4 @@ -use std::collections::HashMap; +use std::{collections::HashMap, sync::mpsc}; use candle_nn::VarBuilder; use ed25519_consensus::VerificationKey as PublicKey; @@ -23,7 +23,7 @@ pub enum ModelThreadError { } pub struct ModelThreadHandle { - sender: std::sync::mpsc::Sender, + sender: mpsc::Sender, join_handle: std::thread::JoinHandle<()>, } @@ -36,7 +36,7 @@ impl ModelThreadHandle { pub struct ModelThread { model: T, - receiver: std::sync::mpsc::Receiver, + receiver: mpsc::Receiver, } impl ModelThread @@ -85,7 +85,7 @@ where } pub struct ModelThreadDispatcher { - model_senders: HashMap>, + model_senders: HashMap>, pub(crate) responses: FuturesUnordered>, } @@ -101,7 +101,7 @@ impl ModelThreadDispatcher { let mut model_senders = HashMap::with_capacity(models.len()); for (model_type, model_specs, var_builder) in models { - let (model_sender, model_receiver) = std::sync::mpsc::channel::(); + let (model_sender, model_receiver) = mpsc::channel::(); let model = T::load(model_specs, var_builder); // TODO: for now this piece of code cannot be shared 
among threads safely let model_thread = ModelThread { model, From 04b6d6c802fcd1bf214d01b06fd2014f9864af31 Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Thu, 28 Mar 2024 23:17:34 +0000 Subject: [PATCH 20/28] change to main branch --- Cargo.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index d9738cda..ba5f2c71 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,9 +8,9 @@ version = "0.1.0" [workspace.dependencies] async-trait = "0.1.78" -candle = { git = "https://github.com/jorgeantonio21/candle/", package = "candle-core", branch = "ja-send-sync-sd-scheduler" } -candle-nn = { git = "https://github.com/jorgeantonio21/candle/", package = "candle-nn", branch = "ja-send-sync-sd-scheduler" } -candle-transformers = { git = "https://github.com/jorgeantonio21/candle/", package = "candle-transformers", branch = "ja-send-sync-sd-scheduler" } +candle = { git = "https://github.com/jorgeantonio21/candle/", package = "candle-core", branch = "main" } +candle-nn = { git = "https://github.com/jorgeantonio21/candle/", package = "candle-nn", branch = "main" } +candle-transformers = { git = "https://github.com/jorgeantonio21/candle/", package = "candle-transformers", branch = "main" } config = "0.14.0" ed25519-consensus = "2.1.0" futures = "0.3.30" From a19817ebbd83407a965c0024f7ad7326290e7ac2 Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Sun, 31 Mar 2024 09:55:33 +0100 Subject: [PATCH 21/28] add model trait interface and refactor code to be more general --- atoma-inference/src/apis/hugging_face.rs | 159 ++++++------ atoma-inference/src/apis/mod.rs | 6 +- atoma-inference/src/lib.rs | 3 +- atoma-inference/src/main.rs | 33 +-- atoma-inference/src/model_thread.rs | 152 +++++++----- atoma-inference/src/models.rs | 267 --------------------- atoma-inference/src/{ => models}/config.rs | 31 +-- atoma-inference/src/models/mod.rs | 47 ++++ atoma-inference/src/service.rs | 206 ++++++++-------- atoma-inference/src/types.rs | 6 +- 10 files changed, 346 insertions(+), 564 deletions(-) delete mode 100644 atoma-inference/src/models.rs rename atoma-inference/src/{ => models}/config.rs (59%) create mode 100644 atoma-inference/src/models/mod.rs diff --git a/atoma-inference/src/apis/hugging_face.rs b/atoma-inference/src/apis/hugging_face.rs index db0be8f9..bc80099a 100644 --- a/atoma-inference/src/apis/hugging_face.rs +++ b/atoma-inference/src/apis/hugging_face.rs @@ -3,95 +3,96 @@ use std::path::PathBuf; use async_trait::async_trait; use hf_hub::api::sync::{Api, ApiBuilder}; -use crate::models::ModelType; +use crate::models::ModelId; -use super::ApiTrait; +use super::{ApiError, ApiTrait}; struct FilePaths { file_paths: Vec, } -impl ModelType { - fn get_hugging_face_model_path(&self) -> (String, FilePaths) { - match self { - Self::Llama2_7b => ( - String::from("meta-llama/Llama-2-7b-hf"), - FilePaths { - file_paths: vec![ - "model-00001-of-00002.safetensors".to_string(), - "model-00002-of-00002.safetensors".to_string(), - ], - }, - ), - Self::Mamba3b => ( - String::from("state-spaces/mamba-2.8b-hf"), - FilePaths { - file_paths: vec![ - "model-00001-of-00003.safetensors".to_string(), - "model-00002-of-00003.safetensors".to_string(), - "model-00003-of-00003.safetensors".to_string(), - ], - }, - ), - Self::Mistral7b => ( - String::from("mistralai/Mistral-7B-Instruct-v0.2"), - FilePaths { - file_paths: vec![ - "model-00001-of-00003.safetensors".to_string(), - "model-00002-of-00003.safetensors".to_string(), - "model-00003-of-00003.safetensors".to_string(), - ], - }, - ), - 
Self::Mixtral8x7b => ( - String::from("mistralai/Mixtral-8x7B-Instruct-v0.1"), - FilePaths { - file_paths: vec![ - "model-00001-of-00019.safetensors".to_string(), - "model-00002-of-00019.safetensors".to_string(), - "model-00003-of-00019.safetensors".to_string(), - "model-00004-of-00019.safetensors".to_string(), - "model-00005-of-00019.safetensors".to_string(), - "model-00006-of-00019.safetensors".to_string(), - "model-00007-of-00019.safetensors".to_string(), - "model-00008-of-00019.safetensors".to_string(), - "model-00009-of-00019.safetensors".to_string(), - "model-000010-of-00019.safetensors".to_string(), - "model-000011-of-00019.safetensors".to_string(), - "model-000012-of-00019.safetensors".to_string(), - "model-000013-of-00019.safetensors".to_string(), - "model-000014-of-00019.safetensors".to_string(), - "model-000015-of-00019.safetensors".to_string(), - "model-000016-of-00019.safetensors".to_string(), - "model-000017-of-00019.safetensors".to_string(), - "model-000018-of-00019.safetensors".to_string(), - "model-000019-of-00019.safetensors".to_string(), - ], - }, - ), - Self::StableDiffusion2 => ( - String::from("stabilityai/stable-diffusion-2"), - FilePaths { - file_paths: vec!["768-v-ema.safetensors".to_string()], - }, - ), - Self::StableDiffusionXl => ( - String::from("stabilityai/stable-diffusion-xl-base-1.0"), - FilePaths { - file_paths: vec![ - "sd_xl_base_1.0.safetensors".to_string(), - "sd_xl_base_1.0_0.9vae.safetensors".to_string(), - "sd_xl_offset_example-lora_1.0.safetensors".to_string(), - ], - }, - ), +fn get_model_safe_tensors_from_hf(model_id: &ModelId) -> (String, FilePaths) { + match model_id.as_str() { + "Llama2_7b" => ( + String::from("meta-llama/Llama-2-7b-hf"), + FilePaths { + file_paths: vec![ + "model-00001-of-00002.safetensors".to_string(), + "model-00002-of-00002.safetensors".to_string(), + ], + }, + ), + "Mamba3b" => ( + String::from("state-spaces/mamba-2.8b-hf"), + FilePaths { + file_paths: vec![ + "model-00001-of-00003.safetensors".to_string(), + "model-00002-of-00003.safetensors".to_string(), + "model-00003-of-00003.safetensors".to_string(), + ], + }, + ), + "Mistral7b" => ( + String::from("mistralai/Mistral-7B-Instruct-v0.2"), + FilePaths { + file_paths: vec![ + "model-00001-of-00003.safetensors".to_string(), + "model-00002-of-00003.safetensors".to_string(), + "model-00003-of-00003.safetensors".to_string(), + ], + }, + ), + "Mixtral8x7b" => ( + String::from("mistralai/Mixtral-8x7B-Instruct-v0.1"), + FilePaths { + file_paths: vec![ + "model-00001-of-00019.safetensors".to_string(), + "model-00002-of-00019.safetensors".to_string(), + "model-00003-of-00019.safetensors".to_string(), + "model-00004-of-00019.safetensors".to_string(), + "model-00005-of-00019.safetensors".to_string(), + "model-00006-of-00019.safetensors".to_string(), + "model-00007-of-00019.safetensors".to_string(), + "model-00008-of-00019.safetensors".to_string(), + "model-00009-of-00019.safetensors".to_string(), + "model-000010-of-00019.safetensors".to_string(), + "model-000011-of-00019.safetensors".to_string(), + "model-000012-of-00019.safetensors".to_string(), + "model-000013-of-00019.safetensors".to_string(), + "model-000014-of-00019.safetensors".to_string(), + "model-000015-of-00019.safetensors".to_string(), + "model-000016-of-00019.safetensors".to_string(), + "model-000017-of-00019.safetensors".to_string(), + "model-000018-of-00019.safetensors".to_string(), + "model-000019-of-00019.safetensors".to_string(), + ], + }, + ), + "StableDiffusion2" => ( + 
String::from("stabilityai/stable-diffusion-2"), + FilePaths { + file_paths: vec!["768-v-ema.safetensors".to_string()], + }, + ), + "StableDiffusionXl" => ( + String::from("stabilityai/stable-diffusion-xl-base-1.0"), + FilePaths { + file_paths: vec![ + "sd_xl_base_1.0.safetensors".to_string(), + "sd_xl_base_1.0_0.9vae.safetensors".to_string(), + "sd_xl_offset_example-lora_1.0.safetensors".to_string(), + ], + }, + ), + _ => { + panic!("Invalid model id") } } } #[async_trait] impl ApiTrait for Api { - fn create(api_key: String, cache_dir: PathBuf) -> Result + fn create(api_key: String, cache_dir: PathBuf) -> Result where Self: Sized, { @@ -102,8 +103,8 @@ impl ApiTrait for Api { .build()?) } - fn fetch(&self, model: ModelType) -> Result, super::ApiError> { - let (model_path, files) = model.get_hugging_face_model_path(); + fn fetch(&self, model_id: &ModelId) -> Result, ApiError> { + let (model_path, files) = get_model_safe_tensors_from_hf(model_id); let api_repo = self.model(model_path); let mut path_bufs = Vec::with_capacity(files.file_paths.len()); diff --git a/atoma-inference/src/apis/mod.rs b/atoma-inference/src/apis/mod.rs index 29f34435..e6d27941 100644 --- a/atoma-inference/src/apis/mod.rs +++ b/atoma-inference/src/apis/mod.rs @@ -5,7 +5,7 @@ use std::path::PathBuf; use thiserror::Error; -use crate::models::ModelType; +use crate::models::ModelId; #[derive(Debug, Error)] pub enum ApiError { @@ -21,8 +21,8 @@ impl From for ApiError { } } -pub trait ApiTrait { - fn fetch(&self, model: ModelType) -> Result, ApiError>; +pub trait ApiTrait: Send { + fn fetch(&self, model_id: &ModelId) -> Result, ApiError>; fn create(api_key: String, cache_dir: PathBuf) -> Result where Self: Sized; diff --git a/atoma-inference/src/lib.rs b/atoma-inference/src/lib.rs index bcd084bd..539230f0 100644 --- a/atoma-inference/src/lib.rs +++ b/atoma-inference/src/lib.rs @@ -1,8 +1,7 @@ -pub mod config; pub mod model_thread; -pub mod models; pub mod service; pub mod specs; pub mod types; pub mod apis; +pub mod models; diff --git a/atoma-inference/src/main.rs b/atoma-inference/src/main.rs index 8c8d58a2..6feadac8 100644 --- a/atoma-inference/src/main.rs +++ b/atoma-inference/src/main.rs @@ -1,31 +1,16 @@ -use inference::models::Model; -use inference::service::InferenceService; +// use hf_hub::api::sync::Api; +// use inference::service::InferenceService; #[tokio::main] async fn main() { tracing_subscriber::fmt::init(); - let (_, receiver) = tokio::sync::mpsc::channel(32); + // let (_, receiver) = tokio::sync::mpsc::channel(32); - let _ = InferenceService::start::( - "../inference.toml".parse().unwrap(), - "../private_key".parse().unwrap(), - receiver, - ) - .expect("Failed to start inference service"); - - // inference_service - // .run_inference(InferenceRequest { - // prompt: String::from("Which protocols are faster, zk-STARKs or zk-SNARKs ?"), - // max_tokens: 512, - // model: inference::models::ModelType::Llama2_7b, - // random_seed: 42, - // sampled_nodes: vec![], - // repeat_penalty: 1.0, - // temperature: Some(0.6), - // top_k: 10, - // top_p: None, - // }) - // .await - // .unwrap(); + // let _ = InferenceService::start::( + // "../inference.toml".parse().unwrap(), + // "../private_key".parse().unwrap(), + // receiver, + // ) + // .expect("Failed to start inference service"); } diff --git a/atoma-inference/src/model_thread.rs b/atoma-inference/src/model_thread.rs index 1939ccf2..8e1ee6bb 100644 --- a/atoma-inference/src/model_thread.rs +++ b/atoma-inference/src/model_thread.rs @@ -1,6 +1,5 @@ use 
std::{collections::HashMap, sync::mpsc}; -use candle_nn::VarBuilder; use ed25519_consensus::VerificationKey as PublicKey; use futures::stream::FuturesUnordered; use thiserror::Error; @@ -8,40 +7,67 @@ use tokio::sync::oneshot::{self, error::RecvError}; use tracing::{debug, error, warn}; use crate::{ - models::{ModelApi, ModelError, ModelSpecs, ModelType}, - types::{InferenceRequest, InferenceResponse}, + apis::{ApiError, ApiTrait}, + models::{config::ModelConfig, ModelError, ModelId, ModelTrait, Request, Response}, }; -pub struct ModelThreadCommand(InferenceRequest, oneshot::Sender); +pub struct ModelThreadCommand(T, oneshot::Sender) +where + T: Request, + U: Response; #[derive(Debug, Error)] pub enum ModelThreadError { + #[error("Model thread shutdown: `{0}`")] + ApiError(ApiError), #[error("Model thread shutdown: `{0}`")] ModelError(ModelError), #[error("Core thread shutdown: `{0}`")] Shutdown(RecvError), } -pub struct ModelThreadHandle { - sender: mpsc::Sender, - join_handle: std::thread::JoinHandle<()>, +impl From for ModelThreadError { + fn from(error: ModelError) -> Self { + Self::ModelError(error) + } +} + +impl From for ModelThreadError { + fn from(error: ApiError) -> Self { + Self::ApiError(error) + } +} + +pub struct ModelThreadHandle +where + T: Request, + U: Response, +{ + sender: mpsc::Sender>, + join_handle: std::thread::JoinHandle>, } -impl ModelThreadHandle { +impl ModelThreadHandle +where + T: Request, + U: Response, +{ pub fn stop(self) { drop(self.sender); self.join_handle.join().ok(); } } -pub struct ModelThread { - model: T, - receiver: mpsc::Receiver, +pub struct ModelThread { + model: M, + receiver: mpsc::Receiver>, } -impl ModelThread +impl ModelThread where - T: ModelApi, + M: ModelTrait, + T: Request, + U: Response, { pub fn run(self, public_key: PublicKey) -> Result<(), ModelThreadError> { debug!("Start Model thread"); @@ -49,34 +75,17 @@ where while let Ok(command) = self.receiver.recv() { let ModelThreadCommand(request, sender) = command; - let InferenceRequest { - prompt, - max_tokens, - temperature, - random_seed, - repeat_last_n, - repeat_penalty, - top_p, - sampled_nodes, - .. 
- } = request; - if !sampled_nodes.contains(&public_key) { - error!("Current node, with verification key = {:?} was not sampled from {sampled_nodes:?}", public_key); + if !request.is_node_authorized(&public_key) { + error!("Current node, with verification key = {:?} is not authorized to run request with id = {}", public_key, request.request_id()); continue; } - let response = self + + let model_input = request.into_model_input(); + let model_output = self .model - .run( - prompt, - max_tokens, - random_seed, - repeat_last_n, - repeat_penalty, - temperature.unwrap_or_default(), - top_p.unwrap_or_default(), - ) + .run(model_input) .map_err(ModelThreadError::ModelError)?; - let response = InferenceResponse { response }; + let response = U::from_model_output(model_output); sender.send(response).ok(); } @@ -84,42 +93,61 @@ where } } -pub struct ModelThreadDispatcher { - model_senders: HashMap>, - pub(crate) responses: FuturesUnordered>, +pub struct ModelThreadDispatcher +where + T: Request, + U: Response, +{ + model_senders: HashMap>>, + pub(crate) responses: FuturesUnordered>, } -impl ModelThreadDispatcher { - pub(crate) fn start( - models: Vec<(ModelType, ModelSpecs, VarBuilder)>, +impl ModelThreadDispatcher +where + T: Clone + Request, + U: Response, +{ + pub(crate) fn start( + api: F, + config: ModelConfig, public_key: PublicKey, - ) -> Result<(Self, Vec), ModelThreadError> + ) -> Result<(Self, Vec>), ModelThreadError> where - T: ModelApi + Send + 'static, + F: ApiTrait, + M: ModelTrait + + Send + + 'static, { - let mut handles = Vec::with_capacity(models.len()); - let mut model_senders = HashMap::with_capacity(models.len()); - - for (model_type, model_specs, var_builder) in models { - let (model_sender, model_receiver) = mpsc::channel::(); - let model = T::load(model_specs, var_builder); // TODO: for now this piece of code cannot be shared among threads safely - let model_thread = ModelThread { - model, - receiver: model_receiver, - }; + let model_ids = config.model_ids(); + let mut handles = Vec::with_capacity(model_ids.len()); + let mut model_senders = HashMap::with_capacity(model_ids.len()); + + for model_id in model_ids { + let filenames = api.fetch(&model_id)?; + + let (model_sender, model_receiver) = mpsc::channel::>(); + let join_handle = std::thread::spawn(move || { + let model = M::load(filenames)?; // TODO: for now this piece of code cannot be shared among threads safely + let model_thread = ModelThread { + model, + receiver: model_receiver, + }; + if let Err(e) = model_thread.run(public_key) { error!("Model thread error: {e}"); if !matches!(e, ModelThreadError::Shutdown(_)) { panic!("Fatal error occurred: {e}"); } } + + Ok(()) }); handles.push(ModelThreadHandle { join_handle, sender: model_sender.clone(), }); - model_senders.insert(model_type, model_sender); + model_senders.insert(model_id, model_sender); } let model_dispatcher = ModelThreadDispatcher { @@ -130,9 +158,9 @@ impl ModelThreadDispatcher { Ok((model_dispatcher, handles)) } - fn send(&self, command: ModelThreadCommand) { + fn send(&self, command: ModelThreadCommand) { let request = command.0.clone(); - let model_type = request.model; + let model_type = request.requested_model(); let sender = self .model_senders @@ -145,8 +173,12 @@ impl ModelThreadDispatcher { } } -impl ModelThreadDispatcher { - pub(crate) fn run_inference(&self, request: InferenceRequest) { +impl ModelThreadDispatcher +where + T: Clone + Request, + U: Response, +{ + pub(crate) fn run_inference(&self, request: T) { let (sender, receiver) = 
oneshot::channel(); self.send(ModelThreadCommand(request, sender)); self.responses.push(receiver); diff --git a/atoma-inference/src/models.rs b/atoma-inference/src/models.rs deleted file mode 100644 index e42a95cb..00000000 --- a/atoma-inference/src/models.rs +++ /dev/null @@ -1,267 +0,0 @@ -use std::fmt::Display; - -use candle::{DType, Device, Error as CandleError, Tensor}; -use candle_nn::VarBuilder; -use candle_transformers::{ - generation::LogitsProcessor, - models::{ - llama::{Cache as LlamaCache, Config as LlamaConfig, Llama}, - mamba::{Config as MambaConfig, Model as MambaModel}, - mistral::{Config as MistralConfig, Model as MistralModel}, - mixtral::{Config as MixtralConfig, Model as MixtralModel}, - stable_diffusion::StableDiffusionConfig, - }, -}; -use serde::{Deserialize, Serialize}; -use thiserror::Error; - -use tokenizers::Tokenizer; - -use crate::types::Temperature; - -const EOS_TOKEN: &str = ""; - -#[derive(Clone, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)] -pub enum ModelType { - Llama2_7b, - Mamba3b, - Mixtral8x7b, - Mistral7b, - StableDiffusion2, - StableDiffusionXl, -} - -impl Display for ModelType { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::Llama2_7b => write!(f, "llama2_7b"), - Self::Mamba3b => write!(f, "mamba_3b"), - Self::Mixtral8x7b => write!(f, "mixtral_8x7b"), - Self::Mistral7b => write!(f, "mistral_7b"), - Self::StableDiffusion2 => write!(f, "stable_diffusion_2"), - Self::StableDiffusionXl => write!(f, "stable_diffusion_xl"), - } - } -} - -impl ModelType { - pub(crate) fn model_config(&self) -> ModelConfig { - match self { - Self::Llama2_7b => ModelConfig::Llama(LlamaConfig::config_7b_v2(false)), // TODO: add the case for flash attention - Self::Mamba3b => todo!(), - Self::Mistral7b => ModelConfig::Mistral(MistralConfig::config_7b_v0_1(false)), // TODO: add the case for flash attention - Self::Mixtral8x7b => { - ModelConfig::Mixtral8x7b(Box::new(MixtralConfig::v0_1_8x7b(false))) - } // TODO: add the case for flash attention - Self::StableDiffusion2 => ModelConfig::StableDiffusion(Box::new( - StableDiffusionConfig::v2_1(None, None, None), - )), - Self::StableDiffusionXl => ModelConfig::StableDiffusion(Box::new( - StableDiffusionConfig::sdxl_turbo(None, None, None), - )), - } - } -} - -#[derive(Clone)] -pub enum ModelConfig { - Llama(LlamaConfig), - Mamba(MambaConfig), - Mixtral8x7b(Box), - Mistral(MistralConfig), - StableDiffusion(Box), -} - -impl From for ModelConfig { - fn from(_model_type: ModelType) -> Self { - todo!() - } -} - -pub trait ModelApi { - fn load(model_specs: ModelSpecs, var_builder: VarBuilder) -> Self; - - #[allow(clippy::too_many_arguments)] - fn run( - &self, - input: String, - max_tokens: usize, - random_seed: usize, - repeat_last_n: usize, - repeat_penalty: f32, - temperature: Temperature, - top_p: f32, - ) -> Result; -} - -#[allow(dead_code)] -pub struct ModelSpecs { - pub(crate) cache: Option, - pub(crate) config: ModelConfig, - pub(crate) device: Device, - pub(crate) dtype: DType, - pub(crate) tokenizer: Tokenizer, -} - -pub enum Model { - Llama { - model_specs: ModelSpecs, - model: Llama, - }, - Mamba { - model_specs: ModelSpecs, - model: MambaModel, - }, - Mixtral8x7b { - model_specs: ModelSpecs, - model: MixtralModel, - }, - Mistral { - model_specs: ModelSpecs, - model: MistralModel, - }, -} - -impl ModelApi for Model { - fn load(model_specs: ModelSpecs, var_builder: VarBuilder) -> Self { - let model_config = model_specs.config.clone(); - match model_config { - 
ModelConfig::Llama(config) => { - let model = Llama::load(var_builder, &config).expect("Failed to load LlaMa model"); - Self::Llama { model, model_specs } - } - // ModelConfig::Llama2(config) => { - // let model = Llama2::load(var_builder, config).expect("Failed to load LlaMa2 model"); - // Self::Llama2 { model_specs, model } - // } - ModelConfig::Mamba(config) => { - let model = - MambaModel::new(&config, var_builder).expect("Failed to load Mamba model"); - Self::Mamba { model_specs, model } - } - ModelConfig::Mistral(config) => { - let model = - MistralModel::new(&config, var_builder).expect("Failed to load Mistral model"); - Self::Mistral { model_specs, model } - } - ModelConfig::Mixtral8x7b(config) => { - let model = - MixtralModel::new(&config, var_builder).expect("Failed to load Mixtral model"); - Self::Mixtral8x7b { model_specs, model } - } - ModelConfig::StableDiffusion(_) => { - panic!("TODO: implement it") - } - } - } - - fn run( - &self, - input: String, - max_tokens: usize, - random_seed: usize, - repeat_last_n: usize, - repeat_penalty: f32, - temperature: Temperature, - top_p: f32, - ) -> Result { - match self { - Self::Llama { model_specs, model } => { - let mut cache = model_specs.cache.clone().expect("Failed to get cache"); - let mut tokens = model_specs - .tokenizer - .encode(input, true) - .map_err(ModelError::TokenizerError)? - .get_ids() - .to_vec(); - - let mut logits_processor = LogitsProcessor::new( - random_seed as u64, - Some(temperature as f64), - Some(top_p as f64), - ); - - let eos_token_id = model_specs.tokenizer.token_to_id(EOS_TOKEN); - - let start = std::time::Instant::now(); - - let mut index_pos = 0; - let mut tokens_generated = 0; - - let mut output = Vec::with_capacity(max_tokens); - - for index in 0..max_tokens { - let (context_size, context_index) = if cache.use_kv_cache && index > 0 { - (1, index_pos) - } else { - (tokens.len(), 0) - }; - let ctx = &tokens[tokens.len().saturating_sub(context_size)..]; - let input = Tensor::new(ctx, &model_specs.device) - .map_err(ModelError::TensorError)? - .unsqueeze(0) - .map_err(ModelError::TensorError)?; - let logits = model - .forward(&input, context_index, &mut cache) - .map_err(ModelError::TensorError)?; - let logits = logits.squeeze(0).map_err(ModelError::LogitsError)?; - let logits = if repeat_penalty == 1. { - logits - } else { - let start_at = tokens.len().saturating_sub(repeat_last_n); - candle_transformers::utils::apply_repeat_penalty( - &logits, - repeat_penalty, - &tokens[start_at..], - ) - .map_err(ModelError::TensorError)? - }; - index_pos += ctx.len(); - - let next_token = logits_processor - .sample(&logits) - .map_err(ModelError::TensorError)?; - tokens_generated += 1; - tokens.push(next_token); - - if Some(next_token) == eos_token_id { - break; - } - // TODO: possibly do this in batches will speed up the process - if let Ok(t) = model_specs.tokenizer.decode(&[next_token], true) { - output.push(t); - } - let dt = start.elapsed(); - tracing::info!( - "Generated {tokens_generated} tokens ({} tokens/s)", - tokens_generated as f64 / dt.as_secs_f64() - ); - } - Ok(output.join(" ")) - } - Self::Mamba { .. } => { - todo!() - } - Self::Mistral { .. } => { - todo!() - } - Self::Mixtral8x7b { .. 
} => { - todo!() - } - } - } -} - -#[derive(Debug, Error)] -pub enum ModelError { - #[error("Cache error: `{0}`")] - CacheError(String), - #[error("Failed to load error: `{0}`")] - LoadError(CandleError), - #[error("Logits error: `{0}`")] - LogitsError(CandleError), - #[error("Tensor error: `{0}`")] - TensorError(CandleError), - #[error("Failed input tokenization: `{0}`")] - TokenizerError(Box), -} diff --git a/atoma-inference/src/config.rs b/atoma-inference/src/models/config.rs similarity index 59% rename from atoma-inference/src/config.rs rename to atoma-inference/src/models/config.rs index 2c33090b..08217d91 100644 --- a/atoma-inference/src/config.rs +++ b/atoma-inference/src/models/config.rs @@ -3,28 +3,20 @@ use std::path::PathBuf; use config::Config; use serde::{Deserialize, Serialize}; -use crate::{models::ModelType, types::PrecisionBits}; - -#[derive(Clone, Debug, Deserialize, Serialize)] -pub struct ModelTokenizer { - pub(crate) model_type: ModelType, - pub(crate) tokenizer: PathBuf, - pub(crate) precision: PrecisionBits, - pub(crate) use_kv_cache: Option, -} +use crate::models::ModelId; #[derive(Debug, Deserialize, Serialize)] -pub struct InferenceConfig { +pub struct ModelConfig { api_key: String, - models: Vec, + models: Vec, storage_folder: PathBuf, tracing: bool, } -impl InferenceConfig { +impl ModelConfig { pub fn new( api_key: String, - models: Vec, + models: Vec, storage_folder: PathBuf, tracing: bool, ) -> Self { @@ -40,7 +32,7 @@ impl InferenceConfig { self.api_key.clone() } - pub fn models(&self) -> Vec { + pub fn model_ids(&self) -> Vec { self.models.clone() } @@ -71,20 +63,15 @@ pub mod tests { #[test] fn test_config() { - let config = InferenceConfig::new( + let config = ModelConfig::new( String::from("my_key"), - vec![ModelTokenizer { - model_type: ModelType::Llama2_7b, - tokenizer: "tokenizer".parse().unwrap(), - precision: PrecisionBits::BF16, - use_kv_cache: Some(true), - }], + vec!["Llama2_7b".to_string()], "storage_folder".parse().unwrap(), true, ); let toml_str = toml::to_string(&config).unwrap(); - let should_be_toml_str = "api_key = \"my_key\"\nstorage_folder = \"storage_folder\"\ntracing = true\n\n[[models]]\nmodel_type = \"Llama2_7b\"\ntokenizer = \"tokenizer\"\nprecision = \"BF16\"\nuse_kv_cache = true\n"; + let should_be_toml_str = "api_key = \"my_key\"\nmodels = [\"Llama2_7b\"]\nstorage_folder = \"storage_folder\"\ntracing = true\n"; assert_eq!(toml_str, should_be_toml_str); } } diff --git a/atoma-inference/src/models/mod.rs b/atoma-inference/src/models/mod.rs new file mode 100644 index 00000000..46f820bf --- /dev/null +++ b/atoma-inference/src/models/mod.rs @@ -0,0 +1,47 @@ +use std::path::PathBuf; + +use crate::{apis::ApiTrait, models::config::ModelConfig}; +use ed25519_consensus::VerificationKey as PublicKey; +use thiserror::Error; + +pub mod config; + +pub type ModelId = String; + +pub trait ModelBuilder { + fn try_from_file(path: PathBuf) -> Result + where + Self: Sized; +} + +pub trait ModelTrait { + type Builder: Send + Sync + 'static; + type FetchApi: ApiTrait + Send + Sync + 'static; + type Input; + type Output; + + fn fetch(api: &Self::FetchApi, config: ModelConfig) -> Result<(), ModelError>; + fn load(filenames: Vec) -> Result + where + Self: Sized; + fn model_id(&self) -> ModelId; + fn run(&self, input: Self::Input) -> Result; +} + +pub trait Request: Send + 'static { + type ModelInput; + + fn into_model_input(self) -> Self::ModelInput; + fn requested_model(&self) -> ModelId; + fn request_id(&self) -> usize; // TODO: replace with Uuid + fn 
is_node_authorized(&self, public_key: &PublicKey) -> bool; +} + +pub trait Response: Send + 'static { + type ModelOutput; + + fn from_model_output(model_output: Self::ModelOutput) -> Self; +} + +#[derive(Debug, Error)] +pub enum ModelError {} diff --git a/atoma-inference/src/service.rs b/atoma-inference/src/service.rs index 40781185..f572a33d 100644 --- a/atoma-inference/src/service.rs +++ b/atoma-inference/src/service.rs @@ -1,39 +1,44 @@ -use candle::{Device, Error as CandleError}; -use candle_nn::var_builder::VarBuilder; -use candle_transformers::models::llama::Cache as LlamaCache; +use candle::Error as CandleError; use ed25519_consensus::SigningKey as PrivateKey; use futures::StreamExt; -use hf_hub::api::sync::Api; use std::{io, path::PathBuf, time::Instant}; -use tokenizers::Tokenizer; -use tokio::sync::mpsc::{error::SendError, Receiver}; +use tokio::sync::mpsc::Receiver; use tracing::{error, info}; use thiserror::Error; use crate::{ apis::{ApiError, ApiTrait}, - config::{InferenceConfig, ModelTokenizer}, model_thread::{ModelThreadDispatcher, ModelThreadError, ModelThreadHandle}, - models::{ModelApi, ModelConfig, ModelSpecs, ModelType}, - types::{InferenceRequest, InferenceResponse}, + models::{config::ModelConfig, ModelTrait, Request, Response}, }; -pub struct InferenceService { - model_thread_handle: Vec, - dispatcher: ModelThreadDispatcher, +pub struct InferenceService +where + T: Request, + U: Response, +{ + model_thread_handle: Vec>, + dispatcher: ModelThreadDispatcher, start_time: Instant, - request_receiver: Receiver, + request_receiver: Receiver, } -impl InferenceService { - pub fn start( +impl InferenceService +where + T: Clone + Request, + U: std::fmt::Debug + Response, +{ + pub fn start( config_file_path: PathBuf, private_key_path: PathBuf, - request_receiver: Receiver, + request_receiver: Receiver, ) -> Result where - T: ModelApi + Send + 'static, + M: ModelTrait + + Send + + 'static, + F: ApiTrait, { let private_key_bytes = std::fs::read(private_key_path).map_err(InferenceServiceError::PrivateKeyError)?; @@ -43,77 +48,14 @@ impl InferenceService { let private_key = PrivateKey::from(private_key_bytes); let public_key = private_key.verification_key(); - let inference_config = InferenceConfig::from_file_path(config_file_path); - let api_key = inference_config.api_key(); - let storage_folder = inference_config.storage_folder(); - let models = inference_config.models(); - - let api = Api::create(api_key, storage_folder)?; - - let mut handles = Vec::with_capacity(models.len()); - for model in &models { - let api = api.clone(); - let model_type = model.model_type.clone(); - let handle = - std::thread::spawn(move || api.fetch(model_type).expect("Failed to fetch model")); - handles.push(handle); - } - - let path_bufs = handles - .into_iter() - .zip(models) - .map(|(h, mt)| { - let path_bufs = h.join().unwrap(); - (mt, path_bufs) - }) - .collect::>(); + let model_config = ModelConfig::from_file_path(config_file_path); + let api_key = model_config.api_key(); + let storage_folder = model_config.storage_folder(); - info!("Starting Core Dispatcher.."); - - let device = Device::new_metal(0)?; // TODO: check this - let models = path_bufs - .iter() - .map(|(mt, paths)| { - let ModelTokenizer { - model_type, - tokenizer, - precision, - use_kv_cache, - } = mt; - let config = model_type.model_config(); - let tokenizer = Tokenizer::from_file(tokenizer) - .map_err(InferenceServiceError::TokenizerError)?; - let dtype = precision.into_dtype(); - let var_builder = - unsafe { 
VarBuilder::from_mmaped_safetensors(paths, dtype, &device)? }; - let cache = if let ModelType::Llama2_7b = model_type { - let llama_config = if let ModelConfig::Llama(cfg) = config.clone() { - cfg - } else { - panic!("Configuration for Llama model unexpected") - }; - Some(LlamaCache::new( - use_kv_cache.unwrap_or_default(), - dtype, - &llama_config, - &device, - )?) - } else { - None - }; - let model_specs = ModelSpecs { - cache, - config, - device: device.clone(), - dtype, - tokenizer, - }; - Ok::<_, InferenceServiceError>((model_type.clone(), model_specs, var_builder)) - }) - .collect::, _>>()?; + let api = F::create(api_key, storage_folder)?; let (dispatcher, model_thread_handle) = - ModelThreadDispatcher::start::(models, public_key) + ModelThreadDispatcher::start::(api, model_config, public_key) .map_err(InferenceServiceError::ModelThreadError)?; let start_time = Instant::now(); @@ -125,7 +67,7 @@ impl InferenceService { }) } - pub async fn run(&mut self) -> Result { + pub async fn run(&mut self) -> Result { loop { tokio::select! { message = self.request_receiver.recv() => { @@ -150,7 +92,11 @@ impl InferenceService { } } -impl InferenceService { +impl InferenceService +where + T: Request, + U: Response, +{ pub async fn stop(mut self) { info!( "Stopping Inference Service, running time: {:?}", @@ -177,8 +123,8 @@ pub enum InferenceServiceError { PrivateKeyError(io::Error), #[error("Core error: `{0}`")] ModelThreadError(ModelThreadError), - #[error("Send error: `{0}`")] - SendError(SendError), + // #[error("Send error: `{0}`")] + // SendError(SendError<_>), #[error("Api error: `{0}`")] ApiError(ApiError), #[error("Tokenizer error: `{0}`")] @@ -201,31 +147,81 @@ impl From for InferenceServiceError { #[cfg(test)] mod tests { + use ed25519_consensus::VerificationKey as PublicKey; use rand::rngs::OsRng; use std::io::Write; use toml::{toml, Value}; + use crate::models::ModelId; + use super::*; + struct MockApi {} + + impl ApiTrait for MockApi { + fn create(_: String, _: PathBuf) -> Result + where + Self: Sized, + { + Ok(Self {}) + } + + fn fetch(&self, _: &ModelId) -> Result, ApiError> { + Ok(vec![]) + } + } + + impl Request for () { + type ModelInput = (); + + fn into_model_input(self) -> Self::ModelInput { + () + } + + fn is_node_authorized(&self, _: &PublicKey) -> bool { + true + } + + fn request_id(&self) -> usize { + 0 + } + + fn requested_model(&self) -> crate::models::ModelId { + String::from("") + } + } + + impl Response for () { + type ModelOutput = (); + + fn from_model_output(_: Self::ModelOutput) -> Self { + () + } + } + #[derive(Clone)] struct TestModelInstance {} - impl ModelApi for TestModelInstance { - fn load(_model_specs: ModelSpecs, _var_builder: VarBuilder) -> Self { - Self {} + impl ModelTrait for TestModelInstance { + type Builder = (); + type FetchApi = MockApi; + type Input = (); + type Output = (); + + fn fetch(_: &Self::FetchApi, _: ModelConfig) -> Result<(), crate::models::ModelError> { + Ok(()) + } + + fn load(_: Vec) -> Result { + Ok(Self {}) + } + + fn model_id(&self) -> crate::models::ModelId { + String::from("") } - fn run( - &self, - _input: String, - _max_tokens: usize, - _random_seed: usize, - _repeat_last_n: usize, - _repeat_penalty: f32, - _temperature: crate::types::Temperature, - _top_p: f32, - ) -> Result { - Ok(String::from("")) + fn run(&self, _: Self::Input) -> Result { + Ok(()) } } @@ -251,9 +247,9 @@ mod tests { file.write_all(toml_string.as_bytes()) .expect("Failed to write to file"); - let (_, receiver) = tokio::sync::mpsc::channel(1); + let (_, 
receiver) = tokio::sync::mpsc::channel::<()>(1); - let _ = InferenceService::start::( + let _ = InferenceService::<(), ()>::start::( PathBuf::try_from(CONFIG_FILE_PATH).unwrap(), PathBuf::try_from(PRIVATE_KEY_FILE_PATH).unwrap(), receiver, diff --git a/atoma-inference/src/types.rs b/atoma-inference/src/types.rs index bd417a66..82643c19 100644 --- a/atoma-inference/src/types.rs +++ b/atoma-inference/src/types.rs @@ -1,8 +1,9 @@ -use crate::models::ModelType; use candle::DType; use ed25519_consensus::VerificationKey; use serde::{Deserialize, Serialize}; +use crate::models::ModelId; + pub type NodeId = VerificationKey; pub type Temperature = f32; @@ -10,7 +11,7 @@ pub type Temperature = f32; pub struct InferenceRequest { pub request_id: u128, pub prompt: String, - pub model: ModelType, + pub model: ModelId, pub max_tokens: usize, pub random_seed: usize, pub repeat_last_n: usize, @@ -49,6 +50,7 @@ pub enum PrecisionBits { } impl PrecisionBits { + #[allow(dead_code)] pub(crate) fn into_dtype(self) -> DType { match self { Self::BF16 => DType::BF16, From 4a12b71cf947e89f12abad4deb6d3ddac96c341f Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Sun, 31 Mar 2024 09:56:23 +0100 Subject: [PATCH 22/28] rename InferenceService to ModelService --- atoma-inference/src/main.rs | 4 ++-- atoma-inference/src/service.rs | 22 +++++++++++----------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/atoma-inference/src/main.rs b/atoma-inference/src/main.rs index 6feadac8..1a614417 100644 --- a/atoma-inference/src/main.rs +++ b/atoma-inference/src/main.rs @@ -1,5 +1,5 @@ // use hf_hub::api::sync::Api; -// use inference::service::InferenceService; +// use inference::service::ModelService; #[tokio::main] async fn main() { @@ -7,7 +7,7 @@ async fn main() { // let (_, receiver) = tokio::sync::mpsc::channel(32); - // let _ = InferenceService::start::( + // let _ = ModelService::start::( // "../inference.toml".parse().unwrap(), // "../private_key".parse().unwrap(), // receiver, diff --git a/atoma-inference/src/service.rs b/atoma-inference/src/service.rs index f572a33d..ef8b351d 100644 --- a/atoma-inference/src/service.rs +++ b/atoma-inference/src/service.rs @@ -13,7 +13,7 @@ use crate::{ models::{config::ModelConfig, ModelTrait, Request, Response}, }; -pub struct InferenceService +pub struct ModelService where T: Request, U: Response, @@ -24,7 +24,7 @@ where request_receiver: Receiver, } -impl InferenceService +impl ModelService where T: Clone + Request, U: std::fmt::Debug + Response, @@ -33,7 +33,7 @@ where config_file_path: PathBuf, private_key_path: PathBuf, request_receiver: Receiver, - ) -> Result + ) -> Result where M: ModelTrait + Send @@ -41,7 +41,7 @@ where F: ApiTrait, { let private_key_bytes = - std::fs::read(private_key_path).map_err(InferenceServiceError::PrivateKeyError)?; + std::fs::read(private_key_path).map_err(ModelServiceError::PrivateKeyError)?; let private_key_bytes: [u8; 32] = private_key_bytes .try_into() .expect("Incorrect private key bytes length"); @@ -56,7 +56,7 @@ where let (dispatcher, model_thread_handle) = ModelThreadDispatcher::start::(api, model_config, public_key) - .map_err(InferenceServiceError::ModelThreadError)?; + .map_err(ModelServiceError::ModelThreadError)?; let start_time = Instant::now(); Ok(Self { @@ -67,7 +67,7 @@ where }) } - pub async fn run(&mut self) -> Result { + pub async fn run(&mut self) -> Result { loop { tokio::select! 
{ message = self.request_receiver.recv() => { @@ -92,7 +92,7 @@ where } } -impl InferenceService +impl ModelService where T: Request, U: Response, @@ -112,7 +112,7 @@ where } #[derive(Debug, Error)] -pub enum InferenceServiceError { +pub enum ModelServiceError { #[error("Failed to connect to API: `{0}`")] FailedApiConnection(ApiError), #[error("Failed to run inference: `{0}`")] @@ -133,13 +133,13 @@ pub enum InferenceServiceError { CandleError(CandleError), } -impl From for InferenceServiceError { +impl From for ModelServiceError { fn from(error: ApiError) -> Self { Self::ApiError(error) } } -impl From for InferenceServiceError { +impl From for ModelServiceError { fn from(error: CandleError) -> Self { Self::CandleError(error) } @@ -249,7 +249,7 @@ mod tests { let (_, receiver) = tokio::sync::mpsc::channel::<()>(1); - let _ = InferenceService::<(), ()>::start::( + let _ = ModelService::<(), ()>::start::( PathBuf::try_from(CONFIG_FILE_PATH).unwrap(), PathBuf::try_from(PRIVATE_KEY_FILE_PATH).unwrap(), receiver, From cddb5345e3d2e489374cc27217cb53cfd4a50efe Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Sun, 31 Mar 2024 09:59:16 +0100 Subject: [PATCH 23/28] simplify code --- atoma-inference/src/model_thread.rs | 5 ++++- atoma-inference/src/models/config.rs | 14 +++++++------- atoma-inference/src/service.rs | 8 ++------ 3 files changed, 13 insertions(+), 14 deletions(-) diff --git a/atoma-inference/src/model_thread.rs b/atoma-inference/src/model_thread.rs index 8e1ee6bb..5b623aaa 100644 --- a/atoma-inference/src/model_thread.rs +++ b/atoma-inference/src/model_thread.rs @@ -108,7 +108,6 @@ where U: Response, { pub(crate) fn start( - api: F, config: ModelConfig, public_key: PublicKey, ) -> Result<(Self, Vec>), ModelThreadError> @@ -119,6 +118,10 @@ where + 'static, { let model_ids = config.model_ids(); + let api_key = config.api_key(); + let storage_path = config.storage_path(); + let api = F::create(api_key, storage_path)?; + let mut handles = Vec::with_capacity(model_ids.len()); let mut model_senders = HashMap::with_capacity(model_ids.len()); diff --git a/atoma-inference/src/models/config.rs b/atoma-inference/src/models/config.rs index 08217d91..ff5bf9a6 100644 --- a/atoma-inference/src/models/config.rs +++ b/atoma-inference/src/models/config.rs @@ -9,7 +9,7 @@ use crate::models::ModelId; pub struct ModelConfig { api_key: String, models: Vec, - storage_folder: PathBuf, + storage_path: PathBuf, tracing: bool, } @@ -17,13 +17,13 @@ impl ModelConfig { pub fn new( api_key: String, models: Vec, - storage_folder: PathBuf, + storage_path: PathBuf, tracing: bool, ) -> Self { Self { api_key, models, - storage_folder, + storage_path, tracing, } } @@ -36,8 +36,8 @@ impl ModelConfig { self.models.clone() } - pub fn storage_folder(&self) -> PathBuf { - self.storage_folder.clone() + pub fn storage_path(&self) -> PathBuf { + self.storage_path.clone() } pub fn tracing(&self) -> bool { @@ -66,12 +66,12 @@ pub mod tests { let config = ModelConfig::new( String::from("my_key"), vec!["Llama2_7b".to_string()], - "storage_folder".parse().unwrap(), + "storage_path".parse().unwrap(), true, ); let toml_str = toml::to_string(&config).unwrap(); - let should_be_toml_str = "api_key = \"my_key\"\nmodels = [\"Llama2_7b\"]\nstorage_folder = \"storage_folder\"\ntracing = true\n"; + let should_be_toml_str = "api_key = \"my_key\"\nmodels = [\"Llama2_7b\"]\nstorage_path = \"storage_path\"\ntracing = true\n"; assert_eq!(toml_str, should_be_toml_str); } } diff --git a/atoma-inference/src/service.rs 
b/atoma-inference/src/service.rs index ef8b351d..076966aa 100644 --- a/atoma-inference/src/service.rs +++ b/atoma-inference/src/service.rs @@ -49,13 +49,9 @@ where let private_key = PrivateKey::from(private_key_bytes); let public_key = private_key.verification_key(); let model_config = ModelConfig::from_file_path(config_file_path); - let api_key = model_config.api_key(); - let storage_folder = model_config.storage_folder(); - - let api = F::create(api_key, storage_folder)?; let (dispatcher, model_thread_handle) = - ModelThreadDispatcher::start::(api, model_config, public_key) + ModelThreadDispatcher::start::(model_config, public_key) .map_err(ModelServiceError::ModelThreadError)?; let start_time = Instant::now(); @@ -236,7 +232,7 @@ mod tests { let config_data = Value::Table(toml! { api_key = "your_api_key" models = ["Mamba3b"] - storage_folder = "./storage_folder/" + storage_path = "./storage_path/" tokenizer_file_path = "./tokenizer_file_path/" tracing = true }); From f673deaab8eba49d910c80abafca9caff8915b89 Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Sun, 31 Mar 2024 10:00:04 +0100 Subject: [PATCH 24/28] remove fetch method from ModelTrait --- atoma-inference/src/model_thread.rs | 2 +- atoma-inference/src/models/mod.rs | 3 --- atoma-inference/src/service.rs | 7 +------ 3 files changed, 2 insertions(+), 10 deletions(-) diff --git a/atoma-inference/src/model_thread.rs b/atoma-inference/src/model_thread.rs index 5b623aaa..bfe6475c 100644 --- a/atoma-inference/src/model_thread.rs +++ b/atoma-inference/src/model_thread.rs @@ -113,7 +113,7 @@ where ) -> Result<(Self, Vec>), ModelThreadError> where F: ApiTrait, - M: ModelTrait + M: ModelTrait + Send + 'static, { diff --git a/atoma-inference/src/models/mod.rs b/atoma-inference/src/models/mod.rs index 46f820bf..efc3978e 100644 --- a/atoma-inference/src/models/mod.rs +++ b/atoma-inference/src/models/mod.rs @@ -1,6 +1,5 @@ use std::path::PathBuf; -use crate::{apis::ApiTrait, models::config::ModelConfig}; use ed25519_consensus::VerificationKey as PublicKey; use thiserror::Error; @@ -16,11 +15,9 @@ pub trait ModelBuilder { pub trait ModelTrait { type Builder: Send + Sync + 'static; - type FetchApi: ApiTrait + Send + Sync + 'static; type Input; type Output; - fn fetch(api: &Self::FetchApi, config: ModelConfig) -> Result<(), ModelError>; fn load(filenames: Vec) -> Result where Self: Sized; diff --git a/atoma-inference/src/service.rs b/atoma-inference/src/service.rs index 076966aa..cf643641 100644 --- a/atoma-inference/src/service.rs +++ b/atoma-inference/src/service.rs @@ -35,7 +35,7 @@ where request_receiver: Receiver, ) -> Result where - M: ModelTrait + M: ModelTrait + Send + 'static, F: ApiTrait, @@ -200,14 +200,9 @@ mod tests { impl ModelTrait for TestModelInstance { type Builder = (); - type FetchApi = MockApi; type Input = (); type Output = (); - fn fetch(_: &Self::FetchApi, _: ModelConfig) -> Result<(), crate::models::ModelError> { - Ok(()) - } - fn load(_: Vec) -> Result { Ok(Self {}) } From e403ba96eb53f8c8bacb045d96304e0fc752c6e6 Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Sun, 31 Mar 2024 10:01:35 +0100 Subject: [PATCH 25/28] cargo fmt --- atoma-inference/src/model_thread.rs | 4 +--- atoma-inference/src/service.rs | 4 +--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/atoma-inference/src/model_thread.rs b/atoma-inference/src/model_thread.rs index bfe6475c..710b66e9 100644 --- a/atoma-inference/src/model_thread.rs +++ b/atoma-inference/src/model_thread.rs @@ -113,9 +113,7 @@ where ) -> Result<(Self, Vec>), 
ModelThreadError> where F: ApiTrait, - M: ModelTrait - + Send - + 'static, + M: ModelTrait + Send + 'static, { let model_ids = config.model_ids(); let api_key = config.api_key(); diff --git a/atoma-inference/src/service.rs b/atoma-inference/src/service.rs index cf643641..848944f8 100644 --- a/atoma-inference/src/service.rs +++ b/atoma-inference/src/service.rs @@ -35,9 +35,7 @@ where request_receiver: Receiver, ) -> Result where - M: ModelTrait - + Send - + 'static, + M: ModelTrait + Send + 'static, F: ApiTrait, { let private_key_bytes = From b8a51ac613cd91bc57b104b04bd0e2cbca681252 Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Sun, 31 Mar 2024 21:19:52 +0100 Subject: [PATCH 26/28] rename --- atoma-inference/src/model_thread.rs | 56 ++++++++++++++--------------- atoma-inference/src/service.rs | 30 ++++++++-------- 2 files changed, 43 insertions(+), 43 deletions(-) diff --git a/atoma-inference/src/model_thread.rs b/atoma-inference/src/model_thread.rs index 710b66e9..5e5db91c 100644 --- a/atoma-inference/src/model_thread.rs +++ b/atoma-inference/src/model_thread.rs @@ -11,10 +11,10 @@ use crate::{ models::{config::ModelConfig, ModelError, ModelId, ModelTrait, Request, Response}, }; -pub struct ModelThreadCommand(T, oneshot::Sender) +pub struct ModelThreadCommand(Req, oneshot::Sender) where - T: Request, - U: Response; + Req: Request, + Resp: Response; #[derive(Debug, Error)] pub enum ModelThreadError { @@ -38,19 +38,19 @@ impl From for ModelThreadError { } } -pub struct ModelThreadHandle +pub struct ModelThreadHandle where - T: Request, - U: Response, + Req: Request, + Resp: Response, { - sender: mpsc::Sender>, + sender: mpsc::Sender>, join_handle: std::thread::JoinHandle>, } -impl ModelThreadHandle +impl ModelThreadHandle where - T: Request, - U: Response, + Req: Request, + Resp: Response, { pub fn stop(self) { drop(self.sender); @@ -58,16 +58,16 @@ where } } -pub struct ModelThread { +pub struct ModelThread { model: M, - receiver: mpsc::Receiver>, + receiver: mpsc::Receiver>, } -impl ModelThread +impl ModelThread where - M: ModelTrait, - T: Request, - U: Response, + M: ModelTrait, + Req: Request, + Resp: Response, { pub fn run(self, public_key: PublicKey) -> Result<(), ModelThreadError> { debug!("Start Model thread"); @@ -85,7 +85,7 @@ where .model .run(model_input) .map_err(ModelThreadError::ModelError)?; - let response = U::from_model_output(model_output); + let response = Resp::from_model_output(model_output); sender.send(response).ok(); } @@ -93,27 +93,27 @@ where } } -pub struct ModelThreadDispatcher +pub struct ModelThreadDispatcher where - T: Request, - U: Response, + Req: Request, + Resp: Response, { - model_senders: HashMap>>, - pub(crate) responses: FuturesUnordered>, + model_senders: HashMap>>, + pub(crate) responses: FuturesUnordered>, } -impl ModelThreadDispatcher +impl ModelThreadDispatcher where - T: Clone + Request, - U: Response, + Req: Clone + Request, + Resp: Response, { pub(crate) fn start( config: ModelConfig, public_key: PublicKey, - ) -> Result<(Self, Vec>), ModelThreadError> + ) -> Result<(Self, Vec>), ModelThreadError> where F: ApiTrait, - M: ModelTrait + Send + 'static, + M: ModelTrait + Send + 'static, { let model_ids = config.model_ids(); let api_key = config.api_key(); @@ -159,7 +159,7 @@ where Ok((model_dispatcher, handles)) } - fn send(&self, command: ModelThreadCommand) { + fn send(&self, command: ModelThreadCommand) { let request = command.0.clone(); let model_type = request.requested_model(); diff --git a/atoma-inference/src/service.rs 
b/atoma-inference/src/service.rs index 848944f8..68866ca9 100644 --- a/atoma-inference/src/service.rs +++ b/atoma-inference/src/service.rs @@ -13,29 +13,29 @@ use crate::{ models::{config::ModelConfig, ModelTrait, Request, Response}, }; -pub struct ModelService +pub struct ModelService where - T: Request, - U: Response, + Req: Request, + Resp: Response, { - model_thread_handle: Vec>, - dispatcher: ModelThreadDispatcher, + model_thread_handle: Vec>, + dispatcher: ModelThreadDispatcher, start_time: Instant, - request_receiver: Receiver, + request_receiver: Receiver, } -impl ModelService +impl ModelService where - T: Clone + Request, - U: std::fmt::Debug + Response, + Req: Clone + Request, + Resp: std::fmt::Debug + Response, { pub fn start( config_file_path: PathBuf, private_key_path: PathBuf, - request_receiver: Receiver, + request_receiver: Receiver, ) -> Result where - M: ModelTrait + Send + 'static, + M: ModelTrait + Send + 'static, F: ApiTrait, { let private_key_bytes = @@ -61,7 +61,7 @@ where }) } - pub async fn run(&mut self) -> Result { + pub async fn run(&mut self) -> Result { loop { tokio::select! { message = self.request_receiver.recv() => { @@ -86,10 +86,10 @@ where } } -impl ModelService +impl ModelService where - T: Request, - U: Response, + Req: Request, + Resp: Response, { pub async fn stop(mut self) { info!( From 6193465ca4e48cabe53d413943dfe8814f5a4174 Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Sun, 31 Mar 2024 21:21:47 +0100 Subject: [PATCH 27/28] remove unused error fields --- atoma-inference/src/service.rs | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/atoma-inference/src/service.rs b/atoma-inference/src/service.rs index 68866ca9..5425bdfa 100644 --- a/atoma-inference/src/service.rs +++ b/atoma-inference/src/service.rs @@ -107,22 +107,16 @@ where #[derive(Debug, Error)] pub enum ModelServiceError { - #[error("Failed to connect to API: `{0}`")] - FailedApiConnection(ApiError), #[error("Failed to run inference: `{0}`")] - FailedInference(Box), + FailedInference(Box), #[error("Failed to fecth model: `{0}`")] FailedModelFetch(String), #[error("Failed to generate private key: `{0}`")] PrivateKeyError(io::Error), #[error("Core error: `{0}`")] ModelThreadError(ModelThreadError), - // #[error("Send error: `{0}`")] - // SendError(SendError<_>), #[error("Api error: `{0}`")] ApiError(ApiError), - #[error("Tokenizer error: `{0}`")] - TokenizerError(Box), #[error("Candle error: `{0}`")] CandleError(CandleError), } From 303aedf6959cf0de12af5293a57b05be6fd739ce Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Mon, 1 Apr 2024 10:30:25 +0100 Subject: [PATCH 28/28] removed unused Builder from ModelTrait associated type --- atoma-inference/src/models/mod.rs | 1 - atoma-inference/src/service.rs | 1 - 2 files changed, 2 deletions(-) diff --git a/atoma-inference/src/models/mod.rs b/atoma-inference/src/models/mod.rs index efc3978e..73e2f09c 100644 --- a/atoma-inference/src/models/mod.rs +++ b/atoma-inference/src/models/mod.rs @@ -14,7 +14,6 @@ pub trait ModelBuilder { } pub trait ModelTrait { - type Builder: Send + Sync + 'static; type Input; type Output; diff --git a/atoma-inference/src/service.rs b/atoma-inference/src/service.rs index 5425bdfa..a01bd5f4 100644 --- a/atoma-inference/src/service.rs +++ b/atoma-inference/src/service.rs @@ -191,7 +191,6 @@ mod tests { struct TestModelInstance {} impl ModelTrait for TestModelInstance { - type Builder = (); type Input = (); type Output = ();
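
For reference, below is a minimal, illustrative sketch of how the ModelTrait, Request and Response interface introduced over the last several patches fits together. It mirrors the TestModelInstance and the unit Request/Response impls from the service tests; the EchoModel, EchoRequest and EchoResponse names are placeholders of this sketch rather than types from the series, and generic parameters such as Vec<PathBuf> and Result<Self, ModelError> are reconstructed from context since they do not survive in the text above.

use std::path::PathBuf;

use ed25519_consensus::VerificationKey as PublicKey;
// Import paths follow main.rs and the service tests in this series.
use inference::models::{ModelError, ModelId, ModelTrait, Request, Response};

// A trivial model that echoes its prompt back; it stands in for a real
// Candle-backed model and ignores the safetensors paths handed to `load`.
struct EchoModel;

impl ModelTrait for EchoModel {
    type Input = String;
    type Output = String;

    fn load(_filenames: Vec<PathBuf>) -> Result<Self, ModelError> {
        Ok(Self)
    }

    fn model_id(&self) -> ModelId {
        "EchoModel".to_string()
    }

    fn run(&self, input: Self::Input) -> Result<Self::Output, ModelError> {
        Ok(input)
    }
}

// The request/response pair that ModelService and the dispatcher are generic over.
#[derive(Clone)]
struct EchoRequest {
    id: usize,
    prompt: String,
}

impl Request for EchoRequest {
    type ModelInput = String;

    fn into_model_input(self) -> Self::ModelInput {
        self.prompt
    }

    fn requested_model(&self) -> ModelId {
        "EchoModel".to_string()
    }

    fn request_id(&self) -> usize {
        self.id
    }

    // A real node would check that it was actually sampled for this request.
    fn is_node_authorized(&self, _public_key: &PublicKey) -> bool {
        true
    }
}

#[derive(Debug)]
struct EchoResponse(String);

impl Response for EchoResponse {
    type ModelOutput = String;

    fn from_model_output(output: Self::ModelOutput) -> Self {
        Self(output)
    }
}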
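
A companion sketch of wiring those placeholder types into ModelService, following the shape of the commented-out code in main.rs and the service test. The config and key paths are placeholders, the inference.toml would need to list the same model id that requested_model returns, and Api is the hf_hub sync client that this series implements ApiTrait for.

use hf_hub::api::sync::Api;
use inference::service::ModelService;

#[tokio::main]
async fn main() {
    tracing_subscriber::fmt::init();

    // Requests produced elsewhere (e.g. an RPC front end) are pushed into this channel.
    let (request_sender, request_receiver) = tokio::sync::mpsc::channel::<EchoRequest>(32);

    let mut service = ModelService::<EchoRequest, EchoResponse>::start::<EchoModel, Api>(
        "../inference.toml".parse().unwrap(),
        "../private_key".parse().unwrap(),
        request_receiver,
    )
    .expect("Failed to start model service");

    request_sender
        .send(EchoRequest { id: 0, prompt: "hello".to_string() })
        .await
        .unwrap();

    // `run` loops forever, multiplexing incoming requests and finished
    // inferences with tokio::select! and logging each response.
    if let Err(e) = service.run().await {
        eprintln!("Model service stopped with error: {e}");
    }
}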