]> Cypherpunks repositories - gostls13.git/commitdiff
internal/poll: don't skip empty writes on Windows
authorqmuntal <quimmuntal@gmail.com>
Fri, 28 Mar 2025 09:39:49 +0000 (10:39 +0100)
committerQuim Muntal <quimmuntal@gmail.com>
Fri, 28 Mar 2025 16:44:35 +0000 (09:44 -0700)
Empty writes might be important for some protocols. Let Windows decide
what do with them rather than skipping them on our side. This is inline
with the behavior of other platforms.

While here, refactor the Read/Write/Pwrite methods to reduce one
indentation level and make the code easier to read.

Fixes #73084.

Change-Id: Ic5393358e237d53b8be6097cd7359ac0ff205309
Reviewed-on: https://go-review.googlesource.com/c/go/+/661435
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
Reviewed-by: Damien Neil <dneil@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>

src/internal/poll/fd_windows.go
src/internal/poll/fd_windows_test.go
src/internal/syscall/windows/syscall_windows.go

index 14b1febbf46712180f18abd578e37d657604c360..81c8293911d4034b89f29e4a3c5b104d8bf7606f 100644 (file)
@@ -434,6 +434,10 @@ func (fd *FD) Read(buf []byte) (int, error) {
                return 0, err
        }
        defer fd.readUnlock()
