]> Cypherpunks repositories - gostls13.git/commitdiff
net,os: arrange zero-copy of os.File and net.TCPConn to net.UnixConn
authorAndy Pan <panjf2000@gmail.com>
Tue, 28 Feb 2023 08:39:15 +0000 (16:39 +0800)
committerDamien Neil <dneil@google.com>
Fri, 17 Nov 2023 23:16:28 +0000 (23:16 +0000)
Fixes #58808

goos: linux
goarch: amd64
pkg: net
cpu: DO-Premium-Intel
                             │      old      │                 new                  │
                             │    sec/op     │    sec/op     vs base                │
Splice/tcp-to-unix/1024-4       3.783µ ± 10%   3.201µ ±  7%  -15.40% (p=0.001 n=10)
Splice/tcp-to-unix/2048-4       3.967µ ± 13%   3.818µ ± 16%        ~ (p=0.971 n=10)
Splice/tcp-to-unix/4096-4       4.988µ ± 16%   4.590µ ± 11%        ~ (p=0.089 n=10)
Splice/tcp-to-unix/8192-4       6.981µ ± 13%   5.236µ ±  9%  -25.00% (p=0.000 n=10)
Splice/tcp-to-unix/16384-4     10.192µ ±  9%   7.350µ ±  7%  -27.89% (p=0.000 n=10)
Splice/tcp-to-unix/32768-4      19.65µ ± 13%   10.28µ ± 16%  -47.69% (p=0.000 n=10)
Splice/tcp-to-unix/65536-4      41.89µ ± 18%   15.70µ ± 13%  -62.52% (p=0.000 n=10)
Splice/tcp-to-unix/131072-4     90.05µ ± 11%   29.55µ ± 10%  -67.18% (p=0.000 n=10)
Splice/tcp-to-unix/262144-4    170.24µ ± 15%   52.66µ ±  4%  -69.06% (p=0.000 n=10)
Splice/tcp-to-unix/524288-4     326.4µ ± 13%   109.3µ ± 11%  -66.52% (p=0.000 n=10)
Splice/tcp-to-unix/1048576-4    651.4µ ±  9%   228.3µ ± 14%  -64.95% (p=0.000 n=10)
geomean                         29.42µ         15.62µ        -46.90%

                             │      old      │                  new                   │
                             │      B/s      │      B/s       vs base                 │
Splice/tcp-to-unix/1024-4      258.2Mi ± 11%   305.2Mi ±  8%   +18.21% (p=0.001 n=10)
Splice/tcp-to-unix/2048-4      492.5Mi ± 15%   511.7Mi ± 13%         ~ (p=0.971 n=10)
Splice/tcp-to-unix/4096-4      783.5Mi ± 14%   851.2Mi ± 12%         ~ (p=0.089 n=10)
Splice/tcp-to-unix/8192-4      1.093Gi ± 11%   1.458Gi ±  8%   +33.36% (p=0.000 n=10)
Splice/tcp-to-unix/16384-4     1.497Gi ±  9%   2.076Gi ±  7%   +38.67% (p=0.000 n=10)
Splice/tcp-to-unix/32768-4     1.553Gi ± 11%   2.969Gi ± 14%   +91.17% (p=0.000 n=10)
Splice/tcp-to-unix/65536-4     1.458Gi ± 23%   3.888Gi ± 11%  +166.69% (p=0.000 n=10)
Splice/tcp-to-unix/131072-4    1.356Gi ± 10%   4.131Gi ±  9%  +204.72% (p=0.000 n=10)
Splice/tcp-to-unix/262144-4    1.434Gi ± 13%   4.637Gi ±  4%  +223.32% (p=0.000 n=10)
Splice/tcp-to-unix/524288-4    1.497Gi ± 15%   4.468Gi ± 10%  +198.47% (p=0.000 n=10)
Splice/tcp-to-unix/1048576-4   1.501Gi ± 10%   4.277Gi ± 16%  +184.88% (p=0.000 n=10)
geomean                        1.038Gi         1.954Gi         +88.28%

                             │      old      │                   new                   │
                             │     B/op      │    B/op     vs base                     │
Splice/tcp-to-unix/1024-4      0.000 ±  0%     0.000 ± 0%         ~ (p=1.000 n=10) ¹
Splice/tcp-to-unix/2048-4      0.000 ±  0%     0.000 ± 0%         ~ (p=1.000 n=10) ¹
Splice/tcp-to-unix/4096-4      0.000 ±  0%     0.000 ± 0%         ~ (p=1.000 n=10) ¹
Splice/tcp-to-unix/8192-4      0.000 ±  0%     0.000 ± 0%         ~ (p=1.000 n=10) ¹
Splice/tcp-to-unix/16384-4     0.000 ±  0%     0.000 ± 0%         ~ (p=1.000 n=10) ¹
Splice/tcp-to-unix/32768-4     0.000 ±  0%     0.000 ± 0%         ~ (p=1.000 n=10) ¹
Splice/tcp-to-unix/65536-4     1.000 ±   ?     0.000 ± 0%  -100.00% (p=0.001 n=10)
Splice/tcp-to-unix/131072-4    2.000 ±  0%     0.000 ± 0%  -100.00% (p=0.000 n=10)
Splice/tcp-to-unix/262144-4    4.000 ± 25%     0.000 ± 0%  -100.00% (p=0.000 n=10)
Splice/tcp-to-unix/524288-4    7.500 ± 33%     0.000 ± 0%  -100.00% (p=0.000 n=10)
Splice/tcp-to-unix/1048576-4   17.00 ± 12%      0.00 ± 0%  -100.00% (p=0.000 n=10)
geomean                                    ²               ?                       ² ³
¹ all samples are equal
² summaries must be >0 to compute geomean
³ ratios must be >0 to compute geomean

                             │     old      │                 new                 │
                             │  allocs/op   │ allocs/op   vs base                 │
