]> Cypherpunks repositories - gostls13.git/commitdiff
[dev.fuzz] internal/fuzz: avoid marshaling input before calling fuzz function
authorJay Conrod <jayconrod@google.com>
Mon, 12 Jul 2021 22:39:43 +0000 (15:39 -0700)
committerJay Conrod <jayconrod@google.com>
Tue, 20 Jul 2021 00:06:06 +0000 (00:06 +0000)
Previously, before each call to the fuzz function, the worker process
marshalled the mutated input into shared memory. If the worker process
terminates unexpectedly, it's important that the coordinator can find
the crashing input in shared memory.

Profiling shows this marshalling is very expensive though. This change
takes another strategy. Instead of marshaling each mutated input, the
worker process no longer modifies the input in shared memory at
all. Instead, it saves its PRNG state in shared memory and increments
a counter before each fuzz function call. If the worker process
terminates, the coordinator can reconstruct the crashing value using
this information.

This change gives a ~10x increase in execs/s for a trivial fuzz
function with -parallel=1.

Change-Id: I18cf326c252727385dc53ea2518922b1f6ae36b6
Reviewed-on: https://go-review.googlesource.com/c/go/+/334149
Trust: Jay Conrod <jayconrod@google.com>
Trust: Katie Hockman <katie@golang.org>
Run-TryBot: Jay Conrod <jayconrod@google.com>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Katie Hockman <katie@golang.org>
src/cmd/go/testdata/script/test_fuzz_mutator_repeat.txt [new file with mode: 0644]
src/internal/fuzz/mem.go
src/internal/fuzz/mutators_byteslice_test.go
src/internal/fuzz/pcg.go
src/internal/fuzz/worker.go
src/internal/fuzz/worker_test.go

diff --git a/src/cmd/go/testdata/script/test_fuzz_mutator_repeat.txt b/src/cmd/go/testdata/script/test_fuzz_mutator_repeat.txt
new file mode 100644 (file)
index 0000000..0924ed3
--- /dev/null
@@ -0,0 +1,66 @@
+# TODO(jayconrod): support shared memory on more platforms.
+[!darwin] [!linux] [!windows] skip
+
+# Verify that the fuzzing engine records the actual crashing input, even when
+# a worker process terminates without communicating the crashing input back
+# to the coordinator.
+
+[short] skip
+
+# Start fuzzing. The worker crashes after ~100 iterations.
+# The fuzz function writes the crashing input to "want" before exiting.
+# The fuzzing engine reconstructs the crashing input and saves it to testdata.
+! exists want
+! go test -fuzz=. -parallel=1
+stdout 'fuzzing process terminated unexpectedly'
+stdout 'Crash written to testdata'
+
+# Run the fuzz target without fuzzing. The fuzz function is called with the
+# crashing input in testdata. The test passes if that input is identical to
+# the one saved in "want".
+exists want
+go test -want=want
+
+-- go.mod --
+module fuzz
+
+go 1.17
+-- fuzz_test.go --
+package fuzz
+
+import (
+       "bytes"
+       "flag"
+       "os"
+       "testing"
+)
+
+var wantFlag = flag.String("want", "", "file containing previous crashing input")
+
+func FuzzRepeat(f *testing.F) {
+       i := 0
+       f.Fuzz(func(t *testing.T, b []byte) {
+               i++
+               if i == 100 {
+                       f, err := os.OpenFile("want", os.O_WRONLY|os.O_CREATE|os.O_EXCL, 0666)
+                       if err != nil {
+                               // Couldn't create the file, probably because it already exists,
+                               // and we're minimizing now. Return without crashing.
+                               return
+                       }
+                       f.Write(b)
+                       f.Close()
+                       os.Exit(1) // crash without communicating
+               }
+
+               if *wantFlag != "" {
+                       want, err := os.ReadFile(*wantFlag)
+                       if err != nil {
+                               t.Fatal(err)
+                       }
+                       if !bytes.Equal(want, b) {
+                               t.Fatalf("inputs are not equal!\n got: %q\nwant:%q", b, want)
+                       }
+               }
+       })
+}
index a7792321eeadee6f8321fd631aa5965d7ed90532..ccd4da24550cf0aa9f84a480f3327211e18ed9f7 100644 (file)
@@ -43,6 +43,9 @@ type sharedMemHeader struct {
 
        // valueLen is the length of the value that was last fuzzed.
        valueLen int
+
+       // randState and randInc hold the state of a pseudo-random number generator.
+       randState, randInc uint64
 }
 
 // sharedMemSize returns the size needed for a shared memory buffer that can
