From c06d554bd4a757dc5fcf03436d1d9a9ca7832908 Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Wed, 8 May 2024 10:28:36 +0100 Subject: [PATCH] Proto --- Cargo.lock | 2 ++ Cargo.toml | 1 + build-vortex/src/lib.rs | 2 +- vortex-array/src/array/varbin/builder.rs | 2 +- vortex-array/src/array/varbin/mod.rs | 2 +- vortex-array/src/array/varbin/stats.rs | 8 ++--- vortex-buffer/src/string.rs | 14 +++++--- vortex-dtype/Cargo.toml | 2 ++ vortex-dtype/src/lib.rs | 6 ++-- vortex-dtype/src/serde/proto.rs | 10 +++--- vortex-scalar/Cargo.toml | 10 +++++- vortex-scalar/proto/scalar.proto | 21 ++++++++++++ vortex-scalar/src/binary.rs | 2 +- vortex-scalar/src/lib.rs | 9 ++++++ vortex-scalar/src/serde/mod.rs | 3 +- vortex-scalar/src/serde/proto.rs | 41 ++++++++++++++++++++++++ vortex-scalar/src/serde/serde.rs | 4 ++- vortex-scalar/src/utf8.rs | 10 ++---- vortex-scalar/src/value.rs | 14 ++++++-- 19 files changed, 132 insertions(+), 31 deletions(-) create mode 100644 vortex-scalar/proto/scalar.proto create mode 100644 vortex-scalar/src/serde/proto.rs diff --git a/Cargo.lock b/Cargo.lock index 328d5fe9d7..41c5e3a314 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5115,6 +5115,8 @@ dependencies = [ "itertools 0.12.1", "num-traits", "paste", + "prost", + "prost-types", "serde", "vortex-buffer", "vortex-dtype", diff --git a/Cargo.toml b/Cargo.toml index adb1e60e59..bf92f3ed1b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -77,6 +77,7 @@ parquet = "51.0.0" paste = "1.0.14" prost = "0.12.4" prost-build = "0.12.4" +prost-types = "0.12.4" pyo3 = { version = "0.20.2", features = ["extension-module", "abi3-py311"] } pyo3-log = "0.9.0" rand = "0.8.5" diff --git a/build-vortex/src/lib.rs b/build-vortex/src/lib.rs index 650292159e..5e70edc4c0 100644 --- a/build-vortex/src/lib.rs +++ b/build-vortex/src/lib.rs @@ -26,7 +26,7 @@ pub fn build() { } // Proto (prost) - if env::var("CARGO_FEATURE_PROST").ok().is_some() { + if env::var("CARGO_FEATURE_PROTO").ok().is_some() { build_proto(); } } diff --git a/vortex-array/src/array/varbin/builder.rs b/vortex-array/src/array/varbin/builder.rs index 1bbbda906c..da9104420a 100644 --- a/vortex-array/src/array/varbin/builder.rs +++ b/vortex-array/src/array/varbin/builder.rs @@ -87,7 +87,7 @@ mod test { assert_eq!(array.dtype().nullability(), Nullable); assert_eq!( scalar_at(&array, 0).unwrap(), - Scalar::utf8("hello", Nullable) + Scalar::utf8("hello".to_string(), Nullable) ); assert!(scalar_at(&array, 1).unwrap().is_null()); } diff --git a/vortex-array/src/array/varbin/mod.rs b/vortex-array/src/array/varbin/mod.rs index 34ecb5b85a..cbd6ac3254 100644 --- a/vortex-array/src/array/varbin/mod.rs +++ b/vortex-array/src/array/varbin/mod.rs @@ -210,7 +210,7 @@ impl<'a> FromIterator> for VarBinArray<'_> { pub fn varbin_scalar(value: Vec, dtype: &DType) -> Scalar { if matches!(dtype, DType::Utf8(_)) { let str = unsafe { String::from_utf8_unchecked(value) }; - Scalar::utf8(str.as_ref(), dtype.nullability()) + Scalar::utf8(str, dtype.nullability()) } else { Scalar::binary(value.into(), dtype.nullability()) } diff --git a/vortex-array/src/array/varbin/stats.rs b/vortex-array/src/array/varbin/stats.rs index c77fb4c0e4..7060ed1351 100644 --- a/vortex-array/src/array/varbin/stats.rs +++ b/vortex-array/src/array/varbin/stats.rs @@ -144,11 +144,11 @@ mod test { let arr = array(DType::Utf8(Nullability::NonNullable)); assert_eq!( arr.statistics().compute_min::().unwrap(), - BufferString::from("hello world") + BufferString::from("hello world".to_string()) ); assert_eq!( arr.statistics().compute_max::().unwrap(), - BufferString::from("hello world this is a long string") + BufferString::from("hello world this is a long string".to_string()) ); assert_eq!(arr.statistics().compute_run_count().unwrap(), 2); assert!(!arr.statistics().compute_is_constant().unwrap()); @@ -184,11 +184,11 @@ mod test { ); assert_eq!( array.statistics().compute_min::().unwrap(), - BufferString::from("hello world") + BufferString::from("hello world".to_string()) ); assert_eq!( array.statistics().compute_max::().unwrap(), - BufferString::from("hello world this is a long string") + BufferString::from("hello world this is a long string".to_string()) ); } diff --git a/vortex-buffer/src/string.rs b/vortex-buffer/src/string.rs index bb79c40aa0..b0023490eb 100644 --- a/vortex-buffer/src/string.rs +++ b/vortex-buffer/src/string.rs @@ -15,6 +15,11 @@ impl BufferString { pub unsafe fn new_unchecked(buffer: Buffer) -> Self { Self(buffer) } + + pub fn as_str(&self) -> &str { + // SAFETY: We have already validated that the buffer is valid UTF-8 + unsafe { std::str::from_utf8_unchecked(self.0.as_ref()) } + } } impl From for Buffer { @@ -23,9 +28,9 @@ impl From for Buffer { } } -impl From<&str> for BufferString { - fn from(value: &str) -> Self { - BufferString(Buffer::from(value.as_bytes())) +impl From for BufferString { + fn from(value: String) -> Self { + BufferString(Buffer::from(value.into_bytes())) } } @@ -42,7 +47,6 @@ impl Deref for BufferString { type Target = str; fn deref(&self) -> &Self::Target { - // SAFETY: We have already validated that the buffer is valid UTF-8 - unsafe { std::str::from_utf8_unchecked(self.0.as_ref()) } + self.as_str() } } diff --git a/vortex-dtype/Cargo.toml b/vortex-dtype/Cargo.toml index 2797da7644..2b6e324853 100644 --- a/vortex-dtype/Cargo.toml +++ b/vortex-dtype/Cargo.toml @@ -36,4 +36,6 @@ build-vortex = { path = "../build-vortex" } workspace = true [features] +default = ["flatbuffers", "proto", "serde"] +proto = ["dep:prost"] serde = ["dep:serde", "half/serde"] \ No newline at end of file diff --git a/vortex-dtype/src/lib.rs b/vortex-dtype/src/lib.rs index 7bdd85a8d3..179d9d3a61 100644 --- a/vortex-dtype/src/lib.rs +++ b/vortex-dtype/src/lib.rs @@ -11,9 +11,11 @@ mod nullability; mod ptype; mod serde; -#[cfg(feature = "prost")] +#[cfg(feature = "proto")] pub mod proto { - include!(concat!(env!("OUT_DIR"), "/proto/vortex.dtype.rs")); + pub mod dtype { + include!(concat!(env!("OUT_DIR"), "/proto/vortex.dtype.rs")); + } } #[cfg(feature = "flatbuffers")] diff --git a/vortex-dtype/src/serde/proto.rs b/vortex-dtype/src/serde/proto.rs index 0e49ded1cf..247b806c62 100644 --- a/vortex-dtype/src/serde/proto.rs +++ b/vortex-dtype/src/serde/proto.rs @@ -1,9 +1,11 @@ -#![cfg(feature = "prost")] +#![cfg(feature = "proto")] + +use std::sync::Arc; use vortex_error::{vortex_err, VortexError, VortexResult}; -use crate::proto::d_type::Type; -use crate::{proto as pb, DType, ExtDType, ExtID, ExtMetadata, PType, StructDType}; +use crate::proto::dtype::d_type::Type; +use crate::{proto::dtype as pb, DType, ExtDType, ExtID, ExtMetadata, PType, StructDType}; impl TryFrom<&pb::DType> for DType { type Error = VortexError; @@ -38,7 +40,7 @@ impl TryFrom<&pb::DType> for DType { .ok_or_else(|| vortex_err!(InvalidSerde: "Invalid list element type"))? .as_ref() .try_into() - .map(Box::new)?, + .map(Arc::new)?, nullable, )) } diff --git a/vortex-scalar/Cargo.toml b/vortex-scalar/Cargo.toml index de34776c62..5bd3267d23 100644 --- a/vortex-scalar/Cargo.toml +++ b/vortex-scalar/Cargo.toml @@ -16,6 +16,8 @@ flatbuffers = { workspace = true, optional = true } flexbuffers = { workspace = true, optional = true } itertools = { workspace = true } paste = { workspace = true } +prost = { workspace = true, optional = true } +prost-types = { workspace = true, optional = true } num-traits = { workspace = true } serde = { workspace = true, optional = true, features = ["rc"] } vortex-buffer = { path = "../vortex-buffer" } @@ -30,6 +32,7 @@ build-vortex = { path = "../build-vortex" } workspace = true [features] +default = ["flatbuffers", "proto", "serde"] flatbuffers = [ "dep:flatbuffers", "dep:flexbuffers", @@ -37,4 +40,9 @@ flatbuffers = [ "vortex-buffer/flexbuffers", "vortex-error/flexbuffers", ] -serde = ["dep:serde", "serde/derive"] \ No newline at end of file +proto = [ + "dep:prost", + "dep:prost-types", + "vortex-dtype/proto", +] +serde = ["dep:serde", "serde/derive"] diff --git a/vortex-scalar/proto/scalar.proto b/vortex-scalar/proto/scalar.proto new file mode 100644 index 0000000000..308ec0670a --- /dev/null +++ b/vortex-scalar/proto/scalar.proto @@ -0,0 +1,21 @@ +syntax = "proto3"; + +import "google/protobuf/any.proto"; +import "vortex-dtype/proto/dtype.proto"; + +package vortex.scalar; + +message Scalar { + vortex.dtype.DType dtype = 1; + oneof value { + bool bool = 2; + uint32 uint32 = 3; + uint64 uint64 = 4; + sint32 sint32 = 5; + sint64 sint64 = 6; + float float = 7; + double double = 8; + bytes bytes = 9; + string string = 10; + } +} \ No newline at end of file diff --git a/vortex-scalar/src/binary.rs b/vortex-scalar/src/binary.rs index de62772b2a..8d9f811aca 100644 --- a/vortex-scalar/src/binary.rs +++ b/vortex-scalar/src/binary.rs @@ -43,7 +43,7 @@ impl<'a> TryFrom<&'a Scalar> for BinaryScalar<'a> { } Ok(Self { dtype: value.dtype(), - value: value.value.as_bytes()?, + value: value.value.as_buffer()?, }) } } diff --git a/vortex-scalar/src/lib.rs b/vortex-scalar/src/lib.rs index 9331b165b6..10cdac600f 100644 --- a/vortex-scalar/src/lib.rs +++ b/vortex-scalar/src/lib.rs @@ -24,6 +24,15 @@ pub use utf8::*; pub use value::*; use vortex_error::{vortex_bail, VortexResult}; +#[cfg(feature = "proto")] +pub mod proto { + pub mod scalar { + include!(concat!(env!("OUT_DIR"), "/proto/vortex.scalar.rs")); + } + + pub use vortex_dtype::proto::dtype; +} + #[cfg(feature = "flatbuffers")] pub mod flatbuffers { pub use gen_scalar::vortex::*; diff --git a/vortex-scalar/src/serde/mod.rs b/vortex-scalar/src/serde/mod.rs index 2171e0141f..a5e9e2a19a 100644 --- a/vortex-scalar/src/serde/mod.rs +++ b/vortex-scalar/src/serde/mod.rs @@ -1,3 +1,4 @@ +mod flatbuffers; +mod proto; #[allow(clippy::module_inception)] mod serde; -mod flatbuffers; diff --git a/vortex-scalar/src/serde/proto.rs b/vortex-scalar/src/serde/proto.rs new file mode 100644 index 0000000000..72dcee1e4e --- /dev/null +++ b/vortex-scalar/src/serde/proto.rs @@ -0,0 +1,41 @@ +#![cfg(feature = "proto")] + +use vortex_buffer::{Buffer, BufferString}; +use vortex_dtype::DType; +use vortex_error::{vortex_err, VortexError}; + +use crate::proto::scalar::scalar::Value; +use crate::pvalue::PValue; +use crate::{proto::scalar as pb, Scalar, ScalarValue}; + +impl TryFrom<&pb::Scalar> for Scalar { + type Error = VortexError; + + fn try_from(value: &pb::Scalar) -> Result { + let dtype = DType::try_from( + value + .dtype + .as_ref() + .ok_or_else(|| vortex_err!(InvalidSerde: "Scalar missing dtype"))?, + )?; + + let value = value + .value + .as_ref() + .ok_or_else(|| vortex_err!(InvalidSerde: "Scalar missing value"))?; + + let value = match value { + Value::Bool(b) => ScalarValue::Bool(*b), + Value::Uint32(v) => ScalarValue::Primitive(PValue::U32(*v)), + Value::Uint64(v) => ScalarValue::Primitive(PValue::U64(*v)), + Value::Sint32(v) => ScalarValue::Primitive(PValue::I32(*v)), + Value::Sint64(v) => ScalarValue::Primitive(PValue::I64(*v)), + Value::Float(v) => ScalarValue::Primitive(PValue::F32(*v)), + Value::Double(v) => ScalarValue::Primitive(PValue::F64(*v)), + Value::Bytes(v) => ScalarValue::Buffer(Buffer::from(v.clone())), + Value::String(v) => ScalarValue::BufferString(BufferString::from(v.clone())), + }; + + Ok(Scalar { dtype, value }) + } +} diff --git a/vortex-scalar/src/serde/serde.rs b/vortex-scalar/src/serde/serde.rs index d9666e6412..5f2f59f4cb 100644 --- a/vortex-scalar/src/serde/serde.rs +++ b/vortex-scalar/src/serde/serde.rs @@ -4,6 +4,7 @@ use std::fmt::Formatter; use serde::de::{Error, SeqAccess, Visitor}; use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use vortex_buffer::BufferString; use crate::pvalue::PValue; use crate::value::ScalarValue; @@ -18,6 +19,7 @@ impl Serialize for ScalarValue { ScalarValue::Bool(b) => b.serialize(serializer), ScalarValue::Primitive(p) => p.serialize(serializer), ScalarValue::Buffer(buffer) => buffer.as_ref().serialize(serializer), + ScalarValue::BufferString(buffer) => buffer.as_str().serialize(serializer), ScalarValue::List(l) => l.serialize(serializer), } } @@ -117,7 +119,7 @@ impl<'de> Deserialize<'de> for ScalarValue { where E: Error, { - Ok(ScalarValue::Buffer(v.as_bytes().to_vec().into())) + Ok(ScalarValue::BufferString(BufferString::from(v.to_string()))) } fn visit_bytes(self, v: &[u8]) -> Result diff --git a/vortex-scalar/src/utf8.rs b/vortex-scalar/src/utf8.rs index 01a2093364..e5eb25df49 100644 --- a/vortex-scalar/src/utf8.rs +++ b/vortex-scalar/src/utf8.rs @@ -34,7 +34,7 @@ impl Scalar { { Scalar { dtype: DType::Utf8(nullability), - value: ScalarValue::Buffer(BufferString::from(str).into()), + value: ScalarValue::BufferString(BufferString::from(str)), } } } @@ -48,11 +48,7 @@ impl<'a> TryFrom<&'a Scalar> for Utf8Scalar<'a> { } Ok(Self { dtype: value.dtype(), - value: value - .value - .as_bytes()? - .map(BufferString::try_from) - .transpose()?, + value: value.value.as_buffer_string()?, }) } } @@ -71,7 +67,7 @@ impl From<&str> for Scalar { fn from(value: &str) -> Self { Scalar { dtype: DType::Utf8(NonNullable), - value: ScalarValue::Buffer(value.as_bytes().into()), + value: ScalarValue::BufferString(value.to_string().into()), } } } diff --git a/vortex-scalar/src/value.rs b/vortex-scalar/src/value.rs index ee9c661aa3..cebf1a3033 100644 --- a/vortex-scalar/src/value.rs +++ b/vortex-scalar/src/value.rs @@ -1,6 +1,6 @@ use std::sync::Arc; -use vortex_buffer::Buffer; +use vortex_buffer::{Buffer, BufferString}; use vortex_error::{vortex_err, VortexResult}; use crate::pvalue::PValue; @@ -17,6 +17,7 @@ pub enum ScalarValue { Bool(bool), Primitive(PValue), Buffer(Buffer), + BufferString(BufferString), List(Arc<[ScalarValue]>), } @@ -41,7 +42,7 @@ impl ScalarValue { } } - pub fn as_bytes(&self) -> VortexResult> { + pub fn as_buffer(&self) -> VortexResult> { match self { ScalarValue::Null => Ok(None), ScalarValue::Buffer(b) => Ok(Some(b.clone())), @@ -49,6 +50,15 @@ impl ScalarValue { } } + pub fn as_buffer_string(&self) -> VortexResult> { + match self { + ScalarValue::Null => Ok(None), + ScalarValue::Buffer(b) => Ok(Some(BufferString::try_from(b.clone())?)), + ScalarValue::BufferString(b) => Ok(Some(b.clone())), + _ => Err(vortex_err!("Expected a string scalar, found {:?}", self)), + } + } + pub fn as_list(&self) -> VortexResult>> { match self { ScalarValue::List(l) => Ok(Some(l)),