]> Cypherpunks repositories - gostls13.git/commitdiff
net: use windows ConnectEx to dial (when possible)
authorAlex Brainman <alex.brainman@gmail.com>
Fri, 11 Jan 2013 01:42:09 +0000 (12:42 +1100)
committerAlex Brainman <alex.brainman@gmail.com>
Fri, 11 Jan 2013 01:42:09 +0000 (12:42 +1100)
Update #2631.
Update #3097.

R=golang-dev, bradfitz
CC=golang-dev
https://golang.org/cl/7061060

src/pkg/net/dial.go
src/pkg/net/dial_windows_test.go [new file with mode: 0644]
src/pkg/net/fd_plan9.go
src/pkg/net/fd_unix.go
src/pkg/net/fd_windows.go
src/pkg/syscall/syscall_windows.go
src/pkg/syscall/ztypes_windows.go

index c1eb983cc0f882a642accb41bbb38fccd1b8fe4d..354028a157ab4154b61cc4a672359719d525bd2a 100644 (file)
@@ -5,7 +5,6 @@
 package net
 
 import (
-       "runtime"
        "time"
 )
 
@@ -113,30 +112,16 @@ func dialAddr(net, addr string, addri Addr, deadline time.Time) (c Conn, err err
        return
 }
 
-const useDialTimeoutRace = runtime.GOOS == "windows" || runtime.GOOS == "plan9"
-
 // DialTimeout acts like Dial but takes a timeout.
 // The timeout includes name resolution, if required.
 func DialTimeout(net, addr string, timeout time.Duration) (Conn, error) {
-       if useDialTimeoutRace {
-               // On windows and plan9, use the relatively inefficient
-               // goroutine-racing implementation of DialTimeout that
-               // doesn't push down deadlines to the pollster.
-               // TODO: remove this once those are implemented.
-               return dialTimeoutRace(net, addr, timeout)
-       }
-       deadline := time.Now().Add(timeout)
-       _, addri, err := resolveNetAddr("dial", net, addr, deadline)
-       if err != nil {
-               return nil, err
-       }
-       return dialAddr(net, addr, addri, deadline)
+       return dialTimeout(net, addr, timeout)
 }
 
 // dialTimeoutRace is the old implementation of DialTimeout, still used
 // on operating systems where the deadline hasn't been pushed down
 // into the pollserver.
-// TODO: fix this on Windows and plan9.
+// TODO: fix this on plan9.
 func dialTimeoutRace(net, addr string, timeout time.Duration) (Conn, error) {
        t := time.NewTimer(timeout)
        defer t.Stop()
diff --git a/src/pkg/net/dial_windows_test.go b/src/pkg/net/dial_windows_test.go
new file mode 100644 (file)
index 0000000..8fc9b2f
--- /dev/null
@@ -0,0 +1,74 @@
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+       "sync"
+       "syscall"
+       "testing"
+       "time"
+       "unsafe"
+)
+
+var handleCounter struct {
+       once sync.Once
+       proc *syscall.Proc
+}
+
+func numHandles(t *testing.T) int {
+
+       handleCounter.once.Do(func() {
+               d, err := syscall.LoadDLL("kernel32.dll")
+               if err != nil {
+                       t.Fatalf("LoadDLL: %v\n", err)
+               }
+               handleCounter.proc, err = d.FindProc("GetProcessHandleCount")
+               if err != nil {
+                       t.Fatalf("FindProc: %v\n", err)
+               }
+       })
+
+       cp, err := syscall.GetCurrentProcess()
+       if err != nil {
+               t.Fatalf("GetCurrentProcess: %v\n", err)
+       }
+       var n uint32
+       r, _, err := handleCounter.proc.Call(uintptr(cp), uintptr(unsafe.Pointer(&n)))
+       if r == 0 {
+               t.Fatalf("GetProcessHandleCount: %v\n", error(err))
+       }
+       return int(n)
+}
+
+func testDialTimeoutHandleLeak(t *testing.T) (before, after int) {
+       before = numHandles(t)
+       // See comment in TestDialTimeout about why we use this address.
+       c, err := DialTimeout("tcp", "127.0.71.111:49151", 200*time.Millisecond)
+       after = numHandles(t)
+       if err == nil {
+               c.Close()
+               t.Fatalf("unexpected: connected to %s", c.RemoteAddr())
+       }
+       terr, ok := err.(timeout)
+       if !ok {
+               t.Fatalf("got error %q; want error with timeout interface", err)
+       }
+       if !terr.Timeout() {
+               t.Fatalf("got error %q; not a timeout", err)
+       }
+       return
+}
+
+func TestDialTimeoutHandleLeak(t *testing.T) {
+       if !canUseConnectEx("tcp") {
+               t.Logf("skipping test; no ConnectEx found.")
+               return
+       }
+       testDialTimeoutHandleLeak(t) // ignore first call results
+       before, after := testDialTimeoutHandleLeak(t)
+       if before != after {
+               t.Fatalf("handle count is different before=%d and after=%d", before, after)
+       }
+}
index 6d7ab388ae783c02bfd792742a71801a47c8ad84..3462792816e5f4d2e9952b4ae69075051825658e 100644 (file)
@@ -23,6 +23,12 @@ var canCancelIO = true // used for testing current package
 func sysInit() {
 }
 
