]> Cypherpunks repositories - gostls13.git/commitdiff
net: use WSASocket instead of socket call
authorAlex Brainman <alex.brainman@gmail.com>
Mon, 23 Oct 2017 05:23:43 +0000 (16:23 +1100)
committerAlex Brainman <alex.brainman@gmail.com>
Tue, 7 Nov 2017 22:04:15 +0000 (22:04 +0000)
WSASocket (unlike socket call) allows to create sockets that
will not be inherited by child process. So call WSASocket to
save on using syscall.ForkLock and calling syscall.CloseOnExec.

Some very old versions of Windows do not have that functionality.
Call socket, if WSASocket failed, to support these.

Change-Id: I2dab9fa00d1a8609dd6feae1c9cc31d4e55b8cb5
Reviewed-on: https://go-review.googlesource.com/72590
Reviewed-by: Ian Lance Taylor <iant@golang.org>
src/go/build/deps_test.go
src/internal/syscall/windows/syscall_windows.go
src/internal/syscall/windows/zsyscall_windows.go
src/net/hook_windows.go
src/net/internal/socktest/sys_windows.go
src/net/main_windows_test.go
src/net/sock_windows.go

index 16ac51ef07c4dc4ea3dbaa6811219f5d6349b73c..0048469ef4c85c591bc53148483d6652f318d2de 100644 (file)
@@ -266,7 +266,7 @@ var pkgDeps = map[string][]string{
        "math/big":                 {"L4"},
        "mime":                     {"L4", "OS", "syscall", "internal/syscall/windows/registry"},
        "mime/quotedprintable":     {"L4"},
-       "net/internal/socktest":    {"L4", "OS", "syscall"},
+       "net/internal/socktest":    {"L4", "OS", "syscall", "internal/syscall/windows"},
        "net/url":                  {"L4"},
        "plugin":                   {"L0", "OS", "CGO"},
        "runtime/pprof/internal/profile": {"L4", "OS", "compress/gzip", "regexp"},
index af87416f5effe37900a104eddbf14eac1f088b39..3c14691e1dd1adf8d734dd5f15cdc433e7242d5d 100644 (file)
@@ -112,6 +112,13 @@ const (
 //sys  MoveFileEx(from *uint16, to *uint16, flags uint32) (err error) = MoveFileExW
 //sys  GetModuleFileName(module syscall.Handle, fn *uint16, len uint32) (n uint32, err error) = kernel32.GetModuleFileNameW
 
+const (
+       WSA_FLAG_OVERLAPPED        = 0x01
+       WSA_FLAG_NO_HANDLE_INHERIT = 0x80
+)
+
+//sys  WSASocket(af int32, typ int32, protocol int32, protinfo *syscall.WSAProtocolInfo, group uint32, flags uint32) (handle syscall.Handle, err error) [failretval==syscall.InvalidHandle] = ws2_32.WSASocketW
+
 const (
        ComputerNameNetBIOS                   = 0
        ComputerNameDnsHostname               = 1
index ba16456b677d042dffe7c0ef237161b8d4aabe5d..d745fe11a5a1eb955a297477c904718cf97e23cc 100644 (file)
@@ -38,6 +38,7 @@ func errnoErr(e syscall.Errno) error {
 var (
        modiphlpapi = syscall.NewLazyDLL(sysdll.Add("iphlpapi.dll"))
        modkernel32 = syscall.NewLazyDLL(sysdll.Add("kernel32.dll"))
+       modws2_32   = syscall.NewLazyDLL(sysdll.Add("ws2_32.dll"))
        modnetapi32 = syscall.NewLazyDLL(sysdll.Add("netapi32.dll"))
        modadvapi32 = syscall.NewLazyDLL(sysdll.Add("advapi32.dll"))
        modpsapi    = syscall.NewLazyDLL(sysdll.Add("psapi.dll"))
@@ -46,6 +47,7 @@ var (
        procGetComputerNameExW        = modkernel32.NewProc("GetComputerNameExW")
        procMoveFileExW               = modkernel32.NewProc("MoveFileExW")
        procGetModuleFileNameW        = modkernel32.NewProc("GetModuleFileNameW")
+       procWSASocketW                = modws2_32.NewProc("WSASocketW")
        procGetACP                    = modkernel32.NewProc("GetACP")
        procGetConsoleCP              = modkernel32.NewProc("GetConsoleCP")
        procMultiByteToWideChar       = modkernel32.NewProc("MultiByteToWideChar")
@@ -108,6 +110,19 @@ func GetModuleFileName(module syscall.Handle, fn *uint16, len uint32) (n uint32,
        return
 }
 
+func WSASocket(af int32, typ int32, protocol int32, protinfo *syscall.WSAProtocolInfo, group uint32, flags uint32) (handle syscall.Handle, err error) {
+       r0, _, e1 := syscall.Syscall6(procWSASocketW.Addr(), 6, uintptr(af), uintptr(typ), uintptr(protocol), uintptr(unsafe.Pointer(protinfo)), uintptr(group), uintptr(flags))
+       handle = syscall.Handle(r0)
+       if handle == syscall.InvalidHandle {
+               if e1 != 0 {
+                       err = errnoErr(e1)
+               } else {
+                       err = syscall.EINVAL
+               }
+       }
+       return
+}
+
 func GetACP() (acp uint32) {
        r0, _, _ := syscall.Syscall(procGetACP.Addr(), 0, 0, 0, 0)
        acp = uint32(r0)
index 4e64dcef517e6a781bc67bca875dbace2bde4acb..ab8656cbbf343bdc6d720a7ce24d7a5f41172b7c 100644 (file)
@@ -5,6 +5,7 @@
 package net
 
 import (
+       "internal/syscall/windows"
        "syscall"
        "time"
 )
@@ -13,7 +14,8 @@ var (
        testHookDialChannel = func() { time.Sleep(time.Millisecond) } // see golang.org/issue/5349
 
        // Placeholders for socket system calls.
-       socketFunc  func(int, int, int) (syscall.Handle, error)  = syscall.Socket
-       connectFunc func(syscall.Handle, syscall.Sockaddr) error = syscall.Connect
-       listenFunc  func(syscall.Handle, int) error              = syscall.Listen
+       socketFunc    func(int, int, int) (syscall.Handle, error)                                                 = syscall.Socket
+       wsaSocketFunc func(int32, int32, int32, *syscall.WSAProtocolInfo, uint32, uint32) (syscall.Handle, error) = windows.WSASocket
+       connectFunc   func(syscall.Handle, syscall.Sockaddr) error                                                = syscall.Connect
+       listenFunc    func(syscall.Handle, int) error                                                             = syscall.Listen
 )
index 2e3d2bc7fcecb4e98926ff748fd8f2a8b5bf210e..8c1c862f33c9bd6093dec58331f93c38b4dc4fcd 100644 (file)
@@ -4,7 +4,10 @@
 
 package socktest
 
-import "syscall"
+import (
+       "internal/syscall/windows"
+       "syscall"
+)
 
 // Socket wraps syscall.Socket.
 func (sw *Switch) Socket(family, sotype, proto int) (s syscall.Handle, err error) {
@@ -38,6 +41,38 @@ func (sw *Switch) Socket(family, sotype, proto int) (s syscall.Handle, err error
        return s, nil
 }
 
+// WSASocket wraps syscall.WSASocket.
+func (sw *Switch) WSASocket(family, sotype, proto int32, protinfo *syscall.WSAProtocolInfo, group uint32, flags uint32) (s syscall.Handle, err error) {
+       sw.once.Do(sw.init)
+
+       so := &Status{Cookie: cookie(int(family), int(sotype), int(proto))}
+       sw.fmu.RLock()
+       f, _ := sw.fltab[FilterSocket]
+       sw.fmu.RUnlock()
+
+       af, err := f.apply(so)
+       if err != nil {
+               return syscall.InvalidHandle, err
+       }
+       s, so.Err = windows.WSASocket(family, sotype, proto, protinfo, group, flags)
+       if err = af.apply(so); err != nil {
+               if so.Err == nil {
+                       syscall.Closesocket(s)
+               }
+               return syscall.InvalidHandle, err
+       }
+
+       sw.smu.Lock()
+       defer sw.smu.Unlock()
+       if so.Err != nil {
+               sw.stats.getLocked(so.Cookie).OpenFailed++
+               return syscall.InvalidHandle, so.Err
+       }
+       nso := sw.addLocked(s, int(family), int(sotype), int(proto))
+       sw.stats.getLocked(nso.Cookie).Opened++
+       return s, nil
+}
+
 // Closesocket wraps syscall.Closesocket.
 func (sw *Switch) Closesocket(s syscall.Handle) (err error) {
        so := sw.sockso(s)
index f38a3a0d668179388751af798cd1d7b7486f6a8d..07f21b72eb1fcc9a57c352312721ce1c5b0a2190 100644 (file)
@@ -9,6 +9,7 @@ import "internal/poll"
 var (
        // Placeholders for saving original socket system calls.
        origSocket      = socketFunc
+       origWSASocket   = wsaSocketFunc
        origClosesocket = poll.CloseFunc
        origConnect     = connectFunc
        origConnectEx   = poll.ConnectExFunc
@@ -18,6 +19,7 @@ var (
 
 func installTestHooks() {
        socketFunc = sw.Socket
+       wsaSocketFunc = sw.WSASocket
        poll.CloseFunc = sw.Closesocket
        connectFunc = sw.Connect
        poll.ConnectExFunc = sw.ConnectEx
@@ -27,6 +29,7 @@ func installTestHooks() {
 
 func uninstallTestHooks() {
        socketFunc = origSocket
+       wsaSocketFunc = origWSASocket
        poll.CloseFunc = origClosesocket
        connectFunc = origConnect
        poll.ConnectExFunc = origConnectEx
index 89a3ca425857c2150893edaad598d94b7d554061..fa11c7af2e727c21facca8c6daa2d7b3af9ba672 100644 (file)
@@ -5,6 +5,7 @@
 package net
 
 import (
+       "internal/syscall/windows"
        "os"
        "syscall"
 )
@@ -16,9 +17,19 @@ func maxListenerBacklog() int {
 }
 
 func sysSocket(family, sotype, proto int) (syscall.Handle, error) {
+       s, err := wsaSocketFunc(int32(family), int32(sotype), int32(proto),
+               nil, 0, windows.WSA_FLAG_OVERLAPPED|windows.WSA_FLAG_NO_HANDLE_INHERIT)
+       if err == nil {
+               return s, nil
+       }
+       // WSA_FLAG_NO_HANDLE_INHERIT flag is not supported on some
+       // old versions of Windows, see
+       // https://msdn.microsoft.com/en-us/library/windows/desktop/ms742212(v=vs.85).aspx
+       // for details. Just use syscall.Socket, if windows.WSASocket failed.
+
        // See ../syscall/exec_unix.go for description of ForkLock.
        syscall.ForkLock.RLock()
-       s, err := socketFunc(family, sotype, proto)
+       s, err = socketFunc(family, sotype, proto)
        if err == nil {
                syscall.CloseOnExec(s)
        }