]> Cypherpunks repositories - gostls13.git/commitdiff
net: avoid nil pointer dereference when RemoteAddr.String method chain is called
authorMikio Hara <mikioh.mikioh@gmail.com>
Thu, 23 Aug 2012 11:54:00 +0000 (20:54 +0900)
committerMikio Hara <mikioh.mikioh@gmail.com>
Thu, 23 Aug 2012 11:54:00 +0000 (20:54 +0900)
Fixes #3721.

R=dave, rsc
CC=golang-dev
https://golang.org/cl/6395055

src/pkg/net/fd.go
src/pkg/net/file.go
src/pkg/net/ipraw_test.go
src/pkg/net/sock.go
src/pkg/net/udp_test.go

index 52527ec8f2f744fc41abb3390d52b2935abbec5d..e9927d9534af980b4c3d2b8f7ac9b09783b24df3 100644 (file)
@@ -612,11 +612,10 @@ func (fd *netFD) accept(toAddr func(syscall.Sockaddr) Addr) (netfd *netFD, err e
        syscall.ForkLock.RUnlock()
 
        if netfd, err = newFD(s, fd.family, fd.sotype, fd.net); err != nil {
-               syscall.Close(s)
+               closesocket(s)
                return nil, err
        }
-       lsa, _ := syscall.Getsockname(netfd.sysfd)
-       netfd.setAddr(toAddr(lsa), toAddr(rsa))
+       netfd.setAddr(localSockname(fd, toAddr), toAddr(rsa))
        return netfd, nil
 }
 
