]> Cypherpunks repositories - gostls13.git/commitdiff
net: implement ip protocol name to number resolver for windows
authorAlex Brainman <alex.brainman@gmail.com>
Tue, 11 Oct 2011 23:29:22 +0000 (10:29 +1100)
committerAlex Brainman <alex.brainman@gmail.com>
Tue, 11 Oct 2011 23:29:22 +0000 (10:29 +1100)
Fixes #2215.
Fixes #2216.

R=golang-dev, dave, rsc
CC=golang-dev
https://golang.org/cl/5248055

src/pkg/net/ipraw_test.go
src/pkg/net/iprawsock_posix.go
src/pkg/net/lookup_unix.go
src/pkg/net/lookup_windows.go
src/pkg/syscall/syscall_windows.go
src/pkg/syscall/zsyscall_windows_386.go
src/pkg/syscall/zsyscall_windows_amd64.go
src/pkg/syscall/ztypes_windows.go

index 6894ce656dd5b7cfd7c19e9c8c6205653bd1ce0d..7f8c7b841ea1c92d90d7eb57e0ab0587222d2b60 100644 (file)
@@ -11,6 +11,7 @@ import (
        "bytes"
        "flag"
        "os"
+       "runtime"
        "testing"
 )
 
@@ -64,7 +65,7 @@ var dsthost = flag.String("dsthost", "127.0.0.1", "Destination for the ICMP ECHO
 
 // test (raw) IP socket using ICMP
 func TestICMP(t *testing.T) {
-       if os.Getuid() != 0 {
+       if runtime.GOOS != "windows" && os.Getuid() != 0 {
                t.Logf("test disabled; must be root")
                return
        }
index f9e497f173e426a24d20532a8199e5890a4d4139..dafbdab780404be7f8453154961b21bb6658436a 100644 (file)
@@ -10,12 +10,9 @@ package net
 
 import (
        "os"
-       "sync"
        "syscall"
 )
 
-var onceReadProtocols sync.Once
-
 func sockaddrToIP(sa syscall.Sockaddr) Addr {
        switch sa := sa.(type) {
        case *syscall.SockaddrInet4:
@@ -209,33 +206,7 @@ func (c *IPConn) WriteTo(b []byte, addr Addr) (n int, err os.Error) {
        return c.WriteToIP(b, a)
 }
 
-var protocols map[string]int
-
-func readProtocols() {
-       protocols = make(map[string]int)
-       if file, err := open("/etc/protocols"); err == nil {
-               for line, ok := file.readLine(); ok; line, ok = file.readLine() {
-                       // tcp    6   TCP    # transmission control protocol
-                       if i := byteIndex(line, '#'); i >= 0 {
-                               line = line[0:i]
-                       }
-                       f := getFields(line)
-                       if len(f) < 2 {
-                               continue
-                       }
-                       if proto, _, ok := dtoi(f[1], 0); ok {
-                               protocols[f[0]] = proto
-                               for _, alias := range f[2:] {
-                                       protocols[alias] = proto
-                               }
-                       }
-               }
-               file.close()
-       }
-}
-
 func splitNetProto(netProto string) (net string, proto int, err os.Error) {
-       onceReadProtocols.Do(readProtocols)
        i := last(netProto, ':')
        if i < 0 { // no colon
                return "", 0, os.NewError("no IP protocol specified")
@@ -244,13 +215,12 @@ func splitNetProto(netProto string) (net string, proto int, err os.Error) {
        protostr := netProto[i+1:]
        proto, i, ok := dtoi(protostr, 0)
        if !ok || i != len(protostr) {
-               // lookup by name
-               proto, ok = protocols[protostr]
-               if ok {
-                       return
+               proto, err = lookupProtocol(protostr)
+               if err != nil {
+                       return "", 0, err
                }
        }
-       return
+       return net, proto, nil
 }
 
 // DialIP connects to the remote address raddr on the network net,
index 7368b751ee002924261e5e9038242662300de5c4..387bb5976c9e08a9a39a25d4c777c92ecd74b443 100644 (file)
@@ -8,8 +8,50 @@ package net
 
 import (
        "os"
+       "sync"
 )
 
+var (
+       protocols         map[string]int
+       onceReadProtocols sync.Once
+)
+
+// readProtocols loads contents of /etc/protocols into protocols map
+// for quick access.
+func readProtocols() {
+       protocols = make(map[string]int)
+       if file, err := open("/etc/protocols"); err == nil {
+               for line, ok := file.readLine(); ok; line, ok = file.readLine() {
+                       // tcp    6   TCP    # transmission control protocol
+                       if i := byteIndex(line, '#'); i >= 0 {
+                               line = line[0:i]
+                       }
+                       f := getFields(line)
+                       if len(f) < 2 {
+                               continue
+                       }
+                       if proto, _, ok := dtoi(f[1], 0); ok {
+                               protocols[f[0]] = proto
+                               for _, alias := range f[2:] {
+                                       protocols[alias] = proto
+                               }
+                       }
+               }
+               file.close()
+       }
+}
+
+// lookupProtocol looks up IP protocol name in /etc/protocols and
+// returns correspondent protocol number.
+func lookupProtocol(name string) (proto int, err os.Error) {
+       onceReadProtocols.Do(readProtocols)
+       proto, found := protocols[name]
+       if !found {
+               return 0, os.NewError("unknown IP protocol specified: " + name)
+       }
+       return
+}
+
 // LookupHost looks up the given host using the local resolver.
 // It returns an array of that host's addresses.
 func LookupHost(host string) (addrs []string, err os.Error) {
index b33c7f949e0ed2e0e2c7dc942f22be209e82af8f..e138698241c8ac0b77d1d4f3baf3edc171311037 100644 (file)
@@ -11,8 +11,22 @@ import (
        "sync"
 )
 
-var hostentLock sync.Mutex
-var serventLock sync.Mutex
+var (
+       protoentLock sync.Mutex
+       hostentLock  sync.Mutex
+       serventLock  sync.Mutex
+)
+
+// lookupProtocol looks up IP protocol name and returns correspondent protocol number.
+func lookupProtocol(name string) (proto int, err os.Error) {
+       protoentLock.Lock()
+       defer protoentLock.Unlock()
+       p, e := syscall.GetProtoByName(name)
+       if e != 0 {
+               return 0, os.NewSyscallError("GetProtoByName", e)
+       }
+       return int(p.Proto), nil
+}
 
 func LookupHost(name string) (addrs []string, err os.Error) {
        ips, err := LookupIP(name)
index e7bae326d8e31c3293979adbaf5c9d2ee6122278..c482b8073c1d79eba8b92197df8825a34903dd7b 100644 (file)
@@ -502,6 +502,7 @@ func Chmod(path string, mode uint32) (errno int) {
 //sys  GetHostByName(name string) (h *Hostent, errno int) [failretval==nil] = ws2_32.gethostbyname
 //sys  GetServByName(name string, proto string) (s *Servent, errno int) [failretval==nil] = ws2_32.getservbyname
 //sys  Ntohs(netshort uint16) (u uint16) = ws2_32.ntohs
+//sys  GetProtoByName(name string) (p *Protoent, errno int) [failretval==nil] = ws2_32.getprotobyname
 //sys  DnsQuery(name string, qtype uint16, options uint32, extra *byte, qrs **DNSRecord, pr *byte) (status uint32) = dnsapi.DnsQuery_W
 //sys  DnsRecordListFree(rl *DNSRecord, freetype uint32) = dnsapi.DnsRecordListFree
 //sys  GetIfEntry(pIfRow *MibIfRow) (errcode int) = iphlpapi.GetIfEntry
index 845004aa3e0d0e844a1e426efe9475bc7f72a1d0..7a666403e81f3c04ed497f39a7c54bc3714b9471 100644 (file)
@@ -101,6 +101,7 @@ var (
        procgethostbyname              = modws2_32.NewProc("gethostbyname")
        procgetservbyname              = modws2_32.NewProc("getservbyname")
        procntohs                      = modws2_32.NewProc("ntohs")
+       procgetprotobyname             = modws2_32.NewProc("getprotobyname")
        procDnsQuery_W                 = moddnsapi.NewProc("DnsQuery_W")
        procDnsRecordListFree          = moddnsapi.NewProc("DnsRecordListFree")
        procGetIfEntry                 = modiphlpapi.NewProc("GetIfEntry")
@@ -1314,6 +1315,21 @@ func Ntohs(netshort uint16) (u uint16) {
        return
 }
 
+func GetProtoByName(name string) (p *Protoent, errno int) {
+       r0, _, e1 := Syscall(procgetprotobyname.Addr(), 1, uintptr(unsafe.Pointer(StringBytePtr(name))), 0, 0)
+       p = (*Protoent)(unsafe.Pointer(r0))
+       if p == nil {
+               if e1 != 0 {
+                       errno = int(e1)
+               } else {
+                       errno = EINVAL
+               }
+       } else {
+               errno = 0
+       }
+       return
+}
+
 func DnsQuery(name string, qtype uint16, options uint32, extra *byte, qrs **DNSRecord, pr *byte) (status uint32) {
        r0, _, _ := Syscall6(procDnsQuery_W.Addr(), 6, uintptr(unsafe.Pointer(StringToUTF16Ptr(name))), uintptr(qtype), uintptr(options), uintptr(unsafe.Pointer(extra)), uintptr(unsafe.Pointer(qrs)), uintptr(unsafe.Pointer(pr)))
        status = uint32(r0)
index 0904085b9c3b4962802de4184296abf1093f3e9b..f6488ce9d81825e9915a311c513f8aea622faa35 100644 (file)
@@ -101,6 +101,7 @@ var (
        procgethostbyname              = modws2_32.NewProc("gethostbyname")
        procgetservbyname              = modws2_32.NewProc("getservbyname")
        procntohs                      = modws2_32.NewProc("ntohs")
+       procgetprotobyname             = modws2_32.NewProc("getprotobyname")
        procDnsQuery_W                 = moddnsapi.NewProc("DnsQuery_W")
        procDnsRecordListFree          = moddnsapi.NewProc("DnsRecordListFree")
        procGetIfEntry                 = modiphlpapi.NewProc("GetIfEntry")
@@ -1314,6 +1315,21 @@ func Ntohs(netshort uint16) (u uint16) {
        return
 }
 
+func GetProtoByName(name string) (p *Protoent, errno int) {
+       r0, _, e1 := Syscall(procgetprotobyname.Addr(), 1, uintptr(unsafe.Pointer(StringBytePtr(name))), 0, 0)
+       p = (*Protoent)(unsafe.Pointer(r0))
+       if p == nil {
+               if e1 != 0 {
+                       errno = int(e1)
+               } else {
+                       errno = EINVAL
+               }
+       } else {
+               errno = 0
+       }
+       return
+}
+
 func DnsQuery(name string, qtype uint16, options uint32, extra *byte, qrs **DNSRecord, pr *byte) (status uint32) {
        r0, _, _ := Syscall6(procDnsQuery_W.Addr(), 6, uintptr(unsafe.Pointer(StringToUTF16Ptr(name))), uintptr(qtype), uintptr(options), uintptr(unsafe.Pointer(extra)), uintptr(unsafe.Pointer(qrs)), uintptr(unsafe.Pointer(pr)))
        status = uint32(r0)
index a3ef1ba43b90f54bdf7681b06cc07de3394dbbbc..451cbf03d1becb55e660d71340ff962907e7d656 100644 (file)
@@ -411,6 +411,12 @@ type Hostent struct {
        AddrList **byte
 }
 
+type Protoent struct {
+       Name    *byte
+       Aliases **byte
+       Proto   uint16
+}
+
 const (
        DNS_TYPE_A       = 0x0001
        DNS_TYPE_NS      = 0x0002