]> Cypherpunks repositories - gostls13.git/commitdiff
os: support zero-copy from TCP/Unix socket to file
authorAndy Pan <panjf2000@gmail.com>
Sat, 3 Sep 2022 02:41:03 +0000 (10:41 +0800)
committerGopher Robot <gobot@golang.org>
Mon, 27 Feb 2023 00:12:08 +0000 (00:12 +0000)
Go currently supports the cases of zero-copy: from TCP/Unix socket to TCP socket, from file to TCP socket, from file to file.

Now implementing the new support of zero-copy: from TCP/Unix socket to file.

goos: linux
goarch: amd64
pkg: net
cpu: DO-Premium-Intel
                                  │      old      │                 new                  │
                                  │    sec/op     │    sec/op     vs base                │
SpliceFile/tcp-to-file/1024-4        5.910µ ±  9%   4.116µ ± 13%  -30.35% (p=0.000 n=10)
SpliceFile/tcp-to-file/2048-4        6.150µ ± 10%   4.077µ ± 13%  -33.72% (p=0.002 n=10)
SpliceFile/tcp-to-file/4096-4        4.837µ ± 28%   4.447µ ± 23%        ~ (p=0.353 n=10)
SpliceFile/tcp-to-file/8192-4        9.309µ ±  7%   6.293µ ±  9%  -32.40% (p=0.000 n=10)
SpliceFile/tcp-to-file/16384-4       19.43µ ± 12%   12.48µ ±  9%  -35.76% (p=0.000 n=10)
SpliceFile/tcp-to-file/32768-4       42.73µ ± 10%   25.32µ ±  8%  -40.76% (p=0.000 n=10)
SpliceFile/tcp-to-file/65536-4       70.37µ ± 11%   48.60µ ±  4%  -30.93% (p=0.000 n=10)
SpliceFile/tcp-to-file/131072-4     141.91µ ±  6%   96.24µ ±  4%  -32.18% (p=0.000 n=10)
SpliceFile/tcp-to-file/262144-4      329.7µ ±  8%   246.7µ ± 13%  -25.19% (p=0.000 n=10)
SpliceFile/tcp-to-file/524288-4      653.5µ ±  7%   441.6µ ±  7%  -32.43% (p=0.000 n=10)
SpliceFile/tcp-to-file/1048576-4    1184.4µ ±  9%   851.8µ ± 14%  -28.09% (p=0.000 n=10)
SpliceFile/unix-to-file/1024-4       1.734µ ± 10%   1.524µ ± 25%  -12.06% (p=0.035 n=10)
SpliceFile/unix-to-file/2048-4       2.614µ ±  7%   2.231µ ±  8%  -14.65% (p=0.000 n=10)
SpliceFile/unix-to-file/4096-4       5.081µ ±  7%   3.947µ ± 11%  -22.33% (p=0.000 n=10)
SpliceFile/unix-to-file/8192-4       8.560µ ±  5%   8.531µ ± 17%        ~ (p=0.796 n=10)
SpliceFile/unix-to-file/16384-4      18.09µ ± 12%   12.92µ ± 25%  -28.59% (p=0.000 n=10)
SpliceFile/unix-to-file/32768-4      35.50µ ±  5%   24.50µ ±  6%  -31.00% (p=0.000 n=10)
SpliceFile/unix-to-file/65536-4      69.99µ ±  7%   51.22µ ± 23%  -26.82% (p=0.000 n=10)
SpliceFile/unix-to-file/131072-4     133.7µ ± 17%   119.7µ ±  6%  -10.43% (p=0.000 n=10)
SpliceFile/unix-to-file/262144-4     246.5µ ±  5%   207.3µ ± 19%  -15.90% (p=0.007 n=10)
SpliceFile/unix-to-file/524288-4     484.8µ ± 20%   382.9µ ± 10%  -21.02% (p=0.000 n=10)
SpliceFile/unix-to-file/1048576-4   1188.4µ ± 27%   781.8µ ± 11%  -34.21% (p=0.000 n=10)
geomean                              42.24µ         31.45µ        -25.53%

                                  │      old       │                  new                   │
                                  │      B/s       │      B/s        vs base                │
