]> Cypherpunks repositories - gostls13.git/commitdiff
runtime: tidy Windows callback test
authorAustin Clements <austin@google.com>
Sat, 17 Oct 2020 02:22:20 +0000 (22:22 -0400)
committerAustin Clements <austin@google.com>
Mon, 26 Oct 2020 14:50:37 +0000 (14:50 +0000)
This simplifies the systematic test of Windows callbacks with
different signatures and prepares it for expanded coverage of function
signatures.

It now returns a result from the Go function and threads it back
through C. This simplifies things, but also previously the code could
have succeeded by simply not calling the callbacks at all (though
other tests would have caught that).

It bundles together the C function description and the Go function
it's intended to call. Now the test source generation and the test
running both loop over a single slice of test functions.

Since the C function and Go function are now bundled, it generates the
C function by reflectively inspecting the signature of the Go
function. For the moment, we keep the same test suite, which is
entirely functions with "uintptr" arguments, but we'll expand this
shortly.

It now use sub-tests. This way tests automatically get useful
diagnostic labels in failures and the tests don't have to catch panics
on their own.

It eliminates the DLL function argument. I honestly couldn't figure
out what the point of this was, and it added what appeared to be an
unnecessary loop level to the tests.

Change-Id: I120dfd4785057cc2c392bd2c821302f276bd128e
Reviewed-on: https://go-review.googlesource.com/c/go/+/263270
Trust: Austin Clements <austin@google.com>
Trust: Alex Brainman <alex.brainman@gmail.com>
Run-TryBot: Austin Clements <austin@google.com>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Alex Brainman <alex.brainman@gmail.com>
src/runtime/syscall_windows_test.go

index 2e74546e38333639a73157b2c5c1a969eccd0893..cb942beb3e8950ce2ca3e4e7192686811f71b116 100644 (file)
@@ -9,11 +9,13 @@ import (
        "fmt"
        "internal/syscall/windows/sysdll"
        "internal/testenv"
+       "io"
        "io/ioutil"
        "math"
        "os"
        "os/exec"
        "path/filepath"
+       "reflect"
        "runtime"
        "strconv"
        "strings"
@@ -285,99 +287,85 @@ func TestCallbackInAnotherThread(t *testing.T) {
        }
 }
 
