From: Andy Pan Date: Fri, 22 Dec 2023 13:49:46 +0000 (+0800) Subject: net: verify if internal/poll.Splice has been called during io.Copy X-Git-Tag: go1.23rc1~1390 X-Git-Url: http://www.git.cypherpunks.su/?a=commitdiff_plain;h=6674f2225e3a409f3f03b8b6ba31c1f3ddb0b35b;p=gostls13.git net: verify if internal/poll.Splice has been called during io.Copy Change-Id: I29ef35b034cd4ec401ac502bf95dbd8c3ef2a2d4 Reviewed-on: https://go-review.googlesource.com/c/go/+/552535 LUCI-TryBot-Result: Go LUCI Auto-Submit: Matthew Dempsky Reviewed-by: Matthew Dempsky Reviewed-by: Damien Neil --- diff --git a/src/net/splice_linux.go b/src/net/splice_linux.go index bdafcb59ab..9fc26b4c23 100644 --- a/src/net/splice_linux.go +++ b/src/net/splice_linux.go @@ -9,6 +9,8 @@ import ( "io" ) +var pollSplice = poll.Splice + // spliceFrom transfers data from r to c using the splice system call to minimize // copies from and to userspace. c must be a TCP connection. // Currently, spliceFrom is only enabled if r is a TCP or a stream-oriented Unix connection. @@ -39,7 +41,7 @@ func spliceFrom(c *netFD, r io.Reader) (written int64, err error, handled bool) return 0, nil, false } - written, handled, sc, err := poll.Splice(&c.pfd, &s.pfd, remain) + written, handled, sc, err := pollSplice(&c.pfd, &s.pfd, remain) if lr != nil { lr.N -= written } @@ -57,6 +59,6 @@ func spliceTo(w io.Writer, c *netFD) (written int64, err error, handled bool) { return } - written, handled, sc, err := poll.Splice(&uc.fd.pfd, &c.pfd, 1<<63-1) + written, handled, sc, err := pollSplice(&uc.fd.pfd, &c.pfd, 1<<63-1) return written, wrapSyscallError(sc, err), handled } diff --git a/src/net/splice_test.go b/src/net/splice_linux_test.go similarity index 80% rename from src/net/splice_test.go rename to src/net/splice_linux_test.go index 227ddebff4..7082ecdfbe 100644 --- a/src/net/splice_test.go +++ b/src/net/splice_linux_test.go @@ -7,12 +7,14 @@ package net import ( + "internal/poll" "io" "log" "os" "os/exec" "strconv" "sync" + "syscall" "testing" "time" ) @@ -58,6 +60,8 @@ type spliceTestCase struct { } func (tc spliceTestCase) test(t *testing.T) { + hook := hookSplice(t) + clientUp, serverUp := spliceTestSocketPair(t, tc.upNet) defer serverUp.Close() cleanup, err := startSpliceClient(clientUp, "w", tc.chunkSize, tc.totalSize) @@ -72,6 +76,7 @@ func (tc spliceTestCase) test(t *testing.T) { t.Fatal(err) } defer cleanup() + var ( r io.Reader = serverUp size = tc.totalSize @@ -88,10 +93,10 @@ func (tc spliceTestCase) test(t *testing.T) { defer serverUp.Close() } n, err := io.Copy(serverDown, r) - serverDown.Close() if err != nil { t.Fatal(err) } + if want := int64(size); want != n { t.Errorf("want %d bytes spliced, got %d", want, n) } @@ -106,9 +111,62 @@ func (tc spliceTestCase) test(t *testing.T) { t.Errorf("r.N = %d, want %d", n, wantN) } } + + // poll.Splice is expected to be called when the source is not + // a wrapper or the destination is TCPConn. + if tc.limitReadSize == 0 || tc.downNet == "tcp" { + // We should have called poll.Splice with the right file descriptor arguments. + if n > 0 && !hook.called { + t.Fatal("expected poll.Splice to be called") + } + + verifySpliceFds(t, serverDown, hook, "dst") + verifySpliceFds(t, serverUp, hook, "src") + + // poll.Splice is expected to handle the data transmission successfully. + if !hook.handled || hook.written != int64(size) || hook.err != nil { + t.Errorf("expected handled = true, written = %d, err = nil, but got handled = %t, written = %d, err = %v", + size, hook.handled, hook.written, hook.err) + } + } else if hook.called { + // poll.Splice will certainly not be called when the source + // is a wrapper and the destination is not TCPConn. + t.Errorf("expected poll.Splice not be called") + } +} + +func verifySpliceFds(t *testing.T, c Conn, hook *spliceHook, fdType string) { + t.Helper() + + sc, ok := c.(syscall.Conn) + if !ok { + t.Fatalf("expected syscall.Conn") + } + rc, err := sc.SyscallConn() + if err != nil { + t.Fatalf("syscall.Conn.SyscallConn error: %v", err) + } + var hookFd int + switch fdType { + case "src": + hookFd = hook.srcfd + case "dst": + hookFd = hook.dstfd + default: + t.Fatalf("unknown fdType %q", fdType) + } + if err := rc.Control(func(fd uintptr) { + if hook.called && hookFd != int(fd) { + t.Fatalf("wrong %s file descriptor: got %d, want %d", fdType, hook.dstfd, int(fd)) + } + }); err != nil { + t.Fatalf("syscall.RawConn.Control error: %v", err) + } } func (tc spliceTestCase) testFile(t *testing.T) { + hook := hookSplice(t) + f, err := os.OpenFile(os.DevNull, os.O_WRONLY, 0) if err != nil { t.Fatal(err) @@ -144,6 +202,13 @@ func (tc spliceTestCase) testFile(t *testing.T) { if err != nil { t.Fatalf("failed to ReadFrom with error: %v", err) } + + // We shouldn't have called poll.Splice in TCPConn.WriteTo, + // it's supposed to be called from File.ReadFrom. + if got > 0 && hook.called { + t.Error("expected not poll.Splice to be called") + } + if want := int64(actualSize); got != want { t.Errorf("got %d bytes, want %d", got, want) } @@ -163,14 +228,17 @@ func testSpliceReaderAtEOF(t *testing.T, upNet, downNet string) { // UnixConn doesn't implement io.ReaderFrom, which will fail // the following test in asserting a UnixConn to be an io.ReaderFrom, // so skip this test. - if upNet == "unix" || downNet == "unix" { + if downNet == "unix" { t.Skip("skipping test on unix socket") } + hook := hookSplice(t) + clientUp, serverUp := spliceTestSocketPair(t, upNet) defer clientUp.Close() clientDown, serverDown := spliceTestSocketPair(t, downNet) defer clientDown.Close() + defer serverDown.Close() serverUp.Close() @@ -194,17 +262,30 @@ func testSpliceReaderAtEOF(t *testing.T, upNet, downNet string) { go func() { serverDown.(io.ReaderFrom).ReadFrom(serverUp) io.WriteString(serverDown, msg) - serverDown.Close() }() buf := make([]byte, 3) - _, err := io.ReadFull(clientDown, buf) + n, err := io.ReadFull(clientDown, buf) if err != nil { t.Errorf("clientDown: %v", err) } if string(buf) != msg { t.Errorf("clientDown got %q, want %q", buf, msg) } + + // We should have called poll.Splice with the right file descriptor arguments. + if n > 0 && !hook.called { + t.Fatal("expected poll.Splice to be called") + } + + verifySpliceFds(t, serverDown, hook, "dst") + + // poll.Splice is expected to handle the data transmission but fail + // when working with a closed endpoint, return an error. + if !hook.handled || hook.written > 0 || hook.err == nil { + t.Errorf("expected handled = true, written = 0, err != nil, but got handled = %t, written = %d, err = %v", + hook.handled, hook.written, hook.err) + } } func testSpliceIssue25985(t *testing.T, upNet, downNet string) { @@ -539,3 +620,42 @@ func (bench spliceFileBench) benchSpliceFile(b *testing.B) { b.Errorf("bytes sent mismatch, got: %d, want: %d", got, want) } } + +func hookSplice(t *testing.T) *spliceHook { + t.Helper() + + h := new(spliceHook) + h.install() + t.Cleanup(h.uninstall) + return h +} + +type spliceHook struct { + called bool + dstfd int + srcfd int + remain int64 + + written int64 + handled bool + sc string + err error + + original func(dst, src *poll.FD, remain int64) (int64, bool, string, error) +} + +func (h *spliceHook) install() { + h.original = pollSplice + pollSplice = func(dst, src *poll.FD, remain int64) (int64, bool, string, error) { + h.called = true + h.dstfd = dst.Sysfd + h.srcfd = src.Sysfd + h.remain = remain + h.written, h.handled, h.sc, h.err = h.original(dst, src, remain) + return h.written, h.handled, h.sc, h.err + } +} + +func (h *spliceHook) uninstall() { + pollSplice = h.original +}