From: Alex Brainman Date: Tue, 11 Oct 2011 23:29:22 +0000 (+1100) Subject: net: implement ip protocol name to number resolver for windows X-Git-Tag: weekly.2011-10-18~123 X-Git-Url: http://www.git.cypherpunks.su/?a=commitdiff_plain;h=059c68bf0cccea85bea19c44c15d7fec9e9cbd21;p=gostls13.git net: implement ip protocol name to number resolver for windows Fixes #2215. Fixes #2216. R=golang-dev, dave, rsc CC=golang-dev https://golang.org/cl/5248055 --- diff --git a/src/pkg/net/ipraw_test.go b/src/pkg/net/ipraw_test.go index 6894ce656d..7f8c7b841e 100644 --- a/src/pkg/net/ipraw_test.go +++ b/src/pkg/net/ipraw_test.go @@ -11,6 +11,7 @@ import ( "bytes" "flag" "os" + "runtime" "testing" ) @@ -64,7 +65,7 @@ var dsthost = flag.String("dsthost", "127.0.0.1", "Destination for the ICMP ECHO // test (raw) IP socket using ICMP func TestICMP(t *testing.T) { - if os.Getuid() != 0 { + if runtime.GOOS != "windows" && os.Getuid() != 0 { t.Logf("test disabled; must be root") return } diff --git a/src/pkg/net/iprawsock_posix.go b/src/pkg/net/iprawsock_posix.go index f9e497f173..dafbdab780 100644 --- a/src/pkg/net/iprawsock_posix.go +++ b/src/pkg/net/iprawsock_posix.go @@ -10,12 +10,9 @@ package net import ( "os" - "sync" "syscall" ) -var onceReadProtocols sync.Once - func sockaddrToIP(sa syscall.Sockaddr) Addr { switch sa := sa.(type) { case *syscall.SockaddrInet4: @@ -209,33 +206,7 @@ func (c *IPConn) WriteTo(b []byte, addr Addr) (n int, err os.Error) { return c.WriteToIP(b, a) } -var protocols map[string]int - -func readProtocols() { - protocols = make(map[string]int) - if file, err := open("/etc/protocols"); err == nil { - for line, ok := file.readLine(); ok; line, ok = file.readLine() { - // tcp 6 TCP # transmission control protocol - if i := byteIndex(line, '#'); i >= 0 { - line = line[0:i] - } - f := getFields(line) - if len(f) < 2 { - continue - } - if proto, _, ok := dtoi(f[1], 0); ok { - protocols[f[0]] = proto - for _, alias := range f[2:] { - protocols[alias] = proto - } - } - } - file.close() - } -} - func splitNetProto(netProto string) (net string, proto int, err os.Error) { - onceReadProtocols.Do(readProtocols) i := last(netProto, ':') if i < 0 { // no colon return "", 0, os.NewError("no IP protocol specified") @@ -244,13 +215,12 @@ func splitNetProto(netProto string) (net string, proto int, err os.Error) { protostr := netProto[i+1:] proto, i, ok := dtoi(protostr, 0) if !ok || i != len(protostr) { - // lookup by name - proto, ok = protocols[protostr] - if ok { - return + proto, err = lookupProtocol(protostr) + if err != nil { + return "", 0, err } } - return + return net, proto, nil } // DialIP connects to the remote address raddr on the network net, diff --git a/src/pkg/net/lookup_unix.go b/src/pkg/net/lookup_unix.go index 7368b751ee..387bb5976c 100644 --- a/src/pkg/net/lookup_unix.go +++ b/src/pkg/net/lookup_unix.go @@ -8,8 +8,50 @@ package net import ( "os" + "sync" ) +var ( + protocols map[string]int + onceReadProtocols sync.Once +) + +// readProtocols loads contents of /etc/protocols into protocols map +// for quick access. +func readProtocols() { + protocols = make(map[string]int) + if file, err := open("/etc/protocols"); err == nil { + for line, ok := file.readLine(); ok; line, ok = file.readLine() { + // tcp 6 TCP # transmission control protocol + if i := byteIndex(line, '#'); i >= 0 { + line = line[0:i] + } + f := getFields(line) + if len(f) < 2 { + continue + } + if proto, _, ok := dtoi(f[1], 0); ok { + protocols[f[0]] = proto + for _, alias := range f[2:] { + protocols[alias] = proto + } + } + } + file.close() + } +} + +// lookupProtocol looks up IP protocol name in /etc/protocols and +// returns correspondent protocol number. +func lookupProtocol(name string) (proto int, err os.Error) { + onceReadProtocols.Do(readProtocols) + proto, found := protocols[name] + if !found { + return 0, os.NewError("unknown IP protocol specified: " + name) + } + return +} + // LookupHost looks up the given host using the local resolver. // It returns an array of that host's addresses. func LookupHost(host string) (addrs []string, err os.Error) { diff --git a/src/pkg/net/lookup_windows.go b/src/pkg/net/lookup_windows.go index b33c7f949e..e138698241 100644 --- a/src/pkg/net/lookup_windows.go +++ b/src/pkg/net/lookup_windows.go @@ -11,8 +11,22 @@ import ( "sync" ) -var hostentLock sync.Mutex -var serventLock sync.Mutex +var ( + protoentLock sync.Mutex + hostentLock sync.Mutex + serventLock sync.Mutex +) + +// lookupProtocol looks up IP protocol name and returns correspondent protocol number. +func lookupProtocol(name string) (proto int, err os.Error) { + protoentLock.Lock() + defer protoentLock.Unlock() + p, e := syscall.GetProtoByName(name) + if e != 0 { + return 0, os.NewSyscallError("GetProtoByName", e) + } + return int(p.Proto), nil +} func LookupHost(name string) (addrs []string, err os.Error) { ips, err := LookupIP(name) diff --git a/src/pkg/syscall/syscall_windows.go b/src/pkg/syscall/syscall_windows.go index e7bae326d8..c482b8073c 100644 --- a/src/pkg/syscall/syscall_windows.go +++ b/src/pkg/syscall/syscall_windows.go @@ -502,6 +502,7 @@ func Chmod(path string, mode uint32) (errno int) { //sys GetHostByName(name string) (h *Hostent, errno int) [failretval==nil] = ws2_32.gethostbyname //sys GetServByName(name string, proto string) (s *Servent, errno int) [failretval==nil] = ws2_32.getservbyname //sys Ntohs(netshort uint16) (u uint16) = ws2_32.ntohs +//sys GetProtoByName(name string) (p *Protoent, errno int) [failretval==nil] = ws2_32.getprotobyname //sys DnsQuery(name string, qtype uint16, options uint32, extra *byte, qrs **DNSRecord, pr *byte) (status uint32) = dnsapi.DnsQuery_W //sys DnsRecordListFree(rl *DNSRecord, freetype uint32) = dnsapi.DnsRecordListFree //sys GetIfEntry(pIfRow *MibIfRow) (errcode int) = iphlpapi.GetIfEntry diff --git a/src/pkg/syscall/zsyscall_windows_386.go b/src/pkg/syscall/zsyscall_windows_386.go index 845004aa3e..7a666403e8 100644 --- a/src/pkg/syscall/zsyscall_windows_386.go +++ b/src/pkg/syscall/zsyscall_windows_386.go @@ -101,6 +101,7 @@ var ( procgethostbyname = modws2_32.NewProc("gethostbyname") procgetservbyname = modws2_32.NewProc("getservbyname") procntohs = modws2_32.NewProc("ntohs") + procgetprotobyname = modws2_32.NewProc("getprotobyname") procDnsQuery_W = moddnsapi.NewProc("DnsQuery_W") procDnsRecordListFree = moddnsapi.NewProc("DnsRecordListFree") procGetIfEntry = modiphlpapi.NewProc("GetIfEntry") @@ -1314,6 +1315,21 @@ func Ntohs(netshort uint16) (u uint16) { return } +func GetProtoByName(name string) (p *Protoent, errno int) { + r0, _, e1 := Syscall(procgetprotobyname.Addr(), 1, uintptr(unsafe.Pointer(StringBytePtr(name))), 0, 0) + p = (*Protoent)(unsafe.Pointer(r0)) + if p == nil { + if e1 != 0 { + errno = int(e1) + } else { + errno = EINVAL + } + } else { + errno = 0 + } + return +} + func DnsQuery(name string, qtype uint16, options uint32, extra *byte, qrs **DNSRecord, pr *byte) (status uint32) { r0, _, _ := Syscall6(procDnsQuery_W.Addr(), 6, uintptr(unsafe.Pointer(StringToUTF16Ptr(name))), uintptr(qtype), uintptr(options), uintptr(unsafe.Pointer(extra)), uintptr(unsafe.Pointer(qrs)), uintptr(unsafe.Pointer(pr))) status = uint32(r0) diff --git a/src/pkg/syscall/zsyscall_windows_amd64.go b/src/pkg/syscall/zsyscall_windows_amd64.go index 0904085b9c..f6488ce9d8 100644 --- a/src/pkg/syscall/zsyscall_windows_amd64.go +++ b/src/pkg/syscall/zsyscall_windows_amd64.go @@ -101,6 +101,7 @@ var ( procgethostbyname = modws2_32.NewProc("gethostbyname") procgetservbyname = modws2_32.NewProc("getservbyname") procntohs = modws2_32.NewProc("ntohs") + procgetprotobyname = modws2_32.NewProc("getprotobyname") procDnsQuery_W = moddnsapi.NewProc("DnsQuery_W") procDnsRecordListFree = moddnsapi.NewProc("DnsRecordListFree") procGetIfEntry = modiphlpapi.NewProc("GetIfEntry") @@ -1314,6 +1315,21 @@ func Ntohs(netshort uint16) (u uint16) { return } +func GetProtoByName(name string) (p *Protoent, errno int) { + r0, _, e1 := Syscall(procgetprotobyname.Addr(), 1, uintptr(unsafe.Pointer(StringBytePtr(name))), 0, 0) + p = (*Protoent)(unsafe.Pointer(r0)) + if p == nil { + if e1 != 0 { + errno = int(e1) + } else { + errno = EINVAL + } + } else { + errno = 0 + } + return +} + func DnsQuery(name string, qtype uint16, options uint32, extra *byte, qrs **DNSRecord, pr *byte) (status uint32) { r0, _, _ := Syscall6(procDnsQuery_W.Addr(), 6, uintptr(unsafe.Pointer(StringToUTF16Ptr(name))), uintptr(qtype), uintptr(options), uintptr(unsafe.Pointer(extra)), uintptr(unsafe.Pointer(qrs)), uintptr(unsafe.Pointer(pr))) status = uint32(r0) diff --git a/src/pkg/syscall/ztypes_windows.go b/src/pkg/syscall/ztypes_windows.go index a3ef1ba43b..451cbf03d1 100644 --- a/src/pkg/syscall/ztypes_windows.go +++ b/src/pkg/syscall/ztypes_windows.go @@ -411,6 +411,12 @@ type Hostent struct { AddrList **byte } +type Protoent struct { + Name *byte + Aliases **byte + Proto uint16 +} + const ( DNS_TYPE_A = 0x0001 DNS_TYPE_NS = 0x0002