From: Andy Pan Date: Tue, 28 Feb 2023 08:39:15 +0000 (+0800) Subject: net,os: arrange zero-copy of os.File and net.TCPConn to net.UnixConn X-Git-Tag: go1.22rc1~263 X-Git-Url: http://www.git.cypherpunks.su/?a=commitdiff_plain;h=f664031bc17629080332a1c7bede38d67fd32e47;p=gostls13.git net,os: arrange zero-copy of os.File and net.TCPConn to net.UnixConn Fixes #58808 goos: linux goarch: amd64 pkg: net cpu: DO-Premium-Intel │ old │ new │ │ sec/op │ sec/op vs base │ Splice/tcp-to-unix/1024-4 3.783µ ± 10% 3.201µ ± 7% -15.40% (p=0.001 n=10) Splice/tcp-to-unix/2048-4 3.967µ ± 13% 3.818µ ± 16% ~ (p=0.971 n=10) Splice/tcp-to-unix/4096-4 4.988µ ± 16% 4.590µ ± 11% ~ (p=0.089 n=10) Splice/tcp-to-unix/8192-4 6.981µ ± 13% 5.236µ ± 9% -25.00% (p=0.000 n=10) Splice/tcp-to-unix/16384-4 10.192µ ± 9% 7.350µ ± 7% -27.89% (p=0.000 n=10) Splice/tcp-to-unix/32768-4 19.65µ ± 13% 10.28µ ± 16% -47.69% (p=0.000 n=10) Splice/tcp-to-unix/65536-4 41.89µ ± 18% 15.70µ ± 13% -62.52% (p=0.000 n=10) Splice/tcp-to-unix/131072-4 90.05µ ± 11% 29.55µ ± 10% -67.18% (p=0.000 n=10) Splice/tcp-to-unix/262144-4 170.24µ ± 15% 52.66µ ± 4% -69.06% (p=0.000 n=10) Splice/tcp-to-unix/524288-4 326.4µ ± 13% 109.3µ ± 11% -66.52% (p=0.000 n=10) Splice/tcp-to-unix/1048576-4 651.4µ ± 9% 228.3µ ± 14% -64.95% (p=0.000 n=10) geomean 29.42µ 15.62µ -46.90% │ old │ new │ │ B/s │ B/s vs base │ Splice/tcp-to-unix/1024-4 258.2Mi ± 11% 305.2Mi ± 8% +18.21% (p=0.001 n=10) Splice/tcp-to-unix/2048-4 492.5Mi ± 15% 511.7Mi ± 13% ~ (p=0.971 n=10) Splice/tcp-to-unix/4096-4 783.5Mi ± 14% 851.2Mi ± 12% ~ (p=0.089 n=10) Splice/tcp-to-unix/8192-4 1.093Gi ± 11% 1.458Gi ± 8% +33.36% (p=0.000 n=10) Splice/tcp-to-unix/16384-4 1.497Gi ± 9% 2.076Gi ± 7% +38.67% (p=0.000 n=10) Splice/tcp-to-unix/32768-4 1.553Gi ± 11% 2.969Gi ± 14% +91.17% (p=0.000 n=10) Splice/tcp-to-unix/65536-4 1.458Gi ± 23% 3.888Gi ± 11% +166.69% (p=0.000 n=10) Splice/tcp-to-unix/131072-4 1.356Gi ± 10% 4.131Gi ± 9% +204.72% (p=0.000 n=10) Splice/tcp-to-unix/262144-4 1.434Gi ± 13% 4.637Gi ± 4% +223.32% (p=0.000 n=10) Splice/tcp-to-unix/524288-4 1.497Gi ± 15% 4.468Gi ± 10% +198.47% (p=0.000 n=10) Splice/tcp-to-unix/1048576-4 1.501Gi ± 10% 4.277Gi ± 16% +184.88% (p=0.000 n=10) geomean 1.038Gi 1.954Gi +88.28% │ old │ new │ │ B/op │ B/op vs base │ Splice/tcp-to-unix/1024-4 0.000 ± 0% 0.000 ± 0% ~ (p=1.000 n=10) ¹ Splice/tcp-to-unix/2048-4 0.000 ± 0% 0.000 ± 0% ~ (p=1.000 n=10) ¹ Splice/tcp-to-unix/4096-4 0.000 ± 0% 0.000 ± 0% ~ (p=1.000 n=10) ¹ Splice/tcp-to-unix/8192-4 0.000 ± 0% 0.000 ± 0% ~ (p=1.000 n=10) ¹ Splice/tcp-to-unix/16384-4 0.000 ± 0% 0.000 ± 0% ~ (p=1.000 n=10) ¹ Splice/tcp-to-unix/32768-4 0.000 ± 0% 0.000 ± 0% ~ (p=1.000 n=10) ¹ Splice/tcp-to-unix/65536-4 1.000 ± ? 0.000 ± 0% -100.00% (p=0.001 n=10) Splice/tcp-to-unix/131072-4 2.000 ± 0% 0.000 ± 0% -100.00% (p=0.000 n=10) Splice/tcp-to-unix/262144-4 4.000 ± 25% 0.000 ± 0% -100.00% (p=0.000 n=10) Splice/tcp-to-unix/524288-4 7.500 ± 33% 0.000 ± 0% -100.00% (p=0.000 n=10) Splice/tcp-to-unix/1048576-4 17.00 ± 12% 0.00 ± 0% -100.00% (p=0.000 n=10) geomean ² ? ² ³ ¹ all samples are equal ² summaries must be >0 to compute geomean ³ ratios must be >0 to compute geomean │ old │ new │ │ allocs/op │ allocs/op vs base │ Splice/tcp-to-unix/1024-4 0.000 ± 0% 0.000 ± 0% ~ (p=1.000 n=10) ¹ Splice/tcp-to-unix/2048-4 0.000 ± 0% 0.000 ± 0% ~ (p=1.000 n=10) ¹ Splice/tcp-to-unix/4096-4 0.000 ± 0% 0.000 ± 0% ~ (p=1.000 n=10) ¹ Splice/tcp-to-unix/8192-4 0.000 ± 0% 0.000 ± 0% ~ (p=1.000 n=10) ¹ Splice/tcp-to-unix/16384-4 0.000 ± 0% 0.000 ± 0% ~ (p=1.000 n=10) ¹ Splice/tcp-to-unix/32768-4 0.000 ± 0% 0.000 ± 0% ~ (p=1.000 n=10) ¹ Splice/tcp-to-unix/65536-4 0.000 ± 0% 0.000 ± 0% ~ (p=1.000 n=10) ¹ Splice/tcp-to-unix/131072-4 0.000 ± 0% 0.000 ± 0% ~ (p=1.000 n=10) ¹ Splice/tcp-to-unix/262144-4 0.000 ± 0% 0.000 ± 0% ~ (p=1.000 n=10) ¹ Splice/tcp-to-unix/524288-4 0.000 ± 0% 0.000 ± 0% ~ (p=1.000 n=10) ¹ Splice/tcp-to-unix/1048576-4 0.000 ± 0% 0.000 ± 0% ~ (p=1.000 n=10) ¹ geomean ² +0.00% ² ¹ all samples are equal ² summaries must be >0 to compute geomean Change-Id: I829061b009a0929a8ef1a15c183793c0b9104dde Reviewed-on: https://go-review.googlesource.com/c/go/+/472475 Reviewed-by: Damien Neil Reviewed-by: Bryan Mills LUCI-TryBot-Result: Go LUCI --- diff --git a/api/next/58808.txt b/api/next/58808.txt new file mode 100644 index 0000000000..f1105c3168 --- /dev/null +++ b/api/next/58808.txt @@ -0,0 +1,2 @@ +pkg net, method (*TCPConn) WriteTo(io.Writer) (int64, error) #58808 +pkg os, method (*File) WriteTo(io.Writer) (int64, error) #58808 diff --git a/src/internal/poll/fd.go b/src/internal/poll/fd.go index ef61d0cb3f..4e038d00dd 100644 --- a/src/internal/poll/fd.go +++ b/src/internal/poll/fd.go @@ -81,3 +81,14 @@ func consume(v *[][]byte, n int64) { // TestHookDidWritev is a hook for testing writev. var TestHookDidWritev = func(wrote int) {} + +// String is an internal string definition for methods/functions +// that is not intended for use outside the standard libraries. +// +// Other packages in std that import internal/poll and have some +// exported APIs (now we've got some in net.rawConn) which are only used +// internally and are not intended to be used outside the standard libraries, +// Therefore, we make those APIs use internal types like poll.FD or poll.String +// in their function signatures to disable the usability of these APIs from +// external codebase. +type String string diff --git a/src/net/http/transfer_test.go b/src/net/http/transfer_test.go index 3f9ebdea7b..b1a5a93103 100644 --- a/src/net/http/transfer_test.go +++ b/src/net/http/transfer_test.go @@ -264,6 +264,12 @@ func TestTransferWriterWriteBodyReaderTypes(t *testing.T) { actualReader = reflect.TypeOf(lr.R) } else { actualReader = reflect.TypeOf(mw.CalledReader) + // We have to handle this special case for genericWriteTo in os, + // this struct is introduced to support a zero-copy optimization, + // check out https://go.dev/issue/58808 for details. + if actualReader.Kind() == reflect.Struct && actualReader.PkgPath() == "os" && actualReader.Name() == "fileWithoutWriteTo" { + actualReader = actualReader.Field(1).Type + } } if tc.expectedReader != actualReader { diff --git a/src/net/net.go b/src/net/net.go index 396713ce4a..02c2ceda32 100644 --- a/src/net/net.go +++ b/src/net/net.go @@ -664,15 +664,53 @@ var errClosed = poll.ErrNetClosing // errors.Is(err, net.ErrClosed). var ErrClosed error = errClosed -type writerOnly struct { - io.Writer +// noReadFrom can be embedded alongside another type to +// hide the ReadFrom method of that other type. +type noReadFrom struct{} + +// ReadFrom hides another ReadFrom method. +// It should never be called. +func (noReadFrom) ReadFrom(io.Reader) (int64, error) { + panic("can't happen") +} + +// tcpConnWithoutReadFrom implements all the methods of *TCPConn other +// than ReadFrom. This is used to permit ReadFrom to call io.Copy +// without leading to a recursive call to ReadFrom. +type tcpConnWithoutReadFrom struct { + noReadFrom + *TCPConn } // Fallback implementation of io.ReaderFrom's ReadFrom, when sendfile isn't // applicable. -func genericReadFrom(w io.Writer, r io.Reader) (n int64, err error) { +func genericReadFrom(c *TCPConn, r io.Reader) (n int64, err error) { // Use wrapper to hide existing r.ReadFrom from io.Copy. - return io.Copy(writerOnly{w}, r) + return io.Copy(tcpConnWithoutReadFrom{TCPConn: c}, r) +} + +// noWriteTo can be embedded alongside another type to +// hide the WriteTo method of that other type. +type noWriteTo struct{} + +// WriteTo hides another WriteTo method. +// It should never be called. +func (noWriteTo) WriteTo(io.Writer) (int64, error) { + panic("can't happen") +} + +// tcpConnWithoutWriteTo implements all the methods of *TCPConn other +// than WriteTo. This is used to permit WriteTo to call io.Copy +// without leading to a recursive call to WriteTo. +type tcpConnWithoutWriteTo struct { + noWriteTo + *TCPConn +} + +// Fallback implementation of io.WriterTo's WriteTo, when zero-copy isn't applicable. +func genericWriteTo(c *TCPConn, w io.Writer) (n int64, err error) { + // Use wrapper to hide existing w.WriteTo from io.Copy. + return io.Copy(w, tcpConnWithoutWriteTo{TCPConn: c}) } // Limit the number of concurrent cgo-using goroutines, because diff --git a/src/net/rawconn.go b/src/net/rawconn.go index e49b9fb81b..7a69fe5c25 100644 --- a/src/net/rawconn.go +++ b/src/net/rawconn.go @@ -79,6 +79,17 @@ func newRawConn(fd *netFD) *rawConn { return &rawConn{fd: fd} } +// Network returns the network type of the underlying connection. +// +// Other packages in std that import internal/poll and are unable to +// import net (such as os) can use a type assertion to access this +// extension method so that they can distinguish different socket types. +// +// Network is not intended for use outside the standard library. +func (c *rawConn) Network() poll.String { + return poll.String(c.fd.net) +} + type rawListener struct { rawConn } diff --git a/src/net/sendfile_linux_test.go b/src/net/sendfile_linux_test.go index 0b5af36cdb..7a66d3645f 100644 --- a/src/net/sendfile_linux_test.go +++ b/src/net/sendfile_linux_test.go @@ -14,29 +14,36 @@ import ( ) func BenchmarkSendFile(b *testing.B) { + b.Run("file-to-tcp", func(b *testing.B) { benchmarkSendFile(b, "tcp") }) + b.Run("file-to-unix", func(b *testing.B) { benchmarkSendFile(b, "unix") }) +} + +func benchmarkSendFile(b *testing.B, proto string) { for i := 0; i <= 10; i++ { size := 1 << (i + 10) - bench := sendFileBench{chunkSize: size} + bench := sendFileBench{ + proto: proto, + chunkSize: size, + } b.Run(strconv.Itoa(size), bench.benchSendFile) } } type sendFileBench struct { + proto string chunkSize int } func (bench sendFileBench) benchSendFile(b *testing.B) { fileSize := b.N * bench.chunkSize f := createTempFile(b, fileSize) - fileName := f.Name() - defer os.Remove(fileName) - defer f.Close() - client, server := spliceTestSocketPair(b, "tcp") + client, server := spliceTestSocketPair(b, bench.proto) defer server.Close() cleanUp, err := startSpliceClient(client, "r", bench.chunkSize, fileSize) if err != nil { + client.Close() b.Fatal(err) } defer cleanUp() @@ -51,15 +58,18 @@ func (bench sendFileBench) benchSendFile(b *testing.B) { b.Fatalf("failed to copy data with sendfile, error: %v", err) } if sent != int64(fileSize) { - b.Fatalf("bytes sent mismatch\n\texpect: %d\n\tgot: %d", fileSize, sent) + b.Fatalf("bytes sent mismatch, got: %d, want: %d", sent, fileSize) } } func createTempFile(b *testing.B, size int) *os.File { - f, err := os.CreateTemp("", "linux-sendfile-test") + f, err := os.CreateTemp(b.TempDir(), "linux-sendfile-bench") if err != nil { b.Fatalf("failed to create temporary file: %v", err) } + b.Cleanup(func() { + f.Close() + }) data := make([]byte, size) if _, err := f.Write(data); err != nil { diff --git a/src/net/splice_linux.go b/src/net/splice_linux.go index ab2ab70b28..bdafcb59ab 100644 --- a/src/net/splice_linux.go +++ b/src/net/splice_linux.go @@ -9,12 +9,12 @@ import ( "io" ) -// splice transfers data from r to c using the splice system call to minimize -// copies from and to userspace. c must be a TCP connection. Currently, splice -// is only enabled if r is a TCP or a stream-oriented Unix connection. +// spliceFrom transfers data from r to c using the splice system call to minimize +// copies from and to userspace. c must be a TCP connection. +// Currently, spliceFrom is only enabled if r is a TCP or a stream-oriented Unix connection. // -// If splice returns handled == false, it has performed no work. -func splice(c *netFD, r io.Reader) (written int64, err error, handled bool) { +// If spliceFrom returns handled == false, it has performed no work. +func spliceFrom(c *netFD, r io.Reader) (written int64, err error, handled bool) { var remain int64 = 1<<63 - 1 // by default, copy until EOF lr, ok := r.(*io.LimitedReader) if ok { @@ -25,14 +25,17 @@ func splice(c *netFD, r io.Reader) (written int64, err error, handled bool) { } var s *netFD - if tc, ok := r.(*TCPConn); ok { - s = tc.fd - } else if uc, ok := r.(*UnixConn); ok { - if uc.fd.net != "unix" { + switch v := r.(type) { + case *TCPConn: + s = v.fd + case tcpConnWithoutWriteTo: + s = v.fd + case *UnixConn: + if v.fd.net != "unix" { return 0, nil, false } - s = uc.fd - } else { + s = v.fd + default: return 0, nil, false } @@ -42,3 +45,18 @@ func splice(c *netFD, r io.Reader) (written int64, err error, handled bool) { } return written, wrapSyscallError(sc, err), handled } + +// spliceTo transfers data from c to w using the splice system call to minimize +// copies from and to userspace. c must be a TCP connection. +// Currently, spliceTo is only enabled if w is a stream-oriented Unix connection. +// +// If spliceTo returns handled == false, it has performed no work. +func spliceTo(w io.Writer, c *netFD) (written int64, err error, handled bool) { + uc, ok := w.(*UnixConn) + if !ok || uc.fd.net != "unix" { + return + } + + written, handled, sc, err := poll.Splice(&uc.fd.pfd, &c.pfd, 1<<63-1) + return written, wrapSyscallError(sc, err), handled +} diff --git a/src/net/splice_stub.go b/src/net/splice_stub.go index 3cdadb11c5..239227ff88 100644 --- a/src/net/splice_stub.go +++ b/src/net/splice_stub.go @@ -8,6 +8,10 @@ package net import "io" -func splice(c *netFD, r io.Reader) (int64, error, bool) { +func spliceFrom(_ *netFD, _ io.Reader) (int64, error, bool) { + return 0, nil, false +} + +func spliceTo(_ io.Writer, _ *netFD) (int64, error, bool) { return 0, nil, false } diff --git a/src/net/splice_test.go b/src/net/splice_test.go index 75a8f274ff..227ddebff4 100644 --- a/src/net/splice_test.go +++ b/src/net/splice_test.go @@ -23,6 +23,7 @@ func TestSplice(t *testing.T) { t.Skip("skipping unix-to-tcp tests") } t.Run("unix-to-tcp", func(t *testing.T) { testSplice(t, "unix", "tcp") }) + t.Run("tcp-to-unix", func(t *testing.T) { testSplice(t, "tcp", "unix") }) t.Run("tcp-to-file", func(t *testing.T) { testSpliceToFile(t, "tcp", "file") }) t.Run("unix-to-file", func(t *testing.T) { testSpliceToFile(t, "unix", "file") }) t.Run("no-unixpacket", testSpliceNoUnixpacket) @@ -159,6 +160,13 @@ func (tc spliceTestCase) testFile(t *testing.T) { } func testSpliceReaderAtEOF(t *testing.T, upNet, downNet string) { + // UnixConn doesn't implement io.ReaderFrom, which will fail + // the following test in asserting a UnixConn to be an io.ReaderFrom, + // so skip this test. + if upNet == "unix" || downNet == "unix" { + t.Skip("skipping test on unix socket") + } + clientUp, serverUp := spliceTestSocketPair(t, upNet) defer clientUp.Close() clientDown, serverDown := spliceTestSocketPair(t, downNet) @@ -166,16 +174,16 @@ func testSpliceReaderAtEOF(t *testing.T, upNet, downNet string) { serverUp.Close() - // We'd like to call net.splice here and check the handled return + // We'd like to call net.spliceFrom here and check the handled return // value, but we disable splice on old Linux kernels. // - // In that case, poll.Splice and net.splice return a non-nil error + // In that case, poll.Splice and net.spliceFrom return a non-nil error // and handled == false. We'd ideally like to see handled == true // because the source reader is at EOF, but if we're running on an old - // kernel, and splice is disabled, we won't see EOF from net.splice, + // kernel, and splice is disabled, we won't see EOF from net.spliceFrom, // because we won't touch the reader at all. // - // Trying to untangle the errors from net.splice and match them + // Trying to untangle the errors from net.spliceFrom and match them // against the errors created by the poll package would be brittle, // so this is a higher level test. // @@ -268,7 +276,7 @@ func testSpliceNoUnixpacket(t *testing.T) { // // What we want is err == nil and handled == false, i.e. we never // called poll.Splice, because we know the unix socket's network. - _, err, handled := splice(serverDown.(*TCPConn).fd, serverUp) + _, err, handled := spliceFrom(serverDown.(*TCPConn).fd, serverUp) if err != nil || handled != false { t.Fatalf("got err = %v, handled = %t, want nil error, handled == false", err, handled) } @@ -289,7 +297,7 @@ func testSpliceNoUnixgram(t *testing.T) { defer clientDown.Close() defer serverDown.Close() // Analogous to testSpliceNoUnixpacket. - _, err, handled := splice(serverDown.(*TCPConn).fd, up) + _, err, handled := spliceFrom(serverDown.(*TCPConn).fd, up) if err != nil || handled != false { t.Fatalf("got err = %v, handled = %t, want nil error, handled == false", err, handled) } @@ -300,6 +308,7 @@ func BenchmarkSplice(b *testing.B) { b.Run("tcp-to-tcp", func(b *testing.B) { benchSplice(b, "tcp", "tcp") }) b.Run("unix-to-tcp", func(b *testing.B) { benchSplice(b, "unix", "tcp") }) + b.Run("tcp-to-unix", func(b *testing.B) { benchSplice(b, "tcp", "unix") }) } func benchSplice(b *testing.B, upNet, downNet string) { diff --git a/src/net/tcpsock.go b/src/net/tcpsock.go index 1528353cba..6257f2515b 100644 --- a/src/net/tcpsock.go +++ b/src/net/tcpsock.go @@ -134,6 +134,18 @@ func (c *TCPConn) ReadFrom(r io.Reader) (int64, error) { return n, err } +// WriteTo implements the io.WriterTo WriteTo method. +func (c *TCPConn) WriteTo(w io.Writer) (int64, error) { + if !c.ok() { + return 0, syscall.EINVAL + } + n, err := c.writeTo(w) + if err != nil && err != io.EOF { + err = &OpError{Op: "writeto", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err} + } + return n, err +} + // CloseRead shuts down the reading side of the TCP connection. // Most callers should just use Close. func (c *TCPConn) CloseRead() error { diff --git a/src/net/tcpsock_plan9.go b/src/net/tcpsock_plan9.go index d55948f69e..463dedcf44 100644 --- a/src/net/tcpsock_plan9.go +++ b/src/net/tcpsock_plan9.go @@ -14,6 +14,10 @@ func (c *TCPConn) readFrom(r io.Reader) (int64, error) { return genericReadFrom(c, r) } +func (c *TCPConn) writeTo(w io.Writer) (int64, error) { + return genericWriteTo(c, w) +} + func (sd *sysDialer) dialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCPConn, error) { if h := sd.testHookDialTCP; h != nil { return h(ctx, sd.network, laddr, raddr) diff --git a/src/net/tcpsock_posix.go b/src/net/tcpsock_posix.go index 83cee7c789..01b5ec9ed0 100644 --- a/src/net/tcpsock_posix.go +++ b/src/net/tcpsock_posix.go @@ -45,7 +45,7 @@ func (a *TCPAddr) toLocal(net string) sockaddr { } func (c *TCPConn) readFrom(r io.Reader) (int64, error) { - if n, err, handled := splice(c.fd, r); handled { + if n, err, handled := spliceFrom(c.fd, r); handled { return n, err } if n, err, handled := sendFile(c.fd, r); handled { @@ -54,6 +54,13 @@ func (c *TCPConn) readFrom(r io.Reader) (int64, error) { return genericReadFrom(c, r) } +func (c *TCPConn) writeTo(w io.Writer) (int64, error) { + if n, err, handled := spliceTo(w, c.fd); handled { + return n, err + } + return genericWriteTo(c, w) +} + func (sd *sysDialer) dialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCPConn, error) { if h := sd.testHookDialTCP; h != nil { return h(ctx, sd.network, laddr, raddr) diff --git a/src/os/export_linux_test.go b/src/os/export_linux_test.go index 3fd5e61de7..942b48a17d 100644 --- a/src/os/export_linux_test.go +++ b/src/os/export_linux_test.go @@ -5,7 +5,8 @@ package os var ( - PollCopyFileRangeP = &pollCopyFileRange - PollSpliceFile = &pollSplice - GetPollFDForTest = getPollFD + PollCopyFileRangeP = &pollCopyFileRange + PollSpliceFile = &pollSplice + PollSendFile = &pollSendFile + GetPollFDAndNetwork = getPollFDAndNetwork ) diff --git a/src/os/file.go b/src/os/file.go index 82be00a834..37a30ccf04 100644 --- a/src/os/file.go +++ b/src/os/file.go @@ -157,20 +157,26 @@ func (f *File) ReadFrom(r io.Reader) (n int64, err error) { return n, f.wrapErr("write", e) } -func genericReadFrom(f *File, r io.Reader) (int64, error) { - return io.Copy(fileWithoutReadFrom{f}, r) +// noReadFrom can be embedded alongside another type to +// hide the ReadFrom method of that other type. +type noReadFrom struct{} + +// ReadFrom hides another ReadFrom method. +// It should never be called. +func (noReadFrom) ReadFrom(io.Reader) (int64, error) { + panic("can't happen") } // fileWithoutReadFrom implements all the methods of *File other // than ReadFrom. This is used to permit ReadFrom to call io.Copy // without leading to a recursive call to ReadFrom. type fileWithoutReadFrom struct { + noReadFrom *File } -// This ReadFrom method hides the *File ReadFrom method. -func (fileWithoutReadFrom) ReadFrom(fileWithoutReadFrom) { - panic("unreachable") +func genericReadFrom(f *File, r io.Reader) (int64, error) { + return io.Copy(fileWithoutReadFrom{File: f}, r) } // Write writes len(b) bytes from b to the File. @@ -229,6 +235,40 @@ func (f *File) WriteAt(b []byte, off int64) (n int, err error) { return } +// WriteTo implements io.WriterTo. +func (f *File) WriteTo(w io.Writer) (n int64, err error) { + if err := f.checkValid("read"); err != nil { + return 0, err + } + n, handled, e := f.writeTo(w) + if handled { + return n, f.wrapErr("read", e) + } + return genericWriteTo(f, w) // without wrapping +} + +// noWriteTo can be embedded alongside another type to +// hide the WriteTo method of that other type. +type noWriteTo struct{} + +// WriteTo hides another WriteTo method. +// It should never be called. +func (noWriteTo) WriteTo(io.Writer) (int64, error) { + panic("can't happen") +} + +// fileWithoutWriteTo implements all the methods of *File other +// than WriteTo. This is used to permit WriteTo to call io.Copy +// without leading to a recursive call to WriteTo. +type fileWithoutWriteTo struct { + noWriteTo + *File +} + +func genericWriteTo(f *File, w io.Writer) (int64, error) { + return io.Copy(w, fileWithoutWriteTo{File: f}) +} + // Seek sets the offset for the next Read or Write on file to offset, interpreted // according to whence: 0 means relative to the origin of the file, 1 means // relative to the current offset, and 2 means relative to the end. diff --git a/src/os/readfrom_linux_test.go b/src/os/readfrom_linux_test.go index 4f98be4b9b..93f78032e7 100644 --- a/src/os/readfrom_linux_test.go +++ b/src/os/readfrom_linux_test.go @@ -749,12 +749,12 @@ func TestProcCopy(t *testing.T) { } } -func TestGetPollFDFromReader(t *testing.T) { - t.Run("tcp", func(t *testing.T) { testGetPollFromReader(t, "tcp") }) - t.Run("unix", func(t *testing.T) { testGetPollFromReader(t, "unix") }) +func TestGetPollFDAndNetwork(t *testing.T) { + t.Run("tcp4", func(t *testing.T) { testGetPollFDAndNetwork(t, "tcp4") }) + t.Run("unix", func(t *testing.T) { testGetPollFDAndNetwork(t, "unix") }) } -func testGetPollFromReader(t *testing.T, proto string) { +func testGetPollFDAndNetwork(t *testing.T, proto string) { _, server := createSocketPair(t, proto) sc, ok := server.(syscall.Conn) if !ok { @@ -765,12 +765,15 @@ func testGetPollFromReader(t *testing.T, proto string) { t.Fatalf("server SyscallConn error: %v", err) } if err = rc.Control(func(fd uintptr) { - pfd := GetPollFDForTest(server) + pfd, network := GetPollFDAndNetwork(server) if pfd == nil { - t.Fatalf("GetPollFDForTest didn't return poll.FD") + t.Fatalf("GetPollFDAndNetwork didn't return poll.FD") + } + if string(network) != proto { + t.Fatalf("GetPollFDAndNetwork returned wrong network, got: %s, want: %s", network, proto) } if pfd.Sysfd != int(fd) { - t.Fatalf("GetPollFDForTest returned wrong poll.FD, got: %d, want: %d", pfd.Sysfd, int(fd)) + t.Fatalf("GetPollFDAndNetwork returned wrong poll.FD, got: %d, want: %d", pfd.Sysfd, int(fd)) } if !pfd.IsStream { t.Fatalf("expected IsStream to be true") diff --git a/src/os/writeto_linux_test.go b/src/os/writeto_linux_test.go new file mode 100644 index 0000000000..5ffab88a2a --- /dev/null +++ b/src/os/writeto_linux_test.go @@ -0,0 +1,171 @@ +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package os_test + +import ( + "bytes" + "internal/poll" + "io" + "math/rand" + "net" + . "os" + "strconv" + "syscall" + "testing" + "time" +) + +func TestSendFile(t *testing.T) { + sizes := []int{ + 1, + 42, + 1025, + syscall.Getpagesize() + 1, + 32769, + } + t.Run("sendfile-to-unix", func(t *testing.T) { + for _, size := range sizes { + t.Run(strconv.Itoa(size), func(t *testing.T) { + testSendFile(t, "unix", int64(size)) + }) + } + }) + t.Run("sendfile-to-tcp", func(t *testing.T) { + for _, size := range sizes { + t.Run(strconv.Itoa(size), func(t *testing.T) { + testSendFile(t, "tcp", int64(size)) + }) + } + }) +} + +func testSendFile(t *testing.T, proto string, size int64) { + dst, src, recv, data, hook := newSendFileTest(t, proto, size) + + // Now call WriteTo (through io.Copy), which will hopefully call poll.SendFile + n, err := io.Copy(dst, src) + if err != nil { + t.Fatalf("io.Copy error: %v", err) + } + + // We should have called poll.Splice with the right file descriptor arguments. + if n > 0 && !hook.called { + t.Fatal("expected to called poll.SendFile") + } + if hook.called && hook.srcfd != int(src.Fd()) { + t.Fatalf("wrong source file descriptor: got %d, want %d", hook.srcfd, src.Fd()) + } + sc, ok := dst.(syscall.Conn) + if !ok { + t.Fatalf("destination is not a syscall.Conn") + } + rc, err := sc.SyscallConn() + if err != nil { + t.Fatalf("destination SyscallConn error: %v", err) + } + if err = rc.Control(func(fd uintptr) { + if hook.called && hook.dstfd != int(fd) { + t.Fatalf("wrong destination file descriptor: got %d, want %d", hook.dstfd, int(fd)) + } + }); err != nil { + t.Fatalf("destination Conn Control error: %v", err) + } + + // Verify the data size and content. + dataSize := len(data) + dstData := make([]byte, dataSize) + m, err := io.ReadFull(recv, dstData) + if err != nil { + t.Fatalf("server Conn Read error: %v", err) + } + if n != int64(dataSize) { + t.Fatalf("data length mismatch for io.Copy, got %d, want %d", n, dataSize) + } + if m != dataSize { + t.Fatalf("data length mismatch for net.Conn.Read, got %d, want %d", m, dataSize) + } + if !bytes.Equal(dstData, data) { + t.Errorf("data mismatch, got %s, want %s", dstData, data) + } +} + +// newSendFileTest initializes a new test for sendfile. +// +// It creates source file and destination sockets, and populates the source file +// with random data of the specified size. It also hooks package os' call +// to poll.Sendfile and returns the hook so it can be inspected. +func newSendFileTest(t *testing.T, proto string, size int64) (net.Conn, *File, net.Conn, []byte, *sendFileHook) { + t.Helper() + + hook := hookSendFile(t) + + client, server := createSocketPair(t, proto) + tempFile, data := createTempFile(t, size) + + return client, tempFile, server, data, hook +} + +func hookSendFile(t *testing.T) *sendFileHook { + h := new(sendFileHook) + h.install() + t.Cleanup(h.uninstall) + return h +} + +type sendFileHook struct { + called bool + dstfd int + srcfd int + remain int64 + + written int64 + handled bool + err error + + original func(dst *poll.FD, src int, remain int64) (int64, error, bool) +} + +func (h *sendFileHook) install() { + h.original = *PollSendFile + *PollSendFile = func(dst *poll.FD, src int, remain int64) (int64, error, bool) { + h.called = true + h.dstfd = dst.Sysfd + h.srcfd = src + h.remain = remain + h.written, h.err, h.handled = h.original(dst, src, remain) + return h.written, h.err, h.handled + } +} + +func (h *sendFileHook) uninstall() { + *PollSendFile = h.original +} + +func createTempFile(t *testing.T, size int64) (*File, []byte) { + f, err := CreateTemp(t.TempDir(), "writeto-sendfile-to-socket") + if err != nil { + t.Fatalf("failed to create temporary file: %v", err) + } + t.Cleanup(func() { + f.Close() + }) + + randSeed := time.Now().Unix() + t.Logf("random data seed: %d\n", randSeed) + prng := rand.New(rand.NewSource(randSeed)) + data := make([]byte, size) + prng.Read(data) + if _, err := f.Write(data); err != nil { + t.Fatalf("failed to create and feed the file: %v", err) + } + if err := f.Sync(); err != nil { + t.Fatalf("failed to save the file: %v", err) + } + if _, err := f.Seek(0, io.SeekStart); err != nil { + t.Fatalf("failed to rewind the file: %v", err) + } + + return f, data +} diff --git a/src/os/readfrom_linux.go b/src/os/zero_copy_linux.go similarity index 70% rename from src/os/readfrom_linux.go rename to src/os/zero_copy_linux.go index 7e8024028e..7c45aefeee 100644 --- a/src/os/readfrom_linux.go +++ b/src/os/zero_copy_linux.go @@ -13,8 +13,33 @@ import ( var ( pollCopyFileRange = poll.CopyFileRange pollSplice = poll.Splice + pollSendFile = poll.SendFile ) +func (f *File) writeTo(w io.Writer) (written int64, handled bool, err error) { + pfd, network := getPollFDAndNetwork(w) + // TODO(panjf2000): same as File.spliceToFile. + if pfd == nil || !pfd.IsStream || !isUnixOrTCP(string(network)) { + return + } + + sc, err := f.SyscallConn() + if err != nil { + return + } + + rerr := sc.Read(func(fd uintptr) (done bool) { + written, err, handled = pollSendFile(pfd, int(fd), 1<<63-1) + return true + }) + + if err == nil { + err = rerr + } + + return written, handled, wrapSyscallError("sendfile", err) +} + func (f *File) readFrom(r io.Reader) (written int64, handled bool, err error) { // Neither copy_file_range(2) nor splice(2) supports destinations opened with // O_APPEND, so don't bother to try zero-copy with these system calls. @@ -41,7 +66,7 @@ func (f *File) spliceToFile(r io.Reader) (written int64, handled bool, err error return 0, true, nil } - pfd := getPollFD(r) + pfd, _ := getPollFDAndNetwork(r) // TODO(panjf2000): run some tests to see if we should unlock the non-streams for splice. // Streams benefit the most from the splice(2), non-streams are not even supported in old kernels // where splice(2) will just return EINVAL; newer kernels support non-streams like UDP, but I really @@ -63,25 +88,6 @@ func (f *File) spliceToFile(r io.Reader) (written int64, handled bool, err error return written, handled, wrapSyscallError(syscallName, err) } -// getPollFD tries to get the poll.FD from the given io.Reader by expecting -// the underlying type of r to be the implementation of syscall.Conn that contains -// a *net.rawConn. -func getPollFD(r io.Reader) *poll.FD { - sc, ok := r.(syscall.Conn) - if !ok { - return nil - } - rc, err := sc.SyscallConn() - if err != nil { - return nil - } - ipfd, ok := rc.(interface{ PollFD() *poll.FD }) - if !ok { - return nil - } - return ipfd.PollFD() -} - func (f *File) copyFileRange(r io.Reader) (written int64, handled bool, err error) { var ( remain int64 @@ -91,10 +97,16 @@ func (f *File) copyFileRange(r io.Reader) (written int64, handled bool, err erro return 0, true, nil } - src, ok := r.(*File) - if !ok { + var src *File + switch v := r.(type) { + case *File: + src = v + case fileWithoutWriteTo: + src = v.File + default: return 0, false, nil } + if src.checkValid("ReadFrom") != nil { // Avoid returning the error as we report handled as false, // leave further error handling as the responsibility of the caller. @@ -108,6 +120,28 @@ func (f *File) copyFileRange(r io.Reader) (written int64, handled bool, err erro return written, handled, wrapSyscallError("copy_file_range", err) } +// getPollFDAndNetwork tries to get the poll.FD and network type from the given interface +// by expecting the underlying type of i to be the implementation of syscall.Conn +// that contains a *net.rawConn. +func getPollFDAndNetwork(i any) (*poll.FD, poll.String) { + sc, ok := i.(syscall.Conn) + if !ok { + return nil, "" + } + rc, err := sc.SyscallConn() + if err != nil { + return nil, "" + } + irc, ok := rc.(interface { + PollFD() *poll.FD + Network() poll.String + }) + if !ok { + return nil, "" + } + return irc.PollFD(), irc.Network() +} + // tryLimitedReader tries to assert the io.Reader to io.LimitedReader, it returns the io.LimitedReader, // the underlying io.Reader and the remaining amount of bytes if the assertion succeeds, // otherwise it just returns the original io.Reader and the theoretical unlimited remaining amount of bytes. @@ -122,3 +156,12 @@ func tryLimitedReader(r io.Reader) (*io.LimitedReader, io.Reader, int64) { remain = lr.N return lr, lr.R, remain } + +func isUnixOrTCP(network string) bool { + switch network { + case "tcp", "tcp4", "tcp6", "unix": + return true + default: + return false + } +} diff --git a/src/os/readfrom_stub.go b/src/os/zero_copy_stub.go similarity index 74% rename from src/os/readfrom_stub.go rename to src/os/zero_copy_stub.go index 8b7d5fb8f9..9ec5808101 100644 --- a/src/os/readfrom_stub.go +++ b/src/os/zero_copy_stub.go @@ -8,6 +8,10 @@ package os import "io" +func (f *File) writeTo(w io.Writer) (written int64, handled bool, err error) { + return 0, false, nil +} + func (f *File) readFrom(r io.Reader) (n int64, handled bool, err error) { return 0, false, nil }