From: Mateusz Poliwczak Date: Wed, 21 Feb 2024 17:18:43 +0000 (+0000) Subject: net: support context cancellation in acquireThread X-Git-Tag: go1.23rc1~1168 X-Git-Url: http://www.git.cypherpunks.su/?a=commitdiff_plain;h=c07b9b00361b9a99fc64ffe36f897d24954f99cf;p=gostls13.git net: support context cancellation in acquireThread acquireThread is already waiting on a channel, so it can be easily wired up to support context cancellation. This change will make sure that contexts that are cancelled at the acquireThread stage (when the limit of threads is reached) do not queue unnecessarily and cause an unnecessary cgo call that will be soon aborted by the doBlockingWithCtx function. Updates #63978 Change-Id: I8ae4debd51995637567d8f51c6f1ed60f23d6c0c GitHub-Last-Rev: 4189b9faf07c073a2ca440becee07b6aa9c4e795 GitHub-Pull-Request: golang/go#63985 Reviewed-on: https://go-review.googlesource.com/c/go/+/539360 Auto-Submit: Ian Lance Taylor Reviewed-by: Ian Lance Taylor Reviewed-by: Bryan Mills Commit-Queue: Ian Lance Taylor LUCI-TryBot-Result: Go LUCI --- diff --git a/src/net/cgo_unix.go b/src/net/cgo_unix.go index 9879315019..82ec4441fc 100644 --- a/src/net/cgo_unix.go +++ b/src/net/cgo_unix.go @@ -40,8 +40,20 @@ func (eai addrinfoErrno) isAddrinfoErrno() {} // doBlockingWithCtx executes a blocking function in a separate goroutine when the provided // context is cancellable. It is intended for use with calls that don't support context // cancellation (cgo, syscalls). blocking func may still be running after this function finishes. -func doBlockingWithCtx[T any](ctx context.Context, blocking func() (T, error)) (T, error) { +// For the duration of the execution of the blocking function, the thread is 'acquired' using [acquireThread], +// blocking might not be executed when the context gets cancelled early. +func doBlockingWithCtx[T any](ctx context.Context, lookupName string, blocking func() (T, error)) (T, error) { + if err := acquireThread(ctx); err != nil { + var zero T + return zero, &DNSError{ + Name: lookupName, + Err: mapErr(err).Error(), + IsTimeout: err == context.DeadlineExceeded, + } + } + if ctx.Done() == nil { + defer releaseThread() return blocking() } @@ -52,6 +64,7 @@ func doBlockingWithCtx[T any](ctx context.Context, blocking func() (T, error)) ( res := make(chan result, 1) go func() { + defer releaseThread() var r result r.res, r.err = blocking() res <- r @@ -62,7 +75,11 @@ func doBlockingWithCtx[T any](ctx context.Context, blocking func() (T, error)) ( return r.res, r.err case <-ctx.Done(): var zero T - return zero, mapErr(ctx.Err()) + return zero, &DNSError{ + Name: lookupName, + Err: mapErr(ctx.Err()).Error(), + IsTimeout: ctx.Err() == context.DeadlineExceeded, + } } } @@ -97,7 +114,7 @@ func cgoLookupPort(ctx context.Context, network, service string) (port int, err *_C_ai_family(&hints) = _C_AF_INET6 } - return doBlockingWithCtx(ctx, func() (int, error) { + return doBlockingWithCtx(ctx, network+"/"+service, func() (int, error) { return cgoLookupServicePort(&hints, network, service) }) } @@ -146,9 +163,6 @@ func cgoLookupServicePort(hints *_C_struct_addrinfo, network, service string) (p } func cgoLookupHostIP(network, name string) (addrs []IPAddr, err error) { - acquireThread() - defer releaseThread() - var hints _C_struct_addrinfo *_C_ai_flags(&hints) = cgoAddrInfoFlags *_C_ai_socktype(&hints) = _C_SOCK_STREAM @@ -213,7 +227,7 @@ func cgoLookupHostIP(network, name string) (addrs []IPAddr, err error) { } func cgoLookupIP(ctx context.Context, network, name string) (addrs []IPAddr, err error) { - return doBlockingWithCtx(ctx, func() ([]IPAddr, error) { + return doBlockingWithCtx(ctx, name, func() ([]IPAddr, error) { return cgoLookupHostIP(network, name) }) } @@ -241,15 +255,12 @@ func cgoLookupPTR(ctx context.Context, addr string) (names []string, err error) return nil, &DNSError{Err: "invalid address " + ip.String(), Name: addr} } - return doBlockingWithCtx(ctx, func() ([]string, error) { + return doBlockingWithCtx(ctx, addr, func() ([]string, error) { return cgoLookupAddrPTR(addr, sa, salen) }) } func cgoLookupAddrPTR(addr string, sa *_C_struct_sockaddr, salen _C_socklen_t) (names []string, err error) { - acquireThread() - defer releaseThread() - var gerrno int var b []byte for l := nameinfoLen; l <= maxNameinfoLen; l *= 2 { @@ -310,15 +321,12 @@ func cgoLookupCNAME(ctx context.Context, name string) (cname string, err error, // resSearch will make a call to the 'res_nsearch' routine in the C library // and parse the output as a slice of DNS resources. func resSearch(ctx context.Context, hostname string, rtype, class int) ([]dnsmessage.Resource, error) { - return doBlockingWithCtx(ctx, func() ([]dnsmessage.Resource, error) { + return doBlockingWithCtx(ctx, hostname, func() ([]dnsmessage.Resource, error) { return cgoResSearch(hostname, rtype, class) }) } func cgoResSearch(hostname string, rtype, class int) ([]dnsmessage.Resource, error) { - acquireThread() - defer releaseThread() - resStateSize := unsafe.Sizeof(_C_struct___res_state{}) var state *_C_struct___res_state if resStateSize > 0 { diff --git a/src/net/lookup_windows.go b/src/net/lookup_windows.go index 3048f3269b..946622761c 100644 --- a/src/net/lookup_windows.go +++ b/src/net/lookup_windows.go @@ -54,7 +54,10 @@ func lookupProtocol(ctx context.Context, name string) (int, error) { } ch := make(chan result) // unbuffered go func() { - acquireThread() + if err := acquireThread(ctx); err != nil { + ch <- result{err: mapErr(err)} + return + } defer releaseThread() runtime.LockOSThread() defer runtime.UnlockOSThread() @@ -111,7 +114,13 @@ func (r *Resolver) lookupIP(ctx context.Context, network, name string) ([]IPAddr } getaddr := func() ([]IPAddr, error) { - acquireThread() + if err := acquireThread(ctx); err != nil { + return nil, &DNSError{ + Name: name, + Err: mapErr(err).Error(), + IsTimeout: ctx.Err() == context.DeadlineExceeded, + } + } defer releaseThread() hints := syscall.AddrinfoW{ Family: family, @@ -200,8 +209,14 @@ func (r *Resolver) lookupPort(ctx context.Context, network, service string) (int return lookupPortMap(network, service) } - // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this. - acquireThread() + // TODO(bradfitz): finish ctx plumbing + if err := acquireThread(ctx); err != nil { + return 0, &DNSError{ + Name: network + "/" + service, + Err: mapErr(err).Error(), + IsTimeout: ctx.Err() == context.DeadlineExceeded, + } + } defer releaseThread() var hints syscall.AddrinfoW @@ -263,8 +278,14 @@ func (r *Resolver) lookupCNAME(ctx context.Context, name string) (string, error) return r.goLookupCNAME(ctx, name, order, conf) } - // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this. - acquireThread() + // TODO(bradfitz): finish ctx plumbing + if err := acquireThread(ctx); err != nil { + return "", &DNSError{ + Name: name, + Err: mapErr(err).Error(), + IsTimeout: ctx.Err() == context.DeadlineExceeded, + } + } defer releaseThread() var rec *syscall.DNSRecord e := syscall.DnsQuery(name, syscall.DNS_TYPE_CNAME, 0, nil, &rec, nil) @@ -288,8 +309,14 @@ func (r *Resolver) lookupSRV(ctx context.Context, service, proto, name string) ( if systemConf().mustUseGoResolver(r) { return r.goLookupSRV(ctx, service, proto, name) } - // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this. - acquireThread() + // TODO(bradfitz): finish ctx plumbing + if err := acquireThread(ctx); err != nil { + return "", nil, &DNSError{ + Name: name, + Err: mapErr(err).Error(), + IsTimeout: ctx.Err() == context.DeadlineExceeded, + } + } defer releaseThread() var target string if service == "" && proto == "" { @@ -318,8 +345,14 @@ func (r *Resolver) lookupMX(ctx context.Context, name string) ([]*MX, error) { if systemConf().mustUseGoResolver(r) { return r.goLookupMX(ctx, name) } - // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this. - acquireThread() + // TODO(bradfitz): finish ctx plumbing. + if err := acquireThread(ctx); err != nil { + return nil, &DNSError{ + Name: name, + Err: mapErr(err).Error(), + IsTimeout: ctx.Err() == context.DeadlineExceeded, + } + } defer releaseThread() var rec *syscall.DNSRecord e := syscall.DnsQuery(name, syscall.DNS_TYPE_MX, 0, nil, &rec, nil) @@ -342,8 +375,14 @@ func (r *Resolver) lookupNS(ctx context.Context, name string) ([]*NS, error) { if systemConf().mustUseGoResolver(r) { return r.goLookupNS(ctx, name) } - // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this. - acquireThread() + // TODO(bradfitz): finish ctx plumbing. + if err := acquireThread(ctx); err != nil { + return nil, &DNSError{ + Name: name, + Err: mapErr(err).Error(), + IsTimeout: ctx.Err() == context.DeadlineExceeded, + } + } defer releaseThread() var rec *syscall.DNSRecord e := syscall.DnsQuery(name, syscall.DNS_TYPE_NS, 0, nil, &rec, nil) @@ -365,8 +404,14 @@ func (r *Resolver) lookupTXT(ctx context.Context, name string) ([]string, error) if systemConf().mustUseGoResolver(r) { return r.goLookupTXT(ctx, name) } - // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this. - acquireThread() + // TODO(bradfitz): finish ctx plumbing. + if err := acquireThread(ctx); err != nil { + return nil, &DNSError{ + Name: name, + Err: mapErr(err).Error(), + IsTimeout: ctx.Err() == context.DeadlineExceeded, + } + } defer releaseThread() var rec *syscall.DNSRecord e := syscall.DnsQuery(name, syscall.DNS_TYPE_TEXT, 0, nil, &rec, nil) @@ -393,8 +438,14 @@ func (r *Resolver) lookupAddr(ctx context.Context, addr string) ([]string, error return r.goLookupPTR(ctx, addr, order, conf) } - // TODO(bradfitz): finish ctx plumbing. Nothing currently depends on this. - acquireThread() + // TODO(bradfitz): finish ctx plumbing. + if err := acquireThread(ctx); err != nil { + return nil, &DNSError{ + Name: addr, + Err: mapErr(err).Error(), + IsTimeout: ctx.Err() == context.DeadlineExceeded, + } + } defer releaseThread() arpa, err := reverseaddr(addr) if err != nil { diff --git a/src/net/net.go b/src/net/net.go index 387f2bb14d..b5f7303db3 100644 --- a/src/net/net.go +++ b/src/net/net.go @@ -727,11 +727,16 @@ var threadLimit chan struct{} var threadOnce sync.Once -func acquireThread() { +func acquireThread(ctx context.Context) error { threadOnce.Do(func() { threadLimit = make(chan struct{}, concurrentThreadsLimit()) }) - threadLimit <- struct{}{} + select { + case threadLimit <- struct{}{}: + return nil + case <-ctx.Done(): + return ctx.Err() + } } func releaseThread() {