]> Cypherpunks repositories - gostls13.git/commitdiff
net, internal/syscall/unix: add SocketConn, SocketPacketConn
authorMikio Hara <mikioh.mikioh@gmail.com>
Thu, 23 Apr 2015 14:57:00 +0000 (23:57 +0900)
committerMikio Hara <mikioh.mikioh@gmail.com>
Wed, 13 May 2015 01:04:23 +0000 (01:04 +0000)
FileConn and FilePacketConn APIs accept user-configured socket
descriptors to make them work together with runtime-integrated network
poller, but there's a limitation. The APIs reject protocol sockets that
are not supported by standard library. It's very hard for the net,
syscall packages to look after all platform, feature-specific sockets.

This change allows various platform, feature-specific socket descriptors
to use runtime-integrated network poller by using SocketConn,
SocketPacketConn APIs that bridge between the net, syscall packages and
platforms.

New exposed APIs:
pkg net, func SocketConn(*os.File, SocketAddr) (Conn, error)
pkg net, func SocketPacketConn(*os.File, SocketAddr) (PacketConn, error)
pkg net, type SocketAddr interface { Addr, Raw }
pkg net, type SocketAddr interface, Addr([]uint8) Addr
pkg net, type SocketAddr interface, Raw(Addr) []uint8

Fixes #10565.

Change-Id: Iec57499b3d84bb5cb0bcf3f664330c535eec11e3
Reviewed-on: https://go-review.googlesource.com/9275
Reviewed-by: Ian Lance Taylor <iant@golang.org>
14 files changed:
src/go/build/deps_test.go
src/internal/syscall/unix/socket.go [new file with mode: 0644]
src/internal/syscall/unix/socket_linux_386.go [new file with mode: 0644]
src/internal/syscall/unix/socket_linux_386.s [new file with mode: 0644]
src/internal/syscall/unix/socket_stub.go [new file with mode: 0644]
src/internal/syscall/unix/socket_unix.go [new file with mode: 0644]
src/net/fd_unix.go
src/net/file.go
src/net/file_bsd_test.go [new file with mode: 0644]
src/net/file_linux_test.go [new file with mode: 0644]
src/net/file_plan9.go
src/net/file_stub.go
src/net/file_unix.go
src/net/file_windows.go

