]> Cypherpunks repositories - gostls13.git/commitdiff
net: support context cancellation in acquireThread
authorMateusz Poliwczak <mpoliwczak34@gmail.com>
Wed, 21 Feb 2024 17:18:43 +0000 (17:18 +0000)
committerGopher Robot <gobot@golang.org>
Wed, 21 Feb 2024 18:47:20 +0000 (18:47 +0000)
acquireThread is already waiting on a channel, so
it can be easily wired up to support context cancellation.
This change will make sure that contexts that are
cancelled at the acquireThread stage (when the limit of
threads is reached) do not queue unnecessarily and cause
an unnecessary cgo call that will be soon aborted by
the doBlockingWithCtx function.

Updates #63978

Change-Id: I8ae4debd51995637567d8f51c6f1ed60f23d6c0c
GitHub-Last-Rev: 4189b9faf07c073a2ca440becee07b6aa9c4e795
GitHub-Pull-Request: golang/go#63985
Reviewed-on: https://go-review.googlesource.com/c/go/+/539360
Auto-Submit: Ian Lance Taylor <iant@golang.org>
Reviewed-by: Ian Lance Taylor <iant@google.com>
Reviewed-by: Bryan Mills <bcmills@google.com>
Commit-Queue: Ian Lance Taylor <iant@golang.org>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>

src/net/cgo_unix.go
src/net/lookup_windows.go
src/net/net.go

index 987931501948192c904f7a8543be46a83158b9a2..82ec4441fcbbc9ffd7f1897bba7ceaaa8b3bf22e 100644 (file)
@@ -40,8 +40,20 @@ func (eai addrinfoErrno) isAddrinfoErrno() {}
 // doBlockingWithCtx executes a blocking function in a separate goroutine when the provided
 // context is cancellable. It is intended for use with calls that don't support context
 // cancellation (cgo, syscalls). blocking func may still be running after this function finishes.