SpliceFile/tcp-to-file/1024-4        165.4Mi ± 10%    237.3Mi ± 11%  +43.47% (p=0.000 n=10)
SpliceFile/tcp-to-file/2048-4        317.6Mi ± 12%    479.7Mi ± 14%  +51.02% (p=0.002 n=10)
SpliceFile/tcp-to-file/4096-4        808.2Mi ± 22%    886.8Mi ± 19%        ~ (p=0.353 n=10)
SpliceFile/tcp-to-file/8192-4        839.3Mi ±  6%   1241.5Mi ±  8%  +47.91% (p=0.000 n=10)
SpliceFile/tcp-to-file/16384-4       804.7Mi ± 13%   1252.2Mi ± 10%  +55.61% (p=0.000 n=10)
SpliceFile/tcp-to-file/32768-4       731.3Mi ± 11%   1234.3Mi ±  7%  +68.78% (p=0.000 n=10)
SpliceFile/tcp-to-file/65536-4       888.7Mi ± 10%   1286.2Mi ±  4%  +44.73% (p=0.000 n=10)
SpliceFile/tcp-to-file/131072-4      880.9Mi ±  6%   1299.0Mi ±  4%  +47.47% (p=0.000 n=10)
SpliceFile/tcp-to-file/262144-4      758.2Mi ±  7%   1014.4Mi ± 15%  +33.78% (p=0.000 n=10)
SpliceFile/tcp-to-file/524288-4      765.3Mi ±  7%   1132.5Mi ±  7%  +47.99% (p=0.000 n=10)
SpliceFile/tcp-to-file/1048576-4     845.0Mi ±  8%   1174.0Mi ± 16%  +38.94% (p=0.000 n=10)
SpliceFile/unix-to-file/1024-4       564.2Mi ± 11%    640.5Mi ± 20%  +13.53% (p=0.035 n=10)
SpliceFile/unix-to-file/2048-4       747.4Mi ±  7%    875.7Mi ±  8%  +17.17% (p=0.000 n=10)
SpliceFile/unix-to-file/4096-4       768.8Mi ±  6%    989.8Mi ± 10%  +28.74% (p=0.000 n=10)
SpliceFile/unix-to-file/8192-4       912.9Mi ±  5%    915.8Mi ± 15%        ~ (p=0.796 n=10)
SpliceFile/unix-to-file/16384-4      863.6Mi ± 10%   1209.7Mi ± 20%  +40.06% (p=0.000 n=10)
SpliceFile/unix-to-file/32768-4      880.2Mi ±  6%   1275.7Mi ±  6%  +44.93% (p=0.000 n=10)
SpliceFile/unix-to-file/65536-4      893.0Mi ±  7%   1220.3Mi ± 19%  +36.66% (p=0.000 n=10)
SpliceFile/unix-to-file/131072-4     935.1Mi ± 14%   1043.9Mi ±  7%  +11.64% (p=0.000 n=10)
SpliceFile/unix-to-file/262144-4    1014.2Mi ±  6%   1205.9Mi ± 16%  +18.91% (p=0.007 n=10)
SpliceFile/unix-to-file/524288-4     1.007Gi ± 17%    1.275Gi ±  9%  +26.61% (p=0.000 n=10)
SpliceFile/unix-to-file/1048576-4    841.8Mi ± 21%   1279.0Mi ± 10%  +51.94% (p=0.000 n=10)
geomean                              740.1Mi          994.2Mi        +34.33%

                                  │      old       │                   new                   │
                                  │      B/op      │    B/op     vs base                     │
SpliceFile/tcp-to-file/1024-4       0.000 ±   0%     0.000 ± 0%         ~ (p=1.000 n=10) ¹
SpliceFile/tcp-to-file/2048-4       0.000 ±   0%     0.000 ± 0%         ~ (p=1.000 n=10) ¹
SpliceFile/tcp-to-file/4096-4       0.000 ±   0%     0.000 ± 0%         ~ (p=1.000 n=10) ¹
SpliceFile/tcp-to-file/8192-4       0.000 ±   0%     0.000 ± 0%         ~ (p=1.000 n=10) ¹
SpliceFile/tcp-to-file/16384-4      0.000 ±   0%     0.000 ± 0%         ~ (p=1.000 n=10) ¹
SpliceFile/tcp-to-file/32768-4      1.000 ±   0%     0.000 ± 0%  -100.00% (p=0.000 n=10)
SpliceFile/tcp-to-file/65536-4      1.000 ± 100%     0.000 ± 0%  -100.00% (p=0.000 n=10)
SpliceFile/tcp-to-file/131072-4     3.000 ±  33%     0.000 ± 0%  -100.00% (p=0.000 n=10)
SpliceFile/tcp-to-file/262144-4     8.500 ±  18%     0.000 ± 0%  -100.00% (p=0.000 n=10)
SpliceFile/tcp-to-file/524288-4     16.50 ±  21%      0.00 ± 0%  -100.00% (p=0.000 n=10)
SpliceFile/tcp-to-file/1048576-4    30.50 ±  15%      0.00 ± 0%  -100.00% (p=0.000 n=10)
SpliceFile/unix-to-file/1024-4      0.000 ±   0%     0.000 ± 0%         ~ (p=1.000 n=10) ¹
SpliceFile/unix-to-file/2048-4      0.000 ±   0%     0.000 ± 0%         ~ (p=1.000 n=10) ¹
SpliceFile/unix-to-file/4096-4      0.000 ±   0%     0.000 ± 0%         ~ (p=1.000 n=10) ¹
SpliceFile/unix-to-file/8192-4      0.000 ±   0%     0.000 ± 0%         ~ (p=1.000 n=10) ¹
SpliceFile/unix-to-file/16384-4     0.000 ±   0%     0.000 ± 0%         ~ (p=1.000 n=10) ¹
SpliceFile/unix-to-file/32768-4     0.000 ±   0%     0.000 ± 0%         ~ (p=1.000 n=10) ¹
SpliceFile/unix-to-file/65536-4     1.000 ± 100%     0.000 ± 0%  -100.00% (p=0.000 n=10)
SpliceFile/unix-to-file/131072-4    3.000 ±  33%     0.000 ± 0%  -100.00% (p=0.000 n=10)
SpliceFile/unix-to-file/262144-4    6.000 ±  17%     0.000 ± 0%  -100.00% (p=0.000 n=10)
SpliceFile/unix-to-file/524288-4    12.00 ±   8%      0.00 ± 0%  -100.00% (p=0.000 n=10)
SpliceFile/unix-to-file/1048576-4   33.50 ±  10%      0.00 ± 0%  -100.00% (p=0.000 n=10)
geomean                                          ²               ?                       ² ³
¹ all samples are equal
² summaries must be >0 to compute geomean
³ ratios must be >0 to compute geomean

                                  │     old      │                 new                 │
                                  │  allocs/op   │ allocs/op   vs base                 │
