]> Cypherpunks repositories - gostls13.git/commitdiff
net: verify if internal/poll.Splice has been called during io.Copy
authorAndy Pan <panjf2000@gmail.com>
Fri, 22 Dec 2023 13:49:46 +0000 (21:49 +0800)
committerGopher Robot <gobot@golang.org>
Wed, 24 Jan 2024 22:13:22 +0000 (22:13 +0000)
Change-Id: I29ef35b034cd4ec401ac502bf95dbd8c3ef2a2d4
Reviewed-on: https://go-review.googlesource.com/c/go/+/552535
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Auto-Submit: Matthew Dempsky <mdempsky@google.com>
Reviewed-by: Matthew Dempsky <mdempsky@google.com>
Reviewed-by: Damien Neil <dneil@google.com>
src/net/splice_linux.go
src/net/splice_linux_test.go [moved from src/net/splice_test.go with 80% similarity]

index bdafcb59ab84802ec6e703ed625ea130b15a42f6..9fc26b4c23e9d2b58ef0005b39faaef26fce60fd 100644 (file)
@@ -9,6 +9,8 @@ import (
        "io"
 )
 
+var pollSplice = poll.Splice
+
 // spliceFrom transfers data from r to c using the splice system call to minimize
 // copies from and to userspace. c must be a TCP connection.
 // Currently, spliceFrom is only enabled if r is a TCP or a stream-oriented Unix connection.
@@ -39,7 +41,7 @@ func spliceFrom(c *netFD, r io.Reader) (written int64, err error, handled bool)
                return 0, nil, false
        }
 
-       written, handled, sc, err := poll.Splice(&c.pfd, &s.pfd, remain)
+       written, handled, sc, err := pollSplice(&c.pfd, &s.pfd, remain)
        if lr != nil {
                lr.N -= written
        }
@@ -57,6 +59,6 @@ func spliceTo(w io.Writer, c *netFD) (written int64, err error, handled bool) {
                return
        }
 
-       written, handled, sc, err := poll.Splice(&uc.fd.pfd, &c.pfd, 1<<63-1)
+       written, handled, sc, err := pollSplice(&uc.fd.pfd, &c.pfd, 1<<63-1)
        return written, wrapSyscallError(sc, err), handled
 }
similarity index 80%
rename from src/net/splice_test.go
rename to src/net/splice_linux_test.go
index 227ddebff402cea3a8ece694d44b68e91cea02f0..7082ecdfbe82dcfe7ef39afaa4719a705527272c 100644 (file)
@@ -7,12 +7,14 @@
 package net
 
 import (
+       "internal/poll"
        "io"
        "log"
        "os"
        "os/exec"
        "strconv"
        "sync"
+       "syscall"
        "testing"
        "time"
 )
