]> Cypherpunks repositories - gostls13.git/commitdiff
net: fix panic when calling net.Listen or net.Dial on wasip1
authorAchille Roussel <achille.roussel@gmail.com>
Sat, 10 Jun 2023 18:31:43 +0000 (11:31 -0700)
committerGopher Robot <gobot@golang.org>
Wed, 14 Jun 2023 01:36:27 +0000 (01:36 +0000)
Address a panic that was caused by net.Dial/net.Listen entering the fake
network stack and assuming that the addresses would be of type *TCPAddr,
where in fact they could have been *UDPAddr or *UnixAddr as well.

The fix consist in implementing the fake network facility for udp and
unix addresses, preventing the assumed type assertion to TCPAddr from
triggering a panic. New tests are added to verify that using the fake
network from the exported functions of the net package satisfies the
minimal requirement of being able to create a listener and establish a
connection for all the supported network types.

Fixes #60012
Fixes #60739

Change-Id: I2688f1a0a7c6c9894ad3d137a5d311192c77a9b2
Reviewed-on: https://go-review.googlesource.com/c/go/+/502315
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
TryBot-Result: Gopher Robot <gobot@golang.org>
Run-TryBot: Dmitri Shuralyov <dmitshur@golang.org>
Reviewed-by: Johan Brandhorst-Satzkorn <johan.brandhorst@gmail.com>
Reviewed-by: Ian Lance Taylor <iant@google.com>
Auto-Submit: Dmitri Shuralyov <dmitshur@golang.org>

src/net/net_fake.go
src/net/net_fake_test.go [new file with mode: 0644]

index a816213f8de0d86fe2d4891bb67935ead1687d2d..68d36966ca2ae46f7ce8e33f278f9f0c864ec5b9 100644 (file)
@@ -20,7 +20,7 @@ import (
 )
 
 var listenersMu sync.Mutex
-var listeners = make(map[string]*netFD)
+var listeners = make(map[fakeNetAddr]*netFD)
 
 var portCounterMu sync.Mutex
 var portCounter = 0
@@ -32,13 +32,16 @@ func nextPort() int {
        return portCounter
 }
 
+type fakeNetAddr struct {
+       network string
+       address string
+}
+
 type fakeNetFD struct {
-       listener bool
-       laddr    Addr
+       listener fakeNetAddr
        r        *bufferedPipe
        w        *bufferedPipe
        incoming chan *netFD
-
        closedMu sync.Mutex
        closed   bool
 }
@@ -51,32 +54,110 @@ func socket(ctx context.Context, net string, family, sotype, proto int, ipv6only
                return fakelistener(fd, laddr)
        }
        fd2 := &netFD{family: family, sotype: sotype, net: net}
