Skip to content

Commit

Permalink
feat: add blockchain query for newly created stacks that are requested to the atoma service (#248)
Browse files Browse the repository at this point in the history

* first commit

* update state manager documents

* address PR comments

* address PR comments

* address additional PR comments

* add check constraints and unit tests
  • Loading branch information
jorgeantonio21 authored Nov 27, 2024
1 parent 1960417 commit 28d875f
Show file tree
Hide file tree
Showing 12 changed files with 479 additions and 51 deletions.
3 changes: 3 additions & 0 deletions atoma-bin/atoma_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ async fn main() -> Result<()> {
shutdown_sender.clone(),
);

let (stack_retrieve_sender, stack_retrieve_receiver) = tokio::sync::mpsc::unbounded_channel();
let package_id = config.sui.atoma_package_id();
info!(
target = "atoma-node-service",
Expand All @@ -209,6 +210,7 @@ async fn main() -> Result<()> {
let subscriber = SuiEventSubscriber::new(
config.sui,
event_subscriber_sender,
stack_retrieve_receiver,
shutdown_receiver.clone(),
);

Expand Down Expand Up @@ -245,6 +247,7 @@ async fn main() -> Result<()> {

let app_state = AppState {
state_manager_sender,
stack_retrieve_sender,
tokenizers: Arc::new(tokenizers),
models: Arc::new(config.service.models),
chat_completions_service_url: config
Expand Down
4 changes: 1 addition & 3 deletions atoma-service/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,5 @@ utoipa-swagger-ui = { workspace = true, features = ["axum"] }
[dev-dependencies]
rand = { workspace = true }
serial_test = { workspace = true }
tempfile = { workspace = true }

[dev_dependencies]
sqlx = { workspace = true, features = ["runtime-tokio", "postgres"] }
tempfile = { workspace = true }
85 changes: 82 additions & 3 deletions atoma-service/src/middleware.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use serde_json::Value;
use sui_sdk::types::{
base_types::SuiAddress,
crypto::{PublicKey, Signature, SuiSignature},
digests::TransactionDigest,
};
use tokio::sync::oneshot;
use tracing::{error, instrument};
Expand Down Expand Up @@ -334,8 +335,31 @@ pub async fn verify_stack_permissions(
StatusCode::UNAUTHORIZED
})?;
if available_stack.is_none() {
error!("No available stack with enough compute units");
return Err(StatusCode::UNAUTHORIZED);
let tx_digest_str = req_parts
.headers
.get("X-Tx-Digest")
.ok_or_else(|| {
error!("Stack not found, tx digest header expected but not found");
StatusCode::BAD_REQUEST
})?
.to_str()
.map_err(|_| {
error!("Tx digest cannot be converted to a string");
StatusCode::BAD_REQUEST
})?;
let tx_digest = TransactionDigest::from_str(tx_digest_str).unwrap();
let (tx_stack_small_id, compute_units) =
utils::request_blockchain_for_stack(&state, tx_digest, total_num_compute_units).await?;

// NOTE: We need to check that the stack small id matches the one in the request
// otherwise, the user is requesting for a different stack, which is invalid. We
// must also check that the compute units are enough for processing the request.
if stack_small_id != tx_stack_small_id as i64
|| compute_units > total_num_compute_units as u64
{
error!("No available stack with enough compute units");
return Err(StatusCode::UNAUTHORIZED);
}
}
let request_metadata = RequestMetadata::default()
.with_stack_info(stack_small_id, total_num_compute_units)
Expand All @@ -353,7 +377,7 @@ pub(crate) mod utils {
secp256r1::{Secp256r1PublicKey, Secp256r1Signature},
traits::{ToFromBytes, VerifyingKey},
};
use sui_sdk::types::crypto::SignatureScheme;
use sui_sdk::types::{crypto::SignatureScheme, digests::TransactionDigest};

/// Verifies the authenticity of a request by checking its signature against the provided hash.
///
Expand Down Expand Up @@ -381,6 +405,7 @@ pub(crate) mod utils {
/// This function is critical for ensuring request authenticity. It verifies that:
/// 1. The request was signed by the owner of the public key
/// 2. The request body hasn't been tampered with since signing
#[instrument(level = "trace", skip_all)]
pub(crate) fn verify_signature(
base64_signature: &str,
body_hash: &[u8; 32],
Expand Down Expand Up @@ -431,6 +456,57 @@ pub(crate) mod utils {
Ok(())
}

/// Queries the blockchain to retrieve the stack associated with a specific transaction.
///
/// This asynchronous function uses a combination of channels to communicate with the blockchain:
/// - An mpsc channel to send the query request to the stack-retrieval handler
/// - A oneshot channel to receive the response back in the current scope
///
/// # Arguments
/// * `state` - Reference to the application state containing the `stack_retrieve_sender` mpsc sender
/// * `tx_digest` - The transaction digest to query the stack for
/// * `estimated_compute_units` - The number of compute units the current request is estimated to need
///
/// # Returns
/// * `Ok((u64, u64))` - The `(stack_small_id, compute_units)` pair if a stack was found
/// * `Err(StatusCode)` - If the request fails, returns one of:
///   - `INTERNAL_SERVER_ERROR` if channel communication fails
///   - `UNAUTHORIZED` if no stack / compute units are found for the transaction
///
/// # Channel Communication Flow
/// 1. Creates a oneshot channel for receiving the response
/// 2. Sends the transaction digest, estimated compute units, and the oneshot sender
///    through the mpsc channel
/// 3. Awaits the response on the oneshot receiver
///
/// # Example
/// ```rust,ignore
/// let (stack_small_id, compute_units) =
///     request_blockchain_for_stack(&app_state, tx_digest, estimated_compute_units).await?;
/// ```
#[instrument(level = "trace", skip_all)]
pub(crate) async fn request_blockchain_for_stack(
    state: &AppState,
    tx_digest: TransactionDigest,
    estimated_compute_units: i64,
) -> Result<(u64, u64), StatusCode> {
    // Oneshot channel on which the blockchain-query handler replies.
    let (result_sender, result_receiver) = oneshot::channel();
    state
        .stack_retrieve_sender
        .send((tx_digest, estimated_compute_units, result_sender))
        .map_err(|_| {
            // Send fails only when the receiving half was dropped,
            // i.e. the subscriber task is no longer running.
            error!("Failed to send compute units request");
            StatusCode::INTERNAL_SERVER_ERROR
        })?;
    let result = result_receiver.await.map_err(|_| {
        // The sender was dropped without replying.
        error!("Failed to receive compute units");
        StatusCode::INTERNAL_SERVER_ERROR
    })?;
    // Both fields must be present; a partially/fully absent result means the
    // transaction did not yield a usable stack for this node.
    if let (Some(stack_small_id), Some(compute_units)) = result {
        Ok((stack_small_id, compute_units))
    } else {
        error!("No compute units found for transaction");
        Err(StatusCode::UNAUTHORIZED)
    }
}

/// Calculates the total number of compute units required for a request based on its type and content.
///
/// # Arguments
Expand Down Expand Up @@ -509,6 +585,7 @@ pub(crate) mod utils {
/// "max_tokens": 100
/// }
/// ```
#[instrument(level = "trace", skip_all)]
pub(crate) fn calculate_chat_completion_compute_units(
body_json: &Value,
state: &AppState,
Expand Down Expand Up @@ -609,6 +686,7 @@ pub(crate) mod utils {
/// # Computation
/// The total compute units is calculated as the sum of tokens across all input texts.
/// For array inputs, each string is tokenized separately and the results are summed.
#[instrument(level = "trace", skip_all)]
fn calculate_embedding_compute_units(
body_json: &Value,
state: &AppState,
Expand Down Expand Up @@ -685,6 +763,7 @@ pub(crate) mod utils {
/// "n": 1
/// }
/// ```
#[instrument(level = "trace", skip_all)]
fn calculate_image_generation_compute_units(body_json: &Value) -> Result<i64, StatusCode> {
let size = body_json
.get(IMAGE_SIZE)
Expand Down
19 changes: 18 additions & 1 deletion atoma-service/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,12 @@ use hyper::StatusCode;
use prometheus::Encoder;
use serde_json::{json, Value};
use sui_keys::keystore::FileBasedKeystore;
use sui_sdk::types::digests::TransactionDigest;
use tokenizers::Tokenizer;
use tokio::{net::TcpListener, sync::watch::Receiver};
use tokio::{
net::TcpListener,
sync::{mpsc, oneshot, watch::Receiver},
};
use tower::ServiceBuilder;
use tracing::error;
use utoipa::OpenApi;
Expand All @@ -39,6 +43,15 @@ pub const HEALTH_PATH: &str = "/health";
/// The path for the metrics endpoint.
pub const METRICS_PATH: &str = "/metrics";

/// A small identifier for a Stack, represented as a 64-bit unsigned integer.
type StackSmallId = u64;

/// Represents the number of compute units available, stored as a 64-bit unsigned integer.
type ComputeUnits = u64;

/// Represents the result of a blockchain query for stack information.
type StackQueryResult = (Option<StackSmallId>, Option<ComputeUnits>);

/// Represents the shared state of the application.
///
/// This struct holds various components and configurations that are shared
Expand All @@ -53,6 +66,10 @@ pub struct AppState {
/// updates and notifications across different components.
pub state_manager_sender: FlumeSender<AtomaAtomaStateManagerEvent>,

/// Channel sender for requesting compute units from the blockchain.
pub stack_retrieve_sender:
mpsc::UnboundedSender<(TransactionDigest, i64, oneshot::Sender<StackQueryResult>)>,

/// Tokenizer used for processing text input.
///
/// The tokenizer is responsible for breaking down text input into
Expand Down
2 changes: 2 additions & 0 deletions atoma-service/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ mod middleware {
let tokenizer = load_tokenizer().await;
let (state_manager_handle, state_manager_sender, shutdown_sender, _event_subscriber_sender) =
setup_database(public_key.clone()).await;
let (stack_retrieve_sender, _) = tokio::sync::mpsc::unbounded_channel();
(
AppState {
models: Arc::new(models.into_iter().map(|s| s.to_string()).collect()),
Expand All @@ -201,6 +202,7 @@ mod middleware {
image_generations_service_url: "".to_string(),
keystore: Arc::new(keystore),
address_index: 0,
stack_retrieve_sender,
},
public_key,
signature,
Expand Down
22 changes: 20 additions & 2 deletions atoma-state/src/handlers.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use atoma_sui::events::{
AtomaEvent, NewStackSettlementAttestationEvent, NodeSubscribedToTaskEvent,
NodeSubscriptionUpdatedEvent, NodeUnsubscribedFromTaskEvent, StackAttestationDisputeEvent,
StackCreatedEvent, StackSettlementTicketClaimedEvent, StackSettlementTicketEvent,
StackTrySettleEvent, TaskDeprecationEvent, TaskRegisteredEvent,
StackCreateAndUpdateEvent, StackCreatedEvent, StackSettlementTicketClaimedEvent,
StackSettlementTicketEvent, StackTrySettleEvent, TaskDeprecationEvent, TaskRegisteredEvent,
};
use tracing::{info, instrument};

Expand Down Expand Up @@ -33,6 +33,9 @@ pub async fn handle_atoma_event(
AtomaEvent::StackCreatedEvent(event) => {
handle_stack_created_event(state_manager, event).await
}
AtomaEvent::StackCreateAndUpdateEvent(event) => {
handle_stack_create_and_update_event(state_manager, event).await
}
AtomaEvent::StackTrySettleEvent(event) => {
handle_stack_try_settle_event(state_manager, event).await
}
Expand Down Expand Up @@ -372,6 +375,21 @@ pub(crate) async fn handle_stack_created_event(
Ok(())
}

/// Handles a stack create-and-update event.
///
/// This function processes a `StackCreateAndUpdateEvent` by converting the event
/// into a stack record and inserting it into the database through the state manager.
///
/// # Arguments
///
/// * `state_manager` - Reference to the `AtomaStateManager` holding the database connection.
/// * `event` - The `StackCreateAndUpdateEvent` to process.
///
/// # Returns
///
/// * `Ok(())` if the stack was inserted successfully.
/// * An error if the database insertion fails.
#[instrument(level = "info", skip_all)]
pub(crate) async fn handle_stack_create_and_update_event(
    state_manager: &AtomaStateManager,
    event: StackCreateAndUpdateEvent,
) -> Result<()> {
    info!(
        target = "atoma-state-handlers",
        event = "handle-stack-create-and-update-event",
        "Processing stack create and update event"
    );
    // The event converts into the stack row type via `Into` —
    // presumably the crate's `Stack` model; TODO(review): confirm.
    let stack = event.into();
    state_manager.state.insert_new_stack(stack).await?;
    Ok(())
}

/// Handles a stack try settle event.
///
/// This function processes a stack try settle event by parsing the event data,
Expand Down
43 changes: 21 additions & 22 deletions atoma-state/src/migrations/20241121103103_create_node_tables.sql
Original file line number Diff line number Diff line change
@@ -1,46 +1,45 @@
-- Add migration script here

-- Create tasks table
CREATE TABLE IF NOT EXISTS tasks (
task_small_id BIGINT PRIMARY KEY,
task_small_id BIGINT PRIMARY KEY,
task_id TEXT UNIQUE NOT NULL,
role BIGINT NOT NULL,
role BIGINT NOT NULL,
model_name TEXT,
is_deprecated BOOLEAN NOT NULL,
valid_until_epoch BIGINT,
deprecated_at_epoch BIGINT,
security_level BIGINT NOT NULL,
security_level BIGINT NOT NULL,
minimum_reputation_score BIGINT
);

-- Create node_subscriptions table
CREATE TABLE IF NOT EXISTS node_subscriptions (
task_small_id BIGINT NOT NULL,
node_small_id BIGINT NOT NULL,
price_per_compute_unit BIGINT NOT NULL,
max_num_compute_units BIGINT NOT NULL,
task_small_id BIGINT NOT NULL,
node_small_id BIGINT NOT NULL,
price_per_compute_unit BIGINT NOT NULL,
max_num_compute_units BIGINT NOT NULL,
valid BOOLEAN NOT NULL,
PRIMARY KEY (task_small_id, node_small_id),
FOREIGN KEY (task_small_id)
REFERENCES tasks (task_small_id)
);

CREATE INDEX IF NOT EXISTS idx_node_subscriptions_task_small_id_node_small_id
CREATE INDEX IF NOT EXISTS idx_node_subscriptions_task_small_id_node_small_id
ON node_subscriptions (task_small_id, node_small_id);

-- Create stacks table
CREATE TABLE IF NOT EXISTS stacks (
stack_small_id BIGINT PRIMARY KEY,
stack_small_id BIGINT PRIMARY KEY,
owner_address TEXT NOT NULL,
stack_id TEXT UNIQUE NOT NULL,
task_small_id BIGINT NOT NULL,
selected_node_id BIGINT NOT NULL,
num_compute_units BIGINT NOT NULL,
price BIGINT NOT NULL,
already_computed_units BIGINT NOT NULL,
task_small_id BIGINT NOT NULL,
selected_node_id BIGINT NOT NULL,
num_compute_units BIGINT NOT NULL,
price BIGINT NOT NULL,
already_computed_units BIGINT NOT NULL,
in_settle_period BOOLEAN NOT NULL,
total_hash BYTEA NOT NULL,
num_total_messages BIGINT NOT NULL,
total_hash BYTEA NOT NULL,
num_total_messages BIGINT NOT NULL,
CONSTRAINT check_compute_units CHECK (already_computed_units <= num_compute_units),
FOREIGN KEY (selected_node_id, task_small_id)
REFERENCES node_subscriptions (node_small_id, task_small_id)
);
Expand All @@ -54,10 +53,10 @@ CREATE INDEX IF NOT EXISTS idx_stacks_stack_small_id
-- Create stack_attestation_disputes table
CREATE TABLE IF NOT EXISTS stack_attestation_disputes (
stack_small_id BIGINT NOT NULL,
attestation_commitment BYTEA NOT NULL,
attestation_commitment BYTEA NOT NULL,
attestation_node_id BIGINT NOT NULL,
original_node_id BIGINT NOT NULL,
original_commitment BYTEA NOT NULL,
original_commitment BYTEA NOT NULL,
PRIMARY KEY (stack_small_id, attestation_node_id),
FOREIGN KEY (stack_small_id)
REFERENCES stacks (stack_small_id)
Expand All @@ -69,8 +68,8 @@ CREATE TABLE IF NOT EXISTS stack_settlement_tickets (
selected_node_id BIGINT NOT NULL,
num_claimed_compute_units BIGINT NOT NULL,
requested_attestation_nodes TEXT NOT NULL,
committed_stack_proofs BYTEA NOT NULL,
stack_merkle_leaves BYTEA NOT NULL,
committed_stack_proofs BYTEA NOT NULL,
stack_merkle_leaves BYTEA NOT NULL,
dispute_settled_at_epoch BIGINT,
already_attested_nodes TEXT NOT NULL,
is_in_dispute BOOLEAN NOT NULL,
Expand Down
Loading

0 comments on commit 28d875f

Please sign in to comment.