]> Cypherpunks repositories - gostls13.git/commitdiff
[dev.fuzz] internal/fuzz: support minimization of strings, integers, and floats
authorRoland Shoemaker <roland@golang.org>
Tue, 18 May 2021 03:14:15 +0000 (20:14 -0700)
committerRoland Shoemaker <roland@golang.org>
Thu, 27 May 2021 02:15:49 +0000 (02:15 +0000)
Adds support for minimizing strings using the same logic as byte slices
as well as minimizing both signed and unsigned integers and floats using
extremely basic logic. A more complex approach is probably warranted in
the future, but for now this should be _good enough_.

Change-Id: Ibc6c3d6ae82685998f571aa2c1ecea2f85c2708b
Reviewed-on: https://go-review.googlesource.com/c/go/+/320669
Trust: Roland Shoemaker <roland@golang.org>
Trust: Katie Hockman <katie@golang.org>
Run-TryBot: Roland Shoemaker <roland@golang.org>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Katie Hockman <katie@golang.org>
Reviewed-by: Jay Conrod <jayconrod@google.com>
src/internal/fuzz/minimize.go [new file with mode: 0644]
src/internal/fuzz/minimize_test.go [new file with mode: 0644]
src/internal/fuzz/worker.go

diff --git a/src/internal/fuzz/minimize.go b/src/internal/fuzz/minimize.go
new file mode 100644 (file)
index 0000000..c5533bd
--- /dev/null
@@ -0,0 +1,104 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package fuzz
+
+import (
+       "context"
+       "math"
+)
+
+func minimizeBytes(ctx context.Context, v []byte, stillCrashes func(interface{}) bool, shouldStop func() bool) {
+       // First, try to cut the tail.
+       for n := 1024; n != 0; n /= 2 {
+               for len(v) > n {
+                       if shouldStop() {
+                               return
+                       }
+                       candidate := v[:len(v)-n]
+                       if !stillCrashes(candidate) {
+                               break
+                       }
+                       // Set v to the new value to continue iterating.
+                       v = candidate
+               }
+       }
+
+       // Then, try to remove each individual byte.
+       tmp := make([]byte, len(v))
+       for i := 0; i < len(v)-1; i++ {
+               if shouldStop() {
+                       return
+               }
+               candidate := tmp[:len(v)-1]
+               copy(candidate[:i], v[:i])
+               copy(candidate[i:], v[i+1:])
+               if !stillCrashes(candidate) {
+                       continue
+               }
+               // Update v to delete the value at index i.
+               copy(v[i:], v[i+1:])
+               v = v[:len(candidate)]
+               // v[i] is now different, so decrement i to redo this iteration
+               // of the loop with the new value.
+               i--
+       }
+
+       // Then, try to remove each possible subset of bytes.
+       for i := 0; i < len(v)-1; i++ {
+               copy(tmp, v[:i])
+               for j := len(v); j > i+1; j-- {
+                       if shouldStop() {
+                               return
+                       }
+                       candidate := tmp[:len(v)-j+i]
+                       copy(candidate[i:], v[j:])
+                       if !stillCrashes(candidate) {
+                               continue
+                       }
+                       // Update v and reset the loop with the new length.
+                       copy(v[i:], v[j:])
+                       v = v[:len(candidate)]
+                       j = len(v)
+               }
+       }
+
+       return
+}
+
+func minimizeInteger(ctx context.Context, v uint, stillCrashes func(interface{}) bool, shouldStop func() bool) {
+       // TODO(rolandshoemaker): another approach could be either unsetting/setting all bits
+       // (depending on signed-ness), or rotating bits? When operating on cast signed integers
+       // this would probably be more complex though.
+       for ; v != 0; v /= 10 {
+               if shouldStop() {
+                       return
+               }
+               // We ignore the return value here because there is no point
+               // advancing the loop, since there is nothing after this check,
+               // and we don't return early because a smaller value could
+               // re-trigger the crash.
+               stillCrashes(v)
+       }
+       return
+}
+
+func minimizeFloat(ctx context.Context, v float64, stillCrashes func(interface{}) bool, shouldStop func() bool) {
+       if math.IsNaN(v) {
+               return
+       }
+       minimized := float64(0)
+       for div := 10.0; minimized < v; div *= 10 {
+               if shouldStop() {
+                       return
+               }
+               minimized = float64(int(v*div)) / div
+               if !stillCrashes(minimized) {
+                       // Since we are searching from least precision -> highest precision we
+                       // can return early since we've already found the smallest value
+                       return
+               }
+       }
+       return
+}
diff --git a/src/internal/fuzz/minimize_test.go b/src/internal/fuzz/minimize_test.go
new file mode 100644 (file)
index 0000000..500ff43
--- /dev/null
@@ -0,0 +1,219 @@
+// Copyright 2021 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build darwin || linux || windows
+// +build darwin linux windows
+
+package fuzz
+
+import (
+       "context"
+       "fmt"
+       "reflect"
+       "testing"
+)
+
+func TestMinimizeInput(t *testing.T) {
+       type testcase struct {
+               fn       func(CorpusEntry) error
+               input    []interface{}
+               expected []interface{}
+       }
+       cases := []testcase{
+               {
+                       fn: func(e CorpusEntry) error {
+                               b := e.Values[0].([]byte)
+                               ones := 0
+                               for _, v := range b {
+                                       if v == 1 {
+                                               ones++
+                                       }
+                               }
+                               if ones == 3 {
+                                       return fmt.Errorf("bad %v", e.Values[0])
+                               }
+                               return nil
+                       },
+                       input:    []interface{}{[]byte{0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}},
+                       expected: []interface{}{[]byte{1, 1, 1}},
+               },
+               {
+                       fn: func(e CorpusEntry) error {
+                               b := e.Values[0].(string)
+                               ones := 0
+                               for _, v := range b {
+                                       if v == '1' {
+                                               ones++
+                                       }
+                               }
+                               if ones == 3 {
+                                       return fmt.Errorf("bad %v", e.Values[0])
+                               }
+                               return nil
+                       },
+                       input:    []interface{}{"001010001000000000000000000"},
+                       expected: []interface{}{"111"},
+               },
+               {
+                       fn: func(e CorpusEntry) error {
+                               i := e.Values[0].(int)
+                               if i > 100 {
+                                       return fmt.Errorf("bad %v", e.Values[0])
+                               }
+                               return nil
+                       },
+                       input:    []interface{}{123456},
+                       expected: []interface{}{123},
+               },
+               {
+                       fn: func(e CorpusEntry) error {
+                               i := e.Values[0].(int8)
+                               if i > 10 {
+                                       return fmt.Errorf("bad %v", e.Values[0])
+                               }
+                               return nil
+                       },
+                       input:    []interface{}{int8(1<<7 - 1)},
+                       expected: []interface{}{int8(12)},
+               },
+               {
+                       fn: func(e CorpusEntry) error {
+                               i := e.Values[0].(int16)
+                               if i > 10 {
+                                       return fmt.Errorf("bad %v", e.Values[0])
+                               }
+                               return nil
+                       },
+                       input:    []interface{}{int16(1<<15 - 1)},
+                       expected: []interface{}{int16(32)},
+               },
+               {
+                       fn: func(e CorpusEntry) error {
+                               i := e.Values[0].(int32)
+                               if i > 10 {
+                                       return fmt.Errorf("bad %v", e.Values[0])
+                               }
+                               return nil
+                       },
+                       input:    []interface{}{int32(1<<31 - 1)},
+                       expected: []interface{}{int32(21)},
+               },
+               {
+                       fn: func(e CorpusEntry) error {
+                               i := e.Values[0].(uint)
+                               if i > 10 {
+                                       return fmt.Errorf("bad %v", e.Values[0])
+                               }
+                               return nil
+                       },
+                       input:    []interface{}{uint(123456)},
+                       expected: []interface{}{uint(12)},
+               },
+               {
+                       fn: func(e CorpusEntry) error {
+                               i := e.Values[0].(uint8)
+                               if i > 10 {
+                                       return fmt.Errorf("bad %v", e.Values[0])
+                               }
+                               return nil
+                       },
+                       input:    []interface{}{uint8(1<<8 - 1)},
+                       expected: []interface{}{uint8(25)},
+               },
+               {
+                       fn: func(e CorpusEntry) error {
+                               i := e.Values[0].(uint16)
+                               if i > 10 {
+                                       return fmt.Errorf("bad %v", e.Values[0])
+                               }
+                               return nil
+                       },
+                       input:    []interface{}{uint16(1<<16 - 1)},
+                       expected: []interface{}{uint16(65)},
+               },
+               {
+                       fn: func(e CorpusEntry) error {
+                               i := e.Values[0].(uint32)
+                               if i > 10 {
+                                       return fmt.Errorf("bad %v", e.Values[0])
+                               }
+                               return nil
+                       },
+                       input:    []interface{}{uint32(1<<32 - 1)},
+                       expected: []interface{}{uint32(42)},
+               },
+               {
+                       fn: func(e CorpusEntry) error {
+                               if i := e.Values[0].(float32); i == 1.23 {
+                                       return nil
+                               }
+                               return fmt.Errorf("bad %v", e.Values[0])
+                       },
+                       input:    []interface{}{float32(1.23456789)},
+                       expected: []interface{}{float32(1.2)},
+               },
+               {
+                       fn: func(e CorpusEntry) error {
+                               if i := e.Values[0].(float64); i == 1.23 {
+                                       return nil
+                               }
+                               return fmt.Errorf("bad %v", e.Values[0])
+                       },
+                       input:    []interface{}{float64(1.23456789)},
+                       expected: []interface{}{float64(1.2)},
+               },
+       }
+
+       // If we are on a 64 bit platform add int64 and uint64 tests
+       if v := int64(1<<63 - 1); int64(int(v)) == v {
+               cases = append(cases, testcase{
+                       fn: func(e CorpusEntry) error {
+                               i := e.Values[0].(int64)
+                               if i > 10 {
+                                       return fmt.Errorf("bad %v", e.Values[0])
+                               }
+                               return nil
+                       },
+                       input:    []interface{}{int64(1<<63 - 1)},
+                       expected: []interface{}{int64(92)},
+               }, testcase{
+                       fn: func(e CorpusEntry) error {
+                               i := e.Values[0].(uint64)
+                               if i > 10 {
+                                       return fmt.Errorf("bad %v", e.Values[0])
+                               }
+                               return nil
+                       },
+                       input:    []interface{}{uint64(1<<64 - 1)},
+                       expected: []interface{}{uint64(18)},
+               })
+       }
+
+       sm, err := sharedMemTempFile(workerSharedMemSize)
+       if err != nil {
+               t.Fatalf("failed to create temporary shared memory file: %s", err)
+       }
+       defer sm.Close()
+
+       for _, tc := range cases {
+               ws := &workerServer{
+                       fuzzFn: tc.fn,
+               }
+               count := int64(0)
+               err = ws.minimizeInput(context.Background(), tc.input, sm, &count, 0)
+               if err == nil {
+                       t.Error("minimizeInput didn't fail")
+               }
+               if expected := fmt.Sprintf("bad %v", tc.input[0]); err.Error() != expected {
+                       t.Errorf("unexpected error: got %s, want %s", err, expected)
+               }
+               vals, err := unmarshalCorpusFile(sm.valueCopy())
+               if err != nil {
+                       t.Fatalf("failed to unmarshal values from shared memory file: %s", err)
+               }
+               if !reflect.DeepEqual(vals, tc.expected) {
+                       t.Errorf("unexpected results: got %v, want %v", vals, tc.expected)
+               }
+       }
+}
index 875f3ac5bab5a7ea733c25aba30f2d0933e52251..33727a543858a170fe94a95ae1940cbd61a60fe5 100644 (file)
@@ -714,7 +714,7 @@ func (ws *workerServer) minimize(ctx context.Context, args minimizeArgs) (resp m
 // mem just in case an unrecoverable error occurs. It uses the context to
 // determine how long to run, stopping once closed. It returns the last error it
 // found.
-func (ws *workerServer) minimizeInput(ctx context.Context, vals []interface{}, mem *sharedMem, count *int64, limit int64) (retErr error) {
+func (ws *workerServer) minimizeInput(ctx context.Context, vals []interface{}, mem *sharedMem, count *int64, limit int64) error {
        // Make sure the last crashing value is written to mem.
        defer writeToMem(vals, mem)
 
@@ -725,87 +725,116 @@ func (ws *workerServer) minimizeInput(ctx context.Context, vals []interface{}, m
                return nil
        }
 
-       // tryMinimized will run the fuzz function for the values in vals at the
-       // time the function is called. If err is nil, then the minimization was
-       // unsuccessful, since we expect an error to still occur.
-       tryMinimized := func(i int, prevVal interface{}) error {
-               writeToMem(vals, mem) // write to mem in case a non-recoverable crash occurs
+       var valI int
+       var retErr error
+       tryMinimized := func(candidate interface{}) bool {
+               prev := vals[valI]
+               // Set vals[valI] to the candidate after it has been
+               // properly cast. We know that candidate must be of
+               // the same type as prev, so use that as a reference.
+               switch c := candidate.(type) {
+               case float64:
+                       switch prev.(type) {
+                       case float32:
+                               vals[valI] = float32(c)
+                       case float64:
+                               vals[valI] = c
+                       default:
+                               panic("impossible")
+                       }
+               case uint:
+                       switch prev.(type) {
+                       case uint:
+                               vals[valI] = c
+                       case uint8:
+                               vals[valI] = uint8(c)
+                       case uint16:
+                               vals[valI] = uint16(c)
+                       case uint32:
+                               vals[valI] = uint32(c)
+                       case uint64:
+                               vals[valI] = uint64(c)
+                       case int:
+                               vals[valI] = int(c)
+                       case int8:
+                               vals[valI] = int8(c)
+                       case int16:
+                               vals[valI] = int16(c)
+                       case int32:
+                               vals[valI] = int32(c)
+                       case int64:
+                               vals[valI] = int64(c)
+                       default:
+                               panic("impossible")
+                       }
+               case []byte:
+                       switch prev.(type) {
+                       case []byte:
+                               vals[valI] = c
+                       case string:
+                               vals[valI] = string(c)
+                       default:
+                               panic("impossible")
+                       }
+               default:
+                       panic("impossible")
+               }
+               writeToMem(vals, mem)
                err := ws.fuzzFn(CorpusEntry{Values: vals})
-               if err == nil {
-                       // The fuzz function succeeded, so return the value at index i back
-                       // to the previously failing input.
-                       vals[i] = prevVal
-               } else {
-                       // The fuzz function failed, so save the most recent error.
+               if err != nil {
                        retErr = err
+                       return true
                }
                *count++
-               return err
+               vals[valI] = prev
+               return false
        }
-       for valI := range vals {
+
+       for valI = range vals {
+               if shouldStop() {
+                       return retErr
+               }
                switch v := vals[valI].(type) {
-               case bool, byte, rune:
+               case bool:
                        continue // can't minimize
-               case string, int, int8, int16, int64, uint, uint16, uint32, uint64, float32, float64:
-                       // TODO(jayconrod,katiehockman): support minimizing other types
-               case []byte:
-                       // First, try to cut the tail.
-                       for n := 1024; n != 0; n /= 2 {
-                               for len(v) > n {
-                                       if shouldStop() {
-                                               return retErr
-                                       }
-                                       vals[valI] = v[:len(v)-n]
-                                       if tryMinimized(valI, v) == nil {
-                                               break
-                                       }
-                                       // Set v to the new value to continue iterating.
-                                       v = v[:len(v)-n]
-                               }
+               case float32:
+                       minimizeFloat(ctx, float64(v), tryMinimized, shouldStop)
+               case float64:
+                       minimizeFloat(ctx, v, tryMinimized, shouldStop)
+               case uint:
+                       minimizeInteger(ctx, v, tryMinimized, shouldStop)
+               case uint8:
+                       minimizeInteger(ctx, uint(v), tryMinimized, shouldStop)
+               case uint16:
+                       minimizeInteger(ctx, uint(v), tryMinimized, shouldStop)
+               case uint32:
+                       minimizeInteger(ctx, uint(v), tryMinimized, shouldStop)
+               case uint64:
+                       if uint64(uint(v)) != v {
+                               // Skip minimizing a uint64 on 32 bit platforms, since we'll truncate the
+                               // value when casting
+                               continue
                        }
-
-                       // Then, try to remove each individual byte.
-                       tmp := make([]byte, len(v))
-                       for i := 0; i < len(v)-1; i++ {
-                               if shouldStop() {
-                                       return retErr
-                               }
-                               candidate := tmp[:len(v)-1]
-                               copy(candidate[:i], v[:i])
-                               copy(candidate[i:], v[i+1:])
-                               vals[valI] = candidate
-                               if tryMinimized(valI, v) == nil {
-                                       continue
-                               }
-                               // Update v to delete the value at index i.
-                               copy(v[i:], v[i+1:])
-                               v = v[:len(candidate)]
-                               // v[i] is now different, so decrement i to redo this iteration
-                               // of the loop with the new value.
-                               i--
+                       minimizeInteger(ctx, uint(v), tryMinimized, shouldStop)
+               case int:
+                       minimizeInteger(ctx, uint(v), tryMinimized, shouldStop)
+               case int8:
+                       minimizeInteger(ctx, uint(v), tryMinimized, shouldStop)
+               case int16:
+                       minimizeInteger(ctx, uint(v), tryMinimized, shouldStop)
+               case int32:
+                       minimizeInteger(ctx, uint(v), tryMinimized, shouldStop)
+               case int64:
+                       if int64(int(v)) != v {
+                               // Skip minimizing a int64 on 32 bit platforms, since we'll truncate the
+                               // value when casting
+                               continue
                        }
-
-                       // Then, try to remove each possible subset of bytes.
-                       for i := 0; i < len(v)-1; i++ {
-                               copy(tmp, v[:i])
-                               for j := len(v); j > i+1; j-- {
-                                       if shouldStop() {
-                                               return retErr
-                                       }
-                                       candidate := tmp[:len(v)-j+i]
-                                       copy(candidate[i:], v[j:])
-                                       vals[valI] = candidate
-                                       if tryMinimized(valI, v) == nil {
-                                               continue
-                                       }
-                                       // Update v and reset the loop with the new length.
-                                       copy(v[i:], v[j:])
-                                       v = v[:len(candidate)]
-                                       j = len(v)
-                               }
-                       }
-                       // TODO(jayconrod,katiehockman): consider adding canonicalization
-                       // which replaces each individual byte with '0'
+                       minimizeInteger(ctx, uint(v), tryMinimized, shouldStop)
+               case string:
+                       minimizeBytes(ctx, []byte(v), tryMinimized, shouldStop)
+               case []byte:
+                       minimizeBytes(ctx, v, tryMinimized, shouldStop)
                default:
                        panic("unreachable")
                }