Commit

add test to config construction
jorgeantonio21 committed Mar 27, 2024
1 parent e60b586 commit 58cfca5
Showing 4 changed files with 38 additions and 14 deletions.
30 changes: 27 additions & 3 deletions atoma-inference/src/config.rs
@@ -1,19 +1,19 @@
 use std::path::PathBuf;
 
 use config::Config;
-use serde::Deserialize;
+use serde::{Deserialize, Serialize};
 
 use crate::{models::ModelType, types::PrecisionBits};
 
-#[derive(Clone, Debug, Deserialize)]
+#[derive(Clone, Debug, Deserialize, Serialize)]
 pub struct ModelTokenizer {
     pub(crate) model_type: ModelType,
     pub(crate) tokenizer: PathBuf,
     pub(crate) precision: PrecisionBits,
     pub(crate) use_kv_cache: Option<bool>,
 }
 
-#[derive(Debug, Deserialize)]
+#[derive(Debug, Deserialize, Serialize)]
 pub struct InferenceConfig {
     api_key: String,
     models: Vec<ModelTokenizer>,
@@ -64,3 +64,27 @@ impl InferenceConfig
             .expect("Failed to generated config file")
     }
 }
+
+#[cfg(test)]
+pub mod tests {
+    use super::*;
+
+    #[test]
+    fn test_config() {
+        let config = InferenceConfig::new(
+            String::from("my_key"),
+            vec![ModelTokenizer {
+                model_type: ModelType::Llama2_7b,
+                tokenizer: "tokenizer".parse().unwrap(),
+                precision: PrecisionBits::BF16,
+                use_kv_cache: Some(true),
+            }],
+            "storage_folder".parse().unwrap(),
+            true,
+        );
+
+        let toml_str = toml::to_string(&config).unwrap();
+        let should_be_toml_str = "api_key = \"my_key\"\nstorage_folder = \"storage_folder\"\ntracing = true\n\n[[models]]\nmodel_type = \"Llama2_7b\"\ntokenizer = \"tokenizer\"\nprecision = \"BF16\"\nuse_kv_cache = true\n";
+        assert_eq!(toml_str, should_be_toml_str);
+    }
+}
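For reference, the TOML the test expects (the `should_be_toml_str` literal above, unescaped into its multi-line form for readability):

api_key = "my_key"
storage_folder = "storage_folder"
tracing = true

[[models]]
model_type = "Llama2_7b"
tokenizer = "tokenizer"
precision = "BF16"
use_kv_cache = true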
14 changes: 7 additions & 7 deletions atoma-inference/src/main.rs
@@ -5,14 +5,14 @@ use inference::service::InferenceService
 async fn main() {
     tracing_subscriber::fmt::init();
 
-    // let (_, receiver) = tokio::sync::mpsc::channel(32);
+    let (_, receiver) = tokio::sync::mpsc::channel(32);
 
-    // let _ = InferenceService::start::<Model>(
-    //     "../inference.toml".parse().unwrap(),
-    //     "../private_key".parse().unwrap(),
-    //     receiver,
-    // )
-    // .expect("Failed to start inference service");
+    let _ = InferenceService::start::<Model>(
+        "../inference.toml".parse().unwrap(),
+        "../private_key".parse().unwrap(),
+        receiver,
+    )
+    .expect("Failed to start inference service");
 
     // inference_service
     //     .run_inference(InferenceRequest {
4 changes: 2 additions & 2 deletions atoma-inference/src/models.rs
@@ -12,7 +12,7 @@ use candle_transformers::{
         stable_diffusion::StableDiffusionConfig,
     },
 };
-use serde::Deserialize;
+use serde::{Deserialize, Serialize};
 use thiserror::Error;
 
 use tokenizers::Tokenizer;
@@ -21,7 +21,7 @@ use crate::types::Temperature
 
 const EOS_TOKEN: &str = "</s>";
 
-#[derive(Clone, Debug, Deserialize, Eq, Hash, PartialEq)]
+#[derive(Clone, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)]
 pub enum ModelType {
     Llama2_7b,
     Mamba3b,
4 changes: 2 additions & 2 deletions atoma-inference/src/types.rs
@@ -1,7 +1,7 @@
 use crate::models::ModelType;
 use candle::DType;
 use ed25519_consensus::VerificationKey;
-use serde::Deserialize;
+use serde::{Deserialize, Serialize};
 
 pub type NodeId = VerificationKey;
 pub type Temperature = f32;
@@ -37,7 +37,7 @@ pub enum QuantizationMethod {
     Gptq(PrecisionBits),
 }
 
-#[derive(Copy, Clone, Debug, Deserialize)]
+#[derive(Copy, Clone, Debug, Deserialize, Serialize)]
 pub enum PrecisionBits {
     BF16,
     F16,
