]> Cypherpunks repositories - gostls13.git/commitdiff
http: export Transport, add keep-alive support
authorBrad Fitzpatrick <bradfitz@golang.org>
Wed, 23 Mar 2011 17:38:18 +0000 (10:38 -0700)
committerBrad Fitzpatrick <bradfitz@golang.org>
Wed, 23 Mar 2011 17:38:18 +0000 (10:38 -0700)
This patch adds a connection cache and keep-alive
support to Transport, which is used by the
HTTP client.

It's also structured such that it's easy to add
HTTP pipelining in the future.

R=rsc, petar-m, bradfitzwork, r
CC=golang-dev
https://golang.org/cl/4272045

src/pkg/http/client.go
src/pkg/http/export_test.go [new file with mode: 0644]
src/pkg/http/httptest/server.go
src/pkg/http/persist.go
src/pkg/http/proxy_test.go
src/pkg/http/serve_test.go
src/pkg/http/transport.go
src/pkg/http/transport_test.go

index c43e58332b9bcd904f846b410bc3f16ff72ded43..daba3a89b0c21effce0c6f9e4f85c3d97afb6968 100644 (file)
@@ -57,40 +57,6 @@ type readClose struct {
        io.Closer
 }
 
-// matchNoProxy returns true if requests to addr should not use a proxy,
-// according to the NO_PROXY or no_proxy environment variable.
-func matchNoProxy(addr string) bool {
-       if len(addr) == 0 {
-               return false
-       }
-       no_proxy := os.Getenv("NO_PROXY")
-       if len(no_proxy) == 0 {
-               no_proxy = os.Getenv("no_proxy")
-       }
-       if no_proxy == "*" {
-               return true
-       }
-
-       addr = strings.ToLower(strings.TrimSpace(addr))
-       if hasPort(addr) {
-               addr = addr[:strings.LastIndex(addr, ":")]
-       }
-
-       for _, p := range strings.Split(no_proxy, ",", -1) {
-               p = strings.ToLower(strings.TrimSpace(p))
-               if len(p) == 0 {
-                       continue
-               }
-               if hasPort(p) {
-                       p = p[:strings.LastIndex(p, ":")]
-               }
-               if addr == p || (p[0] == '.' && (strings.HasSuffix(addr, p) || addr == p[1:])) {
-                       return true
-               }
-       }
-       return false
-}
-
 // Do sends an HTTP request and returns an HTTP response, following
 // policy (e.g. redirects, cookies, auth) as configured on the client.
 //
diff --git a/src/pkg/http/export_test.go b/src/pkg/http/export_test.go
new file mode 100644 (file)
index 0000000..a76b707
--- /dev/null
@@ -0,0 +1,34 @@
+// Copyright 2011 The Go Authors.  All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Bridge package to expose http internals to tests in the http_test
+// package.
+
+package http
+
+func (t *Transport) IdleConnKeysForTesting() (keys []string) {
+       keys = make([]string, 0)
+       t.lk.Lock()
+       defer t.lk.Unlock()
+       if t.idleConn == nil {
+               return
+       }
+       for key, _ := range t.idleConn {
+               keys = append(keys, key)
+       }
+       return
+}
+
+func (t *Transport) IdleConnCountForTesting(cacheKey string) int {
+       t.lk.Lock()
+       defer t.lk.Unlock()
+       if t.idleConn == nil {
+               return 0
+       }
+       conns, ok := t.idleConn[cacheKey]
+       if !ok {
+               return 0
+       }
+       return len(conns)
+}
index 86c9eb4353562ca5e7465dc94a65706429c27a92..6e825a890d19382122f768161fa0f685839ec9ae 100644 (file)
@@ -9,6 +9,7 @@ package httptest
 import (
        "fmt"
        "http"
+       "os"
        "net"
 )
 
@@ -19,6 +20,21 @@ type Server struct {
        Listener net.Listener
 }
 
+// historyListener keeps track of all connections that it's ever
+// accepted.
+type historyListener struct {
+       net.Listener
+       history []net.Conn
+}
+
+func (hs *historyListener) Accept() (c net.Conn, err os.Error) {
+       c, err = hs.Listener.Accept()
+       if err == nil {
+               hs.history = append(hs.history, c)
+       }
+       return
+}
+
 // NewServer starts and returns a new Server.
 // The caller should call Close when finished, to shut it down.
 func NewServer(handler http.Handler) *Server {
@@ -29,10 +45,10 @@ func NewServer(handler http.Handler) *Server {
                        panic(fmt.Sprintf("httptest: failed to listen on a port: %v", err))
                }
        }
-       ts.Listener = l
+       ts.Listener = &historyListener{l, make([]net.Conn, 0)}
        ts.URL = "http://" + l.Addr().String()
        server := &http.Server{Handler: handler}
