From 00e7fdc07792abdd1cf323c62c5c9bede368ecc0 Mon Sep 17 00:00:00 2001 From: Mateusz Poliwczak Date: Tue, 31 Jan 2023 08:34:13 +0000 Subject: [PATCH] net: move context cancellation logic of blocking calls to a common function. Change-Id: I5f7219a111436e3d6a4685df9461f5a8f8bcb000 GitHub-Last-Rev: e420129bade2681d2b6ce92087ed94444f424810 GitHub-Pull-Request: golang/go#58108 Reviewed-on: https://go-review.googlesource.com/c/go/+/463231 Reviewed-by: Michael Knyszek Run-TryBot: Ian Lance Taylor Auto-Submit: Ian Lance Taylor TryBot-Result: Gopher Robot Reviewed-by: Ian Lance Taylor --- src/net/cgo_unix.go | 128 ++++++++++++++------------------------------ 1 file changed, 41 insertions(+), 87 deletions(-) diff --git a/src/net/cgo_unix.go b/src/net/cgo_unix.go index 38bf20cbb1..0cb71c7d38 100644 --- a/src/net/cgo_unix.go +++ b/src/net/cgo_unix.go @@ -29,20 +29,33 @@ func (eai addrinfoErrno) Error() string { return _C_gai_strerror(_C_int(eai)) func (eai addrinfoErrno) Temporary() bool { return eai == _C_EAI_AGAIN } func (eai addrinfoErrno) Timeout() bool { return false } -type portLookupResult struct { - port int - err error -} +// doBlockingWithCtx executes a blocking function in a separate goroutine when the provided +// context is cancellable. It is intended for use with calls that don't support context +// cancellation (cgo, syscalls). blocking func may still be running after this function finishes. +func doBlockingWithCtx[T any](ctx context.Context, blocking func() (T, error)) (T, error) { + if ctx.Done() == nil { + return blocking() + } -type ipLookupResult struct { - addrs []IPAddr - cname string - err error -} + type result struct { + res T + err error + } -type reverseLookupResult struct { - names []string - err error + res := make(chan result, 1) + go func() { + var r result + r.res, r.err = blocking() + res <- r + }() + + select { + case r := <-res: + return r.res, r.err + case <-ctx.Done(): + var zero T + return zero, mapErr(ctx.Err()) + } } func cgoLookupHost(ctx context.Context, name string) (hosts []string, err error, completed bool) { @@ -72,20 +85,11 @@ func cgoLookupPort(ctx context.Context, network, service string) (port int, err case '6': *_C_ai_family(&hints) = _C_AF_INET6 } - if ctx.Done() == nil { - port, err := cgoLookupServicePort(&hints, network, service) - return port, err, true - } - result := make(chan portLookupResult, 1) - go cgoPortLookup(result, &hints, network, service) - select { - case r := <-result: - return r.port, r.err, true - case <-ctx.Done(): - // Since there isn't a portable way to cancel the lookup, - // we just let it finish and write to the buffered channel. - return 0, mapErr(ctx.Err()), false - } + + port, err = doBlockingWithCtx(ctx, func() (int, error) { + return cgoLookupServicePort(&hints, network, service) + }) + return port, err, true } func cgoLookupServicePort(hints *_C_struct_addrinfo, network, service string) (port int, err error) { @@ -127,11 +131,6 @@ func cgoLookupServicePort(hints *_C_struct_addrinfo, network, service string) (p return 0, &DNSError{Err: "unknown port", Name: network + "/" + service} } -func cgoPortLookup(result chan<- portLookupResult, hints *_C_struct_addrinfo, network, service string) { - port, err := cgoLookupServicePort(hints, network, service) - result <- portLookupResult{port, err} -} - func cgoLookupIPCNAME(network, name string) (addrs []IPAddr, cname string, err error) { acquireThread() defer releaseThread() @@ -206,24 +205,12 @@ func cgoLookupIPCNAME(network, name string) (addrs []IPAddr, cname string, err e return addrs, cname, nil } -func cgoIPLookup(result chan<- ipLookupResult, network, name string) { - addrs, cname, err := cgoLookupIPCNAME(network, name) - result <- ipLookupResult{addrs, cname, err} -} - func cgoLookupIP(ctx context.Context, network, name string) (addrs []IPAddr, err error, completed bool) { - if ctx.Done() == nil { + addrs, err = doBlockingWithCtx(ctx, func() ([]IPAddr, error) { addrs, _, err = cgoLookupIPCNAME(network, name) - return addrs, err, true - } - result := make(chan ipLookupResult, 1) - go cgoIPLookup(result, network, name) - select { - case r := <-result: - return r.addrs, r.err, true - case <-ctx.Done(): - return nil, mapErr(ctx.Err()), true - } + return addrs, err + }) + return addrs, err, true } // These are roughly enough for the following: @@ -252,18 +239,11 @@ func cgoLookupPTR(ctx context.Context, addr string) (names []string, err error, if sa == nil { return nil, &DNSError{Err: "invalid address " + ip.String(), Name: addr}, true } - if ctx.Done() == nil { - names, err := cgoLookupAddrPTR(addr, sa, salen) - return names, err, true - } - result := make(chan reverseLookupResult, 1) - go cgoReverseLookup(result, addr, sa, salen) - select { - case r := <-result: - return r.names, r.err, true - case <-ctx.Done(): - return nil, mapErr(ctx.Err()), true - } + + names, err = doBlockingWithCtx(ctx, func() ([]string, error) { + return cgoLookupAddrPTR(addr, sa, salen) + }) + return names, err, true } func cgoLookupAddrPTR(addr string, sa *_C_struct_sockaddr, salen _C_socklen_t) (names []string, err error) { @@ -301,11 +281,6 @@ func cgoLookupAddrPTR(addr string, sa *_C_struct_sockaddr, salen _C_socklen_t) ( return []string{absDomainName(string(b))}, nil } -func cgoReverseLookup(result chan<- reverseLookupResult, addr string, sa *_C_struct_sockaddr, salen _C_socklen_t) { - names, err := cgoLookupAddrPTR(addr, sa, salen) - result <- reverseLookupResult{names, err} -} - func cgoSockaddr(ip IP, zone string) (*_C_struct_sockaddr, _C_socklen_t) { if ip4 := ip.To4(); ip4 != nil { return cgoSockaddrInet4(ip4), _C_socklen_t(syscall.SizeofSockaddrInet4) @@ -331,30 +306,9 @@ func cgoLookupCNAME(ctx context.Context, name string) (cname string, err error, // resSearch will make a call to the 'res_nsearch' routine in the C library // and parse the output as a slice of DNS resources. func resSearch(ctx context.Context, hostname string, rtype, class int) ([]dnsmessage.Resource, error) { - if ctx.Done() == nil { + return doBlockingWithCtx(ctx, func() ([]dnsmessage.Resource, error) { return cgoResSearch(hostname, rtype, class) - } - - type result struct { - res []dnsmessage.Resource - err error - } - - res := make(chan result, 1) - go func() { - r, err := cgoResSearch(hostname, rtype, class) - res <- result{ - res: r, - err: err, - } - }() - - select { - case res := <-res: - return res.res, res.err - case <-ctx.Done(): - return nil, mapErr(ctx.Err()) - } + }) } func cgoResSearch(hostname string, rtype, class int) ([]dnsmessage.Resource, error) { -- 2.50.0