]> Cypherpunks repositories - gostls13.git/commitdiff
crypto/rand: prevent Read argument from escaping to heap
authorFilippo Valsorda <filippo@golang.org>
Thu, 1 Aug 2024 17:59:07 +0000 (19:59 +0200)
committerFilippo Valsorda <filippo@golang.org>
Mon, 7 Oct 2024 15:33:40 +0000 (15:33 +0000)
Mateusz had this idea before me in CL 578516, but it got much easier
after the recent cleanup.

It's unfortunate we lose the test coverage of batched, but the package
is significantly simpler than when we introduced it, so it should be
easier to review that everything does what it's supposed to do.

Fixes #66779

Co-authored-by: Mateusz Poliwczak <mpoliwczak34@gmail.com>
Change-Id: Id35f1172e678fec184efb0efae3631afac8121d0
Reviewed-on: https://go-review.googlesource.com/c/go/+/602498
Reviewed-by: David Chase <drchase@google.com>
Reviewed-by: Roland Shoemaker <roland@golang.org>
Reviewed-by: Daniel McCarney <daniel@binaryparadox.net>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>

src/crypto/rand/rand.go
src/crypto/rand/rand_batched_test.go [deleted file]
src/crypto/rand/rand_getentropy.go
src/crypto/rand/rand_js.go
src/crypto/rand/rand_test.go
src/internal/syscall/unix/getentropy_openbsd.go

index 73e8a8bc3916ada6c35f6ee09cc760ba0b1601c3..20a2438e84312f10a06ea70b56a2e42d58f99d90 100644 (file)
@@ -70,28 +70,20 @@ func fatal(string)
 // If [Reader] is set to a non-default value, Read calls [io.ReadFull] on
 // [Reader] and crashes the program irrecoverably if an error is returned.
 func Read(b []byte) (n int, err error) {
-       _, err = io.ReadFull(Reader, b)
+       // We don't want b to escape to the heap, but escape analysis can't see
+       // through a potentially overridden Reader, so we special-case the default
+       // case which we can keep non-escaping, and in the general case we read into
+       // a heap buffer and copy from it.
+       if r, ok := Reader.(*reader); ok {
+               _, err = r.Read(b)
+       } else {
+               bb := make([]byte, len(b))
+               _, err = io.ReadFull(Reader, bb)
+               copy(b, bb)
+       }
        if err != nil {
                fatal("crypto/rand: failed to read random data (see https://go.dev/issue/66821): " + err.Error())
                panic("unreachable") // To be sure.
        }
        return len(b), nil
 }
-
-// batched returns a function that calls f to populate a []byte by chunking it
-// into subslices of, at most, readMax bytes.
-func batched(f func([]byte) error, readMax int) func([]byte) error {
-       return func(out []byte) error {
-               for len(out) > 0 {
-                       read := len(out)
-                       if read > readMax {
-                               read = readMax
-                       }
-                       if err := f(out[:read]); err != nil {
-                               return err
-                       }
-                       out = out[read:]
-               }
-               return nil
-       }
-}
diff --git a/src/crypto/rand/rand_batched_test.go b/src/crypto/rand/rand_batched_test.go
deleted file mode 100644 (file)
index 02f4893..0000000
+++ /dev/null
@@ -1,75 +0,0 @@
-// Copyright 2014 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-//go:build unix
-
-package rand
-
-import (
-       "bytes"
-       "errors"
-       prand "math/rand"
-       "testing"
-)
-
-func TestBatched(t *testing.T) {
-       fillBatched := batched(func(p []byte) error {
-               for i := range p {
-                       p[i] = byte(i)
-               }
-               return nil
-       }, 5)
-
-       p := make([]byte, 13)
-       if err := fillBatched(p); err != nil {
-               t.Fatalf("batched function returned error: %s", err)
-       }
-       expected := []byte{0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2}
-       if !bytes.Equal(expected, p) {
-               t.Errorf("incorrect batch result: got %x, want %x", p, expected)
-       }
-}
-
-func TestBatchedBuffering(t *testing.T) {
-       backingStore := make([]byte, 1<<23)
-       prand.Read(backingStore)
-       backingMarker := backingStore[:]
-       output := make([]byte, len(backingStore))
-       outputMarker := output[:]
-
-       fillBatched := batched(func(p []byte) error {
-               n := copy(p, backingMarker)
-               backingMarker = backingMarker[n:]
-               return nil
-       }, 731)
-
-       for len(outputMarker) > 0 {
-               max := 9200
-               if max > len(outputMarker) {
-                       max = len(outputMarker)
-               }
-               howMuch := prand.Intn(max + 1)
-               if err := fillBatched(outputMarker[:howMuch]); err != nil {
-                       t.Fatalf("batched function returned error: %s", err)
-               }
-               outputMarker = outputMarker[howMuch:]
-       }
-       if !bytes.Equal(backingStore, output) {
-               t.Error("incorrect batch result")
-       }
-}
-
-func TestBatchedError(t *testing.T) {
-       b := batched(func(p []byte) error { return errors.New("failure") }, 5)
-       if b(make([]byte, 13)) == nil {
-               t.Fatal("batched function should have returned an error")
-       }
-}
-
-func TestBatchedEmpty(t *testing.T) {
-       b := batched(func(p []byte) error { return errors.New("failure") }, 5)
-       if b(make([]byte, 0)) != nil {
-               t.Fatal("empty slice should always return successful")
-       }
-}
index 47320133e546cf76a0919a6f75ab7f2f8da7df83..b9e41433a2de4adf91443b0e9b987349292c763c 100644 (file)
@@ -8,5 +8,17 @@ package rand
 
 import "internal/syscall/unix"
 
