]> Cypherpunks repositories - gostls13.git/commitdiff
crypto/internal/fips140/nistec: make SetBytes constant time
authorFilippo Valsorda <filippo@golang.org>
Wed, 19 Feb 2025 21:41:59 +0000 (22:41 +0100)
committerGopher Robot <gobot@golang.org>
Fri, 21 Feb 2025 18:31:33 +0000 (10:31 -0800)
Similarly to CL 648035, SetBytes doesn't need to be constant time for
the uses we make of it in the standard library (ECDH and ECDSA public
keys), but it doesn't cost much to make it constant time for users of
the re-exported package, or even just to save the next person from
convincing themselves that it's ok for it not to be constant time.

Change-Id: I6a6a465622a0de08d9fc71db75c63185a82aa54a
Reviewed-on: https://go-review.googlesource.com/c/go/+/650579
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Auto-Submit: Filippo Valsorda <filippo@golang.org>
Reviewed-by: Michael Knyszek <mknyszek@google.com>
Reviewed-by: Roland Shoemaker <roland@golang.org>
src/crypto/ecdh/ecdh_test.go
src/crypto/internal/fips140/nistec/fiat/generate.go
src/crypto/internal/fips140/nistec/fiat/p224.go
src/crypto/internal/fips140/nistec/fiat/p256.go
src/crypto/internal/fips140/nistec/fiat/p384.go
src/crypto/internal/fips140/nistec/fiat/p521.go
src/crypto/internal/fips140/subtle/constant_time.go
src/crypto/internal/fips140/subtle/constant_time_test.go [new file with mode: 0644]

