From 670ee53ca012ad5bafef56aff9dc4485940e6eed Mon Sep 17 00:00:00 2001 From: Jorge Antonio Date: Thu, 2 Jan 2025 16:13:06 +0000 Subject: [PATCH] address PR comments --- .../src/handlers/chat_completions.rs | 7 +++++-- atoma-service/src/handlers/embeddings.rs | 7 +++++-- .../src/handlers/image_generations.rs | 7 +++++-- atoma-service/src/handlers/mod.rs | 2 +- atoma-service/src/middleware.rs | 16 +++++++-------- atoma-service/src/streamer.rs | 20 +++++++++---------- 6 files changed, 33 insertions(+), 26 deletions(-) diff --git a/atoma-service/src/handlers/chat_completions.rs b/atoma-service/src/handlers/chat_completions.rs index 435bf7bd..e26e7657 100644 --- a/atoma-service/src/handlers/chat_completions.rs +++ b/atoma-service/src/handlers/chat_completions.rs @@ -37,6 +37,9 @@ pub const CHAT_COMPLETIONS_PATH: &str = "/v1/chat/completions"; /// The keep-alive interval in seconds const STREAM_KEEP_ALIVE_INTERVAL_IN_SECONDS: u64 = 15; +/// The key for the model parameter in the request body +const MODEL_KEY: &str = "model"; + /// The key for the stream parameter in the request body const STREAM_KEY: &str = "stream"; @@ -512,7 +515,7 @@ async fn handle_non_streaming_response( ) -> Result, AtomaServiceError> { // Record token metrics and extract the response total number of tokens let model = payload - .get("model") + .get(MODEL_KEY) .and_then(|m| m.as_str()) .unwrap_or("unknown"); let timer = CHAT_COMPLETIONS_LATENCY_METRICS @@ -608,7 +611,7 @@ async fn handle_streaming_response( }); let model = payload - .get("model") + .get(MODEL_KEY) .and_then(|m| m.as_str()) .unwrap_or("unknown"); CHAT_COMPLETIONS_NUM_REQUESTS diff --git a/atoma-service/src/handlers/embeddings.rs b/atoma-service/src/handlers/embeddings.rs index 009c5fbd..183a9f0a 100644 --- a/atoma-service/src/handlers/embeddings.rs +++ b/atoma-service/src/handlers/embeddings.rs @@ -22,6 +22,9 @@ pub const CONFIDENTIAL_EMBEDDINGS_PATH: &str = "/v1/confidential/embeddings"; /// The path for embeddings requests pub const EMBEDDINGS_PATH: &str = "/v1/embeddings"; +/// The key for the model parameter in the request body +pub const MODEL_KEY: &str = "model"; + /// OpenAPI documentation structure for the embeddings endpoint. /// /// This struct defines the OpenAPI (Swagger) documentation for the embeddings API, @@ -72,7 +75,7 @@ pub async fn embeddings_handler( ) -> Result, AtomaServiceError> { info!("Received embeddings request, with payload: {payload}"); let model = payload - .get("model") + .get(MODEL_KEY) .and_then(|m| m.as_str()) .unwrap_or("unknown"); let RequestMetadata { @@ -181,7 +184,7 @@ pub async fn confidential_embeddings_handler( ) -> Result, AtomaServiceError> { info!("Received embeddings request, with payload: {payload}"); let model = payload - .get("model") + .get(MODEL_KEY) .and_then(|m| m.as_str()) .unwrap_or("unknown"); let RequestMetadata { diff --git a/atoma-service/src/handlers/image_generations.rs b/atoma-service/src/handlers/image_generations.rs index a7df513d..5e0be598 100644 --- a/atoma-service/src/handlers/image_generations.rs +++ b/atoma-service/src/handlers/image_generations.rs @@ -23,6 +23,9 @@ pub const CONFIDENTIAL_IMAGE_GENERATIONS_PATH: &str = "/v1/confidential/images/g /// The path for image generations requests pub const IMAGE_GENERATIONS_PATH: &str = "/v1/images/generations"; +/// The key for the model parameter in the request body +pub const MODEL_KEY: &str = "model"; + /// OpenAPI documentation structure for the image generations endpoint. /// /// This struct defines the OpenAPI (Swagger) documentation for the image generations API, @@ -73,7 +76,7 @@ pub async fn image_generations_handler( ) -> Result, AtomaServiceError> { info!("Received image generations request, with payload: {payload}"); let model = payload - .get("model") + .get(MODEL_KEY) .and_then(|m| m.as_str()) .unwrap_or("unknown"); @@ -182,7 +185,7 @@ pub async fn confidential_image_generations_handler( ) -> Result, AtomaServiceError> { info!("Received image generations request, with payload: {payload}"); let model = payload - .get("model") + .get(MODEL_KEY) .and_then(|m| m.as_str()) .unwrap_or("unknown"); diff --git a/atoma-service/src/handlers/mod.rs b/atoma-service/src/handlers/mod.rs index f0662616..49d798c5 100644 --- a/atoma-service/src/handlers/mod.rs +++ b/atoma-service/src/handlers/mod.rs @@ -33,7 +33,7 @@ const RESPONSE_HASH_KEY: &str = "response_hash"; const SIGNATURE_KEY: &str = "signature"; /// Key for the usage in the response body -const USAGE_KEY: &str = "usage"; +pub const USAGE_KEY: &str = "usage"; /// Updates response signature and stack hash state /// diff --git a/atoma-service/src/middleware.rs b/atoma-service/src/middleware.rs index 76b44488..c14cdc11 100644 --- a/atoma-service/src/middleware.rs +++ b/atoma-service/src/middleware.rs @@ -1135,14 +1135,14 @@ pub(crate) mod utils { endpoint: endpoint.to_string(), })?; let salt_bytes: [u8; SALT_SIZE] = salt_bytes.try_into().map_err(|e| { - AtomaServiceError::InvalidHeader { - message: format!( - "Failed to convert salt bytes to {SALT_SIZE}-byte array, incorrect length, with error: {:?}", - e - ), - endpoint: endpoint.to_string(), - } - })?; + AtomaServiceError::InvalidHeader { + message: format!( + "Failed to convert salt bytes to {SALT_SIZE}-byte array, incorrect length, with error: {:?}", + e + ), + endpoint: endpoint.to_string(), + } + })?; let nonce_bytes = STANDARD .decode(&confidential_compute_request.nonce) .map_err(|e| AtomaServiceError::InvalidHeader { diff --git a/atoma-service/src/streamer.rs b/atoma-service/src/streamer.rs index c272f5c0..e9007930 100644 --- a/atoma-service/src/streamer.rs +++ b/atoma-service/src/streamer.rs @@ -27,7 +27,7 @@ use crate::{ CHAT_COMPLETIONS_DECODING_TIME, CHAT_COMPLETIONS_INPUT_TOKENS_METRICS, CHAT_COMPLETIONS_OUTPUT_TOKENS_METRICS, }, - update_stack_num_compute_units, + update_stack_num_compute_units, USAGE_KEY, }, server::utils, }; @@ -44,9 +44,6 @@ const KEEP_ALIVE_CHUNK: &[u8] = b": keep-alive\n\n"; /// The choices key const CHOICES: &str = "choices"; -/// The usage key -const USAGE: &str = "usage"; - /// The ciphertext key const CIPHERTEXT_KEY: &str = "ciphertext"; @@ -59,9 +56,6 @@ const RESPONSE_HASH_KEY: &str = "response_hash"; /// The signature key const SIGNATURE_KEY: &str = "signature"; -/// The usage key -const USAGE_KEY: &str = "usage"; - /// Metadata required for encrypting streaming responses to clients. /// /// This structure contains the cryptographic elements needed to establish @@ -88,7 +82,7 @@ pub struct Streamer { /// The estimated total compute units for the request estimated_total_compute_units: i64, /// The request payload hash - payload_hash: [u8; 32], + payload_hash: [u8; PAYLOAD_HASH_SIZE], /// The sender for the state manager state_manager_sender: FlumeSender, /// The keystore @@ -130,7 +124,7 @@ impl Streamer { state_manager_sender: FlumeSender, stack_small_id: i64, estimated_total_compute_units: i64, - payload_hash: [u8; 32], + payload_hash: [u8; PAYLOAD_HASH_SIZE], keystore: Arc, address_index: usize, model: String, @@ -196,7 +190,11 @@ impl Streamer { payload_hash = hex::encode(self.payload_hash) ) )] - fn handle_final_chunk(&mut self, usage: &Value, response_hash: [u8; 32]) -> Result<(), Error> { + fn handle_final_chunk( + &mut self, + usage: &Value, + response_hash: [u8; PAYLOAD_HASH_SIZE], + ) -> Result<(), Error> { // Record the decoding phase timer if let Some(timer) = self.decoding_phase_timer.take() { timer.observe_duration(); @@ -453,7 +451,7 @@ impl Stream for Streamer { if choices.is_empty() { // Check if this is a final chunk with usage info - if let Some(usage) = chunk.get(USAGE) { + if let Some(usage) = chunk.get(USAGE_KEY) { self.status = StreamStatus::Completed; let mut chunk = if let Some(streaming_encryption_metadata) = self.streaming_encryption_metadata.as_ref()