From: Alex Brainman Date: Mon, 4 Feb 2013 02:05:20 +0000 (+1100) Subject: net: prevent races during windows lookup calls X-Git-Tag: go1.1rc2~1160 X-Git-Url: http://www.git.cypherpunks.su/?a=commitdiff_plain;h=bc9999337b0c54a8035c4f9e6ea13b1fe3f34706;p=gostls13.git net: prevent races during windows lookup calls This only affects code (with exception of lookupProtocol) that is only executed on older versions of Windows. R=rsc, bradfitz CC=golang-dev https://golang.org/cl/7293043 --- diff --git a/src/pkg/net/lookup_windows.go b/src/pkg/net/lookup_windows.go index b433d0cbbd..3b29724f27 100644 --- a/src/pkg/net/lookup_windows.go +++ b/src/pkg/net/lookup_windows.go @@ -6,26 +6,17 @@ package net import ( "os" - "sync" + "runtime" "syscall" "unsafe" ) -var ( - protoentLock sync.Mutex - hostentLock sync.Mutex - serventLock sync.Mutex -) - var ( lookupPort = oldLookupPort lookupIP = oldLookupIP ) -// lookupProtocol looks up IP protocol name and returns correspondent protocol number. -func lookupProtocol(name string) (proto int, err error) { - protoentLock.Lock() - defer protoentLock.Unlock() +func getprotobyname(name string) (proto int, err error) { p, err := syscall.GetProtoByName(name) if err != nil { return 0, os.NewSyscallError("GetProtoByName", err) @@ -33,6 +24,25 @@ func lookupProtocol(name string) (proto int, err error) { return int(p.Proto), nil } +// lookupProtocol looks up IP protocol name and returns correspondent protocol number. +func lookupProtocol(name string) (proto int, err error) { + // GetProtoByName return value is stored in thread local storage. + // Start new os thread before the call to prevent races. + type result struct { + proto int + err error + } + ch := make(chan result) + go func() { + runtime.LockOSThread() + defer runtime.UnlockOSThread() + proto, err := getprotobyname(name) + ch <- result{proto: proto, err: err} + }() + r := <-ch + return r.proto, r.err +} + func lookupHost(name string) (addrs []string, err error) { ips, err := LookupIP(name) if err != nil { @@ -45,9 +55,7 @@ func lookupHost(name string) (addrs []string, err error) { return } -func oldLookupIP(name string) (addrs []IP, err error) { - hostentLock.Lock() - defer hostentLock.Unlock() +func gethostbyname(name string) (addrs []IP, err error) { h, err := syscall.GetHostByName(name) if err != nil { return nil, os.NewSyscallError("GetHostByName", err) @@ -66,6 +74,24 @@ func oldLookupIP(name string) (addrs []IP, err error) { return addrs, nil } +func oldLookupIP(name string) (addrs []IP, err error) { + // GetHostByName return value is stored in thread local storage. + // Start new os thread before the call to prevent races. + type result struct { + addrs []IP + err error + } + ch := make(chan result) + go func() { + runtime.LockOSThread() + defer runtime.UnlockOSThread() + addrs, err := gethostbyname(name) + ch <- result{addrs: addrs, err: err} + }() + r := <-ch + return r.addrs, r.err +} + func newLookupIP(name string) (addrs []IP, err error) { hints := syscall.AddrinfoW{ Family: syscall.AF_UNSPEC, @@ -95,15 +121,13 @@ func newLookupIP(name string) (addrs []IP, err error) { return addrs, nil } -func oldLookupPort(network, service string) (port int, err error) { +func getservbyname(network, service string) (port int, err error) { switch network { case "tcp4", "tcp6": network = "tcp" case "udp4", "udp6": network = "udp" } - serventLock.Lock() - defer serventLock.Unlock() s, err := syscall.GetServByName(service, network) if err != nil { return 0, os.NewSyscallError("GetServByName", err) @@ -111,6 +135,24 @@ func oldLookupPort(network, service string) (port int, err error) { return int(syscall.Ntohs(s.Port)), nil } +func oldLookupPort(network, service string) (port int, err error) { + // GetServByName return value is stored in thread local storage. + // Start new os thread before the call to prevent races. + type result struct { + port int + err error + } + ch := make(chan result) + go func() { + runtime.LockOSThread() + defer runtime.UnlockOSThread() + port, err := getservbyname(network, service) + ch <- result{port: port, err: err} + }() + r := <-ch + return r.port, r.err +} + func newLookupPort(network, service string) (port int, err error) { var stype int32 switch network {