Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into taylor/schemas-5
Browse files Browse the repository at this point in the history
  • Loading branch information
tbantle22 committed Jun 27, 2024
2 parents 0536414 + d3d72b5 commit 15f3e40
Show file tree
Hide file tree
Showing 9 changed files with 188 additions and 65 deletions.
6 changes: 3 additions & 3 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@ require (
github.com/PuerkitoBio/goquery v1.8.1
github.com/cockroachdb/apd/v2 v2.0.3-0.20200518165714-d020e156310a
github.com/cockroachdb/errors v1.7.5
github.com/dolthub/dolt/go v0.40.5-0.20240625223514-4c704c3daeca
github.com/dolthub/dolt/go v0.40.5-0.20240626185946-7aef8fcde146
github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi v0.0.0-20240529071237-4a099b896ce8
github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2
github.com/dolthub/go-mysql-server v0.18.2-0.20240625212035-80f4e402d726
github.com/dolthub/go-mysql-server v0.18.2-0.20240626180128-807a2e35937f
github.com/dolthub/sqllogictest/go v0.0.0-20240618184124-ca47f9354216
github.com/dolthub/vitess v0.0.0-20240617225939-55a46c5dcfc8
github.com/dolthub/vitess v0.0.0-20240626174323-4083c07f5e9c
github.com/fatih/color v1.13.0
github.com/goccy/go-json v0.10.2
github.com/gogo/protobuf v1.3.2
Expand Down
12 changes: 6 additions & 6 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -214,8 +214,8 @@ github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZm
github.com/dgryski/go-farm v0.0.0-20190423205320-6a90982ecee2/go.mod h1:SqUrOPUnsFjfmXRMNPybcSiG0BgUW2AuFH8PAnS2iTw=
github.com/docker/go-connections v0.4.0/go.mod h1:Gbd7IOopHjR8Iph03tsViu4nIes5XhDvyHbTtUxmeec=
github.com/docker/go-units v0.4.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
github.com/dolthub/dolt/go v0.40.5-0.20240625223514-4c704c3daeca h1:iaw0jiEb9QmF0wtrYyN0+xrYfZzjrtnNmKtNlumEVS0=
github.com/dolthub/dolt/go v0.40.5-0.20240625223514-4c704c3daeca/go.mod h1:3AsoVPqO/ELL1eBO9wOZYdUy0Iy7kpJ4Y9t87UN8mlI=
github.com/dolthub/dolt/go v0.40.5-0.20240626185946-7aef8fcde146 h1:PoVHlUEWWnkS7VlyVnryfumvFyHJnmwojCNOQ17QbnM=
github.com/dolthub/dolt/go v0.40.5-0.20240626185946-7aef8fcde146/go.mod h1:t9SrujEmNlUnQ/fPjpXfGUoQ4CQMkjIuASq5GT6bUnw=
github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi v0.0.0-20240529071237-4a099b896ce8 h1:izuogF6KRc6Pr5g5KevRtn8JK/KwyEGjbpqWJIORbQo=
github.com/dolthub/dolt/go/gen/proto/dolt/services/eventsapi v0.0.0-20240529071237-4a099b896ce8/go.mod h1:L5RDYZbC9BBWmoU2+TjTekeqqhFXX5EqH9ln00O0stY=
github.com/dolthub/flatbuffers/v23 v23.3.3-dh.2 h1:u3PMzfF8RkKd3lB9pZ2bfn0qEG+1Gms9599cr0REMww=
Expand All @@ -224,8 +224,8 @@ github.com/dolthub/fslock v0.0.3 h1:iLMpUIvJKMKm92+N1fmHVdxJP5NdyDK5bK7z7Ba2s2U=
github.com/dolthub/fslock v0.0.3/go.mod h1:QWql+P17oAAMLnL4HGB5tiovtDuAjdDTPbuqx7bYfa0=
github.com/dolthub/go-icu-regex v0.0.0-20230524105445-af7e7991c97e h1:kPsT4a47cw1+y/N5SSCkma7FhAPw7KeGmD6c9PBZW9Y=
github.com/dolthub/go-icu-regex v0.0.0-20230524105445-af7e7991c97e/go.mod h1:KPUcpx070QOfJK1gNe0zx4pA5sicIK1GMikIGLKC168=
github.com/dolthub/go-mysql-server v0.18.2-0.20240625212035-80f4e402d726 h1:zVMMW0gpT/Cq+xEUPulbt5y7dWz0K1A5BtNRDarW+i0=
github.com/dolthub/go-mysql-server v0.18.2-0.20240625212035-80f4e402d726/go.mod h1:XdiHsd2TX3OOhjwY6tPcw1ztT2BdBiP6Wp0m/7OYHn4=
github.com/dolthub/go-mysql-server v0.18.2-0.20240626180128-807a2e35937f h1:ouZxORcShC3qeQdsTkKDpROh+OI1OZuoOAJx8HThMrs=
github.com/dolthub/go-mysql-server v0.18.2-0.20240626180128-807a2e35937f/go.mod h1:JahRYjx/Py6T/bWrnTu25CaGn94Df+McAuWGEG0shwU=
github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63 h1:OAsXLAPL4du6tfbBgK0xXHZkOlos63RdKYS3Sgw/dfI=
github.com/dolthub/gozstd v0.0.0-20240423170813-23a2903bca63/go.mod h1:lV7lUeuDhH5thVGDCKXbatwKy2KW80L4rMT46n+Y2/Q=
github.com/dolthub/ishell v0.0.0-20221214210346-d7db0b066488 h1:0HHu0GWJH0N6a6keStrHhUAK5/o9LVfkh44pvsV4514=
Expand All @@ -238,8 +238,8 @@ github.com/dolthub/sqllogictest/go v0.0.0-20240618184124-ca47f9354216 h1:JWkKRE4
github.com/dolthub/sqllogictest/go v0.0.0-20240618184124-ca47f9354216/go.mod h1:e/FIZVvT2IR53HBCAo41NjqgtEnjMJGKca3Y/dAmZaA=
github.com/dolthub/swiss v0.1.0 h1:EaGQct3AqeP/MjASHLiH6i4TAmgbG/c4rA6a1bzCOPc=
github.com/dolthub/swiss v0.1.0/go.mod h1:BeucyB08Vb1G9tumVN3Vp/pyY4AMUnr9p7Rz7wJ7kAQ=
github.com/dolthub/vitess v0.0.0-20240617225939-55a46c5dcfc8 h1:d+dOTwI8dkwNYmcweXNjei2ot3GHJB3HqLWUeNvAkC0=
github.com/dolthub/vitess v0.0.0-20240617225939-55a46c5dcfc8/go.mod h1:uBvlRluuL+SbEWTCZ68o0xvsdYZER3CEG/35INdzfJM=
github.com/dolthub/vitess v0.0.0-20240626174323-4083c07f5e9c h1:Y3M0hPCUvT+5RTNbJLKywGc9aHIRCIlg+0NOhC91GYE=
github.com/dolthub/vitess v0.0.0-20240626174323-4083c07f5e9c/go.mod h1:uBvlRluuL+SbEWTCZ68o0xvsdYZER3CEG/35INdzfJM=
github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
Expand Down
11 changes: 6 additions & 5 deletions postgres/parser/parser/sql/sql_parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package sql

import (
"context"
"fmt"

"github.com/dolthub/doltgresql/postgres/parser/parser"
Expand All @@ -35,17 +36,17 @@ func NewPostgresParser() *PostgresParser { return &PostgresParser{} }

// ParseSimple implements sql.Parser interface.
func (p *PostgresParser) ParseSimple(query string) (vitess.Statement, error) {
stmt, _, _, err := p.ParseWithOptions(query, ';', false, vitess.ParserOptions{})
stmt, _, _, err := p.ParseWithOptions(context.Background(), query, ';', false, vitess.ParserOptions{})
return stmt, err
}

// Parse implements sql.Parser interface.
func (p *PostgresParser) Parse(_ *sql.Context, query string, multi bool) (vitess.Statement, string, string, error) {
return p.ParseWithOptions(query, ';', multi, vitess.ParserOptions{})
func (p *PostgresParser) Parse(ctx *sql.Context, query string, multi bool) (vitess.Statement, string, string, error) {
return p.ParseWithOptions(ctx, query, ';', multi, vitess.ParserOptions{})
}

// ParseWithOptions implements sql.Parser interface.
func (p *PostgresParser) ParseWithOptions(query string, delimiter rune, _ bool, _ vitess.ParserOptions) (vitess.Statement, string, string, error) {
func (p *PostgresParser) ParseWithOptions(ctx context.Context, query string, delimiter rune, _ bool, _ vitess.ParserOptions) (vitess.Statement, string, string, error) {
q := sql.RemoveSpaceAndDelimiter(query, delimiter)
stmts, err := parser.Parse(q)
if err != nil {
Expand All @@ -70,7 +71,7 @@ func (p *PostgresParser) ParseWithOptions(query string, delimiter rune, _ bool,
}

// ParseOneWithOptions implements sql.Parser interface.
func (p *PostgresParser) ParseOneWithOptions(query string, _ vitess.ParserOptions) (vitess.Statement, int, error) {
func (p *PostgresParser) ParseOneWithOptions(_ context.Context, query string, _ vitess.ParserOptions) (vitess.Statement, int, error) {
stmt, err := parser.ParseOne(query)
if err != nil {
return nil, 0, err
Expand Down
9 changes: 6 additions & 3 deletions server/analyzer/type_sanitizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,14 @@ func TypeSanitizer(ctx *sql.Context, a *analyzer.Analyzer, node sql.Node, scope
case sql.FunctionExpression:
// Compiled functions are Doltgres functions. We're only concerned with GMS functions.
if _, ok := expr.(*framework.CompiledFunction); !ok {
// The COUNT functions cannot be wrapped due to expectations in the analyzer, so we exclude them here.
// Some aggregation functions cannot be wrapped due to expectations in the analyzer, so we exclude them here.
switch expr.FunctionName() {
case "Count", "CountDistinct", "GroupConcat", "JSONObjectAgg":
case "Count", "CountDistinct", "GroupConcat", "JSONObjectAgg", "Sum":
default:
return pgexprs.NewGMSCast(expr), transform.NewTree, nil
// Some GMS functions wrap Doltgres parameters, so we'll only handle those that return GMS types
if _, ok := expr.Type().(pgtypes.DoltgresType); !ok {
return pgexprs.NewGMSCast(expr), transform.NewTree, nil
}
}
}
case *sql.ColumnDefaultValue:
Expand Down
4 changes: 4 additions & 0 deletions server/ast/resolvable_type_reference.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ func nodeResolvableTypeReference(typ tree.ResolvableTypeReference) (*vitess.Conv
return nil, nil, err
}
resolvedType = baseResolvedType.ToArrayType()
} else if columnType.Family() == types.GeometryFamily {
return nil, nil, fmt.Errorf("geometry types are not yet supported")
} else if columnType.Family() == types.GeographyFamily {
return nil, nil, fmt.Errorf("geography types are not yet supported")
} else {
switch columnType.Oid() {
case oid.T_bool:
Expand Down
21 changes: 14 additions & 7 deletions server/connection_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package server

import (
"context"
"crypto/tls"
"fmt"
"io"
Expand Down Expand Up @@ -93,6 +94,12 @@ func (h *ConnectionHandler) HandleConnection() {
eomErr = fmt.Errorf("panic: %v", r)
}

// Sending eom can panic, which means we must recover again
defer func() {
if r := recover(); r != nil {
fmt.Printf("Listener recovered panic: %v", r)
}
}()
h.endOfMessages(eomErr)
}

Expand Down Expand Up @@ -263,7 +270,7 @@ InitialMessageLoop:
// startup message provided
func (h *ConnectionHandler) chooseInitialDatabase(startupMessage messages.StartupMessage) error {
if db, ok := startupMessage.Parameters["database"]; ok && len(db) > 0 {
err := h.handler.ComQuery(h.mysqlConn, fmt.Sprintf("USE `%s`;", db), func(res *sqltypes.Result, more bool) error {
err := h.handler.ComQuery(context.Background(), h.mysqlConn, fmt.Sprintf("USE `%s`;", db), func(res *sqltypes.Result, more bool) error {
return nil
})
if err != nil {
Expand All @@ -280,7 +287,7 @@ func (h *ConnectionHandler) chooseInitialDatabase(startupMessage messages.Startu
} else {
// If a database isn't specified, then we attempt to connect to a database with the same name as the user,
// ignoring any error
_ = h.handler.ComQuery(h.mysqlConn, fmt.Sprintf("USE `%s`;", h.mysqlConn.User), func(*sqltypes.Result, bool) error {
_ = h.handler.ComQuery(context.Background(), h.mysqlConn, fmt.Sprintf("USE `%s`;", h.mysqlConn.User), func(*sqltypes.Result, bool) error {
return nil
})
}
Expand Down Expand Up @@ -483,7 +490,7 @@ func (h *ConnectionHandler) handleExecute(message messages.Execute) error {
return connection.Send(h.Conn(), messages.EmptyQueryResponse{})
}

err := h.handler.(mysql.ExtendedHandler).ComExecuteBound(h.mysqlConn, query.String, portalData.BoundPlan, spoolRowsCallback(h.Conn(), &complete, true))
err := h.handler.(mysql.ExtendedHandler).ComExecuteBound(context.Background(), h.mysqlConn, query.String, portalData.BoundPlan, spoolRowsCallback(h.Conn(), &complete, true))
if err != nil {
return err
}
Expand Down Expand Up @@ -895,7 +902,7 @@ func (h *ConnectionHandler) getPlanAndFields(query ConvertedQuery) (sql.Node, []
return nil, nil, fmt.Errorf("cannot prepare a query that has not been parsed")
}

parsedQuery, fields, err := h.handler.(mysql.ExtendedHandler).ComPrepareParsed(h.mysqlConn, query.String, query.AST, &mysql.PrepareData{
parsedQuery, fields, err := h.handler.(mysql.ExtendedHandler).ComPrepareParsed(context.Background(), h.mysqlConn, query.String, query.AST, &mysql.PrepareData{
PrepareStmt: query.String,
})

Expand All @@ -914,9 +921,9 @@ func (h *ConnectionHandler) getPlanAndFields(query ConvertedQuery) (sql.Node, []
// comQuery is a shortcut that determines which version of ComQuery to call based on whether the query has been parsed.
func (h *ConnectionHandler) comQuery(query ConvertedQuery, callback func(res *sqltypes.Result, more bool) error) error {
if query.AST == nil {
return h.handler.ComQuery(h.mysqlConn, query.String, callback)
return h.handler.ComQuery(context.Background(), h.mysqlConn, query.String, callback)
} else {
return h.handler.(mysql.ExtendedHandler).ComParsedQuery(h.mysqlConn, query.String, query.AST, callback)
return h.handler.(mysql.ExtendedHandler).ComParsedQuery(context.Background(), h.mysqlConn, query.String, query.AST, callback)
}
}

Expand All @@ -926,7 +933,7 @@ func (h *ConnectionHandler) bindParams(
parsedQuery sqlparser.Statement,
bindVars map[string]*querypb.BindVariable,
) (sql.Node, []*querypb.Field, error) {
bound, fields, err := h.handler.(mysql.ExtendedHandler).ComBind(h.mysqlConn, query, parsedQuery, &mysql.PrepareData{
bound, fields, err := h.handler.(mysql.ExtendedHandler).ComBind(context.Background(), h.mysqlConn, query, parsedQuery, &mysql.PrepareData{
PrepareStmt: query,
ParamsCount: uint16(len(bindVars)),
BindVars: bindVars,
Expand Down
125 changes: 88 additions & 37 deletions server/expression/in_tuple.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,14 @@ import (

// InTuple represents a VALUE IN (<VALUES>) expression.
type InTuple struct {
left sql.Expression
right expression.Tuple
leftExpr sql.Expression
rightExpr expression.Tuple

// These variables are used so that we can resolve the comparison functions once and reuse them as we iterate over rows.
// These are assigned in WithChildren, so refer there for more information.
staticLiteral *Literal
arrayLiterals []*Literal
compFuncs []*framework.CompiledFunction
}

var _ vitess.Injectable = (*BinaryOperator)(nil)
Expand All @@ -38,59 +44,51 @@ var _ expression.BinaryExpression = (*BinaryOperator)(nil)
// NewInTuple returns a new *InTuple.
func NewInTuple() *InTuple {
return &InTuple{
left: nil,
right: nil,
leftExpr: nil,
rightExpr: nil,
}
}

// Children implements the sql.Expression interface.
func (it *InTuple) Children() []sql.Expression {
return []sql.Expression{it.left, it.right}
return []sql.Expression{it.leftExpr, it.rightExpr}
}

// Eval implements the sql.Expression interface.
func (it *InTuple) Eval(ctx *sql.Context, row sql.Row) (any, error) {
left, err := it.left.Eval(ctx, row)
if len(it.compFuncs) == 0 {
return nil, fmt.Errorf("%T: cannot Eval as it has not been fully resolved", it)
}
// First we'll evaluate everything before we do the comparisons
left, err := it.leftExpr.Eval(ctx, row)
if err != nil {
return nil, err
}
rightInterface, err := it.right.Eval(ctx, row)
rightInterface, err := it.rightExpr.Eval(ctx, row)
if err != nil {
return nil, err
}
rightValues, ok := rightInterface.([]any)
if !ok {
// Tuples will return the value directly if it has a length of one, so we'll check for that first
if len(it.right) == 1 {
if len(it.rightExpr) == 1 {
rightValues = []any{rightInterface}
} else {
return nil, fmt.Errorf("%T: expected right child to return `%T` but returned `%T`", it, []any{}, rightInterface)
}
}
leftType, ok := it.left.Type().(pgtypes.DoltgresType)
if !ok {
return nil, fmt.Errorf("%T: GMS type `%s` on left child", it, it.left.Type().String())
}
// Next we'll assign our evaluated values to the expressions that the comparison functions reference
it.staticLiteral.value = left
for i, rightValue := range rightValues {
rightType, ok := it.right[i].Type().(pgtypes.DoltgresType)
if !ok {
return nil, fmt.Errorf("%T: GMS type `%s` within right child", it, it.right[i].Type().String())
}
// TODO: this should use the BinaryOperator expression, but since equality is not yet implemented, we implicitly cast
if !leftType.Equals(rightType) {
castFunc := framework.GetImplicitCast(rightType.BaseID(), leftType.BaseID())
if castFunc == nil {
return nil, fmt.Errorf("operator does not exist: %s = %s",
leftType.String(), rightType.String())
}
rightValue, err = castFunc(ctx, rightValue, leftType)
if err != nil {
return nil, err
}
}
if res, err := leftType.Compare(left, rightValue); err != nil {
it.arrayLiterals[i].value = rightValue
}
// Now we can loop over all of the comparison functions, as they'll reference their respective values
for _, compFunc := range it.compFuncs {
result, err := compFunc.Eval(ctx, row)
if err != nil {
return nil, err
} else if res == 0 {
}
if result.(bool) {
return true, nil
}
}
Expand All @@ -104,15 +102,23 @@ func (it *InTuple) IsNullable() bool {

// Resolved implements the sql.Expression interface.
func (it *InTuple) Resolved() bool {
return it.left != nil && it.left.Resolved() && it.right != nil && it.right.Resolved()
if it.leftExpr == nil || !it.leftExpr.Resolved() || it.rightExpr == nil || !it.rightExpr.Resolved() || len(it.compFuncs) == 0 {
return false
}
for _, compFunc := range it.compFuncs {
if !compFunc.Resolved() {
return false
}
}
return true
}

// String implements the sql.Expression interface.
func (it *InTuple) String() string {
if it.left == nil || it.right == nil {
if it.leftExpr == nil || it.rightExpr == nil {
return "? IN ?"
}
return fmt.Sprintf("%s IN %s", it.left.String(), it.right.String())
return fmt.Sprintf("%s IN %s", it.leftExpr.String(), it.rightExpr.String())
}

// Type implements the sql.Expression interface.
Expand All @@ -129,9 +135,54 @@ func (it *InTuple) WithChildren(children ...sql.Expression) (sql.Expression, err
if !ok {
return nil, fmt.Errorf("%T: expected right child to be `%T` but has type `%T`", it, expression.Tuple{}, children[1])
}
if len(rightTuple) == 0 {
return nil, fmt.Errorf("IN must contain at least 1 expression")
}
// We'll only resolve the comparison functions once we have all Doltgres types.
// We may see GMS types during some analyzer steps, so we should wait until those are done.
if leftType, ok := children[0].Type().(pgtypes.DoltgresType); ok {
// Rather than finding and resolving a comparison function every time we call Eval, we resolve them once and
// reuse the functions. We also want to avoid re-assigning the parameters of the comparison functions since that
// will also cause the functions to resolve again. To do this, we store expressions within our struct that the
// functions reference, so we can freely switch the values within the literals without changing anything
// regarding the comparison functions. This is usually unsafe, but since we're verifying the types returned by
// the parameters, and assigning the values to our own literals, we do not have to worry. This offers a
// significant speedup as function resolution is very expensive, so we want to do it as few times as possible
// (preferably once).
staticLiteral := &Literal{typ: leftType}
arrayLiterals := make([]*Literal, len(rightTuple))
// Each expression may be a different type (which is valid), so we need a comparison function for each expression.
compFuncs := make([]*framework.CompiledFunction, len(rightTuple))
allValidChildren := true
for i, rightExpr := range rightTuple {
rightType, ok := rightExpr.Type().(pgtypes.DoltgresType)
if !ok {
allValidChildren = false
break
}
arrayLiterals[i] = &Literal{typ: rightType}
compFuncs[i] = framework.GetBinaryFunction(framework.Operator_BinaryEqual).Compile("internal_in_comparison", staticLiteral, arrayLiterals[i])
if compFuncs[i] == nil {
return nil, fmt.Errorf("operator does not exist: %s = %s", leftType.String(), rightType.String())
}
if compFuncs[i].Type().(pgtypes.DoltgresType).BaseID() != pgtypes.DoltgresTypeBaseID_Bool {
// This should never happen, but this is just to be safe
return nil, fmt.Errorf("%T: found equality comparison that does not return a bool", it)
}
}
if allValidChildren {
return &InTuple{
leftExpr: children[0],
rightExpr: rightTuple,
staticLiteral: staticLiteral,
arrayLiterals: arrayLiterals,
compFuncs: compFuncs,
}, nil
}
}
return &InTuple{
left: children[0],
right: rightTuple,
leftExpr: children[0],
rightExpr: rightTuple,
}, nil
}

Expand All @@ -153,10 +204,10 @@ func (it *InTuple) WithResolvedChildren(children []any) (any, error) {

// Left implements the expression.BinaryExpression interface.
func (it *InTuple) Left() sql.Expression {
return it.left
return it.leftExpr
}

// Right implements the expression.BinaryExpression interface.
func (it *InTuple) Right() sql.Expression {
return it.right
return it.rightExpr
}
Loading

0 comments on commit 15f3e40

Please sign in to comment.