]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: fix a few crashes with a ClientTrace with nil funcs
authorBrad Fitzpatrick <bradfitz@golang.org>
Sun, 1 May 2016 02:27:04 +0000 (21:27 -0500)
committerBrad Fitzpatrick <bradfitz@golang.org>
Sun, 1 May 2016 05:46:39 +0000 (05:46 +0000)
And add a test.

Updates #12580

Change-Id: Ia7eaba09b8e7fd0eddbcaefb948d01ab10af876e
Reviewed-on: https://go-review.googlesource.com/22659
Reviewed-by: Andrew Gerrand <adg@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>

src/net/http/transport.go
src/net/http/transport_test.go

index b4d56ab69947dff98ca50217e6a8b1d4a3b26d7c..755a807bed47487c9e5f1993b3bce75128f1180a 100644 (file)
@@ -787,11 +787,11 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC
        req := treq.Request
        trace := treq.trace
        ctx := req.Context()
-       if trace != nil {
+       if trace != nil && trace.GetConn != nil {
                trace.GetConn(cm.addr())
        }
        if pc, idleSince := t.getIdleConn(cm); pc != nil {
-               if trace != nil {
+               if trace != nil && trace.GotConn != nil {
                        trace.GotConn(pc.gotIdleConnTrace(idleSince))
                }
                // set request canceler to some non-nil function so we
@@ -834,7 +834,7 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC
        select {
        case v := <-dialc:
                // Our dial finished.
-               if trace != nil && v.pc != nil {
+               if trace != nil && trace.GotConn != nil && v.pc != nil {
                        trace.GotConn(httptrace.GotConnInfo{Conn: v.pc.conn})
                }
                return v.pc, v.err
index 67f0b74ba05a26addf99b9153cb612715fb5db1d..9f14c9649ad9789deb8b6212dbbf3271e6055c96 100644 (file)
@@ -3193,7 +3193,12 @@ func TestTransportResponseHeaderLength(t *testing.T) {
        }
 }
 
-func TestTransportEventTrace(t *testing.T) {
+func TestTransportEventTrace(t *testing.T) { testTransportEventTrace(t, false) }
+
+// test a non-nil httptrace.ClientTrace but with all hooks set to zero.
+func TestTransportEventTrace_NoHooks(t *testing.T) { testTransportEventTrace(t, true) }
+
+func testTransportEventTrace(t *testing.T, noHooks bool) {
        defer afterTest(t)
        const resBody = "some body"
        ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
@@ -3233,7 +3238,7 @@ func TestTransportEventTrace(t *testing.T) {
        })
 
        req, _ := NewRequest("POST", "http://dns-is-faked.golang:"+port, strings.NewReader("some body"))
-       req = req.WithContext(httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{
+       trace := &httptrace.ClientTrace{
                GetConn:              func(hostPort string) { logf("Getting conn for %v ...", hostPort) },
                GotConn:              func(ci httptrace.GotConnInfo) { logf("got conn: %+v", ci) },
                GotFirstResponseByte: func() { logf("first response byte") },
@@ -3250,7 +3255,12 @@ func TestTransportEventTrace(t *testing.T) {
                Wait100Continue: func() { logf("Wait100Continue") },
                Got100Continue:  func() { logf("Got100Continue") },
                WroteRequest:    func(e httptrace.WroteRequestInfo) { logf("WroteRequest: %+v", e) },
-       }))
+       }
+       if noHooks {
+               // zero out all func pointers, trying to get some path to crash
+               *trace = httptrace.ClientTrace{}
+       }
+       req = req.WithContext(httptrace.WithClientTrace(ctx, trace))
 
        req.Header.Set("Expect", "100-continue")
        res, err := c.Do(req)
@@ -3266,6 +3276,13 @@ func TestTransportEventTrace(t *testing.T) {
        }
        res.Body.Close()
 
+       if noHooks {
+               // Done at this point. Just testing a full HTTP
+               // requests can happen with a trace pointing to a zero
+               // ClientTrace, full of nil func pointers.
+               return
+       }
+
        got := buf.String()
        wantSub := func(sub string) {
                if !strings.Contains(got, sub) {