]> Cypherpunks repositories - gostls13.git/commitdiff
net: use contexts for cgo-based DNS resolution
authorScott Bell <scott@sctsm.com>
Mon, 9 May 2016 01:17:59 +0000 (18:17 -0700)
committerBrad Fitzpatrick <bradfitz@golang.org>
Tue, 10 May 2016 15:55:48 +0000 (15:55 +0000)
Although calls to getaddrinfo can't be portably interrupted,
we still benefit from more granular resource management by
pushing the context downwards.

Fixes #15321

Change-Id: I5506195fc6493080410e3d46aaa3fe02018a24fe
Reviewed-on: https://go-review.googlesource.com/22961
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>

src/net/cgo_stub.go
src/net/cgo_unix.go
src/net/cgo_unix_test.go
src/net/lookup_unix.go
src/net/netgo_unix_test.go

index 52d1dfd3460e6c71966b06df67de64169ee14e5c..51259722aec8e6ce54d0d8e2ed46457afa02b3e4 100644 (file)
@@ -6,6 +6,8 @@
 
 package net
 
+import "context"
+
 func init() { netGo = true }
 
 type addrinfoErrno int
@@ -14,22 +16,22 @@ func (eai addrinfoErrno) Error() string   { return "<nil>" }
 func (eai addrinfoErrno) Temporary() bool { return false }
 func (eai addrinfoErrno) Timeout() bool   { return false }
 
