return true
}
-func (t *Transport) dial(network, addr string) (c net.Conn, err error) {
+func (t *Transport) dial(network, addr string) (net.Conn, error) {
if t.Dial != nil {
- return t.Dial(network, addr)
+ c, err := t.Dial(network, addr)
+ if c == nil && err == nil {
+ err = errors.New("net/http: Transport.Dial hook returned (nil, nil)")
+ }
+ return c, err
}
return net.Dial(network, addr)
}
return pc, nil
case <-req.Cancel:
handlePendingDial()
- return nil, errors.New("net/http: request canceled while waiting for connection")
+ return nil, errRequestCanceledConn
case <-cancelc:
handlePendingDial()
- return nil, errors.New("net/http: request canceled while waiting for connection")
+ return nil, errRequestCanceledConn
}
}
if err != nil {
return nil, err
}
+ if pconn.conn == nil {
+ return nil, errors.New("net/http: Transport.DialTLS returned (nil, nil)")
+ }
if tc, ok := pconn.conn.(*tls.Conn); ok {
cs := tc.ConnectionState()
pconn.tlsState = &cs
var errTimeout error = &httpError{err: "net/http: timeout awaiting response headers", timeout: true}
var errClosed error = &httpError{err: "net/http: server closed connection before response was received"}
var errRequestCanceled = errors.New("net/http: request canceled")
+var errRequestCanceledConn = errors.New("net/http: request canceled while waiting for connection") // TODO: unify?
func nop() {}
}
pc.broken = true
if pc.closed == nil {
- pc.conn.Close()
pc.closed = err
- close(pc.closech)
+ if pc.alt != nil {
+ // Do nothing; can only get here via getConn's
+ // handlePendingDial's putOrCloseIdleConn when
+ // it turns out the abandoned connection in
+ // flight ended up negotiating an alternate
+ // protocol. We don't use the connection
+ // freelist for http2. That's done by the
+ // alternate protocol's RoundTripper.
+ } else {
+ pc.conn.Close()
+ close(pc.closech)
+ }
}
pc.mutateHeaderFunc = nil
}
. "net/http"
"net/http/httptest"
"net/http/httputil"
+ "net/http/internal"
"net/url"
"os"
"reflect"
}
}
+// Issue 13839
+func TestNoCrashReturningTransportAltConn(t *testing.T) {
+ cert, err := tls.X509KeyPair(internal.LocalhostCert, internal.LocalhostKey)
+ if err != nil {
+ t.Fatal(err)
+ }
+ ln := newLocalListener(t)
+ defer ln.Close()
+
+ handledPendingDial := make(chan bool, 1)
+ SetPendingDialHooks(nil, func() { handledPendingDial <- true })
+ defer SetPendingDialHooks(nil, nil)
+
+ testDone := make(chan struct{})
+ defer close(testDone)
+ go func() {
+ tln := tls.NewListener(ln, &tls.Config{
+ NextProtos: []string{"foo"},
+ Certificates: []tls.Certificate{cert},
+ })
+ sc, err := tln.Accept()
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ if err := sc.(*tls.Conn).Handshake(); err != nil {
+ t.Error(err)
+ return
+ }
+ <-testDone
+ sc.Close()
+ }()
+
+ addr := ln.Addr().String()
+
+ req, _ := NewRequest("GET", "https://fake.tld/", nil)
+ cancel := make(chan struct{})
+ req.Cancel = cancel
+
+ doReturned := make(chan bool, 1)
+ madeRoundTripper := make(chan bool, 1)
+
+ tr := &Transport{
+ DisableKeepAlives: true,
+ TLSNextProto: map[string]func(string, *tls.Conn) RoundTripper{
+ "foo": func(authority string, c *tls.Conn) RoundTripper {
+ madeRoundTripper <- true
+ return funcRoundTripper(func() {
+ t.Error("foo RoundTripper should not be called")
+ })
+ },
+ },
+ Dial: func(_, _ string) (net.Conn, error) {
+ panic("shouldn't be called")
+ },
+ DialTLS: func(_, _ string) (net.Conn, error) {
+ tc, err := tls.Dial("tcp", addr, &tls.Config{
+ InsecureSkipVerify: true,
+ NextProtos: []string{"foo"},
+ })
+ if err != nil {
+ return nil, err
+ }
+ if err := tc.Handshake(); err != nil {
+ return nil, err
+ }
+ close(cancel)
+ <-doReturned
+ return tc, nil
+ },
+ }
+ c := &Client{Transport: tr}
+
+ _, err = c.Do(req)
+ if ue, ok := err.(*url.Error); !ok || ue.Err != ExportErrRequestCanceledConn {
+ t.Fatalf("Do error = %v; want url.Error with errRequestCanceledConn", err)
+ }
+
+ doReturned <- true
+ <-madeRoundTripper
+ <-handledPendingDial
+}
+
+var errFakeRoundTrip = errors.New("fake roundtrip")
+
+type funcRoundTripper func()
+
+func (fn funcRoundTripper) RoundTrip(*Request) (*Response, error) {
+ fn()
+ return nil, errFakeRoundTrip
+}
+
func wantBody(res *Response, err error, want string) error {
if err != nil {
return err