]> Cypherpunks repositories - gostls13.git/commitdiff
rpc.
authorRob Pike <r@golang.org>
Mon, 13 Jul 2009 23:52:57 +0000 (16:52 -0700)
committerRob Pike <r@golang.org>
Mon, 13 Jul 2009 23:52:57 +0000 (16:52 -0700)
client library.
muxes on both ends.

R=rsc
DELTA=178  (132 added, 26 deleted, 20 changed)
OCL=31541
CL=31556

src/pkg/rpc/Makefile
src/pkg/rpc/client.go
src/pkg/rpc/server.go
src/pkg/rpc/server_test.go

index 4d8c15629e8fec6ac101390414d435dfc3bc7c64..0d0f109b8623c38cac462c4aba129e49c2792ffc 100644 (file)
@@ -33,17 +33,23 @@ coverage: packages
        $(AS) $*.s
 
 O1=\
-       client.$O\
        server.$O\
 
+O2=\
+       client.$O\
+
 
-phases: a1
+phases: a1 a2
 _obj$D/rpc.a: phases
 
 a1: $(O1)
-       $(AR) grc _obj$D/rpc.a client.$O server.$O
+       $(AR) grc _obj$D/rpc.a server.$O
        rm -f $(O1)
 
+a2: $(O2)
+       $(AR) grc _obj$D/rpc.a client.$O
+       rm -f $(O2)
+
 
 newpkg: clean
        mkdir -p _obj$D
@@ -51,6 +57,7 @@ newpkg: clean
 
 $(O1): newpkg
 $(O2): a1
+$(O3): a2
 
 nuke: clean
        rm -f $(GOROOT)/pkg/$(GOOS)_$(GOARCH)$D/rpc.a
index 7df3dc8d51f6524ccddad5df18ae1af5bde8288e..725add1a54a59e03119c5a7cbfea34a4a37c5a7b 100644 (file)
@@ -8,7 +8,106 @@ import (
        "gob";
        "io";
        "os";
-       "reflect";
+       "rpc";
        "sync";
 )
 
+// Call represents an active RPC
+type Call struct {
+       ServiceMethod   string; // The name of the service and method to call.
+       Args    interface{};    // The argument to the function (*struct).
+       Reply   interface{};    // The reply from the function (*struct).
+       Error   os.Error;       // After completion, the error status.
+       Done    chan *Call;     // Strobes when call is complete; value is the error status.
+       seq     uint64;
+}
+
+// Client represents an RPC Client.
+type Client struct {
+       sync.Mutex;     // protects pending, seq
+       sending sync.Mutex;
+       seq     uint64;
+       conn io.ReadWriteCloser;
+       enc     *gob.Encoder;
+       dec     *gob.Decoder;
+       pending map[uint64] *Call;
+}
+
+func (client *Client) send(c *Call) {
+       // Register this call.
+       client.Lock();
+       c.seq = client.seq;
+       client.seq++;
+       client.pending[c.seq] = c;
+       client.Unlock();
+
+       // Encode and send the request.
+       request := new(Request);
+       client.sending.Lock();
+       request.Seq = c.seq;
+       request.ServiceMethod = c.ServiceMethod;
+       client.enc.Encode(request);
+       err := client.enc.Encode(c.Args);
+       if err != nil {
+               panicln("client encode error:", err)
+       }
+       client.sending.Unlock();
+}
+
+func (client *Client) serve() {
+       for {
+               response := new(Response);
+               err := client.dec.Decode(response);
+               seq := response.Seq;
+               client.Lock();
+               c := client.pending[seq];
+               client.pending[seq] = c, false;
+               client.Unlock();
+               client.dec.Decode(c.Reply);
+               if err != nil {
+                       panicln("client decode error:", err)
+               }
+               c.Error = os.ErrorString(response.Error);
+               // We don't want to block here, it is the caller's responsibility to make
+               // sure the channel has enough buffer space. See comment in Start().
+               doNotBlock := c.Done <- c;
+       }
+}
+
+// NewClient returns a new Client to handle requests to the
+// set of services at the other end of the connection.
+func NewClient(conn io.ReadWriteCloser) *Client {
+       client := new(Client);
+       client.conn = conn;
+       client.enc = gob.NewEncoder(conn);
+       client.dec = gob.NewDecoder(conn);
+       client.pending = make(map[uint64] *Call);
+       go client.serve();
+       return client;
+}
+
+// Start invokes the function asynchronously.  It returns the Call structure representing
+// the invocation.
+func (client *Client) Start(serviceMethod string, args interface{}, reply interface{}, done chan *Call) *Call {
+       c := new(Call);
+       c.ServiceMethod = serviceMethod;
+       c.Args = args;
+       c.Reply = reply;
+       if done == nil {
+               done = make(chan *Call, 1);     // buffered.
+       } else {
+               // TODO(r): check cap > 0
+               // If caller passes done != nil, it must arrange that
+               // done has enough buffer for the number of simultaneous
+               // RPCs that will be using that channel.
+       }
+       c.Done = done;
+       client.send(c);
+       return c;
+}
+
+// Call invokes the named function, waits for it to complete, and returns its error status.
+func (client *Client) Call(serviceMethod string, args interface{}, reply interface{}) os.Error {
+       call := <-client.Start(serviceMethod, args, reply, nil).Done;
+       return call.Error;
+}
index b03685298d2ea0c8f4ad23c0cbfa45aee2487c87..3b7a5df707b58cd9546c3e59e56d015b09ec700f 100644 (file)
@@ -143,16 +143,9 @@ func _new(t *reflect.PtrType) *reflect.PtrValue {
        return v;
 }
 
