// higher-level protocol details such as redirects,
// authentication, or cookies.
//
- // RoundTrip may modify the request. The request Headers field is
- // guaranteed to be initialized.
- RoundTrip(req *Request) (resp *Response, err os.Error)
+ // RoundTrip should not modify the request, except for
+ // consuming the Body. The request's URL and Header fields
+ // are guaranteed to be initialized.
+ RoundTrip(*Request) (*Response, os.Error)
}
// Given a string of the form "host", "host:port", or "[ipv6::address]:port",
if t == nil {
t = DefaultTransport
if t == nil {
- err = os.NewError("no http.Client.Transport or http.DefaultTransport")
+ err = os.NewError("http: no Client.Transport or DefaultTransport")
return
}
}
+ if req.URL == nil {
+ return nil, os.NewError("http: nil Request.URL")
+ }
+
// Most the callers of send (Get, Post, et al) don't need
// Headers, leaving it uninitialized. We guarantee to the
// Transport that this has been initialized, though.
// hasn't been set to "identity", Write adds "Transfer-Encoding:
// chunked" to the header. Body is closed after it is sent.
func (req *Request) Write(w io.Writer) os.Error {
- return req.write(w, false)
+ return req.write(w, false, nil)
}
// WriteProxy is like Write but writes the request in the form
// either case, WriteProxy also writes a Host header, using either
// req.Host or req.URL.Host.
func (req *Request) WriteProxy(w io.Writer) os.Error {
- return req.write(w, true)
+ return req.write(w, true, nil)
}
func (req *Request) dumpWrite(w io.Writer) os.Error {
return nil
}
-func (req *Request) write(w io.Writer, usingProxy bool) os.Error {
+// extraHeaders may be nil
+func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header) os.Error {
host := req.Host
if host == "" {
if req.URL == nil {
return err
}
+ if extraHeaders != nil {
+ err = extraHeaders.Write(bw)
+ if err != nil {
+ return err
+ }
+ }
+
io.WriteString(bw, "\r\n")
// Write body and trailer
}
}
+// transportRequest is a wrapper around a *Request that adds
+// optional extra headers to write.
+type transportRequest struct {
+ *Request // original request, not to be mutated
+ extra Header // extra headers to write, or nil
+}
+
+func (tr *transportRequest) extraHeaders() Header {
+ if tr.extra == nil {
+ tr.extra = make(Header)
+ }
+ return tr.extra
+}
+
// RoundTrip implements the RoundTripper interface.
func (t *Transport) RoundTrip(req *Request) (resp *Response, err os.Error) {
if req.URL == nil {
return nil, os.NewError("http: nil Request.URL")
}
+ if req.Header == nil {
+ return nil, os.NewError("http: nil Request.Header")
+ }
if req.URL.Scheme != "http" && req.URL.Scheme != "https" {
t.lk.Lock()
var rt RoundTripper
}
return rt.RoundTrip(req)
}
-
- cm, err := t.connectMethodForRequest(req)
+ treq := &transportRequest{Request: req}
+ cm, err := t.connectMethodForRequest(treq)
if err != nil {
return nil, err
}
return nil, err
}
- return pconn.roundTrip(req)
+ return pconn.roundTrip(treq)
}
// RegisterProtocol registers a new protocol with scheme.
return os.Getenv(strings.ToLower(k))
}
-func (t *Transport) connectMethodForRequest(req *Request) (*connectMethod, os.Error) {
+func (t *Transport) connectMethodForRequest(treq *transportRequest) (*connectMethod, os.Error) {
cm := &connectMethod{
- targetScheme: req.URL.Scheme,
- targetAddr: canonicalAddr(req.URL),
+ targetScheme: treq.URL.Scheme,
+ targetAddr: canonicalAddr(treq.URL),
}
if t.Proxy != nil {
var err os.Error
- cm.proxyURL, err = t.Proxy(req)
+ cm.proxyURL, err = t.Proxy(treq.Request)
if err != nil {
return nil, err
}
conn: conn,
reqch: make(chan requestAndChan, 50),
}
- newClientConnFunc := NewClientConn
switch {
case cm.proxyURL == nil:
// Do nothing.
case cm.targetScheme == "http":
- newClientConnFunc = NewProxyClientConn
+ pconn.isProxy = true
if pa != "" {
- pconn.mutateRequestFunc = func(req *Request) {
- if req.Header == nil {
- req.Header = make(Header)
- }
- req.Header.Set("Proxy-Authorization", pa)
+ pconn.mutateHeaderFunc = func(h Header) {
+ h.Set("Proxy-Authorization", pa)
}
}
case cm.targetScheme == "https":
}
pconn.br = bufio.NewReader(pconn.conn)
- pconn.cc = newClientConnFunc(conn, pconn.br)
+ pconn.cc = NewClientConn(conn, pconn.br)
go pconn.readLoop()
return pconn, nil
}
return h
}
-type readResult struct {
- res *Response // either res or err will be set
- err os.Error
-}
-
-type writeRequest struct {
- // Set by client (in pc.roundTrip)
- req *Request
- resch chan *readResult
-
- // Set by writeLoop if an error writing headers.
- writeErr os.Error
-}
-
// persistConn wraps a connection, usually a persistent one
// (but may be used for non-keep-alive requests as well)
type persistConn struct {
- t *Transport
- cacheKey string // its connectMethod.String()
- conn net.Conn
- cc *ClientConn
- br *bufio.Reader
- reqch chan requestAndChan // written by roundTrip(); read by readLoop()
- mutateRequestFunc func(*Request) // nil or func to modify each outbound request
+ t *Transport
+ cacheKey string // its connectMethod.String()
+ conn net.Conn
+ cc *ClientConn
+ br *bufio.Reader
+ reqch chan requestAndChan // written by roundTrip(); read by readLoop()
+ isProxy bool
+
+ // mutateHeaderFunc is an optional func to modify extra
+ // headers on each outbound request before it's written. (the
+ // original Request given to RoundTrip is not modified)
+ mutateHeaderFunc func(Header)
lk sync.Mutex // guards numExpectedResponses and broken
numExpectedResponses int
if err != nil || resp.ContentLength == 0 {
return resp, err
}
- if rc.addedGzip {
- forReq.Header.Del("Accept-Encoding")
- }
if rc.addedGzip && resp.Header.Get("Content-Encoding") == "gzip" {
resp.Header.Del("Content-Encoding")
resp.Header.Del("Content-Length")
addedGzip bool
}
-func (pc *persistConn) roundTrip(req *Request) (resp *Response, err os.Error) {
- if pc.mutateRequestFunc != nil {
- pc.mutateRequestFunc(req)
+func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err os.Error) {
+ if pc.mutateHeaderFunc != nil {
+ pc.mutateHeaderFunc(req.extraHeaders())
}
// Ask for a compressed version if the caller didn't set their
requestedGzip := false
if !pc.t.DisableCompression && req.Header.Get("Accept-Encoding") == "" {
// Request gzip only, not deflate. Deflate is ambiguous and
- // as universally supported anyway.
+ // not as universally supported anyway.
// See: http://www.gzip.org/zlib/zlib_faq.html#faq38
requestedGzip = true
- req.Header.Set("Accept-Encoding", "gzip")
+ req.extraHeaders().Set("Accept-Encoding", "gzip")
}
pc.lk.Lock()
pc.numExpectedResponses++
pc.lk.Unlock()
- err = pc.cc.Write(req)
+ pc.cc.writeReq = func(r *Request, w io.Writer) os.Error {
+ return r.write(w, pc.isProxy, req.extra)
+ }
+
+ err = pc.cc.Write(req.Request)
if err != nil {
pc.close()
return
}
ch := make(chan responseAndError, 1)
- pc.reqch <- requestAndChan{req, ch, requestedGzip}
+ pc.reqch <- requestAndChan{req.Request, ch, requestedGzip}
re := <-ch
pc.lk.Lock()
pc.numExpectedResponses--
pc.broken = true
pc.cc.Close()
pc.conn.Close()
- pc.mutateRequestFunc = nil
+ pc.mutateHeaderFunc = nil
}
var portMap = map[string]string{
// Requests with other accept-encoding should pass through unmodified
{"foo", "foo", false},
// Requests with accept-encoding == gzip should be passed through
- {"gzip", "gzip", true}}
+ {"gzip", "gzip", true},
+}
// Test that the modification made to the Request by the RoundTripper is cleaned up
func TestRoundTripGzip(t *testing.T) {
ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
accept := req.Header.Get("Accept-Encoding")
if expect := req.FormValue("expect_accept"); accept != expect {
- t.Errorf("Accept-Encoding = %q, want %q", accept, expect)
+ t.Errorf("in handler, test %v: Accept-Encoding = %q, want %q",
+ req.FormValue("testnum"), accept, expect)
}
if accept == "gzip" {
rw.Header().Set("Content-Encoding", "gzip")
for i, test := range roundTripTests {
// Test basic request (no accept-encoding)
- req, _ := NewRequest("GET", ts.URL+"?expect_accept="+test.expectAccept, nil)
- req.Header.Set("Accept-Encoding", test.accept)
+ req, _ := NewRequest("GET", fmt.Sprintf("%s/?testnum=%d&expect_accept=%s", ts.URL, i, test.expectAccept), nil)
+ if test.accept != "" {
+ req.Header.Set("Accept-Encoding", test.accept)
+ }
res, err := DefaultTransport.RoundTrip(req)
var body []byte
if test.compressed {
}
if err != nil {
t.Errorf("%d. Error: %q", i, err)
- } else {
- if g, e := string(body), responseBody; g != e {
- t.Errorf("%d. body = %q; want %q", i, g, e)
- }
- if g, e := req.Header.Get("Accept-Encoding"), test.accept; g != e {
- t.Errorf("%d. Accept-Encoding = %q; want %q", i, g, e)
- }
- if g, e := res.Header.Get("Content-Encoding"), test.accept; g != e {
- t.Errorf("%d. Content-Encoding = %q; want %q", i, g, e)
- }
+ continue
+ }
+ if g, e := string(body), responseBody; g != e {
+ t.Errorf("%d. body = %q; want %q", i, g, e)
+ }
+ if g, e := req.Header.Get("Accept-Encoding"), test.accept; g != e {
+ t.Errorf("%d. Accept-Encoding = %q; want %q (it was mutated, in violation of RoundTrip contract)", i, g, e)
+ }
+ if g, e := res.Header.Get("Content-Encoding"), test.accept; g != e {
+ t.Errorf("%d. Content-Encoding = %q; want %q", i, g, e)
}
}