]> Cypherpunks repositories - gostls13.git/commitdiff
net/http: don't ignore errors in Request.Write
authorBrad Fitzpatrick <bradfitz@golang.org>
Thu, 24 Jul 2014 01:38:13 +0000 (18:38 -0700)
committerBrad Fitzpatrick <bradfitz@golang.org>
Thu, 24 Jul 2014 01:38:13 +0000 (18:38 -0700)
LGTM=josharian, adg
R=golang-codereviews, josharian, adg
CC=golang-codereviews
https://golang.org/cl/119110043

src/pkg/net/http/request.go
src/pkg/net/http/requestwrite_test.go

index 80bff9c0ec6c9de6fb41e9a929ca858c17734f22..131cb6d67eef5bc1a1b744f019ed39360c1e89cb 100644 (file)
@@ -390,10 +390,16 @@ func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header) err
                w = bw
        }
 
-       fmt.Fprintf(w, "%s %s HTTP/1.1\r\n", valueOrDefault(req.Method, "GET"), ruri)
+       _, err := fmt.Fprintf(w, "%s %s HTTP/1.1\r\n", valueOrDefault(req.Method, "GET"), ruri)
+       if err != nil {
+               return err
+       }
 
        // Header lines
-       fmt.Fprintf(w, "Host: %s\r\n", host)
+       _, err = fmt.Fprintf(w, "Host: %s\r\n", host)
+       if err != nil {
+               return err
+       }
 
        // Use the defaultUserAgent unless the Header contains one, which
        // may be blank to not send the header.
@@ -404,7 +410,10 @@ func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header) err
                }
        }
        if userAgent != "" {
-               fmt.Fprintf(w, "User-Agent: %s\r\n", userAgent)
+               _, err = fmt.Fprintf(w, "User-Agent: %s\r\n", userAgent)
+               if err != nil {
+                       return err
+               }
        }
 
        // Process Body,ContentLength,Close,Trailer
@@ -429,7 +438,10 @@ func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header) err
                }
        }
 
-       io.WriteString(w, "\r\n")
+       _, err = io.WriteString(w, "\r\n")
+       if err != nil {
+               return err
+       }
 
        // Write body and trailer
        err = tw.WriteBody(w)
index dc0e204cac98c7c2d86686ff76c438604513e8dc..997010c2b2b41b0ab70578b0d04c747c66f6304e 100644 (file)
@@ -563,3 +563,61 @@ func mustParseURL(s string) *url.URL {
        }
        return u
 }
+
+type writerFunc func([]byte) (int, error)
+
+func (f writerFunc) Write(p []byte) (int, error) { return f(p) }
+
+// TestRequestWriteError tests the Write err != nil checks in (*Request).write.
+func TestRequestWriteError(t *testing.T) {
+       failAfter, writeCount := 0, 0
+       errFail := errors.New("fake write failure")
+
+       // w is the buffered io.Writer to write the request to.  It
+       // fails exactly once on its Nth Write call, as controlled by
+       // failAfter. It also tracks the number of calls in
+       // writeCount.
+       w := struct {
+               io.ByteWriter // to avoid being wrapped by a bufio.Writer
+               io.Writer
+       }{
+               nil,
+               writerFunc(func(p []byte) (n int, err error) {
+                       writeCount++
+                       if failAfter == 0 {
+                               err = errFail
+                       }
+                       failAfter--
+                       return len(p), err
+               }),
+       }
+
+       req, _ := NewRequest("GET", "http://example.com/", nil)
+       const writeCalls = 4 // number of Write calls in current implementation
+       sawGood := false
+       for n := 0; n <= writeCalls+2; n++ {
+               failAfter = n
+               writeCount = 0
+               err := req.Write(w)
+               var wantErr error
+               if n < writeCalls {
+                       wantErr = errFail
+               }
+               if err != wantErr {
+                       t.Errorf("for fail-after %d Writes, err = %v; want %v", n, err, wantErr)
+                       continue
+               }
+               if err == nil {
+                       sawGood = true
+                       if writeCount != writeCalls {
+                               t.Fatalf("writeCalls constant is outdated in test")
+                       }
+               }
+               if writeCount > writeCalls || writeCount > n+1 {
+                       t.Errorf("for fail-after %d, saw unexpectedly high (%d) write calls", n, writeCount)
+               }
+       }
+       if !sawGood {
+               t.Fatalf("writeCalls constant is outdated in test")
+       }
+}