]> Cypherpunks repositories - gostls13.git/commitdiff
netchan: added drain method to importer.
authorDavid Jakob Fritz <david.jakob.fritz@gmail.com>
Mon, 6 Jun 2011 06:55:32 +0000 (06:55 +0000)
committerRob Pike <r@golang.org>
Mon, 6 Jun 2011 06:55:32 +0000 (06:55 +0000)
Fixes #1868.

R=golang-dev, r, rsc
CC=golang-dev
https://golang.org/cl/4550093

src/pkg/netchan/import.go
src/pkg/netchan/netchan_test.go

index 0a700ca2b9906a9f01eda51075dc81f7b17a1538..7d96228c407d2358981aa298cd4c4e104c806660 100644 (file)
@@ -11,6 +11,7 @@ import (
        "os"
        "reflect"
        "sync"
+       "time"
 )
 
 // Import
@@ -31,6 +32,9 @@ type Importer struct {
        chans    map[int]*netChan
        errors   chan os.Error
        maxId    int
+       mu       sync.Mutex // protects remaining fields
+       unacked  int64      // number of unacknowledged sends.
+       seqLock  sync.Mutex // guarantees messages are in sequence, only locked under mu
 }
 
 // NewImporter creates a new Importer object to import a set of channels
@@ -42,6 +46,7 @@ func NewImporter(conn io.ReadWriter) *Importer {
        imp.chans = make(map[int]*netChan)
        imp.names = make(map[string]*netChan)
        imp.errors = make(chan os.Error, 10)
+       imp.unacked = 0
        go imp.run()
        return imp
 }
@@ -80,8 +85,10 @@ func (imp *Importer) run() {
        for {
                *hdr = header{}
                if e := imp.decode(hdrValue); e != nil {
-                       impLog("header:", e)
-                       imp.shutdown()
+                       if e != os.EOF {
+                               impLog("header:", e)
+                               imp.shutdown()
+                       }
                        return
                }
                switch hdr.PayloadType {
@@ -114,6 +121,9 @@ func (imp *Importer) run() {
                        nch := imp.getChan(hdr.Id, true)
                        if nch != nil {
                                nch.acked()
+                               imp.mu.Lock()
+                               imp.unacked--
+                               imp.mu.Unlock()
                        }
                        continue
                default:
@@ -220,10 +230,17 @@ func (imp *Importer) ImportNValues(name string, chT interface{}, dir Dir, size,
                                        }
                                        return
                                }
+                               // We hold the lock during transmission to guarantee messages are
+                               // sent in order.
+                               imp.mu.Lock()
+                               imp.unacked++
+                               imp.seqLock.Lock()
+                               imp.mu.Unlock()
                                if err = imp.encode(hdr, payData, val.Interface()); err != nil {
                                        impLog("error encoding client send:", err)
                                        return
                                }
+                               imp.seqLock.Unlock()
                        }
                }()
        }
@@ -244,3 +261,27 @@ func (imp *Importer) Hangup(name string) os.Error {
        nc.close()
        return nil
 }
+
+func (imp *Importer) unackedCount() int64 {
+       imp.mu.Lock()
+       n := imp.unacked
+       imp.mu.Unlock()
+       return n
+}
+
+// Drain waits until all messages sent from this exporter/importer, including
+// those not yet sent to any server and possibly including those sent while
+// Drain was executing, have been received by the exporter.  In short, it
+// waits until all the importer's messages have been received.
+// If the timeout (measured in nanoseconds) is positive and Drain takes
+// longer than that to complete, an error is returned.
+func (imp *Importer) Drain(timeout int64) os.Error {
+       startTime := time.Nanoseconds()
+       for imp.unackedCount() > 0 {
+               if timeout > 0 && time.Nanoseconds()-startTime >= timeout {
+                       return os.ErrorString("timeout")
+               }
+               time.Sleep(100 * 1e6)
+       }
+       return nil
+}
index fd4d8f780d9328953f42302f8ee676c18bb61c28..8c0f9a6e4b7ff3df15dfaef844157b0e79ff2052 100644 (file)
@@ -178,6 +178,16 @@ func TestExportDrain(t *testing.T) {
        <-done
 }
 
+// Not a great test but it does at least invoke Drain.
+func TestImportDrain(t *testing.T) {
+       exp, imp := pair(t)
+       expDone := make(chan bool)
+       go exportReceive(exp, t, expDone)
+       <-expDone
+       importSend(imp, closeCount, t, nil)
+       imp.Drain(0)
+}
+
 // Not a great test but it does at least invoke Sync.
 func TestExportSync(t *testing.T) {
        exp, imp := pair(t)