-func doBlockingWithCtx[T any](ctx context.Context, blocking func() (T, error)) (T, error) {
+// For the duration of the execution of the blocking function, the thread is 'acquired' using [acquireThread],
+// blocking might not be executed when the context gets cancelled early.
+func doBlockingWithCtx[T any](ctx context.Context, lookupName string, blocking func() (T, error)) (T, error) {
+       if err := acquireThread(ctx); err != nil {
+               var zero T
+               return zero, &DNSError{
+                       Name:      lookupName,
+                       Err:       mapErr(err).Error(),
+                       IsTimeout: err == context.DeadlineExceeded,
+               }
+       }
+
        if ctx.Done() == nil {
+               defer releaseThread()
                return blocking()
        }
 
@@ -52,6 +64,7 @@ func doBlockingWithCtx[T any](ctx context.Context, blocking func() (T, error)) (
 
        res := make(chan result, 1)
        go func() {
+               defer releaseThread()
                var r result
                r.res, r.err = blocking()
                res <- r
@@ -62,7 +75,11 @@ func doBlockingWithCtx[T any](ctx context.Context, blocking func() (T, error)) (
                return r.res, r.err
        case <-ctx.Done():
                var zero T
-               return zero, mapErr(ctx.Err())
+               return zero, &DNSError{
+                       Name:      lookupName,
+                       Err:       mapErr(ctx.Err()).Error(),
+                       IsTimeout: ctx.Err() == context.DeadlineExceeded,
+               }
        }
 }
 
@@ -97,7 +114,7 @@ func cgoLookupPort(ctx context.Context, network, service string) (port int, err
                *_C_ai_family(&hints) = _C_AF_INET6
        }
 
-       return doBlockingWithCtx(ctx, func() (int, error) {
+       return doBlockingWithCtx(ctx, network+"/"+service, func() (int, error) {
                return cgoLookupServicePort(&hints, network, service)
        })
 }
@@ -146,9 +163,6 @@ func cgoLookupServicePort(hints *_C_struct_addrinfo, network, service string) (p
 }
 
 func cgoLookupHostIP(network, name string) (addrs []IPAddr, err error) {
-       acquireThread()
-       defer releaseThread()
-
        var hints _C_struct_addrinfo
        *_C_ai_flags(&hints) = cgoAddrInfoFlags
        *_C_ai_socktype(&hints) = _C_SOCK_STREAM
@@ -213,7 +227,7 @@ func cgoLookupHostIP(network, name string) (addrs []IPAddr, err error) {
 }
 
 func cgoLookupIP(ctx context.Context, network, name string) (addrs []IPAddr, err error) {
-       return doBlockingWithCtx(ctx, func() ([]IPAddr, error) {
+       return doBlockingWithCtx(ctx, name, func() ([]IPAddr, error) {
                return cgoLookupHostIP(network, name)
        })
 }
@@ -241,15 +255,12 @@ func cgoLookupPTR(ctx context.Context, addr string) (names []string, err error)
                return nil, &DNSError{Err: "invalid address " + ip.String(), Name: addr}
        }
 
-       return doBlockingWithCtx(ctx, func() ([]string, error) {
+       return doBlockingWithCtx(ctx, addr, func() ([]string, error) {
                return cgoLookupAddrPTR(addr, sa, salen)
        })
 }
 
 func cgoLookupAddrPTR(addr string, sa *_C_struct_sockaddr, salen _C_socklen_t) (names []string, err error) {
-       acquireThread()
-       defer releaseThread()
-
        var gerrno int
        var b []byte
        for l := nameinfoLen; l <= maxNameinfoLen; l *= 2 {
@@ -310,15 +321,12 @@ func cgoLookupCNAME(ctx context.Context, name string) (cname string, err error,
 // resSearch will make a call to the 'res_nsearch' routine in the C library
 // and parse the output as a slice of DNS resources.
 func resSearch(ctx context.Context, hostname string, rtype, class int) ([]dnsmessage.Resource, error) {
-       return doBlockingWithCtx(ctx, func() ([]dnsmessage.Resource, error) {
+       return doBlockingWithCtx(ctx, hostname, func() ([]dnsmessage.Resource, error) {
                return cgoResSearch(hostname, rtype, class)
        })
 }
 
 func cgoResSearch(hostname string, rtype, class int) ([]dnsmessage.Resource, error) {
-       acquireThread()
-       defer releaseThread()
-
        resStateSize := unsafe.Sizeof(_C_struct___res_state{})
        var state *_C_struct___res_state
        if resStateSize > 0 {
index 3048f3269b003e05c34cb049411f3db72ff746c8..946622761cf202d50539de8464bea63e4483b09b 100644 (file)
@@ -54,7 +54,10 @@ func lookupProtocol(ctx context.Context, name string) (int, error) {
        }
        ch := make(chan result) // unbuffered
        go func() {
-               acquireThread()
+               if err := acquireThread(ctx); err != nil {
+                       ch <- result{err: mapErr(err)}
+                       return
+               }
                defer releaseThread()
                runtime.LockOSThread()
                defer runtime.UnlockOSThread()
@@ -111,7 +114,13 @@ func (r *Resolver) lookupIP(ctx context.Context, network, name string) ([]IPAddr
        }
 
        getaddr := func() ([]IPAddr, error) {
-               acquireThread()
+               if err := acquireThread(ctx); err != nil {
+                       return nil, &DNSError{
+                               Name:      name,
+                               Err:       mapErr(err).Error(),
+                               IsTimeout: ctx.Err() == context.DeadlineExceeded,
+                       }
+               }
                defer releaseThread()
                hints := syscall.AddrinfoW{
                        Family:   family,
@@ -200,8 +209,14 @@ func (r *Resolver) lookupPort(ctx context.Context, network, service string) (int
                return lookupPortMap(network, service)
        }
 
-       // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this.
-       acquireThread()
+       // TODO(bradfitz): finish ctx plumbing
+       if err := acquireThread(ctx); err != nil {
+               return 0, &DNSError{
+                       Name:      network + "/" + service,
+                       Err:       mapErr(err).Error(),
+                       IsTimeout: ctx.Err() == context.DeadlineExceeded,
+               }
+       }
        defer releaseThread()
 
        var hints syscall.AddrinfoW
@@ -263,8 +278,14 @@ func (r *Resolver) lookupCNAME(ctx context.Context, name string) (string, error)
                return r.goLookupCNAME(ctx, name, order, conf)
        }
 
-       // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this.
-       acquireThread()
+       // TODO(bradfitz): finish ctx plumbing
+       if err := acquireThread(ctx); err != nil {
+               return "", &DNSError{
+                       Name:      name,
+                       Err:       mapErr(err).Error(),
+                       IsTimeout: ctx.Err() == context.DeadlineExceeded,
+               }
+       }
        defer releaseThread()
        var rec *syscall.DNSRecord
        e := syscall.DnsQuery(name, syscall.DNS_TYPE_CNAME, 0, nil, &rec, nil)
@@ -288,8 +309,14 @@ func (r *Resolver) lookupSRV(ctx context.Context, service, proto, name string) (
        if systemConf().mustUseGoResolver(r) {
                return r.goLookupSRV(ctx, service, proto, name)
        }
-       // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this.
-       acquireThread()
+       // TODO(bradfitz): finish ctx plumbing
+       if err := acquireThread(ctx); err != nil {
+               return "", nil, &DNSError{
+                       Name:      name,
+                       Err:       mapErr(err).Error(),
+                       IsTimeout: ctx.Err() == context.DeadlineExceeded,
+               }
+       }
        defer releaseThread()
        var target string
        if service == "" && proto == "" {
@@ -318,8 +345,14 @@ func (r *Resolver) lookupMX(ctx context.Context, name string) ([]*MX, error) {
        if systemConf().mustUseGoResolver(r) {
                return r.goLookupMX(ctx, name)
        }
-       // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this.
-       acquireThread()
+       // TODO(bradfitz): finish ctx plumbing.
+       if err := acquireThread(ctx); err != nil {
+               return nil, &DNSError{
+                       Name:      name,
+                       Err:       mapErr(err).Error(),
+                       IsTimeout: ctx.Err() == context.DeadlineExceeded,
+               }
+       }
        defer releaseThread()
        var rec *syscall.DNSRecord
        e := syscall.DnsQuery(name, syscall.DNS_TYPE_MX, 0, nil, &rec, nil)
@@ -342,8 +375,14 @@ func (r *Resolver) lookupNS(ctx context.Context, name string) ([]*NS, error) {
        if systemConf().mustUseGoResolver(r) {
                return r.goLookupNS(ctx, name)
        }
-       // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this.
-       acquireThread()
+       // TODO(bradfitz): finish ctx plumbing.
+       if err := acquireThread(ctx); err != nil {
+               return nil, &DNSError{
+                       Name:      name,
+                       Err:       mapErr(err).Error(),
+                       IsTimeout: ctx.Err() == context.DeadlineExceeded,
+               }
+       }
        defer releaseThread()
        var rec *syscall.DNSRecord
        e := syscall.DnsQuery(name, syscall.DNS_TYPE_NS, 0, nil, &rec, nil)
@@ -365,8 +404,14 @@ func (r *Resolver) lookupTXT(ctx context.Context, name string) ([]string, error)
        if systemConf().mustUseGoResolver(r) {
                return r.goLookupTXT(ctx, name)
        }
-       // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this.
-       acquireThread()
+       // TODO(bradfitz): finish ctx plumbing.
+       if err := acquireThread(ctx); err != nil {
+               return nil, &DNSError{
+                       Name:      name,
+                       Err:       mapErr(err).Error(),
+                       IsTimeout: ctx.Err() == context.DeadlineExceeded,
+               }
+       }
        defer releaseThread()
        var rec *syscall.DNSRecord
        e := syscall.DnsQuery(name, syscall.DNS_TYPE_TEXT, 0, nil, &rec, nil)
@@ -393,8 +438,14 @@ func (r *Resolver) lookupAddr(ctx context.Context, addr string) ([]string, error
                return r.goLookupPTR(ctx, addr, order, conf)
        }
 
-       // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this.
-       acquireThread()
+       // TODO(bradfitz): finish ctx plumbing.
+       if err := acquireThread(ctx); err != nil {
+               return nil, &DNSError{
+                       Name:      addr,
+                       Err:       mapErr(err).Error(),
+                       IsTimeout: ctx.Err() == context.DeadlineExceeded,
+               }
+       }
        defer releaseThread()
        arpa, err := reverseaddr(addr)
        if err != nil {
index 387f2bb14dfe7dc2dc9adcb7edfa5a4c21b458f0..b5f7303db34dae489310b29a01cb3b892b07cf25 100644 (file)
@@ -727,11 +727,16 @@ var threadLimit chan struct{}
 
 var threadOnce sync.Once
 
-func acquireThread() {
+func acquireThread(ctx context.Context) error {
        threadOnce.Do(func() {
                threadLimit = make(chan struct{}, concurrentThreadsLimit())
        })
-       threadLimit <- struct{}{}
+       select {
+       case threadLimit <- struct{}{}:
+               return nil
+       case <-ctx.Done():
+               return ctx.Err()
+       }
 }
 
 func releaseThread() {