Skip to content

Commit

Permalink
address PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
jorgeantonio21 committed Jan 2, 2025
1 parent 21b0e82 commit 670ee53
Show file tree
Hide file tree
Showing 6 changed files with 33 additions and 26 deletions.
7 changes: 5 additions & 2 deletions atoma-service/src/handlers/chat_completions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ pub const CHAT_COMPLETIONS_PATH: &str = "/v1/chat/completions";
/// The keep-alive interval in seconds
const STREAM_KEEP_ALIVE_INTERVAL_IN_SECONDS: u64 = 15;

+ /// The key for the model parameter in the request body
+ const MODEL_KEY: &str = "model";

/// The key for the stream parameter in the request body
const STREAM_KEY: &str = "stream";

Expand Down Expand Up @@ -512,7 +515,7 @@ async fn handle_non_streaming_response(
) -> Result<Response<Body>, AtomaServiceError> {
// Record token metrics and extract the response total number of tokens
let model = payload
- .get("model")
+ .get(MODEL_KEY)
.and_then(|m| m.as_str())
.unwrap_or("unknown");
let timer = CHAT_COMPLETIONS_LATENCY_METRICS
Expand Down Expand Up @@ -608,7 +611,7 @@ async fn handle_streaming_response(
});

let model = payload
- .get("model")
+ .get(MODEL_KEY)
.and_then(|m| m.as_str())
.unwrap_or("unknown");
CHAT_COMPLETIONS_NUM_REQUESTS
Expand Down
7 changes: 5 additions & 2 deletions atoma-service/src/handlers/embeddings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ pub const CONFIDENTIAL_EMBEDDINGS_PATH: &str = "/v1/confidential/embeddings";
/// The path for embeddings requests
pub const EMBEDDINGS_PATH: &str = "/v1/embeddings";

+ /// The key for the model parameter in the request body
+ pub const MODEL_KEY: &str = "model";