-func cgoLookupHost(name string) (addrs []string, err error, completed bool) {
+func cgoLookupHost(ctx context.Context, name string) (addrs []string, err error, completed bool) {
        return nil, nil, false
 }
 
-func cgoLookupPort(network, service string) (port int, err error, completed bool) {
+func cgoLookupPort(ctx context.Context, network, service string) (port int, err error, completed bool) {
        return 0, nil, false
 }
 
-func cgoLookupIP(name string) (addrs []IPAddr, err error, completed bool) {
+func cgoLookupIP(ctx context.Context, name string) (addrs []IPAddr, err error, completed bool) {
        return nil, nil, false
 }
 
-func cgoLookupCNAME(name string) (cname string, err error, completed bool) {
+func cgoLookupCNAME(ctx context.Context, name string) (cname string, err error, completed bool) {
        return "", nil, false
 }
 
-func cgoLookupPTR(addr string) (ptrs []string, err error, completed bool) {
+func cgoLookupPTR(ctx context.Context, addr string) (ptrs []string, err error, completed bool) {
        return nil, nil, false
 }
index 59c40c8d8a7f0cdd4caa8f43508f30b774311dbe..5a1eed843751a45ab72de2365337132f3729f4ba 100644 (file)
@@ -19,6 +19,7 @@ package net
 import "C"
 
 import (
+       "context"
        "syscall"
        "unsafe"
 )
@@ -32,18 +33,31 @@ func (eai addrinfoErrno) Error() string   { return C.GoString(C.gai_strerror(C.i
 func (eai addrinfoErrno) Temporary() bool { return eai == C.EAI_AGAIN }
 func (eai addrinfoErrno) Timeout() bool   { return false }
 
-func cgoLookupHost(name string) (hosts []string, err error, completed bool) {
-       addrs, err, completed := cgoLookupIP(name)
+type portLookupResult struct {
+       port int
+       err  error
+}
+
+type ipLookupResult struct {
+       addrs []IPAddr
+       cname string
+       err   error
+}
+
+type reverseLookupResult struct {
+       names []string
+       err   error
+}
+
+func cgoLookupHost(ctx context.Context, name string) (hosts []string, err error, completed bool) {
+       addrs, err, completed := cgoLookupIP(ctx, name)
        for _, addr := range addrs {
                hosts = append(hosts, addr.String())
        }
        return
 }
 
-func cgoLookupPort(network, service string) (port int, err error, completed bool) {
-       acquireThread()
-       defer releaseThread()
-
+func cgoLookupPort(ctx context.Context, network, service string) (port int, err error, completed bool) {
        var hints C.struct_addrinfo
        switch network {
        case "": // no hints
@@ -64,11 +78,27 @@ func cgoLookupPort(network, service string) (port int, err error, completed bool
                        hints.ai_family = C.AF_INET6
                }
        }
+       if ctx.Done() == nil {
+               port, err := cgoLookupServicePort(&hints, network, service)
+               return port, err, true
+       }
+       result := make(chan portLookupResult, 1)
+       go cgoPortLookup(result, &hints, network, service)
+       select {
+       case r := <-result:
+               return r.port, r.err, true
+       case <-ctx.Done():
+               // Since there isn't a portable way to cancel the lookup,
+               // we just let it finish and write to the buffered channel.
+               return 0, mapErr(ctx.Err()), false
+       }
+}
 
+func cgoLookupServicePort(hints *C.struct_addrinfo, network, service string) (port int, err error) {
        s := C.CString(service)
        var res *C.struct_addrinfo
        defer C.free(unsafe.Pointer(s))
-       gerrno, err := C.getaddrinfo(nil, s, &hints, &res)
+       gerrno, err := C.getaddrinfo(nil, s, hints, &res)
        if gerrno != 0 {
                switch gerrno {
                case C.EAI_SYSTEM:
@@ -78,7 +108,7 @@ func cgoLookupPort(network, service string) (port int, err error, completed bool
                default:
                        err = addrinfoErrno(gerrno)
                }
-               return 0, &DNSError{Err: err.Error(), Name: network + "/" + service}, true
+               return 0, &DNSError{Err: err.Error(), Name: network + "/" + service}
        }
        defer C.freeaddrinfo(res)
 
@@ -87,17 +117,22 @@ func cgoLookupPort(network, service string) (port int, err error, completed bool
                case C.AF_INET:
                        sa := (*syscall.RawSockaddrInet4)(unsafe.Pointer(r.ai_addr))
                        p := (*[2]byte)(unsafe.Pointer(&sa.Port))
-                       return int(p[0])<<8 | int(p[1]), nil, true
+                       return int(p[0])<<8 | int(p[1]), nil
                case C.AF_INET6:
                        sa := (*syscall.RawSockaddrInet6)(unsafe.Pointer(r.ai_addr))
                        p := (*[2]byte)(unsafe.Pointer(&sa.Port))
-                       return int(p[0])<<8 | int(p[1]), nil, true
+                       return int(p[0])<<8 | int(p[1]), nil
                }
        }
-       return 0, &DNSError{Err: "unknown port", Name: network + "/" + service}, true
+       return 0, &DNSError{Err: "unknown port", Name: network + "/" + service}
 }
 
-func cgoLookupIPCNAME(name string) (addrs []IPAddr, cname string, err error, completed bool) {
+func cgoPortLookup(result chan<- portLookupResult, hints *C.struct_addrinfo, network, service string) {
+       port, err := cgoLookupServicePort(hints, network, service)
+       result <- portLookupResult{port, err}
+}
+
+func cgoLookupIPCNAME(name string) (addrs []IPAddr, cname string, err error) {
        acquireThread()
        defer releaseThread()
 
@@ -127,7 +162,7 @@ func cgoLookupIPCNAME(name string) (addrs []IPAddr, cname string, err error, com
                default:
                        err = addrinfoErrno(gerrno)
                }
-               return nil, "", &DNSError{Err: err.Error(), Name: name}, true
+               return nil, "", &DNSError{Err: err.Error(), Name: name}
        }
        defer C.freeaddrinfo(res)
 
@@ -156,17 +191,42 @@ func cgoLookupIPCNAME(name string) (addrs []IPAddr, cname string, err error, com
                        addrs = append(addrs, addr)
                }
        }
-       return addrs, cname, nil, true
+       return addrs, cname, nil
 }
 
-func cgoLookupIP(name string) (addrs []IPAddr, err error, completed bool) {
-       addrs, _, err, completed = cgoLookupIPCNAME(name)
-       return
+func cgoIPLookup(result chan<- ipLookupResult, name string) {
+       addrs, cname, err := cgoLookupIPCNAME(name)
+       result <- ipLookupResult{addrs, cname, err}
 }
 
-func cgoLookupCNAME(name string) (cname string, err error, completed bool) {
-       _, cname, err, completed = cgoLookupIPCNAME(name)
-       return
+func cgoLookupIP(ctx context.Context, name string) (addrs []IPAddr, err error, completed bool) {
+       if ctx.Done() == nil {
+               addrs, _, err = cgoLookupIPCNAME(name)
+               return addrs, err, true
+       }
+       result := make(chan ipLookupResult, 1)
+       go cgoIPLookup(result, name)
+       select {
+       case r := <-result:
+               return r.addrs, r.err, true
+       case <-ctx.Done():
+               return nil, mapErr(ctx.Err()), false
+       }
+}
+
+func cgoLookupCNAME(ctx context.Context, name string) (cname string, err error, completed bool) {
+       if ctx.Done() == nil {
+               _, cname, err = cgoLookupIPCNAME(name)
+               return cname, err, true
+       }
+       result := make(chan ipLookupResult, 1)
+       go cgoIPLookup(result, name)
+       select {
+       case r := <-result:
+               return r.cname, r.err, true
+       case <-ctx.Done():
+               return "", mapErr(ctx.Err()), false
+       }
 }
 
 // These are roughly enough for the following:
@@ -182,10 +242,7 @@ const (
        maxNameinfoLen = 4096
 )
 
-func cgoLookupPTR(addr string) ([]string, error, bool) {
-       acquireThread()
-       defer releaseThread()
-
+func cgoLookupPTR(ctx context.Context, addr string) (names []string, err error, completed bool) {
        var zone string
        ip := parseIPv4(addr)
        if ip == nil {
@@ -198,9 +255,26 @@ func cgoLookupPTR(addr string) ([]string, error, bool) {
        if sa == nil {
                return nil, &DNSError{Err: "invalid address " + ip.String(), Name: addr}, true
        }
-       var err error
-       var b []byte
+       if ctx.Done() == nil {
+               names, err := cgoLookupAddrPTR(addr, sa, salen)
+               return names, err, true
+       }
+       result := make(chan reverseLookupResult, 1)
+       go cgoReverseLookup(result, addr, sa, salen)
+       select {
+       case r := <-result:
+               return r.names, r.err, true
+       case <-ctx.Done():
+               return nil, mapErr(ctx.Err()), false
+       }
+}
+
+func cgoLookupAddrPTR(addr string, sa *C.struct_sockaddr, salen C.socklen_t) (names []string, err error) {
+       acquireThread()
+       defer releaseThread()
+
        var gerrno int
+       var b []byte
        for l := nameinfoLen; l <= maxNameinfoLen; l *= 2 {
                b = make([]byte, l)
                gerrno, err = cgoNameinfoPTR(b, sa, salen)
@@ -217,16 +291,20 @@ func cgoLookupPTR(addr string) ([]string, error, bool) {
                default:
                        err = addrinfoErrno(gerrno)
                }
-               return nil, &DNSError{Err: err.Error(), Name: addr}, true
+               return nil, &DNSError{Err: err.Error(), Name: addr}
        }
-
        for i := 0; i < len(b); i++ {
                if b[i] == 0 {
                        b = b[:i]
                        break
                }
        }
-       return []string{absDomainName(b)}, nil, true
+       return []string{absDomainName(b)}, nil
+}
+
+func cgoReverseLookup(result chan<- reverseLookupResult, addr string, sa *C.struct_sockaddr, salen C.socklen_t) {
+       names, err := cgoLookupAddrPTR(addr, sa, salen)
+       result <- reverseLookupResult{names, err}
 }
 
 func cgoSockaddr(ip IP, zone string) (*C.struct_sockaddr, C.socklen_t) {
index 5dc7b1a62d466312a9249eca9ad748a6ee07e852..e861c7aa1f90d8a1372a08d858d778d16c36fff3 100644 (file)
@@ -13,15 +13,70 @@ import (
 )
 
 func TestCgoLookupIP(t *testing.T) {
-       host := "localhost"
-       _, err, ok := cgoLookupIP(host)
+       ctx := context.Background()
+       _, err, ok := cgoLookupIP(ctx, "localhost")
        if !ok {
                t.Errorf("cgoLookupIP must not be a placeholder")
        }
        if err != nil {
                t.Error(err)
        }
-       if _, err := goLookupIP(context.Background(), host); err != nil {
+}
+
+func TestCgoLookupIPWithCancel(t *testing.T) {
+       ctx, cancel := context.WithCancel(context.Background())
+       defer cancel()
+       _, err, ok := cgoLookupIP(ctx, "localhost")
+       if !ok {
+               t.Errorf("cgoLookupIP must not be a placeholder")
+       }
+       if err != nil {
+               t.Error(err)
+       }
+}
+
+func TestCgoLookupPort(t *testing.T) {
+       ctx := context.Background()
+       _, err, ok := cgoLookupPort(ctx, "tcp", "smtp")
+       if !ok {
+               t.Errorf("cgoLookupPort must not be a placeholder")
+       }
+       if err != nil {
+               t.Error(err)
+       }
+}
+
+func TestCgoLookupPortWithCancel(t *testing.T) {
+       ctx, cancel := context.WithCancel(context.Background())
+       defer cancel()
+       _, err, ok := cgoLookupPort(ctx, "tcp", "smtp")
+       if !ok {
+               t.Errorf("cgoLookupPort must not be a placeholder")
+       }
+       if err != nil {
+               t.Error(err)
+       }
+}
+
+func TestCgoLookupPTR(t *testing.T) {
+       ctx := context.Background()
+       _, err, ok := cgoLookupPTR(ctx, "127.0.0.1")
+       if !ok {
+               t.Errorf("cgoLookupPTR must not be a placeholder")
+       }
+       if err != nil {
+               t.Error(err)
+       }
+}
+
+func TestCgoLookupPTRWithCancel(t *testing.T) {
+       ctx, cancel := context.WithCancel(context.Background())
+       defer cancel()
+       _, err, ok := cgoLookupPTR(ctx, "127.0.0.1")
+       if !ok {
+               t.Errorf("cgoLookupPTR must not be a placeholder")
+       }
+       if err != nil {
                t.Error(err)
        }
 }
index 5461fe8a411e8d60e6eb8a71c43382356285fe90..15397e8105a92e84ecebe35489edabf9b9ceaf01 100644 (file)
@@ -55,7 +55,7 @@ func lookupProtocol(_ context.Context, name string) (int, error) {
 func lookupHost(ctx context.Context, host string) (addrs []string, err error) {
        order := systemConf().hostLookupOrder(host)
        if order == hostLookupCgo {
-               if addrs, err, ok := cgoLookupHost(host); ok {
+               if addrs, err, ok := cgoLookupHost(ctx, host); ok {
                        return addrs, err
                }
                // cgo not available (or netgo); fall back to Go's DNS resolver
@@ -67,8 +67,7 @@ func lookupHost(ctx context.Context, host string) (addrs []string, err error) {
 func lookupIP(ctx context.Context, host string) (addrs []IPAddr, err error) {
        order := systemConf().hostLookupOrder(host)
        if order == hostLookupCgo {
-               // TODO(bradfitz): push down ctx, or at least its deadline to start
-               if addrs, err, ok := cgoLookupIP(host); ok {
+               if addrs, err, ok := cgoLookupIP(ctx, host); ok {
                        return addrs, err
                }
                // cgo not available (or netgo); fall back to Go's DNS resolver
@@ -84,7 +83,7 @@ func lookupPort(ctx context.Context, network, service string) (int, error) {
        // files might be on a remote filesystem, though. This should
        // probably race goroutines if ctx != context.Background().
        if systemConf().canUseCgo() {
-               if port, err, ok := cgoLookupPort(network, service); ok {
+               if port, err, ok := cgoLookupPort(ctx, network, service); ok {
                        return port, err
                }
        }
@@ -93,8 +92,7 @@ func lookupPort(ctx context.Context, network, service string) (int, error) {
 
 func lookupCNAME(ctx context.Context, name string) (string, error) {
        if systemConf().canUseCgo() {
-               // TODO: use ctx. issue 15321. Or race goroutines.
-               if cname, err, ok := cgoLookupCNAME(name); ok {
+               if cname, err, ok := cgoLookupCNAME(ctx, name); ok {
                        return cname, err
                }
        }
@@ -161,7 +159,7 @@ func lookupTXT(ctx context.Context, name string) ([]string, error) {
 
 func lookupAddr(ctx context.Context, addr string) ([]string, error) {
        if systemConf().canUseCgo() {
-               if ptrs, err, ok := cgoLookupPTR(addr); ok {
+               if ptrs, err, ok := cgoLookupPTR(ctx, addr); ok {
                        return ptrs, err
                }
        }
index 0a118874c25aef7912c82054b1b6f45149bf013e..5f1eb19e1240bdc9907810482a3102ddb18e707e 100644 (file)
@@ -14,14 +14,15 @@ import (
 
 func TestGoLookupIP(t *testing.T) {
        host := "localhost"
-       _, err, ok := cgoLookupIP(host)
+       ctx := context.Background()
+       _, err, ok := cgoLookupIP(ctx, host)
        if ok {
                t.Errorf("cgoLookupIP must be a placeholder")
        }
        if err != nil {
                t.Error(err)
        }
-       if _, err := goLookupIP(context.Background(), host); err != nil {
+       if _, err := goLookupIP(ctx, host); err != nil {
                t.Error(err)
        }
 }