Splice/tcp-to-unix/1024-4      0.000 ± 0%     0.000 ± 0%       ~ (p=1.000 n=10) ¹
Splice/tcp-to-unix/2048-4      0.000 ± 0%     0.000 ± 0%       ~ (p=1.000 n=10) ¹
Splice/tcp-to-unix/4096-4      0.000 ± 0%     0.000 ± 0%       ~ (p=1.000 n=10) ¹
Splice/tcp-to-unix/8192-4      0.000 ± 0%     0.000 ± 0%       ~ (p=1.000 n=10) ¹
Splice/tcp-to-unix/16384-4     0.000 ± 0%     0.000 ± 0%       ~ (p=1.000 n=10) ¹
Splice/tcp-to-unix/32768-4     0.000 ± 0%     0.000 ± 0%       ~ (p=1.000 n=10) ¹
Splice/tcp-to-unix/65536-4     0.000 ± 0%     0.000 ± 0%       ~ (p=1.000 n=10) ¹
Splice/tcp-to-unix/131072-4    0.000 ± 0%     0.000 ± 0%       ~ (p=1.000 n=10) ¹
Splice/tcp-to-unix/262144-4    0.000 ± 0%     0.000 ± 0%       ~ (p=1.000 n=10) ¹
Splice/tcp-to-unix/524288-4    0.000 ± 0%     0.000 ± 0%       ~ (p=1.000 n=10) ¹
Splice/tcp-to-unix/1048576-4   0.000 ± 0%     0.000 ± 0%       ~ (p=1.000 n=10) ¹
geomean                                   ²               +0.00%                ²
¹ all samples are equal
² summaries must be >0 to compute geomean

Change-Id: I829061b009a0929a8ef1a15c183793c0b9104dde
Reviewed-on: https://go-review.googlesource.com/c/go/+/472475
Reviewed-by: Damien Neil <dneil@google.com>
Reviewed-by: Bryan Mills <bcmills@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>

18 files changed:
api/next/58808.txt [new file with mode: 0644]
src/internal/poll/fd.go
src/net/http/transfer_test.go
src/net/net.go
src/net/rawconn.go
src/net/sendfile_linux_test.go
src/net/splice_linux.go
src/net/splice_stub.go
src/net/splice_test.go
src/net/tcpsock.go
src/net/tcpsock_plan9.go
src/net/tcpsock_posix.go
src/os/export_linux_test.go
src/os/file.go
src/os/readfrom_linux_test.go
src/os/writeto_linux_test.go [new file with mode: 0644]
src/os/zero_copy_linux.go [moved from src/os/readfrom_linux.go with 70% similarity]
src/os/zero_copy_stub.go [moved from src/os/readfrom_stub.go with 74% similarity]

