]> Cypherpunks repositories - gostls13.git/commitdiff
internal/poll: keep copying after successful Sendfile return on BSD
authorDamien Neil <dneil@google.com>
Wed, 23 Oct 2024 23:01:08 +0000 (16:01 -0700)
committerDamien Neil <dneil@google.com>
Thu, 24 Oct 2024 01:18:06 +0000 (01:18 +0000)
The BSD implementation of poll.SendFile incorrectly halted
copying after succesfully writing one full chunk of data.
Adjust the copy loop to match the Linux and Solaris
implementations.

In testing, empirically macOS appears to sometimes return
EAGAIN from sendfile after successfully copying a full
chunk. Add a check to all implementations to return nil
after successfully copying all data if the last sendfile
call returns EAGAIN.

For #70000

Change-Id: I57ba649491fc078c7330310b23e1cfd85135c8ff
Reviewed-on: https://go-review.googlesource.com/c/go/+/622235
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Ian Lance Taylor <iant@google.com>
src/internal/poll/sendfile_bsd.go
src/internal/poll/sendfile_linux.go
src/internal/poll/sendfile_solaris.go
src/os/copy_test.go [new file with mode: 0644]
src/os/readfrom_linux_test.go

index f82af4407d0750a0d071c6c1e68a52dc720fbd16..bc665978c39e6e90f96e1c0cc993c417d94f643c 100644 (file)
@@ -41,22 +41,25 @@ func SendFile(dstFD *FD, src int, pos, remain int64) (written int64, err error,
                        pos += int64(n)
                        written += int64(n)
                        remain -= int64(n)
+                       continue
+               } else if err != syscall.EAGAIN && err != syscall.EINTR {
+                       // This includes syscall.ENOSYS (no kernel
+                       // support) and syscall.EINVAL (fd types which
+                       // don't implement sendfile), and other errors.
+                       // We should end the loop when there is no error
+                       // returned from sendfile(2) or it is not a retryable error.
+                       break
                }
                if err == syscall.EINTR {
                        continue
                }
-               // This includes syscall.ENOSYS (no kernel
-               // support) and syscall.EINVAL (fd types which
-               // don't implement sendfile), and other errors.
-               // We should end the loop when there is no error
-               // returned from sendfile(2) or it is not a retryable error.
-               if err != syscall.EAGAIN {
-                       break
-               }
                if err = dstFD.pd.waitWrite(dstFD.isFile); err != nil {
                        break
                }
        }
+       if err == syscall.EAGAIN {
+               err = nil
+       }
        handled = written != 0 || (err != syscall.ENOSYS && err != syscall.EINVAL)
        return
 }
index 1b618a72a2b4cd7b50de81a111949178ade99d84..7e800a3b7ea222ea71debbddeadbdfb095b39358 100644 (file)
@@ -53,6 +53,9 @@ func SendFile(dstFD *FD, src int, remain int64) (written int64, err error, handl
                        break
                }
        }
+       if err == syscall.EAGAIN {
+               err = nil
+       }
        handled = written != 0 || (err != syscall.ENOSYS && err != syscall.EINVAL)
        return
 }
index e6f2d908a1d365c8aaaaf4db7eb1d87541023982..113214ff391fc0e923e5b2f87badf966ef7a3aaa 100644 (file)
@@ -71,6 +71,9 @@ func SendFile(dstFD *FD, src int, pos, remain int64) (written int64, err error,
                        break
                }
        }
+       if err == syscall.EAGAIN {
+               err = nil
+       }
        handled = written != 0 || (err != syscall.ENOSYS && err != syscall.EINVAL)
        return
 }
