diff --git a/enginetest/evaluation.go b/enginetest/evaluation.go index 32d77fde64..a99c9decd0 100644 --- a/enginetest/evaluation.go +++ b/enginetest/evaluation.go @@ -948,8 +948,12 @@ func AssertErrWithCtx(t *testing.T, e QueryEngine, harness Harness, ctx *sql.Con require.Error(t, err) if expectedErrKind != nil { err = sql.UnwrapError(err) - if !IsServerEngine(e) { + if reh, ok := harness.(ResultEvaluationHarness); ok { + reh.EvaluateExpectedErrorKind(t, expectedErrKind, err) + } else if !IsServerEngine(e) { require.True(t, expectedErrKind.Is(err), "Expected error of type %s but got %s", expectedErrKind, err) + } else { + t.Skipf("Unimplemented error kind check for harness %T", harness) } } diff --git a/enginetest/harness.go b/enginetest/harness.go index 94c733a2ce..2a28d882fd 100644 --- a/enginetest/harness.go +++ b/enginetest/harness.go @@ -17,6 +17,8 @@ package enginetest import ( "testing" + "gopkg.in/src-d/go-errors.v1" + sqle "github.com/dolthub/go-mysql-server" "github.com/dolthub/go-mysql-server/enginetest/scriptgen/setup" "github.com/dolthub/go-mysql-server/server" @@ -167,4 +169,7 @@ type ResultEvaluationHarness interface { // EvaluateExpectedError compares expected error strings to actual errors and emits failed test assertions in the // event there are any EvaluateExpectedError(t *testing.T, expected string, err error) + + // EvaluateExpectedErrorKind compares expected error kinds to actual errors and emits failed test assertions in the + EvaluateExpectedErrorKind(t *testing.T, expected *errors.Kind, err error) } diff --git a/sql/analyzer/apply_indexes_from_outer_scope.go b/sql/analyzer/apply_indexes_from_outer_scope.go index ab8a933ff9..ef5f55274a 100644 --- a/sql/analyzer/apply_indexes_from_outer_scope.go +++ b/sql/analyzer/apply_indexes_from_outer_scope.go @@ -92,24 +92,24 @@ func applyIndexesFromOuterScope(ctx *sql.Context, a *Analyzer, n sql.Node, scope // sql.IndexAddressableTable func pushdownIndexToTable(ctx *sql.Context, a *Analyzer, tableNode sql.NameableNode, index sql.Index, keyExpr []sql.Expression, nullmask []bool) (sql.Node, transform.TreeIdentity, error) { return transform.Node(tableNode, func(n sql.Node) (sql.Node, transform.TreeIdentity, error) { - switch n := n.(type) { + switch nn := n.(type) { case *plan.IndexedTableAccess: case sql.TableNode: table := getTable(tableNode) if table == nil { return n, transform.SameTree, nil } - if _, ok := table.(sql.IndexAddressableTable); ok { - a.Log("table %q transformed with pushdown of index", tableNode.Name()) - lb := plan.NewLookupBuilder(index, keyExpr, nullmask) - - ret, err := plan.NewIndexedAccessForTableNode(n, lb) - if err != nil { - return nil, transform.SameTree, err - } - - return ret, transform.NewTree, nil + _, isIdxAddrTbl := table.(sql.IndexAddressableTable) + if !isIdxAddrTbl { + return n, transform.SameTree, nil + } + a.Log("table %q transformed with pushdown of index", tableNode.Name()) + lb := plan.NewLookupBuilder(index, keyExpr, nullmask) + ret, err := plan.NewIndexedAccessForTableNode(nn, lb) + if err != nil { + return nil, transform.SameTree, err } + return ret, transform.NewTree, nil } return n, transform.SameTree, nil }) diff --git a/sql/analyzer/indexed_joins.go b/sql/analyzer/indexed_joins.go index b9896e7848..37a0b23e38 100644 --- a/sql/analyzer/indexed_joins.go +++ b/sql/analyzer/indexed_joins.go @@ -1396,7 +1396,7 @@ func attrsRefSingleTableCol(e sql.Expression) (tableCol, bool) { transform.InspectExpr(e, func(e sql.Expression) bool { switch e := e.(type) { case *expression.GetField: - newTc := tableCol{col: strings.ToLower(e.Name()), table: strings.ToLower(e.Table())} + newTc := newTableCol(e.Table(), e.Name()) if tc.table == "" && !invalid { tc = newTc } else if tc != newTc { diff --git a/sql/analyzer/symbol_resolution.go b/sql/analyzer/symbol_resolution.go index 3166f18fc2..c829a20bae 100644 --- a/sql/analyzer/symbol_resolution.go +++ b/sql/analyzer/symbol_resolution.go @@ -47,6 +47,11 @@ import ( // - stars: a tablescan with a qualified star or cannot be pruned. An // unqualified star prevents pruning every child tablescan. func pruneTables(ctx *sql.Context, a *Analyzer, n sql.Node, s *plan.Scope, sel RuleSelector, qFlags *sql.QueryFlags) (sql.Node, transform.TreeIdentity, error) { + // MATCH ... AGAINST ... prevents pruning due to its internal reliance on an expected and consistent schema in all situations + if hasMatchAgainstExpr(n) { + return n, transform.SameTree, nil + } + // the same table can appear in multiple table scans, // so we use a counter to pin references parentCols := make(map[tableCol]int) @@ -73,11 +78,6 @@ func pruneTables(ctx *sql.Context, a *Analyzer, n sql.Node, s *plan.Scope, sel R unqualifiedStar = beforeUnq } - // MATCH ... AGAINST ... prevents pruning due to its internal reliance on an expected and consistent schema in all situations - if ma := findMatchAgainstExpr(n); ma != nil { - return n, transform.SameTree, nil - } - var pruneWalk func(n sql.Node) (sql.Node, transform.TreeIdentity, error) pruneWalk = func(n sql.Node) (sql.Node, transform.TreeIdentity, error) { switch n := n.(type) { @@ -170,17 +170,17 @@ func findSubqueryExpr(n sql.Node) *plan.Subquery { return nil } -// findMatchAgainstExpr searches for an *expression.MatchAgainst within the node, returning the node or nil. -func findMatchAgainstExpr(n sql.Node) *expression.MatchAgainst { - var maExpr *expression.MatchAgainst - transform.InspectExpressionsWithNode(n, func(n sql.Node, expr sql.Expression) bool { - if matchAgainstExpr, ok := expr.(*expression.MatchAgainst); ok { - maExpr = matchAgainstExpr - return false +// hasMatchAgainstExpr searches for an *expression.MatchAgainst within the node's expressions +func hasMatchAgainstExpr(node sql.Node) bool { + var foundMatchAgainstExpr bool + transform.InspectExpressions(node, func(expr sql.Expression) bool { + _, isMatchAgainstExpr := expr.(*expression.MatchAgainst) + if isMatchAgainstExpr { + foundMatchAgainstExpr = true } - return true + return !foundMatchAgainstExpr }) - return maExpr + return foundMatchAgainstExpr } // pruneTableCols uses a list of parent dependencies columns and stars @@ -194,8 +194,11 @@ func pruneTableCols( unqualifiedStar bool, ) (sql.Node, transform.TreeIdentity, error) { table := getTable(n) - ptab, ok := table.(sql.ProjectedTable) - if !ok || table.Name() == plan.DualTableName { + ptab, isProjTbl := table.(sql.ProjectedTable) + if !isProjTbl || plan.IsDualTable(table) { + return n, transform.SameTree, nil + } + if len(ptab.Projections()) > 0 { return n, transform.SameTree, nil } @@ -204,33 +207,31 @@ func pruneTableCols( selectStar = true } - if len(ptab.Projections()) > 0 { - return n, transform.SameTree, nil - } - // Don't prune columns if they're needed by a virtual column virtualColDeps := make(map[tableCol]int) - if vct, ok := n.WrappedTable().(*plan.VirtualColumnTable); ok { - for _, projection := range vct.Projections { - transform.Expr(projection, func(e sql.Expression) (sql.Expression, transform.TreeIdentity, error) { - if cd, ok := e.(*sql.ColumnDefaultValue); ok { - transform.Expr(cd.Expr, func(e sql.Expression) (sql.Expression, transform.TreeIdentity, error) { - if gf, ok := e.(*expression.GetField); ok { - c := tableCol{table: strings.ToLower(gf.Table()), col: strings.ToLower(gf.Name())} - virtualColDeps[c] = virtualColDeps[c] + 1 - } - return e, transform.SameTree, nil - }) - } - return e, transform.SameTree, nil - }) + if !selectStar { // if selectStar, we're adding all columns anyway + if vct, isVCT := n.WrappedTable().(*plan.VirtualColumnTable); isVCT { + for _, projection := range vct.Projections { + transform.InspectExpr(projection, func(e sql.Expression) bool { + if cd, isCD := e.(*sql.ColumnDefaultValue); isCD { + transform.InspectExpr(cd.Expr, func(e sql.Expression) bool { + if gf, ok := e.(*expression.GetField); ok { + c := newTableCol(gf.Table(), gf.Name()) + virtualColDeps[c]++ + } + return false + }) + } + return false + }) + } } } cols := make([]string, 0) source := strings.ToLower(table.Name()) for _, col := range table.Schema() { - c := tableCol{table: strings.ToLower(source), col: strings.ToLower(col.Name)} + c := newTableCol(source, col.Name) if selectStar || parentCols[c] > 0 || virtualColDeps[c] > 0 { cols = append(cols, c.col) } @@ -251,6 +252,7 @@ func gatherOuterCols(n sql.Node) ([]tableCol, []string, bool) { if !ok { return nil, nil, false } + var cols []tableCol var nodeStars []string var nodeUnqualifiedStar bool @@ -261,15 +263,15 @@ func gatherOuterCols(n sql.Node) ([]tableCol, []string, bool) { case *expression.Alias: switch e := e.Child.(type) { case *expression.GetField: - col = tableCol{table: strings.ToLower(e.Table()), col: strings.ToLower(e.Name())} + col = newTableCol(e.Table(), e.Name()) case *expression.UnresolvedColumn: - col = tableCol{table: strings.ToLower(e.Table()), col: strings.ToLower(e.Name())} + col = newTableCol(e.Table(), e.Name()) default: } case *expression.GetField: - col = tableCol{table: strings.ToLower(e.Table()), col: strings.ToLower(e.Name())} + col = newTableCol(e.Table(), e.Name()) case *expression.UnresolvedColumn: - col = tableCol{table: strings.ToLower(e.Table()), col: strings.ToLower(e.Name())} + col = newTableCol(e.Table(), e.Name()) case *expression.Star: if len(e.Table) > 0 { nodeStars = append(nodeStars, strings.ToLower(e.Table)) @@ -280,7 +282,6 @@ func gatherOuterCols(n sql.Node) ([]tableCol, []string, bool) { } if col.col != "" { cols = append(cols, col) - } return false }) @@ -315,8 +316,8 @@ func gatherTableAlias( starred = true } for _, col := range n.Schema() { - baseCol := tableCol{table: strings.ToLower(base), col: strings.ToLower(col.Name)} - aliasCol := tableCol{table: strings.ToLower(alias), col: strings.ToLower(col.Name)} + baseCol := newTableCol(base, col.Name) + aliasCol := newTableCol(alias, col.Name) if starred || parentCols[aliasCol] > 0 { // if the outer scope requests an aliased column // a table lower in the tree must provide the source diff --git a/sql/analyzer/tables.go b/sql/analyzer/tables.go index 31582d9a80..4d2336901e 100644 --- a/sql/analyzer/tables.go +++ b/sql/analyzer/tables.go @@ -71,21 +71,23 @@ func getUnaliasedTableName(node sql.Node) string { // Finds first table node that is a descendant of the node given func getTable(node sql.Node) sql.Table { var table sql.Table - transform.Inspect(node, func(node sql.Node) bool { + transform.Inspect(node, func(n sql.Node) bool { + // Inspect is called on all children of a node even if an earlier child's call returns false. + // We only want the first TableNode match. if table != nil { return false } - - switch n := node.(type) { + switch nn := n.(type) { case sql.TableNode: - table = n.UnderlyingTable() - // TODO unwinding a table wrapper here causes infinite analyzer recursion + // TODO: unwinding a table wrapper here causes infinite analyzer recursion + table = nn.UnderlyingTable() return false case *plan.IndexedTableAccess: - table = n.TableNode.UnderlyingTable() + table = nn.TableNode.UnderlyingTable() return false + default: + return true } - return true }) return table } @@ -94,25 +96,23 @@ func getTable(node sql.Node) sql.Table { // This function will not look inside SubqueryAliases func getResolvedTable(node sql.Node) *plan.ResolvedTable { var table *plan.ResolvedTable - transform.Inspect(node, func(node sql.Node) bool { - // plan.Inspect will get called on all children of a node even if one of the children's calls returns false. We - // only want the first TableNode match. + transform.Inspect(node, func(n sql.Node) bool { + // Inspect is called on all children of a node even if an earlier child's call returns false. + // We only want the first TableNode match. if table != nil { return false } - - switch n := node.(type) { + switch nn := n.(type) { case *plan.SubqueryAlias: // We should not be matching with ResolvedTables inside SubqueryAliases return false case *plan.ResolvedTable: - if !plan.IsDualTable(n) { - table = n + if !plan.IsDualTable(nn) { + table = nn return false } case *plan.IndexedTableAccess: - rt, ok := n.TableNode.(*plan.ResolvedTable) - if ok { + if rt, ok := nn.TableNode.(*plan.ResolvedTable); ok { table = rt return false } diff --git a/sql/expression/function/str_to_date_test.go b/sql/expression/function/str_to_date_test.go index 19e96dac59..fffea82a30 100644 --- a/sql/expression/function/str_to_date_test.go +++ b/sql/expression/function/str_to_date_test.go @@ -18,9 +18,17 @@ func TestStrToDate(t *testing.T) { name string dateStr string fmtStr string - expected string + expected interface{} }{ {"standard", "Dec 26, 2000 2:13:15", "%b %e, %Y %T", "2000-12-26 02:13:15"}, + {"ymd", "20240101", "%Y%m%d", "2024-01-01"}, + {"ymd", "2024121", "%Y%m%d", "2024-12-01"}, + {"ymd", "20241301", "%Y%m%d", nil}, + {"ymd", "20240001", "%Y%m%d", nil}, + {"ymd-with-time", "2024010203:04:05", "%Y%m%d%T", "2024-01-02 03:04:05"}, + {"ymd-with-time", "202408122:03:04", "%Y%m%d%T", "2024-08-12 02:03:04"}, + // TODO: It shoud be nil, but returns "2024-02-31" + // {"ymd", "20240231", "%Y%m%d", nil}, } for _, tt := range testCases { diff --git a/sql/plan/resolved_table.go b/sql/plan/resolved_table.go index ffcd364119..78529e34ac 100644 --- a/sql/plan/resolved_table.go +++ b/sql/plan/resolved_table.go @@ -252,28 +252,29 @@ func (*ResolvedTable) CollationCoercibility(ctx *sql.Context) (collation sql.Col // WithTable returns this Node with the given table, re-wrapping it with any MutableTableWrapper that was // wrapping it prior to this call. -func (t ResolvedTable) WithTable(table sql.Table) (sql.MutableTableNode, error) { +func (t *ResolvedTable) WithTable(table sql.Table) (sql.MutableTableNode, error) { if t.Name() != table.Name() { return nil, fmt.Errorf("attempted to update TableNode `%s` with table `%s`", t.Name(), table.Name()) } - if mtw, ok := t.Table.(sql.MutableTableWrapper); ok { - t.Table = mtw.WithUnderlying(table) + nt := *t + if mtw, ok := nt.Table.(sql.MutableTableWrapper); ok { + nt.Table = mtw.WithUnderlying(table) } else { - t.Table = table + nt.Table = table } - return &t, nil + return &nt, nil } // ReplaceTable returns this Node with the given table without performing any re-wrapping of any MutableTableWrapper -func (t ResolvedTable) ReplaceTable(table sql.Table) (sql.MutableTableNode, error) { +func (t *ResolvedTable) ReplaceTable(table sql.Table) (sql.MutableTableNode, error) { if t.Name() != table.Name() { return nil, fmt.Errorf("attempted to update TableNode `%s` with table `%s`", t.Name(), table.Name()) } - - t.Table = table - return &t, nil + nt := *t + nt.Table = table + return &nt, nil } // TableIdNode is a distinct source of rows associated with a table diff --git a/sql/planbuilder/dateparse/date.go b/sql/planbuilder/dateparse/date.go index 9a4d15bea3..12e05cbf33 100644 --- a/sql/planbuilder/dateparse/date.go +++ b/sql/planbuilder/dateparse/date.go @@ -220,7 +220,7 @@ var formatSpecifiers = map[byte]parser{ // %D Day of the month with English suffix (0th, 1st, 2nd, 3rd, …) 'D': parseDayNumericWithEnglishSuffix, // %d Day of the month, numeric (00..31) - 'd': parseDayOfMonthNumeric, + 'd': parseDayOfMonth2DigitNumeric, // %e Day of the month, numeric (0..31) 'e': parseDayOfMonthNumeric, // %f Microseconds (000000..999999) @@ -242,7 +242,7 @@ var formatSpecifiers = map[byte]parser{ // %M Month name (January..December) 'M': parseMonthName, // %m Month, numeric (00..12) - 'm': parseMonthNumeric, + 'm': parseMonth2DigitNumeric, // %p AM or PM 'p': parseAmPm, // %r Time, 12-hour (hh:mm:ss followed by AM or PM) diff --git a/sql/planbuilder/dateparse/parsers.go b/sql/planbuilder/dateparse/parsers.go index 18a44a0150..756c686ac7 100644 --- a/sql/planbuilder/dateparse/parsers.go +++ b/sql/planbuilder/dateparse/parsers.go @@ -84,6 +84,19 @@ func parseMonthNumeric(result *datetime, chars string) (rest string, _ error) { return rest, nil } +func parseMonth2DigitNumeric(result *datetime, chars string) (rest string, _ error) { + num, rest, err := takeNumberAtMostNChars(2, chars) + if err != nil { + return "", err + } + if num < 1 || num > 12 { + return "", fmt.Errorf("expected 01-12, got %s", string(chars)) + } + month := time.Month(num) + result.month = &month + return rest, nil +} + func parseDayOfMonthNumeric(result *datetime, chars string) (rest string, _ error) { num, rest, err := takeNumber(chars) if err != nil { @@ -93,6 +106,18 @@ func parseDayOfMonthNumeric(result *datetime, chars string) (rest string, _ erro return rest, nil } +func parseDayOfMonth2DigitNumeric(result *datetime, chars string) (rest string, _ error) { + num, rest, err := takeNumberAtMostNChars(2, chars) + if err != nil { + return "", err + } + if num < 1 || num > 31 { + return "", fmt.Errorf("expected 01-31, got %s", string(chars)) + } + result.day = &num + return rest, nil +} + func parseMicrosecondsNumeric(result *datetime, chars string) (rest string, _ error) { num, rest, err := takeNumber(chars) if err != nil { @@ -224,7 +249,7 @@ func parseYear4DigitNumeric(result *datetime, chars string) (rest string, _ erro if len(chars) < 4 { return "", fmt.Errorf("expected at least 4 chars, got %d", len(chars)) } - year, rest, err := takeNumber(chars) + year, rest, err := takeNumberAtMostNChars(4, chars) if err != nil { return "", err }