]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: add Transport.ResponseHeaderTimeout
authorBrad Fitzpatrick <bradfitz@golang.org>
Wed, 27 Feb 2013 16:47:08 +0000 (08:47 -0800)
committerBrad Fitzpatrick <bradfitz@golang.org>
Wed, 27 Feb 2013 16:47:08 +0000 (08:47 -0800)
Update #3362

R=golang-dev, adg, rsc
CC=golang-dev
https://golang.org/cl/7369055

src/pkg/net/http/transport.go
src/pkg/net/http/transport_test.go

index a94579de6da2dc623f340bf110398ddc0ef2f917..984c39154e4b3e23cceec6f1b7a169e0b20adc6e 100644 (file)
@@ -73,6 +73,12 @@ type Transport struct {
        // (keep-alive) to keep per-host.  If zero,
        // DefaultMaxIdleConnsPerHost is used.
        MaxIdleConnsPerHost int
+
+       // ResponseHeaderTimeout, if non-zero, specifies the amount of
+       // time to wait for a server's response headers after fully
+       // writing the request (including its body, if any). This
+       // time does not include the time to read the response body.
+       ResponseHeaderTimeout time.Duration
 }
 
 // ProxyFromEnvironment returns the URL of the proxy to use for a
@@ -743,6 +749,7 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err
        var re responseAndError
        var pconnDeadCh = pc.closech
        var failTicker <-chan time.Time
+       var respHeaderTimer <-chan time.Time
 WaitResponse:
        for {
                select {
@@ -752,6 +759,9 @@ WaitResponse:
                                pc.close()
                                break WaitResponse
                        }
+                       if d := pc.t.ResponseHeaderTimeout; d > 0 {
+                               respHeaderTimer = time.After(d)
+                       }
                case <-pconnDeadCh:
                        // The persist connection is dead. This shouldn't
                        // usually happen (only with Connection: close responses
@@ -768,7 +778,11 @@ WaitResponse:
                        pconnDeadCh = nil                               // avoid spinning
                        failTicker = time.After(100 * time.Millisecond) // arbitrary time to wait for resc
                case <-failTicker:
-                       re = responseAndError{nil, errors.New("net/http: transport closed before response was received")}
+                       re = responseAndError{err: errors.New("net/http: transport closed before response was received")}
+                       break WaitResponse
+               case <-respHeaderTimer:
+                       pc.close()
+                       re = responseAndError{err: errors.New("net/http: timeout awaiting response headers")}
                        break WaitResponse
                case re = <-resc:
                        break WaitResponse
index a0ba91735c1e73cf64127fc5f82dc3d0d39031fc..248e1507a9d09270cddc805bdc9478c5e4188c08 100644 (file)
@@ -1113,6 +1113,54 @@ func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) {
        ts.Close()
 }
 
+func TestTransportResponseHeaderTimeout(t *testing.T) {
+       defer checkLeakedTransports(t)
+       if testing.Short() {
+               t.Skip("skipping timeout test in -short mode")
+       }
+       const debug = false
+       mux := NewServeMux()
+       mux.HandleFunc("/fast", func(w ResponseWriter, r *Request) {})
+       mux.HandleFunc("/slow", func(w ResponseWriter, r *Request) {
+               time.Sleep(2 * time.Second)
+       })
+       ts := httptest.NewServer(mux)
+       defer ts.Close()
+
+       tr := &Transport{
+               ResponseHeaderTimeout: 500 * time.Millisecond,
+       }
+       defer tr.CloseIdleConnections()
+       c := &Client{Transport: tr}
+
+       tests := []struct {
+               path    string
+               want    int
+               wantErr string
+       }{
+               {path: "/fast", want: 200},
+               {path: "/slow", wantErr: "timeout awaiting response headers"},
+               {path: "/fast", want: 200},
+       }
+       for i, tt := range tests {
+               res, err := c.Get(ts.URL + tt.path)
+               if err != nil {
+                       if strings.Contains(err.Error(), tt.wantErr) {
+                               continue
+                       }
+                       t.Errorf("%d. unexpected error: %v", i, err)
+                       continue
+               }
+               if tt.wantErr != "" {
+                       t.Errorf("%d. no error. expected error: %v", i, tt.wantErr)
+                       continue
+               }
+               if res.StatusCode != tt.want {
+                       t.Errorf("%d for path %q status = %d; want %d", i, tt.path, res.StatusCode, tt.want)
+               }
+       }
+}
+
 type fooProto struct{}
 
 func (fooProto) RoundTrip(req *Request) (*Response, error) {