From 0ab78df9ea602d6bc9cf45dbd610c3d6f534cb58 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Sat, 30 Apr 2016 21:27:04 -0500 Subject: [PATCH] net/http: fix a few crashes with a ClientTrace with nil funcs And add a test. Updates #12580 Change-Id: Ia7eaba09b8e7fd0eddbcaefb948d01ab10af876e Reviewed-on: https://go-review.googlesource.com/22659 Reviewed-by: Andrew Gerrand Run-TryBot: Brad Fitzpatrick TryBot-Result: Gobot Gobot --- src/net/http/transport.go | 6 +++--- src/net/http/transport_test.go | 23 ++++++++++++++++++++--- 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/src/net/http/transport.go b/src/net/http/transport.go index b4d56ab699..755a807bed 100644 --- a/src/net/http/transport.go +++ b/src/net/http/transport.go @@ -787,11 +787,11 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC req := treq.Request trace := treq.trace ctx := req.Context() - if trace != nil { + if trace != nil && trace.GetConn != nil { trace.GetConn(cm.addr()) } if pc, idleSince := t.getIdleConn(cm); pc != nil { - if trace != nil { + if trace != nil && trace.GotConn != nil { trace.GotConn(pc.gotIdleConnTrace(idleSince)) } // set request canceler to some non-nil function so we @@ -834,7 +834,7 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC select { case v := <-dialc: // Our dial finished. - if trace != nil && v.pc != nil { + if trace != nil && trace.GotConn != nil && v.pc != nil { trace.GotConn(httptrace.GotConnInfo{Conn: v.pc.conn}) } return v.pc, v.err diff --git a/src/net/http/transport_test.go b/src/net/http/transport_test.go index 67f0b74ba0..9f14c9649a 100644 --- a/src/net/http/transport_test.go +++ b/src/net/http/transport_test.go @@ -3193,7 +3193,12 @@ func TestTransportResponseHeaderLength(t *testing.T) { } } -func TestTransportEventTrace(t *testing.T) { +func TestTransportEventTrace(t *testing.T) { testTransportEventTrace(t, false) } + +// test a non-nil httptrace.ClientTrace but with all hooks set to zero. +func TestTransportEventTrace_NoHooks(t *testing.T) { testTransportEventTrace(t, true) } + +func testTransportEventTrace(t *testing.T, noHooks bool) { defer afterTest(t) const resBody = "some body" ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { @@ -3233,7 +3238,7 @@ func TestTransportEventTrace(t *testing.T) { }) req, _ := NewRequest("POST", "http://dns-is-faked.golang:"+port, strings.NewReader("some body")) - req = req.WithContext(httptrace.WithClientTrace(ctx, &httptrace.ClientTrace{ + trace := &httptrace.ClientTrace{ GetConn: func(hostPort string) { logf("Getting conn for %v ...", hostPort) }, GotConn: func(ci httptrace.GotConnInfo) { logf("got conn: %+v", ci) }, GotFirstResponseByte: func() { logf("first response byte") }, @@ -3250,7 +3255,12 @@ func TestTransportEventTrace(t *testing.T) { Wait100Continue: func() { logf("Wait100Continue") }, Got100Continue: func() { logf("Got100Continue") }, WroteRequest: func(e httptrace.WroteRequestInfo) { logf("WroteRequest: %+v", e) }, - })) + } + if noHooks { + // zero out all func pointers, trying to get some path to crash + *trace = httptrace.ClientTrace{} + } + req = req.WithContext(httptrace.WithClientTrace(ctx, trace)) req.Header.Set("Expect", "100-continue") res, err := c.Do(req) @@ -3266,6 +3276,13 @@ func TestTransportEventTrace(t *testing.T) { } res.Body.Close() + if noHooks { + // Done at this point. Just testing a full HTTP + // requests can happen with a trace pointing to a zero + // ClientTrace, full of nil func pointers. + return + } + got := buf.String() wantSub := func(sub string) { if !strings.Contains(got, sub) { -- 2.48.1