From c8ab1adc4eff0969c640e21f8c1c581497388459 Mon Sep 17 00:00:00 2001 From: jorgeantonio21 Date: Tue, 2 Apr 2024 10:57:12 +0100 Subject: [PATCH] address PR comments --- Cargo.toml | 1 + atoma-inference/Cargo.toml | 1 + atoma-inference/src/models/config.rs | 31 ++++++++++++++++++++++++++++ 3 files changed, 33 insertions(+) diff --git a/Cargo.toml b/Cargo.toml index e37a9e35..c2ccab6a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,6 +13,7 @@ candle-flash-attn = { git = "https://github.com/huggingface/candle", package = " candle-nn = { git = "https://github.com/huggingface/candle", package = "candle-nn", version = "0.4.2" } candle-transformers = { git = "https://github.com/huggingface/candle", package = "candle-transformers", version = "0.4.2" } config = "0.14.0" +dotenv = "0.15.0" ed25519-consensus = "2.1.0" futures = "0.3.30" hf-hub = "0.3.2" diff --git a/atoma-inference/Cargo.toml b/atoma-inference/Cargo.toml index 87aeee0f..a9cbffc5 100644 --- a/atoma-inference/Cargo.toml +++ b/atoma-inference/Cargo.toml @@ -12,6 +12,7 @@ candle-flash-attn = { workspace = true, optional = true } candle-nn.workspace = true candle-transformers.workspace = true config.workspace = true +dotenv.workspace = true ed25519-consensus.workspace = true futures.workspace = true hf-hub.workspace = true diff --git a/atoma-inference/src/models/config.rs b/atoma-inference/src/models/config.rs index e17a3582..37ea846c 100644 --- a/atoma-inference/src/models/config.rs +++ b/atoma-inference/src/models/config.rs @@ -1,6 +1,7 @@ use std::path::PathBuf; use config::Config; +use dotenv::dotenv; use serde::{Deserialize, Serialize}; use crate::{models::types::PrecisionBits, models::ModelId}; @@ -64,6 +65,36 @@ impl ModelConfig { .try_deserialize::() .expect("Failed to generated config file") } + + pub fn from_env_file() -> Self { + dotenv().ok(); + + let api_key = std::env::var("API_KEY").expect("Failed to retrieve api key, from .env file"); + let flush_storage = std::env::var("FLUSH_STORAGE") + .expect("Failed to retrieve flush storage variable, from .env file") + .parse() + .unwrap(); + let models = serde_json::from_str( + &std::env::var("MODELS").expect("Failed to retrieve models metadata, from .env file"), + ) + .unwrap(); + let storage_path = std::env::var("STORAGE_PATH") + .expect("Failed to retrieve storage path, from .env file") + .parse() + .unwrap(); + let tracing = std::env::var("TRACING") + .expect("Failed to retrieve tracing variable, from .env file") + .parse() + .unwrap(); + + Self { + api_key, + flush_storage, + models, + storage_path, + tracing, + } + } } #[cfg(test)]