]> Cypherpunks repositories - gostls13.git/commitdiff
net: add helpers for server testing
authorMikio Hara <mikioh.mikioh@gmail.com>
Sun, 19 Apr 2015 14:42:11 +0000 (23:42 +0900)
committerMikio Hara <mikioh.mikioh@gmail.com>
Tue, 21 Apr 2015 01:00:11 +0000 (01:00 +0000)
Also moves a few server test helpers into mockserver_test.go.

Change-Id: I5a95c9bc6f0c4683751bcca77e26a8586a377466
Reviewed-on: https://go-review.googlesource.com/9106
Reviewed-by: Ian Lance Taylor <iant@golang.org>
src/net/conn_test.go
src/net/error_test.go
src/net/mockserver_test.go
src/net/protoconn_test.go
src/net/tcp_test.go

index c5d11b99ee0cb629c4febd700b4ddab0747b84e3..cb055d880bc02f015877ffd07f30cc63d6a2dbae 100644 (file)
@@ -18,8 +18,7 @@ import (
 const someTimeout = 10 * time.Second
 
 func TestConnAndListener(t *testing.T) {
-       handler := func(ls *localServer, ln Listener) { transponder(t, ln) }
-       for _, network := range []string{"tcp", "unix", "unixpacket"} {
+       for i, network := range []string{"tcp", "unix", "unixpacket"} {
                if !testableNetwork(network) {
                        t.Logf("skipping %s test", network)
                        continue
@@ -30,6 +29,8 @@ func TestConnAndListener(t *testing.T) {
                        t.Fatalf("Listen failed: %v", err)
                }
                defer ls.teardown()
+               ch := make(chan error, 1)
+               handler := func(ls *localServer, ln Listener) { transponder(ln, ch) }
                if err := ls.buildup(handler); err != nil {
                        t.Fatal(err)
                }
@@ -56,40 +57,9 @@ func TestConnAndListener(t *testing.T) {
                if _, err := c.Read(rb); err != nil {
                        t.Fatalf("Conn.Read failed: %v", err)
                }
-       }
-}
-
-func transponder(t *testing.T, ln Listener) {
-       switch ln := ln.(type) {
-       case *TCPListener:
-               ln.SetDeadline(time.Now().Add(someTimeout))
-       case *UnixListener:
-               ln.SetDeadline(time.Now().Add(someTimeout))
-       }
-       c, err := ln.Accept()
-       if err != nil {
-               t.Errorf("Listener.Accept failed: %v", err)
-               return
-       }
-       defer c.Close()
 
-       network := ln.Addr().Network()
-       if c.LocalAddr().Network() != network || c.LocalAddr().Network() != network {
-               t.Errorf("got %v->%v; expected %v->%v", c.LocalAddr().Network(), c.RemoteAddr().Network(), network, network)
-               return
-       }
-       c.SetDeadline(time.Now().Add(someTimeout))
-       c.SetReadDeadline(time.Now().Add(someTimeout))
-       c.SetWriteDeadline(time.Now().Add(someTimeout))
-
-       b := make([]byte, 128)
-       n, err := c.Read(b)
-       if err != nil {
-               t.Errorf("Conn.Read failed: %v", err)
-               return
-       }
-       if _, err := c.Write(b[:n]); err != nil {
-               t.Errorf("Conn.Write failed: %v", err)
-               return
+               for err := range ch {
+                       t.Errorf("#%d: %v", i, err)
+               }
        }
 }
index ebb395d8f9b3d80e541cd077885ea1f648badf06..7c12cba76294a3d52acc55a7acb09e524a89cdbd 100644 (file)
@@ -317,6 +317,11 @@ second:
                return nil
        }
        switch err := nestedErr.(type) {
+       case *AddrError, addrinfoErrno, *DNSError, InvalidAddrError, *ParseError, *timeoutError, UnknownNetworkError:
+               return nil
+       case *DNSConfigError:
+               nestedErr = err.Err
+               goto third
        case *os.SyscallError:
                nestedErr = err.Err
                goto third
index 3a21452f714342ab4332105a2ecbdec44f464cc9..07ffd63386305a8317d08cab7424c72ca9fd86f2 100644 (file)
@@ -6,10 +6,26 @@ package net
 
 import (
        "fmt"
+       "io/ioutil"
        "os"
        "sync"
+       "time"
 )
 
+// testUnixAddr uses ioutil.TempFile to get a name that is unique.
+// It also uses /tmp directory in case it is prohibited to create UNIX
+// sockets in TMPDIR.
+func testUnixAddr() string {
+       f, err := ioutil.TempFile("", "nettest")
+       if err != nil {
+               panic(err)
+       }
+       addr := f.Name()
+       f.Close()
+       os.Remove(addr)
+       return addr
+}
+
 func newLocalListener(network string) (Listener, error) {
        switch network {
        case "tcp", "tcp4", "tcp6":
@@ -155,3 +171,212 @@ func newDualStackServer(lns []streamListener) (*dualStackServer, error) {
        }
        return dss, nil
 }
+
+func transponder(ln Listener, ch chan<- error) {
+       defer close(ch)
+
+       switch ln := ln.(type) {
+       case *TCPListener:
+               ln.SetDeadline(time.Now().Add(someTimeout))
+       case *UnixListener:
+               ln.SetDeadline(time.Now().Add(someTimeout))
+       }
+       c, err := ln.Accept()
+       if err != nil {
+               if perr := parseAcceptError(err); perr != nil {
+                       ch <- perr
+               }
+               ch <- err
+               return
+       }
+       defer c.Close()
+
+       network := ln.Addr().Network()
+       if c.LocalAddr().Network() != network || c.LocalAddr().Network() != network {
+               ch <- fmt.Errorf("got %v->%v; expected %v->%v", c.LocalAddr().Network(), c.RemoteAddr().Network(), network, network)
+               return
+       }
+       c.SetDeadline(time.Now().Add(someTimeout))
+       c.SetReadDeadline(time.Now().Add(someTimeout))
+       c.SetWriteDeadline(time.Now().Add(someTimeout))
+
+       b := make([]byte, 256)
+       n, err := c.Read(b)
+       if err != nil {
+               if perr := parseReadError(err); perr != nil {
+                       ch <- perr
+               }
+               ch <- err
+               return
+       }
+       if _, err := c.Write(b[:n]); err != nil {
+               if perr := parseWriteError(err); perr != nil {
+                       ch <- perr
+               }
+               ch <- err
+               return
+       }
+}
+
+func transceiver(c Conn, wb []byte, ch chan<- error) {
+       defer close(ch)
+
+       c.SetDeadline(time.Now().Add(someTimeout))
+       c.SetReadDeadline(time.Now().Add(someTimeout))
+       c.SetWriteDeadline(time.Now().Add(someTimeout))
+
+       n, err := c.Write(wb)
+       if err != nil {
+               if perr := parseWriteError(err); perr != nil {
+                       ch <- perr
+               }
+               ch <- err
+               return
+       }
+       if n != len(wb) {
+               ch <- fmt.Errorf("wrote %d; want %d", n, len(wb))
+       }
+       rb := make([]byte, len(wb))
+       n, err = c.Read(rb)
+       if err != nil {
+               if perr := parseReadError(err); perr != nil {
+                       ch <- perr
+               }
+               ch <- err
+               return
+       }
+       if n != len(wb) {
+               ch <- fmt.Errorf("read %d; want %d", n, len(wb))
+       }
+}
+
+func newLocalPacketListener(network string) (PacketConn, error) {
+       switch network {
+       case "udp", "udp4", "udp6":
+               if supportsIPv4 {
+                       return ListenPacket("udp4", "127.0.0.1:0")
+               }
+               if supportsIPv6 {
+                       return ListenPacket("udp6", "[::1]:0")
+               }
+       case "unixgram":
+               return ListenPacket(network, testUnixAddr())
+       }
+       return nil, fmt.Errorf("%s is not supported", network)
+}
+
+type localPacketServer struct {
+       pcmu sync.RWMutex
+       PacketConn
+       done chan bool // signal that indicates server stopped
+}
+
+func (ls *localPacketServer) buildup(handler func(*localPacketServer, PacketConn)) error {
+       go func() {
+               handler(ls, ls.PacketConn)
+               close(ls.done)
+       }()
+       return nil
+}
+
+func (ls *localPacketServer) teardown() error {
+       ls.pcmu.Lock()
+       if ls.PacketConn != nil {
+               network := ls.PacketConn.LocalAddr().Network()
+               address := ls.PacketConn.LocalAddr().String()
+               ls.PacketConn.Close()
+               <-ls.done
+               ls.PacketConn = nil
+               switch network {
+               case "unixgram":
+                       os.Remove(address)
+               }
+       }
+       ls.pcmu.Unlock()
+       return nil
+}
+
+func newLocalPacketServer(network string) (*localPacketServer, error) {
+       c, err := newLocalPacketListener(network)
+       if err != nil {
+               return nil, err
+       }
+       return &localPacketServer{PacketConn: c, done: make(chan bool)}, nil
+}
+
+type packetListener struct {
+       PacketConn
+}
+
+func (pl *packetListener) newLocalServer() (*localPacketServer, error) {
+       return &localPacketServer{PacketConn: pl.PacketConn, done: make(chan bool)}, nil
+}
+
+func packetTransponder(c PacketConn, ch chan<- error) {
+       defer close(ch)
+
+       c.SetDeadline(time.Now().Add(someTimeout))
+       c.SetReadDeadline(time.Now().Add(someTimeout))
+       c.SetWriteDeadline(time.Now().Add(someTimeout))
+
+       b := make([]byte, 256)
+       n, peer, err := c.ReadFrom(b)
+       if err != nil {
+               if perr := parseReadError(err); perr != nil {
+                       ch <- perr
+               }
+               ch <- err
+               return
+       }
+       if peer == nil { // for connected-mode sockets
+               switch c.LocalAddr().Network() {
+               case "udp":
+                       peer, err = ResolveUDPAddr("udp", string(b[:n]))
+               case "unixgram":
+                       peer, err = ResolveUnixAddr("unixgram", string(b[:n]))
+               }
+               if err != nil {
+                       ch <- err
+                       return
+               }
+       }
+       if _, err := c.WriteTo(b[:n], peer); err != nil {
+               if perr := parseWriteError(err); perr != nil {
+                       ch <- perr
+               }
+               ch <- err
+               return
+       }
+}
+
+func packetTransceiver(c PacketConn, wb []byte, dst Addr, ch chan<- error) {
+       defer close(ch)
+
+       c.SetDeadline(time.Now().Add(someTimeout))
+       c.SetReadDeadline(time.Now().Add(someTimeout))
+       c.SetWriteDeadline(time.Now().Add(someTimeout))
+
+       n, err := c.WriteTo(wb, dst)
+       if err != nil {
+               if perr := parseWriteError(err); perr != nil {
+                       ch <- perr
+               }
+               ch <- err
+               return
+       }
+       if n != len(wb) {
+               ch <- fmt.Errorf("wrote %d; want %d", n, len(wb))
+       }
+       rb := make([]byte, len(wb))
+       n, _, err = c.ReadFrom(rb)
+       if err != nil {
+               if perr := parseReadError(err); perr != nil {
+                       ch <- perr
+               }
+               ch <- err
+               return
+       }
+       if n != len(wb) {
+               ch <- fmt.Errorf("read %d; want %d", n, len(wb))
+       }
+}
index aad8686720142eeaaf0084b774be34dc1420cd7e..b04c4e58e0492aa0de81274c5ca3f389b86ae471 100644 (file)
@@ -8,7 +8,6 @@
 package net
 
 import (
-       "io/ioutil"
        "os"
        "runtime"
        "testing"
@@ -21,20 +20,6 @@ import (
 //     golang.org/x/net/ipv6
 //     golang.org/x/net/icmp
 
-// testUnixAddr uses ioutil.TempFile to get a name that is unique. It
-// also uses /tmp directory in case it is prohibited to create UNIX
-// sockets in TMPDIR.
-func testUnixAddr() string {
-       f, err := ioutil.TempFile("", "nettest")
-       if err != nil {
-               panic(err)
-       }
-       addr := f.Name()
-       f.Close()
-       os.Remove(addr)
-       return addr
-}
-
 func TestTCPListenerSpecificMethods(t *testing.T) {
        switch runtime.GOOS {
        case "plan9":
@@ -84,7 +69,8 @@ func TestTCPConnSpecificMethods(t *testing.T) {
        if err != nil {
                t.Fatalf("ListenTCP failed: %v", err)
        }
-       handler := func(ls *localServer, ln Listener) { transponder(t, ls.Listener) }
+       ch := make(chan error, 1)
+       handler := func(ls *localServer, ln Listener) { transponder(ls.Listener, ch) }
        ls, err := (&streamListener{Listener: ln}).newLocalServer()
        if err != nil {
                t.Fatal(err)
@@ -120,6 +106,10 @@ func TestTCPConnSpecificMethods(t *testing.T) {
        if _, err := c.Read(rb); err != nil {
                t.Fatalf("TCPConn.Read failed: %v", err)
        }
+
+       for err := range ch {
+               t.Error(err)
+       }
 }
 
 func TestUDPConnSpecificMethods(t *testing.T) {
index 9b2c8b3cd3945c30f7352b59421adbaae2c0d800..cb58ab571dbe4668e0e09cab4e518706a7484f21 100644 (file)
@@ -398,8 +398,7 @@ func TestIPv6LinkLocalUnicastTCP(t *testing.T) {
                        {"tcp6", "[ip6-localhost%" + ifi.Name + "]:0", true},
                }...)
        }
-       handler := func(ls *localServer, ln Listener) { transponder(t, ln) }
-       for _, tt := range tests {
+       for i, tt := range tests {
                ln, err := Listen(tt.net, tt.addr)
                if err != nil {
                        // It might return "LookupHost returned no
@@ -412,6 +411,8 @@ func TestIPv6LinkLocalUnicastTCP(t *testing.T) {
                        t.Fatal(err)
                }
                defer ls.teardown()
+               ch := make(chan error, 1)
+               handler := func(ls *localServer, ln Listener) { transponder(ln, ch) }
                if err := ls.buildup(handler); err != nil {
                        t.Fatal(err)
                }
@@ -438,6 +439,10 @@ func TestIPv6LinkLocalUnicastTCP(t *testing.T) {
                if _, err := c.Read(b); err != nil {
                        t.Fatalf("Conn.Read failed: %v", err)
                }
+
+               for err := range ch {
+                       t.Errorf("#%d: %v", i, err)
+               }
        }
 }