diff --git a/crates/errors/src/error_codes/META0002.md b/crates/errors/src/error_codes/META0002.md index ea46a5dcf..9aa9a02cc 100644 --- a/crates/errors/src/error_codes/META0002.md +++ b/crates/errors/src/error_codes/META0002.md @@ -2,10 +2,7 @@ Bad key definition encountered while registering/updating a service. When a service is keyed, for each method the input message must have a field annotated with `dev.restate.ext.field`. -When defining the key field, make sure: - -* The field type is either a primitive or a custom message, and not a repeated field nor a map. -* The field type is the same for every method input message of the same service. +The key field type must be `string`. Example: @@ -17,6 +14,6 @@ service HelloWorld { } message GreetingRequest { - Person person = 1 [(dev.restate.ext.field) = KEY]; + string person_id = 1 [(dev.restate.ext.field) = KEY]; } ``` \ No newline at end of file diff --git a/crates/ingress-dispatcher/src/event_remapping.rs b/crates/ingress-dispatcher/src/event_remapping.rs index 93b091a39..df2eda4a9 100644 --- a/crates/ingress-dispatcher/src/event_remapping.rs +++ b/crates/ingress-dispatcher/src/event_remapping.rs @@ -16,10 +16,10 @@ use std::fmt; #[derive(Debug, thiserror::Error)] #[error("Field {field_name} cannot be mapped to field tag {tag} because it's not a valid UTF-8 string: {reason}")] pub struct Error { - field_name: &'static str, - tag: u32, + pub(crate) field_name: &'static str, + pub(crate) tag: u32, #[source] - reason: core::str::Utf8Error, + pub(crate) reason: core::str::Utf8Error, } /// Structure that implements the remapping of the event fields. diff --git a/crates/ingress-dispatcher/src/lib.rs b/crates/ingress-dispatcher/src/lib.rs index ee733b84c..ef743b54c 100644 --- a/crates/ingress-dispatcher/src/lib.rs +++ b/crates/ingress-dispatcher/src/lib.rs @@ -8,7 +8,7 @@ // the Business Source License, use of this software will be governed // by the Apache License, Version 2.0. -use bytes::{Buf, BufMut, Bytes, BytesMut}; +use bytes::Bytes; use bytestring::ByteString; use prost::Message; use restate_pb::restate::Event; @@ -188,13 +188,19 @@ impl IngressRequest { match instance_type { EventReceiverServiceInstanceType::Keyed { ordering_key_is_key, - } => generate_restate_key(if *ordering_key_is_key { + } => Bytes::from(if *ordering_key_is_key { event.ordering_key.clone() } else { - event.key.clone() + std::str::from_utf8(&event.key) + .map_err(|e| EventError { + field_name: "key", + tag: 2, + reason: e, + })? + .to_owned() }), EventReceiverServiceInstanceType::Unkeyed => { - Bytes::copy_from_slice(InvocationUuid::now_v7().as_bytes()) + Bytes::from(InvocationUuid::now_v7().to_string()) } EventReceiverServiceInstanceType::Singleton => Bytes::new(), }, @@ -254,18 +260,6 @@ impl IngressRequest { } } -fn generate_restate_key(key: impl Buf) -> Bytes { - // Because this needs to be a valid Restate key, we need to prepend it with its length to make it - // look like it was extracted using the RestateKeyExtractor - // This is done to ensure all the other operations on the key will work correctly (e.g. 
key to json) - let key_len = key.remaining(); - let mut buf = - BytesMut::with_capacity(prost::encoding::encoded_len_varint(key_len as u64) + key_len); - prost::encoding::encode_varint(key_len as u64, &mut buf); - buf.put(key); - buf.freeze() -} - // -- Types used by the network to interact with the ingress dispatcher service pub type IngressDispatcherInputReceiver = mpsc::Receiver; diff --git a/crates/ingress-kafka/Cargo.toml b/crates/ingress-kafka/Cargo.toml index 43b582502..19b0b585a 100644 --- a/crates/ingress-kafka/Cargo.toml +++ b/crates/ingress-kafka/Cargo.toml @@ -20,6 +20,7 @@ restate-schema-api = { workspace = true, features = ["subscription"] } restate-timer-queue = { workspace = true } restate-types = { workspace = true } +base64 = { workspace = true } bytes = { workspace = true } derive_builder = { workspace = true } drain = { workspace = true } diff --git a/crates/ingress-kafka/src/consumer_task.rs b/crates/ingress-kafka/src/consumer_task.rs index 41d5749fd..0b9e0323c 100644 --- a/crates/ingress-kafka/src/consumer_task.rs +++ b/crates/ingress-kafka/src/consumer_task.rs @@ -8,7 +8,8 @@ // the Business Source License, use of this software will be governed // by the Apache License, Version 2.0. -use bytes::{BufMut, Bytes, BytesMut}; +use base64::Engine; +use bytes::Bytes; use opentelemetry_api::trace::TraceContextExt; use rdkafka::consumer::{Consumer, DefaultConsumerContext, StreamConsumer}; use rdkafka::error::KafkaError; @@ -161,20 +162,22 @@ impl MessageSender { ordering_key_format: &KafkaOrderingKeyFormat, ordering_key_prefix: &str, msg: &impl Message, - ) -> Bytes { - let mut buf = BytesMut::new(); - buf.put(ordering_key_prefix.as_bytes()); - buf.put(msg.topic().as_bytes()); - buf.put_i32(msg.partition()); + ) -> String { + let partition = msg.partition().to_string(); - match ordering_key_format { - KafkaOrderingKeyFormat::ConsumerGroupTopicPartitionKey if msg.key().is_some() => { - buf.put(msg.key().unwrap()) - } - _ => {} - }; + let mut buf = + String::with_capacity(ordering_key_prefix.len() + msg.topic().len() + partition.len()); + buf.push_str(ordering_key_prefix); + buf.push_str(msg.topic()); + buf.push_str(&partition); + + if let (KafkaOrderingKeyFormat::ConsumerGroupTopicPartitionKey, Some(key)) = + (ordering_key_format, msg.key()) + { + buf.push_str(&base64::prelude::BASE64_STANDARD.encode(key)); + } - buf.freeze() + buf } fn generate_events_attributes( diff --git a/crates/pb/proto/dev/restate/events.proto b/crates/pb/proto/dev/restate/events.proto index 98ffc2dbf..97518517c 100644 --- a/crates/pb/proto/dev/restate/events.proto +++ b/crates/pb/proto/dev/restate/events.proto @@ -12,36 +12,16 @@ syntax = "proto3"; package dev.restate; import "dev/restate/ext.proto"; -import "google/protobuf/empty.proto"; -import "google/protobuf/struct.proto"; option java_multiple_files = true; option java_package = "dev.restate.generated"; option go_package = "restate.dev/sdk-go/pb"; message Event { - bytes ordering_key = 1 [(dev.restate.ext.field) = KEY]; + string ordering_key = 1 [(dev.restate.ext.field) = KEY]; bytes key = 2; bytes payload = 3; map attributes = 15; } - -message KeyedEvent { - option deprecated = true; - - // Payload - bytes key = 1 [(dev.restate.ext.field) = KEY]; - bytes payload = 2 [(dev.restate.ext.field) = EVENT_PAYLOAD]; - map attributes = 15 [(dev.restate.ext.field) = EVENT_METADATA]; -} - -message StringKeyedEvent { - option deprecated = true; - - // Payload - string key = 1 [(dev.restate.ext.field) = KEY]; - bytes payload = 2 [(dev.restate.ext.field) = 
EVENT_PAYLOAD]; - map attributes = 15 [(dev.restate.ext.field) = EVENT_METADATA]; -} diff --git a/crates/pb/proto/dev/restate/ext.proto b/crates/pb/proto/dev/restate/ext.proto index 86f8ca1c3..6bc840661 100644 --- a/crates/pb/proto/dev/restate/ext.proto +++ b/crates/pb/proto/dev/restate/ext.proto @@ -29,8 +29,14 @@ enum ServiceType { enum FieldType { // protolint:disable:next ENUM_FIELD_NAMES_ZERO_VALUE_END_WITH + // Note: only string fields can be used for service key fields KEY = 0; + + // Flag a field as event payload. When receiving events, this field will be filled with the event payload. + // Note: only string fields can be used for event payload fields EVENT_PAYLOAD = 1; + // Flag a field as event metadata. When receiving events, this field will be filled with the event metadata. + // Note: only map fields can be used for event metadata fields EVENT_METADATA = 2; } diff --git a/crates/pb/tests/proto/event_handler.proto b/crates/pb/tests/proto/event_handler.proto index 9a5a565fa..e27b4c4d8 100644 --- a/crates/pb/tests/proto/event_handler.proto +++ b/crates/pb/tests/proto/event_handler.proto @@ -8,11 +8,3 @@ package eventhandler; service EventHandler { rpc Handle(dev.restate.Event) returns (google.protobuf.Empty); } - -service KeyedEventHandler { - rpc Handle(dev.restate.KeyedEvent) returns (google.protobuf.Empty); -} - -service StringKeyedEventHandler { - rpc Handle(dev.restate.StringKeyedEvent) returns (google.protobuf.Empty); -} diff --git a/crates/schema-api/src/lib.rs b/crates/schema-api/src/lib.rs index fc75f6fcf..aa3a9edae 100644 --- a/crates/schema-api/src/lib.rs +++ b/crates/schema-api/src/lib.rs @@ -559,6 +559,8 @@ pub mod key { UnexpectedServiceInstanceType, #[error("unexpected value for a singleton service. Singleton service have no service key associated")] UnexpectedNonNullSingletonKey, + #[error("bad unkeyed service key. Expected a string")] + BadUnkeyedKey, #[error("error when decoding the json key: {0}")] DecodeJson(#[from] serde_json::Error), } diff --git a/crates/schema-impl/src/json_key_conversion.rs b/crates/schema-impl/src/json_key_conversion.rs index 82c38f15a..380301685 100644 --- a/crates/schema-impl/src/json_key_conversion.rs +++ b/crates/schema-impl/src/json_key_conversion.rs @@ -16,11 +16,9 @@ use bytes::Bytes; use prost::Message; use prost_reflect::{DynamicMessage, MethodDescriptor}; use restate_schema_api::key::json_conversion::{Error, RestateKeyConverter}; -use restate_serde_util::SerdeableUuid; use serde::de::IntoDeserializer; -use serde::{Deserialize, Serialize}; +use serde::Serialize; use serde_json::{Map, Value}; -use uuid::Uuid; impl RestateKeyConverter for Schemas { fn key_to_json( @@ -89,10 +87,7 @@ fn key_to_json( }) } InstanceTypeMetadata::Unkeyed => Ok(Value::String( - uuid::Builder::from_slice(key.as_ref()) - .unwrap() - .into_uuid() - .to_string(), + String::from_utf8(key.as_ref().to_vec()).expect("Must be a valid UTF-8 string"), )), InstanceTypeMetadata::Singleton => Ok(Value::Object(Map::new())), InstanceTypeMetadata::Unsupported => Err(Error::NotFound), @@ -135,9 +130,11 @@ fn json_to_key( )?)
} InstanceTypeMetadata::Unkeyed => { - let parse_result: Uuid = SerdeableUuid::deserialize(key.into_deserializer())?.into(); - - Ok(parse_result.as_bytes().to_vec().into()) + return if let Some(key_str) = key.as_str() { + Ok(Bytes::copy_from_slice(key_str.as_bytes())) + } else { + Err(Error::BadUnkeyedKey) + } } InstanceTypeMetadata::Singleton if key.is_null() => Ok(Bytes::default()), InstanceTypeMetadata::Singleton => Err(Error::UnexpectedNonNullSingletonKey), @@ -156,7 +153,7 @@ mod tests { use restate_pb::mocks::test::*; use restate_schema_api::discovery::KeyStructure; use serde::Serialize; - use std::collections::{BTreeMap, HashMap}; + use std::collections::HashMap; use uuid::Uuid; static METHOD_NAME: &str = "Test"; @@ -189,18 +186,6 @@ mod tests { } } - fn nested_key_structure() -> KeyStructure { - KeyStructure::Nested(BTreeMap::from([ - (1, KeyStructure::Scalar), - (2, KeyStructure::Scalar), - (3, KeyStructure::Scalar), - ( - 4, - KeyStructure::Nested(BTreeMap::from([(1, KeyStructure::Scalar)])), - ), - ])) - } - fn mock_keyed_service_instance_type( key_structure: KeyStructure, field_number: u32, @@ -328,36 +313,6 @@ mod tests { } json_tests!(string); - json_tests!(bytes); - json_tests!(number); - json_tests!(nested_message, nested_key_structure()); - json_tests!( - test: nested_message_with_default, - field_name: nested_message, - key_structure: nested_key_structure(), - test_message: TestMessage { - nested_message: Some(NestedKey { - b: "b".to_string(), - ..Default::default() - }), - ..Default::default() - } - ); - json_tests!( - test: double_nested_message, - field_name: nested_message, - key_structure: nested_key_structure(), - test_message: TestMessage { - nested_message: Some(NestedKey { - b: "b".to_string(), - other: Some(OtherMessage { - d: "d".to_string() - }), - ..Default::default() - }), - ..Default::default() - } - ); #[test] fn unkeyed_convert_key_to_json() { @@ -411,14 +366,11 @@ mod tests { let expected_restate_key = extract(&service_instance_type, METHOD_NAME, Bytes::new()) .expect("successful key extraction"); - // Parse this as uuid - let uuid = Uuid::from_slice(&expected_restate_key).unwrap(); - // Now convert the key to json let actual_restate_key = json_to_key( &service_instance_type, test_method_descriptor(), - Value::String(uuid.as_simple().to_string()), + Value::String(String::from_utf8(expected_restate_key.to_vec()).unwrap()), ) .unwrap(); diff --git a/crates/schema-impl/src/key_expansion.rs b/crates/schema-impl/src/key_expansion.rs index cd744b03a..ab07f08c7 100644 --- a/crates/schema-impl/src/key_expansion.rs +++ b/crates/schema-impl/src/key_expansion.rs @@ -45,7 +45,7 @@ pub(crate) mod expand_impls { use bytes::{BufMut, BytesMut}; use crate::schemas_impl::InstanceTypeMetadata; - use prost::encoding::{encode_key, key_len}; + use prost::encoding::{encode_key, encode_varint, key_len}; use prost_reflect::{DynamicMessage, MessageDescriptor}; pub(crate) fn expand( @@ -76,6 +76,8 @@ pub(crate) mod expand_impls { // which converts groups to nested messages. 
encode_key(root_number, field_descriptor_kind.wire_type(), &mut b); + encode_varint(restate_key.as_ref().len() as u64, &mut b); + // Append the restate key buffer b.put(restate_key.as_ref()); @@ -95,7 +97,7 @@ pub(crate) mod expand_impls { use restate_pb::mocks::test::*; use restate_pb::mocks::DESCRIPTOR_POOL; use restate_schema_api::discovery::KeyStructure; - use std::collections::{BTreeMap, HashMap}; + use std::collections::HashMap; static METHOD_NAME: &str = "test"; @@ -119,18 +121,6 @@ pub(crate) mod expand_impls { } } - fn nested_key_structure() -> KeyStructure { - KeyStructure::Nested(BTreeMap::from([ - (1, KeyStructure::Scalar), - (2, KeyStructure::Scalar), - (3, KeyStructure::Scalar), - ( - 4, - KeyStructure::Nested(BTreeMap::from([(1, KeyStructure::Scalar)])), - ), - ])) - } - fn mock_keyed_service_instance_type( key_structure: KeyStructure, field_number: u32, @@ -207,35 +197,5 @@ pub(crate) mod expand_impls { } expand_tests!(string); - expand_tests!(bytes); - expand_tests!(number); - expand_tests!(nested_message, nested_key_structure()); - expand_tests!( - test: nested_message_with_default, - field_name: nested_message, - key_structure: nested_key_structure(), - test_message: TestMessage { - nested_message: Some(NestedKey { - b: "b".to_string(), - ..Default::default() - }), - ..Default::default() - } - ); - expand_tests!( - test: double_nested_message, - field_name: nested_message, - key_structure: nested_key_structure(), - test_message: TestMessage { - nested_message: Some(NestedKey { - b: "b".to_string(), - other: Some(OtherMessage { - d: "d".to_string() - }), - ..Default::default() - }), - ..Default::default() - } - ); } } diff --git a/crates/schema-impl/src/key_extraction.rs b/crates/schema-impl/src/key_extraction.rs index 2c4c433c1..38e0aff04 100644 --- a/crates/schema-impl/src/key_extraction.rs +++ b/crates/schema-impl/src/key_extraction.rs @@ -35,13 +35,13 @@ pub(crate) mod extract_impls { use bytes::{Buf, BufMut, Bytes, BytesMut}; use prost::encoding::WireType::*; use prost::encoding::{ - decode_key, decode_varint, encode_key, encode_varint, skip_field, DecodeContext, WireType, + decode_key, decode_varint, encode_key, skip_field, DecodeContext, WireType, }; use restate_schema_api::discovery::KeyStructure; use uuid::Uuid; fn generate_random_key() -> Bytes { - Bytes::copy_from_slice(Uuid::now_v7().as_bytes()) + Bytes::copy_from_slice(Uuid::now_v7().to_string().as_ref()) } pub(crate) fn extract( @@ -167,142 +167,20 @@ pub(crate) mod extract_impls { match (current_wire_type, current_parser_directive) { // Primitive cases - (Varint, _) => result_buf.put(slice_varint_bytes(buf)?), - (ThirtyTwoBit, _) => result_buf.put(slice_const_bytes(buf, 4)?), - (SixtyFourBit, _) => result_buf.put(slice_const_bytes(buf, 8)?), (LengthDelimited, KeyStructure::Scalar) => { - let (length, field_slice) = slice_length_delimited_bytes(buf)?; - encode_varint(length, &mut result_buf); + let (_length, field_slice) = slice_length_delimited_bytes(buf)?; result_buf.put(field_slice) } - - // Composite cases - (StartGroup, KeyStructure::Nested(expected_message_fields)) => { - let mut message_fields = Vec::with_capacity(expected_message_fields.len()); - loop { - let (next_field_number, next_wire_type) = decode_key(buf)?; - if next_wire_type == EndGroup { - break; - } - match expected_message_fields.get(&next_field_number) { - None => { - // Unknown field, just skip it - skip_field( - next_wire_type, - next_field_number, - buf, - DecodeContext::default(), - )?; - continue; - } - Some(next_parser_directive) 
=> { - message_fields.push(( - next_field_number, - deep_extract( - buf, - next_field_number, - next_wire_type, - next_parser_directive, - true, - )?, - )); - } - }; - } - // Reorder fields - message_fields.sort_by(|(index_a, _), (index_b, _)| index_a.cmp(index_b)); - - // Compute length delimited message length - let inner_message_length: usize = - message_fields.iter().map(|(_, buf)| buf.len()).sum(); - encode_varint(inner_message_length as u64, &mut result_buf); - - // Write the fields - for (_, b) in message_fields { - result_buf.put(b) - } - } - - (LengthDelimited, KeyStructure::Nested(expected_message_fields)) => { - let mut message_fields = Vec::with_capacity(expected_message_fields.len()); - let inner_message_len = decode_varint(buf)? as usize; - let mut current_buf = buf.split_to(inner_message_len); - while current_buf.has_remaining() { - let (next_field_number, next_wire_type) = decode_key(&mut current_buf)?; - match expected_message_fields.get(&next_field_number) { - None => { - // Unknown field, just skip it - skip_field( - next_wire_type, - next_field_number, - &mut current_buf, - DecodeContext::default(), - )?; - } - Some(next_parser_directive) => { - message_fields.push(( - next_field_number, - deep_extract( - &mut current_buf, - next_field_number, - next_wire_type, - next_parser_directive, - true, - )?, - )); - } - }; - } - // Reorder fields - message_fields.sort_by(|(index_a, _), (index_b, _)| index_a.cmp(index_b)); - - // Compute length delimited message length - // We recompute it as the size could be different if we converted a nested message that was a group - let inner_message_length: usize = - message_fields.iter().map(|(_, buf)| buf.len()).sum(); - encode_varint(inner_message_length as u64, &mut result_buf); - - // Write the fields - for (_, b) in message_fields { - result_buf.put(b) - } - } - - // Expecting a primitive message, but got composite -> schema mismatch - (StartGroup | EndGroup, KeyStructure::Scalar) => return Err(Error::UnexpectedValue), - // EndGroup is handled by the loop below, so we're not supposed to have a match here - (EndGroup, _) => return Err(Error::UnexpectedValue), + // Expecting a string key, but got something else -> schema mismatch + (_, KeyStructure::Scalar) => return Err(Error::UnexpectedValue), + (_, _) => panic!( + "Unsupported key extraction. See https://github.com/restatedev/restate/issues/955" + ), }; Ok(result_buf.freeze()) } - /// This behaves similarly to [decode_varint], but without parsing the number, but simply returning the bytes composing it. 
- fn slice_varint_bytes(buf: &mut Bytes) -> Result { - let len = buf.len(); - if len == 0 { - return Err(Error::UnexpectedEndOfBuffer); - } - - let mut scanned_bytes = 0; - let mut end_byte_reached = false; - while scanned_bytes < len { - if buf[scanned_bytes] < 0x80 { - // MSB == 1 means more bytes, == 0 means last byte - end_byte_reached = true; - break; - } - scanned_bytes += 1; - } - if end_byte_reached { - let res = buf.split_to(scanned_bytes + 1); - - return Ok(res); - } - - Err(Error::UnexpectedEndOfBuffer) - } - fn slice_const_bytes(buf: &mut Bytes, len: usize) -> Result { check_remaining(buf, len)?; let res = buf.split_to(len); @@ -322,7 +200,6 @@ pub(crate) mod extract_impls { use prost::encoding::{encode_key, encode_varint, key_len, DecodeContext}; use prost::{length_delimiter_len, DecodeError, Message}; - use std::collections::BTreeMap; #[derive(Debug)] struct MockMessage { @@ -350,43 +227,6 @@ pub(crate) mod extract_impls { } impl MockMessage { - fn fill_expected_buf(&self, out_buf: &mut BytesMut) { - let mut msg_buf = BytesMut::new(); - - // Write fields - self.write_a(&mut msg_buf); - self.write_b(false, &mut msg_buf); - self.write_c(&mut msg_buf); - - let msg_buf = msg_buf.freeze(); - - // Write the msg_buf in the output buf - encode_varint(msg_buf.len() as u64, out_buf); - out_buf.put(msg_buf); - } - - fn parser_directive(nested: bool) -> KeyStructure { - if nested { - KeyStructure::Nested(BTreeMap::from([ - ( - 1, - KeyStructure::Nested(BTreeMap::from([(1, KeyStructure::Scalar)])), - ), - ( - 2, - KeyStructure::Nested(BTreeMap::from([(1, KeyStructure::Scalar)])), - ), - (3, KeyStructure::Scalar), - ])) - } else { - KeyStructure::Nested(BTreeMap::from([ - (1, KeyStructure::Scalar), - (2, KeyStructure::Scalar), - (3, KeyStructure::Scalar), - ])) - } - } - fn write_a(&self, buf: &mut B) { if self.nested { Self::write_in_length_delimited(1, &self.a, buf); @@ -594,262 +434,15 @@ pub(crate) mod extract_impls { }; } - // Note: The encoding from rust types to protobuf has been taken directly from - // https://github.com/tokio-rs/prost/blob/master/src/encoding.rs - - // Test single varint size type - extract_tests!( - bool, - item: true, - fill_expected_buf: |buf: &mut BytesMut, val| encode_varint(if val { 1u64 } else { 0u64 }, buf) - ); - extract_tests!( - int32, - item: -21314_i32, - fill_expected_buf: |buf: &mut BytesMut, val| encode_varint(val as u64, buf) - ); - extract_tests!( - int64, - item: -245361314_i64, - fill_expected_buf: |buf: &mut BytesMut, val| encode_varint(val as u64, buf) - ); - extract_tests!( - uint32, - item: 21314_u32, - fill_expected_buf: |buf: &mut BytesMut, val| encode_varint(val as u64, buf) - ); - extract_tests!( - uint64, - item: 245361314_u64, - fill_expected_buf: |buf: &mut BytesMut, val: u64| encode_varint(val, buf) - ); - extract_tests!( - sint32, - item: -21314_i32, - fill_expected_buf: |buf: &mut BytesMut, val| encode_varint(((val << 1) ^ (val >> 31)) as u32 as u64, buf) - ); - extract_tests!( - sint64, - item: -245361314_i64, - fill_expected_buf: |buf: &mut BytesMut, val| encode_varint(((val << 1) ^ (val >> 63)) as u64, buf) - ); - - // Test single 32/64 const size type - extract_tests!( - float, - item: 4543.342_f32, - fill_expected_buf: |buf: &mut BytesMut, val| buf.put_f32_le(val) - ); - extract_tests!( - double, - item: 4543986.342542_f64, - fill_expected_buf: |buf: &mut BytesMut, val| buf.put_f64_le(val) - ); - extract_tests!( - fixed32, - item: 4543_u32, - fill_expected_buf: |buf: &mut BytesMut, val| buf.put_u32_le(val) - ); - 
extract_tests!( - fixed64, - item: 349320_u64, - fill_expected_buf: |buf: &mut BytesMut, val| buf.put_u64_le(val) - ); - extract_tests!( - sfixed32, - item: -4543_i32, - fill_expected_buf: |buf: &mut BytesMut, val| buf.put_i32_le(val) - ); - extract_tests!( - sfixed64, - item: -349320_i64, - fill_expected_buf: |buf: &mut BytesMut, val| buf.put_i64_le(val) - ); - // Test single length delimited type extract_tests!( string, item: "my awesome string".to_string(), fill_expected_buf: |buf: &mut BytesMut, val: String| { - encode_varint(val.len().try_into().unwrap(), buf); + //encode_varint(val.len().try_into().unwrap(), buf); buf.put_slice(val.as_bytes()); } ); - extract_tests!( - bytes, - item: Bytes::from_static(&[1_u8, 2, 3]), - fill_expected_buf: |buf: &mut BytesMut, val: Bytes| { - encode_varint(val.len().try_into().unwrap(), buf); - buf.put_slice(&val); - } - ); - - // Test message - // Note: the difference between message and group is that - // the former encodes using length delimited message encoding, - // while the latter encodes using the [Start/End]Group markers - extract_tests!( - message, - item: MockMessage::default(), - fill_expected_buf: |buf: &mut BytesMut, val: MockMessage| val.fill_expected_buf(buf), - parser: MockMessage::parser_directive(false) - ); - extract_tests!( - message, - mod: message_reverse, - item: MockMessage { - ordered_encoding: false, - ..MockMessage::default() - }, - fill_expected_buf: |buf: &mut BytesMut, val: MockMessage| val.fill_expected_buf(buf), - parser: MockMessage::parser_directive(false) - ); - extract_tests!( - group, - item: MockMessage::default(), - fill_expected_buf: |buf: &mut BytesMut, val: MockMessage| val.fill_expected_buf(buf), - parser: MockMessage::parser_directive(false) - ); - extract_tests!( - group, - mod: group_reverse, - item: MockMessage { - ordered_encoding: false, - ..MockMessage::default() - }, - fill_expected_buf: |buf: &mut BytesMut, val: MockMessage| val.fill_expected_buf(buf), - parser: MockMessage::parser_directive(false) - ); - extract_tests!( - message, - mod: message_nested, - item: MockMessage { - nested: true, - ..MockMessage::default() - }, - fill_expected_buf: |buf: &mut BytesMut, val: MockMessage| val.fill_expected_buf(buf), - parser: MockMessage::parser_directive(true) - ); - extract_tests!( - message, - mod: message_nested_reverse, - item: MockMessage { - nested: true, - ordered_encoding: false, - ..MockMessage::default() - }, - fill_expected_buf: |buf: &mut BytesMut, val: MockMessage| val.fill_expected_buf(buf), - parser: MockMessage::parser_directive(true) - ); - extract_tests!( - group, - mod: group_nested, - item: MockMessage { - nested: true, - ..MockMessage::default() - }, - fill_expected_buf: |buf: &mut BytesMut, val: MockMessage| val.fill_expected_buf(buf), - parser: MockMessage::parser_directive(true) - ); - extract_tests!( - group, - mod: group_nested_reverse, - item: MockMessage { - nested: true, - ordered_encoding: false, - ..MockMessage::default() - }, - fill_expected_buf: |buf: &mut BytesMut, val: MockMessage| val.fill_expected_buf(buf), - parser: MockMessage::parser_directive(true) - ); - - // Tests with unknown field - extract_tests!( - message, - mod: message_unknown, - item: MockMessage { - unknown_field: true, - ..MockMessage::default() - }, - fill_expected_buf: |buf: &mut BytesMut, val: MockMessage| val.fill_expected_buf(buf), - parser: MockMessage::parser_directive(false) - ); - extract_tests!( - group, - mod: group_unknown, - item: MockMessage { - unknown_field: true, - 
..MockMessage::default() - }, - fill_expected_buf: |buf: &mut BytesMut, val: MockMessage| val.fill_expected_buf(buf), - parser: MockMessage::parser_directive(false) - ); - extract_tests!( - message, - mod: message_reverse_unknown, - item: MockMessage { - ordered_encoding: false, - unknown_field: true, - ..MockMessage::default() - }, - fill_expected_buf: |buf: &mut BytesMut, val: MockMessage| val.fill_expected_buf(buf), - parser: MockMessage::parser_directive(false) - ); - extract_tests!( - group, - mod: group_reverse_unknown, - item: MockMessage { - ordered_encoding: false, - unknown_field: true, - ..MockMessage::default() - }, - fill_expected_buf: |buf: &mut BytesMut, val: MockMessage| val.fill_expected_buf(buf), - parser: MockMessage::parser_directive(false) - ); - - // Test skipping B - extract_tests!( - message, - mod: message_skip_b, - item: MockMessage { - skip_b: true, - ..MockMessage::default() - }, - fill_expected_buf: |buf: &mut BytesMut, val: MockMessage| val.fill_expected_buf(buf), - parser: MockMessage::parser_directive(false) - ); - extract_tests!( - group, - mod: group_skip_b, - item: MockMessage { - skip_b: true, - ..MockMessage::default() - }, - fill_expected_buf: |buf: &mut BytesMut, val: MockMessage| val.fill_expected_buf(buf), - parser: MockMessage::parser_directive(false) - ); - extract_tests!( - message, - mod: message_reverse_skip_b, - item: MockMessage { - ordered_encoding: false, - skip_b: true, - ..MockMessage::default() - }, - fill_expected_buf: |buf: &mut BytesMut, val: MockMessage| val.fill_expected_buf(buf), - parser: MockMessage::parser_directive(false) - ); - extract_tests!( - group, - mod: group_reverse_skip_b, - item: MockMessage { - ordered_encoding: false, - skip_b: true, - ..MockMessage::default() - }, - fill_expected_buf: |buf: &mut BytesMut, val: MockMessage| val.fill_expected_buf(buf), - parser: MockMessage::parser_directive(false) - ); // Additional tests #[test] @@ -874,36 +467,5 @@ pub(crate) mod extract_impls { Bytes::new() ); } - - // {a: "AA", b: "B"} and {a: "A", b: "AB"} are different keys! 
- #[test] - fn fields_are_correctly_separated() { - fn build_input_buf(str_1: &'static str, str_2: &'static str) -> Bytes { - // Prepare the key message - let mut key_msg = BytesMut::new(); - prost::encoding::string::encode(1, &str_1.to_string(), &mut key_msg); - prost::encoding::string::encode(2, &str_2.to_string(), &mut key_msg); - - // Prepare the root message (key is a nested message) - let mut out_msg = BytesMut::new(); - encode_key(1, LengthDelimited, &mut out_msg); - encode_varint(key_msg.len() as u64, &mut out_msg); - out_msg.put(key_msg); - - out_msg.freeze() - } - - let root_key_field_number = 1; - let key_structure = - KeyStructure::Nested([(1, KeyStructure::Scalar), (2, KeyStructure::Scalar)].into()); - - let input_buf_a = build_input_buf("AA", "B"); - let input_buf_b = build_input_buf("A", "AB"); - - assert_ne!( - root_extract(input_buf_a, root_key_field_number, &key_structure).unwrap(), - root_extract(input_buf_b, root_key_field_number, &key_structure).unwrap() - ); - } } } diff --git a/crates/schema-impl/src/schemas_impl.rs b/crates/schema-impl/src/schemas_impl.rs index 6726edc6e..336a12b6f 100644 --- a/crates/schema-impl/src/schemas_impl.rs +++ b/crates/schema-impl/src/schemas_impl.rs @@ -540,12 +540,12 @@ impl SchemasInner { let key = if let Some(index) = method_schemas.input_field_annotated(FieldAnnotation::Key) { - let kind = input_type.get_field(index).unwrap().kind(); - if kind == Kind::String { - Some((index, FieldRemapType::String)) - } else { - Some((index, FieldRemapType::Bytes)) - } + debug_assert_eq!( + input_type.get_field(index).unwrap().kind(), + Kind::String, + "discovery should check whether this field is string or not." + ); + Some((index, FieldRemapType::String)) } else { None }; diff --git a/crates/service-protocol/src/discovery.rs b/crates/service-protocol/src/discovery.rs index e883c0b1a..788e3d125 100644 --- a/crates/service-protocol/src/discovery.rs +++ b/crates/service-protocol/src/discovery.rs @@ -499,13 +499,7 @@ fn resolve_key_field( } }; - // Validate type - if field_descriptor.is_map() { - return Err(ServiceDiscoveryError::BadKeyFieldType( - method_descriptor.clone(), - )); - } - if field_descriptor.is_list() { + if field_descriptor.kind() != Kind::String { return Err(ServiceDiscoveryError::BadKeyFieldType( method_descriptor.clone(), )); diff --git a/crates/storage-query-datafusion/src/inbox/row.rs b/crates/storage-query-datafusion/src/inbox/row.rs index b42629937..344670b7a 100644 --- a/crates/storage-query-datafusion/src/inbox/row.rs +++ b/crates/storage-query-datafusion/src/inbox/row.rs @@ -10,8 +10,6 @@ use super::schema::InboxBuilder; use crate::table_util::format_using; -use crate::udfs::restate_keys; -use restate_schema_api::key::RestateKeyConverter; use restate_storage_api::inbox_table::InboxEntry; use restate_types::identifiers::{InvocationId, WithPartitionKey}; use restate_types::invocation::{ServiceInvocation, ServiceInvocationResponseSink}; @@ -23,7 +21,6 @@ pub(crate) fn append_inbox_row( builder: &mut InboxBuilder, output: &mut String, inbox_entry: InboxEntry, - resolver: impl RestateKeyConverter, ) { let InboxEntry { inbox_sequence_number, @@ -43,35 +40,7 @@ pub(crate) fn append_inbox_row( row.service_name(&fid.service_id.service_name); row.method(&method_name); - row.service_key(&fid.service_id.key); - if row.is_service_key_utf8_defined() { - if let Some(utf8) = restate_keys::try_decode_restate_key_as_utf8(&fid.service_id.key) { - row.service_key_utf8(utf8); - } - } - if row.is_service_key_int32_defined() { - if let Some(key) = 
restate_keys::try_decode_restate_key_as_int32(&fid.service_id.key) { - row.service_key_int32(key); - } - } - if row.is_service_key_uuid_defined() { - let mut buffer = Uuid::encode_buffer(); - if let Some(key) = - restate_keys::try_decode_restate_key_as_uuid(&fid.service_id.key, &mut buffer) - { - row.service_key_uuid(key); - } - } - if row.is_service_key_json_defined() { - if let Some(key) = restate_keys::try_decode_restate_key_as_json( - &fid.service_id.service_name, - &fid.service_id.key, - output, - resolver, - ) { - row.service_key_json(key); - } - } + row.service_key(std::str::from_utf8(&fid.service_id.key).expect("The key must be a string!")); if row.is_id_defined() { row.id(format_using(output, &InvocationId::from(&fid))); diff --git a/crates/storage-query-datafusion/src/inbox/schema.rs b/crates/storage-query-datafusion/src/inbox/schema.rs index 1cd0fe87d..4d6eb2bb6 100644 --- a/crates/storage-query-datafusion/src/inbox/schema.rs +++ b/crates/storage-query-datafusion/src/inbox/schema.rs @@ -20,11 +20,7 @@ define_table!(inbox( service_name: DataType::LargeUtf8, method: DataType::LargeUtf8, - service_key: DataType::LargeBinary, - service_key_utf8: DataType::LargeUtf8, - service_key_int32: DataType::Int32, - service_key_uuid: DataType::LargeUtf8, - service_key_json: DataType::LargeUtf8, + service_key: DataType::LargeUtf8, id: DataType::LargeUtf8, diff --git a/crates/storage-query-datafusion/src/inbox/table.rs b/crates/storage-query-datafusion/src/inbox/table.rs index 021ea7322..76a030d9d 100644 --- a/crates/storage-query-datafusion/src/inbox/table.rs +++ b/crates/storage-query-datafusion/src/inbox/table.rs @@ -23,7 +23,6 @@ use datafusion::physical_plan::stream::RecordBatchReceiverStream; use datafusion::physical_plan::SendableRecordBatchStream; pub use datafusion_expr::UserDefinedLogicalNode; use futures::StreamExt; -use restate_schema_api::key::RestateKeyConverter; use restate_storage_api::inbox_table::{InboxEntry, InboxTable}; use restate_storage_api::GetStream; use restate_storage_rocksdb::RocksDBStorage; @@ -33,12 +32,8 @@ use tokio::sync::mpsc::Sender; pub(crate) fn register_self( ctx: &QueryContext, storage: RocksDBStorage, - resolver: impl RestateKeyConverter + Send + Sync + Debug + Clone + 'static, ) -> datafusion::common::Result<()> { - let table = GenericTableProvider::new( - InboxBuilder::schema(), - Arc::new(InboxScanner(storage, resolver)), - ); + let table = GenericTableProvider::new(InboxBuilder::schema(), Arc::new(InboxScanner(storage))); ctx.as_ref() .register_table("sys_inbox", Arc::new(table)) @@ -46,11 +41,9 @@ pub(crate) fn register_self( } #[derive(Debug, Clone)] -struct InboxScanner(RocksDBStorage, R); +struct InboxScanner(RocksDBStorage); -impl RangeScanner - for InboxScanner -{ +impl RangeScanner for InboxScanner { fn scan( &self, range: RangeInclusive, @@ -58,13 +51,12 @@ impl RangeScanne ) -> SendableRecordBatchStream { let db = self.0.clone(); let schema = projection.clone(); - let resolver = self.1.clone(); let mut stream_builder = RecordBatchReceiverStream::builder(projection, 16); let tx = stream_builder.tx(); let background_task = async move { let mut transaction = db.transaction(); let rows = transaction.all_inboxes(range); - for_each_state(schema, tx, rows, resolver).await; + for_each_state(schema, tx, rows).await; }; stream_builder.spawn(background_task); stream_builder.build() @@ -75,12 +67,11 @@ async fn for_each_state( schema: SchemaRef, tx: Sender>, mut rows: GetStream<'_, InboxEntry>, - resolver: impl RestateKeyConverter + Clone, ) { let mut 
builder = InboxBuilder::new(schema.clone()); let mut temp = String::new(); while let Some(Ok(row)) = rows.next().await { - append_inbox_row(&mut builder, &mut temp, row, resolver.clone()); + append_inbox_row(&mut builder, &mut temp, row); if builder.full() { let batch = builder.finish(); if tx.send(Ok(batch)).await.is_err() { diff --git a/crates/storage-query-datafusion/src/invocation_state/row.rs b/crates/storage-query-datafusion/src/invocation_state/row.rs index 68ed8dfb2..56cb95064 100644 --- a/crates/storage-query-datafusion/src/invocation_state/row.rs +++ b/crates/storage-query-datafusion/src/invocation_state/row.rs @@ -10,20 +10,15 @@ use crate::invocation_state::schema::StateBuilder; use crate::table_util::format_using; -use crate::udfs::restate_keys; use restate_invoker_api::InvocationStatusReport; -use restate_schema_api::key::RestateKeyConverter; use restate_types::identifiers::{InvocationId, WithPartitionKey}; use restate_types::time::MillisSinceEpoch; -use uuid::Uuid; - #[inline] pub(crate) fn append_state_row( builder: &mut StateBuilder, output: &mut String, status_row: InvocationStatusReport, - resolver: impl RestateKeyConverter, ) { let mut row = builder.row(); @@ -31,39 +26,9 @@ pub(crate) fn append_state_row( row.partition_key(invocation_id.service_id.partition_key()); row.service(&invocation_id.service_id.service_name); - row.service_key(&invocation_id.service_id.key); - if row.is_service_key_utf8_defined() { - if let Some(utf8) = - restate_keys::try_decode_restate_key_as_utf8(&invocation_id.service_id.key) - { - row.service_key_utf8(utf8); - } - } - if row.is_service_key_int32_defined() { - if let Some(key) = - restate_keys::try_decode_restate_key_as_int32(&invocation_id.service_id.key) - { - row.service_key_int32(key); - } - } - if row.is_service_key_uuid_defined() { - let mut buffer = Uuid::encode_buffer(); - if let Some(key) = - restate_keys::try_decode_restate_key_as_uuid(&invocation_id.service_id.key, &mut buffer) - { - row.service_key_uuid(key); - } - } - if row.is_service_key_json_defined() { - if let Some(key) = restate_keys::try_decode_restate_key_as_json( - &invocation_id.service_id.service_name, - &invocation_id.service_id.key, - output, - resolver, - ) { - row.service_key_json(key); - } - } + row.service_key( + std::str::from_utf8(&invocation_id.service_id.key).expect("The key must be a string!"), + ); if row.is_id_defined() { row.id(format_using(output, &InvocationId::from(invocation_id))); } diff --git a/crates/storage-query-datafusion/src/invocation_state/schema.rs b/crates/storage-query-datafusion/src/invocation_state/schema.rs index 8762195d5..acb617a45 100644 --- a/crates/storage-query-datafusion/src/invocation_state/schema.rs +++ b/crates/storage-query-datafusion/src/invocation_state/schema.rs @@ -17,11 +17,7 @@ use datafusion::arrow::datatypes::DataType; define_table!(state( partition_key: DataType::UInt64, service: DataType::LargeUtf8, - service_key: DataType::LargeBinary, - service_key_utf8: DataType::LargeUtf8, - service_key_int32: DataType::Int32, - service_key_uuid: DataType::LargeUtf8, - service_key_json: DataType::LargeUtf8, + service_key: DataType::LargeUtf8, id: DataType::LargeUtf8, in_flight: DataType::Boolean, retry_count: DataType::UInt64, diff --git a/crates/storage-query-datafusion/src/invocation_state/table.rs b/crates/storage-query-datafusion/src/invocation_state/table.rs index ee01c07cf..7f85fca9b 100644 --- a/crates/storage-query-datafusion/src/invocation_state/table.rs +++ 
b/crates/storage-query-datafusion/src/invocation_state/table.rs @@ -23,19 +23,15 @@ use datafusion::physical_plan::stream::RecordBatchReceiverStream; use datafusion::physical_plan::SendableRecordBatchStream; pub use datafusion_expr::UserDefinedLogicalNode; use restate_invoker_api::{InvocationStatusReport, StatusHandle}; -use restate_schema_api::key::RestateKeyConverter; use restate_types::identifiers::{PartitionKey, WithPartitionKey}; use tokio::sync::mpsc::Sender; pub(crate) fn register_self( ctx: &QueryContext, status: impl StatusHandle + Send + Sync + Debug + Clone + 'static, - resolver: impl RestateKeyConverter + Send + Sync + Debug + Clone + 'static, ) -> datafusion::common::Result<()> { - let status_table = GenericTableProvider::new( - StateBuilder::schema(), - Arc::new(StatusScanner(status, resolver)), - ); + let status_table = + GenericTableProvider::new(StateBuilder::schema(), Arc::new(StatusScanner(status))); ctx.as_ref() .register_table("sys_invocation_state", Arc::new(status_table)) @@ -43,13 +39,9 @@ pub(crate) fn register_self( } #[derive(Debug, Clone)] -struct StatusScanner(S, R); +struct StatusScanner(S); -impl< - S: StatusHandle + Send + Sync + Debug + Clone + 'static, - R: RestateKeyConverter + Send + Sync + Debug + Clone + 'static, - > RangeScanner for StatusScanner -{ +impl RangeScanner for StatusScanner { fn scan( &self, range: RangeInclusive, @@ -57,12 +49,11 @@ impl< ) -> SendableRecordBatchStream { let status = self.0.clone(); let schema = projection.clone(); - let resolver = self.1.clone(); let mut stream_builder = RecordBatchReceiverStream::builder(projection, 16); let tx = stream_builder.tx(); let background_task = async move { let rows = status.read_status(range).await; - for_each_state(schema, tx, rows, resolver).await; + for_each_state(schema, tx, rows).await; }; stream_builder.spawn(background_task); stream_builder.build() @@ -73,7 +64,6 @@ async fn for_each_state<'a, I>( schema: SchemaRef, tx: Sender>, rows: I, - resolver: impl RestateKeyConverter + Clone, ) where I: Iterator + 'a, { @@ -83,7 +73,7 @@ async fn for_each_state<'a, I>( // need to be ordered by partition key for symmetric joins rows.sort_unstable_by_key(|row| row.full_invocation_id().service_id.partition_key()); for row in rows { - append_state_row(&mut builder, &mut temp, row, resolver.clone()); + append_state_row(&mut builder, &mut temp, row); if builder.full() { let batch = builder.finish(); if tx.send(Ok(batch)).await.is_err() { diff --git a/crates/storage-query-datafusion/src/journal/row.rs b/crates/storage-query-datafusion/src/journal/row.rs index 2d384fac5..570d6f56a 100644 --- a/crates/storage-query-datafusion/src/journal/row.rs +++ b/crates/storage-query-datafusion/src/journal/row.rs @@ -9,9 +9,7 @@ // by the Apache License, Version 2.0. 
use crate::journal::schema::JournalBuilder; -use crate::udfs::restate_keys; -use restate_schema_api::key::RestateKeyConverter; use restate_service_protocol::codec::ProtobufRawEntryCodec; use restate_storage_api::journal_table::JournalEntry; @@ -22,48 +20,20 @@ use restate_types::journal::raw::{EntryHeader, RawEntryCodec}; use crate::table_util::format_using; use restate_types::journal::{BackgroundInvokeEntry, Entry, InvokeEntry, InvokeRequest}; -use uuid::Uuid; #[inline] pub(crate) fn append_journal_row( builder: &mut JournalBuilder, output: &mut String, journal_row: OwnedJournalRow, - resolver: impl RestateKeyConverter + Clone, ) { let mut row = builder.row(); row.partition_key(journal_row.partition_key); row.service(&journal_row.service); - row.service_key(&journal_row.service_key); - if row.is_service_key_utf8_defined() { - if let Some(utf8) = restate_keys::try_decode_restate_key_as_utf8(&journal_row.service_key) { - row.service_key_utf8(utf8); - } - } - if row.is_service_key_int32_defined() { - if let Some(key) = restate_keys::try_decode_restate_key_as_int32(&journal_row.service_key) { - row.service_key_int32(key); - } - } - if row.is_service_key_uuid_defined() { - let mut buffer = Uuid::encode_buffer(); - if let Some(key) = - restate_keys::try_decode_restate_key_as_uuid(&journal_row.service_key, &mut buffer) - { - row.service_key_uuid(key); - } - } - if row.is_service_key_json_defined() { - if let Some(key) = restate_keys::try_decode_restate_key_as_json( - &journal_row.service, - &journal_row.service_key, - output, - resolver.clone(), - ) { - row.service_key_json(key); - } - } + row.service_key( + std::str::from_utf8(&journal_row.service_key).expect("The key must be a string!"), + ); row.index(journal_row.journal_index); @@ -81,21 +51,13 @@ pub(crate) fn append_journal_row( .. 
} | EnrichedEntryHeader::BackgroundInvoke { resolution_result } => { - row.invoked_service_key(&resolution_result.service_key); + row.invoked_service_key( + std::str::from_utf8(&resolution_result.service_key) + .expect("The key must be a string!"), + ); row.invoked_service(&resolution_result.service_name); - if row.is_invoked_service_key_json_defined() { - if let Some(key) = restate_keys::try_decode_restate_key_as_json( - &resolution_result.service_name, - &resolution_result.service_key, - output, - resolver, - ) { - row.invoked_service_key_json(key) - } - } - if row.is_invoked_id_defined() { let partition_key = ServiceId::new( resolution_result.service_name.clone(), diff --git a/crates/storage-query-datafusion/src/journal/schema.rs b/crates/storage-query-datafusion/src/journal/schema.rs index c0f6442b7..5fc665689 100644 --- a/crates/storage-query-datafusion/src/journal/schema.rs +++ b/crates/storage-query-datafusion/src/journal/schema.rs @@ -17,17 +17,12 @@ use datafusion::arrow::datatypes::DataType; define_table!(journal( partition_key: DataType::UInt64, service: DataType::LargeUtf8, - service_key: DataType::LargeBinary, - service_key_utf8: DataType::LargeUtf8, - service_key_int32: DataType::Int32, - service_key_uuid: DataType::LargeUtf8, - service_key_json: DataType::LargeUtf8, + service_key: DataType::LargeUtf8, index: DataType::UInt32, entry_type: DataType::LargeUtf8, completed: DataType::Boolean, invoked_id: DataType::LargeUtf8, invoked_service: DataType::LargeUtf8, invoked_method: DataType::LargeUtf8, - invoked_service_key: DataType::LargeBinary, - invoked_service_key_json: DataType::LargeUtf8, + invoked_service_key: DataType::LargeUtf8, )); diff --git a/crates/storage-query-datafusion/src/journal/table.rs b/crates/storage-query-datafusion/src/journal/table.rs index c4141a6d6..bf9eaee81 100644 --- a/crates/storage-query-datafusion/src/journal/table.rs +++ b/crates/storage-query-datafusion/src/journal/table.rs @@ -22,7 +22,6 @@ use crate::journal::schema::JournalBuilder; use datafusion::physical_plan::stream::RecordBatchReceiverStream; use datafusion::physical_plan::SendableRecordBatchStream; pub use datafusion_expr::UserDefinedLogicalNode; -use restate_schema_api::key::RestateKeyConverter; use restate_storage_rocksdb::journal_table::OwnedJournalRow; use restate_storage_rocksdb::RocksDBStorage; use restate_types::identifiers::PartitionKey; @@ -31,12 +30,9 @@ use tokio::sync::mpsc::Sender; pub(crate) fn register_self( ctx: &QueryContext, storage: RocksDBStorage, - resolver: impl RestateKeyConverter + Send + Sync + Debug + Clone + 'static, ) -> datafusion::common::Result<()> { - let journal_table = GenericTableProvider::new( - JournalBuilder::schema(), - Arc::new(JournalScanner(storage, resolver)), - ); + let journal_table = + GenericTableProvider::new(JournalBuilder::schema(), Arc::new(JournalScanner(storage))); ctx.as_ref() .register_table("sys_journal", Arc::new(journal_table)) @@ -44,11 +40,9 @@ pub(crate) fn register_self( } #[derive(Debug, Clone)] -struct JournalScanner(RocksDBStorage, R); +struct JournalScanner(RocksDBStorage); -impl RangeScanner - for JournalScanner -{ +impl RangeScanner for JournalScanner { fn scan( &self, range: RangeInclusive, @@ -56,12 +50,11 @@ impl RangeScanne ) -> SendableRecordBatchStream { let db = self.0.clone(); let schema = projection.clone(); - let resolver = self.1.clone(); let mut stream_builder = RecordBatchReceiverStream::builder(projection, 16); let tx = stream_builder.tx(); let background_task = move || { let rows = db.all_journal(range); - 
for_each_journal(schema, tx, rows, resolver); + for_each_journal(schema, tx, rows); }; stream_builder.spawn_blocking(background_task); stream_builder.build() @@ -72,14 +65,13 @@ fn for_each_journal<'a, I>( schema: SchemaRef, tx: Sender>, rows: I, - resolver: impl RestateKeyConverter + Clone, ) where I: Iterator + 'a, { let mut builder = JournalBuilder::new(schema.clone()); let mut temp = String::new(); for row in rows { - append_journal_row(&mut builder, &mut temp, row, resolver.clone()); + append_journal_row(&mut builder, &mut temp, row); if builder.full() { let batch = builder.finish(); if tx.blocking_send(Ok(batch)).is_err() { diff --git a/crates/storage-query-datafusion/src/lib.rs b/crates/storage-query-datafusion/src/lib.rs index 7493aea2d..ce2771a38 100644 --- a/crates/storage-query-datafusion/src/lib.rs +++ b/crates/storage-query-datafusion/src/lib.rs @@ -20,6 +20,5 @@ mod state; mod status; mod table_macro; mod table_util; -mod udfs; pub use crate::options::{BuildError, Options, OptionsBuilder, OptionsBuilderError}; diff --git a/crates/storage-query-datafusion/src/options.rs b/crates/storage-query-datafusion/src/options.rs index c2b05d36c..dc815ef56 100644 --- a/crates/storage-query-datafusion/src/options.rs +++ b/crates/storage-query-datafusion/src/options.rs @@ -12,7 +12,6 @@ use crate::context::QueryContext; use codederror::CodedError; use datafusion::error::DataFusionError; use restate_invoker_api::StatusHandle; -use restate_schema_api::key::RestateKeyConverter; use restate_storage_rocksdb::RocksDBStorage; use std::fmt::Debug; @@ -45,7 +44,6 @@ impl Options { pub fn build( self, rocksdb: RocksDBStorage, - schema: impl RestateKeyConverter + Sync + Send + Clone + Debug + 'static, status: impl StatusHandle + Send + Sync + Debug + Clone + 'static, ) -> Result { let Options { @@ -55,11 +53,11 @@ impl Options { } = self; let ctx = QueryContext::new(memory_limit, temp_folder, query_parallelism); - crate::status::register_self(&ctx, rocksdb.clone(), schema.clone())?; - crate::state::register_self(&ctx, rocksdb.clone(), schema.clone())?; - crate::journal::register_self(&ctx, rocksdb.clone(), schema.clone())?; - crate::invocation_state::register_self(&ctx, status, schema.clone())?; - crate::inbox::register_self(&ctx, rocksdb, schema)?; + crate::status::register_self(&ctx, rocksdb.clone())?; + crate::state::register_self(&ctx, rocksdb.clone())?; + crate::journal::register_self(&ctx, rocksdb.clone())?; + crate::invocation_state::register_self(&ctx, status)?; + crate::inbox::register_self(&ctx, rocksdb)?; Ok(ctx) } diff --git a/crates/storage-query-datafusion/src/state/row.rs b/crates/storage-query-datafusion/src/state/row.rs index 6f385faea..eff20ba38 100644 --- a/crates/storage-query-datafusion/src/state/row.rs +++ b/crates/storage-query-datafusion/src/state/row.rs @@ -9,18 +9,10 @@ // by the Apache License, Version 2.0. 
use crate::state::schema::StateBuilder; -use crate::udfs::restate_keys; -use restate_schema_api::key::RestateKeyConverter; use restate_storage_rocksdb::state_table::OwnedStateRow; -use uuid::Uuid; #[inline] -pub(crate) fn append_state_row( - builder: &mut StateBuilder, - output: &mut String, - state_row: OwnedStateRow, - resolver: impl RestateKeyConverter, -) { +pub(crate) fn append_state_row(builder: &mut StateBuilder, state_row: OwnedStateRow) { let OwnedStateRow { partition_key, service, @@ -32,30 +24,7 @@ pub(crate) fn append_state_row( let mut row = builder.row(); row.partition_key(partition_key); row.service(&service); - row.service_key(&service_key); - if row.is_service_key_utf8_defined() { - if let Some(utf8) = restate_keys::try_decode_restate_key_as_utf8(&service_key) { - row.service_key_utf8(utf8); - } - } - if row.is_service_key_int32_defined() { - if let Some(key) = restate_keys::try_decode_restate_key_as_int32(&service_key) { - row.service_key_int32(key); - } - } - if row.is_service_key_uuid_defined() { - let mut buffer = Uuid::encode_buffer(); - if let Some(key) = restate_keys::try_decode_restate_key_as_uuid(&service_key, &mut buffer) { - row.service_key_uuid(key); - } - } - if row.is_service_key_json_defined() { - if let Some(key) = - restate_keys::try_decode_restate_key_as_json(&service, &service_key, output, resolver) - { - row.service_key_json(key); - } - } + row.service_key(std::str::from_utf8(&service_key).expect("The key must be a string!")); if row.is_key_defined() { if let Ok(str) = std::str::from_utf8(&state_key) { row.key(str); diff --git a/crates/storage-query-datafusion/src/state/schema.rs b/crates/storage-query-datafusion/src/state/schema.rs index a78ecdff6..4c502fc66 100644 --- a/crates/storage-query-datafusion/src/state/schema.rs +++ b/crates/storage-query-datafusion/src/state/schema.rs @@ -17,11 +17,7 @@ use datafusion::arrow::datatypes::DataType; define_table!(state( partition_key: DataType::UInt64, service: DataType::LargeUtf8, - service_key: DataType::LargeBinary, - service_key_utf8: DataType::LargeUtf8, - service_key_int32: DataType::Int32, - service_key_uuid: DataType::LargeUtf8, - service_key_json: DataType::LargeUtf8, + service_key: DataType::LargeUtf8, key: DataType::LargeUtf8, value_utf8: DataType::LargeUtf8, value: DataType::LargeBinary, diff --git a/crates/storage-query-datafusion/src/state/table.rs b/crates/storage-query-datafusion/src/state/table.rs index 3349df823..fd1f2ae3e 100644 --- a/crates/storage-query-datafusion/src/state/table.rs +++ b/crates/storage-query-datafusion/src/state/table.rs @@ -22,7 +22,6 @@ use crate::state::schema::StateBuilder; use datafusion::physical_plan::stream::RecordBatchReceiverStream; use datafusion::physical_plan::SendableRecordBatchStream; pub use datafusion_expr::UserDefinedLogicalNode; -use restate_schema_api::key::RestateKeyConverter; use restate_storage_rocksdb::state_table::OwnedStateRow; use restate_storage_rocksdb::RocksDBStorage; use restate_types::identifiers::PartitionKey; @@ -31,12 +30,8 @@ use tokio::sync::mpsc::Sender; pub(crate) fn register_self( ctx: &QueryContext, storage: RocksDBStorage, - resolver: impl RestateKeyConverter + Send + Sync + Debug + Clone + 'static, ) -> datafusion::common::Result<()> { - let table = GenericTableProvider::new( - StateBuilder::schema(), - Arc::new(StateScanner(storage, resolver)), - ); + let table = GenericTableProvider::new(StateBuilder::schema(), Arc::new(StateScanner(storage))); ctx.as_ref() .register_table("state", Arc::new(table)) @@ -44,11 +39,9 @@ pub(crate) 
fn register_self( } #[derive(Debug, Clone)] -struct StateScanner(RocksDBStorage, R); +struct StateScanner(RocksDBStorage); -impl RangeScanner - for StateScanner -{ +impl RangeScanner for StateScanner { fn scan( &self, range: RangeInclusive, @@ -56,12 +49,11 @@ impl RangeScanne ) -> SendableRecordBatchStream { let db = self.0.clone(); let schema = projection.clone(); - let resolver = self.1.clone(); let mut stream_builder = RecordBatchReceiverStream::builder(projection, 16); let tx = stream_builder.tx(); let background_task = move || { let rows = db.all_states(range); - for_each_state(schema, tx, rows, resolver); + for_each_state(schema, tx, rows); }; stream_builder.spawn_blocking(background_task); stream_builder.build() @@ -72,14 +64,12 @@ fn for_each_state<'a, I>( schema: SchemaRef, tx: Sender>, rows: I, - resolver: impl RestateKeyConverter + Clone, ) where I: Iterator + 'a, { let mut builder = StateBuilder::new(schema.clone()); - let mut temp = String::new(); for row in rows { - append_state_row(&mut builder, &mut temp, row, resolver.clone()); + append_state_row(&mut builder, row); if builder.full() { let batch = builder.finish(); if tx.blocking_send(Ok(batch)).is_err() { diff --git a/crates/storage-query-datafusion/src/status/row.rs b/crates/storage-query-datafusion/src/status/row.rs index 1a19957e6..900a2b4ef 100644 --- a/crates/storage-query-datafusion/src/status/row.rs +++ b/crates/storage-query-datafusion/src/status/row.rs @@ -10,8 +10,6 @@ use crate::status::schema::{StatusBuilder, StatusRowBuilder}; use crate::table_util::format_using; -use crate::udfs::restate_keys; -use restate_schema_api::key::RestateKeyConverter; use restate_storage_api::status_table::{ InvocationMetadata, InvocationStatus, JournalMetadata, StatusTimestamps, }; @@ -19,48 +17,19 @@ use restate_storage_rocksdb::status_table::OwnedStatusRow; use restate_types::identifiers::InvocationId; use restate_types::invocation::ServiceInvocationResponseSink; -use uuid::Uuid; - #[inline] pub(crate) fn append_status_row( builder: &mut StatusBuilder, output: &mut String, status_row: OwnedStatusRow, - resolver: impl RestateKeyConverter, ) { let mut row = builder.row(); row.partition_key(status_row.partition_key); row.service(&status_row.service); - row.service_key(&status_row.service_key); - if row.is_service_key_utf8_defined() { - if let Some(utf8) = restate_keys::try_decode_restate_key_as_utf8(&status_row.service_key) { - row.service_key_utf8(utf8); - } - } - if row.is_service_key_int32_defined() { - if let Some(key) = restate_keys::try_decode_restate_key_as_int32(&status_row.service_key) { - row.service_key_int32(key); - } - } - if row.is_service_key_uuid_defined() { - let mut buffer = Uuid::encode_buffer(); - if let Some(key) = - restate_keys::try_decode_restate_key_as_uuid(&status_row.service_key, &mut buffer) - { - row.service_key_uuid(key); - } - } - if row.is_service_key_json_defined() { - if let Some(key) = restate_keys::try_decode_restate_key_as_json( - &status_row.service, - &status_row.service_key, - output, - resolver, - ) { - row.service_key_json(key); - } - } + row.service_key( + std::str::from_utf8(&status_row.service_key).expect("The key must be a string!"), + ); // Invocation id if row.is_id_defined() { diff --git a/crates/storage-query-datafusion/src/status/schema.rs b/crates/storage-query-datafusion/src/status/schema.rs index 11188aa25..00a4d6625 100644 --- a/crates/storage-query-datafusion/src/status/schema.rs +++ b/crates/storage-query-datafusion/src/status/schema.rs @@ -18,11 +18,7 @@ 
define_table!(status( partition_key: DataType::UInt64, service: DataType::LargeUtf8, method: DataType::LargeUtf8, - service_key: DataType::LargeBinary, - service_key_utf8: DataType::LargeUtf8, - service_key_int32: DataType::Int32, - service_key_uuid: DataType::LargeUtf8, - service_key_json: DataType::LargeUtf8, + service_key: DataType::LargeUtf8, status: DataType::LargeUtf8, id: DataType::LargeUtf8, invoked_by: DataType::LargeUtf8, diff --git a/crates/storage-query-datafusion/src/status/table.rs b/crates/storage-query-datafusion/src/status/table.rs index f8bb18fd2..26c4ac11c 100644 --- a/crates/storage-query-datafusion/src/status/table.rs +++ b/crates/storage-query-datafusion/src/status/table.rs @@ -22,7 +22,6 @@ use crate::status::schema::StatusBuilder; use datafusion::physical_plan::stream::RecordBatchReceiverStream; use datafusion::physical_plan::SendableRecordBatchStream; pub use datafusion_expr::UserDefinedLogicalNode; -use restate_schema_api::key::RestateKeyConverter; use restate_storage_rocksdb::status_table::OwnedStatusRow; use restate_storage_rocksdb::RocksDBStorage; use restate_types::identifiers::PartitionKey; @@ -31,12 +30,9 @@ use tokio::sync::mpsc::Sender; pub(crate) fn register_self( ctx: &QueryContext, storage: RocksDBStorage, - resolver: impl RestateKeyConverter + Send + Sync + Debug + Clone + 'static, ) -> datafusion::common::Result<()> { - let status_table = GenericTableProvider::new( - StatusBuilder::schema(), - Arc::new(StatusScanner(storage, resolver)), - ); + let status_table = + GenericTableProvider::new(StatusBuilder::schema(), Arc::new(StatusScanner(storage))); ctx.as_ref() .register_table("sys_status", Arc::new(status_table)) @@ -44,11 +40,9 @@ pub(crate) fn register_self( } #[derive(Debug, Clone)] -struct StatusScanner(RocksDBStorage, R); +struct StatusScanner(RocksDBStorage); -impl RangeScanner - for StatusScanner -{ +impl RangeScanner for StatusScanner { fn scan( &self, range: RangeInclusive, @@ -56,12 +50,11 @@ impl RangeScanne ) -> SendableRecordBatchStream { let db = self.0.clone(); let schema = projection.clone(); - let resolver = self.1.clone(); let mut stream_builder = RecordBatchReceiverStream::builder(projection, 16); let tx = stream_builder.tx(); let background_task = move || { let rows = db.all_status(range); - for_each_status(schema, tx, rows, resolver); + for_each_status(schema, tx, rows); }; stream_builder.spawn_blocking(background_task); stream_builder.build() @@ -72,14 +65,13 @@ fn for_each_status<'a, I>( schema: SchemaRef, tx: Sender>, rows: I, - resolver: impl RestateKeyConverter + Clone, ) where I: Iterator + 'a, { let mut builder = StatusBuilder::new(schema.clone()); let mut temp = String::new(); for row in rows { - append_status_row(&mut builder, &mut temp, row, resolver.clone()); + append_status_row(&mut builder, &mut temp, row); if builder.full() { let batch = builder.finish(); if tx.blocking_send(Ok(batch)).is_err() { diff --git a/crates/storage-query-datafusion/src/udfs/mod.rs b/crates/storage-query-datafusion/src/udfs/mod.rs deleted file mode 100644 index e55a7942c..000000000 --- a/crates/storage-query-datafusion/src/udfs/mod.rs +++ /dev/null @@ -1,11 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH. -// All rights reserved. -// -// Use of this software is governed by the Business Source License -// included in the LICENSE file. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0. 
- -pub(crate) mod restate_keys; diff --git a/crates/storage-query-datafusion/src/udfs/restate_keys.rs b/crates/storage-query-datafusion/src/udfs/restate_keys.rs deleted file mode 100644 index 8d4158e83..000000000 --- a/crates/storage-query-datafusion/src/udfs/restate_keys.rs +++ /dev/null @@ -1,57 +0,0 @@ -// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH. -// All rights reserved. -// -// Use of this software is governed by the Business Source License -// included in the LICENSE file. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0. - -use restate_schema_api::key::RestateKeyConverter; -use std::fmt::Write; -use uuid::Uuid; - -#[inline] -pub(crate) fn try_decode_restate_key_as_utf8(mut key_slice: &[u8]) -> Option<&str> { - let len = prost::encoding::decode_varint(&mut key_slice).ok()?; - if len != key_slice.len() as u64 { - return None; - } - std::str::from_utf8(key_slice).ok() -} - -#[inline] -pub(crate) fn try_decode_restate_key_as_int32(mut key_slice: &[u8]) -> Option { - let value = prost::encoding::decode_varint(&mut key_slice).ok()?; - i32::try_from(value).ok() -} - -#[inline] -pub(crate) fn try_decode_restate_key_as_uuid<'a>( - key_slice: &[u8], - temp_buffer: &'a mut [u8], -) -> Option<&'a str> { - if key_slice.len() != 16 { - return None; - } - let uuid = Uuid::from_slice(key_slice).ok()?; - Some(uuid.simple().encode_lower(temp_buffer)) -} - -#[inline] -pub(crate) fn try_decode_restate_key_as_json<'a>( - service_name: &str, - key_slice: &[u8], - output: &'a mut String, - resolver: impl RestateKeyConverter, -) -> Option<&'a str> { - resolver - .key_to_json(service_name, key_slice) - .map(|value| { - output.clear(); - write!(output, "{}", value).expect("Error occurred while trying to write in String"); - output.as_str() - }) - .ok() -} diff --git a/crates/worker/src/lib.rs b/crates/worker/src/lib.rs index 6101fa60d..7fc763a42 100644 --- a/crates/worker/src/lib.rs +++ b/crates/worker/src/lib.rs @@ -297,11 +297,8 @@ impl Worker { schemas.clone(), ); - let query_context = storage_query_datafusion.build( - rocksdb.clone(), - schemas.clone(), - invoker.status_reader(), - )?; + let query_context = + storage_query_datafusion.build(rocksdb.clone(), invoker.status_reader())?; let storage_query_http = storage_query_http.build(query_context.clone()); let storage_query_postgres = storage_query_postgres.build(query_context);
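
For readers following the ordering-key change in `consumer_task.rs` above: the sketch below is a minimal, standalone rendering of the new string-based ordering-key format. `MockMessage` and the `ConsumerGroupTopicPartition` variant are stand-ins introduced here for illustration; only `KafkaOrderingKeyFormat::ConsumerGroupTopicPartitionKey`, the `base64::prelude::BASE64_STANDARD` engine (base64 is added as a dependency in this diff), and the prefix/topic/partition/key concatenation are taken from the change itself.

```rust
use base64::Engine;

// Stand-in for the fields of rdkafka's message type used by the ordering key.
struct MockMessage<'a> {
    topic: &'a str,
    partition: i32,
    key: Option<&'a [u8]>,
}

#[allow(dead_code)]
enum KafkaOrderingKeyFormat {
    ConsumerGroupTopicPartition,
    ConsumerGroupTopicPartitionKey,
}

fn generate_ordering_key(
    ordering_key_format: &KafkaOrderingKeyFormat,
    ordering_key_prefix: &str,
    msg: &MockMessage<'_>,
) -> String {
    let partition = msg.partition.to_string();

    let mut buf =
        String::with_capacity(ordering_key_prefix.len() + msg.topic.len() + partition.len());
    buf.push_str(ordering_key_prefix);
    buf.push_str(msg.topic);
    buf.push_str(&partition);

    // Kafka keys are arbitrary bytes, so they are base64-encoded to keep the
    // ordering key a valid UTF-8 string, matching the new string-key requirement.
    if let (KafkaOrderingKeyFormat::ConsumerGroupTopicPartitionKey, Some(key)) =
        (ordering_key_format, msg.key)
    {
        buf.push_str(&base64::prelude::BASE64_STANDARD.encode(key));
    }

    buf
}

fn main() {
    let msg = MockMessage {
        topic: "orders",
        partition: 0,
        key: Some("user-42".as_bytes()),
    };
    println!(
        "{}",
        generate_ordering_key(
            &KafkaOrderingKeyFormat::ConsumerGroupTopicPartitionKey,
            "my-consumer-group",
            &msg
        )
    );
}
```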
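
A companion sketch for the expansion side touched in `key_expansion.rs`: since the stored service key is now the raw UTF-8 string with no embedded length prefix, expanding it back into a protobuf message means writing the field tag and the length varint explicitly before the key bytes (hence the added `encode_varint` call). The `field_number` parameter and the fixed `LengthDelimited` wire type are simplifications here; the real code takes the wire type from the field descriptor.

```rust
use bytes::{BufMut, BytesMut};
use prost::encoding::{encode_key, encode_varint, encoded_len_varint, key_len, WireType};

// Re-wrap a plain string key into a length-delimited protobuf field so it can
// be spliced into the request message as the key-annotated field.
fn expand_string_key(field_number: u32, restate_key: &str) -> BytesMut {
    let key_bytes = restate_key.as_bytes();
    let mut b = BytesMut::with_capacity(
        key_len(field_number) + encoded_len_varint(key_bytes.len() as u64) + key_bytes.len(),
    );

    // Field tag (number + wire type), then the length prefix that is no longer
    // stored inside the Restate key itself, then the raw UTF-8 bytes.
    encode_key(field_number, WireType::LengthDelimited, &mut b);
    encode_varint(key_bytes.len() as u64, &mut b);
    b.put(key_bytes);
    b
}

fn main() {
    // For field number 1 and the key "person-id" this yields 0x0a, 0x09,
    // followed by the UTF-8 bytes of the key.
    let buf = expand_string_key(1, "person-id");
    println!("{:02x?}", &buf[..]);
}
```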