]> Cypherpunks repositories - gostls13.git/commitdiff
net/http, net/http/httptrace: new package for tracing HTTP client requests
authorBrad Fitzpatrick <bradfitz@golang.org>
Sat, 16 Apr 2016 18:57:06 +0000 (11:57 -0700)
committerBrad Fitzpatrick <bradfitz@golang.org>
Thu, 28 Apr 2016 20:56:38 +0000 (20:56 +0000)
Updates #12580

Change-Id: I9f9578148ef2b48dffede1007317032d39f6af55
Reviewed-on: https://go-review.googlesource.com/22191
Reviewed-by: Ian Lance Taylor <iant@golang.org>
Reviewed-by: Tom Bergan <tombergan@google.com>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>

src/go/build/deps_test.go
src/internal/nettrace/nettrace.go [new file with mode: 0644]
src/net/dial.go
src/net/http/httptrace/trace.go [new file with mode: 0644]
src/net/http/httptrace/trace_test.go [new file with mode: 0644]
src/net/http/request.go
src/net/http/transport.go
src/net/http/transport_test.go
src/net/lookup.go

index a87de577b5b1346cb977dd187c7a44bb8a973bce..67057a960bc1e4c79d37706d166c561c77c5ab82 100644 (file)
@@ -282,6 +282,7 @@ var pkgDeps = map[string][]string{
        // do networking portably, it must have a small dependency set: just L0+basic os.
        "net": {"L0", "CGO",
                "context", "math/rand", "os", "sort", "syscall", "time",
+               "internal/nettrace",
                "internal/syscall/windows", "internal/singleflight", "internal/race"},
 
        // NET enables use of basic network-related packages.
@@ -363,8 +364,11 @@ var pkgDeps = map[string][]string{
                "mime/multipart", "runtime/debug",
                "net/http/internal",
                "golang.org/x/net/http2/hpack",
+               "internal/nettrace",
+               "net/http/httptrace",
        },
-       "net/http/internal": {"L4"},
+       "net/http/internal":  {"L4"},
+       "net/http/httptrace": {"context", "internal/nettrace", "net", "reflect", "time"},
 
        // HTTP-using packages.
        "expvar":             {"L4", "OS", "encoding/json", "net/http"},
diff --git a/src/internal/nettrace/nettrace.go b/src/internal/nettrace/nettrace.go
new file mode 100644 (file)
index 0000000..51a8b2c
--- /dev/null
@@ -0,0 +1,43 @@
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package nettrace contains internal hooks for tracing activity in
+// the net package. This package is purely internal for use by the
+// net/http/httptrace package and has no stable API exposed to end
+// users.
+package nettrace
+
+// TraceKey is a context.Context Value key. Its associated value should
+// be a *Trace struct.
+type TraceKey struct{}
+
+// LookupIPAltResolverKey is a context.Context Value key used by tests to
+// specify an alternate resolver func.
+// It is not exposed to outsider users. (But see issue 12503)
+// The value should be the same type as lookupIP:
+//     func lookupIP(ctx context.Context, host string) ([]IPAddr, error)
+type LookupIPAltResolverKey struct{}
+
+// Trace contains a set of hooks for tracing events within
+// the net package. Any specific hook may be nil.
+type Trace struct {
+       // DNSStart is called with the hostname of a DNS lookup
+       // before it begins.
+       DNSStart func(name string)
+
+       // DNSDone is called after a DNS lookup completes (or fails).
+       // The coalesced parameter is whether singleflight de-dupped
+       // the call. The addrs are of type net.IPAddr but can't
+       // actually be for circular dependency reasons.
+       DNSDone func(netIPs []interface{}, coalesced bool, err error)
+
+       // ConnectStart is called before a Dial. In the case of
+       // DualStack (Happy Eyeballs) dialing, this may be called
+       // multiple times, from multiple goroutines.
+       ConnectStart func(network, addr string)
+
+       // ConnectStart is called after a Dial with the results. It
+       // may also be called multiple times, like ConnectStart.
+       ConnectDone func(network, addr string, err error)
+}
index 05d7e98027a938316788bdd09ef286080160a49f..256ef3806147c426c18a4e4cf3756bd293fff221 100644 (file)
@@ -6,6 +6,7 @@ package net
 
 import (
        "context"
+       "internal/nettrace"
        "time"
 )
 
@@ -474,6 +475,16 @@ func dialSerial(ctx context.Context, dp *dialParam, ras addrList) (Conn, error)
 // dialSingle attempts to establish and returns a single connection to
 // the destination address.
 func dialSingle(ctx context.Context, dp *dialParam, ra Addr) (c Conn, err error) {
+       trace, _ := ctx.Value(nettrace.TraceKey{}).(*nettrace.Trace)
+       if trace != nil {
+               raStr := ra.String()
+               if trace.ConnectStart != nil {
+                       trace.ConnectStart(dp.network, raStr)
+               }
+               if trace.ConnectDone != nil {
+                       defer func() { trace.ConnectDone(dp.network, raStr, err) }()
+               }
+       }
        la := dp.LocalAddr
        switch ra := ra.(type) {
        case *TCPAddr:
diff --git a/src/net/http/httptrace/trace.go b/src/net/http/httptrace/trace.go
new file mode 100644 (file)
index 0000000..5d2c548
--- /dev/null
@@ -0,0 +1,225 @@
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.h
+
+// Package httptrace provides mechanisms to trace the events within
+// HTTP client requests.
+package httptrace
+
+import (
+       "context"
+       "internal/nettrace"
+       "net"
+       "reflect"
+       "time"
+)
+
+// unique type to prevent assignment.
+type clientEventContextKey struct{}
+
+// ContextClientTrace returns the ClientTrace associated with the
+// provided context. If none, it returns nil.
+func ContextClientTrace(ctx context.Context) *ClientTrace {
+       trace, _ := ctx.Value(clientEventContextKey{}).(*ClientTrace)
+       return trace
+}
+
+// WithClientTrace returns a new context based on the provided parent
+// ctx. HTTP client requests made with the returned context will use
+// the provided trace hooks, in addition to any previous hooks
+// registered with ctx. Any hooks defined in the provided trace will
+// be called first.
+func WithClientTrace(ctx context.Context, trace *ClientTrace) context.Context {
+       if trace == nil {
+               panic("nil trace")
+       }
+       old := ContextClientTrace(ctx)
+       trace.compose(old)
+
+       ctx = context.WithValue(ctx, clientEventContextKey{}, trace)
+       if trace.hasNetHooks() {
+               nt := &nettrace.Trace{
+                       ConnectStart: trace.ConnectStart,
+                       ConnectDone:  trace.ConnectDone,
+               }
+               if trace.DNSStart != nil {
+                       nt.DNSStart = func(name string) {
+                               trace.DNSStart(DNSStartInfo{Host: name})
+                       }
+               }
+               if trace.DNSDone != nil {
+                       nt.DNSDone = func(netIPs []interface{}, coalesced bool, err error) {
+                               addrs := make([]net.IPAddr, len(netIPs))
+                               for i, ip := range netIPs {
+                                       addrs[i] = ip.(net.IPAddr)
+                               }
+                               trace.DNSDone(DNSDoneInfo{
+                                       Addrs:     addrs,
+                                       Coalesced: coalesced,
+                                       Err:       err,
+                               })
+                       }
+               }
+               ctx = context.WithValue(ctx, nettrace.TraceKey{}, nt)
+       }
+       return ctx
+}
+
+// ClientTrace is a set of hooks to run at various stages of an HTTP
+// client request. Any particular hook may be nil. Functions may be
+// called concurrently from different goroutines, starting after the
+// call to Transport.RoundTrip and ending either when RoundTrip
+// returns an error, or when the Response.Body is closed.
+type ClientTrace struct {
+       // GetConn is called before a connection is created or
+       // retrieved from an idle pool. The hostPort is the
+       // "host:port" of the target or proxy. GetConn is called even
+       // if there's already an idle cached connection available.
+       GetConn func(hostPort string)
+
+       // GotConn is called after a successful connection is
+       // obtained. There is no hook for failure to obtain a
+       // connection; instead, use the error from
+       // Transport.RoundTrip.
+       GotConn func(GotConnInfo)
+
+       // PutIdleConn is called when the connection is returned to
+       // the idle pool. If err is nil, the connection was
+       // successfully returned to the idle pool. If err is non-nil,
+       // it describes why not. PutIdleConn is not called if
+       // connection reuse is disabled via Transport.DisableKeepAlives.
+       // PutIdleConn is called before the caller's Response.Body.Close
+       // call returns.
+       PutIdleConn func(err error)
+
+       // GotFirstResponseByte is called when the first byte of the response
+       // headers is available.
+       GotFirstResponseByte func()
+
+       // Got100Continue is called if the server replies with a "100
+       // Continue" response.
+       Got100Continue func()
+
+       // DNSStart is called when a DNS lookup begins.
+       DNSStart func(DNSStartInfo)
+
+       // DNSDone is called when a DNS lookup ends.
+       DNSDone func(DNSDoneInfo)
+
+       // ConnectStart is called when a new connection's Dial begins.
+       // If net.Dialer.DualStack (IPv6 "Happy Eyeballs") support is
+       // enabled, this may be called multiple times.
+       ConnectStart func(network, addr string)
+
+       // ConnectDone is called when a new connection's Dial
+       // completes. The provided err indicates whether the
+       // connection completedly successfully.
+       // If net.Dialer.DualStack ("Happy Eyeballs") support is
+       // enabled, this may be called multiple times.
+       ConnectDone func(network, addr string, err error)
+
+       // WroteHeaders is called after the Transport has written
+       // the request headers.
+       WroteHeaders func()
+
+       // Wait100Continue is called if the Request specified
+       // "Expected: 100-continue" and the Transport has written the
+       // request headers but is waiting for "100 Continue" from the
+       // server before writing the request body.
+       Wait100Continue func()
+
+       // WroteRequest is called with the result of writing the
+       // request and any body.
+       WroteRequest func(WroteRequestInfo)
+}
+
+// WroteRequestInfo contains information provided to the WroteRequest
+// hook.
+type WroteRequestInfo struct {
+       // Err is any error encountered while writing the Request.
+       Err error
+}
+
+// compose modifies t such that it respects the previously-registered hooks in old,
+// subject to the composition policy requested in t.Compose.
+func (t *ClientTrace) compose(old *ClientTrace) {
+       if old == nil {
+               return
+       }
+       tv := reflect.ValueOf(t).Elem()
+       ov := reflect.ValueOf(old).Elem()
+       structType := tv.Type()
+       for i := 0; i < structType.NumField(); i++ {
+               tf := tv.Field(i)
+               hookType := tf.Type()
+               if hookType.Kind() != reflect.Func {
+                       continue
+               }
+               of := ov.Field(i)
+               if of.IsNil() {
+                       continue
+               }
+               if tf.IsNil() {
+                       tf.Set(of)
+                       continue
+               }
+
+               // Make a copy of tf for tf to call. (Otherwise it
+               // creates a recursive call cycle and stack overflows)
+               tfCopy := reflect.ValueOf(tf.Interface())
+
+               // We need to call both tf and of in some order.
+               newFunc := reflect.MakeFunc(hookType, func(args []reflect.Value) []reflect.Value {
+                       tfCopy.Call(args)
+                       return of.Call(args)
+               })
+               tv.Field(i).Set(newFunc)
+       }
+}
+
+// DNSStartInfo contains information about a DNS request.
+type DNSStartInfo struct {
+       Host string
+}
+
+// DNSDoneInfo contains information about the results of a DNS lookup.
+type DNSDoneInfo struct {
+       // Addrs are the IPv4 and/or IPv6 addresses found in the DNS
+       // lookup. The contents of the slice should not be mutated.
+       Addrs []net.IPAddr
+
+       // Err is any error that occurred during the DNS lookup.
+       Err error
+
+       // Coalesced is whether the Addrs were shared with another
+       // caller who was doing the same DNS lookup concurrently.
+       Coalesced bool
+}
+
+func (t *ClientTrace) hasNetHooks() bool {
+       if t == nil {
+               return false
+       }
+       return t.DNSStart != nil || t.DNSDone != nil || t.ConnectStart != nil || t.ConnectDone != nil
+}
+
+// GotConnInfo is the argument to the ClientTrace.GotConn function and
+// contains information about the obtained connection.
+type GotConnInfo struct {
+       // Conn is the connection that was obtained. It is owned by
+       // the http.Transport and should not be read, written or
+       // closed by users of ClientTrace.
+       Conn net.Conn
+
+       // Reused is whether this connection has been previously
+       // used for another HTTP request.
+       Reused bool
+
+       // WasIdle is whether this connection was obtained from an
+       // idle pool.
+       WasIdle bool
+
+       // IdleTime reports how long the connection was previously
+       // idle, if WasIdle is true.
+       IdleTime time.Duration
+}
diff --git a/src/net/http/httptrace/trace_test.go b/src/net/http/httptrace/trace_test.go
new file mode 100644 (file)
index 0000000..ed6ddbb
--- /dev/null
@@ -0,0 +1,62 @@
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.h
+
+package httptrace
+
+import (
+       "bytes"
+       "testing"
+)
+
+func TestCompose(t *testing.T) {
+       var buf bytes.Buffer
+       var testNum int
+
+       connectStart := func(b byte) func(network, addr string) {
+               return func(network, addr string) {
+                       if addr != "addr" {
+                               t.Errorf(`%d. args for %Q case = %q, %q; want addr of "addr"`, testNum, b, network, addr)
+                       }
+                       buf.WriteByte(b)
+               }
+       }
+
+       tests := [...]struct {
+               trace, old *ClientTrace
+               want       string
+       }{
+               0: {
+                       want: "T",
+                       trace: &ClientTrace{
+                               ConnectStart: connectStart('T'),
+                       },
+               },
+               1: {
+                       want: "TO",
+                       trace: &ClientTrace{
+                               ConnectStart: connectStart('T'),
+                       },
+                       old: &ClientTrace{ConnectStart: connectStart('O')},
+               },
+               2: {
+                       want:  "O",
+                       trace: &ClientTrace{},
+                       old:   &ClientTrace{ConnectStart: connectStart('O')},
+               },
+       }
+       for i, tt := range tests {
+               testNum = i
+               buf.Reset()
+
+               tr := *tt.trace
+               tr.compose(tt.old)
+               if tr.ConnectStart != nil {
+                       tr.ConnectStart("net", "addr")
+               }
+               if got := buf.String(); got != tt.want {
+                       t.Errorf("%d. got = %q; want %q", i, got, tt.want)
+               }
+       }
+
+}
index a49ab369644f318bf18b1ee99988a610632f07f1..1bde114909c38eaff04cead872059f9b575d19a3 100644 (file)
@@ -18,6 +18,7 @@ import (
        "io/ioutil"
        "mime"
        "mime/multipart"
+       "net/http/httptrace"
        "net/textproto"
        "net/url"
        "strconv"
@@ -437,7 +438,16 @@ var errMissingHost = errors.New("http: Request.Write on Request with no Host or
 
 // extraHeaders may be nil
 // waitForContinue may be nil
-func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, waitForContinue func() bool) error {
+func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, waitForContinue func() bool) (err error) {
+       trace := httptrace.ContextClientTrace(req.Context())
+       if trace != nil && trace.WroteRequest != nil {
+               defer func() {
+                       trace.WroteRequest(httptrace.WroteRequestInfo{
+                               Err: err,
+                       })
+               }()
+       }
+
        // Find the target host. Prefer the Host: header, but if that
        // is not given, use the host from the request URL.
        //
@@ -474,7 +484,7 @@ func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, wai
                w = bw
        }
 
-       _, err := fmt.Fprintf(w, "%s %s HTTP/1.1\r\n", valueOrDefault(req.Method, "GET"), ruri)
+       _, err = fmt.Fprintf(w, "%s %s HTTP/1.1\r\n", valueOrDefault(req.Method, "GET"), ruri)
        if err != nil {
                return err
        }
@@ -525,6 +535,10 @@ func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, wai
                return err
        }
 
+       if trace != nil && trace.WroteHeaders != nil {
+               trace.WroteHeaders()
+       }
+
        // Flush and wait for 100-continue if expected.
        if waitForContinue != nil {
                if bw, ok := w.(*bufio.Writer); ok {
@@ -533,7 +547,9 @@ func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, wai
                                return err
                        }
                }
-
+               if trace != nil && trace.Wait100Continue != nil {
+                       trace.Wait100Continue()
+               }
                if !waitForContinue() {
                        req.closeBody()
                        return nil
index b8b75b3e293e3365ffbc856fc6066d0f50485658..f1e0560ab7ebeb8f8b3999cf19f4fd855aeed88e 100644 (file)
@@ -19,6 +19,7 @@ import (
        "io"
        "log"
        "net"
+       "net/http/httptrace"
        "net/url"
        "os"
        "strings"
@@ -168,7 +169,7 @@ type Transport struct {
        h2transport   *http2Transport // non-nil if http2 wired up
 
        // TODO: MaxIdleConns tunable for global max cached connections (Issue 15461)
-       // TODO: tunable on timeout on cached connections
+       // TODO: tunable on timeout on cached connections (and advertise with Keep-Alive header?)
        // TODO: tunable on max per-host TCP dials in flight (Issue 13957)
 }
 
@@ -280,8 +281,9 @@ func ProxyURL(fixedURL *url.URL) func(*Request) (*url.URL, error) {
 // transportRequest is a wrapper around a *Request that adds
 // optional extra headers to write.
 type transportRequest struct {
-       *Request        // original request, not to be mutated
-       extra    Header // extra headers to write, or nil
+       *Request                        // original request, not to be mutated
+       extra    Header                 // extra headers to write, or nil
+       trace    *httptrace.ClientTrace // optional
 }
 
 func (tr *transportRequest) extraHeaders() Header {
@@ -297,6 +299,9 @@ func (tr *transportRequest) extraHeaders() Header {
 // and redirects), see Get, Post, and the Client type.
 func (t *Transport) RoundTrip(req *Request) (*Response, error) {
        t.nextProtoOnce.Do(t.onceSetNextProtoDefaults)
+       ctx := req.Context()
+       trace := httptrace.ContextClientTrace(ctx)
+
        if req.URL == nil {
                req.closeBody()
                return nil, errors.New("http: nil Request.URL")
@@ -342,7 +347,7 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) {
 
        for {
                // treq gets modified by roundTrip, so we need to recreate for each retry.
-               treq := &transportRequest{Request: req}
+               treq := &transportRequest{Request: req, trace: trace}
                cm, err := t.connectMethodForRequest(treq)
                if err != nil {
                        req.closeBody()
@@ -353,7 +358,7 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) {
                // host (for http or https), the http proxy, or the http proxy
                // pre-CONNECTed to https server. In any case, we'll be ready
                // to send it requests.
-               pconn, err := t.getConn(req, cm)
+               pconn, err := t.getConn(treq, cm)
                if err != nil {
                        t.setReqCanceler(req, nil)
                        req.closeBody()
@@ -605,16 +610,19 @@ func (t *Transport) tryPutIdleConn(pconn *persistConn) error {
        if t.idleConn == nil {
                t.idleConn = make(map[connectMethodKey][]*persistConn)
        }
-       if len(t.idleConn[key]) >= max {
+       idles := t.idleConn[key]
+       if len(idles) >= max {
                return errTooManyIdle
        }
-       for _, exist := range t.idleConn[key] {
+       for _, exist := range idles {
                if exist == pconn {
                        log.Fatalf("dup idle pconn %p in freelist", pconn)
                }
        }
-       t.idleConn[key] = append(t.idleConn[key], pconn)
+
+       t.idleConn[key] = append(idles, pconn)
        t.idleCount++
+       pconn.idleAt = time.Now()
        return nil
 }
 
@@ -640,18 +648,14 @@ func (t *Transport) getIdleConnCh(cm connectMethod) chan *persistConn {
        return ch
 }
 
-func (t *Transport) getIdleConn(cm connectMethod) *persistConn {
+func (t *Transport) getIdleConn(cm connectMethod) (pconn *persistConn, idleSince time.Time) {
        key := cm.key()
        t.idleMu.Lock()
        defer t.idleMu.Unlock()
-       if t.idleConn == nil {
-               return nil
-       }
-       var pconn *persistConn
        for {
                pconns, ok := t.idleConn[key]
                if !ok {
-                       return nil
+                       return nil, time.Time{}
                }
                if len(pconns) == 1 {
                        pconn = pconns[0]
@@ -671,7 +675,7 @@ func (t *Transport) getIdleConn(cm connectMethod) *persistConn {
                        // carry on.
                        continue
                }
-               return pconn
+               return pconn, pconn.idleAt
        }
 }
 
@@ -736,6 +740,8 @@ func (t *Transport) replaceReqCanceler(r *Request, fn func()) bool {
        return true
 }
 
+var zeroDialer net.Dialer
+
 func (t *Transport) dial(ctx context.Context, network, addr string) (net.Conn, error) {
        if t.Dialer != nil {
                return t.Dialer.DialContext(ctx, network, addr)
@@ -747,16 +753,24 @@ func (t *Transport) dial(ctx context.Context, network, addr string) (net.Conn, e
                }
                return c, err
        }
-       return net.Dial(network, addr)
+       return zeroDialer.DialContext(ctx, network, addr)
 }
 
 // getConn dials and creates a new persistConn to the target as
 // specified in the connectMethod. This includes doing a proxy CONNECT
 // and/or setting up TLS.  If this doesn't return an error, the persistConn
 // is ready to write requests to.
-func (t *Transport) getConn(req *Request, cm connectMethod) (*persistConn, error) {
+func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistConn, error) {
+       req := treq.Request
+       trace := treq.trace
        ctx := req.Context()
-       if pc := t.getIdleConn(cm); pc != nil {
+       if trace != nil {
+               trace.GetConn(cm.addr())
+       }
+       if pc, idleSince := t.getIdleConn(cm); pc != nil {
+               if trace != nil {
+                       trace.GotConn(pc.gotIdleConnTrace(idleSince))
+               }
                // set request canceler to some non-nil function so we
                // can detect whether it was cleared between now and when
                // we enter roundTrip
@@ -797,6 +811,9 @@ func (t *Transport) getConn(req *Request, cm connectMethod) (*persistConn, error
        select {
        case v := <-dialc:
                // Our dial finished.
+               if trace != nil && v.pc != nil {
+                       trace.GotConn(httptrace.GotConnInfo{Conn: v.pc.conn})
+               }
                return v.pc, v.err
        case pc := <-idleConnCh:
                // Another request finished first and its net.Conn
@@ -805,6 +822,9 @@ func (t *Transport) getConn(req *Request, cm connectMethod) (*persistConn, error
                // But our dial is still going, so give it away
                // when it finishes:
                handlePendingDial()
+               if trace != nil {
+                       trace.GotConn(httptrace.GotConnInfo{Conn: pc.conn, Reused: pc.isReused()})
+               }
                return pc, nil
        case <-req.Cancel:
                handlePendingDial()
@@ -1093,6 +1113,8 @@ type persistConn struct {
        // whether or not a connection can be reused. Issue 7569.
        writeErrCh chan error
 
+       idleAt time.Time // time it last become idle; guarded by Transport.idleMu
+
        mu                   sync.Mutex // guards following fields
        numExpectedResponses int
        closed               error // set non-nil when conn is closed, before closech is closed
@@ -1150,6 +1172,16 @@ func (pc *persistConn) isReused() bool {
        return r
 }
 
+func (pc *persistConn) gotIdleConnTrace(idleAt time.Time) (t httptrace.GotConnInfo) {
+       pc.mu.Lock()
+       defer pc.mu.Unlock()
+       t.Reused = pc.reused
+       t.Conn = pc.conn
+       t.WasIdle = true
+       t.IdleTime = time.Since(idleAt)
+       return
+}
+
 func (pc *persistConn) cancelRequest() {
        pc.mu.Lock()
        defer pc.mu.Unlock()
@@ -1164,11 +1196,17 @@ func (pc *persistConn) readLoop() {
                pc.t.removeIdleConn(pc)
        }()
 
-       tryPutIdleConn := func() bool {
+       tryPutIdleConn := func(trace *httptrace.ClientTrace) bool {
                if err := pc.t.tryPutIdleConn(pc); err != nil {
                        closeErr = err
+                       if trace != nil && trace.PutIdleConn != nil && err != errKeepAlivesDisabled {
+                               trace.PutIdleConn(err)
+                       }
                        return false
                }
+               if trace != nil && trace.PutIdleConn != nil {
+                       trace.PutIdleConn(nil)
+               }
                return true
        }
 
@@ -1200,10 +1238,11 @@ func (pc *persistConn) readLoop() {
                pc.mu.Unlock()
 
                rc := <-pc.reqch
+               trace := httptrace.ContextClientTrace(rc.req.Context())
 
                var resp *Response
                if err == nil {
-                       resp, err = pc.readResponse(rc)
+                       resp, err = pc.readResponse(rc, trace)
                }
 
                if err != nil {
@@ -1254,7 +1293,7 @@ func (pc *persistConn) readLoop() {
                        alive = alive &&
                                !pc.sawEOF &&
                                pc.wroteRequest() &&
-                               tryPutIdleConn()
+                               tryPutIdleConn(trace)
 
                        select {
                        case rc.ch <- responseAndError{res: resp}:
@@ -1313,7 +1352,7 @@ func (pc *persistConn) readLoop() {
                                bodyEOF &&
                                !pc.sawEOF &&
                                pc.wroteRequest() &&
-                               tryPutIdleConn()
+                               tryPutIdleConn(trace)
                        if bodyEOF {
                                eofc <- struct{}{}
                        }
@@ -1349,13 +1388,22 @@ func (pc *persistConn) readLoopPeekFailLocked(peekErr error) {
 
 // readResponse reads an HTTP response (or two, in the case of "Expect:
 // 100-continue") from the server. It returns the final non-100 one.
-func (pc *persistConn) readResponse(rc requestAndChan) (resp *Response, err error) {
+// trace is optional.
+func (pc *persistConn) readResponse(rc requestAndChan, trace *httptrace.ClientTrace) (resp *Response, err error) {
+       if trace != nil && trace.GotFirstResponseByte != nil {
+               if peek, err := pc.br.Peek(1); err == nil && len(peek) == 1 {
+                       trace.GotFirstResponseByte()
+               }
+       }
        resp, err = ReadResponse(pc.br, rc.req)
        if err != nil {
                return
        }
        if rc.continueCh != nil {
                if resp.StatusCode == 100 {
+                       if trace != nil && trace.Got100Continue != nil {
+                               trace.Got100Continue()
+                       }
                        rc.continueCh <- struct{}{}
                } else {
                        close(rc.continueCh)
index bf2aa2f0b6406753f4b7edd258c591ab83c894bb..3f6ab7b01b7af0c036bfeabeec7fd03a3d21c576 100644 (file)
@@ -18,6 +18,7 @@ import (
        "crypto/tls"
        "errors"
        "fmt"
+       "internal/nettrace"
        "internal/testenv"
        "io"
        "io/ioutil"
@@ -25,6 +26,7 @@ import (
        "net"
        . "net/http"
        "net/http/httptest"
+       "net/http/httptrace"
        "net/http/httputil"
        "net/http/internal"
        "net/url"
@@ -3191,6 +3193,104 @@ func TestTransportResponseHeaderLength(t *testing.T) {
        }
 }
 
+func TestTransportEventTrace(t *testing.T) {
+       defer afterTest(t)
+       const resBody = "some body"
+       ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+               if _, err := ioutil.ReadAll(r.Body); err != nil {
+                       t.Error(err)
+               }
+               io.WriteString(w, resBody)
+       }))
+       defer ts.Close()
+       tr := &Transport{
+               ExpectContinueTimeout: 1 * time.Second,
+       }
+       defer tr.CloseIdleConnections()
+       c := &Client{Transport: tr}
+
+       var mu sync.Mutex
+       var buf bytes.Buffer
+       logf := func(format string, args ...interface{}) {
+               mu.Lock()
+               defer mu.Unlock()
+               fmt.Fprintf(&buf, format, args...)
+               buf.WriteByte('\n')
+       }
+
+       ip, port, err := net.SplitHostPort(ts.Listener.Addr().String())
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       // Install a fake DNS server.
+       ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, host string) ([]net.IPAddr, error) {
+               if host != "dns-is-faked.golang" {
+                       t.Errorf("unexpected DNS host lookup for %q", host)
+                       return nil, nil
+               }
+               if err != nil {
+                       t.Error(err)
+                       return nil, err
+               }
+               return []net.IPAddr{net.IPAddr{IP: net.ParseIP(ip)}}, nil
+       })
+
+       req, _ := NewRequest("POST", "http://dns-is-faked.golang:"+port, strings.NewReader("some body"))
+       req = req.WithContext(httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{
+               GetConn:              func(hostPort string) { logf("Getting conn for %v ...", hostPort) },
+               GotConn:              func(ci httptrace.GotConnInfo) { logf("got conn: %+v", ci) },
+               GotFirstResponseByte: func() { logf("first response byte") },
+               PutIdleConn:          func(err error) { logf("PutIdleConn = %v", err) },
+               DNSStart:             func(e httptrace.DNSStartInfo) { logf("DNS start: %+v", e) },
+               DNSDone:              func(e httptrace.DNSDoneInfo) { logf("DNS done: %+v", e) },
+               ConnectStart:         func(network, addr string) { logf("ConnectStart: Connecting to %s %s ...", network, addr) },
+               ConnectDone: func(network, addr string, err error) {
+                       if err != nil {
+                               t.Errorf("ConnectDone: %v", err)
+                       }
+                       logf("ConnectDone: connected to %s %s = %v", network, addr, err)
+               },
+               Wait100Continue: func() { logf("Wait100Continue") },
+               Got100Continue:  func() { logf("Got100Continue") },
+               WroteRequest:    func(e httptrace.WroteRequestInfo) { logf("WroteRequest: %+v", e) },
+       }))
+
+       req.Header.Set("Expect", "100-continue")
+       res, err := c.Do(req)
+       if err != nil {
+               t.Fatal(err)
+       }
+       slurp, err := ioutil.ReadAll(res.Body)
+       if err != nil {
+               t.Fatal(err)
+       }
+       if string(slurp) != resBody || res.StatusCode != 200 {
+               t.Fatalf("Got %q, %v; want %q, 200 OK", slurp, res.Status, resBody)
+       }
+       res.Body.Close()
+
+       got := buf.String()
+       wantSub := func(sub string) {
+               if !strings.Contains(got, sub) {
+                       t.Errorf("expected substring %q in output.", sub)
+               }
+       }
+       wantSub("Getting conn for dns-is-faked.golang:" + port)
+       wantSub("DNS start: {Host:dns-is-faked.golang}")
+       wantSub("DNS done: {Addrs:[{IP:" + ip + " Zone:}] Err:<nil> Coalesced:false}")
+       wantSub("connected to tcp " + ts.Listener.Addr().String() + " = <nil>")
+       wantSub("Reused:false WasIdle:false IdleTime:0s")
+       wantSub("first response byte")
+       wantSub("PutIdleConn = <nil>")
+       wantSub("WroteRequest: {Err:<nil>}")
+       wantSub("Wait100Continue")
+       wantSub("Got100Continue")
+       if t.Failed() {
+               t.Errorf("Output:\n%s", got)
+       }
+}
+
 var errFakeRoundTrip = errors.New("fake roundtrip")
 
 type funcRoundTripper func()
index 5e60011165dda2f828da35fe3eb3117b45caac63..c169e9e902c2654e01ac6490aed6dc2224e3ded0 100644 (file)
@@ -6,6 +6,7 @@ package net
 
 import (
        "context"
+       "internal/nettrace"
        "internal/singleflight"
 )
 
@@ -85,16 +86,37 @@ func lookupIPReturn(addrsi interface{}, err error, shared bool) ([]IPAddr, error
        return addrs, nil
 }
 
+// ipAddrsEface returns an empty interface slice of addrs.
+func ipAddrsEface(addrs []IPAddr) []interface{} {
+       s := make([]interface{}, len(addrs))
+       for i, v := range addrs {
+               s[i] = v
+       }
+       return s
+}
+
 // lookupIPContext looks up a hostname with a context.
+//
+// TODO(bradfitz): rename this function. All the other
+// build-tag-specific lookupIP funcs also take a context now, so this
+// name is no longer great. Maybe make this lookupIPMerge and ditch
+// the other one, making its callers call this instead with a
+// context.Background().
 func lookupIPContext(ctx context.Context, host string) (addrs []IPAddr, err error) {
-       // TODO(bradfitz): when adding trace hooks later here, make
-       // sure the tracing is done outside of the singleflight
-       // merging. Both callers should see the DNS lookup delay, even
-       // if it's only being done once. The r.Shared bit can be
-       // included in the trace for callers who need it.
+       trace, _ := ctx.Value(nettrace.TraceKey{}).(*nettrace.Trace)
+       if trace != nil && trace.DNSStart != nil {
+               trace.DNSStart(host)
+       }
+       // The underlying resolver func is lookupIP by default but it
+       // can be overridden by tests. This is needed by net/http, so it
+       // uses a context key instead of unexported variables.
+       resolverFunc := lookupIP
+       if alt, _ := ctx.Value(nettrace.LookupIPAltResolverKey{}).(func(context.Context, string) ([]IPAddr, error)); alt != nil {
+               resolverFunc = alt
+       }
 
        ch := lookupGroup.DoChan(host, func() (interface{}, error) {
-               return testHookLookupIP(ctx, lookupIP, host)
+               return testHookLookupIP(ctx, resolverFunc, host)
        })
 
        select {
@@ -103,9 +125,17 @@ func lookupIPContext(ctx context.Context, host string) (addrs []IPAddr, err erro
                // future requests to start the DNS lookup again
                // rather than waiting for the current lookup to
                // complete. See issue 8602.
+               err := mapErr(ctx.Err())
                lookupGroup.Forget(host)
-               return nil, mapErr(ctx.Err())
+               if trace != nil && trace.DNSDone != nil {
+                       trace.DNSDone(nil, false, err)
+               }
+               return nil, err
        case r := <-ch:
+               if trace != nil && trace.DNSDone != nil {
+                       addrs, _ := r.Val.([]IPAddr)
+                       trace.DNSDone(ipAddrsEface(addrs), r.Shared, r.Err)
+               }
                return lookupIPReturn(r.Val, r.Err, r.Shared)
        }
 }