// (keep-alive) to keep per-host. If zero,
// DefaultMaxIdleConnsPerHost is used.
MaxIdleConnsPerHost int
+
+ // ResponseHeaderTimeout, if non-zero, specifies the amount of
+ // time to wait for a server's response headers after fully
+ // writing the request (including its body, if any). This
+ // time does not include the time to read the response body.
+ ResponseHeaderTimeout time.Duration
}
// ProxyFromEnvironment returns the URL of the proxy to use for a
var re responseAndError
var pconnDeadCh = pc.closech
var failTicker <-chan time.Time
+ var respHeaderTimer <-chan time.Time
WaitResponse:
for {
select {
pc.close()
break WaitResponse
}
+ if d := pc.t.ResponseHeaderTimeout; d > 0 {
+ respHeaderTimer = time.After(d)
+ }
case <-pconnDeadCh:
// The persist connection is dead. This shouldn't
// usually happen (only with Connection: close responses
pconnDeadCh = nil // avoid spinning
failTicker = time.After(100 * time.Millisecond) // arbitrary time to wait for resc
case <-failTicker:
- re = responseAndError{nil, errors.New("net/http: transport closed before response was received")}
+ re = responseAndError{err: errors.New("net/http: transport closed before response was received")}
+ break WaitResponse
+ case <-respHeaderTimer:
+ pc.close()
+ re = responseAndError{err: errors.New("net/http: timeout awaiting response headers")}
break WaitResponse
case re = <-resc:
break WaitResponse
ts.Close()
}
+func TestTransportResponseHeaderTimeout(t *testing.T) {
+ defer checkLeakedTransports(t)
+ if testing.Short() {
+ t.Skip("skipping timeout test in -short mode")
+ }
+ const debug = false
+ mux := NewServeMux()
+ mux.HandleFunc("/fast", func(w ResponseWriter, r *Request) {})
+ mux.HandleFunc("/slow", func(w ResponseWriter, r *Request) {
+ time.Sleep(2 * time.Second)
+ })
+ ts := httptest.NewServer(mux)
+ defer ts.Close()
+
+ tr := &Transport{
+ ResponseHeaderTimeout: 500 * time.Millisecond,
+ }
+ defer tr.CloseIdleConnections()
+ c := &Client{Transport: tr}
+
+ tests := []struct {
+ path string
+ want int
+ wantErr string
+ }{
+ {path: "/fast", want: 200},
+ {path: "/slow", wantErr: "timeout awaiting response headers"},
+ {path: "/fast", want: 200},
+ }
+ for i, tt := range tests {
+ res, err := c.Get(ts.URL + tt.path)
+ if err != nil {
+ if strings.Contains(err.Error(), tt.wantErr) {
+ continue
+ }
+ t.Errorf("%d. unexpected error: %v", i, err)
+ continue
+ }
+ if tt.wantErr != "" {
+ t.Errorf("%d. no error. expected error: %v", i, tt.wantErr)
+ continue
+ }
+ if res.StatusCode != tt.want {
+ t.Errorf("%d for path %q status = %d; want %d", i, tt.path, res.StatusCode, tt.want)
+ }
+ }
+}
+
type fooProto struct{}
func (fooProto) RoundTrip(req *Request) (*Response, error) {