SpliceFile/tcp-to-file/1024-4       0.000 ± 0%     0.000 ± 0%       ~ (p=1.000 n=10) ¹
SpliceFile/tcp-to-file/2048-4       0.000 ± 0%     0.000 ± 0%       ~ (p=1.000 n=10) ¹
SpliceFile/tcp-to-file/4096-4       0.000 ± 0%     0.000 ± 0%       ~ (p=1.000 n=10) ¹
SpliceFile/tcp-to-file/8192-4       0.000 ± 0%     0.000 ± 0%       ~ (p=1.000 n=10) ¹
SpliceFile/tcp-to-file/16384-4      0.000 ± 0%     0.000 ± 0%       ~ (p=1.000 n=10) ¹
SpliceFile/tcp-to-file/32768-4      0.000 ± 0%     0.000 ± 0%       ~ (p=1.000 n=10) ¹
SpliceFile/tcp-to-file/65536-4      0.000 ± 0%     0.000 ± 0%       ~ (p=1.000 n=10) ¹
SpliceFile/tcp-to-file/131072-4     0.000 ± 0%     0.000 ± 0%       ~ (p=1.000 n=10) ¹
SpliceFile/tcp-to-file/262144-4     0.000 ± 0%     0.000 ± 0%       ~ (p=1.000 n=10) ¹
SpliceFile/tcp-to-file/524288-4     0.000 ± 0%     0.000 ± 0%       ~ (p=1.000 n=10) ¹
SpliceFile/tcp-to-file/1048576-4    0.000 ± 0%     0.000 ± 0%       ~ (p=1.000 n=10) ¹
SpliceFile/unix-to-file/1024-4      0.000 ± 0%     0.000 ± 0%       ~ (p=1.000 n=10) ¹
SpliceFile/unix-to-file/2048-4      0.000 ± 0%     0.000 ± 0%       ~ (p=1.000 n=10) ¹
SpliceFile/unix-to-file/4096-4      0.000 ± 0%     0.000 ± 0%       ~ (p=1.000 n=10) ¹
SpliceFile/unix-to-file/8192-4      0.000 ± 0%     0.000 ± 0%       ~ (p=1.000 n=10) ¹
SpliceFile/unix-to-file/16384-4     0.000 ± 0%     0.000 ± 0%       ~ (p=1.000 n=10) ¹
SpliceFile/unix-to-file/32768-4     0.000 ± 0%     0.000 ± 0%       ~ (p=1.000 n=10) ¹
SpliceFile/unix-to-file/65536-4     0.000 ± 0%     0.000 ± 0%       ~ (p=1.000 n=10) ¹
SpliceFile/unix-to-file/131072-4    0.000 ± 0%     0.000 ± 0%       ~ (p=1.000 n=10) ¹
SpliceFile/unix-to-file/262144-4    0.000 ± 0%     0.000 ± 0%       ~ (p=1.000 n=10) ¹
SpliceFile/unix-to-file/524288-4    0.000 ± 0%     0.000 ± 0%       ~ (p=1.000 n=10) ¹
SpliceFile/unix-to-file/1048576-4   0.000 ± 0%     0.000 ± 0%       ~ (p=1.000 n=10) ¹
geomean                                        ²               +0.00%                ²
¹ all samples are equal
² summaries must be >0 to compute geomean

