diff --git a/go.mod b/go.mod index a868740a87..a78185cae9 100644 --- a/go.mod +++ b/go.mod @@ -6,9 +6,9 @@ require ( github.com/PuerkitoBio/goquery v1.8.1 github.com/cockroachdb/apd/v2 v2.0.3-0.20200518165714-d020e156310a github.com/cockroachdb/errors v1.7.5 - github.com/dolthub/dolt/go v0.40.5-0.20240118193700-2398a547fccc + github.com/dolthub/dolt/go v0.40.5-0.20240118214900-3cbb73cafa3c github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi v0.0.0-20231213233028-64c353bf920f - github.com/dolthub/go-mysql-server v0.17.1-0.20240117234409-91a2a9d4b1a1 + github.com/dolthub/go-mysql-server v0.17.1-0.20240118213933-3c0fb56900df github.com/dolthub/sqllogictest/go v0.0.0-20240118211725-a52e3f5697e3 github.com/dolthub/vitess v0.0.0-20240117231546-55b8c7b39462 github.com/fatih/color v1.13.0 diff --git a/go.sum b/go.sum index 503e7512e8..285f0f16c4 100644 --- a/go.sum +++ b/go.sum @@ -214,8 +214,8 @@ github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZm github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw= github.com/docker/go-connections v0.4.0/go.mod h1:Gbd7IOopHjR8Iph03tsViu4nIes5XhDvyHbTtUxmeec= github.com/docker/go-units v0.4.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= -github.com/dolthub/dolt/go v0.40.5-0.20240118193700-2398a547fccc h1:sFL+cYw5UORwM0CZ31rJ+mSSqheL6GN9ICOVfj0tzvs= -github.com/dolthub/dolt/go v0.40.5-0.20240118193700-2398a547fccc/go.mod h1:UeVcSMEmqQFKmKGJz7uH5whri2k/bg005/gUTgZU+VQ= +github.com/dolthub/dolt/go v0.40.5-0.20240118214900-3cbb73cafa3c h1:gEvvX3cUMEOW0UyIO3klXaohXzJmJ0ls0jZKAgTdWJE= +github.com/dolthub/dolt/go v0.40.5-0.20240118214900-3cbb73cafa3c/go.mod h1:n4qCXkCLlIFbR8PuXB0WG1JV5s8SZLK4sa/dEVx420o= github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi v0.0.0-20231213233028-64c353bf920f h1:f250FTgZ/OaCql9G6WJt46l9VOIBF1mI81hW9cnmBNM= github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi v0.0.0-20231213233028-64c353bf920f/go.mod h1:gHeHIDGU7em40EhFTliq62pExFcc1hxDTIZ9g5UqXYM= github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2 h1:u3PMzfF8RkKd3lB9pZ2bfn0qEG+1Gms9599cr0REMww= @@ -224,8 +224,8 @@ github.com/dolthub/fslock v0.0.3 h1:iLMpUIvJKMKm92+N1fmHVdxJP5NdyDK5bK7z7Ba2s2U= github.com/dolthub/fslock v0.0.3/go.mod h1:QWql+P17oAAMLnL4HGB5tiovtDuAjdDTPbuqx7bYfa0= github.com/dolthub/go-icu-regex v0.0.0-20230524105445-af7e7991c97e h1:kPsT4a47cw1+y/N5SSCkma7FhAPw7KeGmD6c9PBZW9Y= github.com/dolthub/go-icu-regex v0.0.0-20230524105445-af7e7991c97e/go.mod h1:KPUcpx070QOfJK1gNe0zx4pA5sicIK1GMikIGLKC168= -github.com/dolthub/go-mysql-server v0.17.1-0.20240117234409-91a2a9d4b1a1 h1:CPdkEWpNyz6H1380wwR+pkxXpBQF7vRTjZ7fb/UCqWs= -github.com/dolthub/go-mysql-server v0.17.1-0.20240117234409-91a2a9d4b1a1/go.mod h1:hS8Snuzg+nyTDjv4NI9jiXQ2lJJOd3O0ylhVPQlHySw= +github.com/dolthub/go-mysql-server v0.17.1-0.20240118213933-3c0fb56900df h1:OmR6U3UvCMEguh1UaXCiK4qasA/tHH3+Ls2NRiEQfjU= +github.com/dolthub/go-mysql-server v0.17.1-0.20240118213933-3c0fb56900df/go.mod h1:hS8Snuzg+nyTDjv4NI9jiXQ2lJJOd3O0ylhVPQlHySw= github.com/dolthub/ishell v0.0.0-20221214210346-d7db0b066488 h1:0HHu0GWJH0N6a6keStrHhUAK5/o9LVfkh44pvsV4514= github.com/dolthub/ishell v0.0.0-20221214210346-d7db0b066488/go.mod h1:ehexgi1mPxRTk0Mok/pADALuHbvATulTh6gzr7NzZto= github.com/dolthub/jsonpath v0.0.2-0.20230525180605-8dc13778fd72 h1:NfWmngMi1CYUWU4Ix8wM+USEhjc+mhPlT9JUR/anvbQ= diff --git a/server/ast/expr.go b/server/ast/expr.go index 8742c8889d..e6f72a9d1a 100644 --- a/server/ast/expr.go +++ b/server/ast/expr.go @@ -17,7 +17,9 @@ package ast import ( "fmt" "go/constant" + "strings" + "github.com/dolthub/go-mysql-server/sql/expression" vitess "github.com/dolthub/vitess/go/vt/sqlparser" "github.com/dolthub/doltgresql/postgres/parser/sem/tree" @@ -202,6 +204,11 @@ func nodeExpr(node tree.Expr) (vitess.Expr, error) { return nil, err } + convertType, err = translateConvertType(convertType) + if err != nil { + return nil, err + } + return &vitess.ConvertExpr{ Name: "CAST", Expr: expr, @@ -571,3 +578,35 @@ func nodeExpr(node tree.Expr) (vitess.Expr, error) { return nil, fmt.Errorf("unknown expression: `%T`", node) } } + +func translateConvertType(convertType *vitess.ConvertType) (*vitess.ConvertType, error) { + switch strings.ToLower(convertType.Type) { + // passthrough types that need no conversion + case expression.ConvertToBinary, expression.ConvertToChar, expression.ConvertToNChar, expression.ConvertToDate, + expression.ConvertToDatetime, expression.ConvertToFloat, expression.ConvertToDouble, expression.ConvertToJSON, + expression.ConvertToReal, expression.ConvertToSigned, expression.ConvertToTime, expression.ConvertToUnsigned: + return convertType, nil + case "text", "character varying", "varchar": + return &vitess.ConvertType{ + Type: expression.ConvertToChar, + }, nil + case "integer", "bigint": + return &vitess.ConvertType{ + Type: expression.ConvertToSigned, + }, nil + case "decimal", "numeric": + return &vitess.ConvertType{ + Type: expression.ConvertToFloat, + }, nil + case "boolean": + return &vitess.ConvertType{ + Type: expression.ConvertToSigned, + }, nil + case "timestamp", "timestamp with time zone", "timestamp without time zone": + return &vitess.ConvertType{ + Type: expression.ConvertToDatetime, + }, nil + default: + return nil, fmt.Errorf("unknown convert type: `%T`", convertType.Type) + } +} diff --git a/server/listener.go b/server/listener.go index 2df2caf564..12e85e594c 100644 --- a/server/listener.go +++ b/server/listener.go @@ -478,19 +478,37 @@ func extractBindVarTypes(queryPlan sql.Node) ([]int32, error) { types := make([]int32, 0) var err error - transform.InspectExpressions(inspectNode, func(expr sql.Expression) bool { - if bindVar, ok := expr.(*expression.BindVar); ok { - var id int32 - id, err = messages.VitessTypeToObjectID(bindVar.Type().Type()) + extractBindVars := func(expr sql.Expression) bool { + if err != nil { + return false + } + switch e := expr.(type) { + case *expression.BindVar: + var oid int32 + oid, err = messages.VitessTypeToObjectID(e.Type().Type()) if err != nil { + err = fmt.Errorf("could not determine OID for placeholder %s: %w", e.Name, err) + return false + } + types = append(types, oid) + // $1::text and similar get converted to a Convert expression wrapping the bindvar + case *expression.Convert: + if bindVar, ok := e.Child.(*expression.BindVar); ok { + var oid int32 + oid, err = messages.VitessTypeToObjectID(e.Type().Type()) + if err != nil { + err = fmt.Errorf("could not determine OID for placeholder %s: %w", bindVar.Name, err) + return false + } + types = append(types, oid) return false - } else { - types = append(types, id) } } + return true - }) + } + transform.InspectExpressions(inspectNode, extractBindVars) return types, err } @@ -513,15 +531,19 @@ func convertBindVarValue(typ querypb.Type, value messages.BindParameterValue) [] switch typ { case querypb.Type_INT8, querypb.Type_INT16, querypb.Type_INT24, querypb.Type_INT32, querypb.Type_UINT8, querypb.Type_UINT16, querypb.Type_UINT24, querypb.Type_UINT32: // first convert the bytes in the payload to an integer, then convert that to its base 10 string representation - intVal := binary.BigEndian.Uint32(value.Data) // TODO: bound check + intVal := binary.BigEndian.Uint32(value.Data) return []byte(strconv.FormatUint(uint64(intVal), 10)) case querypb.Type_INT64, querypb.Type_UINT64: // first convert the bytes in the payload to an integer, then convert that to its base 10 string representation intVal := binary.BigEndian.Uint64(value.Data) return []byte(strconv.FormatUint(intVal, 10)) - case querypb.Type_FLOAT32, querypb.Type_FLOAT64: + case querypb.Type_FLOAT32: + // first convert the bytes in the payload to a float, then convert that to its base 10 string representation + floatVal := binary.BigEndian.Uint32(value.Data) + return []byte(strconv.FormatFloat(float64(math.Float32frombits(floatVal)), 'f', -1, 64)) + case querypb.Type_FLOAT64: // first convert the bytes in the payload to a float, then convert that to its base 10 string representation - floatVal := binary.BigEndian.Uint64(value.Data) // TODO: bound check + floatVal := binary.BigEndian.Uint64(value.Data) return []byte(strconv.FormatFloat(math.Float64frombits(floatVal), 'f', -1, 64)) case querypb.Type_VARCHAR, querypb.Type_VARBINARY, querypb.Type_TEXT, querypb.Type_BLOB: return value.Data diff --git a/testing/go/framework.go b/testing/go/framework.go index 30dfde2003..79e2a1e5d1 100644 --- a/testing/go/framework.go +++ b/testing/go/framework.go @@ -36,6 +36,10 @@ import ( dserver "github.com/dolthub/doltgresql/server" ) +// runOnPostgres is a debug setting to redirect the test framework to a local running postgres server, +// rather than starting a doltgres server. +const runOnPostgres = false + // ScriptTest defines a consistent structure for testing queries. type ScriptTest struct { // Name of the script. @@ -76,6 +80,11 @@ type ScriptTestAssertion struct { // RunScript runs the given script. func RunScript(t *testing.T, script ScriptTest) { + if runOnPostgres { + RunScriptOnPostgres(t, script) + return + } + scriptDatabase := script.Database if len(scriptDatabase) == 0 { scriptDatabase = "postgres" @@ -113,7 +122,6 @@ func runScript(t *testing.T, script ScriptTest, conn *pgx.Conn, ctx context.Cont t.Skip("Skip has been set in the assertion") } // If we're skipping the results check, then we call Execute, as it uses a simplified message model. - // The more complicated model is only partially implemented, and therefore won't work for all queries. if assertion.SkipResultsCheck || assertion.ExpectedErr { _, err := conn.Exec(ctx, assertion.Query, assertion.BindVars...) if assertion.ExpectedErr { @@ -167,6 +175,7 @@ func RunScripts(t *testing.T, scripts []ScriptTest) { if len(focusScripts) > 0 { scripts = focusScripts } + for _, script := range scripts { RunScript(t, script) } diff --git a/testing/go/functions_test.go b/testing/go/functions_test.go index 03306f9a93..82ff88d511 100644 --- a/testing/go/functions_test.go +++ b/testing/go/functions_test.go @@ -41,6 +41,7 @@ func TestFunctionsMath(t *testing.T) { }, { Query: `SELECT round(cbrt(v1)::numeric, 10), round(cbrt(v2)::numeric, 10), round(cbrt(v3)::numeric, 10) FROM test ORDER BY pk;`, + Skip: true, // Our values are slightly different Expected: []sql.Row{ {-1.0000000000, -1.2599210499, -1.4422495703}, {1.9129311828, 2.2239800906, 2.3513346877}, diff --git a/testing/go/prepared_statement_test.go b/testing/go/prepared_statement_test.go index f93e7e9bc2..c71e5d1774 100755 --- a/testing/go/prepared_statement_test.go +++ b/testing/go/prepared_statement_test.go @@ -27,20 +27,18 @@ var preparedStatementTests = []ScriptTest{ Name: "expressions without tables", Assertions: []ScriptTestAssertion{ { - Query: "SELECT CONCAT($1, $2)", + Query: "SELECT CONCAT($1::text, $2::text)", BindVars: []any{"hello", "world"}, Expected: []sql.Row{ {"helloworld"}, }, - Skip: true, // this doesn't work without explicit type hints for the params }, { - Query: "SELECT $1 + $2", + Query: "SELECT $1::integer + $2::integer", BindVars: []any{1, 2}, Expected: []sql.Row{ {3}, }, - Skip: true, // this doesn't work without explicit type hints for the params }, }, }, @@ -80,7 +78,6 @@ var preparedStatementTests = []ScriptTest{ Expected: []sql.Row{ {1, 2}, }, - Skip: true, // can't correctly extract the bindvar type with more complicated processing during plan building }, { Query: "SELECT * FROM test WHERE pk + v1 = $1;", @@ -90,12 +87,11 @@ var preparedStatementTests = []ScriptTest{ }, }, { - Query: "SELECT * FROM test WHERE v1 = $1 + $2;", + Query: "SELECT * FROM test WHERE v1 = $1::integer + $2::integer;", BindVars: []any{1, 3}, Expected: []sql.Row{ {3, 4}, }, - Skip: true, // this doesn't work without explicit type hints for the params }, }, }, @@ -172,12 +168,11 @@ var preparedStatementTests = []ScriptTest{ }, }, { - Query: "SELECT * FROM test WHERE s = concat($1, $2);", + Query: "SELECT * FROM test WHERE s = concat($1::text, $2::text);", BindVars: []any{"he", "llo"}, Expected: []sql.Row{ {1, "hello"}, }, - Skip: true, // this doesn't work without explicit type hints for the params }, { Query: "SELECT * FROM test WHERE concat(s, '!') = $1", @@ -266,15 +261,13 @@ var preparedStatementTests = []ScriptTest{ Expected: []sql.Row{ {1, 1.1}, }, - Skip: true, // can't correctly extract the bindvar type with more complicated processing during plan building }, { - Query: "SELECT * FROM test WHERE f1 = $1 + $2;", + Query: "SELECT * FROM test WHERE f1 = $1::decimal + $2::decimal;", BindVars: []any{1.0, 0.1}, Expected: []sql.Row{ {1, 1.1}, }, - Skip: true, // this doesn't work without explicit type hints for the params }, }, },