index 4b8652cf03a8bb70af458c17d514ccbfc1bde12a..50a39a9a5b5fa6cf6caee4515a7e877c99127f14 100644 (file)
@@ -44,6 +44,14 @@ func (mr *mockRand) bool() bool {
        return b
 }
 
+func (mr *mockRand) save(*uint64, *uint64) {
+       panic("unimplemented")
+}
+
+func (mr *mockRand) restore(uint64, uint64) {
+       panic("unimplemented")
+}
+
 func TestByteSliceMutators(t *testing.T) {
        for _, tc := range []struct {
                name     string
index 0b799aab02ffe991b1560e04dd78b536bc088c5a..c9ea0afcf8c3286141ce54735a796a3fd7d5dfb5 100644 (file)
@@ -19,6 +19,9 @@ type mutatorRand interface {
        uint32n(uint32) uint32
        exp2() int
        bool() bool
+
+       save(randState, randInc *uint64)
+       restore(randState, randInc uint64)
 }
 
 // The functions in pcg implement a 32 bit PRNG with a 64 bit period: pcg xsh rr
@@ -74,6 +77,16 @@ func (r *pcgRand) step() {
        r.state += r.inc
 }
 
+func (r *pcgRand) save(randState, randInc *uint64) {
+       *randState = r.state
+       *randInc = r.inc
+}
+
+func (r *pcgRand) restore(randState, randInc uint64) {
+       r.state = randState
+       r.inc = randInc
+}
+
 // uint32 returns a pseudo-random uint32.
 func (r *pcgRand) uint32() uint32 {
        x := r.state
index 2acbf30ead1dd517b0597e8755530c25d514a463..e3029bcd66e0fb3f24802c47c473f843109e32d9 100644 (file)
@@ -5,6 +5,7 @@
 package fuzz
 
 import (
+       "bytes"
        "context"
        "crypto/sha256"
        "encoding/json"
@@ -156,7 +157,7 @@ func (w *worker) coordinate(ctx context.Context) error {
                                // to the client.
                                args.CoverageData = input.coverageData
                        }
-                       value, resp, err := w.client.fuzz(ctx, input.entry.Data, args)
+                       entry, resp, err := w.client.fuzz(ctx, input.entry, args)
                        if err != nil {
                                // Error communicating with worker.
                                w.stop()
@@ -194,26 +195,11 @@ func (w *worker) coordinate(ctx context.Context) error {
                                count:          resp.Count,
                                totalDuration:  resp.TotalDuration,
                                entryDuration:  resp.InterestingDuration,
+                               entry:          entry,
                        }
                        if resp.Err != "" {
-                               h := sha256.Sum256(value)
-                               name := fmt.Sprintf("%x", h[:4])
-                               result.entry = CorpusEntry{
-                                       Name:       name,
-                                       Parent:     input.entry.Name,
-                                       Data:       value,
-                                       Generation: input.entry.Generation + 1,
-                               }
                                result.crasherMsg = resp.Err
                        } else if resp.CoverageData != nil {
-                               h := sha256.Sum256(value)
-                               name := fmt.Sprintf("%x", h[:4])
-                               result.entry = CorpusEntry{
-                                       Name:       name,
-                                       Parent:     input.entry.Name,
-                                       Data:       value,
-                                       Generation: input.entry.Generation + 1,
-                               }
                                result.coverageData = resp.CoverageData
                        }
                        w.coordinator.resultC <- result
@@ -252,7 +238,7 @@ func (w *worker) minimize(ctx context.Context, input fuzzResult) (min fuzzResult
                Limit:   w.coordinator.opts.MinimizeLimit,
                Timeout: w.coordinator.opts.MinimizeTimeout,
        }
-       value, resp, err := w.client.minimize(ctx, input.entry.Data, args)
+       minEntry, resp, err := w.client.minimize(ctx, input.entry, args)
        if err != nil {
                // Error communicating with worker.
                w.stop()
@@ -274,7 +260,7 @@ func (w *worker) minimize(ctx context.Context, input fuzzResult) (min fuzzResult
        min.crasherMsg = resp.Err
        min.count = resp.Count
        min.totalDuration = resp.Duration
-       min.entry.Data = value
+       min.entry = minEntry
        return min, nil
 }
 
@@ -369,7 +355,9 @@ func (w *worker) start() (err error) {
        // called later by stop.
        w.cmd = cmd
        w.termC = make(chan struct{})
-       w.client = newWorkerClient(workerComm{fuzzIn: fuzzInW, fuzzOut: fuzzOutR, memMu: w.memMu})
+       comm := workerComm{fuzzIn: fuzzInW, fuzzOut: fuzzOutR, memMu: w.memMu}
+       m := newMutator()
+       w.client = newWorkerClient(comm, m)
 
        go func() {
                w.waitErr = w.cmd.Wait()
@@ -632,9 +620,17 @@ func (ws *workerServer) serve(ctx context.Context) error {
        }
 }
 
-// fuzz runs the test function on random variations of a given input value for
-// a given amount of time. fuzz returns early if it finds an input that crashes
-// the fuzz function or an input that expands coverage.
+// fuzz runs the test function on random variations of the input value in shared
+// memory for a limited duration or number of iterations.
+//
+// fuzz returns early if it finds an input that crashes the fuzz function (with
+// fuzzResponse.Err set) or an input that expands coverage (with
+// fuzzResponse.InterestingDuration set).
+//
+// fuzz does not modify the input in shared memory. Instead, it saves the
+// initial PRNG state in shared memory and increments a counter in shared
+// memory before each call to the test function. The caller may reconstruct
+// the crashing input with this information, since the PRNG is deterministic.
 func (ws *workerServer) fuzz(ctx context.Context, args fuzzArgs) (resp fuzzResponse) {
        if args.CoverageData != nil {
                ws.coverageData = args.CoverageData
@@ -648,6 +644,7 @@ func (ws *workerServer) fuzz(ctx context.Context, args fuzzArgs) (resp fuzzRespo
                defer cancel()
        }
        mem := <-ws.memMu
+       ws.m.r.save(&mem.header().randState, &mem.header().randInc)
        defer func() {
                resp.Count = mem.header().count
                ws.memMu <- mem
@@ -680,7 +677,6 @@ func (ws *workerServer) fuzz(ctx context.Context, args fuzzArgs) (resp fuzzRespo
                default:
                        mem.header().count++
                        ws.m.mutate(vals, cap(mem.valueRef()))
-                       writeToMem(vals, mem)
                        fStart := time.Now()
                        err := ws.fuzzFn(CorpusEntry{Values: vals})
                        fDur := time.Since(fStart)
@@ -879,10 +875,11 @@ func (ws *workerServer) ping(ctx context.Context, args pingArgs) pingResponse {
 type workerClient struct {
        workerComm
        mu sync.Mutex
+       m  *mutator
 }
 
-func newWorkerClient(comm workerComm) *workerClient {
-       return &workerClient{workerComm: comm}
+func newWorkerClient(comm workerComm, m *mutator) *workerClient {
+       return &workerClient{workerComm: comm, m: m}
 }
 
 // Close shuts down the connection to the RPC server (the worker process) by
@@ -919,55 +916,81 @@ var errSharedMemClosed = errors.New("internal error: shared memory was closed an
 
 // minimize tells the worker to call the minimize method. See
 // workerServer.minimize.
-func (wc *workerClient) minimize(ctx context.Context, valueIn []byte, args minimizeArgs) (valueOut []byte, resp minimizeResponse, err error) {
+func (wc *workerClient) minimize(ctx context.Context, entryIn CorpusEntry, args minimizeArgs) (entryOut CorpusEntry, resp minimizeResponse, err error) {
        wc.mu.Lock()
        defer wc.mu.Unlock()
 
        mem, ok := <-wc.memMu
        if !ok {
-               return nil, minimizeResponse{}, errSharedMemClosed
+               return CorpusEntry{}, minimizeResponse{}, errSharedMemClosed
        }
        mem.header().count = 0
-       mem.setValue(valueIn)
+       mem.setValue(entryIn.Data)
        wc.memMu <- mem
+       defer func() { wc.memMu <- mem }()
 
        c := call{Minimize: &args}
-       err = wc.callLocked(ctx, c, &resp)
+       callErr := wc.callLocked(ctx, c, &resp)
        mem, ok = <-wc.memMu
        if !ok {
-               return nil, minimizeResponse{}, errSharedMemClosed
+               return CorpusEntry{}, minimizeResponse{}, errSharedMemClosed
+       }
+       entryOut.Data = mem.valueCopy()
+       entryOut.Values, err = unmarshalCorpusFile(entryOut.Data)
+       if err != nil {
+               panic(fmt.Sprintf("workerClient.minimize unmarshaling minimized value: %v", err))
        }
-       valueOut = mem.valueCopy()
        resp.Count = mem.header().count
-       wc.memMu <- mem
 
-       return valueOut, resp, err
+       return entryOut, resp, callErr
 }
 
 // fuzz tells the worker to call the fuzz method. See workerServer.fuzz.
-func (wc *workerClient) fuzz(ctx context.Context, valueIn []byte, args fuzzArgs) (valueOut []byte, resp fuzzResponse, err error) {
+func (wc *workerClient) fuzz(ctx context.Context, entryIn CorpusEntry, args fuzzArgs) (entryOut CorpusEntry, resp fuzzResponse, err error) {
        wc.mu.Lock()
        defer wc.mu.Unlock()
 
        mem, ok := <-wc.memMu
        if !ok {
-               return nil, fuzzResponse{}, errSharedMemClosed
+               return CorpusEntry{}, fuzzResponse{}, errSharedMemClosed
        }
        mem.header().count = 0
-       mem.setValue(valueIn)
+       mem.setValue(entryIn.Data)
        wc.memMu <- mem
 
        c := call{Fuzz: &args}
-       err = wc.callLocked(ctx, c, &resp)
+       callErr := wc.callLocked(ctx, c, &resp)
        mem, ok = <-wc.memMu
        if !ok {
-               return nil, fuzzResponse{}, errSharedMemClosed
+               return CorpusEntry{}, fuzzResponse{}, errSharedMemClosed
        }
-       valueOut = mem.valueCopy()
+       defer func() { wc.memMu <- mem }()
        resp.Count = mem.header().count
-       wc.memMu <- mem
 
-       return valueOut, resp, err
+       if !bytes.Equal(entryIn.Data, mem.valueRef()) {
+               panic("workerServer.fuzz modified input")
+       }
+       valuesOut, err := unmarshalCorpusFile(entryIn.Data)
+       if err != nil {
+               panic(fmt.Sprintf("unmarshaling fuzz input value after call: %v", err))
+       }
+       wc.m.r.restore(mem.header().randState, mem.header().randInc)
+       for i := int64(0); i < mem.header().count; i++ {
+               wc.m.mutate(valuesOut, cap(mem.valueRef()))
+       }
+       dataOut := marshalCorpusFile(valuesOut...)
+
+       h := sha256.Sum256(dataOut)
+       name := fmt.Sprintf("%x", h[:4])
+       entryOut = CorpusEntry{
+               Name:       name,
+               Parent:     entryIn.Name,
+               Data:       dataOut,
+               Values:     valuesOut,
+               Generation: entryIn.Generation + 1,
+       }
+
+       return entryOut, resp, callErr
 }
 
 // ping tells the worker to call the ping method. See workerServer.ping.
index 6c75fc412c4e70a827fcaa074b9896ade618062d..2369b4ce3f48beb2c8e74f5830b7d5bf1d7a4be0 100644 (file)
@@ -79,13 +79,14 @@ func BenchmarkWorkerPing(b *testing.B) {
 func BenchmarkWorkerFuzz(b *testing.B) {
        b.SetParallelism(1)
        w := newWorkerForTest(b)
-       data := marshalCorpusFile([]byte(nil))
+       entry := CorpusEntry{Values: []interface{}{[]byte(nil)}}
+       entry.Data = marshalCorpusFile(entry.Values...)
        for i := int64(0); i < int64(b.N); {
                args := fuzzArgs{
                        Limit:   int64(b.N) - i,
                        Timeout: workerFuzzDuration,
                }
-               _, resp, err := w.client.fuzz(context.Background(), data, args)
+               _, resp, err := w.client.fuzz(context.Background(), entry, args)
                if err != nil {
                        b.Fatal(err)
                }