]> Cypherpunks repositories - gostls13.git/commitdiff
net: fix handling of Conns created by Resolver.Dial
authorIan Gudger <igudger@google.com>
Tue, 24 Jul 2018 23:19:19 +0000 (16:19 -0700)
committerBrad Fitzpatrick <bradfitz@golang.org>
Tue, 24 Jul 2018 23:54:08 +0000 (23:54 +0000)
The DNS client in net is documented to treat Conns returned by
Resolver.Dial which implement PacketConn as UDP and those which don't as
TCP regardless of what was requested. golang.org/cl/37879 changed the
DNS client to assume that the Conn returned by Resolver.Dial was the
requested type which broke compatibility.

Fixes #26573
Updates #16218

Change-Id: Idf4f073a4cc3b1db36a3804898df206907f9c43c
Reviewed-on: https://go-review.googlesource.com/125735
Run-TryBot: Ian Gudger <igudger@google.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
src/net/dnsclient_unix.go
src/net/dnsclient_unix_test.go

index fe00fe19fe366ef7d10ce6f9d4be285e265d858e..2fee3346e92dcffea35a604cddfbd41644d0fbef 100644 (file)
@@ -137,10 +137,10 @@ func (r *Resolver) exchange(ctx context.Context, server string, q dnsmessage.Que
                }
                var p dnsmessage.Parser
                var h dnsmessage.Header
