]> Cypherpunks repositories - gostls13.git/commitdiff
http: RoundTrippers shouldn't mutate Request
authorBrad Fitzpatrick <bradfitz@golang.org>
Fri, 14 Oct 2011 21:16:43 +0000 (14:16 -0700)
committerBrad Fitzpatrick <bradfitz@golang.org>
Fri, 14 Oct 2011 21:16:43 +0000 (14:16 -0700)
Fixes #2146

R=rsc
CC=golang-dev
https://golang.org/cl/5284041

src/pkg/http/client.go
src/pkg/http/request.go
src/pkg/http/transport.go
src/pkg/http/transport_test.go

index 8997a07923b4c3d73f22125fe5be200bdc0dc4f0..bce9014c4bfb18fa98262dc7f0944b7b20e5c2ea 100644 (file)
@@ -56,9 +56,10 @@ type RoundTripper interface {
        // higher-level protocol details such as redirects,
        // authentication, or cookies.
        //
-       // RoundTrip may modify the request. The request Headers field is
-       // guaranteed to be initialized.
-       RoundTrip(req *Request) (resp *Response, err os.Error)
+       // RoundTrip should not modify the request, except for
+       // consuming the Body.  The request's URL and Header fields
+       // are guaranteed to be initialized.
+       RoundTrip(*Request) (*Response, os.Error)
 }
 
 // Given a string of the form "host", "host:port", or "[ipv6::address]:port",