-       go server.Serve(l)
+       go server.Serve(ts.Listener)
        return ts
 }
 
@@ -40,3 +56,15 @@ func NewServer(handler http.Handler) *Server {
 func (s *Server) Close() {
        s.Listener.Close()
 }
+
+// CloseClientConnections closes any currently open HTTP connections
+// to the test Server.
+func (s *Server) CloseClientConnections() {
+       hl, ok := s.Listener.(*historyListener)
+       if !ok {
+               return
+       }
+       for _, conn := range hl.history {
+               conn.Close()
+       }
+}
index a8285c894a46483e72facbe568f59156856a4662..b93c5fe4855c8c5d207c2a0b9374ddb5005e68ba 100644 (file)
@@ -213,6 +213,7 @@ type ClientConn struct {
 
        pipe     textproto.Pipeline
        writeReq func(*Request, io.Writer) os.Error
+       readRes  func(buf *bufio.Reader, method string) (*Response, os.Error)
 }
 
 // NewClientConn returns a new ClientConn reading and writing c.  If r is not
@@ -226,6 +227,7 @@ func NewClientConn(c net.Conn, r *bufio.Reader) *ClientConn {
                r:        r,
                pipereq:  make(map[*Request]uint),
                writeReq: (*Request).Write,
+               readRes:  ReadResponse,
        }
 }
 
@@ -363,7 +365,7 @@ func (cc *ClientConn) Read(req *Request) (resp *Response, err os.Error) {
                }
        }
 