diff --git a/src/os/copy_test.go b/src/os/copy_test.go
new file mode 100644 (file)
index 0000000..82346ca
--- /dev/null
@@ -0,0 +1,154 @@
+// Copyright 2024 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package os_test
+
+import (
+       "bytes"
+       "errors"
+       "io"
+       "math/rand/v2"
+       "net"
+       "os"
+       "runtime"
+       "sync"
+       "testing"
+
+       "golang.org/x/net/nettest"
+)
+
+// Exercise sendfile/splice fast paths with a moderately large file.
+//
+// https://go.dev/issue/70000
+
+func TestLargeCopyViaNetwork(t *testing.T) {
+       const size = 10 * 1024 * 1024
+       dir := t.TempDir()
+
+       src, err := os.Create(dir + "/src")
+       if err != nil {
+               t.Fatal(err)
+       }
+       defer src.Close()
+       if _, err := io.CopyN(src, newRandReader(), size); err != nil {
+               t.Fatal(err)
+       }
+       if _, err := src.Seek(0, 0); err != nil {
+               t.Fatal(err)
+       }
+
+       dst, err := os.Create(dir + "/dst")
+       if err != nil {
+               t.Fatal(err)
+       }
+       defer dst.Close()
+
+       client, server := createSocketPair(t, "tcp")
+       var wg sync.WaitGroup
+       wg.Add(2)
+       go func() {
+               defer wg.Done()
+               if n, err := io.Copy(dst, server); n != size || err != nil {
+                       t.Errorf("copy to destination = %v, %v; want %v, nil", n, err, size)
+               }
+       }()
+       go func() {
+               defer wg.Done()
+               defer client.Close()
+               if n, err := io.Copy(client, src); n != size || err != nil {
+                       t.Errorf("copy from source = %v, %v; want %v, nil", n, err, size)
+               }
+       }()
+       wg.Wait()
+
+       if _, err := dst.Seek(0, 0); err != nil {
+               t.Fatal(err)
+       }
+       if err := compareReaders(dst, io.LimitReader(newRandReader(), size)); err != nil {
+               t.Fatal(err)
+       }
+}
+
+func compareReaders(a, b io.Reader) error {
+       bufa := make([]byte, 4096)
+       bufb := make([]byte, 4096)
+       for {
+               na, erra := io.ReadFull(a, bufa)
+               if erra != nil && erra != io.EOF {
+                       return erra
+               }
+               nb, errb := io.ReadFull(b, bufb)
+               if errb != nil && errb != io.EOF {
+                       return errb
+               }
+               if !bytes.Equal(bufa[:na], bufb[:nb]) {
+                       return errors.New("contents mismatch")
+               }
+               if erra == io.EOF && errb == io.EOF {
+                       break
+               }
+       }
+       return nil
+}
+
+type randReader struct {
+       rand *rand.Rand
+}
+
+func newRandReader() *randReader {
+       return &randReader{rand.New(rand.NewPCG(0, 0))}
+}
+
+func (r *randReader) Read(p []byte) (int, error) {
+       var v uint64
+       var n int
+       for i := range p {
+               if n == 0 {
+                       v = r.rand.Uint64()
+                       n = 8
+               }
+               p[i] = byte(v & 0xff)
+               v >>= 8
+               n--
+       }
+       return len(p), nil
+}
+
+func createSocketPair(t *testing.T, proto string) (client, server net.Conn) {
+       t.Helper()
+       if !nettest.TestableNetwork(proto) {
+               t.Skipf("%s does not support %q", runtime.GOOS, proto)
+       }
+
+       ln, err := nettest.NewLocalListener(proto)
+       if err != nil {
+               t.Fatalf("NewLocalListener error: %v", err)
+       }
+       t.Cleanup(func() {
+               if ln != nil {
+                       ln.Close()
+               }
+               if client != nil {
+                       client.Close()
+               }
+               if server != nil {
+                       server.Close()
+               }
+       })
+       ch := make(chan struct{})
+       go func() {
+               var err error
+               server, err = ln.Accept()
+               if err != nil {
+                       t.Errorf("Accept new connection error: %v", err)
+               }
+               ch <- struct{}{}
+       }()
+       client, err = net.Dial(proto, ln.Addr().String())
+       <-ch
+       if err != nil {
+               t.Fatalf("Dial new connection error: %v", err)
+       }
+       return client, server
+}
index 2dcc7f0cf8d8c53d859a5b7dfb69850643f861c0..cc0322882b1305013ce8797f8e9f17f39fc1cc87 100644 (file)
@@ -14,14 +14,11 @@ import (
        "net"
        . "os"
        "path/filepath"
-       "runtime"
        "strconv"
        "sync"
        "syscall"
        "testing"
        "time"
-
-       "golang.org/x/net/nettest"
 )
 
 func TestSpliceFile(t *testing.T) {
@@ -479,41 +476,3 @@ func testGetPollFDAndNetwork(t *testing.T, proto string) {
                t.Fatalf("server Control error: %v", err)
        }
 }
-
-func createSocketPair(t *testing.T, proto string) (client, server net.Conn) {
-       t.Helper()
-       if !nettest.TestableNetwork(proto) {
-               t.Skipf("%s does not support %q", runtime.GOOS, proto)
-       }
-
-       ln, err := nettest.NewLocalListener(proto)
-       if err != nil {
-               t.Fatalf("NewLocalListener error: %v", err)
-       }
-       t.Cleanup(func() {
-               if ln != nil {
-                       ln.Close()
-               }
-               if client != nil {
-                       client.Close()
-               }
-               if server != nil {
-                       server.Close()
-               }
-       })
-       ch := make(chan struct{})
-       go func() {
-               var err error
-               server, err = ln.Accept()
-               if err != nil {
-                       t.Errorf("Accept new connection error: %v", err)
-               }
-               ch <- struct{}{}
-       }()
-       client, err = net.Dial(proto, ln.Addr().String())
-       <-ch
-       if err != nil {
-               t.Fatalf("Dial new connection error: %v", err)
-       }
-       return client, server
-}