From 0e83666c8e76b7bee87838fe0d9b4a91dda7904c Mon Sep 17 00:00:00 2001 From: Marko Bakovic Date: Sun, 9 Jun 2024 15:20:33 +0200 Subject: [PATCH] Implement StructValue proto serde without google.protobuf.Value (#343) Move away from google.protobuf.Value. It makes things unnecessarily complicated - a number is always f64, and see the previous logic for handling structs. It seems that Value models a JSON value, which is not what we need. I considered using google.protobuf wrappers, e.g. BytesValue, Int32Value, instead of primitives, but those are also meant for use with JSON or with proto2 (before optional was introduced to enable optional primitive fields), so I decided to use proto primitives. Adds conversions to proto (From) and from proto (TryFrom). --- .../proto/vortex/scalar/scalar.proto | 19 +- vortex-scalar/src/serde/proto.rs | 411 +++++++----------- 2 files changed, 168 insertions(+), 262 deletions(-) diff --git a/vortex-scalar/proto/vortex/scalar/scalar.proto b/vortex-scalar/proto/vortex/scalar/scalar.proto index 6b95928e82..7fa00ad5dc 100644 --- a/vortex-scalar/proto/vortex/scalar/scalar.proto +++ b/vortex-scalar/proto/vortex/scalar/scalar.proto @@ -4,6 +4,7 @@ package vortex.scalar; import "vortex/dtype/dtype.proto"; import "google/protobuf/struct.proto"; +import "google/protobuf/wrappers.proto"; message Scalar { vortex.dtype.DType dtype = 1; @@ -11,5 +12,21 @@ message Scalar { } message ScalarValue { - google.protobuf.Value value = 1; + oneof kind { + google.protobuf.NullValue null_value = 1; + bool bool_value = 2; + int32 int32_value = 3; + int64 int64_value = 4; + uint32 uint32_value = 5; + uint64 uint64_value = 6; + float float_value = 7; + double double_value = 8; + string string_value = 9; + bytes bytes_value = 10; + ListValue list_value = 12; + } +} + +message ListValue { + repeated ScalarValue values = 1; } diff --git a/vortex-scalar/src/serde/proto.rs b/vortex-scalar/src/serde/proto.rs index 1bff3d24dd..4de7110d84 100644 --- a/vortex-scalar/src/serde/proto.rs +++ 
b/vortex-scalar/src/serde/proto.rs @@ -1,14 +1,92 @@ #![cfg(feature = "proto")] -use prost_types::value::Kind; -use prost_types::{ListValue, Struct, Value}; -use vortex_buffer::BufferString; -use vortex_dtype::{DType, StructDType}; -use vortex_error::{vortex_bail, vortex_err, VortexError}; +use vortex_buffer::{Buffer, BufferString}; +use vortex_dtype::DType; +use vortex_error::{vortex_err, VortexError}; +use crate::proto::scalar::scalar_value::Kind; +use crate::proto::scalar::ListValue; use crate::pvalue::PValue; use crate::{proto::scalar as pb, Scalar, ScalarValue}; +impl From<&Scalar> for pb::Scalar { + fn from(value: &Scalar) -> Self { + pb::Scalar { + dtype: Some((&value.dtype).into()), + value: Some((&value.value).into()), + } + } +} + +impl From<&ScalarValue> for pb::ScalarValue { + fn from(value: &ScalarValue) -> Self { + match value { + ScalarValue::Null => pb::ScalarValue { + kind: Some(Kind::NullValue(0)), + }, + ScalarValue::Bool(v) => pb::ScalarValue { + kind: Some(Kind::BoolValue(*v)), + }, + ScalarValue::Primitive(v) => v.into(), + ScalarValue::Buffer(v) => pb::ScalarValue { + kind: Some(Kind::BytesValue(v.as_slice().to_vec())), + }, + ScalarValue::BufferString(v) => pb::ScalarValue { + kind: Some(Kind::StringValue(v.as_str().to_string())), + }, + ScalarValue::List(v) => { + let mut values = Vec::with_capacity(v.len()); + for elem in v.iter() { + values.push(pb::ScalarValue::from(elem)); + } + pb::ScalarValue { + kind: Some(Kind::ListValue(ListValue { values })), + } + } + } + } +} + +impl From<&PValue> for pb::ScalarValue { + fn from(value: &PValue) -> Self { + match value { + PValue::I8(v) => pb::ScalarValue { + kind: Some(Kind::Int32Value(*v as i32)), + }, + PValue::I16(v) => pb::ScalarValue { + kind: Some(Kind::Int32Value(*v as i32)), + }, + PValue::I32(v) => pb::ScalarValue { + kind: Some(Kind::Int32Value(*v)), + }, + PValue::I64(v) => pb::ScalarValue { + kind: Some(Kind::Int64Value(*v)), + }, + PValue::U8(v) => pb::ScalarValue { + kind: 
Some(Kind::Uint32Value(*v as u32)), + }, + PValue::U16(v) => pb::ScalarValue { + kind: Some(Kind::Uint32Value(*v as u32)), + }, + PValue::U32(v) => pb::ScalarValue { + kind: Some(Kind::Uint32Value(*v)), + }, + PValue::U64(v) => pb::ScalarValue { + kind: Some(Kind::Uint64Value(*v)), + }, + PValue::F16(v) => pb::ScalarValue { + kind: Some(Kind::FloatValue(v.to_f32())), + }, + PValue::F32(v) => pb::ScalarValue { + kind: Some(Kind::FloatValue(*v)), + }, + PValue::F64(v) => pb::ScalarValue { + kind: Some(Kind::DoubleValue(*v)), + }, + } + } +} + impl TryFrom<&pb::Scalar> for Scalar { type Error = VortexError; @@ -20,303 +98,114 @@ impl TryFrom<&pb::Scalar> for Scalar { .ok_or_else(|| vortex_err!(InvalidSerde: "Scalar missing dtype"))?, )?; - let scalar_value = value - .value - .as_ref() - .ok_or_else(|| vortex_err!(InvalidSerde: "Scalar missing value"))?; - - let pb_value = scalar_value - .value - .as_ref() - .ok_or_else(|| vortex_err!(InvalidSerde: "ScalarValue missing value"))?; - - let value = try_from_value(&dtype, pb_value)?; + let value = ScalarValue::try_from( + value + .value + .as_ref() + .ok_or_else(|| vortex_err!(InvalidSerde: "Scalar missing value"))?, + )?; Ok(Self { dtype, value }) } } -fn try_from_value(dtype: &DType, value: &Value) -> Result { - let kind = value - .kind - .as_ref() - .ok_or_else(|| vortex_err!(InvalidSerde: "Value missing kind"))?; - - Ok(match kind { - Kind::NullValue(_) => { - if !dtype.is_nullable() { - vortex_bail!(InvalidSerde: "Expected a nullable or Null dtype, found {:?}", dtype); - } - - ScalarValue::Null - } - Kind::BoolValue(v) => { - if !matches!(dtype, DType::Bool(_)) { - vortex_bail!(InvalidSerde: "Expected a bool dtype, found {:?}", dtype); - } - - ScalarValue::Bool(*v) - } - Kind::NumberValue(v) => { - if !matches!(dtype, DType::Primitive(_, _)) { - vortex_bail!(InvalidSerde: "Expected a primitive dtype, found {:?}", dtype); - } - - ScalarValue::Primitive(PValue::F64(*v)) - } - Kind::StringValue(v) => { - if 
!matches!(dtype, DType::Utf8(_)) { - vortex_bail!(InvalidSerde: "Expected a utf8 dtype, found {:?}", dtype); - } - - ScalarValue::BufferString(BufferString::from(v.clone())) - } - Kind::ListValue(v) => { - if let DType::List(elem_dtype, _) = dtype { - return try_from_list_value(elem_dtype, v); - } - - vortex_bail!(InvalidSerde: "Expected a list dtype, found {:?}", dtype); - } - Kind::StructValue(v) => { - if let DType::Struct(sdt, _) = dtype { - return try_from_struct_value(sdt, v); - } - - vortex_bail!(InvalidSerde: "Expected a struct dtype, found {:?}", dtype); - } - }) -} - -fn try_from_list_value(elem_dtype: &DType, value: &ListValue) -> Result { - let mut values = vec![]; - - for elem in value.values.iter() { - let nested = try_from_value(elem_dtype, elem)?; - - // Allow null values for nullable list only. - if matches!(nested, ScalarValue::Null) && !elem_dtype.is_nullable() { - vortex_bail!(InvalidSerde: "Non-nullable list element is null"); - } - - values.push(try_from_value(elem_dtype, elem)?); - } - - Ok(ScalarValue::List(values.into())) -} - -fn try_from_struct_value(dtype: &StructDType, value: &Struct) -> Result { - let mut values = vec![]; - - for (field, field_dt) in dtype.names().iter().zip(dtype.dtypes().iter()) { - if let Some((_, v)) = - // Add field values in order defined by the struct dtype. - value - .fields - .iter() - .find(|(f, _)| field.as_ref() == f.as_str()) - { - let nested = try_from_value(field_dt, v)?; +impl TryFrom<&pb::ScalarValue> for ScalarValue { + type Error = VortexError; - // Allow null values for nullable struct only. 
- if matches!(nested, ScalarValue::Null) && !field_dt.is_nullable() { - vortex_bail!(InvalidSerde: "Non-nullable struct field {} is null", field); + fn try_from(value: &pb::ScalarValue) -> Result { + let kind = value + .kind + .as_ref() + .ok_or_else(|| vortex_err!(InvalidSerde: "ScalarValue missing kind"))?; + + Ok(match kind { + Kind::NullValue(_) => ScalarValue::Null, + Kind::BoolValue(v) => ScalarValue::Bool(*v), + Kind::Int32Value(v) => ScalarValue::Primitive(PValue::I32(*v)), + Kind::Int64Value(v) => ScalarValue::Primitive(PValue::I64(*v)), + Kind::Uint32Value(v) => ScalarValue::Primitive(PValue::U32(*v)), + Kind::Uint64Value(v) => ScalarValue::Primitive(PValue::U64(*v)), + Kind::FloatValue(v) => ScalarValue::Primitive(PValue::F32(*v)), + Kind::DoubleValue(v) => ScalarValue::Primitive(PValue::F64(*v)), + Kind::StringValue(v) => ScalarValue::BufferString(BufferString::from(v.clone())), + Kind::BytesValue(v) => ScalarValue::Buffer(Buffer::from(v.as_slice())), + Kind::ListValue(v) => { + let mut values = Vec::with_capacity(v.values.len()); + for elem in v.values.iter() { + values.push(ScalarValue::try_from(elem)?); + } + ScalarValue::List(values.into()) } - - values.push(try_from_value(field_dt, v)?); - } else if field_dt.is_nullable() { - values.push(ScalarValue::Null); - } else { - vortex_bail!(InvalidSerde: "Non-nullable struct field {} not found", field); - } + }) } - - Ok(ScalarValue::List(values.into())) } #[cfg(test)] mod test { - use std::collections::BTreeMap; use std::sync::Arc; - use prost_types::value::Kind; - use prost_types::Value; - use vortex_dtype::{DType, FieldNames, Nullability, PType, StructDType}; + use vortex_buffer::BufferString; + use vortex_dtype::PType::I32; + use vortex_dtype::{DType, Nullability}; use crate::Scalar; - use crate::{proto as pb, PValue, ScalarValue}; + use crate::{proto as pb, ScalarValue}; - fn round_trip(dtype: DType, value: Value) -> Scalar { - let pb_scalar = pb::scalar::Scalar { - dtype: 
Some(pb::dtype::DType::from(&dtype)), - value: Some(pb::scalar::ScalarValue { value: Some(value) }), - }; - Scalar::try_from(&pb_scalar).unwrap() + fn round_trip(scalar: Scalar) { + Scalar::try_from(&pb::scalar::Scalar::from(&scalar)).unwrap(); } #[test] fn test_null() { - let scalar = round_trip( - DType::Null, - Value { - kind: Some(Kind::NullValue(0)), - }, - ); - assert_eq!(scalar.value, ScalarValue::Null); + round_trip(Scalar::null(DType::Null)); } #[test] - fn test_nullable() { - let scalar = round_trip( + fn test_bool() { + round_trip(Scalar::new( DType::Bool(Nullability::Nullable), - Value { - kind: Some(Kind::NullValue(0)), - }, - ); - assert_eq!(scalar.value, ScalarValue::Null); + ScalarValue::Bool(true), + )); } #[test] - fn test_bool() { - let scalar = round_trip( - DType::Bool(Nullability::NonNullable), - Value { - kind: Some(Kind::BoolValue(true)), - }, - ); - assert_eq!(scalar.value, ScalarValue::Bool(true)); + fn test_primitive() { + round_trip(Scalar::new( + DType::Primitive(I32, Nullability::Nullable), + ScalarValue::Primitive(42i32.into()), + )); } #[test] - fn test_number() { - let scalar = round_trip( - DType::Primitive(PType::F64, Nullability::NonNullable), - Value { - kind: Some(Kind::NumberValue(42.42)), - }, - ); - assert_eq!(scalar.value, ScalarValue::Primitive(PValue::F64(42.42))); + fn test_buffer() { + round_trip(Scalar::new( + DType::Binary(Nullability::Nullable), + ScalarValue::Buffer(vec![1, 2, 3].into()), + )); } #[test] - fn test_string() { - let scalar = round_trip( - DType::Utf8(Nullability::NonNullable), - Value { - kind: Some(Kind::StringValue("hello".to_string())), - }, - ); - assert_eq!( - scalar.value, - ScalarValue::BufferString("hello".to_string().into()) - ); + fn test_buffer_string() { + round_trip(Scalar::new( + DType::Utf8(Nullability::Nullable), + ScalarValue::BufferString(BufferString::from("hello".to_string())), + )); } #[test] fn test_list() { - let scalar = round_trip( + round_trip(Scalar::new( DType::List( - 
Arc::new(DType::Bool(Nullability::Nullable)), - Nullability::NonNullable, - ), - Value { - kind: Some(Kind::ListValue(prost_types::ListValue { - values: vec![Value { - kind: Some(Kind::BoolValue(true)), - }], - })), - }, - ); - assert_eq!( - scalar.value, - ScalarValue::List(vec![ScalarValue::Bool(true)].into()) - ); - } - - #[test] - fn test_list_nullable() { - let scalar = round_trip( - DType::List( - Arc::new(DType::Bool(Nullability::Nullable)), + Arc::new(DType::Primitive(I32, Nullability::Nullable)), Nullability::Nullable, ), - Value { - kind: Some(Kind::ListValue(prost_types::ListValue { - values: vec![Value { - kind: Some(Kind::NullValue(0)), - }], - })), - }, - ); - assert_eq!( - scalar.value, - ScalarValue::List(vec![ScalarValue::Null].into()) - ); - } - - #[test] - fn test_struct() { - let names = FieldNames::from(vec![Arc::from("a")]); - let mut nested_fields = BTreeMap::new(); - nested_fields.insert( - "a".to_string(), - Value { - kind: Some(Kind::BoolValue(true)), - }, - ); - - let scalar = round_trip( - DType::Struct( - StructDType::new(names, vec![DType::Bool(Nullability::NonNullable)]), - Nullability::NonNullable, - ), - Value { - kind: Some(Kind::StructValue(prost_types::Struct { - fields: nested_fields, - })), - }, - ); - assert_eq!( - scalar.value, - ScalarValue::List(vec![ScalarValue::Bool(true)].into()) - ); - } - - #[test] - fn test_struct_nullable() { - let names = FieldNames::from(vec![Arc::from("a")]); - let nested_fields = BTreeMap::new(); - - let scalar = round_trip( - DType::Struct( - StructDType::new(names, vec![DType::Bool(Nullability::Nullable)]), - Nullability::NonNullable, + ScalarValue::List( + vec![ + ScalarValue::Primitive(42i32.into()), + ScalarValue::Primitive(43i32.into()), + ] + .into(), ), - Value { - kind: Some(Kind::StructValue(prost_types::Struct { - fields: nested_fields, - })), - }, - ); - assert_eq!( - scalar.value, - ScalarValue::List(vec![ScalarValue::Null].into()) - ); - } - - #[test] - fn test_wrong_type() { - let 
pb_scalar = pb::scalar::Scalar { - dtype: Some(pb::dtype::DType::from(&DType::Primitive( - PType::F64, - Nullability::NonNullable, - ))), - value: Some(pb::scalar::ScalarValue { - value: Some(Value { - kind: Some(Kind::BoolValue(true)), - }), - }), - }; - assert!(Scalar::try_from(&pb_scalar).is_err()); + )); } }