return 1, io.EOF
}
-// transferBodyReader is an io.Reader that reads from tw.Body
-// and records any non-EOF error in tw.bodyReadError.
-// It is exactly 1 pointer wide to avoid allocations into interfaces.
-type transferBodyReader struct{ tw *transferWriter }
-
-func (br transferBodyReader) Read(p []byte) (n int, err error) {
- n, err = br.tw.Body.Read(p)
- if err != nil && err != io.EOF {
- br.tw.bodyReadError = err
- }
- return
-}
-
// transferWriter inspects the fields of a user-supplied Request or Response,
// sanitizes them without changing the user object and provides methods for
// writing the respective header, body and trailer in wire format.
var err error
var ncopy int64
- // Write body
+ // Write body. We "unwrap" the body first if it was wrapped in a
+ // nopCloser. This is to ensure that we can take advantage of
+ // OS-level optimizations in the event that the body is an
+ // *os.File.
if t.Body != nil {
- var body = transferBodyReader{t}
+ var body = t.unwrapBody()
if chunked(t.TransferEncoding) {
if bw, ok := w.(*bufio.Writer); ok && !t.IsResponse {
w = &internal.FlushAfterChunkWriter{Writer: bw}
}
cw := internal.NewChunkedWriter(w)
- _, err = io.Copy(cw, body)
+ _, err = t.doBodyCopy(cw, body)
if err == nil {
err = cw.Close()
}
if t.Method == "CONNECT" {
dst = bufioFlushWriter{dst}
}
- ncopy, err = io.Copy(dst, body)
+ ncopy, err = t.doBodyCopy(dst, body)
} else {
- ncopy, err = io.Copy(w, io.LimitReader(body, t.ContentLength))
+ ncopy, err = t.doBodyCopy(w, io.LimitReader(body, t.ContentLength))
if err != nil {
return err
}
var nextra int64
- nextra, err = io.Copy(ioutil.Discard, body)
+ nextra, err = t.doBodyCopy(ioutil.Discard, body)
ncopy += nextra
}
if err != nil {
return err
}
+// doBodyCopy wraps a copy operation, with any resulting error also
+// being saved in bodyReadError.
+//
+// This function is only intended for use in writeBody.
+func (t *transferWriter) doBodyCopy(dst io.Writer, src io.Reader) (n int64, err error) {
+ n, err = io.Copy(dst, src)
+ if err != nil && err != io.EOF {
+ t.bodyReadError = err
+ }
+ return
+}
+
+// unwrapBodyReader unwraps the body's inner reader if it's a
+// nopCloser. This is to ensure that body writes sourced from local
+// files (*os.File types) are properly optimized.
+//
+// This function is only intended for use in writeBody.
+func (t *transferWriter) unwrapBody() io.Reader {
+ if reflect.TypeOf(t.Body) == nopCloserType {
+ return reflect.ValueOf(t.Body).Field(0).Interface().(io.Reader)
+ }
+
+ return t.Body
+}
+
type transferReader struct {
// Input
Header Header
import (
"bufio"
"bytes"
+ "crypto/rand"
+ "fmt"
"io"
"io/ioutil"
+ "os"
+ "reflect"
"strings"
"testing"
)
}
}
}
+
+type mockTransferWriter struct {
+ CalledReader io.Reader
+ WriteCalled bool
+}
+
+var _ io.ReaderFrom = (*mockTransferWriter)(nil)
+
+func (w *mockTransferWriter) ReadFrom(r io.Reader) (int64, error) {
+ w.CalledReader = r
+ return io.Copy(ioutil.Discard, r)
+}
+
+func (w *mockTransferWriter) Write(p []byte) (int, error) {
+ w.WriteCalled = true
+ return ioutil.Discard.Write(p)
+}
+
+func TestTransferWriterWriteBodyReaderTypes(t *testing.T) {
+ fileType := reflect.TypeOf(&os.File{})
+ bufferType := reflect.TypeOf(&bytes.Buffer{})
+
+ nBytes := int64(1 << 10)
+ newFileFunc := func() (r io.Reader, done func(), err error) {
+ f, err := ioutil.TempFile("", "net-http-newfilefunc")
+ if err != nil {
+ return nil, nil, err
+ }
+
+ // Write some bytes to the file to enable reading.
+ if _, err := io.CopyN(f, rand.Reader, nBytes); err != nil {
+ return nil, nil, fmt.Errorf("failed to write data to file: %v", err)
+ }
+ if _, err := f.Seek(0, 0); err != nil {
+ return nil, nil, fmt.Errorf("failed to seek to front: %v", err)
+ }
+
+ done = func() {
+ f.Close()
+ os.Remove(f.Name())
+ }
+
+ return f, done, nil
+ }
+
+ newBufferFunc := func() (io.Reader, func(), error) {
+ return bytes.NewBuffer(make([]byte, nBytes)), func() {}, nil
+ }
+
+ cases := []struct {
+ name string
+ bodyFunc func() (io.Reader, func(), error)
+ method string
+ contentLength int64
+ transferEncoding []string
+ limitedReader bool
+ expectedReader reflect.Type
+ expectedWrite bool
+ }{
+ {
+ name: "file, non-chunked, size set",
+ bodyFunc: newFileFunc,
+ method: "PUT",
+ contentLength: nBytes,
+ limitedReader: true,
+ expectedReader: fileType,
+ },
+ {
+ name: "file, non-chunked, size set, nopCloser wrapped",
+ method: "PUT",
+ bodyFunc: func() (io.Reader, func(), error) {
+ r, cleanup, err := newFileFunc()
+ return ioutil.NopCloser(r), cleanup, err
+ },
+ contentLength: nBytes,
+ limitedReader: true,
+ expectedReader: fileType,
+ },
+ {
+ name: "file, non-chunked, negative size",
+ method: "PUT",
+ bodyFunc: newFileFunc,
+ contentLength: -1,
+ expectedReader: fileType,
+ },
+ {
+ name: "file, non-chunked, CONNECT, negative size",
+ method: "CONNECT",
+ bodyFunc: newFileFunc,
+ contentLength: -1,
+ expectedReader: fileType,
+ },
+ {
+ name: "file, chunked",
+ method: "PUT",
+ bodyFunc: newFileFunc,
+ transferEncoding: []string{"chunked"},
+ expectedWrite: true,
+ },
+ {
+ name: "buffer, non-chunked, size set",
+ bodyFunc: newBufferFunc,
+ method: "PUT",
+ contentLength: nBytes,
+ limitedReader: true,
+ expectedReader: bufferType,
+ },
+ {
+ name: "buffer, non-chunked, size set, nopCloser wrapped",
+ method: "PUT",
+ bodyFunc: func() (io.Reader, func(), error) {
+ r, cleanup, err := newBufferFunc()
+ return ioutil.NopCloser(r), cleanup, err
+ },
+ contentLength: nBytes,
+ limitedReader: true,
+ expectedReader: bufferType,
+ },
+ {
+ name: "buffer, non-chunked, negative size",
+ method: "PUT",
+ bodyFunc: newBufferFunc,
+ contentLength: -1,
+ expectedWrite: true,
+ },
+ {
+ name: "buffer, non-chunked, CONNECT, negative size",
+ method: "CONNECT",
+ bodyFunc: newBufferFunc,
+ contentLength: -1,
+ expectedWrite: true,
+ },
+ {
+ name: "buffer, chunked",
+ method: "PUT",
+ bodyFunc: newBufferFunc,
+ transferEncoding: []string{"chunked"},
+ expectedWrite: true,
+ },
+ }
+
+ for _, tc := range cases {
+ t.Run(tc.name, func(t *testing.T) {
+ body, cleanup, err := tc.bodyFunc()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer cleanup()
+
+ mw := &mockTransferWriter{}
+ tw := &transferWriter{
+ Body: body,
+ ContentLength: tc.contentLength,
+ TransferEncoding: tc.transferEncoding,
+ }
+
+ if err := tw.writeBody(mw); err != nil {
+ t.Fatal(err)
+ }
+
+ if tc.expectedReader != nil {
+ if mw.CalledReader == nil {
+ t.Fatal("did not call ReadFrom")
+ }
+
+ var actualReader reflect.Type
+ lr, ok := mw.CalledReader.(*io.LimitedReader)
+ if ok && tc.limitedReader {
+ actualReader = reflect.TypeOf(lr.R)
+ } else {
+ actualReader = reflect.TypeOf(mw.CalledReader)
+ }
+
+ if tc.expectedReader != actualReader {
+ t.Fatalf("got reader %T want %T", actualReader, tc.expectedReader)
+ }
+ }
+
+ if tc.expectedWrite && !mw.WriteCalled {
+ t.Fatal("did not invoke Write")
+ }
+ })
+ }
+}