]> Cypherpunks repositories - gostls13.git/commitdiff
os: employ sendfile(2) for file-to-file copying on Linux when needed
authorAndy Pan <i@andypan.me>
Tue, 6 Aug 2024 08:36:45 +0000 (16:36 +0800)
committerGopher Robot <gobot@golang.org>
Wed, 7 Aug 2024 17:28:44 +0000 (17:28 +0000)
Go utilizes copy_file_range(2) for file-to-file copying only on kernel 5.3+,
but even on 5.3+ this system call can still go wrong for some reason (check
out the comment inside poll.CopyFileRange).

Before Linux 2.6.33, out_fd must refer to a socket, but since Linux 2.6.33
it can be any file. Thus, we can employ sendfile(2) for copy between files
when copy_file_range(2) fails to handle the copy, that way we can still
benefit from the zero-copy technique on kernel <5.3 and wherever
copy_file_range(2) is available but broken.

Change-Id: I3922218c95ad34ee649ccdf3ccfbd1ce692bebcc
Reviewed-on: https://go-review.googlesource.com/c/go/+/603295
Reviewed-by: David Chase <drchase@google.com>
Auto-Submit: Ian Lance Taylor <iant@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Ian Lance Taylor <iant@google.com>
src/os/readfrom_linux_test.go
src/os/writeto_linux_test.go
src/os/zero_copy_linux.go

index 8dcb9cb217288263221afa03f877fe5503879e54..3822b2e329d1fd353d304d0ae1956dbd1ee82ab8 100644 (file)
@@ -25,7 +25,7 @@ import (
        "golang.org/x/net/nettest"
 )
 
-func TestCopyFileRange(t *testing.T) {
+func TestCopyFileRangeAndSendFile(t *testing.T) {
        sizes := []int{
                1,
                42,
@@ -37,6 +37,7 @@ func TestCopyFileRange(t *testing.T) {
                for _, size := range sizes {
                        t.Run(strconv.Itoa(size), func(t *testing.T) {
                                testCopyFileRange(t, int64(size), -1)
+                               testSendfileOverCopyFileRange(t, int64(size), -1)
                        })
                }
        })
@@ -45,6 +46,7 @@ func TestCopyFileRange(t *testing.T) {
                        for _, size := range sizes {
                                t.Run(strconv.Itoa(size), func(t *testing.T) {
                                        testCopyFileRange(t, int64(size), int64(size)-1)
+                                       testSendfileOverCopyFileRange(t, int64(size), int64(size)-1)
                                })
                        }
                })
@@ -52,6 +54,7 @@ func TestCopyFileRange(t *testing.T) {
                        for _, size := range sizes {
                                t.Run(strconv.Itoa(size), func(t *testing.T) {
                                        testCopyFileRange(t, int64(size), int64(size)/2)
+                                       testSendfileOverCopyFileRange(t, int64(size), int64(size)/2)
                                })
                        }
                })
