diff --git a/Cargo.lock b/Cargo.lock index 987847e..661120f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -514,6 +514,7 @@ dependencies = [ "dotenv", "envy", "futures-util", + "http", "jsonwebtoken", "prost", "prost-types", diff --git a/Cargo.toml b/Cargo.toml index 6694744..167c719 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,6 +9,7 @@ ctrlc2 = { version = "3", features = ["termination", "tokio"] } derive_builder = "0.20" envy = "0.4" futures-util = "0.3" +http = "1" jsonwebtoken = "9" prost = "0.13" prost-types = "0.13" diff --git a/examples/fibonacci.rs b/examples/fibonacci.rs index 8910d25..49babd8 100644 --- a/examples/fibonacci.rs +++ b/examples/fibonacci.rs @@ -1,4 +1,4 @@ -use hatchet_sdk::{Client, StepBuilder, WorkflowBuilder}; +use hatchet_sdk::{Client, Context, StepBuilder, WorkflowBuilder}; fn fibonacci(n: u32) -> u32 { (1..=n) @@ -33,7 +33,7 @@ struct Output { result: u32, } -async fn execute(Input { n }: Input) -> anyhow::Result { +async fn execute(_context: Context, Input { n }: Input) -> anyhow::Result { Ok(Output { result: fibonacci(n), }) @@ -51,7 +51,7 @@ async fn main() -> anyhow::Result<()> { .init(); let client = Client::new()?; - let mut worker = client.worker("example").build(); + let mut worker = client.worker("example_fibonacci").build(); worker.register_workflow( WorkflowBuilder::default() .name("fibonacci") diff --git a/examples/spawn_workflow.rs b/examples/spawn_workflow.rs new file mode 100644 index 0000000..e60ed5c --- /dev/null +++ b/examples/spawn_workflow.rs @@ -0,0 +1,66 @@ +use hatchet_sdk::{Client, Context, StepBuilder, WorkflowBuilder}; + +async fn execute_hello( + context: Context, + _: serde_json::Value, +) -> anyhow::Result { + context + .trigger_workflow( + "world", + serde_json::json!({ + "x": 42 + }), + ) + .await?; + Ok(serde_json::json!({ + "message": "Hello" + })) +} + +async fn execute_world( + _context: Context, + _: serde_json::Value, +) -> anyhow::Result { + Ok(serde_json::json!({ + "message": "World" + })) +} + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + dotenv::dotenv().ok(); + tracing_subscriber::fmt() + .with_target(false) + .with_env_filter( + tracing_subscriber::EnvFilter::from_default_env() + .add_directive("hatchet_sdk=debug".parse()?), + ) + .init(); + + let client = Client::new()?; + let mut worker = client.worker("example_spawn_workflow").build(); + worker.register_workflow( + WorkflowBuilder::default() + .name("hello") + .step( + StepBuilder::default() + .name("execute") + .function(&execute_hello) + .build()?, + ) + .build()?, + ); + worker.register_workflow( + WorkflowBuilder::default() + .name("world") + .step( + StepBuilder::default() + .name("execute") + .function(&execute_world) + .build()?, + ) + .build()?, + ); + worker.start().await?; + Ok(()) +} diff --git a/src/client.rs b/src/client.rs index 05d4d3d..6485d53 100644 --- a/src/client.rs +++ b/src/client.rs @@ -25,7 +25,7 @@ pub struct Client { } impl Client { - pub fn new() -> crate::Result { + pub fn new() -> crate::InternalResult { let environment = envy::prefixed("HATCHET_CLIENT_").from_env::()?; Ok(Self { environment }) } diff --git a/src/error.rs b/src/error.rs index 70dace0..9098383 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,11 +1,13 @@ #[derive(Debug, thiserror::Error)] -pub enum Error { +pub enum InternalError { #[error("failed to load configuration from the environment: {0}")] Environment(#[from] envy::Error), #[error("worker registration request: {0}")] CouldNotRegisterWorker(tonic::Status), #[error("workflow registration request:: {0}")] CouldNotPutWorkflow(tonic::Status), + #[error("workflow schedule request:: {0}")] + CouldNotTriggerWorkflow(tonic::Status), #[error("dispatcher listen error: {0}")] CouldNotListenToDispatcher(tonic::Status), #[error("step status send error: {0}")] @@ -28,4 +30,12 @@ pub enum Error { CouldNotDecodeActionPayload(serde_json::Error), } +pub type InternalResult = std::result::Result; + +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error("internal error: {0}")] + Internal(#[from] InternalError), +} + pub type Result = std::result::Result; diff --git a/src/lib.rs b/src/lib.rs index eaea36a..fb37495 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,9 +1,11 @@ mod client; mod error; +mod step_function; mod worker; mod workflow; pub use error::{Error, Result}; +pub(crate) use error::{InternalError, InternalResult}; #[derive(Clone, Copy, Debug, Default, serde::Deserialize)] #[serde(rename_all = "lowercase")] @@ -15,6 +17,7 @@ enum ClientTlStrategy { } pub use client::Client; +pub use step_function::Context; pub use worker::{Worker, WorkerBuilder}; pub use workflow::{Step, StepBuilder, Workflow, WorkflowBuilder}; diff --git a/src/step_function.rs b/src/step_function.rs new file mode 100644 index 0000000..8b6f1dd --- /dev/null +++ b/src/step_function.rs @@ -0,0 +1,75 @@ +use futures_util::lock::Mutex; +use tracing::info; + +use crate::worker::{grpc, ServiceWithAuthorization}; + +pub struct Context { + workflow_run_id: String, + workflow_step_run_id: String, + workflow_service_client_and_spawn_index: Mutex<( + grpc::workflow_service_client::WorkflowServiceClient< + tonic::service::interceptor::InterceptedService< + tonic::transport::Channel, + ServiceWithAuthorization, + >, + >, + u16, + )>, +} + +impl Context { + pub(crate) fn new( + workflow_run_id: String, + workflow_step_run_id: String, + workflow_service_client: grpc::workflow_service_client::WorkflowServiceClient< + tonic::service::interceptor::InterceptedService< + tonic::transport::Channel, + ServiceWithAuthorization, + >, + >, + ) -> Self { + Self { + workflow_run_id, + workflow_service_client_and_spawn_index: Mutex::new((workflow_service_client, 0)), + workflow_step_run_id, + } + } + + pub async fn trigger_workflow( + &self, + workflow_name: &str, + input: I, + ) -> anyhow::Result<()> { + info!("Scheduling another workflow {workflow_name}"); + let mut mutex_guard = self.workflow_service_client_and_spawn_index.lock().await; + let (workflow_service_client, spawn_index) = &mut *mutex_guard; + let response = workflow_service_client + .trigger_workflow(grpc::TriggerWorkflowRequest { + name: workflow_name.to_owned(), + input: serde_json::to_string(&input).expect("must succeed"), + parent_id: Some(self.workflow_run_id.clone()), + parent_step_run_id: Some(self.workflow_step_run_id.clone()), + child_index: Some(*spawn_index as i32), + child_key: None, + additional_metadata: None, // FIXME: Add support. + desired_worker_id: None, // FIXME: Add support. + priority: Some(1), // FIXME: Add support. + }) + .await + .map_err(crate::InternalError::CouldNotTriggerWorkflow) + .map_err(crate::Error::Internal)? + .into_inner(); + info!( + "Scheduled another workflow run ID: {}", + response.workflow_run_id + ); + *spawn_index += 1; + Ok(()) + } +} + +pub(crate) type StepFunction = + dyn Fn( + Context, + serde_json::Value, + ) -> futures_util::future::LocalBoxFuture<'static, anyhow::Result>; diff --git a/src/worker/heartbeat.rs b/src/worker/heartbeat.rs index 0fc5562..842866c 100644 --- a/src/worker/heartbeat.rs +++ b/src/worker/heartbeat.rs @@ -10,7 +10,7 @@ pub(crate) async fn run( >, worker_id: &str, mut interrupt_receiver: tokio::sync::mpsc::Receiver<()>, -) -> crate::Result<()> +) -> crate::InternalResult<()> where F: tonic::service::Interceptor + Send + 'static, { @@ -24,7 +24,7 @@ where worker_id: worker_id.clone(), }) .await - .map_err(crate::Error::CouldNotSendHeartbeat)?; + .map_err(crate::InternalError::CouldNotSendHeartbeat)?; tokio::select! { _ = interval.tick() => { @@ -35,7 +35,7 @@ where } } } - crate::Result::Ok(()) + crate::InternalResult::Ok(()) }) .await .expect("must succeed spawing")?; diff --git a/src/worker/listener.rs b/src/worker/listener.rs index d9b088d..b579edc 100644 --- a/src/worker/listener.rs +++ b/src/worker/listener.rs @@ -1,9 +1,10 @@ use futures_util::FutureExt; -use tokio::{task::LocalSet, task_local}; +use tokio::task::LocalSet; use tonic::IntoRequest; use tracing::{debug, error, info, warn}; use crate::{ + step_function::Context, worker::{grpc::ActionType, DEFAULT_ACTION_TIMEOUT}, Workflow, }; @@ -13,7 +14,7 @@ use super::{ dispatcher_client::DispatcherClient, AssignedAction, StepActionEvent, StepActionEventType, WorkerListenRequest, }, - ListenStrategy, + ListenStrategy, ServiceWithAuthorization, }; const DEFAULT_ACTION_LISTENER_RETRY_INTERVAL: std::time::Duration = @@ -44,19 +45,25 @@ struct ActionInput { input: T, } -async fn handle_start_step_run( +async fn handle_start_step_run( dispatcher: &mut DispatcherClient< - tonic::service::interceptor::InterceptedService, + tonic::service::interceptor::InterceptedService< + tonic::transport::Channel, + ServiceWithAuthorization, + >, + >, + workflow_service_client: super::grpc::workflow_service_client::WorkflowServiceClient< + tonic::service::interceptor::InterceptedService< + tonic::transport::Channel, + ServiceWithAuthorization, + >, >, local_set: &tokio::task::LocalSet, namespace: &str, worker_id: &str, workflows: &[Workflow], action: AssignedAction, -) -> crate::Result<()> -where - F: tonic::service::Interceptor + Send + 'static, -{ +) -> crate::InternalResult<()> { let Some(action_callable) = workflows .iter() .flat_map(|workflow| workflow.actions(namespace)) @@ -77,16 +84,27 @@ where Default::default(), )) .await - .map_err(crate::Error::CouldNotSendStepStatus)? + .map_err(crate::InternalError::CouldNotSendStepStatus)? .into_inner(); let input: ActionInput = serde_json::from_str(&action.action_payload) - .map_err(crate::Error::CouldNotDecodeActionPayload)?; + .map_err(crate::InternalError::CouldNotDecodeActionPayload)?; + + let workflow_run_id = action.workflow_run_id.clone(); + let workflow_step_run_id = action.step_run_id.clone(); // FIXME: Obviously, run this asynchronously rather than blocking the main listening loop. let action_event = match local_set .run_until(async move { - tokio::task::spawn_local(async move { action_callable(input.input).await }).await + tokio::task::spawn_local(async move { + let context = Context::new( + workflow_run_id, + workflow_step_run_id, + workflow_service_client, + ); + action_callable(context, input.input).await + }) + .await }) .await { @@ -113,15 +131,24 @@ where dispatcher .send_step_action_event(action_event) .await - .map_err(crate::Error::CouldNotSendStepStatus)? + .map_err(crate::InternalError::CouldNotSendStepStatus)? .into_inner(); Ok(()) } -pub(crate) async fn run( +pub(crate) async fn run( mut dispatcher: DispatcherClient< - tonic::service::interceptor::InterceptedService, + tonic::service::interceptor::InterceptedService< + tonic::transport::Channel, + ServiceWithAuthorization, + >, + >, + workflow_service_client: super::grpc::workflow_service_client::WorkflowServiceClient< + tonic::service::interceptor::InterceptedService< + tonic::transport::Channel, + ServiceWithAuthorization, + >, >, namespace: &str, worker_id: &str, @@ -129,10 +156,7 @@ pub(crate) async fn run( listener_v2_timeout: Option, mut interrupt_receiver: tokio::sync::mpsc::Receiver<()>, _heartbeat_interrupt_sender: tokio::sync::mpsc::Sender<()>, -) -> crate::Result<()> -where - F: tonic::service::Interceptor + Send + 'static, -{ +) -> crate::InternalResult<()> { use futures_util::StreamExt; let mut retries: usize = 0; @@ -147,7 +171,7 @@ where retries = 0; } if retries > DEFAULT_ACTION_LISTENER_RETRY_COUNT { - return Err(crate::Error::CouldNotSubscribeToActions( + return Err(crate::InternalError::CouldNotSubscribeToActions( DEFAULT_ACTION_LISTENER_RETRY_COUNT, )); } @@ -180,7 +204,7 @@ where let mut stream = tokio::select! { response = response => { response - .map_err(crate::Error::CouldNotListenToDispatcher)? + .map_err(crate::InternalError::CouldNotListenToDispatcher)? .into_inner() } result = interrupt_receiver.recv() => { @@ -229,7 +253,7 @@ where match action_type { ActionType::StartStepRun => { - handle_start_step_run(&mut dispatcher, &local_set, namespace, worker_id, &workflows, action).await?; + handle_start_step_run(&mut dispatcher, workflow_service_client.clone(), &local_set, namespace, worker_id, &workflows, action).await?; } ActionType::CancelStepRun => { todo!() diff --git a/src/worker/mod.rs b/src/worker/mod.rs index 52f195a..87fcafe 100644 --- a/src/worker/mod.rs +++ b/src/worker/mod.rs @@ -5,12 +5,45 @@ use grpc::{ CreateWorkflowJobOpts, CreateWorkflowStepOpts, CreateWorkflowVersionOpts, PutWorkflowRequest, WorkerRegisterRequest, WorkerRegisterResponse, WorkflowKind, }; -use secrecy::{ExposeSecret, SecretString}; use tonic::transport::Certificate; use tracing::info; use crate::{client::Environment, ClientTlStrategy, Workflow}; +#[derive(Clone)] +pub(crate) struct ServiceWithAuthorization { + authorization_header_value: secrecy::SecretString, +} + +impl ServiceWithAuthorization { + fn new(token: secrecy::SecretString) -> Self { + use secrecy::ExposeSecret; + + Self { + authorization_header_value: format!("Bearer {token}", token = token.expose_secret()) + .into(), + } + } +} + +impl tonic::service::Interceptor for ServiceWithAuthorization { + fn call( + &mut self, + mut request: tonic::Request<()>, + ) -> Result, tonic::Status> { + use secrecy::ExposeSecret; + let authorization_header_value: tonic::metadata::MetadataValue = + self.authorization_header_value + .expose_secret() + .parse() + .expect("must parse successfully"); + request + .metadata_mut() + .insert("authorization", authorization_header_value); + Ok(request) + } +} + #[derive(derive_builder::Builder)] #[builder(pattern = "owned", build_fn(private, name = "build_private"))] pub struct Worker<'a> { @@ -52,8 +85,8 @@ struct TokenClaims { fn construct_endpoint_url<'a>( tls_strategy: ClientTlStrategy, host_port_in_environment: Option<&'a str>, - token: &SecretString, -) -> crate::Result { + token: &secrecy::SecretString, +) -> crate::InternalResult { use secrecy::ExposeSecret; let protocol = match tls_strategy { @@ -64,7 +97,7 @@ fn construct_endpoint_url<'a>( Ok(format!( "{protocol}://{}", host_port_in_environment - .map(|value| crate::Result::Ok(std::borrow::Cow::Borrowed(value))) + .map(|value| crate::InternalResult::Ok(std::borrow::Cow::Borrowed(value))) .unwrap_or_else(|| { let key = jsonwebtoken::DecodingKey::from_secret(&[]); let mut validation = jsonwebtoken::Validation::new(jsonwebtoken::Algorithm::ES256); @@ -72,7 +105,7 @@ fn construct_endpoint_url<'a>( validation.validate_aud = false; let data: jsonwebtoken::TokenData = jsonwebtoken::decode(token.expose_secret(), &key, &validation) - .map_err(crate::Error::CouldNotDecodeToken)?; + .map_err(crate::InternalError::CouldNotDecodeToken)?; Ok(data.claims.grpc_broadcast_address.into()) })? )) @@ -84,8 +117,8 @@ async fn construct_endpoint( tls_root_ca_file: Option<&str>, tls_root_ca: Option<&str>, host_port: Option<&str>, - token: &SecretString, -) -> crate::Result { + token: &secrecy::SecretString, +) -> crate::InternalResult { let mut endpoint = tonic::transport::Endpoint::new(construct_endpoint_url(tls_strategy, host_port, token)?) .expect("endpoint must be valid"); @@ -99,7 +132,7 @@ async fn construct_endpoint( }; let extra_root_certificate = match (tls_root_ca, tls_root_ca_file) { (Some(_), Some(_)) => { - return Err(crate::Error::CantSetBothEnvironmentVariables( + return Err(crate::InternalError::CantSetBothEnvironmentVariables( "HATCHET_CLIENT_TLS_ROOT_CA", "HATCHET_CLIENT_TLS_ROOT_CA_FILE", )); @@ -109,7 +142,7 @@ async fn construct_endpoint( } (None, Some(tls_root_ca_file)) => Some(std::borrow::Cow::Owned( tokio::fs::read(tls_root_ca_file).await.map_err(|err| { - crate::Error::CouldNotReadFile(err, tls_root_ca_file.to_owned()) + crate::InternalError::CouldNotReadFile(err, tls_root_ca_file.to_owned()) })?, )), (None, None) => None, @@ -131,7 +164,7 @@ impl<'a> Worker<'a> { self.workflows.push(workflow); } - pub async fn start(self) -> crate::Result<()> { + pub async fn start(self) -> crate::InternalResult<()> { use tonic::IntoRequest; let (heartbeat_interrupt_sender1, heartbeat_interrupt_receiver) = @@ -174,24 +207,15 @@ impl<'a> Worker<'a> { ) .await?; - let authorization_header: tonic::metadata::MetadataValue = - format!("Bearer {token}", token = token.expose_secret()) - .parse() - .expect("must parse successfully"); - let authorization_header_cloned = authorization_header.clone(); + let interceptor = ServiceWithAuthorization::new(token.clone()); let mut workflow_service_client = grpc::workflow_service_client::WorkflowServiceClient::with_interceptor( endpoint .connect() .await - .map_err(crate::Error::CouldNotConnectToWorkflowService)?, - move |mut request: tonic::Request<()>| { - request - .metadata_mut() - .insert("authorization", authorization_header.clone()); - Ok(request) - }, + .map_err(crate::InternalError::CouldNotConnectToWorkflowService)?, + interceptor.clone(), ); let mut all_actions = vec![]; @@ -244,7 +268,7 @@ impl<'a> Worker<'a> { workflow_service_client .put_workflow(PutWorkflowRequest { opts: Some(opts) }) .await - .map_err(crate::Error::CouldNotPutWorkflow)?; + .map_err(crate::InternalError::CouldNotPutWorkflow)?; } // FIXME: Account for all the settings from `self.environment`. @@ -253,13 +277,8 @@ impl<'a> Worker<'a> { endpoint .connect() .await - .map_err(crate::Error::CouldNotConnectToDispatcher)?, - move |mut request: tonic::Request<()>| { - request - .metadata_mut() - .insert("authorization", authorization_header_cloned.clone()); - Ok(request) - }, + .map_err(crate::InternalError::CouldNotConnectToDispatcher)?, + interceptor.clone(), ) }; @@ -280,12 +299,12 @@ impl<'a> Worker<'a> { let WorkerRegisterResponse { worker_id, .. } = dispatcher .register(request) .await - .map_err(crate::Error::CouldNotRegisterWorker)? + .map_err(crate::InternalError::CouldNotRegisterWorker)? .into_inner(); futures_util::try_join! { heartbeat::run(dispatcher.clone(), &worker_id, heartbeat_interrupt_receiver), - listener::run(dispatcher, namespace, &worker_id, self.workflows, *listener_v2_timeout, listening_interrupt_receiver, heartbeat_interrupt_sender2), + listener::run(dispatcher, workflow_service_client, namespace, &worker_id, self.workflows, *listener_v2_timeout, listening_interrupt_receiver, heartbeat_interrupt_sender2), }?; Ok(()) diff --git a/src/workflow.rs b/src/workflow.rs index b6ad2e7..5b54266 100644 --- a/src/workflow.rs +++ b/src/workflow.rs @@ -1,9 +1,8 @@ -use std::sync::Arc; +use crate::step_function::Context; + +use super::step_function::StepFunction; -type StepFunction = - dyn Fn( - serde_json::Value, - ) -> futures_util::future::LocalBoxFuture<'static, anyhow::Result>; +use std::sync::Arc; #[derive(derive_builder::Builder)] #[builder(pattern = "owned")] @@ -26,11 +25,14 @@ impl StepBuilder { I: serde::de::DeserializeOwned, O: serde::ser::Serialize, Fut: std::future::Future> + 'static, - F: Fn(I) -> Fut, + F: Fn(Context, I) -> Fut, { use futures_util::FutureExt; - self.function = Some(Arc::new(|value| { - let result = function(serde_json::from_value(value).expect("must succeed")); + self.function = Some(Arc::new(|context, value| { + let result = function( + context, + serde_json::from_value(value).expect("must succeed"), + ); async { Ok(serde_json::to_value(result.await?).expect("must succeed")) }.boxed_local() })); self