-       resp, err = ReadResponse(r, req.Method)
+       resp, err = cc.readRes(r, req.Method)
        cc.lk.Lock()
        defer cc.lk.Unlock()
        if err != nil {
index 0f2ca458fed2c53d03781a95f462875da9384f28..7050ef5ed063d96fc15fe957c72a2020580668a7 100644 (file)
@@ -12,31 +12,33 @@ import (
 // TODO(mattn):
 //     test ProxyAuth
 
-var MatchNoProxyTests = []struct {
+var UseProxyTests = []struct {
        host  string
        match bool
 }{
-       {"localhost", true},        // match completely
-       {"barbaz.net", true},       // match as .barbaz.net
-       {"foobar.com:443", true},   // have a port but match 
-       {"foofoobar.com", false},   // not match as a part of foobar.com
-       {"baz.com", false},         // not match as a part of barbaz.com
-       {"localhost.net", false},   // not match as suffix of address
-       {"local.localhost", false}, // not match as prefix as address
-       {"barbarbaz.net", false},   // not match because NO_PROXY have a '.'
-       {"www.foobar.com", false},  // not match because NO_PROXY is not .foobar.com
+       {"localhost", false},      // match completely
+       {"barbaz.net", false},     // match as .barbaz.net
+       {"foobar.com:443", false}, // have a port but match 
+       {"foofoobar.com", true},   // not match as a part of foobar.com
+       {"baz.com", true},         // not match as a part of barbaz.com
+       {"localhost.net", true},   // not match as suffix of address
+       {"local.localhost", true}, // not match as prefix as address
+       {"barbarbaz.net", true},   // not match because NO_PROXY have a '.'
+       {"www.foobar.com", true},  // not match because NO_PROXY is not .foobar.com
 }
 
-func TestMatchNoProxy(t *testing.T) {
+func TestUseProxy(t *testing.T) {
        oldenv := os.Getenv("NO_PROXY")
        no_proxy := "foobar.com, .barbaz.net   , localhost"
        os.Setenv("NO_PROXY", no_proxy)
        defer os.Setenv("NO_PROXY", oldenv)
 
-       for _, test := range MatchNoProxyTests {
-               if matchNoProxy(test.host) != test.match {
+       tr := &Transport{}
+
+       for _, test := range UseProxyTests {
+               if tr.useProxy(test.host) != test.match {
                        if test.match {
-                               t.Errorf("matchNoProxy(%v) = %v, want %v", test.host, !test.match, test.match)
+                               t.Errorf("useProxy(%v) = %v, want %v", test.host, !test.match, test.match)
                        } else {
                                t.Errorf("not expected: '%s' shouldn't match as '%s'", test.host, no_proxy)
                        }
index 6b881a2491da94ae58596d28ae136bdf74377dc5..b5487358cdff7f98e47ab3f9e7a7b86b404f544c 100644 (file)
@@ -250,7 +250,9 @@ func TestServerTimeouts(t *testing.T) {
        url := fmt.Sprintf("http://localhost:%d/", addr.Port)
 
        // Hit the HTTP server successfully.
-       r, _, err := Get(url)
+       tr := &Transport{DisableKeepAlives: true} // they interfere with this test
+       c := &Client{Transport: tr}
+       r, _, err := c.Get(url)
        if err != nil {
                t.Fatalf("http Get #1: %v", err)
        }
@@ -335,6 +337,7 @@ func TestIdentityResponse(t *testing.T) {
                        t.Errorf("for %s expected len(res.TransferEncoding) of %d; got %d (%v)",
                                url, expected, tl, res.TransferEncoding)
                }
+               res.Body.Close()
        }
 
        // Verify that ErrContentLength is returned
@@ -343,7 +346,6 @@ func TestIdentityResponse(t *testing.T) {
        if err != nil {
                t.Fatalf("error with Get of %s: %v", url, err)
        }
-
        // Verify that the connection is closed when the declared Content-Length
        // is larger than what the handler wrote.
        conn, err := net.Dial("tcp", "", ts.Listener.Addr().String())
index cea1a3b2401890155980ac24edca79602e103e1f..8a73ead31f9f855c4598232474333a87963fb527 100644 (file)
@@ -9,6 +9,8 @@ import (
        "crypto/tls"
        "encoding/base64"
        "fmt"
+       "io"
+       "log"
        "net"
        "os"
        "strings"
@@ -20,22 +22,24 @@ import (
 // each call to Do and uses HTTP proxies as directed by the
 // $HTTP_PROXY and $NO_PROXY (or $http_proxy and $no_proxy)
 // environment variables.
-var DefaultTransport RoundTripper = &transport{}
-
-// transport implements Tranport for the default case, using TCP
-// connections to either the host or a proxy, serving http or https
-// schemes.  In the future this may become public and support options
-// on keep-alive connection duration, pipelining controls, etc.  For
-// now this is simply a port of the old Go code client code to the
-// Transport interface.
-type transport struct {
-       // TODO: keep-alives, pipelining, etc using a map from
-       // scheme/host to a connection.  Something like:
-       l        sync.Mutex
-       hostConn map[string]*ClientConn
-}
-
-func (ct *transport) RoundTrip(req *Request) (resp *Response, err os.Error) {
+var DefaultTransport RoundTripper = &Transport{}
+
+// Transport is an implementation of RoundTripper that supports http,
+// https, and http proxies (for either http or https with CONNECT).
+// Transport can also cache connections for future re-use.
+type Transport struct {
+       lk       sync.Mutex
+       idleConn map[string][]*persistConn
+
+       // TODO: tunables on max cached connections (total, per-server), duration
+       // TODO: optional pipelining
+
+       IgnoreEnvironment bool // don't look at environment variables for proxy configuration
+       DisableKeepAlives bool
+}
+
+// RoundTrip implements the RoundTripper interface.
+func (t *Transport) RoundTrip(req *Request) (resp *Response, err os.Error) {
        if req.URL == nil {
                if req.URL, err = ParseURL(req.RawURL); err != nil {
                        return
@@ -45,26 +49,71 @@ func (ct *transport) RoundTrip(req *Request) (resp *Response, err os.Error) {
                return nil, &badStringError{"unsupported protocol scheme", req.URL.Scheme}
        }
 
-       addr := req.URL.Host
-       if !hasPort(addr) {
-               addr += ":" + req.URL.Scheme
+       cm, err := t.connectMethodForRequest(req)
+       if err != nil {
+               return nil, err
+       }
+
+       // Get the cached or newly-created connection to either the
+       // host (for http or https), the http proxy, or the http proxy
+       // pre-CONNECTed to https server.  In any case, we'll be ready
+       // to send it requests.
+       pconn, err := t.getConn(cm)
+       if err != nil {
+               return nil, err
        }
 
-       var proxyURL *URL
-       proxyAuth := ""
-       proxy := ""
-       if !matchNoProxy(addr) {
-               proxy = os.Getenv("HTTP_PROXY")
-               if proxy == "" {
-                       proxy = os.Getenv("http_proxy")
+       return pconn.roundTrip(req)
+}
+
+// CloseIdleConnections closes any connections which were previously
+// connected from previous requests but are now sitting idle in
+// a "keep-alive" state. It does not interrupt any connections currently
+// in use.
+func (t *Transport) CloseIdleConnections() {
+       t.lk.Lock()
+       defer t.lk.Unlock()
+       if t.idleConn == nil {
+               return
+       }
+       for _, conns := range t.idleConn {
+               for _, pconn := range conns {
+                       pconn.close()
                }
        }
+       t.idleConn = nil
+}
 
-       var write = (*Request).Write
+//
+// Private implementation past this point.
+//
 
-       if proxy != "" {
-               write = (*Request).WriteProxy
-               proxyURL, err = ParseRequestURL(proxy)
+func (t *Transport) getenvEitherCase(k string) string {
+       if t.IgnoreEnvironment {
+               return ""
+       }
+       if v := t.getenv(strings.ToUpper(k)); v != "" {
+               return v
+       }
+       return t.getenv(strings.ToLower(k))
+}
+
+func (t *Transport) getenv(k string) string {
+       if t.IgnoreEnvironment {
+               return ""
+       }
+       return os.Getenv(k)
+}
+
+func (t *Transport) connectMethodForRequest(req *Request) (*connectMethod, os.Error) {
+       cm := &connectMethod{
+               targetScheme: req.URL.Scheme,
+               targetAddr:   canonicalAddr(req.URL),
+       }
+
+       proxy := t.getenvEitherCase("HTTP_PROXY")
+       if proxy != "" && t.useProxy(cm.targetAddr) {
+               proxyURL, err := ParseRequestURL(proxy)
                if err != nil {
                        return nil, os.ErrorString("invalid proxy address")
                }
@@ -74,83 +123,405 @@ func (ct *transport) RoundTrip(req *Request) (resp *Response, err os.Error) {
                                return nil, os.ErrorString("invalid proxy address")
                        }
                }
-               addr = proxyURL.Host
-               proxyInfo := proxyURL.RawUserinfo
-               if proxyInfo != "" {
-                       enc := base64.URLEncoding
-                       encoded := make([]byte, enc.EncodedLen(len(proxyInfo)))
-                       enc.Encode(encoded, []byte(proxyInfo))
-                       proxyAuth = "Basic " + string(encoded)
+               cm.proxyURL = proxyURL
+       }
+       return cm, nil
+}
+
+// proxyAuth returns the Proxy-Authorization header to set
+// on requests, if applicable.
+func (cm *connectMethod) proxyAuth() string {
+       if cm.proxyURL == nil {
+               return ""
+       }
+       proxyInfo := cm.proxyURL.RawUserinfo
+       if proxyInfo != "" {
+               enc := base64.URLEncoding
+               encoded := make([]byte, enc.EncodedLen(len(proxyInfo)))
+               enc.Encode(encoded, []byte(proxyInfo))
+               return "Basic " + string(encoded)
+       }
+       return ""
+}
+
+func (t *Transport) putIdleConn(pconn *persistConn) {
+       t.lk.Lock()
+       defer t.lk.Unlock()
+       if t.DisableKeepAlives {
+               pconn.close()
+               return
+       }
+       if pconn.isBroken() {
+               return
+       }
+       key := pconn.cacheKey
+       t.idleConn[key] = append(t.idleConn[key], pconn)
+}
+
+func (t *Transport) getIdleConn(cm *connectMethod) (pconn *persistConn) {
+       t.lk.Lock()
+       defer t.lk.Unlock()
+       if t.idleConn == nil {
+               t.idleConn = make(map[string][]*persistConn)
+       }
+       key := cm.String()
+       for {
+               pconns, ok := t.idleConn[key]
+               if !ok {
+                       return nil
+               }
+               if len(pconns) == 1 {
+                       pconn = pconns[0]
+                       t.idleConn[key] = nil, false
+               } else {
+                       // 2 or more cached connections; pop last
+                       // TODO: queue?
+                       pconn = pconns[len(pconns)-1]
+                       t.idleConn[key] = pconns[0 : len(pconns)-1]
+               }
+               if !pconn.isBroken() {
+                       return
                }
        }
+       return
+}
+
+// getConn dials and creates a new persistConn to the target as
+// specified in the connectMethod.  This includes doing a proxy CONNECT
+// and/or setting up TLS.  If this doesn't return an error, the persistConn
+// is ready to write requests to.
+func (t *Transport) getConn(cm *connectMethod) (*persistConn, os.Error) {
+       if pc := t.getIdleConn(cm); pc != nil {
+               return pc, nil
+       }
 
-       // Connect to server or proxy
-       conn, err := net.Dial("tcp", "", addr)
+       conn, err := net.Dial("tcp", "", cm.addr())
        if err != nil {
                return nil, err
        }
 
-       if req.URL.Scheme == "http" {
-               // Include proxy http header if needed.
-               if proxyAuth != "" {
-                       req.Header.Set("Proxy-Authorization", proxyAuth)
-               }
-       } else { // https
-               if proxyURL != nil {
-                       // Ask proxy for direct connection to server.
-                       // addr defaults above to ":https" but we need to use numbers
-                       addr = req.URL.Host
-                       if !hasPort(addr) {
-                               addr += ":443"
-                       }
-                       fmt.Fprintf(conn, "CONNECT %s HTTP/1.1\r\n", addr)
-                       fmt.Fprintf(conn, "Host: %s\r\n", addr)
-                       if proxyAuth != "" {
-                               fmt.Fprintf(conn, "Proxy-Authorization: %s\r\n", proxyAuth)
-                       }
-                       fmt.Fprintf(conn, "\r\n")
+       pa := cm.proxyAuth()
 
-                       // Read response.
-                       // Okay to use and discard buffered reader here, because
-                       // TLS server will not speak until spoken to.
-                       br := bufio.NewReader(conn)
-                       resp, err := ReadResponse(br, "CONNECT")
-                       if err != nil {
-                               return nil, err
-                       }
-                       if resp.StatusCode != 200 {
-                               f := strings.Split(resp.Status, " ", 2)
-                               return nil, os.ErrorString(f[1])
+       pconn := &persistConn{
+               t:        t,
+               cacheKey: cm.String(),
+               conn:     conn,
+               reqch:    make(chan requestAndChan, 50),
+       }
+       newClientConnFunc := NewClientConn
+
+       switch {
+       case cm.proxyURL == nil:
+               // Do nothing.
+       case cm.targetScheme == "http":
+               newClientConnFunc = NewProxyClientConn
+               if pa != "" {
+                       pconn.mutateRequestFunc = func(req *Request) {
+                               if req.Header == nil {
+                                       req.Header = make(Header)
+                               }
+                               req.Header.Set("Proxy-Authorization", pa)
                        }
                }
+       case cm.targetScheme == "https":
+               fmt.Fprintf(conn, "CONNECT %s HTTP/1.1\r\n", cm.targetAddr)
+               fmt.Fprintf(conn, "Host: %s\r\n", cm.targetAddr)
+               if pa != "" {
+                       fmt.Fprintf(conn, "Proxy-Authorization: %s\r\n", pa)
+               }
+               fmt.Fprintf(conn, "\r\n")
 
+               // Read response.
+               // Okay to use and discard buffered reader here, because
+               // TLS server will not speak until spoken to.
+               br := bufio.NewReader(conn)
+               resp, err := ReadResponse(br, "CONNECT")
+               if err != nil {
+                       conn.Close()
+                       return nil, err
+               }
+               if resp.StatusCode != 200 {
+                       f := strings.Split(resp.Status, " ", 2)
+                       conn.Close()
+                       return nil, os.ErrorString(f[1])
+               }
+       }
+
+       if cm.targetScheme == "https" {
                // Initiate TLS and check remote host name against certificate.
                conn = tls.Client(conn, nil)
                if err = conn.(*tls.Conn).Handshake(); err != nil {
                        return nil, err
                }
-               h := req.URL.Host
-               if hasPort(h) {
-                       h = h[:strings.LastIndex(h, ":")]
-               }
-               if err = conn.(*tls.Conn).VerifyHostname(h); err != nil {
+               if err = conn.(*tls.Conn).VerifyHostname(cm.tlsHost()); err != nil {
                        return nil, err
                }
+               pconn.conn = conn
        }
 
-       err = write(req, conn)
-       if err != nil {
-               conn.Close()
-               return nil, err
+       pconn.br = bufio.NewReader(pconn.conn)
+       pconn.cc = newClientConnFunc(conn, pconn.br)
+       pconn.cc.readRes = readResponseWithEOFSignal
+       go pconn.readLoop()
+       return pconn, nil
+}
+
+// useProxy returns true if requests to addr should use a proxy,
+// according to the NO_PROXY or no_proxy environment variable.
+func (t *Transport) useProxy(addr string) bool {
+       if len(addr) == 0 {
+               return true
+       }
+       no_proxy := t.getenvEitherCase("NO_PROXY")
+       if no_proxy == "*" {
+               return false
+       }
+
+       addr = strings.ToLower(strings.TrimSpace(addr))
+       if hasPort(addr) {
+               addr = addr[:strings.LastIndex(addr, ":")]
        }
 
-       reader := bufio.NewReader(conn)
-       resp, err = ReadResponse(reader, req.Method)
+       for _, p := range strings.Split(no_proxy, ",", -1) {
+               p = strings.ToLower(strings.TrimSpace(p))
+               if len(p) == 0 {
+                       continue
+               }
+               if hasPort(p) {
+                       p = p[:strings.LastIndex(p, ":")]
+               }
+               if addr == p || (p[0] == '.' && (strings.HasSuffix(addr, p) || addr == p[1:])) {
+                       return false
+               }
+       }
+       return true
+}
+
+// connectMethod is the map key (in its String form) for keeping persistent
+// TCP connections alive for subsequent HTTP requests.
+//
+// A connect method may be of the following types:
+//
+// Cache key form                Description
+// -----------------             -------------------------
+// ||http|foo.com                http directly to server, no proxy
+// ||https|foo.com               https directly to server, no proxy
+// http://proxy.com|https|foo.com  http to proxy, then CONNECT to foo.com
+// http://proxy.com|http           http to proxy, http to anywhere after that
+//
+// Note: no support to https to the proxy yet.
+//
+type connectMethod struct {
+       proxyURL     *URL   // "" for no proxy, else full proxy URL
+       targetScheme string // "http" or "https"
+       targetAddr   string // Not used if proxy + http targetScheme (4th example in table)
+}
+
+func (ck *connectMethod) String() string {
+       proxyStr := ""
+       if ck.proxyURL != nil {
+               proxyStr = ck.proxyURL.String()
+       }
+       return strings.Join([]string{proxyStr, ck.targetScheme, ck.targetAddr}, "|")
+}
+
+// addr returns the first hop "host:port" to which we need to TCP connect.
+func (cm *connectMethod) addr() string {
+       if cm.proxyURL != nil {
+               return canonicalAddr(cm.proxyURL)
+       }
+       return cm.targetAddr
+}
+
+// tlsHost returns the host name to match against the peer's
+// TLS certificate.
+func (cm *connectMethod) tlsHost() string {
+       h := cm.targetAddr
+       if hasPort(h) {
+               h = h[:strings.LastIndex(h, ":")]
+       }
+       return h
+}
+
+type readResult struct {
+       res *Response // either res or err will be set
+       err os.Error
+}
+
+type writeRequest struct {
+       // Set by client (in pc.roundTrip)
+       req   *Request
+       resch chan *readResult
+
+       // Set by writeLoop if an error writing headers.
+       writeErr os.Error
+}
+
+// persistConn wraps a connection, usually a persistent one
+// (but may be used for non-keep-alive requests as well)
+type persistConn struct {
+       t                 *Transport
+       cacheKey          string // its connectMethod.String()
+       conn              net.Conn
+       cc                *ClientConn
+       br                *bufio.Reader
+       reqch             chan requestAndChan // written by roundTrip(); read by readLoop()
+       mutateRequestFunc func(*Request)      // nil or func to modify each outbound request
+
+       lk                   sync.Mutex // guards numExpectedResponses and broken
+       numExpectedResponses int
+       broken               bool // an error has happened on this connection; marked broken so it's not reused.
+}
+
+func (pc *persistConn) isBroken() bool {
+       pc.lk.Lock()
+       defer pc.lk.Unlock()
+       return pc.broken
+}
+
+func (pc *persistConn) expectingResponse() bool {
+       pc.lk.Lock()
+       defer pc.lk.Unlock()
+       return pc.numExpectedResponses > 0
+}
+
+func (pc *persistConn) readLoop() {
+       alive := true
+       for alive {
+               pb, err := pc.br.Peek(1)
+               if err != nil {
+                       if (err == os.EOF || err == os.EINVAL) && !pc.expectingResponse() {
+                               // Remote side closed on us.  (We probably hit their
+                               // max idle timeout)
+                               pc.close()
+                               return
+                       }
+               }
+               if !pc.expectingResponse() {
+                       log.Printf("Unsolicited response received on idle HTTP channel starting with %q; err=%v",
+                               string(pb), err)
+                       pc.close()
+                       return
+               }
+
+               rc := <-pc.reqch
+               resp, err := pc.cc.Read(rc.req)
+               if err == nil && !rc.req.Close {
+                       pc.t.putIdleConn(pc)
+               }
+               if err == ErrPersistEOF {
+                       // Succeeded, but we can't send any more
+                       // persistent connections on this again.  We
+                       // hide this error to upstream callers.
+                       alive = false
+                       err = nil
+               } else if err != nil {
+                       alive = false
+               }
+               rc.ch <- responseAndError{resp, err}
+
+               // Wait for the just-returned response body to be fully consumed
+               // before we race and peek on the underlying bufio reader.
+               if alive {
+                       <-resp.Body.(*bodyEOFSignal).ch
+               }
+       }
+}
+
+type responseAndError struct {
+       res *Response
+       err os.Error
+}
+
+type requestAndChan struct {
+       req *Request
+       ch  chan responseAndError
+}
+
+func (pc *persistConn) roundTrip(req *Request) (resp *Response, err os.Error) {
+       if pc.mutateRequestFunc != nil {
+               pc.mutateRequestFunc(req)
+       }
+
+       pc.lk.Lock()
+       pc.numExpectedResponses++
+       pc.lk.Unlock()
+
+       err = pc.cc.Write(req)
        if err != nil {
-               conn.Close()
-               return nil, err
+               pc.close()
+               return
+       }
+
+       ch := make(chan responseAndError, 1)
+       pc.reqch <- requestAndChan{req, ch}
+       re := <-ch
+       pc.lk.Lock()
+       pc.numExpectedResponses--
+       pc.lk.Unlock()
+       return re.res, re.err
+}
+
+func (pc *persistConn) close() {
+       pc.lk.Lock()
+       defer pc.lk.Unlock()
+       pc.broken = true
+       pc.cc.Close()
+       pc.conn.Close()
+       pc.mutateRequestFunc = nil
+}
+
+var portMap = map[string]string{
+       "http":  "80",
+       "https": "443",
+}
+
+// canonicalAddr returns url.Host but always with a ":port" suffix
+func canonicalAddr(url *URL) string {
+       addr := url.Host
+       if !hasPort(addr) {
+               return addr + ":" + portMap[url.Scheme]
        }
+       return addr
+}
 
-       resp.Body = readClose{resp.Body, conn}
+func responseIsKeepAlive(res *Response) bool {
+       // TODO: implement.  for now just always shutting down the connection.
+       return false
+}
+
+// readResponseWithEOFSignal is a wrapper around ReadResponse that replaces
+// the response body with a bodyEOFSignal-wrapped version.
+func readResponseWithEOFSignal(r *bufio.Reader, requestMethod string) (resp *Response, err os.Error) {
+       resp, err = ReadResponse(r, requestMethod)
+       if err == nil {
+               resp.Body = &bodyEOFSignal{resp.Body, make(chan bool, 1), false}
+       }
+       return
+}
+
+// bodyEOFSignal wraps a ReadCloser but sends on ch once once
+// the wrapped ReadCloser is fully consumed (including on Close)
+type bodyEOFSignal struct {
+       body io.ReadCloser
+       ch   chan bool
+       done bool
+}
+
+func (es *bodyEOFSignal) Read(p []byte) (n int, err os.Error) {
+       n, err = es.body.Read(p)
+       if err == os.EOF && !es.done {
+               es.ch <- true
+               es.done = true
+       }
+       return
+}
+
+func (es *bodyEOFSignal) Close() (err os.Error) {
+       err = es.body.Close()
+       if err == nil && !es.done {
+               es.ch <- true
+               es.done = true
+       }
        return
 }
index 2bdca7b99b760654744a22dd322c866de64126c6..5c3e1cdb582a92d0a2b11c2719e710646e1f7ba1 100644 (file)
@@ -11,9 +11,204 @@ import (
        . "http"
        "http/httptest"
        "io/ioutil"
+       "os"
        "testing"
+       "time"
 )
 
+// TODO: test 5 pipelined requests with responses: 1) OK, 2) OK, Connection: Close
+//       and then verify that the final 2 responses get errors back.
+
+// hostPortHandler writes back the client's "host:port".
+var hostPortHandler = HandlerFunc(func(w ResponseWriter, r *Request) {
+       if r.FormValue("close") == "true" {
+               w.Header().Set("Connection", "close")
+       }
+       fmt.Fprintf(w, "%s", r.RemoteAddr)
+})
+
+// Two subsequent requests and verify their response is the same.
+// The response from the server is our own IP:port
+func TestTransportKeepAlives(t *testing.T) {
+       ts := httptest.NewServer(hostPortHandler)
+       defer ts.Close()
+
+       for _, disableKeepAlive := range []bool{false, true} {
+               tr := &Transport{DisableKeepAlives: disableKeepAlive}
+               c := &Client{Transport: tr}
+
+               fetch := func(n int) string {
+                       res, _, err := c.Get(ts.URL)
+                       if err != nil {
+                               t.Fatalf("error in disableKeepAlive=%v, req #%d, GET: %v", disableKeepAlive, n, err)
+                       }
+                       body, err := ioutil.ReadAll(res.Body)
+                       if err != nil {
+                               t.Fatalf("error in disableKeepAlive=%v, req #%d, ReadAll: %v", disableKeepAlive, n, err)
+                       }
+                       return string(body)
+               }
+
+               body1 := fetch(1)
+               body2 := fetch(2)
+
+               bodiesDiffer := body1 != body2
+               if bodiesDiffer != disableKeepAlive {
+                       t.Errorf("error in disableKeepAlive=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q",
+                               disableKeepAlive, bodiesDiffer, body1, body2)
+               }
+       }
+}
+
+func TestTransportConnectionCloseOnResponse(t *testing.T) {
+       ts := httptest.NewServer(hostPortHandler)
+       defer ts.Close()
+
+       for _, connectionClose := range []bool{false, true} {
+               tr := &Transport{}
+               c := &Client{Transport: tr}
+
+               fetch := func(n int) string {
+                       req := new(Request)
+                       var err os.Error
+                       req.URL, err = ParseURL(ts.URL + fmt.Sprintf("?close=%v", connectionClose))
+                       if err != nil {
+                               t.Fatalf("URL parse error: %v", err)
+                       }
+                       req.Method = "GET"
+                       req.Proto = "HTTP/1.1"
+                       req.ProtoMajor = 1
+                       req.ProtoMinor = 1
+
+                       res, err := c.Do(req)
+                       if err != nil {
+                               t.Fatalf("error in connectionClose=%v, req #%d, Do: %v", connectionClose, n, err)
+                       }
+                       body, err := ioutil.ReadAll(res.Body)
+                       if err != nil {
+                               t.Fatalf("error in connectionClose=%v, req #%d, ReadAll: %v", connectionClose, n, err)
+                       }
+                       return string(body)
+               }
+
+               body1 := fetch(1)
+               body2 := fetch(2)
+               bodiesDiffer := body1 != body2
+               if bodiesDiffer != connectionClose {
+                       t.Errorf("error in connectionClose=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q",
+                               connectionClose, bodiesDiffer, body1, body2)
+               }
+       }
+}
+
+func TestTransportConnectionCloseOnRequest(t *testing.T) {
+       ts := httptest.NewServer(hostPortHandler)
+       defer ts.Close()
+
+       for _, connectionClose := range []bool{false, true} {
+               tr := &Transport{}
+               c := &Client{Transport: tr}
+
+               fetch := func(n int) string {
+                       req := new(Request)
+                       var err os.Error
+                       req.URL, err = ParseURL(ts.URL)
+                       if err != nil {
+                               t.Fatalf("URL parse error: %v", err)
+                       }
+                       req.Method = "GET"
+                       req.Proto = "HTTP/1.1"
+                       req.ProtoMajor = 1
+                       req.ProtoMinor = 1
+                       req.Close = connectionClose
+
+                       res, err := c.Do(req)
+                       if err != nil {
+                               t.Fatalf("error in connectionClose=%v, req #%d, Do: %v", connectionClose, n, err)
+                       }
+                       body, err := ioutil.ReadAll(res.Body)
+                       if err != nil {
+                               t.Fatalf("error in connectionClose=%v, req #%d, ReadAll: %v", connectionClose, n, err)
+                       }
+                       return string(body)
+               }
+
+               body1 := fetch(1)
+               body2 := fetch(2)
+               bodiesDiffer := body1 != body2
+               if bodiesDiffer != connectionClose {
+                       t.Errorf("error in connectionClose=%v. unexpected bodiesDiffer=%v; body1=%q; body2=%q",
+                               connectionClose, bodiesDiffer, body1, body2)
+               }
+       }
+}
+
+func TestTransportIdleCacheKeys(t *testing.T) {
+       ts := httptest.NewServer(hostPortHandler)
+       defer ts.Close()
+
+       tr := &Transport{DisableKeepAlives: false}
+       c := &Client{Transport: tr}
+
+       if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
+               t.Errorf("After CloseIdleConnections expected %d idle conn cache keys; got %d", e, g)
+       }
+
+       if _, _, err := c.Get(ts.URL); err != nil {
+               t.Error(err)
+       }
+
+       keys := tr.IdleConnKeysForTesting()
+       if e, g := 1, len(keys); e != g {
+               t.Fatalf("After Get expected %d idle conn cache keys; got %d", e, g)
+       }
+
+       if e := "|http|" + ts.Listener.Addr().String(); keys[0] != e {
+               t.Logf("Expected idle cache key %q; got %q", e, keys[0])
+       }
+
+       tr.CloseIdleConnections()
+       if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g {
+               t.Errorf("After CloseIdleConnections expected %d idle conn cache keys; got %d", e, g)
+       }
+}
+
+func TestTransportServerClosingUnexpectedly(t *testing.T) {
+       ts := httptest.NewServer(hostPortHandler)
+       defer ts.Close()
+
+       tr := &Transport{}
+       c := &Client{Transport: tr}
+
+       fetch := func(n int) string {
+               res, _, err := c.Get(ts.URL)
+               if err != nil {
+                       t.Fatalf("error in req #%d, GET: %v", n, err)
+               }
+               body, err := ioutil.ReadAll(res.Body)
+               if err != nil {
+                       t.Fatalf("error in req #%d, ReadAll: %v", n, err)
+               }
+               res.Body.Close()
+               return string(body)
+       }
+
+       body1 := fetch(1)
+       body2 := fetch(2)
+
+       ts.CloseClientConnections() // surprise!
+       time.Sleep(25e6)            // idle for a bit (test is inherently racey, but expectedly)
+
+       body3 := fetch(3)
+
+       if body1 != body2 {
+               t.Errorf("expected body1 and body2 to be equal")
+       }
+       if body2 == body3 {
+               t.Errorf("expected body2 and body3 to be different")
+       }
+}
+
 func TestTransportNilURL(t *testing.T) {
        ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
                fmt.Fprintf(w, "Hi")
@@ -28,9 +223,8 @@ func TestTransportNilURL(t *testing.T) {
        req.ProtoMajor = 1
        req.ProtoMinor = 1
 
-       // TODO(bradfitz): test &transport{} and not DefaultTransport
-       // once Transport is exported.
-       res, err := DefaultTransport.RoundTrip(req)
+       tr := &Transport{}
+       res, err := tr.RoundTrip(req)
        if err != nil {
                t.Fatalf("unexpected RoundTrip error: %v", err)
        }