index 75d2480775669f0b5e2a9b65345e4cb6cef31133..8a3eb87061ad8c90b448c74c6afeaab015e0aea4 100644 (file)
@@ -308,6 +308,8 @@ var invalidPublicKeys = map[ecdh.Curve][]string{
                // Points not on the curve.
                "046b17d1f2e12c4247f8bce6e563a440f277037d812deb33a0f4a13945d898c2964fe342e2fe1a7f9b8ee7eb4a7c0f9e162bce33576b315ececbb6406837bf51f6",
                "0400000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
+               // Non-canonical encoding.
+               "04ffffffff00000001000000000000000000000001000000000000000000000004ba6dbc4555a7e7fa016ec431667e8521ee35afc49b265c3accbea3f7cdb70433",
        },
        ecdh.P384(): {
                // Bad lengths.
@@ -322,6 +324,8 @@ var invalidPublicKeys = map[ecdh.Curve][]string{
                // Points not on the curve.
                "04aa87ca22be8b05378eb1c71ef320ad746e1d3b628ba79b9859f741e082542a385502f25dbf55296c3a545e3872760ab73617de4a96262c6f5d9e98bf9292dc29f8f41dbd289a147ce9da3113b5f0b8c00a60b1ce1d7e819d7a431d7c90ea0e60",
                "04000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
+               // Non-canonical encoding.
+               "04fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffeffffffff000000000000000100000001732152442fb6ee5c3e6ce1d920c059bc623563814d79042b903ce60f1d4487fccd450a86da03f3e6ed525d02017bfdb3",
        },
        ecdh.P521(): {
                // Bad lengths.
@@ -336,6 +340,8 @@ var invalidPublicKeys = map[ecdh.Curve][]string{
                // Points not on the curve.
                "0400c6858e06b70404e9cd9e3ecb662395b4429c648139053fb521f828af606b4d3dbaa14b5e77efe75928fe1dc127a2ffa8de3348b3c1856a429bf97e7e31c2e5bd66011839296a789a3bc0045c8a5fb42c7d1bd998f54449579b446817afbd17273e662c97ee72995ef42640c550b9013fad0761353c7086a272c24088be94769fd16651",
                "04000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000",
+               // Non-canonical encoding.
+               "0402000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000100d9254fdf800496acb33790b103c5ee9fac12832fe546c632225b0f7fce3da4574b1a879b623d722fa8fc34d5fc2a8731aad691a9a8bb8b554c95a051d6aa505acf",
        },
        ecdh.X25519(): {},
 }
index b8c5a1389c56731f8896f082162657866f843a67..5dda9434d41ecab89f2773cf05eedb37b9f6597d 100644 (file)
@@ -223,13 +223,8 @@ func (e *{{ .Element }}) SetBytes(v []byte) (*{{ .Element }}, error) {
        // the encoding of -1 mod p, so p - 1, the highest canonical encoding.
        var minusOneEncoding = new({{ .Element }}).Sub(
                new({{ .Element }}), new({{ .Element }}).One()).Bytes()
-       for i := range v {
-               if v[i] < minusOneEncoding[i] {
-                       break
-               }
-               if v[i] > minusOneEncoding[i] {
-                       return nil, errors.New("invalid {{ .Element }} encoding")
-               }
+       if subtle.ConstantTimeLessOrEqBytes(v, minusOneEncoding) == 0 {
+               return nil, errors.New("invalid {{ .Element }} encoding")
        }
 
        var in [{{ .Prefix }}ElementLen]byte
index cdce9f7018f9e809f71ce9e58f683edcc8279657..335fa42cdadc399a2bb73248d94b443c3df12eb8 100644 (file)
@@ -78,13 +78,8 @@ func (e *P224Element) SetBytes(v []byte) (*P224Element, error) {
        // the encoding of -1 mod p, so p - 1, the highest canonical encoding.
        var minusOneEncoding = new(P224Element).Sub(
                new(P224Element), new(P224Element).One()).Bytes()
-       for i := range v {
-               if v[i] < minusOneEncoding[i] {
-                       break
-               }
-               if v[i] > minusOneEncoding[i] {
-                       return nil, errors.New("invalid P224Element encoding")
-               }
+       if subtle.ConstantTimeLessOrEqBytes(v, minusOneEncoding) == 0 {
+               return nil, errors.New("invalid P224Element encoding")
        }
 
        var in [p224ElementLen]byte
index fb7284977ac90286a4d1f9f42c7269356e11b557..2301656b591dfaad36b30002c84e186450e260b4 100644 (file)
@@ -78,13 +78,8 @@ func (e *P256Element) SetBytes(v []byte) (*P256Element, error) {
        // the encoding of -1 mod p, so p - 1, the highest canonical encoding.
        var minusOneEncoding = new(P256Element).Sub(
                new(P256Element), new(P256Element).One()).Bytes()
-       for i := range v {
-               if v[i] < minusOneEncoding[i] {
-                       break
-               }
-               if v[i] > minusOneEncoding[i] {
-                       return nil, errors.New("invalid P256Element encoding")
-               }
+       if subtle.ConstantTimeLessOrEqBytes(v, minusOneEncoding) == 0 {
+               return nil, errors.New("invalid P256Element encoding")
        }
 
        var in [p256ElementLen]byte
index 505b7e9a2d94911a212d83776ca035b1f8560cde..f514ab2d60319bf7c95d37f65c772b5cbc4e0e50 100644 (file)
@@ -78,13 +78,8 @@ func (e *P384Element) SetBytes(v []byte) (*P384Element, error) {
        // the encoding of -1 mod p, so p - 1, the highest canonical encoding.
        var minusOneEncoding = new(P384Element).Sub(
                new(P384Element), new(P384Element).One()).Bytes()
-       for i := range v {
-               if v[i] < minusOneEncoding[i] {
-                       break
-               }
-               if v[i] > minusOneEncoding[i] {
-                       return nil, errors.New("invalid P384Element encoding")
-               }
+       if subtle.ConstantTimeLessOrEqBytes(v, minusOneEncoding) == 0 {
+               return nil, errors.New("invalid P384Element encoding")
        }
 
        var in [p384ElementLen]byte
index 48141900ff67231ba7d6530b4e6c5eb2fed0af9b..d4d576503d4bc5dfad4d395847ace8b2a760b1d9 100644 (file)
@@ -78,13 +78,8 @@ func (e *P521Element) SetBytes(v []byte) (*P521Element, error) {
        // the encoding of -1 mod p, so p - 1, the highest canonical encoding.
        var minusOneEncoding = new(P521Element).Sub(
                new(P521Element), new(P521Element).One()).Bytes()
-       for i := range v {
-               if v[i] < minusOneEncoding[i] {
-                       break
-               }
-               if v[i] > minusOneEncoding[i] {
-                       return nil, errors.New("invalid P521Element encoding")
-               }
+       if subtle.ConstantTimeLessOrEqBytes(v, minusOneEncoding) == 0 {
+               return nil, errors.New("invalid P521Element encoding")
        }
 
        var in [p521ElementLen]byte
index 9fd3923e76137c39e0ac5182863ea119b301eb9b..fa7a002d3fa45642d5042ae86fa6f3c9c5e30b4b 100644 (file)
@@ -4,6 +4,11 @@
 
 package subtle
 
+import (
+       "crypto/internal/fips140deps/byteorder"
+       "math/bits"
+)
+
 // ConstantTimeCompare returns 1 if the two slices, x and y, have equal contents
 // and 0 otherwise. The time taken is a function of the length of the slices and
 // is independent of the contents. If the lengths of x and y do not match it
@@ -22,6 +27,37 @@ func ConstantTimeCompare(x, y []byte) int {
        return ConstantTimeByteEq(v, 0)
 }
 
+// ConstantTimeLessOrEqBytes returns 1 if x <= y and 0 otherwise. The comparison
+// is lexigraphical, or big-endian. The time taken is a function of the length of
+// the slices and is independent of the contents. If the lengths of x and y do not
+// match it returns 0 immediately.
+func ConstantTimeLessOrEqBytes(x, y []byte) int {
+       if len(x) != len(y) {
+               return 0
+       }
+
+       // Do a constant time subtraction chain y - x.
+       // If there is no borrow at the end, then x <= y.
+       var b uint64
+       for len(x) > 8 {
+               x0 := byteorder.BEUint64(x[len(x)-8:])
+               y0 := byteorder.BEUint64(y[len(y)-8:])
+               _, b = bits.Sub64(y0, x0, b)
+               x = x[:len(x)-8]
+               y = y[:len(y)-8]
+       }
+       if len(x) > 0 {
+               xb := make([]byte, 8)
+               yb := make([]byte, 8)
+               copy(xb[8-len(x):], x)
+               copy(yb[8-len(y):], y)
+               x0 := byteorder.BEUint64(xb)
+               y0 := byteorder.BEUint64(yb)
+               _, b = bits.Sub64(y0, x0, b)
+       }
+       return int(b ^ 1)
+}
+
 // ConstantTimeSelect returns x if v == 1 and y if v == 0.
 // Its behavior is undefined if v takes any other value.
 func ConstantTimeSelect(v, x, y int) int { return ^(v-1)&x | (v-1)&y }
diff --git a/src/crypto/internal/fips140/subtle/constant_time_test.go b/src/crypto/internal/fips140/subtle/constant_time_test.go
new file mode 100644 (file)
index 0000000..bcd548c
--- /dev/null
@@ -0,0 +1,104 @@
+// Copyright 2025 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package subtle
+
+import (
+       "bytes"
+       "crypto/internal/fips140deps/byteorder"
+       "math/rand/v2"
+       "testing"
+       "time"
+)
+
+func TestConstantTimeLessOrEqBytes(t *testing.T) {
+       seed := make([]byte, 32)
+       byteorder.BEPutUint64(seed, uint64(time.Now().UnixNano()))
+       r := rand.NewChaCha8([32]byte(seed))
+       for l := range 20 {
+               a := make([]byte, l)
+               b := make([]byte, l)
+               empty := make([]byte, l)
+               r.Read(a)
+               r.Read(b)
+               exp := 0
+               if bytes.Compare(a, b) <= 0 {
+                       exp = 1
+               }
+               if got := ConstantTimeLessOrEqBytes(a, b); got != exp {
+                       t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want %d", a, b, got, exp)
+               }
+               exp = 0
+               if bytes.Compare(b, a) <= 0 {
+                       exp = 1
+               }
+               if got := ConstantTimeLessOrEqBytes(b, a); got != exp {
+                       t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want %d", b, a, got, exp)
+               }
+               if got := ConstantTimeLessOrEqBytes(empty, a); got != 1 {
+                       t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 1", empty, a, got)
+               }
+               if got := ConstantTimeLessOrEqBytes(empty, b); got != 1 {
+                       t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 1", empty, b, got)
+               }
+               if got := ConstantTimeLessOrEqBytes(a, a); got != 1 {
+                       t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 1", a, a, got)
+               }
+               if got := ConstantTimeLessOrEqBytes(b, b); got != 1 {
+                       t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 1", b, b, got)
+               }
+               if got := ConstantTimeLessOrEqBytes(empty, empty); got != 1 {
+                       t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 1", empty, empty, got)
+               }
+               if l == 0 {
+                       continue
+               }
+               max := make([]byte, l)
+               for i := range max {
+                       max[i] = 0xff
+               }
+               if got := ConstantTimeLessOrEqBytes(a, max); got != 1 {
+                       t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 1", a, max, got)
+               }
+               if got := ConstantTimeLessOrEqBytes(b, max); got != 1 {
+                       t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 1", b, max, got)
+               }
+               if got := ConstantTimeLessOrEqBytes(empty, max); got != 1 {
+                       t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 1", empty, max, got)
+               }
+               if got := ConstantTimeLessOrEqBytes(max, max); got != 1 {
+                       t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 1", max, max, got)
+               }
+               aPlusOne := make([]byte, l)
+               copy(aPlusOne, a)
+               for i := l - 1; i >= 0; i-- {
+                       if aPlusOne[i] == 0xff {
+                               aPlusOne[i] = 0
+                               continue
+                       }
+                       aPlusOne[i]++
+                       if got := ConstantTimeLessOrEqBytes(a, aPlusOne); got != 1 {
+                               t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 1", a, aPlusOne, got)
+                       }
+                       if got := ConstantTimeLessOrEqBytes(aPlusOne, a); got != 0 {
+                               t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 0", aPlusOne, a, got)
+                       }
+                       break
+               }
+               shorter := make([]byte, l-1)
+               copy(shorter, a)
+               if got := ConstantTimeLessOrEqBytes(a, shorter); got != 0 {
+                       t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 0", a, shorter, got)
+               }
+               if got := ConstantTimeLessOrEqBytes(shorter, a); got != 0 {
+                       t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 0", shorter, a, got)
+               }
+               if got := ConstantTimeLessOrEqBytes(b, shorter); got != 0 {
+                       t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 0", b, shorter, got)
+               }
+               if got := ConstantTimeLessOrEqBytes(shorter, b); got != 0 {
+                       t.Errorf("ConstantTimeLessOrEqBytes(%x, %x) = %d, want 0", shorter, b, got)
+               }
+       }
+}