]> Cypherpunks repositories - gostls13.git/commitdiff
net: improve test coverage for syscall.RawConn
authorMikio Hara <mikioh.mikioh@gmail.com>
Tue, 20 Feb 2018 01:21:36 +0000 (10:21 +0900)
committerMikio Hara <mikioh.mikioh@gmail.com>
Sat, 31 Mar 2018 23:41:59 +0000 (23:41 +0000)
An application using syscall.RawConn in a particular way must take
account of the operating system or platform-dependent behavior.
This change consolidates duplicate code and improves the test coverage
for applications that use socket options.

Change-Id: Ie42340ac5373875cf1fd9123df0e99a1e7ac280f
Reviewed-on: https://go-review.googlesource.com/95335
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
src/net/rawconn_stub_test.go [new file with mode: 0644]
src/net/rawconn_test.go [new file with mode: 0644]
src/net/rawconn_unix_test.go
src/net/rawconn_windows_test.go

diff --git a/src/net/rawconn_stub_test.go b/src/net/rawconn_stub_test.go
new file mode 100644 (file)
index 0000000..391b4d1
--- /dev/null
@@ -0,0 +1,24 @@
+// Copyright 2018 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build nacl plan9
+
+package net
+
+import (
+       "errors"
+       "syscall"
+)
+
+func readRawConn(c syscall.RawConn, b []byte) (int, error) {
+       return 0, errors.New("not supported")
+}
+
+func writeRawConn(c syscall.RawConn, b []byte) error {
+       return errors.New("not supported")
+}
+
+func controlRawConn(c syscall.RawConn, addr Addr) error {
+       return errors.New("not supported")
+}
diff --git a/src/net/rawconn_test.go b/src/net/rawconn_test.go
new file mode 100644 (file)
index 0000000..287282f
--- /dev/null
@@ -0,0 +1,142 @@
+// Copyright 2018 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+       "bytes"
+       "runtime"
+       "testing"
+)
+
+func TestRawConnReadWrite(t *testing.T) {
+       switch runtime.GOOS {
+       case "nacl", "plan9", "windows":
+               t.Skipf("not supported on %s", runtime.GOOS)
+       }
+
+       t.Run("TCP", func(t *testing.T) {
+               handler := func(ls *localServer, ln Listener) {
+                       c, err := ln.Accept()
+                       if err != nil {
+                               t.Error(err)
+                               return
+                       }
+                       defer c.Close()
+
+                       cc, err := ln.(*TCPListener).SyscallConn()
+                       if err != nil {
+                               t.Fatal(err)
+                       }
+                       called := false
+                       op := func(uintptr) bool {
+                               called = true
+                               return true
+                       }
+                       err = cc.Write(op)
+                       if err == nil {
+                               t.Error("Write should return an error")
+                       }
+                       if called {
+                               t.Error("Write shouldn't call op")
+                       }
+                       called = false
+                       err = cc.Read(op)
+                       if err == nil {
+                               t.Error("Read should return an error")
+                       }
+                       if called {
+                               t.Error("Read shouldn't call op")
+                       }
+
+                       var b [32]byte
+                       n, err := c.Read(b[:])
+                       if err != nil {
+                               t.Error(err)
+                               return
+                       }
+                       if _, err := c.Write(b[:n]); err != nil {
+                               t.Error(err)
+                               return
+                       }
+               }
+               ls, err := newLocalServer("tcp")
+               if err != nil {
+                       t.Fatal(err)
+               }
+               defer ls.teardown()
+               if err := ls.buildup(handler); err != nil {
+                       t.Fatal(err)
+               }
+
+               c, err := Dial(ls.Listener.Addr().Network(), ls.Listener.Addr().String())
+               if err != nil {
+                       t.Fatal(err)
+               }
+               defer c.Close()
+
+               cc, err := c.(*TCPConn).SyscallConn()
+               if err != nil {
+                       t.Fatal(err)
+               }
+               data := []byte("HELLO-R-U-THERE")
+               if err := writeRawConn(cc, data); err != nil {
+                       t.Fatal(err)
+               }
+               var b [32]byte
+               n, err := readRawConn(cc, b[:])
+               if err != nil {
+                       t.Fatal(err)
+               }
+               if bytes.Compare(b[:n], data) != 0 {
+                       t.Fatalf("got %q; want %q", b[:n], data)
+               }
+       })
+}
+
+func TestRawConnControl(t *testing.T) {
+       switch runtime.GOOS {
+       case "nacl", "plan9":
+               t.Skipf("not supported on %s", runtime.GOOS)
+       }
+
+       t.Run("TCP", func(t *testing.T) {
+               ln, err := newLocalListener("tcp")
+               if err != nil {
+                       t.Fatal(err)
+               }
+               defer ln.Close()
+
+               cc1, err := ln.(*TCPListener).SyscallConn()
+               if err != nil {
+                       t.Fatal(err)
+               }
+               if err := controlRawConn(cc1, ln.Addr()); err != nil {
+                       t.Fatal(err)
+               }
+
+               c, err := Dial(ln.Addr().Network(), ln.Addr().String())
+               if err != nil {
+                       t.Fatal(err)
+               }
+               defer c.Close()
+
+               cc2, err := c.(*TCPConn).SyscallConn()
+               if err != nil {
+                       t.Fatal(err)
+               }
+               if err := controlRawConn(cc2, c.LocalAddr()); err != nil {
+                       t.Fatal(err)
+               }
+
+               ln.Close()
+               if err := controlRawConn(cc1, ln.Addr()); err == nil {
+                       t.Fatal("Control after Close should fail")
+               }
+               c.Close()
+               if err := controlRawConn(cc2, c.LocalAddr()); err == nil {
+                       t.Fatal("Control after Close should fail")
+               }
+       })
+}
index 913ad8659510f531a9f2d91381361a63b6b45f4c..2fe4d2c6bace69d78bce014ee42e7c71f3a489f4 100644 (file)
 
 package net
 
