Skip to content

Commit

Permalink
fix: Handling nil values when converting from proto to arrow (#61)
Browse files Browse the repository at this point in the history
* fix: Handling nil values when converting from proto to arrow

---------

Co-authored-by: Bhargav Dodla <[email protected]>
  • Loading branch information
EXPEbdodla and Bhargav Dodla authored Oct 11, 2023
1 parent 6bb38a6 commit 8597890
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 62 deletions.
139 changes: 78 additions & 61 deletions go/types/typeconversion.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,81 +91,71 @@ func ValueTypeEnumToArrowType(t types.ValueType_Enum) (arrow.DataType, error) {
}

// CopyProtoValuesToArrowArray appends every element of values to builder,
// choosing the proto getter from the concrete type of the Arrow builder
// (Boolean, Binary, String, Int32, Int64, Float32, Float64, Timestamp or List).
//
// NOTE(review): this span is a rendered diff, so removed and added lines are
// interleaved without +/- markers. In the lines that loop over `value`, a nil
// entry is recorded with builder.AppendNull() and skipped; the lines that loop
// over `v` / `list` show no nil check before calling the getter.
//
// For a *array.ListBuilder, Append(true) is called for the entry and the
// elements of the matching *ListVal getter are appended to the list's value
// builder. The inner switch on the value builder shows no default case here,
// so an unmatched element builder appears to leave that list entry empty
// without an error - TODO confirm against the full file.
//
// It returns an "unsupported array builder" error when builder matches none of
// the handled types, and nil otherwise.
// NOTE(review): the error formats builder with %s; %T may have been intended - confirm.
func CopyProtoValuesToArrowArray(builder array.Builder, values []*types.Value) error {
switch fieldBuilder := builder.(type) {
case *array.BooleanBuilder:
for _, v := range values {
fieldBuilder.Append(v.GetBoolVal())
for _, value := range values {
if value == nil {
builder.AppendNull()
continue
}
case *array.BinaryBuilder:
for _, v := range values {
fieldBuilder.Append(v.GetBytesVal())
}
case *array.StringBuilder:
for _, v := range values {
fieldBuilder.Append(v.GetStringVal())
}
case *array.Int32Builder:
for _, v := range values {
fieldBuilder.Append(v.GetInt32Val())
}
case *array.Int64Builder:
for _, v := range values {
fieldBuilder.Append(v.GetInt64Val())
}
case *array.Float32Builder:
for _, v := range values {
fieldBuilder.Append(v.GetFloatVal())
}
case *array.Float64Builder:
for _, v := range values {
fieldBuilder.Append(v.GetDoubleVal())
}
case *array.TimestampBuilder:
for _, v := range values {
fieldBuilder.Append(arrow.Timestamp(v.GetUnixTimestampVal()))
}
case *array.ListBuilder:
for _, list := range values {

switch fieldBuilder := builder.(type) {

case *array.BooleanBuilder:
fieldBuilder.Append(value.GetBoolVal())
case *array.BinaryBuilder:
fieldBuilder.Append(value.GetBytesVal())
case *array.StringBuilder:
fieldBuilder.Append(value.GetStringVal())
case *array.Int32Builder:
fieldBuilder.Append(value.GetInt32Val())
case *array.Int64Builder:
fieldBuilder.Append(value.GetInt64Val())
case *array.Float32Builder:
fieldBuilder.Append(value.GetFloatVal())
case *array.Float64Builder:
fieldBuilder.Append(value.GetDoubleVal())
case *array.TimestampBuilder:
fieldBuilder.Append(arrow.Timestamp(value.GetUnixTimestampVal()))
case *array.ListBuilder:
fieldBuilder.Append(true)

switch valueBuilder := fieldBuilder.ValueBuilder().(type) {

case *array.BooleanBuilder:
for _, v := range list.GetBoolListVal().GetVal() {
for _, v := range value.GetBoolListVal().GetVal() {
valueBuilder.Append(v)
}
case *array.BinaryBuilder:
for _, v := range list.GetBytesListVal().GetVal() {
for _, v := range value.GetBytesListVal().GetVal() {
valueBuilder.Append(v)
}
case *array.StringBuilder:
for _, v := range list.GetStringListVal().GetVal() {
for _, v := range value.GetStringListVal().GetVal() {
valueBuilder.Append(v)
}
case *array.Int32Builder:
for _, v := range list.GetInt32ListVal().GetVal() {
for _, v := range value.GetInt32ListVal().GetVal() {
valueBuilder.Append(v)
}
case *array.Int64Builder:
for _, v := range list.GetInt64ListVal().GetVal() {
for _, v := range value.GetInt64ListVal().GetVal() {
valueBuilder.Append(v)
}
case *array.Float32Builder:
for _, v := range list.GetFloatListVal().GetVal() {
for _, v := range value.GetFloatListVal().GetVal() {
valueBuilder.Append(v)
}
case *array.Float64Builder:
for _, v := range list.GetDoubleListVal().GetVal() {
for _, v := range value.GetDoubleListVal().GetVal() {
valueBuilder.Append(v)
}
case *array.TimestampBuilder:
for _, v := range list.GetUnixTimestampListVal().GetVal() {
for _, v := range value.GetUnixTimestampListVal().GetVal() {
valueBuilder.Append(arrow.Timestamp(v))
}
}
default:
return fmt.Errorf("unsupported array builder: %s", builder)
}
default:
return fmt.Errorf("unsupported array builder: %s", builder)
}
return nil
}
Expand Down Expand Up @@ -249,41 +239,68 @@ func ArrowValuesToProtoValues(arr arrow.Array) ([]*types.Value, error) {

switch arr.DataType() {
case arrow.PrimitiveTypes.Int32:
for _, v := range arr.(*array.Int32).Int32Values() {
values = append(values, &types.Value{Val: &types.Value_Int32Val{Int32Val: v}})
for idx := 0; idx < arr.Len(); idx++ {
if arr.IsNull(idx) {
values = append(values, nil)
} else {
values = append(values, &types.Value{Val: &types.Value_Int32Val{Int32Val: arr.(*array.Int32).Value(idx)}})
}
}
case arrow.PrimitiveTypes.Int64:
for _, v := range arr.(*array.Int64).Int64Values() {
values = append(values, &types.Value{Val: &types.Value_Int64Val{Int64Val: v}})
for idx := 0; idx < arr.Len(); idx++ {
if arr.IsNull(idx) {
values = append(values, nil)
} else {
values = append(values, &types.Value{Val: &types.Value_Int64Val{Int64Val: arr.(*array.Int64).Value(idx)}})
}
}
case arrow.PrimitiveTypes.Float32:
for _, v := range arr.(*array.Float32).Float32Values() {
values = append(values, &types.Value{Val: &types.Value_FloatVal{FloatVal: v}})
for idx := 0; idx < arr.Len(); idx++ {
if arr.IsNull(idx) {
values = append(values, nil)
} else {
values = append(values, &types.Value{Val: &types.Value_FloatVal{FloatVal: arr.(*array.Float32).Value(idx)}})
}
}
case arrow.PrimitiveTypes.Float64:
for _, v := range arr.(*array.Float64).Float64Values() {
values = append(values, &types.Value{Val: &types.Value_DoubleVal{DoubleVal: v}})
for idx := 0; idx < arr.Len(); idx++ {
if arr.IsNull(idx) {
values = append(values, nil)
} else {
values = append(values, &types.Value{Val: &types.Value_DoubleVal{DoubleVal: arr.(*array.Float64).Value(idx)}})
}
}
case arrow.FixedWidthTypes.Boolean:
for idx := 0; idx < arr.Len(); idx++ {
values = append(values,
&types.Value{Val: &types.Value_BoolVal{BoolVal: arr.(*array.Boolean).Value(idx)}})
if arr.IsNull(idx) {
values = append(values, nil)
} else {
values = append(values, &types.Value{Val: &types.Value_BoolVal{BoolVal: arr.(*array.Boolean).Value(idx)}})
}
}
case arrow.BinaryTypes.Binary:
for idx := 0; idx < arr.Len(); idx++ {
values = append(values,
&types.Value{Val: &types.Value_BytesVal{BytesVal: arr.(*array.Binary).Value(idx)}})
if arr.IsNull(idx) {
values = append(values, nil)
} else {
values = append(values, &types.Value{Val: &types.Value_BytesVal{BytesVal: arr.(*array.Binary).Value(idx)}})
}
}
case arrow.BinaryTypes.String:
for idx := 0; idx < arr.Len(); idx++ {
values = append(values,
&types.Value{Val: &types.Value_StringVal{StringVal: arr.(*array.String).Value(idx)}})
if arr.IsNull(idx) {
values = append(values, nil)
} else {
values = append(values, &types.Value{Val: &types.Value_StringVal{StringVal: arr.(*array.String).Value(idx)}})
}
}
case arrow.FixedWidthTypes.Timestamp_s:
for idx := 0; idx < arr.Len(); idx++ {
values = append(values,
&types.Value{Val: &types.Value_UnixTimestampVal{
UnixTimestampVal: int64(arr.(*array.Timestamp).Value(idx))}})
if arr.IsNull(idx) {
values = append(values, nil)
} else {
values = append(values, &types.Value{Val: &types.Value_UnixTimestampVal{UnixTimestampVal: int64(arr.(*array.Timestamp).Value(idx))}})
}
}
case arrow.Null:
for idx := 0; idx < arr.Len(); idx++ {
Expand Down
9 changes: 9 additions & 0 deletions go/types/typeconversion_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,22 @@ import (

var (
PROTO_VALUES = [][]*types.Value{
{{Val: &types.Value_Int32Val{10}}, nil},
{nil, {Val: &types.Value_Int32Val{20}}},
{{Val: &types.Value_Int32Val{10}}, {Val: &types.Value_Int32Val{20}}},
{{Val: &types.Value_Int64Val{10}}, nil},
{{Val: &types.Value_Int64Val{10}}, {Val: &types.Value_Int64Val{20}}},
{nil, {Val: &types.Value_FloatVal{2.0}}},
{{Val: &types.Value_FloatVal{1.0}}, {Val: &types.Value_FloatVal{2.0}}},
{{Val: &types.Value_DoubleVal{1.0}}, {Val: &types.Value_DoubleVal{2.0}}},
{{Val: &types.Value_DoubleVal{1.0}}, nil},
{nil, {Val: &types.Value_StringVal{"bbb"}}},
{{Val: &types.Value_StringVal{"aaa"}}, {Val: &types.Value_StringVal{"bbb"}}},
{{Val: &types.Value_BytesVal{[]byte{1, 2, 3}}}, nil},
{{Val: &types.Value_BytesVal{[]byte{1, 2, 3}}}, {Val: &types.Value_BytesVal{[]byte{4, 5, 6}}}},
{nil, {Val: &types.Value_BoolVal{false}}},
{{Val: &types.Value_BoolVal{true}}, {Val: &types.Value_BoolVal{false}}},
{{Val: &types.Value_UnixTimestampVal{time.Now().Unix()}}, nil},
{{Val: &types.Value_UnixTimestampVal{time.Now().Unix()}},
{Val: &types.Value_UnixTimestampVal{time.Now().Unix()}}},

Expand Down
5 changes: 4 additions & 1 deletion sdk/python/feast/infra/registry/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -999,7 +999,10 @@ def _apply_object(
}
update_stmt = (
update(table)
.where(getattr(table.c, id_field_name) == name)
.where(
getattr(table.c, id_field_name) == name,
table.c.project_id == project,
)
.values(
values,
)
Expand Down

0 comments on commit 8597890

Please sign in to comment.