package net
import (
+ "internal/poll"
"io"
"log"
"os"
"os/exec"
"strconv"
"sync"
+ "syscall"
"testing"
"time"
)
}
func (tc spliceTestCase) test(t *testing.T) {
+ hook := hookSplice(t)
+
clientUp, serverUp := spliceTestSocketPair(t, tc.upNet)
defer serverUp.Close()
cleanup, err := startSpliceClient(clientUp, "w", tc.chunkSize, tc.totalSize)
t.Fatal(err)
}
defer cleanup()
+
var (
r io.Reader = serverUp
size = tc.totalSize
defer serverUp.Close()
}
n, err := io.Copy(serverDown, r)
- serverDown.Close()
if err != nil {
t.Fatal(err)
}
+
if want := int64(size); want != n {
t.Errorf("want %d bytes spliced, got %d", want, n)
}
t.Errorf("r.N = %d, want %d", n, wantN)
}
}
+
+ // poll.Splice is expected to be called when the source is not
+ // a wrapper or the destination is TCPConn.
+ if tc.limitReadSize == 0 || tc.downNet == "tcp" {
+ // We should have called poll.Splice with the right file descriptor arguments.
+ if n > 0 && !hook.called {
+ t.Fatal("expected poll.Splice to be called")
+ }
+
+ verifySpliceFds(t, serverDown, hook, "dst")
+ verifySpliceFds(t, serverUp, hook, "src")
+
+ // poll.Splice is expected to handle the data transmission successfully.
+ if !hook.handled || hook.written != int64(size) || hook.err != nil {
+ t.Errorf("expected handled = true, written = %d, err = nil, but got handled = %t, written = %d, err = %v",
+ size, hook.handled, hook.written, hook.err)
+ }
+ } else if hook.called {
+ // poll.Splice will certainly not be called when the source
+ // is a wrapper and the destination is not TCPConn.
+ t.Errorf("expected poll.Splice not be called")
+ }
+}
+
+func verifySpliceFds(t *testing.T, c Conn, hook *spliceHook, fdType string) {
+ t.Helper()
+
+ sc, ok := c.(syscall.Conn)
+ if !ok {
+ t.Fatalf("expected syscall.Conn")
+ }
+ rc, err := sc.SyscallConn()
+ if err != nil {
+ t.Fatalf("syscall.Conn.SyscallConn error: %v", err)
+ }
+ var hookFd int
+ switch fdType {
+ case "src":
+ hookFd = hook.srcfd
+ case "dst":
+ hookFd = hook.dstfd
+ default:
+ t.Fatalf("unknown fdType %q", fdType)
+ }
+ if err := rc.Control(func(fd uintptr) {
+ if hook.called && hookFd != int(fd) {
+ t.Fatalf("wrong %s file descriptor: got %d, want %d", fdType, hook.dstfd, int(fd))
+ }
+ }); err != nil {
+ t.Fatalf("syscall.RawConn.Control error: %v", err)
+ }
}
func (tc spliceTestCase) testFile(t *testing.T) {
+ hook := hookSplice(t)
+
f, err := os.OpenFile(os.DevNull, os.O_WRONLY, 0)
if err != nil {
t.Fatal(err)
if err != nil {
t.Fatalf("failed to ReadFrom with error: %v", err)
}
+
+ // We shouldn't have called poll.Splice in TCPConn.WriteTo,
+ // it's supposed to be called from File.ReadFrom.
+ if got > 0 && hook.called {
+ t.Error("expected not poll.Splice to be called")
+ }
+
if want := int64(actualSize); got != want {
t.Errorf("got %d bytes, want %d", got, want)
}
// UnixConn doesn't implement io.ReaderFrom, which will fail
// the following test in asserting a UnixConn to be an io.ReaderFrom,
// so skip this test.
- if upNet == "unix" || downNet == "unix" {
+ if downNet == "unix" {
t.Skip("skipping test on unix socket")
}
+ hook := hookSplice(t)
+
clientUp, serverUp := spliceTestSocketPair(t, upNet)
defer clientUp.Close()
clientDown, serverDown := spliceTestSocketPair(t, downNet)
defer clientDown.Close()
+ defer serverDown.Close()
serverUp.Close()
go func() {
serverDown.(io.ReaderFrom).ReadFrom(serverUp)
io.WriteString(serverDown, msg)
- serverDown.Close()
}()
buf := make([]byte, 3)
- _, err := io.ReadFull(clientDown, buf)
+ n, err := io.ReadFull(clientDown, buf)
if err != nil {
t.Errorf("clientDown: %v", err)
}
if string(buf) != msg {
t.Errorf("clientDown got %q, want %q", buf, msg)
}
+
+ // We should have called poll.Splice with the right file descriptor arguments.
+ if n > 0 && !hook.called {
+ t.Fatal("expected poll.Splice to be called")
+ }
+
+ verifySpliceFds(t, serverDown, hook, "dst")
+
+ // poll.Splice is expected to handle the data transmission but fail
+ // when working with a closed endpoint, return an error.
+ if !hook.handled || hook.written > 0 || hook.err == nil {
+ t.Errorf("expected handled = true, written = 0, err != nil, but got handled = %t, written = %d, err = %v",
+ hook.handled, hook.written, hook.err)
+ }
}
func testSpliceIssue25985(t *testing.T, upNet, downNet string) {
b.Errorf("bytes sent mismatch, got: %d, want: %d", got, want)
}
}
+
+func hookSplice(t *testing.T) *spliceHook {
+ t.Helper()
+
+ h := new(spliceHook)
+ h.install()
+ t.Cleanup(h.uninstall)
+ return h
+}
+
+type spliceHook struct {
+ called bool
+ dstfd int
+ srcfd int
+ remain int64
+
+ written int64
+ handled bool
+ sc string
+ err error
+
+ original func(dst, src *poll.FD, remain int64) (int64, bool, string, error)
+}
+
+func (h *spliceHook) install() {
+ h.original = pollSplice
+ pollSplice = func(dst, src *poll.FD, remain int64) (int64, bool, string, error) {
+ h.called = true
+ h.dstfd = dst.Sysfd
+ h.srcfd = src.Sysfd
+ h.remain = remain
+ h.written, h.handled, h.sc, h.err = h.original(dst, src, remain)
+ return h.written, h.handled, h.sc, h.err
+ }
+}
+
+func (h *spliceHook) uninstall() {
+ pollSplice = h.original
+}