]> Cypherpunks repositories - gostls13.git/commitdiff
[release-branch.go1.23] internal/poll: keep copying after successful Sendfile return...
authorDamien Neil <dneil@google.com>
Wed, 23 Oct 2024 23:01:08 +0000 (16:01 -0700)
committerMichael Pratt <mpratt@google.com>
Wed, 30 Oct 2024 17:03:55 +0000 (17:03 +0000)
The BSD implementation of poll.SendFile incorrectly halted
copying after succesfully writing one full chunk of data.
Adjust the copy loop to match the Linux and Solaris
implementations.

In testing, empirically macOS appears to sometimes return
EAGAIN from sendfile after successfully copying a full
chunk. Add a check to all implementations to return nil
after successfully copying all data if the last sendfile
call returns EAGAIN.

For #70000
For #70020

Change-Id: I57ba649491fc078c7330310b23e1cfd85135c8ff
Reviewed-on: https://go-review.googlesource.com/c/go/+/622235
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Ian Lance Taylor <iant@google.com>
(cherry picked from commit bd388c0216bcb33d7325b0ad9722a3be8155a289)
Reviewed-on: https://go-review.googlesource.com/c/go/+/622696

src/internal/poll/sendfile_bsd.go
src/internal/poll/sendfile_linux.go
src/internal/poll/sendfile_solaris.go
src/os/copy_test.go [new file with mode: 0644]
src/os/readfrom_linux_test.go

index 669df94cc12e0d50d25b9b99a896252377018d38..0b0966815deedd1876a203ae0ad121c22fb3d809 100644 (file)
@@ -38,22 +38,25 @@ func SendFile(dstFD *FD, src int, pos, remain int64) (written int64, err error,
                        pos += int64(n)
                        written += int64(n)
                        remain -= int64(n)
+                       continue
+               } else if err != syscall.EAGAIN && err != syscall.EINTR {
+                       // This includes syscall.ENOSYS (no kernel
+                       // support) and syscall.EINVAL (fd types which
+                       // don't implement sendfile), and other errors.
+                       // We should end the loop when there is no error
+                       // returned from sendfile(2) or it is not a retryable error.
+                       break
                }
                if err == syscall.EINTR {
                        continue
                }
-               // This includes syscall.ENOSYS (no kernel
-               // support) and syscall.EINVAL (fd types which
-               // don't implement sendfile), and other errors.
-               // We should end the loop when there is no error
-               // returned from sendfile(2) or it is not a retryable error.
-               if err != syscall.EAGAIN {
-                       break
-               }
                if err = dstFD.pd.waitWrite(dstFD.isFile); err != nil {
                        break
                }
        }
+       if err == syscall.EAGAIN {
+               err = nil
+       }
        handled = written != 0 || (err != syscall.ENOSYS && err != syscall.EINVAL)
        return
 }
index d1c4d5c0d3d34de9c3cd8d3cb8ae7cccce7bee8d..1c4130d45da89c0d23b9075f50bf7070aedc6c9c 100644 (file)
@@ -50,6 +50,9 @@ func SendFile(dstFD *FD, src int, remain int64) (written int64, err error, handl
                        break
                }
        }
+       if err == syscall.EAGAIN {
+               err = nil
+       }
        handled = written != 0 || (err != syscall.ENOSYS && err != syscall.EINVAL)
        return
 }
index ec675833a225dc2ae5af469c3aa986a307a9a3ab..b7c3f81a1efdcda5ca7fa9a1f5a7e2f5e3f6e1bc 100644 (file)
@@ -61,6 +61,9 @@ func SendFile(dstFD *FD, src int, pos, remain int64) (written int64, err error,
                        break
                }
        }
+       if err == syscall.EAGAIN {
+               err = nil
+       }
        handled = written != 0 || (err != syscall.ENOSYS && err != syscall.EINVAL)
        return
 }
