]> Cypherpunks repositories - gostls13.git/commitdiff
runtime: store syscall parameters in m not on stack
authorAlex Brainman <alex.brainman@gmail.com>
Fri, 22 May 2015 00:58:57 +0000 (10:58 +1000)
committerAlex Brainman <alex.brainman@gmail.com>
Mon, 29 Jun 2015 02:45:45 +0000 (02:45 +0000)
Stack can move during callback, so libcall struct cannot be stored on stack.
asmstdcall updates return values and errno in libcall struct parameter, but
these could be at different location when callback returns.
Store these in m, so they are not affected by GC.

Fixes #10406

Change-Id: Id01c9d2b4b44530494e6d9e9e1c875261ce477cd
Reviewed-on: https://go-review.googlesource.com/10370
Reviewed-by: Russ Cox <rsc@golang.org>
src/runtime/cgocall.go
src/runtime/runtime2.go
src/runtime/syscall_windows.go
src/runtime/syscall_windows_test.go

index 17e01251e227201b2a83d8ade2ce51e1b33ba038..f09a66a07d0e3deab1aef43c0f47a8225fc314a0 100644 (file)
@@ -163,6 +163,10 @@ func cgocallbackg() {
                exit(2)
        }
 
+       // Save current syscall parameters, so m.syscall can be
+       // used again if callback decide to make syscall.
+       syscall := gp.m.syscall
+
        // entersyscall saves the caller's SP to allow the GC to trace the Go
        // stack. However, since we're returning to an earlier stack frame and
        // need to pair with the entersyscall() call made by cgocall, we must
