]> Cypherpunks repositories - gostls13.git/commitdiff
netchan: use acknowledgements on export send.
authorRob Pike <r@golang.org>
Sat, 4 Sep 2010 13:41:54 +0000 (23:41 +1000)
committerRob Pike <r@golang.org>
Sat, 4 Sep 2010 13:41:54 +0000 (23:41 +1000)
Also add exporter.Drain() to wait for completion.
This makes it possible for an Exporter to fire off a message
and wait (by calling Drain) for the message to be received,
even if a client has yet to call to retrieve it.

Once this design is settled, I'll do the same for import send.

Testing strategies welcome.  I have some working stand-alone
tests.

R=rsc
CC=golang-dev
https://golang.org/cl/2137041

src/pkg/netchan/common.go
src/pkg/netchan/export.go
src/pkg/netchan/import.go
src/pkg/netchan/netchan_test.go

index 624397ef46db514e75a601749f7883669d9718c8..c5fd5698cf9043f32428e2f3fa4ebd75d5546a4a 100644 (file)
@@ -10,6 +10,7 @@ import (
        "os"
        "reflect"
        "sync"
+       "time"
 )
 
 // The direction of a connection from the client's perspective.
@@ -25,6 +26,7 @@ const (
        payRequest = iota // request structure follows
        payError          // error structure follows
        payData           // user payload follows
+       payAck            // acknowledgement; no payload
 )
 
 // A header is sent as a prefix to every transmission.  It will be followed by