+func dialTimeout(net, addr string, timeout time.Duration) (Conn, error) {
+       // On plan9, use the relatively inefficient
+       // goroutine-racing implementation.
+       return dialTimeoutRace(net, addr, timeout)
+}
+
 func newFD(proto, name string, ctl *os.File, laddr, raddr Addr) *netFD {
        return &netFD{proto, name, "/net/" + proto + "/" + name, ctl, nil, laddr, raddr}
 }
index 6d8af0ab2e26b6b7abb58462d4733e24b73bb74a..cfe6df2130814cb32018d9a1453746a6e941c10c 100644 (file)
@@ -288,6 +288,15 @@ func server(fd int) *pollServer {
        return pollservers[k]
 }
 
+func dialTimeout(net, addr string, timeout time.Duration) (Conn, error) {
+       deadline := time.Now().Add(timeout)
+       _, addri, err := resolveNetAddr("dial", net, addr, deadline)
+       if err != nil {
+               return nil, err
+       }
+       return dialAddr(net, addr, addri, deadline)
+}
+
 func newFD(fd, family, sotype int, net string) (*netFD, error) {
        if err := syscall.SetNonblock(fd, true); err != nil {
                return nil, err
index 18712191fee18b5c03bde470ae374c39283b53d1..ea6ef10ec1a950976108c8e89c6103bfaacbfdee 100644 (file)
@@ -45,6 +45,28 @@ func closesocket(s syscall.Handle) error {
        return syscall.Closesocket(s)
 }
 
+func canUseConnectEx(net string) bool {
+       if net == "udp" || net == "udp4" || net == "udp6" {
+               // ConnectEx windows API does not support connectionless sockets.
+               return false
+       }
+       return syscall.LoadConnectEx() == nil
+}
+
+func dialTimeout(net, addr string, timeout time.Duration) (Conn, error) {
+       if !canUseConnectEx(net) {
+               // Use the relatively inefficient goroutine-racing
+               // implementation of DialTimeout.
+               return dialTimeoutRace(net, addr, timeout)
+       }
+       deadline := time.Now().Add(timeout)
+       _, addri, err := resolveNetAddr("dial", net, addr, deadline)
+       if err != nil {
+               return nil, err
+       }
+       return dialAddr(net, addr, addri, deadline)
+}
+
 // Interface for all IO operations.
 type anOpIface interface {
        Op() *anOp
@@ -321,8 +343,48 @@ func (fd *netFD) setAddr(laddr, raddr Addr) {
        runtime.SetFinalizer(fd, (*netFD).closesocket)
 }
 
+// Make new connection.
+
+type connectOp struct {
+       anOp
+       ra syscall.Sockaddr
+}
+
+func (o *connectOp) Submit() error {
+       return syscall.ConnectEx(o.fd.sysfd, o.ra, nil, 0, nil, &o.o)
+}
+
+func (o *connectOp) Name() string {
+       return "ConnectEx"
+}
+
 func (fd *netFD) connect(ra syscall.Sockaddr) error {
-       return syscall.Connect(fd.sysfd, ra)
+       if !canUseConnectEx(fd.net) {
+               return syscall.Connect(fd.sysfd, ra)
+       }
+       // ConnectEx windows API requires an unconnected, previously bound socket.
+       var la syscall.Sockaddr
+       switch ra.(type) {
+       case *syscall.SockaddrInet4:
+               la = &syscall.SockaddrInet4{}
+       case *syscall.SockaddrInet6:
+               la = &syscall.SockaddrInet6{}
+       default:
+               panic("unexpected type in connect")
+       }
+       if err := syscall.Bind(fd.sysfd, la); err != nil {
+               return err
+       }
+       // Call ConnectEx API.
+       var o connectOp
+       o.Init(fd, 'w')
+       o.ra = ra
+       _, err := iosrv.ExecIO(&o, fd.wdeadline.value())
+       if err != nil {
+               return err
+       }
+       // Refresh socket properties.
+       return syscall.Setsockopt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_UPDATE_CONNECT_CONTEXT, (*byte)(unsafe.Pointer(&fd.sysfd)), int32(unsafe.Sizeof(fd.sysfd)))
 }
 
 // Add a reference to this fd.
index 5acb65dee144ca55e340ad8f9e6962c0f6fed712..e745fbe510f43f77b3025378001a8a7b60d10df1 100644 (file)
@@ -7,6 +7,8 @@
 package syscall
 
 import (
+       errorspkg "errors"
+       "sync"
        "unicode/utf16"
        "unsafe"
 )
@@ -712,6 +714,56 @@ func LoadGetAddrInfo() error {
        return procGetAddrInfoW.Find()
 }
 
+var connectExFunc struct {
+       once sync.Once
+       addr uintptr
+       err  error
+}
+
+func LoadConnectEx() error {
+       connectExFunc.once.Do(func() {
+               var s Handle
+               s, connectExFunc.err = Socket(AF_INET, SOCK_STREAM, IPPROTO_TCP)
+               if connectExFunc.err != nil {
+                       return
+               }
+               defer CloseHandle(s)
+               var n uint32
+               connectExFunc.err = WSAIoctl(s,
+                       SIO_GET_EXTENSION_FUNCTION_POINTER,
+                       (*byte)(unsafe.Pointer(&WSAID_CONNECTEX)),
+                       uint32(unsafe.Sizeof(WSAID_CONNECTEX)),
+                       (*byte)(unsafe.Pointer(&connectExFunc.addr)),
+                       uint32(unsafe.Sizeof(connectExFunc.addr)),
+                       &n, nil, 0)
+       })
+       return connectExFunc.err
+}
+
+func connectEx(s Handle, name uintptr, namelen int32, sendBuf *byte, sendDataLen uint32, bytesSent *uint32, overlapped *Overlapped) (err error) {
+       r1, _, e1 := Syscall9(connectExFunc.addr, 7, uintptr(s), uintptr(name), uintptr(namelen), uintptr(unsafe.Pointer(sendBuf)), uintptr(sendDataLen), uintptr(unsafe.Pointer(bytesSent)), uintptr(unsafe.Pointer(overlapped)), 0, 0)
+       if r1 == 0 {
+               if e1 != 0 {
+                       err = error(e1)
+               } else {
+                       err = EINVAL
+               }
+       }
+       return
+}
+
+func ConnectEx(fd Handle, sa Sockaddr, sendBuf *byte, sendDataLen uint32, bytesSent *uint32, overlapped *Overlapped) error {
+       err := LoadConnectEx()
+       if err != nil {
+               return errorspkg.New("failed to find ConnectEx: " + err.Error())
+       }
+       ptr, n, err := sa.sockaddr()
+       if err != nil {
+               return err
+       }
+       return connectEx(fd, ptr, n, sendBuf, sendDataLen, bytesSent, overlapped)
+}
+
 // Invented structures to support what package os expects.
 type Rusage struct {
        CreationTime Filetime
index 1f7308796fbb25295b9a01e6402f984c325a00b5..a2006f803d6237ad3f549f82bcf8886093b89501 100644 (file)
@@ -496,15 +496,22 @@ const (
        IPPROTO_TCP  = 6
        IPPROTO_UDP  = 17
 
-       SOL_SOCKET               = 0xffff
-       SO_REUSEADDR             = 4
-       SO_KEEPALIVE             = 8
-       SO_DONTROUTE             = 16
-       SO_BROADCAST             = 32
-       SO_LINGER                = 128
-       SO_RCVBUF                = 0x1002
-       SO_SNDBUF                = 0x1001
-       SO_UPDATE_ACCEPT_CONTEXT = 0x700b
+       SOL_SOCKET                = 0xffff
+       SO_REUSEADDR              = 4
+       SO_KEEPALIVE              = 8
+       SO_DONTROUTE              = 16
+       SO_BROADCAST              = 32
+       SO_LINGER                 = 128
+       SO_RCVBUF                 = 0x1002
+       SO_SNDBUF                 = 0x1001
+       SO_UPDATE_ACCEPT_CONTEXT  = 0x700b
+       SO_UPDATE_CONNECT_CONTEXT = 0x7010
+
+       IOC_OUT                            = 0x40000000
+       IOC_IN                             = 0x80000000
+       IOC_INOUT                          = IOC_IN | IOC_OUT
+       IOC_WS2                            = 0x08000000
+       SIO_GET_EXTENSION_FUNCTION_POINTER = IOC_INOUT | IOC_WS2 | 6
 
        // cf. http://support.microsoft.com/default.aspx?scid=kb;en-us;257460
 
@@ -941,3 +948,17 @@ const (
        AI_CANONNAME   = 2
        AI_NUMERICHOST = 4
 )
+
+type GUID struct {
+       Data1 uint32
+       Data2 uint16
+       Data3 uint16
+       Data4 [8]byte
+}
+
+var WSAID_CONNECTEX = GUID{
+       0x25a207b9,
+       0xddf3,
+       0x4660,
+       [8]byte{0x8e, 0xe9, 0x76, 0xe5, 0x8c, 0x74, 0x06, 0x3e},
+}