@@ -59,173 +62,207 @@ func TestCopyFileRange(t *testing.T) {
                        for _, size := range sizes {
                                t.Run(strconv.Itoa(size), func(t *testing.T) {
                                        testCopyFileRange(t, int64(size), int64(size)+7)
+                                       testSendfileOverCopyFileRange(t, int64(size), int64(size)+7)
                                })
                        }
                })
        })
        t.Run("DoesntTryInAppendMode", func(t *testing.T) {
-               dst, src, data, hook := newCopyFileRangeTest(t, 42)
+               for _, newTest := range []func(*testing.T, int64) (*File, *File, []byte, *copyFileHook, string){
+                       newCopyFileRangeTest, newSendfileOverCopyFileRangeTest} {
+                       dst, src, data, hook, testName := newTest(t, 42)
 
-               dst2, err := OpenFile(dst.Name(), O_RDWR|O_APPEND, 0755)
-               if err != nil {
-                       t.Fatal(err)
-               }
-               defer dst2.Close()
+                       dst2, err := OpenFile(dst.Name(), O_RDWR|O_APPEND, 0755)
+                       if err != nil {
+                               t.Fatalf("%s: %v", testName, err)
+                       }
+                       defer dst2.Close()
 
-               if _, err := io.Copy(dst2, src); err != nil {
-                       t.Fatal(err)
-               }
-               if hook.called {
-                       t.Fatal("called poll.CopyFileRange for destination in O_APPEND mode")
+                       if _, err := io.Copy(dst2, src); err != nil {
+                               t.Fatalf("%s: %v", testName, err)
+                       }
+                       if hook.called {
+                               t.Fatalf("%s: hook shouldn't be called with destination in O_APPEND mode", testName)
+                       }
+                       mustSeekStart(t, dst2)
+                       mustContainData(t, dst2, data) // through traditional means
                }
-               mustSeekStart(t, dst2)
-               mustContainData(t, dst2, data) // through traditional means
        })
        t.Run("CopyFileItself", func(t *testing.T) {
-               hook := hookCopyFileRange(t)
+               for _, hookFunc := range []func(*testing.T) (*copyFileHook, string){hookCopyFileRange, hookSendFileOverCopyFileRange} {
+                       hook, testName := hookFunc(t)
 
-               f, err := CreateTemp("", "file-readfrom-itself-test")
-               if err != nil {
-                       t.Fatalf("failed to create tmp file: %v", err)
-               }
-               t.Cleanup(func() {
-                       f.Close()
-                       Remove(f.Name())
-               })
+                       f, err := CreateTemp("", "file-readfrom-itself-test")
+                       if err != nil {
+                               t.Fatalf("%s: failed to create tmp file: %v", testName, err)
+                       }
+                       t.Cleanup(func() {
+                               f.Close()
+                               Remove(f.Name())
+                       })
 
-               data := []byte("hello world!")
-               if _, err := f.Write(data); err != nil {
-                       t.Fatalf("failed to create and feed the file: %v", err)
-               }
+                       data := []byte("hello world!")
+                       if _, err := f.Write(data); err != nil {
+                               t.Fatalf("%s: failed to create and feed the file: %v", testName, err)
+                       }
 
-               if err := f.Sync(); err != nil {
-                       t.Fatalf("failed to save the file: %v", err)
-               }
+                       if err := f.Sync(); err != nil {
+                               t.Fatalf("%s: failed to save the file: %v", testName, err)
+                       }
 
-               // Rewind it.
-               if _, err := f.Seek(0, io.SeekStart); err != nil {
-                       t.Fatalf("failed to rewind the file: %v", err)
-               }
+                       // Rewind it.
+                       if _, err := f.Seek(0, io.SeekStart); err != nil {
+                               t.Fatalf("%s: failed to rewind the file: %v", testName, err)
+                       }
 
-               // Read data from the file itself.
-               if _, err := io.Copy(f, f); err != nil {
-                       t.Fatalf("failed to read from the file: %v", err)
-               }
+                       // Read data from the file itself.
+                       if _, err := io.Copy(f, f); err != nil {
+                               t.Fatalf("%s: failed to read from the file: %v", testName, err)
+                       }
 
-               if !hook.called || hook.written != 0 || hook.handled || hook.err != nil {
-                       t.Fatalf("poll.CopyFileRange should be called and return the EINVAL error, but got hook.called=%t, hook.err=%v", hook.called, hook.err)
-               }
+                       if hook.written != 0 || hook.handled || hook.err != nil {
+                               t.Fatalf("%s: File.readFrom is expected not to use any zero-copy techniques when copying itself."+
+                                       "got hook.written=%d, hook.handled=%t, hook.err=%v; expected hook.written=0, hook.handled=false, hook.err=nil",
+                                       testName, hook.written, hook.handled, hook.err)
+                       }
 
-               // Rewind it.
-               if _, err := f.Seek(0, io.SeekStart); err != nil {
-                       t.Fatalf("failed to rewind the file: %v", err)
-               }
+                       switch testName {
+                       case "hookCopyFileRange":
+                               // For copy_file_range(2), it fails and returns EINVAL when the source and target
+                               // refer to the same file and their ranges overlap. The hook should be called to
+                               // get the returned error and fall back to generic copy.
+                               if !hook.called {
+                                       t.Fatalf("%s: should have called the hook", testName)
+                               }
+                       case "hookSendFileOverCopyFileRange":
+                               // For sendfile(2), it allows the source and target to refer to the same file and overlap.
+                               // The hook should not be called and just fall back to generic copy directly.
+                               if hook.called {
+                                       t.Fatalf("%s: shouldn't have called the hook", testName)
+                               }
+                       default:
+                               t.Fatalf("%s: unexpected test", testName)
+                       }
 
-               data2, err := io.ReadAll(f)
-               if err != nil {
-                       t.Fatalf("failed to read from the file: %v", err)
-               }
+                       // Rewind it.
+                       if _, err := f.Seek(0, io.SeekStart); err != nil {
+                               t.Fatalf("%s: failed to rewind the file: %v", testName, err)
+                       }
+
+                       data2, err := io.ReadAll(f)
+                       if err != nil {
+                               t.Fatalf("%s: failed to read from the file: %v", testName, err)
+                       }
 
-               // It should wind up a double of the original data.
-               if strings.Repeat(string(data), 2) != string(data2) {
-                       t.Fatalf("data mismatch: %s != %s", string(data), string(data2))
+                       // It should wind up a double of the original data.
+                       if s := strings.Repeat(string(data), 2); s != string(data2) {
+                               t.Fatalf("%s: file contained %s, expected %s", testName, data2, s)
+                       }
                }
        })
        t.Run("NotRegular", func(t *testing.T) {
                t.Run("BothPipes", func(t *testing.T) {
-                       hook := hookCopyFileRange(t)
+                       for _, hookFunc := range []func(*testing.T) (*copyFileHook, string){hookCopyFileRange, hookSendFileOverCopyFileRange} {
+                               hook, testName := hookFunc(t)
 
-                       pr1, pw1, err := Pipe()
-                       if err != nil {
-                               t.Fatal(err)
-                       }
-                       defer pr1.Close()
-                       defer pw1.Close()
+                               pr1, pw1, err := Pipe()
+                               if err != nil {
+                                       t.Fatalf("%s: %v", testName, err)
+                               }
+                               defer pr1.Close()
+                               defer pw1.Close()
 
-                       pr2, pw2, err := Pipe()
-                       if err != nil {
-                               t.Fatal(err)
-                       }
-                       defer pr2.Close()
-                       defer pw2.Close()
-
-                       // The pipe is empty, and PIPE_BUF is large enough
-                       // for this, by (POSIX) definition, so there is no
-                       // need for an additional goroutine.
-                       data := []byte("hello")
-                       if _, err := pw1.Write(data); err != nil {
-                               t.Fatal(err)
-                       }
-                       pw1.Close()
+                               pr2, pw2, err := Pipe()
+                               if err != nil {
+                                       t.Fatalf("%s: %v", testName, err)
+                               }
+                               defer pr2.Close()
+                               defer pw2.Close()
+
+                               // The pipe is empty, and PIPE_BUF is large enough
+                               // for this, by (POSIX) definition, so there is no
+                               // need for an additional goroutine.
+                               data := []byte("hello")
+                               if _, err := pw1.Write(data); err != nil {
+                                       t.Fatalf("%s: %v", testName, err)
+                               }
+                               pw1.Close()
 
-                       n, err := io.Copy(pw2, pr1)
-                       if err != nil {
-                               t.Fatal(err)
-                       }
-                       if n != int64(len(data)) {
-                               t.Fatalf("transferred %d, want %d", n, len(data))
-                       }
-                       if !hook.called {
-                               t.Fatalf("should have called poll.CopyFileRange")
+                               n, err := io.Copy(pw2, pr1)
+                               if err != nil {
+                                       t.Fatalf("%s: %v", testName, err)
+                               }
+                               if n != int64(len(data)) {
+                                       t.Fatalf("%s: transferred %d, want %d", testName, n, len(data))
+                               }
+                               if !hook.called {
+                                       t.Fatalf("%s: should have called the hook", testName)
+                               }
+                               pw2.Close()
+                               mustContainData(t, pr2, data)
                        }
-                       pw2.Close()
-                       mustContainData(t, pr2, data)
                })
                t.Run("DstPipe", func(t *testing.T) {
-                       dst, src, data, hook := newCopyFileRangeTest(t, 255)
-                       dst.Close()
-
-                       pr, pw, err := Pipe()
-                       if err != nil {
-                               t.Fatal(err)
-                       }
-                       defer pr.Close()
-                       defer pw.Close()
+                       for _, newTest := range []func(*testing.T, int64) (*File, *File, []byte, *copyFileHook, string){
+                               newCopyFileRangeTest, newSendfileOverCopyFileRangeTest} {
+                               dst, src, data, hook, testName := newTest(t, 255)
+                               dst.Close()
+
+                               pr, pw, err := Pipe()
+                               if err != nil {
+                                       t.Fatalf("%s: %v", testName, err)
+                               }
+                               defer pr.Close()
+                               defer pw.Close()
 
-                       n, err := io.Copy(pw, src)
-                       if err != nil {
-                               t.Fatal(err)
-                       }
-                       if n != int64(len(data)) {
-                               t.Fatalf("transferred %d, want %d", n, len(data))
-                       }
-                       if !hook.called {
-                               t.Fatalf("should have called poll.CopyFileRange")
+                               n, err := io.Copy(pw, src)
+                               if err != nil {
+                                       t.Fatalf("%s: %v", testName, err)
+                               }
+                               if n != int64(len(data)) {
+                                       t.Fatalf("%s: transferred %d, want %d", testName, n, len(data))
+                               }
+                               if !hook.called {
+                                       t.Fatalf("%s: should have called the hook", testName)
+                               }
+                               pw.Close()
+                               mustContainData(t, pr, data)
                        }
-                       pw.Close()
-                       mustContainData(t, pr, data)
                })
                t.Run("SrcPipe", func(t *testing.T) {
-                       dst, src, data, hook := newCopyFileRangeTest(t, 255)
-                       src.Close()
-
-                       pr, pw, err := Pipe()
-                       if err != nil {
-                               t.Fatal(err)
-                       }
-                       defer pr.Close()
-                       defer pw.Close()
-
-                       // The pipe is empty, and PIPE_BUF is large enough
-                       // for this, by (POSIX) definition, so there is no
-                       // need for an additional goroutine.
-                       if _, err := pw.Write(data); err != nil {
-                               t.Fatal(err)
-                       }
-                       pw.Close()
+                       for _, newTest := range []func(*testing.T, int64) (*File, *File, []byte, *copyFileHook, string){
+                               newCopyFileRangeTest, newSendfileOverCopyFileRangeTest} {
+                               dst, src, data, hook, testName := newTest(t, 255)
+                               src.Close()
+
+                               pr, pw, err := Pipe()
+                               if err != nil {
+                                       t.Fatalf("%s: %v", testName, err)
+                               }
+                               defer pr.Close()
+                               defer pw.Close()
+
+                               // The pipe is empty, and PIPE_BUF is large enough
+                               // for this, by (POSIX) definition, so there is no
+                               // need for an additional goroutine.
+                               if _, err := pw.Write(data); err != nil {
+                                       t.Fatalf("%s: %v", testName, err)
+                               }
+                               pw.Close()
 
-                       n, err := io.Copy(dst, pr)
-                       if err != nil {
-                               t.Fatal(err)
-                       }
-                       if n != int64(len(data)) {
-                               t.Fatalf("transferred %d, want %d", n, len(data))
-                       }
-                       if !hook.called {
-                               t.Fatalf("should have called poll.CopyFileRange")
+                               n, err := io.Copy(dst, pr)
+                               if err != nil {
+                                       t.Fatalf("%s: %v", testName, err)
+                               }
+                               if n != int64(len(data)) {
+                                       t.Fatalf("%s: transferred %d, want %d", testName, n, len(data))
+                               }
+                               if !hook.called {
+                                       t.Fatalf("%s: should have called the hook", testName)
+                               }
+                               mustSeekStart(t, dst)
+                               mustContainData(t, dst, data)
                        }
-                       mustSeekStart(t, dst)
-                       mustContainData(t, dst, data)
                })
        })
        t.Run("Nil", func(t *testing.T) {
@@ -480,8 +517,16 @@ func testSpliceToTTY(t *testing.T, proto string, size int64) {
 }
 
 func testCopyFileRange(t *testing.T, size int64, limit int64) {
-       dst, src, data, hook := newCopyFileRangeTest(t, size)
+       dst, src, data, hook, name := newCopyFileRangeTest(t, size)
+       testCopyFile(t, dst, src, data, hook, limit, name)
+}
+
+func testSendfileOverCopyFileRange(t *testing.T, size int64, limit int64) {
+       dst, src, data, hook, name := newSendfileOverCopyFileRangeTest(t, size)
+       testCopyFile(t, dst, src, data, hook, limit, name)
+}
 
+func testCopyFile(t *testing.T, dst, src *File, data []byte, hook *copyFileHook, limit int64, testName string) {
        // If we have a limit, wrap the reader.
        var (
                realsrc io.Reader
@@ -498,22 +543,22 @@ func testCopyFileRange(t *testing.T, size int64, limit int64) {
        }
 
        // Now call ReadFrom (through io.Copy), which will hopefully call
-       // poll.CopyFileRange.
+       // poll.CopyFileRange or poll.SendFile.
        n, err := io.Copy(dst, realsrc)
        if err != nil {
-               t.Fatal(err)
+               t.Fatalf("%s: %v", testName, err)
        }
 
-       // If we didn't have a limit, we should have called poll.CopyFileRange
-       // with the right file descriptor arguments.
-       if limit > 0 && !hook.called {
-               t.Fatal("never called poll.CopyFileRange")
+       // If we didn't have a limit or had a positive limit, we should have called
+       // poll.CopyFileRange or poll.SendFile with the right file descriptor arguments.
+       if limit != 0 && !hook.called {
+               t.Fatalf("%s: never called the hook", testName)
        }
        if hook.called && hook.dstfd != int(dst.Fd()) {
-               t.Fatalf("wrong destination file descriptor: got %d, want %d", hook.dstfd, dst.Fd())
+               t.Fatalf("%s: wrong destination file descriptor: got %d, want %d", testName, hook.dstfd, dst.Fd())
        }
        if hook.called && hook.srcfd != int(src.Fd()) {
-               t.Fatalf("wrong source file descriptor: got %d, want %d", hook.srcfd, src.Fd())
+               t.Fatalf("%s: wrong source file descriptor: got %d, want %d", testName, hook.srcfd, src.Fd())
        }
 
        // Check that the offsets after the transfer make sense, that the size
@@ -521,20 +566,20 @@ func testCopyFileRange(t *testing.T, size int64, limit int64) {
        // file contains exactly the bytes we expect it to contain.
        dstoff, err := dst.Seek(0, io.SeekCurrent)
        if err != nil {
-               t.Fatal(err)
+               t.Fatalf("%s: %v", testName, err)
        }
        srcoff, err := src.Seek(0, io.SeekCurrent)
        if err != nil {
-               t.Fatal(err)
+               t.Fatalf("%s: %v", testName, err)
        }
        if dstoff != srcoff {
-               t.Errorf("offsets differ: dstoff = %d, srcoff = %d", dstoff, srcoff)
+               t.Errorf("%s: offsets differ: dstoff = %d, srcoff = %d", testName, dstoff, srcoff)
        }
        if dstoff != int64(len(data)) {
-               t.Errorf("dstoff = %d, want %d", dstoff, len(data))
+               t.Errorf("%s: dstoff = %d, want %d", testName, dstoff, len(data))
        }
        if n != int64(len(data)) {
-               t.Errorf("short ReadFrom: wrote %d bytes, want %d", n, len(data))
+               t.Errorf("%s: short ReadFrom: wrote %d bytes, want %d", testName, n, len(data))
        }
        mustSeekStart(t, dst)
        mustContainData(t, dst, data)
@@ -542,47 +587,53 @@ func testCopyFileRange(t *testing.T, size int64, limit int64) {
        // If we had a limit, check that it was updated.
        if lr != nil {
                if want := limit - n; lr.N != want {
-                       t.Fatalf("didn't update limit correctly: got %d, want %d", lr.N, want)
+                       t.Fatalf("%s: didn't update limit correctly: got %d, want %d", testName, lr.N, want)
                }
        }
 }
 
 // newCopyFileRangeTest initializes a new test for copy_file_range.
 //
-// It creates source and destination files, and populates the source file
-// with random data of the specified size. It also hooks package os' call
-// to poll.CopyFileRange and returns the hook so it can be inspected.
-func newCopyFileRangeTest(t *testing.T, size int64) (dst, src *File, data []byte, hook *copyFileRangeHook) {
+// It hooks package os' call to poll.CopyFileRange and returns the hook,
+// so it can be inspected.
+func newCopyFileRangeTest(t *testing.T, size int64) (dst, src *File, data []byte, hook *copyFileHook, name string) {
        t.Helper()
+       name = "newCopyFileRangeTest"
 
-       hook = hookCopyFileRange(t)
-       tmp := t.TempDir()
+       dst, src, data = newCopyFileTest(t, size)
+       hook, _ = hookCopyFileRange(t)
 
-       src, err := Create(filepath.Join(tmp, "src"))
-       if err != nil {
-               t.Fatal(err)
-       }
-       t.Cleanup(func() { src.Close() })
+       return
+}
+
+// newSendFileTest initializes a new test for sendfile over copy_file_range.
+// It hooks package os' call to poll.SendFile and returns the hook,
+// so it can be inspected.
+func newSendfileOverCopyFileRangeTest(t *testing.T, size int64) (dst, src *File, data []byte, hook *copyFileHook, name string) {
+       t.Helper()
+
+       name = "newSendfileOverCopyFileRangeTest"
 
-       dst, err = Create(filepath.Join(tmp, "dst"))
+       dst, src, data = newCopyFileTest(t, size)
+       hook, _ = hookSendFileOverCopyFileRange(t)
+
+       return
+}
+
+// newCopyFileTest initializes a new test for copying data between files.
+// It creates source and destination files, and populates the source file
+// with random data of the specified size, then rewind it, so it can be
+// consumed by copy_file_range(2) or sendfile(2).
+func newCopyFileTest(t *testing.T, size int64) (dst, src *File, data []byte) {
+       src, data = createTempFile(t, "test-copy_file_range-sendfile-src", size)
+
+       dst, err := CreateTemp(t.TempDir(), "test-copy_file_range-sendfile-dst")
        if err != nil {
                t.Fatal(err)
        }
        t.Cleanup(func() { dst.Close() })
 
-       // Populate the source file with data, then rewind it, so it can be
-       // consumed by copy_file_range(2).
-       prng := rand.New(rand.NewSource(time.Now().Unix()))
-       data = make([]byte, size)
-       prng.Read(data)
-       if _, err := src.Write(data); err != nil {
-               t.Fatal(err)
-       }
-       if _, err := src.Seek(0, io.SeekStart); err != nil {
-               t.Fatal(err)
-       }
-
-       return dst, src, data, hook
+       return
 }
 
 // newSpliceFileTest initializes a new test for splice.
@@ -642,40 +693,58 @@ func mustSeekStart(t *testing.T, f *File) {
        }
 }
 
-func hookCopyFileRange(t *testing.T) *copyFileRangeHook {
-       h := new(copyFileRangeHook)
-       h.install()
-       t.Cleanup(h.uninstall)
-       return h
+func hookCopyFileRange(t *testing.T) (hook *copyFileHook, name string) {
+       name = "hookCopyFileRange"
+
+       hook = new(copyFileHook)
+       orig := *PollCopyFileRangeP
+       t.Cleanup(func() {
+               *PollCopyFileRangeP = orig
+       })
+       *PollCopyFileRangeP = func(dst, src *poll.FD, remain int64) (int64, bool, error) {
+               hook.called = true
+               hook.dstfd = dst.Sysfd
+               hook.srcfd = src.Sysfd
+               hook.written, hook.handled, hook.err = orig(dst, src, remain)
+               return hook.written, hook.handled, hook.err
+       }
+       return
 }
 
-type copyFileRangeHook struct {
+func hookSendFileOverCopyFileRange(t *testing.T) (hook *copyFileHook, name string) {
+       name = "hookSendFileOverCopyFileRange"
+
+       // Disable poll.CopyFileRange to force the fallback to poll.SendFile.
+       originalCopyFileRange := *PollCopyFileRangeP
+       *PollCopyFileRangeP = func(dst, src *poll.FD, remain int64) (written int64, handled bool, err error) {
+               return 0, false, nil
+       }
+
+       hook = new(copyFileHook)
+       orig := poll.TestHookDidSendFile
+       t.Cleanup(func() {
+               *PollCopyFileRangeP = originalCopyFileRange
+               poll.TestHookDidSendFile = orig
+       })
+       poll.TestHookDidSendFile = func(dstFD *poll.FD, src int, written int64, err error, handled bool) {
+               hook.called = true
+               hook.dstfd = dstFD.Sysfd
+               hook.srcfd = src
+               hook.written = written
+               hook.err = err
+               hook.handled = handled
+       }
+       return
+}
+
+type copyFileHook struct {
        called bool
        dstfd  int
        srcfd  int
-       remain int64
 
        written int64
        handled bool
        err     error
-
-       original func(dst, src *poll.FD, remain int64) (int64, bool, error)
-}
-
-func (h *copyFileRangeHook) install() {
-       h.original = *PollCopyFileRangeP
-       *PollCopyFileRangeP = func(dst, src *poll.FD, remain int64) (int64, bool, error) {
-               h.called = true
-               h.dstfd = dst.Sysfd
-               h.srcfd = src.Sysfd
-               h.remain = remain
-               h.written, h.handled, h.err = h.original(dst, src, remain)
-               return h.written, h.handled, h.err
-       }
-}
-
-func (h *copyFileRangeHook) uninstall() {
-       *PollCopyFileRangeP = h.original
 }
 
 func hookSpliceFile(t *testing.T) *spliceFileHook {
index e3900631baa1ee35bf3484867b87e7785c00361f..a6f8980d10614cc045d6876b92befe710c77e31c 100644 (file)
@@ -102,7 +102,7 @@ func newSendFileTest(t *testing.T, proto string, size int64) (net.Conn, *File, n
        hook := hookSendFile(t)
 
        client, server := createSocketPair(t, proto)
-       tempFile, data := createTempFile(t, size)
+       tempFile, data := createTempFile(t, "writeto-sendfile-to-socket", size)
 
        return client, tempFile, server, data, hook
 }
@@ -134,8 +134,8 @@ type sendFileHook struct {
        err     error
 }
 
-func createTempFile(t *testing.T, size int64) (*File, []byte) {
-       f, err := CreateTemp(t.TempDir(), "writeto-sendfile-to-socket")
+func createTempFile(t *testing.T, name string, size int64) (*File, []byte) {
+       f, err := CreateTemp(t.TempDir(), name)
        if err != nil {
                t.Fatalf("failed to create temporary file: %v", err)
        }
index 0afc19e125a24ef850d7f2e9a943ce7b14032afa..4492c56bf59f4dcad1762e2167e8e54359decc67 100644 (file)
@@ -49,16 +49,17 @@ func (f *File) writeTo(w io.Writer) (written int64, handled bool, err error) {
 }
 
 func (f *File) readFrom(r io.Reader) (written int64, handled bool, err error) {
-       // Neither copy_file_range(2) nor splice(2) supports destinations opened with
+       // Neither copy_file_range(2)/sendfile(2) nor splice(2) supports destinations opened with
        // O_APPEND, so don't bother to try zero-copy with these system calls.
        //
        // Visit https://man7.org/linux/man-pages/man2/copy_file_range.2.html#ERRORS and
+       // https://man7.org/linux/man-pages/man2/sendfile.2.html#ERRORS and
        // https://man7.org/linux/man-pages/man2/splice.2.html#ERRORS for details.
        if f.appendMode {
                return 0, false, nil
        }
 
-       written, handled, err = f.copyFileRange(r)
+       written, handled, err = f.copyFile(r)
        if handled {
                return
        }
@@ -95,7 +96,7 @@ func (f *File) spliceToFile(r io.Reader) (written int64, handled bool, err error
        return written, handled, wrapSyscallError("splice", err)
 }
 
-func (f *File) copyFileRange(r io.Reader) (written int64, handled bool, err error) {
+func (f *File) copyFile(r io.Reader) (written int64, handled bool, err error) {
        var (
                remain int64
                lr     *io.LimitedReader
@@ -124,7 +125,44 @@ func (f *File) copyFileRange(r io.Reader) (written int64, handled bool, err erro
        if lr != nil {
                lr.N -= written
        }
-       return written, handled, wrapSyscallError("copy_file_range", err)
+
+       if handled {
+               return written, handled, wrapSyscallError("copy_file_range", err)
+       }
+
+       // If fd_in and fd_out refer to the same file and the source and target ranges overlap,
+       // copy_file_range(2) just returns EINVAL error. poll.CopyFileRange will ignore that
+       // error and act like it didn't call copy_file_range(2). Then the caller will fall back
+       // to generic copy, which results in doubling the content in the file.
+       // By contrast, sendfile(2) allows this kind of overlapping and works like a memmove,
+       // in this case the file content will remain the same after copying, which is not what we want.
+       // Thus, we just bail out here and leave it to generic copy when it's a file copying itself.
+       if f.pfd.Sysfd == src.pfd.Sysfd {
+               return 0, false, nil
+       }
+
+       sc, err := src.SyscallConn()
+       if err != nil {
+               return
+       }
+
+       // We can employ sendfile(2) when copy_file_range(2) fails to handle the copy.
+       // sendfile(2) enabled file-to-file copying since Linux 2.6.33 and Go requires
+       // Linux 3.17 or later, so we're good to go.
+       // Check out https://man7.org/linux/man-pages/man2/sendfile.2.html#DESCRIPTION for more details.
+       rerr := sc.Read(func(fd uintptr) bool {
+               written, err, handled = poll.SendFile(&f.pfd, int(fd), remain)
+               return true
+       })
+       if lr != nil {
+               lr.N -= written
+       }
+
+       if err == nil {
+               err = rerr
+       }
+
+       return written, handled, wrapSyscallError("sendfile", err)
 }
 
 // getPollFDAndNetwork tries to get the poll.FD and network type from the given interface