From 917c33fe8672116b04848cf11545296789cafd3b Mon Sep 17 00:00:00 2001 From: Mikio Hara Date: Tue, 20 Feb 2018 10:21:36 +0900 Subject: [PATCH] net: improve test coverage for syscall.RawConn An application using syscall.RawConn in a particular way must take account of the operating system or platform-dependent behavior. This change consolidates duplicate code and improves the test coverage for applications that use socket options. Change-Id: Ie42340ac5373875cf1fd9123df0e99a1e7ac280f Reviewed-on: https://go-review.googlesource.com/95335 Reviewed-by: Brad Fitzpatrick --- src/net/rawconn_stub_test.go | 24 +++++ src/net/rawconn_test.go | 142 +++++++++++++++++++++++++++ src/net/rawconn_unix_test.go | 169 +++++++++++--------------------- src/net/rawconn_windows_test.go | 108 ++++++++------------ 4 files changed, 266 insertions(+), 177 deletions(-) create mode 100644 src/net/rawconn_stub_test.go create mode 100644 src/net/rawconn_test.go diff --git a/src/net/rawconn_stub_test.go b/src/net/rawconn_stub_test.go new file mode 100644 index 0000000000..391b4d188e --- /dev/null +++ b/src/net/rawconn_stub_test.go @@ -0,0 +1,24 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build nacl plan9 + +package net + +import ( + "errors" + "syscall" +) + +func readRawConn(c syscall.RawConn, b []byte) (int, error) { + return 0, errors.New("not supported") +} + +func writeRawConn(c syscall.RawConn, b []byte) error { + return errors.New("not supported") +} + +func controlRawConn(c syscall.RawConn, addr Addr) error { + return errors.New("not supported") +} diff --git a/src/net/rawconn_test.go b/src/net/rawconn_test.go new file mode 100644 index 0000000000..287282f117 --- /dev/null +++ b/src/net/rawconn_test.go @@ -0,0 +1,142 @@ +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package net + +import ( + "bytes" + "runtime" + "testing" +) + +func TestRawConnReadWrite(t *testing.T) { + switch runtime.GOOS { + case "nacl", "plan9", "windows": + t.Skipf("not supported on %s", runtime.GOOS) + } + + t.Run("TCP", func(t *testing.T) { + handler := func(ls *localServer, ln Listener) { + c, err := ln.Accept() + if err != nil { + t.Error(err) + return + } + defer c.Close() + + cc, err := ln.(*TCPListener).SyscallConn() + if err != nil { + t.Fatal(err) + } + called := false + op := func(uintptr) bool { + called = true + return true + } + err = cc.Write(op) + if err == nil { + t.Error("Write should return an error") + } + if called { + t.Error("Write shouldn't call op") + } + called = false + err = cc.Read(op) + if err == nil { + t.Error("Read should return an error") + } + if called { + t.Error("Read shouldn't call op") + } + + var b [32]byte + n, err := c.Read(b[:]) + if err != nil { + t.Error(err) + return + } + if _, err := c.Write(b[:n]); err != nil { + t.Error(err) + return + } + } + ls, err := newLocalServer("tcp") + if err != nil { + t.Fatal(err) + } + defer ls.teardown() + if err := ls.buildup(handler); err != nil { + t.Fatal(err) + } + + c, err := Dial(ls.Listener.Addr().Network(), ls.Listener.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer c.Close() + + cc, err := c.(*TCPConn).SyscallConn() + if err != nil { + t.Fatal(err) + } + data := []byte("HELLO-R-U-THERE") + if err := writeRawConn(cc, data); err != nil { + t.Fatal(err) + } + var b [32]byte + n, err := readRawConn(cc, b[:]) + if err != nil { + t.Fatal(err) + } + if bytes.Compare(b[:n], data) != 0 { + t.Fatalf("got %q; want %q", b[:n], data) + } + }) +} + +func TestRawConnControl(t *testing.T) { + switch runtime.GOOS { + case "nacl", "plan9": + t.Skipf("not supported on %s", runtime.GOOS) + } + + t.Run("TCP", func(t *testing.T) { + ln, err := newLocalListener("tcp") + if err != nil { + t.Fatal(err) + } + defer ln.Close() + + cc1, err := ln.(*TCPListener).SyscallConn() + if err != nil { + t.Fatal(err) + } + if err := controlRawConn(cc1, ln.Addr()); err != nil { + t.Fatal(err) + } + + c, err := Dial(ln.Addr().Network(), ln.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer c.Close() + + cc2, err := c.(*TCPConn).SyscallConn() + if err != nil { + t.Fatal(err) + } + if err := controlRawConn(cc2, c.LocalAddr()); err != nil { + t.Fatal(err) + } + + ln.Close() + if err := controlRawConn(cc1, ln.Addr()); err == nil { + t.Fatal("Control after Close should fail") + } + c.Close() + if err := controlRawConn(cc2, c.LocalAddr()); err == nil { + t.Fatal("Control after Close should fail") + } + }) +} diff --git a/src/net/rawconn_unix_test.go b/src/net/rawconn_unix_test.go index 913ad86595..2fe4d2c6ba 100644 --- a/src/net/rawconn_unix_test.go +++ b/src/net/rawconn_unix_test.go @@ -6,139 +6,86 @@ package net -import ( - "bytes" - "syscall" - "testing" -) - -func TestRawConn(t *testing.T) { - handler := func(ls *localServer, ln Listener) { - c, err := ln.Accept() - if err != nil { - t.Error(err) - return - } - defer c.Close() - var b [32]byte - n, err := c.Read(b[:]) - if err != nil { - t.Error(err) - return - } - if _, err := c.Write(b[:n]); err != nil { - t.Error(err) - return - } - } - ls, err := newLocalServer("tcp") - if err != nil { - t.Fatal(err) - } - defer ls.teardown() - if err := ls.buildup(handler); err != nil { - t.Fatal(err) - } - - c, err := Dial(ls.Listener.Addr().Network(), ls.Listener.Addr().String()) - if err != nil { - t.Fatal(err) - } - defer c.Close() - cc, err := c.(*TCPConn).SyscallConn() - if err != nil { - t.Fatal(err) - } +import "syscall" +func readRawConn(c syscall.RawConn, b []byte) (int, error) { var operr error - data := []byte("HELLO-R-U-THERE") - err = cc.Write(func(s uintptr) bool { - _, operr = syscall.Write(int(s), data) + var n int + err := c.Read(func(s uintptr) bool { + n, operr = syscall.Read(int(s), b) if operr == syscall.EAGAIN { return false } return true }) - if err != nil || operr != nil { - t.Fatal(err, operr) + if err != nil { + return n, err } + if operr != nil { + return n, operr + } + return n, nil +} - var nr int - var b [32]byte - err = cc.Read(func(s uintptr) bool { - nr, operr = syscall.Read(int(s), b[:]) +func writeRawConn(c syscall.RawConn, b []byte) error { + var operr error + err := c.Write(func(s uintptr) bool { + _, operr = syscall.Write(int(s), b) if operr == syscall.EAGAIN { return false } return true }) - if err != nil || operr != nil { - t.Fatal(err, operr) - } - if bytes.Compare(b[:nr], data) != 0 { - t.Fatalf("got %#v; want %#v", b[:nr], data) - } - - fn := func(s uintptr) { - operr = syscall.SetsockoptInt(int(s), syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1) - } - err = cc.Control(fn) - if err != nil || operr != nil { - t.Fatal(err, operr) - } - c.Close() - err = cc.Control(fn) - if err == nil { - t.Fatal("should fail") - } -} - -func TestRawConnListener(t *testing.T) { - ln, err := newLocalListener("tcp") if err != nil { - t.Fatal(err) - } - defer ln.Close() - - cc, err := ln.(*TCPListener).SyscallConn() - if err != nil { - t.Fatal(err) - } - - called := false - op := func(uintptr) bool { - called = true - return true - } - - err = cc.Write(op) - if err == nil { - t.Error("Write should return an error") - } - if called { - t.Error("Write shouldn't call op") + return err } - - called = false - err = cc.Read(op) - if err == nil { - t.Error("Read should return an error") - } - if called { - t.Error("Read shouldn't call op") + if operr != nil { + return operr } + return nil +} +func controlRawConn(c syscall.RawConn, addr Addr) error { var operr error fn := func(s uintptr) { _, operr = syscall.GetsockoptInt(int(s), syscall.SOL_SOCKET, syscall.SO_REUSEADDR) + if operr != nil { + return + } + switch addr := addr.(type) { + case *TCPAddr: + // There's no guarantee that IP-level socket + // options work well with dual stack sockets. + // A simple solution would be to take a look + // at the bound address to the raw connection + // and to classify the address family of the + // underlying socket by the bound address: + // + // - When IP.To16() != nil and IP.To4() == nil, + // we can assume that the raw connection + // consists of an IPv6 socket using only + // IPv6 addresses. + // + // - When IP.To16() == nil and IP.To4() != nil, + // the raw connection consists of an IPv4 + // socket using only IPv4 addresses. + // + // - Otherwise, the raw connection is a dual + // stack socket, an IPv6 socket using IPv6 + // addresses including IPv4-mapped or + // IPv4-embedded IPv6 addresses. + if addr.IP.To16() != nil && addr.IP.To4() == nil { + operr = syscall.SetsockoptInt(int(s), syscall.IPPROTO_IPV6, syscall.IPV6_UNICAST_HOPS, 1) + } else if addr.IP.To16() == nil && addr.IP.To4() != nil { + operr = syscall.SetsockoptInt(int(s), syscall.IPPROTO_IP, syscall.IP_TTL, 1) + } + } } - err = cc.Control(fn) - if err != nil || operr != nil { - t.Fatal(err, operr) + if err := c.Control(fn); err != nil { + return err } - ln.Close() - err = cc.Control(fn) - if err == nil { - t.Fatal("Control after Close should fail") + if operr != nil { + return operr } + return nil } diff --git a/src/net/rawconn_windows_test.go b/src/net/rawconn_windows_test.go index 2ee12c3596..1b6777bb17 100644 --- a/src/net/rawconn_windows_test.go +++ b/src/net/rawconn_windows_test.go @@ -6,84 +6,60 @@ package net import ( "syscall" - "testing" "unsafe" ) -func TestRawConn(t *testing.T) { - c, err := newLocalPacketListener("udp") - if err != nil { - t.Fatal(err) - } - defer c.Close() - cc, err := c.(*UDPConn).SyscallConn() - if err != nil { - t.Fatal(err) - } - - var operr error - fn := func(s uintptr) { - operr = syscall.SetsockoptInt(syscall.Handle(s), syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1) - } - err = cc.Control(fn) - if err != nil || operr != nil { - t.Fatal(err, operr) - } - c.Close() - err = cc.Control(fn) - if err == nil { - t.Fatal("should fail") - } +func readRawConn(c syscall.RawConn, b []byte) (int, error) { + return 0, syscall.EWINDOWS } -func TestRawConnListener(t *testing.T) { - ln, err := newLocalListener("tcp") - if err != nil { - t.Fatal(err) - } - defer ln.Close() - - cc, err := ln.(*TCPListener).SyscallConn() - if err != nil { - t.Fatal(err) - } - - called := false - op := func(uintptr) bool { - called = true - return true - } - - err = cc.Write(op) - if err == nil { - t.Error("Write should return an error") - } - if called { - t.Error("Write shouldn't call op") - } - - called = false - err = cc.Read(op) - if err == nil { - t.Error("Read should return an error") - } - if called { - t.Error("Read shouldn't call op") - } +func writeRawConn(c syscall.RawConn, b []byte) error { + return syscall.EWINDOWS +} +func controlRawConn(c syscall.RawConn, addr Addr) error { var operr error fn := func(s uintptr) { var v, l int32 l = int32(unsafe.Sizeof(v)) operr = syscall.Getsockopt(syscall.Handle(s), syscall.SOL_SOCKET, syscall.SO_REUSEADDR, (*byte)(unsafe.Pointer(&v)), &l) + if operr != nil { + return + } + switch addr := addr.(type) { + case *TCPAddr: + // There's no guarantee that IP-level socket + // options work well with dual stack sockets. + // A simple solution would be to take a look + // at the bound address to the raw connection + // and to classify the address family of the + // underlying socket by the bound address: + // + // - When IP.To16() != nil and IP.To4() == nil, + // we can assume that the raw connection + // consists of an IPv6 socket using only + // IPv6 addresses. + // + // - When IP.To16() == nil and IP.To4() != nil, + // the raw connection consists of an IPv4 + // socket using only IPv4 addresses. + // + // - Otherwise, the raw connection is a dual + // stack socket, an IPv6 socket using IPv6 + // addresses including IPv4-mapped or + // IPv4-embedded IPv6 addresses. + if addr.IP.To16() != nil && addr.IP.To4() == nil { + operr = syscall.SetsockoptInt(syscall.Handle(s), syscall.IPPROTO_IPV6, syscall.IPV6_UNICAST_HOPS, 1) + } else if addr.IP.To16() == nil && addr.IP.To4() != nil { + operr = syscall.SetsockoptInt(syscall.Handle(s), syscall.IPPROTO_IP, syscall.IP_TTL, 1) + } + } } - err = cc.Control(fn) - if err != nil || operr != nil { - t.Fatal(err, operr) + if err := c.Control(fn); err != nil { + return err } - ln.Close() - err = cc.Control(fn) - if err == nil { - t.Fatal("Control after Close should fail") + if operr != nil { + return operr } + return nil } -- 2.50.0