package net
import (
- "runtime"
"time"
)
return
}
-const useDialTimeoutRace = runtime.GOOS == "windows" || runtime.GOOS == "plan9"
-
// DialTimeout acts like Dial but takes a timeout.
// The timeout includes name resolution, if required.
func DialTimeout(net, addr string, timeout time.Duration) (Conn, error) {
- if useDialTimeoutRace {
- // On windows and plan9, use the relatively inefficient
- // goroutine-racing implementation of DialTimeout that
- // doesn't push down deadlines to the pollster.
- // TODO: remove this once those are implemented.
- return dialTimeoutRace(net, addr, timeout)
- }
- deadline := time.Now().Add(timeout)
- _, addri, err := resolveNetAddr("dial", net, addr, deadline)
- if err != nil {
- return nil, err
- }
- return dialAddr(net, addr, addri, deadline)
+ return dialTimeout(net, addr, timeout)
}
// dialTimeoutRace is the old implementation of DialTimeout, still used
// on operating systems where the deadline hasn't been pushed down
// into the pollserver.
-// TODO: fix this on Windows and plan9.
+// TODO: fix this on plan9.
func dialTimeoutRace(net, addr string, timeout time.Duration) (Conn, error) {
t := time.NewTimer(timeout)
defer t.Stop()
--- /dev/null
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "sync"
+ "syscall"
+ "testing"
+ "time"
+ "unsafe"
+)
+
+var handleCounter struct {
+ once sync.Once
+ proc *syscall.Proc
+}
+
+func numHandles(t *testing.T) int {
+
+ handleCounter.once.Do(func() {
+ d, err := syscall.LoadDLL("kernel32.dll")
+ if err != nil {
+ t.Fatalf("LoadDLL: %v\n", err)
+ }
+ handleCounter.proc, err = d.FindProc("GetProcessHandleCount")
+ if err != nil {
+ t.Fatalf("FindProc: %v\n", err)
+ }
+ })
+
+ cp, err := syscall.GetCurrentProcess()
+ if err != nil {
+ t.Fatalf("GetCurrentProcess: %v\n", err)
+ }
+ var n uint32
+ r, _, err := handleCounter.proc.Call(uintptr(cp), uintptr(unsafe.Pointer(&n)))
+ if r == 0 {
+ t.Fatalf("GetProcessHandleCount: %v\n", error(err))
+ }
+ return int(n)
+}
+
+func testDialTimeoutHandleLeak(t *testing.T) (before, after int) {
+ before = numHandles(t)
+ // See comment in TestDialTimeout about why we use this address.
+ c, err := DialTimeout("tcp", "127.0.71.111:49151", 200*time.Millisecond)
+ after = numHandles(t)
+ if err == nil {
+ c.Close()
+ t.Fatalf("unexpected: connected to %s", c.RemoteAddr())
+ }
+ terr, ok := err.(timeout)
+ if !ok {
+ t.Fatalf("got error %q; want error with timeout interface", err)
+ }
+ if !terr.Timeout() {
+ t.Fatalf("got error %q; not a timeout", err)
+ }
+ return
+}
+
+func TestDialTimeoutHandleLeak(t *testing.T) {
+ if !canUseConnectEx("tcp") {
+ t.Logf("skipping test; no ConnectEx found.")
+ return
+ }
+ testDialTimeoutHandleLeak(t) // ignore first call results
+ before, after := testDialTimeoutHandleLeak(t)
+ if before != after {
+ t.Fatalf("handle count is different before=%d and after=%d", before, after)
+ }
+}
func sysInit() {
}
+func dialTimeout(net, addr string, timeout time.Duration) (Conn, error) {
+ // On plan9, use the relatively inefficient
+ // goroutine-racing implementation.
+ return dialTimeoutRace(net, addr, timeout)
+}
+
func newFD(proto, name string, ctl *os.File, laddr, raddr Addr) *netFD {
return &netFD{proto, name, "/net/" + proto + "/" + name, ctl, nil, laddr, raddr}
}
return pollservers[k]
}
+func dialTimeout(net, addr string, timeout time.Duration) (Conn, error) {
+ deadline := time.Now().Add(timeout)
+ _, addri, err := resolveNetAddr("dial", net, addr, deadline)
+ if err != nil {
+ return nil, err
+ }
+ return dialAddr(net, addr, addri, deadline)
+}
+
func newFD(fd, family, sotype int, net string) (*netFD, error) {
if err := syscall.SetNonblock(fd, true); err != nil {
return nil, err
return syscall.Closesocket(s)
}
+func canUseConnectEx(net string) bool {
+ if net == "udp" || net == "udp4" || net == "udp6" {
+ // ConnectEx windows API does not support connectionless sockets.
+ return false
+ }
+ return syscall.LoadConnectEx() == nil
+}
+
+func dialTimeout(net, addr string, timeout time.Duration) (Conn, error) {
+ if !canUseConnectEx(net) {
+ // Use the relatively inefficient goroutine-racing
+ // implementation of DialTimeout.
+ return dialTimeoutRace(net, addr, timeout)
+ }
+ deadline := time.Now().Add(timeout)
+ _, addri, err := resolveNetAddr("dial", net, addr, deadline)
+ if err != nil {
+ return nil, err
+ }
+ return dialAddr(net, addr, addri, deadline)
+}
+
// Interface for all IO operations.
type anOpIface interface {
Op() *anOp
runtime.SetFinalizer(fd, (*netFD).closesocket)
}
+// Make new connection.
+
+type connectOp struct {
+ anOp
+ ra syscall.Sockaddr
+}
+
+func (o *connectOp) Submit() error {
+ return syscall.ConnectEx(o.fd.sysfd, o.ra, nil, 0, nil, &o.o)
+}
+
+func (o *connectOp) Name() string {
+ return "ConnectEx"
+}
+
func (fd *netFD) connect(ra syscall.Sockaddr) error {
- return syscall.Connect(fd.sysfd, ra)
+ if !canUseConnectEx(fd.net) {
+ return syscall.Connect(fd.sysfd, ra)
+ }
+ // ConnectEx windows API requires an unconnected, previously bound socket.
+ var la syscall.Sockaddr
+ switch ra.(type) {
+ case *syscall.SockaddrInet4:
+ la = &syscall.SockaddrInet4{}
+ case *syscall.SockaddrInet6:
+ la = &syscall.SockaddrInet6{}
+ default:
+ panic("unexpected type in connect")
+ }
+ if err := syscall.Bind(fd.sysfd, la); err != nil {
+ return err
+ }
+ // Call ConnectEx API.
+ var o connectOp
+ o.Init(fd, 'w')
+ o.ra = ra
+ _, err := iosrv.ExecIO(&o, fd.wdeadline.value())
+ if err != nil {
+ return err
+ }
+ // Refresh socket properties.
+ return syscall.Setsockopt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_UPDATE_CONNECT_CONTEXT, (*byte)(unsafe.Pointer(&fd.sysfd)), int32(unsafe.Sizeof(fd.sysfd)))
}
// Add a reference to this fd.
package syscall
import (
+ errorspkg "errors"
+ "sync"
"unicode/utf16"
"unsafe"
)
return procGetAddrInfoW.Find()
}
+var connectExFunc struct {
+ once sync.Once
+ addr uintptr
+ err error
+}
+
+func LoadConnectEx() error {
+ connectExFunc.once.Do(func() {
+ var s Handle
+ s, connectExFunc.err = Socket(AF_INET, SOCK_STREAM, IPPROTO_TCP)
+ if connectExFunc.err != nil {
+ return
+ }
+ defer CloseHandle(s)
+ var n uint32
+ connectExFunc.err = WSAIoctl(s,
+ SIO_GET_EXTENSION_FUNCTION_POINTER,
+ (*byte)(unsafe.Pointer(&WSAID_CONNECTEX)),
+ uint32(unsafe.Sizeof(WSAID_CONNECTEX)),
+ (*byte)(unsafe.Pointer(&connectExFunc.addr)),
+ uint32(unsafe.Sizeof(connectExFunc.addr)),
+ &n, nil, 0)
+ })
+ return connectExFunc.err
+}
+
+func connectEx(s Handle, name uintptr, namelen int32, sendBuf *byte, sendDataLen uint32, bytesSent *uint32, overlapped *Overlapped) (err error) {
+ r1, _, e1 := Syscall9(connectExFunc.addr, 7, uintptr(s), uintptr(name), uintptr(namelen), uintptr(unsafe.Pointer(sendBuf)), uintptr(sendDataLen), uintptr(unsafe.Pointer(bytesSent)), uintptr(unsafe.Pointer(overlapped)), 0, 0)
+ if r1 == 0 {
+ if e1 != 0 {
+ err = error(e1)
+ } else {
+ err = EINVAL
+ }
+ }
+ return
+}
+
+func ConnectEx(fd Handle, sa Sockaddr, sendBuf *byte, sendDataLen uint32, bytesSent *uint32, overlapped *Overlapped) error {
+ err := LoadConnectEx()
+ if err != nil {
+ return errorspkg.New("failed to find ConnectEx: " + err.Error())
+ }
+ ptr, n, err := sa.sockaddr()
+ if err != nil {
+ return err
+ }
+ return connectEx(fd, ptr, n, sendBuf, sendDataLen, bytesSent, overlapped)
+}
+
// Invented structures to support what package os expects.
type Rusage struct {
CreationTime Filetime
IPPROTO_TCP = 6
IPPROTO_UDP = 17
- SOL_SOCKET = 0xffff
- SO_REUSEADDR = 4
- SO_KEEPALIVE = 8
- SO_DONTROUTE = 16
- SO_BROADCAST = 32
- SO_LINGER = 128
- SO_RCVBUF = 0x1002
- SO_SNDBUF = 0x1001
- SO_UPDATE_ACCEPT_CONTEXT = 0x700b
+ SOL_SOCKET = 0xffff
+ SO_REUSEADDR = 4
+ SO_KEEPALIVE = 8
+ SO_DONTROUTE = 16
+ SO_BROADCAST = 32
+ SO_LINGER = 128
+ SO_RCVBUF = 0x1002
+ SO_SNDBUF = 0x1001
+ SO_UPDATE_ACCEPT_CONTEXT = 0x700b
+ SO_UPDATE_CONNECT_CONTEXT = 0x7010
+
+ IOC_OUT = 0x40000000
+ IOC_IN = 0x80000000
+ IOC_INOUT = IOC_IN | IOC_OUT
+ IOC_WS2 = 0x08000000
+ SIO_GET_EXTENSION_FUNCTION_POINTER = IOC_INOUT | IOC_WS2 | 6
// cf. http://support.microsoft.com/default.aspx?scid=kb;en-us;257460
AI_CANONNAME = 2
AI_NUMERICHOST = 4
)
+
+type GUID struct {
+ Data1 uint32
+ Data2 uint16
+ Data3 uint16
+ Data4 [8]byte
+}
+
+var WSAID_CONNECTEX = GUID{
+ 0x25a207b9,
+ 0xddf3,
+ 0x4660,
+ [8]byte{0x8e, 0xe9, 0x76, 0xe5, 0x8c, 0x74, 0x06, 0x3e},
+}