Skip to content

Commit

Permalink
change way we handle stream chunks
Browse files Browse the repository at this point in the history
  • Loading branch information
jorgeantonio21 committed Jan 2, 2025
1 parent 16ef39c commit f0b5004
Show file tree
Hide file tree
Showing 5 changed files with 118 additions and 53 deletions.
13 changes: 6 additions & 7 deletions atoma-confidential/src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,7 @@ impl AtomaConfidentialComputeService {
name = "handle_decryption_request",
skip_all,
fields(
client_public_key = ?decryption_request.client_x25519_public_key,
client_public_key = ?decryption_request.client_dh_public_key,
node_public_key = ?self.key_manager.get_public_key().as_bytes()
)
)]
Expand All @@ -328,24 +328,23 @@ impl AtomaConfidentialComputeService {
ciphertext,
nonce,
salt,
client_x25519_public_key,
node_x25519_public_key,
client_dh_public_key,
node_dh_public_key,
} = decryption_request;
let result = if PublicKey::from(node_x25519_public_key) != self.key_manager.get_public_key()
{
let result = if PublicKey::from(node_dh_public_key) != self.key_manager.get_public_key() {
tracing::error!(
target = "atoma-confidential-compute-service",
event = "confidential_compute_service_decryption_error",
"Node X25519 public key does not match the expected key: {:?} != {:?}",
node_x25519_public_key,
node_dh_public_key,
self.key_manager.get_public_key().as_bytes()
);
Err(anyhow::anyhow!(
"Node X25519 public key does not match the expected key"
))
} else {
self.key_manager
.decrypt_ciphertext(client_x25519_public_key, &ciphertext, &salt, &nonce)
.decrypt_ciphertext(client_dh_public_key, &ciphertext, &salt, &nonce)
.map_err(|e| {
tracing::error!(
target = "atoma-confidential-compute-service",
Expand Down
4 changes: 2 additions & 2 deletions atoma-confidential/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ pub struct ConfidentialComputeDecryptionRequest {
pub salt: [u8; SALT_SIZE],

/// Public key component for Diffie-Hellman key exchange from the client
pub client_x25519_public_key: [u8; DH_PUBLIC_KEY_SIZE],
pub client_dh_public_key: [u8; DH_PUBLIC_KEY_SIZE],

/// Public key component for Diffie-Hellman key exchange from the node
pub node_x25519_public_key: [u8; DH_PUBLIC_KEY_SIZE],
pub node_dh_public_key: [u8; DH_PUBLIC_KEY_SIZE],
}

/// Response containing the decrypted data from a confidential computation request
Expand Down
14 changes: 7 additions & 7 deletions atoma-service/src/middleware.rs
Original file line number Diff line number Diff line change
Expand Up @@ -611,7 +611,7 @@ pub async fn confidential_compute_middleware(
{
Ok(DecryptionMetadata {
plaintext,
client_dh_public_key: client_x25519_public_key_bytes,
client_dh_public_key: client_dh_public_key_bytes,
salt: salt_bytes,
}) => {
utils::check_plaintext_body_hash(plaintext_body_hash_bytes, &plaintext, &endpoint)?;
Expand All @@ -623,7 +623,7 @@ pub async fn confidential_compute_middleware(
.unwrap_or_default();
req_parts.extensions.insert(
request_metadata
.with_client_encryption_metadata(client_x25519_public_key_bytes, salt_bytes)
.with_client_encryption_metadata(client_dh_public_key_bytes, salt_bytes)
.with_payload_hash(plaintext_body_hash_bytes),
);
let stack_small_id = confidential_compute_request.stack_small_id;
Expand Down Expand Up @@ -1158,7 +1158,7 @@ pub(crate) mod utils {
endpoint: endpoint.to_string(),
}
})?;
let client_x25519_public_key_bytes: [u8; DH_PUBLIC_KEY_SIZE] = STANDARD
let client_dh_public_key_bytes: [u8; DH_PUBLIC_KEY_SIZE] = STANDARD
.decode(&confidential_compute_request.client_dh_public_key)
.map_err(|e| {
AtomaServiceError::InvalidHeader {
Expand All @@ -1176,7 +1176,7 @@ pub(crate) mod utils {
endpoint: endpoint.to_string(),
}
})?;
let node_x25519_public_key_bytes: [u8; DH_PUBLIC_KEY_SIZE] = STANDARD
let node_dh_public_key_bytes: [u8; DH_PUBLIC_KEY_SIZE] = STANDARD
.decode(&confidential_compute_request.node_dh_public_key)
.map_err(|e| {
AtomaServiceError::InvalidHeader {
Expand Down Expand Up @@ -1206,8 +1206,8 @@ pub(crate) mod utils {
ciphertext: ciphertext_bytes,
nonce: nonce_bytes,
salt: salt_bytes,
client_x25519_public_key: client_x25519_public_key_bytes,
node_x25519_public_key: node_x25519_public_key_bytes,
client_dh_public_key: client_dh_public_key_bytes,
node_dh_public_key: node_dh_public_key_bytes,
};
let (result_sender, result_receiver) = oneshot::channel();
state
Expand All @@ -1233,7 +1233,7 @@ pub(crate) mod utils {
.plaintext;
Ok(DecryptionMetadata {
plaintext,
client_dh_public_key: client_x25519_public_key_bytes,
client_dh_public_key: client_dh_public_key_bytes,
salt: salt_bytes,
})
}
Expand Down
128 changes: 97 additions & 31 deletions atoma-service/src/streamer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@ use std::{

use atoma_state::types::AtomaAtomaStateManagerEvent;
use atoma_utils::{
constants::{NONCE_SIZE, SALT_SIZE},
constants::{NONCE_SIZE, PAYLOAD_HASH_SIZE, SALT_SIZE},
encryption::encrypt_plaintext,
hashing::blake2b_hash,
};
use axum::body::Bytes;
use axum::{response::sse::Event, Error};
use base64::{engine::general_purpose::STANDARD, Engine};
use flume::Sender as FlumeSender;
use futures::Stream;
use prometheus::HistogramTimer;
Expand All @@ -33,15 +34,34 @@ use crate::{

/// The chunk that indicates the end of a streaming response
const DONE_CHUNK: &str = "[DONE]";

/// The prefix for the data chunk
const DATA_PREFIX: &str = "data: ";

/// The keep-alive chunk
const KEEP_ALIVE_CHUNK: &[u8] = b": keep-alive\n\n";

/// The choices key
const CHOICES: &str = "choices";

/// The usage key
const USAGE: &str = "usage";

/// The ciphertext key
const CIPHERTEXT_KEY: &str = "ciphertext";

/// The nonce key
const NONCE_KEY: &str = "nonce";

/// The response hash key
const RESPONSE_HASH_KEY: &str = "response_hash";

/// The signature key
const SIGNATURE_KEY: &str = "signature";

/// The usage key
const USAGE_KEY: &str = "usage";

/// Metadata required for encrypting streaming responses to clients.
///
/// This structure contains the cryptographic elements needed to establish
Expand Down Expand Up @@ -176,23 +196,12 @@ impl Streamer {
payload_hash = hex::encode(self.payload_hash)
)
)]
fn handle_final_chunk(&mut self, usage: &Value) -> Result<String, Error> {
fn handle_final_chunk(&mut self, usage: &Value, response_hash: [u8; 32]) -> Result<(), Error> {
// Record the decoding phase timer
if let Some(timer) = self.decoding_phase_timer.take() {
timer.observe_duration();
}

// Sign the accumulated response
let (response_hash, signature) = utils::sign_response_body(
&json!(self.accumulated_response),
&self.keystore,
self.address_index,
)
.map_err(|e| {
error!("Error signing response: {}", e);
Error::new(format!("Error signing response: {}", e))
})?;

// Get total tokens
let mut total_compute_units = 0;
if let Some(prompt_tokens) = usage.get("prompt_tokens") {
Expand Down Expand Up @@ -262,7 +271,32 @@ impl Streamer {
error!("Error updating stack num tokens: {}", e);
}

Ok(signature)
Ok(())
}

/// Signs the accumulated response
/// This is used when the streaming is complete and we need to send the final chunk back to the client
/// with the signature and response hash
///
/// # Returns
///
/// Returns a tuple containing:
/// * A base64-encoded string of the signature
/// * A base64-encoded string of the response hash
#[instrument(level = "debug", skip_all)]
pub fn sign_final_chunk(&mut self) -> Result<(String, [u8; PAYLOAD_HASH_SIZE]), Error> {
// Sign the accumulated response
let (response_hash, signature) = utils::sign_response_body(
&json!(self.accumulated_response),
&self.keystore,
self.address_index,
)
.map_err(|e| {
error!("Error signing response: {}", e);
Error::new(format!("Error signing response: {}", e))
})?;

Ok((signature, response_hash))
}

/// Handles the encryption request for a chunk of streaming data.
Expand Down Expand Up @@ -291,19 +325,33 @@ impl Streamer {
#[instrument(level = "debug", skip_all)]
fn handle_encryption_request(
chunk: &Value,
usage: Option<&Value>,
streaming_encryption_metadata: &StreamingEncryptionMetadata,
) -> Result<Value, Error> {
let StreamingEncryptionMetadata {
shared_secret,
nonce,
salt,
} = streaming_encryption_metadata;
let (encrypted_chunk, nonce) = encrypt_plaintext(
chunk.to_string().as_bytes(),
shared_secret,
salt,
Some(*nonce),
)
// NOTE: We remove the usage key from the chunk before encryption
// because we need to send the usage key back to the client in the final chunk
let (encrypted_chunk, nonce) = if usage.is_some() {
let mut chunk = chunk.clone();
chunk.as_object_mut().map(|obj| obj.remove(USAGE_KEY));
encrypt_plaintext(
chunk.to_string().as_bytes(),
shared_secret,
salt,
Some(*nonce),
)
} else {
encrypt_plaintext(
chunk.to_string().as_bytes(),
shared_secret,
salt,
Some(*nonce),
)
}
.map_err(|e| {
error!(
target = "atoma-service",
Expand All @@ -313,10 +361,18 @@ impl Streamer {
);
Error::new(format!("Error encrypting chunk: {}", e))
})?;
Ok(json!({
"ciphertext": encrypted_chunk,
"nonce": nonce,
}))
if let Some(usage) = usage {
Ok(json!({
CIPHERTEXT_KEY: encrypted_chunk,
NONCE_KEY: nonce,
USAGE_KEY: usage.clone(),
}))
} else {
Ok(json!({
CIPHERTEXT_KEY: encrypted_chunk,
NONCE_KEY: nonce,
}))
}
}
}

Expand Down Expand Up @@ -360,7 +416,7 @@ impl Stream for Streamer {
self.status = StreamStatus::Completed;
return Poll::Ready(None);
}
let mut chunk = serde_json::from_str::<Value>(chunk_str).map_err(|e| {
let chunk = serde_json::from_str::<Value>(chunk_str).map_err(|e| {
error!(
target = "atoma-service",
level = "error",
Expand Down Expand Up @@ -399,16 +455,22 @@ impl Stream for Streamer {
// Check if this is a final chunk with usage info
if let Some(usage) = chunk.get(USAGE) {
self.status = StreamStatus::Completed;
let signature = self.handle_final_chunk(usage)?;
chunk["signature"] = json!(signature);
let chunk = if let Some(streaming_encryption_metadata) =
let mut chunk = if let Some(streaming_encryption_metadata) =
self.streaming_encryption_metadata.as_ref()
{
// NOTE: We only need to perform chunk encryption when sending the chunk back to the client
Self::handle_encryption_request(&chunk, streaming_encryption_metadata)?
Self::handle_encryption_request(
&chunk,
Some(usage),
streaming_encryption_metadata,
)?
} else {
chunk
chunk.clone()
};
let (signature, response_hash) = self.sign_final_chunk()?;
chunk[SIGNATURE_KEY] = json!(signature);
chunk[RESPONSE_HASH_KEY] = json!(STANDARD.encode(response_hash));
self.handle_final_chunk(usage, response_hash)?;
Poll::Ready(Some(Ok(Event::default().json_data(&chunk)?)))
} else {
error!(
Expand All @@ -426,7 +488,11 @@ impl Stream for Streamer {
self.streaming_encryption_metadata.as_ref()
{
// NOTE: We only need to perform chunk encryption when sending the chunk back to the client
Self::handle_encryption_request(&chunk, streaming_encryption_metadata)?
Self::handle_encryption_request(
&chunk,
None,
streaming_encryption_metadata,
)?
} else {
chunk
};
Expand Down
12 changes: 6 additions & 6 deletions atoma-service/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1237,7 +1237,7 @@ mod middleware {
let client_dh_private_key = x25519_dalek::StaticSecret::random_from_rng(rand::thread_rng());
let client_dh_public_key = x25519_dalek::PublicKey::from(&client_dh_private_key);

let blake2b_hash: [u8; 32] = blake2b_hash(TEST_MESSAGE.as_bytes()).try_into().unwrap();
let blake2b_hash: [u8; 32] = blake2b_hash(TEST_MESSAGE.as_bytes()).into();
let shared_secret = client_dh_private_key.diffie_hellman(&server_dh_public_key);
let (encrypted_data, nonce) =
encrypt_plaintext(TEST_MESSAGE.as_bytes(), &shared_secret, &salt, None)
Expand Down Expand Up @@ -1293,7 +1293,7 @@ mod middleware {
let client_dh_private_key = x25519_dalek::StaticSecret::random_from_rng(rand::thread_rng());
let client_dh_public_key = x25519_dalek::PublicKey::from(&client_dh_private_key);

let blake2b_hash: [u8; 32] = blake2b_hash(TEST_MESSAGE.as_bytes()).try_into().unwrap();
let blake2b_hash: [u8; 32] = blake2b_hash(TEST_MESSAGE.as_bytes()).into();
let shared_secret = client_dh_private_key.diffie_hellman(&server_dh_public_key);
let (encrypted_data, nonce) =
encrypt_plaintext(TEST_MESSAGE.as_bytes(), &shared_secret, &salt, None)
Expand Down Expand Up @@ -1343,7 +1343,7 @@ mod middleware {
let client_dh_private_key = x25519_dalek::StaticSecret::random_from_rng(rand::thread_rng());
let client_dh_public_key = x25519_dalek::PublicKey::from(&client_dh_private_key);

let blake2b_hash: [u8; 32] = blake2b_hash(TEST_MESSAGE.as_bytes()).try_into().unwrap();
let blake2b_hash: [u8; 32] = blake2b_hash(TEST_MESSAGE.as_bytes()).into();
let encrypted_body_json = json!({
"ciphertext": STANDARD.encode([0u8; 32]),
"salt": STANDARD.encode(salt),
Expand Down Expand Up @@ -1418,7 +1418,7 @@ mod middleware {
let client_dh_private_key = x25519_dalek::StaticSecret::random_from_rng(rand::thread_rng());
let client_dh_public_key = x25519_dalek::PublicKey::from(&client_dh_private_key);

let blake2b_hash: [u8; 32] = blake2b_hash(TEST_MESSAGE.as_bytes()).try_into().unwrap();
let blake2b_hash: [u8; 32] = blake2b_hash(TEST_MESSAGE.as_bytes()).into();

let encrypted_body_json = json!({
"ciphertext": "x".repeat(2 * 1024 * 1024),
Expand Down Expand Up @@ -1471,7 +1471,7 @@ mod middleware {

// Create incorrect hash (hash of different data)
let incorrect_plaintext = "different data".as_bytes();
let incorrect_hash: [u8; 32] = blake2b_hash(incorrect_plaintext).try_into().unwrap();
let incorrect_hash: [u8; 32] = blake2b_hash(incorrect_plaintext).into();

let shared_secret = client_dh_private_key.diffie_hellman(&server_dh_public_key);
let (encrypted_data, nonce) =
Expand Down Expand Up @@ -1527,7 +1527,7 @@ mod middleware {
let client_dh_private_key = x25519_dalek::StaticSecret::random_from_rng(rand::thread_rng());
let client_dh_public_key = x25519_dalek::PublicKey::from(&client_dh_private_key);

let blake2b_hash: [u8; 32] = blake2b_hash(TEST_MESSAGE.as_bytes()).try_into().unwrap();
let blake2b_hash: [u8; 32] = blake2b_hash(TEST_MESSAGE.as_bytes()).into();
let shared_secret = client_dh_private_key.diffie_hellman(&server_dh_public_key);
let (encrypted_data, nonce) =
encrypt_plaintext(TEST_MESSAGE.as_bytes(), &shared_secret, &salt, None)
Expand Down

0 comments on commit f0b5004

Please sign in to comment.