]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: fix Transport crash when abandoning dial which upgrades protos
authorBrad Fitzpatrick <bradfitz@golang.org>
Wed, 13 Jan 2016 16:30:00 +0000 (16:30 +0000)
committerBrad Fitzpatrick <bradfitz@golang.org>
Wed, 13 Jan 2016 17:52:50 +0000 (17:52 +0000)
When the Transport was creating an bound HTTP connection (protocol
unknown initially) and then ends up deciding it doesn't need it, a
goroutine sits around to clean up whatever the result was. That
goroutine made the false assumption that the result was always an
HTTP/1 connection or an error. It may also be an alternate protocol
in which case the *persistConn.conn net.Conn field is nil, and the
alt field is non-nil.

Fixes #13839

Change-Id: Ia4972e5eb1ad53fa00410b3466d4129c753e0871
Reviewed-on: https://go-review.googlesource.com/18573
Reviewed-by: Russ Cox <rsc@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>

src/net/http/export_test.go
src/net/http/transport.go
src/net/http/transport_test.go

index 514d02b2a3a4fae5974802b80e8c1ce9d41233cf..52bccbdce31c5a3139f985140c58e7ebe3733ea1 100644 (file)
@@ -21,6 +21,7 @@ var (
        ExportServerNewConn           = (*Server).newConn
        ExportCloseWriteAndWait       = (*conn).closeWriteAndWait
        ExportErrRequestCanceled      = errRequestCanceled
+       ExportErrRequestCanceledConn  = errRequestCanceledConn
        ExportServeFile               = serveFile
        ExportHttp2ConfigureTransport = http2ConfigureTransport
        ExportHttp2ConfigureServer    = http2ConfigureServer
index 6c083917668514779d447114c8cab93f4a3446bc..9378b8385e9d4678f1316a09cdf565a55b5e2bf0 100644 (file)
@@ -618,9 +618,13 @@ func (t *Transport) replaceReqCanceler(r *Request, fn func()) bool {
        return true
 }
 
-func (t *Transport) dial(network, addr string) (c net.Conn, err error) {
+func (t *Transport) dial(network, addr string) (net.Conn, error) {
        if t.Dial != nil {
-               return t.Dial(network, addr)
+               c, err := t.Dial(network, addr)
+               if c == nil && err == nil {
+                       err = errors.New("net/http: Transport.Dial hook returned (nil, nil)")
+               }
+               return c, err
        }
        return net.Dial(network, addr)
 }
@@ -682,10 +686,10 @@ func (t *Transport) getConn(req *Request, cm connectMethod) (*persistConn, error
                return pc, nil
        case <-req.Cancel:
                handlePendingDial()
-               return nil, errors.New("net/http: request canceled while waiting for connection")
+               return nil, errRequestCanceledConn
        case <-cancelc:
                handlePendingDial()
-               return nil, errors.New("net/http: request canceled while waiting for connection")
+               return nil, errRequestCanceledConn
        }
 }
 
@@ -705,6 +709,9 @@ func (t *Transport) dialConn(cm connectMethod) (*persistConn, error) {
                if err != nil {
                        return nil, err
                }
+               if pconn.conn == nil {
+                       return nil, errors.New("net/http: Transport.DialTLS returned (nil, nil)")
+               }
                if tc, ok := pconn.conn.(*tls.Conn); ok {
                        cs := tc.ConnectionState()
                        pconn.tlsState = &cs
@@ -1326,6 +1333,7 @@ func (e *httpError) Temporary() bool { return true }
 var errTimeout error = &httpError{err: "net/http: timeout awaiting response headers", timeout: true}
 var errClosed error = &httpError{err: "net/http: server closed connection before response was received"}
 var errRequestCanceled = errors.New("net/http: request canceled")
+var errRequestCanceledConn = errors.New("net/http: request canceled while waiting for connection") // TODO: unify?
 
 func nop() {}
 
@@ -1502,9 +1510,19 @@ func (pc *persistConn) closeLocked(err error) {
        }
        pc.broken = true
        if pc.closed == nil {
-               pc.conn.Close()
                pc.closed = err
-               close(pc.closech)
+               if pc.alt != nil {
+                       // Do nothing; can only get here via getConn's
+                       // handlePendingDial's putOrCloseIdleConn when
+                       // it turns out the abandoned connection in
+                       // flight ended up negotiating an alternate
+                       // protocol.  We don't use the connection
+                       // freelist for http2. That's done by the
+                       // alternate protocol's RoundTripper.
+               } else {
+                       pc.conn.Close()
+                       close(pc.closech)
+               }
        }
        pc.mutateHeaderFunc = nil
 }
index 46e330315b78de898d8e52533c7c49b9c384ec87..3b2a5f978e263557feca946c0f1af4eed838cf86 100644 (file)
@@ -24,6 +24,7 @@ import (
        . "net/http"
        "net/http/httptest"
        "net/http/httputil"
+       "net/http/internal"
        "net/url"
        "os"
        "reflect"
@@ -2939,6 +2940,98 @@ func TestTransportReuseConnEmptyResponseBody(t *testing.T) {
        }
 }
 
+// Issue 13839
+func TestNoCrashReturningTransportAltConn(t *testing.T) {
+       cert, err := tls.X509KeyPair(internal.LocalhostCert, internal.LocalhostKey)
+       if err != nil {
+               t.Fatal(err)
+       }
+       ln := newLocalListener(t)
+       defer ln.Close()
+
+       handledPendingDial := make(chan bool, 1)
+       SetPendingDialHooks(nil, func() { handledPendingDial <- true })
+       defer SetPendingDialHooks(nil, nil)
+
+       testDone := make(chan struct{})
+       defer close(testDone)
+       go func() {
+               tln := tls.NewListener(ln, &tls.Config{
+                       NextProtos:   []string{"foo"},
+                       Certificates: []tls.Certificate{cert},
+               })
+               sc, err := tln.Accept()
+               if err != nil {
+                       t.Error(err)
+                       return
+               }
+               if err := sc.(*tls.Conn).Handshake(); err != nil {
+                       t.Error(err)
+                       return
+               }
+               <-testDone
+               sc.Close()
+       }()
+
+       addr := ln.Addr().String()
+
+       req, _ := NewRequest("GET", "https://fake.tld/", nil)
+       cancel := make(chan struct{})
+       req.Cancel = cancel
+
+       doReturned := make(chan bool, 1)
+       madeRoundTripper := make(chan bool, 1)
+
+       tr := &Transport{
+               DisableKeepAlives: true,
+               TLSNextProto: map[string]func(string, *tls.Conn) RoundTripper{
+                       "foo": func(authority string, c *tls.Conn) RoundTripper {
+                               madeRoundTripper <- true
+                               return funcRoundTripper(func() {
+                                       t.Error("foo RoundTripper should not be called")
+                               })
+                       },
+               },
+               Dial: func(_, _ string) (net.Conn, error) {
+                       panic("shouldn't be called")
+               },
+               DialTLS: func(_, _ string) (net.Conn, error) {
+                       tc, err := tls.Dial("tcp", addr, &tls.Config{
+                               InsecureSkipVerify: true,
+                               NextProtos:         []string{"foo"},
+                       })
+                       if err != nil {
+                               return nil, err
+                       }
+                       if err := tc.Handshake(); err != nil {
+                               return nil, err
+                       }
+                       close(cancel)
+                       <-doReturned
+                       return tc, nil
+               },
+       }
+       c := &Client{Transport: tr}
+
+       _, err = c.Do(req)
+       if ue, ok := err.(*url.Error); !ok || ue.Err != ExportErrRequestCanceledConn {
+               t.Fatalf("Do error = %v; want url.Error with errRequestCanceledConn", err)
+       }
+
+       doReturned <- true
+       <-madeRoundTripper
+       <-handledPendingDial
+}
+
+var errFakeRoundTrip = errors.New("fake roundtrip")
+
+type funcRoundTripper func()
+
+func (fn funcRoundTripper) RoundTrip(*Request) (*Response, error) {
+       fn()
+       return nil, errFakeRoundTrip
+}
+
 func wantBody(res *Response, err error, want string) error {
        if err != nil {
                return err