]> Cypherpunks repositories - gostls13.git/commitdiff
net: prevent races during windows lookup calls
authorAlex Brainman <alex.brainman@gmail.com>
Mon, 4 Feb 2013 02:05:20 +0000 (13:05 +1100)
committerAlex Brainman <alex.brainman@gmail.com>
Mon, 4 Feb 2013 02:05:20 +0000 (13:05 +1100)
This only affects code (with exception of lookupProtocol)
that is only executed on older versions of Windows.

R=rsc, bradfitz
CC=golang-dev
https://golang.org/cl/7293043

src/pkg/net/lookup_windows.go

index b433d0cbbdbca9278c9c921dd79c17055070ecf9..3b29724f27a3ff72161184c22dcd2bccf886bb3d 100644 (file)
@@ -6,26 +6,17 @@ package net
 
 import (
        "os"
-       "sync"
+       "runtime"
        "syscall"
        "unsafe"
 )
 
-var (
-       protoentLock sync.Mutex
-       hostentLock  sync.Mutex
-       serventLock  sync.Mutex
-)
-
 var (
        lookupPort = oldLookupPort
        lookupIP   = oldLookupIP
 )
 
-// lookupProtocol looks up IP protocol name and returns correspondent protocol number.
-func lookupProtocol(name string) (proto int, err error) {
-       protoentLock.Lock()
-       defer protoentLock.Unlock()
+func getprotobyname(name string) (proto int, err error) {
        p, err := syscall.GetProtoByName(name)
        if err != nil {
                return 0, os.NewSyscallError("GetProtoByName", err)
@@ -33,6 +24,25 @@ func lookupProtocol(name string) (proto int, err error) {
        return int(p.Proto), nil
 }
 
+// lookupProtocol looks up IP protocol name and returns correspondent protocol number.
+func lookupProtocol(name string) (proto int, err error) {
+       // GetProtoByName return value is stored in thread local storage.
+       // Start new os thread before the call to prevent races.
+       type result struct {
+               proto int
+               err   error
+       }
+       ch := make(chan result)
+       go func() {
+               runtime.LockOSThread()
+               defer runtime.UnlockOSThread()
+               proto, err := getprotobyname(name)
+               ch <- result{proto: proto, err: err}
+       }()
+       r := <-ch
+       return r.proto, r.err
+}
+
 func lookupHost(name string) (addrs []string, err error) {
        ips, err := LookupIP(name)
        if err != nil {
@@ -45,9 +55,7 @@ func lookupHost(name string) (addrs []string, err error) {
        return
 }
 
-func oldLookupIP(name string) (addrs []IP, err error) {
-       hostentLock.Lock()
-       defer hostentLock.Unlock()
+func gethostbyname(name string) (addrs []IP, err error) {
        h, err := syscall.GetHostByName(name)
        if err != nil {
                return nil, os.NewSyscallError("GetHostByName", err)
@@ -66,6 +74,24 @@ func oldLookupIP(name string) (addrs []IP, err error) {
        return addrs, nil
 }
 
+func oldLookupIP(name string) (addrs []IP, err error) {
+       // GetHostByName return value is stored in thread local storage.
+       // Start new os thread before the call to prevent races.
+       type result struct {
+               addrs []IP
+               err   error
+       }
+       ch := make(chan result)
+       go func() {
+               runtime.LockOSThread()
+               defer runtime.UnlockOSThread()
+               addrs, err := gethostbyname(name)
+               ch <- result{addrs: addrs, err: err}
+       }()
+       r := <-ch
+       return r.addrs, r.err
+}
+
 func newLookupIP(name string) (addrs []IP, err error) {
        hints := syscall.AddrinfoW{
                Family:   syscall.AF_UNSPEC,
@@ -95,15 +121,13 @@ func newLookupIP(name string) (addrs []IP, err error) {
        return addrs, nil
 }
 
-func oldLookupPort(network, service string) (port int, err error) {
+func getservbyname(network, service string) (port int, err error) {
        switch network {
        case "tcp4", "tcp6":
                network = "tcp"
        case "udp4", "udp6":
                network = "udp"
        }
-       serventLock.Lock()
-       defer serventLock.Unlock()
        s, err := syscall.GetServByName(service, network)
        if err != nil {
                return 0, os.NewSyscallError("GetServByName", err)
@@ -111,6 +135,24 @@ func oldLookupPort(network, service string) (port int, err error) {
        return int(syscall.Ntohs(s.Port)), nil
 }
 
+func oldLookupPort(network, service string) (port int, err error) {
+       // GetServByName return value is stored in thread local storage.
+       // Start new os thread before the call to prevent races.
+       type result struct {
+               port int
+               err  error
+       }
+       ch := make(chan result)
+       go func() {
+               runtime.LockOSThread()
+               defer runtime.UnlockOSThread()
+               port, err := getservbyname(network, service)
+               ch <- result{port: port, err: err}
+       }()
+       r := <-ch
+       return r.port, r.err
+}
+
 func newLookupPort(network, service string) (port int, err error) {
        var stype int32
        switch network {