@@ -58,6 +60,8 @@ type spliceTestCase struct {
 }
 
 func (tc spliceTestCase) test(t *testing.T) {
+       hook := hookSplice(t)
+
        clientUp, serverUp := spliceTestSocketPair(t, tc.upNet)
        defer serverUp.Close()
        cleanup, err := startSpliceClient(clientUp, "w", tc.chunkSize, tc.totalSize)
@@ -72,6 +76,7 @@ func (tc spliceTestCase) test(t *testing.T) {
                t.Fatal(err)
        }
        defer cleanup()
+
        var (
                r    io.Reader = serverUp
                size           = tc.totalSize
@@ -88,10 +93,10 @@ func (tc spliceTestCase) test(t *testing.T) {
                defer serverUp.Close()
        }
        n, err := io.Copy(serverDown, r)
-       serverDown.Close()
        if err != nil {
                t.Fatal(err)
        }
+
        if want := int64(size); want != n {
                t.Errorf("want %d bytes spliced, got %d", want, n)
        }
@@ -106,9 +111,62 @@ func (tc spliceTestCase) test(t *testing.T) {
                        t.Errorf("r.N = %d, want %d", n, wantN)
                }
        }
+
+       // poll.Splice is expected to be called when the source is not
+       // a wrapper or the destination is TCPConn.
+       if tc.limitReadSize == 0 || tc.downNet == "tcp" {
+               // We should have called poll.Splice with the right file descriptor arguments.
+               if n > 0 && !hook.called {
+                       t.Fatal("expected poll.Splice to be called")
+               }
+
+               verifySpliceFds(t, serverDown, hook, "dst")
+               verifySpliceFds(t, serverUp, hook, "src")
+
+               // poll.Splice is expected to handle the data transmission successfully.
+               if !hook.handled || hook.written != int64(size) || hook.err != nil {
+                       t.Errorf("expected handled = true, written = %d, err = nil, but got handled = %t, written = %d, err = %v",
+                               size, hook.handled, hook.written, hook.err)
+               }
+       } else if hook.called {
+               // poll.Splice will certainly not be called when the source
+               // is a wrapper and the destination is not TCPConn.
+               t.Errorf("expected poll.Splice not be called")
+       }
+}
+
+func verifySpliceFds(t *testing.T, c Conn, hook *spliceHook, fdType string) {
+       t.Helper()
+
+       sc, ok := c.(syscall.Conn)
+       if !ok {
+               t.Fatalf("expected syscall.Conn")
+       }
+       rc, err := sc.SyscallConn()
+       if err != nil {
+               t.Fatalf("syscall.Conn.SyscallConn error: %v", err)
+       }
+       var hookFd int
+       switch fdType {
+       case "src":
+               hookFd = hook.srcfd
+       case "dst":
+               hookFd = hook.dstfd
+       default:
+               t.Fatalf("unknown fdType %q", fdType)
+       }
+       if err := rc.Control(func(fd uintptr) {
+               if hook.called && hookFd != int(fd) {
+                       t.Fatalf("wrong %s file descriptor: got %d, want %d", fdType, hook.dstfd, int(fd))
+               }
+       }); err != nil {
+               t.Fatalf("syscall.RawConn.Control error: %v", err)
+       }
 }
 
 func (tc spliceTestCase) testFile(t *testing.T) {
+       hook := hookSplice(t)
+
        f, err := os.OpenFile(os.DevNull, os.O_WRONLY, 0)
        if err != nil {
                t.Fatal(err)
@@ -144,6 +202,13 @@ func (tc spliceTestCase) testFile(t *testing.T) {
        if err != nil {
                t.Fatalf("failed to ReadFrom with error: %v", err)
        }
+
+       // We shouldn't have called poll.Splice in TCPConn.WriteTo,
+       // it's supposed to be called from File.ReadFrom.
+       if got > 0 && hook.called {
+               t.Error("expected not poll.Splice to be called")
+       }
+
        if want := int64(actualSize); got != want {
                t.Errorf("got %d bytes, want %d", got, want)
        }
@@ -163,14 +228,17 @@ func testSpliceReaderAtEOF(t *testing.T, upNet, downNet string) {
        // UnixConn doesn't implement io.ReaderFrom, which will fail
        // the following test in asserting a UnixConn to be an io.ReaderFrom,
        // so skip this test.
-       if upNet == "unix" || downNet == "unix" {
+       if downNet == "unix" {
                t.Skip("skipping test on unix socket")
        }
 
+       hook := hookSplice(t)
+
        clientUp, serverUp := spliceTestSocketPair(t, upNet)
        defer clientUp.Close()
        clientDown, serverDown := spliceTestSocketPair(t, downNet)
        defer clientDown.Close()
+       defer serverDown.Close()
 
        serverUp.Close()
 
@@ -194,17 +262,30 @@ func testSpliceReaderAtEOF(t *testing.T, upNet, downNet string) {
        go func() {
                serverDown.(io.ReaderFrom).ReadFrom(serverUp)
                io.WriteString(serverDown, msg)
-               serverDown.Close()
        }()
 
        buf := make([]byte, 3)
-       _, err := io.ReadFull(clientDown, buf)
+       n, err := io.ReadFull(clientDown, buf)
        if err != nil {
                t.Errorf("clientDown: %v", err)
        }
        if string(buf) != msg {
                t.Errorf("clientDown got %q, want %q", buf, msg)
        }
+
+       // We should have called poll.Splice with the right file descriptor arguments.
+       if n > 0 && !hook.called {
+               t.Fatal("expected poll.Splice to be called")
+       }
+
+       verifySpliceFds(t, serverDown, hook, "dst")
+
+       // poll.Splice is expected to handle the data transmission but fail
+       // when working with a closed endpoint, return an error.
+       if !hook.handled || hook.written > 0 || hook.err == nil {
+               t.Errorf("expected handled = true, written = 0, err != nil, but got handled = %t, written = %d, err = %v",
+                       hook.handled, hook.written, hook.err)
+       }
 }
 
 func testSpliceIssue25985(t *testing.T, upNet, downNet string) {
@@ -539,3 +620,42 @@ func (bench spliceFileBench) benchSpliceFile(b *testing.B) {
                b.Errorf("bytes sent mismatch, got: %d, want: %d", got, want)
        }
 }
+
+func hookSplice(t *testing.T) *spliceHook {
+       t.Helper()
+
+       h := new(spliceHook)
+       h.install()
+       t.Cleanup(h.uninstall)
+       return h
+}
+
+type spliceHook struct {
+       called bool
+       dstfd  int
+       srcfd  int
+       remain int64
+
+       written int64
+       handled bool
+       sc      string
+       err     error
+
+       original func(dst, src *poll.FD, remain int64) (int64, bool, string, error)
+}
+
+func (h *spliceHook) install() {
+       h.original = pollSplice
+       pollSplice = func(dst, src *poll.FD, remain int64) (int64, bool, string, error) {
+               h.called = true
+               h.dstfd = dst.Sysfd
+               h.srcfd = src.Sysfd
+               h.remain = remain
+               h.written, h.handled, h.sc, h.err = h.original(dst, src, remain)
+               return h.written, h.handled, h.sc, h.err
+       }
+}
+
+func (h *spliceHook) uninstall() {
+       pollSplice = h.original
+}