From 40b2e10e727023e34a028ab898710412cbbafd2c Mon Sep 17 00:00:00 2001 From: Francesco Guardiani Date: Tue, 1 Oct 2024 18:18:56 +0200 Subject: [PATCH] Implement GetInvocationId and CancelInvocation entries (#10) * Squashed 'service-protocol/' changes from 0d6b476..5129eca 5129eca Add idempotency_key to CallEntryMessage & OneWayCallEntryMessage. (#97) 65560bf Add CancelInvocation and GetCallInvocationId entries. (#96) git-subtree-dir: service-protocol git-subtree-split: 5129eca343a214665237cd69b5e6b5d1842d1cab * Implement `sys_get_call_invocation_id` and `sys_cancel_invocation` Refactor header macro New protocol * Add idempotency key * Add version checks --- .../dev/restate/service/protocol.proto | 41 ++- .../service-invocation-protocol.md | 38 ++- src/lib.rs | 40 ++- .../generated/dev.restate.service.protocol.rs | 57 ++++ src/service_protocol/header.rs | 321 +++++++----------- src/service_protocol/messages.rs | 47 ++- src/service_protocol/version.rs | 6 +- src/tests/async_result.rs | 1 + src/tests/calls.rs | 181 ++++++++++ src/tests/failures.rs | 1 + src/tests/mod.rs | 1 + src/tests/state.rs | 12 +- src/vm/context.rs | 1 + src/vm/errors.rs | 36 +- src/vm/mod.rs | 94 ++++- 15 files changed, 636 insertions(+), 241 deletions(-) create mode 100644 src/tests/calls.rs diff --git a/service-protocol/dev/restate/service/protocol.proto b/service-protocol/dev/restate/service/protocol.proto index df39278..405fee0 100644 --- a/service-protocol/dev/restate/service/protocol.proto +++ b/service-protocol/dev/restate/service/protocol.proto @@ -22,6 +22,11 @@ enum ServiceProtocolVersion { // Added // * Entry retry mechanism: ErrorMessage.next_retry_delay, StartMessage.retry_count_since_last_stored_entry and StartMessage.duration_since_last_stored_entry V2 = 2; + // Added + // * New entry to cancel invocations: CancelInvocationEntryMessage + // * New entry to retrieve the invocation id: GetCallInvocationIdEntryMessage + // * New field to set idempotency key for Call entries + V3 = 3; } // --- Core frames --- @@ -313,6 +318,8 @@ message CallEntryMessage { // If this invocation has a key associated (e.g. for objects and workflows), then this key is filled in. Empty otherwise. string key = 5; + string idempotency_key = 6; + oneof result { bytes value = 14; Failure failure = 15; @@ -342,6 +349,8 @@ message OneWayCallEntryMessage { // If this invocation has a key associated (e.g. for objects and workflows), then this key is filled in. Empty otherwise. string key = 6; + string idempotency_key = 7; + // Entry name string name = 12; } @@ -383,13 +392,43 @@ message CompleteAwakeableEntryMessage { message RunEntryMessage { oneof result { bytes value = 14; - dev.restate.service.protocol.Failure failure = 15; + Failure failure = 15; }; // Entry name string name = 12; } +// Completable: No +// Fallible: Yes +// Type: 0x0C00 + 6 +message CancelInvocationEntryMessage { + oneof target { + // Target invocation id to cancel + string invocation_id = 1; + // Target index of the call/one way call journal entry in this journal. + uint32 call_entry_index = 2; + } + + // Entry name + string name = 12; +} + +// Completable: Yes +// Fallible: Yes +// Type: 0x0C00 + 7 +message GetCallInvocationIdEntryMessage { + // Index of the call/one way call journal entry in this journal. + uint32 call_entry_index = 1; + + oneof result { + string value = 14; + Failure failure = 15; + }; + + string name = 12; +} + // --- Nested messages // This failure object carries user visible errors, diff --git a/service-protocol/service-invocation-protocol.md b/service-protocol/service-invocation-protocol.md index 22635bc..89e23cb 100644 --- a/service-protocol/service-invocation-protocol.md +++ b/service-protocol/service-invocation-protocol.md @@ -330,24 +330,26 @@ used for observability purposes by Restate observability tools. The following tables describe the currently available journal entries. For more details, check the protobuf message descriptions in [`protocol.proto`](dev/restate/service/protocol.proto). -| Message | Type | Completable | Fallible | Description | -| ------------------------------- | -------- | ----------- | -------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `InputEntryMessage` | `0x0400` | No | No | Carries the invocation input message(s) of the invocation. | -| `GetStateEntryMessage` | `0x0800` | Yes | No | Get the value of a service instance state key. | -| `GetStateKeysEntryMessage` | `0x0804` | Yes | No | Get all the known state keys for this service instance. Note: the completion value for this message is a protobuf of type `GetStateKeysEntryMessage.StateKeys`. | -| `SleepEntryMessage` | `0x0C00` | Yes | No | Initiate a timer that completes after the given time. | -| `CallEntryMessage` | `0x0C01` | Yes | Yes | Invoke another Restate service. | -| `AwakeableEntryMessage` | `0x0C03` | Yes | No | Arbitrary result container which can be completed from another service, given a specific id. See [Awakeable identifier](#awakeable-identifier) for more details. | -| `OneWayCallEntryMessage` | `0x0C02` | No | Yes | Invoke another Restate service at the given time, without waiting for the response. | -| `CompleteAwakeableEntryMessage` | `0x0C04` | No | Yes | Complete an `Awakeable`, given its id. See [Awakeable identifier](#awakeable-identifier) for more details. | -| `OutputEntryMessage` | `0x0401` | No | No | Carries the invocation output message(s) or terminal failure of the invocation. | -| `SetStateEntryMessage` | `0x0800` | No | No | Set the value of a service instance state key. | -| `ClearStateEntryMessage` | `0x0801` | No | No | Clear the value of a service instance state key. | -| `ClearAllStateEntryMessage` | `0x0802` | No | No | Clear all the values of the service instance state. | -| `RunEntryMessage` | `0x0C05` | No | No | Run non-deterministic user provided code and persist the result. | -| `GetPromiseEntryMessage` | `0x0808` | Yes | No | Get or wait the value of the given promise. If the value is not present yet, this entry will block waiting for the value. | -| `PeekPromiseEntryMessage` | `0x0809` | Yes | No | Get the value of the given promise. If the value is not present, this entry completes immediately with empty completion. | -| `CompletePromiseEntryMessage` | `0x080A` | Yes | No | Complete the given promise. If the promise was completed already, this entry completes with a failure. | +| Message | Type | Completable | Fallible | Description | +|-----------------------------------|----------|-------------|----------|------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| `InputEntryMessage` | `0x0400` | No | No | Carries the invocation input message(s) of the invocation. | +| `GetStateEntryMessage` | `0x0800` | Yes | No | Get the value of a service instance state key. | +| `GetStateKeysEntryMessage` | `0x0804` | Yes | No | Get all the known state keys for this service instance. Note: the completion value for this message is a protobuf of type `GetStateKeysEntryMessage.StateKeys`. | +| `SleepEntryMessage` | `0x0C00` | Yes | No | Initiate a timer that completes after the given time. | +| `CallEntryMessage` | `0x0C01` | Yes | Yes | Invoke another Restate service. | +| `AwakeableEntryMessage` | `0x0C03` | Yes | No | Arbitrary result container which can be completed from another service, given a specific id. See [Awakeable identifier](#awakeable-identifier) for more details. | +| `OneWayCallEntryMessage` | `0x0C02` | No | Yes | Invoke another Restate service at the given time, without waiting for the response. | +| `CompleteAwakeableEntryMessage` | `0x0C04` | No | Yes | Complete an `Awakeable`, given its id. See [Awakeable identifier](#awakeable-identifier) for more details. | +| `OutputEntryMessage` | `0x0401` | No | No | Carries the invocation output message(s) or terminal failure of the invocation. | +| `SetStateEntryMessage` | `0x0800` | No | No | Set the value of a service instance state key. | +| `ClearStateEntryMessage` | `0x0801` | No | No | Clear the value of a service instance state key. | +| `ClearAllStateEntryMessage` | `0x0802` | No | No | Clear all the values of the service instance state. | +| `RunEntryMessage` | `0x0C05` | No | No | Run non-deterministic user provided code and persist the result. | +| `GetPromiseEntryMessage` | `0x0808` | Yes | No | Get or wait the value of the given promise. If the value is not present yet, this entry will block waiting for the value. | +| `PeekPromiseEntryMessage` | `0x0809` | Yes | No | Get the value of the given promise. If the value is not present, this entry completes immediately with empty completion. | +| `CompletePromiseEntryMessage` | `0x080A` | Yes | No | Complete the given promise. If the promise was completed already, this entry completes with a failure. | +| `CancelInvocationEntryMessage` | `0x0C06` | No | Yes | Cancel the target invocation id or the target journal entry. | +| `GetCallInvocationIdEntryMessage` | `0x0C07` | Yes | Yes | Get the invocation id of a previously created call/one way call. | #### Awakeable identifier diff --git a/src/lib.rs b/src/lib.rs index 7183281..d88a1dc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -122,6 +122,7 @@ pub struct Target { pub service: String, pub handler: String, pub key: Option, + pub idempotency_key: Option, } #[derive(Debug, Hash, Clone, Copy, Eq, PartialEq)] @@ -139,6 +140,21 @@ impl From for u32 { } } +#[derive(Debug, Clone, Copy, Eq, PartialEq)] +pub struct SendHandle(u32); + +impl From for SendHandle { + fn from(value: u32) -> Self { + SendHandle(value) + } +} + +impl From for u32 { + fn from(value: SendHandle) -> Self { + value.0 + } +} + #[derive(Debug, Eq, PartialEq)] pub enum Value { /// a void/None/undefined success @@ -147,6 +163,8 @@ pub enum Value { Failure(TerminalFailure), /// Only returned for get_state_keys StateKeys(Vec), + /// Only returned for get_call_invocation_id + InvocationId(String), CombinatorResult(Vec), } @@ -196,6 +214,19 @@ impl From for Value { } } +#[derive(Debug, Eq, PartialEq)] +pub enum GetInvocationIdTarget { + CallEntry(AsyncResultHandle), + SendEntry(SendHandle), +} + +#[derive(Debug, Eq, PartialEq)] +pub enum CancelInvocationTarget { + InvocationId(String), + CallEntry(AsyncResultHandle), + SendEntry(SendHandle), +} + #[derive(Debug, Eq, PartialEq)] pub enum TakeOutputResult { Buffer(Bytes), @@ -274,7 +305,7 @@ pub trait VM: Sized { target: Target, input: Bytes, execution_time_since_unix_epoch: Option, - ) -> VMResult<()>; + ) -> VMResult; fn sys_awakeable(&mut self) -> VMResult<(String, AsyncResultHandle)>; @@ -298,6 +329,13 @@ pub trait VM: Sized { retry_policy: RetryPolicy, ) -> VMResult; + fn sys_get_call_invocation_id( + &mut self, + call: GetInvocationIdTarget, + ) -> VMResult; + + fn sys_cancel_invocation(&mut self, target: CancelInvocationTarget) -> VMResult<()>; + fn sys_write_output(&mut self, value: NonEmptyValue) -> VMResult<()>; fn sys_end(&mut self) -> VMResult<()>; diff --git a/src/service_protocol/generated/dev.restate.service.protocol.rs b/src/service_protocol/generated/dev.restate.service.protocol.rs index b16279d..43f35bd 100644 --- a/src/service_protocol/generated/dev.restate.service.protocol.rs +++ b/src/service_protocol/generated/dev.restate.service.protocol.rs @@ -361,6 +361,8 @@ pub struct CallEntryMessage { /// If this invocation has a key associated (e.g. for objects and workflows), then this key is filled in. Empty otherwise. #[prost(string, tag = "5")] pub key: ::prost::alloc::string::String, + #[prost(string, tag = "6")] + pub idempotency_key: ::prost::alloc::string::String, /// Entry name #[prost(string, tag = "12")] pub name: ::prost::alloc::string::String, @@ -399,6 +401,8 @@ pub struct OneWayCallEntryMessage { /// If this invocation has a key associated (e.g. for objects and workflows), then this key is filled in. Empty otherwise. #[prost(string, tag = "6")] pub key: ::prost::alloc::string::String, + #[prost(string, tag = "7")] + pub idempotency_key: ::prost::alloc::string::String, /// Entry name #[prost(string, tag = "12")] pub name: ::prost::alloc::string::String, @@ -471,6 +475,52 @@ pub mod run_entry_message { Failure(super::Failure), } } +/// Completable: No +/// Fallible: Yes +/// Type: 0x0C00 + 6 +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct CancelInvocationEntryMessage { + /// Entry name + #[prost(string, tag = "12")] + pub name: ::prost::alloc::string::String, + #[prost(oneof = "cancel_invocation_entry_message::Target", tags = "1, 2")] + pub target: ::core::option::Option, +} +/// Nested message and enum types in `CancelInvocationEntryMessage`. +pub mod cancel_invocation_entry_message { + #[derive(Clone, PartialEq, ::prost::Oneof)] + pub enum Target { + /// Target invocation id to cancel + #[prost(string, tag = "1")] + InvocationId(::prost::alloc::string::String), + /// Target index of the call/one way call journal entry in this journal. + #[prost(uint32, tag = "2")] + CallEntryIndex(u32), + } +} +/// Completable: Yes +/// Fallible: Yes +/// Type: 0x0C00 + 7 +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct GetCallInvocationIdEntryMessage { + /// Index of the call/one way call journal entry in this journal. + #[prost(uint32, tag = "1")] + pub call_entry_index: u32, + #[prost(string, tag = "12")] + pub name: ::prost::alloc::string::String, + #[prost(oneof = "get_call_invocation_id_entry_message::Result", tags = "14, 15")] + pub result: ::core::option::Option, +} +/// Nested message and enum types in `GetCallInvocationIdEntryMessage`. +pub mod get_call_invocation_id_entry_message { + #[derive(Clone, PartialEq, ::prost::Oneof)] + pub enum Result { + #[prost(string, tag = "14")] + Value(::prost::alloc::string::String), + #[prost(message, tag = "15")] + Failure(super::Failure), + } +} /// This failure object carries user visible errors, /// e.g. invocation failure return value or failure result of an InvokeEntryMessage. #[derive(Clone, PartialEq, ::prost::Message)] @@ -501,6 +551,11 @@ pub enum ServiceProtocolVersion { /// Added /// * Entry retry mechanism: ErrorMessage.next_retry_delay, StartMessage.retry_count_since_last_stored_entry and StartMessage.duration_since_last_stored_entry V2 = 2, + /// Added + /// * New entry to cancel invocations: CancelInvocationEntryMessage + /// * New entry to retrieve the invocation id: GetCallInvocationIdEntryMessage + /// * New field to set idempotency key for Call entries + V3 = 3, } impl ServiceProtocolVersion { /// String value of the enum field names used in the ProtoBuf definition. @@ -512,6 +567,7 @@ impl ServiceProtocolVersion { Self::Unspecified => "SERVICE_PROTOCOL_VERSION_UNSPECIFIED", Self::V1 => "V1", Self::V2 => "V2", + Self::V3 => "V3", } } /// Creates an enum from field names used in the ProtoBuf definition. @@ -520,6 +576,7 @@ impl ServiceProtocolVersion { "SERVICE_PROTOCOL_VERSION_UNSPECIFIED" => Some(Self::Unspecified), "V1" => Some(Self::V1), "V2" => Some(Self::V2), + "V3" => Some(Self::V3), _ => None, } } diff --git a/src/service_protocol/header.rs b/src/service_protocol/header.rs index 7d00ab0..0c1c09f 100644 --- a/src/service_protocol/header.rs +++ b/src/service_protocol/header.rs @@ -8,83 +8,124 @@ // the Business Source License, use of this software will be governed // by the Apache License, Version 2.0. -const CUSTOM_MESSAGE_MASK: u16 = 0xFC00; +const CUSTOM_ENTRY_MASK: u16 = 0xFC00; const COMPLETED_MASK: u64 = 0x0001_0000_0000; const REQUIRES_ACK_MASK: u64 = 0x8000_0000_0000; type MessageTypeId = u16; -#[derive(Copy, Clone, Debug, PartialEq, Eq)] -pub enum MessageKind { - Core, - IO, - State, - Syscall, - CustomEntry, -} +#[derive(Debug, thiserror::Error)] +#[error("unknown protocol.message code {0:#x}")] +pub struct UnknownMessageType(u16); -#[derive(Copy, Clone, Debug, PartialEq, Eq)] -pub enum MessageType { - Start, - Completion, - Suspension, - Error, - End, - EntryAck, - InputEntry, - OutputEntry, - GetStateEntry, - SetStateEntry, - ClearStateEntry, - GetStateKeysEntry, - ClearAllStateEntry, - SleepEntry, - CallEntry, - OneWayCallEntry, - AwakeableEntry, - CompleteAwakeableEntry, - RunEntry, - GetPromiseEntry, - PeekPromiseEntry, - CompletePromiseEntry, - CombinatorEntry, - CustomEntry(u16), -} +// This macro generates the MessageKind enum, together with the conversions back and forth to MessageTypeId +macro_rules! gen_message_type_enum { + (@gen_enum [] -> [$($body:tt)*]) => { + #[derive(Copy, Clone, Debug, PartialEq, Eq)] + pub enum MessageType { + $($body)* + CustomEntry(u16) + } + }; + (@gen_enum [$variant:ident Entry = $id:literal, $($tail:tt)*] -> [$($body:tt)*]) => { + paste::paste! { gen_message_type_enum!(@gen_enum [$($tail)*] -> [[<$variant Entry>], $($body)*]); } + }; + (@gen_enum [$variant:ident = $id:literal, $($tail:tt)*] -> [$($body:tt)*]) => { + gen_message_type_enum!(@gen_enum [$($tail)*] -> [$variant, $($body)*]); + }; -impl MessageType { - fn kind(&self) -> MessageKind { - match self { - MessageType::Start => MessageKind::Core, - MessageType::Completion => MessageKind::Core, - MessageType::Suspension => MessageKind::Core, - MessageType::Error => MessageKind::Core, - MessageType::End => MessageKind::Core, - MessageType::EntryAck => MessageKind::Core, - MessageType::InputEntry => MessageKind::IO, - MessageType::OutputEntry => MessageKind::IO, - MessageType::GetStateEntry => MessageKind::State, - MessageType::SetStateEntry => MessageKind::State, - MessageType::ClearStateEntry => MessageKind::State, - MessageType::GetStateKeysEntry => MessageKind::State, - MessageType::ClearAllStateEntry => MessageKind::State, - MessageType::SleepEntry => MessageKind::Syscall, - MessageType::CallEntry => MessageKind::Syscall, - MessageType::OneWayCallEntry => MessageKind::Syscall, - MessageType::AwakeableEntry => MessageKind::Syscall, - MessageType::CompleteAwakeableEntry => MessageKind::Syscall, - MessageType::RunEntry => MessageKind::Syscall, - MessageType::GetPromiseEntry => MessageKind::State, - MessageType::PeekPromiseEntry => MessageKind::State, - MessageType::CompletePromiseEntry => MessageKind::State, - MessageType::CombinatorEntry => MessageKind::Syscall, - MessageType::CustomEntry(_) => MessageKind::CustomEntry, + (@gen_is_entry_impl [] -> [$($variant:ident, $is_entry:literal,)*]) => { + impl MessageType { + pub fn is_entry(&self) -> bool { + match self { + $(MessageType::$variant => $is_entry,)* + MessageType::CustomEntry(_) => true + } + } } - } + }; + (@gen_is_entry_impl [$variant:ident Entry = $id:literal, $($tail:tt)*] -> [$($body:tt)*]) => { + paste::paste! { gen_message_type_enum!(@gen_is_entry_impl [$($tail)*] -> [[<$variant Entry>], true, $($body)*]); } + }; + (@gen_is_entry_impl [$variant:ident = $id:literal, $($tail:tt)*] -> [$($body:tt)*]) => { + gen_message_type_enum!(@gen_is_entry_impl [$($tail)*] -> [$variant, false, $($body)*]); + }; - pub fn is_entry(&self) -> bool { - !matches!(self.kind(), MessageKind::Core) - } + (@gen_to_id [] -> [$($variant:ident, $id:literal,)*]) => { + impl From for MessageTypeId { + fn from(mt: MessageType) -> Self { + match mt { + $(MessageType::$variant => $id,)* + MessageType::CustomEntry(id) => id + } + } + } + }; + (@gen_to_id [$variant:ident Entry = $id:literal, $($tail:tt)*] -> [$($body:tt)*]) => { + paste::paste! { gen_message_type_enum!(@gen_to_id [$($tail)*] -> [[<$variant Entry>], $id, $($body)*]); } + }; + (@gen_to_id [$variant:ident = $id:literal, $($tail:tt)*] -> [$($body:tt)*]) => { + gen_message_type_enum!(@gen_to_id [$($tail)*] -> [$variant, $id, $($body)*]); + }; + + (@gen_from_id [] -> [$($variant:ident, $id:literal,)*]) => { + impl TryFrom for MessageType { + type Error = UnknownMessageType; + + fn try_from(value: MessageTypeId) -> Result { + match value { + $($id => Ok(MessageType::$variant),)* + v if (v & CUSTOM_ENTRY_MASK) != 0 => Ok(MessageType::CustomEntry(v)), + v => Err(UnknownMessageType(v)), + } + } + } + }; + (@gen_from_id [$variant:ident Entry = $id:literal, $($tail:tt)*] -> [$($body:tt)*]) => { + paste::paste! { gen_message_type_enum!(@gen_from_id [$($tail)*] -> [[<$variant Entry>], $id, $($body)*]); } + }; + (@gen_from_id [$variant:ident = $id:literal, $($tail:tt)*] -> [$($body:tt)*]) => { + gen_message_type_enum!(@gen_from_id [$($tail)*] -> [$variant, $id, $($body)*]); + }; + + // Entrypoint of the macro + ($($tokens:tt)*) => { + gen_message_type_enum!(@gen_enum [$($tokens)*] -> []); + gen_message_type_enum!(@gen_is_entry_impl [$($tokens)*] -> []); + gen_message_type_enum!(@gen_to_id [$($tokens)*] -> []); + gen_message_type_enum!(@gen_from_id [$($tokens)*] -> []); + }; +} +gen_message_type_enum!( + Start = 0x0000, + Completion = 0x0001, + Suspension = 0x0002, + Error = 0x0003, + End = 0x0005, + EntryAck = 0x0004, + Input Entry = 0x0400, + Output Entry = 0x0401, + GetState Entry = 0x0800, + SetState Entry = 0x0801, + ClearState Entry = 0x0802, + GetStateKeys Entry = 0x0804, + ClearAllState Entry = 0x0803, + GetPromise Entry = 0x0808, + PeekPromise Entry = 0x0809, + CompletePromise Entry = 0x080A, + Sleep Entry = 0x0C00, + Call Entry = 0x0C01, + OneWayCall Entry = 0x0C02, + Awakeable Entry = 0x0C03, + CompleteAwakeable Entry = 0x0C04, + Run Entry = 0x0C05, + CancelInvocation Entry = 0x0C06, + GetCallInvocationId Entry = 0x0C07, + Combinator Entry = 0xFC02, +); + +impl MessageType { fn has_completed_flag(&self) -> bool { matches!( self, @@ -96,108 +137,9 @@ impl MessageType { | MessageType::GetPromiseEntry | MessageType::PeekPromiseEntry | MessageType::CompletePromiseEntry + | MessageType::GetCallInvocationIdEntry ) } - - fn has_requires_ack_flag(&self) -> bool { - matches!( - self.kind(), - MessageKind::State | MessageKind::IO | MessageKind::Syscall | MessageKind::CustomEntry - ) - } -} - -const START_MESSAGE_TYPE: u16 = 0x0000; -const COMPLETION_MESSAGE_TYPE: u16 = 0x0001; -const SUSPENSION_MESSAGE_TYPE: u16 = 0x0002; -const ERROR_MESSAGE_TYPE: u16 = 0x0003; -const ENTRY_ACK_MESSAGE_TYPE: u16 = 0x0004; -const END_MESSAGE_TYPE: u16 = 0x0005; -const INPUT_ENTRY_MESSAGE_TYPE: u16 = 0x0400; -const OUTPUT_ENTRY_MESSAGE_TYPE: u16 = 0x0401; -const GET_STATE_ENTRY_MESSAGE_TYPE: u16 = 0x0800; -const SET_STATE_ENTRY_MESSAGE_TYPE: u16 = 0x0801; -const CLEAR_STATE_ENTRY_MESSAGE_TYPE: u16 = 0x0802; -const CLEAR_ALL_STATE_ENTRY_MESSAGE_TYPE: u16 = 0x0803; -const GET_STATE_KEYS_ENTRY_MESSAGE_TYPE: u16 = 0x0804; -const GET_PROMISE_ENTRY_MESSAGE_TYPE: u16 = 0x0808; -const PEEK_PROMISE_ENTRY_MESSAGE_TYPE: u16 = 0x0809; -const COMPLETE_PROMISE_ENTRY_MESSAGE_TYPE: u16 = 0x080A; -const SLEEP_ENTRY_MESSAGE_TYPE: u16 = 0x0C00; -const INVOKE_ENTRY_MESSAGE_TYPE: u16 = 0x0C01; -const BACKGROUND_INVOKE_ENTRY_MESSAGE_TYPE: u16 = 0x0C02; -const AWAKEABLE_ENTRY_MESSAGE_TYPE: u16 = 0x0C03; -const COMPLETE_AWAKEABLE_ENTRY_MESSAGE_TYPE: u16 = 0x0C04; -const SIDE_EFFECT_ENTRY_MESSAGE_TYPE: u16 = 0x0C05; -const COMBINATOR_ENTRY_MESSAGE_TYPE: u16 = 0xFC02; - -impl From for MessageTypeId { - fn from(mt: MessageType) -> Self { - match mt { - MessageType::Start => START_MESSAGE_TYPE, - MessageType::Completion => COMPLETION_MESSAGE_TYPE, - MessageType::Suspension => SUSPENSION_MESSAGE_TYPE, - MessageType::Error => ERROR_MESSAGE_TYPE, - MessageType::End => END_MESSAGE_TYPE, - MessageType::EntryAck => ENTRY_ACK_MESSAGE_TYPE, - MessageType::InputEntry => INPUT_ENTRY_MESSAGE_TYPE, - MessageType::OutputEntry => OUTPUT_ENTRY_MESSAGE_TYPE, - MessageType::GetStateEntry => GET_STATE_ENTRY_MESSAGE_TYPE, - MessageType::SetStateEntry => SET_STATE_ENTRY_MESSAGE_TYPE, - MessageType::ClearStateEntry => CLEAR_STATE_ENTRY_MESSAGE_TYPE, - MessageType::ClearAllStateEntry => CLEAR_ALL_STATE_ENTRY_MESSAGE_TYPE, - MessageType::GetStateKeysEntry => GET_STATE_KEYS_ENTRY_MESSAGE_TYPE, - MessageType::SleepEntry => SLEEP_ENTRY_MESSAGE_TYPE, - MessageType::CallEntry => INVOKE_ENTRY_MESSAGE_TYPE, - MessageType::OneWayCallEntry => BACKGROUND_INVOKE_ENTRY_MESSAGE_TYPE, - MessageType::AwakeableEntry => AWAKEABLE_ENTRY_MESSAGE_TYPE, - MessageType::CompleteAwakeableEntry => COMPLETE_AWAKEABLE_ENTRY_MESSAGE_TYPE, - MessageType::RunEntry => SIDE_EFFECT_ENTRY_MESSAGE_TYPE, - MessageType::GetPromiseEntry => GET_PROMISE_ENTRY_MESSAGE_TYPE, - MessageType::PeekPromiseEntry => PEEK_PROMISE_ENTRY_MESSAGE_TYPE, - MessageType::CompletePromiseEntry => COMPLETE_PROMISE_ENTRY_MESSAGE_TYPE, - MessageType::CombinatorEntry => COMBINATOR_ENTRY_MESSAGE_TYPE, - MessageType::CustomEntry(id) => id, - } - } -} - -#[derive(Debug, thiserror::Error)] -#[error("unknown protocol.message code {0:#x}")] -pub struct UnknownMessageType(u16); - -impl TryFrom for MessageType { - type Error = UnknownMessageType; - - fn try_from(value: MessageTypeId) -> Result { - match value { - START_MESSAGE_TYPE => Ok(MessageType::Start), - COMPLETION_MESSAGE_TYPE => Ok(MessageType::Completion), - SUSPENSION_MESSAGE_TYPE => Ok(MessageType::Suspension), - ERROR_MESSAGE_TYPE => Ok(MessageType::Error), - END_MESSAGE_TYPE => Ok(MessageType::End), - ENTRY_ACK_MESSAGE_TYPE => Ok(MessageType::EntryAck), - INPUT_ENTRY_MESSAGE_TYPE => Ok(MessageType::InputEntry), - OUTPUT_ENTRY_MESSAGE_TYPE => Ok(MessageType::OutputEntry), - GET_STATE_ENTRY_MESSAGE_TYPE => Ok(MessageType::GetStateEntry), - SET_STATE_ENTRY_MESSAGE_TYPE => Ok(MessageType::SetStateEntry), - CLEAR_STATE_ENTRY_MESSAGE_TYPE => Ok(MessageType::ClearStateEntry), - GET_STATE_KEYS_ENTRY_MESSAGE_TYPE => Ok(MessageType::GetStateKeysEntry), - CLEAR_ALL_STATE_ENTRY_MESSAGE_TYPE => Ok(MessageType::ClearAllStateEntry), - SLEEP_ENTRY_MESSAGE_TYPE => Ok(MessageType::SleepEntry), - INVOKE_ENTRY_MESSAGE_TYPE => Ok(MessageType::CallEntry), - BACKGROUND_INVOKE_ENTRY_MESSAGE_TYPE => Ok(MessageType::OneWayCallEntry), - AWAKEABLE_ENTRY_MESSAGE_TYPE => Ok(MessageType::AwakeableEntry), - COMPLETE_AWAKEABLE_ENTRY_MESSAGE_TYPE => Ok(MessageType::CompleteAwakeableEntry), - GET_PROMISE_ENTRY_MESSAGE_TYPE => Ok(MessageType::GetPromiseEntry), - PEEK_PROMISE_ENTRY_MESSAGE_TYPE => Ok(MessageType::PeekPromiseEntry), - COMPLETE_PROMISE_ENTRY_MESSAGE_TYPE => Ok(MessageType::CompletePromiseEntry), - SIDE_EFFECT_ENTRY_MESSAGE_TYPE => Ok(MessageType::RunEntry), - COMBINATOR_ENTRY_MESSAGE_TYPE => Ok(MessageType::CombinatorEntry), - v if ((v & CUSTOM_MESSAGE_MASK) != 0) => Ok(MessageType::CustomEntry(v)), - v => Err(UnknownMessageType(v)), - } - } } #[derive(Copy, Clone, Debug, PartialEq, Eq)] @@ -262,11 +204,6 @@ impl MessageHeader { } } - #[inline] - pub fn message_kind(&self) -> MessageKind { - self.ty.kind() - } - #[inline] pub fn message_type(&self) -> MessageType { self.ty @@ -308,7 +245,7 @@ impl TryFrom for MessageHeader { let ty: MessageType = ty_code.try_into()?; let completed_flag = read_flag_if!(ty.has_completed_flag(), value, COMPLETED_MASK); - let requires_ack_flag = read_flag_if!(ty.has_requires_ack_flag(), value, REQUIRES_ACK_MASK); + let requires_ack_flag = read_flag_if!(ty.is_entry(), value, REQUIRES_ACK_MASK); let length = value as u32; Ok(MessageHeader::_new( @@ -349,7 +286,7 @@ impl From for u64 { #[cfg(test)] mod tests { - use super::{MessageKind::*, MessageType::*, *}; + use super::{MessageType::*, *}; impl MessageHeader { fn new_completable_entry(ty: MessageType, completed: bool, length: u32) -> Self { @@ -358,65 +295,52 @@ mod tests { } macro_rules! roundtrip_test { - ($test_name:ident, $header:expr, $ty:expr, $kind:expr, $len:expr) => { - roundtrip_test!($test_name, $header, $ty, $kind, $len, None, None, None); + ($test_name:ident, $header:expr, $ty:expr, $len:expr) => { + roundtrip_test!($test_name, $header, $ty, $len, None, None, None); }; - ($test_name:ident, $header:expr, $ty:expr, $kind:expr, $len:expr, version: $protocol_version:expr) => { + ($test_name:ident, $header:expr, $ty:expr, $len:expr, version: $protocol_version:expr) => { roundtrip_test!( $test_name, $header, $ty, - $kind, $len, None, Some($protocol_version), None ); }; - ($test_name:ident, $header:expr, $ty:expr, $kind:expr, $len:expr, completed: $completed:expr) => { - roundtrip_test!( - $test_name, - $header, - $ty, - $kind, - $len, - Some($completed), - None, - None - ); + ($test_name:ident, $header:expr, $ty:expr, $len:expr, completed: $completed:expr) => { + roundtrip_test!($test_name, $header, $ty, $len, Some($completed), None, None); }; - ($test_name:ident, $header:expr, $ty:expr, $kind:expr, $len:expr, requires_ack: $requires_ack:expr) => { + ($test_name:ident, $header:expr, $ty:expr, $len:expr, requires_ack: $requires_ack:expr) => { roundtrip_test!( $test_name, $header, $ty, - $kind, $len, None, None, Some($requires_ack) ); }; - ($test_name:ident, $header:expr, $ty:expr, $kind:expr, $len:expr, requires_ack: $requires_ack:expr, completed: $completed:expr) => { + ($test_name:ident, $header:expr, $ty:expr, $len:expr, requires_ack: $requires_ack:expr, completed: $completed:expr) => { roundtrip_test!( $test_name, $header, $ty, - $kind, $len, Some($completed), None, Some($requires_ack) ); }; - ($test_name:ident, $header:expr, $ty:expr, $kind:expr, $len:expr, $completed:expr, $protocol_version:expr, $requires_ack:expr) => { + ($test_name:ident, $header:expr, $ty:expr, $len:expr, $completed:expr, $protocol_version:expr, $requires_ack:expr) => { #[test] fn $test_name() { let serialized: u64 = $header.into(); let header: MessageHeader = serialized.try_into().unwrap(); assert_eq!(header.message_type(), $ty); - assert_eq!(header.message_kind(), $kind); assert_eq!(header.completed(), $completed); assert_eq!(header.requires_ack(), $requires_ack); assert_eq!(header.frame_length(), $len); @@ -428,7 +352,6 @@ mod tests { completion, MessageHeader::new(Completion, 22), Completion, - Core, 22 ); @@ -436,7 +359,6 @@ mod tests { completed_get_state, MessageHeader::new_completable_entry(GetStateEntry, true, 0), GetStateEntry, - State, 0, requires_ack: false, completed: true @@ -446,7 +368,6 @@ mod tests { not_completed_get_state, MessageHeader::new_completable_entry(GetStateEntry, false, 0), GetStateEntry, - State, 0, requires_ack: false, completed: false @@ -456,7 +377,6 @@ mod tests { completed_get_state_with_len, MessageHeader::new_completable_entry(GetStateEntry, true, 10341), GetStateEntry, - State, 10341, requires_ack: false, completed: true @@ -466,7 +386,6 @@ mod tests { set_state_with_requires_ack, MessageHeader::_new(SetStateEntry, None, Some(true), 10341), SetStateEntry, - State, 10341, requires_ack: true ); @@ -475,7 +394,6 @@ mod tests { custom_entry, MessageHeader::new(MessageType::CustomEntry(0xFC00), 10341), MessageType::CustomEntry(0xFC00), - MessageKind::CustomEntry, 10341, requires_ack: false ); @@ -484,7 +402,6 @@ mod tests { custom_entry_with_requires_ack, MessageHeader::_new(MessageType::CustomEntry(0xFC00), None, Some(true), 10341), MessageType::CustomEntry(0xFC00), - MessageKind::CustomEntry, 10341, requires_ack: true ); diff --git a/src/service_protocol/messages.rs b/src/service_protocol/messages.rs index d1da271..d5b5e70 100644 --- a/src/service_protocol/messages.rs +++ b/src/service_protocol/messages.rs @@ -1,6 +1,9 @@ use crate::service_protocol::messages::get_state_keys_entry_message::StateKeys; use crate::service_protocol::{MessageHeader, MessageType}; -use crate::vm::errors::{DecodeStateKeysProst, DecodeStateKeysUtf8, EmptyStateKeys}; +use crate::vm::errors::{ + DecodeGetCallInvocationIdUtf8, DecodeStateKeysProst, DecodeStateKeysUtf8, + EmptyGetCallInvocationId, EmptyStateKeys, +}; use crate::{Error, NonEmptyValue, Value}; use paste::paste; use prost::Message; @@ -234,6 +237,29 @@ impl EntryMessageHeaderEq for RunEntryMessage { } } +impl_message_traits!(CancelInvocationEntry: non_completable_entry); + +impl_message_traits!(GetCallInvocationIdEntry: message); +impl_message_traits!(GetCallInvocationIdEntry: entry); +impl CompletableEntryMessage for GetCallInvocationIdEntryMessage { + fn is_completed(&self) -> bool { + self.result.is_some() + } + + fn into_completion(self) -> Result, Error> { + self.result.map(TryInto::try_into).transpose() + } + + fn completion_parsing_hint() -> CompletionParsingHint { + CompletionParsingHint::GetCompletionId + } +} +impl EntryMessageHeaderEq for GetCallInvocationIdEntryMessage { + fn header_eq(&self, other: &Self) -> bool { + self.call_entry_index == other.call_entry_index + } +} + impl_message_traits!(CombinatorEntry: message); impl_message_traits!(CombinatorEntry: entry); impl WriteableRestateMessage for CombinatorEntryMessage { @@ -361,6 +387,17 @@ impl From for NonEmptyValue { } } +impl TryFrom for Value { + type Error = Error; + + fn try_from(value: get_call_invocation_id_entry_message::Result) -> Result { + Ok(match value { + get_call_invocation_id_entry_message::Result::Value(id) => Value::InvocationId(id), + get_call_invocation_id_entry_message::Result::Failure(f) => Value::Failure(f.into()), + }) + } +} + // --- Other conversions impl From for Failure { @@ -386,6 +423,7 @@ impl From for crate::TerminalFailure { #[derive(Debug)] pub(crate) enum CompletionParsingHint { StateKeys, + GetCompletionId, /// The normal case EmptyOrSuccessOrValue, } @@ -408,6 +446,13 @@ impl CompletionParsingHint { } completion_message::Result::Failure(f) => Ok(Value::Failure(f.into())), }, + CompletionParsingHint::GetCompletionId => match result { + completion_message::Result::Empty(_) => Err(EmptyGetCallInvocationId.into()), + completion_message::Result::Value(b) => Ok(Value::InvocationId( + String::from_utf8(b.to_vec()).map_err(DecodeGetCallInvocationIdUtf8)?, + )), + completion_message::Result::Failure(f) => Ok(Value::Failure(f.into())), + }, CompletionParsingHint::EmptyOrSuccessOrValue => Ok(match result { completion_message::Result::Empty(_) => Value::Void, completion_message::Result::Value(b) => Value::Success(b), diff --git a/src/service_protocol/version.rs b/src/service_protocol/version.rs index f35ec94..1f71e36 100644 --- a/src/service_protocol/version.rs +++ b/src/service_protocol/version.rs @@ -5,16 +5,19 @@ use std::str::FromStr; pub enum Version { V1 = 1, V2 = 2, + V3 = 3, } const CONTENT_TYPE_V1: &str = "application/vnd.restate.invocation.v1"; const CONTENT_TYPE_V2: &str = "application/vnd.restate.invocation.v2"; +const CONTENT_TYPE_V3: &str = "application/vnd.restate.invocation.v3"; impl Version { pub const fn content_type(&self) -> &'static str { match self { Version::V1 => CONTENT_TYPE_V1, Version::V2 => CONTENT_TYPE_V2, + Version::V3 => CONTENT_TYPE_V3, } } @@ -23,7 +26,7 @@ impl Version { } pub const fn maximum_supported_version() -> Self { - Version::V2 + Version::V3 } } @@ -44,6 +47,7 @@ impl FromStr for Version { match s { CONTENT_TYPE_V1 => Ok(Version::V1), CONTENT_TYPE_V2 => Ok(Version::V2), + CONTENT_TYPE_V3 => Ok(Version::V3), s => Err(UnsupportedVersionError(s.to_owned())), } } diff --git a/src/tests/async_result.rs b/src/tests/async_result.rs index 34804f5..e0fafeb 100644 --- a/src/tests/async_result.rs +++ b/src/tests/async_result.rs @@ -10,6 +10,7 @@ fn greeter_target() -> Target { service: "Greeter".to_string(), handler: "greeter".to_string(), key: None, + idempotency_key: None, } } diff --git a/src/tests/calls.rs b/src/tests/calls.rs new file mode 100644 index 0000000..d43b713 --- /dev/null +++ b/src/tests/calls.rs @@ -0,0 +1,181 @@ +use super::*; + +use crate::service_protocol::messages::*; +use assert2::let_assert; +use googletest::prelude::*; +use test_log::test; + +#[test] +fn call_then_get_invocation_id_then_cancel_invocation() { + let mut output = VMTestCase::new() + .input(start_message(1)) + .input(input_entry_message(b"my-data")) + .input(CompletionMessage { + entry_index: 2, + result: Some(completion_message::Result::Value(Bytes::from_static( + b"my-id", + ))), + }) + .run(|vm| { + vm.sys_input().unwrap(); + + let call_handle = vm + .sys_call( + Target { + service: "MySvc".to_string(), + handler: "MyHandler".to_string(), + key: None, + idempotency_key: None, + }, + Bytes::new(), + ) + .unwrap(); + + let invocation_id_handle = vm + .sys_get_call_invocation_id(GetInvocationIdTarget::CallEntry(call_handle)) + .unwrap(); + vm.notify_await_point(invocation_id_handle); + let_assert!( + Some(Value::InvocationId(invocation_id)) = + vm.take_async_result(invocation_id_handle).unwrap() + ); + assert_eq!(invocation_id, "my-id"); + + vm.sys_cancel_invocation(CancelInvocationTarget::CallEntry(call_handle)) + .unwrap(); + vm.sys_cancel_invocation(CancelInvocationTarget::InvocationId(invocation_id.clone())) + .unwrap(); + + vm.sys_end().unwrap(); + }); + + assert_that!( + output.next_decoded::().unwrap(), + pat!(CallEntryMessage { + service_name: eq("MySvc"), + handler_name: eq("MyHandler") + }) + ); + assert_eq!( + output + .next_decoded::() + .unwrap(), + GetCallInvocationIdEntryMessage { + call_entry_index: 1, + ..Default::default() + } + ); + assert_eq!( + output + .next_decoded::() + .unwrap(), + CancelInvocationEntryMessage { + target: Some(cancel_invocation_entry_message::Target::CallEntryIndex(1)), + ..Default::default() + } + ); + assert_eq!( + output + .next_decoded::() + .unwrap(), + CancelInvocationEntryMessage { + target: Some(cancel_invocation_entry_message::Target::InvocationId( + "my-id".to_string() + )), + ..Default::default() + } + ); + assert_eq!( + output.next_decoded::().unwrap(), + EndMessage::default() + ); + assert_eq!(output.next(), None); +} + +#[test] +fn send_then_get_invocation_id_then_cancel_invocation() { + let mut output = VMTestCase::new() + .input(start_message(1)) + .input(input_entry_message(b"my-data")) + .input(CompletionMessage { + entry_index: 2, + result: Some(completion_message::Result::Value(Bytes::from_static( + b"my-id", + ))), + }) + .run(|vm| { + vm.sys_input().unwrap(); + + let send_handle = vm + .sys_send( + Target { + service: "MySvc".to_string(), + handler: "MyHandler".to_string(), + key: None, + idempotency_key: None, + }, + Bytes::new(), + None, + ) + .unwrap(); + + let invocation_id_handle = vm + .sys_get_call_invocation_id(GetInvocationIdTarget::SendEntry(send_handle)) + .unwrap(); + vm.notify_await_point(invocation_id_handle); + let_assert!( + Some(Value::InvocationId(invocation_id)) = + vm.take_async_result(invocation_id_handle).unwrap() + ); + assert_eq!(invocation_id, "my-id"); + + vm.sys_cancel_invocation(CancelInvocationTarget::SendEntry(send_handle)) + .unwrap(); + vm.sys_cancel_invocation(CancelInvocationTarget::InvocationId(invocation_id.clone())) + .unwrap(); + + vm.sys_end().unwrap(); + }); + + assert_that!( + output.next_decoded::().unwrap(), + pat!(OneWayCallEntryMessage { + service_name: eq("MySvc"), + handler_name: eq("MyHandler") + }) + ); + assert_eq!( + output + .next_decoded::() + .unwrap(), + GetCallInvocationIdEntryMessage { + call_entry_index: 1, + ..Default::default() + } + ); + assert_eq!( + output + .next_decoded::() + .unwrap(), + CancelInvocationEntryMessage { + target: Some(cancel_invocation_entry_message::Target::CallEntryIndex(1)), + ..Default::default() + } + ); + assert_eq!( + output + .next_decoded::() + .unwrap(), + CancelInvocationEntryMessage { + target: Some(cancel_invocation_entry_message::Target::InvocationId( + "my-id".to_string() + )), + ..Default::default() + } + ); + assert_eq!( + output.next_decoded::().unwrap(), + EndMessage::default() + ); + assert_eq!(output.next(), None); +} diff --git a/src/tests/failures.rs b/src/tests/failures.rs index 2baf381..f0b4608 100644 --- a/src/tests/failures.rs +++ b/src/tests/failures.rs @@ -75,6 +75,7 @@ fn one_way_call_entry_mismatch() { service: "greeter".to_owned(), handler: "greet".to_owned(), key: Some("my-key".to_owned()), + idempotency_key: None, }, Bytes::from_static(b"456"), None, diff --git a/src/tests/mod.rs b/src/tests/mod.rs index 659b7d3..12a9eb5 100644 --- a/src/tests/mod.rs +++ b/src/tests/mod.rs @@ -1,4 +1,5 @@ mod async_result; +mod calls; mod failures; mod get_state; mod input_output; diff --git a/src/tests/state.rs b/src/tests/state.rs index d62911d..ee28a4b 100644 --- a/src/tests/state.rs +++ b/src/tests/state.rs @@ -25,7 +25,7 @@ fn get_state_handler(vm: &mut CoreVM) { vm.sys_end().unwrap(); return; } - _ => panic!("Unexpected variant"), + _ => panic!("Unexpected variants"), }; vm.sys_write_output(NonEmptyValue::Success(Bytes::copy_from_slice( @@ -374,7 +374,7 @@ mod eager { vm.sys_end().unwrap(); return; } - _ => panic!("Unexpected variant"), + _ => panic!("Unexpected variants"), }; vm.sys_write_output(NonEmptyValue::Success(Bytes::copy_from_slice( @@ -619,7 +619,7 @@ mod eager { vm.sys_end().unwrap(); return; } - _ => panic!("Unexpected variant"), + _ => panic!("Unexpected variants"), }; vm.sys_state_set( @@ -644,7 +644,7 @@ mod eager { vm.sys_end().unwrap(); return; } - _ => panic!("Unexpected variant"), + _ => panic!("Unexpected variants"), }; vm.sys_write_output(NonEmptyValue::Success(second_get_result)) @@ -799,7 +799,7 @@ mod eager { vm.sys_end().unwrap(); return; } - _ => panic!("Unexpected variant"), + _ => panic!("Unexpected variants"), }; vm.sys_state_clear("STATE".to_owned()).unwrap(); @@ -958,7 +958,7 @@ mod eager { vm.sys_end().unwrap(); return; } - _ => panic!("Unexpected variant"), + _ => panic!("Unexpected variants"), }; vm.sys_state_clear_all().unwrap(); diff --git a/src/vm/context.rs b/src/vm/context.rs index 5532f19..b8c9fc6 100644 --- a/src/vm/context.rs +++ b/src/vm/context.rs @@ -193,6 +193,7 @@ impl AsyncResultsState { Value::Void | Value::Success(_) | Value::StateKeys(_) + | Value::InvocationId(_) | Value::CombinatorResult(_) => AsyncResultState::Success, Value::Failure(_) => AsyncResultState::Failure, }, diff --git a/src/vm/errors.rs b/src/vm/errors.rs index a2be1dc..6cc45d1 100644 --- a/src/vm/errors.rs +++ b/src/vm/errors.rs @@ -1,5 +1,5 @@ use crate::service_protocol::{DecodingError, MessageType, UnsupportedVersionError}; -use crate::Error; +use crate::{Error, Version}; use std::borrow::Cow; use std::fmt; @@ -62,6 +62,7 @@ pub mod codes { pub const JOURNAL_MISMATCH: InvocationErrorCode = InvocationErrorCode(570); pub const PROTOCOL_VIOLATION: InvocationErrorCode = InvocationErrorCode(571); pub const AWAITING_TWO_ASYNC_RESULTS: InvocationErrorCode = InvocationErrorCode(572); + pub const UNSUPPORTED_FEATURE: InvocationErrorCode = InvocationErrorCode(573); } // Const errors @@ -196,6 +197,36 @@ pub struct DecodeStateKeysUtf8(#[from] pub(crate) std::string::FromUtf8Error); #[error("Unexpected empty value variant for state keys")] pub struct EmptyStateKeys; +#[derive(Debug, Clone, thiserror::Error)] +#[error("Unexpected empty variant for get call invocation id")] +pub struct EmptyGetCallInvocationId; + +#[derive(Debug, Clone, thiserror::Error)] +#[error("Cannot decode get call invocation id: {0}")] +pub struct DecodeGetCallInvocationIdUtf8(#[from] pub(crate) std::string::FromUtf8Error); + +#[derive(Debug, thiserror::Error)] +#[error("Feature {feature} is not supported by the negotiated protocol version '{current_version}', the minimum required version is '{minimum_required_version}'")] +pub struct UnsupportedFeatureForNegotiatedVersion { + feature: &'static str, + current_version: Version, + minimum_required_version: Version, +} + +impl UnsupportedFeatureForNegotiatedVersion { + pub fn new( + feature: &'static str, + current_version: Version, + minimum_required_version: Version, + ) -> Self { + Self { + feature, + current_version, + minimum_required_version, + } + } +} + // Conversions to VMError trait WithInvocationErrorCode { @@ -234,3 +265,6 @@ impl_error_code!(BadEagerStateKeyError, INTERNAL); impl_error_code!(DecodeStateKeysProst, PROTOCOL_VIOLATION); impl_error_code!(DecodeStateKeysUtf8, PROTOCOL_VIOLATION); impl_error_code!(EmptyStateKeys, PROTOCOL_VIOLATION); +impl_error_code!(EmptyGetCallInvocationId, PROTOCOL_VIOLATION); +impl_error_code!(DecodeGetCallInvocationIdUtf8, PROTOCOL_VIOLATION); +impl_error_code!(UnsupportedFeatureForNegotiatedVersion, UNSUPPORTED_FEATURE); diff --git a/src/vm/mod.rs b/src/vm/mod.rs index f5e91ec..51df4e0 100644 --- a/src/vm/mod.rs +++ b/src/vm/mod.rs @@ -1,21 +1,22 @@ use crate::headers::HeaderMap; use crate::service_protocol::messages::get_state_keys_entry_message::StateKeys; use crate::service_protocol::messages::{ - complete_awakeable_entry_message, complete_promise_entry_message, get_state_entry_message, - get_state_keys_entry_message, output_entry_message, AwakeableEntryMessage, CallEntryMessage, + cancel_invocation_entry_message, complete_awakeable_entry_message, + complete_promise_entry_message, get_state_entry_message, get_state_keys_entry_message, + output_entry_message, AwakeableEntryMessage, CallEntryMessage, CancelInvocationEntryMessage, ClearAllStateEntryMessage, ClearStateEntryMessage, CompleteAwakeableEntryMessage, - CompletePromiseEntryMessage, Empty, GetPromiseEntryMessage, GetStateEntryMessage, - GetStateKeysEntryMessage, OneWayCallEntryMessage, OutputEntryMessage, PeekPromiseEntryMessage, - SetStateEntryMessage, SleepEntryMessage, + CompletePromiseEntryMessage, Empty, GetCallInvocationIdEntryMessage, GetPromiseEntryMessage, + GetStateEntryMessage, GetStateKeysEntryMessage, OneWayCallEntryMessage, OutputEntryMessage, + PeekPromiseEntryMessage, SetStateEntryMessage, SleepEntryMessage, }; use crate::service_protocol::{Decoder, RawMessage, Version}; use crate::vm::context::{EagerGetState, EagerGetStateKeys}; -use crate::vm::errors::UnexpectedStateError; +use crate::vm::errors::{UnexpectedStateError, UnsupportedFeatureForNegotiatedVersion}; use crate::vm::transitions::*; use crate::{ - AsyncResultCombinator, AsyncResultHandle, Error, Header, Input, NonEmptyValue, ResponseHead, - RetryPolicy, RunEnterResult, RunExitResult, SuspendedOrVMError, TakeOutputResult, Target, - VMOptions, VMResult, Value, + AsyncResultCombinator, AsyncResultHandle, CancelInvocationTarget, Error, GetInvocationIdTarget, + Header, Input, NonEmptyValue, ResponseHead, RetryPolicy, RunEnterResult, RunExitResult, + SendHandle, SuspendedOrVMError, TakeOutputResult, Target, VMOptions, VMResult, Value, }; use base64::engine::{DecodePaddingMode, GeneralPurpose, GeneralPurposeConfig}; use base64::{alphabet, Engine}; @@ -84,6 +85,25 @@ impl CoreVM { "" } } + + fn verify_feature_support( + &mut self, + feature: &'static str, + minimum_required_protocol: Version, + ) -> VMResult<()> { + if self.version < minimum_required_protocol { + return self.do_transition(HitError { + error: UnsupportedFeatureForNegotiatedVersion::new( + feature, + self.version, + minimum_required_protocol, + ) + .into(), + next_retry_delay: None, + }); + } + Ok(()) + } } impl fmt::Debug for CoreVM { @@ -413,12 +433,16 @@ impl super::VM for CoreVM { ret )] fn sys_call(&mut self, target: Target, input: Bytes) -> VMResult { + if target.idempotency_key.is_some() { + self.verify_feature_support("attach idempotency key to one way call", Version::V3)?; + } self.do_transition(SysCompletableEntry( "SysCall", CallEntryMessage { service_name: target.service, handler_name: target.handler, key: target.key.unwrap_or_default(), + idempotency_key: target.idempotency_key.unwrap_or_default(), parameter: input, ..Default::default() }, @@ -431,13 +455,22 @@ impl super::VM for CoreVM { fields(restate.invocation.id = self.debug_invocation_id(), restate.journal.index = self.context.journal.index(), restate.protocol.version = %self.version), ret )] - fn sys_send(&mut self, target: Target, input: Bytes, delay: Option) -> VMResult<()> { + fn sys_send( + &mut self, + target: Target, + input: Bytes, + delay: Option, + ) -> VMResult { + if target.idempotency_key.is_some() { + self.verify_feature_support("attach idempotency key to one way call", Version::V3)?; + } self.do_transition(SysNonCompletableEntry( "SysOneWayCall", OneWayCallEntryMessage { service_name: target.service, handler_name: target.handler, key: target.key.unwrap_or_default(), + idempotency_key: target.idempotency_key.unwrap_or_default(), parameter: input, invoke_time: delay .map(|d| { @@ -448,6 +481,7 @@ impl super::VM for CoreVM { ..Default::default() }, )) + .map(|_| SendHandle(self.context.journal.expect_index())) } #[instrument( @@ -578,6 +612,46 @@ impl super::VM for CoreVM { self.do_transition(SysRunExit(value, retry_policy)) } + #[instrument(level = "debug", ret)] + fn sys_get_call_invocation_id( + &mut self, + call: GetInvocationIdTarget, + ) -> VMResult { + self.verify_feature_support("get call invocation id", Version::V3)?; + self.do_transition(SysCompletableEntry( + "SysGetCallInvocationId", + GetCallInvocationIdEntryMessage { + call_entry_index: match call { + GetInvocationIdTarget::CallEntry(h) => h.0, + GetInvocationIdTarget::SendEntry(h) => h.0, + }, + ..Default::default() + }, + )) + } + + #[instrument(level = "debug", ret)] + fn sys_cancel_invocation(&mut self, target: CancelInvocationTarget) -> VMResult<()> { + self.verify_feature_support("cancel invocation", Version::V3)?; + self.do_transition(SysNonCompletableEntry( + "SysCancelInvocation", + CancelInvocationEntryMessage { + target: Some(match target { + CancelInvocationTarget::InvocationId(id) => { + cancel_invocation_entry_message::Target::InvocationId(id) + } + CancelInvocationTarget::CallEntry(handle) => { + cancel_invocation_entry_message::Target::CallEntryIndex(handle.0) + } + CancelInvocationTarget::SendEntry(handle) => { + cancel_invocation_entry_message::Target::CallEntryIndex(handle.0) + } + }), + ..Default::default() + }, + )) + } + #[instrument( level = "debug", skip(self, value),