Skip to content

Commit

Permalink
Merge pull request #1081 from dolthub/jennifer/enum-cast
Browse files Browse the repository at this point in the history
support enum type cast
  • Loading branch information
jennifersp authored Dec 18, 2024
2 parents 9d098e6 + d3730a1 commit 535cdf8
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 13 deletions.
70 changes: 60 additions & 10 deletions server/functions/enum.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ import (

"github.com/dolthub/go-mysql-server/sql"

"github.com/dolthub/doltgresql/core"
"github.com/dolthub/doltgresql/core/id"
"github.com/dolthub/doltgresql/server/functions/framework"
pgtypes "github.com/dolthub/doltgresql/server/types"
"github.com/dolthub/doltgresql/utils"
Expand All @@ -40,10 +42,20 @@ var enum_in = framework.Function2{
Parameters: [2]*pgtypes.DoltgresType{pgtypes.Cstring, pgtypes.Oid},
Strict: true,
Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1, val2 any) (any, error) {
// typOid := val2.(id.Internal)
// TODO: get type using given OID, which should give access to enum labels.
// should return the index of label?
return val1.(string), nil
typ, err := getDoltgresTypeFromInternal(ctx, val2.(id.Internal))
if err != nil {
return nil, err
}
if typ.TypCategory != pgtypes.TypeCategory_EnumTypes {
return nil, fmt.Errorf(`"%s" is not an enum type`, typ.Name())
}

value := val1.(string)
if _, exists := typ.EnumLabels[value]; !exists {
return nil, pgtypes.ErrInvalidInputValueForEnum.New(typ.Name(), value)
}
// TODO: should return the index instead of label?
return value, nil
},
}

Expand All @@ -54,7 +66,7 @@ var enum_out = framework.Function1{
Parameters: [1]*pgtypes.DoltgresType{pgtypes.AnyEnum},
Strict: true,
Callable: func(ctx *sql.Context, _ [2]*pgtypes.DoltgresType, val any) (any, error) {
// TODO: should return the index of label?
// TODO: should receive the index instead of label?
return val.(string), nil
},
}
Expand All @@ -66,15 +78,28 @@ var enum_recv = framework.Function2{
Parameters: [2]*pgtypes.DoltgresType{pgtypes.Internal, pgtypes.Oid},
Strict: true,
Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1, val2 any) (any, error) {
// typOid := val2.(id.Internal)
// TODO: get type using given OID, which should give access to enum labels.
// should return the index of label?
// TODO: should return the index instead of label?
data := val1.([]byte)
if len(data) == 0 {
return nil, nil
}
reader := utils.NewReader(data)
return reader.String(), nil
value := reader.String()
if ctx == nil {
// TODO: currently, in some places we use nil context, should fix it.
return value, nil
}
typ, err := getDoltgresTypeFromInternal(ctx, val2.(id.Internal))
if err != nil {
return nil, err
}
if typ.TypCategory != pgtypes.TypeCategory_EnumTypes {
return nil, fmt.Errorf(`"%s" is not an enum type`, typ.Name())
}
if _, exists := typ.EnumLabels[value]; !exists {
return nil, pgtypes.ErrInvalidInputValueForEnum.New(typ.Name(), value)
}
return value, nil
},
}

Expand All @@ -85,7 +110,7 @@ var enum_send = framework.Function1{
Parameters: [1]*pgtypes.DoltgresType{pgtypes.AnyEnum},
Strict: true,
Callable: func(ctx *sql.Context, _ [2]*pgtypes.DoltgresType, val any) (any, error) {
// TODO: should return the index of label?
// TODO: should return the index instead of label?
str := val.(string)
writer := utils.NewWriter(uint64(len(str) + 4))
writer.String(str)
Expand Down Expand Up @@ -123,3 +148,28 @@ var enum_cmp = framework.Function2{
}
},
}

// getDoltgresTypeFromInternal takes internal ID and returns DoltgresType associated to it.
// It allows retrieving user-defined type and requires valid sql.Context.
func getDoltgresTypeFromInternal(ctx *sql.Context, typID id.Internal) (*pgtypes.DoltgresType, error) {
typCol, err := core.GetTypesCollectionFromContext(ctx)
if err != nil {
return nil, err
}

schName := typID.Segment(0)
sch, err := core.GetCurrentSchema(ctx)
if err != nil {
return nil, err
}
if schName == "" {
schName = sch
}

typName := typID.Segment(1)
typ, found := typCol.GetType(schName, typName)
if !found {
return nil, pgtypes.ErrTypeDoesNotExist.New(typName)
}
return typ, nil
}
4 changes: 2 additions & 2 deletions testing/go/functions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2070,8 +2070,8 @@ func TestSelectFromFunctions(t *testing.T) {
Expected: []sql.Row{{"37.89_1.2"}},
},
{
Query: `SELECT format_type(874938247, 20);`,
Expected: []sql.Row{{"???"}},
Query: `SELECT * FROM format_type('text'::regtype, 4);`,
Expected: []sql.Row{{"text(4)"}},
},
{
Query: `SELECT * from format_type(874938247, 20);`,
Expand Down
5 changes: 4 additions & 1 deletion testing/go/types_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3292,6 +3292,10 @@ var enumTypeTests = []ScriptTest{
Query: `SELECT * FROM person WHERE current_mood > 'sad' ORDER BY current_mood;`,
Expected: []sql.Row{{"Curly", "ok"}, {"Moe", "happy"}},
},
{
Query: `INSERT INTO person VALUES ('Joey', 'invalid');`,
ExpectedErr: `invalid input value for enum mood: "invalid"`,
},
},
},
{
Expand Down Expand Up @@ -3324,7 +3328,6 @@ var enumTypeTests = []ScriptTest{
},
},
{
Skip: true,
Name: "enum type cast",
SetUpScript: []string{
`CREATE TYPE mood AS ENUM ('sad', 'ok', 'happy')`,
Expand Down

0 comments on commit 535cdf8

Please sign in to comment.