From c6623293bcc1606a10bf8c8f361e6316220c2c12 Mon Sep 17 00:00:00 2001 From: jennifersp Date: Tue, 17 Dec 2024 16:39:21 -0800 Subject: [PATCH 1/2] support enum type cast --- server/functions/enum.go | 70 ++++++++++++++++++++++++++++++------ testing/go/functions_test.go | 4 --- testing/go/types_test.go | 5 ++- 3 files changed, 64 insertions(+), 15 deletions(-) diff --git a/server/functions/enum.go b/server/functions/enum.go index 9a1731b120..8b2573417c 100644 --- a/server/functions/enum.go +++ b/server/functions/enum.go @@ -19,6 +19,8 @@ import ( "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/doltgresql/core" + "github.com/dolthub/doltgresql/core/id" "github.com/dolthub/doltgresql/server/functions/framework" pgtypes "github.com/dolthub/doltgresql/server/types" "github.com/dolthub/doltgresql/utils" @@ -40,10 +42,20 @@ var enum_in = framework.Function2{ Parameters: [2]*pgtypes.DoltgresType{pgtypes.Cstring, pgtypes.Oid}, Strict: true, Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1, val2 any) (any, error) { - // typOid := val2.(id.Internal) - // TODO: get type using given OID, which should give access to enum labels. - // should return the index of label? - return val1.(string), nil + typ, err := getDoltgresTypeFromInternal(ctx, val2.(id.Internal)) + if err != nil { + return nil, err + } + if typ.TypCategory != pgtypes.TypeCategory_EnumTypes { + return nil, fmt.Errorf(`"%s" is not an enum type`, typ.Name()) + } + + value := val1.(string) + if _, exists := typ.EnumLabels[value]; !exists { + return nil, pgtypes.ErrInvalidInputValueForEnum.New(typ.Name(), value) + } + // TODO: should return the index instead of label? + return value, nil }, } @@ -54,7 +66,7 @@ var enum_out = framework.Function1{ Parameters: [1]*pgtypes.DoltgresType{pgtypes.AnyEnum}, Strict: true, Callable: func(ctx *sql.Context, _ [2]*pgtypes.DoltgresType, val any) (any, error) { - // TODO: should return the index of label? + // TODO: should receive the index instead of label? return val.(string), nil }, } @@ -66,15 +78,28 @@ var enum_recv = framework.Function2{ Parameters: [2]*pgtypes.DoltgresType{pgtypes.Internal, pgtypes.Oid}, Strict: true, Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1, val2 any) (any, error) { - // typOid := val2.(id.Internal) - // TODO: get type using given OID, which should give access to enum labels. - // should return the index of label? + // TODO: should return the index instead of label? data := val1.([]byte) if len(data) == 0 { return nil, nil } reader := utils.NewReader(data) - return reader.String(), nil + value := reader.String() + if ctx == nil { + // TODO: currently, in some places we use nil context, should fix it. + return value, nil + } + typ, err := getDoltgresTypeFromInternal(ctx, val2.(id.Internal)) + if err != nil { + return nil, err + } + if typ.TypCategory != pgtypes.TypeCategory_EnumTypes { + return nil, fmt.Errorf(`"%s" is not an enum type`, typ.Name()) + } + if _, exists := typ.EnumLabels[value]; !exists { + return nil, pgtypes.ErrInvalidInputValueForEnum.New(typ.Name(), value) + } + return value, nil }, } @@ -85,7 +110,7 @@ var enum_send = framework.Function1{ Parameters: [1]*pgtypes.DoltgresType{pgtypes.AnyEnum}, Strict: true, Callable: func(ctx *sql.Context, _ [2]*pgtypes.DoltgresType, val any) (any, error) { - // TODO: should return the index of label? + // TODO: should return the index instead of label? str := val.(string) writer := utils.NewWriter(uint64(len(str) + 4)) writer.String(str) @@ -123,3 +148,28 @@ var enum_cmp = framework.Function2{ } }, } + +// getDoltgresTypeFromInternal takes internal ID and returns DoltgresType associated to it. +// It allows retrieving user-defined type and requires valid sql.Context. +func getDoltgresTypeFromInternal(ctx *sql.Context, typID id.Internal) (*pgtypes.DoltgresType, error) { + typCol, err := core.GetTypesCollectionFromContext(ctx) + if err != nil { + return nil, err + } + + schName := typID.Segment(0) + sch, err := core.GetCurrentSchema(ctx) + if err != nil { + return nil, err + } + if schName == "" { + schName = sch + } + + typName := typID.Segment(1) + typ, found := typCol.GetType(schName, typName) + if !found { + return nil, pgtypes.ErrTypeDoesNotExist.New(typName) + } + return typ, nil +} diff --git a/testing/go/functions_test.go b/testing/go/functions_test.go index 23900f4d1b..0f5d3e3223 100644 --- a/testing/go/functions_test.go +++ b/testing/go/functions_test.go @@ -2069,10 +2069,6 @@ func TestSelectFromFunctions(t *testing.T) { Query: `SELECT * FROM array_to_string(ARRAY[37.89, 1.2], '_');`, Expected: []sql.Row{{"37.89_1.2"}}, }, - { - Query: `SELECT format_type(874938247, 20);`, - Expected: []sql.Row{{"???"}}, - }, { Query: `SELECT * from format_type(874938247, 20);`, Expected: []sql.Row{{"???"}}, diff --git a/testing/go/types_test.go b/testing/go/types_test.go index 2882e7a8fb..7470cf2b4c 100644 --- a/testing/go/types_test.go +++ b/testing/go/types_test.go @@ -3292,6 +3292,10 @@ var enumTypeTests = []ScriptTest{ Query: `SELECT * FROM person WHERE current_mood > 'sad' ORDER BY current_mood;`, Expected: []sql.Row{{"Curly", "ok"}, {"Moe", "happy"}}, }, + { + Query: `INSERT INTO person VALUES ('Joey', 'invalid');`, + ExpectedErr: `invalid input value for enum mood: "invalid"`, + }, }, }, { @@ -3324,7 +3328,6 @@ var enumTypeTests = []ScriptTest{ }, }, { - Skip: true, Name: "enum type cast", SetUpScript: []string{ `CREATE TYPE mood AS ENUM ('sad', 'ok', 'happy')`, From d3730a109a36390ad22d35759d011d6c3cdaa9bb Mon Sep 17 00:00:00 2001 From: jennifersp Date: Wed, 18 Dec 2024 09:58:36 -0800 Subject: [PATCH 2/2] update test --- testing/go/functions_test.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/testing/go/functions_test.go b/testing/go/functions_test.go index 0f5d3e3223..4086d57ae8 100644 --- a/testing/go/functions_test.go +++ b/testing/go/functions_test.go @@ -2069,6 +2069,10 @@ func TestSelectFromFunctions(t *testing.T) { Query: `SELECT * FROM array_to_string(ARRAY[37.89, 1.2], '_');`, Expected: []sql.Row{{"37.89_1.2"}}, }, + { + Query: `SELECT * FROM format_type('text'::regtype, 4);`, + Expected: []sql.Row{{"text(4)"}}, + }, { Query: `SELECT * from format_type(874938247, 20);`, Expected: []sql.Row{{"???"}},