w = bw
}
- fmt.Fprintf(w, "%s %s HTTP/1.1\r\n", valueOrDefault(req.Method, "GET"), ruri)
+ _, err := fmt.Fprintf(w, "%s %s HTTP/1.1\r\n", valueOrDefault(req.Method, "GET"), ruri)
+ if err != nil {
+ return err
+ }
// Header lines
- fmt.Fprintf(w, "Host: %s\r\n", host)
+ _, err = fmt.Fprintf(w, "Host: %s\r\n", host)
+ if err != nil {
+ return err
+ }
// Use the defaultUserAgent unless the Header contains one, which
// may be blank to not send the header.
}
}
if userAgent != "" {
- fmt.Fprintf(w, "User-Agent: %s\r\n", userAgent)
+ _, err = fmt.Fprintf(w, "User-Agent: %s\r\n", userAgent)
+ if err != nil {
+ return err
+ }
}
// Process Body,ContentLength,Close,Trailer
}
}
- io.WriteString(w, "\r\n")
+ _, err = io.WriteString(w, "\r\n")
+ if err != nil {
+ return err
+ }
// Write body and trailer
err = tw.WriteBody(w)
}
return u
}
+
+type writerFunc func([]byte) (int, error)
+
+func (f writerFunc) Write(p []byte) (int, error) { return f(p) }
+
+// TestRequestWriteError tests the Write err != nil checks in (*Request).write.
+func TestRequestWriteError(t *testing.T) {
+ failAfter, writeCount := 0, 0
+ errFail := errors.New("fake write failure")
+
+ // w is the buffered io.Writer to write the request to. It
+ // fails exactly once on its Nth Write call, as controlled by
+ // failAfter. It also tracks the number of calls in
+ // writeCount.
+ w := struct {
+ io.ByteWriter // to avoid being wrapped by a bufio.Writer
+ io.Writer
+ }{
+ nil,
+ writerFunc(func(p []byte) (n int, err error) {
+ writeCount++
+ if failAfter == 0 {
+ err = errFail
+ }
+ failAfter--
+ return len(p), err
+ }),
+ }
+
+ req, _ := NewRequest("GET", "http://example.com/", nil)
+ const writeCalls = 4 // number of Write calls in current implementation
+ sawGood := false
+ for n := 0; n <= writeCalls+2; n++ {
+ failAfter = n
+ writeCount = 0
+ err := req.Write(w)
+ var wantErr error
+ if n < writeCalls {
+ wantErr = errFail
+ }
+ if err != wantErr {
+ t.Errorf("for fail-after %d Writes, err = %v; want %v", n, err, wantErr)
+ continue
+ }
+ if err == nil {
+ sawGood = true
+ if writeCount != writeCalls {
+ t.Fatalf("writeCalls constant is outdated in test")
+ }
+ }
+ if writeCount > writeCalls || writeCount > n+1 {
+ t.Errorf("for fail-after %d, saw unexpectedly high (%d) write calls", n, writeCount)
+ }
+ }
+ if !sawGood {
+ t.Fatalf("writeCalls constant is outdated in test")
+ }
+}