}
}
-// ModSqrt sets z to a square root of x mod p if such a square root exists, and
-// returns z. The modulus p must be an odd prime. If x is not a square mod p,
-// ModSqrt leaves z unchanged and returns nil. This function panics if p is
-// not an odd integer.
-func (z *Int) ModSqrt(x, p *Int) *Int {
- switch Jacobi(x, p) {
- case -1:
- return nil // x is not a square mod p
- case 0:
- return z.SetInt64(0) // sqrt(0) mod p = 0
- case 1:
- break
- }
- if x.neg || x.Cmp(p) >= 0 { // ensure 0 <= x < p
- x = new(Int).Mod(x, p)
- }
+// modSqrt3Mod4 uses the identity
+// (a^((p+1)/4))^2 mod p
+// == u^(p+1) mod p
+// == u^2 mod p
+// to calculate the square root of any quadratic residue mod p quickly for 3
+// mod 4 primes.
+func (z *Int) modSqrt3Mod4Prime(x, p *Int) *Int {
+ z.Set(p) // z = p
+ z.Add(z, intOne) // z = p + 1
+ z.Rsh(z, 2) // z = (p + 1) / 4
+ z.Exp(x, z, p) // z = x^z mod p
+ return z
+}
+// modSqrtTonelliShanks uses the Tonelli-Shanks algorithm to find the square
+// root of a quadratic residue modulo any prime.
+func (z *Int) modSqrtTonelliShanks(x, p *Int) *Int {
// Break p-1 into s*2^e such that s is odd.
var s Int
s.Sub(p, intOne)
}
}
+// ModSqrt sets z to a square root of x mod p if such a square root exists, and
+// returns z. The modulus p must be an odd prime. If x is not a square mod p,
+// ModSqrt leaves z unchanged and returns nil. This function panics if p is
+// not an odd integer.
+func (z *Int) ModSqrt(x, p *Int) *Int {
+ switch Jacobi(x, p) {
+ case -1:
+ return nil // x is not a square mod p
+ case 0:
+ return z.SetInt64(0) // sqrt(0) mod p = 0
+ case 1:
+ break
+ }
+ if x.neg || x.Cmp(p) >= 0 { // ensure 0 <= x < p
+ x = new(Int).Mod(x, p)
+ }
+
+ // Check whether p is 3 mod 4, and if so, use the faster algorithm.
+ if len(p.abs) > 0 && p.abs[0]%4 == 3 {
+ return z.modSqrt3Mod4Prime(x, p)
+ }
+ // Otherwise, use Tonelli-Shanks.
+ return z.modSqrtTonelliShanks(x, p)
+}
+
// Lsh sets z = x << n and returns z.
func (z *Int) Lsh(x *Int, n uint) *Int {
z.abs = z.abs.shl(x.abs, n)
}
}
+// tri generates the trinomial 2**(n*2) - 2**n - 1, which is always 3 mod 4 and
+// 7 mod 8, so that 2 is always a quadratic residue.
+func tri(n uint) *Int {
+ x := NewInt(1)
+ x.Lsh(x, n)
+ x2 := new(Int).Lsh(x, n)
+ x2.Sub(x2, x)
+ x2.Sub(x2, intOne)
+ return x2
+}
+
+func BenchmarkModSqrt225_Tonelli(b *testing.B) {
+ p := tri(225)
+ x := NewInt(2)
+ for i := 0; i < b.N; i++ {
+ x.SetUint64(2)
+ x.modSqrtTonelliShanks(x, p)
+ }
+}
+
+func BenchmarkModSqrt224_3Mod4(b *testing.B) {
+ p := tri(225)
+ x := new(Int).SetUint64(2)
+ for i := 0; i < b.N; i++ {
+ x.SetUint64(2)
+ x.modSqrt3Mod4Prime(x, p)
+ }
+}
+
+func BenchmarkModSqrt5430_Tonelli(b *testing.B) {
+ p := tri(5430)
+ x := new(Int).SetUint64(2)
+ for i := 0; i < b.N; i++ {
+ x.SetUint64(2)
+ x.modSqrtTonelliShanks(x, p)
+ }
+}
+
+func BenchmarkModSqrt5430_3Mod4(b *testing.B) {
+ p := tri(5430)
+ x := new(Int).SetUint64(2)
+ for i := 0; i < b.N; i++ {
+ x.SetUint64(2)
+ x.modSqrt3Mod4Prime(x, p)
+ }
+}
+
func TestBitwise(t *testing.T) {
x := new(Int)
y := new(Int)