]> Cypherpunks repositories - gostls13.git/commitdiff
netchan: handle closing of channels.
authorRob Pike <r@golang.org>
Mon, 20 Sep 2010 00:14:39 +0000 (10:14 +1000)
committerRob Pike <r@golang.org>
Mon, 20 Sep 2010 00:14:39 +0000 (10:14 +1000)
This also silences some misleading logging.
Also improve logging.

R=rsc
CC=golang-dev
https://golang.org/cl/2245041

src/pkg/netchan/common.go
src/pkg/netchan/export.go
src/pkg/netchan/import.go
src/pkg/netchan/netchan_test.go

index 3f99868490bcf284c2bfad49df7e76146d7c4444..87981ca860387dfdb9ca5d4b5e54a8f60b434dd0 100644 (file)
@@ -37,6 +37,7 @@ const (
        payError          // error structure follows
        payData           // user payload follows
        payAck            // acknowledgement; no payload
+       payClosed         // channel is now closed
 )
 
 // A header is sent as a prefix to every transmission.  It will be followed by
index a58797e630c1717348e02839a3aaea48f3440f8f..d7dceead9900d46fc04140d5010c286cfd63e10a 100644 (file)
@@ -31,6 +31,12 @@ import (
 
 // Export
 
+// expLog is a logging convenience function.  The first argument must be a string.
+func expLog(args ...interface{}) {
+       args[0] = "netchan export: " + args[0].(string)
+       log.Stderr(args)
+}
+
 // An Exporter allows a set of channels to be published on a single
 // network port.  A single machine may have multiple Exporters
 // but they must use different ports.
@@ -60,7 +66,7 @@ func newClient(exp *Exporter, conn net.Conn) *expClient {
 
 func (client *expClient) sendError(hdr *header, err string) {
        error := &error{err}
-       log.Stderr("export:", error.error)
+       expLog("sending error to client", error.error)
        client.encode(hdr, payError, error) // ignore any encode error, hope client gets it
        client.mu.Lock()
        client.errored = true
@@ -96,13 +102,13 @@ func (client *expClient) run() {
        for {
                *hdr = header{}
                if err := client.decode(hdrValue); err != nil {
-                       log.Stderr("error decoding client header:", err)
+                       expLog("error decoding client header:", err)
                        break
                }
                switch hdr.payloadType {
                case payRequest:
                        if err := client.decode(reqValue); err != nil {
-                               log.Stderr("error decoding client request:", err)
+                               expLog("error decoding client request:", err)
                                break
                        }
                        switch req.dir {
@@ -114,12 +120,14 @@ func (client *expClient) run() {
                                // The actual sends will have payload type payData.
                                // TODO: manage the count?
                        default:
-                               error.error = "export request: can't handle channel direction"
-                               log.Stderr(error.error, req.dir)
+                               error.error = "request: can't handle channel direction"
+                               expLog(error.error, req.dir)
                                client.encode(hdr, payError, error)
                        }
                case payData:
                        client.serveSend(*hdr)
+               case payClosed:
+                       client.serveClosed(*hdr)
                case payAck:
                        client.mu.Lock()
                        if client.ackNum != hdr.seqNum-1 {
@@ -127,12 +135,14 @@ func (client *expClient) run() {
                                // in a single instance of locking client.mu, the messages are guaranteed
                                // to be sent in order.  Therefore receipt of acknowledgement N means
                                // all messages <=N have been seen by the recipient.  We check anyway.
-                               log.Stderr("netchan export: sequence out of order:", client.ackNum, hdr.seqNum)
+                               expLog("sequence out of order:", client.ackNum, hdr.seqNum)
                        }
                        if client.ackNum < hdr.seqNum { // If there has been an error, don't back up the count. 
                                client.ackNum = hdr.seqNum
                        }
                        client.mu.Unlock()
+               default:
+                       log.Exit("netchan export: unknown payload type", hdr.payloadType)
                }
        }
        client.exp.delClient(client)
@@ -148,7 +158,9 @@ func (client *expClient) serveRecv(hdr header, count int64) {
        for {
                val := ech.ch.Recv()
                if ech.ch.Closed() {
-                       client.sendError(&hdr, os.EOF.String())
+                       if err := client.encode(&hdr, payClosed, nil); err != nil {
+                               expLog("error encoding server closed message:", err)
+                       }
                        break
                }
                // We hold the lock during transmission to guarantee messages are
@@ -161,7 +173,7 @@ func (client *expClient) serveRecv(hdr header, count int64) {
                err := client.encode(&hdr, payData, val.Interface())
                client.mu.Unlock()
                if err != nil {
-                       log.Stderr("error encoding client response:", err)
+                       expLog("error encoding client response:", err)
                        client.sendError(&hdr, err.String())
                        break
                }
@@ -184,11 +196,20 @@ func (client *expClient) serveSend(hdr header) {
        // Create a new value for each received item.
        val := reflect.MakeZero(ech.ch.Type().(*reflect.ChanType).Elem())
        if err := client.decode(val); err != nil {
-               log.Stderr("exporter value decode:", err)
+               expLog("value decode:", err)
                return
        }
        ech.ch.Send(val)
-       // TODO count
+}
+
+// Report that client has closed the channel that is sending to us.
+// The header is passed by value to avoid issues of overwriting.
+func (client *expClient) serveClosed(hdr header) {
+       ech := client.getChan(&hdr, Recv)
+       if ech == nil {
+               return
+       }
+       ech.ch.Close()
 }
 
 func (client *expClient) unackedCount() int64 {
@@ -217,7 +238,7 @@ func (exp *Exporter) listen() {
        for {
                conn, err := exp.listener.Accept()
                if err != nil {
-                       log.Stderr("exporter.listen:", err)
+                       expLog("listen:", err)
                        break
                }
                client := exp.addClient(conn)
index 028a25f7f800bef682d2a8807df0ceddf0b05052..e6bf4cbb3283fd731c386911703eb83ba49e516c 100644 (file)
@@ -14,6 +14,12 @@ import (
 
 // Import
 
+// impLog is a logging convenience function.  The first argument must be a string.
+func impLog(args ...interface{}) {
+       args[0] = "netchan import: " + args[0].(string)
+       log.Stderr(args)
+}
+
 // An Importer allows a set of channels to be imported from a single
 // remote machine/network port.  A machine may have multiple
 // importers, even from the same machine/network port.
@@ -66,7 +72,7 @@ func (imp *Importer) run() {
        for {
                *hdr = header{}
                if e := imp.decode(hdrValue); e != nil {
-                       log.Stderr("importer header:", e)
+                       impLog("header:", e)
                        imp.shutdown()
                        return
                }
@@ -75,27 +81,30 @@ func (imp *Importer) run() {
                        // done lower in loop
                case payError:
                        if e := imp.decode(errValue); e != nil {
-                               log.Stderr("importer error:", e)
+                               impLog("error:", e)
                                return
                        }
                        if err.error != "" {
-                               log.Stderr("importer response error:", err.error)
+                               impLog("response error:", err.error)
                                imp.shutdown()
                                return
                        }
+               case payClosed:
+                       ich := imp.getChan(hdr.name)
+                       if ich != nil {
+                               ich.ch.Close()
+                       }
+                       continue
                default:
-                       log.Stderr("unexpected payload type:", hdr.payloadType)
+                       impLog("unexpected payload type:", hdr.payloadType)
                        return
                }
-               imp.chanLock.Lock()
-               ich, ok := imp.chans[hdr.name]
-               imp.chanLock.Unlock()
-               if !ok {
-                       log.Stderr("unknown name in request:", hdr.name)
-                       return
+               ich := imp.getChan(hdr.name)
+               if ich == nil {
+                       continue
                }
                if ich.dir != Recv {
-                       log.Stderr("cannot happen: receive from non-Recv channel")
+                       impLog("cannot happen: receive from non-Recv channel")
                        return
                }
                // Acknowledge receipt
@@ -105,13 +114,24 @@ func (imp *Importer) run() {
                // Create a new value for each received item.
                value := reflect.MakeZero(ich.ch.Type().(*reflect.ChanType).Elem())
                if e := imp.decode(value); e != nil {
-                       log.Stderr("importer value decode:", e)
+                       impLog("importer value decode:", e)
                        return
                }
                ich.ch.Send(value)
        }
 }
 
+func (imp *Importer) getChan(name string) *chanDir {
+       imp.chanLock.Lock()
+       ich := imp.chans[name]
+       imp.chanLock.Unlock()
+       if ich == nil {
+               impLog("unknown name in netchan request:", name)
+               return nil
+       }
+       return ich
+}
+
 // Import imports a channel of the given type and specified direction.
 // It is equivalent to ImportNValues with a count of -1, meaning unbounded.
 func (imp *Importer) Import(name string, chT interface{}, dir Dir) os.Error {
@@ -145,18 +165,24 @@ func (imp *Importer) ImportNValues(name string, chT interface{}, dir Dir, n int)
        }
        imp.chans[name] = &chanDir{ch, dir}
        // Tell the other side about this channel.
-       hdr := &header{name: name, payloadType: payRequest}
+       hdr := &header{name: name}
        req := &request{count: int64(n), dir: dir}
-       if err := imp.encode(hdr, payRequest, req); err != nil {
-               log.Stderr("importer request encode:", err)
+       if err = imp.encode(hdr, payRequest, req); err != nil {
+               impLog("request encode:", err)
                return err
        }
        if dir == Send {
                go func() {
                        for i := 0; n == -1 || i < n; i++ {
                                val := ch.Recv()
-                               if err := imp.encode(hdr, payData, val.Interface()); err != nil {
-                                       log.Stderr("error encoding client response:", err)
+                               if ch.Closed() {
+                                       if err = imp.encode(hdr, payClosed, nil); err != nil {
+                                               impLog("error encoding client closed message:", err)
+                                       }
+                                       return
+                               }
+                               if err = imp.encode(hdr, payData, val.Interface()); err != nil {
+                                       impLog("error encoding client send:", err)
                                        return
                                }
                        }
index 1626c367d3bfb5458378405aef60a23b43015919..42cb3d1ec1edc7ffce94b7433540d97e2132c38a 100644 (file)
@@ -9,6 +9,8 @@ import "testing"
 const count = 10     // number of items in most tests
 const closeCount = 5 // number of items when sender closes early
 
+const base = 23
+
 func exportSend(exp *Exporter, n int, t *testing.T) {
        ch := make(chan int)
        err := exp.Export("exportedSend", ch, Send)
@@ -17,7 +19,7 @@ func exportSend(exp *Exporter, n int, t *testing.T) {
        }
        go func() {
                for i := 0; i < n; i++ {
-                       ch <- 23+i
+                       ch <- base+i
                }
                close(ch)
        }()
@@ -31,12 +33,32 @@ func exportReceive(exp *Exporter, t *testing.T) {
        }
        for i := 0; i < count; i++ {
                v := <-ch
-               if v != 45+i {
-                       t.Errorf("export Receive: bad value: expected 4%d; got %d", 45+i, v)
+               if closed(ch) {
+                       if i != closeCount {
+                               t.Errorf("exportReceive expected close at %d; got one at %d\n", closeCount, i)
+                       }
+                       break
+               }
+               if v != base+i {
+                       t.Errorf("export Receive: bad value: expected %d+%d=%d; got %d", base, i, base+i, v)
                }
        }
 }
 
+func importSend(imp *Importer, n int, t *testing.T) {
+       ch := make(chan int)
+       err := imp.ImportNValues("exportedRecv", ch, Send, count)
+       if err != nil {
+               t.Fatal("importSend:", err)
+       }
+       go func() {
+               for i := 0; i < n; i++ {
+                       ch <- base+i
+               }
+               close(ch)
+       }()
+}
+
 func importReceive(imp *Importer, t *testing.T, done chan bool) {
        ch := make(chan int)
        err := imp.ImportNValues("exportedSend", ch, Recv, count)
@@ -47,12 +69,12 @@ func importReceive(imp *Importer, t *testing.T, done chan bool) {
                v := <-ch
                if closed(ch) {
                        if i != closeCount {
-                               t.Errorf("expected close at %d; got one at %d\n", closeCount, i)
+                               t.Errorf("importReceive expected close at %d; got one at %d\n", closeCount, i)
                        }
                        break
                }
                if v != 23+i {
-                       t.Errorf("importReceive: bad value: expected %d; got %+d", 23+i, v)
+                       t.Errorf("importReceive: bad value: expected %%d+%d=%d; got %+d", base, i, base+i, v)
                }
        }
        if done != nil {
@@ -60,17 +82,6 @@ func importReceive(imp *Importer, t *testing.T, done chan bool) {
        }
 }
 
-func importSend(imp *Importer, t *testing.T) {
-       ch := make(chan int)
-       err := imp.ImportNValues("exportedRecv", ch, Send, count)
-       if err != nil {
-               t.Fatal("importSend:", err)
-       }
-       for i := 0; i < count; i++ {
-               ch <- 45+i
-       }
-}
-
 func TestExportSendImportReceive(t *testing.T) {
        exp, err := NewExporter("tcp", "127.0.0.1:0")
        if err != nil {
@@ -93,7 +104,7 @@ func TestExportReceiveImportSend(t *testing.T) {
        if err != nil {
                t.Fatal("new importer:", err)
        }
-       go importSend(imp, t)
+       importSend(imp, count, t)
        exportReceive(exp, t)
 }
 
@@ -110,6 +121,19 @@ func TestClosingExportSendImportReceive(t *testing.T) {
        importReceive(imp, t, nil)
 }
 
+func TestClosingImportSendExportReceive(t *testing.T) {
+       exp, err := NewExporter("tcp", "127.0.0.1:0")
+       if err != nil {
+               t.Fatal("new exporter:", err)
+       }
+       imp, err := NewImporter("tcp", exp.Addr().String())
+       if err != nil {
+               t.Fatal("new importer:", err)
+       }
+       importSend(imp, closeCount, t)
+       exportReceive(exp, t)
+}
+
 // Not a great test but it does at least invoke Drain.
 func TestExportDrain(t *testing.T) {
        exp, err := NewExporter("tcp", "127.0.0.1:0")