]> Cypherpunks repositories - gostls13.git/commitdiff
netchan: improve closing and shutdown. there's still more to do.
authorRob Pike <r@golang.org>
Sat, 29 May 2010 05:32:29 +0000 (22:32 -0700)
committerRob Pike <r@golang.org>
Sat, 29 May 2010 05:32:29 +0000 (22:32 -0700)
Fixes #805.

R=rsc
CC=golang-dev
https://golang.org/cl/1400041

src/pkg/netchan/export.go
src/pkg/netchan/import.go
src/pkg/netchan/netchan_test.go

index 89deb20ae24589259ac85f5b16a360c5e9afccdb..ea1d63fb9e3d923f05a3e2c045e37f69a883e769 100644 (file)
        use the channels in the usual way.
 
        Networked channels are not synchronized; they always behave
-       as if there is a buffer of at least one element between the
-       two machines.
+       as if they are buffered channels of at least one element.
 */
 package netchan
 
+// BUG: can't use range clause to receive when using ImportNValues with N non-zero.
+
 import (
        "log"
        "net"
@@ -143,6 +144,10 @@ func (client *expClient) serveRecv(hdr header, count int) {
        }
        for {
                val := ech.ch.Recv()
+               if ech.ch.Closed() {
+                       client.sendError(&hdr, os.EOF.String())
+                       break
+               }
                if err := client.encode(&hdr, payData, val.Interface()); err != nil {
                        log.Stderr("error encoding client response:", err)
                        client.sendError(&hdr, err.String())
index bde36f6152ac0107d16e0fc07e8cf75c684ab1ae..454e265b217cfbdae559d3037baae09baed406fb 100644 (file)
@@ -49,6 +49,17 @@ func NewImporter(network, remoteaddr string) (*Importer, os.Error) {
        return imp, nil
 }
 
+// shutdown closes all channels for which we are receiving data from the remote side.
+func (imp *Importer) shutdown() {
+       imp.chanLock.Lock()
+       for _, ich := range imp.chans {
+               if ich.dir == Recv {
+                       ich.ch.Close()
+               }
+       }
+       imp.chanLock.Unlock()
+}
+
 // Handle the data from a single imported data stream, which will
 // have the form
 //     (response, data)*
@@ -60,6 +71,7 @@ func (imp *Importer) run() {
        for {
                if e := imp.decode(hdr); e != nil {
                        log.Stderr("importer header:", e)
+                       imp.shutdown()
                        return
                }
                switch hdr.payloadType {
@@ -72,7 +84,7 @@ func (imp *Importer) run() {
                        }
                        if err.error != "" {
                                log.Stderr("importer response error:", err.error)
-                               // TODO: tear down connection
+                               imp.shutdown()
                                return
                        }
                default:
index bce37c866963e32cedef81dce6bd3ff4e2804a02..1981a00c9ecc48404faf0d0af5867e6630ed5610 100644 (file)
@@ -11,17 +11,19 @@ type value struct {
        s string
 }
 
-const count = 10
+const count = 10     // number of items in most tests
+const closeCount = 5 // number of items when sender closes early
 
-func exportSend(exp *Exporter, t *testing.T) {
+func exportSend(exp *Exporter, n int, t *testing.T) {
        ch := make(chan value)
        err := exp.Export("exportedSend", ch, Send, new(value))
        if err != nil {
                t.Fatal("exportSend:", err)
        }
-       for i := 0; i < count; i++ {
+       for i := 0; i < n; i++ {
                ch <- value{23 + i, "hello"}
        }
+       close(ch)
 }
 
 func exportReceive(exp *Exporter, t *testing.T) {
@@ -46,6 +48,12 @@ func importReceive(imp *Importer, t *testing.T) {
        }
        for i := 0; i < count; i++ {
                v := <-ch
+               if closed(ch) {
+                       if i != closeCount {
+                               t.Errorf("expected close at %d; got one at %d\n", count/2, i)
+                       }
+                       break
+               }
                if v.i != 23+i || v.s != "hello" {
                        t.Errorf("importReceive: bad value: expected %d, hello; got %+v", 23+i, v)
                }
@@ -72,7 +80,7 @@ func TestExportSendImportReceive(t *testing.T) {
        if err != nil {
                t.Fatal("new importer:", err)
        }
-       go exportSend(exp, t)
+       go exportSend(exp, count, t)
        importReceive(imp, t)
 }
 
@@ -88,3 +96,16 @@ func TestExportReceiveImportSend(t *testing.T) {
        go importSend(imp, t)
        exportReceive(exp, t)
 }
+
+func TestClosingExportSendImportReceive(t *testing.T) {
+       exp, err := NewExporter("tcp", ":0")
+       if err != nil {
+               t.Fatal("new exporter:", err)
+       }
+       imp, err := NewImporter("tcp", exp.Addr().String())
+       if err != nil {
+               t.Fatal("new importer:", err)
+       }
+       go exportSend(exp, closeCount, t)
+       importReceive(imp, t)
+}