]> Cypherpunks repositories - gostls13.git/commitdiff
netchan: allow client to send as well as receive.
authorRob Pike <r@golang.org>
Tue, 13 Apr 2010 20:49:42 +0000 (13:49 -0700)
committerRob Pike <r@golang.org>
Tue, 13 Apr 2010 20:49:42 +0000 (13:49 -0700)
much rewriting and improving, but it's still experimental.

R=rsc
CC=golang-dev
https://golang.org/cl/875045

src/pkg/netchan/common.go
src/pkg/netchan/export.go
src/pkg/netchan/import.go
src/pkg/netchan/netchan_test.go

index a82bd07c16b0f8eadfd02908cd7a3bb878552f20..0fe9c96bb8e72ed794cc6359dd7e685132db3849 100644 (file)
@@ -1,4 +1,4 @@
-// Copyright 2009 The Go Authors. All rights reserved.
+// Copyright 2010 The Go Authors. All rights reserved.
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 
@@ -6,12 +6,12 @@ package netchan
 
 import (
        "gob"
-       "log"
        "net"
        "os"
        "sync"
 )
 
+// The direction of a connection from the client's perspective.
 type Dir int
 
 const (
@@ -19,8 +19,34 @@ const (
        Send
 )
 
-// Mutex-protected encoder and decoder pair
+// Payload types
+const (
+       payRequest = iota // request structure follows
+       payError          // error structure follows
+       payData           // user payload follows
+)
+
+// A header is sent as a prefix to every transmission.  It will be followed by
+// a request structure, an error structure, or an arbitrary user payload structure.
+type header struct {
+       name        string
+       payloadType int
+}
+
+// Sent with a header once per channel from importer to exporter to report
+// that it wants to bind to a channel with the specified direction for count
+// messages.  If count is zero, it means unlimited.
+type request struct {
+       count int
+       dir   Dir
+}
 
+// Sent with a header to report an error.
+type error struct {
+       error string
+}
+
+// Mutex-protected encoder and decoder pair.
 type encDec struct {
        decLock sync.Mutex
        dec     *gob.Decoder
@@ -35,29 +61,27 @@ func newEncDec(conn net.Conn) *encDec {
        }
 }
 
+// Decode an item from the connection.
 func (ed *encDec) decode(e interface{}) os.Error {
        ed.decLock.Lock()
-       defer ed.decLock.Unlock()
        err := ed.dec.Decode(e)
        if err != nil {
-               log.Stderr("exporter decode:", err)
-               // TODO: tear down connection
-               return err
+               // TODO: tear down connection?
        }
-       return nil
+       ed.decLock.Unlock()
+       return err
 }
 
-func (ed *encDec) encode(e0, e1 interface{}) os.Error {
+// Encode a header and payload onto the connection.
+func (ed *encDec) encode(hdr *header, payloadType int, payload interface{}) os.Error {
        ed.encLock.Lock()
-       defer ed.encLock.Unlock()
-       err := ed.enc.Encode(e0)
-       if err == nil && e1 != nil {
-               err = ed.enc.Encode(e1)
-       }
-       if err != nil {
-               log.Stderr("exporter encode:", err)
-               // TODO: tear down connection?
-               return err
+       hdr.payloadType = payloadType
+       err := ed.enc.Encode(hdr)
+       if err == nil {
+               err = ed.enc.Encode(payload)
+       } else {
+               // TODO: tear down connection if there is an error?
        }
-       return nil
+       ed.encLock.Unlock()
+       return err
 }
index 626630b4a04c7b6cdc888abf331fdf5718f571cb..89deb20ae24589259ac85f5b16a360c5e9afccdb 100644 (file)
@@ -1,4 +1,4 @@
-// Copyright 2009 The Go Authors. All rights reserved.
+// Copyright 2010 The Go Authors. All rights reserved.
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 
        Networked channels are not synchronized; they always behave
        as if there is a buffer of at least one element between the
        two machines.
-
-       TODO: at the moment, the exporting machine must send and
-       the importing machine must receive.  This restriction will
-       be lifted soon.
 */
 package netchan
 
@@ -34,10 +30,12 @@ import (
 
 // Export
 
-// A channel and its associated information: a direction
+// A channel and its associated information: a direction plus
+// a handy marshaling place for its data.
 type exportChan struct {
        ch  *reflect.ChanValue
        dir Dir
+       ptr *reflect.PtrValue // a pointer value we can point at each new received item
 }
 
 // An Exporter allows a set of channels to be published on a single
@@ -62,21 +60,6 @@ func newClient(exp *Exporter, conn net.Conn) *expClient {
 
 }
 
-// TODO: ASSUMES EXPORT MEANS SEND
-
-// Sent once per channel from importer to exporter to report that it's listening to a channel
-type request struct {
-       name  string
-       dir   Dir
-       count int
-}
-
-// Reply to request, sent from exporter to importer on each send.
-type response struct {
-       name  string
-       error string
-}
-
 // Wait for incoming connections, start a new runner for each
 func (exp *Exporter) listen() {
        for {
@@ -85,70 +68,112 @@ func (exp *Exporter) listen() {
                        log.Stderr("exporter.listen:", err)
                        break
                }
-               log.Stderr("accepted call from", conn.RemoteAddr())
                client := newClient(exp, conn)
                go client.run()
        }
 }
 
-// Send a single client all its data.  For each request, this will launch
-// a serveRecv goroutine to deliver the data for that channel.
+func (client *expClient) sendError(hdr *header, err string) {
+       error := &error{err}
+       log.Stderr("export:", error.error)
+       client.encode(hdr, payError, error) // ignore any encode error, hope client gets it
+}
+
+func (client *expClient) getChan(hdr *header, dir Dir) *exportChan {
+       exp := client.exp
+       exp.chanLock.Lock()
+       ech, ok := exp.chans[hdr.name]
+       exp.chanLock.Unlock()
+       if !ok {
+               client.sendError(hdr, "no such channel: "+hdr.name)
+               return nil
+       }
+       if ech.dir != dir {
+               client.sendError(hdr, "wrong direction for channel: "+hdr.name)
+               return nil
+       }
+       return ech
+}
+
+// Manage sends and receives for a single client.  For each (client Recv) request,
+// this will launch a serveRecv goroutine to deliver the data for that channel,
+// while (client Send) requests are handled as data arrives from the client.
 func (client *expClient) run() {
+       hdr := new(header)
        req := new(request)
+       error := new(error)
        for {
-               if err := client.decode(req); err != nil {
-                       log.Stderr("error decoding client request:", err)
+               if err := client.decode(hdr); err != nil {
+                       log.Stderr("error decoding client header:", err)
                        // TODO: tear down connection
-                       break
+                       return
                }
-               log.Stderrf("export request: %+v", req)
-               if req.dir == Recv {
-                       go client.serveRecv(req)
-               } else {
-                       log.Stderr("export request: can't handle channel direction", req.dir)
-                       resp := new(response)
-                       resp.name = req.name
-                       resp.error = "export request: can't handle channel direction"
-                       client.encode(resp, nil)
-                       break
+               switch hdr.payloadType {
+               case payRequest:
+                       if err := client.decode(req); err != nil {
+                               log.Stderr("error decoding client request:", err)
+                               // TODO: tear down connection
+                               return
+                       }
+                       switch req.dir {
+                       case Recv:
+                               go client.serveRecv(*hdr, req.count)
+                       case Send:
+                               // Request to send is clear as a matter of protocol
+                               // but not actually used by the implementation.
+                               // The actual sends will have payload type payData.
+                               // TODO: manage the count?
+                       default:
+                               error.error = "export request: can't handle channel direction"
+                               log.Stderr(error.error, req.dir)
+                               client.encode(hdr, payError, error)
+                       }
+               case payData:
+                       client.serveSend(*hdr)
                }
        }
 }
 
-// Send all the data on a single channel to a client asking for a Recv
-func (client *expClient) serveRecv(req *request) {
-       exp := client.exp
-       resp := new(response)
-       resp.name = req.name
-       var ok bool
-       exp.chanLock.Lock()
-       ech, ok := exp.chans[req.name]
-       exp.chanLock.Unlock()
-       if !ok {
-               resp.error = "no such channel: " + req.name
-               log.Stderr("export:", resp.error)
-               client.encode(resp, nil) // ignore any encode error, hope client gets it
+// Send all the data on a single channel to a client asking for a Recv.
+// The header is passed by value to avoid issues of overwriting.
+func (client *expClient) serveRecv(hdr header, count int) {
+       ech := client.getChan(&hdr, Send)
+       if ech == nil {
                return
        }
        for {
-               if ech.dir != Send {
-                       log.Stderr("TODO: recv export unimplemented")
-                       break
-               }
                val := ech.ch.Recv()
-               if err := client.encode(resp, val.Interface()); err != nil {
+               if err := client.encode(&hdr, payData, val.Interface()); err != nil {
                        log.Stderr("error encoding client response:", err)
+                       client.sendError(&hdr, err.String())
                        break
                }
-               if req.count > 0 {
-                       req.count--
-                       if req.count == 0 {
+               if count > 0 {
+                       if count--; count == 0 {
                                break
                        }
                }
        }
 }
 
+// Receive and deliver locally one item from a client asking for a Send
+// The header is passed by value to avoid issues of overwriting.
+func (client *expClient) serveSend(hdr header) {
+       ech := client.getChan(&hdr, Recv)
+       if ech == nil {
+               return
+       }
+       // Create a new value for each received item.
+       val := reflect.MakeZero(ech.ptr.Type().(*reflect.PtrType).Elem())
+       ech.ptr.PointTo(val)
+       if err := client.decode(ech.ptr.Interface()); err != nil {
+               log.Stderr("exporter value decode:", err)
+               return
+       }
+       ech.ch.Send(val)
+       // TODO count
+}
+
 // NewExporter creates a new Exporter to export channels
 // on the network and local address defined as in net.Listen.
 func NewExporter(network, localaddr string) (*Exporter, os.Error) {
@@ -195,7 +220,8 @@ func checkChan(chT interface{}, dir Dir) (*reflect.ChanValue, os.Error) {
 // Despite the literal signature, the effective signature is
 //     Export(name string, chT chan T, dir Dir)
 // where T must be a struct, pointer to struct, etc.
-func (exp *Exporter) Export(name string, chT interface{}, dir Dir) os.Error {
+// TODO: fix gob interface so we can eliminate the need for pT, and for structs.
+func (exp *Exporter) Export(name string, chT interface{}, dir Dir, pT interface{}) os.Error {
        ch, err := checkChan(chT, dir)
        if err != nil {
                return err
@@ -206,6 +232,7 @@ func (exp *Exporter) Export(name string, chT interface{}, dir Dir) os.Error {
        if present {
                return os.ErrorString("channel name already being exported:" + name)
        }
-       exp.chans[name] = &exportChan{ch, dir}
+       ptr := reflect.MakeZero(reflect.Typeof(pT)).(*reflect.PtrValue)
+       exp.chans[name] = &exportChan{ch, dir, ptr}
        return nil
 }
index 263ee4404b913938331902c97c295f866c83ac90..bde36f6152ac0107d16e0fc07e8cf75c684ab1ae 100644 (file)
@@ -1,4 +1,4 @@
-// Copyright 2009 The Go Authors. All rights reserved.
+// Copyright 2010 The Go Authors. All rights reserved.
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 
@@ -14,12 +14,12 @@ import (
 
 // Import
 
-// A channel and its associated information: a template value, direction and a count
+// A channel and its associated information: a template value and direction,
+// plus a handy marshaling place for its data.
 type importChan struct {
-       ch    *reflect.ChanValue
-       dir   Dir
-       ptr   *reflect.PtrValue // a pointer value we can point at each new item
-       count int
+       ch  *reflect.ChanValue
+       dir Dir
+       ptr *reflect.PtrValue // a pointer value we can point at each new received item
 }
 
 // An Importer allows a set of channels to be imported from a single
@@ -32,8 +32,6 @@ type Importer struct {
        chans    map[string]*importChan
 }
 
-// TODO: ASSUMES IMPORT MEANS RECEIVE
-
 // NewImporter creates a new Importer object to import channels
 // from an Exporter at the network and remote address as defined in net.Dial.
 // The Exporter must be available and serving when the Importer is
@@ -54,37 +52,49 @@ func NewImporter(network, remoteaddr string) (*Importer, os.Error) {
 // Handle the data from a single imported data stream, which will
 // have the form
 //     (response, data)*
-// The response identifies by name which channel is receiving data.
-// TODO: allow an importer to send.
+// The response identifies by name which channel is transmitting data.
 func (imp *Importer) run() {
        // Loop on responses; requests are sent by ImportNValues()
-       resp := new(response)
+       hdr := new(header)
+       err := new(error)
        for {
-               if err := imp.decode(resp); err != nil {
-                       log.Stderr("importer response decode:", err)
-                       break
+               if e := imp.decode(hdr); e != nil {
+                       log.Stderr("importer header:", e)
+                       return
                }
-               if resp.error != "" {
-                       log.Stderr("importer response error:", resp.error)
-                       // TODO: tear down connection
-                       break
+               switch hdr.payloadType {
+               case payData:
+                       // done lower in loop
+               case payError:
+                       if e := imp.decode(err); e != nil {
+                               log.Stderr("importer error:", e)
+                               return
+                       }
+                       if err.error != "" {
+                               log.Stderr("importer response error:", err.error)
+                               // TODO: tear down connection
+                               return
+                       }
+               default:
+                       log.Stderr("unexpected payload type:", hdr.payloadType)
+                       return
                }
                imp.chanLock.Lock()
-               ich, ok := imp.chans[resp.name]
+               ich, ok := imp.chans[hdr.name]
                imp.chanLock.Unlock()
                if !ok {
-                       log.Stderr("unknown name in request:", resp.name)
-                       break
+                       log.Stderr("unknown name in request:", hdr.name)
+                       return
                }
                if ich.dir != Recv {
-                       log.Stderr("TODO: import send unimplemented")
-                       break
+                       log.Stderr("cannot happen: receive from non-Recv channel")
+                       return
                }
                // Create a new value for each received item.
                val := reflect.MakeZero(ich.ptr.Type().(*reflect.PtrType).Elem())
                ich.ptr.PointTo(val)
-               if err := imp.decode(ich.ptr.Interface()); err != nil {
-                       log.Stderr("importer value decode:", err)
+               if e := imp.decode(ich.ptr.Interface()); e != nil {
+                       log.Stderr("importer value decode:", e)
                        return
                }
                ich.ch.Send(val)
@@ -103,7 +113,7 @@ func (imp *Importer) Import(name string, chT interface{}, dir Dir, pT interface{
 // the remote site's channel is provided in the call and may be of arbitrary
 // channel type.
 // Despite the literal signature, the effective signature is
-//     ImportNValues(name string, chT chan T, dir Dir, pT T)
+//     ImportNValues(name string, chT chan T, dir Dir, pT T, n int) os.Error
 // where T must be a struct, pointer to struct, etc.  pT may be more indirect
 // than the value type of the channel (e.g.  chan T, pT *T) but it must be a
 // pointer.
@@ -114,7 +124,7 @@ func (imp *Importer) Import(name string, chT interface{}, dir Dir, pT interface{
 //     err := imp.ImportNValues("name", ch, Recv, new(myType), 1)
 //     if err != nil { log.Exit(err) }
 //     fmt.Printf("%+v\n", <-ch)
-// (TODO: Can we eliminate the need for pT?)
+// TODO: fix gob interface so we can eliminate the need for pT, and for structs.
 func (imp *Importer) ImportNValues(name string, chT interface{}, dir Dir, pT interface{}, n int) os.Error {
        ch, err := checkChan(chT, dir)
        if err != nil {
@@ -135,15 +145,28 @@ func (imp *Importer) ImportNValues(name string, chT interface{}, dir Dir, pT int
                return os.ErrorString("channel name already being imported:" + name)
        }
        ptr := reflect.MakeZero(reflect.Typeof(pT)).(*reflect.PtrValue)
-       imp.chans[name] = &importChan{ch, dir, ptr, n}
+       imp.chans[name] = &importChan{ch, dir, ptr}
        // Tell the other side about this channel.
+       hdr := new(header)
+       hdr.name = name
+       hdr.payloadType = payRequest
        req := new(request)
-       req.name = name
        req.dir = dir
        req.count = n
-       if err := imp.encode(req, nil); err != nil {
+       if err := imp.encode(hdr, payRequest, req); err != nil {
                log.Stderr("importer request encode:", err)
                return err
        }
+       if dir == Send {
+               go func() {
+                       for i := 0; n == 0 || i < n; i++ {
+                               val := ch.Recv()
+                               if err := imp.encode(hdr, payData, val.Interface()); err != nil {
+                                       log.Stderr("error encoding client response:", err)
+                                       return
+                               }
+                       }
+               }()
+       }
        return nil
 }
index b7fd100cf5eb74a9c87c6e793a3c37fc74bf2221..cdf709406107b4e40ab6e50ecf3556857f92fb28 100644 (file)
@@ -1,4 +1,4 @@
-// Copyright 2009 The Go Authors. All rights reserved.
+// Copyright 2010 The Go Authors. All rights reserved.
 // Use of this source code is governed by a BSD-style
 // license that can be found in the LICENSE file.
 
@@ -14,37 +14,82 @@ type value struct {
        s string
 }
 
+const count = 10
+
 func exportSend(exp *Exporter, t *testing.T) {
-       c := make(chan value)
-       err := exp.Export("name", c, Send)
+       ch := make(chan value)
+       err := exp.Export("exportedSend", ch, Send, new(value))
+       if err != nil {
+               t.Fatal("exportSend:", err)
+       }
+       for i := 0; i < count; i++ {
+               ch <- value{23 + i, "hello"}
+       }
+}
+
+func exportReceive(exp *Exporter, t *testing.T) {
+       ch := make(chan value)
+       err := exp.Export("exportedRecv", ch, Recv, new(value))
        if err != nil {
-               t.Fatal("export:", err)
+               t.Fatal("exportReceive:", err)
+       }
+       for i := 0; i < count; i++ {
+               v := <-ch
+               fmt.Printf("%v\n", v)
+               if v.i != 45+i || v.s != "hello" {
+                       t.Errorf("export Receive: bad value: expected 4%d, hello; got %+v", 45+i, v)
+               }
        }
-       c <- value{23, "hello"}
 }
 
 func importReceive(imp *Importer, t *testing.T) {
        ch := make(chan value)
-       err := imp.ImportNValues("name", ch, Recv, new(value), 1)
+       err := imp.ImportNValues("exportedSend", ch, Recv, new(value), count)
        if err != nil {
-               t.Fatal("import:", err)
+               t.Fatal("importReceive:", err)
        }
-       v := <-ch
-       fmt.Printf("%v\n", v)
-       if v.i != 23 || v.s != "hello" {
-               t.Errorf("bad value: expected 23, hello; got %+v\n", v)
+       for i := 0; i < count; i++ {
+               v := <-ch
+               fmt.Printf("%v\n", v)
+               if v.i != 23+i || v.s != "hello" {
+                       t.Errorf("importReceive: bad value: expected %d, hello; got %+v", 23+i, v)
+               }
        }
 }
 
-func TestBabyStep(t *testing.T) {
+func importSend(imp *Importer, t *testing.T) {
+       ch := make(chan value)
+       err := imp.ImportNValues("exportedRecv", ch, Send, new(value), count)
+       if err != nil {
+               t.Fatal("importSend:", err)
+       }
+       for i := 0; i < count; i++ {
+               ch <- value{45 + i, "hello"}
+       }
+}
+
+func TestExportSendImportReceive(t *testing.T) {
        exp, err := NewExporter("tcp", ":0")
        if err != nil {
                t.Fatal("new exporter:", err)
        }
-       go exportSend(exp, t)
        imp, err := NewImporter("tcp", exp.Addr().String())
        if err != nil {
                t.Fatal("new importer:", err)
        }
+       go exportSend(exp, t)
        importReceive(imp, t)
 }
+
+func TestExportReceiveImportSend(t *testing.T) {
+       exp, err := NewExporter("tcp", ":0")
+       if err != nil {
+               t.Fatal("new exporter:", err)
+       }
+       imp, err := NewImporter("tcp", exp.Addr().String())
+       if err != nil {
+               t.Fatal("new importer:", err)
+       }
+       go importSend(imp, t)
+       exportReceive(exp, t)
+}