diff --git a/src/os/copy_test.go b/src/os/copy_test.go
new file mode 100644 (file)
index 0000000..82346ca
--- /dev/null
@@ -0,0 +1,154 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package os_test
+
+import (
+       "bytes"
+       "errors"
+       "io"
+       "math/rand/v2"
+       "net"
+       "os"
+       "runtime"
+       "sync"
+       "testing"
+
+       "golang.org/x/net/nettest"
+)
+
+// Exercise sendfile/splice fast paths with a moderately large file.
+//
+// https://go.dev/issue/70000
+
+func TestLargeCopyViaNetwork(t *testing.T) {
+       const size = 10 * 1024 * 1024
+       dir := t.TempDir()
+
+       src, err := os.Create(dir + "/src")
+       if err != nil {
+               t.Fatal(err)
+       }
+       defer src.Close()
+       if _, err := io.CopyN(src, newRandReader(), size); err != nil {
+               t.Fatal(err)
+       }
+       if _, err := src.Seek(0, 0); err != nil {
+               t.Fatal(err)
+       }
+
+       dst, err := os.Create(dir + "/dst")
+       if err != nil {
+               t.Fatal(err)
+       }
+       defer dst.Close()
+
+       client, server := createSocketPair(t, "tcp")
+       var wg sync.WaitGroup
+       wg.Add(2)
+       go func() {
+               defer wg.Done()
+               if n, err := io.Copy(dst, server); n != size || err != nil {
+                       t.Errorf("copy to destination = %v, %v; want %v, nil", n, err, size)
+               }
+       }()
+       go func() {
+               defer wg.Done()
+               defer client.Close()
+               if n, err := io.Copy(client, src); n != size || err != nil {
+                       t.Errorf("copy from source = %v, %v; want %v, nil", n, err, size)
+               }
+       }()
+       wg.Wait()
+
+       if _, err := dst.Seek(0, 0); err != nil {
+               t.Fatal(err)
+       }
+       if err := compareReaders(dst, io.LimitReader(newRandReader(), size)); err != nil {
+               t.Fatal(err)
+       }
+}
+
+func compareReaders(a, b io.Reader) error {
+       bufa := make([]byte, 4096)
+       bufb := make([]byte, 4096)
+       for {
+               na, erra := io.ReadFull(a, bufa)
+               if erra != nil && erra != io.EOF {
+                       return erra
+               }
+               nb, errb := io.ReadFull(b, bufb)
+               if errb != nil && errb != io.EOF {
+                       return errb
+               }
+               if !bytes.Equal(bufa[:na], bufb[:nb]) {
+                       return errors.New("contents mismatch")
+               }
+               if erra == io.EOF && errb == io.EOF {
+                       break
+               }
+       }
+       return nil
+}
+
+type randReader struct {
+       rand *rand.Rand
+}
+
+func newRandReader() *randReader {
+       return &randReader{rand.New(rand.NewPCG(0, 0))}
+}
+
+func (r *randReader) Read(p []byte) (int, error) {
+       var v uint64
+       var n int
+       for i := range p {
+               if n == 0 {
+                       v = r.rand.Uint64()
+                       n = 8
+               }
+               p[i] = byte(v & 0xff)
+               v >>= 8
+               n--
+       }
+       return len(p), nil
+}
+
+func createSocketPair(t *testing.T, proto string) (client, server net.Conn) {
+       t.Helper()
+       if !nettest.TestableNetwork(proto) {
+               t.Skipf("%s does not support %q", runtime.GOOS, proto)
+       }
+
+       ln, err := nettest.NewLocalListener(proto)
+       if err != nil {
+               t.Fatalf("NewLocalListener error: %v", err)
+       }
+       t.Cleanup(func() {
+               if ln != nil {
+                       ln.Close()
+               }
+               if client != nil {
+                       client.Close()
+               }
+               if server != nil {
+                       server.Close()
+               }
+       })
+       ch := make(chan struct{})
+       go func() {
+               var err error
+               server, err = ln.Accept()
+               if err != nil {
+                       t.Errorf("Accept new connection error: %v", err)
+               }
+               ch <- struct{}{}
+       }()
+       client, err = net.Dial(proto, ln.Addr().String())
+       <-ch
+       if err != nil {
+               t.Fatalf("Dial new connection error: %v", err)
+       }
+       return client, server
+}
index 8dcb9cb217288263221afa03f877fe5503879e54..45867477dc26b26326ab8fffb1c2c955a30f1ae5 100644 (file)
@@ -14,15 +14,12 @@ import (
        "net"
        . "os"
        "path/filepath"
-       "runtime"
        "strconv"
        "strings"
        "sync"
        "syscall"
        "testing"
        "time"
-
-       "golang.org/x/net/nettest"
 )
 
 func TestCopyFileRange(t *testing.T) {
@@ -784,41 +781,3 @@ func testGetPollFDAndNetwork(t *testing.T, proto string) {
                t.Fatalf("server Control error: %v", err)
        }
 }
-
-func createSocketPair(t *testing.T, proto string) (client, server net.Conn) {
-       t.Helper()
-       if !nettest.TestableNetwork(proto) {
-               t.Skipf("%s does not support %q", runtime.GOOS, proto)
-       }
-
-       ln, err := nettest.NewLocalListener(proto)
-       if err != nil {
-               t.Fatalf("NewLocalListener error: %v", err)
-       }
-       t.Cleanup(func() {
-               if ln != nil {
-                       ln.Close()
-               }
-               if client != nil {
-                       client.Close()
-               }
-               if server != nil {
-                       server.Close()
-               }
-       })
-       ch := make(chan struct{})
-       go func() {
-               var err error
-               server, err = ln.Accept()
-               if err != nil {
-                       t.Errorf("Accept new connection error: %v", err)
-               }
-               ch <- struct{}{}
-       }()
-       client, err = net.Dial(proto, ln.Addr().String())
-       <-ch
-       if err != nil {
-               t.Fatalf("Dial new connection error: %v", err)
-       }
-       return client, server
-}