-               if network == "tcp" {
-                       p, h, err = dnsStreamRoundTrip(c, id, q, tcpReq)
-               } else {
+               if _, ok := c.(PacketConn); ok {
                        p, h, err = dnsPacketRoundTrip(c, id, q, udpReq)
+               } else {
+                       p, h, err = dnsStreamRoundTrip(c, id, q, tcpReq)
                }
                c.Close()
                if err != nil {
index a95b2fe645175e1475243fcc17f2075ed1ad1e4c..bb014b903ab9335a1d205a94a7813ff93b9cca3d 100644 (file)
@@ -113,7 +113,7 @@ var specialDomainNameTests = []struct {
 }
 
 func TestSpecialDomainName(t *testing.T) {
-       fake := fakeDNSServer{func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
+       fake := fakeDNSServer{rh: func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
                r := dnsmessage.Message{
                        Header: dnsmessage.Header{
                                ID:       q.ID,
@@ -189,7 +189,7 @@ func TestAvoidDNSName(t *testing.T) {
        }
 }
 
-var fakeDNSServerSuccessful = fakeDNSServer{func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
+var fakeDNSServerSuccessful = fakeDNSServer{rh: func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
        r := dnsmessage.Message{
                Header: dnsmessage.Header{
                        ID:       q.ID,
@@ -473,7 +473,7 @@ var goLookupIPWithResolverConfigTests = []struct {
 
 func TestGoLookupIPWithResolverConfig(t *testing.T) {
        defer dnsWaitGroup.Wait()
-       fake := fakeDNSServer{func(n, s string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
+       fake := fakeDNSServer{rh: func(n, s string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
                switch s {
                case "[2001:4860:4860::8888]:53", "8.8.8.8:53":
                        break
@@ -571,7 +571,7 @@ func TestGoLookupIPWithResolverConfig(t *testing.T) {
 func TestGoLookupIPOrderFallbackToFile(t *testing.T) {
        defer dnsWaitGroup.Wait()
 
-       fake := fakeDNSServer{func(n, s string, q dnsmessage.Message, tm time.Time) (dnsmessage.Message, error) {
+       fake := fakeDNSServer{rh: func(n, s string, q dnsmessage.Message, tm time.Time) (dnsmessage.Message, error) {
                r := dnsmessage.Message{
                        Header: dnsmessage.Header{
                                ID:       q.ID,
@@ -641,7 +641,7 @@ func TestErrorForOriginalNameWhenSearching(t *testing.T) {
                t.Fatal(err)
        }
 
-       fake := fakeDNSServer{func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
+       fake := fakeDNSServer{rh: func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
                r := dnsmessage.Message{
                        Header: dnsmessage.Header{
                                ID:       q.ID,
@@ -696,7 +696,7 @@ func TestIgnoreLameReferrals(t *testing.T) {
                t.Fatal(err)
        }
 
-       fake := fakeDNSServer{func(_, s string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
+       fake := fakeDNSServer{rh: func(_, s string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
                t.Log(s, q)
                r := dnsmessage.Message{
                        Header: dnsmessage.Header{
@@ -788,12 +788,15 @@ func BenchmarkGoLookupIPWithBrokenNameServer(b *testing.B) {
 }
 
 type fakeDNSServer struct {
-       rh func(n, s string, q dnsmessage.Message, t time.Time) (dnsmessage.Message, error)
+       rh        func(n, s string, q dnsmessage.Message, t time.Time) (dnsmessage.Message, error)
+       alwaysTCP bool
 }
 
 func (server *fakeDNSServer) DialContext(_ context.Context, n, s string) (Conn, error) {
-       tcp := n == "tcp" || n == "tcp4" || n == "tcp6"
-       return &fakeDNSConn{tcp: tcp, server: server, n: n, s: s}, nil
+       if server.alwaysTCP || n == "tcp" || n == "tcp4" || n == "tcp6" {
+               return &fakeDNSConn{tcp: true, server: server, n: n, s: s}, nil
+       }
+       return &fakeDNSPacketConn{fakeDNSConn: fakeDNSConn{tcp: false, server: server, n: n, s: s}}, nil
 }
 
 type fakeDNSConn struct {
@@ -846,10 +849,6 @@ func (f *fakeDNSConn) Read(b []byte) (int, error) {
        return len(bb), nil
 }
 
-func (f *fakeDNSConn) ReadFrom(b []byte) (int, Addr, error) {
-       return 0, nil, nil
-}
-
 func (f *fakeDNSConn) Write(b []byte) (int, error) {
        if f.tcp && len(b) >= 2 {
                b = b[2:]
@@ -860,15 +859,24 @@ func (f *fakeDNSConn) Write(b []byte) (int, error) {
        return len(b), nil
 }
 
-func (f *fakeDNSConn) WriteTo(b []byte, addr Addr) (int, error) {
-       return 0, nil
-}
-
 func (f *fakeDNSConn) SetDeadline(t time.Time) error {
        f.t = t
        return nil
 }
 
+type fakeDNSPacketConn struct {
+       PacketConn
+       fakeDNSConn
+}
+
+func (f *fakeDNSPacketConn) SetDeadline(t time.Time) error {
+       return f.fakeDNSConn.SetDeadline(t)
+}
+
+func (f *fakeDNSPacketConn) Close() error {
+       return f.fakeDNSConn.Close()
+}
+
 // UDP round-tripper algorithm should ignore invalid DNS responses (issue 13281).
 func TestIgnoreDNSForgeries(t *testing.T) {
        c, s := Pipe()
@@ -973,7 +981,7 @@ func TestRetryTimeout(t *testing.T) {
 
        var deadline0 time.Time
 
-       fake := fakeDNSServer{func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) {
+       fake := fakeDNSServer{rh: func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) {
                t.Log(s, q, deadline)
 
                if deadline.IsZero() {
@@ -1034,7 +1042,7 @@ func testRotate(t *testing.T, rotate bool, nameservers, wantServers []string) {
        }
 
        var usedServers []string
-       fake := fakeDNSServer{func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) {
+       fake := fakeDNSServer{rh: func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) {
                usedServers = append(usedServers, s)
                return mockTXTResponse(q), nil
        }}
@@ -1218,7 +1226,7 @@ func TestStrictErrorsLookupIP(t *testing.T) {
        }
 
        for i, tt := range cases {
-               fake := fakeDNSServer{func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) {
+               fake := fakeDNSServer{rh: func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) {
                        t.Log(s, q)
 
                        switch tt.resolveWhich(q.Questions[0]) {
@@ -1356,7 +1364,7 @@ func TestStrictErrorsLookupTXT(t *testing.T) {
        const searchY = "test.y.golang.org."
        const txt = "Hello World"
 
-       fake := fakeDNSServer{func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) {
+       fake := fakeDNSServer{rh: func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) {
                t.Log(s, q)
 
                switch q.Questions[0].Name.String() {
@@ -1402,7 +1410,7 @@ func TestStrictErrorsLookupTXT(t *testing.T) {
 func TestDNSGoroutineRace(t *testing.T) {
        defer dnsWaitGroup.Wait()
 
-       fake := fakeDNSServer{func(n, s string, q dnsmessage.Message, t time.Time) (dnsmessage.Message, error) {
+       fake := fakeDNSServer{rh: func(n, s string, q dnsmessage.Message, t time.Time) (dnsmessage.Message, error) {
                time.Sleep(10 * time.Microsecond)
                return dnsmessage.Message{}, poll.ErrTimeout
        }}
@@ -1502,3 +1510,28 @@ func TestIssue12778(t *testing.T) {
                t.Fatalf("Err = %#v; wanted %q", de.Err, errNoSuchHost.Error())
        }
 }
+
+// Issue 26573: verify that Conns that don't implement PacketConn are treated
+// as streams even when udp was requested.
+func TestDNSDialTCP(t *testing.T) {
+       fake := fakeDNSServer{
+               rh: func(n, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
+                       r := dnsmessage.Message{
+                               Header: dnsmessage.Header{
+                                       ID:       q.Header.ID,
+                                       Response: true,
+                                       RCode:    dnsmessage.RCodeSuccess,
+                               },
+                               Questions: q.Questions,
+                       }
+                       return r, nil
+               },
+               alwaysTCP: true,
+       }
+       r := Resolver{PreferGo: true, Dial: fake.DialContext}
+       ctx := context.Background()
+       _, _, err := r.exchange(ctx, "0.0.0.0", mustQuestion("com.", dnsmessage.TypeALL, dnsmessage.ClassINET), time.Second)
+       if err != nil {
+               t.Fatal("exhange failed:", err)
+       }
+}