]> Cypherpunks repositories - gostls13.git/commitdiff
net: add Unwrap to *DNSError
authorMateusz Poliwczak <mpoliwczak34@gmail.com>
Sat, 13 Apr 2024 07:01:44 +0000 (07:01 +0000)
committerGopher Robot <gobot@golang.org>
Sun, 14 Apr 2024 18:23:45 +0000 (18:23 +0000)
Fixes #63116

Change-Id: Iab8c415555ab85097be6d2d133b3349c5219a23b
GitHub-Last-Rev: 8a8177b9af5509ebbaa701b06c79126aae7510a8
GitHub-Pull-Request: golang/go#63348
Reviewed-on: https://go-review.googlesource.com/c/go/+/532217
Auto-Submit: Ian Lance Taylor <iant@google.com>
Reviewed-by: Damien Neil <dneil@google.com>
Reviewed-by: Ian Lance Taylor <iant@google.com>
Commit-Queue: Ian Lance Taylor <iant@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>

api/next/63116.txt [new file with mode: 0644]
doc/next/6-stdlib/99-minor/net/63116.md [new file with mode: 0644]
src/net/cgo_unix.go
src/net/dnsclient_unix.go
src/net/dnsclient_unix_test.go
src/net/lookup.go
src/net/lookup_plan9.go
src/net/lookup_test.go
src/net/lookup_windows.go
src/net/net.go