-type cbDLLFunc int // int determines number of callback parameters
-
-func (f cbDLLFunc) stdcallName() string {
-       return fmt.Sprintf("stdcall%d", f)
+type cbFunc struct {
+       goFunc interface{}
 }
 
-func (f cbDLLFunc) cdeclName() string {
-       return fmt.Sprintf("cdecl%d", f)
+func (f cbFunc) cName(cdecl bool) string {
+       name := "stdcall"
+       if cdecl {
+               name = "cdecl"
+       }
+       t := reflect.TypeOf(f.goFunc)
+       for i := 0; i < t.NumIn(); i++ {
+               name += "_" + t.In(i).Name()
+       }
+       return name
 }
 
-func (f cbDLLFunc) buildOne(stdcall bool) string {
-       var funcname, attr string
-       if stdcall {
-               funcname = f.stdcallName()
-               attr = "__stdcall"
-       } else {
-               funcname = f.cdeclName()
+func (f cbFunc) cSrc(w io.Writer, cdecl bool) {
+       // Construct a C function that takes a callback with
+       // f.goFunc's signature, and calls it with integers 1..N.
+       funcname := f.cName(cdecl)
+       attr := "__stdcall"
+       if cdecl {
                attr = "__cdecl"
        }
        typename := "t" + funcname
-       p := make([]string, f)
-       for i := range p {
-               p[i] = "uintptr_t"
-       }
-       params := strings.Join(p, ",")
-       for i := range p {
-               p[i] = fmt.Sprintf("%d", i+1)
-       }
-       args := strings.Join(p, ",")
-       return fmt.Sprintf(`
-typedef void %s (*%s)(%s);
-void %s(%s f, uintptr_t n) {
-       uintptr_t i;
-       for(i=0;i<n;i++){
-               f(%s);
-       }
+       t := reflect.TypeOf(f.goFunc)
+       cTypes := make([]string, t.NumIn())
+       cArgs := make([]string, t.NumIn())
+       for i := range cTypes {
+               // We included stdint.h, so this works for all sized
+               // integer types.
+               cTypes[i] = t.In(i).Name() + "_t"
+               cArgs[i] = fmt.Sprintf("%d", i+1)
+       }
+       fmt.Fprintf(w, `
+typedef uintptr_t %s (*%s)(%s);
+uintptr_t %s(%s f) {
+       return f(%s);
 }
-       `, attr, typename, params, funcname, typename, args)
+       `, attr, typename, strings.Join(cTypes, ","), funcname, typename, strings.Join(cArgs, ","))
 }
 
-func (f cbDLLFunc) build() string {
-       return "#include <stdint.h>\n\n" + f.buildOne(false) + f.buildOne(true)
+func (f cbFunc) testOne(t *testing.T, dll *syscall.DLL, cdecl bool, cb uintptr) {
+       r1, _, _ := dll.MustFindProc(f.cName(cdecl)).Call(cb)
+
+       want := 0
+       for i := 0; i < reflect.TypeOf(f.goFunc).NumIn(); i++ {
+               want += i + 1
+       }
+       if int(r1) != want {
+               t.Errorf("wanted result %d; got %d", want, r1)
+       }
 }
 
-var cbFuncs = [...]interface{}{
-       2: func(i1, i2 uintptr) uintptr {
-               if i1+i2 != 3 {
-                       panic("bad input")
-               }
-               return 0
-       },
-       3: func(i1, i2, i3 uintptr) uintptr {
-               if i1+i2+i3 != 6 {
-                       panic("bad input")
-               }
-               return 0
-       },
-       4: func(i1, i2, i3, i4 uintptr) uintptr {
-               if i1+i2+i3+i4 != 10 {
-                       panic("bad input")
-               }
-               return 0
-       },
-       5: func(i1, i2, i3, i4, i5 uintptr) uintptr {
-               if i1+i2+i3+i4+i5 != 15 {
-                       panic("bad input")
-               }
-               return 0
-       },
-       6: func(i1, i2, i3, i4, i5, i6 uintptr) uintptr {
-               if i1+i2+i3+i4+i5+i6 != 21 {
-                       panic("bad input")
-               }
-               return 0
-       },
-       7: func(i1, i2, i3, i4, i5, i6, i7 uintptr) uintptr {
-               if i1+i2+i3+i4+i5+i6+i7 != 28 {
-                       panic("bad input")
-               }
-               return 0
-       },
-       8: func(i1, i2, i3, i4, i5, i6, i7, i8 uintptr) uintptr {
-               if i1+i2+i3+i4+i5+i6+i7+i8 != 36 {
-                       panic("bad input")
-               }
-               return 0
-       },
-       9: func(i1, i2, i3, i4, i5, i6, i7, i8, i9 uintptr) uintptr {
-               if i1+i2+i3+i4+i5+i6+i7+i8+i9 != 45 {
-                       panic("bad input")
-               }
-               return 0
-       },
+var cbFuncs = []cbFunc{
+       {func(i1, i2 uintptr) uintptr {
+               return i1 + i2
+       }},
+       {func(i1, i2, i3 uintptr) uintptr {
+               return i1 + i2 + i3
+       }},
+       {func(i1, i2, i3, i4 uintptr) uintptr {
+               return i1 + i2 + i3 + i4
+       }},
+       {func(i1, i2, i3, i4, i5 uintptr) uintptr {
+               return i1 + i2 + i3 + i4 + i5
+       }},
+       {func(i1, i2, i3, i4, i5, i6 uintptr) uintptr {
+               return i1 + i2 + i3 + i4 + i5 + i6
+       }},
+       {func(i1, i2, i3, i4, i5, i6, i7 uintptr) uintptr {
+               return i1 + i2 + i3 + i4 + i5 + i6 + i7
+       }},
+       {func(i1, i2, i3, i4, i5, i6, i7, i8 uintptr) uintptr {
+               return i1 + i2 + i3 + i4 + i5 + i6 + i7 + i8
+       }},
+       {func(i1, i2, i3, i4, i5, i6, i7, i8, i9 uintptr) uintptr {
+               return i1 + i2 + i3 + i4 + i5 + i6 + i7 + i8 + i9
+       }},
 }
 
 type cbDLL struct {
@@ -385,21 +373,23 @@ type cbDLL struct {
        buildArgs func(out, src string) []string
 }
 
-func (d *cbDLL) buildSrc(t *testing.T, path string) {
+func (d *cbDLL) makeSrc(t *testing.T, path string) {
        f, err := os.Create(path)
        if err != nil {
                t.Fatalf("failed to create source file: %v", err)
        }
        defer f.Close()
 
-       for i := 2; i < 10; i++ {
-               fmt.Fprint(f, cbDLLFunc(i).build())
+       fmt.Fprintf(f, "#include <stdint.h>\n\n")
+       for _, cbf := range cbFuncs {
+               cbf.cSrc(f, false)
+               cbf.cSrc(f, true)
        }
 }
 
 func (d *cbDLL) build(t *testing.T, dir string) string {
        srcname := d.name + ".c"
-       d.buildSrc(t, filepath.Join(dir, srcname))
+       d.makeSrc(t, filepath.Join(dir, srcname))
        outname := d.name + ".dll"
        args := d.buildArgs(outname, srcname)
        cmd := exec.Command(args[0], args[1:]...)
@@ -426,51 +416,6 @@ var cbDLLs = []cbDLL{
        },
 }
 
-type cbTest struct {
-       n     int     // number of callback parameters
-       param uintptr // dll function parameter
-}
-
-func (test *cbTest) run(t *testing.T, dllpath string) {
-       dll := syscall.MustLoadDLL(dllpath)
-       defer dll.Release()
-       cb := cbFuncs[test.n]
-       stdcall := syscall.NewCallback(cb)
-       f := cbDLLFunc(test.n)
-       test.runOne(t, dll, f.stdcallName(), stdcall)
-       cdecl := syscall.NewCallbackCDecl(cb)
-       test.runOne(t, dll, f.cdeclName(), cdecl)
-}
-
-func (test *cbTest) runOne(t *testing.T, dll *syscall.DLL, proc string, cb uintptr) {
-       defer func() {
-               if r := recover(); r != nil {
-                       t.Errorf("dll call %v(..., %d) failed: %v", proc, test.param, r)
-               }
-       }()
-       dll.MustFindProc(proc).Call(cb, test.param)
-}
-
-var cbTests = []cbTest{
-       {2, 1},
-       {2, 10000},
-       {3, 3},
-       {4, 5},
-       {4, 6},
-       {5, 2},
-       {6, 7},
-       {6, 8},
-       {7, 6},
-       {8, 1},
-       {9, 8},
-       {9, 10000},
-       {3, 4},
-       {5, 3},
-       {7, 7},
-       {8, 2},
-       {9, 9},
-}
-
 func TestStdcallAndCDeclCallbacks(t *testing.T) {
        if _, err := exec.LookPath("gcc"); err != nil {
                t.Skip("skipping test: gcc is missing")
@@ -482,10 +427,21 @@ func TestStdcallAndCDeclCallbacks(t *testing.T) {
        defer os.RemoveAll(tmp)
 
        for _, dll := range cbDLLs {
-               dllPath := dll.build(t, tmp)
-               for _, test := range cbTests {
-                       test.run(t, dllPath)
-               }
+               t.Run(dll.name, func(t *testing.T) {
+                       dllPath := dll.build(t, tmp)
+                       dll := syscall.MustLoadDLL(dllPath)
+                       defer dll.Release()
+                       for _, cbf := range cbFuncs {
+                               t.Run(cbf.cName(false), func(t *testing.T) {
+                                       stdcall := syscall.NewCallback(cbf.goFunc)
+                                       cbf.testOne(t, dll, false, stdcall)
+                               })
+                               t.Run(cbf.cName(true), func(t *testing.T) {
+                                       cdecl := syscall.NewCallbackCDecl(cbf.goFunc)
+                                       cbf.testOne(t, dll, true, cdecl)
+                               })
+                       }
+               })
        }
 }