for {
select {
case err := <-writeErrCh:
+ if isSyscallWriteError(err) {
+ // Issue 11745. If we failed to write the request
+ // body, it's possible the server just heard enough
+ // and already wrote to us. Prioritize the server's
+ // response over returning a body write error.
+ select {
+ case re = <-resc:
+ pc.close()
+ break WaitResponse
+ case <-time.After(50 * time.Millisecond):
+ // Fall through.
+ }
+ }
if err != nil {
re = responseAndError{nil, err}
pc.close()
func (fakeLocker) Lock() {}
func (fakeLocker) Unlock() {}
+
+func isSyscallWriteError(err error) bool {
+ switch e := err.(type) {
+ case *url.Error:
+ return isSyscallWriteError(e.Err)
+ case *net.OpError:
+ return e.Op == "write" && isSyscallWriteError(e.Err)
+ case *os.SyscallError:
+ return e.Syscall == "write"
+ default:
+ return false
+ }
+}
"io/ioutil"
"log"
"net"
- "net/http"
. "net/http"
"net/http/httptest"
"net/url"
addrSeen[r.RemoteAddr]++
if r.URL.Path == "/chunked/" {
w.WriteHeader(200)
- w.(http.Flusher).Flush()
+ w.(Flusher).Flush()
} else {
w.Header().Set("Content-Type", strconv.Itoa(len(msg)))
w.WriteHeader(200)
wantLen := []int{len(msg), -1}[pi]
addrSeen = make(map[string]int)
for i := 0; i < 3; i++ {
- res, err := http.Get(ts.URL + path)
+ res, err := Get(ts.URL + path)
if err != nil {
t.Errorf("Get %s: %v", path, err)
continue
// then closes it.
func TestTransportClosesRequestBody(t *testing.T) {
defer afterTest(t)
- ts := httptest.NewServer(http.HandlerFunc(func(w ResponseWriter, r *Request) {
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
io.Copy(ioutil.Discard, r.Body)
}))
defer ts.Close()
// Test for issue 8755
// Ensure that if a proxy returns an error, it is exposed by RoundTrip
func TestRoundTripReturnsProxyError(t *testing.T) {
- badProxy := func(*http.Request) (*url.URL, error) {
+ badProxy := func(*Request) (*url.URL, error) {
return nil, errors.New("errorMessage")
}
tr := &Transport{Proxy: badProxy}
- req, _ := http.NewRequest("GET", "http://example.com", nil)
+ req, _ := NewRequest("GET", "http://example.com", nil)
_, err := tr.RoundTrip(req)
}
}
-func wantBody(res *http.Response, err error, want string) error {
+// Issue 11745.
+func TestTransportPrefersResponseOverWriteError(t *testing.T) {
+ if testing.Short() {
+ t.Skip("skipping in short mode")
+ }
+ defer afterTest(t)
+ const contentLengthLimit = 1024 * 1024 // 1MB
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ if r.ContentLength >= contentLengthLimit {
+ w.WriteHeader(StatusBadRequest)
+ r.Body.Close()
+ return
+ }
+ w.WriteHeader(StatusOK)
+ }))
+ defer ts.Close()
+
+ fail := 0
+ count := 100
+ bigBody := strings.Repeat("a", contentLengthLimit*2)
+ for i := 0; i < count; i++ {
+ req, err := NewRequest("PUT", ts.URL, strings.NewReader(bigBody))
+ if err != nil {
+ t.Fatal(err)
+ }
+ tr := new(Transport)
+ defer tr.CloseIdleConnections()
+ client := &Client{Transport: tr}
+ resp, err := client.Do(req)
+ if err != nil {
+ fail++
+ t.Logf("%d = %#v", i, err)
+ if ue, ok := err.(*url.Error); ok {
+ t.Logf("urlErr = %#v", ue.Err)
+ if ne, ok := ue.Err.(*net.OpError); ok {
+ t.Logf("netOpError = %#v", ne.Err)
+ }
+ }
+ } else {
+ resp.Body.Close()
+ if resp.StatusCode != 400 {
+ t.Errorf("Expected status code 400, got %v", resp.Status)
+ }
+ }
+ }
+ if fail > 0 {
+ t.Errorf("Failed %v out of %v\n", fail, count)
+ }
+}
+
+func wantBody(res *Response, err error, want string) error {
if err != nil {
return err
}