]> Cypherpunks repositories - gostls13.git/commitdiff
math/bits: add extended precision Add, Sub, Mul, Div
authorBrian Kessler <brian.m.kessler@gmail.com>
Tue, 14 Aug 2018 22:39:13 +0000 (16:39 -0600)
committerRobert Griesemer <gri@golang.org>
Tue, 11 Sep 2018 17:54:20 +0000 (17:54 +0000)
Port math/big pure go versions of add-with-carry, subtract-with-borrow,
full-width multiply, and full-width divide.

Updates #24813

Change-Id: Ifae5d2f6ee4237137c9dcba931f69c91b80a4b1c
Reviewed-on: https://go-review.googlesource.com/123157
Reviewed-by: Robert Griesemer <gri@golang.org>
Run-TryBot: Robert Griesemer <gri@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>

src/math/bits/bits.go
src/math/bits/bits_test.go

index 989baacc13fc17848fd63d86f5096eed5196deb3..58cf52d2a74582d427267ccadeb6288d453e1848 100644 (file)
@@ -328,3 +328,197 @@ func Len64(x uint64) (n int) {
        }
        return n + int(len8tab[x])
 }
+
+// --- Add with carry ---
+
+// Add returns the sum with carry of x, y and carry: sum = x + y + carry.
+// The carry input must be 0 or 1; otherwise the behavior is undefined.
+// The carryOut output is guaranteed to be 0 or 1.
+func Add(x, y, carry uint) (sum, carryOut uint) {
+       yc := y + carry
+       sum = x + yc
+       if sum < x || yc < y {
+               carryOut = 1
+       }
+       return
+}
+
+// Add32 returns the sum with carry of x, y and carry: sum = x + y + carry.
+// The carry input must be 0 or 1; otherwise the behavior is undefined.
+// The carryOut output is guaranteed to be 0 or 1.
+func Add32(x, y, carry uint32) (sum, carryOut uint32) {
+       yc := y + carry
+       sum = x + yc
+       if sum < x || yc < y {
+               carryOut = 1
+       }
+       return
+}
+
+// Add64 returns the sum with carry of x, y and carry: sum = x + y + carry.
+// The carry input must be 0 or 1; otherwise the behavior is undefined.
+// The carryOut output is guaranteed to be 0 or 1.
+func Add64(x, y, carry uint64) (sum, carryOut uint64) {
+       yc := y + carry
+       sum = x + yc
+       if sum < x || yc < y {
+               carryOut = 1
+       }
+       return
+}
+
+// --- Subtract with borrow ---
+
+// Sub returns the difference of x, y and borrow: diff = x - y - borrow.
+// The borrow input must be 0 or 1; otherwise the behavior is undefined.
+// The borrowOut output is guaranteed to be 0 or 1.
+func Sub(x, y, borrow uint) (diff, borrowOut uint) {
+       yb := y + borrow
+       diff = x - yb
+       if diff > x || yb < y {
+               borrowOut = 1
+       }
+       return
+}
+
+// Sub32 returns the difference of x, y and borrow, diff = x - y - borrow.
+// The borrow input must be 0 or 1; otherwise the behavior is undefined.
+// The borrowOut output is guaranteed to be 0 or 1.
+func Sub32(x, y, borrow uint32) (diff, borrowOut uint32) {
+       yb := y + borrow
+       diff = x - yb
+       if diff > x || yb < y {
+               borrowOut = 1
+       }
+       return
+}
+
+// Sub64 returns the difference of x, y and borrow: diff = x - y - borrow.
+// The borrow input must be 0 or 1; otherwise the behavior is undefined.
+// The borrowOut output is guaranteed to be 0 or 1.
+func Sub64(x, y, borrow uint64) (diff, borrowOut uint64) {
+       yb := y + borrow
+       diff = x - yb
+       if diff > x || yb < y {
+               borrowOut = 1
+       }
+       return
+}
+
+// --- Full-width multiply ---
+
+// Mul returns the full-width product of x and y: (hi, lo) = x * y
+// with the product bits' upper half returned in hi and the lower
+// half returned in lo.
+func Mul(x, y uint) (hi, lo uint) {
+       if UintSize == 32 {
+               h, l := Mul32(uint32(x), uint32(y))
+               return uint(h), uint(l)
+       }
+       h, l := Mul64(uint64(x), uint64(y))
+       return uint(h), uint(l)
+}
+
+// Mul32 returns the 64-bit product of x and y: (hi, lo) = x * y
+// with the product bits' upper half returned in hi and the lower
+// half returned in lo.
+func Mul32(x, y uint32) (hi, lo uint32) {
+       tmp := uint64(x) * uint64(y)
+       hi, lo = uint32(tmp>>32), uint32(tmp)
+       return
+}
+
+// Mul64 returns the 128-bit product of x and y: (hi, lo) = x * y
+// with the product bits' upper half returned in hi and the lower
+// half returned in lo.
+func Mul64(x, y uint64) (hi, lo uint64) {
+       const mask32 = 1<<32 - 1
+       x0 := x & mask32
+       x1 := x >> 32
+       y0 := y & mask32
+       y1 := y >> 32
+       w0 := x0 * y0
+       t := x1*y0 + w0>>32
+       w1 := t & mask32
+       w2 := t >> 32
+       w1 += x0 * y1
+       hi = x1*y1 + w2 + w1>>32
+       lo = x * y
+       return
+}
+
+// --- Full-width divide ---
+
+// Div returns the quotient and remainder of (hi, lo) divided by y:
+// quo = (hi, lo)/y, rem = (hi, lo)%y with the dividend bits' upper
+// half in parameter hi and the lower half in parameter lo.
+// hi must be < y otherwise the behavior is undefined (the quotient
+// won't fit into quo).
+func Div(hi, lo, y uint) (quo, rem uint) {
+       if UintSize == 32 {
+               q, r := Div32(uint32(hi), uint32(lo), uint32(y))
+               return uint(q), uint(r)
+       }
+       q, r := Div64(uint64(hi), uint64(lo), uint64(y))
+       return uint(q), uint(r)
+}
+
+// Div32 returns the quotient and remainder of (hi, lo) divided by y:
+// quo = (hi, lo)/y, rem = (hi, lo)%y with the dividend bits' upper
+// half in parameter hi and the lower half in parameter lo.
+// hi must be < y otherwise the behavior is undefined (the quotient
+// won't fit into quo).
+func Div32(hi, lo, y uint32) (quo, rem uint32) {
+       z := uint64(hi)<<32 | uint64(lo)
+       quo, rem = uint32(z/uint64(y)), uint32(z%uint64(y))
+       return
+}
+
+// Div64 returns the quotient and remainder of (hi, lo) divided by y:
+// quo = (hi, lo)/y, rem = (hi, lo)%y with the dividend bits' upper
+// half in parameter hi and the lower half in parameter lo.
+// hi must be < y otherwise the behavior is undefined (the quotient
+// won't fit into quo).
+func Div64(hi, lo, y uint64) (quo, rem uint64) {
+       const (
+               two32  = 1 << 32
+               mask32 = two32 - 1
+       )
+       if hi >= y {
+               return 1<<64 - 1, 1<<64 - 1
+       }
+
+       s := uint(LeadingZeros64(y))
+       y <<= s
+
+       yn1 := y >> 32
+       yn0 := y & mask32
+       un32 := hi<<s | lo>>(64-s)
+       un10 := lo << s
+       un1 := un10 >> 32
+       un0 := un10 & mask32
+       q1 := un32 / yn1
+       rhat := un32 - q1*yn1
+
+       for q1 >= two32 || q1*yn0 > two32*rhat+un1 {
+               q1--
+               rhat += yn1
+               if rhat >= two32 {
+                       break
+               }
+       }
+
+       un21 := un32*two32 + un1 - q1*y
+       q0 := un21 / yn1
+       rhat = un21 - q0*yn1
+
+       for q0 >= two32 || q0*yn0 > two32*rhat+un0 {
+               q0--
+               rhat += yn1
+               if rhat >= two32 {
+                       break
+               }
+       }
+
+       return q1*two32 + q0, (un21*two32 + un0 - q0*y) >> s
+}
index 5c34f6dbf7f841591d057b5f20bdade6708c4007..bd6b618f3557d115ae78ff769c2d53177bb53e95 100644 (file)
@@ -705,6 +705,272 @@ func TestLen(t *testing.T) {
        }
 }
 