@@ -32,13 +34,14 @@ const (
 type header struct {
        name        string
        payloadType int
+       seqNum      int64
 }
 
 // Sent with a header once per channel from importer to exporter to report
 // that it wants to bind to a channel with the specified direction for count
 // messages.  If count is zero, it means unlimited.
 type request struct {
-       count int
+       count int64
        dir   Dir
 }
 
@@ -47,6 +50,27 @@ type error struct {
        error string
 }
 
+// Used to unify management of acknowledgements for import and export.
+type unackedCounter interface {
+       unackedCount() int64
+       ack() int64
+       seq() int64
+}
+
+// A channel and its direction.
+type chanDir struct {
+       ch  *reflect.ChanValue
+       dir Dir
+}
+
+// clientSet contains the objects and methods needed for tracking
+// clients of an exporter and draining outstanding messages.
+type clientSet struct {
+       mu      sync.Mutex // protects access to channel and client maps
+       chans   map[string]*chanDir
+       clients map[unackedCounter]bool
+}
+
 // Mutex-protected encoder and decoder pair.
 type encDec struct {
        decLock sync.Mutex
@@ -79,10 +103,78 @@ func (ed *encDec) encode(hdr *header, payloadType int, payload interface{}) os.E
        hdr.payloadType = payloadType
        err := ed.enc.Encode(hdr)
        if err == nil {
-               err = ed.enc.Encode(payload)
-       } else {
+               if payload != nil {
+                       err = ed.enc.Encode(payload)
+               }
+       }
+       if err != nil {
                // TODO: tear down connection if there is an error?
        }
        ed.encLock.Unlock()
        return err
 }
+
+// See the comment for Exporter.Drain.
+func (cs *clientSet) drain(timeout int64) os.Error {
+       startTime := time.Nanoseconds()
+       for {
+               pending := false
+               cs.mu.Lock()
+               // Any messages waiting for a client?
+               for _, chDir := range cs.chans {
+                       if chDir.ch.Len() > 0 {
+                               pending = true
+                       }
+               }
+               // Any unacknowledged messages?
+               for client := range cs.clients {
+                       n := client.unackedCount()
+                       if n > 0 { // Check for > rather than != just to be safe.
+                               pending = true
+                               break
+                       }
+               }
+               cs.mu.Unlock()
+               if !pending {
+                       break
+               }
+               if timeout > 0 && time.Nanoseconds()-startTime >= timeout {
+                       return os.ErrorString("timeout")
+               }
+               time.Sleep(100 * 1e6) // 100 milliseconds
+       }
+       return nil
+}
+
+// See the comment for Exporter.Sync.
+func (cs *clientSet) sync(timeout int64) os.Error {
+       startTime := time.Nanoseconds()
+       // seq remembers the clients and their seqNum at point of entry.
+       seq := make(map[unackedCounter]int64)
+       for client := range cs.clients {
+               seq[client] = client.seq()
+       }
+       for {
+               pending := false
+               cs.mu.Lock()
+               // Any unacknowledged messages?  Look only at clients that existed
+               // when we started and are still in this client set.
+               for client := range seq {
+                       if _, ok := cs.clients[client]; ok {
+                               if client.ack() < seq[client] {
+                                       pending = true
+                                       break
+                               }
+                       }
+               }
+               cs.mu.Unlock()
+               if !pending {
+                       break
+               }
+               if timeout > 0 && time.Nanoseconds()-startTime >= timeout {
+                       return os.ErrorString("timeout")
+               }
+               time.Sleep(100 * 1e6) // 100 milliseconds
+       }
+       return nil
+}
index 3142eebf736d5b959322b45b4b0dc061d796f283..c42e35c56d988dfc580d7390a85d53bffbec0f31 100644 (file)
@@ -31,59 +31,47 @@ import (
 
 // Export
 
-// A channel and its associated information: a direction plus
-// a handy marshaling place for its data.
-type exportChan struct {
-       ch  *reflect.ChanValue
-       dir Dir
-}
-
 // An Exporter allows a set of channels to be published on a single
 // network port.  A single machine may have multiple Exporters
 // but they must use different ports.
 type Exporter struct {
+       *clientSet
        listener net.Listener
-       chanLock sync.Mutex // protects access to channel map
-       chans    map[string]*exportChan
 }
 
 type expClient struct {
        *encDec
-       exp *Exporter
+       exp     *Exporter
+       mu      sync.Mutex // protects remaining fields
+       errored bool       // client has been sent an error
+       seqNum  int64      // sequences messages sent to client; has value of highest sent
+       ackNum  int64      // highest sequence number acknowledged
 }
 
 func newClient(exp *Exporter, conn net.Conn) *expClient {
        client := new(expClient)
        client.exp = exp
        client.encDec = newEncDec(conn)
+       client.seqNum = 0
+       client.ackNum = 0
        return client
 
 }
 
-// Wait for incoming connections, start a new runner for each
-func (exp *Exporter) listen() {
-       for {
-               conn, err := exp.listener.Accept()
-               if err != nil {
-                       log.Stderr("exporter.listen:", err)
-                       break
-               }
-               client := newClient(exp, conn)
-               go client.run()
-       }
-}
-
 func (client *expClient) sendError(hdr *header, err string) {
        error := &error{err}
        log.Stderr("export:", error.error)
        client.encode(hdr, payError, error) // ignore any encode error, hope client gets it
+       client.mu.Lock()
+       client.errored = true
+       client.mu.Unlock()
 }
 
-func (client *expClient) getChan(hdr *header, dir Dir) *exportChan {
+func (client *expClient) getChan(hdr *header, dir Dir) *chanDir {
        exp := client.exp
-       exp.chanLock.Lock()
+       exp.mu.Lock()
        ech, ok := exp.chans[hdr.name]
-       exp.chanLock.Unlock()
+       exp.mu.Unlock()
        if !ok {
                client.sendError(hdr, "no such channel: "+hdr.name)
                return nil
@@ -95,9 +83,10 @@ func (client *expClient) getChan(hdr *header, dir Dir) *exportChan {
        return ech
 }
 
-// Manage sends and receives for a single client.  For each (client Recv) request,
-// this will launch a serveRecv goroutine to deliver the data for that channel,
-// while (client Send) requests are handled as data arrives from the client.
+// The function run manages sends and receives for a single client.  For each
+// (client Recv) request, this will launch a serveRecv goroutine to deliver
+// the data for that channel, while (client Send) requests are handled as
+// data arrives from the client.
 func (client *expClient) run() {
        hdr := new(header)
        hdrValue := reflect.NewValue(hdr)
@@ -107,15 +96,13 @@ func (client *expClient) run() {
        for {
                if err := client.decode(hdrValue); err != nil {
                        log.Stderr("error decoding client header:", err)
-                       // TODO: tear down connection
-                       return
+                       break
                }
                switch hdr.payloadType {
                case payRequest:
                        if err := client.decode(reqValue); err != nil {
                                log.Stderr("error decoding client request:", err)
-                               // TODO: tear down connection
-                               return
+                               break
                        }
                        switch req.dir {
                        case Recv:
@@ -132,13 +119,27 @@ func (client *expClient) run() {
                        }
                case payData:
                        client.serveSend(*hdr)
+               case payAck:
+                       client.mu.Lock()
+                       if client.ackNum != hdr.seqNum-1 {
+                               // Since the sequence number is incremented and the message is sent
+                               // in a single instance of locking client.mu, the messages are guaranteed
+                               // to be sent in order.  Therefore receipt of acknowledgement N means
+                               // all messages <=N have been seen by the recipient.  We check anyway.
+                               log.Stderr("netchan export: sequence out of order:", client.ackNum, hdr.seqNum)
+                       }
+                       if client.ackNum < hdr.seqNum { // If there has been an error, don't back up the count. 
+                               client.ackNum = hdr.seqNum
+                       }
+                       client.mu.Unlock()
                }
        }
+       client.exp.delClient(client)
 }
 
 // Send all the data on a single channel to a client asking for a Recv.
 // The header is passed by value to avoid issues of overwriting.
-func (client *expClient) serveRecv(hdr header, count int) {
+func (client *expClient) serveRecv(hdr header, count int64) {
        ech := client.getChan(&hdr, Send)
        if ech == nil {
                return
@@ -149,7 +150,16 @@ func (client *expClient) serveRecv(hdr header, count int) {
                        client.sendError(&hdr, os.EOF.String())
                        break
                }
-               if err := client.encode(&hdr, payData, val.Interface()); err != nil {
+               // We hold the lock during transmission to guarantee messages are
+               // sent in sequence number order.  Also, we increment first so the
+               // value of client.seqNum is the value of the highest used sequence
+               // number, not one beyond.
+               client.mu.Lock()
+               client.seqNum++
+               hdr.seqNum = client.seqNum
+               err := client.encode(&hdr, payData, val.Interface())
+               client.mu.Unlock()
+               if err != nil {
                        log.Stderr("error encoding client response:", err)
                        client.sendError(&hdr, err.String())
                        break
@@ -180,6 +190,40 @@ func (client *expClient) serveSend(hdr header) {
        // TODO count
 }
 
+func (client *expClient) unackedCount() int64 {
+       client.mu.Lock()
+       n := client.seqNum - client.ackNum
+       client.mu.Unlock()
+       return n
+}
+
+func (client *expClient) seq() int64 {
+       client.mu.Lock()
+       n := client.seqNum
+       client.mu.Unlock()
+       return n
+}
+
+func (client *expClient) ack() int64 {
+       client.mu.Lock()
+       n := client.seqNum
+       client.mu.Unlock()
+       return n
+}
+
+// Wait for incoming connections, start a new runner for each
+func (exp *Exporter) listen() {
+       for {
+               conn, err := exp.listener.Accept()
+               if err != nil {
+                       log.Stderr("exporter.listen:", err)
+                       break
+               }
+               client := exp.addClient(conn)
+               go client.run()
+       }
+}
+
 // NewExporter creates a new Exporter to export channels
 // on the network and local address defined as in net.Listen.
 func NewExporter(network, localaddr string) (*Exporter, os.Error) {
@@ -189,12 +233,52 @@ func NewExporter(network, localaddr string) (*Exporter, os.Error) {
        }
        e := &Exporter{
                listener: listener,
-               chans:    make(map[string]*exportChan),
+               clientSet: &clientSet{
+                       chans:   make(map[string]*chanDir),
+                       clients: make(map[unackedCounter]bool),
+               },
        }
        go e.listen()
        return e, nil
 }
 
+// addClient creates a new expClient and records its existence
+func (exp *Exporter) addClient(conn net.Conn) *expClient {
+       client := newClient(exp, conn)
+       exp.clients[client] = true
+       exp.mu.Unlock()
+       return client
+}
+
+// delClient forgets the client existed
+func (exp *Exporter) delClient(client *expClient) {
+       exp.mu.Lock()
+       exp.clients[client] = false, false
+       exp.mu.Unlock()
+}
+
+// Drain waits until all messages sent from this exporter/importer, including
+// those not yet sent to any client and possibly including those sent while
+// Drain was executing, have been received by the importer.  In short, it
+// waits until all the exporter's messages have been received by a client.
+// If the timeout (measured in nanoseconds) is positive and Drain takes
+// longer than that to complete, an error is returned.
+func (exp *Exporter) Drain(timeout int64) os.Error {
+       // This wrapper function is here so the method's comment will appear in godoc.
+       return exp.clientSet.drain(timeout)
+}
+
+// Sync waits until all clients of the exporter have received the messages
+// that were sent at the time Sync was invoked.  Unlike Drain, it does not
+// wait for messages sent while it is running or messages that have not been
+// dispatched to any client.  If the timeout (measured in nanoseconds) is
+// positive and Sync takes longer than that to complete, an error is
+// returned.
+func (exp *Exporter) Sync(timeout int64) os.Error {
+       // This wrapper function is here so the method's comment will appear in godoc.
+       return exp.clientSet.sync(timeout)
+}
+
 // Addr returns the Exporter's local network address.
 func (exp *Exporter) Addr() net.Addr { return exp.listener.Addr() }
 
@@ -230,12 +314,12 @@ func (exp *Exporter) Export(name string, chT interface{}, dir Dir) os.Error {
        if err != nil {
                return err
        }
-       exp.chanLock.Lock()
-       defer exp.chanLock.Unlock()
+       exp.mu.Lock()
+       defer exp.mu.Unlock()
        _, present := exp.chans[name]
        if present {
                return os.ErrorString("channel name already being exported:" + name)
        }
-       exp.chans[name] = &exportChan{ch, dir}
+       exp.chans[name] = &chanDir{ch, dir}
        return nil
 }
index 1effbaef4a06db47c8fdd49e09c939f9b12a0dcd..6a065543b5e0619ae1bffb9bd53ecdfd30358cba 100644 (file)
@@ -14,13 +14,6 @@ import (
 
 // Import
 
-// A channel and its associated information: a template value and direction,
-// plus a handy marshaling place for its data.
-type importChan struct {
-       ch  *reflect.ChanValue
-       dir Dir
-}
-
 // An Importer allows a set of channels to be imported from a single
 // remote machine/network port.  A machine may have multiple
 // importers, even from the same machine/network port.
@@ -28,7 +21,7 @@ type Importer struct {
        *encDec
        conn     net.Conn
        chanLock sync.Mutex // protects access to channel map
-       chans    map[string]*importChan
+       chans    map[string]*chanDir
 }
 
 // NewImporter creates a new Importer object to import channels
@@ -43,7 +36,7 @@ func NewImporter(network, remoteaddr string) (*Importer, os.Error) {
        imp := new(Importer)
        imp.encDec = newEncDec(conn)
        imp.conn = conn
-       imp.chans = make(map[string]*importChan)
+       imp.chans = make(map[string]*chanDir)
        go imp.run()
        return imp, nil
 }
@@ -67,6 +60,7 @@ func (imp *Importer) run() {
        // Loop on responses; requests are sent by ImportNValues()
        hdr := new(header)
        hdrValue := reflect.NewValue(hdr)
+       ackHdr := new(header)
        err := new(error)
        errValue := reflect.NewValue(err)
        for {
@@ -103,6 +97,10 @@ func (imp *Importer) run() {
                        log.Stderr("cannot happen: receive from non-Recv channel")
                        return
                }
+               // Acknowledge receipt
+               ackHdr.name = hdr.name
+               ackHdr.seqNum = hdr.seqNum
+               imp.encode(ackHdr, payAck, nil)
                // Create a new value for each received item.
                value := reflect.MakeZero(ich.ch.Type().(*reflect.ChanType).Elem())
                if e := imp.decode(value); e != nil {
@@ -144,14 +142,10 @@ func (imp *Importer) ImportNValues(name string, chT interface{}, dir Dir, n int)
        if present {
                return os.ErrorString("channel name already being imported:" + name)
        }
-       imp.chans[name] = &importChan{ch, dir}
+       imp.chans[name] = &chanDir{ch, dir}
        // Tell the other side about this channel.
-       hdr := new(header)
-       hdr.name = name
-       hdr.payloadType = payRequest
-       req := new(request)
-       req.dir = dir
-       req.count = n
+       hdr := &header{name: name, payloadType: payRequest}
+       req := &request{count: int64(n), dir: dir}
        if err := imp.encode(hdr, payRequest, req); err != nil {
                log.Stderr("importer request encode:", err)
                return err
index eb5a11ea449e0a4cf8ae1f1b25527cbebcc35c9b..1bd4c9d4f84e0fad431a691e422f25145fd1511f 100644 (file)
@@ -37,7 +37,7 @@ func exportReceive(exp *Exporter, t *testing.T) {
        }
 }
 
-func importReceive(imp *Importer, t *testing.T) {
+func importReceive(imp *Importer, t *testing.T, done chan bool) {
        ch := make(chan int)
        err := imp.ImportNValues("exportedSend", ch, Recv, count)
        if err != nil {
@@ -55,6 +55,9 @@ func importReceive(imp *Importer, t *testing.T) {
                        t.Errorf("importReceive: bad value: expected %d; got %+d", 23+i, v)
                }
        }
+       if done != nil {
+               done <- true
+       }
 }
 
 func importSend(imp *Importer, t *testing.T) {
@@ -78,7 +81,7 @@ func TestExportSendImportReceive(t *testing.T) {
                t.Fatal("new importer:", err)
        }
        exportSend(exp, count, t)
-       importReceive(imp, t)
+       importReceive(imp, t, nil)
 }
 
 func TestExportReceiveImportSend(t *testing.T) {
@@ -104,5 +107,39 @@ func TestClosingExportSendImportReceive(t *testing.T) {
                t.Fatal("new importer:", err)
        }
        exportSend(exp, closeCount, t)
-       importReceive(imp, t)
+       importReceive(imp, t, nil)
+}
+
+// Not a great test but it does at least invoke Drain.
+func TestExportDrain(t *testing.T) {
+       exp, err := NewExporter("tcp", "127.0.0.1:0")
+       if err != nil {
+               t.Fatal("new exporter:", err)
+       }
+       imp, err := NewImporter("tcp", exp.Addr().String())
+       if err != nil {
+               t.Fatal("new importer:", err)
+       }
+       done := make(chan bool)
+       go exportSend(exp, closeCount, t)
+       go importReceive(imp, t, done)
+       exp.Drain(0)
+       <-done
+}
+
+// Not a great test but it does at least invoke Sync.
+func TestExportSync(t *testing.T) {
+       exp, err := NewExporter("tcp", "127.0.0.1:0")
+       if err != nil {
+               t.Fatal("new exporter:", err)
+       }
+       imp, err := NewImporter("tcp", exp.Addr().String())
+       if err != nil {
+               t.Fatal("new importer:", err)
+       }
+       done := make(chan bool)
+       go importReceive(imp, t, done)
+       exportSend(exp, closeCount, t)
+       exp.Sync(0)
+       <-done
 }