]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: on Transport body write error, wait briefly for a response
authorBrad Fitzpatrick <bradfitz@golang.org>
Thu, 23 Jul 2015 21:05:54 +0000 (14:05 -0700)
committerRuss Cox <rsc@golang.org>
Mon, 27 Jul 2015 16:24:01 +0000 (16:24 +0000)
From https://github.com/golang/go/issues/11745#issuecomment-123555313 :

The http.RoundTripper interface says you get either a *Response, or an
error.

But in the case of a client writing a large request and the server
replying prematurely (e.g. 403 Forbidden) and closing the connection
without reading the request body, what does the client want? The 403
response, or the error that the body couldn't be copied?

This CL implements the aforementioned comment's option c), making the
Transport give an N millisecond advantage to responses over body write
errors.

Updates #11745

Change-Id: I4485a782505d54de6189f6856a7a1f33ce4d5e5e
Reviewed-on: https://go-review.googlesource.com/12590
Reviewed-by: Russ Cox <rsc@golang.org>
src/net/http/transport.go
src/net/http/transport_test.go

index 41e02fc5802befdc020bf459519967f2d5aafbb9..6f181efc1a128e0a10c10482e506aad29f5b8e2a 100644 (file)
@@ -1164,6 +1164,19 @@ WaitResponse:
        for {
                select {
                case err := <-writeErrCh:
+                       if isSyscallWriteError(err) {
+                               // Issue 11745. If we failed to write the request
+                               // body, it's possible the server just heard enough
+                               // and already wrote to us. Prioritize the server's
+                               // response over returning a body write error.
+                               select {
+                               case re = <-resc:
+                                       pc.close()
+                                       break WaitResponse
+                               case <-time.After(50 * time.Millisecond):
+                                       // Fall through.
+                               }
+                       }
                        if err != nil {
                                re = responseAndError{nil, err}
                                pc.close()
@@ -1366,3 +1379,16 @@ type fakeLocker struct{}
 
 func (fakeLocker) Lock()   {}
 func (fakeLocker) Unlock() {}
+
+func isSyscallWriteError(err error) bool {
+       switch e := err.(type) {
+       case *url.Error:
+               return isSyscallWriteError(e.Err)
+       case *net.OpError:
+               return e.Op == "write" && isSyscallWriteError(e.Err)
+       case *os.SyscallError:
+               return e.Syscall == "write"
+       default:
+               return false
+       }
+}
index d399552e4718d1e9999e7b7d86f103f14d9ff638..cae254b4daf00c1f71a8c27eac2cad47f444a8b4 100644 (file)
@@ -18,7 +18,6 @@ import (
        "io/ioutil"
        "log"
        "net"
-       "net/http"
        . "net/http"
        "net/http/httptest"
        "net/url"
@@ -320,7 +319,7 @@ func TestTransportReadToEndReusesConn(t *testing.T) {
                addrSeen[r.RemoteAddr]++
                if r.URL.Path == "/chunked/" {
                        w.WriteHeader(200)
-                       w.(http.Flusher).Flush()
+                       w.(Flusher).Flush()
                } else {
                        w.Header().Set("Content-Type", strconv.Itoa(len(msg)))
                        w.WriteHeader(200)
@@ -335,7 +334,7 @@ func TestTransportReadToEndReusesConn(t *testing.T) {
                wantLen := []int{len(msg), -1}[pi]
                addrSeen = make(map[string]int)
                for i := 0; i < 3; i++ {
-                       res, err := http.Get(ts.URL + path)
+                       res, err := Get(ts.URL + path)
                        if err != nil {
                                t.Errorf("Get %s: %v", path, err)
                                continue
@@ -1976,7 +1975,7 @@ func TestIdleConnChannelLeak(t *testing.T) {
 // then closes it.
 func TestTransportClosesRequestBody(t *testing.T) {
        defer afterTest(t)
-       ts := httptest.NewServer(http.HandlerFunc(func(w ResponseWriter, r *Request) {
+       ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
                io.Copy(ioutil.Discard, r.Body)
        }))
        defer ts.Close()
@@ -2346,13 +2345,13 @@ func TestTransportDialTLS(t *testing.T) {
 // Test for issue 8755
 // Ensure that if a proxy returns an error, it is exposed by RoundTrip
 func TestRoundTripReturnsProxyError(t *testing.T) {
-       badProxy := func(*http.Request) (*url.URL, error) {
+       badProxy := func(*Request) (*url.URL, error) {
                return nil, errors.New("errorMessage")
        }
 
        tr := &Transport{Proxy: badProxy}
 
-       req, _ := http.NewRequest("GET", "http://example.com", nil)
+       req, _ := NewRequest("GET", "http://example.com", nil)
 
        _, err := tr.RoundTrip(req)
 
@@ -2644,7 +2643,57 @@ func TestTransportFlushesBodyChunks(t *testing.T) {
        }
 }
 
-func wantBody(res *http.Response, err error, want string) error {
+// Issue 11745.
+func TestTransportPrefersResponseOverWriteError(t *testing.T) {
+       if testing.Short() {
+               t.Skip("skipping in short mode")
+       }
+       defer afterTest(t)
+       const contentLengthLimit = 1024 * 1024 // 1MB
+       ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+               if r.ContentLength >= contentLengthLimit {
+                       w.WriteHeader(StatusBadRequest)
+                       r.Body.Close()
+                       return
+               }
+               w.WriteHeader(StatusOK)
+       }))
+       defer ts.Close()
+
+       fail := 0
+       count := 100
+       bigBody := strings.Repeat("a", contentLengthLimit*2)
+       for i := 0; i < count; i++ {
+               req, err := NewRequest("PUT", ts.URL, strings.NewReader(bigBody))
+               if err != nil {
+                       t.Fatal(err)
+               }
+               tr := new(Transport)
+               defer tr.CloseIdleConnections()
+               client := &Client{Transport: tr}
+               resp, err := client.Do(req)
+               if err != nil {
+                       fail++
+                       t.Logf("%d = %#v", i, err)
+                       if ue, ok := err.(*url.Error); ok {
+                               t.Logf("urlErr = %#v", ue.Err)
+                               if ne, ok := ue.Err.(*net.OpError); ok {
+                                       t.Logf("netOpError = %#v", ne.Err)
+                               }
+                       }
+               } else {
+                       resp.Body.Close()
+                       if resp.StatusCode != 400 {
+                               t.Errorf("Expected status code 400, got %v", resp.Status)
+                       }
+               }
+       }
+       if fail > 0 {
+               t.Errorf("Failed %v out of %v\n", fail, count)
+       }
+}
+
+func wantBody(res *Response, err error, want string) error {
        if err != nil {
                return err
        }