package net
import (
- "bytes"
- "fmt"
"io"
"io/ioutil"
+ "log"
+ "os"
+ "os/exec"
+ "strconv"
"sync"
"testing"
)
func TestSplice(t *testing.T) {
- t.Run("simple", testSpliceSimple)
- t.Run("multipleWrite", testSpliceMultipleWrite)
- t.Run("big", testSpliceBig)
- t.Run("honorsLimitedReader", testSpliceHonorsLimitedReader)
- t.Run("readerAtEOF", testSpliceReaderAtEOF)
- t.Run("issue25985", testSpliceIssue25985)
+ t.Run("tcp-to-tcp", func(t *testing.T) { testSplice(t, "tcp", "tcp") })
+ t.Run("unix-to-tcp", func(t *testing.T) { testSplice(t, "unix", "tcp") })
}
-func testSpliceSimple(t *testing.T) {
- srv, err := newSpliceTestServer()
- if err != nil {
- t.Fatal(err)
- }
- defer srv.Close()
- copyDone := srv.Copy()
- msg := []byte("splice test")
- if _, err := srv.Write(msg); err != nil {
- t.Fatal(err)
- }
- got := make([]byte, len(msg))
- if _, err := io.ReadFull(srv, got); err != nil {
- t.Fatal(err)
- }
- if !bytes.Equal(got, msg) {
- t.Errorf("got %q, wrote %q", got, msg)
- }
- srv.CloseWrite()
- srv.CloseRead()
- if err := <-copyDone; err != nil {
- t.Errorf("splice: %v", err)
- }
+func testSplice(t *testing.T, upNet, downNet string) {
+ t.Run("simple", spliceTestCase{upNet, downNet, 128, 128, 0}.test)
+ t.Run("multipleWrite", spliceTestCase{upNet, downNet, 4096, 1 << 20, 0}.test)
+ t.Run("big", spliceTestCase{upNet, downNet, 5 << 20, 1 << 30, 0}.test)
+ t.Run("honorsLimitedReader", spliceTestCase{upNet, downNet, 4096, 1 << 20, 1 << 10}.test)
+ t.Run("updatesLimitedReaderN", spliceTestCase{upNet, downNet, 1024, 4096, 4096 + 100}.test)
+ t.Run("limitedReaderAtLimit", spliceTestCase{upNet, downNet, 32, 128, 128}.test)
+ t.Run("readerAtEOF", func(t *testing.T) { testSpliceReaderAtEOF(t, upNet, downNet) })
+ t.Run("issue25985", func(t *testing.T) { testSpliceIssue25985(t, upNet, downNet) })
}
-func testSpliceMultipleWrite(t *testing.T) {
- srv, err := newSpliceTestServer()
- if err != nil {
- t.Fatal(err)
- }
- defer srv.Close()
- copyDone := srv.Copy()
- msg1 := []byte("splice test part 1 ")
- msg2 := []byte(" splice test part 2")
- if _, err := srv.Write(msg1); err != nil {
- t.Fatalf("Write: %v", err)
- }
- if _, err := srv.Write(msg2); err != nil {
- t.Fatal(err)
- }
- got := make([]byte, len(msg1)+len(msg2))
- if _, err := io.ReadFull(srv, got); err != nil {
- t.Fatal(err)
- }
- want := append(msg1, msg2...)
- if !bytes.Equal(got, want) {
- t.Errorf("got %q, wrote %q", got, want)
- }
- srv.CloseWrite()
- srv.CloseRead()
- if err := <-copyDone; err != nil {
- t.Errorf("splice: %v", err)
- }
-}
+type spliceTestCase struct {
+ upNet, downNet string
-func testSpliceBig(t *testing.T) {
- // The maximum amount of data that internal/poll.Splice will use in a
- // splice(2) call is 4 << 20. Use a bigger size here so that we test an
- // amount that doesn't fit in a single call.
- size := 5 << 20
- srv, err := newSpliceTestServer()
- if err != nil {
- t.Fatal(err)
- }
- defer srv.Close()
- big := make([]byte, size)
- copyDone := srv.Copy()
- type readResult struct {
- b []byte
- err error
- }
- readDone := make(chan readResult)
- go func() {
- got := make([]byte, len(big))
- _, err := io.ReadFull(srv, got)
- readDone <- readResult{got, err}
- }()
- if _, err := srv.Write(big); err != nil {
- t.Fatal(err)
- }
- res := <-readDone
- if res.err != nil {
- t.Fatal(res.err)
- }
- got := res.b
- if !bytes.Equal(got, big) {
- t.Errorf("input and output differ")
- }
- srv.CloseWrite()
- srv.CloseRead()
- if err := <-copyDone; err != nil {
- t.Errorf("splice: %v", err)
- }
-}
-
-func testSpliceHonorsLimitedReader(t *testing.T) {
- t.Run("stopsAfterN", testSpliceStopsAfterN)
- t.Run("updatesN", testSpliceUpdatesN)
- t.Run("readerAtLimit", testSpliceReaderAtLimit)
+ chunkSize, totalSize int
+ limitReadSize int
}
-func testSpliceStopsAfterN(t *testing.T) {
- clientUp, serverUp, err := spliceTestSocketPair("tcp")
+func (tc spliceTestCase) test(t *testing.T) {
+ clientUp, serverUp, err := spliceTestSocketPair(tc.upNet)
if err != nil {
t.Fatal(err)
}
- defer clientUp.Close()
defer serverUp.Close()
- clientDown, serverDown, err := spliceTestSocketPair("tcp")
+ cleanup, err := startSpliceClient(clientUp, "w", tc.chunkSize, tc.totalSize)
if err != nil {
t.Fatal(err)
}
- defer clientDown.Close()
- defer serverDown.Close()
- count := 128
- copyDone := make(chan error)
- lr := &io.LimitedReader{
- N: int64(count),
- R: serverUp,
- }
- go func() {
- _, err := io.Copy(serverDown, lr)
- serverDown.Close()
- copyDone <- err
- }()
- msg := make([]byte, 2*count)
- if _, err := clientUp.Write(msg); err != nil {
- t.Fatal(err)
- }
- clientUp.Close()
- var buf bytes.Buffer
- if _, err := io.Copy(&buf, clientDown); err != nil {
- t.Fatal(err)
- }
- if buf.Len() != count {
- t.Errorf("splice transferred %d bytes, want to stop after %d", buf.Len(), count)
- }
- clientDown.Close()
- if err := <-copyDone; err != nil {
- t.Errorf("splice: %v", err)
- }
-}
-
-func testSpliceUpdatesN(t *testing.T) {
- clientUp, serverUp, err := spliceTestSocketPair("tcp")
+ defer cleanup()
+ clientDown, serverDown, err := spliceTestSocketPair(tc.downNet)
if err != nil {
t.Fatal(err)
}
- defer clientUp.Close()
- defer serverUp.Close()
- clientDown, serverDown, err := spliceTestSocketPair("tcp")
- if err != nil {
- t.Fatal(err)
- }
- defer clientDown.Close()
defer serverDown.Close()
- count := 128
- copyDone := make(chan error)
- lr := &io.LimitedReader{
- N: int64(100 + count),
- R: serverUp,
- }
- go func() {
- _, err := io.Copy(serverDown, lr)
- copyDone <- err
- }()
- msg := make([]byte, count)
- if _, err := clientUp.Write(msg); err != nil {
- t.Fatal(err)
- }
- clientUp.Close()
- got := make([]byte, count)
- if _, err := io.ReadFull(clientDown, got); err != nil {
+ cleanup, err = startSpliceClient(clientDown, "r", tc.chunkSize, tc.totalSize)
+ if err != nil {
t.Fatal(err)
}
- clientDown.Close()
- if err := <-copyDone; err != nil {
- t.Errorf("splice: %v", err)
- }
- wantN := int64(100)
- if lr.N != wantN {
- t.Errorf("lr.N = %d, want %d", lr.N, wantN)
- }
-}
+ defer cleanup()
+ var (
+ r io.Reader = serverUp
+ size = tc.totalSize
+ )
+ if tc.limitReadSize > 0 {
+ if tc.limitReadSize < size {
+ size = tc.limitReadSize
+ }
-func testSpliceReaderAtLimit(t *testing.T) {
- clientUp, serverUp, err := spliceTestSocketPair("tcp")
- if err != nil {
- t.Fatal(err)
+ r = &io.LimitedReader{
+ N: int64(tc.limitReadSize),
+ R: serverUp,
+ }
+ defer serverUp.Close()
}
- defer clientUp.Close()
- defer serverUp.Close()
- clientDown, serverDown, err := spliceTestSocketPair("tcp")
+ n, err := io.Copy(serverDown, r)
+ serverDown.Close()
if err != nil {
t.Fatal(err)
}
- defer clientDown.Close()
- defer serverDown.Close()
-
- lr := &io.LimitedReader{
- N: 0,
- R: serverUp,
+ if want := int64(size); want != n {
+ t.Errorf("want %d bytes spliced, got %d", want, n)
}
- _, err, handled := splice(serverDown.(*TCPConn).fd, lr)
- if !handled {
- t.Errorf("exhausted LimitedReader: got err = %v, handled = %t, want handled = true", err, handled)
+
+ if tc.limitReadSize > 0 {
+ wantN := 0
+ if tc.limitReadSize > size {
+ wantN = tc.limitReadSize - size
+ }
+
+ if n := r.(*io.LimitedReader).N; n != int64(wantN) {
+ t.Errorf("r.N = %d, want %d", n, wantN)
+ }
}
}
-func testSpliceReaderAtEOF(t *testing.T) {
- clientUp, serverUp, err := spliceTestSocketPair("tcp")
+func testSpliceReaderAtEOF(t *testing.T, upNet, downNet string) {
+ clientUp, serverUp, err := spliceTestSocketPair(upNet)
if err != nil {
t.Fatal(err)
}
defer clientUp.Close()
- clientDown, serverDown, err := spliceTestSocketPair("tcp")
+ clientDown, serverDown, err := spliceTestSocketPair(downNet)
if err != nil {
t.Fatal(err)
}
// get a goodbye signal. Test for the goodbye signal.
msg := "bye"
go func() {
- serverDown.(*TCPConn).ReadFrom(serverUp)
+ serverDown.(io.ReaderFrom).ReadFrom(serverUp)
io.WriteString(serverDown, msg)
serverDown.Close()
}()
}
}
-func testSpliceIssue25985(t *testing.T) {
- front, err := newLocalListener("tcp")
+func testSpliceIssue25985(t *testing.T, upNet, downNet string) {
+ front, err := newLocalListener(upNet)
if err != nil {
t.Fatal(err)
}
defer front.Close()
- back, err := newLocalListener("tcp")
+ back, err := newLocalListener(downNet)
if err != nil {
t.Fatal(err)
}
if err != nil {
return
}
- dst, err := Dial("tcp", back.Addr().String())
+ dst, err := Dial(downNet, back.Addr().String())
if err != nil {
return
}
go proxy()
- toFront, err := Dial("tcp", front.Addr().String())
+ toFront, err := Dial(upNet, front.Addr().String())
if err != nil {
t.Fatal(err)
}
wg.Wait()
}
-func BenchmarkTCPReadFrom(b *testing.B) {
+func BenchmarkSplice(b *testing.B) {
testHookUninstaller.Do(uninstallTestHooks)
- var chunkSizes []int
- for i := uint(10); i <= 20; i++ {
- chunkSizes = append(chunkSizes, 1<<i)
- }
- // To benchmark the genericReadFrom code path, set this to false.
- useSplice := true
- for _, chunkSize := range chunkSizes {
- b.Run(fmt.Sprint(chunkSize), func(b *testing.B) {
- benchmarkSplice(b, chunkSize, useSplice)
- })
- }
+ b.Run("tcp-to-tcp", func(b *testing.B) { benchSplice(b, "tcp", "tcp") })
+ b.Run("unix-to-tcp", func(b *testing.B) { benchSplice(b, "unix", "tcp") })
}
-func benchmarkSplice(b *testing.B, chunkSize int, useSplice bool) {
- srv, err := newSpliceTestServer()
- if err != nil {
- b.Fatal(err)
- }
- defer srv.Close()
- var copyDone <-chan error
- if useSplice {
- copyDone = srv.Copy()
- } else {
- copyDone = srv.CopyNoSplice()
- }
- chunk := make([]byte, chunkSize)
- discardDone := make(chan struct{})
- go func() {
- for {
- buf := make([]byte, chunkSize)
- _, err := srv.Read(buf)
- if err != nil {
- break
- }
+func benchSplice(b *testing.B, upNet, downNet string) {
+ for i := 0; i <= 10; i++ {
+ chunkSize := 1 << uint(i+10)
+ tc := spliceTestCase{
+ upNet: upNet,
+ downNet: downNet,
+ chunkSize: chunkSize,
}
- discardDone <- struct{}{}
- }()
- b.SetBytes(int64(chunkSize))
- b.ResetTimer()
- for i := 0; i < b.N; i++ {
- srv.Write(chunk)
+
+ b.Run(strconv.Itoa(chunkSize), tc.bench)
}
- srv.CloseWrite()
- <-copyDone
- srv.CloseRead()
- <-discardDone
}
-type spliceTestServer struct {
- clientUp io.WriteCloser
- clientDown io.ReadCloser
- serverUp io.ReadCloser
- serverDown io.WriteCloser
-}
+func (tc spliceTestCase) bench(b *testing.B) {
+ // To benchmark the genericReadFrom code path, set this to false.
+ useSplice := true
-func newSpliceTestServer() (*spliceTestServer, error) {
- // For now, both networks are hard-coded to TCP.
- // If splice is enabled for non-tcp upstream connections,
- // newSpliceTestServer will need to take a network parameter.
- clientUp, serverUp, err := spliceTestSocketPair("tcp")
+ clientUp, serverUp, err := spliceTestSocketPair(tc.upNet)
if err != nil {
- return nil, err
+ b.Fatal(err)
}
- clientDown, serverDown, err := spliceTestSocketPair("tcp")
+ defer serverUp.Close()
+
+ cleanup, err := startSpliceClient(clientUp, "w", tc.chunkSize, tc.chunkSize*b.N)
if err != nil {
- clientUp.Close()
- serverUp.Close()
- return nil, err
+ b.Fatal(err)
}
- return &spliceTestServer{clientUp, clientDown, serverUp, serverDown}, nil
-}
-
-// Read reads from the downstream connection.
-func (srv *spliceTestServer) Read(b []byte) (int, error) {
- return srv.clientDown.Read(b)
-}
-
-// Write writes to the upstream connection.
-func (srv *spliceTestServer) Write(b []byte) (int, error) {
- return srv.clientUp.Write(b)
-}
+ defer cleanup()
-// Close closes the server.
-func (srv *spliceTestServer) Close() error {
- err := srv.closeUp()
- err1 := srv.closeDown()
- if err == nil {
- return err1
+ clientDown, serverDown, err := spliceTestSocketPair(tc.downNet)
+ if err != nil {
+ b.Fatal(err)
}
- return err
-}
-
-// CloseWrite closes the client side of the upstream connection.
-func (srv *spliceTestServer) CloseWrite() error {
- return srv.clientUp.Close()
-}
-
-// CloseRead closes the client side of the downstream connection.
-func (srv *spliceTestServer) CloseRead() error {
- return srv.clientDown.Close()
-}
-
-// Copy copies from the server side of the upstream connection
-// to the server side of the downstream connection, in a separate
-// goroutine. Copy is done when the first send on the returned
-// channel succeeds.
-func (srv *spliceTestServer) Copy() <-chan error {
- ch := make(chan error)
- go func() {
- _, err := io.Copy(srv.serverDown, srv.serverUp)
- ch <- err
- close(ch)
- }()
- return ch
-}
+ defer serverDown.Close()
-// CopyNoSplice is like Copy, but ensures that the splice code path
-// is not reached.
-func (srv *spliceTestServer) CopyNoSplice() <-chan error {
- type onlyReader struct {
- io.Reader
+ cleanup, err = startSpliceClient(clientDown, "r", tc.chunkSize, tc.chunkSize*b.N)
+ if err != nil {
+ b.Fatal(err)
}
- ch := make(chan error)
- go func() {
- _, err := io.Copy(srv.serverDown, onlyReader{srv.serverUp})
- ch <- err
- close(ch)
- }()
- return ch
-}
+ defer cleanup()
-func (srv *spliceTestServer) closeUp() error {
- var err, err1 error
- if srv.serverUp != nil {
- err = srv.serverUp.Close()
- }
- if srv.clientUp != nil {
- err1 = srv.clientUp.Close()
- }
- if err == nil {
- return err1
- }
- return err
-}
+ b.SetBytes(int64(tc.chunkSize))
+ b.ResetTimer()
-func (srv *spliceTestServer) closeDown() error {
- var err, err1 error
- if srv.serverDown != nil {
- err = srv.serverDown.Close()
- }
- if srv.clientDown != nil {
- err1 = srv.clientDown.Close()
- }
- if err == nil {
- return err1
+ if useSplice {
+ _, err := io.Copy(serverDown, serverUp)
+ if err != nil {
+ b.Fatal(err)
+ }
+ } else {
+ type onlyReader struct {
+ io.Reader
+ }
+ _, err := io.Copy(serverDown, onlyReader{serverUp})
+ if err != nil {
+ b.Fatal(err)
+ }
}
- return err
}
func spliceTestSocketPair(net string) (client, server Conn, err error) {
}
return client, server, nil
}
+
+func startSpliceClient(conn Conn, op string, chunkSize, totalSize int) (func(), error) {
+ f, err := conn.(interface{ File() (*os.File, error) }).File()
+ if err != nil {
+ return nil, err
+ }
+
+ cmd := exec.Command(os.Args[0], os.Args[1:]...)
+ cmd.Env = []string{
+ "GO_NET_TEST_SPLICE=1",
+ "GO_NET_TEST_SPLICE_OP=" + op,
+ "GO_NET_TEST_SPLICE_CHUNK_SIZE=" + strconv.Itoa(chunkSize),
+ "GO_NET_TEST_SPLICE_TOTAL_SIZE=" + strconv.Itoa(totalSize),
+ }
+ cmd.ExtraFiles = append(cmd.ExtraFiles, f)
+ cmd.Stdout = os.Stdout
+ cmd.Stderr = os.Stderr
+
+ if err := cmd.Start(); err != nil {
+ return nil, err
+ }
+
+ donec := make(chan struct{})
+ go func() {
+ cmd.Wait()
+ conn.Close()
+ f.Close()
+ close(donec)
+ }()
+
+ return func() { <-donec }, nil
+}
+
+func init() {
+ if os.Getenv("GO_NET_TEST_SPLICE") == "" {
+ return
+ }
+ defer os.Exit(0)
+
+ f := os.NewFile(uintptr(3), "splice-test-conn")
+ defer f.Close()
+
+ conn, err := FileConn(f)
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ var chunkSize int
+ if chunkSize, err = strconv.Atoi(os.Getenv("GO_NET_TEST_SPLICE_CHUNK_SIZE")); err != nil {
+ log.Fatal(err)
+ }
+ buf := make([]byte, chunkSize)
+
+ var totalSize int
+ if totalSize, err = strconv.Atoi(os.Getenv("GO_NET_TEST_SPLICE_TOTAL_SIZE")); err != nil {
+ log.Fatal(err)
+ }
+
+ var fn func([]byte) (int, error)
+ switch op := os.Getenv("GO_NET_TEST_SPLICE_OP"); op {
+ case "r":
+ fn = conn.Read
+ case "w":
+ defer conn.Close()
+
+ fn = conn.Write
+ default:
+ log.Fatalf("unknown op %q", op)
+ }
+
+ var n int
+ for count := 0; count < totalSize; count += n {
+ if count+chunkSize > totalSize {
+ buf = buf[:totalSize-count]
+ }
+
+ var err error
+ if n, err = fn(buf); err != nil {
+ return
+ }
+ }
+}