diff --git a/vortex-scalar/proto/vortex/scalar/scalar.proto b/vortex-scalar/proto/vortex/scalar/scalar.proto index 6b95928e82..7fa00ad5dc 100644 --- a/vortex-scalar/proto/vortex/scalar/scalar.proto +++ b/vortex-scalar/proto/vortex/scalar/scalar.proto @@ -4,6 +4,7 @@ package vortex.scalar; import "vortex/dtype/dtype.proto"; import "google/protobuf/struct.proto"; +import "google/protobuf/wrappers.proto"; message Scalar { vortex.dtype.DType dtype = 1; @@ -11,5 +12,21 @@ message Scalar { } message ScalarValue { - google.protobuf.Value value = 1; + oneof kind { + google.protobuf.NullValue null_value = 1; + bool bool_value = 2; + int32 int32_value = 3; + int64 int64_value = 4; + uint32 uint32_value = 5; + uint64 uint64_value = 6; + float float_value = 7; + double double_value = 8; + string string_value = 9; + bytes bytes_value = 10; + ListValue list_value = 12; + } +} + +message ListValue { + repeated ScalarValue values = 1; } diff --git a/vortex-scalar/src/serde/proto.rs b/vortex-scalar/src/serde/proto.rs index 1bff3d24dd..4de7110d84 100644 --- a/vortex-scalar/src/serde/proto.rs +++ b/vortex-scalar/src/serde/proto.rs @@ -1,14 +1,92 @@ #![cfg(feature = "proto")] -use prost_types::value::Kind; -use prost_types::{ListValue, Struct, Value}; -use vortex_buffer::BufferString; -use vortex_dtype::{DType, StructDType}; -use vortex_error::{vortex_bail, vortex_err, VortexError}; +use vortex_buffer::{Buffer, BufferString}; +use vortex_dtype::DType; +use vortex_error::{vortex_err, VortexError}; +use crate::proto::scalar::scalar_value::Kind; +use crate::proto::scalar::ListValue; use crate::pvalue::PValue; use crate::{proto::scalar as pb, Scalar, ScalarValue}; +impl From<&Scalar> for pb::Scalar { + fn from(value: &Scalar) -> Self { + pb::Scalar { + dtype: Some((&value.dtype).into()), + value: Some((&value.value).into()), + } + } +} + +impl From<&ScalarValue> for pb::ScalarValue { + fn from(value: &ScalarValue) -> Self { + match value { + ScalarValue::Null => pb::ScalarValue { + kind: Some(Kind::NullValue(0)), + }, + ScalarValue::Bool(v) => pb::ScalarValue { + kind: Some(Kind::BoolValue(*v)), + }, + ScalarValue::Primitive(v) => v.into(), + ScalarValue::Buffer(v) => pb::ScalarValue { + kind: Some(Kind::BytesValue(v.as_slice().to_vec())), + }, + ScalarValue::BufferString(v) => pb::ScalarValue { + kind: Some(Kind::StringValue(v.as_str().to_string())), + }, + ScalarValue::List(v) => { + let mut values = Vec::with_capacity(v.len()); + for elem in v.iter() { + values.push(pb::ScalarValue::from(elem)); + } + pb::ScalarValue { + kind: Some(Kind::ListValue(ListValue { values })), + } + } + } + } +} + +impl From<&PValue> for pb::ScalarValue { + fn from(value: &PValue) -> Self { + match value { + PValue::I8(v) => pb::ScalarValue { + kind: Some(Kind::Int32Value(*v as i32)), + }, + PValue::I16(v) => pb::ScalarValue { + kind: Some(Kind::Int32Value(*v as i32)), + }, + PValue::I32(v) => pb::ScalarValue { + kind: Some(Kind::Int32Value(*v)), + }, + PValue::I64(v) => pb::ScalarValue { + kind: Some(Kind::Int64Value(*v)), + }, + PValue::U8(v) => pb::ScalarValue { + kind: Some(Kind::Uint32Value(*v as u32)), + }, + PValue::U16(v) => pb::ScalarValue { + kind: Some(Kind::Uint32Value(*v as u32)), + }, + PValue::U32(v) => pb::ScalarValue { + kind: Some(Kind::Uint32Value(*v)), + }, + PValue::U64(v) => pb::ScalarValue { + kind: Some(Kind::Uint64Value(*v)), + }, + PValue::F16(v) => pb::ScalarValue { + kind: Some(Kind::FloatValue(v.to_f32())), + }, + PValue::F32(v) => pb::ScalarValue { + kind: Some(Kind::FloatValue(*v)), + }, + PValue::F64(v) => pb::ScalarValue { + kind: Some(Kind::DoubleValue(*v)), + }, + } + } +} + impl TryFrom<&pb::Scalar> for Scalar { type Error = VortexError; @@ -20,303 +98,114 @@ impl TryFrom<&pb::Scalar> for Scalar { .ok_or_else(|| vortex_err!(InvalidSerde: "Scalar missing dtype"))?, )?; - let scalar_value = value - .value - .as_ref() - .ok_or_else(|| vortex_err!(InvalidSerde: "Scalar missing value"))?; - - let pb_value = scalar_value - .value - .as_ref() - .ok_or_else(|| vortex_err!(InvalidSerde: "ScalarValue missing value"))?; - - let value = try_from_value(&dtype, pb_value)?; + let value = ScalarValue::try_from( + value + .value + .as_ref() + .ok_or_else(|| vortex_err!(InvalidSerde: "Scalar missing value"))?, + )?; Ok(Self { dtype, value }) } } -fn try_from_value(dtype: &DType, value: &Value) -> Result { - let kind = value - .kind - .as_ref() - .ok_or_else(|| vortex_err!(InvalidSerde: "Value missing kind"))?; - - Ok(match kind { - Kind::NullValue(_) => { - if !dtype.is_nullable() { - vortex_bail!(InvalidSerde: "Expected a nullable or Null dtype, found {:?}", dtype); - } - - ScalarValue::Null - } - Kind::BoolValue(v) => { - if !matches!(dtype, DType::Bool(_)) { - vortex_bail!(InvalidSerde: "Expected a bool dtype, found {:?}", dtype); - } - - ScalarValue::Bool(*v) - } - Kind::NumberValue(v) => { - if !matches!(dtype, DType::Primitive(_, _)) { - vortex_bail!(InvalidSerde: "Expected a primitive dtype, found {:?}", dtype); - } - - ScalarValue::Primitive(PValue::F64(*v)) - } - Kind::StringValue(v) => { - if !matches!(dtype, DType::Utf8(_)) { - vortex_bail!(InvalidSerde: "Expected a utf8 dtype, found {:?}", dtype); - } - - ScalarValue::BufferString(BufferString::from(v.clone())) - } - Kind::ListValue(v) => { - if let DType::List(elem_dtype, _) = dtype { - return try_from_list_value(elem_dtype, v); - } - - vortex_bail!(InvalidSerde: "Expected a list dtype, found {:?}", dtype); - } - Kind::StructValue(v) => { - if let DType::Struct(sdt, _) = dtype { - return try_from_struct_value(sdt, v); - } - - vortex_bail!(InvalidSerde: "Expected a struct dtype, found {:?}", dtype); - } - }) -} - -fn try_from_list_value(elem_dtype: &DType, value: &ListValue) -> Result { - let mut values = vec![]; - - for elem in value.values.iter() { - let nested = try_from_value(elem_dtype, elem)?; - - // Allow null values for nullable list only. - if matches!(nested, ScalarValue::Null) && !elem_dtype.is_nullable() { - vortex_bail!(InvalidSerde: "Non-nullable list element is null"); - } - - values.push(try_from_value(elem_dtype, elem)?); - } - - Ok(ScalarValue::List(values.into())) -} - -fn try_from_struct_value(dtype: &StructDType, value: &Struct) -> Result { - let mut values = vec![]; - - for (field, field_dt) in dtype.names().iter().zip(dtype.dtypes().iter()) { - if let Some((_, v)) = - // Add field values in order defined by the struct dtype. - value - .fields - .iter() - .find(|(f, _)| field.as_ref() == f.as_str()) - { - let nested = try_from_value(field_dt, v)?; +impl TryFrom<&pb::ScalarValue> for ScalarValue { + type Error = VortexError; - // Allow null values for nullable struct only. - if matches!(nested, ScalarValue::Null) && !field_dt.is_nullable() { - vortex_bail!(InvalidSerde: "Non-nullable struct field {} is null", field); + fn try_from(value: &pb::ScalarValue) -> Result { + let kind = value + .kind + .as_ref() + .ok_or_else(|| vortex_err!(InvalidSerde: "ScalarValue missing kind"))?; + + Ok(match kind { + Kind::NullValue(_) => ScalarValue::Null, + Kind::BoolValue(v) => ScalarValue::Bool(*v), + Kind::Int32Value(v) => ScalarValue::Primitive(PValue::I32(*v)), + Kind::Int64Value(v) => ScalarValue::Primitive(PValue::I64(*v)), + Kind::Uint32Value(v) => ScalarValue::Primitive(PValue::U32(*v)), + Kind::Uint64Value(v) => ScalarValue::Primitive(PValue::U64(*v)), + Kind::FloatValue(v) => ScalarValue::Primitive(PValue::F32(*v)), + Kind::DoubleValue(v) => ScalarValue::Primitive(PValue::F64(*v)), + Kind::StringValue(v) => ScalarValue::BufferString(BufferString::from(v.clone())), + Kind::BytesValue(v) => ScalarValue::Buffer(Buffer::from(v.as_slice())), + Kind::ListValue(v) => { + let mut values = Vec::with_capacity(v.values.len()); + for elem in v.values.iter() { + values.push(ScalarValue::try_from(elem)?); + } + ScalarValue::List(values.into()) } - - values.push(try_from_value(field_dt, v)?); - } else if field_dt.is_nullable() { - values.push(ScalarValue::Null); - } else { - vortex_bail!(InvalidSerde: "Non-nullable struct field {} not found", field); - } + }) } - - Ok(ScalarValue::List(values.into())) } #[cfg(test)] mod test { - use std::collections::BTreeMap; use std::sync::Arc; - use prost_types::value::Kind; - use prost_types::Value; - use vortex_dtype::{DType, FieldNames, Nullability, PType, StructDType}; + use vortex_buffer::BufferString; + use vortex_dtype::PType::I32; + use vortex_dtype::{DType, Nullability}; use crate::Scalar; - use crate::{proto as pb, PValue, ScalarValue}; + use crate::{proto as pb, ScalarValue}; - fn round_trip(dtype: DType, value: Value) -> Scalar { - let pb_scalar = pb::scalar::Scalar { - dtype: Some(pb::dtype::DType::from(&dtype)), - value: Some(pb::scalar::ScalarValue { value: Some(value) }), - }; - Scalar::try_from(&pb_scalar).unwrap() + fn round_trip(scalar: Scalar) { + Scalar::try_from(&pb::scalar::Scalar::from(&scalar)).unwrap(); } #[test] fn test_null() { - let scalar = round_trip( - DType::Null, - Value { - kind: Some(Kind::NullValue(0)), - }, - ); - assert_eq!(scalar.value, ScalarValue::Null); + round_trip(Scalar::null(DType::Null)); } #[test] - fn test_nullable() { - let scalar = round_trip( + fn test_bool() { + round_trip(Scalar::new( DType::Bool(Nullability::Nullable), - Value { - kind: Some(Kind::NullValue(0)), - }, - ); - assert_eq!(scalar.value, ScalarValue::Null); + ScalarValue::Bool(true), + )); } #[test] - fn test_bool() { - let scalar = round_trip( - DType::Bool(Nullability::NonNullable), - Value { - kind: Some(Kind::BoolValue(true)), - }, - ); - assert_eq!(scalar.value, ScalarValue::Bool(true)); + fn test_primitive() { + round_trip(Scalar::new( + DType::Primitive(I32, Nullability::Nullable), + ScalarValue::Primitive(42i32.into()), + )); } #[test] - fn test_number() { - let scalar = round_trip( - DType::Primitive(PType::F64, Nullability::NonNullable), - Value { - kind: Some(Kind::NumberValue(42.42)), - }, - ); - assert_eq!(scalar.value, ScalarValue::Primitive(PValue::F64(42.42))); + fn test_buffer() { + round_trip(Scalar::new( + DType::Binary(Nullability::Nullable), + ScalarValue::Buffer(vec![1, 2, 3].into()), + )); } #[test] - fn test_string() { - let scalar = round_trip( - DType::Utf8(Nullability::NonNullable), - Value { - kind: Some(Kind::StringValue("hello".to_string())), - }, - ); - assert_eq!( - scalar.value, - ScalarValue::BufferString("hello".to_string().into()) - ); + fn test_buffer_string() { + round_trip(Scalar::new( + DType::Utf8(Nullability::Nullable), + ScalarValue::BufferString(BufferString::from("hello".to_string())), + )); } #[test] fn test_list() { - let scalar = round_trip( + round_trip(Scalar::new( DType::List( - Arc::new(DType::Bool(Nullability::Nullable)), - Nullability::NonNullable, - ), - Value { - kind: Some(Kind::ListValue(prost_types::ListValue { - values: vec![Value { - kind: Some(Kind::BoolValue(true)), - }], - })), - }, - ); - assert_eq!( - scalar.value, - ScalarValue::List(vec![ScalarValue::Bool(true)].into()) - ); - } - - #[test] - fn test_list_nullable() { - let scalar = round_trip( - DType::List( - Arc::new(DType::Bool(Nullability::Nullable)), + Arc::new(DType::Primitive(I32, Nullability::Nullable)), Nullability::Nullable, ), - Value { - kind: Some(Kind::ListValue(prost_types::ListValue { - values: vec![Value { - kind: Some(Kind::NullValue(0)), - }], - })), - }, - ); - assert_eq!( - scalar.value, - ScalarValue::List(vec![ScalarValue::Null].into()) - ); - } - - #[test] - fn test_struct() { - let names = FieldNames::from(vec![Arc::from("a")]); - let mut nested_fields = BTreeMap::new(); - nested_fields.insert( - "a".to_string(), - Value { - kind: Some(Kind::BoolValue(true)), - }, - ); - - let scalar = round_trip( - DType::Struct( - StructDType::new(names, vec![DType::Bool(Nullability::NonNullable)]), - Nullability::NonNullable, - ), - Value { - kind: Some(Kind::StructValue(prost_types::Struct { - fields: nested_fields, - })), - }, - ); - assert_eq!( - scalar.value, - ScalarValue::List(vec![ScalarValue::Bool(true)].into()) - ); - } - - #[test] - fn test_struct_nullable() { - let names = FieldNames::from(vec![Arc::from("a")]); - let nested_fields = BTreeMap::new(); - - let scalar = round_trip( - DType::Struct( - StructDType::new(names, vec![DType::Bool(Nullability::Nullable)]), - Nullability::NonNullable, + ScalarValue::List( + vec![ + ScalarValue::Primitive(42i32.into()), + ScalarValue::Primitive(43i32.into()), + ] + .into(), ), - Value { - kind: Some(Kind::StructValue(prost_types::Struct { - fields: nested_fields, - })), - }, - ); - assert_eq!( - scalar.value, - ScalarValue::List(vec![ScalarValue::Null].into()) - ); - } - - #[test] - fn test_wrong_type() { - let pb_scalar = pb::scalar::Scalar { - dtype: Some(pb::dtype::DType::from(&DType::Primitive( - PType::F64, - Nullability::NonNullable, - ))), - value: Some(pb::scalar::ScalarValue { - value: Some(Value { - kind: Some(Kind::BoolValue(true)), - }), - }), - }; - assert!(Scalar::try_from(&pb_scalar).is_err()); + )); } }