]> Cypherpunks repositories - gostls13.git/commitdiff
net/http/httptrace: add ClientTrace.TLSHandshakeStart & TLSHandshakeDone
authorEdward Muller <edwardam@interlix.com>
Wed, 5 Oct 2016 04:24:58 +0000 (21:24 -0700)
committerBrad Fitzpatrick <bradfitz@golang.org>
Wed, 19 Oct 2016 19:12:05 +0000 (19:12 +0000)
Fixes #16965

Change-Id: I3638fe280a5b1063ff589e6e1ff8a97c74b77c66
Reviewed-on: https://go-review.googlesource.com/30359
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>

src/go/build/deps_test.go
src/net/http/httptrace/trace.go
src/net/http/transport.go
src/net/http/transport_test.go

index 6da1e68fde878ca786ca2ebe72be44b900662cef..cbdcca4ac8a8565a6706531ff815287104f1e3fa 100644 (file)
@@ -396,7 +396,7 @@ var pkgDeps = map[string][]string{
                "runtime/debug",
        },
        "net/http/internal":  {"L4"},
-       "net/http/httptrace": {"context", "internal/nettrace", "net", "reflect", "time"},
+       "net/http/httptrace": {"context", "crypto/tls", "internal/nettrace", "net", "reflect", "time"},
 
        // HTTP-using packages.
        "expvar":             {"L4", "OS", "encoding/json", "net/http"},
