From 5b7216c2aeeada143ad351c982f2b39e7e7fbd53 Mon Sep 17 00:00:00 2001 From: James Cor Date: Fri, 9 Feb 2024 15:27:41 -0800 Subject: [PATCH 01/27] fix decimal return type --- enginetest/memory_engine_test.go | 124 +++++++++++++++++++++++++------ enginetest/queries/tpch_plans.go | 4 +- sql/expression/arithmetic.go | 78 +++++++++++++++++-- sql/expression/comparison.go | 5 ++ sql/expression/div.go | 80 ++++++++++++++++---- sql/expression/div_test.go | 25 +++---- sql/expression/in.go | 13 +++- 7 files changed, 267 insertions(+), 62 deletions(-) diff --git a/enginetest/memory_engine_test.go b/enginetest/memory_engine_test.go index b2cb356a99..5a9e1edae7 100644 --- a/enginetest/memory_engine_test.go +++ b/enginetest/memory_engine_test.go @@ -203,46 +203,128 @@ func newUpdateResult(matched, updated int) types.OkResult { // Convenience test for debugging a single query. Unskip and set to the desired query. func TestSingleScript(t *testing.T) { - t.Skip() var scripts = []queries.ScriptTest{ { - Name: "physical columns added after virtual one", + Name: "delete me", SetUpScript: []string{ - "create table t (pk int primary key, col1 int as (pk + 1));", - "insert into t (pk) values (1), (3)", - "alter table t add index idx1 (col1, pk);", - "alter table t add index idx2 (col1);", - "alter table t add column col2 int;", - "alter table t add column col3 int;", - "insert into t (pk, col2, col3) values (2, 4, 5);", + "CREATE TABLE tab0(col0 INTEGER, col1 INTEGER, col2 INTEGER);", + "INSERT INTO tab0 VALUES(97,1,99);", }, Assertions: []queries.ScriptTestAssertion{ { - Query: "select * from t order by pk", + Query: "select 1 / 1", + Expected: []sql.Row{ + {"1.0000"}, + }, + }, + + { + Query: "select 1 / 3 * 3;", + Expected: []sql.Row{ + {"1.0000"}, + }, + }, + { + Query: "select 1 / 3 * 3 = 0.999999999;", + Expected: []sql.Row{ + {true}, + }, + }, + { + Query: "SELECT col2 IN ( 98 + col0 / 99 ) from tab0;", Expected: []sql.Row{ - {1, 2, nil, nil}, - {2, 3, 4, 5}, - {3, 4, nil, nil}, + {false}, }, }, { - Query: "select * from t where col1 = 2", + Query: "SELECT col2 IN ( 98 + 97 / 99 ) from tab0;", Expected: []sql.Row{ - {1, 2, nil, nil}, + {false}, }, }, { - Query: "select * from t where col1 = 3 and pk = 2", + Query: "SELECT 99 IN ( 98 + 97 / 99 );", Expected: []sql.Row{ - {2, 3, 4, 5}, + {false}, }, }, { - Query: "select * from t where pk = 2", + Query: "SELECT 1 IN ( 97 / 99 );", Expected: []sql.Row{ - {2, 3, 4, 5}, + {false}, }, }, + + { + Query: "SELECT * FROM tab0 WHERE col2 IN ( 98 + 97 / 99 );", + Expected: []sql.Row{ + }, + }, + { + Query: "SELECT ALL * FROM tab0 AS cor0 WHERE col2 IN ( 39 + + 89, col0 + + col1 + + ( - ( - col0 ) ) / col2, + ( col0 ) + - 99, + col1, + col2 * - + col2 * - 12 + col1 + - 66 );", + Expected: []sql.Row{ + }, + }, + + + { + Query: `SELECT 1 IN (1 / 9 * 5);`, + Expected: []sql.Row{{false}}, + }, + { + Query: `select 1 / 3 * 3 = 1;`, + Expected: []sql.Row{{false}}, + }, + { + Query: `select 1 / 3 * 3 = 0.999999999;`, + Expected: []sql.Row{{true}}, + }, + + { + Query: `select 1 / 3 * 3 in (1);`, + Expected: []sql.Row{{false}}, + }, + { + Query: `select 1 in (1 / 3 * 3);`, + Expected: []sql.Row{{false}}, + }, + { + Query: `SELECT 1 IN (1 / 9 * 5);`, + Expected: []sql.Row{{false}}, + }, + { + Query: `SELECT 1 / 9 * 5 IN (1);`, + Expected: []sql.Row{{false}}, + }, + { + Query: `SELECT 1 / 9 * 5 IN (1 / 9 * 5);`, + Expected: []sql.Row{{true}}, + }, + + { + Query: `SELECT 1 IN (1 / 99 * 50);`, + Expected: []sql.Row{{false}}, + }, + { + Query: `select 1 / 3 * 3 in (0.999999999);`, + Expected: []sql.Row{{true}}, + }, + { + Query: `SELECT 96 / 51 * 51 > 96;`, + Expected: []sql.Row{{false}}, + }, + { + Query: `SELECT 96 / 51 * 51 = 95.999999991;`, + Expected: []sql.Row{{true}}, + }, + { + Query: `select 64 / 77 * 77;`, + Expected: []sql.Row{{"64.0000"}}, + }, + { + Query: `select (1 / 3) * (1 / 3);`, + Expected: []sql.Row{{"0.11111111"}}, + }, }, }, } @@ -254,8 +336,8 @@ func TestSingleScript(t *testing.T) { if err != nil { panic(err) } - engine.EngineAnalyzer().Debug = true - engine.EngineAnalyzer().Verbose = true + //engine.EngineAnalyzer().Debug = true + //engine.EngineAnalyzer().Verbose = true enginetest.TestScriptWithEngine(t, engine, harness, test) } diff --git a/enginetest/queries/tpch_plans.go b/enginetest/queries/tpch_plans.go index 60f5284d8c..e0b82c01ec 100644 --- a/enginetest/queries/tpch_plans.go +++ b/enginetest/queries/tpch_plans.go @@ -867,10 +867,10 @@ where " │ │ └─ AND\n" + " │ │ ├─ GreaterThanOrEqual\n" + " │ │ │ ├─ lineitem.l_discount:2!null\n" + - " │ │ │ └─ 0.05 (decimal(3,2))\n" + + " │ │ │ └─ 0.05 (decimal(5,2))\n" + " │ │ └─ LessThanOrEqual\n" + " │ │ ├─ lineitem.l_discount:2!null\n" + - " │ │ └─ 0.07 (decimal(3,2))\n" + + " │ │ └─ 0.07 (decimal(5,2))\n" + " │ └─ LessThan\n" + " │ ├─ lineitem.l_quantity:0!null\n" + " │ └─ 24 (tinyint)\n" + diff --git a/sql/expression/arithmetic.go b/sql/expression/arithmetic.go index 9fccd808c8..1f41836233 100644 --- a/sql/expression/arithmetic.go +++ b/sql/expression/arithmetic.go @@ -157,11 +157,61 @@ func (a *Arithmetic) Type() sql.Type { return types.Int64 } - if a.Op == sqlparser.MultStr { - return floatOrDecimalTypeForMult(a.LeftChild, a.RightChild) - } else { - return getFloatOrMaxDecimalType(a, false) + // TODO: special cases for div, intdiv, and mod? + + if types.IsDecimal(lTyp) && !types.IsDecimal(rTyp) { + return lTyp + } + + if types.IsDecimal(rTyp) && !types.IsDecimal(lTyp) { + return rTyp + } + + if types.IsDecimal(lTyp) && types.IsDecimal(rTyp) { + lPrec := lTyp.(types.DecimalType_).Precision() + lScale := lTyp.(types.DecimalType_).Scale() + rPrec := rTyp.(types.DecimalType_).Precision() + rScale := rTyp.(types.DecimalType_).Scale() + + var prec, scale uint8 + if lPrec > rPrec { + prec = lPrec + } else { + prec = rPrec + } + + switch a.Op { + case sqlparser.PlusStr, sqlparser.MinusStr: + if lScale > rScale { + scale = lScale + } else { + scale = rScale + } + prec = prec + scale + case sqlparser.MultStr: + scale = lScale+rScale + prec = prec + scale + case sqlparser.DivStr: + if lScale > rScale { + scale = lScale + } else { + scale = rScale + } + scale = scale + divPrecisionIncrement + prec = prec + divPrecisionIncrement + } + + if prec > types.DecimalTypeMaxPrecision { + prec = types.DecimalTypeMaxPrecision + } + if scale > types.DecimalTypeMaxScale { + scale = types.DecimalTypeMaxScale + } + + return types.MustCreateDecimalType(prec, scale) } + + return getFloatOrMaxDecimalType(a, false) } // CollationCoercibility implements the interface sql.CollationCoercible. @@ -215,10 +265,16 @@ func (a *Arithmetic) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { } // Decimals must be rounded - if res, ok := result.(decimal.Decimal); ok && isOutermostArithmeticOp(a, a.ops) { - finalScale, hasDiv := getFinalScale(a) - if hasDiv { - return res.Round(finalScale), nil + if res, ok := result.(decimal.Decimal); ok { + if isOutermostArithmeticOp(a, a.ops) { + finalScale, hasDiv := getFinalScale(a) + if hasDiv { + return res.Round(finalScale), nil + } + } + // In comparisons, we need to truncate decimals to have scale of 9 + if a.ops == -1 { + result = res.Truncate(9) } } @@ -319,6 +375,12 @@ func setArithmeticOps(e sql.Expression, opScale int32) { setArithmeticOps(a.Right(), opScale) } + if tup, ok := e.(Tuple); ok { + for _, expr := range tup { + setArithmeticOps(expr, opScale) + } + } + return } diff --git a/sql/expression/comparison.go b/sql/expression/comparison.go index 1023fc43f1..176e30e444 100644 --- a/sql/expression/comparison.go +++ b/sql/expression/comparison.go @@ -68,6 +68,11 @@ type comparison struct { } func newComparison(left, right sql.Expression) comparison { + // TODO: somewhat hacky way to disable rounding for comparisons + setArithmeticOps(left, -1) + setArithmeticOps(right, -1) + setDivs(left, -1) + setDivs(right, -1) return comparison{BinaryExpressionStub{left, right}} } diff --git a/sql/expression/div.go b/sql/expression/div.go index 1994f60f66..3adb1e7191 100644 --- a/sql/expression/div.go +++ b/sql/expression/div.go @@ -304,13 +304,52 @@ func (d *Div) determineResultType(outermostResult bool) sql.Type { return types.Float64 } - // For division operations, the result type is always either a float or decimal.Decimal. When working with - // integers, we prefer float types internally, since the performance is orders of magnitude faster to divide - // floats than to divide Decimals, but if this is the outermost division operation, we need to - // return a decimal in order to match MySQL's results exactly. - return floatOrDecimalTypeForDiv(d, !outermostResult) + if !outermostResult && types.IsNumber(lTyp) && types.IsNumber(rTyp) { + // TODO: does this mean we're never using float optimization? + //return types.Float64 + } + + // numerical outermost results from here on + + if types.IsFloat(lTyp) || types.IsFloat(rTyp) { + return types.Float64 + } + + if types.IsInteger(lTyp) { + // TODO: determine precision? + if d.ops == -1 { + return types.MustCreateDecimalType(types.DecimalTypeMaxPrecision, 9) + } + return types.MustCreateDecimalType(types.DecimalTypeMaxPrecision, divPrecisionIncrement) + } + + if types.IsDecimal(lTyp) { + // TODO: better verify this + lPrec, lScale := lTyp.(types.DecimalType_).Precision(), lTyp.(types.DecimalType_).Scale() + lPrec = lPrec + divPrecisionIncrement + lScale = lScale + divPrecisionIncrement + if d.ops == -1 { + lScale = (lScale / 9 + 1) * 9 + if lScale > types.DecimalTypeMaxScale { + lScale = types.DecimalTypeMaxScale + } + return types.MustCreateDecimalType(lPrec, lScale) + } + + if lPrec > types.DecimalTypeMaxPrecision { + lPrec = types.DecimalTypeMaxPrecision + } + if lScale > types.DecimalTypeMaxScale { + lScale = types.DecimalTypeMaxScale + } + return types.MustCreateDecimalType(lPrec, lScale) + } + + // TODO: missing cases + return lTyp } +// TODO: this is unused now, consider deleting // floatOrDecimalTypeForDiv returns either Float64 or Decimal type depending on column reference, // left and right expression types and left and right evaluated types. // If |treatIntsAsFloats| is true, then integers are treated as floats instead of Decimals. This @@ -328,24 +367,31 @@ func floatOrDecimalTypeForDiv(e sql.Expression, treatIntsAsFloats bool) sql.Type // if not float, it must be decimal type if treatIntsAsFloats { - return t + //return t } // for Div expression, if it's the outermostResult, then add the additional scales for the final result - p, s := t.(types.DecimalType_).Precision(), t.(types.DecimalType_).Scale() - maxWhole := p - s - maxFrac := s + prec, scale := t.(types.DecimalType_).Precision(), t.(types.DecimalType_).Scale() + maxWhole := prec - scale + maxScale := scale div := e.(*Div) finalScale := div.leftmostScale.Load() + div.divScale*int32(divPrecisionIncrement) - if finalScale > types.DecimalTypeMaxScale { - finalScale = types.DecimalTypeMaxScale - } else if uint8(finalScale) > maxFrac { - maxFrac = uint8(finalScale) + if uint8(finalScale) > maxScale { + maxScale = uint8(finalScale) + } + + if maxScale > types.DecimalTypeMaxScale { + maxScale = types.DecimalTypeMaxScale } - return types.MustCreateDecimalType(maxWhole+maxFrac, maxFrac) + prec = maxWhole + maxScale + if prec > types.DecimalTypeMaxPrecision { + prec = types.DecimalTypeMaxPrecision + } + + return types.MustCreateDecimalType(prec, maxScale) } // getFloatOrMaxDecimalType returns either Float64 or Decimal type with max precision and scale @@ -501,6 +547,12 @@ func setDivs(e sql.Expression, dScale int32) { setDivs(a.Right(), dScale) } + if tup, ok := e.(Tuple); ok { + for _, expr := range tup { + setDivs(expr, dScale) + } + } + return } diff --git a/sql/expression/div_test.go b/sql/expression/div_test.go index 7fb01ee5ab..669786b080 100644 --- a/sql/expression/div_test.go +++ b/sql/expression/div_test.go @@ -29,23 +29,23 @@ func TestDiv(t *testing.T) { var floatTestCases = []struct { name string left, right float64 - expected string + expected float64 null bool }{ - {"1 / 1", 1, 1, "1.0000", false}, - {"1 / 2", 1, 2, "0.5000", false}, - {"-1 / 1.0", -1, 1, "-1.0000", false}, - {"0 / 1234567890", 0, 12345677890, "0.0000", false}, - {"3.14159 / 3.0", 3.14159, 3.0, "1.047196667", false}, - {"1/0", 1, 0, "", true}, - {"-1/0", -1, 0, "", true}, - {"0/0", 0, 0, "", true}, + {"1 / 1", 1, 1, 1.0, false}, + {"1 / 2", 1, 2, 0.5, false}, + {"-1 / 1.0", -1, 1, -1.0, false}, + {"0 / 1234567890", 0, 12345677890, 0.0, false}, + {"3.14159 / 3.0", 3.14159, 3.0, 1.0471966666666666, false}, + {"1/0", 1, 0, 0.0, true}, + {"-1/0", -1, 0, 0.0, true}, + {"0/0", 0, 0, 0.0, true}, } for _, tt := range floatTestCases { t.Run(tt.name, func(t *testing.T) { + // The numbers are interpreted as Float64 without going through parser, so we lose precision here for 1.0 result, err := NewDiv( - // The numbers are interpreted as Float64 without going through parser, so we lose precision here for 1.0 NewLiteral(tt.left, types.Float64), NewLiteral(tt.right, types.Float64), ).Eval(sql.NewEmptyContext(), sql.NewRow()) @@ -53,9 +53,7 @@ func TestDiv(t *testing.T) { if tt.null { assert.Equal(t, nil, result) } else { - r, ok := result.(decimal.Decimal) - assert.True(t, ok) - assert.Equal(t, tt.expected, r.StringFixed(r.Exponent()*-1)) + assert.Equal(t, tt.expected, result) } }) } @@ -121,6 +119,7 @@ func TestDiv(t *testing.T) { // TestDivUsesFloatsInternally tests that division expression trees internally use floating point types when operating // on integers, but when returning the final result from the expression tree, it is returned as a Decimal. func TestDivUsesFloatsInternally(t *testing.T) { + t.Skip("maybe we don't want this") bottomDiv := NewDiv( NewGetField(0, types.Int32, "", false), NewGetField(1, types.Int64, "", false)) diff --git a/sql/expression/in.go b/sql/expression/in.go index cca8cb92c9..c337e031bb 100644 --- a/sql/expression/in.go +++ b/sql/expression/in.go @@ -54,6 +54,10 @@ func (in *InTuple) Right() sql.Expression { // NewInTuple creates an InTuple expression. func NewInTuple(left sql.Expression, right sql.Expression) *InTuple { + setArithmeticOps(left, -1) + setArithmeticOps(right, -1) + setDivs(left, -1) + setDivs(right, -1) return &InTuple{BinaryExpressionStub{left, right}} } @@ -61,12 +65,12 @@ func NewInTuple(left sql.Expression, right sql.Expression) *InTuple { func (in *InTuple) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { typ := in.Left().Type().Promote() leftElems := types.NumColumns(typ) - left, err := in.Left().Eval(ctx, row) + originalLeft, err := in.Left().Eval(ctx, row) if err != nil { return nil, err } - if left == nil { + if originalLeft == nil { return nil, nil } @@ -76,7 +80,7 @@ func (in *InTuple) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { // also if no match is found in the list and one of the expressions in the list is NULL. rightNull := false - left, _, err = typ.Convert(left) + left, _, err := typ.Convert(originalLeft) if err != nil { return nil, err } @@ -101,7 +105,8 @@ func (in *InTuple) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { } var cmp int - if types.IsDecimal(el.Type()) || types.IsFloat(el.Type()) { + elType := el.Type() + if types.IsDecimal(elType) || types.IsFloat(elType) { rtyp := el.Type().Promote() left, err := convertOrTruncate(ctx, left, rtyp) if err != nil { From 367930e7f63e693f32cbbf3327b8550da3641898 Mon Sep 17 00:00:00 2001 From: jycor Date: Fri, 9 Feb 2024 23:30:14 +0000 Subject: [PATCH 02/27] [ga-format-pr] Run ./format_repo.sh to fix formatting --- enginetest/memory_engine_test.go | 13 +++++-------- sql/expression/arithmetic.go | 2 +- sql/expression/div.go | 2 +- 3 files changed, 7 insertions(+), 10 deletions(-) diff --git a/enginetest/memory_engine_test.go b/enginetest/memory_engine_test.go index 5a9e1edae7..62e1983703 100644 --- a/enginetest/memory_engine_test.go +++ b/enginetest/memory_engine_test.go @@ -208,7 +208,7 @@ func TestSingleScript(t *testing.T) { Name: "delete me", SetUpScript: []string{ "CREATE TABLE tab0(col0 INTEGER, col1 INTEGER, col2 INTEGER);", - "INSERT INTO tab0 VALUES(97,1,99);", + "INSERT INTO tab0 VALUES(97,1,99);", }, Assertions: []queries.ScriptTestAssertion{ { @@ -256,17 +256,14 @@ func TestSingleScript(t *testing.T) { }, { - Query: "SELECT * FROM tab0 WHERE col2 IN ( 98 + 97 / 99 );", - Expected: []sql.Row{ - }, + Query: "SELECT * FROM tab0 WHERE col2 IN ( 98 + 97 / 99 );", + Expected: []sql.Row{}, }, { - Query: "SELECT ALL * FROM tab0 AS cor0 WHERE col2 IN ( 39 + + 89, col0 + + col1 + + ( - ( - col0 ) ) / col2, + ( col0 ) + - 99, + col1, + col2 * - + col2 * - 12 + col1 + - 66 );", - Expected: []sql.Row{ - }, + Query: "SELECT ALL * FROM tab0 AS cor0 WHERE col2 IN ( 39 + + 89, col0 + + col1 + + ( - ( - col0 ) ) / col2, + ( col0 ) + - 99, + col1, + col2 * - + col2 * - 12 + col1 + - 66 );", + Expected: []sql.Row{}, }, - { Query: `SELECT 1 IN (1 / 9 * 5);`, Expected: []sql.Row{{false}}, diff --git a/sql/expression/arithmetic.go b/sql/expression/arithmetic.go index 1f41836233..d2330933cb 100644 --- a/sql/expression/arithmetic.go +++ b/sql/expression/arithmetic.go @@ -189,7 +189,7 @@ func (a *Arithmetic) Type() sql.Type { } prec = prec + scale case sqlparser.MultStr: - scale = lScale+rScale + scale = lScale + rScale prec = prec + scale case sqlparser.DivStr: if lScale > rScale { diff --git a/sql/expression/div.go b/sql/expression/div.go index 3adb1e7191..00d4c5878c 100644 --- a/sql/expression/div.go +++ b/sql/expression/div.go @@ -329,7 +329,7 @@ func (d *Div) determineResultType(outermostResult bool) sql.Type { lPrec = lPrec + divPrecisionIncrement lScale = lScale + divPrecisionIncrement if d.ops == -1 { - lScale = (lScale / 9 + 1) * 9 + lScale = (lScale/9 + 1) * 9 if lScale > types.DecimalTypeMaxScale { lScale = types.DecimalTypeMaxScale } From 6e4b7a4186d97f464e9998116c27256007d20fa6 Mon Sep 17 00:00:00 2001 From: James Cor Date: Mon, 12 Feb 2024 10:29:39 -0800 Subject: [PATCH 03/27] guard max precision --- sql/expression/div.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/sql/expression/div.go b/sql/expression/div.go index 3adb1e7191..eaf84255fe 100644 --- a/sql/expression/div.go +++ b/sql/expression/div.go @@ -330,10 +330,6 @@ func (d *Div) determineResultType(outermostResult bool) sql.Type { lScale = lScale + divPrecisionIncrement if d.ops == -1 { lScale = (lScale / 9 + 1) * 9 - if lScale > types.DecimalTypeMaxScale { - lScale = types.DecimalTypeMaxScale - } - return types.MustCreateDecimalType(lPrec, lScale) } if lPrec > types.DecimalTypeMaxPrecision { From 562b30e8315e974f3841537444d34039a3b67676 Mon Sep 17 00:00:00 2001 From: James Cor Date: Mon, 12 Feb 2024 12:20:44 -0800 Subject: [PATCH 04/27] prevent scale > precision and wire fix --- sql/expression/div.go | 49 ++++++++++++++++++++++++------------------- 1 file changed, 28 insertions(+), 21 deletions(-) diff --git a/sql/expression/div.go b/sql/expression/div.go index e21d1d1949..72d66463a2 100644 --- a/sql/expression/div.go +++ b/sql/expression/div.go @@ -304,45 +304,52 @@ func (d *Div) determineResultType(outermostResult bool) sql.Type { return types.Float64 } - if !outermostResult && types.IsNumber(lTyp) && types.IsNumber(rTyp) { - // TODO: does this mean we're never using float optimization? - //return types.Float64 + if types.IsBinaryType(lTyp) || types.IsBinaryType(rTyp) { + return types.Float64 } - // numerical outermost results from here on - if types.IsFloat(lTyp) || types.IsFloat(rTyp) { return types.Float64 } - if types.IsInteger(lTyp) { - // TODO: determine precision? - if d.ops == -1 { - return types.MustCreateDecimalType(types.DecimalTypeMaxPrecision, 9) + // Decimal only results from here on + + if types.IsDatetimeType(lTyp) { + if dtType, ok := lTyp.(sql.DatetimeType); ok { + scale := uint8(dtType.Precision() + divPrecisionIncrement) + if scale > types.DecimalTypeMaxScale { + scale = types.DecimalTypeMaxScale + } + // TODO: determine actual precision + return types.MustCreateDecimalType(types.DecimalTypeMaxPrecision, scale) } - return types.MustCreateDecimalType(types.DecimalTypeMaxPrecision, divPrecisionIncrement) } if types.IsDecimal(lTyp) { // TODO: better verify this - lPrec, lScale := lTyp.(types.DecimalType_).Precision(), lTyp.(types.DecimalType_).Scale() - lPrec = lPrec + divPrecisionIncrement - lScale = lScale + divPrecisionIncrement + prec, scale := lTyp.(types.DecimalType_).Precision(), lTyp.(types.DecimalType_).Scale() + scale = scale + divPrecisionIncrement if d.ops == -1 { - lScale = (lScale/9 + 1) * 9 + scale = (scale/9 + 1) * 9 + prec = prec + scale + } else { + prec = prec + divPrecisionIncrement } - if lPrec > types.DecimalTypeMaxPrecision { - lPrec = types.DecimalTypeMaxPrecision + if prec > types.DecimalTypeMaxPrecision { + prec = types.DecimalTypeMaxPrecision } - if lScale > types.DecimalTypeMaxScale { - lScale = types.DecimalTypeMaxScale + if scale > types.DecimalTypeMaxScale { + scale = types.DecimalTypeMaxScale } - return types.MustCreateDecimalType(lPrec, lScale) + return types.MustCreateDecimalType(prec, scale) } - // TODO: missing cases - return lTyp + // All other types are treated as if they were integers + if d.ops == -1 { + return types.MustCreateDecimalType(types.DecimalTypeMaxPrecision, 9) + } + return types.MustCreateDecimalType(types.DecimalTypeMaxPrecision, divPrecisionIncrement) } // TODO: this is unused now, consider deleting From 6b9dcf7ad4043ea9f1f5b6b43d8f7df603028a58 Mon Sep 17 00:00:00 2001 From: James Cor Date: Mon, 12 Feb 2024 13:49:24 -0800 Subject: [PATCH 05/27] various type fixes --- enginetest/queries/procedure_queries.go | 8 +--- enginetest/queries/queries.go | 6 +++ enginetest/queries/tpch_plans.go | 20 ++++----- sql/expression/arithmetic.go | 54 ++++++++++++++++--------- sql/expression/div.go | 3 +- sql/expression/function/date.go | 4 +- sql/expression/function/str_to_date.go | 3 +- sql/expression/function/time.go | 4 ++ 8 files changed, 62 insertions(+), 40 deletions(-) diff --git a/enginetest/queries/procedure_queries.go b/enginetest/queries/procedure_queries.go index 1afb6f84ae..da6a430e1c 100644 --- a/enginetest/queries/procedure_queries.go +++ b/enginetest/queries/procedure_queries.go @@ -75,17 +75,13 @@ var ProcedureLogicTests = []ScriptTest{ { Query: "CALL testabc(2, 3)", Expected: []sql.Row{ - { - "6", - }, + {6.0}, }, }, { Query: "CALL testabc(9, 9.5)", Expected: []sql.Row{ - { - "85.5", - }, + {85.5}, }, }, }, diff --git a/enginetest/queries/queries.go b/enginetest/queries/queries.go index 7279afb458..7ff7a83e82 100644 --- a/enginetest/queries/queries.go +++ b/enginetest/queries/queries.go @@ -6228,6 +6228,7 @@ Select * from ( Expected: []sql.Row{{1}}, }, { + // TODO: this is invalid... Query: `SELECT DATETIME(NOW()) - NOW()`, Expected: []sql.Row{{int64(0)}}, }, @@ -6239,6 +6240,11 @@ Select * from ( Query: `SELECT STR_TO_DATE('01,5,2013 09:30:17','%d,%m,%Y %h:%i:%s') - (STR_TO_DATE('01,5,2013 09:30:17','%d,%m,%Y %h:%i:%s') - INTERVAL 1 SECOND)`, Expected: []sql.Row{{int64(1)}}, }, + // TODO: skip this test + //{ + // Query: `SELECT STR_TO_DATE('01,5,2013 09:30:17','%d,%m,%Y %h:%i:%s %f') - (STR_TO_DATE('01,5,2013 09:30:17','%d,%m,%Y %h:%i:%s') - INTERVAL 1 SECOND)`, + // Expected: []sql.Row{{int64(1)}}, + //}, { Query: `SELECT SUBSTR(SUBSTRING('0123456789ABCDEF', 1, 10), -4)`, Expected: []sql.Row{{"6789"}}, diff --git a/enginetest/queries/tpch_plans.go b/enginetest/queries/tpch_plans.go index e0b82c01ec..d71e0baebe 100644 --- a/enginetest/queries/tpch_plans.go +++ b/enginetest/queries/tpch_plans.go @@ -52,7 +52,7 @@ order by " └─ Filter\n" + " ├─ LessThanOrEqual\n" + " │ ├─ lineitem.l_shipdate:6!null\n" + - " │ └─ 1998-09-02 00:00:00 +0000 UTC (datetime(6))\n" + + " │ └─ 1998-09-02 00:00:00 +0000 UTC (datetime)\n" + " └─ ProcessTable\n" + " └─ Table\n" + " ├─ name: lineitem\n" + @@ -574,7 +574,7 @@ order by " │ │ └─ 1993-07-01 (longtext)\n" + " │ └─ LessThan\n" + " │ ├─ orders.o_orderdate:4!null\n" + - " │ └─ 1993-10-01 00:00:00 +0000 UTC (datetime(6))\n" + + " │ └─ 1993-10-01 00:00:00 +0000 UTC (datetime)\n" + " └─ IndexedTableAccess(orders)\n" + " ├─ index: [orders.O_ORDERKEY]\n" + " ├─ static: [{[NULL, ∞)}]\n" + @@ -702,7 +702,7 @@ order by " │ │ │ │ │ │ └─ 1994-01-01 (longtext)\n" + " │ │ │ │ │ └─ LessThan\n" + " │ │ │ │ │ ├─ orders.o_orderdate:2!null\n" + - " │ │ │ │ │ └─ 1995-01-01 00:00:00 +0000 UTC (datetime(6))\n" + + " │ │ │ │ │ └─ 1995-01-01 00:00:00 +0000 UTC (datetime)\n" + " │ │ │ │ └─ IndexedTableAccess(orders)\n" + " │ │ │ │ ├─ index: [orders.O_ORDERKEY]\n" + " │ │ │ │ ├─ static: [{[NULL, ∞)}]\n" + @@ -863,7 +863,7 @@ where " │ │ │ │ └─ 1994-01-01 (longtext)\n" + " │ │ │ └─ LessThan\n" + " │ │ │ ├─ lineitem.l_shipdate:3!null\n" + - " │ │ │ └─ 1995-01-01 00:00:00 +0000 UTC (datetime(6))\n" + + " │ │ │ └─ 1995-01-01 00:00:00 +0000 UTC (datetime)\n" + " │ │ └─ AND\n" + " │ │ ├─ GreaterThanOrEqual\n" + " │ │ │ ├─ lineitem.l_discount:2!null\n" + @@ -1738,7 +1738,7 @@ order by " │ │ │ │ └─ 1993-10-01 (longtext)\n" + " │ │ │ └─ LessThan\n" + " │ │ │ ├─ orders.o_orderdate:2!null\n" + - " │ │ │ └─ 1994-01-01 00:00:00 +0000 UTC (datetime(6))\n" + + " │ │ │ └─ 1994-01-01 00:00:00 +0000 UTC (datetime)\n" + " │ │ └─ IndexedTableAccess(orders)\n" + " │ │ ├─ index: [orders.O_ORDERKEY]\n" + " │ │ ├─ static: [{[NULL, ∞)}]\n" + @@ -2109,7 +2109,7 @@ order by " │ │ │ └─ 1994-01-01 (longtext)\n" + " │ │ └─ LessThan\n" + " │ │ ├─ lineitem.l_receiptdate:3!null\n" + - " │ │ └─ 1995-01-01 00:00:00 +0000 UTC (datetime(6))\n" + + " │ │ └─ 1995-01-01 00:00:00 +0000 UTC (datetime)\n" + " │ └─ IndexedTableAccess(lineitem)\n" + " │ ├─ index: [lineitem.L_ORDERKEY,lineitem.L_LINENUMBER]\n" + " │ ├─ static: [{[NULL, ∞), [NULL, ∞)}]\n" + @@ -2315,7 +2315,7 @@ where " │ │ │ └─ 1995-09-01 (longtext)\n" + " │ │ └─ LessThan\n" + " │ │ ├─ lineitem.l_shipdate:3!null\n" + - " │ │ └─ 1995-10-01 00:00:00 +0000 UTC (datetime(6))\n" + + " │ │ └─ 1995-10-01 00:00:00 +0000 UTC (datetime)\n" + " │ └─ ProcessTable\n" + " │ └─ Table\n" + " │ ├─ name: lineitem\n" + @@ -2430,7 +2430,7 @@ order by " │ │ │ └─ 1996-01-01 (longtext)\n" + " │ │ └─ LessThan\n" + " │ │ ├─ lineitem.l_shipdate:12!null\n" + - " │ │ └─ 1996-04-01 00:00:00 +0000 UTC (datetime(6))\n" + + " │ │ └─ 1996-04-01 00:00:00 +0000 UTC (datetime)\n" + " │ └─ Table\n" + " │ ├─ name: lineitem\n" + " │ ├─ columns: [l_suppkey l_extendedprice l_discount l_shipdate]\n" + @@ -2456,7 +2456,7 @@ order by " │ │ │ └─ 1996-01-01 (longtext)\n" + " │ │ └─ LessThan\n" + " │ │ ├─ lineitem.l_shipdate:3!null\n" + - " │ │ └─ 1996-04-01 00:00:00 +0000 UTC (datetime(6))\n" + + " │ │ └─ 1996-04-01 00:00:00 +0000 UTC (datetime)\n" + " │ └─ Table\n" + " │ ├─ name: lineitem\n" + " │ ├─ columns: [l_suppkey l_extendedprice l_discount l_shipdate]\n" + @@ -3311,7 +3311,7 @@ order by " │ │ │ │ │ └─ 1994-01-01 (longtext)\n" + " │ │ │ │ └─ LessThan\n" + " │ │ │ │ ├─ lineitem.l_shipdate:8!null\n" + - " │ │ │ │ └─ 1995-01-01 00:00:00 +0000 UTC (datetime(6))\n" + + " │ │ │ │ └─ 1995-01-01 00:00:00 +0000 UTC (datetime)\n" + " │ │ │ └─ Table\n" + " │ │ │ ├─ name: lineitem\n" + " │ │ │ ├─ columns: [l_partkey l_suppkey l_quantity l_shipdate]\n" + diff --git a/sql/expression/arithmetic.go b/sql/expression/arithmetic.go index d2330933cb..ef6100ab6a 100644 --- a/sql/expression/arithmetic.go +++ b/sql/expression/arithmetic.go @@ -134,31 +134,53 @@ func (a *Arithmetic) Type() sql.Type { // applies for + and - ops if isInterval(a.LeftChild) || isInterval(a.RightChild) { - // TODO: we might need to truncate precision here - return types.DatetimeMaxPrecision + // TODO: need to use the precision stored in datetimeType + // return MustCreateDatetimeType(sqltypes.Datetime, ...) + return types.Datetime } - if types.IsTime(lTyp) && types.IsTime(rTyp) { - return types.Int64 + if types.IsText(lTyp) || types.IsText(rTyp) { + return types.Float64 + } + + if types.IsJSON(lTyp) || types.IsJSON(rTyp) { + return types.Float64 } - if !types.IsNumber(lTyp) || !types.IsNumber(rTyp) { + if types.IsFloat(lTyp) || types.IsFloat(rTyp) { return types.Float64 } + // Datetimes are decimals, unless they have precision 0 + if types.IsDatetimeType(lTyp) { + if dtType, ok := lTyp.(sql.DatetimeType); ok { + scale := uint8(dtType.Precision()) + if scale == 0 { + lTyp = types.Int64 + } else { + lTyp = types.MustCreateDecimalType(types.DecimalTypeMaxPrecision, scale) + } + } + } + if types.IsDatetimeType(rTyp) { + if dtType, ok := rTyp.(sql.DatetimeType); ok { + scale := uint8(dtType.Precision()) + if scale == 0 { + rTyp = types.Int64 + } else { + rTyp = types.MustCreateDecimalType(types.DecimalTypeMaxPrecision, scale) + } + } + } + if types.IsUnsigned(lTyp) && types.IsUnsigned(rTyp) { return types.Uint64 - } else if types.IsSigned(lTyp) && types.IsSigned(rTyp) { - return types.Int64 } - // if one is uint and the other is int of any size, then use int64 if types.IsInteger(lTyp) && types.IsInteger(rTyp) { return types.Int64 } - // TODO: special cases for div, intdiv, and mod? - if types.IsDecimal(lTyp) && !types.IsDecimal(rTyp) { return lTyp } @@ -173,6 +195,7 @@ func (a *Arithmetic) Type() sql.Type { rPrec := rTyp.(types.DecimalType_).Precision() rScale := rTyp.(types.DecimalType_).Scale() + // TODO: determine real precision var prec, scale uint8 if lPrec > rPrec { prec = lPrec @@ -191,14 +214,6 @@ func (a *Arithmetic) Type() sql.Type { case sqlparser.MultStr: scale = lScale + rScale prec = prec + scale - case sqlparser.DivStr: - if lScale > rScale { - scale = lScale - } else { - scale = rScale - } - scale = scale + divPrecisionIncrement - prec = prec + divPrecisionIncrement } if prec > types.DecimalTypeMaxPrecision { @@ -211,7 +226,8 @@ func (a *Arithmetic) Type() sql.Type { return types.MustCreateDecimalType(prec, scale) } - return getFloatOrMaxDecimalType(a, false) + // When in doubt return float64 + return types.Float64 } // CollationCoercibility implements the interface sql.CollationCoercible. diff --git a/sql/expression/div.go b/sql/expression/div.go index 72d66463a2..0838015511 100644 --- a/sql/expression/div.go +++ b/sql/expression/div.go @@ -304,7 +304,7 @@ func (d *Div) determineResultType(outermostResult bool) sql.Type { return types.Float64 } - if types.IsBinaryType(lTyp) || types.IsBinaryType(rTyp) { + if types.IsJSON(lTyp) || types.IsJSON(rTyp) { return types.Float64 } @@ -326,7 +326,6 @@ func (d *Div) determineResultType(outermostResult bool) sql.Type { } if types.IsDecimal(lTyp) { - // TODO: better verify this prec, scale := lTyp.(types.DecimalType_).Precision(), lTyp.(types.DecimalType_).Scale() scale = scale + divPrecisionIncrement if d.ops == -1 { diff --git a/sql/expression/function/date.go b/sql/expression/function/date.go index ecf8788848..8498a21dad 100644 --- a/sql/expression/function/date.go +++ b/sql/expression/function/date.go @@ -280,7 +280,7 @@ func (t *TimestampConversion) String() string { } func (t *TimestampConversion) Type() sql.Type { - return types.TimestampMaxPrecision + return t.Date.Type() } // CollationCoercibility implements the interface sql.CollationCoercible. @@ -346,7 +346,7 @@ func (t *DatetimeConversion) String() string { } func (t *DatetimeConversion) Type() sql.Type { - return types.DatetimeMaxPrecision + return t.Date.Type() } // CollationCoercibility implements the interface sql.CollationCoercible. diff --git a/sql/expression/function/str_to_date.go b/sql/expression/function/str_to_date.go index bae5de3812..def6760c3e 100644 --- a/sql/expression/function/str_to_date.go +++ b/sql/expression/function/str_to_date.go @@ -47,7 +47,8 @@ func (s StrToDate) String() string { // Type returns the expression type. func (s StrToDate) Type() sql.Type { - return types.DatetimeMaxPrecision + // TODO: needs to take into account precision + return types.Datetime } // CollationCoercibility implements the interface sql.CollationCoercible. diff --git a/sql/expression/function/time.go b/sql/expression/function/time.go index f2d0e4bb3a..e2199396dd 100644 --- a/sql/expression/function/time.go +++ b/sql/expression/function/time.go @@ -948,6 +948,10 @@ func (n *Now) Description() string { // Type implements the sql.Expression interface. func (n *Now) Type() sql.Type { + // TODO: This should be types.NewDatetime(n.prec) + if n.prec == nil { + return types.Datetime + } return types.DatetimeMaxPrecision } From e7f396a6c4fc8348a9d0b343348aaad29f93d9fc Mon Sep 17 00:00:00 2001 From: James Cor Date: Mon, 12 Feb 2024 16:05:54 -0800 Subject: [PATCH 06/27] rewriting tests for add, sub, and mult --- sql/expression/arithmetic.go | 21 ++ sql/expression/arithmetic_test.go | 607 ++++++++++++++++++++++-------- sql/types/typecheck.go | 6 + 3 files changed, 487 insertions(+), 147 deletions(-) diff --git a/sql/expression/arithmetic.go b/sql/expression/arithmetic.go index ef6100ab6a..1377e5b0f7 100644 --- a/sql/expression/arithmetic.go +++ b/sql/expression/arithmetic.go @@ -151,6 +151,27 @@ func (a *Arithmetic) Type() sql.Type { return types.Float64 } + if types.IsYear(lTyp) && types.IsYear(rTyp) { + // MySQL just returns the largest int that fits + return types.Uint64 + } + + // Bit types are integers + if types.IsBit(lTyp) { + lTyp = types.Int64 + } + if types.IsBit(rTyp) { + rTyp = types.Int64 + } + + // Dates are Integers + if types.IsDateType(lTyp) { + lTyp = types.Int64 + } + if types.IsDateType(rTyp) { + rTyp = types.Int64 + } + // Datetimes are decimals, unless they have precision 0 if types.IsDatetimeType(lTyp) { if dtType, ok := lTyp.(sql.DatetimeType); ok { diff --git a/sql/expression/arithmetic_test.go b/sql/expression/arithmetic_test.go index df9e9b6936..8aa5f1ec03 100644 --- a/sql/expression/arithmetic_test.go +++ b/sql/expression/arithmetic_test.go @@ -15,155 +15,515 @@ package expression import ( + "fmt" "testing" "time" "github.com/shopspring/decimal" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "gopkg.in/src-d/go-errors.v1" "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/types" _ "github.com/dolthub/go-mysql-server/sql/variables" + "github.com/dolthub/vitess/go/sqltypes" ) func TestPlus(t *testing.T) { var testCases = []struct { - name string - left, right float64 - expected string + name string + left sql.Expression + right sql.Expression + exp interface{} + skip bool }{ - {"1 + 1", 1, 1, "2"}, - {"-1 + 1", -1, 1, "0"}, - {"0 + 0", 0, 0, "0"}, - {"0.14159 + 3.0", 0.14159, 3.0, "3.14159"}, + { + left: NewLiteral(1, types.Uint32), + right: NewLiteral(1, types.Uint32), + exp: uint64(2), + }, + { + left: NewLiteral(1, types.Uint64), + right: NewLiteral(1, types.Uint64), + exp: uint64(2), + }, + { + left: NewLiteral(1, types.Int32), + right: NewLiteral(1, types.Int32), + exp: int64(2), + }, + { + left: NewLiteral(1, types.Int64), + right: NewLiteral(1, types.Int64), + exp: int64(2), + }, + { + left: NewLiteral(0, types.Int64), + right: NewLiteral(0, types.Int64), + exp: int64(0), + }, + { + left: NewLiteral(-1, types.Int64), + right: NewLiteral(1, types.Int64), + exp: int64(0), + }, + { + left: NewLiteral(1, types.Float32), + right: NewLiteral(1, types.Float32), + exp: float64(2), + }, + { + left: NewLiteral(1, types.Float64), + right: NewLiteral(1, types.Float64), + exp: float64(2), + }, + { + left: NewLiteral(0.1459, types.Float64), + right: NewLiteral(3.0, types.Float64), + exp: 3.1459, + }, + { + left: NewLiteral(decimal.New(1, 0), types.MustCreateDecimalType(10, 0)), + right: NewLiteral(decimal.New(1, 0), types.MustCreateDecimalType(10, 0)), + exp: "2", + }, + { + left: NewLiteral(decimal.New(1000, -3), types.MustCreateDecimalType(10, 3)), // 1.000 + right: NewLiteral(decimal.New(1, 0), types.MustCreateDecimalType(10, 0)), + exp: "2.000", + }, + { + left: NewLiteral(decimal.New(1000, -3), types.MustCreateDecimalType(10, 3)), // 1.000 + right: NewLiteral(decimal.New(100000, -5), types.MustCreateDecimalType(10, 5)), // 1.00000 + exp: "2.00000", + }, + { + left: NewLiteral(decimal.New(1459, -4), types.MustCreateDecimalType(10, 4)), // 0.1459 + right: NewLiteral(decimal.New(3, 0), types.MustCreateDecimalType(10, 0)), // 3 + exp: "3.1459", + }, + { + left: NewLiteral(2001, types.Year), + right: NewLiteral(2002, types.Year), + exp: uint64(4003), + }, + { + left: NewLiteral("2001-01-01", types.Date), + right: NewLiteral("2001-01-01", types.Date), + exp: int64(40020202), + }, + { + skip: true, // need to trim just the date portion + left: NewLiteral("2001-01-01 12:00:00", types.Date), + right: NewLiteral("2001-01-01 12:00:00", types.Date), + exp: int64(40020202), + }, + { + skip: true, // need to trim just the date portion + left: NewLiteral("2001-01-01 12:00:00.123456", types.Date), + right: NewLiteral("2001-01-01 12:00:00.123456", types.Date), + exp: int64(40020202), + }, + { + left: NewLiteral("2001-01-01 12:00:00", types.Datetime), + right: NewLiteral("2001-01-01 12:00:00", types.Datetime), + exp: int64(40020202240000), + }, + { + skip: true, // need to trim just the datetime portion according to precision + left: NewLiteral("2001-01-01 12:00:00.123456", types.Datetime), + right: NewLiteral("2001-01-01 12:00:00.123456", types.Datetime), + exp: int64(40020202240000), + }, + { + skip: true, // need to trim just the datetime portion according to precision and use as exponent + left: NewLiteral("2001-01-01 12:00:00.123456", types.MustCreateDatetimeType(sqltypes.Datetime, 3)), + right: NewLiteral("2001-01-01 12:00:00.123456", types.MustCreateDatetimeType(sqltypes.Datetime, 3)), + exp: "40020202240000.246", + }, + { + skip: true, // need to use precision as exponent + left: NewLiteral("2001-01-01 12:00:00.123456", types.DatetimeMaxPrecision), + right: NewLiteral("2001-01-01 12:00:00.123456", types.DatetimeMaxPrecision), + exp: "40020202240000.246912", + }, + { + left: NewLiteral("1", types.Text), + right: NewLiteral("1", types.Text), + exp: float64(2), + }, + { + left: NewLiteral("1", types.Text), + right: NewLiteral(1.0, types.Float64), + exp: float64(2), + }, + { + left: NewLiteral(1, types.MustCreateBitType(1)), + right: NewLiteral(0, types.MustCreateBitType(1)), + exp: int64(1), + }, + { + left: NewLiteral("2018-05-01", types.LongText), + right: NewInterval(NewLiteral(int64(1), types.Int64), "DAY"), + exp: time.Date(2018, time.May, 2, 0, 0, 0, 0, time.UTC), + }, + { + left: NewInterval(NewLiteral(int64(1), types.Int64), "DAY"), + right: NewLiteral("2018-05-01", types.LongText), + exp: time.Date(2018, time.May, 2, 0, 0, 0, 0, time.UTC), + }, } for _, tt := range testCases { - t.Run(tt.name, func(t *testing.T) { + name := fmt.Sprintf("%s(%v)+%s(%v)", tt.left.Type(), tt.left, tt.right.Type(), tt.right) + t.Run(name, func(t *testing.T) { require := require.New(t) - result, err := NewPlus( - NewLiteral(tt.left, types.Float64), - NewLiteral(tt.right, types.Float64), - ).Eval(sql.NewEmptyContext(), sql.NewRow()) + if tt.skip { + t.Skip() + } + f := NewPlus(tt.left, tt.right) + result, err := f.Eval(sql.NewEmptyContext(), nil) require.NoError(err) - r, ok := result.(decimal.Decimal) - assert.True(t, ok) - assert.Equal(t, tt.expected, r.StringFixed(r.Exponent()*-1)) + if dec, ok := result.(decimal.Decimal); ok { + result = dec.StringFixed(dec.Exponent()*-1) + } + assert.Equal(t, tt.exp, result) }) } - - require := require.New(t) - result, err := NewPlus(NewLiteral("2", types.LongText), NewLiteral(3, types.Float64)). - Eval(sql.NewEmptyContext(), sql.NewRow()) - require.NoError(err) - require.Equal(5.0, result) -} - -func TestPlusInterval(t *testing.T) { - require := require.New(t) - - expected := time.Date(2018, time.May, 2, 0, 0, 0, 0, time.UTC) - op := NewPlus( - NewLiteral("2018-05-01", types.LongText), - NewInterval(NewLiteral(int64(1), types.Int64), "DAY"), - ) - - result, err := op.Eval(sql.NewEmptyContext(), nil) - require.NoError(err) - require.Equal(expected, result) - - op = NewPlus( - NewInterval(NewLiteral(int64(1), types.Int64), "DAY"), - NewLiteral("2018-05-01", types.LongText), - ) - - result, err = op.Eval(sql.NewEmptyContext(), nil) - require.NoError(err) - require.Equal(expected, result) } func TestMinus(t *testing.T) { var testCases = []struct { - name string - left, right float64 - expected string + name string + left sql.Expression + right sql.Expression + exp interface{} + skip bool }{ - {"1 - 1", 1, 1, "0"}, - {"1 - -1", 1, -1, "2"}, - {"0 - 0", 0, 0, "0"}, - {"3.14159 - 3.0", 3.14159, 3.0, "0.14159"}, + { + left: NewLiteral(1, types.Uint32), + right: NewLiteral(1, types.Uint32), + exp: uint64(0), + }, + { + left: NewLiteral(1, types.Uint64), + right: NewLiteral(1, types.Uint64), + exp: uint64(0), + }, + { + left: NewLiteral(1, types.Int32), + right: NewLiteral(1, types.Int32), + exp: int64(0), + }, + { + left: NewLiteral(1, types.Int64), + right: NewLiteral(1, types.Int64), + exp: int64(0), + }, + { + left: NewLiteral(0, types.Int64), + right: NewLiteral(0, types.Int64), + exp: int64(0), + }, + { + left: NewLiteral(-1, types.Int64), + right: NewLiteral(1, types.Int64), + exp: int64(-2), + }, + { + left: NewLiteral(1, types.Float32), + right: NewLiteral(1, types.Float32), + exp: float64(0), + }, + { + left: NewLiteral(1, types.Float64), + right: NewLiteral(1, types.Float64), + exp: float64(0), + }, + { + left: NewLiteral(0.1459, types.Float64), + right: NewLiteral(3.0, types.Float64), + exp: -2.8541, + }, + { + left: NewLiteral(decimal.New(1, 0), types.MustCreateDecimalType(10, 0)), + right: NewLiteral(decimal.New(1, 0), types.MustCreateDecimalType(10, 0)), + exp: "0", + }, + { + left: NewLiteral(decimal.New(1000, -3), types.MustCreateDecimalType(10, 3)), // 1.000 + right: NewLiteral(decimal.New(1, 0), types.MustCreateDecimalType(10, 0)), + exp: "0.000", + }, + { + left: NewLiteral(decimal.New(1000, -3), types.MustCreateDecimalType(10, 3)), // 1.000 + right: NewLiteral(decimal.New(100000, -5), types.MustCreateDecimalType(10, 5)), // 1.00000 + exp: "0.00000", + }, + { + left: NewLiteral(decimal.New(1459, -4), types.MustCreateDecimalType(10, 4)), // 0.1459 + right: NewLiteral(decimal.New(3, 0), types.MustCreateDecimalType(10, 0)), // 3 + exp: "-2.8541", + }, + { + left: NewLiteral(2002, types.Year), + right: NewLiteral(2001, types.Year), + exp: uint64(1), + }, + { + left: NewLiteral("2001-01-01", types.Date), + right: NewLiteral("2001-01-01", types.Date), + exp: int64(0), + }, + { + skip: true, // need to trim just the date portion + left: NewLiteral("2001-01-01 12:00:00", types.Date), + right: NewLiteral("2001-01-01 12:00:00", types.Date), + exp: int64(0), + }, + { + skip: true, // need to trim just the date portion + left: NewLiteral("2001-01-01 12:00:00.123456", types.Date), + right: NewLiteral("2001-01-01 12:00:00.123456", types.Date), + exp: int64(0), + }, + { + left: NewLiteral("2001-01-01 12:00:00", types.Datetime), + right: NewLiteral("2001-01-01 12:00:00", types.Datetime), + exp: int64(0), + }, + { + skip: true, // need to trim just the datetime portion according to precision + left: NewLiteral("2001-01-01 12:00:00.123456", types.Datetime), + right: NewLiteral("2001-01-01 12:00:00.123456", types.Datetime), + exp: int64(0), + }, + { + skip: true, // need to trim just the datetime portion according to precision and use as exponent + left: NewLiteral("2001-01-01 12:00:00.123456", types.MustCreateDatetimeType(sqltypes.Datetime, 3)), + right: NewLiteral("2001-01-01 12:00:00.123456", types.MustCreateDatetimeType(sqltypes.Datetime, 3)), + exp: "0.000", + }, + { + skip: true, // need to use precision as exponent + left: NewLiteral("2001-01-01 12:00:00.123456", types.DatetimeMaxPrecision), + right: NewLiteral("2001-01-01 12:00:00.123456", types.DatetimeMaxPrecision), + exp: "0.000000", + }, + { + left: NewLiteral("1", types.Text), + right: NewLiteral("1", types.Text), + exp: float64(0), + }, + { + left: NewLiteral("1", types.Text), + right: NewLiteral(1.0, types.Float64), + exp: float64(0), + }, + { + left: NewLiteral(1, types.MustCreateBitType(1)), + right: NewLiteral(0, types.MustCreateBitType(1)), + exp: int64(1), + }, + { + left: NewLiteral("2018-05-01", types.LongText), + right: NewInterval(NewLiteral(int64(1), types.Int64), "DAY"), + exp: time.Date(2018, time.April, 30, 0, 0, 0, 0, time.UTC), + }, } for _, tt := range testCases { - t.Run(tt.name, func(t *testing.T) { + name := fmt.Sprintf("%s(%v)-%s(%v)", tt.left.Type(), tt.left, tt.right.Type(), tt.right) + t.Run(name, func(t *testing.T) { require := require.New(t) - result, err := NewMinus( - NewLiteral(tt.left, types.Float64), - NewLiteral(tt.right, types.Float64), - ).Eval(sql.NewEmptyContext(), sql.NewRow()) + if tt.skip { + t.Skip() + } + f := NewMinus(tt.left, tt.right) + result, err := f.Eval(sql.NewEmptyContext(), nil) require.NoError(err) - r, ok := result.(decimal.Decimal) - assert.True(t, ok) - assert.Equal(t, tt.expected, r.StringFixed(r.Exponent()*-1)) + if dec, ok := result.(decimal.Decimal); ok { + result = dec.StringFixed(dec.Exponent()*-1) + } + assert.Equal(t, tt.exp, result) }) } - - require := require.New(t) - result, err := NewMinus(NewLiteral("10", types.LongText), NewLiteral(10, types.Int64)). - Eval(sql.NewEmptyContext(), sql.NewRow()) - require.NoError(err) - require.Equal(0.0, result) -} - -func TestMinusInterval(t *testing.T) { - require := require.New(t) - - expected := time.Date(2018, time.May, 1, 0, 0, 0, 0, time.UTC) - op := NewMinus( - NewLiteral("2018-05-02", types.LongText), - NewInterval(NewLiteral(int64(1), types.Int64), "DAY"), - ) - - result, err := op.Eval(sql.NewEmptyContext(), nil) - require.NoError(err) - require.Equal(expected, result) } func TestMult(t *testing.T) { var testCases = []struct { - name string - left, right float64 - expected string + name string + left sql.Expression + right sql.Expression + exp interface{} + err *errors.Kind + skip bool }{ - {"1 * 1", 1, 1, "1"}, - {"-1 * 1", -1, 1, "-1"}, - {"0 * 0", 0, 0, "0"}, - {"3.14159 * 3.0", 3.14159, 3.0, "9.42477"}, + { + left: NewLiteral(1, types.Uint32), + right: NewLiteral(1, types.Uint32), + exp: uint64(1), + }, + { + left: NewLiteral(1, types.Uint64), + right: NewLiteral(1, types.Uint64), + exp: uint64(1), + }, + { + left: NewLiteral(1, types.Int32), + right: NewLiteral(1, types.Int32), + exp: int64(1), + }, + { + left: NewLiteral(1, types.Int64), + right: NewLiteral(1, types.Int64), + exp: int64(1), + }, + { + left: NewLiteral(0, types.Int64), + right: NewLiteral(0, types.Int64), + exp: int64(0), + }, + { + left: NewLiteral(-1, types.Int64), + right: NewLiteral(1, types.Int64), + exp: int64(-1), + }, + { + left: NewLiteral(1, types.Float32), + right: NewLiteral(1, types.Float32), + exp: float64(1), + }, + { + left: NewLiteral(1, types.Float64), + right: NewLiteral(1, types.Float64), + exp: float64(1), + }, + { + left: NewLiteral(0.1459, types.Float64), + right: NewLiteral(3.0, types.Float64), + exp: 0.4377, + }, + { + left: NewLiteral(3.1459, types.Float64), + right: NewLiteral(3.0, types.Float64), + exp: 9.4377, + }, + { + left: NewLiteral(decimal.New(1, 0), types.MustCreateDecimalType(10, 0)), + right: NewLiteral(decimal.New(1, 0), types.MustCreateDecimalType(10, 0)), + exp: "1", + }, + { + left: NewLiteral(decimal.New(1000, -3), types.MustCreateDecimalType(10, 3)), // 1.000 + right: NewLiteral(decimal.New(1, 0), types.MustCreateDecimalType(10, 0)), + exp: "1.000", + }, + { + left: NewLiteral(decimal.New(1000, -3), types.MustCreateDecimalType(10, 3)), // 1.000 + right: NewLiteral(decimal.New(100000, -5), types.MustCreateDecimalType(10, 5)), // 1.00000 + exp: "1.00000000", + }, + { + left: NewLiteral(decimal.New(1459, -4), types.MustCreateDecimalType(10, 4)), // 0.1459 + right: NewLiteral(decimal.New(3, 0), types.MustCreateDecimalType(10, 0)), // 3 + exp: "0.4377", + }, + { + left: NewLiteral(decimal.New(31459, -4), types.MustCreateDecimalType(10, 4)), // 3.1459 + right: NewLiteral(decimal.New(3, 0), types.MustCreateDecimalType(10, 0)), // 3 + exp: "9.4377", + }, + { + left: NewLiteral(2002, types.Year), + right: NewLiteral(2001, types.Year), + exp: uint64(4006002), + }, + { + left: NewLiteral("2001-01-01", types.Date), + right: NewLiteral("2001-01-01", types.Date), + exp: int64(400404142030201), + }, + { + skip: true, // need to trim just the date portion + left: NewLiteral("2001-01-01 12:00:00", types.Date), + right: NewLiteral("2001-01-01 12:00:00", types.Date), + exp: int64(400404142030201), + }, + { + skip: true, // need to trim just the date portion + left: NewLiteral("2001-01-01 12:00:00.123456", types.Date), + right: NewLiteral("2001-01-01 12:00:00.123456", types.Date), + exp: int64(400404142030201), + }, + { + // MySQL throws out of range + skip: true, + left: NewLiteral("2001-01-01 12:00:00", types.Datetime), + right: NewLiteral("2001-01-01 12:00:00", types.Datetime), + err: sql.ErrValueOutOfRange, + }, + { + skip: true, // need to trim just the datetime portion according to precision + left: NewLiteral("2001-01-01 12:00:00.123456", types.Datetime), + right: NewLiteral("2001-01-01 12:00:00.123456", types.Datetime), + err: sql.ErrValueOutOfRange, + }, + { + skip: true, // need to trim just the datetime portion according to precision and use as exponent + left: NewLiteral("2001-01-01 12:00:00.123456", types.MustCreateDatetimeType(sqltypes.Datetime, 3)), + right: NewLiteral("2001-01-01 12:00:00.123456", types.MustCreateDatetimeType(sqltypes.Datetime, 3)), + exp: "400404146832630176884875520.015129", + }, + { + skip: true, // need to use precision as exponent + left: NewLiteral("2001-01-01 12:00:00.123456", types.DatetimeMaxPrecision), + right: NewLiteral("2001-01-01 12:00:00.123456", types.DatetimeMaxPrecision), + exp: "400404146832630195134087741.455241383936", + }, + { + left: NewLiteral("10", types.Text), + right: NewLiteral("10", types.Text), + exp: float64(100), + }, + { + left: NewLiteral("10", types.Text), + right: NewLiteral(10.0, types.Float64), + exp: float64(100), + }, + { + left: NewLiteral(1, types.MustCreateBitType(1)), + right: NewLiteral(0, types.MustCreateBitType(1)), + exp: int64(0), + }, } for _, tt := range testCases { - t.Run(tt.name, func(t *testing.T) { + name := fmt.Sprintf("%s(%v)*%s(%v)", tt.left.Type(), tt.left, tt.right.Type(), tt.right) + t.Run(name, func(t *testing.T) { require := require.New(t) - result, err := NewMult( - NewLiteral(tt.left, types.Float64), - NewLiteral(tt.right, types.Float64), - ).Eval(sql.NewEmptyContext(), sql.NewRow()) + if tt.skip { + t.Skip() + } + f := NewMult(tt.left, tt.right) + result, err := f.Eval(sql.NewEmptyContext(), nil) + if tt.err != nil { + require.Error(err) + require.True(tt.err.Is(err), err.Error()) + return + } require.NoError(err) - r, ok := result.(decimal.Decimal) - assert.True(t, ok) - assert.Equal(t, tt.expected, r.StringFixed(r.Exponent()*-1)) + if dec, ok := result.(decimal.Decimal); ok { + result = dec.StringFixed(dec.Exponent()*-1) + } + assert.Equal(t, tt.exp, result) }) } - - require := require.New(t) - result, err := NewMult(NewLiteral("10", types.LongText), NewLiteral("10", types.LongText)). - Eval(sql.NewEmptyContext(), sql.NewRow()) - require.NoError(err) - require.Equal(100.0, result) } func TestMod(t *testing.T) { + // TODO: make this match the others var testCases = []struct { name string left, right int64 @@ -201,53 +561,6 @@ func TestMod(t *testing.T) { } } -func TestAllFloat64(t *testing.T) { - var testCases = []struct { - op string - value float64 - expected string - }{ - // The value here are given with decimal place to force the value type to float, but the interpreted values - // will not have 0 scale, so the mult is 3.0000 * 0 = 0.0000 instead of 3.0000 * 0.0 = 0.00000 - {"+", 1.0, "1"}, - {"-", -8.0, "9"}, - {"/", 3.0, "3.0000"}, - {"*", 4.0, "12.0000"}, - {"%", 11, "1.0000"}, - } - - // ((((0 + 1) - (-8)) / 3) * 4) % 11 == 1 - lval := NewLiteral(float64(0.0), types.Float64) - for _, tt := range testCases { - t.Run(tt.op, func(t *testing.T) { - require := require.New(t) - var result interface{} - var err error - if tt.op == "/" { - result, err = NewDiv(lval, - NewLiteral(tt.value, types.Float64), - ).Eval(sql.NewEmptyContext(), sql.NewRow()) - } else if tt.op == "%" { - result, err = NewMod(lval, - NewLiteral(tt.value, types.Float64), - ).Eval(sql.NewEmptyContext(), sql.NewRow()) - } else { - result, err = NewArithmetic(lval, - NewLiteral(tt.value, types.Float64), tt.op, - ).Eval(sql.NewEmptyContext(), sql.NewRow()) - } - require.NoError(err) - if r, ok := result.(decimal.Decimal); ok { - assert.Equal(t, tt.expected, r.StringFixed(r.Exponent()*-1)) - } else { - assert.Equal(t, tt.expected, result) - } - - lval = NewLiteral(result, types.Float64) - }) - } -} - func TestUnaryMinus(t *testing.T) { testCases := []struct { name string diff --git a/sql/types/typecheck.go b/sql/types/typecheck.go index 298d5d4d2d..6cef518a9a 100644 --- a/sql/types/typecheck.go +++ b/sql/types/typecheck.go @@ -192,3 +192,9 @@ func IsUnsigned(t sql.Type) bool { return t == Uint8 || t == Uint16 || t == Uint24 || t == Uint32 || t == Uint64 } + +// IsYear checks if t is a year type. +func IsYear(t sql.Type) bool { + _, ok := t.(YearType_) + return ok +} From 847f4f01a710c439658c32866a1ea6a0b922f199 Mon Sep 17 00:00:00 2001 From: James Cor Date: Mon, 12 Feb 2024 16:56:31 -0800 Subject: [PATCH 07/27] rewriting div tests --- sql/expression/div.go | 5 + sql/expression/div_test.go | 373 ++++++++++++++++++++++++++++--------- 2 files changed, 290 insertions(+), 88 deletions(-) diff --git a/sql/expression/div.go b/sql/expression/div.go index 0838015511..be16f3e072 100644 --- a/sql/expression/div.go +++ b/sql/expression/div.go @@ -312,6 +312,11 @@ func (d *Div) determineResultType(outermostResult bool) sql.Type { return types.Float64 } + // TODO: see if we can actually do this + //if !outermostResult { + // return types.Float64 + //} + // Decimal only results from here on if types.IsDatetimeType(lTyp) { diff --git a/sql/expression/div_test.go b/sql/expression/div_test.go index 669786b080..44c49ebb64 100644 --- a/sql/expression/div_test.go +++ b/sql/expression/div_test.go @@ -15,6 +15,9 @@ package expression import ( + "fmt" + "github.com/dolthub/vitess/go/sqltypes" +"gopkg.in/src-d/go-errors.v1" "testing" "github.com/shopspring/decimal" @@ -26,92 +29,290 @@ import ( ) func TestDiv(t *testing.T) { - var floatTestCases = []struct { - name string - left, right float64 - expected float64 - null bool + var testCases = []struct { + name string + left sql.Expression + right sql.Expression + exp interface{} + err *errors.Kind + skip bool }{ - {"1 / 1", 1, 1, 1.0, false}, - {"1 / 2", 1, 2, 0.5, false}, - {"-1 / 1.0", -1, 1, -1.0, false}, - {"0 / 1234567890", 0, 12345677890, 0.0, false}, - {"3.14159 / 3.0", 3.14159, 3.0, 1.0471966666666666, false}, - {"1/0", 1, 0, 0.0, true}, - {"-1/0", -1, 0, 0.0, true}, - {"0/0", 0, 0, 0.0, true}, - } + { + left: NewLiteral(1, types.Int64), + right: NewLiteral(0, types.Int64), + exp: nil, + }, - for _, tt := range floatTestCases { - t.Run(tt.name, func(t *testing.T) { - // The numbers are interpreted as Float64 without going through parser, so we lose precision here for 1.0 - result, err := NewDiv( - NewLiteral(tt.left, types.Float64), - NewLiteral(tt.right, types.Float64), - ).Eval(sql.NewEmptyContext(), sql.NewRow()) - require.NoError(t, err) - if tt.null { - assert.Equal(t, nil, result) - } else { - assert.Equal(t, tt.expected, result) - } - }) - } + // Unsigned Integers + { + left: NewLiteral(1, types.Uint32), + right: NewLiteral(1, types.Uint32), + exp: "1.0000", + }, + { + left: NewLiteral(1, types.Uint32), + right: NewLiteral(2, types.Uint32), + exp: "0.5000", + }, + { + left: NewLiteral(1, types.Uint64), + right: NewLiteral(1, types.Uint64), + exp: "1.0000", + }, + { + left: NewLiteral(1, types.Uint64), + right: NewLiteral(2, types.Uint64), + exp: "0.5000", + }, - var intTestCases = []struct { - name string - left, right int64 - expected string - null bool - }{ - {"1 / 1", 1, 1, "1.0000", false}, - {"-1 / 1", -1, 1, "-1.0000", false}, - {"0 / 1234567890", 0, 12345677890, "0.0000", false}, - {"1/0", 1, 0, "", true}, - {"0/0", 1, 0, "", true}, - } - for _, tt := range intTestCases { - t.Run(tt.name, func(t *testing.T) { - result, err := NewDiv( - NewLiteral(tt.left, types.Int64), - NewLiteral(tt.right, types.Int64), - ).Eval(sql.NewEmptyContext(), sql.NewRow()) - require.NoError(t, err) - if tt.null { - assert.Equal(t, nil, result) - } else { - r, ok := result.(decimal.Decimal) - assert.True(t, ok) - assert.Equal(t, tt.expected, r.StringFixed(r.Exponent()*-1)) - } - }) - } + // Signed Integers + { + left: NewLiteral(1, types.Int32), + right: NewLiteral(1, types.Int32), + exp: "1.0000", + }, + { + left: NewLiteral(1, types.Int32), + right: NewLiteral(2, types.Int32), + exp: "0.5000", + }, + { + left: NewLiteral(-1, types.Int32), + right: NewLiteral(2, types.Int32), + exp: "-0.5000", + }, + { + left: NewLiteral(1, types.Int32), + right: NewLiteral(-2, types.Int32), + exp: "-0.5000", + }, + { + left: NewLiteral(1, types.Int64), + right: NewLiteral(1, types.Int64), + exp: "1.0000", + }, + { + left: NewLiteral(1, types.Int64), + right: NewLiteral(2, types.Int64), + exp: "0.5000", + }, + { + left: NewLiteral(-1, types.Int64), + right: NewLiteral(2, types.Int64), + exp: "-0.5000", + }, + { + left: NewLiteral(1, types.Int64), + right: NewLiteral(-2, types.Int64), + exp: "-0.5000", + }, - var uintTestCases = []struct { - name string - left, right uint64 - expected string - null bool - }{ - {"1 / 1", 1, 1, "1.0000", false}, - {"0 / 1234567890", 0, 12345677890, "0.0000", false}, - {"1/0", 1, 0, "", true}, - {"0/0", 1, 0, "", true}, + // Unsigned and Signed Integers + { + left: NewLiteral(1, types.Uint32), + right: NewLiteral(-2, types.Int32), + exp: "-0.5000", + }, + { + left: NewLiteral(-1, types.Int64), + right: NewLiteral(2, types.Uint32), + exp: "-0.5000", + }, + { + left: NewLiteral(1, types.Int64), + right: NewLiteral(123456789, types.Int64), + exp: "0.0000", + }, + + // Repeating Decimals + { + left: NewLiteral(1, types.Int64), + right: NewLiteral(3, types.Int64), + exp: "0.3333", + }, + { + left: NewLiteral(1, types.Int64), + right: NewLiteral(9, types.Int64), + exp: "0.1111", + }, + { + left: NewLiteral(1, types.Int64), + right: NewLiteral(6, types.Int64), + exp: "0.1667", + }, + + // Floats + { + left: NewLiteral(1.0, types.Float32), + right: NewLiteral(3.0, types.Float32), + exp: 0.3333333333333333, + }, + { + left: NewLiteral(1.0, types.Float32), + right: NewLiteral(9.0, types.Float32), + exp: 0.1111111111111111, + }, + { + left: NewLiteral(1.0, types.Float64), + right: NewLiteral(3.0, types.Float64), + exp: 0.3333333333333333, + }, + { + left: NewLiteral(1.0, types.Float64), + right: NewLiteral(9.0, types.Float64), + exp: 0.1111111111111111, + }, + { + // MySQL treats float32 a little differently + skip: true, + left: NewLiteral(3.14159, types.Float32), + right: NewLiteral(3.0, types.Float32), + exp: 1.0471967061360676, + }, + { + left: NewLiteral(3.14159, types.Float64), + right: NewLiteral(3.0, types.Float64), + exp: 1.0471966666666666, + }, + + // Decimals + { + left: NewLiteral(decimal.New(1, 0), types.MustCreateDecimalType(10, 0)), + right: NewLiteral(decimal.New(3, 0), types.MustCreateDecimalType(10, 0)), + exp: "0.3333", + }, + { + left: NewLiteral(decimal.New(1000, -3), types.MustCreateDecimalType(10, 3)), + right: NewLiteral(decimal.New(3, 0), types.MustCreateDecimalType(10, 0)), + exp: "0.3333333", + }, + { + left: NewLiteral(decimal.New(1, 0), types.MustCreateDecimalType(10, 0)), + right: NewLiteral(decimal.New(3000, -3), types.MustCreateDecimalType(10, 3)), + exp: "0.3333", + }, + { + left: NewLiteral(decimal.New(314159, -5), types.MustCreateDecimalType(10, 5)), + right: NewLiteral(decimal.New(3, 0), types.MustCreateDecimalType(10, 0)), + exp: "1.047196667", + }, + { + left: NewLiteral(decimal.NewFromFloat(3.14159), types.MustCreateDecimalType(10, 5)), + right: NewLiteral(3, types.Int64), + exp: "1.047196667", + }, + + // Bit + { + left: NewLiteral(0, types.MustCreateBitType(1)), + right: NewLiteral(1, types.MustCreateBitType(1)), + exp: "0.0000", + }, + { + left: NewLiteral(1, types.MustCreateBitType(1)), + right: NewLiteral(1, types.MustCreateBitType(1)), + exp: "1.0000", + }, + + // Year + { + left: NewLiteral(2001, types.YearType_{}), + right: NewLiteral(2002, types.YearType_{}), + exp: "0.9995", + }, + + // Time + { + left: NewLiteral("2001-01-01", types.Date), + right: NewLiteral("2001-01-01", types.Date), + exp: "1.0000", + }, + { + left: NewLiteral("2001-01-01 12:00:00", types.Date), + right: NewLiteral("2001-01-01 12:00:00", types.Date), + exp: "1.0000", + }, + { + skip: true, // need to trim just the date portion + left: NewLiteral("2001-01-01 12:00:00.123456", types.Date), + right: NewLiteral("2001-01-01 12:00:00.123456", types.Date), + exp: "1.0000", + }, + { + left: NewLiteral("2001-01-01 12:00:00", types.Datetime), + right: NewLiteral("2001-01-01 12:00:00", types.Datetime), + exp: "1.0000", + }, + { + skip: true, // need to trim just the datetime portion according to precision and use as exponent + left: NewLiteral("2001-01-01 12:00:00.123456", types.Datetime), + right: NewLiteral("2001-01-01 12:00:00.123456", types.Datetime), + exp: "1.0000", + }, + { + skip: true, // need to trim just the datetime portion according to precision and use as exponent + left: NewLiteral("2001-01-01 12:00:00.123456", types.MustCreateDatetimeType(sqltypes.Datetime, 3)), + right: NewLiteral("2001-01-01 12:00:00.123456", types.MustCreateDatetimeType(sqltypes.Datetime, 3)), + exp: "1.0000000", + }, + { + left: NewLiteral("2001-01-01 12:00:00.123456", types.DatetimeMaxPrecision), + right: NewLiteral("2001-01-01 12:00:00.123456", types.DatetimeMaxPrecision), + exp: "1.0000000000", + }, + + // Text + { + left: NewLiteral("1", types.Text), + right: NewLiteral("3", types.Text), + exp: 0.3333333333333333, + }, + { + left: NewLiteral("1.000", types.Text), + right: NewLiteral("3", types.Text), + exp: 0.3333333333333333, + }, + { + left: NewLiteral("1", types.Text), + right: NewLiteral("3.000", types.Text), + exp: 0.3333333333333333, + }, + { + left: NewLiteral("3.14159", types.Text), + right: NewLiteral("3", types.Text), + exp: 1.0471966666666666, + }, + { + left: NewLiteral("1", types.Text), + right: NewLiteral(decimal.New(3, 0), types.MustCreateDecimalType(10, 0)), + exp: 0.3333333333333333, + }, + { + left: NewLiteral(decimal.New(1, 0), types.MustCreateDecimalType(10, 0)), + right: NewLiteral("3", types.Text), + exp: 0.3333333333333333, + }, } - for _, tt := range uintTestCases { - t.Run(tt.name, func(t *testing.T) { - result, err := NewDiv( - NewLiteral(tt.left, types.Uint64), - NewLiteral(tt.right, types.Uint64), - ).Eval(sql.NewEmptyContext(), sql.NewRow()) - require.NoError(t, err) - if tt.null { - assert.Equal(t, nil, result) - } else { - r, ok := result.(decimal.Decimal) - assert.True(t, ok) - assert.Equal(t, tt.expected, r.StringFixed(r.Exponent()*-1)) + + for _, tt := range testCases { + name := fmt.Sprintf("%s(%v)/%s(%v)", tt.left.Type(), tt.left, tt.right.Type(), tt.right) + t.Run(name, func(t *testing.T) { + require := require.New(t) + if tt.skip { + t.Skip() + } + f := NewDiv(tt.left, tt.right) + result, err := f.Eval(sql.NewEmptyContext(), nil) + if tt.err != nil { + require.Error(err) + require.True(tt.err.Is(err), err.Error()) + return + } + require.NoError(err) + if dec, ok := result.(decimal.Decimal); ok { + result = dec.StringFixed(dec.Exponent()*-1) } + assert.Equal(t, tt.exp, result) }) } } @@ -119,14 +320,10 @@ func TestDiv(t *testing.T) { // TestDivUsesFloatsInternally tests that division expression trees internally use floating point types when operating // on integers, but when returning the final result from the expression tree, it is returned as a Decimal. func TestDivUsesFloatsInternally(t *testing.T) { - t.Skip("maybe we don't want this") - bottomDiv := NewDiv( - NewGetField(0, types.Int32, "", false), - NewGetField(1, types.Int64, "", false)) - middleDiv := NewDiv(bottomDiv, - NewGetField(2, types.Int64, "", false)) - topDiv := NewDiv(middleDiv, - NewGetField(3, types.Int64, "", false)) + t.Skip("TODO: see if we can actually enable this") + bottomDiv := NewDiv(NewGetField(0, types.Int32, "", false), NewGetField(1, types.Int64, "", false)) + middleDiv := NewDiv(bottomDiv, NewGetField(2, types.Int64, "", false)) + topDiv := NewDiv(middleDiv, NewGetField(3, types.Int64, "", false)) result, err := topDiv.Eval(sql.NewEmptyContext(), sql.NewRow(250, 2, 5, 2)) require.NoError(t, err) From 8bbaa4465f633ead7701a8e03a3180344753acc4 Mon Sep 17 00:00:00 2001 From: James Cor Date: Mon, 12 Feb 2024 17:13:13 -0800 Subject: [PATCH 08/27] more tests --- enginetest/queries/queries.go | 52 +++++++++++++++++++++++++++- enginetest/queries/script_queries.go | 6 ++++ 2 files changed, 57 insertions(+), 1 deletion(-) diff --git a/enginetest/queries/queries.go b/enginetest/queries/queries.go index 7ff7a83e82..0d3fc3a223 100644 --- a/enginetest/queries/queries.go +++ b/enginetest/queries/queries.go @@ -782,7 +782,8 @@ var QueryTests = []QueryTest{ { Query: "select 1 as x from xy having AVG(x) > 0", Expected: []sql.Row{{1}}, - }, { + }, + { Query: "select 1 as x, AVG(x) from xy group by (y) having AVG(x) > 0", Expected: []sql.Row{{1, float64(1)}, {1, float64(2)}, {1, float64(3)}}, }, @@ -2550,6 +2551,31 @@ Select * from ( Query: "SELECT i + 1 FROM mytable;", Expected: []sql.Row{{int64(2)}, {int64(3)}, {int64(4)}}, }, + { + Query: "select 1 / 3 * 3;", + Expected: []sql.Row{ + {"1.0000"}, + }, + }, + { + Query: "select 1 / 3 * 3 = 0.999999999;", + Expected: []sql.Row{ + {true}, + }, + }, + { + Query: "select 1.00000 / 3 * 3 = 0.999999999;", + Expected: []sql.Row{ + {true}, + }, + }, + // TODO: fix this + //{ + // Query: "select 1.000000 / 3 * 3 = 0.999999999999999999;", + // Expected: []sql.Row{ + // {true}, + // }, + //}, { Query: "SELECT i div 2 FROM mytable order by 1;", Expected: []sql.Row{{int64(0)}, {int64(1)}, {int64(1)}}, @@ -2734,6 +2760,30 @@ Select * from ( Query: "SELECT 'HOMER' IN (1.0)", Expected: []sql.Row{{false}}, }, + { + Query: "select 1 / 3 * 3 in (0.999999999);", + Expected: []sql.Row{{true}}, + }, + { + Query: "SELECT 99 NOT IN ( 98 + 97 / 99 );", + Expected: []sql.Row{{true}}, + }, + { + Query: "SELECT 1 NOT IN ( 97 / 99 );", + Expected: []sql.Row{{true}}, + }, + { + Query: `SELECT 1 NOT IN (1 / 9 * 5);`, + Expected: []sql.Row{{true}}, + }, + { + Query: `SELECT 1 / 9 * 5 NOT IN (1);`, + Expected: []sql.Row{{true}}, + }, + { + Query: `SELECT 1 / 9 * 5 IN (1 / 9 * 5);`, + Expected: []sql.Row{{true}}, + }, { Query: `SELECT * FROM mytable WHERE i in (CAST(NULL AS SIGNED), 2, 3, 4)`, Expected: []sql.Row{{3, "third row"}, {2, "second row"}}, diff --git a/enginetest/queries/script_queries.go b/enginetest/queries/script_queries.go index b485e3aa8f..63569c152a 100644 --- a/enginetest/queries/script_queries.go +++ b/enginetest/queries/script_queries.go @@ -5401,6 +5401,12 @@ CREATE TABLE tab3 ( }, }, }, + //{ + // Name: "dividing has different rounding behavior", + // SetUpScript: []string{ + // + // }, + //}, } var SpatialScriptTests = []ScriptTest{ From b0a5aa5b80beb983382fdeead359ed722b92b53e Mon Sep 17 00:00:00 2001 From: jycor Date: Tue, 13 Feb 2024 01:15:25 +0000 Subject: [PATCH 09/27] [ga-format-pr] Run ./format_repo.sh to fix formatting --- enginetest/queries/queries.go | 6 +++--- sql/expression/arithmetic_test.go | 10 +++++----- sql/expression/div_test.go | 8 ++++---- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/enginetest/queries/queries.go b/enginetest/queries/queries.go index 0d3fc3a223..c68b84083e 100644 --- a/enginetest/queries/queries.go +++ b/enginetest/queries/queries.go @@ -2761,15 +2761,15 @@ Select * from ( Expected: []sql.Row{{false}}, }, { - Query: "select 1 / 3 * 3 in (0.999999999);", + Query: "select 1 / 3 * 3 in (0.999999999);", Expected: []sql.Row{{true}}, }, { - Query: "SELECT 99 NOT IN ( 98 + 97 / 99 );", + Query: "SELECT 99 NOT IN ( 98 + 97 / 99 );", Expected: []sql.Row{{true}}, }, { - Query: "SELECT 1 NOT IN ( 97 / 99 );", + Query: "SELECT 1 NOT IN ( 97 / 99 );", Expected: []sql.Row{{true}}, }, { diff --git a/sql/expression/arithmetic_test.go b/sql/expression/arithmetic_test.go index 8aa5f1ec03..cdf5776a36 100644 --- a/sql/expression/arithmetic_test.go +++ b/sql/expression/arithmetic_test.go @@ -19,6 +19,7 @@ import ( "testing" "time" + "github.com/dolthub/vitess/go/sqltypes" "github.com/shopspring/decimal" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -27,7 +28,6 @@ import ( "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/types" _ "github.com/dolthub/go-mysql-server/sql/variables" - "github.com/dolthub/vitess/go/sqltypes" ) func TestPlus(t *testing.T) { @@ -186,7 +186,7 @@ func TestPlus(t *testing.T) { result, err := f.Eval(sql.NewEmptyContext(), nil) require.NoError(err) if dec, ok := result.(decimal.Decimal); ok { - result = dec.StringFixed(dec.Exponent()*-1) + result = dec.StringFixed(dec.Exponent() * -1) } assert.Equal(t, tt.exp, result) }) @@ -344,7 +344,7 @@ func TestMinus(t *testing.T) { result, err := f.Eval(sql.NewEmptyContext(), nil) require.NoError(err) if dec, ok := result.(decimal.Decimal); ok { - result = dec.StringFixed(dec.Exponent()*-1) + result = dec.StringFixed(dec.Exponent() * -1) } assert.Equal(t, tt.exp, result) }) @@ -459,7 +459,7 @@ func TestMult(t *testing.T) { }, { // MySQL throws out of range - skip: true, + skip: true, left: NewLiteral("2001-01-01 12:00:00", types.Datetime), right: NewLiteral("2001-01-01 12:00:00", types.Datetime), err: sql.ErrValueOutOfRange, @@ -515,7 +515,7 @@ func TestMult(t *testing.T) { } require.NoError(err) if dec, ok := result.(decimal.Decimal); ok { - result = dec.StringFixed(dec.Exponent()*-1) + result = dec.StringFixed(dec.Exponent() * -1) } assert.Equal(t, tt.exp, result) }) diff --git a/sql/expression/div_test.go b/sql/expression/div_test.go index 44c49ebb64..6a806cd213 100644 --- a/sql/expression/div_test.go +++ b/sql/expression/div_test.go @@ -16,13 +16,13 @@ package expression import ( "fmt" - "github.com/dolthub/vitess/go/sqltypes" -"gopkg.in/src-d/go-errors.v1" "testing" + "github.com/dolthub/vitess/go/sqltypes" "github.com/shopspring/decimal" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "gopkg.in/src-d/go-errors.v1" "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/types" @@ -164,7 +164,7 @@ func TestDiv(t *testing.T) { }, { // MySQL treats float32 a little differently - skip: true, + skip: true, left: NewLiteral(3.14159, types.Float32), right: NewLiteral(3.0, types.Float32), exp: 1.0471967061360676, @@ -310,7 +310,7 @@ func TestDiv(t *testing.T) { } require.NoError(err) if dec, ok := result.(decimal.Decimal); ok { - result = dec.StringFixed(dec.Exponent()*-1) + result = dec.StringFixed(dec.Exponent() * -1) } assert.Equal(t, tt.exp, result) }) From a98c9cb118d17cd6cf859f957d248dbc65239f09 Mon Sep 17 00:00:00 2001 From: James Cor Date: Tue, 13 Feb 2024 15:06:28 -0800 Subject: [PATCH 10/27] implement internal decimal scale --- enginetest/queries/queries.go | 175 +++++++++++++++++++++++++++ enginetest/queries/script_queries.go | 35 +++++- sql/expression/arithmetic.go | 4 - sql/expression/convert.go | 4 + sql/expression/div.go | 52 ++++---- 5 files changed, 230 insertions(+), 40 deletions(-) diff --git a/enginetest/queries/queries.go b/enginetest/queries/queries.go index c68b84083e..a3d943b581 100644 --- a/enginetest/queries/queries.go +++ b/enginetest/queries/queries.go @@ -2551,6 +2551,10 @@ Select * from ( Query: "SELECT i + 1 FROM mytable;", Expected: []sql.Row{{int64(2)}, {int64(3)}, {int64(4)}}, }, + { + Query: `select (1 / 3) * (1 / 3);`, + Expected: []sql.Row{{"0.11111111"}}, + }, { Query: "select 1 / 3 * 3;", Expected: []sql.Row{ @@ -4153,6 +4157,177 @@ Select * from ( Query: "select 1/2/3%4/5/6;", Expected: []sql.Row{{"0.0055555555555556"}}, }, + + // check that internal precision is preserved in comparisons + { + // 0 scale + 0 scale = 9 scale + Query: "select 1 / 3 = 0.333333333;", + Expected: []sql.Row{{true}}, + }, + { + // 0 scale + 1 scale = 9 scale + Query: "select 1 / 3.0 = 0.333333333;", + Expected: []sql.Row{{true}}, + }, + { + // 0 scale + 6 scale = 18 scale + Query: "select 1 / 3.000000 = 0.333333333333333333;", + Expected: []sql.Row{{true}}, + }, + { + // 0 scale + 15 scale = 27 scale + Query: "select 1 / 3.000000000000000 = 0.333333333333333333333333333;", + Expected: []sql.Row{{true}}, + }, + { + // 0 scale + 24 scale = 36 scale + Query: "select 1 / 3.000000000000000000000000 = 0.333333333333333333333333333333333333;", + Expected: []sql.Row{{true}}, + }, + + { + // 1 scale + 0 scale = 9 scale + Query: "select 1.0 / 3 = 0.333333333;", + Expected: []sql.Row{{true}}, + }, + { + // 1 scale + 1 scale = 18 scale + Query: "select 1.0 / 3.0 = 0.333333333333333333;", + Expected: []sql.Row{{true}}, + }, + { + // 1 scale + 10 scale = 27 scale + Query: "select 1.0 / 3.0000000000 = 0.333333333333333333333333333;", + Expected: []sql.Row{{true}}, + }, + { + // 1 scale + 19 scale = 36 scale + Query: "select 1.0 / 3.0000000000000000000 = 0.333333333333333333333333333333333333;", + Expected: []sql.Row{{true}}, + }, + + { + // 6 scale + 8 scale = 18 scale + Query: "select 1.000000 / 3.00000000 = 0.333333333333333333;", + Expected: []sql.Row{{true}}, + }, + { + // 6 scale + 9 scale = 27 scale + Query: "select 1.000000 / 3.000000000 = 0.333333333333333333333333333;", + Expected: []sql.Row{{true}}, + }, + { + // 6 scale + 17 scale = 27 scale + Query: "select 1.000000 / 3.00000000000000000 = 0.333333333333333333333333333;", + Expected: []sql.Row{{true}}, + }, + { + // 6 scale + 18 scale = 36 scale + Query: "select 1.000000 / 3.000000000000000000 = 0.333333333333333333333333333333333333;", + Expected: []sql.Row{{true}}, + }, + + { + // 7 scale + 7 scale = 18 scale + Query: "select 1.0000000 / 3.0000000 = 0.333333333333333333;", + Expected: []sql.Row{{true}}, + }, + { + // 7 scale + 8 scale = 27 scale + Query: "select 1.0000000 / 3.00000000 = 0.333333333333333333333333333;", + Expected: []sql.Row{{true}}, + }, + { + // 7 scale + 16 scale = 27 scale + Query: "select 1.0000000 / 3.0000000000000000 = 0.333333333333333333333333333;", + Expected: []sql.Row{{true}}, + }, + { + // 7 scale + 15 scale = 36 scale + Query: "select 1.0000000 / 3.00000000000000000 = 0.333333333333333333333333333333333333;", + Expected: []sql.Row{{true}}, + }, + + { + // 8 scale + 6 scale = 18 scale + Query: "select 1.00000000 / 3.000000 = 0.333333333333333333;", + Expected: []sql.Row{{true}}, + }, + { + // 8 scale + 7 scale = 27 scale + Query: "select 1.00000000 / 3.0000000 = 0.333333333333333333333333333;", + Expected: []sql.Row{{true}}, + }, + { + // 8 scale + 15 scale = 27 scale + Query: "select 1.00000000 / 3.000000000000000 = 0.333333333333333333333333333;", + Expected: []sql.Row{{true}}, + }, + { + // 8 scale + 14 scale = 36 scale + Query: "select 1.00000000 / 3.0000000000000000 = 0.333333333333333333333333333333333333;", + Expected: []sql.Row{{true}}, + }, + + { + // 9 scale + 5 scale = 18 scale + Query: "select 1.000000000 / 3.00000 = 0.333333333333333333;", + Expected: []sql.Row{{true}}, + }, + { + // 9 scale + 6 scale = 27 scale + Query: "select 1.000000000 / 3.000000 = 0.333333333333333333333333333;", + Expected: []sql.Row{{true}}, + }, + { + // 9 scale + 14 scale = 27 scale + Query: "select 1.000000000 / 3.00000000000000 = 0.333333333333333333333333333;", + Expected: []sql.Row{{true}}, + }, + { + // 9 scale + 13 scale = 36 scale + Query: "select 1.000000000 / 3.000000000000000 = 0.333333333333333333333333333333333333;", + Expected: []sql.Row{{true}}, + }, + + { + // 10 scale + 1 scale = 27 scale + Query: "select 1.0000000000 / 3.0 = 0.333333333333333333333333333;", + Expected: []sql.Row{{true}}, + }, + { + // 10 scale + 10 scale = 36 scale + Query: "select 1.0000000000 / 3.0000000000 = 0.333333333333333333333333333333333333;", + Expected: []sql.Row{{true}}, + }, + + // check that decimal internal precision is preserved in casts + { + // 0 scale + 0 scale = 9 scale + Query: "select cast(1 / 3 as decimal(65,30));", + Expected: []sql.Row{{"0.333333333000000000000000000000"}}, + }, + { + // 0 scale + 1 scale = 9 scale + Query: "select cast(1 / 3.0 as decimal(65,30));", + Expected: []sql.Row{{"0.333333333000000000000000000000"}}, + }, + { + // 0 scale + 6 scale = 18 scale + Query: "select cast(1 / 3.000000 as decimal(65,30));", + Expected: []sql.Row{{"0.333333333333333333000000000000"}}, + }, + { + // 0 scale + 15 scale = 27 scale + Query: "select cast(1 / 3.000000000000000 as decimal(65,30));", + Expected: []sql.Row{{"0.333333333333333333333333333000"}}, + }, + { + // 0 scale + 24 scale = 36 scale + Query: "select cast(1 / 3.000000000000000000000000 as decimal(65,30));", + Expected: []sql.Row{{"0.333333333333333333333333333333"}}, + }, + { Query: "select 0.05 % 0.024;", Expected: []sql.Row{{"0.002"}}, diff --git a/enginetest/queries/script_queries.go b/enginetest/queries/script_queries.go index 63569c152a..26b732d9dd 100644 --- a/enginetest/queries/script_queries.go +++ b/enginetest/queries/script_queries.go @@ -5401,12 +5401,35 @@ CREATE TABLE tab3 ( }, }, }, - //{ - // Name: "dividing has different rounding behavior", - // SetUpScript: []string{ - // - // }, - //}, + { + Name: "dividing has different rounding behavior", + SetUpScript: []string{ + "CREATE TABLE tab0(col0 INTEGER, col1 INTEGER, col2 INTEGER);", + "INSERT INTO tab0 VALUES(97, 1, 99);", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "SELECT col2 IN ( 98 + col0 / 99 ) from tab0;", + Expected: []sql.Row{ + {false}, + }, + }, + { + Query: "SELECT col2 IN ( 98 + 97 / 99 ) from tab0;", + Expected: []sql.Row{ + {false}, + }, + }, + { + Query: "SELECT * FROM tab0 WHERE col2 IN ( 98 + 97 / 99 );", + Expected: []sql.Row{}, + }, + { + Query: "SELECT ALL * FROM tab0 AS cor0 WHERE col2 IN ( 39 + + 89, col0 + + col1 + + ( - ( - col0 ) ) / col2, + ( col0 ) + - 99, + col1, + col2 * - + col2 * - 12 + col1 + - 66 );", + Expected: []sql.Row{}, + }, + }, + }, } var SpatialScriptTests = []ScriptTest{ diff --git a/sql/expression/arithmetic.go b/sql/expression/arithmetic.go index 1377e5b0f7..47d3382dc8 100644 --- a/sql/expression/arithmetic.go +++ b/sql/expression/arithmetic.go @@ -309,10 +309,6 @@ func (a *Arithmetic) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return res.Round(finalScale), nil } } - // In comparisons, we need to truncate decimals to have scale of 9 - if a.ops == -1 { - result = res.Truncate(9) - } } return result, nil diff --git a/sql/expression/convert.go b/sql/expression/convert.go index a33114f509..0173fedb93 100644 --- a/sql/expression/convert.go +++ b/sql/expression/convert.go @@ -82,6 +82,8 @@ var _ sql.CollationCoercible = (*Convert)(nil) // |castToType| type. All optional parameters (i.e. typeLength, typeScale, and charset) are omitted and initialized // to their zero values. func NewConvert(expr sql.Expression, castToType string) *Convert { + setArithmeticOps(expr, -1) + setDivs(expr, -1) return &Convert{ UnaryExpression: UnaryExpression{Child: expr}, castToType: strings.ToLower(castToType), @@ -92,6 +94,8 @@ func NewConvert(expr sql.Expression, castToType string) *Convert { // |castToType| type, with |typeLength| specifying a length constraint of the converted type, and |typeScale| specifying // a scale constraint of the converted type. func NewConvertWithLengthAndScale(expr sql.Expression, castToType string, typeLength, typeScale int) *Convert { + setArithmeticOps(expr, -1) + setDivs(expr, -1) return &Convert{ UnaryExpression: UnaryExpression{Child: expr}, castToType: strings.ToLower(castToType), diff --git a/sql/expression/div.go b/sql/expression/div.go index be16f3e072..9a0af56e68 100644 --- a/sql/expression/div.go +++ b/sql/expression/div.go @@ -34,10 +34,10 @@ var ErrIntDivDataOutOfRange = errors.NewKind("BIGINT value is out of range (%s D // '4 scales' are added to scale of the number on the left side of division operator at every division operation. // The default value is 4, and it can be set using sysvar https://dev.mysql.com/doc/refman/8.0/en/server-system-variables.html#sysvar_div_precision_increment -const divPrecisionIncrement = 4 +const divPrecInc = 4 // '9 scales' are added for every non-integer divider(right side). -const divIntermediatePrecisionInc = 9 +const divIntPrecInc = 9 const ERDivisionByZero = 1365 @@ -252,32 +252,24 @@ func (d *Div) div(ctx *sql.Context, lval, rval interface{}) (interface{}, error) return nil, nil } - if d.curIntermediatePrecisionInc == 0 { - d.curIntermediatePrecisionInc = getPrecInc(d, 0) - // if the first dividend / the leftmost value is non int value, - // then curIntermediatePrecisionInc gets additional increment per every 9 scales - if d.curIntermediatePrecisionInc == 0 { - if !isIntOr1(l) { - d.curIntermediatePrecisionInc = int(math.Ceil(float64(l.Exponent()*-1) / float64(divIntermediatePrecisionInc))) - } + lScale, rScale := -1 * l.Exponent(), -1 * r.Exponent() + inc := int32(math.Ceil(float64(lScale + rScale + divPrecInc) / divIntPrecInc)) + if lScale != 0 && rScale != 0 { + lInc := int32(math.Ceil(float64(lScale) / divIntPrecInc)) + rInc := int32(math.Ceil(float64(rScale) / divIntPrecInc)) + inc2 := lInc + rInc + if inc2 > inc { + inc = inc2 } } - - // for every divider we increment the curIntermediatePrecisionInc per every 9 scales - // for 0 scaled number, we increment 1 - if r.Exponent() == 0 { - d.curIntermediatePrecisionInc += 1 - } else { - d.curIntermediatePrecisionInc += int(math.Ceil(float64(r.Exponent()*-1) / float64(divIntermediatePrecisionInc))) - } - - storedScale := d.leftmostScale.Load() + int32(d.curIntermediatePrecisionInc*divIntermediatePrecisionInc) - l = l.Truncate(storedScale) - r = r.Truncate(storedScale) + scale := inc * divIntPrecInc + l = l.Truncate(scale) + r = r.Truncate(scale) // give it buffer of 2 additional scale to avoid the result to be rounded - divRes := l.DivRound(r, storedScale+2) - return divRes.Truncate(storedScale), nil + res := l.DivRound(r, scale + 2) + res = res.Truncate(scale) + return res, nil } } @@ -321,7 +313,7 @@ func (d *Div) determineResultType(outermostResult bool) sql.Type { if types.IsDatetimeType(lTyp) { if dtType, ok := lTyp.(sql.DatetimeType); ok { - scale := uint8(dtType.Precision() + divPrecisionIncrement) + scale := uint8(dtType.Precision() + divPrecInc) if scale > types.DecimalTypeMaxScale { scale = types.DecimalTypeMaxScale } @@ -332,12 +324,12 @@ func (d *Div) determineResultType(outermostResult bool) sql.Type { if types.IsDecimal(lTyp) { prec, scale := lTyp.(types.DecimalType_).Precision(), lTyp.(types.DecimalType_).Scale() - scale = scale + divPrecisionIncrement + scale = scale + divPrecInc if d.ops == -1 { scale = (scale/9 + 1) * 9 prec = prec + scale } else { - prec = prec + divPrecisionIncrement + prec = prec + divPrecInc } if prec > types.DecimalTypeMaxPrecision { @@ -353,7 +345,7 @@ func (d *Div) determineResultType(outermostResult bool) sql.Type { if d.ops == -1 { return types.MustCreateDecimalType(types.DecimalTypeMaxPrecision, 9) } - return types.MustCreateDecimalType(types.DecimalTypeMaxPrecision, divPrecisionIncrement) + return types.MustCreateDecimalType(types.DecimalTypeMaxPrecision, divPrecInc) } // TODO: this is unused now, consider deleting @@ -383,7 +375,7 @@ func floatOrDecimalTypeForDiv(e sql.Expression, treatIntsAsFloats bool) sql.Type maxScale := scale div := e.(*Div) - finalScale := div.leftmostScale.Load() + div.divScale*int32(divPrecisionIncrement) + finalScale := div.leftmostScale.Load() + div.divScale*int32(divPrecInc) if uint8(finalScale) > maxScale { maxScale = uint8(finalScale) @@ -635,7 +627,7 @@ func getFinalScale(e sql.Expression) (int32, bool) { } if d, isDiv := e.(*Div); isDiv { - finalScale := d.leftmostScale.Load() + d.divScale*int32(divPrecisionIncrement) + finalScale := d.leftmostScale.Load() + d.divScale*int32(divPrecInc) if finalScale > types.DecimalTypeMaxScale { finalScale = types.DecimalTypeMaxScale } From c45c076f448d69412d8c8d2e2961e5aff5896a82 Mon Sep 17 00:00:00 2001 From: jycor Date: Tue, 13 Feb 2024 23:07:59 +0000 Subject: [PATCH 11/27] [ga-format-pr] Run ./format_repo.sh to fix formatting --- sql/expression/div.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/expression/div.go b/sql/expression/div.go index 9a0af56e68..6513a0ff69 100644 --- a/sql/expression/div.go +++ b/sql/expression/div.go @@ -252,8 +252,8 @@ func (d *Div) div(ctx *sql.Context, lval, rval interface{}) (interface{}, error) return nil, nil } - lScale, rScale := -1 * l.Exponent(), -1 * r.Exponent() - inc := int32(math.Ceil(float64(lScale + rScale + divPrecInc) / divIntPrecInc)) + lScale, rScale := -1*l.Exponent(), -1*r.Exponent() + inc := int32(math.Ceil(float64(lScale+rScale+divPrecInc) / divIntPrecInc)) if lScale != 0 && rScale != 0 { lInc := int32(math.Ceil(float64(lScale) / divIntPrecInc)) rInc := int32(math.Ceil(float64(rScale) / divIntPrecInc)) @@ -267,7 +267,7 @@ func (d *Div) div(ctx *sql.Context, lval, rval interface{}) (interface{}, error) r = r.Truncate(scale) // give it buffer of 2 additional scale to avoid the result to be rounded - res := l.DivRound(r, scale + 2) + res := l.DivRound(r, scale+2) res = res.Truncate(scale) return res, nil } From da9a7699bcf526b9a07ac2f4add5ba36d61fcd23 Mon Sep 17 00:00:00 2001 From: James Cor Date: Tue, 13 Feb 2024 15:09:01 -0800 Subject: [PATCH 12/27] reverting changes to memory_engine_test --- enginetest/memory_engine_test.go | 121 ++++++------------------------- 1 file changed, 21 insertions(+), 100 deletions(-) diff --git a/enginetest/memory_engine_test.go b/enginetest/memory_engine_test.go index 62e1983703..b2cb356a99 100644 --- a/enginetest/memory_engine_test.go +++ b/enginetest/memory_engine_test.go @@ -203,125 +203,46 @@ func newUpdateResult(matched, updated int) types.OkResult { // Convenience test for debugging a single query. Unskip and set to the desired query. func TestSingleScript(t *testing.T) { + t.Skip() var scripts = []queries.ScriptTest{ { - Name: "delete me", + Name: "physical columns added after virtual one", SetUpScript: []string{ - "CREATE TABLE tab0(col0 INTEGER, col1 INTEGER, col2 INTEGER);", - "INSERT INTO tab0 VALUES(97,1,99);", + "create table t (pk int primary key, col1 int as (pk + 1));", + "insert into t (pk) values (1), (3)", + "alter table t add index idx1 (col1, pk);", + "alter table t add index idx2 (col1);", + "alter table t add column col2 int;", + "alter table t add column col3 int;", + "insert into t (pk, col2, col3) values (2, 4, 5);", }, Assertions: []queries.ScriptTestAssertion{ { - Query: "select 1 / 1", + Query: "select * from t order by pk", Expected: []sql.Row{ - {"1.0000"}, + {1, 2, nil, nil}, + {2, 3, 4, 5}, + {3, 4, nil, nil}, }, }, - { - Query: "select 1 / 3 * 3;", + Query: "select * from t where col1 = 2", Expected: []sql.Row{ - {"1.0000"}, + {1, 2, nil, nil}, }, }, { - Query: "select 1 / 3 * 3 = 0.999999999;", + Query: "select * from t where col1 = 3 and pk = 2", Expected: []sql.Row{ - {true}, + {2, 3, 4, 5}, }, }, { - Query: "SELECT col2 IN ( 98 + col0 / 99 ) from tab0;", + Query: "select * from t where pk = 2", Expected: []sql.Row{ - {false}, + {2, 3, 4, 5}, }, }, - { - Query: "SELECT col2 IN ( 98 + 97 / 99 ) from tab0;", - Expected: []sql.Row{ - {false}, - }, - }, - { - Query: "SELECT 99 IN ( 98 + 97 / 99 );", - Expected: []sql.Row{ - {false}, - }, - }, - { - Query: "SELECT 1 IN ( 97 / 99 );", - Expected: []sql.Row{ - {false}, - }, - }, - - { - Query: "SELECT * FROM tab0 WHERE col2 IN ( 98 + 97 / 99 );", - Expected: []sql.Row{}, - }, - { - Query: "SELECT ALL * FROM tab0 AS cor0 WHERE col2 IN ( 39 + + 89, col0 + + col1 + + ( - ( - col0 ) ) / col2, + ( col0 ) + - 99, + col1, + col2 * - + col2 * - 12 + col1 + - 66 );", - Expected: []sql.Row{}, - }, - - { - Query: `SELECT 1 IN (1 / 9 * 5);`, - Expected: []sql.Row{{false}}, - }, - { - Query: `select 1 / 3 * 3 = 1;`, - Expected: []sql.Row{{false}}, - }, - { - Query: `select 1 / 3 * 3 = 0.999999999;`, - Expected: []sql.Row{{true}}, - }, - - { - Query: `select 1 / 3 * 3 in (1);`, - Expected: []sql.Row{{false}}, - }, - { - Query: `select 1 in (1 / 3 * 3);`, - Expected: []sql.Row{{false}}, - }, - { - Query: `SELECT 1 IN (1 / 9 * 5);`, - Expected: []sql.Row{{false}}, - }, - { - Query: `SELECT 1 / 9 * 5 IN (1);`, - Expected: []sql.Row{{false}}, - }, - { - Query: `SELECT 1 / 9 * 5 IN (1 / 9 * 5);`, - Expected: []sql.Row{{true}}, - }, - - { - Query: `SELECT 1 IN (1 / 99 * 50);`, - Expected: []sql.Row{{false}}, - }, - { - Query: `select 1 / 3 * 3 in (0.999999999);`, - Expected: []sql.Row{{true}}, - }, - { - Query: `SELECT 96 / 51 * 51 > 96;`, - Expected: []sql.Row{{false}}, - }, - { - Query: `SELECT 96 / 51 * 51 = 95.999999991;`, - Expected: []sql.Row{{true}}, - }, - { - Query: `select 64 / 77 * 77;`, - Expected: []sql.Row{{"64.0000"}}, - }, - { - Query: `select (1 / 3) * (1 / 3);`, - Expected: []sql.Row{{"0.11111111"}}, - }, }, }, } @@ -333,8 +254,8 @@ func TestSingleScript(t *testing.T) { if err != nil { panic(err) } - //engine.EngineAnalyzer().Debug = true - //engine.EngineAnalyzer().Verbose = true + engine.EngineAnalyzer().Debug = true + engine.EngineAnalyzer().Verbose = true enginetest.TestScriptWithEngine(t, engine, harness, test) } From a5950269eed885c198a4d6e5bb33aab6fba6ba9f Mon Sep 17 00:00:00 2001 From: James Cor Date: Tue, 13 Feb 2024 15:20:09 -0800 Subject: [PATCH 13/27] clean up --- enginetest/queries/queries.go | 46 +++++--------------------- sql/expression/arithmetic.go | 5 ++- sql/expression/function/str_to_date.go | 2 +- sql/expression/function/time.go | 2 +- 4 files changed, 13 insertions(+), 42 deletions(-) diff --git a/enginetest/queries/queries.go b/enginetest/queries/queries.go index a3d943b581..bdcefccf9d 100644 --- a/enginetest/queries/queries.go +++ b/enginetest/queries/queries.go @@ -2555,31 +2555,6 @@ Select * from ( Query: `select (1 / 3) * (1 / 3);`, Expected: []sql.Row{{"0.11111111"}}, }, - { - Query: "select 1 / 3 * 3;", - Expected: []sql.Row{ - {"1.0000"}, - }, - }, - { - Query: "select 1 / 3 * 3 = 0.999999999;", - Expected: []sql.Row{ - {true}, - }, - }, - { - Query: "select 1.00000 / 3 * 3 = 0.999999999;", - Expected: []sql.Row{ - {true}, - }, - }, - // TODO: fix this - //{ - // Query: "select 1.000000 / 3 * 3 = 0.999999999999999999;", - // Expected: []sql.Row{ - // {true}, - // }, - //}, { Query: "SELECT i div 2 FROM mytable order by 1;", Expected: []sql.Row{{int64(0)}, {int64(1)}, {int64(1)}}, @@ -2788,6 +2763,10 @@ Select * from ( Query: `SELECT 1 / 9 * 5 IN (1 / 9 * 5);`, Expected: []sql.Row{{true}}, }, + { + Query: "select 0 in (1/100000);", + Expected: []sql.Row{{false}}, + }, { Query: `SELECT * FROM mytable WHERE i in (CAST(NULL AS SIGNED), 2, 3, 4)`, Expected: []sql.Row{{3, "third row"}, {2, "second row"}}, @@ -6465,11 +6444,6 @@ Select * from ( Query: `SELECT STR_TO_DATE('01,5,2013 09:30:17','%d,%m,%Y %h:%i:%s') - (STR_TO_DATE('01,5,2013 09:30:17','%d,%m,%Y %h:%i:%s') - INTERVAL 1 SECOND)`, Expected: []sql.Row{{int64(1)}}, }, - // TODO: skip this test - //{ - // Query: `SELECT STR_TO_DATE('01,5,2013 09:30:17','%d,%m,%Y %h:%i:%s %f') - (STR_TO_DATE('01,5,2013 09:30:17','%d,%m,%Y %h:%i:%s') - INTERVAL 1 SECOND)`, - // Expected: []sql.Row{{int64(1)}}, - //}, { Query: `SELECT SUBSTR(SUBSTRING('0123456789ABCDEF', 1, 10), -4)`, Expected: []sql.Row{{"6789"}}, @@ -9246,13 +9220,6 @@ var KeylessQueries = []QueryTest{ // BrokenQueries are queries that are known to be broken in the engine. var BrokenQueries = []QueryTest{ - // https://github.com/dolthub/dolt/issues/7207 - { - Query: "select 0 in (1/100000);", - Expected: []sql.Row{ - {false}, - }, - }, // union and aggregation typing are tricky { Query: "with recursive t (n) as (select sum('1') from dual union all select (2.00) from dual) select sum(n) from t;", @@ -9357,6 +9324,11 @@ var BrokenQueries = []QueryTest{ Query: "SELECT STR_TO_DATE('2013 32 Tuesday', '%X %V %W')", // Tuesday of 32th week Expected: []sql.Row{{"2013-08-13"}}, }, + { + // TODO: need to properly handle datetime precision + Query: `SELECT STR_TO_DATE('01,5,2013 09:30:17','%d,%m,%Y %h:%i:%s %f') - (STR_TO_DATE('01,5,2013 09:30:17','%d,%m,%Y %h:%i:%s') - INTERVAL 1 SECOND)`, + Expected: []sql.Row{{int64(1)}}, + }, { // This panics // The non-recursive part of the UNION ALL returns too many rows, causing index out of bounds errors diff --git a/sql/expression/arithmetic.go b/sql/expression/arithmetic.go index 47d3382dc8..dc7ffce22d 100644 --- a/sql/expression/arithmetic.go +++ b/sql/expression/arithmetic.go @@ -134,7 +134,7 @@ func (a *Arithmetic) Type() sql.Type { // applies for + and - ops if isInterval(a.LeftChild) || isInterval(a.RightChild) { - // TODO: need to use the precision stored in datetimeType + // TODO: need to use the precision stored in datetimeType; something like // return MustCreateDatetimeType(sqltypes.Datetime, ...) return types.Datetime } @@ -172,7 +172,7 @@ func (a *Arithmetic) Type() sql.Type { rTyp = types.Int64 } - // Datetimes are decimals, unless they have precision 0 + // Datetime(0) is treated as Int64, otherwise as Decimal if types.IsDatetimeType(lTyp) { if dtType, ok := lTyp.(sql.DatetimeType); ok { scale := uint8(dtType.Precision()) @@ -216,7 +216,6 @@ func (a *Arithmetic) Type() sql.Type { rPrec := rTyp.(types.DecimalType_).Precision() rScale := rTyp.(types.DecimalType_).Scale() - // TODO: determine real precision var prec, scale uint8 if lPrec > rPrec { prec = lPrec diff --git a/sql/expression/function/str_to_date.go b/sql/expression/function/str_to_date.go index def6760c3e..6fc0bd68d0 100644 --- a/sql/expression/function/str_to_date.go +++ b/sql/expression/function/str_to_date.go @@ -47,7 +47,7 @@ func (s StrToDate) String() string { // Type returns the expression type. func (s StrToDate) Type() sql.Type { - // TODO: needs to take into account precision + // TODO: precision return types.Datetime } diff --git a/sql/expression/function/time.go b/sql/expression/function/time.go index e2199396dd..7bd1f900a2 100644 --- a/sql/expression/function/time.go +++ b/sql/expression/function/time.go @@ -948,7 +948,7 @@ func (n *Now) Description() string { // Type implements the sql.Expression interface. func (n *Now) Type() sql.Type { - // TODO: This should be types.NewDatetime(n.prec) + // TODO: precision if n.prec == nil { return types.Datetime } From 5bc7edbddd95887a9101a49011b27667dececf4e Mon Sep 17 00:00:00 2001 From: James Cor Date: Tue, 13 Feb 2024 15:22:44 -0800 Subject: [PATCH 14/27] fix test --- sql/expression/div_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/expression/div_test.go b/sql/expression/div_test.go index 6a806cd213..a76b047a44 100644 --- a/sql/expression/div_test.go +++ b/sql/expression/div_test.go @@ -194,12 +194,12 @@ func TestDiv(t *testing.T) { { left: NewLiteral(decimal.New(314159, -5), types.MustCreateDecimalType(10, 5)), right: NewLiteral(decimal.New(3, 0), types.MustCreateDecimalType(10, 0)), - exp: "1.047196667", + exp: "1.047196666", }, { left: NewLiteral(decimal.NewFromFloat(3.14159), types.MustCreateDecimalType(10, 5)), right: NewLiteral(3, types.Int64), - exp: "1.047196667", + exp: "1.047196666", }, // Bit From 65c3451ef7b7350ed3fc722e76e00cbeb4564698 Mon Sep 17 00:00:00 2001 From: James Cor Date: Wed, 14 Feb 2024 10:42:56 -0800 Subject: [PATCH 15/27] removing unused code --- sql/expression/arithmetic.go | 26 ---------- sql/expression/div.go | 98 ------------------------------------ 2 files changed, 124 deletions(-) diff --git a/sql/expression/arithmetic.go b/sql/expression/arithmetic.go index dc7ffce22d..8b8ca7f552 100644 --- a/sql/expression/arithmetic.go +++ b/sql/expression/arithmetic.go @@ -602,32 +602,6 @@ func minus(lval, rval interface{}) (interface{}, error) { return nil, errUnableToCast.New(lval, rval) } -// floatOrDecimalTypeForMult returns Float64 type if either left or right side is of type int or float. -// Otherwise, it returns decimal type of sum of left and right sides' precisions and scales. E.g. `1.40 * 1.0 = 1.400` -func floatOrDecimalTypeForMult(l, r sql.Expression) sql.Type { - lType := getFloatOrMaxDecimalType(l, false) - rType := getFloatOrMaxDecimalType(r, false) - - if lType == types.Float64 || rType == types.Float64 { - return types.Float64 - } - - lPrec := lType.(types.DecimalType_).Precision() - lScale := lType.(types.DecimalType_).Scale() - rPrec := rType.(types.DecimalType_).Precision() - rScale := rType.(types.DecimalType_).Scale() - - maxWhole := (lPrec - lScale) + (rPrec - rScale) - maxScale := lScale + rScale - if maxWhole > types.DecimalTypeMaxPrecision-types.DecimalTypeMaxScale { - maxWhole = types.DecimalTypeMaxPrecision - types.DecimalTypeMaxScale - } - if maxScale > types.DecimalTypeMaxScale { - maxScale = types.DecimalTypeMaxScale - } - return types.MustCreateDecimalType(maxWhole+maxScale, maxScale) -} - func mult(lval, rval interface{}) (interface{}, error) { switch l := lval.(type) { case uint8: diff --git a/sql/expression/div.go b/sql/expression/div.go index 6513a0ff69..fc1df0578c 100644 --- a/sql/expression/div.go +++ b/sql/expression/div.go @@ -348,51 +348,6 @@ func (d *Div) determineResultType(outermostResult bool) sql.Type { return types.MustCreateDecimalType(types.DecimalTypeMaxPrecision, divPrecInc) } -// TODO: this is unused now, consider deleting -// floatOrDecimalTypeForDiv returns either Float64 or Decimal type depending on column reference, -// left and right expression types and left and right evaluated types. -// If |treatIntsAsFloats| is true, then integers are treated as floats instead of Decimals. This -// is a performance optimization for division operations, since float division can be several orders -// of magnitude faster than division with Decimals. -// Otherwise, the return type is always decimal. The expression and evaluated types -// are used to determine appropriate Decimal type to return that will not result in -// precision loss. -func floatOrDecimalTypeForDiv(e sql.Expression, treatIntsAsFloats bool) sql.Type { - t := getFloatOrMaxDecimalType(e, treatIntsAsFloats) - - if t == types.Float64 { - return types.Float64 - } - - // if not float, it must be decimal type - if treatIntsAsFloats { - //return t - } - - // for Div expression, if it's the outermostResult, then add the additional scales for the final result - prec, scale := t.(types.DecimalType_).Precision(), t.(types.DecimalType_).Scale() - maxWhole := prec - scale - maxScale := scale - - div := e.(*Div) - finalScale := div.leftmostScale.Load() + div.divScale*int32(divPrecInc) - - if uint8(finalScale) > maxScale { - maxScale = uint8(finalScale) - } - - if maxScale > types.DecimalTypeMaxScale { - maxScale = types.DecimalTypeMaxScale - } - - prec = maxWhole + maxScale - if prec > types.DecimalTypeMaxPrecision { - prec = types.DecimalTypeMaxPrecision - } - - return types.MustCreateDecimalType(prec, maxScale) -} - // getFloatOrMaxDecimalType returns either Float64 or Decimal type with max precision and scale // depending on column reference, expression types and evaluated value types. Otherwise, the return // type is always max decimal type. |treatIntsAsFloats| is used for division operation optimization. @@ -700,59 +655,6 @@ func GetPrecisionAndScale(val interface{}) (uint8, uint8) { return GetDecimalPrecisionAndScale(str) } -// isIntOr1 checks whether the decimal number is equal to 1 -// or it is an integer value even though the value can be -// given as decimal. This function returns true if val is -// 1 or 1.000 or 2.00 or 13. These all are int numbers. -func isIntOr1(val decimal.Decimal) bool { - if val.Equal(decimal.NewFromInt(1)) { - return true - } - if val.Equal(decimal.NewFromInt(-1)) { - return true - } - if val.Equal(decimal.NewFromInt(val.IntPart())) { - return true - } - return false -} - -// getPrecInc returns the max curIntermediatePrecisionInc by searching the children -// of the expression given. This allows us to keep track of the appropriate value -// of curIntermediatePrecisionInc that is used to storing scale number for the decimal value. -func getPrecInc(e sql.Expression, cur int) int { - if e == nil { - return 0 - } - - if d, ok := e.(*Div); ok { - if d.curIntermediatePrecisionInc > cur { - return d.curIntermediatePrecisionInc - } - l := getPrecInc(d.LeftChild, cur) - if l > cur { - cur = l - } - r := getPrecInc(d.RightChild, cur) - if r > cur { - cur = r - } - return cur - } else if d, ok := e.(ArithmeticOp); ok { - l := getPrecInc(d.Left(), cur) - if l > cur { - cur = l - } - r := getPrecInc(d.Right(), cur) - if r > cur { - cur = r - } - return cur - } else { - return cur - } -} - var _ ArithmeticOp = (*IntDiv)(nil) var _ sql.CollationCoercible = (*IntDiv)(nil) From 44d928de0c484b1d8aec0186cd0090b1751d004c Mon Sep 17 00:00:00 2001 From: James Cor Date: Thu, 15 Feb 2024 13:00:02 -0800 Subject: [PATCH 16/27] improve getFinalScale to handle leftmost div --- sql/expression/arithmetic.go | 11 ++--- sql/expression/div.go | 82 +++++++++++++++++++++++------------- sql/expression/mod.go | 6 +-- 3 files changed, 59 insertions(+), 40 deletions(-) diff --git a/sql/expression/arithmetic.go b/sql/expression/arithmetic.go index 8b8ca7f552..bc4e44a75d 100644 --- a/sql/expression/arithmetic.go +++ b/sql/expression/arithmetic.go @@ -211,10 +211,10 @@ func (a *Arithmetic) Type() sql.Type { } if types.IsDecimal(lTyp) && types.IsDecimal(rTyp) { - lPrec := lTyp.(types.DecimalType_).Precision() - lScale := lTyp.(types.DecimalType_).Scale() - rPrec := rTyp.(types.DecimalType_).Precision() - rScale := rTyp.(types.DecimalType_).Scale() + lPrec := lTyp.(sql.DecimalType).Precision() + lScale := lTyp.(sql.DecimalType).Scale() + rPrec := rTyp.(sql.DecimalType).Precision() + rScale := rTyp.(sql.DecimalType).Scale() var prec, scale uint8 if lPrec > rPrec { @@ -303,8 +303,9 @@ func (a *Arithmetic) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { // Decimals must be rounded if res, ok := result.(decimal.Decimal); ok { if isOutermostArithmeticOp(a, a.ops) { - finalScale, hasDiv := getFinalScale(a) + finalScale, hasDiv := getFinalScale(ctx, row, a,0) if hasDiv { + // TODO: should always round regardless; we have bad Decimal defaults return res.Round(finalScale), nil } } diff --git a/sql/expression/div.go b/sql/expression/div.go index fc1df0578c..ea21d9abc6 100644 --- a/sql/expression/div.go +++ b/sql/expression/div.go @@ -49,7 +49,7 @@ type Div struct { BinaryExpressionStub ops int32 // divScale is number of continuous division operations; this value will be available of all layers - divScale int32 + divScale int32 // TODO: calling this divScale is confusing // leftmostScale is a length of scale of the leftmost value in continuous division operation // It is accessed concurrently read in the .Type() and written in the .Eval() methods. leftmostScale atomic.Int32 @@ -121,14 +121,6 @@ func (d *Div) WithChildren(children ...sql.Expression) (sql.Expression, error) { // Eval implements the Expression interface. func (d *Div) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { - // we need to get the scale of the leftmost value of all continuous division - // for the final result rounding precision. This only is able to happen in the - // outermost layer, which is where we use this value to round the final result. - // we do not round the value until it's the last division operation. - if isOutermostDiv(d, 0, d.divScale) { - d.leftmostScale.Store(getScaleOfLeftmostValue(ctx, row, d, 0, d.divScale)) - } - lval, rval, err := d.evalLeftRight(ctx, row) if err != nil { return nil, err @@ -145,19 +137,11 @@ func (d *Div) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { return nil, err } - // we do not round the value until it's the last division operation. - if isOutermostDiv(d, 0, d.divScale) { - // We prefer using floats internally for division operations, but if this expressions output type - // is a Decimal, make sure we convert the result and return it as a decimal. - if types.IsDecimal(d.Type()) { - result = convertValueToType(ctx, types.InternalDecimalType, result, false) - } - - if res, ok := result.(decimal.Decimal); ok { - if isOutermostArithmeticOp(d, d.ops) { - finalScale, _ := getFinalScale(d) - return res.Round(finalScale), nil - } + // Decimals must be rounded + if res, ok := result.(decimal.Decimal); ok { + if isOutermostArithmeticOp(d, d.ops) { + finalScale, _ := getFinalScale(ctx, row, d, 0) + return res.Round(finalScale), nil } } @@ -575,14 +559,37 @@ func isOutermostDiv(e sql.Expression, d, dScale int32) bool { } // getFinalScale returns the final scale of the result value. -// it traverses both the left and right nodes looking for Div nodes -func getFinalScale(e sql.Expression) (int32, bool) { +// it traverses both the left and right nodes looking for Div, Arithmetic, and Literal nodes +func getFinalScale(ctx *sql.Context, row sql.Row, e sql.Expression, d int32) (int32, bool) { if e == nil { return 0, false } - if d, isDiv := e.(*Div); isDiv { - finalScale := d.leftmostScale.Load() + d.divScale*int32(divPrecInc) + if div, isDiv := e.(*Div); isDiv { + // TODO: there's gotta be a better way of determining if this is the leftmost div... + finalScale := int32(divPrecInc) + d = d + 1 + if d == div.divScale { + // TODO: redundant call to Eval for LeftChild + lval, err := div.LeftChild.Eval(ctx, row) + if err != nil { + return 0, false + } + _, s := GetPrecisionAndScale(lval) + typ := div.LeftChild.Type() + if dt, dok := typ.(sql.DecimalType); dok { + ts := dt.Scale() + if ts > s { + s = ts + } + } + finalScale += int32(s) + } else { + // We only care about left scale for divs + leftScale, _ := getFinalScale(ctx, row, div.LeftChild, d) + finalScale += leftScale + } + if finalScale > types.DecimalTypeMaxScale { finalScale = types.DecimalTypeMaxScale } @@ -590,8 +597,8 @@ func getFinalScale(e sql.Expression) (int32, bool) { } if a, isArith := e.(*Arithmetic); isArith { - leftScale, leftHasDiv := getFinalScale(a.Left()) - rightScale, rightHasDiv := getFinalScale(a.Right()) + leftScale, leftHasDiv := getFinalScale(ctx, row, a.Left(), d) + rightScale, rightHasDiv := getFinalScale(ctx, row, a.Right(), d) var finalScale int32 switch a.Operator() { case sqlparser.PlusStr, sqlparser.MinusStr: @@ -609,6 +616,22 @@ func getFinalScale(e sql.Expression) (int32, bool) { return finalScale, leftHasDiv || rightHasDiv } + // TODO: this is just a guess of what mod should do with scale; test this + if m, isMod := e.(*Mod); isMod { + leftScale, leftHasDiv := getFinalScale(ctx, row, m.LeftChild, d) + rightScale, rightHasDiv := getFinalScale(ctx, row, m.RightChild, d) + finalScale := leftScale + if rightScale > finalScale { + finalScale = rightScale + } + if finalScale > types.DecimalTypeMaxScale { + finalScale = types.DecimalTypeMaxScale + } + return finalScale, leftHasDiv || rightHasDiv + } + + // TODO: likely need a case for IntDiv + s := uint8(0) if lit, isLit := e.(*Literal); isLit { _, s = GetPrecisionAndScale(lit.value) @@ -771,8 +794,7 @@ func (i *IntDiv) convertLeftRight(ctx *sql.Context, left interface{}, right inte } else if (lIsTimeType && rIsTimeType) || (types.IsSigned(lTyp) && types.IsSigned(rTyp)) { typ = types.Int64 } else { - // using max precision which is 65. - typ = types.MustCreateDecimalType(65, 0) + typ = types.MustCreateDecimalType(types.DecimalTypeMaxPrecision, 0) } if types.IsInteger(typ) || types.IsFloat(typ) { diff --git a/sql/expression/mod.go b/sql/expression/mod.go index 9fabb6e535..c49ee1313b 100644 --- a/sql/expression/mod.go +++ b/sql/expression/mod.go @@ -149,13 +149,9 @@ func (m *Mod) convertLeftRight(ctx *sql.Context, left interface{}, right interfa if types.IsFloat(typ) { left = convertValueToType(ctx, typ, left, lIsTimeType) - } else { - left = convertToDecimalValue(left, lIsTimeType) - } - - if types.IsFloat(typ) { right = convertValueToType(ctx, typ, right, rIsTimeType) } else { + left = convertToDecimalValue(left, lIsTimeType) right = convertToDecimalValue(right, rIsTimeType) } From c2735557119e89a6571006cd8dcd7948b951054e Mon Sep 17 00:00:00 2001 From: jycor Date: Thu, 15 Feb 2024 21:01:37 +0000 Subject: [PATCH 17/27] [ga-format-pr] Run ./format_repo.sh to fix formatting --- sql/expression/arithmetic.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/expression/arithmetic.go b/sql/expression/arithmetic.go index bc4e44a75d..2f172f7ac4 100644 --- a/sql/expression/arithmetic.go +++ b/sql/expression/arithmetic.go @@ -303,7 +303,7 @@ func (a *Arithmetic) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { // Decimals must be rounded if res, ok := result.(decimal.Decimal); ok { if isOutermostArithmeticOp(a, a.ops) { - finalScale, hasDiv := getFinalScale(ctx, row, a,0) + finalScale, hasDiv := getFinalScale(ctx, row, a, 0) if hasDiv { // TODO: should always round regardless; we have bad Decimal defaults return res.Round(finalScale), nil From bec19e9afe0280194c4faacb603a502d34ef92d9 Mon Sep 17 00:00:00 2001 From: James Cor Date: Thu, 15 Feb 2024 13:07:27 -0800 Subject: [PATCH 18/27] feedback --- enginetest/queries/queries.go | 2 +- sql/expression/comparison.go | 12 ++++++++---- sql/expression/div.go | 11 +++-------- sql/expression/in.go | 6 ++---- 4 files changed, 14 insertions(+), 17 deletions(-) diff --git a/enginetest/queries/queries.go b/enginetest/queries/queries.go index bdcefccf9d..2733d2fc0a 100644 --- a/enginetest/queries/queries.go +++ b/enginetest/queries/queries.go @@ -6432,7 +6432,7 @@ Select * from ( Expected: []sql.Row{{1}}, }, { - // TODO: this is invalid... + // TODO: Neither MySQL or MariaDB have a function called DATETIME; remove this function. Query: `SELECT DATETIME(NOW()) - NOW()`, Expected: []sql.Row{{int64(0)}}, }, diff --git a/sql/expression/comparison.go b/sql/expression/comparison.go index 176e30e444..40d794c031 100644 --- a/sql/expression/comparison.go +++ b/sql/expression/comparison.go @@ -67,12 +67,16 @@ type comparison struct { BinaryExpressionStub } +// disableRounding disables rounding for the given expression. +func disableRounding(expr sql.Expression) { + setArithmeticOps(expr, -1) + setDivs(expr, -1) +} + func newComparison(left, right sql.Expression) comparison { // TODO: somewhat hacky way to disable rounding for comparisons - setArithmeticOps(left, -1) - setArithmeticOps(right, -1) - setDivs(left, -1) - setDivs(right, -1) + disableRounding(left) + disableRounding(right) return comparison{BinaryExpressionStub{left, right}} } diff --git a/sql/expression/div.go b/sql/expression/div.go index ea21d9abc6..79a7540c50 100644 --- a/sql/expression/div.go +++ b/sql/expression/div.go @@ -288,11 +288,6 @@ func (d *Div) determineResultType(outermostResult bool) sql.Type { return types.Float64 } - // TODO: see if we can actually do this - //if !outermostResult { - // return types.Float64 - //} - // Decimal only results from here on if types.IsDatetimeType(lTyp) { @@ -307,10 +302,10 @@ func (d *Div) determineResultType(outermostResult bool) sql.Type { } if types.IsDecimal(lTyp) { - prec, scale := lTyp.(types.DecimalType_).Precision(), lTyp.(types.DecimalType_).Scale() + prec, scale := lTyp.(sql.DecimalType).Precision(), lTyp.(sql.DecimalType).Scale() scale = scale + divPrecInc if d.ops == -1 { - scale = (scale/9 + 1) * 9 + scale = (scale / divIntPrecInc + 1) * divIntPrecInc prec = prec + scale } else { prec = prec + divPrecInc @@ -327,7 +322,7 @@ func (d *Div) determineResultType(outermostResult bool) sql.Type { // All other types are treated as if they were integers if d.ops == -1 { - return types.MustCreateDecimalType(types.DecimalTypeMaxPrecision, 9) + return types.MustCreateDecimalType(types.DecimalTypeMaxPrecision, divIntPrecInc) } return types.MustCreateDecimalType(types.DecimalTypeMaxPrecision, divPrecInc) } diff --git a/sql/expression/in.go b/sql/expression/in.go index c337e031bb..d9c695fa42 100644 --- a/sql/expression/in.go +++ b/sql/expression/in.go @@ -54,10 +54,8 @@ func (in *InTuple) Right() sql.Expression { // NewInTuple creates an InTuple expression. func NewInTuple(left sql.Expression, right sql.Expression) *InTuple { - setArithmeticOps(left, -1) - setArithmeticOps(right, -1) - setDivs(left, -1) - setDivs(right, -1) + disableRounding(left) + disableRounding(right) return &InTuple{BinaryExpressionStub{left, right}} } From ae8cb67bd46714571e46da1ed104bdbbbf46e74c Mon Sep 17 00:00:00 2001 From: jycor Date: Thu, 15 Feb 2024 21:08:55 +0000 Subject: [PATCH 19/27] [ga-format-pr] Run ./format_repo.sh to fix formatting --- sql/expression/div.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/expression/div.go b/sql/expression/div.go index 79a7540c50..7d067b6faf 100644 --- a/sql/expression/div.go +++ b/sql/expression/div.go @@ -305,7 +305,7 @@ func (d *Div) determineResultType(outermostResult bool) sql.Type { prec, scale := lTyp.(sql.DecimalType).Precision(), lTyp.(sql.DecimalType).Scale() scale = scale + divPrecInc if d.ops == -1 { - scale = (scale / divIntPrecInc + 1) * divIntPrecInc + scale = (scale/divIntPrecInc + 1) * divIntPrecInc prec = prec + scale } else { prec = prec + divPrecInc From d53a2620085f00c68f910125f3a62e12c7114d27 Mon Sep 17 00:00:00 2001 From: James Cor Date: Thu, 15 Feb 2024 13:28:56 -0800 Subject: [PATCH 20/27] improving comment --- sql/expression/arithmetic.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/expression/arithmetic.go b/sql/expression/arithmetic.go index 2f172f7ac4..0390dcddf1 100644 --- a/sql/expression/arithmetic.go +++ b/sql/expression/arithmetic.go @@ -135,7 +135,7 @@ func (a *Arithmetic) Type() sql.Type { // applies for + and - ops if isInterval(a.LeftChild) || isInterval(a.RightChild) { // TODO: need to use the precision stored in datetimeType; something like - // return MustCreateDatetimeType(sqltypes.Datetime, ...) + // return types.MustCreateDatetimeType(sqltypes.Datetime, 0) return types.Datetime } From 5cc9d47f48020e1a04e47255124f3412e1e85ec6 Mon Sep 17 00:00:00 2001 From: James Cor Date: Thu, 15 Feb 2024 13:56:59 -0800 Subject: [PATCH 21/27] tidying up variable names --- sql/expression/arithmetic.go | 10 +-- sql/expression/comparison.go | 2 +- sql/expression/convert.go | 6 +- sql/expression/div.go | 160 ++++++++++++----------------------- 4 files changed, 62 insertions(+), 116 deletions(-) diff --git a/sql/expression/arithmetic.go b/sql/expression/arithmetic.go index 0390dcddf1..9488dbe517 100644 --- a/sql/expression/arithmetic.go +++ b/sql/expression/arithmetic.go @@ -397,20 +397,20 @@ func countArithmeticOps(e sql.Expression) int32 { // setArithmeticOps will set ops number with number counted by countArithmeticOps. This allows // us to keep track of whether the expression is the last arithmetic operation. -func setArithmeticOps(e sql.Expression, opScale int32) { +func setArithmeticOps(e sql.Expression, ops int32) { if e == nil { return } if a, ok := e.(ArithmeticOp); ok { - a.SetOpCount(opScale) - setArithmeticOps(a.Left(), opScale) - setArithmeticOps(a.Right(), opScale) + a.SetOpCount(ops) + setArithmeticOps(a.Left(), ops) + setArithmeticOps(a.Right(), ops) } if tup, ok := e.(Tuple); ok { for _, expr := range tup { - setArithmeticOps(expr, opScale) + setArithmeticOps(expr, ops) } } diff --git a/sql/expression/comparison.go b/sql/expression/comparison.go index 40d794c031..1aa894b93c 100644 --- a/sql/expression/comparison.go +++ b/sql/expression/comparison.go @@ -70,7 +70,7 @@ type comparison struct { // disableRounding disables rounding for the given expression. func disableRounding(expr sql.Expression) { setArithmeticOps(expr, -1) - setDivs(expr, -1) + setDivOps(expr, -1) } func newComparison(left, right sql.Expression) comparison { diff --git a/sql/expression/convert.go b/sql/expression/convert.go index 0173fedb93..c4139c3255 100644 --- a/sql/expression/convert.go +++ b/sql/expression/convert.go @@ -82,8 +82,7 @@ var _ sql.CollationCoercible = (*Convert)(nil) // |castToType| type. All optional parameters (i.e. typeLength, typeScale, and charset) are omitted and initialized // to their zero values. func NewConvert(expr sql.Expression, castToType string) *Convert { - setArithmeticOps(expr, -1) - setDivs(expr, -1) + disableRounding(expr) return &Convert{ UnaryExpression: UnaryExpression{Child: expr}, castToType: strings.ToLower(castToType), @@ -94,8 +93,7 @@ func NewConvert(expr sql.Expression, castToType string) *Convert { // |castToType| type, with |typeLength| specifying a length constraint of the converted type, and |typeScale| specifying // a scale constraint of the converted type. func NewConvertWithLengthAndScale(expr sql.Expression, castToType string, typeLength, typeScale int) *Convert { - setArithmeticOps(expr, -1) - setDivs(expr, -1) + disableRounding(expr) return &Convert{ UnaryExpression: UnaryExpression{Child: expr}, castToType: strings.ToLower(castToType), diff --git a/sql/expression/div.go b/sql/expression/div.go index 7d067b6faf..c985540768 100644 --- a/sql/expression/div.go +++ b/sql/expression/div.go @@ -19,8 +19,7 @@ import ( "math" "strconv" "strings" - "sync/atomic" - "time" + "time" "github.com/dolthub/vitess/go/vt/sqlparser" "github.com/shopspring/decimal" @@ -47,27 +46,16 @@ var _ sql.CollationCoercible = (*Div)(nil) // Div expression represents "/" arithmetic operation type Div struct { BinaryExpressionStub - ops int32 - // divScale is number of continuous division operations; this value will be available of all layers - divScale int32 // TODO: calling this divScale is confusing - // leftmostScale is a length of scale of the leftmost value in continuous division operation - // It is accessed concurrently read in the .Type() and written in the .Eval() methods. - leftmostScale atomic.Int32 - curIntermediatePrecisionInc int + ops int32 + divOps int32 } // NewDiv creates a new Div / sql.Expression. func NewDiv(left, right sql.Expression) *Div { - a := &Div{ - BinaryExpressionStub: BinaryExpressionStub{LeftChild: left, RightChild: right}, - curIntermediatePrecisionInc: 0, - } - a.leftmostScale.Store(0) - divs := countDivs(a) - setDivs(a, divs) - ops := countArithmeticOps(a) - setArithmeticOps(a, ops) - return a + d := &Div{BinaryExpressionStub: BinaryExpressionStub{LeftChild: left, RightChild: right}} + setDivOps(d, countDivOps(d)) + setArithmeticOps(d, countArithmeticOps(d)) + return d } func (d *Div) Operator() string { @@ -96,7 +84,7 @@ func (d *Div) IsNullable() bool { // However, if this is the outermost division expression in an expression tree, we must return the result as a // Decimal type in order to match MySQL's results exactly. func (d *Div) Type() sql.Type { - return d.determineResultType(isOutermostDiv(d, 0, d.divScale)) + return d.determineResultType(isOutermostDiv(d, 0, d.divOps)) } // internalType returns the internal result type for this division expression. For performance reasons, we prefer @@ -448,80 +436,41 @@ func convertToDecimalValue(val interface{}, isTimeType bool) interface{} { // 'div' 2 // / \ // 24 3 -func countDivs(e sql.Expression) int32 { +func countDivOps(e sql.Expression) int32 { if e == nil { return 0 } - if a, ok := e.(*Div); ok { - return countDivs(a.LeftChild) + 1 + return countDivOps(a.LeftChild) + 1 } - if a, ok := e.(ArithmeticOp); ok { - return countDivs(a.Left()) + return countDivOps(a.Left()) } - return 0 } // setDivs will set each node's DivScale to the number counted by countDivs. This allows us to // keep track of whether the current Div expression is the last Div operation, so the result is // rounded appropriately. -func setDivs(e sql.Expression, dScale int32) { +func setDivOps(e sql.Expression, divOps int32) { if e == nil { return } - if a, isArithmeticOp := e.(ArithmeticOp); isArithmeticOp { if d, ok := a.(*Div); ok { - d.divScale = dScale + d.divOps = divOps } - setDivs(a.Left(), dScale) - setDivs(a.Right(), dScale) + setDivOps(a.Left(), divOps) + setDivOps(a.Right(), divOps) } - if tup, ok := e.(Tuple); ok { for _, expr := range tup { - setDivs(expr, dScale) + setDivOps(expr, divOps) } } - return } -// getScaleOfLeftmostValue find the leftmost/first value of all continuous divisions. -// E.g. 24/50/3.2/2/1 will return 2 for len('50') of number '24.50'. -func getScaleOfLeftmostValue(ctx *sql.Context, row sql.Row, e sql.Expression, d, dScale int32) int32 { - if e == nil { - return 0 - } - - if a, ok := e.(*Div); ok { - d = d + 1 - if d == dScale { - lval, err := a.LeftChild.Eval(ctx, row) - if err != nil { - return 0 - } - _, s := GetPrecisionAndScale(lval) - // the leftmost value can be row value of decimal type column - // the evaluated value does not always match the scale of column type definition - typ := a.LeftChild.Type() - if dt, dok := typ.(sql.DecimalType); dok { - ts := dt.Scale() - if ts > s { - s = ts - } - } - return int32(s) - } else { - return getScaleOfLeftmostValue(ctx, row, a.LeftChild, d, dScale) - } - } - - return 0 -} - // isOutermostDiv returns whether the expression we're currently evaluating is // the last division operation of all continuous divisions. // E.g. the top 'div' (divided by 1) is the outermost/last division that is calculated: @@ -555,16 +504,16 @@ func isOutermostDiv(e sql.Expression, d, dScale int32) bool { // getFinalScale returns the final scale of the result value. // it traverses both the left and right nodes looking for Div, Arithmetic, and Literal nodes -func getFinalScale(ctx *sql.Context, row sql.Row, e sql.Expression, d int32) (int32, bool) { - if e == nil { +func getFinalScale(ctx *sql.Context, row sql.Row, expr sql.Expression, divOpCnt int32) (int32, bool) { + if expr == nil { return 0, false } - if div, isDiv := e.(*Div); isDiv { + if div, isDiv := expr.(*Div); isDiv { // TODO: there's gotta be a better way of determining if this is the leftmost div... - finalScale := int32(divPrecInc) - d = d + 1 - if d == div.divScale { + fScale := int32(divPrecInc) + divOpCnt = divOpCnt + 1 + if divOpCnt == div.divOps { // TODO: redundant call to Eval for LeftChild lval, err := div.LeftChild.Eval(ctx, row) if err != nil { @@ -578,68 +527,67 @@ func getFinalScale(ctx *sql.Context, row sql.Row, e sql.Expression, d int32) (in s = ts } } - finalScale += int32(s) + fScale += int32(s) } else { // We only care about left scale for divs - leftScale, _ := getFinalScale(ctx, row, div.LeftChild, d) - finalScale += leftScale + lScale, _ := getFinalScale(ctx, row, div.LeftChild, divOpCnt) + fScale += lScale } - if finalScale > types.DecimalTypeMaxScale { - finalScale = types.DecimalTypeMaxScale + if fScale > types.DecimalTypeMaxScale { + fScale = types.DecimalTypeMaxScale } - return finalScale, true + return fScale, true } - if a, isArith := e.(*Arithmetic); isArith { - leftScale, leftHasDiv := getFinalScale(ctx, row, a.Left(), d) - rightScale, rightHasDiv := getFinalScale(ctx, row, a.Right(), d) - var finalScale int32 + if a, isArith := expr.(*Arithmetic); isArith { + lScale, lHasDiv := getFinalScale(ctx, row, a.Left(), divOpCnt) + rScale, rHasDiv := getFinalScale(ctx, row, a.Right(), divOpCnt) + var fScale int32 switch a.Operator() { case sqlparser.PlusStr, sqlparser.MinusStr: - if leftScale > rightScale { - finalScale = leftScale + if lScale > rScale { + fScale = lScale } else { - finalScale = rightScale + fScale = rScale } case sqlparser.MultStr: - finalScale = leftScale + rightScale + fScale = lScale + rScale } - if finalScale > types.DecimalTypeMaxScale { - finalScale = types.DecimalTypeMaxScale + if fScale > types.DecimalTypeMaxScale { + fScale = types.DecimalTypeMaxScale } - return finalScale, leftHasDiv || rightHasDiv + return fScale, lHasDiv || rHasDiv } // TODO: this is just a guess of what mod should do with scale; test this - if m, isMod := e.(*Mod); isMod { - leftScale, leftHasDiv := getFinalScale(ctx, row, m.LeftChild, d) - rightScale, rightHasDiv := getFinalScale(ctx, row, m.RightChild, d) - finalScale := leftScale - if rightScale > finalScale { - finalScale = rightScale + if m, isMod := expr.(*Mod); isMod { + fScale, leftHasDiv := getFinalScale(ctx, row, m.LeftChild, divOpCnt) + rScale, rightHasDiv := getFinalScale(ctx, row, m.RightChild, divOpCnt) + if rScale > fScale { + fScale = rScale } - if finalScale > types.DecimalTypeMaxScale { - finalScale = types.DecimalTypeMaxScale + if fScale > types.DecimalTypeMaxScale { + fScale = types.DecimalTypeMaxScale } - return finalScale, leftHasDiv || rightHasDiv + return fScale, leftHasDiv || rightHasDiv } // TODO: likely need a case for IntDiv - s := uint8(0) - if lit, isLit := e.(*Literal); isLit { - _, s = GetPrecisionAndScale(lit.value) + var fScale uint8 + if lit, isLit := expr.(*Literal); isLit { + _, fScale = GetPrecisionAndScale(lit.value) } - typ := e.Type() + typ := expr.Type() if dt, dok := typ.(sql.DecimalType); dok { ts := dt.Scale() - if ts > s { - s = ts + if ts > fScale { + fScale = ts } } - return int32(s), false + return int32(fScale), false } // GetDecimalPrecisionAndScale returns precision and scale for given string formatted float/double number. From af48541e5637a25719505864a7e86ee7f8c2d652 Mon Sep 17 00:00:00 2001 From: jycor Date: Thu, 15 Feb 2024 21:58:33 +0000 Subject: [PATCH 22/27] [ga-format-pr] Run ./format_repo.sh to fix formatting --- sql/expression/div.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/expression/div.go b/sql/expression/div.go index c985540768..7606e040cc 100644 --- a/sql/expression/div.go +++ b/sql/expression/div.go @@ -19,7 +19,7 @@ import ( "math" "strconv" "strings" - "time" + "time" "github.com/dolthub/vitess/go/vt/sqlparser" "github.com/shopspring/decimal" @@ -52,7 +52,7 @@ type Div struct { // NewDiv creates a new Div / sql.Expression. func NewDiv(left, right sql.Expression) *Div { - d := &Div{BinaryExpressionStub: BinaryExpressionStub{LeftChild: left, RightChild: right}} + d := &Div{BinaryExpressionStub: BinaryExpressionStub{LeftChild: left, RightChild: right}} setDivOps(d, countDivOps(d)) setArithmeticOps(d, countArithmeticOps(d)) return d From a77c57a11bad6c0eb06dab576868b73255ef338b Mon Sep 17 00:00:00 2001 From: James Cor Date: Thu, 15 Feb 2024 14:00:17 -0800 Subject: [PATCH 23/27] small clean --- sql/expression/comparison.go | 1 - sql/expression/div.go | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/sql/expression/comparison.go b/sql/expression/comparison.go index 1aa894b93c..df6ac73983 100644 --- a/sql/expression/comparison.go +++ b/sql/expression/comparison.go @@ -74,7 +74,6 @@ func disableRounding(expr sql.Expression) { } func newComparison(left, right sql.Expression) comparison { - // TODO: somewhat hacky way to disable rounding for comparisons disableRounding(left) disableRounding(right) return comparison{BinaryExpressionStub{left, right}} diff --git a/sql/expression/div.go b/sql/expression/div.go index c985540768..6636a9a7c0 100644 --- a/sql/expression/div.go +++ b/sql/expression/div.go @@ -19,7 +19,7 @@ import ( "math" "strconv" "strings" - "time" + "time" "github.com/dolthub/vitess/go/vt/sqlparser" "github.com/shopspring/decimal" From b241aaa763385e0df602689bea0837138224cada Mon Sep 17 00:00:00 2001 From: James Cor Date: Thu, 15 Feb 2024 15:01:50 -0800 Subject: [PATCH 24/27] microbenchmarks --- sql/expression/div_test.go | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/sql/expression/div_test.go b/sql/expression/div_test.go index a76b047a44..d40cd7351f 100644 --- a/sql/expression/div_test.go +++ b/sql/expression/div_test.go @@ -373,3 +373,36 @@ func TestIntDiv(t *testing.T) { }) } } + + +func BenchmarkDiv(b *testing.B) { + require := require.New(b) + ctx := sql.NewEmptyContext() + l := NewLiteral(1, types.Int64) + r := NewLiteral(3, types.Int64) + d := NewDiv(l, r) + for i := 0; i < b.N; i++ { + res, err := d.Eval(ctx, nil) + require.NoError(err) + if dec, ok := res.(decimal.Decimal); ok { + res = dec.StringFixed(dec.Exponent() * -1) + } + require.Equal(res, "0.3333") + } +} + +func BenchmarkDivFloat(b *testing.B) { + require := require.New(b) + ctx := sql.NewEmptyContext() + l := NewLiteral(1.0, types.Float64) + r := NewLiteral(3.0, types.Float64) + d := NewDiv(l, r) + for i := 0; i < b.N; i++ { + res, err := d.Eval(ctx, nil) + require.NoError(err) + if dec, ok := res.(decimal.Decimal); ok { + res = dec.StringFixed(dec.Exponent() * -1) + } + require.Equal(res, 1.0/3.0) + } +} \ No newline at end of file From 69ec5f378379f6bbd4b516562e70f4dd3ecf960c Mon Sep 17 00:00:00 2001 From: James Cor Date: Thu, 15 Feb 2024 16:51:05 -0800 Subject: [PATCH 25/27] benchmarks --- sql/expression/div_test.go | 140 ++++++++++++++++++++++++++++++++----- 1 file changed, 123 insertions(+), 17 deletions(-) diff --git a/sql/expression/div_test.go b/sql/expression/div_test.go index d40cd7351f..7042916384 100644 --- a/sql/expression/div_test.go +++ b/sql/expression/div_test.go @@ -375,34 +375,140 @@ func TestIntDiv(t *testing.T) { } -func BenchmarkDiv(b *testing.B) { +// Results: +// BenchmarkDivInt-16 365416 3117 ns/op +func BenchmarkDivInt(b *testing.B) { require := require.New(b) ctx := sql.NewEmptyContext() - l := NewLiteral(1, types.Int64) - r := NewLiteral(3, types.Int64) - d := NewDiv(l, r) + div := NewDiv( + NewLiteral(1, types.Int64), + NewLiteral(3, types.Int64), + ) + var res interface{} + var err error for i := 0; i < b.N; i++ { - res, err := d.Eval(ctx, nil) + res, err = div.Eval(ctx, nil) require.NoError(err) - if dec, ok := res.(decimal.Decimal); ok { - res = dec.StringFixed(dec.Exponent() * -1) - } - require.Equal(res, "0.3333") + } + if dec, ok := res.(decimal.Decimal); ok { + res = dec.StringFixed(dec.Exponent() * -1) + } + exp := "0.3333" + if res != exp { + b.Logf("Expected %v, got %v", exp, res) } } +// Results: +// BenchmarkDivFloat-16 1521937 787.7 ns/op func BenchmarkDivFloat(b *testing.B) { require := require.New(b) ctx := sql.NewEmptyContext() - l := NewLiteral(1.0, types.Float64) - r := NewLiteral(3.0, types.Float64) - d := NewDiv(l, r) + div := NewDiv( + NewLiteral(1.0, types.Float64), + NewLiteral(3.0, types.Float64), + ) + var res interface{} + var err error for i := 0; i < b.N; i++ { - res, err := d.Eval(ctx, nil) + res, err = div.Eval(ctx, nil) require.NoError(err) - if dec, ok := res.(decimal.Decimal); ok { - res = dec.StringFixed(dec.Exponent() * -1) - } - require.Equal(res, 1.0/3.0) + } + exp := 1.0/3.0 + if res != exp { + b.Logf("Expected %v, got %v", exp, res) + } +} + +// Results: +// BenchmarkDivHighScaleDecimals-16 294921 3901 ns/op +func BenchmarkDivHighScaleDecimals(b *testing.B) { + require := require.New(b) + ctx := sql.NewEmptyContext() + div := NewDiv( + NewLiteral(decimal.NewFromFloat(0.123456789), types.MustCreateDecimalType(types.DecimalTypeMaxPrecision, types.DecimalTypeMaxScale)), + NewLiteral(decimal.NewFromFloat(0.987654321), types.MustCreateDecimalType(types.DecimalTypeMaxPrecision, types.DecimalTypeMaxScale)), + ) + var res interface{} + var err error + for i := 0; i < b.N; i++ { + res, err = div.Eval(ctx, nil) + require.NoError(err) + } + if dec, ok := res.(decimal.Decimal); ok { + res = dec.StringFixed(dec.Exponent() * -1) + } + exp := "0.124999998860937500014238281250" + if res != exp { + b.Logf("Expected %v, got %v", exp, res) + } +} + +// Results: +// BenchmarkDivManyInts-16 40711 29372 ns/op +func BenchmarkDivManyInts(b *testing.B) { + require := require.New(b) + var div sql.Expression = NewLiteral(1, types.Int64) + for i := 2; i < 10; i++ { + div = NewDiv(div, NewLiteral(int64(i), types.Int64)) + } + ctx := sql.NewEmptyContext() + var res interface{} + var err error + for i := 0; i < b.N; i++ { + res, err = div.Eval(ctx, nil) + require.NoError(err) + } + if dec, ok := res.(decimal.Decimal); ok { + res = dec.StringFixed(dec.Exponent() * -1) + } + exp := "0.000002755731922398589054232804" + if res != exp { + b.Logf("Expected %v, got %v", exp, res) + } +} + +// Results: +// BenchmarkManyFloats-16 174555 6666 ns/op +func BenchmarkManyFloats(b *testing.B) { + require := require.New(b) + ctx := sql.NewEmptyContext() + var div sql.Expression = NewLiteral(1.0, types.Float64) + for i := 2; i < 10; i++ { + div = NewDiv(div, NewLiteral(float64(i), types.Float64)) + } + var res interface{} + var err error + for i := 0; i < b.N; i++ { + res, err = div.Eval(ctx, nil) + require.NoError(err) + } + exp := 1.0/2.0/3.0/4.0/5.0/6.0/7.0/8.0/9.0 + if res != exp { + b.Logf("Expected %v, got %v", exp, res) + } +} + +// Results: +// BenchmarkDivManyDecimals-16 52053 23134 ns/op +func BenchmarkDivManyDecimals(b *testing.B) { + require := require.New(b) + var div sql.Expression = NewLiteral(decimal.NewFromInt(int64(1)), types.DecimalType_{}) + for i := 2; i < 10; i++ { + div = NewDiv(div, NewLiteral(decimal.NewFromInt(int64(i)), types.DecimalType_{})) + } + ctx := sql.NewEmptyContext() + var res interface{} + var err error + for i := 0; i < b.N; i++ { + res, err = div.Eval(ctx, nil) + require.NoError(err) + } + if dec, ok := res.(decimal.Decimal); ok { + res = dec.StringFixed(dec.Exponent() * -1) + } + exp := "0.000002755731922398589054232804" + if res != exp { + b.Logf("Expected %v, got %v", exp, res) } } \ No newline at end of file From f769a015421872189539ff81d3159da1fa0329ca Mon Sep 17 00:00:00 2001 From: James Cor Date: Thu, 15 Feb 2024 16:51:31 -0800 Subject: [PATCH 26/27] new line --- sql/expression/div_test.go | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/expression/div_test.go b/sql/expression/div_test.go index 7042916384..ca9e82338d 100644 --- a/sql/expression/div_test.go +++ b/sql/expression/div_test.go @@ -374,7 +374,6 @@ func TestIntDiv(t *testing.T) { } } - // Results: // BenchmarkDivInt-16 365416 3117 ns/op func BenchmarkDivInt(b *testing.B) { From 7b5c2f1f7a115ebab07acf8ae6a08da1663e2af5 Mon Sep 17 00:00:00 2001 From: jycor Date: Fri, 16 Feb 2024 00:58:21 +0000 Subject: [PATCH 27/27] [ga-format-pr] Run ./format_repo.sh to fix formatting --- sql/expression/div_test.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/expression/div_test.go b/sql/expression/div_test.go index ca9e82338d..6ef9f95470 100644 --- a/sql/expression/div_test.go +++ b/sql/expression/div_test.go @@ -413,7 +413,7 @@ func BenchmarkDivFloat(b *testing.B) { res, err = div.Eval(ctx, nil) require.NoError(err) } - exp := 1.0/3.0 + exp := 1.0 / 3.0 if res != exp { b.Logf("Expected %v, got %v", exp, res) } @@ -482,7 +482,7 @@ func BenchmarkManyFloats(b *testing.B) { res, err = div.Eval(ctx, nil) require.NoError(err) } - exp := 1.0/2.0/3.0/4.0/5.0/6.0/7.0/8.0/9.0 + exp := 1.0 / 2.0 / 3.0 / 4.0 / 5.0 / 6.0 / 7.0 / 8.0 / 9.0 if res != exp { b.Logf("Expected %v, got %v", exp, res) } @@ -510,4 +510,4 @@ func BenchmarkDivManyDecimals(b *testing.B) { if res != exp { b.Logf("Expected %v, got %v", exp, res) } -} \ No newline at end of file +}