--- /dev/null
+// Copyright 2018 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package poll
+
+import "syscall"
+
+const (
+ // spliceNonblock makes calls to splice(2) non-blocking.
+ spliceNonblock = 0x2
+
+ // maxSpliceSize is the maximum amount of data Splice asks
+ // the kernel to move in a single call to splice(2).
+ maxSpliceSize = 4 << 20
+)
+
+// Splice transfers at most remain bytes of data from src to dst, using the
+// splice system call to minimize copies of data from and to userspace.
+//
+// Splice creates a temporary pipe, to serve as a buffer for the data transfer.
+// src and dst must both be stream-oriented sockets.
+//
+// If err != nil, sc is the system call which caused the error.
+func Splice(dst, src *FD, remain int64) (written int64, handled bool, sc string, err error) {
+ prfd, pwfd, sc, err := newTempPipe()
+ if err != nil {
+ return 0, false, sc, err
+ }
+ defer destroyTempPipe(prfd, pwfd)
+ // From here on, the operation should be considered handled,
+ // even if Splice doesn't transfer any data.
+ if err := src.readLock(); err != nil {
+ return 0, true, "splice", err
+ }
+ defer src.readUnlock()
+ if err := dst.writeLock(); err != nil {
+ return 0, true, "splice", err
+ }
+ defer dst.writeUnlock()
+ if err := src.pd.prepareRead(src.isFile); err != nil {
+ return 0, true, "splice", err
+ }
+ if err := dst.pd.prepareWrite(dst.isFile); err != nil {
+ return 0, true, "splice", err
+ }
+ var inPipe, n int
+ for err == nil && remain > 0 {
+ max := maxSpliceSize
+ if int64(max) > remain {
+ max = int(remain)
+ }
+ inPipe, err = spliceDrain(pwfd, src, max)
+ // spliceDrain should never return EAGAIN, so if err != nil,
+ // Splice cannot continue. If inPipe == 0 && err == nil,
+ // src is at EOF, and the transfer is complete.
+ if err != nil || (inPipe == 0 && err == nil) {
+ break
+ }
+ n, err = splicePump(dst, prfd, inPipe)
+ if n > 0 {
+ written += int64(n)
+ remain -= int64(n)
+ }
+ }
+ if err != nil {
+ return written, true, "splice", err
+ }
+ return written, true, "", nil
+}
+
+// spliceDrain moves data from a socket to a pipe.
+//
+// Invariant: when entering spliceDrain, the pipe is empty. It is either in its
+// initial state, or splicePump has emptied it previously.
+//
+// Given this, spliceDrain can reasonably assume that the pipe is ready for
+// writing, so if splice returns EAGAIN, it must be because the socket is not
+// ready for reading.
+//
+// If spliceDrain returns (0, nil), src is at EOF.
+func spliceDrain(pipefd int, sock *FD, max int) (int, error) {
+ for {
+ n, err := splice(pipefd, sock.Sysfd, max, spliceNonblock)
+ if err != syscall.EAGAIN {
+ return n, err
+ }
+ if err := sock.pd.waitRead(sock.isFile); err != nil {
+ return n, err
+ }
+ }
+}
+
+// splicePump moves all the buffered data from a pipe to a socket.
+//
+// Invariant: when entering splicePump, there are exactly inPipe
+// bytes of data in the pipe, from a previous call to spliceDrain.
+//
+// By analogy to the condition from spliceDrain, splicePump
+// only needs to poll the socket for readiness, if splice returns
+// EAGAIN.
+//
+// If splicePump cannot move all the data in a single call to
+// splice(2), it loops over the buffered data until it has written
+// all of it to the socket. This behavior is similar to the Write
+// step of an io.Copy in userspace.
+func splicePump(sock *FD, pipefd int, inPipe int) (int, error) {
+ written := 0
+ for inPipe > 0 {
+ n, err := splice(sock.Sysfd, pipefd, inPipe, spliceNonblock)
+ // Here, the condition n == 0 && err == nil should never be
+ // observed, since Splice controls the write side of the pipe.
+ if n > 0 {
+ inPipe -= n
+ written += n
+ continue
+ }
+ if err != syscall.EAGAIN {
+ return written, err
+ }
+ if err := sock.pd.waitWrite(sock.isFile); err != nil {
+ return written, err
+ }
+ }
+ return written, nil
+}
+
+// splice wraps the splice system call. Since the current implementation
+// only uses splice on sockets and pipes, the offset arguments are unused.
+// splice returns int instead of int64, because callers never ask it to
+// move more data in a single call than can fit in an int32.
+func splice(out int, in int, max int, flags int) (int, error) {
+ n, err := syscall.Splice(in, nil, out, nil, max, flags)
+ return int(n), err
+}
+
+// newTempPipe sets up a temporary pipe for a splice operation.
+func newTempPipe() (prfd, pwfd int, sc string, err error) {
+ var fds [2]int
+ const flags = syscall.O_CLOEXEC | syscall.O_NONBLOCK
+ if err := syscall.Pipe2(fds[:], flags); err != nil {
+ // pipe2 was added in 2.6.27 and our minimum requirement
+ // is 2.6.23, so it might not be implemented.
+ if err == syscall.ENOSYS {
+ return newTempPipeFallback(fds[:])
+ }
+ return -1, -1, "pipe2", err
+ }
+ return fds[0], fds[1], "", nil
+}
+
+// newTempPipeFallback is a fallback for newTempPipe, for systems
+// which do not support pipe2.
+func newTempPipeFallback(fds []int) (prfd, pwfd int, sc string, err error) {
+ syscall.ForkLock.RLock()
+ defer syscall.ForkLock.RUnlock()
+ if err := syscall.Pipe(fds); err != nil {
+ return -1, -1, "pipe", err
+ }
+ prfd, pwfd = fds[0], fds[1]
+ syscall.CloseOnExec(prfd)
+ syscall.CloseOnExec(pwfd)
+ if err := syscall.SetNonblock(prfd, true); err != nil {
+ CloseFunc(prfd)
+ CloseFunc(pwfd)
+ return -1, -1, "setnonblock", err
+ }
+ if err := syscall.SetNonblock(pwfd, true); err != nil {
+ CloseFunc(prfd)
+ CloseFunc(pwfd)
+ return -1, -1, "setnonblock", err
+ }
+ return prfd, pwfd, "", nil
+}
+
+// destroyTempPipe destroys a temporary pipe.
+func destroyTempPipe(prfd, pwfd int) error {
+ err := CloseFunc(prfd)
+ err1 := CloseFunc(pwfd)
+ if err == nil {
+ return err1
+ }
+ return err
+}
--- /dev/null
+// Copyright 2018 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build linux
+
+package net
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "testing"
+)
+
+func TestSplice(t *testing.T) {
+ t.Run("simple", testSpliceSimple)
+ t.Run("multipleWrite", testSpliceMultipleWrite)
+ t.Run("big", testSpliceBig)
+ t.Run("honorsLimitedReader", testSpliceHonorsLimitedReader)
+ t.Run("readerAtEOF", testSpliceReaderAtEOF)
+}
+
+func testSpliceSimple(t *testing.T) {
+ srv, err := newSpliceTestServer()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer srv.Close()
+ copyDone := srv.Copy()
+ msg := []byte("splice test")
+ if _, err := srv.Write(msg); err != nil {
+ t.Fatal(err)
+ }
+ got := make([]byte, len(msg))
+ if _, err := io.ReadFull(srv, got); err != nil {
+ t.Fatal(err)
+ }
+ if !bytes.Equal(got, msg) {
+ t.Errorf("got %q, wrote %q", got, msg)
+ }
+ srv.CloseWrite()
+ srv.CloseRead()
+ if err := <-copyDone; err != nil {
+ t.Errorf("splice: %v", err)
+ }
+}
+
+func testSpliceMultipleWrite(t *testing.T) {
+ srv, err := newSpliceTestServer()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer srv.Close()
+ copyDone := srv.Copy()
+ msg1 := []byte("splice test part 1 ")
+ msg2 := []byte(" splice test part 2")
+ if _, err := srv.Write(msg1); err != nil {
+ t.Fatalf("Write: %v", err)
+ }
+ if _, err := srv.Write(msg2); err != nil {
+ t.Fatal(err)
+ }
+ got := make([]byte, len(msg1)+len(msg2))
+ if _, err := io.ReadFull(srv, got); err != nil {
+ t.Fatal(err)
+ }
+ want := append(msg1, msg2...)
+ if !bytes.Equal(got, want) {
+ t.Errorf("got %q, wrote %q", got, want)
+ }
+ srv.CloseWrite()
+ srv.CloseRead()
+ if err := <-copyDone; err != nil {
+ t.Errorf("splice: %v", err)
+ }
+}
+
+func testSpliceBig(t *testing.T) {
+ size := 1<<31 - 1
+ if testing.Short() {
+ size = 1 << 25
+ }
+ srv, err := newSpliceTestServer()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer srv.Close()
+ big := make([]byte, size)
+ copyDone := srv.Copy()
+ type readResult struct {
+ b []byte
+ err error
+ }
+ readDone := make(chan readResult)
+ go func() {
+ got := make([]byte, len(big))
+ _, err := io.ReadFull(srv, got)
+ readDone <- readResult{got, err}
+ }()
+ if _, err := srv.Write(big); err != nil {
+ t.Fatal(err)
+ }
+ res := <-readDone
+ if res.err != nil {
+ t.Fatal(res.err)
+ }
+ got := res.b
+ if !bytes.Equal(got, big) {
+ t.Errorf("input and output differ")
+ }
+ srv.CloseWrite()
+ srv.CloseRead()
+ if err := <-copyDone; err != nil {
+ t.Errorf("splice: %v", err)
+ }
+}
+
+func testSpliceHonorsLimitedReader(t *testing.T) {
+ t.Run("stopsAfterN", testSpliceStopsAfterN)
+ t.Run("updatesN", testSpliceUpdatesN)
+}
+
+func testSpliceStopsAfterN(t *testing.T) {
+ clientUp, serverUp, err := spliceTestSocketPair("tcp")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer clientUp.Close()
+ defer serverUp.Close()
+ clientDown, serverDown, err := spliceTestSocketPair("tcp")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer clientDown.Close()
+ defer serverDown.Close()
+ count := 128
+ copyDone := make(chan error)
+ lr := &io.LimitedReader{
+ N: int64(count),
+ R: serverUp,
+ }
+ go func() {
+ _, err := io.Copy(serverDown, lr)
+ serverDown.Close()
+ copyDone <- err
+ }()
+ msg := make([]byte, 2*count)
+ if _, err := clientUp.Write(msg); err != nil {
+ t.Fatal(err)
+ }
+ clientUp.Close()
+ var buf bytes.Buffer
+ if _, err := io.Copy(&buf, clientDown); err != nil {
+ t.Fatal(err)
+ }
+ if buf.Len() != count {
+ t.Errorf("splice transferred %d bytes, want to stop after %d", buf.Len(), count)
+ }
+ clientDown.Close()
+ if err := <-copyDone; err != nil {
+ t.Errorf("splice: %v", err)
+ }
+}
+
+func testSpliceUpdatesN(t *testing.T) {
+ clientUp, serverUp, err := spliceTestSocketPair("tcp")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer clientUp.Close()
+ defer serverUp.Close()
+ clientDown, serverDown, err := spliceTestSocketPair("tcp")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer clientDown.Close()
+ defer serverDown.Close()
+ count := 128
+ copyDone := make(chan error)
+ lr := &io.LimitedReader{
+ N: int64(100 + count),
+ R: serverUp,
+ }
+ go func() {
+ _, err := io.Copy(serverDown, lr)
+ copyDone <- err
+ }()
+ msg := make([]byte, count)
+ if _, err := clientUp.Write(msg); err != nil {
+ t.Fatal(err)
+ }
+ clientUp.Close()
+ got := make([]byte, count)
+ if _, err := io.ReadFull(clientDown, got); err != nil {
+ t.Fatal(err)
+ }
+ clientDown.Close()
+ if err := <-copyDone; err != nil {
+ t.Errorf("splice: %v", err)
+ }
+ wantN := int64(100)
+ if lr.N != wantN {
+ t.Errorf("lr.N = %d, want %d", lr.N, wantN)
+ }
+}
+
+func testSpliceReaderAtEOF(t *testing.T) {
+ clientUp, serverUp, err := spliceTestSocketPair("tcp")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer clientUp.Close()
+ defer serverUp.Close()
+ clientDown, serverDown, err := spliceTestSocketPair("tcp")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer clientDown.Close()
+ defer serverDown.Close()
+
+ serverUp.Close()
+ _, err, handled := splice(serverDown.(*TCPConn).fd, serverUp)
+ if !handled {
+ t.Errorf("closed connection: got err = %v, handled = %t, want handled = true", err, handled)
+ }
+ lr := &io.LimitedReader{
+ N: 0,
+ R: serverUp,
+ }
+ _, err, handled = splice(serverDown.(*TCPConn).fd, lr)
+ if !handled {
+ t.Errorf("exhausted LimitedReader: got err = %v, handled = %t, want handled = true", err, handled)
+ }
+}
+
+func BenchmarkTCPReadFrom(b *testing.B) {
+ testHookUninstaller.Do(uninstallTestHooks)
+
+ var chunkSizes []int
+ for i := uint(10); i <= 20; i++ {
+ chunkSizes = append(chunkSizes, 1<<i)
+ }
+ // To benchmark the genericReadFrom code path, set this to false.
+ useSplice := true
+ for _, chunkSize := range chunkSizes {
+ b.Run(fmt.Sprint(chunkSize), func(b *testing.B) {
+ benchmarkSplice(b, chunkSize, useSplice)
+ })
+ }
+}
+
+func benchmarkSplice(b *testing.B, chunkSize int, useSplice bool) {
+ srv, err := newSpliceTestServer()
+ if err != nil {
+ b.Fatal(err)
+ }
+ defer srv.Close()
+ var copyDone <-chan error
+ if useSplice {
+ copyDone = srv.Copy()
+ } else {
+ copyDone = srv.CopyNoSplice()
+ }
+ chunk := make([]byte, chunkSize)
+ discardDone := make(chan struct{})
+ go func() {
+ for {
+ buf := make([]byte, chunkSize)
+ _, err := srv.Read(buf)
+ if err != nil {
+ break
+ }
+ }
+ discardDone <- struct{}{}
+ }()
+ b.SetBytes(int64(chunkSize))
+ b.ResetTimer()
+ for i := 0; i < b.N; i++ {
+ srv.Write(chunk)
+ }
+ srv.CloseWrite()
+ <-copyDone
+ srv.CloseRead()
+ <-discardDone
+}
+
+type spliceTestServer struct {
+ clientUp io.WriteCloser
+ clientDown io.ReadCloser
+ serverUp io.ReadCloser
+ serverDown io.WriteCloser
+}
+
+func newSpliceTestServer() (*spliceTestServer, error) {
+ // For now, both networks are hard-coded to TCP.
+ // If splice is enabled for non-tcp upstream connections,
+ // newSpliceTestServer will need to take a network parameter.
+ clientUp, serverUp, err := spliceTestSocketPair("tcp")
+ if err != nil {
+ return nil, err
+ }
+ clientDown, serverDown, err := spliceTestSocketPair("tcp")
+ if err != nil {
+ clientUp.Close()
+ serverUp.Close()
+ return nil, err
+ }
+ return &spliceTestServer{clientUp, clientDown, serverUp, serverDown}, nil
+}
+
+// Read reads from the downstream connection.
+func (srv *spliceTestServer) Read(b []byte) (int, error) {
+ return srv.clientDown.Read(b)
+}
+
+// Write writes to the upstream connection.
+func (srv *spliceTestServer) Write(b []byte) (int, error) {
+ return srv.clientUp.Write(b)
+}
+
+// Close closes the server.
+func (srv *spliceTestServer) Close() error {
+ err := srv.closeUp()
+ err1 := srv.closeDown()
+ if err == nil {
+ return err1
+ }
+ return err
+}
+
+// CloseWrite closes the client side of the upstream connection.
+func (srv *spliceTestServer) CloseWrite() error {
+ return srv.clientUp.Close()
+}
+
+// CloseRead closes the client side of the downstream connection.
+func (srv *spliceTestServer) CloseRead() error {
+ return srv.clientDown.Close()
+}
+
+// Copy copies from the server side of the upstream connection
+// to the server side of the downstream connection, in a separate
+// goroutine. Copy is done when the first send on the returned
+// channel succeeds.
+func (srv *spliceTestServer) Copy() <-chan error {
+ ch := make(chan error)
+ go func() {
+ _, err := io.Copy(srv.serverDown, srv.serverUp)
+ ch <- err
+ close(ch)
+ }()
+ return ch
+}
+
+// CopyNoSplice is like Copy, but ensures that the splice code path
+// is not reached.
+func (srv *spliceTestServer) CopyNoSplice() <-chan error {
+ type onlyReader struct {
+ io.Reader
+ }
+ ch := make(chan error)
+ go func() {
+ _, err := io.Copy(srv.serverDown, onlyReader{srv.serverUp})
+ ch <- err
+ close(ch)
+ }()
+ return ch
+}
+
+func (srv *spliceTestServer) closeUp() error {
+ var err, err1 error
+ if srv.serverUp != nil {
+ err = srv.serverUp.Close()
+ }
+ if srv.clientUp != nil {
+ err1 = srv.clientUp.Close()
+ }
+ if err == nil {
+ return err1
+ }
+ return err
+}
+
+func (srv *spliceTestServer) closeDown() error {
+ var err, err1 error
+ if srv.serverDown != nil {
+ err = srv.serverDown.Close()
+ }
+ if srv.clientDown != nil {
+ err1 = srv.clientDown.Close()
+ }
+ if err == nil {
+ return err1
+ }
+ return err
+}
+
+func spliceTestSocketPair(net string) (client, server Conn, err error) {
+ ln, err := newLocalListener(net)
+ if err != nil {
+ return nil, nil, err
+ }
+ defer ln.Close()
+ var cerr, serr error
+ acceptDone := make(chan struct{})
+ go func() {
+ server, serr = ln.Accept()
+ acceptDone <- struct{}{}
+ }()
+ client, cerr = Dial(ln.Addr().Network(), ln.Addr().String())
+ <-acceptDone
+ if cerr != nil {
+ if server != nil {
+ server.Close()
+ }
+ return nil, nil, cerr
+ }
+ if serr != nil {
+ if client != nil {
+ client.Close()
+ }
+ return nil, nil, serr
+ }
+ return client, server, nil
+}