-       return fakeconn(fd, fd2, raddr)
+       return fakeconn(fd, fd2, laddr, raddr)
+}
+
+func fakeIPAndPort(ip IP, port int) (IP, int) {
+       if ip == nil {
+               ip = IPv4(127, 0, 0, 1)
+       }
+       if port == 0 {
+               port = nextPort()
+       }
+       return ip, port
+}
+
+func fakeTCPAddr(addr *TCPAddr) *TCPAddr {
+       var ip IP
+       var port int
+       var zone string
+       if addr != nil {
+               ip, port, zone = addr.IP, addr.Port, addr.Zone
+       }
+       ip, port = fakeIPAndPort(ip, port)
+       return &TCPAddr{IP: ip, Port: port, Zone: zone}
+}
+
+func fakeUDPAddr(addr *UDPAddr) *UDPAddr {
+       var ip IP
+       var port int
+       var zone string
+       if addr != nil {
+               ip, port, zone = addr.IP, addr.Port, addr.Zone
+       }
+       ip, port = fakeIPAndPort(ip, port)
+       return &UDPAddr{IP: ip, Port: port, Zone: zone}
+}
+
+func fakeUnixAddr(sotype int, addr *UnixAddr) *UnixAddr {
+       var net, name string
+       if addr != nil {
+               name = addr.Name
+       }
+       switch sotype {
+       case syscall.SOCK_DGRAM:
+               net = "unixgram"
+       case syscall.SOCK_SEQPACKET:
+               net = "unixpacket"
+       default:
+               net = "unix"
+       }
+       return &UnixAddr{Net: net, Name: name}
 }
 
 func fakelistener(fd *netFD, laddr sockaddr) (*netFD, error) {
-       l := laddr.(*TCPAddr)
-       fd.laddr = &TCPAddr{
-               IP:   l.IP,
-               Port: nextPort(),
-               Zone: l.Zone,
+       switch l := laddr.(type) {
+       case *TCPAddr:
+               laddr = fakeTCPAddr(l)
+       case *UDPAddr:
+               laddr = fakeUDPAddr(l)
+       case *UnixAddr:
+               if l.Name == "" {
+                       return nil, syscall.ENOENT
+               }
+               laddr = fakeUnixAddr(fd.sotype, l)
+       default:
+               return nil, syscall.EOPNOTSUPP
        }
+
+       listener := fakeNetAddr{
+               network: laddr.Network(),
+               address: laddr.String(),
+       }
+
        fd.fakeNetFD = &fakeNetFD{
-               listener: true,
-               laddr:    fd.laddr,
+               listener: listener,
                incoming: make(chan *netFD, 1024),
        }
+
+       fd.laddr = laddr
        listenersMu.Lock()
-       listeners[fd.laddr.(*TCPAddr).String()] = fd
-       listenersMu.Unlock()
+       defer listenersMu.Unlock()
+       if _, exists := listeners[listener]; exists {
+               return nil, syscall.EADDRINUSE
+       }
+       listeners[listener] = fd
        return fd, nil
 }
 
-func fakeconn(fd *netFD, fd2 *netFD, raddr sockaddr) (*netFD, error) {
-       fd.laddr = &TCPAddr{
-               IP:   IPv4(127, 0, 0, 1),
-               Port: nextPort(),
+func fakeconn(fd *netFD, fd2 *netFD, laddr, raddr sockaddr) (*netFD, error) {
+       switch r := raddr.(type) {
+       case *TCPAddr:
+               r = fakeTCPAddr(r)
+               raddr = r
+               laddr = fakeTCPAddr(laddr.(*TCPAddr))
+       case *UDPAddr:
+               r = fakeUDPAddr(r)
+               raddr = r
+               laddr = fakeUDPAddr(laddr.(*UDPAddr))
+       case *UnixAddr:
+               r = fakeUnixAddr(fd.sotype, r)
+               raddr = r
+               laddr = &UnixAddr{Net: r.Net, Name: r.Name}
+       default:
+               return nil, syscall.EAFNOSUPPORT
        }
+       fd.laddr = laddr
        fd.raddr = raddr
 
        fd.fakeNetFD = &fakeNetFD{
@@ -90,15 +171,18 @@ func fakeconn(fd *netFD, fd2 *netFD, raddr sockaddr) (*netFD, error) {
 
        fd2.laddr = fd.raddr
        fd2.raddr = fd.laddr
+
+       listener := fakeNetAddr{
+               network: fd.raddr.Network(),
+               address: fd.raddr.String(),
+       }
        listenersMu.Lock()
-       l, ok := listeners[fd.raddr.(*TCPAddr).String()]
+       defer listenersMu.Unlock()
+       l, ok := listeners[listener]
        if !ok {
-               listenersMu.Unlock()
                return nil, syscall.ECONNREFUSED
        }
        l.incoming <- fd2
-       listenersMu.Unlock()
-
        return fd, nil
 }
 
@@ -119,11 +203,11 @@ func (fd *fakeNetFD) Close() error {
        fd.closed = true
        fd.closedMu.Unlock()
 
-       if fd.listener {
+       if fd.listener != (fakeNetAddr{}) {
                listenersMu.Lock()
-               delete(listeners, fd.laddr.String())
+               delete(listeners, fd.listener)
                close(fd.incoming)
-               fd.listener = false
+               fd.listener = fakeNetAddr{}
                listenersMu.Unlock()
                return nil
        }
diff --git a/src/net/net_fake_test.go b/src/net/net_fake_test.go
new file mode 100644 (file)
index 0000000..783304d
--- /dev/null
@@ -0,0 +1,203 @@
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+//go:build js || wasip1
+
+package net
+
+// GOOS=js and GOOS=wasip1 do not have typical socket networking capabilities
+// found on other platforms. To help run test suites of the stdlib packages,
+// an in-memory "fake network" facility is implemented.
+//
+// The tests in this files are intended to validate the behavior of the fake
+// network stack on these platforms.
+
+import "testing"
+
+func TestFakeConn(t *testing.T) {
+       tests := []struct {
+               name   string
+               listen func() (Listener, error)
+               dial   func(Addr) (Conn, error)
+               addr   func(*testing.T, Addr)
+       }{
+               {
+                       name: "Listener:tcp",
+                       listen: func() (Listener, error) {
+                               return Listen("tcp", ":0")
+                       },
+                       dial: func(addr Addr) (Conn, error) {
+                               return Dial(addr.Network(), addr.String())
+                       },
+                       addr: testFakeTCPAddr,
+               },
+
+               {
+                       name: "ListenTCP:tcp",
+                       listen: func() (Listener, error) {
+                               // Creating a listening TCP connection with a nil address must
+                               // select an IP address on localhost with a random port.
+                               // This test verifies that the fake network facility does that.
+                               return ListenTCP("tcp", nil)
+                       },
+                       dial: func(addr Addr) (Conn, error) {
+                               // Connecting a listening TCP connection will select a local
+                               // address on the local network and connects to the destination
+                               // address.
+                               return DialTCP("tcp", nil, addr.(*TCPAddr))
+                       },
+                       addr: testFakeTCPAddr,
+               },
+
+               {
+                       name: "ListenUnix:unix",
+                       listen: func() (Listener, error) {
+                               return ListenUnix("unix", &UnixAddr{Name: "test"})
+                       },
+                       dial: func(addr Addr) (Conn, error) {
+                               return DialUnix("unix", nil, addr.(*UnixAddr))
+                       },
+                       addr: testFakeUnixAddr("unix", "test"),
+               },
+
+               {
+                       name: "ListenUnix:unixpacket",
+                       listen: func() (Listener, error) {
+                               return ListenUnix("unixpacket", &UnixAddr{Name: "test"})
+                       },
+                       dial: func(addr Addr) (Conn, error) {
+                               return DialUnix("unixpacket", nil, addr.(*UnixAddr))
+                       },
+                       addr: testFakeUnixAddr("unixpacket", "test"),
+               },
+       }
+
+       for _, test := range tests {
+               t.Run(test.name, func(t *testing.T) {
+                       l, err := test.listen()
+                       if err != nil {
+                               t.Fatal(err)
+                       }
+                       defer l.Close()
+                       test.addr(t, l.Addr())
+
+                       c, err := test.dial(l.Addr())
+                       if err != nil {
+                               t.Fatal(err)
+                       }
+                       defer c.Close()
+                       test.addr(t, c.LocalAddr())
+                       test.addr(t, c.RemoteAddr())
+               })
+       }
+}
+
+func TestFakePacketConn(t *testing.T) {
+       tests := []struct {
+               name   string
+               listen func() (PacketConn, error)
+               dial   func(Addr) (Conn, error)
+               addr   func(*testing.T, Addr)
+       }{
+               {
+                       name: "ListenPacket:udp",
+                       listen: func() (PacketConn, error) {
+                               return ListenPacket("udp", ":0")
+                       },
+                       dial: func(addr Addr) (Conn, error) {
+                               return Dial(addr.Network(), addr.String())
+                       },
+                       addr: testFakeUDPAddr,
+               },
+
+               {
+                       name: "ListenUDP:udp",
+                       listen: func() (PacketConn, error) {
+                               // Creating a listening UDP connection with a nil address must
+                               // select an IP address on localhost with a random port.
+                               // This test verifies that the fake network facility does that.
+                               return ListenUDP("udp", nil)
+                       },
+                       dial: func(addr Addr) (Conn, error) {
+                               // Connecting a listening UDP connection will select a local
+                               // address on the local network and connects to the destination
+                               // address.
+                               return DialUDP("udp", nil, addr.(*UDPAddr))
+                       },
+                       addr: testFakeUDPAddr,
+               },
+
+               {
+                       name: "ListenUnixgram:unixgram",
+                       listen: func() (PacketConn, error) {
+                               return ListenUnixgram("unixgram", &UnixAddr{Name: "test"})
+                       },
+                       dial: func(addr Addr) (Conn, error) {
+                               return DialUnix("unixgram", nil, addr.(*UnixAddr))
+                       },
+                       addr: testFakeUnixAddr("unixgram", "test"),
+               },
+       }
+
+       for _, test := range tests {
+               t.Run(test.name, func(t *testing.T) {
+                       l, err := test.listen()
+                       if err != nil {
+                               t.Fatal(err)
+                       }
+                       defer l.Close()
+                       test.addr(t, l.LocalAddr())
+
+                       c, err := test.dial(l.LocalAddr())
+                       if err != nil {
+                               t.Fatal(err)
+                       }
+                       defer c.Close()
+                       test.addr(t, c.LocalAddr())
+                       test.addr(t, c.RemoteAddr())
+               })
+       }
+}
+
+func testFakeTCPAddr(t *testing.T, addr Addr) {
+       t.Helper()
+       if a, ok := addr.(*TCPAddr); !ok {
+               t.Errorf("Addr is not *TCPAddr: %T", addr)
+       } else {
+               testFakeNetAddr(t, a.IP, a.Port)
+       }
+}
+
+func testFakeUDPAddr(t *testing.T, addr Addr) {
+       t.Helper()
+       if a, ok := addr.(*UDPAddr); !ok {
+               t.Errorf("Addr is not *UDPAddr: %T", addr)
+       } else {
+               testFakeNetAddr(t, a.IP, a.Port)
+       }
+}
+
+func testFakeNetAddr(t *testing.T, ip IP, port int) {
+       t.Helper()
+       if port == 0 {
+               t.Error("network address is missing port")
+       } else if len(ip) == 0 {
+               t.Error("network address is missing IP")
+       } else if !ip.Equal(IPv4(127, 0, 0, 1)) {
+               t.Errorf("network address has wrong IP: %s", ip)
+       }
+}
+
+func testFakeUnixAddr(net, name string) func(*testing.T, Addr) {
+       return func(t *testing.T, addr Addr) {
+               t.Helper()
+               if a, ok := addr.(*UnixAddr); !ok {
+                       t.Errorf("Addr is not *UnixAddr: %T", addr)
+               } else if a.Net != net {
+                       t.Errorf("unix address has wrong net: want=%q got=%q", net, a.Net)
+               } else if a.Name != name {
+                       t.Errorf("unix address has wrong name: want=%q got=%q", name, a.Name)
+               }
+       }
+}