]> Cypherpunks repositories - gostls13.git/commitdiff
net: use windows GetAddrInfoW in LookupPort when possible
authorAlex Brainman <alex.brainman@gmail.com>
Fri, 18 Jan 2013 06:05:04 +0000 (17:05 +1100)
committerAlex Brainman <alex.brainman@gmail.com>
Fri, 18 Jan 2013 06:05:04 +0000 (17:05 +1100)
R=golang-dev, bradfitz
CC=golang-dev
https://golang.org/cl/7252045

src/pkg/net/fd_windows.go
src/pkg/net/lookup_windows.go
src/pkg/net/port_test.go

index ea6ef10ec1a950976108c8e89c6103bfaacbfdee..0bf361d4437b02582b2f68441c89fb6274c9e4cd 100644 (file)
@@ -37,6 +37,7 @@ func sysInit() {
        }
        canCancelIO = syscall.LoadCancelIoEx() == nil
        if syscall.LoadGetAddrInfo() == nil {
+               lookupPort = newLookupPort
                lookupIP = newLookupIP
        }
 }
index 390fe7f4405e6ea0fb7ce1a6273e9e7338fcf688..b433d0cbbdbca9278c9c921dd79c17055070ecf9 100644 (file)
@@ -17,6 +17,11 @@ var (
        serventLock  sync.Mutex
 )
 
+var (
+       lookupPort = oldLookupPort
+       lookupIP   = oldLookupIP
+)
+
 // lookupProtocol looks up IP protocol name and returns correspondent protocol number.
 func lookupProtocol(name string) (proto int, err error) {
        protoentLock.Lock()
@@ -40,8 +45,6 @@ func lookupHost(name string) (addrs []string, err error) {
        return
 }
 
-var lookupIP = oldLookupIP
-
 func oldLookupIP(name string) (addrs []IP, err error) {
        hostentLock.Lock()
        defer hostentLock.Unlock()
@@ -92,7 +95,7 @@ func newLookupIP(name string) (addrs []IP, err error) {
        return addrs, nil
 }
 
-func lookupPort(network, service string) (port int, err error) {
+func oldLookupPort(network, service string) (port int, err error) {
        switch network {
        case "tcp4", "tcp6":
                network = "tcp"
@@ -108,6 +111,40 @@ func lookupPort(network, service string) (port int, err error) {
        return int(syscall.Ntohs(s.Port)), nil
 }
 
+func newLookupPort(network, service string) (port int, err error) {
+       var stype int32
+       switch network {
+       case "tcp4", "tcp6":
+               stype = syscall.SOCK_STREAM
+       case "udp4", "udp6":
+               stype = syscall.SOCK_DGRAM
+       }
+       hints := syscall.AddrinfoW{
+               Family:   syscall.AF_UNSPEC,
+               Socktype: stype,
+               Protocol: syscall.IPPROTO_IP,
+       }
+       var result *syscall.AddrinfoW
+       e := syscall.GetAddrInfoW(nil, syscall.StringToUTF16Ptr(service), &hints, &result)
+       if e != nil {
+               return 0, os.NewSyscallError("GetAddrInfoW", e)
+       }
+       defer syscall.FreeAddrInfoW(result)
+       if result == nil {
+               return 0, os.NewSyscallError("LookupPort", syscall.EINVAL)
+       }
+       addr := unsafe.Pointer(result.Addr)
+       switch result.Family {
+       case syscall.AF_INET:
+               a := (*syscall.RawSockaddrInet4)(addr)
+               return int(syscall.Ntohs(a.Port)), nil
+       case syscall.AF_INET6:
+               a := (*syscall.RawSockaddrInet6)(addr)
+               return int(syscall.Ntohs(a.Port)), nil
+       }
+       return 0, os.NewSyscallError("LookupPort", syscall.EINVAL)
+}
+
 func lookupCNAME(name string) (cname string, err error) {
        var r *syscall.DNSRecord
        e := syscall.DnsQuery(name, syscall.DNS_TYPE_CNAME, 0, nil, &r, nil)
index 329b169f3492d3a5ab42e470def031824492f013..9e8968f359cdb60ffb30ab80d30da4e596e76362 100644 (file)
@@ -46,7 +46,7 @@ func TestLookupPort(t *testing.T) {
        for i := 0; i < len(porttests); i++ {
                tt := porttests[i]
                if port, err := LookupPort(tt.netw, tt.name); port != tt.port || (err == nil) != tt.ok {
-                       t.Errorf("LookupPort(%q, %q) = %v, %s; want %v",
+                       t.Errorf("LookupPort(%q, %q) = %v, %v; want %v",
                                tt.netw, tt.name, port, err, tt.port)
                }
        }