+const (
+       _M   = 1<<UintSize - 1
+       _M32 = 1<<32 - 1
+       _M64 = 1<<64 - 1
+)
+
+func TestAddSubUint(t *testing.T) {
+       test := func(msg string, f func(x, y, c uint) (z, cout uint), x, y, c, z, cout uint) {
+               z1, cout1 := f(x, y, c)
+               if z1 != z || cout1 != cout {
+                       t.Errorf("%s: got z:cout = %#x:%#x; want %#x:%#x", msg, z1, cout1, z, cout)
+               }
+       }
+       for _, a := range []struct{ x, y, c, z, cout uint }{
+               {0, 0, 0, 0, 0},
+               {0, 1, 0, 1, 0},
+               {0, 0, 1, 1, 0},
+               {0, 1, 1, 2, 0},
+               {12345, 67890, 0, 80235, 0},
+               {12345, 67890, 1, 80236, 0},
+               {_M, 1, 0, 0, 1},
+               {_M, 0, 1, 0, 1},
+               {_M, 1, 1, 1, 1},
+               {_M, _M, 0, _M - 1, 1},
+               {_M, _M, 1, _M, 1},
+       } {
+               test("Add", Add, a.x, a.y, a.c, a.z, a.cout)
+               test("Add symmetric", Add, a.y, a.x, a.c, a.z, a.cout)
+               test("Sub", Sub, a.z, a.x, a.c, a.y, a.cout)
+               test("Sub symmetric", Sub, a.z, a.y, a.c, a.x, a.cout)
+       }
+}
+
+func TestAddSubUint32(t *testing.T) {
+       test := func(msg string, f func(x, y, c uint32) (z, cout uint32), x, y, c, z, cout uint32) {
+               z1, cout1 := f(x, y, c)
+               if z1 != z || cout1 != cout {
+                       t.Errorf("%s: got z:cout = %#x:%#x; want %#x:%#x", msg, z1, cout1, z, cout)
+               }
+       }
+       for _, a := range []struct{ x, y, c, z, cout uint32 }{
+               {0, 0, 0, 0, 0},
+               {0, 1, 0, 1, 0},
+               {0, 0, 1, 1, 0},
+               {0, 1, 1, 2, 0},
+               {12345, 67890, 0, 80235, 0},
+               {12345, 67890, 1, 80236, 0},
+               {_M32, 1, 0, 0, 1},
+               {_M32, 0, 1, 0, 1},
+               {_M32, 1, 1, 1, 1},
+               {_M32, _M32, 0, _M32 - 1, 1},
+               {_M32, _M32, 1, _M32, 1},
+       } {
+               test("Add32", Add32, a.x, a.y, a.c, a.z, a.cout)
+               test("Add32 symmetric", Add32, a.y, a.x, a.c, a.z, a.cout)
+               test("Sub32", Sub32, a.z, a.x, a.c, a.y, a.cout)
+               test("Sub32 symmetric", Sub32, a.z, a.y, a.c, a.x, a.cout)
+       }
+}
+
+func TestAddSubUint64(t *testing.T) {
+       test := func(msg string, f func(x, y, c uint64) (z, cout uint64), x, y, c, z, cout uint64) {
+               z1, cout1 := f(x, y, c)
+               if z1 != z || cout1 != cout {
+                       t.Errorf("%s: got z:cout = %#x:%#x; want %#x:%#x", msg, z1, cout1, z, cout)
+               }
+       }
+       for _, a := range []struct{ x, y, c, z, cout uint64 }{
+               {0, 0, 0, 0, 0},
+               {0, 1, 0, 1, 0},
+               {0, 0, 1, 1, 0},
+               {0, 1, 1, 2, 0},
+               {12345, 67890, 0, 80235, 0},
+               {12345, 67890, 1, 80236, 0},
+               {_M64, 1, 0, 0, 1},
+               {_M64, 0, 1, 0, 1},
+               {_M64, 1, 1, 1, 1},
+               {_M64, _M64, 0, _M64 - 1, 1},
+               {_M64, _M64, 1, _M64, 1},
+       } {
+               test("Add64", Add64, a.x, a.y, a.c, a.z, a.cout)
+               test("Add64 symmetric", Add64, a.y, a.x, a.c, a.z, a.cout)
+               test("Sub64", Sub64, a.z, a.x, a.c, a.y, a.cout)
+               test("Sub64 symmetric", Sub64, a.z, a.y, a.c, a.x, a.cout)
+       }
+}
+
+func TestMulDiv(t *testing.T) {
+       testMul := func(msg string, f func(x, y uint) (hi, lo uint), x, y, hi, lo uint) {
+               hi1, lo1 := f(x, y)
+               if hi1 != hi || lo1 != lo {
+                       t.Errorf("%s: got hi:lo = %#x:%#x; want %#x:%#x", msg, hi1, lo1, hi, lo)
+               }
+       }
+       testDiv := func(msg string, f func(hi, lo, y uint) (q, r uint), hi, lo, y, q, r uint) {
+               q1, r1 := f(hi, lo, y)
+               if q1 != q || r1 != r {
+                       t.Errorf("%s: got q:r = %#x:%#x; want %#x:%#x", msg, q1, r1, q, r)
+               }
+       }
+       for _, a := range []struct {
+               x, y      uint
+               hi, lo, r uint
+       }{
+               {1 << (UintSize - 1), 2, 1, 0, 1},
+               {_M, _M, _M - 1, 1, 42},
+       } {
+               testMul("Mul", Mul, a.x, a.y, a.hi, a.lo)
+               testMul("Mul symmetric", Mul, a.y, a.x, a.hi, a.lo)
+               testDiv("Div", Div, a.hi, a.lo+a.r, a.y, a.x, a.r)
+               testDiv("Div symmetric", Div, a.hi, a.lo+a.r, a.x, a.y, a.r)
+       }
+}
+
+func TestMulDiv32(t *testing.T) {
+       testMul := func(msg string, f func(x, y uint32) (hi, lo uint32), x, y, hi, lo uint32) {
+               hi1, lo1 := f(x, y)
+               if hi1 != hi || lo1 != lo {
+                       t.Errorf("%s: got hi:lo = %#x:%#x; want %#x:%#x", msg, hi1, lo1, hi, lo)
+               }
+       }
+       testDiv := func(msg string, f func(hi, lo, y uint32) (q, r uint32), hi, lo, y, q, r uint32) {
+               q1, r1 := f(hi, lo, y)
+               if q1 != q || r1 != r {
+                       t.Errorf("%s: got q:r = %#x:%#x; want %#x:%#x", msg, q1, r1, q, r)
+               }
+       }
+       for _, a := range []struct {
+               x, y      uint32
+               hi, lo, r uint32
+       }{
+               {1 << 31, 2, 1, 0, 1},
+               {0xc47dfa8c, 50911, 0x98a4, 0x998587f4, 13},
+               {_M32, _M32, _M32 - 1, 1, 42},
+       } {
+               testMul("Mul32", Mul32, a.x, a.y, a.hi, a.lo)
+               testMul("Mul32 symmetric", Mul32, a.y, a.x, a.hi, a.lo)
+               testDiv("Div32", Div32, a.hi, a.lo+a.r, a.y, a.x, a.r)
+               testDiv("Div32 symmetric", Div32, a.hi, a.lo+a.r, a.x, a.y, a.r)
+       }
+}
+
+func TestMulDiv64(t *testing.T) {
+       testMul := func(msg string, f func(x, y uint64) (hi, lo uint64), x, y, hi, lo uint64) {
+               hi1, lo1 := f(x, y)
+               if hi1 != hi || lo1 != lo {
+                       t.Errorf("%s: got hi:lo = %#x:%#x; want %#x:%#x", msg, hi1, lo1, hi, lo)
+               }
+       }
+       testDiv := func(msg string, f func(hi, lo, y uint64) (q, r uint64), hi, lo, y, q, r uint64) {
+               q1, r1 := f(hi, lo, y)
+               if q1 != q || r1 != r {
+                       t.Errorf("%s: got q:r = %#x:%#x; want %#x:%#x", msg, q1, r1, q, r)
+               }
+       }
+       for _, a := range []struct {
+               x, y      uint64
+               hi, lo, r uint64
+       }{
+               {1 << 63, 2, 1, 0, 1},
+               {0x3626229738a3b9, 0xd8988a9f1cc4a61, 0x2dd0712657fe8, 0x9dd6a3364c358319, 13},
+               {_M64, _M64, _M64 - 1, 1, 42},
+       } {
+               testMul("Mul64", Mul64, a.x, a.y, a.hi, a.lo)
+               testMul("Mul64 symmetric", Mul64, a.y, a.x, a.hi, a.lo)
+               testDiv("Div64", Div64, a.hi, a.lo+a.r, a.y, a.x, a.r)
+               testDiv("Div64 symmetric", Div64, a.hi, a.lo+a.r, a.x, a.y, a.r)
+       }
+}
+
+func BenchmarkAdd(b *testing.B) {
+       var z, c uint
+       for i := 0; i < b.N; i++ {
+               z, c = Add(uint(Input), uint(i), c)
+       }
+       Output = int(z + c)
+}
+
+func BenchmarkAdd32(b *testing.B) {
+       var z, c uint32
+       for i := 0; i < b.N; i++ {
+               z, c = Add32(uint32(Input), uint32(i), c)
+       }
+       Output = int(z + c)
+}
+
+func BenchmarkAdd64(b *testing.B) {
+       var z, c uint64
+       for i := 0; i < b.N; i++ {
+               z, c = Add64(uint64(Input), uint64(i), c)
+       }
+       Output = int(z + c)
+}
+
+func BenchmarkSub(b *testing.B) {
+       var z, c uint
+       for i := 0; i < b.N; i++ {
+               z, c = Sub(uint(Input), uint(i), c)
+       }
+       Output = int(z + c)
+}
+
+func BenchmarkSub32(b *testing.B) {
+       var z, c uint32
+       for i := 0; i < b.N; i++ {
+               z, c = Sub32(uint32(Input), uint32(i), c)
+       }
+       Output = int(z + c)
+}
+
+func BenchmarkSub64(b *testing.B) {
+       var z, c uint64
+       for i := 0; i < b.N; i++ {
+               z, c = Add64(uint64(Input), uint64(i), c)
+       }
+       Output = int(z + c)
+}
+
+func BenchmarkMul(b *testing.B) {
+       var hi, lo uint
+       for i := 0; i < b.N; i++ {
+               hi, lo = Mul(uint(Input), uint(i))
+       }
+       Output = int(hi + lo)
+}
+
+func BenchmarkMul32(b *testing.B) {
+       var hi, lo uint32
+       for i := 0; i < b.N; i++ {
+               hi, lo = Mul32(uint32(Input), uint32(i))
+       }
+       Output = int(hi + lo)
+}
+
+func BenchmarkMul64(b *testing.B) {
+       var hi, lo uint64
+       for i := 0; i < b.N; i++ {
+               hi, lo = Mul64(uint64(Input), uint64(i))
+       }
+       Output = int(hi + lo)
+}
+
+func BenchmarkDiv(b *testing.B) {
+       var q, r uint
+       for i := 0; i < b.N; i++ {
+               q, r = Div(1, uint(i), uint(Input))
+       }
+       Output = int(q + r)
+}
+
+func BenchmarkDiv32(b *testing.B) {
+       var q, r uint32
+       for i := 0; i < b.N; i++ {
+               q, r = Div32(1, uint32(i), uint32(Input))
+       }
+       Output = int(q + r)
+}
+
+func BenchmarkDiv64(b *testing.B) {
+       var q, r uint64
+       for i := 0; i < b.N; i++ {
+               q, r = Div64(1, uint64(i), uint64(Input))
+       }
+       Output = int(q + r)
+}
+
 // ----------------------------------------------------------------------------
 // Testing support