From: Mateusz Poliwczak Date: Fri, 10 Feb 2023 09:01:29 +0000 (+0000) Subject: net: move context cancellation logic of blocking calls to a common function X-Git-Tag: go1.21rc1~1583 X-Git-Url: http://www.git.cypherpunks.su/?a=commitdiff_plain;h=88ad44766f8db22a53ea5f3a946a57cdf26b818e;p=gostls13.git net: move context cancellation logic of blocking calls to a common function Follow up of CL 464375 which reverted the CL 463231, because of a data race. Change-Id: I1a52f23a68a6981b902fc59bda1437bd169ca22b GitHub-Last-Rev: 0157bd01807a731239f3f2940d440e798be33d83 GitHub-Pull-Request: golang/go#58383 Reviewed-on: https://go-review.googlesource.com/c/go/+/465836 Reviewed-by: Ian Lance Taylor TryBot-Result: Gopher Robot Reviewed-by: Damien Neil Run-TryBot: Ian Lance Taylor Auto-Submit: Ian Lance Taylor --- diff --git a/src/net/cgo_unix.go b/src/net/cgo_unix.go index b90b579ffc..e378a87ba3 100644 --- a/src/net/cgo_unix.go +++ b/src/net/cgo_unix.go @@ -30,19 +30,33 @@ func (eai addrinfoErrno) Error() string { return _C_gai_strerror(_C_int(eai)) func (eai addrinfoErrno) Temporary() bool { return eai == _C_EAI_AGAIN } func (eai addrinfoErrno) Timeout() bool { return false } -type portLookupResult struct { - port int - err error -} +// doBlockingWithCtx executes a blocking function in a separate goroutine when the provided +// context is cancellable. It is intended for use with calls that don't support context +// cancellation (cgo, syscalls). blocking func may still be running after this function finishes. +func doBlockingWithCtx[T any](ctx context.Context, blocking func() (T, error)) (T, error) { + if ctx.Done() == nil { + return blocking() + } -type ipLookupResult struct { - addrs []IPAddr - err error -} + type result struct { + res T + err error + } -type reverseLookupResult struct { - names []string - err error + res := make(chan result, 1) + go func() { + var r result + r.res, r.err = blocking() + res <- r + }() + + select { + case r := <-res: + return r.res, r.err + case <-ctx.Done(): + var zero T + return zero, mapErr(ctx.Err()) + } } func cgoLookupHost(ctx context.Context, name string) (hosts []string, err error, completed bool) { @@ -72,20 +86,11 @@ func cgoLookupPort(ctx context.Context, network, service string) (port int, err case '6': *_C_ai_family(&hints) = _C_AF_INET6 } - if ctx.Done() == nil { - port, err := cgoLookupServicePort(&hints, network, service) - return port, err, true - } - result := make(chan portLookupResult, 1) - go cgoPortLookup(result, &hints, network, service) - select { - case r := <-result: - return r.port, r.err, true - case <-ctx.Done(): - // Since there isn't a portable way to cancel the lookup, - // we just let it finish and write to the buffered channel. - return 0, mapErr(ctx.Err()), false - } + + port, err = doBlockingWithCtx(ctx, func() (int, error) { + return cgoLookupServicePort(&hints, network, service) + }) + return port, err, true } func cgoLookupServicePort(hints *_C_struct_addrinfo, network, service string) (port int, err error) { @@ -127,11 +132,6 @@ func cgoLookupServicePort(hints *_C_struct_addrinfo, network, service string) (p return 0, &DNSError{Err: "unknown port", Name: network + "/" + service} } -func cgoPortLookup(result chan<- portLookupResult, hints *_C_struct_addrinfo, network, service string) { - port, err := cgoLookupServicePort(hints, network, service) - result <- portLookupResult{port, err} -} - func cgoLookupHostIP(network, name string) (addrs []IPAddr, err error) { acquireThread() defer releaseThread() @@ -197,24 +197,11 @@ func cgoLookupHostIP(network, name string) (addrs []IPAddr, err error) { return addrs, nil } -func cgoIPLookup(result chan<- ipLookupResult, network, name string) { - addrs, err := cgoLookupHostIP(network, name) - result <- ipLookupResult{addrs, err} -} - func cgoLookupIP(ctx context.Context, network, name string) (addrs []IPAddr, err error, completed bool) { - if ctx.Done() == nil { - addrs, err = cgoLookupHostIP(network, name) - return addrs, err, true - } - result := make(chan ipLookupResult, 1) - go cgoIPLookup(result, network, name) - select { - case r := <-result: - return r.addrs, r.err, true - case <-ctx.Done(): - return nil, mapErr(ctx.Err()), true - } + addrs, err = doBlockingWithCtx(ctx, func() ([]IPAddr, error) { + return cgoLookupHostIP(network, name) + }) + return addrs, err, true } // These are roughly enough for the following: @@ -239,18 +226,11 @@ func cgoLookupPTR(ctx context.Context, addr string) (names []string, err error, if sa == nil { return nil, &DNSError{Err: "invalid address " + ip.String(), Name: addr}, true } - if ctx.Done() == nil { - names, err := cgoLookupAddrPTR(addr, sa, salen) - return names, err, true - } - result := make(chan reverseLookupResult, 1) - go cgoReverseLookup(result, addr, sa, salen) - select { - case r := <-result: - return r.names, r.err, true - case <-ctx.Done(): - return nil, mapErr(ctx.Err()), true - } + + names, err = doBlockingWithCtx(ctx, func() ([]string, error) { + return cgoLookupAddrPTR(addr, sa, salen) + }) + return names, err, true } func cgoLookupAddrPTR(addr string, sa *_C_struct_sockaddr, salen _C_socklen_t) (names []string, err error) { @@ -292,11 +272,6 @@ func cgoLookupAddrPTR(addr string, sa *_C_struct_sockaddr, salen _C_socklen_t) ( return []string{absDomainName(string(b))}, nil } -func cgoReverseLookup(result chan<- reverseLookupResult, addr string, sa *_C_struct_sockaddr, salen _C_socklen_t) { - names, err := cgoLookupAddrPTR(addr, sa, salen) - result <- reverseLookupResult{names, err} -} - func cgoSockaddr(ip IP, zone string) (*_C_struct_sockaddr, _C_socklen_t) { if ip4 := ip.To4(); ip4 != nil { return cgoSockaddrInet4(ip4), _C_socklen_t(syscall.SizeofSockaddrInet4) @@ -322,30 +297,9 @@ func cgoLookupCNAME(ctx context.Context, name string) (cname string, err error, // resSearch will make a call to the 'res_nsearch' routine in the C library // and parse the output as a slice of DNS resources. func resSearch(ctx context.Context, hostname string, rtype, class int) ([]dnsmessage.Resource, error) { - if ctx.Done() == nil { + return doBlockingWithCtx(ctx, func() ([]dnsmessage.Resource, error) { return cgoResSearch(hostname, rtype, class) - } - - type result struct { - res []dnsmessage.Resource - err error - } - - res := make(chan result, 1) - go func() { - r, err := cgoResSearch(hostname, rtype, class) - res <- result{ - res: r, - err: err, - } - }() - - select { - case res := <-res: - return res.res, res.err - case <-ctx.Done(): - return nil, mapErr(ctx.Err()) - } + }) } func cgoResSearch(hostname string, rtype, class int) ([]dnsmessage.Resource, error) {