]> Cypherpunks repositories - gostls13.git/commitdiff
net: add missing Close tests
authorMikio Hara <mikioh.mikioh@gmail.com>
Wed, 29 Apr 2015 08:16:21 +0000 (17:16 +0900)
committerMikio Hara <mikioh.mikioh@gmail.com>
Wed, 29 Apr 2015 23:01:45 +0000 (23:01 +0000)
This change adds missing CloseRead test and Close tests on Conn,
Listener and PacketConn with various networks.

Change-Id: Iadf99eaf526a323f853d203edc7c8d0577f67972
Reviewed-on: https://go-review.googlesource.com/9469
Reviewed-by: Ian Lance Taylor <iant@golang.org>
src/net/net_test.go

index df9373996fa0873318757940d5ccc7d38d14911a..3907ce4aa5664096a8844d141e380d0fdb7b81c1 100644 (file)
@@ -6,229 +6,251 @@ package net
 
 import (
        "io"
-       "io/ioutil"
        "os"
        "runtime"
        "testing"
-       "time"
 )
 
-func TestShutdown(t *testing.T) {
-       if runtime.GOOS == "plan9" {
-               t.Skipf("skipping test on %q", runtime.GOOS)
+func TestCloseRead(t *testing.T) {
+       switch runtime.GOOS {
+       case "nacl", "plan9":
+               t.Skipf("not supported on %s", runtime.GOOS)
        }
-       ln, err := Listen("tcp", "127.0.0.1:0")
-       if err != nil {
-               if ln, err = Listen("tcp6", "[::1]:0"); err != nil {
-                       t.Fatalf("ListenTCP on :0: %v", err)
+
+       for _, network := range []string{"tcp", "unix", "unixpacket"} {
+               if !testableNetwork(network) {
+                       t.Logf("skipping %s test", network)
+                       continue
                }
-       }
 
-       go func() {
-               defer ln.Close()
-               c, err := ln.Accept()
+               ln, err := newLocalListener(network)
                if err != nil {
-                       t.Errorf("Accept: %v", err)
-                       return
+                       t.Fatal(err)
                }
-               var buf [10]byte
-               n, err := c.Read(buf[:])
-               if perr := parseReadError(err); perr != nil {
-                       t.Error(perr)
+               switch network {
+               case "unix", "unixpacket":
+                       defer os.Remove(ln.Addr().String())
                }
-               if n != 0 || err != io.EOF {
-                       t.Errorf("server Read = %d, %v; want 0, io.EOF", n, err)
-                       return
-               }
-               c.Write([]byte("response"))
-               c.Close()
-       }()
+               defer ln.Close()
 
-       c, err := Dial("tcp", ln.Addr().String())
-       if err != nil {
-               t.Fatalf("Dial: %v", err)
-       }
-       defer c.Close()
+               c, err := Dial(ln.Addr().Network(), ln.Addr().String())
+               if err != nil {
+                       t.Fatal(err)
+               }
+               switch network {
+               case "unix", "unixpacket":
+                       defer os.Remove(c.LocalAddr().String())
+               }
+               defer c.Close()
 
-       err = c.(*TCPConn).CloseWrite()
-       if err != nil {
-               t.Fatalf("CloseWrite: %v", err)
-       }
-       var buf [10]byte
-       n, err := c.Read(buf[:])
-       if err != nil {
-               t.Fatalf("client Read: %d, %v", n, err)
-       }
-       got := string(buf[:n])
-       if got != "response" {
-               t.Errorf("read = %q, want \"response\"", got)
+               switch c := c.(type) {
+               case *TCPConn:
+                       err = c.CloseRead()
+               case *UnixConn:
+                       err = c.CloseRead()
+               }
+               if err != nil {
+                       if perr := parseCloseError(err); perr != nil {
+                               t.Error(perr)
+                       }
+                       t.Fatal(err)
+               }
+               var b [1]byte
+               n, err := c.Read(b[:])
+               if n != 0 || err == nil {
+                       t.Fatalf("got (%d, %v); want (0, error)", n, err)
+               }
        }
 }
 
-func TestShutdownUnix(t *testing.T) {
-       if !testableNetwork("unix") {
-               t.Skip("unix test")
-       }
-
-       f, err := ioutil.TempFile("", "go_net_unixtest")
-       if err != nil {
-               t.Fatalf("TempFile: %s", err)
-       }
-       f.Close()
-       tmpname := f.Name()
-       os.Remove(tmpname)
-       ln, err := Listen("unix", tmpname)
-       if err != nil {
-               t.Fatalf("ListenUnix on %s: %s", tmpname, err)
+func TestCloseWrite(t *testing.T) {
+       switch runtime.GOOS {
+       case "nacl", "plan9":
+               t.Skipf("not supported on %s", runtime.GOOS)
        }
-       defer func() {
-               ln.Close()
-               os.Remove(tmpname)
-       }()
 
-       go func() {
+       handler := func(ls *localServer, ln Listener) {
                c, err := ln.Accept()
                if err != nil {
-                       t.Errorf("Accept: %v", err)
+                       t.Error(err)
                        return
                }
-               var buf [10]byte
-               n, err := c.Read(buf[:])
-               if perr := parseReadError(err); perr != nil {
-                       t.Error(perr)
-               }
+               defer c.Close()
+
+               var b [1]byte
+               n, err := c.Read(b[:])
                if n != 0 || err != io.EOF {
-                       t.Errorf("server Read = %d, %v; want 0, io.EOF", n, err)
+                       t.Errorf("got (%d, %v); want (0, io.EOF)", n, err)
+                       return
+               }
+               switch c := c.(type) {
+               case *TCPConn:
+                       err = c.CloseWrite()
+               case *UnixConn:
+                       err = c.CloseWrite()
+               }
+               if err != nil {
+                       if perr := parseCloseError(err); perr != nil {
+                               t.Error(perr)
+                       }
+                       t.Error(err)
+                       return
+               }
+               n, err = c.Write(b[:])
+               if err == nil {
+                       t.Errorf("got (%d, %v); want (any, error)", n, err)
                        return
                }
-               c.Write([]byte("response"))
-               c.Close()
-       }()
-
-       c, err := Dial("unix", tmpname)
-       if err != nil {
-               t.Fatalf("Dial: %v", err)
        }
-       defer c.Close()
 
-       err = c.(*UnixConn).CloseWrite()
-       if err != nil {
-               t.Fatalf("CloseWrite: %v", err)
-       }
-       var buf [10]byte
-       n, err := c.Read(buf[:])
-       if err != nil {
-               t.Fatalf("client Read: %d, %v", n, err)
-       }
-       got := string(buf[:n])
-       if got != "response" {
-               t.Errorf("read = %q, want \"response\"", got)
-       }
-}
+       for _, network := range []string{"tcp", "unix", "unixpacket"} {
+               if !testableNetwork(network) {
+                       t.Logf("skipping %s test", network)
+                       continue
+               }
 
-func TestTCPListenClose(t *testing.T) {
-       ln, err := Listen("tcp", "127.0.0.1:0")
-       if err != nil {
-               t.Fatalf("Listen failed: %v", err)
-       }
+               ls, err := newLocalServer(network)
+               if err != nil {
+                       t.Fatal(err)
+               }
+               defer ls.teardown()
+               if err := ls.buildup(handler); err != nil {
+                       t.Fatal(err)
+               }
 
-       done := make(chan bool, 1)
-       go func() {
-               time.Sleep(100 * time.Millisecond)
-               ln.Close()
-       }()
-       go func() {
-               c, err := ln.Accept()
+               c, err := Dial(ls.Listener.Addr().Network(), ls.Listener.Addr().String())
+               if err != nil {
+                       t.Fatal(err)
+               }
+               switch network {
+               case "unix", "unixpacket":
+                       defer os.Remove(c.LocalAddr().String())
+               }
+               defer c.Close()
+
+               switch c := c.(type) {
+               case *TCPConn:
+                       err = c.CloseWrite()
+               case *UnixConn:
+                       err = c.CloseWrite()
+               }
+               if err != nil {
+                       if perr := parseCloseError(err); perr != nil {
+                               t.Error(perr)
+                       }
+                       t.Fatal(err)
+               }
+               var b [1]byte
+               n, err := c.Read(b[:])
+               if n != 0 || err != io.EOF {
+                       t.Fatalf("got (%d, %v); want (0, io.EOF)", n, err)
+               }
+               n, err = c.Write(b[:])
                if err == nil {
-                       c.Close()
-                       t.Error("Accept succeeded")
-               } else {
-                       t.Logf("Accept timeout error: %s (any error is fine)", err)
-               }
-               done <- true
-       }()
-       select {
-       case <-done:
-       case <-time.After(2 * time.Second):
-               t.Fatal("timeout waiting for TCP close")
+                       t.Fatalf("got (%d, %v); want (any, error)", n, err)
+               }
        }
 }
 
-func TestUDPListenClose(t *testing.T) {
-       switch runtime.GOOS {
-       case "plan9":
-               t.Skipf("skipping test on %q", runtime.GOOS)
-       }
-       ln, err := ListenPacket("udp", "127.0.0.1:0")
-       if err != nil {
-               t.Fatalf("Listen failed: %v", err)
-       }
+func TestConnClose(t *testing.T) {
+       for _, network := range []string{"tcp", "unix", "unixpacket"} {
+               if !testableNetwork(network) {
+                       t.Logf("skipping %s test", network)
+                       continue
+               }
 
-       buf := make([]byte, 1000)
-       done := make(chan bool, 1)
-       go func() {
-               time.Sleep(100 * time.Millisecond)
-               ln.Close()
-       }()
-       go func() {
-               _, _, err = ln.ReadFrom(buf)
-               if perr := parseReadError(err); perr != nil {
-                       t.Error(perr)
+               ln, err := newLocalListener(network)
+               if err != nil {
+                       t.Fatal(err)
                }
-               if err == nil {
-                       t.Error("ReadFrom succeeded")
-               } else {
-                       t.Logf("ReadFrom timeout error: %s (any error is fine)", err)
-               }
-               done <- true
-       }()
-       select {
-       case <-done:
-       case <-time.After(2 * time.Second):
-               t.Fatal("timeout waiting for UDP close")
-       }
-}
+               switch network {
+               case "unix", "unixpacket":
+                       defer os.Remove(ln.Addr().String())
+               }
+               defer ln.Close()
 
-func TestTCPClose(t *testing.T) {
-       switch runtime.GOOS {
-       case "plan9":
-               t.Skipf("skipping test on %q", runtime.GOOS)
-       }
-       l, err := Listen("tcp", "127.0.0.1:0")
-       if err != nil {
-               t.Fatal(err)
-       }
-       defer l.Close()
+               c, err := Dial(ln.Addr().Network(), ln.Addr().String())
+               if err != nil {
+                       t.Fatal(err)
+               }
+               switch network {
+               case "unix", "unixpacket":
+                       defer os.Remove(c.LocalAddr().String())
+               }
+               defer c.Close()
 
-       read := func(r io.Reader) error {
-               var m [1]byte
-               _, err := r.Read(m[:])
-               return err
+               if err := c.Close(); err != nil {
+                       if perr := parseCloseError(err); perr != nil {
+                               t.Error(perr)
+                       }
+                       t.Fatal(err)
+               }
+               var b [1]byte
+               n, err := c.Read(b[:])
+               if n != 0 || err == nil {
+                       t.Fatalf("got (%d, %v); want (0, error)", n, err)
+               }
        }
+}
 
-       go func() {
-               c, err := Dial("tcp", l.Addr().String())
+func TestListenerClose(t *testing.T) {
+       for _, network := range []string{"tcp", "unix", "unixpacket"} {
+               if !testableNetwork(network) {
+                       t.Logf("skipping %s test", network)
+                       continue
+               }
+
+               ln, err := newLocalListener(network)
                if err != nil {
-                       t.Errorf("Dial: %v", err)
-                       return
+                       t.Fatal(err)
                }
+               switch network {
+               case "unix", "unixpacket":
+                       defer os.Remove(ln.Addr().String())
+               }
+               defer ln.Close()
 
-               go read(c)
+               if err := ln.Close(); err != nil {
+                       if perr := parseCloseError(err); perr != nil {
+                               t.Error(perr)
+                       }
+                       t.Fatal(err)
+               }
+               c, err := ln.Accept()
+               if err == nil {
+                       c.Close()
+                       t.Fatal("should fail")
+               }
+       }
+}
 
-               time.Sleep(10 * time.Millisecond)
-               c.Close()
-       }()
+func TestPacketConnClose(t *testing.T) {
+       for _, network := range []string{"udp", "unixgram"} {
+               if !testableNetwork(network) {
+                       t.Logf("skipping %s test", network)
+                       continue
+               }
 
-       c, err := l.Accept()
-       if err != nil {
-               t.Fatal(err)
-       }
-       defer c.Close()
+               c, err := newLocalPacketListener(network)
+               if err != nil {
+                       t.Fatal(err)
+               }
+               switch network {
+               case "unixgram":
+                       defer os.Remove(c.LocalAddr().String())
+               }
+               defer c.Close()
 
-       for err == nil {
-               err = read(c)
-       }
-       if err != nil && err != io.EOF {
-               t.Fatal(err)
+               if err := c.Close(); err != nil {
+                       if perr := parseCloseError(err); perr != nil {
+                               t.Error(perr)
+                       }
+                       t.Fatal(err)
+               }
+               var b [1]byte
+               n, _, err := c.ReadFrom(b[:])
+               if n != 0 || err == nil {
+                       t.Fatalf("got (%d, %v); want (0, error)", n, err)
+               }
        }
 }