]> Cypherpunks repositories - gostls13.git/commitdiff
net: use closesocket when closing socket os.File's on Windows
authorqmuntal <quimmuntal@gmail.com>
Tue, 13 May 2025 11:31:22 +0000 (13:31 +0200)
committerQuim Muntal <quimmuntal@gmail.com>
Thu, 15 May 2025 05:21:50 +0000 (22:21 -0700)
The WSASocket documentation states that the returned socket must be
closed by calling closesocket instead of CloseHandle. The different
File methods on the net package return an os.File that is not aware
that it should use closesocket. Ideally, os.NewFile should detect that
the passed handle is a socket and use the appropriate close function,
but there is no reliable way to detect that a handle is a socket on
Windows (see CL 671455).

To work around this, we add a hidden function to the os package that
can be used to return an os.File that uses closesocket. This approach
is the same as used on Unix, which also uses a hidden function for other
purposes.

While here, fix a potential issue with FileConn, which was using File.Fd
rather than File.SyscallConn to get the handle. This could result in the
File being closed and garbage collected before the syscall was made.

Fixes #73683.

Change-Id: I179405f34c63cbbd555d8119e0f77157c670eb3e
Reviewed-on: https://go-review.googlesource.com/c/go/+/672195
Reviewed-by: Damien Neil <dneil@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Michael Knyszek <mknyszek@google.com>
src/internal/poll/fd_windows.go
src/net/fd_windows.go
src/net/file_test.go
src/net/file_windows.go
src/os/file_windows.go

