]> Cypherpunks repositories - gostls13.git/commitdiff
big: add Int methods to act on numbered bits.
authorRoger Peppe <rogpeppe@gmail.com>
Tue, 17 May 2011 20:38:21 +0000 (13:38 -0700)
committerRobert Griesemer <gri@golang.org>
Tue, 17 May 2011 20:38:21 +0000 (13:38 -0700)
Speeds up setting individual bits by ~75%, useful
when using big.Int as a bit set.

R=gri, rsc
CC=golang-dev
https://golang.org/cl/4538053

src/pkg/big/int.go
src/pkg/big/int_test.go
src/pkg/big/nat.go

index f1ea7b1c2ec2cf8c524dc444f541b4f276928ec7..dfb7dcdb63fa86a0f55a05e6f4b66d2f2f17603e 100755 (executable)
@@ -560,6 +560,42 @@ func (z *Int) Rsh(x *Int, n uint) *Int {
 }
 
 
+// Bit returns the value of the i'th bit of z. That is, it
+// returns (z>>i)&1. The bit index i must be >= 0.
+func (z *Int) Bit(i int) uint {
+       if i < 0 {
+               panic("negative bit index")
+       }
+       if z.neg {
+               t := nat{}.sub(z.abs, natOne)
+               return t.bit(uint(i)) ^ 1
+       }
+
+       return z.abs.bit(uint(i))
+}
+
+
+// SetBit sets the i'th bit of z to bit and returns z.
+// That is, if bit is 1 SetBit sets z = x | (1 << i);
+// if bit is 0 it sets z = x &^ (1 << i). If bit is not 0 or 1,
+// SetBit will panic.
+func (z *Int) SetBit(x *Int, i int, b uint) *Int {
+       if i < 0 {
+               panic("negative bit index")
+       }
+       if x.neg {
+               t := z.abs.sub(x.abs, natOne)
+               t = t.setBit(t, uint(i), b^1)
+               z.abs = t.add(t, natOne)
+               z.neg = len(z.abs) > 0
+               return z
+       }
+       z.abs = z.abs.setBit(x.abs, uint(i), b)
+       z.neg = false
+       return z
+}
+
+
 // And sets z = x & y and returns z.
 func (z *Int) And(x, y *Int) *Int {
        if x.neg == y.neg {
index 9c19dd5da6ded695207392ff9d64830f092cd678..82fd7564e82feb800ffe2eb39a2561772fd6c521 100755 (executable)
@@ -984,6 +984,142 @@ func testBitFunSelf(t *testing.T, msg string, f bitFun, x, y *Int, exp string) {
        }
 }
 
+func altBit(x *Int, i int) uint {
+       z := new(Int).Rsh(x, uint(i))
+       z = z.And(z, NewInt(1))
+       if z.Cmp(new(Int)) != 0 {
+               return 1
+       }
+       return 0
+}
+
+func altSetBit(z *Int, x *Int, i int, b uint) *Int {
+       one := NewInt(1)
+       m := one.Lsh(one, uint(i))
+       switch b {
+       case 1:
+               return z.Or(x, m)
+       case 0:
+               return z.AndNot(x, m)
+       }
+       panic("set bit is not 0 or 1")
+}
+
+func testBitset(t *testing.T, x *Int) {
+       n := x.BitLen()
+       z := new(Int).Set(x)
+       z1 := new(Int).Set(x)
+       for i := 0; i < n+10; i++ {
+               old := z.Bit(i)
+               old1 := altBit(z1, i)
+               if old != old1 {
+                       t.Errorf("bitset: inconsistent value for Bit(%s, %d), got %v want %v", z1, i, old, old1)
+               }
+               z := new(Int).SetBit(z, i, 1)
+               z1 := altSetBit(new(Int), z1, i, 1)
+               if z.Bit(i) == 0 {
+                       t.Errorf("bitset: bit %d of %s got 0 want 1", i, x)
+               }
+               if z.Cmp(z1) != 0 {
+                       t.Errorf("bitset: inconsistent value after SetBit 1, got %s want %s", z, z1)
+               }
+               z.SetBit(z, i, 0)
+               altSetBit(z1, z1, i, 0)
+               if z.Bit(i) != 0 {
+                       t.Errorf("bitset: bit %d of %s got 1 want 0", i, x)
+               }
+               if z.Cmp(z1) != 0 {
+                       t.Errorf("bitset: inconsistent value after SetBit 0, got %s want %s", z, z1)
+               }
+               altSetBit(z1, z1, i, old)
+               z.SetBit(z, i, old)
+               if z.Cmp(z1) != 0 {
+                       t.Errorf("bitset: inconsistent value after SetBit old, got %s want %s", z, z1)
+               }
+       }
+       if z.Cmp(x) != 0 {
+               t.Errorf("bitset: got %s want %s", z, x)
+       }
+}
+
+var bitsetTests = []struct {
+       x string
+       i int
+       b uint
+}{
+       {"0", 0, 0},
+       {"0", 200, 0},
+       {"1", 0, 1},
+       {"1", 1, 0},
+       {"-1", 0, 1},
+       {"-1", 200, 1},
+       {"0x2000000000000000000000000000", 108, 0},
+       {"0x2000000000000000000000000000", 109, 1},
+       {"0x2000000000000000000000000000", 110, 0},
+       {"-0x2000000000000000000000000001", 108, 1},
+       {"-0x2000000000000000000000000001", 109, 0},
+       {"-0x2000000000000000000000000001", 110, 1},
+}
+
+func TestBitSet(t *testing.T) {
+       for _, test := range bitwiseTests {
+               x := new(Int)
+               x.SetString(test.x, 0)
+               testBitset(t, x)
+               x = new(Int)
+               x.SetString(test.y, 0)
+               testBitset(t, x)
+       }
+       for i, test := range bitsetTests {
+               x := new(Int)
+               x.SetString(test.x, 0)
+               b := x.Bit(test.i)
+               if b != test.b {
+
+                       t.Errorf("#%d want %v got %v", i, test.b, b)
+               }
+       }
+}
+
+func BenchmarkBitset(b *testing.B) {
+       z := new(Int)
+       z.SetBit(z, 512, 1)
+       b.ResetTimer()
+       b.StartTimer()
+       for i := b.N - 1; i >= 0; i-- {
+               z.SetBit(z, i&512, 1)
+       }
+}
+
+func BenchmarkBitsetNeg(b *testing.B) {
+       z := NewInt(-1)
+       z.SetBit(z, 512, 0)
+       b.ResetTimer()
+       b.StartTimer()
+       for i := b.N - 1; i >= 0; i-- {
+               z.SetBit(z, i&512, 0)
+       }
+}
+
+func BenchmarkBitsetOrig(b *testing.B) {
+       z := new(Int)
+       altSetBit(z, z, 512, 1)
+       b.ResetTimer()
+       b.StartTimer()
+       for i := b.N - 1; i >= 0; i-- {
+               altSetBit(z, z, i&512, 1)
+       }
+}
+
+func BenchmarkBitsetNegOrig(b *testing.B) {
+       z := NewInt(-1)
+       altSetBit(z, z, 512, 0)
+       b.ResetTimer()
+       b.StartTimer()
+       for i := b.N - 1; i >= 0; i-- {
+               altSetBit(z, z, i&512, 0)
+       }
+}
 
 func TestBitwise(t *testing.T) {
        x := new(Int)
index 4848d427b39d48d474d3fb1a3f2a320a99d62954..2fdae9f06f62d6178dcd212e967063457454b1b6 100755 (executable)
@@ -773,6 +773,43 @@ func (z nat) shr(x nat, s uint) nat {
 }
 
 
+func (z nat) setBit(x nat, i uint, b uint) nat {
+       j := int(i / _W)
+       m := Word(1) << (i % _W)
+       n := len(x)
+       switch b {
+       case 0:
+               z = z.make(n)
+               copy(z, x)
+               if j >= n {
+                       // no need to grow
+                       return z
+               }
+               z[j] &^= m
+               return z.norm()
+       case 1:
+               if j >= n {
+                       n = j + 1
+               }
+               z = z.make(n)
+               copy(z, x)
+               z[j] |= m
+               // no need to normalize
+               return z
+       }
+       panic("set bit is not 0 or 1")
+}
+
+
+func (z nat) bit(i uint) uint {
+       j := int(i / _W)
+       if j >= len(z) {
+               return 0
+       }
+       return uint(z[j] >> (i % _W) & 1)
+}
+
+
 func (z nat) and(x, y nat) nat {
        m := len(x)
        n := len(y)