+       if fd.isFile {
+               fd.l.Lock()
+               defer fd.l.Unlock()
+       }
 
        if len(buf) > maxRW {
                buf = buf[:maxRW]
@@ -441,36 +445,29 @@ func (fd *FD) Read(buf []byte) (int, error) {
 
        var n int
        var err error
-       if fd.isFile {
-               fd.l.Lock()
-               defer fd.l.Unlock()
-               switch fd.kind {
-               case kindConsole:
-                       n, err = fd.readConsole(buf)
-               default:
-                       o := &fd.rop
-                       o.InitBuf(buf)
-                       n, err = execIO(o, func(o *operation) error {
-                               return syscall.ReadFile(o.fd.Sysfd, unsafe.Slice(o.buf.Buf, o.buf.Len), &o.qty, o.overlapped())
-                       })
-                       fd.addOffset(n)
-                       if fd.kind == kindPipe && err != nil {
-                               switch err {
-                               case syscall.ERROR_BROKEN_PIPE:
-                                       // Returned by pipes when the other end is closed.
-                                       err = nil
-                               case syscall.ERROR_OPERATION_ABORTED:
-                                       // Close uses CancelIoEx to interrupt concurrent I/O for pipes.
-                                       // If the fd is a pipe and the Read was interrupted by CancelIoEx,
-                                       // we assume it is interrupted by Close.
-                                       err = ErrFileClosing
-                               }
+       switch fd.kind {
+       case kindConsole:
+               n, err = fd.readConsole(buf)
+       case kindFile, kindPipe:
+               o := &fd.rop
+               o.InitBuf(buf)
+               n, err = execIO(o, func(o *operation) error {
+                       return syscall.ReadFile(o.fd.Sysfd, unsafe.Slice(o.buf.Buf, o.buf.Len), &o.qty, o.overlapped())
+               })
+               fd.addOffset(n)
+               if fd.kind == kindPipe && err != nil {
+                       switch err {
+                       case syscall.ERROR_BROKEN_PIPE:
+                               // Returned by pipes when the other end is closed.
+                               err = nil
+                       case syscall.ERROR_OPERATION_ABORTED:
+                               // Close uses CancelIoEx to interrupt concurrent I/O for pipes.
+                               // If the fd is a pipe and the Read was interrupted by CancelIoEx,
+                               // we assume it is interrupted by Close.
+                               err = ErrFileClosing
                        }
                }
-               if err != nil {
-                       n = 0
-               }
-       } else {
+       case kindNet:
                o := &fd.rop
                o.InitBuf(buf)
                n, err = execIO(o, func(o *operation) error {
@@ -701,36 +698,32 @@ func (fd *FD) Write(buf []byte) (int, error) {
                defer fd.l.Unlock()
        }
 
-       ntotal := 0
-       for len(buf) > 0 {
-               b := buf
-               if len(b) > maxRW {
-                       b = b[:maxRW]
+       var ntotal int
+       for {
+               max := len(buf)
+               if max-ntotal > maxRW {
+                       max = ntotal + maxRW
                }
+               b := buf[ntotal:max]
                var n int
                var err error
-               if fd.isFile {
-                       switch fd.kind {
-                       case kindConsole:
-                               n, err = fd.writeConsole(b)
-                       default:
-                               o := &fd.wop
-                               o.InitBuf(b)
-                               n, err = execIO(o, func(o *operation) error {
-                                       return syscall.WriteFile(o.fd.Sysfd, unsafe.Slice(o.buf.Buf, o.buf.Len), &o.qty, o.overlapped())
-                               })
-                               fd.addOffset(n)
-                               if fd.kind == kindPipe && err == syscall.ERROR_OPERATION_ABORTED {
-                                       // Close uses CancelIoEx to interrupt concurrent I/O for pipes.
-                                       // If the fd is a pipe and the Write was interrupted by CancelIoEx,
-                                       // we assume it is interrupted by Close.
-                                       err = ErrFileClosing
-                               }
-                       }
-                       if err != nil {
-                               n = 0
+               switch fd.kind {
+               case kindConsole:
+                       n, err = fd.writeConsole(b)
+               case kindPipe, kindFile:
+                       o := &fd.wop
+                       o.InitBuf(b)
+                       n, err = execIO(o, func(o *operation) error {
+                               return syscall.WriteFile(o.fd.Sysfd, unsafe.Slice(o.buf.Buf, o.buf.Len), &o.qty, o.overlapped())
+                       })
+                       fd.addOffset(n)
+                       if fd.kind == kindPipe && err == syscall.ERROR_OPERATION_ABORTED {
+                               // Close uses CancelIoEx to interrupt concurrent I/O for pipes.
+                               // If the fd is a pipe and the Write was interrupted by CancelIoEx,
+                               // we assume it is interrupted by Close.
+                               err = ErrFileClosing
                        }
-               } else {
+               case kindNet:
                        if race.Enabled {
                                race.ReleaseMerge(unsafe.Pointer(&ioSync))
                        }
@@ -741,12 +734,13 @@ func (fd *FD) Write(buf []byte) (int, error) {
                        })
                }
                ntotal += n
-               if err != nil {
+               if ntotal == len(buf) || err != nil {
                        return ntotal, err
                }
-               buf = buf[n:]
+               if n == 0 {
+                       return ntotal, io.ErrUnexpectedEOF
+               }
        }
-       return ntotal, nil
 }
 
 // writeConsole writes len(b) bytes to the console File.
@@ -814,26 +808,29 @@ func (fd *FD) Pwrite(buf []byte, off int64) (int, error) {
        defer syscall.Seek(fd.Sysfd, curoffset, io.SeekStart)
        defer fd.setOffset(curoffset)
 
-       ntotal := 0
-       for len(buf) > 0 {
-               b := buf
-               if len(b) > maxRW {
-                       b = b[:maxRW]
+       var ntotal int
+       for {
+               max := len(buf)
+               if max-ntotal > maxRW {
+                       max = ntotal + maxRW
                }
+               b := buf[ntotal:max]
                o := &fd.wop
                o.InitBuf(b)
-               fd.setOffset(off)
+               fd.setOffset(off + int64(ntotal))
                n, err := execIO(o, func(o *operation) error {
                        return syscall.WriteFile(o.fd.Sysfd, unsafe.Slice(o.buf.Buf, o.buf.Len), &o.qty, &o.o)
                })
-               ntotal += int(n)
-               if err != nil {
+               if n > 0 {
+                       ntotal += n
+               }
+               if ntotal == len(buf) || err != nil {
                        return ntotal, err
                }
-               buf = buf[n:]
-               off += int64(n)
+               if n == 0 {
+                       return ntotal, io.ErrUnexpectedEOF
+               }
        }
-       return ntotal, nil
 }
 
 // Writev emulates the Unix writev system call.
index f2bc9b2a210df93b84fa6fcc9fdee045d89c7dea..f5fa4a26e351620cb47b4e2a6fa144bd1835db01 100644 (file)
@@ -239,7 +239,7 @@ var currentProces = sync.OnceValue(func() string {
 
 var pipeCounter atomic.Uint64
 
-func newPipe(t testing.TB, overlapped bool) (string, *poll.FD) {
+func newPipe(t testing.TB, overlapped, message bool) (string, *poll.FD) {
        name := `\\.\pipe\go-internal-poll-test-` + currentProces() + `-` + strconv.FormatUint(pipeCounter.Add(1), 10)
        wname, err := syscall.UTF16PtrFromString(name)
        if err != nil {
@@ -250,7 +250,11 @@ func newPipe(t testing.TB, overlapped bool) (string, *poll.FD) {
        if overlapped {
                flags |= syscall.FILE_FLAG_OVERLAPPED
        }
-       h, err := windows.CreateNamedPipe(wname, uint32(flags), windows.PIPE_TYPE_BYTE, 1, 4096, 4096, 0, nil)
+       typ := windows.PIPE_TYPE_BYTE
+       if message {
+               typ = windows.PIPE_TYPE_MESSAGE
+       }
+       h, err := windows.CreateNamedPipe(wname, uint32(flags), uint32(typ), 1, 4096, 4096, 0, nil)
        if err != nil {
                t.Fatal(err)
        }
@@ -358,22 +362,22 @@ func TestFile(t *testing.T) {
 
 func TestPipe(t *testing.T) {
        t.Run("overlapped", func(t *testing.T) {
-               name, pipe := newPipe(t, true)
+               name, pipe := newPipe(t, true, false)
                file := newFile(t, name, true)
                testReadWrite(t, pipe, file)
        })
        t.Run("overlapped-write", func(t *testing.T) {
-               name, pipe := newPipe(t, true)
+               name, pipe := newPipe(t, true, false)
                file := newFile(t, name, false)
                testReadWrite(t, file, pipe)
        })
        t.Run("overlapped-read", func(t *testing.T) {
-               name, pipe := newPipe(t, false)
+               name, pipe := newPipe(t, false, false)
                file := newFile(t, name, true)
                testReadWrite(t, file, pipe)
        })
        t.Run("sync", func(t *testing.T) {
-               name, pipe := newPipe(t, false)
+               name, pipe := newPipe(t, false, false)
                file := newFile(t, name, false)
                testReadWrite(t, file, pipe)
        })
@@ -397,6 +401,28 @@ func TestPipe(t *testing.T) {
        })
 }
 
+func TestPipeWriteEOF(t *testing.T) {
+       name, pipe := newPipe(t, false, true)
+       file := newFile(t, name, false)
+       read := make(chan struct{}, 1)
+       go func() {
+               _, err := pipe.Write(nil)
+               read <- struct{}{}
+               if err != nil {
+                       t.Error(err)
+               }
+       }()
+       <-read
+       var buf [10]byte
+       n, err := file.Read(buf[:])
+       if err != io.EOF {
+               t.Errorf("expected EOF, got %v", err)
+       }
+       if n != 0 {
+               t.Errorf("expected 0 bytes, got %d", n)
+       }
+}
+
 func BenchmarkReadOverlapped(b *testing.B) {
        benchmarkRead(b, true)
 }
index af542c80036bde9227be2e1c387a6449404c97b6..3a197f1c261625aa4000fc82ecf617ec96272f6c 100644 (file)
@@ -509,7 +509,8 @@ const (
        PIPE_ACCESS_OUTBOUND = 0x00000002
        PIPE_ACCESS_DUPLEX   = 0x00000003
 
-       PIPE_TYPE_BYTE = 0x00000000
+       PIPE_TYPE_BYTE    = 0x00000000
+       PIPE_TYPE_MESSAGE = 0x00000004
 )
 
 //sys  GetOverlappedResult(handle syscall.Handle, overlapped *syscall.Overlapped, done *uint32, wait bool) (err error)