@@ -96,11 +97,15 @@ func send(req *Request, t RoundTripper) (resp *Response, err os.Error) {
        if t == nil {
                t = DefaultTransport
                if t == nil {
-                       err = os.NewError("no http.Client.Transport or http.DefaultTransport")
+                       err = os.NewError("http: no Client.Transport or DefaultTransport")
                        return
                }
        }
 
+       if req.URL == nil {
+               return nil, os.NewError("http: nil Request.URL")
+       }
+
        // Most the callers of send (Get, Post, et al) don't need
        // Headers, leaving it uninitialized.  We guarantee to the
        // Transport that this has been initialized, though.
index 4f555ff5753fcba7df42dc6c3f9a8d3622f90fa4..02317e0c41944317e9e0c8afbbe9547dab8d2410 100644 (file)
@@ -275,7 +275,7 @@ const defaultUserAgent = "Go http package"
 // hasn't been set to "identity", Write adds "Transfer-Encoding:
 // chunked" to the header. Body is closed after it is sent.
 func (req *Request) Write(w io.Writer) os.Error {
-       return req.write(w, false)
+       return req.write(w, false, nil)
 }
 
 // WriteProxy is like Write but writes the request in the form
@@ -285,7 +285,7 @@ func (req *Request) Write(w io.Writer) os.Error {
 // either case, WriteProxy also writes a Host header, using either
 // req.Host or req.URL.Host.
 func (req *Request) WriteProxy(w io.Writer) os.Error {
-       return req.write(w, true)
+       return req.write(w, true, nil)
 }
 
 func (req *Request) dumpWrite(w io.Writer) os.Error {
@@ -333,7 +333,8 @@ func (req *Request) dumpWrite(w io.Writer) os.Error {
        return nil
 }
 
-func (req *Request) write(w io.Writer, usingProxy bool) os.Error {
+// extraHeaders may be nil
+func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header) os.Error {
        host := req.Host
        if host == "" {
                if req.URL == nil {
@@ -394,6 +395,13 @@ func (req *Request) write(w io.Writer, usingProxy bool) os.Error {
                return err
        }
 
+       if extraHeaders != nil {
+               err = extraHeaders.Write(bw)
+               if err != nil {
+                       return err
+               }
+       }
+
        io.WriteString(bw, "\r\n")
 
        // Write body and trailer
index d46d565677f2f72da60279546abc524c638aaf1e..b0aea970873073d344b61ff35c12d0fc8453cf63 100644 (file)
@@ -100,11 +100,28 @@ func ProxyURL(fixedURL *url.URL) func(*Request) (*url.URL, os.Error) {
        }
 }
 
+// transportRequest is a wrapper around a *Request that adds
+// optional extra headers to write.
+type transportRequest struct {
+       *Request        // original request, not to be mutated
+       extra    Header // extra headers to write, or nil
+}
+
+func (tr *transportRequest) extraHeaders() Header {
+       if tr.extra == nil {
+               tr.extra = make(Header)
+       }
+       return tr.extra
+}
+
 // RoundTrip implements the RoundTripper interface.
 func (t *Transport) RoundTrip(req *Request) (resp *Response, err os.Error) {
        if req.URL == nil {
                return nil, os.NewError("http: nil Request.URL")
        }
+       if req.Header == nil {
+               return nil, os.NewError("http: nil Request.Header")
+       }
        if req.URL.Scheme != "http" && req.URL.Scheme != "https" {
                t.lk.Lock()
                var rt RoundTripper
@@ -117,8 +134,8 @@ func (t *Transport) RoundTrip(req *Request) (resp *Response, err os.Error) {
                }
                return rt.RoundTrip(req)
        }
-
-       cm, err := t.connectMethodForRequest(req)
+       treq := &transportRequest{Request: req}
+       cm, err := t.connectMethodForRequest(treq)
        if err != nil {
                return nil, err
        }
@@ -132,7 +149,7 @@ func (t *Transport) RoundTrip(req *Request) (resp *Response, err os.Error) {
                return nil, err
        }
 
-       return pconn.roundTrip(req)
+       return pconn.roundTrip(treq)
 }
 
 // RegisterProtocol registers a new protocol with scheme.
@@ -185,14 +202,14 @@ func getenvEitherCase(k string) string {
        return os.Getenv(strings.ToLower(k))
 }
 
-func (t *Transport) connectMethodForRequest(req *Request) (*connectMethod, os.Error) {
+func (t *Transport) connectMethodForRequest(treq *transportRequest) (*connectMethod, os.Error) {
        cm := &connectMethod{
-               targetScheme: req.URL.Scheme,
-               targetAddr:   canonicalAddr(req.URL),
+               targetScheme: treq.URL.Scheme,
+               targetAddr:   canonicalAddr(treq.URL),
        }
        if t.Proxy != nil {
                var err os.Error
-               cm.proxyURL, err = t.Proxy(req)
+               cm.proxyURL, err = t.Proxy(treq.Request)
                if err != nil {
                        return nil, err
                }
@@ -295,19 +312,15 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, os.Error) {
                conn:     conn,
                reqch:    make(chan requestAndChan, 50),
        }
-       newClientConnFunc := NewClientConn
 
        switch {
        case cm.proxyURL == nil:
                // Do nothing.
        case cm.targetScheme == "http":
-               newClientConnFunc = NewProxyClientConn
+               pconn.isProxy = true
                if pa != "" {
-                       pconn.mutateRequestFunc = func(req *Request) {
-                               if req.Header == nil {
-                                       req.Header = make(Header)
-                               }
-                               req.Header.Set("Proxy-Authorization", pa)
+                       pconn.mutateHeaderFunc = func(h Header) {
+                               h.Set("Proxy-Authorization", pa)
                        }
                }
        case cm.targetScheme == "https":
@@ -351,7 +364,7 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, os.Error) {
        }
 
        pconn.br = bufio.NewReader(pconn.conn)
-       pconn.cc = newClientConnFunc(conn, pconn.br)
+       pconn.cc = NewClientConn(conn, pconn.br)
        go pconn.readLoop()
        return pconn, nil
 }
@@ -447,30 +460,21 @@ func (cm *connectMethod) tlsHost() string {
        return h
 }
 
-type readResult struct {
-       res *Response // either res or err will be set
-       err os.Error
-}
-
-type writeRequest struct {
-       // Set by client (in pc.roundTrip)
-       req   *Request
-       resch chan *readResult
-
-       // Set by writeLoop if an error writing headers.
-       writeErr os.Error
-}
-
 // persistConn wraps a connection, usually a persistent one
 // (but may be used for non-keep-alive requests as well)
 type persistConn struct {
-       t                 *Transport
-       cacheKey          string // its connectMethod.String()
-       conn              net.Conn
-       cc                *ClientConn
-       br                *bufio.Reader
-       reqch             chan requestAndChan // written by roundTrip(); read by readLoop()
-       mutateRequestFunc func(*Request)      // nil or func to modify each outbound request
+       t        *Transport
+       cacheKey string // its connectMethod.String()
+       conn     net.Conn
+       cc       *ClientConn
+       br       *bufio.Reader
+       reqch    chan requestAndChan // written by roundTrip(); read by readLoop()
+       isProxy  bool
+
+       // mutateHeaderFunc is an optional func to modify extra
+       // headers on each outbound request before it's written. (the
+       // original Request given to RoundTrip is not modified)
+       mutateHeaderFunc func(Header)
 
        lk                   sync.Mutex // guards numExpectedResponses and broken
        numExpectedResponses int
@@ -526,9 +530,6 @@ func (pc *persistConn) readLoop() {
                        if err != nil || resp.ContentLength == 0 {
                                return resp, err
                        }
-                       if rc.addedGzip {
-                               forReq.Header.Del("Accept-Encoding")
-                       }
                        if rc.addedGzip && resp.Header.Get("Content-Encoding") == "gzip" {
                                resp.Header.Del("Content-Encoding")
                                resp.Header.Del("Content-Length")
@@ -604,9 +605,9 @@ type requestAndChan struct {
        addedGzip bool
 }
 
-func (pc *persistConn) roundTrip(req *Request) (resp *Response, err os.Error) {
-       if pc.mutateRequestFunc != nil {
-               pc.mutateRequestFunc(req)
+func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err os.Error) {
+       if pc.mutateHeaderFunc != nil {
+               pc.mutateHeaderFunc(req.extraHeaders())
        }
 
        // Ask for a compressed version if the caller didn't set their
@@ -616,24 +617,28 @@ func (pc *persistConn) roundTrip(req *Request) (resp *Response, err os.Error) {
        requestedGzip := false
        if !pc.t.DisableCompression && req.Header.Get("Accept-Encoding") == "" {
                // Request gzip only, not deflate. Deflate is ambiguous and 
-               // as universally supported anyway.
+               // not as universally supported anyway.
                // See: http://www.gzip.org/zlib/zlib_faq.html#faq38
                requestedGzip = true
-               req.Header.Set("Accept-Encoding", "gzip")
+               req.extraHeaders().Set("Accept-Encoding", "gzip")
        }
 
        pc.lk.Lock()
        pc.numExpectedResponses++
        pc.lk.Unlock()
 
-       err = pc.cc.Write(req)
+       pc.cc.writeReq = func(r *Request, w io.Writer) os.Error {
+               return r.write(w, pc.isProxy, req.extra)
+       }
+
+       err = pc.cc.Write(req.Request)
        if err != nil {
                pc.close()
                return
        }
 
        ch := make(chan responseAndError, 1)
-       pc.reqch <- requestAndChan{req, ch, requestedGzip}
+       pc.reqch <- requestAndChan{req.Request, ch, requestedGzip}
        re := <-ch
        pc.lk.Lock()
        pc.numExpectedResponses--
@@ -648,7 +653,7 @@ func (pc *persistConn) close() {
        pc.broken = true
        pc.cc.Close()
        pc.conn.Close()
-       pc.mutateRequestFunc = nil
+       pc.mutateHeaderFunc = nil
 }
 
 var portMap = map[string]string{
index a5dfe5ee3caf931a6db7ff80dae1f72dc582548a..f3162b9ede43f47fb6bb4a45d17f78deb5bb9af4 100644 (file)
@@ -372,7 +372,8 @@ var roundTripTests = []struct {
        // Requests with other accept-encoding should pass through unmodified
        {"foo", "foo", false},
        // Requests with accept-encoding == gzip should be passed through
-       {"gzip", "gzip", true}}
+       {"gzip", "gzip", true},
+}
 
 // Test that the modification made to the Request by the RoundTripper is cleaned up
 func TestRoundTripGzip(t *testing.T) {
@@ -380,7 +381,8 @@ func TestRoundTripGzip(t *testing.T) {
        ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
                accept := req.Header.Get("Accept-Encoding")
                if expect := req.FormValue("expect_accept"); accept != expect {
-                       t.Errorf("Accept-Encoding = %q, want %q", accept, expect)
+                       t.Errorf("in handler, test %v: Accept-Encoding = %q, want %q",
+                               req.FormValue("testnum"), accept, expect)
                }
                if accept == "gzip" {
                        rw.Header().Set("Content-Encoding", "gzip")
@@ -396,8 +398,10 @@ func TestRoundTripGzip(t *testing.T) {
 
        for i, test := range roundTripTests {
                // Test basic request (no accept-encoding)
-               req, _ := NewRequest("GET", ts.URL+"?expect_accept="+test.expectAccept, nil)
-               req.Header.Set("Accept-Encoding", test.accept)
+               req, _ := NewRequest("GET", fmt.Sprintf("%s/?testnum=%d&expect_accept=%s", ts.URL, i, test.expectAccept), nil)
+               if test.accept != "" {
+                       req.Header.Set("Accept-Encoding", test.accept)
+               }
                res, err := DefaultTransport.RoundTrip(req)
                var body []byte
                if test.compressed {
@@ -409,16 +413,16 @@ func TestRoundTripGzip(t *testing.T) {
                }
                if err != nil {
                        t.Errorf("%d. Error: %q", i, err)
-               } else {
-                       if g, e := string(body), responseBody; g != e {
-                               t.Errorf("%d. body = %q; want %q", i, g, e)
-                       }
-                       if g, e := req.Header.Get("Accept-Encoding"), test.accept; g != e {
-                               t.Errorf("%d. Accept-Encoding = %q; want %q", i, g, e)
-                       }
-                       if g, e := res.Header.Get("Content-Encoding"), test.accept; g != e {
-                               t.Errorf("%d. Content-Encoding = %q; want %q", i, g, e)
-                       }
+                       continue
+               }
+               if g, e := string(body), responseBody; g != e {
+                       t.Errorf("%d. body = %q; want %q", i, g, e)
+               }
+               if g, e := req.Header.Get("Accept-Encoding"), test.accept; g != e {
+                       t.Errorf("%d. Accept-Encoding = %q; want %q (it was mutated, in violation of RoundTrip contract)", i, g, e)
+               }
+               if g, e := res.Header.Get("Content-Encoding"), test.accept; g != e {
+                       t.Errorf("%d. Content-Encoding = %q; want %q", i, g, e)
                }
        }