diff --git a/Cargo.toml b/Cargo.toml
index fc2b0a3a..1758ec34 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -14,6 +14,7 @@ members = [
 version = "0.1.0"
 
 [workspace.dependencies]
+reqwest = "0.12.1"
 anyhow = "1.0.81"
 async-trait = "0.1.78"
 candle = { git = "https://github.com/huggingface/candle", package = "candle-core", version = "0.4.2" }
@@ -33,7 +34,6 @@ image = { version = "0.25.0", default-features = false, features = [
 serde = "1.0.197"
 serde_json = "1.0.114"
 rand = "0.8.5"
-reqwest = "0.12.1"
 thiserror = "1.0.58"
 tokenizers = "0.15.2"
 tokio = "1.36.0"
diff --git a/atoma-inference/src/candle/stable_diffusion.rs b/atoma-inference/src/candle/stable_diffusion.rs
index c8718e4c..46dadbc1 100644
--- a/atoma-inference/src/candle/stable_diffusion.rs
+++ b/atoma-inference/src/candle/stable_diffusion.rs
@@ -43,9 +43,6 @@ pub struct Input {
     sd_version: StableDiffusionVersion,
 
-    /// Generate intermediary images at each step.
-    intermediary_images: bool,
-
     use_flash_attn: bool,
 
     use_f16: bool,
 
@@ -78,7 +75,6 @@ impl Input {
            n_steps: Some(20),
            num_samples: 1,
            sd_version: StableDiffusionVersion::V1_5,
-            intermediary_images: false,
            use_flash_attn: false,
            use_f16: true,
            guidance_scale: None,
@@ -315,6 +311,7 @@ impl CandleModel for StableDiffusion {
     }
 }
 
+#[allow(dead_code)]
 #[derive(Clone, Copy)]
 enum StableDiffusionVersion {
     V1_5,
diff --git a/atoma-inference/src/lib.rs b/atoma-inference/src/lib.rs
index c6f12e7a..80915305 100644
--- a/atoma-inference/src/lib.rs
+++ b/atoma-inference/src/lib.rs
@@ -1,10 +1,6 @@
-pub mod model_thread;
+pub mod apis;
 pub mod candle;
-pub mod config;
-pub mod core_thread;
+pub mod model_thread;
 pub mod models;
 pub mod service;
 pub mod specs;
-
-pub mod apis;
-pub mod models;
diff --git a/atoma-inference/src/models/candle/mamba.rs b/atoma-inference/src/models/candle/mamba.rs
index 284e25e4..65ed9835 100644
--- a/atoma-inference/src/models/candle/mamba.rs
+++ b/atoma-inference/src/models/candle/mamba.rs
@@ -64,8 +64,7 @@ impl ModelTrait for MambaModel {
         let tokenizer_filename = filenames[1].clone();
         let weights_filenames = filenames[2..].to_vec();
 
-        let tokenizer =
-            Tokenizer::from_file(tokenizer_filename).map_err(ModelError::TokenizerError)?;
+        let tokenizer = Tokenizer::from_file(tokenizer_filename)?;
 
         let config: Config =
             serde_json::from_slice(&std::fs::read(config_filename).map_err(ModelError::IoError)?)
@@ -110,8 +109,7 @@ impl ModelTrait for MambaModel {
         let mut tokens = self
             .tokenizer
             .tokenizer()
-            .encode(prompt, true)
-            .map_err(ModelError::TokenizerError)?
+            .encode(prompt, true)?
             .get_ids()
             .to_vec();
         let mut logits_processor =
diff --git a/atoma-inference/src/models/mod.rs b/atoma-inference/src/models/mod.rs
index dc82f4c1..724a5d42 100644
--- a/atoma-inference/src/models/mod.rs
+++ b/atoma-inference/src/models/mod.rs
@@ -41,22 +41,22 @@ pub trait Response: Send + 'static {
 #[derive(Debug, Error)]
 pub enum ModelError {
-    #[error("Tokenizer error: `{0}`")]
-    TokenizerError(Box<dyn std::error::Error + Send + Sync>),
-    #[error("IO error: `{0}`")]
-    IoError(std::io::Error),
     #[error("Deserialize error: `{0}`")]
-    DeserializeError(serde_json::Error),
-    #[error("Candle error: `{0}`")]
-    CandleError(CandleError),
+    DeserializeError(#[from] serde_json::Error),
     #[error("{0}")]
     Msg(String),
-}
-
-impl From<CandleError> for ModelError {
-    fn from(error: CandleError) -> Self {
-        Self::CandleError(error)
-    }
+    #[error("Candle error: `{0}`")]
+    CandleError(#[from] CandleError),
+    #[error("Config error: `{0}`")]
+    Config(String),
+    #[error("Image error: `{0}`")]
+    ImageError(#[from] image::ImageError),
+    #[error("Io error: `{0}`")]
+    IoError(#[from] std::io::Error),
+    #[error("Error: `{0}`")]
+    BoxedError(#[from] Box<dyn std::error::Error + Send + Sync>),
+    #[error("ApiError error: `{0}`")]
+    ApiError(#[from] hf_hub::api::sync::ApiError),
 }
 
 #[macro_export]