diff --git a/api/next/63116.txt b/api/next/63116.txt
new file mode 100644 (file)
index 0000000..47214a9
--- /dev/null
@@ -0,0 +1,2 @@
+pkg net, type DNSError struct, UnwrapErr error #63116
+pkg net, method (*DNSError) Unwrap() error #63116
diff --git a/doc/next/6-stdlib/99-minor/net/63116.md b/doc/next/6-stdlib/99-minor/net/63116.md
new file mode 100644 (file)
index 0000000..d847a55
--- /dev/null
@@ -0,0 +1,3 @@
+The [`DNSError`](/pkg/net#DNSError) type now wraps errors caused by timeouts
+or cancelation. For example, `errors.Is(someDNSErr, context.DeadlineExceedeed)`
+will now report whether a DNS error was caused by a timeout.
index 1858e495d22238c8c2a45b2bac8f0d315098a729..bc374c2c7688b4e2c143b7a7ea2e706a69bd9e30 100644 (file)
@@ -132,19 +132,17 @@ func cgoLookupServicePort(hints *_C_struct_addrinfo, network, service string) (p
        var res *_C_struct_addrinfo
        gerrno, err := _C_getaddrinfo(nil, (*_C_char)(unsafe.Pointer(&cservice[0])), hints, &res)
        if gerrno != 0 {
-               isTemporary := false
                switch gerrno {
                case _C_EAI_SYSTEM:
                        if err == nil { // see golang.org/issue/6232
                                err = syscall.EMFILE
                        }
+                       return 0, newDNSError(err, network+"/"+service, "")
                case _C_EAI_SERVICE, _C_EAI_NONAME: // Darwin returns EAI_NONAME.
-                       return 0, &DNSError{Err: "unknown port", Name: network + "/" + service, IsNotFound: true}
+                       return 0, newDNSError(errUnknownPort, network+"/"+service, "")
                default:
-                       err = addrinfoErrno(gerrno)
-                       isTemporary = addrinfoErrno(gerrno).Temporary()
+                       return 0, newDNSError(addrinfoErrno(gerrno), network+"/"+service, "")
                }
-               return 0, &DNSError{Err: err.Error(), Name: network + "/" + service, IsTemporary: isTemporary}
        }
        defer _C_freeaddrinfo(res)
 
@@ -160,7 +158,7 @@ func cgoLookupServicePort(hints *_C_struct_addrinfo, network, service string) (p
                        return int(p[0])<<8 | int(p[1]), nil
                }
        }
-       return 0, &DNSError{Err: "unknown port", Name: network + "/" + service, IsNotFound: true}
+       return 0, newDNSError(errUnknownPort, network+"/"+service, "")
 }
 
 func cgoLookupHostIP(network, name string) (addrs []IPAddr, err error) {
@@ -182,8 +180,6 @@ func cgoLookupHostIP(network, name string) (addrs []IPAddr, err error) {
        var res *_C_struct_addrinfo
        gerrno, err := _C_getaddrinfo((*_C_char)(unsafe.Pointer(h)), nil, &hints, &res)
        if gerrno != 0 {
-               isErrorNoSuchHost := false
-               isTemporary := false
                switch gerrno {
                case _C_EAI_SYSTEM:
                        if err == nil {
@@ -196,15 +192,13 @@ func cgoLookupHostIP(network, name string) (addrs []IPAddr, err error) {
                                // comes up again. golang.org/issue/6232.
                                err = syscall.EMFILE
                        }
+                       return nil, newDNSError(err, name, "")
                case _C_EAI_NONAME, _C_EAI_NODATA:
-                       err = errNoSuchHost
-                       isErrorNoSuchHost = true
+                       return nil, newDNSError(errNoSuchHost, name, "")
                default:
-                       err = addrinfoErrno(gerrno)
-                       isTemporary = addrinfoErrno(gerrno).Temporary()
+                       return nil, newDNSError(addrinfoErrno(gerrno), name, "")
                }
 
-               return nil, &DNSError{Err: err.Error(), Name: name, IsNotFound: isErrorNoSuchHost, IsTemporary: isTemporary}
        }
        defer _C_freeaddrinfo(res)
 
@@ -272,21 +266,17 @@ func cgoLookupAddrPTR(addr string, sa *_C_struct_sockaddr, salen _C_socklen_t) (
                }
        }
        if gerrno != 0 {
-               isErrorNoSuchHost := false
-               isTemporary := false
                switch gerrno {
                case _C_EAI_SYSTEM:
                        if err == nil { // see golang.org/issue/6232
                                err = syscall.EMFILE
                        }
+                       return nil, newDNSError(err, addr, "")
                case _C_EAI_NONAME:
-                       err = errNoSuchHost
-                       isErrorNoSuchHost = true
+                       return nil, newDNSError(errNoSuchHost, addr, "")
                default:
-                       err = addrinfoErrno(gerrno)
-                       isTemporary = addrinfoErrno(gerrno).Temporary()
+                       return nil, newDNSError(addrinfoErrno(gerrno), addr, "")
                }
-               return nil, &DNSError{Err: err.Error(), Name: addr, IsTemporary: isTemporary, IsNotFound: isErrorNoSuchHost}
        }
        if i := bytealg.IndexByte(b, 0); i != -1 {
                b = b[:i]
index ad5c245dbf74b3167f723381bd87b9d6f9f5bd66..8193189cc76f9dfffa56712b8cc4ab7dd3e87a51 100644 (file)
@@ -48,7 +48,7 @@ var (
        // errServerTemporarilyMisbehaving is like errServerMisbehaving, except
        // that when it gets translated to a DNSError, the IsTemporary field
        // gets set to true.
-       errServerTemporarilyMisbehaving = errors.New("server misbehaving")
+       errServerTemporarilyMisbehaving = &temporaryError{"server misbehaving"}
 )
 
 func newRequest(q dnsmessage.Question, ad bool) (id uint16, udpReq, tcpReq []byte, err error) {
@@ -292,7 +292,7 @@ func (r *Resolver) tryOneName(ctx context.Context, cfg *dnsConfig, name string,
 
        n, err := dnsmessage.NewName(name)
        if err != nil {
-               return dnsmessage.Parser{}, "", errCannotMarshalDNSMessage
+               return dnsmessage.Parser{}, "", &DNSError{Err: errCannotMarshalDNSMessage.Error(), Name: name}
        }
        q := dnsmessage.Question{
                Name:  n,
@@ -306,14 +306,7 @@ func (r *Resolver) tryOneName(ctx context.Context, cfg *dnsConfig, name string,
 
                        p, h, err := r.exchange(ctx, server, q, cfg.timeout, cfg.useTCP, cfg.trustAD)
                        if err != nil {
-                               dnsErr := &DNSError{
-                                       Err:    err.Error(),
-                                       Name:   name,
-                                       Server: server,
-                               }
-                               if nerr, ok := err.(Error); ok && nerr.Timeout() {
-                                       dnsErr.IsTimeout = true
-                               }
+                               dnsErr := newDNSError(err, name, server)
                                // Set IsTemporary for socket-level errors. Note that this flag
                                // may also be used to indicate a SERVFAIL response.
                                if _, ok := err.(*OpError); ok {
@@ -324,41 +317,26 @@ func (r *Resolver) tryOneName(ctx context.Context, cfg *dnsConfig, name string,
                        }
 
                        if err := checkHeader(&p, h); err != nil {
-                               dnsErr := &DNSError{
-                                       Err:    err.Error(),
-                                       Name:   name,
-                                       Server: server,
-                               }
-                               if err == errServerTemporarilyMisbehaving {
-                                       dnsErr.IsTemporary = true
-                               }
                                if err == errNoSuchHost {
                                        // The name does not exist, so trying
                                        // another server won't help.
-
-                                       dnsErr.IsNotFound = true
-                                       return p, server, dnsErr
+                                       return p, server, newDNSError(errNoSuchHost, name, server)
                                }
-                               lastErr = dnsErr
+                               lastErr = newDNSError(err, name, server)
                                continue
                        }
 
-                       err = skipToAnswer(&p, qtype)
-                       if err == nil {
-                               return p, server, nil
-                       }
-                       lastErr = &DNSError{
-                               Err:    err.Error(),
-                               Name:   name,
-                               Server: server,
+                       if err := skipToAnswer(&p, qtype); err != nil {
+                               if err == errNoSuchHost {
+                                       // The name does not exist, so trying
+                                       // another server won't help.
+                                       return p, server, newDNSError(errNoSuchHost, name, server)
+                               }
+                               lastErr = newDNSError(err, name, server)
+                               continue
                        }
-                       if err == errNoSuchHost {
-                               // The name does not exist, so trying another
-                               // server won't help.
 
-                               lastErr.(*DNSError).IsNotFound = true
-                               return p, server, lastErr
-                       }
+                       return p, server, nil
                }
        }
        return dnsmessage.Parser{}, "", lastErr
@@ -458,7 +436,7 @@ func (r *Resolver) lookup(ctx context.Context, name string, qtype dnsmessage.Typ
                // Other lookups might allow broader name syntax
                // (for example Multicast DNS allows UTF-8; see RFC 6762).
                // For consistency with libc resolvers, report no such host.
-               return dnsmessage.Parser{}, "", &DNSError{Err: errNoSuchHost.Error(), Name: name, IsNotFound: true}
+               return dnsmessage.Parser{}, "", newDNSError(errNoSuchHost, name, "")
        }
 
        if conf == nil {
@@ -586,7 +564,7 @@ func (r *Resolver) goLookupHostOrder(ctx context.Context, name string, order hos
                }
 
                if order == hostLookupFiles {
-                       return nil, &DNSError{Err: errNoSuchHost.Error(), Name: name, IsNotFound: true}
+                       return nil, newDNSError(errNoSuchHost, name, "")
                }
        }
        ips, _, err := r.goLookupIPCNAMEOrder(ctx, "ip", name, order, conf)
@@ -636,13 +614,13 @@ func (r *Resolver) goLookupIPCNAMEOrder(ctx context.Context, network, name strin
                }
 
                if order == hostLookupFiles {
-                       return nil, dnsmessage.Name{}, &DNSError{Err: errNoSuchHost.Error(), Name: name, IsNotFound: true}
+                       return nil, dnsmessage.Name{}, newDNSError(errNoSuchHost, name, "")
                }
        }
 
        if !isDomainName(name) {
                // See comment in func lookup above about use of errNoSuchHost.
-               return nil, dnsmessage.Name{}, &DNSError{Err: errNoSuchHost.Error(), Name: name, IsNotFound: true}
+               return nil, dnsmessage.Name{}, newDNSError(errNoSuchHost, name, "")
        }
        type result struct {
                p      dnsmessage.Parser
@@ -847,7 +825,7 @@ func (r *Resolver) goLookupPTR(ctx context.Context, addr string, order hostLooku
                }
 
                if order == hostLookupFiles {
-                       return nil, &DNSError{Err: errNoSuchHost.Error(), Name: addr, IsNotFound: true}
+                       return nil, newDNSError(errNoSuchHost, addr, "")
                }
        }
 
index 0fad9e94ba19dfb250b7f464dc9459c7b2c4b17c..a8874851335d83d888a87d40c5ee6d43d7bc6b7e 100644 (file)
@@ -25,8 +25,6 @@ import (
        "golang.org/x/net/dns/dnsmessage"
 )
 
-var goResolver = Resolver{PreferGo: true}
-
 // Test address from 192.0.2.0/24 block, reserved by RFC 5737 for documentation.
 var TestAddr = [4]byte{0xc0, 0x00, 0x02, 0x01}
 
@@ -1230,10 +1228,11 @@ func TestStrictErrorsLookupIP(t *testing.T) {
        }
        makeTimeout := func() error {
                return &DNSError{
-                       Err:       os.ErrDeadlineExceeded.Error(),
-                       Name:      name,
-                       Server:    server,
-                       IsTimeout: true,
+                       Err:         os.ErrDeadlineExceeded.Error(),
+                       Name:        name,
+                       Server:      server,
+                       IsTimeout:   true,
+                       IsTemporary: true,
                }
        }
        makeNxDomain := func() error {
@@ -1486,10 +1485,11 @@ func TestStrictErrorsLookupTXT(t *testing.T) {
                var wantRRs int
                if strict {
                        wantErr = &DNSError{
-                               Err:       os.ErrDeadlineExceeded.Error(),
-                               Name:      name,
-                               Server:    server,
-                               IsTimeout: true,
+                               Err:         os.ErrDeadlineExceeded.Error(),
+                               Name:        name,
+                               Server:      server,
+                               IsTimeout:   true,
+                               IsTemporary: true,
                        }
                } else {
                        wantRRs = 1
index 3ec2660786094ef3360b108e8a510fd194b95c06..b04dfa23b9877b0b4afb5967fb764bc903b719fb 100644 (file)
@@ -105,7 +105,7 @@ func lookupPortMapWithNetwork(network, errNetwork, service string) (port int, er
                if port, ok := m[string(lowerService[:n])]; ok && n == len(service) {
                        return port, nil
                }
-               return 0, &DNSError{Err: "unknown port", Name: errNetwork + "/" + service, IsNotFound: true}
+               return 0, newDNSError(errUnknownPort, errNetwork+"/"+service, "")
        }
        return 0, &DNSError{Err: "unknown network", Name: errNetwork + "/" + service}
 }
@@ -192,7 +192,7 @@ func LookupHost(host string) (addrs []string, err error) {
 func (r *Resolver) LookupHost(ctx context.Context, host string) (addrs []string, err error) {
        // Make sure that no matter what we do later, host=="" is rejected.
        if host == "" {
-               return nil, &DNSError{Err: errNoSuchHost.Error(), Name: host, IsNotFound: true}
+               return nil, newDNSError(errNoSuchHost, host, "")
        }
        if _, err := netip.ParseAddr(host); err == nil {
                return []string{host}, nil
@@ -236,7 +236,7 @@ func (r *Resolver) LookupIP(ctx context.Context, network, host string) ([]IP, er
        }
 
        if host == "" {
-               return nil, &DNSError{Err: errNoSuchHost.Error(), Name: host, IsNotFound: true}
+               return nil, newDNSError(errNoSuchHost, host, "")
        }
        addrs, err := r.internetAddrList(ctx, afnet, host)
        if err != nil {
@@ -304,7 +304,7 @@ func withUnexpiredValuesPreserved(lookupCtx context.Context) context.Context {
 func (r *Resolver) lookupIPAddr(ctx context.Context, network, host string) ([]IPAddr, error) {
        // Make sure that no matter what we do later, host=="" is rejected.
        if host == "" {
-               return nil, &DNSError{Err: errNoSuchHost.Error(), Name: host, IsNotFound: true}
+               return nil, newDNSError(errNoSuchHost, host, "")
        }
        if ip, err := netip.ParseAddr(host); err == nil {
                return []IPAddr{{IP: IP(ip.AsSlice()).To16(), Zone: ip.Zone()}}, nil
@@ -354,12 +354,7 @@ func (r *Resolver) lookupIPAddr(ctx context.Context, network, host string) ([]IP
                } else {
                        go dnsWaitGroupDone(ch, lookupGroupCancel)
                }
-               ctxErr := ctx.Err()
-               err := &DNSError{
-                       Err:       mapErr(ctxErr).Error(),
-                       Name:      host,
-                       IsTimeout: ctxErr == context.DeadlineExceeded,
-               }
+               err := newDNSError(mapErr(ctx.Err()), host, "")
                if trace != nil && trace.DNSDone != nil {
                        trace.DNSDone(nil, false, err)
                }
@@ -370,17 +365,7 @@ func (r *Resolver) lookupIPAddr(ctx context.Context, network, host string) ([]IP
                err := r.Err
                if err != nil {
                        if _, ok := err.(*DNSError); !ok {
-                               isTimeout := false
-                               if err == context.DeadlineExceeded {
-                                       isTimeout = true
-                               } else if terr, ok := err.(timeout); ok {
-                                       isTimeout = terr.Timeout()
-                               }
-                               err = &DNSError{
-                                       Err:       err.Error(),
-                                       Name:      host,
-                                       IsTimeout: isTimeout,
-                               }
+                               err = newDNSError(mapErr(err), host, "")
                        }
                }
                if trace != nil && trace.DNSDone != nil {
index 8cfc4f6bb3b02d3d99d073ed850be6b84483666f..2532a0e967b25838ebcff5560fffb83bc43559f8 100644 (file)
@@ -170,9 +170,9 @@ func (*Resolver) lookupHost(ctx context.Context, host string) (addrs []string, e
        lines, err := queryCS(ctx, "net", host, "1")
        if err != nil {
                if stringsHasSuffix(err.Error(), "dns failure") {
-                       return nil, &DNSError{Err: errNoSuchHost.Error(), Name: host, IsNotFound: true}
+                       err = errNoSuchHost
                }
-               return nil, handlePlan9DNSError(err, host)
+               return nil, newDNSError(err, host, "")
        }
 loop:
        for _, line := range lines {
index b32591a718479d46a71c14c930329489bd03cc35..bd58498fbc7224124761552c93b758baa2091060 100644 (file)
@@ -20,6 +20,8 @@ import (
        "time"
 )
 
+var goResolver = Resolver{PreferGo: true}
+
 func hasSuffixFold(s, suffix string) bool {
        return strings.HasSuffix(strings.ToLower(s), strings.ToLower(suffix))
 }
@@ -1630,3 +1632,29 @@ func TestLookupNoSuchHost(t *testing.T) {
                })
        }
 }
+
+func TestDNSErrorUnwrap(t *testing.T) {
+       rDeadlineExcceeded := &Resolver{PreferGo: true, Dial: func(ctx context.Context, network, address string) (Conn, error) {
+               return nil, context.DeadlineExceeded
+       }}
+       rCancelled := &Resolver{PreferGo: true, Dial: func(ctx context.Context, network, address string) (Conn, error) {
+               return nil, context.Canceled
+       }}
+
+       _, err := rDeadlineExcceeded.LookupHost(context.Background(), "test.go.dev")
+       if !errors.Is(err, context.DeadlineExceeded) {
+               t.Errorf("errors.Is(err, context.DeadlineExceeded) = false; want = true")
+       }
+
+       _, err = rCancelled.LookupHost(context.Background(), "test.go.dev")
+       if !errors.Is(err, context.Canceled) {
+               t.Errorf("errors.Is(err, context.Canceled) = false; want = true")
+       }
+
+       ctx, cancel := context.WithCancel(context.Background())
+       cancel()
+       _, err = goResolver.LookupHost(ctx, "text.go.dev")
+       if !errors.Is(err, context.Canceled) {
+               t.Errorf("errors.Is(err, context.Canceled) = false; want = true")
+       }
+}
index 946622761cf202d50539de8464bea63e4483b09b..7d415bee4f07b66f1a66e0d16858d1c0097a7a30 100644 (file)
@@ -73,12 +73,7 @@ func lookupProtocol(ctx context.Context, name string) (int, error) {
                        if proto, err := lookupProtocolMap(name); err == nil {
                                return proto, nil
                        }
-
-                       dnsError := &DNSError{Err: r.err.Error(), Name: name}
-                       if r.err == errNoSuchHost {
-                               dnsError.IsNotFound = true
-                       }
-                       r.err = dnsError
+                       r.err = newDNSError(r.err, name, "")
                }
                return r.proto, r.err
        case <-ctx.Done():
@@ -130,7 +125,7 @@ func (r *Resolver) lookupIP(ctx context.Context, network, name string) ([]IPAddr
                var result *syscall.AddrinfoW
                name16p, err := syscall.UTF16PtrFromString(name)
                if err != nil {
-                       return nil, &DNSError{Name: name, Err: err.Error()}
+                       return nil, newDNSError(err, name, "")
                }
 
                dnsConf := getSystemDNSConfig()
@@ -144,12 +139,7 @@ func (r *Resolver) lookupIP(ctx context.Context, network, name string) ([]IPAddr
                        }
                }
                if e != nil {
-                       err := winError("getaddrinfow", e)
-                       dnsError := &DNSError{Err: err.Error(), Name: name}
-                       if err == errNoSuchHost {
-                               dnsError.IsNotFound = true
-                       }
-                       return nil, dnsError
+                       return nil, newDNSError(winError("getaddrinfow", e), name, "")
                }
                defer syscall.FreeAddrInfoW(result)
                addrs := make([]IPAddr, 0, 5)
@@ -164,7 +154,7 @@ func (r *Resolver) lookupIP(ctx context.Context, network, name string) ([]IPAddr
                                zone := zoneCache.name(int((*syscall.RawSockaddrInet6)(addr).Scope_id))
                                addrs = append(addrs, IPAddr{IP: copyIP(a[:]), Zone: zone})
                        default:
-                               return nil, &DNSError{Err: syscall.EWINDOWS.Error(), Name: name}
+                               return nil, newDNSError(syscall.EWINDOWS, name, "")
                        }
                }
                return addrs, nil
@@ -196,11 +186,7 @@ func (r *Resolver) lookupIP(ctx context.Context, network, name string) ([]IPAddr
                //
                // For now we just let it finish and write to the
                // buffered channel.
-               return nil, &DNSError{
-                       Name:      name,
-                       Err:       ctx.Err().Error(),
-                       IsTimeout: ctx.Err() == context.DeadlineExceeded,
-               }
+               return nil, newDNSError(mapErr(ctx.Err()), name, "")
        }
 }
 
@@ -252,14 +238,13 @@ func (r *Resolver) lookupPort(ctx context.Context, network, service string) (int
                // for _WSAHOST_NOT_FOUND here to match the cgo (unix) version
                // cgo_unix.go (cgoLookupServicePort).
                if e == _WSATYPE_NOT_FOUND || e == _WSAHOST_NOT_FOUND {
-                       return 0, &DNSError{Err: "unknown port", Name: network + "/" + service, IsNotFound: true}
+                       return 0, newDNSError(errUnknownPort, network+"/"+service, "")
                }
-               err := os.NewSyscallError("getaddrinfow", e)
-               return 0, &DNSError{Err: err.Error(), Name: network + "/" + service}
+               return 0, newDNSError(winError("getaddrinfow", e), network+"/"+service, "")
        }
        defer syscall.FreeAddrInfoW(result)
        if result == nil {
-               return 0, &DNSError{Err: syscall.EINVAL.Error(), Name: network + "/" + service}
+               return 0, newDNSError(syscall.EINVAL, network+"/"+service, "")
        }
        addr := unsafe.Pointer(result.Addr)
        switch result.Family {
@@ -270,7 +255,7 @@ func (r *Resolver) lookupPort(ctx context.Context, network, service string) (int
                a := (*syscall.RawSockaddrInet6)(addr)
                return int(syscall.Ntohs(a.Port)), nil
        }
-       return 0, &DNSError{Err: syscall.EINVAL.Error(), Name: network + "/" + service}
+       return 0, newDNSError(syscall.EINVAL, network+"/"+service, "")
 }
 
 func (r *Resolver) lookupCNAME(ctx context.Context, name string) (string, error) {
@@ -295,8 +280,7 @@ func (r *Resolver) lookupCNAME(ctx context.Context, name string) (string, error)
                return absDomainName(name), nil
        }
        if e != nil {
-               err := winError("dnsquery", e)
-               return "", &DNSError{Err: err.Error(), Name: name, IsNotFound: err == errNoSuchHost}
+               return "", newDNSError(winError("dnsquery", e), name, "")
        }
        defer syscall.DnsRecordListFree(rec, 1)
 
@@ -327,8 +311,7 @@ func (r *Resolver) lookupSRV(ctx context.Context, service, proto, name string) (
        var rec *syscall.DNSRecord
        e := syscall.DnsQuery(target, syscall.DNS_TYPE_SRV, 0, nil, &rec, nil)
        if e != nil {
-               err := winError("dnsquery", e)
-               return "", nil, &DNSError{Err: err.Error(), Name: name, IsNotFound: err == errNoSuchHost}
+               return "", nil, newDNSError(winError("dnsquery", e), name, "")
        }
        defer syscall.DnsRecordListFree(rec, 1)
 
@@ -357,8 +340,7 @@ func (r *Resolver) lookupMX(ctx context.Context, name string) ([]*MX, error) {
        var rec *syscall.DNSRecord
        e := syscall.DnsQuery(name, syscall.DNS_TYPE_MX, 0, nil, &rec, nil)
        if e != nil {
-               err := winError("dnsquery", e)
-               return nil, &DNSError{Err: err.Error(), Name: name, IsNotFound: err == errNoSuchHost}
+               return nil, newDNSError(winError("dnsquery", e), name, "")
        }
        defer syscall.DnsRecordListFree(rec, 1)
 
@@ -387,8 +369,7 @@ func (r *Resolver) lookupNS(ctx context.Context, name string) ([]*NS, error) {
        var rec *syscall.DNSRecord
        e := syscall.DnsQuery(name, syscall.DNS_TYPE_NS, 0, nil, &rec, nil)
        if e != nil {
-               err := winError("dnsquery", e)
-               return nil, &DNSError{Err: err.Error(), Name: name, IsNotFound: err == errNoSuchHost}
+               return nil, newDNSError(winError("dnsquery", e), name, "")
        }
        defer syscall.DnsRecordListFree(rec, 1)
 
@@ -416,8 +397,7 @@ func (r *Resolver) lookupTXT(ctx context.Context, name string) ([]string, error)
        var rec *syscall.DNSRecord
        e := syscall.DnsQuery(name, syscall.DNS_TYPE_TEXT, 0, nil, &rec, nil)
        if e != nil {
-               err := winError("dnsquery", e)
-               return nil, &DNSError{Err: err.Error(), Name: name, IsNotFound: err == errNoSuchHost}
+               return nil, newDNSError(winError("dnsquery", e), name, "")
        }
        defer syscall.DnsRecordListFree(rec, 1)
 
@@ -454,8 +434,7 @@ func (r *Resolver) lookupAddr(ctx context.Context, addr string) ([]string, error
        var rec *syscall.DNSRecord
        e := syscall.DnsQuery(arpa, syscall.DNS_TYPE_PTR, 0, nil, &rec, nil)
        if e != nil {
-               err := winError("dnsquery", e)
-               return nil, &DNSError{Err: err.Error(), Name: addr, IsNotFound: err == errNoSuchHost}
+               return nil, newDNSError(winError("dnsquery", e), addr, "")
        }
        defer syscall.DnsRecordListFree(rec, 1)
 
index d0db65286b180c5f141ab1108351c416331821a8..deaeea40818a3f7fdb26b9a525b90ab0c185b4d9 100644 (file)
@@ -617,11 +617,27 @@ func (e *DNSConfigError) Temporary() bool { return false }
 
 // Various errors contained in DNSError.
 var (
-       errNoSuchHost = errors.New("no such host")
+       errNoSuchHost  = &notFoundError{"no such host"}
+       errUnknownPort = &notFoundError{"unknown port"}
 )
 
+// notFoundError is a special error understood by the newDNSError function,
+// which causes a creation of a DNSError with IsNotFound field set to true.
+type notFoundError struct{ s string }
+
+func (e *notFoundError) Error() string { return e.s }
+
+// temporaryError is an error type that implements the [Error] interface.
+// It returns true from the Temporary method.
+type temporaryError struct{ s string }
+
+func (e *temporaryError) Error() string   { return e.s }
+func (e *temporaryError) Temporary() bool { return true }
+func (e *temporaryError) Timeout() bool   { return false }
+
 // DNSError represents a DNS lookup error.
 type DNSError struct {
+       UnwrapErr   error  // error returned by the [DNSError.Unwrap] method, might be nil
        Err         string // description of the error
        Name        string // name looked for
        Server      string // server used
@@ -634,6 +650,41 @@ type DNSError struct {
        IsNotFound bool
 }
 
+// newDNSError creates a new *DNSError.
+// Based on the err, it sets the UnwrapErr, IsTimeout, IsTemporary, IsNotFound fields.
+func newDNSError(err error, name, server string) *DNSError {
+       var (
+               isTimeout   bool
+               isTemporary bool
+               unwrapErr   error
+       )
+
+       if err, ok := err.(Error); ok {
+               isTimeout = err.Timeout()
+               isTemporary = err.Temporary()
+       }
+
+       // At this time, the only errors we wrap are context errors, to allow
+       // users to check for canceled/timed out requests.
+       if errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) {
+               unwrapErr = err
+       }
+
+       _, isNotFound := err.(*notFoundError)
+       return &DNSError{
+               UnwrapErr:   unwrapErr,
+               Err:         err.Error(),
+               Name:        name,
+               Server:      server,
+               IsTimeout:   isTimeout,
+               IsTemporary: isTemporary,
+               IsNotFound:  isNotFound,
+       }
+}
+
+// Unwrap returns e.UnwrapErr.
+func (e *DNSError) Unwrap() error { return e.UnwrapErr }
+
 func (e *DNSError) Error() string {
        if e == nil {
                return "<nil>"