]> Cypherpunks repositories - gostls13.git/commitdiff
net: add shutdown: TCPConn.CloseWrite and CloseRead
authorBrad Fitzpatrick <bradfitz@golang.org>
Wed, 28 Sep 2011 15:12:38 +0000 (08:12 -0700)
committerBrad Fitzpatrick <bradfitz@golang.org>
Wed, 28 Sep 2011 15:12:38 +0000 (08:12 -0700)
R=golang-dev, rsc, iant
CC=golang-dev
https://golang.org/cl/5136052

src/pkg/net/fd.go
src/pkg/net/fd_windows.go
src/pkg/net/net_test.go
src/pkg/net/tcpsock_posix.go

index 9084e88755e995a471a1bd9aef8afaa7109708cd..a0c56f78ee9b761c2e2ad2ac44bf8e5c0b5ccc23 100644 (file)
@@ -358,6 +358,22 @@ func (fd *netFD) Close() os.Error {
        return nil
 }
 
+func (fd *netFD) CloseRead() os.Error {
+       if fd == nil || fd.sysfile == nil {
+               return os.EINVAL
+       }
+       syscall.Shutdown(fd.sysfd, syscall.SHUT_RD)
+       return nil
+}
+
+func (fd *netFD) CloseWrite() os.Error {
+       if fd == nil || fd.sysfile == nil {
+               return os.EINVAL
+       }
+       syscall.Shutdown(fd.sysfd, syscall.SHUT_WR)
+       return nil
+}
+
 func (fd *netFD) Read(p []byte) (n int, err os.Error) {
        if fd == nil {
                return 0, os.EINVAL
index b025bddea0b58bf3c7f5ca523f09e6b241808ccd..8155d04aae0d1337ed4f3f64908a994e1da3e5dd 100644 (file)
@@ -312,6 +312,22 @@ func (fd *netFD) Close() os.Error {
        return nil
 }
 
+func (fd *netFD) CloseRead() os.Error {
+       if fd == nil || fd.sysfd == syscall.InvalidHandle {
+               return os.EINVAL
+       }
+       syscall.Shutdown(fd.sysfd, syscall.SHUT_RD)
+       return nil
+}
+
+func (fd *netFD) CloseWrite() os.Error {
+       if fd == nil || fd.sysfd == syscall.InvalidHandle {
+               return os.EINVAL
+       }
+       syscall.Shutdown(fd.sysfd, syscall.SHUT_WR)
+       return nil
+}
+
 // Read from network.
 
 type readOp struct {
index 698a845277552febf3f5be6df494679e4c7fb817..e4d7a253e207291b57e3a73edd143e598d9ad74a 100644 (file)
@@ -6,6 +6,7 @@ package net
 
 import (
        "flag"
+       "os"
        "regexp"
        "testing"
 )
@@ -119,3 +120,46 @@ func TestReverseAddress(t *testing.T) {
                }
        }
 }
+
+func TestShutdown(t *testing.T) {
+       l, err := Listen("tcp", "127.0.0.1:0")
+       if err != nil {
+               if l, err = Listen("tcp6", "[::1]:0"); err != nil {
+                       t.Fatalf("ListenTCP on :0: %v", err)
+               }
+       }
+
+       go func() {
+               c, err := l.Accept()
+               if err != nil {
+                       t.Fatalf("Accept: %v", err)
+               }
+               var buf [10]byte
+               n, err := c.Read(buf[:])
+               if n != 0 || err != os.EOF {
+                       t.Fatalf("server Read = %d, %v; want 0, os.EOF", n, err)
+               }
+               c.Write([]byte("response"))
+               c.Close()
+       }()
+
+       c, err := Dial("tcp", l.Addr().String())
+       if err != nil {
+               t.Fatalf("Dial: %v", err)
+       }
+       defer c.Close()
+
+       err = c.(*TCPConn).CloseWrite()
+       if err != nil {
+               t.Fatalf("CloseWrite: %v", err)
+       }
+       var buf [10]byte
+       n, err := c.Read(buf[:])
+       if err != nil {
+               t.Fatalf("client Read: %d, %v", n, err)
+       }
+       got := string(buf[:n])
+       if got != "response" {
+               t.Errorf("read = %q, want \"response\"", got)
+       }
+}
index 35d536c319acf0d5caab9d047c6696929c8ad899..740a63d30385737bc29b13c0e71c559c71d34abe 100644 (file)
@@ -100,6 +100,24 @@ func (c *TCPConn) Close() os.Error {
        return err
 }
 
+// CloseRead shuts down the reading side of the TCP connection.
+// Most callers should just use Close.
+func (c *TCPConn) CloseRead() os.Error {
+       if !c.ok() {
+               return os.EINVAL
+       }
+       return c.fd.CloseRead()
+}
+
+// CloseWrite shuts down the writing side of the TCP connection.
+// Most callers should just use Close.
+func (c *TCPConn) CloseWrite() os.Error {
+       if !c.ok() {
+               return os.EINVAL
+       }
+       return c.fd.CloseWrite()
+}
+
 // LocalAddr returns the local network address, a *TCPAddr.
 func (c *TCPConn) LocalAddr() Addr {
        if !c.ok() {