Change-Id: Ie7f7d4d7b6b373d9ee7ce6da8f6a4cd157632486
Reviewed-on: https://go-review.googlesource.com/c/go/+/466015
TryBot-Result: Gopher Robot <gobot@golang.org>
Auto-Submit: Ian Lance Taylor <iant@google.com>
Run-TryBot: Andy Pan <panjf2000@gmail.com>
Reviewed-by: Bryan Mills <bcmills@google.com>
Reviewed-by: Ian Lance Taylor <iant@google.com>
Run-TryBot: Ian Lance Taylor <iant@google.com>

src/net/rawconn.go
src/net/splice_test.go
src/os/export_linux_test.go
src/os/readfrom_linux.go
src/os/readfrom_linux_test.go

index c7863545826c636de4ab5cdb41eb5d0a933f04dd..974320c25f8ed7e2854ea7c95c04e605256d632c 100644 (file)
@@ -5,6 +5,7 @@
 package net
 
 import (
+       "internal/poll"
        "runtime"
        "syscall"
 )
@@ -60,6 +61,20 @@ func (c *rawConn) Write(f func(uintptr) bool) error {
        return err
 }
 
+// PollFD returns the poll.FD of the underlying connection.
+//
+// Other packages in std that also import internal/poll (such as os)
+// can use a type assertion to access this extension method so that
+// they can pass the *poll.FD to functions like poll.Splice.
+//
+// PollFD is not intended for use outside the standard library.
+func (c *rawConn) PollFD() *poll.FD {
+       if !c.ok() {
+               return nil
+       }
+       return &c.fd.pfd
+}
+
 func newRawConn(fd *netFD) (*rawConn, error) {
        return &rawConn{fd: fd}, nil
 }
index fa14c95eb715335a96d7fccd117d6d9afde01b83..c74361d61b6230fe7f33542993be47c2e6f9ce76 100644 (file)
@@ -23,10 +23,21 @@ func TestSplice(t *testing.T) {
                t.Skip("skipping unix-to-tcp tests")
        }
        t.Run("unix-to-tcp", func(t *testing.T) { testSplice(t, "unix", "tcp") })
+       t.Run("tcp-to-file", func(t *testing.T) { testSpliceToFile(t, "tcp", "file") })
+       t.Run("unix-to-file", func(t *testing.T) { testSpliceToFile(t, "unix", "file") })
        t.Run("no-unixpacket", testSpliceNoUnixpacket)
        t.Run("no-unixgram", testSpliceNoUnixgram)
 }
 
+func testSpliceToFile(t *testing.T, upNet, downNet string) {
+       t.Run("simple", spliceTestCase{upNet, downNet, 128, 128, 0}.testFile)
+       t.Run("multipleWrite", spliceTestCase{upNet, downNet, 4096, 1 << 20, 0}.testFile)
+       t.Run("big", spliceTestCase{upNet, downNet, 5 << 20, 1 << 30, 0}.testFile)
+       t.Run("honorsLimitedReader", spliceTestCase{upNet, downNet, 4096, 1 << 20, 1 << 10}.testFile)
+       t.Run("updatesLimitedReaderN", spliceTestCase{upNet, downNet, 1024, 4096, 4096 + 100}.testFile)
+       t.Run("limitedReaderAtLimit", spliceTestCase{upNet, downNet, 32, 128, 128}.testFile)
+}
+
 func testSplice(t *testing.T, upNet, downNet string) {
        t.Run("simple", spliceTestCase{upNet, downNet, 128, 128, 0}.test)
        t.Run("multipleWrite", spliceTestCase{upNet, downNet, 4096, 1 << 20, 0}.test)
@@ -96,6 +107,57 @@ func (tc spliceTestCase) test(t *testing.T) {
        }
 }
 