/// OpenAPI documentation structure for the embeddings endpoint.
///
/// This struct defines the OpenAPI (Swagger) documentation for the embeddings API,
Expand Down Expand Up @@ -72,7 +75,7 @@ pub async fn embeddings_handler(
) -> Result<Json<Value>, AtomaServiceError> {
info!("Received embeddings request, with payload: {payload}");
let model = payload
- .get("model")
+ .get(MODEL_KEY)
.and_then(|m| m.as_str())
.unwrap_or("unknown");
let RequestMetadata {
Expand Down Expand Up @@ -181,7 +184,7 @@ pub async fn confidential_embeddings_handler(
) -> Result<Json<Value>, AtomaServiceError> {
info!("Received embeddings request, with payload: {payload}");
let model = payload
- .get("model")
+ .get(MODEL_KEY)
.and_then(|m| m.as_str())
.unwrap_or("unknown");
let RequestMetadata {
Expand Down
7 changes: 5 additions & 2 deletions atoma-service/src/handlers/image_generations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ pub const CONFIDENTIAL_IMAGE_GENERATIONS_PATH: &str = "/v1/confidential/images/g
/// The path for image generations requests
pub const IMAGE_GENERATIONS_PATH: &str = "/v1/images/generations";

+ /// The key for the model parameter in the request body
+ pub const MODEL_KEY: &str = "model";

/// OpenAPI documentation structure for the image generations endpoint.
///
/// This struct defines the OpenAPI (Swagger) documentation for the image generations API,
Expand Down Expand Up @@ -73,7 +76,7 @@ pub async fn image_generations_handler(
) -> Result<Json<Value>, AtomaServiceError> {
info!("Received image generations request, with payload: {payload}");
let model = payload
- .get("model")
+ .get(MODEL_KEY)
.and_then(|m| m.as_str())
.unwrap_or("unknown");

Expand Down Expand Up @@ -182,7 +185,7 @@ pub async fn confidential_image_generations_handler(
) -> Result<Json<Value>, AtomaServiceError> {
info!("Received image generations request, with payload: {payload}");
let model = payload
- .get("model")
+ .get(MODEL_KEY)
.and_then(|m| m.as_str())
.unwrap_or("unknown");

Expand Down
2 changes: 1 addition & 1 deletion atoma-service/src/handlers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ const RESPONSE_HASH_KEY: &str = "response_hash";
const SIGNATURE_KEY: &str = "signature";

/// Key for the usage in the response body
- const USAGE_KEY: &str = "usage";
+ pub const USAGE_KEY: &str = "usage";

/// Updates response signature and stack hash state
///
Expand Down
16 changes: 8 additions & 8 deletions atoma-service/src/middleware.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1135,14 +1135,14 @@ pub(crate) mod utils {
endpoint: endpoint.to_string(),
})?;
let salt_bytes: [u8; SALT_SIZE] = salt_bytes.try_into().map_err(|e| {
- AtomaServiceError::InvalidHeader {
- message: format!(
- "Failed to convert salt bytes to {SALT_SIZE}-byte array, incorrect length, with error: {:?}",
- e
- ),
- endpoint: endpoint.to_string(),
- }
- })?;
+ AtomaServiceError::InvalidHeader {
+ message: format!(
+ "Failed to convert salt bytes to {SALT_SIZE}-byte array, incorrect length, with error: {:?}",
+ e
+ ),
+ endpoint: endpoint.to_string(),
+ }
+ })?;
let nonce_bytes = STANDARD
.decode(&confidential_compute_request.nonce)
.map_err(|e| AtomaServiceError::InvalidHeader {
Expand Down
20 changes: 9 additions & 11 deletions atoma-service/src/streamer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use crate::{
CHAT_COMPLETIONS_DECODING_TIME, CHAT_COMPLETIONS_INPUT_TOKENS_METRICS,
CHAT_COMPLETIONS_OUTPUT_TOKENS_METRICS,
},
- update_stack_num_compute_units,
+ update_stack_num_compute_units, USAGE_KEY,
},
server::utils,
};
Expand All @@ -44,9 +44,6 @@ const KEEP_ALIVE_CHUNK: &[u8] = b": keep-alive\n\n";
/// The choices key
const CHOICES: &str = "choices";

- /// The usage key
- const USAGE: &str = "usage";

/// The ciphertext key
const CIPHERTEXT_KEY: &str = "ciphertext";

Expand All @@ -59,9 +56,6 @@ const RESPONSE_HASH_KEY: &str = "response_hash";
/// The signature key
const SIGNATURE_KEY: &str = "signature";

- /// The usage key
- const USAGE_KEY: &str = "usage";

/// Metadata required for encrypting streaming responses to clients.
///
/// This structure contains the cryptographic elements needed to establish
Expand All @@ -88,7 +82,7 @@ pub struct Streamer {
/// The estimated total compute units for the request
estimated_total_compute_units: i64,
/// The request payload hash
- payload_hash: [u8; 32],
+ payload_hash: [u8; PAYLOAD_HASH_SIZE],
/// The sender for the state manager
state_manager_sender: FlumeSender<AtomaAtomaStateManagerEvent>,
/// The keystore
Expand Down Expand Up @@ -130,7 +124,7 @@ impl Streamer {
state_manager_sender: FlumeSender<AtomaAtomaStateManagerEvent>,
stack_small_id: i64,
estimated_total_compute_units: i64,
- payload_hash: [u8; 32],
+ payload_hash: [u8; PAYLOAD_HASH_SIZE],
keystore: Arc<FileBasedKeystore>,
address_index: usize,
model: String,
Expand Down Expand Up @@ -196,7 +190,11 @@ impl Streamer {
payload_hash = hex::encode(self.payload_hash)
)
)]
- fn handle_final_chunk(&mut self, usage: &Value, response_hash: [u8; 32]) -> Result<(), Error> {
+ fn handle_final_chunk(
+ &mut self,
+ usage: &Value,
+ response_hash: [u8; PAYLOAD_HASH_SIZE],
+ ) -> Result<(), Error> {
// Record the decoding phase timer
if let Some(timer) = self.decoding_phase_timer.take() {
timer.observe_duration();
Expand Down Expand Up @@ -453,7 +451,7 @@ impl Stream for Streamer {

if choices.is_empty() {
// Check if this is a final chunk with usage info
- if let Some(usage) = chunk.get(USAGE) {
+ if let Some(usage) = chunk.get(USAGE_KEY) {
self.status = StreamStatus::Completed;
let mut chunk = if let Some(streaming_encryption_metadata) =
self.streaming_encryption_metadata.as_ref()
Expand Down

0 comments on commit 670ee53

Please sign in to comment.