-import (
-       "bytes"
-       "syscall"
-       "testing"
-)
-
-func TestRawConn(t *testing.T) {
-       handler := func(ls *localServer, ln Listener) {
-               c, err := ln.Accept()
-               if err != nil {
-                       t.Error(err)
-                       return
-               }
-               defer c.Close()
-               var b [32]byte
-               n, err := c.Read(b[:])
-               if err != nil {
-                       t.Error(err)
-                       return
-               }
-               if _, err := c.Write(b[:n]); err != nil {
-                       t.Error(err)
-                       return
-               }
-       }
-       ls, err := newLocalServer("tcp")
-       if err != nil {
-               t.Fatal(err)
-       }
-       defer ls.teardown()
-       if err := ls.buildup(handler); err != nil {
-               t.Fatal(err)
-       }
-
-       c, err := Dial(ls.Listener.Addr().Network(), ls.Listener.Addr().String())
-       if err != nil {
-               t.Fatal(err)
-       }
-       defer c.Close()
-       cc, err := c.(*TCPConn).SyscallConn()
-       if err != nil {
-               t.Fatal(err)
-       }
+import "syscall"
 
+func readRawConn(c syscall.RawConn, b []byte) (int, error) {
        var operr error
-       data := []byte("HELLO-R-U-THERE")
-       err = cc.Write(func(s uintptr) bool {
-               _, operr = syscall.Write(int(s), data)
+       var n int
+       err := c.Read(func(s uintptr) bool {
+               n, operr = syscall.Read(int(s), b)
                if operr == syscall.EAGAIN {
                        return false
                }
                return true
        })
-       if err != nil || operr != nil {
-               t.Fatal(err, operr)
+       if err != nil {
+               return n, err
        }
+       if operr != nil {
+               return n, operr
+       }
+       return n, nil
+}
 
-       var nr int
-       var b [32]byte
-       err = cc.Read(func(s uintptr) bool {
-               nr, operr = syscall.Read(int(s), b[:])
+func writeRawConn(c syscall.RawConn, b []byte) error {
+       var operr error
+       err := c.Write(func(s uintptr) bool {
+               _, operr = syscall.Write(int(s), b)
                if operr == syscall.EAGAIN {
                        return false
                }
                return true
        })
-       if err != nil || operr != nil {
-               t.Fatal(err, operr)
-       }
-       if bytes.Compare(b[:nr], data) != 0 {
-               t.Fatalf("got %#v; want %#v", b[:nr], data)
-       }
-
-       fn := func(s uintptr) {
-               operr = syscall.SetsockoptInt(int(s), syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1)
-       }
-       err = cc.Control(fn)
-       if err != nil || operr != nil {
-               t.Fatal(err, operr)
-       }
-       c.Close()
-       err = cc.Control(fn)
-       if err == nil {
-               t.Fatal("should fail")
-       }
-}
-
-func TestRawConnListener(t *testing.T) {
-       ln, err := newLocalListener("tcp")
        if err != nil {
-               t.Fatal(err)
-       }
-       defer ln.Close()
-
-       cc, err := ln.(*TCPListener).SyscallConn()
-       if err != nil {
-               t.Fatal(err)
-       }
-
-       called := false
-       op := func(uintptr) bool {
-               called = true
-               return true
-       }
-
-       err = cc.Write(op)
-       if err == nil {
-               t.Error("Write should return an error")
-       }
-       if called {
-               t.Error("Write shouldn't call op")
+               return err
        }
-
-       called = false
-       err = cc.Read(op)
-       if err == nil {
-               t.Error("Read should return an error")
-       }
-       if called {
-               t.Error("Read shouldn't call op")
+       if operr != nil {
+               return operr
        }
+       return nil
+}
 
+func controlRawConn(c syscall.RawConn, addr Addr) error {
        var operr error
        fn := func(s uintptr) {
                _, operr = syscall.GetsockoptInt(int(s), syscall.SOL_SOCKET, syscall.SO_REUSEADDR)
+               if operr != nil {
+                       return
+               }
+               switch addr := addr.(type) {
+               case *TCPAddr:
+                       // There's no guarantee that IP-level socket
+                       // options work well with dual stack sockets.
+                       // A simple solution would be to take a look
+                       // at the bound address to the raw connection
+                       // and to classify the address family of the
+                       // underlying socket by the bound address:
+                       //
+                       // - When IP.To16() != nil and IP.To4() == nil,
+                       //   we can assume that the raw connection
+                       //   consists of an IPv6 socket using only
+                       //   IPv6 addresses.
+                       //
+                       // - When IP.To16() == nil and IP.To4() != nil,
+                       //   the raw connection consists of an IPv4
+                       //   socket using only IPv4 addresses.
+                       //
+                       // - Otherwise, the raw connection is a dual
+                       //   stack socket, an IPv6 socket using IPv6
+                       //   addresses including IPv4-mapped or
+                       //   IPv4-embedded IPv6 addresses.
+                       if addr.IP.To16() != nil && addr.IP.To4() == nil {
+                               operr = syscall.SetsockoptInt(int(s), syscall.IPPROTO_IPV6, syscall.IPV6_UNICAST_HOPS, 1)
+                       } else if addr.IP.To16() == nil && addr.IP.To4() != nil {
+                               operr = syscall.SetsockoptInt(int(s), syscall.IPPROTO_IP, syscall.IP_TTL, 1)
+                       }
+               }
        }
-       err = cc.Control(fn)
-       if err != nil || operr != nil {
-               t.Fatal(err, operr)
+       if err := c.Control(fn); err != nil {
+               return err
        }
-       ln.Close()
-       err = cc.Control(fn)
-       if err == nil {
-               t.Fatal("Control after Close should fail")
+       if operr != nil {
+               return operr
        }
+       return nil
 }
index 2ee12c35963d8f62cfcd6224b5f4e1e5fda546db..1b6777bb172012b6f0d05ae350f9e509480457ee 100644 (file)
@@ -6,84 +6,60 @@ package net
 
 import (
        "syscall"
-       "testing"
        "unsafe"
 )
 
-func TestRawConn(t *testing.T) {
-       c, err := newLocalPacketListener("udp")
-       if err != nil {
-               t.Fatal(err)
-       }
-       defer c.Close()
-       cc, err := c.(*UDPConn).SyscallConn()
-       if err != nil {
-               t.Fatal(err)
-       }
-
-       var operr error
-       fn := func(s uintptr) {
-               operr = syscall.SetsockoptInt(syscall.Handle(s), syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1)
-       }
-       err = cc.Control(fn)
-       if err != nil || operr != nil {
-               t.Fatal(err, operr)
-       }
-       c.Close()
-       err = cc.Control(fn)
-       if err == nil {
-               t.Fatal("should fail")
-       }
+func readRawConn(c syscall.RawConn, b []byte) (int, error) {
+       return 0, syscall.EWINDOWS
 }
 
-func TestRawConnListener(t *testing.T) {
-       ln, err := newLocalListener("tcp")
-       if err != nil {
-               t.Fatal(err)
-       }
-       defer ln.Close()
-
-       cc, err := ln.(*TCPListener).SyscallConn()
-       if err != nil {
-               t.Fatal(err)
-       }
-
-       called := false
-       op := func(uintptr) bool {
-               called = true
-               return true
-       }
-
-       err = cc.Write(op)
-       if err == nil {
-               t.Error("Write should return an error")
-       }
-       if called {
-               t.Error("Write shouldn't call op")
-       }
-
-       called = false
-       err = cc.Read(op)
-       if err == nil {
-               t.Error("Read should return an error")
-       }
-       if called {
-               t.Error("Read shouldn't call op")
-       }
+func writeRawConn(c syscall.RawConn, b []byte) error {
+       return syscall.EWINDOWS
+}
 
+func controlRawConn(c syscall.RawConn, addr Addr) error {
        var operr error
        fn := func(s uintptr) {
                var v, l int32
                l = int32(unsafe.Sizeof(v))
                operr = syscall.Getsockopt(syscall.Handle(s), syscall.SOL_SOCKET, syscall.SO_REUSEADDR, (*byte)(unsafe.Pointer(&v)), &l)
+               if operr != nil {
+                       return
+               }
+               switch addr := addr.(type) {
+               case *TCPAddr:
+                       // There's no guarantee that IP-level socket
+                       // options work well with dual stack sockets.
+                       // A simple solution would be to take a look
+                       // at the bound address to the raw connection
+                       // and to classify the address family of the
+                       // underlying socket by the bound address:
+                       //
+                       // - When IP.To16() != nil and IP.To4() == nil,
+                       //   we can assume that the raw connection
+                       //   consists of an IPv6 socket using only
+                       //   IPv6 addresses.
+                       //
+                       // - When IP.To16() == nil and IP.To4() != nil,
+                       //   the raw connection consists of an IPv4
+                       //   socket using only IPv4 addresses.
+                       //
+                       // - Otherwise, the raw connection is a dual
+                       //   stack socket, an IPv6 socket using IPv6
+                       //   addresses including IPv4-mapped or
+                       //   IPv4-embedded IPv6 addresses.
+                       if addr.IP.To16() != nil && addr.IP.To4() == nil {
+                               operr = syscall.SetsockoptInt(syscall.Handle(s), syscall.IPPROTO_IPV6, syscall.IPV6_UNICAST_HOPS, 1)
+                       } else if addr.IP.To16() == nil && addr.IP.To4() != nil {
+                               operr = syscall.SetsockoptInt(syscall.Handle(s), syscall.IPPROTO_IP, syscall.IP_TTL, 1)
+                       }
+               }
        }
-       err = cc.Control(fn)
-       if err != nil || operr != nil {
-               t.Fatal(err, operr)
+       if err := c.Control(fn); err != nil {
+               return err
        }
-       ln.Close()
-       err = cc.Control(fn)
-       if err == nil {
-               t.Fatal("Control after Close should fail")
+       if operr != nil {
+               return operr
        }
+       return nil
 }