From 58cfca5a125d2b3788d2f1c94932e679074d83b2 Mon Sep 17 00:00:00 2001
From: jorgeantonio21
Date: Wed, 27 Mar 2024 19:41:30 +0000
Subject: [PATCH] add test to config construction

---
 atoma-inference/src/config.rs | 30 +++++++++++++++++++++++++++---
 atoma-inference/src/main.rs   | 14 +++++++-------
 atoma-inference/src/models.rs |  4 ++--
 atoma-inference/src/types.rs  |  4 ++--
 4 files changed, 38 insertions(+), 14 deletions(-)

diff --git a/atoma-inference/src/config.rs b/atoma-inference/src/config.rs
index b8313d95..2c33090b 100644
--- a/atoma-inference/src/config.rs
+++ b/atoma-inference/src/config.rs
@@ -1,11 +1,11 @@
 use std::path::PathBuf;
 
 use config::Config;
-use serde::Deserialize;
+use serde::{Deserialize, Serialize};
 
 use crate::{models::ModelType, types::PrecisionBits};
 
-#[derive(Clone, Debug, Deserialize)]
+#[derive(Clone, Debug, Deserialize, Serialize)]
 pub struct ModelTokenizer {
     pub(crate) model_type: ModelType,
     pub(crate) tokenizer: PathBuf,
@@ -13,7 +13,7 @@ pub struct ModelTokenizer {
     pub(crate) use_kv_cache: Option<bool>,
 }
 
-#[derive(Debug, Deserialize)]
+#[derive(Debug, Deserialize, Serialize)]
 pub struct InferenceConfig {
     api_key: String,
     models: Vec<ModelTokenizer>,
@@ -64,3 +64,27 @@ impl InferenceConfig {
             .expect("Failed to generated config file")
     }
 }
+
+#[cfg(test)]
+pub mod tests {
+    use super::*;
+
+    #[test]
+    fn test_config() {
+        let config = InferenceConfig::new(
+            String::from("my_key"),
+            vec![ModelTokenizer {
+                model_type: ModelType::Llama2_7b,
+                tokenizer: "tokenizer".parse().unwrap(),
+                precision: PrecisionBits::BF16,
+                use_kv_cache: Some(true),
+            }],
+            "storage_folder".parse().unwrap(),
+            true,
+        );
+
+        let toml_str = toml::to_string(&config).unwrap();
+        let should_be_toml_str = "api_key = \"my_key\"\nstorage_folder = \"storage_folder\"\ntracing = true\n\n[[models]]\nmodel_type = \"Llama2_7b\"\ntokenizer = \"tokenizer\"\nprecision = \"BF16\"\nuse_kv_cache = true\n";
+        assert_eq!(toml_str, should_be_toml_str);
+    }
+}
diff --git a/atoma-inference/src/main.rs b/atoma-inference/src/main.rs
index 7e2e1e8f..8c8d58a2 100644
--- a/atoma-inference/src/main.rs
+++ b/atoma-inference/src/main.rs
@@ -5,14 +5,14 @@ use inference::service::InferenceService;
 async fn main() {
     tracing_subscriber::fmt::init();
 
-    // let (_, receiver) = tokio::sync::mpsc::channel(32);
+    let (_, receiver) = tokio::sync::mpsc::channel(32);
 
-    // let _ = InferenceService::start::(
-    //     "../inference.toml".parse().unwrap(),
-    //     "../private_key".parse().unwrap(),
-    //     receiver,
-    // )
-    // .expect("Failed to start inference service");
+    let _ = InferenceService::start::(
+        "../inference.toml".parse().unwrap(),
+        "../private_key".parse().unwrap(),
+        receiver,
+    )
+    .expect("Failed to start inference service");
 
     // inference_service
     //     .run_inference(InferenceRequest {
diff --git a/atoma-inference/src/models.rs b/atoma-inference/src/models.rs
index 706b76eb..e42a95cb 100644
--- a/atoma-inference/src/models.rs
+++ b/atoma-inference/src/models.rs
@@ -12,7 +12,7 @@ use candle_transformers::{
         stable_diffusion::StableDiffusionConfig,
     },
 };
-use serde::Deserialize;
+use serde::{Deserialize, Serialize};
 use thiserror::Error;
 use tokenizers::Tokenizer;
 
@@ -21,7 +21,7 @@ use crate::types::Temperature;
 
 const EOS_TOKEN: &str = "";
 
-#[derive(Clone, Debug, Deserialize, Eq, Hash, PartialEq)]
+#[derive(Clone, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)]
 pub enum ModelType {
     Llama2_7b,
     Mamba3b,
diff --git a/atoma-inference/src/types.rs b/atoma-inference/src/types.rs
index c80506f5..bd417a66 100644
--- a/atoma-inference/src/types.rs
+++ b/atoma-inference/src/types.rs
@@ -1,7 +1,7 @@
 use crate::models::ModelType;
 use candle::DType;
 use ed25519_consensus::VerificationKey;
-use serde::Deserialize;
+use serde::{Deserialize, Serialize};
 
 pub type NodeId = VerificationKey;
 pub type Temperature = f32;
@@ -37,7 +37,7 @@ pub enum QuantizationMethod {
     Gptq(PrecisionBits),
 }
 
-#[derive(Copy, Clone, Debug, Deserialize)]
+#[derive(Copy, Clone, Debug, Deserialize, Serialize)]
 pub enum PrecisionBits {
     BF16,
     F16,
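
Note (not part of the patch): the escaped should_be_toml_str asserted by the new test in config.rs corresponds to the following TOML document, shown unescaped here for readability; the serialization relies on the Serialize derives this patch adds in config.rs, models.rs, and types.rs:

    api_key = "my_key"
    storage_folder = "storage_folder"
    tracing = true

    [[models]]
    model_type = "Llama2_7b"
    tokenizer = "tokenizer"
    precision = "BF16"
    use_kv_cache = true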