diff --git a/api/next/58808.txt b/api/next/58808.txt
new file mode 100644 (file)
index 0000000..f1105c3
--- /dev/null
@@ -0,0 +1,2 @@
+pkg net, method (*TCPConn) WriteTo(io.Writer) (int64, error) #58808
+pkg os, method (*File) WriteTo(io.Writer) (int64, error) #58808
index ef61d0cb3ffaa01ecb8b4dd5bf1599cbb67381d2..4e038d00ddab8b306924850849c04cf7e2f665bf 100644 (file)
@@ -81,3 +81,14 @@ func consume(v *[][]byte, n int64) {
 
 // TestHookDidWritev is a hook for testing writev.
 var TestHookDidWritev = func(wrote int) {}
+
+// String is an internal string definition for methods/functions
+// that is not intended for use outside the standard libraries.
+//
+// Other packages in std that import internal/poll and have some
+// exported APIs (now we've got some in net.rawConn) which are only used
+// internally and are not intended to be used outside the standard libraries,
+// Therefore, we make those APIs use internal types like poll.FD or poll.String
+// in their function signatures to disable the usability of these APIs from
+// external codebase.
+type String string
index 3f9ebdea7bb05fdfb6a0468e847d8e23c0fee6b4..b1a5a931035c4b141f61b58d6543e1c5c8481759 100644 (file)
@@ -264,6 +264,12 @@ func TestTransferWriterWriteBodyReaderTypes(t *testing.T) {
                                        actualReader = reflect.TypeOf(lr.R)
                                } else {
                                        actualReader = reflect.TypeOf(mw.CalledReader)
+                                       // We have to handle this special case for genericWriteTo in os,
+                                       // this struct is introduced to support a zero-copy optimization,
+                                       // check out https://go.dev/issue/58808 for details.
+                                       if actualReader.Kind() == reflect.Struct && actualReader.PkgPath() == "os" && actualReader.Name() == "fileWithoutWriteTo" {
+                                               actualReader = actualReader.Field(1).Type
+                                       }
                                }
 
                                if tc.expectedReader != actualReader {
index 396713ce4ab675a8ec2fb4a563552bdc24a2e6da..02c2ceda3267916a7499599f9a8765cb6c9591f7 100644 (file)
@@ -664,15 +664,53 @@ var errClosed = poll.ErrNetClosing
 // errors.Is(err, net.ErrClosed).
 var ErrClosed error = errClosed
 
-type writerOnly struct {
-       io.Writer
+// noReadFrom can be embedded alongside another type to
+// hide the ReadFrom method of that other type.
+type noReadFrom struct{}
+
+// ReadFrom hides another ReadFrom method.
+// It should never be called.
+func (noReadFrom) ReadFrom(io.Reader) (int64, error) {
+       panic("can't happen")
+}
+
+// tcpConnWithoutReadFrom implements all the methods of *TCPConn other
+// than ReadFrom. This is used to permit ReadFrom to call io.Copy
+// without leading to a recursive call to ReadFrom.
+type tcpConnWithoutReadFrom struct {
+       noReadFrom
+       *TCPConn
 }
 
 // Fallback implementation of io.ReaderFrom's ReadFrom, when sendfile isn't
 // applicable.
-func genericReadFrom(w io.Writer, r io.Reader) (n int64, err error) {
+func genericReadFrom(c *TCPConn, r io.Reader) (n int64, err error) {
        // Use wrapper to hide existing r.ReadFrom from io.Copy.
-       return io.Copy(writerOnly{w}, r)
+       return io.Copy(tcpConnWithoutReadFrom{TCPConn: c}, r)
+}
+
+// noWriteTo can be embedded alongside another type to
+// hide the WriteTo method of that other type.
+type noWriteTo struct{}
+
+// WriteTo hides another WriteTo method.
+// It should never be called.
+func (noWriteTo) WriteTo(io.Writer) (int64, error) {
+       panic("can't happen")
+}
+
+// tcpConnWithoutWriteTo implements all the methods of *TCPConn other
+// than WriteTo. This is used to permit WriteTo to call io.Copy
+// without leading to a recursive call to WriteTo.
+type tcpConnWithoutWriteTo struct {
+       noWriteTo
+       *TCPConn
+}
+
+// Fallback implementation of io.WriterTo's WriteTo, when zero-copy isn't applicable.
+func genericWriteTo(c *TCPConn, w io.Writer) (n int64, err error) {
+       // Use wrapper to hide existing w.WriteTo from io.Copy.
+       return io.Copy(w, tcpConnWithoutWriteTo{TCPConn: c})
 }
 
 // Limit the number of concurrent cgo-using goroutines, because
index e49b9fb81b536cc74c502dc41445caa0e19d4aaf..7a69fe5c25bad9ed03ef0adb0b687e751cb26bad 100644 (file)
@@ -79,6 +79,17 @@ func newRawConn(fd *netFD) *rawConn {
        return &rawConn{fd: fd}
 }
 
+// Network returns the network type of the underlying connection.
+//
+// Other packages in std that import internal/poll and are unable to
+// import net (such as os) can use a type assertion to access this
+// extension method so that they can distinguish different socket types.
+//
+// Network is not intended for use outside the standard library.
+func (c *rawConn) Network() poll.String {
+       return poll.String(c.fd.net)
+}
+
 type rawListener struct {
        rawConn
 }
index 0b5af36cdb7933dd3fa16c3507a5f0ccc5e70cd6..7a66d3645f2fa0c63973a95b90417c6b027c069e 100644 (file)
@@ -14,29 +14,36 @@ import (
 )
 
 func BenchmarkSendFile(b *testing.B) {
+       b.Run("file-to-tcp", func(b *testing.B) { benchmarkSendFile(b, "tcp") })
+       b.Run("file-to-unix", func(b *testing.B) { benchmarkSendFile(b, "unix") })
+}
+
+func benchmarkSendFile(b *testing.B, proto string) {
        for i := 0; i <= 10; i++ {
                size := 1 << (i + 10)
-               bench := sendFileBench{chunkSize: size}
+               bench := sendFileBench{
+                       proto:     proto,
+                       chunkSize: size,
+               }
                b.Run(strconv.Itoa(size), bench.benchSendFile)
        }
 }
 
 type sendFileBench struct {
+       proto     string
        chunkSize int
 }
 
 func (bench sendFileBench) benchSendFile(b *testing.B) {
        fileSize := b.N * bench.chunkSize
        f := createTempFile(b, fileSize)
-       fileName := f.Name()
-       defer os.Remove(fileName)
-       defer f.Close()
 
-       client, server := spliceTestSocketPair(b, "tcp")
+       client, server := spliceTestSocketPair(b, bench.proto)
        defer server.Close()
 
        cleanUp, err := startSpliceClient(client, "r", bench.chunkSize, fileSize)
        if err != nil {
+               client.Close()
                b.Fatal(err)
        }
        defer cleanUp()
@@ -51,15 +58,18 @@ func (bench sendFileBench) benchSendFile(b *testing.B) {
                b.Fatalf("failed to copy data with sendfile, error: %v", err)
        }
        if sent != int64(fileSize) {
-               b.Fatalf("bytes sent mismatch\n\texpect: %d\n\tgot: %d", fileSize, sent)
+               b.Fatalf("bytes sent mismatch, got: %d, want: %d", sent, fileSize)
        }
 }
 
 func createTempFile(b *testing.B, size int) *os.File {
-       f, err := os.CreateTemp("", "linux-sendfile-test")
+       f, err := os.CreateTemp(b.TempDir(), "linux-sendfile-bench")
        if err != nil {
                b.Fatalf("failed to create temporary file: %v", err)
        }
+       b.Cleanup(func() {
+               f.Close()
+       })
 
        data := make([]byte, size)
        if _, err := f.Write(data); err != nil {
index ab2ab70b28db8d3df0bc95fadc940b6dbc596c39..bdafcb59ab84802ec6e703ed625ea130b15a42f6 100644 (file)
@@ -9,12 +9,12 @@ import (
        "io"
 )
 
-// splice transfers data from r to c using the splice system call to minimize
-// copies from and to userspace. c must be a TCP connection. Currently, splice
-// is only enabled if r is a TCP or a stream-oriented Unix connection.
+// spliceFrom transfers data from r to c using the splice system call to minimize
+// copies from and to userspace. c must be a TCP connection.
+// Currently, spliceFrom is only enabled if r is a TCP or a stream-oriented Unix connection.
 //
-// If splice returns handled == false, it has performed no work.
-func splice(c *netFD, r io.Reader) (written int64, err error, handled bool) {
+// If spliceFrom returns handled == false, it has performed no work.
+func spliceFrom(c *netFD, r io.Reader) (written int64, err error, handled bool) {
        var remain int64 = 1<<63 - 1 // by default, copy until EOF
        lr, ok := r.(*io.LimitedReader)
        if ok {
@@ -25,14 +25,17 @@ func splice(c *netFD, r io.Reader) (written int64, err error, handled bool) {
        }
 
        var s *netFD
-       if tc, ok := r.(*TCPConn); ok {
-               s = tc.fd
-       } else if uc, ok := r.(*UnixConn); ok {
-               if uc.fd.net != "unix" {
+       switch v := r.(type) {
+       case *TCPConn:
+               s = v.fd
+       case tcpConnWithoutWriteTo:
+               s = v.fd
+       case *UnixConn:
+               if v.fd.net != "unix" {
                        return 0, nil, false
                }
-               s = uc.fd
-       } else {
+               s = v.fd
+       default:
                return 0, nil, false
        }
 
@@ -42,3 +45,18 @@ func splice(c *netFD, r io.Reader) (written int64, err error, handled bool) {
        }
        return written, wrapSyscallError(sc, err), handled
 }
+
+// spliceTo transfers data from c to w using the splice system call to minimize
+// copies from and to userspace. c must be a TCP connection.
+// Currently, spliceTo is only enabled if w is a stream-oriented Unix connection.
+//
+// If spliceTo returns handled == false, it has performed no work.
+func spliceTo(w io.Writer, c *netFD) (written int64, err error, handled bool) {
+       uc, ok := w.(*UnixConn)
+       if !ok || uc.fd.net != "unix" {
+               return
+       }
+
+       written, handled, sc, err := poll.Splice(&uc.fd.pfd, &c.pfd, 1<<63-1)
+       return written, wrapSyscallError(sc, err), handled
+}
index 3cdadb11c5615f6fa628f87facfc61351da3b569..239227ff88397db25b2f8033947b8948d62539be 100644 (file)
@@ -8,6 +8,10 @@ package net
 
 import "io"
 
-func splice(c *netFD, r io.Reader) (int64, error, bool) {
+func spliceFrom(_ *netFD, _ io.Reader) (int64, error, bool) {
+       return 0, nil, false
+}
+
+func spliceTo(_ io.Writer, _ *netFD) (int64, error, bool) {
        return 0, nil, false
 }
index 75a8f274ff05db8fde7fea40930c7fc7b316359a..227ddebff402cea3a8ece694d44b68e91cea02f0 100644 (file)
@@ -23,6 +23,7 @@ func TestSplice(t *testing.T) {
                t.Skip("skipping unix-to-tcp tests")
        }
        t.Run("unix-to-tcp", func(t *testing.T) { testSplice(t, "unix", "tcp") })
+       t.Run("tcp-to-unix", func(t *testing.T) { testSplice(t, "tcp", "unix") })
        t.Run("tcp-to-file", func(t *testing.T) { testSpliceToFile(t, "tcp", "file") })
        t.Run("unix-to-file", func(t *testing.T) { testSpliceToFile(t, "unix", "file") })
        t.Run("no-unixpacket", testSpliceNoUnixpacket)
@@ -159,6 +160,13 @@ func (tc spliceTestCase) testFile(t *testing.T) {
 }
 
 func testSpliceReaderAtEOF(t *testing.T, upNet, downNet string) {
+       // UnixConn doesn't implement io.ReaderFrom, which will fail
+       // the following test in asserting a UnixConn to be an io.ReaderFrom,
+       // so skip this test.
+       if upNet == "unix" || downNet == "unix" {
+               t.Skip("skipping test on unix socket")
+       }
+
        clientUp, serverUp := spliceTestSocketPair(t, upNet)
        defer clientUp.Close()
        clientDown, serverDown := spliceTestSocketPair(t, downNet)
@@ -166,16 +174,16 @@ func testSpliceReaderAtEOF(t *testing.T, upNet, downNet string) {
 
        serverUp.Close()
 
-       // We'd like to call net.splice here and check the handled return
+       // We'd like to call net.spliceFrom here and check the handled return
        // value, but we disable splice on old Linux kernels.
        //
-       // In that case, poll.Splice and net.splice return a non-nil error
+       // In that case, poll.Splice and net.spliceFrom return a non-nil error
        // and handled == false. We'd ideally like to see handled == true
        // because the source reader is at EOF, but if we're running on an old
-       // kernel, and splice is disabled, we won't see EOF from net.splice,
+       // kernel, and splice is disabled, we won't see EOF from net.spliceFrom,
        // because we won't touch the reader at all.
        //
-       // Trying to untangle the errors from net.splice and match them
+       // Trying to untangle the errors from net.spliceFrom and match them
        // against the errors created by the poll package would be brittle,
        // so this is a higher level test.
        //
@@ -268,7 +276,7 @@ func testSpliceNoUnixpacket(t *testing.T) {
        //
        // What we want is err == nil and handled == false, i.e. we never
        // called poll.Splice, because we know the unix socket's network.
-       _, err, handled := splice(serverDown.(*TCPConn).fd, serverUp)
+       _, err, handled := spliceFrom(serverDown.(*TCPConn).fd, serverUp)
        if err != nil || handled != false {
                t.Fatalf("got err = %v, handled = %t, want nil error, handled == false", err, handled)
        }
@@ -289,7 +297,7 @@ func testSpliceNoUnixgram(t *testing.T) {
        defer clientDown.Close()
        defer serverDown.Close()
        // Analogous to testSpliceNoUnixpacket.
-       _, err, handled := splice(serverDown.(*TCPConn).fd, up)
+       _, err, handled := spliceFrom(serverDown.(*TCPConn).fd, up)
        if err != nil || handled != false {
                t.Fatalf("got err = %v, handled = %t, want nil error, handled == false", err, handled)
        }
@@ -300,6 +308,7 @@ func BenchmarkSplice(b *testing.B) {
 
        b.Run("tcp-to-tcp", func(b *testing.B) { benchSplice(b, "tcp", "tcp") })
        b.Run("unix-to-tcp", func(b *testing.B) { benchSplice(b, "unix", "tcp") })
+       b.Run("tcp-to-unix", func(b *testing.B) { benchSplice(b, "tcp", "unix") })
 }
 
 func benchSplice(b *testing.B, upNet, downNet string) {
index 1528353cba0b145f62a788737e192eeee23ed9c1..6257f2515b206f5f8b285525e4d00b5fb2c56873 100644 (file)
@@ -134,6 +134,18 @@ func (c *TCPConn) ReadFrom(r io.Reader) (int64, error) {
        return n, err
 }
 
+// WriteTo implements the io.WriterTo WriteTo method.
+func (c *TCPConn) WriteTo(w io.Writer) (int64, error) {
+       if !c.ok() {
+               return 0, syscall.EINVAL
+       }
+       n, err := c.writeTo(w)
+       if err != nil && err != io.EOF {
+               err = &OpError{Op: "writeto", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err}
+       }
+       return n, err
+}
+
 // CloseRead shuts down the reading side of the TCP connection.
 // Most callers should just use Close.
 func (c *TCPConn) CloseRead() error {
index d55948f69e4fe9f67c187a4cd8b481b9a8a2740c..463dedcf44cdedf424edbf8c88c6ff3fbbed21ba 100644 (file)
@@ -14,6 +14,10 @@ func (c *TCPConn) readFrom(r io.Reader) (int64, error) {
        return genericReadFrom(c, r)
 }
 
+func (c *TCPConn) writeTo(w io.Writer) (int64, error) {
+       return genericWriteTo(c, w)
+}
+
 func (sd *sysDialer) dialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCPConn, error) {
        if h := sd.testHookDialTCP; h != nil {
                return h(ctx, sd.network, laddr, raddr)
index 83cee7c78940f354243a449a5b2863b5e5846b70..01b5ec9ed0564243952ea36a444962ec250e0e9e 100644 (file)
@@ -45,7 +45,7 @@ func (a *TCPAddr) toLocal(net string) sockaddr {
 }
 
 func (c *TCPConn) readFrom(r io.Reader) (int64, error) {
-       if n, err, handled := splice(c.fd, r); handled {
+       if n, err, handled := spliceFrom(c.fd, r); handled {
                return n, err
        }
        if n, err, handled := sendFile(c.fd, r); handled {
@@ -54,6 +54,13 @@ func (c *TCPConn) readFrom(r io.Reader) (int64, error) {
        return genericReadFrom(c, r)
 }
 
+func (c *TCPConn) writeTo(w io.Writer) (int64, error) {
+       if n, err, handled := spliceTo(w, c.fd); handled {
+               return n, err
+       }
+       return genericWriteTo(c, w)
+}
+
 func (sd *sysDialer) dialTCP(ctx context.Context, laddr, raddr *TCPAddr) (*TCPConn, error) {
        if h := sd.testHookDialTCP; h != nil {
                return h(ctx, sd.network, laddr, raddr)
index 3fd5e61de78b79e45f405d90a5d36855cc364a41..942b48a17d802d371fc34e7e6dd0f97ce8988092 100644 (file)
@@ -5,7 +5,8 @@
 package os
 
 var (
-       PollCopyFileRangeP = &pollCopyFileRange
-       PollSpliceFile     = &pollSplice
-       GetPollFDForTest   = getPollFD
+       PollCopyFileRangeP  = &pollCopyFileRange
+       PollSpliceFile      = &pollSplice
+       PollSendFile        = &pollSendFile
+       GetPollFDAndNetwork = getPollFDAndNetwork
 )
index 82be00a834a6c7871599e90afdfc54b4181b85cc..37a30ccf041919e6493181a9dcb83bc708eb2abf 100644 (file)
@@ -157,20 +157,26 @@ func (f *File) ReadFrom(r io.Reader) (n int64, err error) {
        return n, f.wrapErr("write", e)
 }
 
-func genericReadFrom(f *File, r io.Reader) (int64, error) {
-       return io.Copy(fileWithoutReadFrom{f}, r)
+// noReadFrom can be embedded alongside another type to
+// hide the ReadFrom method of that other type.
+type noReadFrom struct{}
+
+// ReadFrom hides another ReadFrom method.
+// It should never be called.
+func (noReadFrom) ReadFrom(io.Reader) (int64, error) {
+       panic("can't happen")
 }
 
 // fileWithoutReadFrom implements all the methods of *File other
 // than ReadFrom. This is used to permit ReadFrom to call io.Copy
 // without leading to a recursive call to ReadFrom.
 type fileWithoutReadFrom struct {
+       noReadFrom
        *File
 }
 
-// This ReadFrom method hides the *File ReadFrom method.
-func (fileWithoutReadFrom) ReadFrom(fileWithoutReadFrom) {
-       panic("unreachable")
+func genericReadFrom(f *File, r io.Reader) (int64, error) {
+       return io.Copy(fileWithoutReadFrom{File: f}, r)
 }
 
 // Write writes len(b) bytes from b to the File.
@@ -229,6 +235,40 @@ func (f *File) WriteAt(b []byte, off int64) (n int, err error) {
        return
 }
 
+// WriteTo implements io.WriterTo.
+func (f *File) WriteTo(w io.Writer) (n int64, err error) {
+       if err := f.checkValid("read"); err != nil {
+               return 0, err
+       }
+       n, handled, e := f.writeTo(w)
+       if handled {
+               return n, f.wrapErr("read", e)
+       }
+       return genericWriteTo(f, w) // without wrapping
+}
+
+// noWriteTo can be embedded alongside another type to
+// hide the WriteTo method of that other type.
+type noWriteTo struct{}
+
+// WriteTo hides another WriteTo method.
+// It should never be called.
+func (noWriteTo) WriteTo(io.Writer) (int64, error) {
+       panic("can't happen")
+}
+
+// fileWithoutWriteTo implements all the methods of *File other
+// than WriteTo. This is used to permit WriteTo to call io.Copy
+// without leading to a recursive call to WriteTo.
+type fileWithoutWriteTo struct {
+       noWriteTo
+       *File
+}
+
+func genericWriteTo(f *File, w io.Writer) (int64, error) {
+       return io.Copy(w, fileWithoutWriteTo{File: f})
+}
+
 // Seek sets the offset for the next Read or Write on file to offset, interpreted
 // according to whence: 0 means relative to the origin of the file, 1 means
 // relative to the current offset, and 2 means relative to the end.
index 4f98be4b9b1786efc449fd804f24d3af6f04fbd8..93f78032e737146813f1cb24ff19aff6b2605a48 100644 (file)
@@ -749,12 +749,12 @@ func TestProcCopy(t *testing.T) {
        }
 }
 
-func TestGetPollFDFromReader(t *testing.T) {
-       t.Run("tcp", func(t *testing.T) { testGetPollFromReader(t, "tcp") })
-       t.Run("unix", func(t *testing.T) { testGetPollFromReader(t, "unix") })
+func TestGetPollFDAndNetwork(t *testing.T) {
+       t.Run("tcp4", func(t *testing.T) { testGetPollFDAndNetwork(t, "tcp4") })
+       t.Run("unix", func(t *testing.T) { testGetPollFDAndNetwork(t, "unix") })
 }
 
-func testGetPollFromReader(t *testing.T, proto string) {
+func testGetPollFDAndNetwork(t *testing.T, proto string) {
        _, server := createSocketPair(t, proto)
        sc, ok := server.(syscall.Conn)
        if !ok {
@@ -765,12 +765,15 @@ func testGetPollFromReader(t *testing.T, proto string) {
                t.Fatalf("server SyscallConn error: %v", err)
        }
        if err = rc.Control(func(fd uintptr) {
-               pfd := GetPollFDForTest(server)
+               pfd, network := GetPollFDAndNetwork(server)
                if pfd == nil {
-                       t.Fatalf("GetPollFDForTest didn't return poll.FD")
+                       t.Fatalf("GetPollFDAndNetwork didn't return poll.FD")
+               }
+               if string(network) != proto {
+                       t.Fatalf("GetPollFDAndNetwork returned wrong network, got: %s, want: %s", network, proto)
                }
                if pfd.Sysfd != int(fd) {
-                       t.Fatalf("GetPollFDForTest returned wrong poll.FD, got: %d, want: %d", pfd.Sysfd, int(fd))
+                       t.Fatalf("GetPollFDAndNetwork returned wrong poll.FD, got: %d, want: %d", pfd.Sysfd, int(fd))
                }
                if !pfd.IsStream {
                        t.Fatalf("expected IsStream to be true")
diff --git a/src/os/writeto_linux_test.go b/src/os/writeto_linux_test.go
new file mode 100644 (file)
index 0000000..5ffab88
--- /dev/null
@@ -0,0 +1,171 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package os_test
+
+import (
+       "bytes"
+       "internal/poll"
+       "io"
+       "math/rand"
+       "net"
+       . "os"
+       "strconv"
+       "syscall"
+       "testing"
+       "time"
+)
+
+func TestSendFile(t *testing.T) {
+       sizes := []int{
+               1,
+               42,
+               1025,
+               syscall.Getpagesize() + 1,
+               32769,
+       }
+       t.Run("sendfile-to-unix", func(t *testing.T) {
+               for _, size := range sizes {
+                       t.Run(strconv.Itoa(size), func(t *testing.T) {
+                               testSendFile(t, "unix", int64(size))
+                       })
+               }
+       })
+       t.Run("sendfile-to-tcp", func(t *testing.T) {
+               for _, size := range sizes {
+                       t.Run(strconv.Itoa(size), func(t *testing.T) {
+                               testSendFile(t, "tcp", int64(size))
+                       })
+               }
+       })
+}
+
+func testSendFile(t *testing.T, proto string, size int64) {
+       dst, src, recv, data, hook := newSendFileTest(t, proto, size)
+
+       // Now call WriteTo (through io.Copy), which will hopefully call poll.SendFile
+       n, err := io.Copy(dst, src)
+       if err != nil {
+               t.Fatalf("io.Copy error: %v", err)
+       }
+
+       // We should have called poll.Splice with the right file descriptor arguments.
+       if n > 0 && !hook.called {
+               t.Fatal("expected to called poll.SendFile")
+       }
+       if hook.called && hook.srcfd != int(src.Fd()) {
+               t.Fatalf("wrong source file descriptor: got %d, want %d", hook.srcfd, src.Fd())
+       }
+       sc, ok := dst.(syscall.Conn)
+       if !ok {
+               t.Fatalf("destination is not a syscall.Conn")
+       }
+       rc, err := sc.SyscallConn()
+       if err != nil {
+               t.Fatalf("destination SyscallConn error: %v", err)
+       }
+       if err = rc.Control(func(fd uintptr) {
+               if hook.called && hook.dstfd != int(fd) {
+                       t.Fatalf("wrong destination file descriptor: got %d, want %d", hook.dstfd, int(fd))
+               }
+       }); err != nil {
+               t.Fatalf("destination Conn Control error: %v", err)
+       }
+
+       // Verify the data size and content.
+       dataSize := len(data)
+       dstData := make([]byte, dataSize)
+       m, err := io.ReadFull(recv, dstData)
+       if err != nil {
+               t.Fatalf("server Conn Read error: %v", err)
+       }
+       if n != int64(dataSize) {
+               t.Fatalf("data length mismatch for io.Copy, got %d, want %d", n, dataSize)
+       }
+       if m != dataSize {
+               t.Fatalf("data length mismatch for net.Conn.Read, got %d, want %d", m, dataSize)
+       }
+       if !bytes.Equal(dstData, data) {
+               t.Errorf("data mismatch, got %s, want %s", dstData, data)
+       }
+}
+
+// newSendFileTest initializes a new test for sendfile.
+//
+// It creates source file and destination sockets, and populates the source file
+// with random data of the specified size. It also hooks package os' call
+// to poll.Sendfile and returns the hook so it can be inspected.
+func newSendFileTest(t *testing.T, proto string, size int64) (net.Conn, *File, net.Conn, []byte, *sendFileHook) {
+       t.Helper()
+
+       hook := hookSendFile(t)
+
+       client, server := createSocketPair(t, proto)
+       tempFile, data := createTempFile(t, size)
+
+       return client, tempFile, server, data, hook
+}
+
+func hookSendFile(t *testing.T) *sendFileHook {
+       h := new(sendFileHook)
+       h.install()
+       t.Cleanup(h.uninstall)
+       return h
+}
+
+type sendFileHook struct {
+       called bool
+       dstfd  int
+       srcfd  int
+       remain int64
+
+       written int64
+       handled bool
+       err     error
+
+       original func(dst *poll.FD, src int, remain int64) (int64, error, bool)
+}
+
+func (h *sendFileHook) install() {
+       h.original = *PollSendFile
+       *PollSendFile = func(dst *poll.FD, src int, remain int64) (int64, error, bool) {
+               h.called = true
+               h.dstfd = dst.Sysfd
+               h.srcfd = src
+               h.remain = remain
+               h.written, h.err, h.handled = h.original(dst, src, remain)
+               return h.written, h.err, h.handled
+       }
+}
+
+func (h *sendFileHook) uninstall() {
+       *PollSendFile = h.original
+}
+
+func createTempFile(t *testing.T, size int64) (*File, []byte) {
+       f, err := CreateTemp(t.TempDir(), "writeto-sendfile-to-socket")
+       if err != nil {
+               t.Fatalf("failed to create temporary file: %v", err)
+       }
+       t.Cleanup(func() {
+               f.Close()
+       })
+
+       randSeed := time.Now().Unix()
+       t.Logf("random data seed: %d\n", randSeed)
+       prng := rand.New(rand.NewSource(randSeed))
+       data := make([]byte, size)
+       prng.Read(data)
+       if _, err := f.Write(data); err != nil {
+               t.Fatalf("failed to create and feed the file: %v", err)
+       }
+       if err := f.Sync(); err != nil {
+               t.Fatalf("failed to save the file: %v", err)
+       }
+       if _, err := f.Seek(0, io.SeekStart); err != nil {
+               t.Fatalf("failed to rewind the file: %v", err)
+       }
+
+       return f, data
+}
similarity index 70%
rename from src/os/readfrom_linux.go
rename to src/os/zero_copy_linux.go
index 7e8024028e98e852d7787845e8a94732f88e8cd4..7c45aefeee8621b4585f39998558ab9316da19d7 100644 (file)
@@ -13,8 +13,33 @@ import (
 var (
        pollCopyFileRange = poll.CopyFileRange
        pollSplice        = poll.Splice
+       pollSendFile      = poll.SendFile
 )
 
+func (f *File) writeTo(w io.Writer) (written int64, handled bool, err error) {
+       pfd, network := getPollFDAndNetwork(w)
+       // TODO(panjf2000): same as File.spliceToFile.
+       if pfd == nil || !pfd.IsStream || !isUnixOrTCP(string(network)) {
+               return
+       }
+
+       sc, err := f.SyscallConn()
+       if err != nil {
+               return
+       }
+
+       rerr := sc.Read(func(fd uintptr) (done bool) {
+               written, err, handled = pollSendFile(pfd, int(fd), 1<<63-1)
+               return true
+       })
+
+       if err == nil {
+               err = rerr
+       }
+
+       return written, handled, wrapSyscallError("sendfile", err)
+}
+
 func (f *File) readFrom(r io.Reader) (written int64, handled bool, err error) {
        // Neither copy_file_range(2) nor splice(2) supports destinations opened with
        // O_APPEND, so don't bother to try zero-copy with these system calls.
@@ -41,7 +66,7 @@ func (f *File) spliceToFile(r io.Reader) (written int64, handled bool, err error
                return 0, true, nil
        }
 
-       pfd := getPollFD(r)
+       pfd, _ := getPollFDAndNetwork(r)
        // TODO(panjf2000): run some tests to see if we should unlock the non-streams for splice.
        // Streams benefit the most from the splice(2), non-streams are not even supported in old kernels
        // where splice(2) will just return EINVAL; newer kernels support non-streams like UDP, but I really
@@ -63,25 +88,6 @@ func (f *File) spliceToFile(r io.Reader) (written int64, handled bool, err error
        return written, handled, wrapSyscallError(syscallName, err)
 }
 
-// getPollFD tries to get the poll.FD from the given io.Reader by expecting
-// the underlying type of r to be the implementation of syscall.Conn that contains
-// a *net.rawConn.
-func getPollFD(r io.Reader) *poll.FD {
-       sc, ok := r.(syscall.Conn)
-       if !ok {
-               return nil
-       }
-       rc, err := sc.SyscallConn()
-       if err != nil {
-               return nil
-       }
-       ipfd, ok := rc.(interface{ PollFD() *poll.FD })
-       if !ok {
-               return nil
-       }
-       return ipfd.PollFD()
-}
-
 func (f *File) copyFileRange(r io.Reader) (written int64, handled bool, err error) {
        var (
                remain int64
@@ -91,10 +97,16 @@ func (f *File) copyFileRange(r io.Reader) (written int64, handled bool, err erro
                return 0, true, nil
        }
 
-       src, ok := r.(*File)
-       if !ok {
+       var src *File
+       switch v := r.(type) {
+       case *File:
+               src = v
+       case fileWithoutWriteTo:
+               src = v.File
+       default:
                return 0, false, nil
        }
+
        if src.checkValid("ReadFrom") != nil {
                // Avoid returning the error as we report handled as false,
                // leave further error handling as the responsibility of the caller.
@@ -108,6 +120,28 @@ func (f *File) copyFileRange(r io.Reader) (written int64, handled bool, err erro
        return written, handled, wrapSyscallError("copy_file_range", err)
 }
 
+// getPollFDAndNetwork tries to get the poll.FD and network type from the given interface
+// by expecting the underlying type of i to be the implementation of syscall.Conn
+// that contains a *net.rawConn.
+func getPollFDAndNetwork(i any) (*poll.FD, poll.String) {
+       sc, ok := i.(syscall.Conn)
+       if !ok {
+               return nil, ""
+       }
+       rc, err := sc.SyscallConn()
+       if err != nil {
+               return nil, ""
+       }
+       irc, ok := rc.(interface {
+               PollFD() *poll.FD
+               Network() poll.String
+       })
+       if !ok {
+               return nil, ""
+       }
+       return irc.PollFD(), irc.Network()
+}
+
 // tryLimitedReader tries to assert the io.Reader to io.LimitedReader, it returns the io.LimitedReader,
 // the underlying io.Reader and the remaining amount of bytes if the assertion succeeds,
 // otherwise it just returns the original io.Reader and the theoretical unlimited remaining amount of bytes.
@@ -122,3 +156,12 @@ func tryLimitedReader(r io.Reader) (*io.LimitedReader, io.Reader, int64) {
        remain = lr.N
        return lr, lr.R, remain
 }
+
+func isUnixOrTCP(network string) bool {
+       switch network {
+       case "tcp", "tcp4", "tcp6", "unix":
+               return true
+       default:
+               return false
+       }
+}
similarity index 74%
rename from src/os/readfrom_stub.go
rename to src/os/zero_copy_stub.go
index 8b7d5fb8f9e35c88dbce07825c73e97614c1a64d..9ec5808101889d7f90879fd41e3216349acef8f7 100644 (file)
@@ -8,6 +8,10 @@ package os
 
 import "io"
 
+func (f *File) writeTo(w io.Writer) (written int64, handled bool, err error) {
+       return 0, false, nil
+}
+
 func (f *File) readFrom(r io.Reader) (n int64, handled bool, err error) {
        return 0, false, nil
 }