@@ -173,6 +177,8 @@ func cgocallbackg() {
        cgocallbackg1()
        // going back to cgo call
        reentersyscall(savedpc, uintptr(savedsp))
+
+       gp.m.syscall = syscall
 }
 
 func cgocallbackg1() {
index d2dfa71edd6ed0b8d47fe8c1407d0e8dc15139e4..64b2d03a929d60c5e771d4c48c29d982babc9fb0 100644 (file)
@@ -332,6 +332,7 @@ type m struct {
        libcallpc uintptr // for cpu profiler
        libcallsp uintptr
        libcallg  guintptr
+       syscall   libcall // stores syscall parameters on windows
        //#endif
        //#ifdef GOOS_solaris
        perrno *int32 // pointer to tls errno
index d2e44d761a130aa23825d5c9fb6cf4fec5de7b10..8e069cdb15f82a014f969eeda8120d9e4b5d743d 100644 (file)
@@ -91,11 +91,11 @@ func compileCallback(fn eface, cleanstack bool) (code uintptr) {
 //go:linkname syscall_loadlibrary syscall.loadlibrary
 //go:nosplit
 func syscall_loadlibrary(filename *uint16) (handle, err uintptr) {
-       var c libcall
+       c := &getg().m.syscall
        c.fn = getLoadLibrary()
        c.n = 1
-       c.args = uintptr(unsafe.Pointer(&filename))
-       cgocall(asmstdcallAddr, unsafe.Pointer(&c))
+       c.args = uintptr(noescape(unsafe.Pointer(&filename)))
+       cgocall(asmstdcallAddr, unsafe.Pointer(c))
        handle = c.r1
        if handle == 0 {
                err = c.err
@@ -106,11 +106,11 @@ func syscall_loadlibrary(filename *uint16) (handle, err uintptr) {
 //go:linkname syscall_getprocaddress syscall.getprocaddress
 //go:nosplit
 func syscall_getprocaddress(handle uintptr, procname *byte) (outhandle, err uintptr) {
-       var c libcall
+       c := &getg().m.syscall
        c.fn = getGetProcAddress()
        c.n = 2
-       c.args = uintptr(unsafe.Pointer(&handle))
-       cgocall(asmstdcallAddr, unsafe.Pointer(&c))
+       c.args = uintptr(noescape(unsafe.Pointer(&handle)))
+       cgocall(asmstdcallAddr, unsafe.Pointer(c))
        outhandle = c.r1
        if outhandle == 0 {
                err = c.err
@@ -121,54 +121,54 @@ func syscall_getprocaddress(handle uintptr, procname *byte) (outhandle, err uint
 //go:linkname syscall_Syscall syscall.Syscall
 //go:nosplit
 func syscall_Syscall(fn, nargs, a1, a2, a3 uintptr) (r1, r2, err uintptr) {
-       var c libcall
+       c := &getg().m.syscall
        c.fn = fn
        c.n = nargs
-       c.args = uintptr(unsafe.Pointer(&a1))
-       cgocall(asmstdcallAddr, unsafe.Pointer(&c))
+       c.args = uintptr(noescape(unsafe.Pointer(&a1)))
+       cgocall(asmstdcallAddr, unsafe.Pointer(c))
        return c.r1, c.r2, c.err
 }
 
 //go:linkname syscall_Syscall6 syscall.Syscall6
 //go:nosplit
 func syscall_Syscall6(fn, nargs, a1, a2, a3, a4, a5, a6 uintptr) (r1, r2, err uintptr) {
-       var c libcall
+       c := &getg().m.syscall
        c.fn = fn
        c.n = nargs
-       c.args = uintptr(unsafe.Pointer(&a1))
-       cgocall(asmstdcallAddr, unsafe.Pointer(&c))
+       c.args = uintptr(noescape(unsafe.Pointer(&a1)))
+       cgocall(asmstdcallAddr, unsafe.Pointer(c))
        return c.r1, c.r2, c.err
 }
 
 //go:linkname syscall_Syscall9 syscall.Syscall9
 //go:nosplit
 func syscall_Syscall9(fn, nargs, a1, a2, a3, a4, a5, a6, a7, a8, a9 uintptr) (r1, r2, err uintptr) {
-       var c libcall
+       c := &getg().m.syscall
        c.fn = fn
        c.n = nargs
-       c.args = uintptr(unsafe.Pointer(&a1))
-       cgocall(asmstdcallAddr, unsafe.Pointer(&c))
+       c.args = uintptr(noescape(unsafe.Pointer(&a1)))
+       cgocall(asmstdcallAddr, unsafe.Pointer(c))
        return c.r1, c.r2, c.err
 }
 
 //go:linkname syscall_Syscall12 syscall.Syscall12
 //go:nosplit
 func syscall_Syscall12(fn, nargs, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12 uintptr) (r1, r2, err uintptr) {
-       var c libcall
+       c := &getg().m.syscall
        c.fn = fn
        c.n = nargs
-       c.args = uintptr(unsafe.Pointer(&a1))
-       cgocall(asmstdcallAddr, unsafe.Pointer(&c))
+       c.args = uintptr(noescape(unsafe.Pointer(&a1)))
+       cgocall(asmstdcallAddr, unsafe.Pointer(c))
        return c.r1, c.r2, c.err
 }
 
 //go:linkname syscall_Syscall15 syscall.Syscall15
 //go:nosplit
 func syscall_Syscall15(fn, nargs, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15 uintptr) (r1, r2, err uintptr) {
-       var c libcall
+       c := &getg().m.syscall
        c.fn = fn
        c.n = nargs
-       c.args = uintptr(unsafe.Pointer(&a1))
-       cgocall(asmstdcallAddr, unsafe.Pointer(&c))
+       c.args = uintptr(noescape(unsafe.Pointer(&a1)))
+       cgocall(asmstdcallAddr, unsafe.Pointer(c))
        return c.r1, c.r2, c.err
 }
index 720f70bdfccd11de8d04e3c057dca237371911c3..cb9dfcde9d89a837bee858f84306dd2601fd18c8 100644 (file)
@@ -554,3 +554,86 @@ func TestWERDialogue(t *testing.T) {
        // Child process should not open WER dialogue, but return immediately instead.
        cmd.CombinedOutput()
 }
+
+var used byte
+
+func use(buf []byte) {
+       for _, c := range buf {
+               used += c
+       }
+}
+
+func forceStackCopy() (r int) {
+       var f func(int) int
+       f = func(i int) int {
+               var buf [256]byte
+               use(buf[:])
+               if i == 0 {
+                       return 0
+               }
+               return i + f(i-1)
+       }
+       r = f(128)
+       return
+}
+
+func TestReturnAfterStackGrowInCallback(t *testing.T) {
+
+       const src = `
+#include <stdint.h>
+#include <windows.h>
+
+typedef uintptr_t __stdcall (*callback)(uintptr_t);
+
+uintptr_t cfunc(callback f, uintptr_t n) {
+   uintptr_t r;
+   r = f(n);
+   SetLastError(333);
+   return r;
+}
+`
+       tmpdir, err := ioutil.TempDir("", "TestReturnAfterStackGrowInCallback")
+       if err != nil {
+               t.Fatal("TempDir failed: ", err)
+       }
+       defer os.RemoveAll(tmpdir)
+
+       srcname := "mydll.c"
+       err = ioutil.WriteFile(filepath.Join(tmpdir, srcname), []byte(src), 0)
+       if err != nil {
+               t.Fatal(err)
+       }
+       outname := "mydll.dll"
+       cmd := exec.Command("gcc", "-shared", "-s", "-Werror", "-o", outname, srcname)
+       cmd.Dir = tmpdir
+       out, err := cmd.CombinedOutput()
+       if err != nil {
+               t.Fatalf("failed to build dll: %v - %v", err, string(out))
+       }
+       dllpath := filepath.Join(tmpdir, outname)
+
+       dll := syscall.MustLoadDLL(dllpath)
+       defer dll.Release()
+
+       proc := dll.MustFindProc("cfunc")
+
+       cb := syscall.NewCallback(func(n uintptr) uintptr {
+               forceStackCopy()
+               return n
+       })
+
+       // Use a new goroutine so that we get a small stack.
+       type result struct {
+               r   uintptr
+               err syscall.Errno
+       }
+       c := make(chan result)
+       go func() {
+               r, _, err := proc.Call(cb, 100)
+               c <- result{r, err.(syscall.Errno)}
+       }()
+       want := result{r: 100, err: 333}
+       if got := <-c; got != want {
+               t.Errorf("got %d want %d", got, want)
+       }
+}