index 5a28c34adf012f76ff5917333d9daa1eef55d545..8e985aa05b962871954a0a06a0844fd40820b414 100644 (file)
@@ -244,7 +244,7 @@ var pkgDeps = map[string][]string{
        // Basic networking.
        // Because net must be used by any package that wants to
        // do networking portably, it must have a small dependency set: just L1+basic os.
-       "net": {"L1", "CGO", "os", "syscall", "time", "internal/syscall/windows", "internal/singleflight"},
+       "net": {"L1", "CGO", "os", "syscall", "time", "internal/syscall/unix", "internal/syscall/windows", "internal/singleflight"},
 
        // NET enables use of basic network-related packages.
        "NET": {
diff --git a/src/internal/syscall/unix/socket.go b/src/internal/syscall/unix/socket.go
new file mode 100644 (file)
index 0000000..d7a9b9c
--- /dev/null
@@ -0,0 +1,39 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build darwin dragonfly freebsd linux nacl netbsd openbsd solaris
+
+package unix
+
+// Getsockname copies the binary encoding of the current address for s
+// into addr.
+func Getsockname(s int, addr []byte) error {
+       return getsockname(s, addr)
+}
+
+// Getpeername copies the binary encoding of the peer address for s
+// into addr.
+func Getpeername(s int, addr []byte) error {
+       return getpeername(s, addr)
+}
+
+var emptyPayload uintptr
+
+// Recvfrom receives a message from s, copying the message into b.
+// The socket address addr must be large enough for storing the source
+// address of the message.
+// Flags must be operation control flags or 0.
+// It retunrs the number of bytes copied into b.
+func Recvfrom(s int, b []byte, flags int, addr []byte) (int, error) {
+       return recvfrom(s, b, flags, addr)
+}
+
+// Sendto sends a message to the socket address addr, copying the
+// message from b.
+// The socket address addr must be suitable for s.
+// Flags must be operation control flags or 0.
+// It retunrs the number of bytes copied from b.
+func Sendto(s int, b []byte, flags int, addr []byte) (int, error) {
+       return sendto(s, b, flags, addr)
+}
diff --git a/src/internal/syscall/unix/socket_linux_386.go b/src/internal/syscall/unix/socket_linux_386.go
new file mode 100644 (file)
index 0000000..47105e0
--- /dev/null
@@ -0,0 +1,67 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package unix
+
+import (
+       "syscall"
+       "unsafe"
+)
+
+const (
+       sysGETSOCKNAME = 0x6
+       sysGETPEERNAME = 0x7
+       sysSENDTO      = 0xb
+       sysRECVFROM    = 0xc
+)
+
+func socketcall(call int, a0, a1, a2, a3, a4, a5 uintptr) (int, syscall.Errno)
+func rawsocketcall(call int, a0, a1, a2, a3, a4, a5 uintptr) (int, syscall.Errno)
+
+func getsockname(s int, addr []byte) error {
+       l := uint32(len(addr))
+       _, errno := rawsocketcall(sysGETSOCKNAME, uintptr(s), uintptr(unsafe.Pointer(&addr[0])), uintptr(unsafe.Pointer(&l)), 0, 0, 0)
+       if errno != 0 {
+               return error(errno)
+       }
+       return nil
+}
+
+func getpeername(s int, addr []byte) error {
+       l := uint32(len(addr))
+       _, errno := rawsocketcall(sysGETPEERNAME, uintptr(s), uintptr(unsafe.Pointer(&addr[0])), uintptr(unsafe.Pointer(&l)), 0, 0, 0)
+       if errno != 0 {
+               return error(errno)
+       }
+       return nil
+}
+
+func recvfrom(s int, b []byte, flags int, from []byte) (int, error) {
+       var p unsafe.Pointer
+       if len(b) > 0 {
+               p = unsafe.Pointer(&b[0])
+       } else {
+               p = unsafe.Pointer(&emptyPayload)
+       }
+       l := uint32(len(from))
+       n, errno := socketcall(sysRECVFROM, uintptr(s), uintptr(p), uintptr(len(b)), uintptr(flags), uintptr(unsafe.Pointer(&from[0])), uintptr(unsafe.Pointer(&l)))
+       if errno != 0 {
+               return int(n), error(errno)
+       }
+       return int(n), nil
+}
+
+func sendto(s int, b []byte, flags int, to []byte) (int, error) {
+       var p unsafe.Pointer
+       if len(b) > 0 {
+               p = unsafe.Pointer(&b[0])
+       } else {
+               p = unsafe.Pointer(&emptyPayload)
+       }
+       n, errno := socketcall(sysSENDTO, uintptr(s), uintptr(p), uintptr(len(b)), uintptr(flags), uintptr(unsafe.Pointer(&to[0])), uintptr(len(to)))
+       if errno != 0 {
+               return int(n), error(errno)
+       }
+       return int(n), nil
+}
diff --git a/src/internal/syscall/unix/socket_linux_386.s b/src/internal/syscall/unix/socket_linux_386.s
new file mode 100644 (file)
index 0000000..48e2094
--- /dev/null
@@ -0,0 +1,11 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+#include "textflag.h"
+
+TEXT   ·socketcall(SB),NOSPLIT,$0-36
+       JMP     syscall·socketcall(SB)
+
+TEXT   ·rawsocketcall(SB),NOSPLIT,$0-36
+       JMP     syscall·socketcall(SB)
diff --git a/src/internal/syscall/unix/socket_stub.go b/src/internal/syscall/unix/socket_stub.go
new file mode 100644 (file)
index 0000000..1c89ed1
--- /dev/null
@@ -0,0 +1,25 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build nacl solaris
+
+package unix
+
+import "syscall"
+
+func getsockname(s int, addr []byte) error {
+       return syscall.EOPNOTSUPP
+}
+
+func getpeername(s int, addr []byte) error {
+       return syscall.EOPNOTSUPP
+}
+
+func recvfrom(s int, b []byte, flags int, from []byte) (int, error) {
+       return 0, syscall.EOPNOTSUPP
+}
+
+func sendto(s int, b []byte, flags int, to []byte) (int, error) {
+       return 0, syscall.EOPNOTSUPP
+}
diff --git a/src/internal/syscall/unix/socket_unix.go b/src/internal/syscall/unix/socket_unix.go
new file mode 100644 (file)
index 0000000..a769bb3
--- /dev/null
@@ -0,0 +1,59 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build darwin dragonfly freebsd linux,!386 netbsd openbsd
+
+package unix
+
+import (
+       "syscall"
+       "unsafe"
+)
+
+func getsockname(s int, addr []byte) error {
+       l := uint32(len(addr))
+       _, _, errno := syscall.RawSyscall(syscall.SYS_GETSOCKNAME, uintptr(s), uintptr(unsafe.Pointer(&addr[0])), uintptr(unsafe.Pointer(&l)))
+       if errno != 0 {
+               return error(errno)
+       }
+       return nil
+}
+
+func getpeername(s int, addr []byte) error {
+       l := uint32(len(addr))
+       _, _, errno := syscall.RawSyscall(syscall.SYS_GETPEERNAME, uintptr(s), uintptr(unsafe.Pointer(&addr[0])), uintptr(unsafe.Pointer(&l)))
+       if errno != 0 {
+               return error(errno)
+       }
+       return nil
+}
+
+func recvfrom(s int, b []byte, flags int, from []byte) (int, error) {
+       var p unsafe.Pointer
+       if len(b) > 0 {
+               p = unsafe.Pointer(&b[0])
+       } else {
+               p = unsafe.Pointer(&emptyPayload)
+       }
+       l := uint32(len(from))
+       n, _, errno := syscall.Syscall6(syscall.SYS_RECVFROM, uintptr(s), uintptr(p), uintptr(len(b)), uintptr(flags), uintptr(unsafe.Pointer(&from[0])), uintptr(unsafe.Pointer(&l)))
+       if errno != 0 {
+               return int(n), error(errno)
+       }
+       return int(n), nil
+}
+
+func sendto(s int, b []byte, flags int, to []byte) (int, error) {
+       var p unsafe.Pointer
+       if len(b) > 0 {
+               p = unsafe.Pointer(&b[0])
+       } else {
+               p = unsafe.Pointer(&emptyPayload)
+       }
+       n, _, errno := syscall.Syscall6(syscall.SYS_SENDTO, uintptr(s), uintptr(p), uintptr(len(b)), uintptr(flags), uintptr(unsafe.Pointer(&to[0])), uintptr(len(to)))
+       if errno != 0 {
+               return int(n), error(errno)
+       }
+       return int(n), nil
+}
index f2d7b348bf0950bf43b44f42eb81b9d3324fdbe1..827045b13d3ed31aaf264c8f38f9b8ee99f12741 100644 (file)
@@ -7,6 +7,7 @@
 package net
 
 import (
+       "internal/syscall/unix"
        "io"
        "os"
        "runtime"
@@ -270,6 +271,33 @@ func (fd *netFD) readFrom(p []byte) (n int, sa syscall.Sockaddr, err error) {
        return
 }
 
+func (fd *netFD) recvFrom(b []byte, flags int, from []byte) (n int, err error) {
+       if err := fd.readLock(); err != nil {
+               return 0, err
+       }
+       defer fd.readUnlock()
+       if err := fd.pd.PrepareRead(); err != nil {
+               return 0, err
+       }
+       for {
+               n, err = unix.Recvfrom(fd.sysfd, b, flags, from)
+               if err != nil {
+                       n = 0
+                       if err == syscall.EAGAIN {
+                               if err = fd.pd.WaitRead(); err == nil {
+                                       continue
+                               }
+                       }
+               }
+               err = fd.eofError(n, err)
+               break
+       }
+       if _, ok := err.(syscall.Errno); ok {
+               err = os.NewSyscallError("recvfrom", err)
+       }
+       return
+}
+
 func (fd *netFD) readMsg(p []byte, oob []byte) (n, oobn, flags int, sa syscall.Sockaddr, err error) {
        if err := fd.readLock(); err != nil {
                return 0, 0, 0, nil, err
@@ -359,6 +387,29 @@ func (fd *netFD) writeTo(p []byte, sa syscall.Sockaddr) (n int, err error) {
        return
 }
 
+func (fd *netFD) sendTo(b []byte, flags int, to []byte) (n int, err error) {
+       if err := fd.writeLock(); err != nil {
+               return 0, err
+       }
+       defer fd.writeUnlock()
+       if err := fd.pd.PrepareWrite(); err != nil {
+               return 0, err
+       }
+       for {
+               n, err = unix.Sendto(fd.sysfd, b, flags, to)
+               if err == syscall.EAGAIN {
+                       if err = fd.pd.WaitWrite(); err == nil {
+                               continue
+                       }
+               }
+               break
+       }
+       if _, ok := err.(syscall.Errno); ok {
+               err = os.NewSyscallError("sendto", err)
+       }
+       return
+}
+
 func (fd *netFD) writeMsg(p []byte, oob []byte, sa syscall.Sockaddr) (n int, oobn int, err error) {
        if err := fd.writeLock(); err != nil {
                return 0, 0, err
index 1aad477400c2d0494709e87eceb23b68fe2d78ca..1d0686c9fc6228ea3dd6e044deec6ede9f14d0f2 100644 (file)
@@ -46,3 +46,47 @@ func FilePacketConn(f *os.File) (c PacketConn, err error) {
        }
        return
 }
+
+// A SocketAddr is used with SocketConn or SocketPacketConn to
+// implement a user-configured socket address.
+// The net package does not provide any implementations of SocketAddr;
+// the caller of SocketConn or SocketPacketConn is expected to provide
+// one.
+type SocketAddr interface {
+       // Addr takes a platform-specific socket address and returns
+       // a net.Addr. The result may be nil when the syscall package,
+       // system call or underlying protocol does not support the
+       // socket address.
+       Addr([]byte) Addr
+
+       // Raw takes a net.Addr and returns a platform-specific socket
+       // address. The result may be nil when the syscall package,
+       // system call or underlying protocol does not support the
+       // socket address.
+       Raw(Addr) []byte
+}
+
+// SocketConn returns a copy of the network connection corresponding
+// to the open file f and user-defined socket address sa.
+// It is the caller's responsibility to close f when finished.
+// Closing c does not affect f, and closing f does not affect c.
+func SocketConn(f *os.File, sa SocketAddr) (c Conn, err error) {
+       c, err = socketConn(f, sa)
+       if err != nil {
+               err = &OpError{Op: "file", Net: "file+net", Source: nil, Addr: fileAddr(f.Name()), Err: err}
+       }
+       return
+}
+
+// SocketPacketConn returns a copy of the packet network connection
+// corresponding to the open file f and user-defined socket address
+// sa.
+// It is the caller's responsibility to close f when finished.
+// Closing c does not affect f, and closing f does not affect c.
+func SocketPacketConn(f *os.File, sa SocketAddr) (c PacketConn, err error) {
+       c, err = socketPacketConn(f, sa)
+       if err != nil {
+               err = &OpError{Op: "file", Net: "file+net", Source: nil, Addr: fileAddr(f.Name()), Err: err}
+       }
+       return
+}
diff --git a/src/net/file_bsd_test.go b/src/net/file_bsd_test.go
new file mode 100644 (file)
index 0000000..6e6cf12
--- /dev/null
@@ -0,0 +1,94 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build darwin dragonfly freebsd netbsd openbsd
+
+package net
+
+import (
+       "os"
+       "runtime"
+       "strings"
+       "sync"
+       "syscall"
+       "testing"
+       "unsafe"
+)
+
+type routeAddr struct{}
+
+func (a *routeAddr) Network() string { return "route" }
+func (a *routeAddr) String() string  { return "<nil>" }
+
+func (a *routeAddr) Addr(rsa []byte) Addr { return &routeAddr{} }
+func (a *routeAddr) Raw(addr Addr) []byte { return nil }
+
+func TestSocketConn(t *testing.T) {
+       var freebsd32o64 bool
+       if runtime.GOOS == "freebsd" && runtime.GOARCH == "386" {
+               archs, _ := syscall.Sysctl("kern.supported_archs")
+               for _, s := range strings.Split(archs, " ") {
+                       if strings.TrimSpace(s) == "amd64" {
+                               freebsd32o64 = true
+                               break
+                       }
+               }
+       }
+
+       s, err := syscall.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC)
+       if err != nil {
+               t.Fatal(err)
+       }
+       f := os.NewFile(uintptr(s), "route")
+       c, err := SocketConn(f, &routeAddr{})
+       f.Close()
+       if err != nil {
+               t.Fatal(err)
+       }
+       defer c.Close()
+
+       const N = 3
+       for i := 0; i < N; i++ {
+               go func(i int) {
+                       l := syscall.SizeofRtMsghdr + syscall.SizeofSockaddrInet4
+                       if freebsd32o64 {
+                               l += syscall.SizeofRtMetrics // see syscall/route_freebsd_32bit.go
+                       }
+                       b := make([]byte, l)
+                       h := (*syscall.RtMsghdr)(unsafe.Pointer(&b[0]))
+                       h.Msglen = uint16(len(b))
+                       h.Version = syscall.RTM_VERSION
+                       h.Type = syscall.RTM_GET
+                       h.Addrs = syscall.RTA_DST
+                       h.Pid = int32(os.Getpid())
+                       h.Seq = int32(i)
+                       p := (*syscall.RawSockaddrInet4)(unsafe.Pointer(&b[syscall.SizeofRtMsghdr]))
+                       p.Len = syscall.SizeofSockaddrInet4
+                       p.Family = syscall.AF_INET
+                       p.Addr = [4]byte{127, 0, 0, 1}
+                       if _, err := c.Write(b); err != nil {
+                               t.Error(err)
+                               return
+                       }
+               }(i + 1)
+       }
+       var wg sync.WaitGroup
+       wg.Add(N)
+       for i := 0; i < N; i++ {
+               go func() {
+                       defer wg.Done()
+                       b := make([]byte, os.Getpagesize())
+                       n, err := c.Read(b)
+                       if err != nil {
+                               t.Error(err)
+                               return
+                       }
+                       if _, err := syscall.ParseRoutingMessage(b[:n]); err != nil {
+                               t.Error(err)
+                               return
+                       }
+               }()
+       }
+       wg.Wait()
+}
diff --git a/src/net/file_linux_test.go b/src/net/file_linux_test.go
new file mode 100644 (file)
index 0000000..58f74d2
--- /dev/null
@@ -0,0 +1,97 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+       "fmt"
+       "os"
+       "sync"
+       "syscall"
+       "testing"
+       "unsafe"
+)
+
+type netlinkAddr struct {
+       PID    uint32
+       Groups uint32
+}
+
+func (a *netlinkAddr) Network() string { return "netlink" }
+func (a *netlinkAddr) String() string  { return fmt.Sprintf("%x:%x", a.PID, a.Groups) }
+
+func (a *netlinkAddr) Addr(rsa []byte) Addr {
+       if len(rsa) < syscall.SizeofSockaddrNetlink {
+               return nil
+       }
+       var addr netlinkAddr
+       b := (*[unsafe.Sizeof(addr)]byte)(unsafe.Pointer(&addr))
+       copy(b[0:4], rsa[4:8])
+       copy(b[4:8], rsa[8:12])
+       return &addr
+}
+
+func (a *netlinkAddr) Raw(addr Addr) []byte {
+       if addr, ok := addr.(*netlinkAddr); ok {
+               rsa := &syscall.RawSockaddrNetlink{Family: syscall.AF_NETLINK, Pid: addr.PID, Groups: addr.Groups}
+               return (*[unsafe.Sizeof(*rsa)]byte)(unsafe.Pointer(rsa))[:]
+       }
+       return nil
+}
+
+func TestSocketPacketConn(t *testing.T) {
+       s, err := syscall.Socket(syscall.AF_NETLINK, syscall.SOCK_RAW, syscall.NETLINK_ROUTE)
+       if err != nil {
+               t.Fatal(err)
+       }
+       lsa := syscall.SockaddrNetlink{Family: syscall.AF_NETLINK}
+       if err := syscall.Bind(s, &lsa); err != nil {
+               syscall.Close(s)
+               t.Fatal(err)
+       }
+       f := os.NewFile(uintptr(s), "netlink")
+       c, err := SocketPacketConn(f, &netlinkAddr{})
+       f.Close()
+       if err != nil {
+               t.Fatal(err)
+       }
+       defer c.Close()
+
+       const N = 3
+       dst := &netlinkAddr{PID: 0}
+       for i := 0; i < N; i++ {
+               go func() {
+                       l := syscall.NLMSG_HDRLEN + syscall.SizeofRtGenmsg
+                       b := make([]byte, l)
+                       *(*uint32)(unsafe.Pointer(&b[0:4][0])) = uint32(l)
+                       *(*uint16)(unsafe.Pointer(&b[4:6][0])) = uint16(syscall.RTM_GETLINK)
+                       *(*uint16)(unsafe.Pointer(&b[6:8][0])) = uint16(syscall.NLM_F_DUMP | syscall.NLM_F_REQUEST)
+                       *(*uint32)(unsafe.Pointer(&b[8:12][0])) = uint32(1)
+                       *(*uint32)(unsafe.Pointer(&b[12:16][0])) = uint32(0)
+                       b[16] = byte(syscall.AF_UNSPEC)
+                       if _, err := c.WriteTo(b, dst); err != nil {
+                               t.Error(err)
+                               return
+                       }
+               }()
+       }
+       var wg sync.WaitGroup
+       wg.Add(N)
+       for i := 0; i < N; i++ {
+               go func() {
+                       defer wg.Done()
+                       b := make([]byte, os.Getpagesize())
+                       n, _, err := c.ReadFrom(b)
+                       if err != nil {
+                               t.Error(err)
+                               return
+                       }
+                       if _, err := syscall.ParseNetlinkMessage(b[:n]); err != nil {
+                               t.Error(err)
+                               return
+                       }
+               }()
+       }
+       wg.Wait()
+}
index 892775a024f45667bd96f4b6206a3a6df80686a8..efe416f690f00ea54120c9f905758469193cf8c6 100644 (file)
@@ -135,3 +135,11 @@ func fileListener(f *os.File) (Listener, error) {
 func filePacketConn(f *os.File) (PacketConn, error) {
        return nil, syscall.EPLAN9
 }
+
+func socketConn(f *os.File, sa SocketAddr) (Conn, error) {
+       return nil, syscall.EPLAN9
+}
+
+func socketPacketConn(f *os.File, sa SocketAddr) (PacketConn, error) {
+       return nil, syscall.EPLAN9
+}
index 0f7460c757974d7cb05a2aaaa0ba18552c3e16d3..41ca78b437c976f246e9bdc8348c669f407c06d8 100644 (file)
@@ -11,6 +11,8 @@ import (
        "syscall"
 )
 
-func fileConn(f *os.File) (Conn, error)             { return nil, syscall.ENOPROTOOPT }
-func fileListener(f *os.File) (Listener, error)     { return nil, syscall.ENOPROTOOPT }
-func filePacketConn(f *os.File) (PacketConn, error) { return nil, syscall.ENOPROTOOPT }
+func fileConn(f *os.File) (Conn, error)                              { return nil, syscall.ENOPROTOOPT }
+func fileListener(f *os.File) (Listener, error)                      { return nil, syscall.ENOPROTOOPT }
+func filePacketConn(f *os.File) (PacketConn, error)                  { return nil, syscall.ENOPROTOOPT }
+func socketConn(f *os.File, sa SocketAddr) (Conn, error)             { return nil, syscall.ENOPROTOOPT }
+func socketPacketConn(f *os.File, sa SocketAddr) (PacketConn, error) { return nil, syscall.ENOPROTOOPT }
index 147ca1ed9575c8bffae183350c9fcb829b395055..df884d16038d27203c4278382a8f445038f7531c 100644 (file)
@@ -7,76 +7,81 @@
 package net
 
 import (
+       "internal/syscall/unix"
        "os"
        "syscall"
 )
 
-func newFileFD(f *os.File) (*netFD, error) {
-       fd, err := dupCloseOnExec(int(f.Fd()))
+func dupSocket(f *os.File) (int, error) {
+       s, err := dupCloseOnExec(int(f.Fd()))
        if err != nil {
-               return nil, err
+               return -1, err
        }
-
-       if err = syscall.SetNonblock(fd, true); err != nil {
-               closeFunc(fd)
-               return nil, os.NewSyscallError("setnonblock", err)
+       if err := syscall.SetNonblock(s, true); err != nil {
+               closeFunc(s)
+               return -1, os.NewSyscallError("setnonblock", err)
        }
+       return s, nil
+}
 
-       sotype, err := syscall.GetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_TYPE)
+func newFileFD(f *os.File, sa SocketAddr) (*netFD, error) {
+       s, err := dupSocket(f)
        if err != nil {
-               closeFunc(fd)
-               return nil, os.NewSyscallError("getsockopt", err)
+               return nil, err
        }
-
-       family := syscall.AF_UNSPEC
-       toAddr := sockaddrToTCP
-       lsa, _ := syscall.Getsockname(fd)
-       switch lsa.(type) {
-       case *syscall.SockaddrInet4:
-               family = syscall.AF_INET
-               if sotype == syscall.SOCK_DGRAM {
-                       toAddr = sockaddrToUDP
-               } else if sotype == syscall.SOCK_RAW {
-                       toAddr = sockaddrToIP
+       var laddr, raddr Addr
+       var fd *netFD
+       if sa != nil {
+               lsa := make([]byte, syscall.SizeofSockaddrAny)
+               if err := unix.Getsockname(s, lsa); err != nil {
+                       lsa = nil
+               }
+               rsa := make([]byte, syscall.SizeofSockaddrAny)
+               if err := unix.Getpeername(s, rsa); err != nil {
+                       rsa = nil
                }
-       case *syscall.SockaddrInet6:
-               family = syscall.AF_INET6
-               if sotype == syscall.SOCK_DGRAM {
-                       toAddr = sockaddrToUDP
-               } else if sotype == syscall.SOCK_RAW {
-                       toAddr = sockaddrToIP
+               laddr = sa.Addr(lsa)
+               raddr = sa.Addr(rsa)
+               fd, err = newFD(s, -1, -1, laddr.Network())
+       } else {
+               family := syscall.AF_UNSPEC
+               sotype, err := syscall.GetsockoptInt(s, syscall.SOL_SOCKET, syscall.SO_TYPE)
+               if err != nil {
+                       closeFunc(s)
+                       return nil, os.NewSyscallError("getsockopt", err)
                }
-       case *syscall.SockaddrUnix:
-               family = syscall.AF_UNIX
-               toAddr = sockaddrToUnix
-               if sotype == syscall.SOCK_DGRAM {
-                       toAddr = sockaddrToUnixgram
-               } else if sotype == syscall.SOCK_SEQPACKET {
-                       toAddr = sockaddrToUnixpacket
+               lsa, _ := syscall.Getsockname(s)
+               rsa, _ := syscall.Getpeername(s)
+               switch lsa.(type) {
+               case *syscall.SockaddrInet4:
+                       family = syscall.AF_INET
+               case *syscall.SockaddrInet6:
+                       family = syscall.AF_INET6
+               case *syscall.SockaddrUnix:
+                       family = syscall.AF_UNIX
+               default:
+                       closeFunc(s)
+                       return nil, syscall.EPROTONOSUPPORT
                }
-       default:
-               closeFunc(fd)
-               return nil, syscall.EPROTONOSUPPORT
+               fd, err = newFD(s, family, sotype, "")
+               laddr = fd.addrFunc()(lsa)
+               raddr = fd.addrFunc()(rsa)
+               fd.net = laddr.Network()
        }
-       laddr := toAddr(lsa)
-       rsa, _ := syscall.Getpeername(fd)
-       raddr := toAddr(rsa)
-
-       netfd, err := newFD(fd, family, sotype, laddr.Network())
        if err != nil {
-               closeFunc(fd)
+               closeFunc(s)
                return nil, err
        }
-       if err := netfd.init(); err != nil {
-               netfd.Close()
+       if err := fd.init(); err != nil {
+               fd.Close()
                return nil, err
        }
-       netfd.setAddr(laddr, raddr)
-       return netfd, nil
+       fd.setAddr(laddr, raddr)
+       return fd, nil
 }
 
 func fileConn(f *os.File) (Conn, error) {
-       fd, err := newFileFD(f)
+       fd, err := newFileFD(f, nil)
        if err != nil {
                return nil, err
        }
@@ -95,7 +100,7 @@ func fileConn(f *os.File) (Conn, error) {
 }
 
 func fileListener(f *os.File) (Listener, error) {
-       fd, err := newFileFD(f)
+       fd, err := newFileFD(f, nil)
        if err != nil {
                return nil, err
        }
@@ -110,7 +115,7 @@ func fileListener(f *os.File) (Listener, error) {
 }
 
 func filePacketConn(f *os.File) (PacketConn, error) {
-       fd, err := newFileFD(f)
+       fd, err := newFileFD(f, nil)
        if err != nil {
                return nil, err
        }
@@ -125,3 +130,55 @@ func filePacketConn(f *os.File) (PacketConn, error) {
        fd.Close()
        return nil, syscall.EINVAL
 }
+
+func socketConn(f *os.File, sa SocketAddr) (Conn, error) {
+       fd, err := newFileFD(f, sa)
+       if err != nil {
+               return nil, err
+       }
+       return &socketFile{conn: conn{fd}, SocketAddr: sa}, nil
+}
+
+func socketPacketConn(f *os.File, sa SocketAddr) (PacketConn, error) {
+       fd, err := newFileFD(f, sa)
+       if err != nil {
+               return nil, err
+       }
+       return &socketFile{conn: conn{fd}, SocketAddr: sa}, nil
+}
+
+var (
+       _ Conn       = &socketFile{}
+       _ PacketConn = &socketFile{}
+)
+
+// A socketFile is a placeholder that holds a user-specified socket
+// descriptor and a profile of socket address encoding.
+// It implements both Conn and PacketConn interfaces.
+type socketFile struct {
+       conn
+       SocketAddr
+}
+
+func (c *socketFile) ReadFrom(b []byte) (int, Addr, error) {
+       if !c.ok() {
+               return 0, nil, syscall.EINVAL
+       }
+       from := make([]byte, syscall.SizeofSockaddrAny)
+       n, err := c.fd.recvFrom(b, 0, from)
+       if err != nil {
+               return n, nil, &OpError{Op: "read", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err}
+       }
+       return n, c.SocketAddr.Addr(from), nil
+}
+
+func (c *socketFile) WriteTo(b []byte, addr Addr) (int, error) {
+       if !c.ok() {
+               return 0, syscall.EINVAL
+       }
+       n, err := c.fd.sendTo(b, 0, c.SocketAddr.Raw(addr))
+       if err != nil {
+               return n, &OpError{Op: "write", Net: c.fd.net, Source: c.fd.laddr, Addr: c.fd.raddr, Err: err}
+       }
+       return n, nil
+}
index 241fa17617c2bf20073d275c479e01ed288dde58..1ed72d5bd433ba73acaeeadece308b5a7afd4a93 100644 (file)
@@ -23,3 +23,13 @@ func filePacketConn(f *os.File) (PacketConn, error) {
        // TODO: Implement this
        return nil, syscall.EWINDOWS
 }
+
+func socketConn(f *os.File, sa SocketAddr) (Conn, error) {
+       // TODO: Implement this
+       return nil, syscall.EWINDOWS
+}
+
+func socketPacketConn(f *os.File, sa SocketAddr) (PacketConn, error) {
+       // TODO: Implement this
+       return nil, syscall.EWINDOWS
+}