index 1abf24f2d6502a9c808c3cfd6498a7fe76e13c85..11c8f77a8254130e6bed3d3d3bf4808d36348539 100644 (file)
@@ -25,8 +25,8 @@ func newFileFD(f *os.File) (*netFD, error) {
 
        family := syscall.AF_UNSPEC
        toAddr := sockaddrToTCP
-       sa, _ := syscall.Getsockname(fd)
-       switch sa.(type) {
+       lsa, _ := syscall.Getsockname(fd)
+       switch lsa.(type) {
        default:
                closesocket(fd)
                return nil, syscall.EINVAL
@@ -53,16 +53,14 @@ func newFileFD(f *os.File) (*netFD, error) {
                        toAddr = sockaddrToUnixpacket
                }
        }
-       laddr := toAddr(sa)
-       sa, _ = syscall.Getpeername(fd)
-       raddr := toAddr(sa)
+       laddr := toAddr(lsa)
 
        netfd, err := newFD(fd, family, sotype, laddr.Network())
        if err != nil {
                closesocket(fd)
                return nil, err
        }
-       netfd.setAddr(laddr, raddr)
+       netfd.setAddr(laddr, remoteSockname(netfd, toAddr))
        return netfd, nil
 }
 
@@ -80,10 +78,10 @@ func FileConn(f *os.File) (c Conn, err error) {
                return newTCPConn(fd), nil
        case *UDPAddr:
                return newUDPConn(fd), nil
-       case *UnixAddr:
-               return newUnixConn(fd), nil
        case *IPAddr:
                return newIPConn(fd), nil
+       case *UnixAddr:
+               return newUnixConn(fd), nil
        }
        fd.Close()
        return nil, syscall.EINVAL
index 0a28827e330d38917c93af755fb2001a575b81a7..5b5b68377f7dbf3a6e976ef14a1138197421c883 100644 (file)
@@ -14,6 +14,55 @@ import (
        "time"
 )
 
+var ipConnAddrStringTests = []struct {
+       net   string
+       laddr string
+       raddr string
+       ipv6  bool
+}{
+       {"ip:icmp", "127.0.0.1", "", false},
+       {"ip:icmp", "::1", "", true},
+       {"ip:icmp", "", "127.0.0.1", false},
+       {"ip:icmp", "", "::1", true},
+}
+
+func TestIPConnAddrString(t *testing.T) {
+       if os.Getuid() != 0 {
+               t.Logf("skipping test; must be root")
+               return
+       }
+
+       for i, tt := range ipConnAddrStringTests {
+               if tt.ipv6 && !supportsIPv6 {
+                       continue
+               }
+               var (
+                       err  error
+                       c    *IPConn
+                       mode string
+               )
+               if tt.raddr == "" {
+                       mode = "listen"
+                       la, _ := ResolveIPAddr(tt.net, tt.laddr)
+                       c, err = ListenIP(tt.net, la)
+                       if err != nil {
+                               t.Fatalf("ListenIP(%q, %q) failed: %v", tt.net, la.String(), err)
+                       }
+               } else {
+                       mode = "dial"
+                       la, _ := ResolveIPAddr(tt.net, tt.laddr)
+                       ra, _ := ResolveIPAddr(tt.net, tt.raddr)
+                       c, err = DialIP(tt.net, la, ra)
+                       if err != nil {
+                               t.Fatalf("DialIP(%q, %q) failed: %v", tt.net, ra.String(), err)
+                       }
+               }
+               t.Logf("%s-%v: LocalAddr: %q, %q", mode, i, c.LocalAddr(), c.LocalAddr().String())
+               t.Logf("%s-%v: RemoteAddr: %q, %q", mode, i, c.RemoteAddr(), c.RemoteAddr().String())
+               c.Close()
+       }
+}
+
 var icmpTests = []struct {
        net   string
        laddr string
@@ -26,7 +75,7 @@ var icmpTests = []struct {
 
 func TestICMP(t *testing.T) {
        if os.Getuid() != 0 {
-               t.Logf("test disabled; must be root")
+               t.Logf("skipping test; must be root")
                return
        }
 
index 3ae16054e4726a142b8ddd56f0477a171d127655..bc9606048a88e87af81254a8926174bbc25de0dc 100644 (file)
@@ -16,7 +16,7 @@ import (
 var listenerBacklog = maxListenerBacklog()
 
 // Generic socket creation.
-func socket(net string, f, t, p int, ipv6only bool, la, ra syscall.Sockaddr, toAddr func(syscall.Sockaddr) Addr) (fd *netFD, err error) {
+func socket(net string, f, t, p int, ipv6only bool, ulsa, ursa syscall.Sockaddr, toAddr func(syscall.Sockaddr) Addr) (fd *netFD, err error) {
        // See ../syscall/exec.go for description of ForkLock.
        syscall.ForkLock.RLock()
        s, err := syscall.Socket(f, t, p)
@@ -27,21 +27,18 @@ func socket(net string, f, t, p int, ipv6only bool, la, ra syscall.Sockaddr, toA
        syscall.CloseOnExec(s)
        syscall.ForkLock.RUnlock()
 
-       err = setDefaultSockopts(s, f, t, ipv6only)
-       if err != nil {
+       if err = setDefaultSockopts(s, f, t, ipv6only); err != nil {
                closesocket(s)
                return nil, err
        }
 
-       var bla syscall.Sockaddr
-       if la != nil {
-               bla, err = listenerSockaddr(s, f, la, toAddr)
-               if err != nil {
+       var blsa syscall.Sockaddr
+       if ulsa != nil {
+               if blsa, err = listenerSockaddr(s, f, ulsa, toAddr); err != nil {
                        closesocket(s)
                        return nil, err
                }
-               err = syscall.Bind(s, bla)
-               if err != nil {
+               if err = syscall.Bind(s, blsa); err != nil {
                        closesocket(s)
                        return nil, err
                }
@@ -52,8 +49,8 @@ func socket(net string, f, t, p int, ipv6only bool, la, ra syscall.Sockaddr, toA
                return nil, err
        }
 
-       if ra != nil {
-               if err = fd.connect(ra); err != nil {
+       if ursa != nil {
+               if err = fd.connect(ursa); err != nil {
                        closesocket(s)
                        fd.Close()
                        return nil, err
@@ -61,17 +58,13 @@ func socket(net string, f, t, p int, ipv6only bool, la, ra syscall.Sockaddr, toA
                fd.isConnected = true
        }
 
-       sa, _ := syscall.Getsockname(s)
        var laddr Addr
-       if la != nil && bla != la {
-               laddr = toAddr(la)
+       if ulsa != nil && blsa != ulsa {
+               laddr = toAddr(ulsa)
        } else {
-               laddr = toAddr(sa)
+               laddr = localSockname(fd, toAddr)
        }
-       sa, _ = syscall.Getpeername(s)
-       raddr := toAddr(sa)
-
-       fd.setAddr(laddr, raddr)
+       fd.setAddr(laddr, remoteSockname(fd, toAddr))
        return fd, nil
 }
 
@@ -85,3 +78,39 @@ func genericReadFrom(w io.Writer, r io.Reader) (n int64, err error) {
        // Use wrapper to hide existing r.ReadFrom from io.Copy.
        return io.Copy(writerOnly{w}, r)
 }
+
+func localSockname(fd *netFD, toAddr func(syscall.Sockaddr) Addr) Addr {
+       sa, _ := syscall.Getsockname(fd.sysfd)
+       if sa == nil {
+               return nullProtocolAddr(fd.family, fd.sotype)
+       }
+       return toAddr(sa)
+}
+
+func remoteSockname(fd *netFD, toAddr func(syscall.Sockaddr) Addr) Addr {
+       sa, _ := syscall.Getpeername(fd.sysfd)
+       if sa == nil {
+               return nullProtocolAddr(fd.family, fd.sotype)
+       }
+       return toAddr(sa)
+}
+
+func nullProtocolAddr(f, t int) Addr {
+       switch f {
+       case syscall.AF_INET, syscall.AF_INET6:
+               switch t {
+               case syscall.SOCK_STREAM:
+                       return (*TCPAddr)(nil)
+               case syscall.SOCK_DGRAM:
+                       return (*UDPAddr)(nil)
+               case syscall.SOCK_RAW:
+                       return (*IPAddr)(nil)
+               }
+       case syscall.AF_UNIX:
+               switch t {
+               case syscall.SOCK_STREAM, syscall.SOCK_DGRAM, syscall.SOCK_SEQPACKET:
+                       return (*UnixAddr)(nil)
+               }
+       }
+       panic("unreachable")
+}
index f80d3b5a9cfa77e2e50549406e13a42e24b60155..90365c05f0a9ea51afc94ccc205f54e66398627f 100644 (file)
@@ -9,6 +9,33 @@ import (
        "testing"
 )
 
+var udpConnAddrStringTests = []struct {
+       net   string
+       laddr string
+       raddr string
+       ipv6  bool
+}{
+       {"udp", "127.0.0.1:0", "", false},
+       {"udp", "[::1]:0", "", true},
+}
+
+func TestUDPConnAddrString(t *testing.T) {
+       for i, tt := range udpConnAddrStringTests {
+               if tt.ipv6 && !supportsIPv6 {
+                       continue
+               }
+               mode := "listen"
+               la, _ := ResolveUDPAddr(tt.net, tt.laddr)
+               c, err := ListenUDP(tt.net, la)
+               if err != nil {
+                       t.Fatalf("ListenUDP(%q, %q) failed: %v", tt.net, la.String(), err)
+               }
+               t.Logf("%s-%v: LocalAddr: %q, %q", mode, i, c.LocalAddr(), c.LocalAddr().String())
+               t.Logf("%s-%v: RemoteAddr: %q, %q", mode, i, c.RemoteAddr(), c.RemoteAddr().String())
+               c.Close()
+       }
+}
+
 func TestWriteToUDP(t *testing.T) {
        switch runtime.GOOS {
        case "plan9":