+func (tc spliceTestCase) testFile(t *testing.T) {
+       f, err := os.CreateTemp(t.TempDir(), "linux-splice-to-file")
+       if err != nil {
+               t.Fatal(err)
+       }
+       defer f.Close()
+
+       client, server := spliceTestSocketPair(t, tc.upNet)
+       defer server.Close()
+
+       cleanup, err := startSpliceClient(client, "w", tc.chunkSize, tc.totalSize)
+       if err != nil {
+               client.Close()
+               t.Fatal("failed to start splice client:", err)
+       }
+       defer cleanup()
+
+       var (
+               r          io.Reader = server
+               actualSize           = tc.totalSize
+       )
+       if tc.limitReadSize > 0 {
+               if tc.limitReadSize < actualSize {
+                       actualSize = tc.limitReadSize
+               }
+
+               r = &io.LimitedReader{
+                       N: int64(tc.limitReadSize),
+                       R: r,
+               }
+       }
+
+       got, err := io.Copy(f, r)
+       if err != nil {
+               t.Fatalf("failed to ReadFrom with error: %v", err)
+       }
+       if want := int64(actualSize); got != want {
+               t.Errorf("got %d bytes, want %d", got, want)
+       }
+       if tc.limitReadSize > 0 {
+               wantN := 0
+               if tc.limitReadSize > actualSize {
+                       wantN = tc.limitReadSize - actualSize
+               }
+
+               if gotN := r.(*io.LimitedReader).N; gotN != int64(wantN) {
+                       t.Errorf("r.N = %d, want %d", gotN, wantN)
+               }
+       }
+}
+
 func testSpliceReaderAtEOF(t *testing.T, upNet, downNet string) {
        clientUp, serverUp := spliceTestSocketPair(t, upNet)
        defer clientUp.Close()
@@ -415,3 +477,56 @@ func init() {
                }
        }
 }
+
+func BenchmarkSpliceFile(b *testing.B) {
+       b.Run("tcp-to-file", func(b *testing.B) { benchmarkSpliceFile(b, "tcp") })
+       b.Run("unix-to-file", func(b *testing.B) { benchmarkSpliceFile(b, "unix") })
+}
+
+func benchmarkSpliceFile(b *testing.B, proto string) {
+       for i := 0; i <= 10; i++ {
+               size := 1 << (i + 10)
+               bench := spliceFileBench{
+                       proto:     proto,
+                       chunkSize: size,
+               }
+               b.Run(strconv.Itoa(size), bench.benchSpliceFile)
+       }
+}
+
+type spliceFileBench struct {
+       proto     string
+       chunkSize int
+}
+
+func (bench spliceFileBench) benchSpliceFile(b *testing.B) {
+       f, err := os.OpenFile(os.DevNull, os.O_WRONLY, 0)
+       if err != nil {
+               b.Fatal(err)
+       }
+       defer f.Close()
+
+       totalSize := b.N * bench.chunkSize
+
+       client, server := spliceTestSocketPair(b, bench.proto)
+       defer server.Close()
+
+       cleanup, err := startSpliceClient(client, "w", bench.chunkSize, totalSize)
+       if err != nil {
+               client.Close()
+               b.Fatalf("failed to start splice client: %v", err)
+       }
+       defer cleanup()
+
+       b.ReportAllocs()
+       b.SetBytes(int64(bench.chunkSize))
+       b.ResetTimer()
+
+       got, err := io.Copy(f, server)
+       if err != nil {
+               b.Fatalf("failed to ReadFrom with error: %v", err)
+       }
+       if want := int64(totalSize); got != want {
+               b.Errorf("bytes sent mismatch, got: %d, want: %d", got, want)
+       }
+}
index d947d05df077e5d1fd2f6e89914b7bbd333fe261..3fd5e61de78b79e45f405d90a5d36855cc364a41 100644 (file)
@@ -4,4 +4,8 @@
 
 package os
 
-var PollCopyFileRangeP = &pollCopyFileRange
+var (
+       PollCopyFileRangeP = &pollCopyFileRange
+       PollSpliceFile     = &pollSplice
+       GetPollFDForTest   = getPollFD
+)
index 63ea45cf6516435265b2b71d81cca5a29ac5f971..950a6553a4c18b4499c9ee120aa54bf137435afa 100644 (file)
@@ -7,25 +7,85 @@ package os
 import (
        "internal/poll"
        "io"
+       "syscall"
 )
 
