diff --git a/cli/src/clients/admin_client.rs b/cli/src/clients/admin_client.rs index 7b15658e7..008b55c72 100644 --- a/cli/src/clients/admin_client.rs +++ b/cli/src/clients/admin_client.rs @@ -27,8 +27,8 @@ use crate::clients::AdminClientInterface; use super::errors::ApiError; /// Min/max supported admin API versions -pub const MIN_ADMIN_API_VERSION: AdminApiVersion = AdminApiVersion::V1; -pub const MAX_ADMIN_API_VERSION: AdminApiVersion = AdminApiVersion::V1; +pub const MIN_ADMIN_API_VERSION: AdminApiVersion = AdminApiVersion::V2; +pub const MAX_ADMIN_API_VERSION: AdminApiVersion = AdminApiVersion::V2; #[derive(Error, Debug)] #[error(transparent)] @@ -186,10 +186,13 @@ impl AdminClient { .request(method, path) .timeout(self.request_timeout); - match self.bearer_token.as_deref() { + let request_builder = match self.bearer_token.as_deref() { Some(token) => request_builder.bearer_auth(token), None => request_builder, - } + }; + + let (api_version_header, api_version) = self.admin_api_version.into(); + request_builder.header(api_version_header, api_version) } /// Prepare a request builder that encodes the body as JSON. 
diff --git a/cli/src/clients/datafusion_helpers.rs b/cli/src/clients/datafusion_helpers.rs index d309841d8..c93dee9f8 100644 --- a/cli/src/clients/datafusion_helpers.rs +++ b/cli/src/clients/datafusion_helpers.rs @@ -16,7 +16,7 @@ use std::str::FromStr; use anyhow::Result; use arrow::array::{Array, ArrayAccessor, AsArray, StringArray}; -use arrow::datatypes::{ArrowTemporalType, Date64Type}; +use arrow::datatypes::ArrowTemporalType; use arrow::record_batch::RecordBatch; use arrow_convert::{ArrowDeserialize, ArrowField}; use bytes::Bytes; @@ -111,10 +111,17 @@ fn value_as_u64_opt(batch: &RecordBatch, col: usize, row: usize) -> Option } fn value_as_dt_opt(batch: &RecordBatch, col: usize, row: usize) -> Option> { - batch - .column(col) - .as_primitive::() - .value_as_local_datetime_opt(row) + let col = batch.column(col); + match col.data_type() { + arrow::datatypes::DataType::Timestamp(arrow::datatypes::TimeUnit::Millisecond, _) => col + .as_primitive::() + .value_as_local_datetime_opt(row), + // older versions of restate used Date64 instead of TimestampMillisecond + arrow::datatypes::DataType::Date64 => col + .as_primitive::() + .value_as_local_datetime_opt(row), + _ => panic!("Column is not a timestamp"), + } } #[derive(ValueEnum, Copy, Clone, Eq, Hash, PartialEq, Debug, Default)] @@ -515,11 +522,7 @@ pub async fn get_service_status( .column(2) .as_primitive::() .value(i); - let oldest_at = batch - .column(3) - .as_primitive::() - .value_as_local_datetime_opt(i) - .unwrap(); + let oldest_at = value_as_dt_opt(&batch, 3, i).unwrap(); let oldest_invocation = batch.column(4).as_string::().value_string(i); @@ -744,23 +747,66 @@ impl From for DateTime { } } +pub static TIMEZONE_UTC: std::sync::LazyLock> = + std::sync::LazyLock::new(|| std::sync::Arc::from("+00:00")); + impl arrow_convert::field::ArrowField for RestateDateTime { type Type = Self; #[inline] fn data_type() -> arrow::datatypes::DataType { - arrow::datatypes::DataType::Date64 + 
arrow::datatypes::DataType::Timestamp( + arrow::datatypes::TimeUnit::Millisecond, + Some(TIMEZONE_UTC.clone()), + ) } } impl arrow_convert::deserialize::ArrowDeserialize for RestateDateTime { - type ArrayType = arrow::array::Date64Array; + type ArrayType = TimestampMillisecondArray; #[inline] fn arrow_deserialize(v: Option) -> Option { - v.and_then(arrow::temporal_conversions::as_datetime::) - .map(|naive| Local.from_utc_datetime(&naive)) - .map(RestateDateTime) + let timestamp = arrow::temporal_conversions::as_datetime::< + arrow::datatypes::TimestampMillisecondType, + >(v?)?; + Some(RestateDateTime(Local.from_utc_datetime(&timestamp))) + } +} + +// This newtype is necessary to implement ArrowArray, which is implemented for TimestampNanosecond but not TimestampMillisecond for some reason +#[repr(transparent)] +struct TimestampMillisecondArray(arrow::array::TimestampMillisecondArray); + +impl TimestampMillisecondArray { + fn from_ref(v: &arrow::array::TimestampMillisecondArray) -> &Self { + // SAFETY: transmuting a single-element newtype struct with repr(transparent) is safe + unsafe { std::mem::transmute(v) } + } +} + +impl arrow_convert::deserialize::ArrowArray for TimestampMillisecondArray { + type BaseArrayType = arrow::array::TimestampMillisecondArray; + #[inline] + fn iter_from_array_ref( + b: &dyn Array, + ) -> ::Iter<'_> { + let b = b.as_any().downcast_ref::().unwrap(); + ::iter( + TimestampMillisecondArray::from_ref(b), + ) + } +} + +impl arrow_convert::deserialize::ArrowArrayIterable for TimestampMillisecondArray { + type Item<'a> = Option< + ::Native, + >; + + type Iter<'a> = arrow::array::PrimitiveIter<'a, arrow::datatypes::TimestampMillisecondType>; + + fn iter(&self) -> Self::Iter<'_> { + IntoIterator::into_iter(&self.0) } } diff --git a/crates/admin-rest-model/src/version.rs b/crates/admin-rest-model/src/version.rs index b9f433d69..0c421bc5a 100644 --- a/crates/admin-rest-model/src/version.rs +++ b/crates/admin-rest-model/src/version.rs @@ -23,13 +23,46
@@ use std::ops::RangeInclusive; pub enum AdminApiVersion { Unknown = 0, V1 = 1, + V2 = 2, +} + +impl From for (http::HeaderName, http::HeaderValue) { + fn from(value: AdminApiVersion) -> Self { + ( + AdminApiVersion::HEADER_NAME, + http::HeaderValue::from(value.as_repr()), + ) + } } impl AdminApiVersion { + const HEADER_NAME: http::HeaderName = + http::HeaderName::from_static("x-restate-admin-api-version"); + pub fn as_repr(&self) -> u16 { *self as u16 } + pub fn from_headers(headers: &http::HeaderMap) -> Self { + let is_cli = matches!(headers.get(http::header::USER_AGENT), Some(value) if value.as_bytes().starts_with(b"restate-cli")); + + match headers.get(Self::HEADER_NAME) { + Some(value) => match value.to_str() { + Ok(value) => match value.parse::() { + Ok(value) => match Self::try_from(value) { + Ok(value) => value, + Err(_) => Self::Unknown, + }, + Err(_) => Self::Unknown, + }, + Err(_) => Self::Unknown, + }, + // The CLI did not previously send the version header, but if we know it's the CLI, then we can treat a missing header as V1 + None if is_cli => Self::V1, + None => Self::Unknown, + } + } + pub fn choose_max_supported_version( client_versions: RangeInclusive, server_versions: RangeInclusive, diff --git a/crates/admin/src/rest_api/version.rs b/crates/admin/src/rest_api/version.rs index 68278aae0..c735cd2a7 100644 --- a/crates/admin/src/rest_api/version.rs +++ b/crates/admin/src/rest_api/version.rs @@ -14,7 +14,7 @@ use restate_admin_rest_model::version::{AdminApiVersion, VersionInformation}; /// Min/max supported admin api versions by the server pub const MIN_ADMIN_API_VERSION: AdminApiVersion = AdminApiVersion::V1; -pub const MAX_ADMIN_API_VERSION: AdminApiVersion = AdminApiVersion::V1; +pub const MAX_ADMIN_API_VERSION: AdminApiVersion = AdminApiVersion::V2; /// Version information endpoint #[openapi( diff --git a/crates/admin/src/storage_query/convert.rs b/crates/admin/src/storage_query/convert.rs new file mode 100644 index 000000000..74e92187b --- /dev/null +++
b/crates/admin/src/storage_query/convert.rs @@ -0,0 +1,284 @@ +// Copyright (c) 2023 - 2025 Restate Software, Inc., Restate GmbH. +// All rights reserved. +// +// Use of this software is governed by the Business Source License +// included in the LICENSE file. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0. + +use std::{ + pin::Pin, + task::{ready, Context, Poll}, +}; + +use datafusion::{ + arrow::{ + array::{ + Array, ArrayRef, AsArray, BinaryArray, GenericByteArray, PrimitiveArray, RecordBatch, + StringArray, + }, + buffer::{OffsetBuffer, ScalarBuffer}, + datatypes::{ + ByteArrayType, DataType, Date64Type, Field, FieldRef, Schema, SchemaRef, TimeUnit, + TimestampMillisecondType, + }, + error::ArrowError, + }, + error::DataFusionError, + execution::{RecordBatchStream, SendableRecordBatchStream}, +}; +use futures::{Stream, StreamExt}; + +pub(super) const V1_CONVERTER: JoinConverter< + JoinConverter, + TimestampConverter, +> = JoinConverter::new( + JoinConverter::new(LargeConverter, FullCountConverter), + TimestampConverter, +); + +pub(super) struct ConvertRecordBatchStream { + converter: C, + inner: SendableRecordBatchStream, + converted_schema: SchemaRef, + done: bool, +} + +impl ConvertRecordBatchStream { + pub(super) fn new(converter: C, inner: SendableRecordBatchStream) -> Self { + let converted_schema = converter.convert_schema(inner.schema()); + + Self { + converter, + inner, + converted_schema, + done: false, + } + } +} + +impl RecordBatchStream for ConvertRecordBatchStream { + fn schema(&self) -> SchemaRef { + self.converted_schema.clone() + } +} + +impl Stream for ConvertRecordBatchStream { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if self.done { + return Poll::Ready(None); + } + + let record_batch = ready!(self.inner.poll_next_unpin(cx)); + + Poll::Ready(match 
record_batch { + Some(record_batch) => Some(match record_batch { + Ok(record_batch) => Ok(self + .converter + .convert_record_batch(&self.converted_schema, record_batch)?), + Err(err) => { + self.done = true; + Err(err) + } + }), + None => None, + }) + } +} + +pub(super) trait Converter: Unpin { + fn convert_schema(&self, schema: SchemaRef) -> SchemaRef { + let fields = Vec::from_iter(schema.fields().iter().cloned()); + let fields = self.convert_fields(fields); + SchemaRef::new(Schema::new_with_metadata(fields, schema.metadata().clone())) + } + + fn convert_record_batch( + &self, + converted_schema: &SchemaRef, + record_batch: RecordBatch, + ) -> Result { + let columns = Vec::from_iter(record_batch.columns().iter().cloned()); + let columns = self.convert_columns(converted_schema, columns)?; + + RecordBatch::try_new(converted_schema.clone(), columns) + } + + fn convert_columns( + &self, + converted_schema: &SchemaRef, + columns: Vec, + ) -> Result, ArrowError>; + + fn convert_fields(&self, fields: Vec) -> Vec; +} + +pub(super) struct JoinConverter { + first: First, + second: Second, +} + +impl JoinConverter { + const fn new(before: Before, after: After) -> Self { + Self { + first: before, + second: after, + } + } +} + +impl Converter for JoinConverter { + fn convert_columns( + &self, + converted_schema: &SchemaRef, + columns: Vec, + ) -> Result, ArrowError> { + self.second.convert_columns( + converted_schema, + self.first.convert_columns(converted_schema, columns)?, + ) + } + + fn convert_fields(&self, fields: Vec) -> Vec { + self.second + .convert_fields(self.first.convert_fields(fields)) + } +} + +// Prior to 1.2, we always converted LargeUtf8 to Utf8 and LargeBinary to Binary because +// Arrow JS didn't used to support the Large datatypes. 
+pub(super) struct LargeConverter; + +impl Converter for LargeConverter { + fn convert_columns( + &self, + _converted_schema: &SchemaRef, + columns: Vec, + ) -> Result, ArrowError> { + columns + .into_iter() + .map(|column| { + Ok(match column.data_type() { + DataType::LargeBinary => { + let col: BinaryArray = convert_array_offset(column.as_binary::())?; + ArrayRef::from(Box::new(col) as Box) + } + DataType::LargeUtf8 => { + let col: StringArray = convert_array_offset(column.as_string::())?; + ArrayRef::from(Box::new(col) as Box) + } + _ => column, + }) + }) + .collect() + } + + fn convert_fields(&self, fields: Vec) -> Vec { + fields + .into_iter() + .map(|field| { + let data_type = match field.data_type() { + DataType::LargeBinary => DataType::Binary, + DataType::LargeUtf8 => DataType::Utf8, + other => other.clone(), + }; + FieldRef::new(Field::new(field.name(), data_type, field.is_nullable())) + }) + .collect() + } +} + +fn convert_array_offset( + array: &GenericByteArray, +) -> Result, ArrowError> +where + After::Offset: TryFrom, +{ + let offsets = array + .offsets() + .iter() + .map(|&o| After::Offset::try_from(o)) + .collect::, _>>() + .map_err(|_| ArrowError::CastError("offset conversion failed".into()))?; + GenericByteArray::::try_new( + OffsetBuffer::new(offsets), + array.values().clone(), + array.nulls().cloned(), + ) +} + +// Prior to 1.2, we used a datafusion version which incorrectly considered the results of 'COUNT' statements to be nullable +// This is relevant for a single field name, `full_count` which is used in `inv ls`. 
+pub(super) struct FullCountConverter; + +impl Converter for FullCountConverter { + fn convert_columns( + &self, + _converted_schema: &SchemaRef, + columns: Vec, + ) -> Result, ArrowError> { + // this is a schema-only conversion + Ok(columns) + } + + fn convert_fields(&self, fields: Vec) -> Vec { + fields + .into_iter() + .map(|field| { + if field.name().as_str() == "full_count" + && field.data_type().eq(&DataType::Int64) + && !field.is_nullable() + { + FieldRef::new(Field::clone(&field).with_nullable(true)) + } else { + field + } + }) + .collect() + } +} + +// Prior to 1.2, we used Date64 fields where we should have used Timestamp fields +// This is relevant for various fields used in the CLI +pub(super) struct TimestampConverter; + +impl Converter for TimestampConverter { + fn convert_columns( + &self, + converted_schema: &SchemaRef, + mut columns: Vec, + ) -> Result, ArrowError> { + for (i, field) in converted_schema.fields().iter().enumerate() { + if let (DataType::Date64, DataType::Timestamp(TimeUnit::Millisecond, _)) = + (field.data_type(), columns[i].data_type()) + { + let col = columns[i].as_primitive::(); + // this doesn't copy; the same backing array can be used because they both use i64 epoch-based times + let col = + PrimitiveArray::::new(col.values().clone(), col.nulls().cloned()); + columns[i] = ArrayRef::from(Box::new(col) as Box); + } + } + Ok(columns) + } + + fn convert_fields(&self, fields: Vec) -> Vec { + fields + .into_iter() + .map(|field| match (field.name().as_str(), field.data_type()) { + ( + // inv ls + "last_start_at" | "next_retry_at" | "modified_at" | "created_at" | + // inv describe + "sleep_wakeup_at", + DataType::Timestamp(TimeUnit::Millisecond, _), + ) => FieldRef::new(Field::clone(&field).with_data_type(DataType::Date64)), + _ => field, + }) + .collect() + } +} diff --git a/crates/admin/src/storage_query/mod.rs b/crates/admin/src/storage_query/mod.rs index 27e91ebb1..33b0802b0 100644 --- a/crates/admin/src/storage_query/mod.rs +++ 
b/crates/admin/src/storage_query/mod.rs @@ -8,6 +8,7 @@ // the Business Source License, use of this software will be governed // by the Apache License, Version 2.0. +mod convert; mod error; mod query; diff --git a/crates/admin/src/storage_query/query.rs b/crates/admin/src/storage_query/query.rs index 8654c40f3..ab152a246 100644 --- a/crates/admin/src/storage_query/query.rs +++ b/crates/admin/src/storage_query/query.rs @@ -29,10 +29,12 @@ use http_body::Frame; use http_body_util::StreamBody; use okapi_operation::*; use parking_lot::Mutex; +use restate_admin_rest_model::version::AdminApiVersion; use schemars::JsonSchema; use serde::Deserialize; use serde_with::serde_as; +use super::convert::{ConvertRecordBatchStream, V1_CONVERTER}; use super::error::StorageQueryError; use crate::state::QueryServiceState; @@ -62,6 +64,17 @@ pub async fn query( ) -> Result { let record_batch_stream = state.query_context.execute(&payload.query).await?; + let version = AdminApiVersion::from_headers(&headers); + + let record_batch_stream: SendableRecordBatchStream = match version { + AdminApiVersion::V1 => Box::pin(ConvertRecordBatchStream::new( + V1_CONVERTER, + record_batch_stream, + )), + // treat 'unknown' as latest, users can specify 1 if they want to maintain old behaviour + AdminApiVersion::Unknown | AdminApiVersion::V2 => record_batch_stream, + }; + let (result_stream, content_type) = match headers.get(http::header::ACCEPT) { Some(v) if v == HeaderValue::from_static("application/json") => ( WriteRecordBatchStream::::new(record_batch_stream)? diff --git a/crates/storage-query-datafusion/src/deployment/schema.rs b/crates/storage-query-datafusion/src/deployment/schema.rs index 85b71c896..0ef2cbdb3 100644 --- a/crates/storage-query-datafusion/src/deployment/schema.rs +++ b/crates/storage-query-datafusion/src/deployment/schema.rs @@ -25,7 +25,7 @@ define_table!(sys_deployment( endpoint: DataType::LargeUtf8, /// Timestamp indicating the deployment registration time. 
- created_at: DataType::Date64, + created_at: TimestampMillisecond, /// Minimum supported protocol version. min_service_protocol_version: DataType::UInt32, diff --git a/crates/storage-query-datafusion/src/inbox/schema.rs b/crates/storage-query-datafusion/src/inbox/schema.rs index bd3cda0a0..27d17b5fc 100644 --- a/crates/storage-query-datafusion/src/inbox/schema.rs +++ b/crates/storage-query-datafusion/src/inbox/schema.rs @@ -32,5 +32,5 @@ define_table!(sys_inbox( /// Timestamp indicating the start of this invocation. /// DEPRECATED: you should not use this field anymore, but join with the sys_invocation table - created_at: DataType::Date64, + created_at: TimestampMillisecond, )); diff --git a/crates/storage-query-datafusion/src/invocation_state/schema.rs b/crates/storage-query-datafusion/src/invocation_state/schema.rs index 62ccf38a3..0f76f0ef5 100644 --- a/crates/storage-query-datafusion/src/invocation_state/schema.rs +++ b/crates/storage-query-datafusion/src/invocation_state/schema.rs @@ -30,7 +30,7 @@ define_table!(sys_invocation_state( retry_count: DataType::UInt64, /// Timestamp indicating the start of the most recent attempt of this invocation. - last_start_at: DataType::Date64, + last_start_at: TimestampMillisecond, // The deployment that was selected in the last invocation attempt. This is // guaranteed to be set unlike in `sys_status` table which require that the @@ -44,7 +44,7 @@ define_table!(sys_invocation_state( last_attempt_server: DataType::LargeUtf8, /// Timestamp indicating the start of the next attempt of this invocation. - next_retry_at: DataType::Date64, + next_retry_at: TimestampMillisecond, /// An error message describing the most recent failed attempt of this invocation, if any. 
last_failure: DataType::LargeUtf8, diff --git a/crates/storage-query-datafusion/src/invocation_status/schema.rs b/crates/storage-query-datafusion/src/invocation_status/schema.rs index 0d6dd616b..b08395beb 100644 --- a/crates/storage-query-datafusion/src/invocation_status/schema.rs +++ b/crates/storage-query-datafusion/src/invocation_status/schema.rs @@ -85,21 +85,21 @@ define_table!(sys_invocation_status( journal_size: DataType::UInt32, /// Timestamp indicating the start of this invocation. - created_at: DataType::Date64, + created_at: TimestampMillisecond, /// Timestamp indicating the last invocation status transition. For example, last time the /// status changed from `invoked` to `suspended`. - modified_at: DataType::Date64, + modified_at: TimestampMillisecond, /// Timestamp indicating when the invocation was inboxed, if ever. - inboxed_at: DataType::Date64, + inboxed_at: TimestampMillisecond, /// Timestamp indicating when the invocation was scheduled, if ever. - scheduled_at: DataType::Date64, + scheduled_at: TimestampMillisecond, /// Timestamp indicating when the invocation first transitioned to running, if ever. - running_at: DataType::Date64, + running_at: TimestampMillisecond, /// Timestamp indicating when the invocation was completed, if ever. - completed_at: DataType::Date64, + completed_at: TimestampMillisecond, )); diff --git a/crates/storage-query-datafusion/src/journal/schema.rs b/crates/storage-query-datafusion/src/journal/schema.rs index 73fbe3a1e..18b1b22ed 100644 --- a/crates/storage-query-datafusion/src/journal/schema.rs +++ b/crates/storage-query-datafusion/src/journal/schema.rs @@ -44,7 +44,7 @@ define_table!(sys_journal( invoked_target: DataType::LargeUtf8, /// If this entry represents a sleep, indicates wakeup time. - sleep_wakeup_at: DataType::Date64, + sleep_wakeup_at: TimestampMillisecond, /// If this entry is a promise related entry (GetPromise, PeekPromise, CompletePromise), indicates the promise name. 
promise_name: DataType::LargeUtf8, diff --git a/crates/storage-query-datafusion/src/table_macro.rs b/crates/storage-query-datafusion/src/table_macro.rs index 428d6d363..655066627 100644 --- a/crates/storage-query-datafusion/src/table_macro.rs +++ b/crates/storage-query-datafusion/src/table_macro.rs @@ -8,6 +8,10 @@ // the Business Source License, use of this software will be governed // by the Apache License, Version 2.0. +// Instead of trying to parse the Timestamp(Millisecond, Some(...)) variant in macros, just use a marker struct +#[allow(dead_code)] +pub struct TimestampMillisecond; + macro_rules! define_builder { (DataType::Utf8) => { ::datafusion::arrow::array::StringBuilder @@ -30,14 +34,67 @@ macro_rules! define_builder { (DataType::Int32) => { ::datafusion::arrow::array::Int32Builder }; - (DataType::Date64) => { - ::datafusion::arrow::array::Date64Builder + (TimestampMillisecond) => { + TimestampMillisecondUTCBuilder }; (DataType::Boolean) => { ::datafusion::arrow::array::BooleanBuilder }; } +// This newtype is necessary to generate values with a UTC timezone, as it will default to having no timezone which can confuse downstream clients +pub struct TimestampMillisecondUTCBuilder(::datafusion::arrow::array::TimestampMillisecondBuilder); + +impl Default for TimestampMillisecondUTCBuilder { + fn default() -> Self { + Self( + ::datafusion::arrow::array::TimestampMillisecondBuilder::default() + .with_timezone(TIMEZONE_UTC.clone()), + ) + } +} + +impl TimestampMillisecondUTCBuilder { + #[inline] + pub fn append_value( + &mut self, + v: <::datafusion::arrow::datatypes::TimestampMillisecondType as ::datafusion::arrow::array::ArrowPrimitiveType>::Native, + ) { + self.0.append_value(v); + } + + #[inline] + pub fn append_null(&mut self) { + self.0.append_null(); + } +} + +impl ::datafusion::arrow::array::ArrayBuilder for TimestampMillisecondUTCBuilder { + fn len(&self) -> usize { + ::datafusion::arrow::array::ArrayBuilder::len(&self.0) + } + + fn finish(&mut self) -> 
datafusion::arrow::array::ArrayRef { + ::datafusion::arrow::array::ArrayBuilder::finish(&mut self.0) + } + + fn finish_cloned(&self) -> datafusion::arrow::array::ArrayRef { + ::datafusion::arrow::array::ArrayBuilder::finish_cloned(&self.0) + } + + fn as_any(&self) -> &dyn std::any::Any { + ::datafusion::arrow::array::ArrayBuilder::as_any(&self.0) + } + + fn as_any_mut(&mut self) -> &mut dyn std::any::Any { + ::datafusion::arrow::array::ArrayBuilder::as_any_mut(&mut self.0) + } + + fn into_box_any(self: Box) -> Box { + ::datafusion::arrow::array::ArrayBuilder::into_box_any(Box::new(self.0)) + } +} + macro_rules! define_primitive_trait { (DataType::Utf8) => { impl AsRef @@ -57,7 +114,7 @@ macro_rules! define_primitive_trait { (DataType::Int32) => { i32 }; - (DataType::Date64) => { + (TimestampMillisecond) => { i64 }; (DataType::UInt64) => { @@ -68,6 +125,42 @@ macro_rules! define_primitive_trait { }; } +pub static TIMEZONE_UTC: std::sync::LazyLock> = + std::sync::LazyLock::new(|| std::sync::Arc::from("+00:00")); + +macro_rules! define_data_type { + (DataType::Utf8) => { + DataType::Utf8 + }; + (DataType::LargeUtf8) => { + DataType::LargeUtf8 + }; + (DataType::Binary) => { + DataType::Binary + }; + (DataType::LargeBinary) => { + DataType::LargeBinary + }; + (DataType::UInt32) => { + DataType::UInt32 + }; + (DataType::Int32) => { + DataType::Int32 + }; + (TimestampMillisecond) => { + DataType::Timestamp( + ::datafusion::arrow::datatypes::TimeUnit::Millisecond, + Some(TIMEZONE_UTC.clone()), + ) + }; + (DataType::UInt64) => { + DataType::UInt64 + }; + (DataType::Boolean) => { + DataType::Boolean + }; +} + #[cfg(feature = "table_docs")] macro_rules! document_type { (DataType::Utf8) => { @@ -91,8 +184,8 @@ macro_rules! document_type { (DataType::Int32) => { "Int32" }; - (DataType::Date64) => { - "Date64" + (TimestampMillisecond) => { + "TimestampMillisecond" }; (DataType::Boolean) => { "Boolean" @@ -108,7 +201,7 @@ macro_rules! 
document_type { /// name: DataType::Utf8, /// age: DataType::UInt32, /// secret: DataType::Binary, -/// birth_date: DataType::Date64, +/// birth_date: TimestampMillisecond, /// )) /// /// ``` @@ -126,7 +219,7 @@ macro_rules! document_type { /// name: Option, /// age: Option, /// secret: Option, -/// birth_date: Option, +/// birth_date: Option, /// } /// pub struct UserRowBuilder<'a> { /// flags: UserRowBuilderFlags, @@ -286,7 +379,7 @@ macro_rules! document_type { /// (Field::new(stringify!( name ), DataType::Utf8, true)), /// (Field::new(stringify!( age ), DataType::UInt32, true)), /// (Field::new(stringify!( secret ), DataType::Binary, true)), -/// (Field::new(stringify!( birth_date ), DataType::Date64, true))]) +/// (Field::new(stringify!( birth_date ), DataType::Timestamp(TimeUnit::Millisecond, Some(TIMEZONE_UTC.clone())), true))]) /// ))) /// ) /// } @@ -439,7 +532,7 @@ macro_rules! define_table { pub fn schema() -> ::datafusion::arrow::datatypes::SchemaRef { std::sync::Arc::new(::datafusion::arrow::datatypes::Schema::new( vec![ - $(::datafusion::arrow::datatypes::Field::new(stringify!($element), $ty, true),)+ + $(::datafusion::arrow::datatypes::Field::new(stringify!($element), define_data_type!($ty), true),)+ ]) ) } @@ -498,6 +591,7 @@ macro_rules! define_table { } pub(crate) use define_builder; +pub(crate) use define_data_type; pub(crate) use define_primitive_trait; pub(crate) use define_table; #[cfg(feature = "table_docs")] diff --git a/crates/storage-query-postgres/src/pgwire_server.rs b/crates/storage-query-postgres/src/pgwire_server.rs index 06dfd1a2e..6acfa9c23 100644 --- a/crates/storage-query-postgres/src/pgwire_server.rs +++ b/crates/storage-query-postgres/src/pgwire_server.rs @@ -8,13 +8,16 @@ // the Business Source License, use of this software will be governed // by the Apache License, Version 2.0. 
+use std::net::SocketAddr; +use std::sync::Arc; +use std::time::{Duration, SystemTime}; + use anyhow::Context; use async_trait::async_trait; use datafusion::arrow::array::{ - Array, BinaryArray, BooleanArray, Date32Array, Date64Array, LargeBinaryArray, LargeStringArray, - PrimitiveArray, StringArray, + Array, BinaryArray, BooleanArray, LargeBinaryArray, LargeStringArray, PrimitiveArray, + StringArray, TimestampMillisecondArray, }; -use datafusion::arrow::datatypes::DataType; use datafusion::arrow::datatypes::Float32Type; use datafusion::arrow::datatypes::Float64Type; use datafusion::arrow::datatypes::Int16Type; @@ -23,12 +26,10 @@ use datafusion::arrow::datatypes::Int64Type; use datafusion::arrow::datatypes::Int8Type; use datafusion::arrow::datatypes::UInt32Type; use datafusion::arrow::datatypes::UInt64Type; +use datafusion::arrow::datatypes::{DataType, TimeUnit}; use datafusion::arrow::record_batch::RecordBatch; -use datafusion::arrow::temporal_conversions::{date32_to_datetime, date64_to_datetime}; use datafusion::physical_plan::SendableRecordBatchStream; use futures::{stream, StreamExt}; -use std::net::SocketAddr; -use std::sync::Arc; use tokio::net::TcpStream; use tokio::sync::Mutex; @@ -154,7 +155,6 @@ fn into_pg_type(df_type: &DataType) -> PgWireResult { DataType::UInt64 => Type::INT8, DataType::Timestamp(_, _) => Type::TIMESTAMP, DataType::Time32(_) | DataType::Time64(_) => Type::TIME, - DataType::Date32 | DataType::Date64 => Type::VARCHAR, DataType::Binary => Type::BYTEA, DataType::LargeBinary => Type::BYTEA, DataType::Float32 => Type::FLOAT4, @@ -241,26 +241,16 @@ fn get_bool_value(arr: &Arc, idx: usize) -> bool { .value(idx) } -fn get_date64_value(arr: &Arc, idx: usize) -> String { +fn get_timestamp_value(arr: &Arc, idx: usize) -> SystemTime { let value = arr .as_any() - .downcast_ref::() + .downcast_ref::() .unwrap() .value(idx); - let dt = date64_to_datetime(value).unwrap(); - dt.format("%Y-%m-%d %H:%M:%S").to_string() -} - -fn get_date32_value(arr: 
&Arc, idx: usize) -> String { - let value = arr - .as_any() - .downcast_ref::() + SystemTime::UNIX_EPOCH + .checked_add(Duration::from_millis(value as u64)) .unwrap() - .value(idx); - - let dt = date32_to_datetime(value).unwrap(); - dt.format("%Y-%m-%d %H:%M:%S").to_string() } macro_rules! get_primitive_value { @@ -330,8 +320,9 @@ fn encode_value( DataType::LargeUtf8 => encoder.encode_field(&get_large_utf8_value(arr, idx))?, DataType::Binary => encoder.encode_field(&get_binary_value(arr, idx))?, DataType::LargeBinary => encoder.encode_field(&get_large_binary_value(arr, idx))?, - DataType::Date64 => encoder.encode_field(&get_date64_value(arr, idx))?, - DataType::Date32 => encoder.encode_field(&get_date32_value(arr, idx))?, + DataType::Timestamp(TimeUnit::Millisecond, None) => { + encoder.encode_field(&get_timestamp_value(arr, idx))? + } _ => { return Err(PgWireError::UserError(Box::new(ErrorInfo::new( "ERROR".to_owned(),