Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better prepared statement support #95

Merged
merged 12 commits into from
Jan 18, 2024
10 changes: 5 additions & 5 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,19 @@ require (
github.com/PuerkitoBio/goquery v1.8.1
github.com/cockroachdb/apd/v2 v2.0.3-0.20200518165714-d020e156310a
github.com/cockroachdb/errors v1.7.5
github.com/dolthub/dolt/go v0.40.5-0.20240110011351-84b9180295cc
github.com/dolthub/dolt/go v0.40.5-0.20240118010436-3613eed18a80
github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi v0.0.0-20231213233028-64c353bf920f
github.com/dolthub/go-mysql-server v0.17.1-0.20240110234302-66c569a3137e
github.com/dolthub/go-mysql-server v0.17.1-0.20240118005749-9120557227aa
github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81
github.com/dolthub/vitess v0.0.0-20240110233415-e46007d964c0
github.com/dolthub/vitess v0.0.0-20240117231546-55b8c7b39462
github.com/fatih/color v1.13.0
github.com/gogo/protobuf v1.3.2
github.com/golang/geo v0.0.0-20200730024412-e86565bf3f35
github.com/google/go-cmp v0.6.0
github.com/grpc-ecosystem/grpc-gateway v1.16.0
github.com/jackc/pgx/v4 v4.18.1
github.com/jackc/pgx/v5 v5.4.3
github.com/lib/pq v1.10.2
github.com/lib/pq v1.10.9
github.com/madflojo/testcerts v1.1.1
github.com/pierrre/geohash v1.0.0
github.com/sergi/go-diff v1.1.0
Expand All @@ -29,6 +29,7 @@ require (
github.com/twpayne/go-geom v1.3.6
golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1
golang.org/x/net v0.17.0
golang.org/x/sync v0.3.0
golang.org/x/sys v0.15.0
golang.org/x/text v0.14.0
)
Expand Down Expand Up @@ -138,7 +139,6 @@ require (
golang.org/x/crypto v0.17.0 // indirect
golang.org/x/mod v0.12.0 // indirect
golang.org/x/oauth2 v0.8.0 // indirect
golang.org/x/sync v0.3.0 // indirect
golang.org/x/term v0.15.0 // indirect
golang.org/x/time v0.1.0 // indirect
golang.org/x/tools v0.13.0 // indirect
Expand Down
19 changes: 8 additions & 11 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -214,8 +214,8 @@ github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZm
github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw=
github.com/docker/go-connections v0.4.0/go.mod h1:Gbd7IOopHjR8Iph03tsViu4nIes5XhDvyHbTtUxmeec=
github.com/docker/go-units v0.4.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
github.com/dolthub/dolt/go v0.40.5-0.20240110011351-84b9180295cc h1:7C97S8tm3cKL4tZIKaudt4BTBOBgwdZ3ceSExwb+bNo=
github.com/dolthub/dolt/go v0.40.5-0.20240110011351-84b9180295cc/go.mod h1:+oni3DE3qkT79htI/fVogLu00bRTfdu15fL4A3KPr24=
github.com/dolthub/dolt/go v0.40.5-0.20240118010436-3613eed18a80 h1:yCXqyA6QRHt3LWF89lFvuZOuvyH9O7DFWTecQd4dkjg=
github.com/dolthub/dolt/go v0.40.5-0.20240118010436-3613eed18a80/go.mod h1:dPcQ+foa5Tr6L3m5PQSY29PW9ZP+ZCoX1O9lO1pctZ8=
github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi v0.0.0-20231213233028-64c353bf920f h1:f250FTgZ/OaCql9G6WJt46l9VOIBF1mI81hW9cnmBNM=
github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi v0.0.0-20231213233028-64c353bf920f/go.mod h1:gHeHIDGU7em40EhFTliq62pExFcc1hxDTIZ9g5UqXYM=
github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2 h1:u3PMzfF8RkKd3lB9pZ2bfn0qEG+1Gms9599cr0REMww=
Expand All @@ -224,10 +224,8 @@ github.com/dolthub/fslock v0.0.3 h1:iLMpUIvJKMKm92+N1fmHVdxJP5NdyDK5bK7z7Ba2s2U=
github.com/dolthub/fslock v0.0.3/go.mod h1:QWql+P17oAAMLnL4HGB5tiovtDuAjdDTPbuqx7bYfa0=
github.com/dolthub/go-icu-regex v0.0.0-20230524105445-af7e7991c97e h1:kPsT4a47cw1+y/N5SSCkma7FhAPw7KeGmD6c9PBZW9Y=
github.com/dolthub/go-icu-regex v0.0.0-20230524105445-af7e7991c97e/go.mod h1:KPUcpx070QOfJK1gNe0zx4pA5sicIK1GMikIGLKC168=
github.com/dolthub/go-mysql-server v0.17.1-0.20240110020052-1eabd6054d96 h1:FDMByaljXrMExow4qE3qwQoyRbXku6GBy6jnqPjx4zg=
github.com/dolthub/go-mysql-server v0.17.1-0.20240110020052-1eabd6054d96/go.mod h1:z98pba7qbSvXiceU3NlUbJaYwITxc1Am06YjK6hexXA=
github.com/dolthub/go-mysql-server v0.17.1-0.20240110234302-66c569a3137e h1:FwStPrVtMcFTqaVp8Pk8KH1iCVTyQ58GzlNMO6ak418=
github.com/dolthub/go-mysql-server v0.17.1-0.20240110234302-66c569a3137e/go.mod h1:vANS+BQiobOQ3sfB1Qxm5zqOrsXOaK6S3EE1yb4vJuc=
github.com/dolthub/go-mysql-server v0.17.1-0.20240118005749-9120557227aa h1:/vldUjqmAMkqz39gN5oRrq0Xnjm/pd0mkY2VtDJqE6M=
github.com/dolthub/go-mysql-server v0.17.1-0.20240118005749-9120557227aa/go.mod h1:hS8Snuzg+nyTDjv4NI9jiXQ2lJJOd3O0ylhVPQlHySw=
github.com/dolthub/ishell v0.0.0-20221214210346-d7db0b066488 h1:0HHu0GWJH0N6a6keStrHhUAK5/o9LVfkh44pvsV4514=
github.com/dolthub/ishell v0.0.0-20221214210346-d7db0b066488/go.mod h1:ehexgi1mPxRTk0Mok/pADALuHbvATulTh6gzr7NzZto=
github.com/dolthub/jsonpath v0.0.2-0.20230525180605-8dc13778fd72 h1:NfWmngMi1CYUWU4Ix8wM+USEhjc+mhPlT9JUR/anvbQ=
Expand All @@ -238,10 +236,8 @@ github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81 h1:7/v8q9X
github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81/go.mod h1:siLfyv2c92W1eN/R4QqG/+RjjX5W2+gCTRjZxBjI3TY=
github.com/dolthub/swiss v0.1.0 h1:EaGQct3AqeP/MjASHLiH6i4TAmgbG/c4rA6a1bzCOPc=
github.com/dolthub/swiss v0.1.0/go.mod h1:BeucyB08Vb1G9tumVN3Vp/pyY4AMUnr9p7Rz7wJ7kAQ=
github.com/dolthub/vitess v0.0.0-20240110003421-4030c3dac015 h1:n45HAYH+kmlvZ+lZPKtJoserQJNwgQkyVWZAL7kJpn0=
github.com/dolthub/vitess v0.0.0-20240110003421-4030c3dac015/go.mod h1:IwjNXSQPymrja5pVqmfnYdcy7Uv7eNJNBPK/MEh9OOw=
github.com/dolthub/vitess v0.0.0-20240110233415-e46007d964c0 h1:P8wb4dR5krirPa0swEJbEObc/I7GaAM/01nOnuQrl0c=
github.com/dolthub/vitess v0.0.0-20240110233415-e46007d964c0/go.mod h1:IwjNXSQPymrja5pVqmfnYdcy7Uv7eNJNBPK/MEh9OOw=
github.com/dolthub/vitess v0.0.0-20240117231546-55b8c7b39462 h1:So1KO202cb047yWg5X27xRso6tkSYmU0Yu96JIVsaEU=
github.com/dolthub/vitess v0.0.0-20240117231546-55b8c7b39462/go.mod h1:IwjNXSQPymrja5pVqmfnYdcy7Uv7eNJNBPK/MEh9OOw=
github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
Expand Down Expand Up @@ -611,8 +607,9 @@ github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/lib/pq v1.8.0/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/lib/pq v1.10.0/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/lib/pq v1.10.2 h1:AqzbZs4ZoCBp+GtejcpCpcxM3zlSMx29dXbUSeVtJb8=
github.com/lib/pq v1.10.2/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/lightstep/lightstep-tracer-common/golang/gogo v0.0.0-20190605223551-bc2310a04743/go.mod h1:qklhhLq1aX+mtWk9cPHPzaBjWImj5ULL6C7HFJtXQMM=
github.com/lightstep/lightstep-tracer-go v0.18.1/go.mod h1:jlF1pusYV4pidLvZ+XD0UBX0ZE6WURAspgAczcDHrL4=
github.com/lyft/protoc-gen-star v0.5.2/go.mod h1:9toiA3cC7z5uVbODF7kEQ91Xn7XNFkVUl+SrEe+ZORU=
Expand Down
39 changes: 39 additions & 0 deletions server/ast/expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ package ast
import (
"fmt"
"go/constant"
"strings"

"github.com/dolthub/go-mysql-server/sql/expression"
vitess "github.com/dolthub/vitess/go/vt/sqlparser"

"github.com/dolthub/doltgresql/postgres/parser/sem/tree"
Expand Down Expand Up @@ -201,6 +203,11 @@ func nodeExpr(node tree.Expr) (vitess.Expr, error) {
if err != nil {
return nil, err
}

convertType, err = translateConvertType(convertType)
if err != nil {
return nil, err
}

return &vitess.ConvertExpr{
Name: "CAST",
Expand Down Expand Up @@ -571,3 +578,35 @@ func nodeExpr(node tree.Expr) (vitess.Expr, error) {
return nil, fmt.Errorf("unknown expression: `%T`", node)
}
}

func translateConvertType(convertType *vitess.ConvertType) (*vitess.ConvertType, error) {
zachmu marked this conversation as resolved.
Show resolved Hide resolved
switch strings.ToLower(convertType.Type) {
// passthrough types that need no conversion
case expression.ConvertToBinary, expression.ConvertToChar, expression.ConvertToNChar, expression.ConvertToDate,
expression.ConvertToDatetime, expression.ConvertToFloat, expression.ConvertToDouble, expression.ConvertToJSON,
expression.ConvertToReal, expression.ConvertToSigned, expression.ConvertToTime, expression.ConvertToUnsigned:
return convertType, nil
case "text", "character varying", "varchar":
return &vitess.ConvertType{
Type: expression.ConvertToChar,
}, nil
case "integer", "bigint":
return &vitess.ConvertType{
Type: expression.ConvertToSigned,
}, nil
case "decimal", "numeric":
return &vitess.ConvertType{
Type: expression.ConvertToFloat,
}, nil
case "boolean":
return &vitess.ConvertType{
Type: expression.ConvertToSigned,
}, nil
case "timestamp", "timestamp with time zone", "timestamp without time zone":
return &vitess.ConvertType{
Type: expression.ConvertToDatetime,
}, nil
default:
return nil, fmt.Errorf("unknown convert type: `%T`", convertType.Type)
}
}
34 changes: 26 additions & 8 deletions server/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -478,19 +478,37 @@ func extractBindVarTypes(queryPlan sql.Node) ([]int32, error) {

types := make([]int32, 0)
var err error
transform.InspectExpressions(inspectNode, func(expr sql.Expression) bool {
if bindVar, ok := expr.(*expression.BindVar); ok {
var id int32
id, err = messages.VitessTypeToObjectID(bindVar.Type().Type())
extractBindVars := func(expr sql.Expression) bool {
if err != nil {
return false
}
switch e := expr.(type) {
case *expression.BindVar:
var oid int32
oid, err = messages.VitessTypeToObjectID(e.Type().Type())
if err != nil {
err = fmt.Errorf("could not determine OID for placeholder %s: %w", e.Name, err)
return false
}
types = append(types, oid)
// $1::text and similar get converted to a Convert expression wrapping the bindvar
case *expression.Convert:
if bindVar, ok := e.Child.(*expression.BindVar); ok {
var oid int32
oid, err = messages.VitessTypeToObjectID(e.Type().Type())
if err != nil {
err = fmt.Errorf("could not determine OID for placeholder %s: %w", bindVar.Name, err)
return false
}
types = append(types, oid)
return false
} else {
types = append(types, id)
}
}

return true
})

}

transform.InspectExpressions(inspectNode, extractBindVars)
return types, err
}

Expand Down
12 changes: 10 additions & 2 deletions testing/go/framework.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ import (
dserver "github.com/dolthub/doltgresql/server"
)

// runOnPostgres is a debug setting to redirect the test framework to a local running postgres server,
// rather than starting a doltgres server.
const runOnPostgres = false

// ScriptTest defines a consistent structure for testing queries.
type ScriptTest struct {
// Name of the script.
Expand Down Expand Up @@ -113,7 +117,6 @@ func runScript(t *testing.T, script ScriptTest, conn *pgx.Conn, ctx context.Cont
t.Skip("Skip has been set in the assertion")
}
// If we're skipping the results check, then we call Execute, as it uses a simplified message model.
// The more complicated model is only partially implemented, and therefore won't work for all queries.
if assertion.SkipResultsCheck || assertion.ExpectedErr {
_, err := conn.Exec(ctx, assertion.Query, assertion.BindVars...)
if assertion.ExpectedErr {
Expand Down Expand Up @@ -167,8 +170,13 @@ func RunScripts(t *testing.T, scripts []ScriptTest) {
if len(focusScripts) > 0 {
scripts = focusScripts
}

for _, script := range scripts {
RunScript(t, script)
if runOnPostgres {
RunScriptOnPostgres(t, script)
zachmu marked this conversation as resolved.
Show resolved Hide resolved
} else {
RunScript(t, script)
}
}
}

Expand Down
17 changes: 5 additions & 12 deletions testing/go/prepared_statement_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,20 +27,18 @@ var preparedStatementTests = []ScriptTest{
Name: "expressions without tables",
Assertions: []ScriptTestAssertion{
{
Query: "SELECT CONCAT($1, $2)",
Query: "SELECT CONCAT($1::text, $2::text)",
BindVars: []any{"hello", "world"},
Expected: []sql.Row{
{"helloworld"},
},
Skip: true, // this doesn't work without explicit type hints for the params
},
{
Query: "SELECT $1 + $2",
Query: "SELECT $1::integer + $2::integer",
BindVars: []any{1, 2},
Expected: []sql.Row{
{3},
},
Skip: true, // this doesn't work without explicit type hints for the params
},
},
},
Expand Down Expand Up @@ -80,7 +78,6 @@ var preparedStatementTests = []ScriptTest{
Expected: []sql.Row{
{1, 2},
},
Skip: true, // can't correctly extract the bindvar type with more complicated processing during plan building
},
{
Query: "SELECT * FROM test WHERE pk + v1 = $1;",
Expand All @@ -90,12 +87,11 @@ var preparedStatementTests = []ScriptTest{
},
},
{
Query: "SELECT * FROM test WHERE v1 = $1 + $2;",
Query: "SELECT * FROM test WHERE v1 = $1::integer + $2::integer;",
BindVars: []any{1, 3},
Expected: []sql.Row{
{3, 4},
},
Skip: true, // this doesn't work without explicit type hints for the params
},
},
},
Expand Down Expand Up @@ -172,12 +168,11 @@ var preparedStatementTests = []ScriptTest{
},
},
{
Query: "SELECT * FROM test WHERE s = concat($1, $2);",
Query: "SELECT * FROM test WHERE s = concat($1::text, $2::text);",
BindVars: []any{"he", "llo"},
Expected: []sql.Row{
{1, "hello"},
},
Skip: true, // this doesn't work without explicit type hints for the params
},
{
Query: "SELECT * FROM test WHERE concat(s, '!') = $1",
Expand Down Expand Up @@ -266,15 +261,13 @@ var preparedStatementTests = []ScriptTest{
Expected: []sql.Row{
{1, 1.1},
},
Skip: true, // can't correctly extract the bindvar type with more complicated processing during plan building
},
{
Query: "SELECT * FROM test WHERE f1 = $1 + $2;",
Query: "SELECT * FROM test WHERE f1 = $1::decimal + $2::decimal;",
BindVars: []any{1.0, 0.1},
Expected: []sql.Row{
{1, 1.1},
},
Skip: true, // this doesn't work without explicit type hints for the params
},
},
},
Expand Down
Loading