-var pollCopyFileRange = poll.CopyFileRange
+var (
+       pollCopyFileRange = poll.CopyFileRange
+       pollSplice        = poll.Splice
+)
 
 func (f *File) readFrom(r io.Reader) (written int64, handled bool, err error) {
+       written, handled, err = f.copyFileRange(r)
+       if handled {
+               return
+       }
+       return f.spliceToFile(r)
+}
+
+func (f *File) spliceToFile(r io.Reader) (written int64, handled bool, err error) {
+       var (
+               remain int64
+               lr     *io.LimitedReader
+       )
+       if lr, r, remain = tryLimitedReader(r); remain <= 0 {
+               return 0, true, nil
+       }
+
+       pfd := getPollFD(r)
+       // TODO(panjf2000): run some tests to see if we should unlock the non-streams for splice.
+       // Streams benefit the most from the splice(2), non-streams are not even supported in old kernels
+       // where splice(2) will just return EINVAL; newer kernels support non-streams like UDP, but I really
+       // doubt that splice(2) could help non-streams, cuz they usually send small frames respectively
+       // and one splice call would result in one frame.
+       // splice(2) is suitable for large data but the generation of fragments defeats its edge here.
+       // Therefore, don't bother to try splice if the r is not a streaming descriptor.
+       if pfd == nil || !pfd.IsStream {
+               return
+       }
+
+       var syscallName string
+       written, handled, syscallName, err = pollSplice(&f.pfd, pfd, remain)
+
+       if lr != nil {
+               lr.N = remain - written
+       }
+
+       return written, handled, NewSyscallError(syscallName, err)
+}
+
+// getPollFD tries to get the poll.FD from the given io.Reader by expecting
+// the underlying type of r to be the implementation of syscall.Conn that contains
+// a *net.rawConn.
+func getPollFD(r io.Reader) *poll.FD {
+       sc, ok := r.(syscall.Conn)
+       if !ok {
+               return nil
+       }
+       rc, err := sc.SyscallConn()
+       if err != nil {
+               return nil
+       }
+       ipfd, ok := rc.(interface{ PollFD() *poll.FD })
+       if !ok {
+               return nil
+       }
+       return ipfd.PollFD()
+}
+
+func (f *File) copyFileRange(r io.Reader) (written int64, handled bool, err error) {
        // copy_file_range(2) does not support destinations opened with
        // O_APPEND, so don't even try.
        if f.appendMode {
                return 0, false, nil
        }
 
-       remain := int64(1 << 62)
-
-       lr, ok := r.(*io.LimitedReader)
-       if ok {
-               remain, r = lr.N, lr.R
-               if remain <= 0 {
-                       return 0, true, nil
-               }
+       var (
+               remain int64
+               lr     *io.LimitedReader
+       )
+       if lr, r, remain = tryLimitedReader(r); remain <= 0 {
+               return 0, true, nil
        }
 
        src, ok := r.(*File)
@@ -44,3 +104,18 @@ func (f *File) readFrom(r io.Reader) (written int64, handled bool, err error) {
        }
        return written, handled, NewSyscallError("copy_file_range", err)
 }
+
+// tryLimitedReader tries to assert the io.Reader to io.LimitedReader, it returns the io.LimitedReader,
+// the underlying io.Reader and the remaining amount of bytes if the assertion succeeds,
+// otherwise it just returns the original io.Reader and the theoretical unlimited remaining amount of bytes.
+func tryLimitedReader(r io.Reader) (*io.LimitedReader, io.Reader, int64) {
+       remain := int64(1 << 62)
+
+       lr, ok := r.(*io.LimitedReader)
+       if !ok {
+               return nil, r, remain
+       }
+
+       remain = lr.N
+       return lr, lr.R, remain
+}
index 20408a887c8d8294b41be2d3b542d9511d98358c..3909c2f02e29c25f4ed7469a750d30a331cf529d 100644 (file)
@@ -9,6 +9,7 @@ import (
        "internal/poll"
        "io"
        "math/rand"
+       "net"
        "os"
        . "os"
        "path/filepath"
@@ -17,6 +18,8 @@ import (
        "syscall"
        "testing"
        "time"
+
+       "golang.org/x/net/nettest"
 )
 
 func TestCopyFileRange(t *testing.T) {
@@ -253,6 +256,145 @@ func TestCopyFileRange(t *testing.T) {
        })
 }
 
+func TestSpliceFile(t *testing.T) {
+       sizes := []int{
+               1,
+               42,
+               1025,
+               syscall.Getpagesize() + 1,
+               32769,
+       }
+       t.Run("Basic-TCP", func(t *testing.T) {
+               for _, size := range sizes {
+                       t.Run(strconv.Itoa(size), func(t *testing.T) {
+                               testSpliceFile(t, "tcp", int64(size), -1)
+                       })
+               }
+       })
+       t.Run("Basic-Unix", func(t *testing.T) {
+               for _, size := range sizes {
+                       t.Run(strconv.Itoa(size), func(t *testing.T) {
+                               testSpliceFile(t, "unix", int64(size), -1)
+                       })
+               }
+       })
+       t.Run("Limited", func(t *testing.T) {
+               t.Run("OneLess-TCP", func(t *testing.T) {
+                       for _, size := range sizes {
+                               t.Run(strconv.Itoa(size), func(t *testing.T) {
+                                       testSpliceFile(t, "tcp", int64(size), int64(size)-1)
+                               })
+                       }
+               })
+               t.Run("OneLess-Unix", func(t *testing.T) {
+                       for _, size := range sizes {
+                               t.Run(strconv.Itoa(size), func(t *testing.T) {
+                                       testSpliceFile(t, "unix", int64(size), int64(size)-1)
+                               })
+                       }
+               })
+               t.Run("Half-TCP", func(t *testing.T) {
+                       for _, size := range sizes {
+                               t.Run(strconv.Itoa(size), func(t *testing.T) {
+                                       testSpliceFile(t, "tcp", int64(size), int64(size)/2)
+                               })
+                       }
+               })
+               t.Run("Half-Unix", func(t *testing.T) {
+                       for _, size := range sizes {
+                               t.Run(strconv.Itoa(size), func(t *testing.T) {
+                                       testSpliceFile(t, "unix", int64(size), int64(size)/2)
+                               })
+                       }
+               })
+               t.Run("More-TCP", func(t *testing.T) {
+                       for _, size := range sizes {
+                               t.Run(strconv.Itoa(size), func(t *testing.T) {
+                                       testSpliceFile(t, "tcp", int64(size), int64(size)+1)
+                               })
+                       }
+               })
+               t.Run("More-Unix", func(t *testing.T) {
+                       for _, size := range sizes {
+                               t.Run(strconv.Itoa(size), func(t *testing.T) {
+                                       testSpliceFile(t, "unix", int64(size), int64(size)+1)
+                               })
+                       }
+               })
+       })
+}
+
+func testSpliceFile(t *testing.T, proto string, size, limit int64) {
+       dst, src, data, hook, cleanup := newSpliceFileTest(t, proto, size)
+       defer cleanup()
+
+       // If we have a limit, wrap the reader.
+       var (
+               r  io.Reader
+               lr *io.LimitedReader
+       )
+       if limit >= 0 {
+               lr = &io.LimitedReader{N: limit, R: src}
+               r = lr
+               if limit < int64(len(data)) {
+                       data = data[:limit]
+               }
+       } else {
+               r = src
+       }
+       // Now call ReadFrom (through io.Copy), which will hopefully call poll.Splice
+       n, err := io.Copy(dst, r)
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       // We should have called poll.Splice with the right file descriptor arguments.
+       if n > 0 && !hook.called {
+               t.Fatal("expected to called poll.Splice")
+       }
+       if hook.called && hook.dstfd != int(dst.Fd()) {
+               t.Fatalf("wrong destination file descriptor: got %d, want %d", hook.dstfd, dst.Fd())
+       }
+       sc, ok := src.(syscall.Conn)
+       if !ok {
+               t.Fatalf("server Conn is not a syscall.Conn")
+       }
+       rc, err := sc.SyscallConn()
+       if err != nil {
+               t.Fatalf("server Conn SyscallConn error: %v", err)
+       }
+       if err = rc.Control(func(fd uintptr) {
+               if hook.called && hook.srcfd != int(fd) {
+                       t.Fatalf("wrong source file descriptor: got %d, want %d", hook.srcfd, int(fd))
+               }
+       }); err != nil {
+               t.Fatalf("server Conn Control error: %v", err)
+       }
+
+       // Check that the offsets after the transfer make sense, that the size
+       // of the transfer was reported correctly, and that the destination
+       // file contains exactly the bytes we expect it to contain.
+       dstoff, err := dst.Seek(0, io.SeekCurrent)
+       if err != nil {
+               t.Fatal(err)
+       }
+       if dstoff != int64(len(data)) {
+               t.Errorf("dstoff = %d, want %d", dstoff, len(data))
+       }
+       if n != int64(len(data)) {
+               t.Errorf("short ReadFrom: wrote %d bytes, want %d", n, len(data))
+       }
+       mustSeekStart(t, dst)
+       mustContainData(t, dst, data)
+
+       // If we had a limit, check that it was updated.
+       if lr != nil {
+               if want := limit - n; lr.N != want {
+                       t.Fatalf("didn't update limit correctly: got %d, want %d", lr.N, want)
+               }
+       }
+}
+
 func testCopyFileRange(t *testing.T, size int64, limit int64) {
        dst, src, data, hook := newCopyFileRangeTest(t, size)
 
@@ -359,6 +501,40 @@ func newCopyFileRangeTest(t *testing.T, size int64) (dst, src *File, data []byte
        return dst, src, data, hook
 }
 
+// newSpliceFileTest initializes a new test for splice.
+//
+// It creates source sockets and destination file, and populates the source sockets
+// with random data of the specified size. It also hooks package os' call
+// to poll.Splice and returns the hook so it can be inspected.
+func newSpliceFileTest(t *testing.T, proto string, size int64) (*File, net.Conn, []byte, *spliceFileHook, func()) {
+       t.Helper()
+
+       hook := hookSpliceFile(t)
+
+       client, server := createSocketPair(t, proto)
+
+       dst, err := CreateTemp(t.TempDir(), "dst-splice-file-test")
+       if err != nil {
+               t.Fatal(err)
+       }
+       t.Cleanup(func() { dst.Close() })
+
+       randSeed := time.Now().Unix()
+       t.Logf("random data seed: %d\n", randSeed)
+       prng := rand.New(rand.NewSource(randSeed))
+       data := make([]byte, size)
+       prng.Read(data)
+
+       done := make(chan struct{})
+       go func() {
+               client.Write(data)
+               client.Close()
+               close(done)
+       }()
+
+       return dst, server, data, hook, func() { <-done }
+}
+
 // mustContainData ensures that the specified file contains exactly the
 // specified data.
 func mustContainData(t *testing.T, f *File, data []byte) {
@@ -418,6 +594,43 @@ func (h *copyFileRangeHook) uninstall() {
        *PollCopyFileRangeP = h.original
 }
 
+func hookSpliceFile(t *testing.T) *spliceFileHook {
+       h := new(spliceFileHook)
+       h.install()
+       t.Cleanup(h.uninstall)
+       return h
+}
+
+type spliceFileHook struct {
+       called bool
+       dstfd  int
+       srcfd  int
+       remain int64
+
+       written int64
+       handled bool
+       sc      string
+       err     error
+
+       original func(dst, src *poll.FD, remain int64) (int64, bool, string, error)
+}
+
+func (h *spliceFileHook) install() {
+       h.original = *PollSpliceFile
+       *PollSpliceFile = func(dst, src *poll.FD, remain int64) (int64, bool, string, error) {
+               h.called = true
+               h.dstfd = dst.Sysfd
+               h.srcfd = src.Sysfd
+               h.remain = remain
+               h.written, h.handled, h.sc, h.err = h.original(dst, src, remain)
+               return h.written, h.handled, h.sc, h.err
+       }
+}
+
+func (h *spliceFileHook) uninstall() {
+       *PollSpliceFile = h.original
+}
+
 // On some kernels copy_file_range fails on files in /proc.
 func TestProcCopy(t *testing.T) {
        t.Parallel()
@@ -451,3 +664,72 @@ func TestProcCopy(t *testing.T) {
                t.Errorf("copy of %q got %q want %q\n", cmdlineFile, copy, cmdline)
        }
 }
+
+func TestGetPollFDFromReader(t *testing.T) {
+       t.Run("tcp", func(t *testing.T) { testGetPollFromReader(t, "tcp") })
+       t.Run("unix", func(t *testing.T) { testGetPollFromReader(t, "unix") })
+}
+
+func testGetPollFromReader(t *testing.T, proto string) {
+       _, server := createSocketPair(t, proto)
+       sc, ok := server.(syscall.Conn)
+       if !ok {
+               t.Fatalf("server Conn is not a syscall.Conn")
+       }
+       rc, err := sc.SyscallConn()
+       if err != nil {
+               t.Fatalf("server SyscallConn error: %v", err)
+       }
+       if err = rc.Control(func(fd uintptr) {
+               pfd := os.GetPollFDForTest(server)
+               if pfd == nil {
+                       t.Fatalf("GetPollFDForTest didn't return poll.FD")
+               }
+               if pfd.Sysfd != int(fd) {
+                       t.Fatalf("GetPollFDForTest returned wrong poll.FD, got: %d, want: %d", pfd.Sysfd, int(fd))
+               }
+               if !pfd.IsStream {
+                       t.Fatalf("expected IsStream to be true")
+               }
+               if err = pfd.Init(proto, true); err == nil {
+                       t.Fatalf("Init should have failed with the initialized poll.FD and return EEXIST error")
+               }
+       }); err != nil {
+               t.Fatalf("server Control error: %v", err)
+       }
+}
+
+func createSocketPair(t *testing.T, proto string) (client, server net.Conn) {
+       t.Helper()
+
+       ln, err := nettest.NewLocalListener(proto)
+       if err != nil {
+               t.Fatalf("NewLocalListener error: %v", err)
+       }
+       t.Cleanup(func() {
+               if ln != nil {
+                       ln.Close()
+               }
+               if client != nil {
+                       client.Close()
+               }
+               if server != nil {
+                       server.Close()
+               }
+       })
+       ch := make(chan struct{})
+       go func() {
+               var err error
+               server, err = ln.Accept()
+               if err != nil {
+                       t.Errorf("Accept new connection error: %v", err)
+               }
+               ch <- struct{}{}
+       }()
+       client, err = net.Dial(proto, ln.Addr().String())
+       <-ch
+       if err != nil {
+               t.Fatalf("Dial new connection error: %v", err)
+       }
+       return client, server
+}