]> Cypherpunks repositories - gostls13.git/commitdiff
net, net/http: don't trace DNS dials
authorTom Bergan <tombergan@google.com>
Fri, 13 May 2016 05:03:46 +0000 (22:03 -0700)
committerBrad Fitzpatrick <bradfitz@golang.org>
Sat, 14 May 2016 00:14:25 +0000 (00:14 +0000)
This fixes change https://go-review.googlesource.com/#/c/23069/, which
assumes all DNS requests are UDP. This is not true -- DNS requests can
be TCP in some cases. See:
https://tip.golang.org/src/net/dnsclient_unix.go#L154
https://en.wikipedia.org/wiki/Domain_Name_System#Protocol_transport

Also, the test code added by the above change doesn't actually test
anything because the test uses a faked DNS resolver that doesn't
actually make any DNS queries. I fixed that by adding another test
that uses the system DNS resolver.

Updates #12580

Change-Id: I6c24c03ebab84d437d3ac610fd6eb5353753c490
Reviewed-on: https://go-review.googlesource.com/23101
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>

src/internal/nettrace/nettrace.go
src/net/dial.go
src/net/http/transport_test.go

index 0f85d727c605bd71c2452091f21cffa0664263fd..de3254df589f0e31844b147324e418d900a97ebd 100644 (file)
@@ -32,14 +32,14 @@ type Trace struct {
        // actually be for circular dependency reasons.
        DNSDone func(netIPs []interface{}, coalesced bool, err error)
 
-       // ConnectStart is called before a TCPAddr or UnixAddr
-       // Dial. In the case of DualStack (Happy Eyeballs) dialing,
-       // this may be called multiple times, from multiple
+       // ConnectStart is called before a Dial, excluding Dials made
+       // during DNS lookups. In the case of DualStack (Happy Eyeballs)
+       // dialing, this may be called multiple times, from multiple
        // goroutines.
        ConnectStart func(network, addr string)
 
-       // ConnectStart is called after a TCPAddr or UnixAddr Dial
-       // with the results. It may also be called multiple times,
-       // like ConnectStart.
+       // ConnectStart is called after a Dial with the results, excluding
+       // Dials made during DNS lookups. It may also be called multiple
+       // times, like ConnectStart.
        ConnectDone func(network, addr string, err error)
 }
index 5985421b0661a9bab78924b8758a29c3007b5bd5..16f67a2f33742743b1fe0234f9e4ab64ff173745 100644 (file)
@@ -317,7 +317,16 @@ func (d *Dialer) DialContext(ctx context.Context, network, address string) (Conn
                ctx = subCtx
        }
 
-       addrs, err := resolveAddrList(ctx, "dial", network, address, d.LocalAddr)
+       // Shadow the nettrace (if any) during resolve so Connect events don't fire for DNS lookups.
+       resolveCtx := ctx
+       if trace, _ := ctx.Value(nettrace.TraceKey{}).(*nettrace.Trace); trace != nil {
+               shadow := *trace
+               shadow.ConnectStart = nil
+               shadow.ConnectDone = nil
+               resolveCtx = context.WithValue(resolveCtx, nettrace.TraceKey{}, &shadow)
+       }
+
+       addrs, err := resolveAddrList(resolveCtx, "dial", network, address, d.LocalAddr)
        if err != nil {
                return nil, &OpError{Op: "dial", Net: network, Source: nil, Addr: nil, Err: err}
        }
@@ -472,21 +481,11 @@ func dialSerial(ctx context.Context, dp *dialParam, ras addrList) (Conn, error)
        return nil, firstErr
 }
 
-// traceDialType reports whether ra is an address type for which
-// nettrace.Trace should trace.
-func traceDialType(ra Addr) bool {
-       switch ra.(type) {
-       case *TCPAddr, *UnixAddr:
-               return true
-       }
-       return false
-}
-
 // dialSingle attempts to establish and returns a single connection to
 // the destination address.
 func dialSingle(ctx context.Context, dp *dialParam, ra Addr) (c Conn, err error) {
        trace, _ := ctx.Value(nettrace.TraceKey{}).(*nettrace.Trace)
-       if trace != nil && traceDialType(ra) {
+       if trace != nil {
                raStr := ra.String()
                if trace.ConnectStart != nil {
                        trace.ConnectStart(dp.network, raStr)
index ab26de2e9526a4a9378eef41c7e9069f5153fce2..328fd5727b9e6b66d54303d3f73c2ef5f26d9153 100644 (file)
@@ -3308,6 +3308,52 @@ func testTransportEventTrace(t *testing.T, noHooks bool) {
        }
 }
 
+func TestTransportEventTraceRealDNS(t *testing.T) {
+       defer afterTest(t)
+       tr := &Transport{}
+       defer tr.CloseIdleConnections()
+       c := &Client{Transport: tr}
+
+       var mu sync.Mutex
+       var buf bytes.Buffer
+       logf := func(format string, args ...interface{}) {
+               mu.Lock()
+               defer mu.Unlock()
+               fmt.Fprintf(&buf, format, args...)
+               buf.WriteByte('\n')
+       }
+
+       req, _ := NewRequest("GET", "http://dns-should-not-resolve.golang:80", nil)
+       trace := &httptrace.ClientTrace{
+               DNSStart:     func(e httptrace.DNSStartInfo) { logf("DNSStart: %+v", e) },
+               DNSDone:      func(e httptrace.DNSDoneInfo) { logf("DNSDone: %+v", e) },
+               ConnectStart: func(network, addr string) { logf("ConnectStart: %s %s", network, addr) },
+               ConnectDone:  func(network, addr string, err error) { logf("ConnectDone: %s %s %v", network, addr, err) },
+       }
+       req = req.WithContext(httptrace.WithClientTrace(context.Background(), trace))
+
+       resp, err := c.Do(req)
+       if err == nil {
+               resp.Body.Close()
+               t.Fatal("expected error during DNS lookup")
+       }
+
+       got := buf.String()
+       wantSub := func(sub string) {
+               if !strings.Contains(got, sub) {
+                       t.Errorf("expected substring %q in output.", sub)
+               }
+       }
+       wantSub("DNSStart: {Host:dns-should-not-resolve.golang}")
+       wantSub("DNSDone: {Addrs:[] Err:")
+       if strings.Contains(got, "ConnectStart") || strings.Contains(got, "ConnectDone") {
+               t.Errorf("should not see Connect events")
+       }
+       if t.Failed() {
+               t.Errorf("Output:\n%s", got)
+       }
+}
+
 func TestTransportMaxIdleConns(t *testing.T) {
        defer afterTest(t)
        ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {