diff --git a/.github/workflows/go-ci.yml b/.github/workflows/go-ci.yml index 2b158b5..3d0c0c0 100644 --- a/.github/workflows/go-ci.yml +++ b/.github/workflows/go-ci.yml @@ -22,8 +22,8 @@ on: branches: - 'main' tags: - - 'v**' - pull_request: + - 'v**' + pull_request: concurrency: group: ${{ github.repository }}-${{ github.head_ref || github.sha }}-${{ github.workflow }} diff --git a/schema.go b/schema.go index 7ea7757..f6d88d3 100644 --- a/schema.go +++ b/schema.go @@ -427,6 +427,25 @@ type AfterMapValueVisitor interface { AfterMapValue(value NestedField) } +type SchemaVisitorPerPrimitiveType[T any] interface { + SchemaVisitor[T] + + VisitFixed(FixedType) T + VisitDecimal(DecimalType) T + VisitBoolean() T + VisitInt32() T + VisitInt64() T + VisitFloat32() T + VisitFloat64() T + VisitDate() T + VisitTime() T + VisitTimestamp() T + VisitTimestampTz() T + VisitString() T + VisitBinary() T + VisitUUID() T +} + // Visit accepts a visitor and performs a post-order traversal of the given schema. func Visit[T any](sc *Schema, visitor SchemaVisitor[T]) (res T, err error) { if sc == nil { @@ -534,6 +553,38 @@ func visitField[T any](f NestedField, visitor SchemaVisitor[T]) T { case *MapType: return visitMap(*typ, visitor) default: // primitive + if perPrimitive, ok := visitor.(SchemaVisitorPerPrimitiveType[T]); ok { + switch t := typ.(type) { + case BooleanType: + return perPrimitive.VisitBoolean() + case Int32Type: + return perPrimitive.VisitInt32() + case Int64Type: + return perPrimitive.VisitInt64() + case Float32Type: + return perPrimitive.VisitFloat32() + case Float64Type: + return perPrimitive.VisitFloat64() + case DateType: + return perPrimitive.VisitDate() + case TimeType: + return perPrimitive.VisitTime() + case TimestampType: + return perPrimitive.VisitTimestamp() + case TimestampTzType: + return perPrimitive.VisitTimestampTz() + case StringType: + return perPrimitive.VisitString() + case BinaryType: + return perPrimitive.VisitBinary() + case UUIDType: + return perPrimitive.VisitUUID() + case DecimalType: + return perPrimitive.VisitDecimal(t) + case FixedType: + return perPrimitive.VisitFixed(t) + } + } return visitor.Primitive(typ.(PrimitiveType)) } } @@ -706,8 +757,7 @@ func (i *indexByName) AfterField(field NestedField) { // PruneColumns visits a schema pruning any columns which do not exist in the // provided selected set. Parent fields of a selected child will be retained. func PruneColumns(schema *Schema, selected map[int]Void, selectFullTypes bool) (*Schema, error) { - - result, err := Visit[Type](schema, &pruneColVisitor{selected: selected, + result, err := Visit(schema, &pruneColVisitor{selected: selected, fullTypes: selectFullTypes}) if err != nil { return nil, err diff --git a/table/arrow_utils.go b/table/arrow_utils.go index 6104fc6..8f3890a 100644 --- a/table/arrow_utils.go +++ b/table/arrow_utils.go @@ -23,6 +23,7 @@ import ( "strconv" "github.com/apache/arrow-go/v18/arrow" + "github.com/apache/arrow-go/v18/arrow/extensions" "github.com/apache/iceberg-go" ) @@ -410,3 +411,132 @@ func ArrowSchemaToIceberg(sc *arrow.Schema, downcastNsTimestamp bool, nameMappin iceberg.ErrInvalidSchema) } } + +type convertToArrow struct { + metadata map[string]string + includeFieldIDs bool +} + +func (c convertToArrow) Schema(_ *iceberg.Schema, result arrow.Field) arrow.Field { + result.Metadata = arrow.MetadataFrom(c.metadata) + return result +} + +func (c convertToArrow) Struct(_ iceberg.StructType, results []arrow.Field) arrow.Field { + return arrow.Field{Type: arrow.StructOf(results...)} +} + +func (c convertToArrow) Field(field iceberg.NestedField, result arrow.Field) arrow.Field { + meta := map[string]string{} + if len(field.Doc) > 0 { + meta[ArrowFieldDocKey] = field.Doc + } + + if c.includeFieldIDs { + meta[ArrowParquetFieldIDKey] = strconv.Itoa(field.ID) + } + + if len(meta) > 0 { + result.Metadata = arrow.MetadataFrom(meta) + } + + result.Name, result.Nullable = field.Name, !field.Required + return result +} + +func (c convertToArrow) List(list iceberg.ListType, elemResult arrow.Field) arrow.Field { + elemField := c.Field(list.ElementField(), elemResult) + return arrow.Field{Type: arrow.LargeListOfField(elemField)} +} + +func (c convertToArrow) Map(m iceberg.MapType, keyResult, valResult arrow.Field) arrow.Field { + keyField := c.Field(m.KeyField(), keyResult) + valField := c.Field(m.ValueField(), valResult) + return arrow.Field{Type: arrow.MapOfWithMetadata(keyField.Type, keyField.Metadata, + valField.Type, valField.Metadata)} +} + +func (c convertToArrow) Primitive(iceberg.PrimitiveType) arrow.Field { panic("shouldn't be called") } + +func (c convertToArrow) VisitFixed(f iceberg.FixedType) arrow.Field { + return arrow.Field{Type: &arrow.FixedSizeBinaryType{ByteWidth: f.Len()}} +} + +func (c convertToArrow) VisitDecimal(d iceberg.DecimalType) arrow.Field { + return arrow.Field{Type: &arrow.Decimal128Type{ + Precision: int32(d.Precision()), Scale: int32(d.Scale())}} +} + +func (c convertToArrow) VisitBoolean() arrow.Field { + return arrow.Field{Type: arrow.FixedWidthTypes.Boolean} +} + +func (c convertToArrow) VisitInt32() arrow.Field { + return arrow.Field{Type: arrow.PrimitiveTypes.Int32} +} + +func (c convertToArrow) VisitInt64() arrow.Field { + return arrow.Field{Type: arrow.PrimitiveTypes.Int64} +} + +func (c convertToArrow) VisitFloat32() arrow.Field { + return arrow.Field{Type: arrow.PrimitiveTypes.Float32} +} + +func (c convertToArrow) VisitFloat64() arrow.Field { + return arrow.Field{Type: arrow.PrimitiveTypes.Float64} +} + +func (c convertToArrow) VisitDate() arrow.Field { + return arrow.Field{Type: arrow.FixedWidthTypes.Date32} +} + +func (c convertToArrow) VisitTime() arrow.Field { + return arrow.Field{Type: arrow.FixedWidthTypes.Time64us} +} + +func (c convertToArrow) VisitTimestampTz() arrow.Field { + return arrow.Field{Type: arrow.FixedWidthTypes.Timestamp_us} +} + +func (c convertToArrow) VisitTimestamp() arrow.Field { + return arrow.Field{Type: &arrow.TimestampType{Unit: arrow.Microsecond}} +} + +func (c convertToArrow) VisitString() arrow.Field { + return arrow.Field{Type: arrow.BinaryTypes.LargeString} +} + +func (c convertToArrow) VisitBinary() arrow.Field { + return arrow.Field{Type: arrow.BinaryTypes.LargeBinary} +} + +func (c convertToArrow) VisitUUID() arrow.Field { + return arrow.Field{Type: extensions.NewUUIDType()} +} + +// SchemaToArrowSchema converts an Iceberg schema to an Arrow schema. If the metadata parameter +// is non-nil, it will be included as the top-level metadata in the schema. If includeFieldIDs +// is true, then each field of the schema will contain a metadata key PARQUET:field_id set to +// the field id from the iceberg schema. +func SchemaToArrowSchema(sc *iceberg.Schema, metadata map[string]string, includeFieldIDs bool) (*arrow.Schema, error) { + top, err := iceberg.Visit(sc, convertToArrow{metadata: metadata, includeFieldIDs: includeFieldIDs}) + if err != nil { + return nil, err + } + + return arrow.NewSchema(top.Type.(*arrow.StructType).Fields(), &top.Metadata), nil +} + +// TypeToArrowType converts a given iceberg type, into the equivalent Arrow data type. +// For dealing with nested fields (List, Struct, Map) if includeFieldIDs is true, then +// the child fields will contain a metadata key PARQUET:field_id set to the field id. +func TypeToArrowType(t iceberg.Type, includeFieldIDs bool) (arrow.DataType, error) { + top, err := iceberg.Visit(iceberg.NewSchema(0, iceberg.NestedField{Type: t}), + convertToArrow{includeFieldIDs: includeFieldIDs}) + if err != nil { + return nil, err + } + + return top.Type.(*arrow.StructType).Field(0).Type, nil +} diff --git a/table/arrow_utils_test.go b/table/arrow_utils_test.go index 1d8173e..76c40fd 100644 --- a/table/arrow_utils_test.go +++ b/table/arrow_utils_test.go @@ -34,49 +34,50 @@ func fieldIDMeta(id string) arrow.Metadata { func TestArrowToIceberg(t *testing.T) { tests := []struct { - dt arrow.DataType - ice iceberg.Type - err string + dt arrow.DataType + ice iceberg.Type + reciprocal bool + err string }{ - {&arrow.FixedSizeBinaryType{ByteWidth: 23}, iceberg.FixedTypeOf(23), ""}, - {&arrow.Decimal32Type{Precision: 8, Scale: 9}, iceberg.DecimalTypeOf(8, 9), ""}, - {&arrow.Decimal64Type{Precision: 15, Scale: 14}, iceberg.DecimalTypeOf(15, 14), ""}, - {&arrow.Decimal128Type{Precision: 26, Scale: 20}, iceberg.DecimalTypeOf(26, 20), ""}, - {&arrow.Decimal256Type{Precision: 8, Scale: 9}, nil, "unsupported arrow type for conversion - decimal256(8, 9)"}, - {arrow.FixedWidthTypes.Boolean, iceberg.PrimitiveTypes.Bool, ""}, - {arrow.PrimitiveTypes.Int8, iceberg.PrimitiveTypes.Int32, ""}, - {arrow.PrimitiveTypes.Uint8, iceberg.PrimitiveTypes.Int32, ""}, - {arrow.PrimitiveTypes.Int16, iceberg.PrimitiveTypes.Int32, ""}, - {arrow.PrimitiveTypes.Uint16, iceberg.PrimitiveTypes.Int32, ""}, - {arrow.PrimitiveTypes.Int32, iceberg.PrimitiveTypes.Int32, ""}, - {arrow.PrimitiveTypes.Uint32, iceberg.PrimitiveTypes.Int32, ""}, - {arrow.PrimitiveTypes.Int64, iceberg.PrimitiveTypes.Int64, ""}, - {arrow.PrimitiveTypes.Uint64, iceberg.PrimitiveTypes.Int64, ""}, - {arrow.FixedWidthTypes.Float16, iceberg.PrimitiveTypes.Float32, ""}, - {arrow.PrimitiveTypes.Float32, iceberg.PrimitiveTypes.Float32, ""}, - {arrow.PrimitiveTypes.Float64, iceberg.PrimitiveTypes.Float64, ""}, - {arrow.FixedWidthTypes.Date32, iceberg.PrimitiveTypes.Date, ""}, - {arrow.FixedWidthTypes.Date64, nil, "unsupported arrow type for conversion - date64"}, - {arrow.FixedWidthTypes.Time32s, nil, "unsupported arrow type for conversion - time32[s]"}, - {arrow.FixedWidthTypes.Time32ms, nil, "unsupported arrow type for conversion - time32[ms]"}, - {arrow.FixedWidthTypes.Time64us, iceberg.PrimitiveTypes.Time, ""}, - {arrow.FixedWidthTypes.Time64ns, nil, "unsupported arrow type for conversion - time64[ns]"}, - {arrow.FixedWidthTypes.Timestamp_s, iceberg.PrimitiveTypes.TimestampTz, ""}, - {arrow.FixedWidthTypes.Timestamp_ms, iceberg.PrimitiveTypes.TimestampTz, ""}, - {arrow.FixedWidthTypes.Timestamp_us, iceberg.PrimitiveTypes.TimestampTz, ""}, - {arrow.FixedWidthTypes.Timestamp_ns, nil, "'ns' timestamp precision not supported"}, - {&arrow.TimestampType{Unit: arrow.Second}, iceberg.PrimitiveTypes.Timestamp, ""}, - {&arrow.TimestampType{Unit: arrow.Millisecond}, iceberg.PrimitiveTypes.Timestamp, ""}, - {&arrow.TimestampType{Unit: arrow.Microsecond}, iceberg.PrimitiveTypes.Timestamp, ""}, - {&arrow.TimestampType{Unit: arrow.Nanosecond}, nil, "'ns' timestamp precision not supported"}, - {&arrow.TimestampType{Unit: arrow.Microsecond, TimeZone: "US/Pacific"}, nil, "unsupported arrow type for conversion - timestamp[us, tz=US/Pacific]"}, - {arrow.BinaryTypes.String, iceberg.PrimitiveTypes.String, ""}, - {arrow.BinaryTypes.LargeString, iceberg.PrimitiveTypes.String, ""}, - {arrow.BinaryTypes.StringView, nil, "unsupported arrow type for conversion - string_view"}, - {arrow.BinaryTypes.Binary, iceberg.PrimitiveTypes.Binary, ""}, - {arrow.BinaryTypes.LargeBinary, iceberg.PrimitiveTypes.Binary, ""}, - {arrow.BinaryTypes.BinaryView, nil, "unsupported arrow type for conversion - binary_view"}, - {extensions.NewUUIDType(), iceberg.PrimitiveTypes.UUID, ""}, + {&arrow.FixedSizeBinaryType{ByteWidth: 23}, iceberg.FixedTypeOf(23), true, ""}, + {&arrow.Decimal32Type{Precision: 8, Scale: 9}, iceberg.DecimalTypeOf(8, 9), false, ""}, + {&arrow.Decimal64Type{Precision: 15, Scale: 14}, iceberg.DecimalTypeOf(15, 14), false, ""}, + {&arrow.Decimal128Type{Precision: 26, Scale: 20}, iceberg.DecimalTypeOf(26, 20), true, ""}, + {&arrow.Decimal256Type{Precision: 8, Scale: 9}, nil, false, "unsupported arrow type for conversion - decimal256(8, 9)"}, + {arrow.FixedWidthTypes.Boolean, iceberg.PrimitiveTypes.Bool, true, ""}, + {arrow.PrimitiveTypes.Int8, iceberg.PrimitiveTypes.Int32, false, ""}, + {arrow.PrimitiveTypes.Uint8, iceberg.PrimitiveTypes.Int32, false, ""}, + {arrow.PrimitiveTypes.Int16, iceberg.PrimitiveTypes.Int32, false, ""}, + {arrow.PrimitiveTypes.Uint16, iceberg.PrimitiveTypes.Int32, false, ""}, + {arrow.PrimitiveTypes.Int32, iceberg.PrimitiveTypes.Int32, true, ""}, + {arrow.PrimitiveTypes.Uint32, iceberg.PrimitiveTypes.Int32, false, ""}, + {arrow.PrimitiveTypes.Int64, iceberg.PrimitiveTypes.Int64, true, ""}, + {arrow.PrimitiveTypes.Uint64, iceberg.PrimitiveTypes.Int64, false, ""}, + {arrow.FixedWidthTypes.Float16, iceberg.PrimitiveTypes.Float32, false, ""}, + {arrow.PrimitiveTypes.Float32, iceberg.PrimitiveTypes.Float32, true, ""}, + {arrow.PrimitiveTypes.Float64, iceberg.PrimitiveTypes.Float64, true, ""}, + {arrow.FixedWidthTypes.Date32, iceberg.PrimitiveTypes.Date, true, ""}, + {arrow.FixedWidthTypes.Date64, nil, false, "unsupported arrow type for conversion - date64"}, + {arrow.FixedWidthTypes.Time32s, nil, false, "unsupported arrow type for conversion - time32[s]"}, + {arrow.FixedWidthTypes.Time32ms, nil, false, "unsupported arrow type for conversion - time32[ms]"}, + {arrow.FixedWidthTypes.Time64us, iceberg.PrimitiveTypes.Time, true, ""}, + {arrow.FixedWidthTypes.Time64ns, nil, false, "unsupported arrow type for conversion - time64[ns]"}, + {arrow.FixedWidthTypes.Timestamp_s, iceberg.PrimitiveTypes.TimestampTz, false, ""}, + {arrow.FixedWidthTypes.Timestamp_ms, iceberg.PrimitiveTypes.TimestampTz, false, ""}, + {arrow.FixedWidthTypes.Timestamp_us, iceberg.PrimitiveTypes.TimestampTz, true, ""}, + {arrow.FixedWidthTypes.Timestamp_ns, nil, false, "'ns' timestamp precision not supported"}, + {&arrow.TimestampType{Unit: arrow.Second}, iceberg.PrimitiveTypes.Timestamp, false, ""}, + {&arrow.TimestampType{Unit: arrow.Millisecond}, iceberg.PrimitiveTypes.Timestamp, false, ""}, + {&arrow.TimestampType{Unit: arrow.Microsecond}, iceberg.PrimitiveTypes.Timestamp, true, ""}, + {&arrow.TimestampType{Unit: arrow.Nanosecond}, nil, false, "'ns' timestamp precision not supported"}, + {&arrow.TimestampType{Unit: arrow.Microsecond, TimeZone: "US/Pacific"}, nil, false, "unsupported arrow type for conversion - timestamp[us, tz=US/Pacific]"}, + {arrow.BinaryTypes.String, iceberg.PrimitiveTypes.String, false, ""}, + {arrow.BinaryTypes.LargeString, iceberg.PrimitiveTypes.String, true, ""}, + {arrow.BinaryTypes.StringView, nil, false, "unsupported arrow type for conversion - string_view"}, + {arrow.BinaryTypes.Binary, iceberg.PrimitiveTypes.Binary, false, ""}, + {arrow.BinaryTypes.LargeBinary, iceberg.PrimitiveTypes.Binary, true, ""}, + {arrow.BinaryTypes.BinaryView, nil, false, "unsupported arrow type for conversion - binary_view"}, + {extensions.NewUUIDType(), iceberg.PrimitiveTypes.UUID, true, ""}, {arrow.StructOf(arrow.Field{ Name: "foo", Type: arrow.BinaryTypes.LargeString, @@ -98,7 +99,7 @@ func TestArrowToIceberg(t *testing.T) { {ID: 1, Name: "foo", Type: iceberg.PrimitiveTypes.String, Required: false, Doc: "foo doc"}, {ID: 2, Name: "bar", Type: iceberg.PrimitiveTypes.Int32, Required: true}, {ID: 3, Name: "baz", Type: iceberg.PrimitiveTypes.Bool, Required: false}, - }}, ""}, + }}, true, ""}, {arrow.ListOfField(arrow.Field{ Name: "element", Type: arrow.PrimitiveTypes.Int32, @@ -108,7 +109,7 @@ func TestArrowToIceberg(t *testing.T) { ElementID: 1, Element: iceberg.PrimitiveTypes.Int32, ElementRequired: true, - }, ""}, + }, false, ""}, {arrow.LargeListOfField(arrow.Field{ Name: "element", Type: arrow.PrimitiveTypes.Int32, @@ -118,7 +119,7 @@ func TestArrowToIceberg(t *testing.T) { ElementID: 1, Element: iceberg.PrimitiveTypes.Int32, ElementRequired: true, - }, ""}, + }, true, ""}, {arrow.FixedSizeListOfField(1, arrow.Field{ Name: "element", Type: arrow.PrimitiveTypes.Int32, @@ -128,23 +129,23 @@ func TestArrowToIceberg(t *testing.T) { ElementID: 1, Element: iceberg.PrimitiveTypes.Int32, ElementRequired: true, - }, ""}, + }, false, ""}, {arrow.MapOfWithMetadata(arrow.PrimitiveTypes.Int32, fieldIDMeta("1"), - arrow.BinaryTypes.String, fieldIDMeta("2")), + arrow.BinaryTypes.LargeString, fieldIDMeta("2")), &iceberg.MapType{ KeyID: 1, KeyType: iceberg.PrimitiveTypes.Int32, ValueID: 2, ValueType: iceberg.PrimitiveTypes.String, ValueRequired: false, - }, ""}, + }, true, ""}, {&arrow.DictionaryType{IndexType: arrow.PrimitiveTypes.Int32, - ValueType: arrow.BinaryTypes.String}, iceberg.PrimitiveTypes.String, ""}, + ValueType: arrow.BinaryTypes.String}, iceberg.PrimitiveTypes.String, false, ""}, {&arrow.DictionaryType{IndexType: arrow.PrimitiveTypes.Int32, - ValueType: arrow.PrimitiveTypes.Int32}, iceberg.PrimitiveTypes.Int32, ""}, + ValueType: arrow.PrimitiveTypes.Int32}, iceberg.PrimitiveTypes.Int32, false, ""}, {&arrow.DictionaryType{IndexType: arrow.PrimitiveTypes.Int64, - ValueType: arrow.PrimitiveTypes.Float64}, iceberg.PrimitiveTypes.Float64, ""}, - {arrow.RunEndEncodedOf(arrow.PrimitiveTypes.Int32, arrow.BinaryTypes.String), iceberg.PrimitiveTypes.String, ""}, - {arrow.RunEndEncodedOf(arrow.PrimitiveTypes.Int32, arrow.PrimitiveTypes.Float64), iceberg.PrimitiveTypes.Float64, ""}, - {arrow.RunEndEncodedOf(arrow.PrimitiveTypes.Int32, arrow.PrimitiveTypes.Int16), iceberg.PrimitiveTypes.Int32, ""}, + ValueType: arrow.PrimitiveTypes.Float64}, iceberg.PrimitiveTypes.Float64, false, ""}, + {arrow.RunEndEncodedOf(arrow.PrimitiveTypes.Int32, arrow.BinaryTypes.String), iceberg.PrimitiveTypes.String, false, ""}, + {arrow.RunEndEncodedOf(arrow.PrimitiveTypes.Int32, arrow.PrimitiveTypes.Float64), iceberg.PrimitiveTypes.Float64, false, ""}, + {arrow.RunEndEncodedOf(arrow.PrimitiveTypes.Int32, arrow.PrimitiveTypes.Int16), iceberg.PrimitiveTypes.Int32, false, ""}, } for _, tt := range tests { @@ -156,11 +157,17 @@ func TestArrowToIceberg(t *testing.T) { } else { assert.ErrorContains(t, err, tt.err) } + + if tt.reciprocal { + result, err := table.TypeToArrowType(tt.ice, true) + require.NoError(t, err) + assert.True(t, arrow.TypeEqual(tt.dt, result), tt.dt.String(), result.String()) + } }) } } -func TestArrowSchemaToIceb(t *testing.T) { +func TestArrowSchemaToIceberg(t *testing.T) { tests := []struct { name string sc *arrow.Schema @@ -292,6 +299,23 @@ var ( ) ) +func TestArrowSchemaRoundTripConversion(t *testing.T) { + schemas := []*iceberg.Schema{ + icebergSchemaSimple, + icebergSchemaNested, + } + + for _, tt := range schemas { + sc, err := table.SchemaToArrowSchema(tt, nil, true) + require.NoError(t, err) + + ice, err := table.ArrowSchemaToIceberg(sc, false, nil) + require.NoError(t, err) + + assert.True(t, tt.Equals(ice), tt.String(), ice.String()) + } +} + func TestArrowSchemaWithNameMapping(t *testing.T) { schemaWithoutIDs := arrow.NewSchema([]arrow.Field{ {Name: "foo", Type: arrow.BinaryTypes.String, Nullable: true},