index e846c2cd52509c238c04fe151f74ddcb16408da9..acc2ab0c6e1eaf471f8aee67f5fe6eee8d1d0c68 100644 (file)
@@ -318,7 +318,7 @@ type FD struct {
        // message based socket connection.
        ZeroReadIsEOF bool
 
-       // Whether this is a file rather than a network socket.
+       // Whether the handle is owned by os.File.
        isFile bool
 
        // The kind of this file.
@@ -368,6 +368,7 @@ const (
        kindFile
        kindConsole
        kindPipe
+       kindFileNet
 )
 
 // Init initializes the FD. The Sysfd field should already be set.
@@ -388,6 +389,8 @@ func (fd *FD) Init(net string, pollable bool) error {
                fd.kind = kindConsole
        case "pipe":
                fd.kind = kindPipe
+       case "file+net":
+               fd.kind = kindFileNet
        default:
                // We don't actually care about the various network types.
                fd.kind = kindNet
@@ -453,7 +456,7 @@ func (fd *FD) destroy() error {
        fd.pd.close()
        var err error
        switch fd.kind {
-       case kindNet:
+       case kindNet, kindFileNet:
                // The net package uses the CloseFunc variable for testing.
                err = CloseFunc(fd.Sysfd)
        default:
@@ -494,7 +497,7 @@ func (fd *FD) Read(buf []byte) (int, error) {
                return 0, err
        }
        defer fd.readUnlock()
-       if fd.isFile {
+       if fd.kind == kindFile {
                fd.l.Lock()
                defer fd.l.Unlock()
        }
@@ -747,7 +750,7 @@ func (fd *FD) Write(buf []byte) (int, error) {
                return 0, err
        }
        defer fd.writeUnlock()
-       if fd.isFile {
+       if fd.kind == kindFile {
                fd.l.Lock()
                defer fd.l.Unlock()
        }
index a23be0501f72f6aeb5b65339301a5b96f2d46a5e..52985be8e65403df188fef357abed6640528a84a 100644 (file)
@@ -233,6 +233,9 @@ func (fd *netFD) accept() (*netFD, error) {
        return netfd, nil
 }
 
+// Defined in os package.
+func newWindowsFile(h syscall.Handle, name string) *os.File
+
 func (fd *netFD) dup() (*os.File, error) {
        // Disassociate the IOCP from the socket,
        // it is not safe to share a duplicated handle
@@ -251,5 +254,8 @@ func (fd *netFD) dup() (*os.File, error) {
        if err != nil {
                return nil, err
        }
-       return os.NewFile(uintptr(h), fd.name()), nil
+       // All WSASocket calls must be match with a syscall.Closesocket call,
+       // but os.NewFile calls syscall.CloseHandle instead. We need to use
+       // a hidden function so that the returned file is aware of this fact.
+       return newWindowsFile(h, fd.name()), nil
 }
index b5d007d6cfc4d574f8b25bbc0966acc6b00676f2..51e54ff506bd1f6a9e2f427abfd63e3f19bb6163 100644 (file)
@@ -34,89 +34,90 @@ func TestFileConn(t *testing.T) {
        }
 
        for _, tt := range fileConnTests {
-               if !testableNetwork(tt.network) {
-                       t.Logf("skipping %s test", tt.network)
-                       continue
-               }
+               t.Run(tt.network, func(t *testing.T) {
+                       if !testableNetwork(tt.network) {
+                               t.Skipf("skipping %s test", tt.network)
+                       }
 
-               var network, address string
-               switch tt.network {
-               case "udp":
-                       c := newLocalPacketListener(t, tt.network)
-                       defer c.Close()
-                       network = c.LocalAddr().Network()
-                       address = c.LocalAddr().String()
-               default:
-                       handler := func(ls *localServer, ln Listener) {
-                               c, err := ln.Accept()
-                               if err != nil {
-                                       return
-                               }
+                       var network, address string
+                       switch tt.network {
+                       case "udp":
+                               c := newLocalPacketListener(t, tt.network)
                                defer c.Close()
-                               var b [1]byte
-                               c.Read(b[:])
+                               network = c.LocalAddr().Network()
+                               address = c.LocalAddr().String()
+                       default:
+                               handler := func(ls *localServer, ln Listener) {
+                                       c, err := ln.Accept()
+                                       if err != nil {
+                                               return
+                                       }
+                                       defer c.Close()
+                                       var b [1]byte
+                                       c.Read(b[:])
+                               }
+                               ls := newLocalServer(t, tt.network)
+                               defer ls.teardown()
+                               if err := ls.buildup(handler); err != nil {
+                                       t.Fatal(err)
+                               }
+                               network = ls.Listener.Addr().Network()
+                               address = ls.Listener.Addr().String()
                        }
-                       ls := newLocalServer(t, tt.network)
-                       defer ls.teardown()
-                       if err := ls.buildup(handler); err != nil {
+
+                       c1, err := Dial(network, address)
+                       if err != nil {
+                               if perr := parseDialError(err); perr != nil {
+                                       t.Error(perr)
+                               }
                                t.Fatal(err)
                        }
-                       network = ls.Listener.Addr().Network()
-                       address = ls.Listener.Addr().String()
-               }
+                       addr := c1.LocalAddr()
 
-               c1, err := Dial(network, address)
-               if err != nil {
-                       if perr := parseDialError(err); perr != nil {
-                               t.Error(perr)
+                       var f *os.File
+                       switch c1 := c1.(type) {
+                       case *TCPConn:
+                               f, err = c1.File()
+                       case *UDPConn:
+                               f, err = c1.File()
+                       case *UnixConn:
+                               f, err = c1.File()
                        }
-                       t.Fatal(err)
-               }
-               addr := c1.LocalAddr()
-
-               var f *os.File
-               switch c1 := c1.(type) {
-               case *TCPConn:
-                       f, err = c1.File()
-               case *UDPConn:
-                       f, err = c1.File()
-               case *UnixConn:
-                       f, err = c1.File()
-               }
-               if err := c1.Close(); err != nil {
-                       if perr := parseCloseError(err, false); perr != nil {
-                               t.Error(perr)
+                       if err := c1.Close(); err != nil {
+                               if perr := parseCloseError(err, false); perr != nil {
+                                       t.Error(perr)
+                               }
+                               t.Error(err)
                        }
-                       t.Error(err)
-               }
-               if err != nil {
-                       if perr := parseCommonError(err); perr != nil {
-                               t.Error(perr)
+                       if err != nil {
+                               if perr := parseCommonError(err); perr != nil {
+                                       t.Error(perr)
+                               }
+                               t.Fatal(err)
                        }
-                       t.Fatal(err)
-               }
 
-               c2, err := FileConn(f)
-               if err := f.Close(); err != nil {
-                       t.Error(err)
-               }
-               if err != nil {
-                       if perr := parseCommonError(err); perr != nil {
-                               t.Error(perr)
+                       c2, err := FileConn(f)
+                       if err := f.Close(); err != nil {
+                               t.Error(err)
                        }
-                       t.Fatal(err)
-               }
-               defer c2.Close()
+                       if err != nil {
+                               if perr := parseCommonError(err); perr != nil {
+                                       t.Error(perr)
+                               }
+                               t.Fatal(err)
+                       }
+                       defer c2.Close()
 
-               if _, err := c2.Write([]byte("FILECONN TEST")); err != nil {
-                       if perr := parseWriteError(err); perr != nil {
-                               t.Error(perr)
+                       if _, err := c2.Write([]byte("FILECONN TEST")); err != nil {
+                               if perr := parseWriteError(err); perr != nil {
+                                       t.Error(perr)
+                               }
+                               t.Fatal(err)
                        }
-                       t.Fatal(err)
-               }
-               if !reflect.DeepEqual(c2.LocalAddr(), addr) {
-                       t.Fatalf("got %#v; want %#v", c2.LocalAddr(), addr)
-               }
+                       if !reflect.DeepEqual(c2.LocalAddr(), addr) {
+                               t.Fatalf("got %#v; want %#v", c2.LocalAddr(), addr)
+                       }
+               })
        }
 }
 
@@ -135,81 +136,82 @@ func TestFileListener(t *testing.T) {
        }
 
        for _, tt := range fileListenerTests {
-               if !testableNetwork(tt.network) {
-                       t.Logf("skipping %s test", tt.network)
-                       continue
-               }
+               t.Run(tt.network, func(t *testing.T) {
+                       if !testableNetwork(tt.network) {
+                               t.Skipf("skipping %s test", tt.network)
+                       }
 
-               ln1 := newLocalListener(t, tt.network)
-               switch tt.network {
-               case "unix", "unixpacket":
-                       defer os.Remove(ln1.Addr().String())
-               }
-               addr := ln1.Addr()
+                       ln1 := newLocalListener(t, tt.network)
+                       switch tt.network {
+                       case "unix", "unixpacket":
+                               defer os.Remove(ln1.Addr().String())
+                       }
+                       addr := ln1.Addr()
 
-               var (
-                       f   *os.File
-                       err error
-               )
-               switch ln1 := ln1.(type) {
-               case *TCPListener:
-                       f, err = ln1.File()
-               case *UnixListener:
-                       f, err = ln1.File()
-               }
-               switch tt.network {
-               case "unix", "unixpacket":
-                       defer ln1.Close() // UnixListener.Close calls syscall.Unlink internally
-               default:
-                       if err := ln1.Close(); err != nil {
-                               t.Error(err)
+                       var (
+                               f   *os.File
+                               err error
+                       )
+                       switch ln1 := ln1.(type) {
+                       case *TCPListener:
+                               f, err = ln1.File()
+                       case *UnixListener:
+                               f, err = ln1.File()
                        }
-               }
-               if err != nil {
-                       if perr := parseCommonError(err); perr != nil {
-                               t.Error(perr)
+                       switch tt.network {
+                       case "unix", "unixpacket":
+                               defer ln1.Close() // UnixListener.Close calls syscall.Unlink internally
+                       default:
+                               if err := ln1.Close(); err != nil {
+                                       t.Error(err)
+                               }
+                       }
+                       if err != nil {
+                               if perr := parseCommonError(err); perr != nil {
+                                       t.Error(perr)
+                               }
+                               t.Fatal(err)
                        }
-                       t.Fatal(err)
-               }
 
-               ln2, err := FileListener(f)
-               if err := f.Close(); err != nil {
-                       t.Error(err)
-               }
-               if err != nil {
-                       if perr := parseCommonError(err); perr != nil {
-                               t.Error(perr)
+                       ln2, err := FileListener(f)
+                       if err := f.Close(); err != nil {
+                               t.Error(err)
                        }
-                       t.Fatal(err)
-               }
-               defer ln2.Close()
+                       if err != nil {
+                               if perr := parseCommonError(err); perr != nil {
+                                       t.Error(perr)
+                               }
+                               t.Fatal(err)
+                       }
+                       defer ln2.Close()
 
-               var wg sync.WaitGroup
-               wg.Add(1)
-               go func() {
-                       defer wg.Done()
-                       c, err := Dial(ln2.Addr().Network(), ln2.Addr().String())
+                       var wg sync.WaitGroup
+                       wg.Add(1)
+                       go func() {
+                               defer wg.Done()
+                               c, err := Dial(ln2.Addr().Network(), ln2.Addr().String())
+                               if err != nil {
+                                       if perr := parseDialError(err); perr != nil {
+                                               t.Error(perr)
+                                       }
+                                       t.Error(err)
+                                       return
+                               }
+                               c.Close()
+                       }()
+                       c, err := ln2.Accept()
                        if err != nil {
-                               if perr := parseDialError(err); perr != nil {
+                               if perr := parseAcceptError(err); perr != nil {
                                        t.Error(perr)
                                }
-                               t.Error(err)
-                               return
+                               t.Fatal(err)
                        }
                        c.Close()
-               }()
-               c, err := ln2.Accept()
-               if err != nil {
-                       if perr := parseAcceptError(err); perr != nil {
-                               t.Error(perr)
+                       wg.Wait()
+                       if !reflect.DeepEqual(ln2.Addr(), addr) {
+                               t.Fatalf("got %#v; want %#v", ln2.Addr(), addr)
                        }
-                       t.Fatal(err)
-               }
-               c.Close()
-               wg.Wait()
-               if !reflect.DeepEqual(ln2.Addr(), addr) {
-                       t.Fatalf("got %#v; want %#v", ln2.Addr(), addr)
-               }
+               })
        }
 }
 
@@ -227,62 +229,63 @@ func TestFilePacketConn(t *testing.T) {
        }
 
        for _, tt := range filePacketConnTests {
-               if !testableNetwork(tt.network) {
-                       t.Logf("skipping %s test", tt.network)
-                       continue
-               }
+               t.Run(tt.network, func(t *testing.T) {
+                       if !testableNetwork(tt.network) {
+                               t.Skipf("skipping %s test", tt.network)
+                       }
 
-               c1 := newLocalPacketListener(t, tt.network)
-               switch tt.network {
-               case "unixgram":
-                       defer os.Remove(c1.LocalAddr().String())
-               }
-               addr := c1.LocalAddr()
+                       c1 := newLocalPacketListener(t, tt.network)
+                       switch tt.network {
+                       case "unixgram":
+                               defer os.Remove(c1.LocalAddr().String())
+                       }
+                       addr := c1.LocalAddr()
 
-               var (
-                       f   *os.File
-                       err error
-               )
-               switch c1 := c1.(type) {
-               case *UDPConn:
-                       f, err = c1.File()
-               case *UnixConn:
-                       f, err = c1.File()
-               }
-               if err := c1.Close(); err != nil {
-                       if perr := parseCloseError(err, false); perr != nil {
-                               t.Error(perr)
+                       var (
+                               f   *os.File
+                               err error
+                       )
+                       switch c1 := c1.(type) {
+                       case *UDPConn:
+                               f, err = c1.File()
+                       case *UnixConn:
+                               f, err = c1.File()
                        }
-                       t.Error(err)
-               }
-               if err != nil {
-                       if perr := parseCommonError(err); perr != nil {
-                               t.Error(perr)
+                       if err := c1.Close(); err != nil {
+                               if perr := parseCloseError(err, false); perr != nil {
+                                       t.Error(perr)
+                               }
+                               t.Error(err)
+                       }
+                       if err != nil {
+                               if perr := parseCommonError(err); perr != nil {
+                                       t.Error(perr)
+                               }
+                               t.Fatal(err)
                        }
-                       t.Fatal(err)
-               }
 
-               c2, err := FilePacketConn(f)
-               if err := f.Close(); err != nil {
-                       t.Error(err)
-               }
-               if err != nil {
-                       if perr := parseCommonError(err); perr != nil {
-                               t.Error(perr)
+                       c2, err := FilePacketConn(f)
+                       if err := f.Close(); err != nil {
+                               t.Error(err)
                        }
-                       t.Fatal(err)
-               }
-               defer c2.Close()
+                       if err != nil {
+                               if perr := parseCommonError(err); perr != nil {
+                                       t.Error(perr)
+                               }
+                               t.Fatal(err)
+                       }
+                       defer c2.Close()
 
-               if _, err := c2.WriteTo([]byte("FILEPACKETCONN TEST"), addr); err != nil {
-                       if perr := parseWriteError(err); perr != nil {
-                               t.Error(perr)
+                       if _, err := c2.WriteTo([]byte("FILEPACKETCONN TEST"), addr); err != nil {
+                               if perr := parseWriteError(err); perr != nil {
+                                       t.Error(perr)
+                               }
+                               t.Fatal(err)
                        }
-                       t.Fatal(err)
-               }
-               if !reflect.DeepEqual(c2.LocalAddr(), addr) {
-                       t.Fatalf("got %#v; want %#v", c2.LocalAddr(), addr)
-               }
+                       if !reflect.DeepEqual(c2.LocalAddr(), addr) {
+                               t.Fatalf("got %#v; want %#v", c2.LocalAddr(), addr)
+                       }
+               })
        }
 }
 
index bd7e2bf4806bf8388ee9c1867d072712a8083035..b4eb00e56448a578d509702f71aaa09edc5b3f6a 100644 (file)
@@ -22,9 +22,26 @@ func dupSocket(h syscall.Handle) (syscall.Handle, error) {
 }
 
 func dupFileSocket(f *os.File) (syscall.Handle, error) {
-       // The resulting handle should not be associated to an IOCP, else the IO operations
-       // will block an OS thread, and that's not what net package users expect.
-       h, err := dupSocket(syscall.Handle(f.Fd()))
+       // Call Fd to disassociate the IOCP from the handle,
+       // it is not safe to share a duplicated handle
+       // that is associated with IOCP.
+       // Don't use the returned fd, as it might be closed
+       // if f happens to be the last reference to the file.
+       f.Fd()
+
+       sc, err := f.SyscallConn()
+       if err != nil {
+               return 0, err
+       }
+
+       var h syscall.Handle
+       var syserr error
+       err = sc.Control(func(fd uintptr) {
+               h, syserr = dupSocket(syscall.Handle(fd))
+       })
+       if err != nil {
+               err = syserr
+       }
        if err != nil {
                return 0, err
        }
index c97307371c88ad1c59d3cb32c26d339bffd9fa6a..ee6735fe443137d57f7d32f40be6177bac6cf2df 100644 (file)
@@ -92,6 +92,18 @@ func newFileFromNewFile(fd uintptr, name string) *File {
        return newFile(h, name, "file", nonBlocking)
 }
 
+// net_newWindowsFile is a hidden entry point called by net.conn.File.
+// This is used so that the File.pfd.close method calls [syscall.Closesocket]
+// instead of [syscall.CloseHandle].
+//
+//go:linkname net_newWindowsFile net.newWindowsFile
+func net_newWindowsFile(h syscall.Handle, name string) *File {
+       if h == syscall.InvalidHandle {
+               panic("invalid FD")
+       }
+       return newFile(h, name, "file+net", true)
+}
+
 func epipecheck(file *File, e error) {
 }