]> Cypherpunks repositories - gostls13.git/commitdiff
netchan: allow use of arbitrary connections.
authorRoger Peppe <rogpeppe@gmail.com>
Wed, 16 Feb 2011 16:14:41 +0000 (08:14 -0800)
committerRob Pike <r@golang.org>
Wed, 16 Feb 2011 16:14:41 +0000 (08:14 -0800)
R=r, r2, rsc
CC=golang-dev
https://golang.org/cl/4119055

src/pkg/netchan/common.go
src/pkg/netchan/export.go
src/pkg/netchan/import.go
src/pkg/netchan/netchan_test.go

index 6c085484e5eb4abcf5a71478891260341ec09d5c..dd06050ee50f0f49eed56ade78e4f8f0d126358c 100644 (file)
@@ -6,7 +6,7 @@ package netchan
 
 import (
        "gob"
-       "net"
+       "io"
        "os"
        "reflect"
        "sync"
@@ -93,7 +93,7 @@ type encDec struct {
        enc     *gob.Encoder
 }
 
-func newEncDec(conn net.Conn) *encDec {
+func newEncDec(conn io.ReadWriter) *encDec {
        return &encDec{
                dec: gob.NewDecoder(conn),
                enc: gob.NewEncoder(conn),
@@ -199,9 +199,10 @@ func (cs *clientSet) sync(timeout int64) os.Error {
 // are delivered into the local channel.
 type netChan struct {
        *chanDir
-       name string
-       id   int
-       size int // buffer size of channel.
+       name   string
+       id     int
+       size   int // buffer size of channel.
+       closed bool
 
        // sender-specific state
        ackCh chan bool // buffered with space for all the acks we need
@@ -227,6 +228,9 @@ func newNetChan(name string, id int, ch *chanDir, ed *encDec, size int, count in
 
 // Close the channel.
 func (nch *netChan) close() {
+       if nch.closed {
+               return
+       }
        if nch.dir == Recv {
                if nch.sendCh != nil {
                        // If the sender goroutine is active, close the channel to it.
@@ -239,6 +243,7 @@ func (nch *netChan) close() {
                nch.ch.Close()
                close(nch.ackCh)
        }
+       nch.closed = true
 }
 
 // Send message from remote side to local receiver.
index 675e252d5c32b768eb28d208b9ac0dff26eff173..55eba0e2e0f61fb5cf5f7780a3000cbc74c1a6da 100644 (file)
@@ -23,6 +23,7 @@ package netchan
 
 import (
        "log"
+       "io"
        "net"
        "os"
        "reflect"
@@ -43,7 +44,6 @@ func expLog(args ...interface{}) {
 // but they must use different ports.
 type Exporter struct {
        *clientSet
-       listener net.Listener
 }
 
 type expClient struct {
@@ -57,7 +57,7 @@ type expClient struct {
        seqLock sync.Mutex       // guarantees messages are in sequence, only locked under mu
 }
 
-func newClient(exp *Exporter, conn net.Conn) *expClient {
+func newClient(exp *Exporter, conn io.ReadWriter) *expClient {
        client := new(expClient)
        client.exp = exp
        client.encDec = newEncDec(conn)
@@ -260,39 +260,50 @@ func (client *expClient) ack() int64 {
        return n
 }
 
-// Wait for incoming connections, start a new runner for each
-func (exp *Exporter) listen() {
+// Serve waits for incoming connections on the listener
+// and serves the Exporter's channels on each.
+// It blocks until the listener is closed.
+func (exp *Exporter) Serve(listener net.Listener) {
        for {
-               conn, err := exp.listener.Accept()
+               conn, err := listener.Accept()
                if err != nil {
                        expLog("listen:", err)
                        break
                }
-               client := exp.addClient(conn)
-               go client.run()
+               go exp.ServeConn(conn)
        }
 }
 
-// NewExporter creates a new Exporter to export channels
-// on the network and local address defined as in net.Listen.
-func NewExporter(network, localaddr string) (*Exporter, os.Error) {
-       listener, err := net.Listen(network, localaddr)
-       if err != nil {
-               return nil, err
-       }
+// ServeConn exports the Exporter's channels on conn.
+// It blocks until the connection is terminated.
+func (exp *Exporter) ServeConn(conn io.ReadWriter) {
+       exp.addClient(conn).run()
+}
+
+// NewExporter creates a new Exporter that exports a set of channels.
+func NewExporter() *Exporter {
        e := &Exporter{
-               listener: listener,
                clientSet: &clientSet{
                        names:   make(map[string]*chanDir),
                        clients: make(map[unackedCounter]bool),
                },
        }
-       go e.listen()
-       return e, nil
+       return e
+}
+
+// ListenAndServe exports the exporter's channels through the
+// given network and local address defined as in net.Listen.
+func (exp *Exporter) ListenAndServe(network, localaddr string) os.Error {
+       listener, err := net.Listen(network, localaddr)
+       if err != nil {
+               return err
+       }
+       go exp.Serve(listener)
+       return nil
 }
 
 // addClient creates a new expClient and records its existence
-func (exp *Exporter) addClient(conn net.Conn) *expClient {
+func (exp *Exporter) addClient(conn io.ReadWriter) *expClient {
        client := newClient(exp, conn)
        exp.mu.Lock()
        exp.clients[client] = true
@@ -329,9 +340,6 @@ func (exp *Exporter) Sync(timeout int64) os.Error {
        return exp.clientSet.sync(timeout)
 }
 
-// Addr returns the Exporter's local network address.
-func (exp *Exporter) Addr() net.Addr { return exp.listener.Addr() }
-
 func checkChan(chT interface{}, dir Dir) (*reflect.ChanValue, os.Error) {
        chanType, ok := reflect.Typeof(chT).(*reflect.ChanType)
        if !ok {
index d220d9a662a370be87199b7558c34da5dee48e71..30edcd8123bd1b2f159255aad5963fb6a11b9c63 100644 (file)
@@ -5,6 +5,7 @@
 package netchan
 
 import (
+       "io"
        "log"
        "net"
        "os"
@@ -25,7 +26,6 @@ func impLog(args ...interface{}) {
 // importers, even from the same machine/network port.
 type Importer struct {
        *encDec
-       conn     net.Conn
        chanLock sync.Mutex // protects access to channel map
        names    map[string]*netChan
        chans    map[int]*netChan
@@ -33,23 +33,26 @@ type Importer struct {
        maxId    int
 }
 
-// NewImporter creates a new Importer object to import channels
-// from an Exporter at the network and remote address as defined in net.Dial.
-// The Exporter must be available and serving when the Importer is
-// created.
-func NewImporter(network, remoteaddr string) (*Importer, os.Error) {
-       conn, err := net.Dial(network, "", remoteaddr)
-       if err != nil {
-               return nil, err
-       }
+// NewImporter creates a new Importer object to import a set of channels
+// from the given connection. The Exporter must be available and serving when
+// the Importer is created.
+func NewImporter(conn io.ReadWriter) *Importer {
        imp := new(Importer)
        imp.encDec = newEncDec(conn)
-       imp.conn = conn
        imp.chans = make(map[int]*netChan)
        imp.names = make(map[string]*netChan)
        imp.errors = make(chan os.Error, 10)
        go imp.run()
-       return imp, nil
+       return imp
+}
+
+// Import imports a set of channels from the given network and address.
+func Import(network, remoteaddr string) (*Importer, os.Error) {
+       conn, err := net.Dial(network, "", remoteaddr)
+       if err != nil {
+               return nil, err
+       }
+       return NewImporter(conn), nil
 }
 
 // shutdown closes all channels for which we are receiving data from the remote side.
@@ -231,15 +234,13 @@ func (imp *Importer) ImportNValues(name string, chT interface{}, dir Dir, size,
 // the channel.  Messages in flight for the channel may be dropped.
 func (imp *Importer) Hangup(name string) os.Error {
        imp.chanLock.Lock()
-       nc, ok := imp.names[name]
-       if ok {
-               imp.names[name] = nil, false
-               imp.chans[nc.id] = nil, false
-       }
-       imp.chanLock.Unlock()
-       if !ok {
+       defer imp.chanLock.Unlock()
+       nc := imp.names[name]
+       if nc == nil {
                return os.ErrorString("netchan import: hangup: no such channel: " + name)
        }
+       imp.names[name] = nil, false
+       imp.chans[nc.id] = nil, false
        nc.close()
        return nil
 }
index 4076aefebfc09934409bd7395b0572ea3fe9a001..1c84a9d14dee485b3b2804ecf8b1d4d0ee9758e2 100644 (file)
@@ -5,6 +5,7 @@
 package netchan
 
 import (
+       "net"
        "strings"
        "testing"
        "time"
@@ -94,27 +95,13 @@ func importReceive(imp *Importer, t *testing.T, done chan bool) {
 }
 
 func TestExportSendImportReceive(t *testing.T) {
-       exp, err := NewExporter("tcp", "127.0.0.1:0")
-       if err != nil {
-               t.Fatal("new exporter:", err)
-       }
-       imp, err := NewImporter("tcp", exp.Addr().String())
-       if err != nil {
-               t.Fatal("new importer:", err)
-       }
+       exp, imp := pair(t)
        exportSend(exp, count, t, nil)
        importReceive(imp, t, nil)
 }
 
 func TestExportReceiveImportSend(t *testing.T) {
-       exp, err := NewExporter("tcp", "127.0.0.1:0")
-       if err != nil {
-               t.Fatal("new exporter:", err)
-       }
-       imp, err := NewImporter("tcp", exp.Addr().String())
-       if err != nil {
-               t.Fatal("new importer:", err)
-       }
+       exp, imp := pair(t)
        expDone := make(chan bool)
        done := make(chan bool)
        go func() {
@@ -127,27 +114,13 @@ func TestExportReceiveImportSend(t *testing.T) {
 }
 
 func TestClosingExportSendImportReceive(t *testing.T) {
-       exp, err := NewExporter("tcp", "127.0.0.1:0")
-       if err != nil {
-               t.Fatal("new exporter:", err)
-       }
-       imp, err := NewImporter("tcp", exp.Addr().String())
-       if err != nil {
-               t.Fatal("new importer:", err)
-       }
+       exp, imp := pair(t)
        exportSend(exp, closeCount, t, nil)
        importReceive(imp, t, nil)
 }
 
 func TestClosingImportSendExportReceive(t *testing.T) {
-       exp, err := NewExporter("tcp", "127.0.0.1:0")
-       if err != nil {
-               t.Fatal("new exporter:", err)
-       }
-       imp, err := NewImporter("tcp", exp.Addr().String())
-       if err != nil {
-               t.Fatal("new importer:", err)
-       }
+       exp, imp := pair(t)
        expDone := make(chan bool)
        done := make(chan bool)
        go func() {
@@ -160,17 +133,10 @@ func TestClosingImportSendExportReceive(t *testing.T) {
 }
 
 func TestErrorForIllegalChannel(t *testing.T) {
-       exp, err := NewExporter("tcp", "127.0.0.1:0")
-       if err != nil {
-               t.Fatal("new exporter:", err)
-       }
-       imp, err := NewImporter("tcp", exp.Addr().String())
-       if err != nil {
-               t.Fatal("new importer:", err)
-       }
+       exp, imp := pair(t)
        // Now export a channel.
        ch := make(chan int, 1)
-       err = exp.Export("aChannel", ch, Send)
+       err := exp.Export("aChannel", ch, Send)
        if err != nil {
                t.Fatal("export:", err)
        }
@@ -200,14 +166,7 @@ func TestErrorForIllegalChannel(t *testing.T) {
 
 // Not a great test but it does at least invoke Drain.
 func TestExportDrain(t *testing.T) {
-       exp, err := NewExporter("tcp", "127.0.0.1:0")
-       if err != nil {
-               t.Fatal("new exporter:", err)
-       }
-       imp, err := NewImporter("tcp", exp.Addr().String())
-       if err != nil {
-               t.Fatal("new importer:", err)
-       }
+       exp, imp := pair(t)
        done := make(chan bool)
        go func() {
                exportSend(exp, closeCount, t, nil)
@@ -221,14 +180,7 @@ func TestExportDrain(t *testing.T) {
 
 // Not a great test but it does at least invoke Sync.
 func TestExportSync(t *testing.T) {
-       exp, err := NewExporter("tcp", "127.0.0.1:0")
-       if err != nil {
-               t.Fatal("new exporter:", err)
-       }
-       imp, err := NewImporter("tcp", exp.Addr().String())
-       if err != nil {
-               t.Fatal("new importer:", err)
-       }
+       exp, imp := pair(t)
        done := make(chan bool)
        exportSend(exp, closeCount, t, nil)
        go importReceive(imp, t, done)
@@ -239,16 +191,9 @@ func TestExportSync(t *testing.T) {
 // Test hanging up the send side of an export.
 // TODO: test hanging up the receive side of an export.
 func TestExportHangup(t *testing.T) {
-       exp, err := NewExporter("tcp", "127.0.0.1:0")
-       if err != nil {
-               t.Fatal("new exporter:", err)
-       }
-       imp, err := NewImporter("tcp", exp.Addr().String())
-       if err != nil {
-               t.Fatal("new importer:", err)
-       }
+       exp, imp := pair(t)
        ech := make(chan int)
-       err = exp.Export("exportedSend", ech, Send)
+       err := exp.Export("exportedSend", ech, Send)
        if err != nil {
                t.Fatal("export:", err)
        }
@@ -276,16 +221,9 @@ func TestExportHangup(t *testing.T) {
 // Test hanging up the send side of an import.
 // TODO: test hanging up the receive side of an import.
 func TestImportHangup(t *testing.T) {
-       exp, err := NewExporter("tcp", "127.0.0.1:0")
-       if err != nil {
-               t.Fatal("new exporter:", err)
-       }
-       imp, err := NewImporter("tcp", exp.Addr().String())
-       if err != nil {
-               t.Fatal("new importer:", err)
-       }
+       exp, imp := pair(t)
        ech := make(chan int)
-       err = exp.Export("exportedRecv", ech, Recv)
+       err := exp.Export("exportedRecv", ech, Recv)
        if err != nil {
                t.Fatal("export:", err)
        }
@@ -343,14 +281,7 @@ func exportLoopback(exp *Exporter, t *testing.T) {
 // This test checks that channel operations can proceed
 // even when other concurrent operations are blocked.
 func TestIndependentSends(t *testing.T) {
-       exp, err := NewExporter("tcp", "127.0.0.1:0")
-       if err != nil {
-               t.Fatal("new exporter:", err)
-       }
-       imp, err := NewImporter("tcp", exp.Addr().String())
-       if err != nil {
-               t.Fatal("new importer:", err)
-       }
+       exp, imp := pair(t)
 
        exportLoopback(exp, t)
 
@@ -377,23 +308,8 @@ type value struct {
 }
 
 func TestCrossConnect(t *testing.T) {
-       e1, err := NewExporter("tcp", "127.0.0.1:0")
-       if err != nil {
-               t.Fatal("new exporter:", err)
-       }
-       i1, err := NewImporter("tcp", e1.Addr().String())
-       if err != nil {
-               t.Fatal("new importer:", err)
-       }
-
-       e2, err := NewExporter("tcp", "127.0.0.1:0")
-       if err != nil {
-               t.Fatal("new exporter:", err)
-       }
-       i2, err := NewImporter("tcp", e2.Addr().String())
-       if err != nil {
-               t.Fatal("new importer:", err)
-       }
+       e1, i1 := pair(t)
+       e2, i2 := pair(t)
 
        crossExport(e1, e2, t)
        crossImport(i1, i2, t)
@@ -452,20 +368,13 @@ const flowCount = 100
 
 // test flow control from exporter to importer.
 func TestExportFlowControl(t *testing.T) {
-       exp, err := NewExporter("tcp", "127.0.0.1:0")
-       if err != nil {
-               t.Fatal("new exporter:", err)
-       }
-       imp, err := NewImporter("tcp", exp.Addr().String())
-       if err != nil {
-               t.Fatal("new importer:", err)
-       }
+       exp, imp := pair(t)
 
        sendDone := make(chan bool, 1)
        exportSend(exp, flowCount, t, sendDone)
 
        ch := make(chan int)
-       err = imp.ImportNValues("exportedSend", ch, Recv, 20, -1)
+       err := imp.ImportNValues("exportedSend", ch, Recv, 20, -1)
        if err != nil {
                t.Fatal("importReceive:", err)
        }
@@ -475,17 +384,10 @@ func TestExportFlowControl(t *testing.T) {
 
 // test flow control from importer to exporter.
 func TestImportFlowControl(t *testing.T) {
-       exp, err := NewExporter("tcp", "127.0.0.1:0")
-       if err != nil {
-               t.Fatal("new exporter:", err)
-       }
-       imp, err := NewImporter("tcp", exp.Addr().String())
-       if err != nil {
-               t.Fatal("new importer:", err)
-       }
+       exp, imp := pair(t)
 
        ch := make(chan int)
-       err = exp.Export("exportedRecv", ch, Recv)
+       err := exp.Export("exportedRecv", ch, Recv)
        if err != nil {
                t.Fatal("importReceive:", err)
        }
@@ -513,3 +415,11 @@ func testFlow(sendDone chan bool, ch <-chan int, N int, t *testing.T) {
                t.Fatalf("expected %d values; got %d", N, n)
        }
 }
+
+func pair(t *testing.T) (*Exporter, *Importer) {
+       c0, c1 := net.Pipe()
+       exp := NewExporter()
+       go exp.ServeConn(c0)
+       imp := NewImporter(c1)
+       return exp, imp
+}