index 8c29c4aa6f9dc0bac80e0b7ee6f314447ff93a22..5b042c097fea48e340ec5f18fa6ed19308f1516c 100644 (file)
@@ -8,6 +8,7 @@ package httptrace
 
 import (
        "context"
+       "crypto/tls"
        "internal/nettrace"
        "net"
        "reflect"
@@ -119,6 +120,16 @@ type ClientTrace struct {
        // enabled, this may be called multiple times.
        ConnectDone func(network, addr string, err error)
 
+       // TLSHandshakeStart is called when the TLS handshake is started. When
+       // connecting to a HTTPS site via a HTTP proxy, the handshake happens after
+       // the CONNECT request is processed by the proxy.
+       TLSHandshakeStart func()
+
+       // TLSHandshakeDone is called after the TLS handshake with either the
+       // successful handshake's connection state, or a non-nil error on handshake
+       // failure.
+       TLSHandshakeDone func(tls.ConnectionState, error)
+
        // WroteHeaders is called after the Transport has written
        // the request headers.
        WroteHeaders func()
index 5594c948cd8854139231e8a442f0757b6a11540b..429f667c1429218dfa681f72424eb76d700597a2 100644 (file)
@@ -955,6 +955,7 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (*persistCon
                writeErrCh:    make(chan error, 1),
                writeLoopDone: make(chan struct{}),
        }
+       trace := httptrace.ContextClientTrace(ctx)
        tlsDial := t.DialTLS != nil && cm.targetScheme == "https" && cm.proxyURL == nil
        if tlsDial {
                var err error
@@ -968,11 +969,20 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (*persistCon
                if tc, ok := pconn.conn.(*tls.Conn); ok {
                        // Handshake here, in case DialTLS didn't. TLSNextProto below
                        // depends on it for knowing the connection state.
+                       if trace != nil && trace.TLSHandshakeStart != nil {
+                               trace.TLSHandshakeStart()
+                       }
                        if err := tc.Handshake(); err != nil {
                                go pconn.conn.Close()
+                               if trace != nil && trace.TLSHandshakeDone != nil {
+                                       trace.TLSHandshakeDone(tls.ConnectionState{}, err)
+                               }
                                return nil, err
                        }
                        cs := tc.ConnectionState()
+                       if trace != nil && trace.TLSHandshakeDone != nil {
+                               trace.TLSHandshakeDone(cs, nil)
+                       }
                        pconn.tlsState = &cs
                }
        } else {
@@ -1042,6 +1052,9 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (*persistCon
                        })
                }
                go func() {
+                       if trace != nil && trace.TLSHandshakeStart != nil {
+                               trace.TLSHandshakeStart()
+                       }
                        err := tlsConn.Handshake()
                        if timer != nil {
                                timer.Stop()
@@ -1050,6 +1063,9 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (*persistCon
                }()
                if err := <-errc; err != nil {
                        plainConn.Close()
+                       if trace != nil && trace.TLSHandshakeDone != nil {
+                               trace.TLSHandshakeDone(tls.ConnectionState{}, err)
+                       }
                        return nil, err
                }
                if !cfg.InsecureSkipVerify {
@@ -1059,6 +1075,9 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (*persistCon
                        }
                }
                cs := tlsConn.ConnectionState()
+               if trace != nil && trace.TLSHandshakeDone != nil {
+                       trace.TLSHandshakeDone(cs, nil)
+               }
                pconn.tlsState = &cs
                pconn.conn = tlsConn
        }
index cef2acc4568dd3fc1104460abd6403c0e0389d3a..147b468e788ca6d8a3f94bc46b2342ba6691f881 100644 (file)
@@ -3288,6 +3288,12 @@ func testTransportEventTrace(t *testing.T, h2 bool, noHooks bool) {
                        close(gotWroteReqEvent)
                },
        }
+       if h2 {
+               trace.TLSHandshakeStart = func() { logf("tls handshake start") }
+               trace.TLSHandshakeDone = func(s tls.ConnectionState, err error) {
+                       logf("tls handshake done. ConnectionState = %v \n err = %v", s, err)
+               }
+       }
        if noHooks {
                // zero out all func pointers, trying to get some path to crash
                *trace = httptrace.ClientTrace{}
@@ -3339,7 +3345,10 @@ func testTransportEventTrace(t *testing.T, h2 bool, noHooks bool) {
        wantOnceOrMore("connected to tcp " + addrStr + " = <nil>")
        wantOnce("Reused:false WasIdle:false IdleTime:0s")
        wantOnce("first response byte")
-       if !h2 {
+       if h2 {
+               wantOnce("tls handshake start")
+               wantOnce("tls handshake done")
+       } else {
                wantOnce("PutIdleConn = <nil>")
        }
        wantOnce("Wait100Continue")
@@ -3411,6 +3420,55 @@ func TestTransportEventTraceRealDNS(t *testing.T) {
        }
 }
 
+// Test the httptrace.TLSHandshake{Start,Done} hooks with a https http1
+// connections. The http2 test is done in TestTransportEventTrace_h2
+func TestTLSHandshakeTrace(t *testing.T) {
+       defer afterTest(t)
+       s := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {}))
+       defer s.Close()
+
+       var mu sync.Mutex
+       var start, done bool
+       trace := &httptrace.ClientTrace{
+               TLSHandshakeStart: func() {
+                       mu.Lock()
+                       defer mu.Unlock()
+                       start = true
+               },
+               TLSHandshakeDone: func(s tls.ConnectionState, err error) {
+                       mu.Lock()
+                       defer mu.Unlock()
+                       done = true
+                       if err != nil {
+                               t.Fatal("Expected error to be nil but was:", err)
+                       }
+               },
+       }
+
+       tr := &Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}
+       defer tr.CloseIdleConnections()
+       c := &Client{Transport: tr}
+       req, err := NewRequest("GET", s.URL, nil)
+       if err != nil {
+               t.Fatal("Unable to construct test request:", err)
+       }
+       req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace))
+
+       r, err := c.Do(req)
+       if err != nil {
+               t.Fatal("Unexpected error making request:", err)
+       }
+       r.Body.Close()
+       mu.Lock()
+       defer mu.Unlock()
+       if !start {
+               t.Fatal("Expected TLSHandshakeStart to be called, but wasn't")
+       }
+       if !done {
+               t.Fatal("Expected TLSHandshakeDone to be called, but wasnt't")
+       }
+}
+
 func TestTransportMaxIdleConns(t *testing.T) {
        defer afterTest(t)
        ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {