From 71d84ee7b4e42ea9f35f409f169bebd44a360331 Mon Sep 17 00:00:00 2001 From: Mateusz Poliwczak Date: Sat, 13 Apr 2024 07:01:44 +0000 Subject: [PATCH] net: add Unwrap to *DNSError Fixes #63116 Change-Id: Iab8c415555ab85097be6d2d133b3349c5219a23b GitHub-Last-Rev: 8a8177b9af5509ebbaa701b06c79126aae7510a8 GitHub-Pull-Request: golang/go#63348 Reviewed-on: https://go-review.googlesource.com/c/go/+/532217 Auto-Submit: Ian Lance Taylor Reviewed-by: Damien Neil Reviewed-by: Ian Lance Taylor Commit-Queue: Ian Lance Taylor LUCI-TryBot-Result: Go LUCI --- api/next/63116.txt | 2 + doc/next/6-stdlib/99-minor/net/63116.md | 3 ++ src/net/cgo_unix.go | 30 +++++-------- src/net/dnsclient_unix.go | 60 ++++++++----------------- src/net/dnsclient_unix_test.go | 20 ++++----- src/net/lookup.go | 27 +++-------- src/net/lookup_plan9.go | 4 +- src/net/lookup_test.go | 28 ++++++++++++ src/net/lookup_windows.go | 51 +++++++-------------- src/net/net.go | 53 +++++++++++++++++++++- 10 files changed, 147 insertions(+), 131 deletions(-) create mode 100644 api/next/63116.txt create mode 100644 doc/next/6-stdlib/99-minor/net/63116.md diff --git a/api/next/63116.txt b/api/next/63116.txt new file mode 100644 index 0000000000..47214a9e05 --- /dev/null +++ b/api/next/63116.txt @@ -0,0 +1,2 @@ +pkg net, type DNSError struct, UnwrapErr error #63116 +pkg net, method (*DNSError) Unwrap() error #63116 diff --git a/doc/next/6-stdlib/99-minor/net/63116.md b/doc/next/6-stdlib/99-minor/net/63116.md new file mode 100644 index 0000000000..d847a5545e --- /dev/null +++ b/doc/next/6-stdlib/99-minor/net/63116.md @@ -0,0 +1,3 @@ +The [`DNSError`](/pkg/net#DNSError) type now wraps errors caused by timeouts +or cancelation. For example, `errors.Is(someDNSErr, context.DeadlineExceedeed)` +will now report whether a DNS error was caused by a timeout. diff --git a/src/net/cgo_unix.go b/src/net/cgo_unix.go index 1858e495d2..bc374c2c76 100644 --- a/src/net/cgo_unix.go +++ b/src/net/cgo_unix.go @@ -132,19 +132,17 @@ func cgoLookupServicePort(hints *_C_struct_addrinfo, network, service string) (p var res *_C_struct_addrinfo gerrno, err := _C_getaddrinfo(nil, (*_C_char)(unsafe.Pointer(&cservice[0])), hints, &res) if gerrno != 0 { - isTemporary := false switch gerrno { case _C_EAI_SYSTEM: if err == nil { // see golang.org/issue/6232 err = syscall.EMFILE } + return 0, newDNSError(err, network+"/"+service, "") case _C_EAI_SERVICE, _C_EAI_NONAME: // Darwin returns EAI_NONAME. - return 0, &DNSError{Err: "unknown port", Name: network + "/" + service, IsNotFound: true} + return 0, newDNSError(errUnknownPort, network+"/"+service, "") default: - err = addrinfoErrno(gerrno) - isTemporary = addrinfoErrno(gerrno).Temporary() + return 0, newDNSError(addrinfoErrno(gerrno), network+"/"+service, "") } - return 0, &DNSError{Err: err.Error(), Name: network + "/" + service, IsTemporary: isTemporary} } defer _C_freeaddrinfo(res) @@ -160,7 +158,7 @@ func cgoLookupServicePort(hints *_C_struct_addrinfo, network, service string) (p return int(p[0])<<8 | int(p[1]), nil } } - return 0, &DNSError{Err: "unknown port", Name: network + "/" + service, IsNotFound: true} + return 0, newDNSError(errUnknownPort, network+"/"+service, "") } func cgoLookupHostIP(network, name string) (addrs []IPAddr, err error) { @@ -182,8 +180,6 @@ func cgoLookupHostIP(network, name string) (addrs []IPAddr, err error) { var res *_C_struct_addrinfo gerrno, err := _C_getaddrinfo((*_C_char)(unsafe.Pointer(h)), nil, &hints, &res) if gerrno != 0 { - isErrorNoSuchHost := false - isTemporary := false switch gerrno { case _C_EAI_SYSTEM: if err == nil { @@ -196,15 +192,13 @@ func cgoLookupHostIP(network, name string) (addrs []IPAddr, err error) { // comes up again. golang.org/issue/6232. err = syscall.EMFILE } + return nil, newDNSError(err, name, "") case _C_EAI_NONAME, _C_EAI_NODATA: - err = errNoSuchHost - isErrorNoSuchHost = true + return nil, newDNSError(errNoSuchHost, name, "") default: - err = addrinfoErrno(gerrno) - isTemporary = addrinfoErrno(gerrno).Temporary() + return nil, newDNSError(addrinfoErrno(gerrno), name, "") } - return nil, &DNSError{Err: err.Error(), Name: name, IsNotFound: isErrorNoSuchHost, IsTemporary: isTemporary} } defer _C_freeaddrinfo(res) @@ -272,21 +266,17 @@ func cgoLookupAddrPTR(addr string, sa *_C_struct_sockaddr, salen _C_socklen_t) ( } } if gerrno != 0 { - isErrorNoSuchHost := false - isTemporary := false switch gerrno { case _C_EAI_SYSTEM: if err == nil { // see golang.org/issue/6232 err = syscall.EMFILE } + return nil, newDNSError(err, addr, "") case _C_EAI_NONAME: - err = errNoSuchHost - isErrorNoSuchHost = true + return nil, newDNSError(errNoSuchHost, addr, "") default: - err = addrinfoErrno(gerrno) - isTemporary = addrinfoErrno(gerrno).Temporary() + return nil, newDNSError(addrinfoErrno(gerrno), addr, "") } - return nil, &DNSError{Err: err.Error(), Name: addr, IsTemporary: isTemporary, IsNotFound: isErrorNoSuchHost} } if i := bytealg.IndexByte(b, 0); i != -1 { b = b[:i] diff --git a/src/net/dnsclient_unix.go b/src/net/dnsclient_unix.go index ad5c245dbf..8193189cc7 100644 --- a/src/net/dnsclient_unix.go +++ b/src/net/dnsclient_unix.go @@ -48,7 +48,7 @@ var ( // errServerTemporarilyMisbehaving is like errServerMisbehaving, except // that when it gets translated to a DNSError, the IsTemporary field // gets set to true. - errServerTemporarilyMisbehaving = errors.New("server misbehaving") + errServerTemporarilyMisbehaving = &temporaryError{"server misbehaving"} ) func newRequest(q dnsmessage.Question, ad bool) (id uint16, udpReq, tcpReq []byte, err error) { @@ -292,7 +292,7 @@ func (r *Resolver) tryOneName(ctx context.Context, cfg *dnsConfig, name string, n, err := dnsmessage.NewName(name) if err != nil { - return dnsmessage.Parser{}, "", errCannotMarshalDNSMessage + return dnsmessage.Parser{}, "", &DNSError{Err: errCannotMarshalDNSMessage.Error(), Name: name} } q := dnsmessage.Question{ Name: n, @@ -306,14 +306,7 @@ func (r *Resolver) tryOneName(ctx context.Context, cfg *dnsConfig, name string, p, h, err := r.exchange(ctx, server, q, cfg.timeout, cfg.useTCP, cfg.trustAD) if err != nil { - dnsErr := &DNSError{ - Err: err.Error(), - Name: name, - Server: server, - } - if nerr, ok := err.(Error); ok && nerr.Timeout() { - dnsErr.IsTimeout = true - } + dnsErr := newDNSError(err, name, server) // Set IsTemporary for socket-level errors. Note that this flag // may also be used to indicate a SERVFAIL response. if _, ok := err.(*OpError); ok { @@ -324,41 +317,26 @@ func (r *Resolver) tryOneName(ctx context.Context, cfg *dnsConfig, name string, } if err := checkHeader(&p, h); err != nil { - dnsErr := &DNSError{ - Err: err.Error(), - Name: name, - Server: server, - } - if err == errServerTemporarilyMisbehaving { - dnsErr.IsTemporary = true - } if err == errNoSuchHost { // The name does not exist, so trying // another server won't help. - - dnsErr.IsNotFound = true - return p, server, dnsErr + return p, server, newDNSError(errNoSuchHost, name, server) } - lastErr = dnsErr + lastErr = newDNSError(err, name, server) continue } - err = skipToAnswer(&p, qtype) - if err == nil { - return p, server, nil - } - lastErr = &DNSError{ - Err: err.Error(), - Name: name, - Server: server, + if err := skipToAnswer(&p, qtype); err != nil { + if err == errNoSuchHost { + // The name does not exist, so trying + // another server won't help. + return p, server, newDNSError(errNoSuchHost, name, server) + } + lastErr = newDNSError(err, name, server) + continue } - if err == errNoSuchHost { - // The name does not exist, so trying another - // server won't help. - lastErr.(*DNSError).IsNotFound = true - return p, server, lastErr - } + return p, server, nil } } return dnsmessage.Parser{}, "", lastErr @@ -458,7 +436,7 @@ func (r *Resolver) lookup(ctx context.Context, name string, qtype dnsmessage.Typ // Other lookups might allow broader name syntax // (for example Multicast DNS allows UTF-8; see RFC 6762). // For consistency with libc resolvers, report no such host. - return dnsmessage.Parser{}, "", &DNSError{Err: errNoSuchHost.Error(), Name: name, IsNotFound: true} + return dnsmessage.Parser{}, "", newDNSError(errNoSuchHost, name, "") } if conf == nil { @@ -586,7 +564,7 @@ func (r *Resolver) goLookupHostOrder(ctx context.Context, name string, order hos } if order == hostLookupFiles { - return nil, &DNSError{Err: errNoSuchHost.Error(), Name: name, IsNotFound: true} + return nil, newDNSError(errNoSuchHost, name, "") } } ips, _, err := r.goLookupIPCNAMEOrder(ctx, "ip", name, order, conf) @@ -636,13 +614,13 @@ func (r *Resolver) goLookupIPCNAMEOrder(ctx context.Context, network, name strin } if order == hostLookupFiles { - return nil, dnsmessage.Name{}, &DNSError{Err: errNoSuchHost.Error(), Name: name, IsNotFound: true} + return nil, dnsmessage.Name{}, newDNSError(errNoSuchHost, name, "") } } if !isDomainName(name) { // See comment in func lookup above about use of errNoSuchHost. - return nil, dnsmessage.Name{}, &DNSError{Err: errNoSuchHost.Error(), Name: name, IsNotFound: true} + return nil, dnsmessage.Name{}, newDNSError(errNoSuchHost, name, "") } type result struct { p dnsmessage.Parser @@ -847,7 +825,7 @@ func (r *Resolver) goLookupPTR(ctx context.Context, addr string, order hostLooku } if order == hostLookupFiles { - return nil, &DNSError{Err: errNoSuchHost.Error(), Name: addr, IsNotFound: true} + return nil, newDNSError(errNoSuchHost, addr, "") } } diff --git a/src/net/dnsclient_unix_test.go b/src/net/dnsclient_unix_test.go index 0fad9e94ba..a887485133 100644 --- a/src/net/dnsclient_unix_test.go +++ b/src/net/dnsclient_unix_test.go @@ -25,8 +25,6 @@ import ( "golang.org/x/net/dns/dnsmessage" ) -var goResolver = Resolver{PreferGo: true} - // Test address from 192.0.2.0/24 block, reserved by RFC 5737 for documentation. var TestAddr = [4]byte{0xc0, 0x00, 0x02, 0x01} @@ -1230,10 +1228,11 @@ func TestStrictErrorsLookupIP(t *testing.T) { } makeTimeout := func() error { return &DNSError{ - Err: os.ErrDeadlineExceeded.Error(), - Name: name, - Server: server, - IsTimeout: true, + Err: os.ErrDeadlineExceeded.Error(), + Name: name, + Server: server, + IsTimeout: true, + IsTemporary: true, } } makeNxDomain := func() error { @@ -1486,10 +1485,11 @@ func TestStrictErrorsLookupTXT(t *testing.T) { var wantRRs int if strict { wantErr = &DNSError{ - Err: os.ErrDeadlineExceeded.Error(), - Name: name, - Server: server, - IsTimeout: true, + Err: os.ErrDeadlineExceeded.Error(), + Name: name, + Server: server, + IsTimeout: true, + IsTemporary: true, } } else { wantRRs = 1 diff --git a/src/net/lookup.go b/src/net/lookup.go index 3ec2660786..b04dfa23b9 100644 --- a/src/net/lookup.go +++ b/src/net/lookup.go @@ -105,7 +105,7 @@ func lookupPortMapWithNetwork(network, errNetwork, service string) (port int, er if port, ok := m[string(lowerService[:n])]; ok && n == len(service) { return port, nil } - return 0, &DNSError{Err: "unknown port", Name: errNetwork + "/" + service, IsNotFound: true} + return 0, newDNSError(errUnknownPort, errNetwork+"/"+service, "") } return 0, &DNSError{Err: "unknown network", Name: errNetwork + "/" + service} } @@ -192,7 +192,7 @@ func LookupHost(host string) (addrs []string, err error) { func (r *Resolver) LookupHost(ctx context.Context, host string) (addrs []string, err error) { // Make sure that no matter what we do later, host=="" is rejected. if host == "" { - return nil, &DNSError{Err: errNoSuchHost.Error(), Name: host, IsNotFound: true} + return nil, newDNSError(errNoSuchHost, host, "") } if _, err := netip.ParseAddr(host); err == nil { return []string{host}, nil @@ -236,7 +236,7 @@ func (r *Resolver) LookupIP(ctx context.Context, network, host string) ([]IP, er } if host == "" { - return nil, &DNSError{Err: errNoSuchHost.Error(), Name: host, IsNotFound: true} + return nil, newDNSError(errNoSuchHost, host, "") } addrs, err := r.internetAddrList(ctx, afnet, host) if err != nil { @@ -304,7 +304,7 @@ func withUnexpiredValuesPreserved(lookupCtx context.Context) context.Context { func (r *Resolver) lookupIPAddr(ctx context.Context, network, host string) ([]IPAddr, error) { // Make sure that no matter what we do later, host=="" is rejected. if host == "" { - return nil, &DNSError{Err: errNoSuchHost.Error(), Name: host, IsNotFound: true} + return nil, newDNSError(errNoSuchHost, host, "") } if ip, err := netip.ParseAddr(host); err == nil { return []IPAddr{{IP: IP(ip.AsSlice()).To16(), Zone: ip.Zone()}}, nil @@ -354,12 +354,7 @@ func (r *Resolver) lookupIPAddr(ctx context.Context, network, host string) ([]IP } else { go dnsWaitGroupDone(ch, lookupGroupCancel) } - ctxErr := ctx.Err() - err := &DNSError{ - Err: mapErr(ctxErr).Error(), - Name: host, - IsTimeout: ctxErr == context.DeadlineExceeded, - } + err := newDNSError(mapErr(ctx.Err()), host, "") if trace != nil && trace.DNSDone != nil { trace.DNSDone(nil, false, err) } @@ -370,17 +365,7 @@ func (r *Resolver) lookupIPAddr(ctx context.Context, network, host string) ([]IP err := r.Err if err != nil { if _, ok := err.(*DNSError); !ok { - isTimeout := false - if err == context.DeadlineExceeded { - isTimeout = true - } else if terr, ok := err.(timeout); ok { - isTimeout = terr.Timeout() - } - err = &DNSError{ - Err: err.Error(), - Name: host, - IsTimeout: isTimeout, - } + err = newDNSError(mapErr(err), host, "") } } if trace != nil && trace.DNSDone != nil { diff --git a/src/net/lookup_plan9.go b/src/net/lookup_plan9.go index 8cfc4f6bb3..2532a0e967 100644 --- a/src/net/lookup_plan9.go +++ b/src/net/lookup_plan9.go @@ -170,9 +170,9 @@ func (*Resolver) lookupHost(ctx context.Context, host string) (addrs []string, e lines, err := queryCS(ctx, "net", host, "1") if err != nil { if stringsHasSuffix(err.Error(), "dns failure") { - return nil, &DNSError{Err: errNoSuchHost.Error(), Name: host, IsNotFound: true} + err = errNoSuchHost } - return nil, handlePlan9DNSError(err, host) + return nil, newDNSError(err, host, "") } loop: for _, line := range lines { diff --git a/src/net/lookup_test.go b/src/net/lookup_test.go index b32591a718..bd58498fbc 100644 --- a/src/net/lookup_test.go +++ b/src/net/lookup_test.go @@ -20,6 +20,8 @@ import ( "time" ) +var goResolver = Resolver{PreferGo: true} + func hasSuffixFold(s, suffix string) bool { return strings.HasSuffix(strings.ToLower(s), strings.ToLower(suffix)) } @@ -1630,3 +1632,29 @@ func TestLookupNoSuchHost(t *testing.T) { }) } } + +func TestDNSErrorUnwrap(t *testing.T) { + rDeadlineExcceeded := &Resolver{PreferGo: true, Dial: func(ctx context.Context, network, address string) (Conn, error) { + return nil, context.DeadlineExceeded + }} + rCancelled := &Resolver{PreferGo: true, Dial: func(ctx context.Context, network, address string) (Conn, error) { + return nil, context.Canceled + }} + + _, err := rDeadlineExcceeded.LookupHost(context.Background(), "test.go.dev") + if !errors.Is(err, context.DeadlineExceeded) { + t.Errorf("errors.Is(err, context.DeadlineExceeded) = false; want = true") + } + + _, err = rCancelled.LookupHost(context.Background(), "test.go.dev") + if !errors.Is(err, context.Canceled) { + t.Errorf("errors.Is(err, context.Canceled) = false; want = true") + } + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + _, err = goResolver.LookupHost(ctx, "text.go.dev") + if !errors.Is(err, context.Canceled) { + t.Errorf("errors.Is(err, context.Canceled) = false; want = true") + } +} diff --git a/src/net/lookup_windows.go b/src/net/lookup_windows.go index 946622761c..7d415bee4f 100644 --- a/src/net/lookup_windows.go +++ b/src/net/lookup_windows.go @@ -73,12 +73,7 @@ func lookupProtocol(ctx context.Context, name string) (int, error) { if proto, err := lookupProtocolMap(name); err == nil { return proto, nil } - - dnsError := &DNSError{Err: r.err.Error(), Name: name} - if r.err == errNoSuchHost { - dnsError.IsNotFound = true - } - r.err = dnsError + r.err = newDNSError(r.err, name, "") } return r.proto, r.err case <-ctx.Done(): @@ -130,7 +125,7 @@ func (r *Resolver) lookupIP(ctx context.Context, network, name string) ([]IPAddr var result *syscall.AddrinfoW name16p, err := syscall.UTF16PtrFromString(name) if err != nil { - return nil, &DNSError{Name: name, Err: err.Error()} + return nil, newDNSError(err, name, "") } dnsConf := getSystemDNSConfig() @@ -144,12 +139,7 @@ func (r *Resolver) lookupIP(ctx context.Context, network, name string) ([]IPAddr } } if e != nil { - err := winError("getaddrinfow", e) - dnsError := &DNSError{Err: err.Error(), Name: name} - if err == errNoSuchHost { - dnsError.IsNotFound = true - } - return nil, dnsError + return nil, newDNSError(winError("getaddrinfow", e), name, "") } defer syscall.FreeAddrInfoW(result) addrs := make([]IPAddr, 0, 5) @@ -164,7 +154,7 @@ func (r *Resolver) lookupIP(ctx context.Context, network, name string) ([]IPAddr zone := zoneCache.name(int((*syscall.RawSockaddrInet6)(addr).Scope_id)) addrs = append(addrs, IPAddr{IP: copyIP(a[:]), Zone: zone}) default: - return nil, &DNSError{Err: syscall.EWINDOWS.Error(), Name: name} + return nil, newDNSError(syscall.EWINDOWS, name, "") } } return addrs, nil @@ -196,11 +186,7 @@ func (r *Resolver) lookupIP(ctx context.Context, network, name string) ([]IPAddr // // For now we just let it finish and write to the // buffered channel. - return nil, &DNSError{ - Name: name, - Err: ctx.Err().Error(), - IsTimeout: ctx.Err() == context.DeadlineExceeded, - } + return nil, newDNSError(mapErr(ctx.Err()), name, "") } } @@ -252,14 +238,13 @@ func (r *Resolver) lookupPort(ctx context.Context, network, service string) (int // for _WSAHOST_NOT_FOUND here to match the cgo (unix) version // cgo_unix.go (cgoLookupServicePort). if e == _WSATYPE_NOT_FOUND || e == _WSAHOST_NOT_FOUND { - return 0, &DNSError{Err: "unknown port", Name: network + "/" + service, IsNotFound: true} + return 0, newDNSError(errUnknownPort, network+"/"+service, "") } - err := os.NewSyscallError("getaddrinfow", e) - return 0, &DNSError{Err: err.Error(), Name: network + "/" + service} + return 0, newDNSError(winError("getaddrinfow", e), network+"/"+service, "") } defer syscall.FreeAddrInfoW(result) if result == nil { - return 0, &DNSError{Err: syscall.EINVAL.Error(), Name: network + "/" + service} + return 0, newDNSError(syscall.EINVAL, network+"/"+service, "") } addr := unsafe.Pointer(result.Addr) switch result.Family { @@ -270,7 +255,7 @@ func (r *Resolver) lookupPort(ctx context.Context, network, service string) (int a := (*syscall.RawSockaddrInet6)(addr) return int(syscall.Ntohs(a.Port)), nil } - return 0, &DNSError{Err: syscall.EINVAL.Error(), Name: network + "/" + service} + return 0, newDNSError(syscall.EINVAL, network+"/"+service, "") } func (r *Resolver) lookupCNAME(ctx context.Context, name string) (string, error) { @@ -295,8 +280,7 @@ func (r *Resolver) lookupCNAME(ctx context.Context, name string) (string, error) return absDomainName(name), nil } if e != nil { - err := winError("dnsquery", e) - return "", &DNSError{Err: err.Error(), Name: name, IsNotFound: err == errNoSuchHost} + return "", newDNSError(winError("dnsquery", e), name, "") } defer syscall.DnsRecordListFree(rec, 1) @@ -327,8 +311,7 @@ func (r *Resolver) lookupSRV(ctx context.Context, service, proto, name string) ( var rec *syscall.DNSRecord e := syscall.DnsQuery(target, syscall.DNS_TYPE_SRV, 0, nil, &rec, nil) if e != nil { - err := winError("dnsquery", e) - return "", nil, &DNSError{Err: err.Error(), Name: name, IsNotFound: err == errNoSuchHost} + return "", nil, newDNSError(winError("dnsquery", e), name, "") } defer syscall.DnsRecordListFree(rec, 1) @@ -357,8 +340,7 @@ func (r *Resolver) lookupMX(ctx context.Context, name string) ([]*MX, error) { var rec *syscall.DNSRecord e := syscall.DnsQuery(name, syscall.DNS_TYPE_MX, 0, nil, &rec, nil) if e != nil { - err := winError("dnsquery", e) - return nil, &DNSError{Err: err.Error(), Name: name, IsNotFound: err == errNoSuchHost} + return nil, newDNSError(winError("dnsquery", e), name, "") } defer syscall.DnsRecordListFree(rec, 1) @@ -387,8 +369,7 @@ func (r *Resolver) lookupNS(ctx context.Context, name string) ([]*NS, error) { var rec *syscall.DNSRecord e := syscall.DnsQuery(name, syscall.DNS_TYPE_NS, 0, nil, &rec, nil) if e != nil { - err := winError("dnsquery", e) - return nil, &DNSError{Err: err.Error(), Name: name, IsNotFound: err == errNoSuchHost} + return nil, newDNSError(winError("dnsquery", e), name, "") } defer syscall.DnsRecordListFree(rec, 1) @@ -416,8 +397,7 @@ func (r *Resolver) lookupTXT(ctx context.Context, name string) ([]string, error) var rec *syscall.DNSRecord e := syscall.DnsQuery(name, syscall.DNS_TYPE_TEXT, 0, nil, &rec, nil) if e != nil { - err := winError("dnsquery", e) - return nil, &DNSError{Err: err.Error(), Name: name, IsNotFound: err == errNoSuchHost} + return nil, newDNSError(winError("dnsquery", e), name, "") } defer syscall.DnsRecordListFree(rec, 1) @@ -454,8 +434,7 @@ func (r *Resolver) lookupAddr(ctx context.Context, addr string) ([]string, error var rec *syscall.DNSRecord e := syscall.DnsQuery(arpa, syscall.DNS_TYPE_PTR, 0, nil, &rec, nil) if e != nil { - err := winError("dnsquery", e) - return nil, &DNSError{Err: err.Error(), Name: addr, IsNotFound: err == errNoSuchHost} + return nil, newDNSError(winError("dnsquery", e), addr, "") } defer syscall.DnsRecordListFree(rec, 1) diff --git a/src/net/net.go b/src/net/net.go index d0db65286b..deaeea4081 100644 --- a/src/net/net.go +++ b/src/net/net.go @@ -617,11 +617,27 @@ func (e *DNSConfigError) Temporary() bool { return false } // Various errors contained in DNSError. var ( - errNoSuchHost = errors.New("no such host") + errNoSuchHost = ¬FoundError{"no such host"} + errUnknownPort = ¬FoundError{"unknown port"} ) +// notFoundError is a special error understood by the newDNSError function, +// which causes a creation of a DNSError with IsNotFound field set to true. +type notFoundError struct{ s string } + +func (e *notFoundError) Error() string { return e.s } + +// temporaryError is an error type that implements the [Error] interface. +// It returns true from the Temporary method. +type temporaryError struct{ s string } + +func (e *temporaryError) Error() string { return e.s } +func (e *temporaryError) Temporary() bool { return true } +func (e *temporaryError) Timeout() bool { return false } + // DNSError represents a DNS lookup error. type DNSError struct { + UnwrapErr error // error returned by the [DNSError.Unwrap] method, might be nil Err string // description of the error Name string // name looked for Server string // server used @@ -634,6 +650,41 @@ type DNSError struct { IsNotFound bool } +// newDNSError creates a new *DNSError. +// Based on the err, it sets the UnwrapErr, IsTimeout, IsTemporary, IsNotFound fields. +func newDNSError(err error, name, server string) *DNSError { + var ( + isTimeout bool + isTemporary bool + unwrapErr error + ) + + if err, ok := err.(Error); ok { + isTimeout = err.Timeout() + isTemporary = err.Temporary() + } + + // At this time, the only errors we wrap are context errors, to allow + // users to check for canceled/timed out requests. + if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) { + unwrapErr = err + } + + _, isNotFound := err.(*notFoundError) + return &DNSError{ + UnwrapErr: unwrapErr, + Err: err.Error(), + Name: name, + Server: server, + IsTimeout: isTimeout, + IsTemporary: isTemporary, + IsNotFound: isNotFound, + } +} + +// Unwrap returns e.UnwrapErr. +func (e *DNSError) Unwrap() error { return e.UnwrapErr } + func (e *DNSError) Error() string { if e == nil { return "" -- 2.48.1