From ce90883143367a6067b2f501cca6847371e2692d Mon Sep 17 00:00:00 2001 From: Martin HS Date: Thu, 25 Jul 2024 10:11:27 +0200 Subject: [PATCH] divmod: fix aliasing error, add tests (#180) This change fixes a flaw in `DivMod` related to aliasing of input arguments. --- ternary_test.go | 33 ++++++++++++++++++++++++++++++++- uint256.go | 28 +++++++++++++++++----------- 2 files changed, 49 insertions(+), 12 deletions(-) diff --git a/ternary_test.go b/ternary_test.go index eb5271c..5822e4f 100644 --- a/ternary_test.go +++ b/ternary_test.go @@ -21,6 +21,8 @@ var ternaryOpFuncs = []struct { {"AddMod", (*Int).AddMod, bigAddMod}, {"MulMod", (*Int).MulMod, bigMulMod}, {"MulModWithReciprocal", (*Int).mulModWithReciprocalWrapper, bigMulMod}, + {"DivModZ", divModZ, bigDivModZ}, + {"DivModM", divModM, bigDivModM}, } func checkTernaryOperation(t *testing.T, opName string, op opThreeArgFunc, bigOp bigThreeArgFunc, x, y, z Int) { @@ -49,7 +51,10 @@ func checkTernaryOperation(t *testing.T, opName string, op opThreeArgFunc, bigOp t.Fatalf("%v\nsecond argument had been modified: %x", operation, f2) } if !f3.Eq(f3orig) { - t.Fatalf("%v\nthird argument had been modified: %x", operation, f3) + if opName != "DivModZ" && opName != "DivModM" { + // DivMod takes m as third argument, modifies it, and returns it. That is by design. + t.Fatalf("%v\nthird argument had been modified: %x", operation, f3) + } } // Check if reusing args as result works correctly. if have = op(f1, f1, f2orig, f3orig); have != f1 { @@ -117,3 +122,29 @@ func (z *Int) mulModWithReciprocalWrapper(x, y, mod *Int) *Int { mu := Reciprocal(mod) return z.MulModWithReciprocal(x, y, mod, &mu) } + +func divModZ(z, x, y, m *Int) *Int { + z2, _ := z.DivMod(x, y, m) + return z2 +} + +func bigDivModZ(result, x, y, mod *big.Int) *big.Int { + if y.Sign() == 0 { + return result.SetUint64(0) + } + z2, _ := result.DivMod(x, y, mod) + return z2 +} + +func divModM(z, x, y, m *Int) *Int { + _, m2 := z.DivMod(x, y, m) + return z.Set(m2) +} + +func bigDivModM(result, x, y, mod *big.Int) *big.Int { + if y.Sign() == 0 { + return result.SetUint64(0) + } + _, m2 := result.DivMod(x, y, mod) + return result.Set(m2) +} diff --git a/uint256.go b/uint256.go index 10dfcb9..9033594 100644 --- a/uint256.go +++ b/uint256.go @@ -369,9 +369,9 @@ func umul(x, y *Int, res *[8]uint64) { func (z *Int) Mul(x, y *Int) *Int { var ( carry0, carry1, carry2 uint64 - res1, res2 uint64 - x0, x1, x2, x3 = x[0], x[1], x[2], x[3] - y0, y1, y2, y3 = y[0], y[1], y[2], y[3] + res1, res2 uint64 + x0, x1, x2, x3 = x[0], x[1], x[2], x[3] + y0, y1, y2, y3 = y[0], y[1], y[2], y[3] ) carry0, z[0] = bits.Mul64(x0, y0) @@ -610,6 +610,11 @@ func (z *Int) Mod(x, y *Int) *Int { // DivMod sets z to the quotient x div y and m to the modulus x mod y and returns the pair (z, m) for y != 0. // If y == 0, both z and m are set to 0 (OBS: differs from the big.Int) func (z *Int) DivMod(x, y, m *Int) (*Int, *Int) { + if z == m { + // We return both z and m as results, if they are aliased, we have to + // un-alias them to be able to return separate results. + m = new(Int).Set(m) + } if y.IsZero() { return z.Clear(), m.Clear() } @@ -617,7 +622,8 @@ func (z *Int) DivMod(x, y, m *Int) (*Int, *Int) { return z.SetOne(), m.Clear() } if x.Lt(y) { - return z.Clear(), m.Set(x) + m.Set(x) + return z.Clear(), m } // At this point: @@ -1279,7 +1285,7 @@ func (z *Int) Sqrt(x *Int) *Int { return z.SetUint64(x0) } for { - z2 = (z1 + x0 / z1) >> 1 + z2 = (z1 + x0/z1) >> 1 if z2 >= z1 { return z.SetUint64(z1) } @@ -1291,18 +1297,18 @@ func (z *Int) Sqrt(x *Int) *Int { z2 := NewInt(0) // Start with value known to be too large and repeat "z = ⌊(z + ⌊x/z⌋)/2⌋" until it stops getting smaller. - z1.Lsh(z1, uint(x.BitLen() + 1) / 2) // must be ≥ √x + z1.Lsh(z1, uint(x.BitLen()+1)/2) // must be ≥ √x // We can do the first division outside the loop - z2.Rsh(x, uint(x.BitLen() + 1) / 2) // The first div is equal to a right shift + z2.Rsh(x, uint(x.BitLen()+1)/2) // The first div is equal to a right shift for { z2.Add(z2, z1) - + // z2 = z2.Rsh(z2, 1) -- the code below does a 1-bit rsh faster - z2[0] = (z2[0] >> 1) | z2[1] << 63 - z2[1] = (z2[1] >> 1) | z2[2] << 63 - z2[2] = (z2[2] >> 1) | z2[3] << 63 + z2[0] = (z2[0] >> 1) | z2[1]<<63 + z2[1] = (z2[1] >> 1) | z2[2]<<63 + z2[2] = (z2[2] >> 1) | z2[3]<<63 z2[3] >>= 1 if !z2.Lt(z1) {