-// Blocks until the decoder is ready for the next message.
-// TODO(r): blocks longer than that. make this async.
-func (s *service) call(req *Request, mt *methodType, dec *gob.Decoder, enc *gob.Encoder) {
-       method := mt.method;
-       // Decode the argument value.
-       argv := _new(mt.argType);
-       dec.Decode(argv.Interface());
+func (s *service) call(sending *sync.Mutex, function *reflect.FuncValue, req *Request, argv, replyv reflect.Value, enc *gob.Encoder) {
        // Invoke the method, providing a new value for the reply.
-       replyv := _new(mt.replyType);
-       returnValues := method.Func.Call([]reflect.Value{s.rcvr, argv, replyv});
+       returnValues := function.Call([]reflect.Value{s.rcvr, argv, replyv});
        // The return value for the method is an os.Error.
        err := returnValues[0].Interface();
        resp := new(Response);
@@ -160,22 +153,25 @@ func (s *service) call(req *Request, mt *methodType, dec *gob.Decoder, enc *gob.
                resp.Error = err.(os.Error).String();
        }
        // Encode the response header
+       sending.Lock();
        resp.ServiceMethod = req.ServiceMethod;
        resp.Seq = req.Seq;
        enc.Encode(resp);
        // Encode the reply value.
        enc.Encode(replyv.Interface());
+       sending.Unlock();
 }
 
 func (server *Server) serve(conn io.ReadWriteCloser) {
        dec := gob.NewDecoder(conn);
        enc := gob.NewEncoder(conn);
+       sending := new(sync.Mutex);
        for {
                // Grab the request header.
                req := new(Request);
                err := dec.Decode(req);
                if err != nil {
-                       panicln("can't handle decode error yet", err);
+                       panicln("can't handle decode error yet", err.String());
                }
                serviceMethod := strings.Split(req.ServiceMethod, ".", 0);
                if len(serviceMethod) != 2 {
@@ -186,11 +182,18 @@ func (server *Server) serve(conn io.ReadWriteCloser) {
                if !ok {
                        panicln("can't find service", serviceMethod[0]);
                }
-               method, ok := service.method[serviceMethod[1]];
+               mtype, ok := service.method[serviceMethod[1]];
                if !ok {
                        panicln("can't find method", serviceMethod[1]);
                }
-               service.call(req, method, dec, enc);
+               method := mtype.method;
+               // Decode the argument value.
+               argv := _new(mtype.argType);
+               err = dec.Decode(argv.Interface());
+               if err != nil {
+                       panicln("can't handle payload decode error yet", err.String());
+               }
+               go service.call(sending, method.Func, req, argv, _new(mtype.replyType), enc);
        }
 }
 
index 01d991f1481f71c41c598057e40330eab2a6e083..0a1ec64be4ed408429b352232b7db529ad0e8473 100644 (file)
@@ -86,49 +86,46 @@ func TestRPC(t *testing.T) {
                t.Fatal("dialing:", err)
        }
 
-       enc := gob.NewEncoder(conn);
-       dec := gob.NewDecoder(conn);
-       req := new(rpc.Request);
-       req.ServiceMethod = "Arith.Add";
-       req.Seq = 1;
-       enc.Encode(req);
+       client := NewClient(conn);
+
+       // Synchronous calls
        args := &Args{7,8};
-       enc.Encode(args);
-       response := new(rpc.Response);
-       dec.Decode(response);
        reply := new(Reply);
-       dec.Decode(reply);
-       fmt.Printf("%d\n", reply.C);
+       err = client.Call("Arith.Add", args, reply);
        if reply.C != args.A + args.B {
-               t.Errorf("Add: expected %d got %d", reply.C != args.A + args.B);
+               t.Errorf("Add: expected %d got %d", reply.C, args.A + args.B);
        }
 
-       req.ServiceMethod = "Arith.Mul";
-       req.Seq++;
-       enc.Encode(req);
        args = &Args{7,8};
-       enc.Encode(args);
-       response = new(rpc.Response);
-       dec.Decode(response);
        reply = new(Reply);
-       dec.Decode(reply);
-       fmt.Printf("%d\n", reply.C);
+       err = client.Call("Arith.Mul", args, reply);
        if reply.C != args.A * args.B {
-               t.Errorf("Mul: expected %d got %d", reply.C != args.A * args.B);
+               t.Errorf("Mul: expected %d got %d", reply.C, args.A * args.B);
+       }
+
+       // Out of order.
+       args = &Args{7,8};
+       mulReply := new(Reply);
+       mulCall := client.Start("Arith.Mul", args, mulReply, nil);
+       addReply := new(Reply);
+       addCall := client.Start("Arith.Add", args, addReply, nil);
+
+       <-addCall.Done;
+       if addReply.C != args.A + args.B {
+               t.Errorf("Add: expected %d got %d", addReply.C, args.A + args.B);
        }
 
-       req.ServiceMethod = "Arith.Div";
-       req.Seq++;
-       enc.Encode(req);
+       <-mulCall.Done;
+       if mulReply.C != args.A * args.B {
+               t.Errorf("Mul: expected %d got %d", mulReply.C, args.A * args.B);
+       }
+
+       // Error test
        args = &Args{7,0};
-       enc.Encode(args);
-       response = new(rpc.Response);
-       dec.Decode(response);
        reply = new(Reply);
-       dec.Decode(reply);
+       err = client.Call("Arith.Div", args, reply);
        // expect an error: zero divide
-       if response.Error == "" {
+       if err == nil {
                t.Errorf("Div: expected error");
        }
 }
-