From: Rob Pike Date: Sat, 4 Sep 2010 13:41:54 +0000 (+1000) Subject: netchan: use acknowledgements on export send. X-Git-Tag: weekly.2010-09-06~3 X-Git-Url: http://www.git.cypherpunks.su/?a=commitdiff_plain;h=d54b921c9baf19105d074896da98d6dbf3c9e219;p=gostls13.git netchan: use acknowledgements on export send. Also add exporter.Drain() to wait for completion. This makes it possible for an Exporter to fire off a message and wait (by calling Drain) for the message to be received, even if a client has yet to call to retrieve it. Once this design is settled, I'll do the same for import send. Testing strategies welcome. I have some working stand-alone tests. R=rsc CC=golang-dev https://golang.org/cl/2137041 --- diff --git a/src/pkg/netchan/common.go b/src/pkg/netchan/common.go index 624397ef46..c5fd5698cf 100644 --- a/src/pkg/netchan/common.go +++ b/src/pkg/netchan/common.go @@ -10,6 +10,7 @@ import ( "os" "reflect" "sync" + "time" ) // The direction of a connection from the client's perspective. @@ -25,6 +26,7 @@ const ( payRequest = iota // request structure follows payError // error structure follows payData // user payload follows + payAck // acknowledgement; no payload ) // A header is sent as a prefix to every transmission. It will be followed by @@ -32,13 +34,14 @@ const ( type header struct { name string payloadType int + seqNum int64 } // Sent with a header once per channel from importer to exporter to report // that it wants to bind to a channel with the specified direction for count // messages. If count is zero, it means unlimited. type request struct { - count int + count int64 dir Dir } @@ -47,6 +50,27 @@ type error struct { error string } +// Used to unify management of acknowledgements for import and export. +type unackedCounter interface { + unackedCount() int64 + ack() int64 + seq() int64 +} + +// A channel and its direction. +type chanDir struct { + ch *reflect.ChanValue + dir Dir +} + +// clientSet contains the objects and methods needed for tracking +// clients of an exporter and draining outstanding messages. +type clientSet struct { + mu sync.Mutex // protects access to channel and client maps + chans map[string]*chanDir + clients map[unackedCounter]bool +} + // Mutex-protected encoder and decoder pair. type encDec struct { decLock sync.Mutex @@ -79,10 +103,78 @@ func (ed *encDec) encode(hdr *header, payloadType int, payload interface{}) os.E hdr.payloadType = payloadType err := ed.enc.Encode(hdr) if err == nil { - err = ed.enc.Encode(payload) - } else { + if payload != nil { + err = ed.enc.Encode(payload) + } + } + if err != nil { // TODO: tear down connection if there is an error? } ed.encLock.Unlock() return err } + +// See the comment for Exporter.Drain. +func (cs *clientSet) drain(timeout int64) os.Error { + startTime := time.Nanoseconds() + for { + pending := false + cs.mu.Lock() + // Any messages waiting for a client? + for _, chDir := range cs.chans { + if chDir.ch.Len() > 0 { + pending = true + } + } + // Any unacknowledged messages? + for client := range cs.clients { + n := client.unackedCount() + if n > 0 { // Check for > rather than != just to be safe. + pending = true + break + } + } + cs.mu.Unlock() + if !pending { + break + } + if timeout > 0 && time.Nanoseconds()-startTime >= timeout { + return os.ErrorString("timeout") + } + time.Sleep(100 * 1e6) // 100 milliseconds + } + return nil +} + +// See the comment for Exporter.Sync. +func (cs *clientSet) sync(timeout int64) os.Error { + startTime := time.Nanoseconds() + // seq remembers the clients and their seqNum at point of entry. + seq := make(map[unackedCounter]int64) + for client := range cs.clients { + seq[client] = client.seq() + } + for { + pending := false + cs.mu.Lock() + // Any unacknowledged messages? Look only at clients that existed + // when we started and are still in this client set. + for client := range seq { + if _, ok := cs.clients[client]; ok { + if client.ack() < seq[client] { + pending = true + break + } + } + } + cs.mu.Unlock() + if !pending { + break + } + if timeout > 0 && time.Nanoseconds()-startTime >= timeout { + return os.ErrorString("timeout") + } + time.Sleep(100 * 1e6) // 100 milliseconds + } + return nil +} diff --git a/src/pkg/netchan/export.go b/src/pkg/netchan/export.go index 3142eebf73..c42e35c56d 100644 --- a/src/pkg/netchan/export.go +++ b/src/pkg/netchan/export.go @@ -31,59 +31,47 @@ import ( // Export -// A channel and its associated information: a direction plus -// a handy marshaling place for its data. -type exportChan struct { - ch *reflect.ChanValue - dir Dir -} - // An Exporter allows a set of channels to be published on a single // network port. A single machine may have multiple Exporters // but they must use different ports. type Exporter struct { + *clientSet listener net.Listener - chanLock sync.Mutex // protects access to channel map - chans map[string]*exportChan } type expClient struct { *encDec - exp *Exporter + exp *Exporter + mu sync.Mutex // protects remaining fields + errored bool // client has been sent an error + seqNum int64 // sequences messages sent to client; has value of highest sent + ackNum int64 // highest sequence number acknowledged } func newClient(exp *Exporter, conn net.Conn) *expClient { client := new(expClient) client.exp = exp client.encDec = newEncDec(conn) + client.seqNum = 0 + client.ackNum = 0 return client } -// Wait for incoming connections, start a new runner for each -func (exp *Exporter) listen() { - for { - conn, err := exp.listener.Accept() - if err != nil { - log.Stderr("exporter.listen:", err) - break - } - client := newClient(exp, conn) - go client.run() - } -} - func (client *expClient) sendError(hdr *header, err string) { error := &error{err} log.Stderr("export:", error.error) client.encode(hdr, payError, error) // ignore any encode error, hope client gets it + client.mu.Lock() + client.errored = true + client.mu.Unlock() } -func (client *expClient) getChan(hdr *header, dir Dir) *exportChan { +func (client *expClient) getChan(hdr *header, dir Dir) *chanDir { exp := client.exp - exp.chanLock.Lock() + exp.mu.Lock() ech, ok := exp.chans[hdr.name] - exp.chanLock.Unlock() + exp.mu.Unlock() if !ok { client.sendError(hdr, "no such channel: "+hdr.name) return nil @@ -95,9 +83,10 @@ func (client *expClient) getChan(hdr *header, dir Dir) *exportChan { return ech } -// Manage sends and receives for a single client. For each (client Recv) request, -// this will launch a serveRecv goroutine to deliver the data for that channel, -// while (client Send) requests are handled as data arrives from the client. +// The function run manages sends and receives for a single client. For each +// (client Recv) request, this will launch a serveRecv goroutine to deliver +// the data for that channel, while (client Send) requests are handled as +// data arrives from the client. func (client *expClient) run() { hdr := new(header) hdrValue := reflect.NewValue(hdr) @@ -107,15 +96,13 @@ func (client *expClient) run() { for { if err := client.decode(hdrValue); err != nil { log.Stderr("error decoding client header:", err) - // TODO: tear down connection - return + break } switch hdr.payloadType { case payRequest: if err := client.decode(reqValue); err != nil { log.Stderr("error decoding client request:", err) - // TODO: tear down connection - return + break } switch req.dir { case Recv: @@ -132,13 +119,27 @@ func (client *expClient) run() { } case payData: client.serveSend(*hdr) + case payAck: + client.mu.Lock() + if client.ackNum != hdr.seqNum-1 { + // Since the sequence number is incremented and the message is sent + // in a single instance of locking client.mu, the messages are guaranteed + // to be sent in order. Therefore receipt of acknowledgement N means + // all messages <=N have been seen by the recipient. We check anyway. + log.Stderr("netchan export: sequence out of order:", client.ackNum, hdr.seqNum) + } + if client.ackNum < hdr.seqNum { // If there has been an error, don't back up the count. + client.ackNum = hdr.seqNum + } + client.mu.Unlock() } } + client.exp.delClient(client) } // Send all the data on a single channel to a client asking for a Recv. // The header is passed by value to avoid issues of overwriting. -func (client *expClient) serveRecv(hdr header, count int) { +func (client *expClient) serveRecv(hdr header, count int64) { ech := client.getChan(&hdr, Send) if ech == nil { return @@ -149,7 +150,16 @@ func (client *expClient) serveRecv(hdr header, count int) { client.sendError(&hdr, os.EOF.String()) break } - if err := client.encode(&hdr, payData, val.Interface()); err != nil { + // We hold the lock during transmission to guarantee messages are + // sent in sequence number order. Also, we increment first so the + // value of client.seqNum is the value of the highest used sequence + // number, not one beyond. + client.mu.Lock() + client.seqNum++ + hdr.seqNum = client.seqNum + err := client.encode(&hdr, payData, val.Interface()) + client.mu.Unlock() + if err != nil { log.Stderr("error encoding client response:", err) client.sendError(&hdr, err.String()) break @@ -180,6 +190,40 @@ func (client *expClient) serveSend(hdr header) { // TODO count } +func (client *expClient) unackedCount() int64 { + client.mu.Lock() + n := client.seqNum - client.ackNum + client.mu.Unlock() + return n +} + +func (client *expClient) seq() int64 { + client.mu.Lock() + n := client.seqNum + client.mu.Unlock() + return n +} + +func (client *expClient) ack() int64 { + client.mu.Lock() + n := client.seqNum + client.mu.Unlock() + return n +} + +// Wait for incoming connections, start a new runner for each +func (exp *Exporter) listen() { + for { + conn, err := exp.listener.Accept() + if err != nil { + log.Stderr("exporter.listen:", err) + break + } + client := exp.addClient(conn) + go client.run() + } +} + // NewExporter creates a new Exporter to export channels // on the network and local address defined as in net.Listen. func NewExporter(network, localaddr string) (*Exporter, os.Error) { @@ -189,12 +233,52 @@ func NewExporter(network, localaddr string) (*Exporter, os.Error) { } e := &Exporter{ listener: listener, - chans: make(map[string]*exportChan), + clientSet: &clientSet{ + chans: make(map[string]*chanDir), + clients: make(map[unackedCounter]bool), + }, } go e.listen() return e, nil } +// addClient creates a new expClient and records its existence +func (exp *Exporter) addClient(conn net.Conn) *expClient { + client := newClient(exp, conn) + exp.clients[client] = true + exp.mu.Unlock() + return client +} + +// delClient forgets the client existed +func (exp *Exporter) delClient(client *expClient) { + exp.mu.Lock() + exp.clients[client] = false, false + exp.mu.Unlock() +} + +// Drain waits until all messages sent from this exporter/importer, including +// those not yet sent to any client and possibly including those sent while +// Drain was executing, have been received by the importer. In short, it +// waits until all the exporter's messages have been received by a client. +// If the timeout (measured in nanoseconds) is positive and Drain takes +// longer than that to complete, an error is returned. +func (exp *Exporter) Drain(timeout int64) os.Error { + // This wrapper function is here so the method's comment will appear in godoc. + return exp.clientSet.drain(timeout) +} + +// Sync waits until all clients of the exporter have received the messages +// that were sent at the time Sync was invoked. Unlike Drain, it does not +// wait for messages sent while it is running or messages that have not been +// dispatched to any client. If the timeout (measured in nanoseconds) is +// positive and Sync takes longer than that to complete, an error is +// returned. +func (exp *Exporter) Sync(timeout int64) os.Error { + // This wrapper function is here so the method's comment will appear in godoc. + return exp.clientSet.sync(timeout) +} + // Addr returns the Exporter's local network address. func (exp *Exporter) Addr() net.Addr { return exp.listener.Addr() } @@ -230,12 +314,12 @@ func (exp *Exporter) Export(name string, chT interface{}, dir Dir) os.Error { if err != nil { return err } - exp.chanLock.Lock() - defer exp.chanLock.Unlock() + exp.mu.Lock() + defer exp.mu.Unlock() _, present := exp.chans[name] if present { return os.ErrorString("channel name already being exported:" + name) } - exp.chans[name] = &exportChan{ch, dir} + exp.chans[name] = &chanDir{ch, dir} return nil } diff --git a/src/pkg/netchan/import.go b/src/pkg/netchan/import.go index 1effbaef4a..6a065543b5 100644 --- a/src/pkg/netchan/import.go +++ b/src/pkg/netchan/import.go @@ -14,13 +14,6 @@ import ( // Import -// A channel and its associated information: a template value and direction, -// plus a handy marshaling place for its data. -type importChan struct { - ch *reflect.ChanValue - dir Dir -} - // An Importer allows a set of channels to be imported from a single // remote machine/network port. A machine may have multiple // importers, even from the same machine/network port. @@ -28,7 +21,7 @@ type Importer struct { *encDec conn net.Conn chanLock sync.Mutex // protects access to channel map - chans map[string]*importChan + chans map[string]*chanDir } // NewImporter creates a new Importer object to import channels @@ -43,7 +36,7 @@ func NewImporter(network, remoteaddr string) (*Importer, os.Error) { imp := new(Importer) imp.encDec = newEncDec(conn) imp.conn = conn - imp.chans = make(map[string]*importChan) + imp.chans = make(map[string]*chanDir) go imp.run() return imp, nil } @@ -67,6 +60,7 @@ func (imp *Importer) run() { // Loop on responses; requests are sent by ImportNValues() hdr := new(header) hdrValue := reflect.NewValue(hdr) + ackHdr := new(header) err := new(error) errValue := reflect.NewValue(err) for { @@ -103,6 +97,10 @@ func (imp *Importer) run() { log.Stderr("cannot happen: receive from non-Recv channel") return } + // Acknowledge receipt + ackHdr.name = hdr.name + ackHdr.seqNum = hdr.seqNum + imp.encode(ackHdr, payAck, nil) // Create a new value for each received item. value := reflect.MakeZero(ich.ch.Type().(*reflect.ChanType).Elem()) if e := imp.decode(value); e != nil { @@ -144,14 +142,10 @@ func (imp *Importer) ImportNValues(name string, chT interface{}, dir Dir, n int) if present { return os.ErrorString("channel name already being imported:" + name) } - imp.chans[name] = &importChan{ch, dir} + imp.chans[name] = &chanDir{ch, dir} // Tell the other side about this channel. - hdr := new(header) - hdr.name = name - hdr.payloadType = payRequest - req := new(request) - req.dir = dir - req.count = n + hdr := &header{name: name, payloadType: payRequest} + req := &request{count: int64(n), dir: dir} if err := imp.encode(hdr, payRequest, req); err != nil { log.Stderr("importer request encode:", err) return err diff --git a/src/pkg/netchan/netchan_test.go b/src/pkg/netchan/netchan_test.go index eb5a11ea44..1bd4c9d4f8 100644 --- a/src/pkg/netchan/netchan_test.go +++ b/src/pkg/netchan/netchan_test.go @@ -37,7 +37,7 @@ func exportReceive(exp *Exporter, t *testing.T) { } } -func importReceive(imp *Importer, t *testing.T) { +func importReceive(imp *Importer, t *testing.T, done chan bool) { ch := make(chan int) err := imp.ImportNValues("exportedSend", ch, Recv, count) if err != nil { @@ -55,6 +55,9 @@ func importReceive(imp *Importer, t *testing.T) { t.Errorf("importReceive: bad value: expected %d; got %+d", 23+i, v) } } + if done != nil { + done <- true + } } func importSend(imp *Importer, t *testing.T) { @@ -78,7 +81,7 @@ func TestExportSendImportReceive(t *testing.T) { t.Fatal("new importer:", err) } exportSend(exp, count, t) - importReceive(imp, t) + importReceive(imp, t, nil) } func TestExportReceiveImportSend(t *testing.T) { @@ -104,5 +107,39 @@ func TestClosingExportSendImportReceive(t *testing.T) { t.Fatal("new importer:", err) } exportSend(exp, closeCount, t) - importReceive(imp, t) + importReceive(imp, t, nil) +} + +// Not a great test but it does at least invoke Drain. +func TestExportDrain(t *testing.T) { + exp, err := NewExporter("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal("new exporter:", err) + } + imp, err := NewImporter("tcp", exp.Addr().String()) + if err != nil { + t.Fatal("new importer:", err) + } + done := make(chan bool) + go exportSend(exp, closeCount, t) + go importReceive(imp, t, done) + exp.Drain(0) + <-done +} + +// Not a great test but it does at least invoke Sync. +func TestExportSync(t *testing.T) { + exp, err := NewExporter("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal("new exporter:", err) + } + imp, err := NewImporter("tcp", exp.Addr().String()) + if err != nil { + t.Fatal("new importer:", err) + } + done := make(chan bool) + go importReceive(imp, t, done) + exportSend(exp, closeCount, t) + exp.Sync(0) + <-done }