-// getentropy(2) returns a maximum of 256 bytes per call.
-var read = batched(unix.GetEntropy, 256)
+func read(b []byte) error {
+       for len(b) > 0 {
+               size := len(b)
+               if size > 256 {
+                       size = 256
+               }
+               // getentropy(2) returns a maximum of 256 bytes per call.
+               if err := unix.GetEntropy(b[:size]); err != nil {
+                       return err
+               }
+               b = b[size:]
+       }
+       return nil
+}
index 3345e4874a5eebbb9f3faac352225de41230a450..82cc75fb4e7cde79288d3e7d24c3444387cd1f08 100644 (file)
@@ -24,3 +24,21 @@ func getRandom(b []byte) error {
        js.CopyBytesToGo(b, a)
        return nil
 }
+
+// batched returns a function that calls f to populate a []byte by chunking it
+// into subslices of, at most, readMax bytes.
+func batched(f func([]byte) error, readMax int) func([]byte) error {
+       return func(out []byte) error {
+               for len(out) > 0 {
+                       read := len(out)
+                       if read > readMax {
+                               read = readMax
+                       }
+                       if err := f(out[:read]); err != nil {
+                               return err
+                       }
+                       out = out[read:]
+               }
+               return nil
+       }
+}
index 35a7d59338c2dd25d7d01426b4d5aa9987cb3b13..6d949ea9ac333edbcdf26312106be2ca51a630f2 100644 (file)
@@ -7,8 +7,10 @@ package rand_test
 import (
        "bytes"
        "compress/flate"
+       "crypto/internal/boring"
        . "crypto/rand"
        "io"
+       "runtime"
        "sync"
        "testing"
 )
@@ -121,6 +123,30 @@ func TestConcurrentRead(t *testing.T) {
        wg.Wait()
 }
 
+var sink byte
+
+func TestAllocations(t *testing.T) {
+       if boring.Enabled {
+               // Might be fixable with https://go.dev/issue/56378.
+               t.Skip("boringcrypto allocates")
+       }
+       if runtime.GOOS == "aix" {
+               t.Skip("/dev/urandom read path allocates")
+       }
+       if runtime.GOOS == "js" {
+               t.Skip("syscall/js allocates")
+       }
+
+       n := int(testing.AllocsPerRun(10, func() {
+               buf := make([]byte, 32)
+               Read(buf)
+               sink ^= buf[0]
+       }))
+       if n > 0 {
+               t.Errorf("allocs = %d, want 0", n)
+       }
+}
+
 func BenchmarkRead(b *testing.B) {
        b.Run("4", func(b *testing.B) {
                benchmarkRead(b, 4)
index ad0914da903b98efca1abc7ccfbdef69be39e34c..7516ac7ce71b08914af442bef571852a42a88ee2 100644 (file)
@@ -14,4 +14,5 @@ func GetEntropy(p []byte) error {
 }
 
 //go:linkname getentropy syscall.getentropy
+//go:noescape
 func getentropy(p []byte) error