Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better prepared statement support #95

Merged
merged 12 commits into from
Jan 18, 2024
4 changes: 2 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ require (
github.com/PuerkitoBio/goquery v1.8.1
github.com/cockroachdb/apd/v2 v2.0.3-0.20200518165714-d020e156310a
github.com/cockroachdb/errors v1.7.5
github.com/dolthub/dolt/go v0.40.5-0.20240118193700-2398a547fccc
github.com/dolthub/dolt/go v0.40.5-0.20240118214900-3cbb73cafa3c
github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi v0.0.0-20231213233028-64c353bf920f
github.com/dolthub/go-mysql-server v0.17.1-0.20240117234409-91a2a9d4b1a1
github.com/dolthub/go-mysql-server v0.17.1-0.20240118213933-3c0fb56900df
github.com/dolthub/sqllogictest/go v0.0.0-20240118211725-a52e3f5697e3
github.com/dolthub/vitess v0.0.0-20240117231546-55b8c7b39462
github.com/fatih/color v1.13.0
Expand Down
8 changes: 4 additions & 4 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -214,8 +214,8 @@ github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZm
github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw=
github.com/docker/go-connections v0.4.0/go.mod h1:Gbd7IOopHjR8Iph03tsViu4nIes5XhDvyHbTtUxmeec=
github.com/docker/go-units v0.4.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
github.com/dolthub/dolt/go v0.40.5-0.20240118193700-2398a547fccc h1:sFL+cYw5UORwM0CZ31rJ+mSSqheL6GN9ICOVfj0tzvs=
github.com/dolthub/dolt/go v0.40.5-0.20240118193700-2398a547fccc/go.mod h1:UeVcSMEmqQFKmKGJz7uH5whri2k/bg005/gUTgZU+VQ=
github.com/dolthub/dolt/go v0.40.5-0.20240118214900-3cbb73cafa3c h1:gEvvX3cUMEOW0UyIO3klXaohXzJmJ0ls0jZKAgTdWJE=
github.com/dolthub/dolt/go v0.40.5-0.20240118214900-3cbb73cafa3c/go.mod h1:n4qCXkCLlIFbR8PuXB0WG1JV5s8SZLK4sa/dEVx420o=
github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi v0.0.0-20231213233028-64c353bf920f h1:f250FTgZ/OaCql9G6WJt46l9VOIBF1mI81hW9cnmBNM=
github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi v0.0.0-20231213233028-64c353bf920f/go.mod h1:gHeHIDGU7em40EhFTliq62pExFcc1hxDTIZ9g5UqXYM=
github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2 h1:u3PMzfF8RkKd3lB9pZ2bfn0qEG+1Gms9599cr0REMww=
Expand All @@ -224,8 +224,8 @@ github.com/dolthub/fslock v0.0.3 h1:iLMpUIvJKMKm92+N1fmHVdxJP5NdyDK5bK7z7Ba2s2U=
github.com/dolthub/fslock v0.0.3/go.mod h1:QWql+P17oAAMLnL4HGB5tiovtDuAjdDTPbuqx7bYfa0=
github.com/dolthub/go-icu-regex v0.0.0-20230524105445-af7e7991c97e h1:kPsT4a47cw1+y/N5SSCkma7FhAPw7KeGmD6c9PBZW9Y=
github.com/dolthub/go-icu-regex v0.0.0-20230524105445-af7e7991c97e/go.mod h1:KPUcpx070QOfJK1gNe0zx4pA5sicIK1GMikIGLKC168=
github.com/dolthub/go-mysql-server v0.17.1-0.20240117234409-91a2a9d4b1a1 h1:CPdkEWpNyz6H1380wwR+pkxXpBQF7vRTjZ7fb/UCqWs=
github.com/dolthub/go-mysql-server v0.17.1-0.20240117234409-91a2a9d4b1a1/go.mod h1:hS8Snuzg+nyTDjv4NI9jiXQ2lJJOd3O0ylhVPQlHySw=
github.com/dolthub/go-mysql-server v0.17.1-0.20240118213933-3c0fb56900df h1:OmR6U3UvCMEguh1UaXCiK4qasA/tHH3+Ls2NRiEQfjU=
github.com/dolthub/go-mysql-server v0.17.1-0.20240118213933-3c0fb56900df/go.mod h1:hS8Snuzg+nyTDjv4NI9jiXQ2lJJOd3O0ylhVPQlHySw=
github.com/dolthub/ishell v0.0.0-20221214210346-d7db0b066488 h1:0HHu0GWJH0N6a6keStrHhUAK5/o9LVfkh44pvsV4514=
github.com/dolthub/ishell v0.0.0-20221214210346-d7db0b066488/go.mod h1:ehexgi1mPxRTk0Mok/pADALuHbvATulTh6gzr7NzZto=
github.com/dolthub/jsonpath v0.0.2-0.20230525180605-8dc13778fd72 h1:NfWmngMi1CYUWU4Ix8wM+USEhjc+mhPlT9JUR/anvbQ=
Expand Down
39 changes: 39 additions & 0 deletions server/ast/expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ package ast
import (
"fmt"
"go/constant"
"strings"

"github.com/dolthub/go-mysql-server/sql/expression"
vitess "github.com/dolthub/vitess/go/vt/sqlparser"

"github.com/dolthub/doltgresql/postgres/parser/sem/tree"
Expand Down Expand Up @@ -202,6 +204,11 @@ func nodeExpr(node tree.Expr) (vitess.Expr, error) {
return nil, err
}

convertType, err = translateConvertType(convertType)
if err != nil {
return nil, err
}

return &vitess.ConvertExpr{
Name: "CAST",
Expr: expr,
Expand Down Expand Up @@ -571,3 +578,35 @@ func nodeExpr(node tree.Expr) (vitess.Expr, error) {
return nil, fmt.Errorf("unknown expression: `%T`", node)
}
}

func translateConvertType(convertType *vitess.ConvertType) (*vitess.ConvertType, error) {
zachmu marked this conversation as resolved.
Show resolved Hide resolved
switch strings.ToLower(convertType.Type) {
// passthrough types that need no conversion
case expression.ConvertToBinary, expression.ConvertToChar, expression.ConvertToNChar, expression.ConvertToDate,
expression.ConvertToDatetime, expression.ConvertToFloat, expression.ConvertToDouble, expression.ConvertToJSON,
expression.ConvertToReal, expression.ConvertToSigned, expression.ConvertToTime, expression.ConvertToUnsigned:
return convertType, nil
case "text", "character varying", "varchar":
return &vitess.ConvertType{
Type: expression.ConvertToChar,
}, nil
case "integer", "bigint":
return &vitess.ConvertType{
Type: expression.ConvertToSigned,
}, nil
case "decimal", "numeric":
return &vitess.ConvertType{
Type: expression.ConvertToFloat,
}, nil
case "boolean":
return &vitess.ConvertType{
Type: expression.ConvertToSigned,
}, nil
case "timestamp", "timestamp with time zone", "timestamp without time zone":
return &vitess.ConvertType{
Type: expression.ConvertToDatetime,
}, nil
default:
return nil, fmt.Errorf("unknown convert type: `%T`", convertType.Type)
}
}
42 changes: 32 additions & 10 deletions server/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -478,19 +478,37 @@ func extractBindVarTypes(queryPlan sql.Node) ([]int32, error) {

types := make([]int32, 0)
var err error
transform.InspectExpressions(inspectNode, func(expr sql.Expression) bool {
if bindVar, ok := expr.(*expression.BindVar); ok {
var id int32
id, err = messages.VitessTypeToObjectID(bindVar.Type().Type())
extractBindVars := func(expr sql.Expression) bool {
if err != nil {
return false
}
switch e := expr.(type) {
case *expression.BindVar:
var oid int32
oid, err = messages.VitessTypeToObjectID(e.Type().Type())
if err != nil {
err = fmt.Errorf("could not determine OID for placeholder %s: %w", e.Name, err)
return false
}
types = append(types, oid)
// $1::text and similar get converted to a Convert expression wrapping the bindvar
case *expression.Convert:
if bindVar, ok := e.Child.(*expression.BindVar); ok {
var oid int32
oid, err = messages.VitessTypeToObjectID(e.Type().Type())
if err != nil {
err = fmt.Errorf("could not determine OID for placeholder %s: %w", bindVar.Name, err)
return false
}
types = append(types, oid)
return false
} else {
types = append(types, id)
}
}

return true
})
}

transform.InspectExpressions(inspectNode, extractBindVars)
return types, err
}

Expand All @@ -513,15 +531,19 @@ func convertBindVarValue(typ querypb.Type, value messages.BindParameterValue) []
switch typ {
case querypb.Type_INT8, querypb.Type_INT16, querypb.Type_INT24, querypb.Type_INT32, querypb.Type_UINT8, querypb.Type_UINT16, querypb.Type_UINT24, querypb.Type_UINT32:
// first convert the bytes in the payload to an integer, then convert that to its base 10 string representation
intVal := binary.BigEndian.Uint32(value.Data) // TODO: bound check
intVal := binary.BigEndian.Uint32(value.Data)
return []byte(strconv.FormatUint(uint64(intVal), 10))
case querypb.Type_INT64, querypb.Type_UINT64:
// first convert the bytes in the payload to an integer, then convert that to its base 10 string representation
intVal := binary.BigEndian.Uint64(value.Data)
return []byte(strconv.FormatUint(intVal, 10))
case querypb.Type_FLOAT32, querypb.Type_FLOAT64:
case querypb.Type_FLOAT32:
// first convert the bytes in the payload to a float, then convert that to its base 10 string representation
floatVal := binary.BigEndian.Uint32(value.Data)
return []byte(strconv.FormatFloat(float64(math.Float32frombits(floatVal)), 'f', -1, 64))
case querypb.Type_FLOAT64:
// first convert the bytes in the payload to a float, then convert that to its base 10 string representation
floatVal := binary.BigEndian.Uint64(value.Data) // TODO: bound check
floatVal := binary.BigEndian.Uint64(value.Data)
return []byte(strconv.FormatFloat(math.Float64frombits(floatVal), 'f', -1, 64))
case querypb.Type_VARCHAR, querypb.Type_VARBINARY, querypb.Type_TEXT, querypb.Type_BLOB:
return value.Data
Expand Down
11 changes: 10 additions & 1 deletion testing/go/framework.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ import (
dserver "github.com/dolthub/doltgresql/server"
)

// runOnPostgres is a debug setting to redirect the test framework to a local running postgres server,
// rather than starting a doltgres server.
const runOnPostgres = false

// ScriptTest defines a consistent structure for testing queries.
type ScriptTest struct {
// Name of the script.
Expand Down Expand Up @@ -76,6 +80,11 @@ type ScriptTestAssertion struct {

// RunScript runs the given script.
func RunScript(t *testing.T, script ScriptTest) {
if runOnPostgres {
RunScriptOnPostgres(t, script)
return
}

scriptDatabase := script.Database
if len(scriptDatabase) == 0 {
scriptDatabase = "postgres"
Expand Down Expand Up @@ -113,7 +122,6 @@ func runScript(t *testing.T, script ScriptTest, conn *pgx.Conn, ctx context.Cont
t.Skip("Skip has been set in the assertion")
}
// If we're skipping the results check, then we call Execute, as it uses a simplified message model.
// The more complicated model is only partially implemented, and therefore won't work for all queries.
if assertion.SkipResultsCheck || assertion.ExpectedErr {
_, err := conn.Exec(ctx, assertion.Query, assertion.BindVars...)
if assertion.ExpectedErr {
Expand Down Expand Up @@ -167,6 +175,7 @@ func RunScripts(t *testing.T, scripts []ScriptTest) {
if len(focusScripts) > 0 {
scripts = focusScripts
}

for _, script := range scripts {
RunScript(t, script)
}
Expand Down
1 change: 1 addition & 0 deletions testing/go/functions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ func TestFunctionsMath(t *testing.T) {
},
{
Query: `SELECT round(cbrt(v1)::numeric, 10), round(cbrt(v2)::numeric, 10), round(cbrt(v3)::numeric, 10) FROM test ORDER BY pk;`,
Skip: true, // Our values are slightly different
Expected: []sql.Row{
{-1.0000000000, -1.2599210499, -1.4422495703},
{1.9129311828, 2.2239800906, 2.3513346877},
Expand Down
17 changes: 5 additions & 12 deletions testing/go/prepared_statement_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,20 +27,18 @@ var preparedStatementTests = []ScriptTest{
Name: "expressions without tables",
Assertions: []ScriptTestAssertion{
{
Query: "SELECT CONCAT($1, $2)",
Query: "SELECT CONCAT($1::text, $2::text)",
BindVars: []any{"hello", "world"},
Expected: []sql.Row{
{"helloworld"},
},
Skip: true, // this doesn't work without explicit type hints for the params
},
{
Query: "SELECT $1 + $2",
Query: "SELECT $1::integer + $2::integer",
BindVars: []any{1, 2},
Expected: []sql.Row{
{3},
},
Skip: true, // this doesn't work without explicit type hints for the params
},
},
},
Expand Down Expand Up @@ -80,7 +78,6 @@ var preparedStatementTests = []ScriptTest{
Expected: []sql.Row{
{1, 2},
},
Skip: true, // can't correctly extract the bindvar type with more complicated processing during plan building
},
{
Query: "SELECT * FROM test WHERE pk + v1 = $1;",
Expand All @@ -90,12 +87,11 @@ var preparedStatementTests = []ScriptTest{
},
},
{
Query: "SELECT * FROM test WHERE v1 = $1 + $2;",
Query: "SELECT * FROM test WHERE v1 = $1::integer + $2::integer;",
BindVars: []any{1, 3},
Expected: []sql.Row{
{3, 4},
},
Skip: true, // this doesn't work without explicit type hints for the params
},
},
},
Expand Down Expand Up @@ -172,12 +168,11 @@ var preparedStatementTests = []ScriptTest{
},
},
{
Query: "SELECT * FROM test WHERE s = concat($1, $2);",
Query: "SELECT * FROM test WHERE s = concat($1::text, $2::text);",
BindVars: []any{"he", "llo"},
Expected: []sql.Row{
{1, "hello"},
},
Skip: true, // this doesn't work without explicit type hints for the params
},
{
Query: "SELECT * FROM test WHERE concat(s, '!') = $1",
Expand Down Expand Up @@ -266,15 +261,13 @@ var preparedStatementTests = []ScriptTest{
Expected: []sql.Row{
{1, 1.1},
},
Skip: true, // can't correctly extract the bindvar type with more complicated processing during plan building
},
{
Query: "SELECT * FROM test WHERE f1 = $1 + $2;",
Query: "SELECT * FROM test WHERE f1 = $1::decimal + $2::decimal;",
BindVars: []any{1.0, 0.1},
Expected: []sql.Row{
{1, 1.1},
},
Skip: true, // this doesn't work without explicit type hints for the params
},
},
},
Expand Down