]> Cypherpunks repositories - gostls13.git/commitdiff
improve rpc code. more robust. more tests.
authorRob Pike <r@golang.org>
Tue, 14 Jul 2009 20:23:14 +0000 (13:23 -0700)
committerRob Pike <r@golang.org>
Tue, 14 Jul 2009 20:23:14 +0000 (13:23 -0700)
R=rsc
DELTA=186  (133 added, 20 deleted, 33 changed)
OCL=31611
CL=31616

src/pkg/gob/decoder.go
src/pkg/rpc/client.go
src/pkg/rpc/server.go
src/pkg/rpc/server_test.go

index 8676533e62bdf224b91f9406c0e56aef3e9ba76c..ef5481e109a5077123bfda87ddf722e2bdd20bfd 100644 (file)
@@ -73,7 +73,7 @@ func (dec *Decoder) Decode(e interface{}) os.Error {
        // TODO(r): need to make the decoder work correctly if the wire type is compatible
        // but not equal to the local type (e.g, extra fields).
        if info.wire.name() != dec.seen[id].name() {
-               dec.state.err = os.ErrorString("gob decode: incorrect type for wire value");
+               dec.state.err = os.ErrorString("gob decode: incorrect type for wire value: want " + info.wire.name() + "; received " + dec.seen[id].name());
                return dec.state.err
        }
 
index 725add1a54a59e03119c5a7cbfea34a4a37c5a7b..196118834b31343c9fcbd6a787f3801b77a4b405 100644 (file)
@@ -7,6 +7,7 @@ package rpc
 import (
        "gob";
        "io";
+       "log";
        "os";
        "rpc";
        "sync";
@@ -25,6 +26,7 @@ type Call struct {
 // Client represents an RPC Client.
 type Client struct {
        sync.Mutex;     // protects pending, seq
+       closed  bool;
        sending sync.Mutex;
        seq     uint64;
        conn io.ReadWriteCloser;
@@ -49,29 +51,35 @@ func (client *Client) send(c *Call) {
        client.enc.Encode(request);
        err := client.enc.Encode(c.Args);
        if err != nil {
-               panicln("client encode error:", err)
+               panicln("rpc: client encode error:", err);
        }
        client.sending.Unlock();
 }
 
 func (client *Client) serve() {
-       for {
+       var err os.Error;
+       for err == nil {
                response := new(Response);
-               err := client.dec.Decode(response);
+               err = client.dec.Decode(response);
+               if err != nil {
+                       if err == os.EOF {
+                               break;
+                       }
+                       break;
+               }
                seq := response.Seq;
                client.Lock();
                c := client.pending[seq];
                client.pending[seq] = c, false;
                client.Unlock();
-               client.dec.Decode(c.Reply);
-               if err != nil {
-                       panicln("client decode error:", err)
-               }
+               err = client.dec.Decode(c.Reply);
                c.Error = os.ErrorString(response.Error);
-               // We don't want to block here, it is the caller's responsibility to make
-               // sure the channel has enough buffer space. See comment in Start().
+               // We don't want to block here.  It is the caller's responsibility to make
+               // sure the channel has enough buffer space. See comment in Go().
                doNotBlock := c.Done <- c;
        }
+       client.closed = true;
+       log.Stderr("client protocol error:", err);
 }
 
 // NewClient returns a new Client to handle requests to the
@@ -86,9 +94,9 @@ func NewClient(conn io.ReadWriteCloser) *Client {
        return client;
 }
 
-// Start invokes the function asynchronously.  It returns the Call structure representing
+// Go invokes the function asynchronously.  It returns the Call structure representing
 // the invocation.
-func (client *Client) Start(serviceMethod string, args interface{}, reply interface{}, done chan *Call) *Call {
+func (client *Client) Go(serviceMethod string, args interface{}, reply interface{}, done chan *Call) *Call {
        c := new(Call);
        c.ServiceMethod = serviceMethod;
        c.Args = args;
@@ -102,12 +110,20 @@ func (client *Client) Start(serviceMethod string, args interface{}, reply interf
                // RPCs that will be using that channel.
        }
        c.Done = done;
+       if client.closed {
+               c.Error = os.EOF;
+               doNotBlock := c.Done <- c;
+               return c;
+       }
        client.send(c);
        return c;
 }
 
 // Call invokes the named function, waits for it to complete, and returns its error status.
 func (client *Client) Call(serviceMethod string, args interface{}, reply interface{}) os.Error {
-       call := <-client.Start(serviceMethod, args, reply, nil).Done;
+       if client.closed {
+               return os.EOF
+       }
+       call := <-client.Go(serviceMethod, args, reply, nil).Done;
        return call.Error;
 }
index 3b7a5df707b58cd9546c3e59e56d015b09ec700f..304f1e2df2f1671761ca97ef1ba452ea6702b071 100644 (file)
@@ -8,6 +8,7 @@ import (
        "gob";
        "log";
        "io";
+       "net";
        "os";
        "reflect";
        "strings";
@@ -16,6 +17,8 @@ import (
        "utf8";
 )
 
+import "fmt" // TODO DELETE
+
 // Precompute the reflect type for os.Error.  Can't use os.Error directly
 // because Typeof takes an empty interface value.  This is annoying.
 var unusedError *os.Error;
@@ -137,31 +140,43 @@ func (server *Server) Add(rcvr interface{}) os.Error {
        return nil;
 }
 
+// A value to be sent as a placeholder for the response when we receive invalid request.
+type InvalidRequest struct {
+       marker int
+}
+var invalidRequest = InvalidRequest{1}
+
 func _new(t *reflect.PtrType) *reflect.PtrValue {
        v := reflect.MakeZero(t).(*reflect.PtrValue);
        v.PointTo(reflect.MakeZero(t.Elem()));
        return v;
 }
 
-func (s *service) call(sending *sync.Mutex, function *reflect.FuncValue, req *Request, argv, replyv reflect.Value, enc *gob.Encoder) {
-       // Invoke the method, providing a new value for the reply.
-       returnValues := function.Call([]reflect.Value{s.rcvr, argv, replyv});
-       // The return value for the method is an os.Error.
-       err := returnValues[0].Interface();
+func (s *service) sendResponse(sending *sync.Mutex, req *Request, reply interface{}, enc *gob.Encoder, errmsg string) {
        resp := new(Response);
-       if err != nil {
-               resp.Error = err.(os.Error).String();
-       }
        // Encode the response header
        sending.Lock();
        resp.ServiceMethod = req.ServiceMethod;
+       resp.Error = errmsg;
        resp.Seq = req.Seq;
        enc.Encode(resp);
        // Encode the reply value.
-       enc.Encode(replyv.Interface());
+       enc.Encode(reply);
        sending.Unlock();
 }
 
+func (s *service) call(sending *sync.Mutex, function *reflect.FuncValue, req *Request, argv, replyv reflect.Value, enc *gob.Encoder) {
+       // Invoke the method, providing a new value for the reply.
+       returnValues := function.Call([]reflect.Value{s.rcvr, argv, replyv});
+       // The return value for the method is an os.Error.
+       errInter := returnValues[0].Interface();
+       errmsg := "";
+       if errInter != nil {
+               errmsg = errInter.(os.Error).String();
+       }
+       s.sendResponse(sending, req, replyv.Interface(), enc, errmsg);
+}
+
 func (server *Server) serve(conn io.ReadWriteCloser) {
        dec := gob.NewDecoder(conn);
        enc := gob.NewEncoder(conn);
@@ -171,33 +186,56 @@ func (server *Server) serve(conn io.ReadWriteCloser) {
                req := new(Request);
                err := dec.Decode(req);
                if err != nil {
-                       panicln("can't handle decode error yet", err.String());
+                       log.Stderr("rpc: server cannot decode request:", err);
+                       break;
                }
                serviceMethod := strings.Split(req.ServiceMethod, ".", 0);
                if len(serviceMethod) != 2 {
-                       panicln("service/Method request ill-formed:", req.ServiceMethod);
+                       log.Stderr("rpc: service/Method request ill-formed:", req.ServiceMethod);
+                       break;
                }
                // Look up the request.
                service, ok := server.serviceMap[serviceMethod[0]];
                if !ok {
-                       panicln("can't find service", serviceMethod[0]);
+                       s := "rpc: can't find service " + req.ServiceMethod;
+                       service.sendResponse(sending, req, invalidRequest, enc, s);
+                       break;
                }
                mtype, ok := service.method[serviceMethod[1]];
                if !ok {
-                       panicln("can't find method", serviceMethod[1]);
+                       s := "rpc: can't find method " + req.ServiceMethod;
+                       service.sendResponse(sending, req, invalidRequest, enc, s);
+                       break;
                }
                method := mtype.method;
                // Decode the argument value.
                argv := _new(mtype.argType);
+               replyv := _new(mtype.replyType);
                err = dec.Decode(argv.Interface());
                if err != nil {
-                       panicln("can't handle payload decode error yet", err.String());
+                       log.Stderr("tearing down connection:", err);
+                       service.sendResponse(sending, req, replyv.Interface(), enc, err.String());
+                       break;
                }
-               go service.call(sending, method.Func, req, argv, _new(mtype.replyType), enc);
+               go service.call(sending, method.Func, req, argv, replyv, enc);
        }
+       conn.Close();
 }
 
-// Serve runs the server on the connection.
-func (server *Server) Serve(conn io.ReadWriteCloser) {
+// ServeConn runs the server on a single connection.  When the connection
+// completes, service terminates.
+func (server *Server) ServeConn(conn io.ReadWriteCloser) {
        go server.serve(conn)
 }
+
+// Accept accepts connections on the listener and serves requests
+// for each incoming connection.
+func (server *Server) Accept(lis net.Listener) {
+       for {
+               conn, addr, err := lis.Accept();
+               if err != nil {
+                       log.Exit("rpc.Serve: accept:", err.String());   // TODO(r): exit?
+               }
+               go server.ServeConn(conn);
+       }
+}
index 0a1ec64be4ed408429b352232b7db529ad0e8473..51d024d76b5dad00997bb17c0157583dbab9a25d 100644 (file)
@@ -11,8 +11,10 @@ import (
        "io";
        "log";
        "net";
+       "once";
        "os";
        "rpc";
+       "strings";
        "testing";
 )
 
@@ -53,15 +55,6 @@ func (t *Arith) Error(args *Args, reply *Reply) os.Error {
        panicln("ERROR");
 }
 
-func run(server *Server, l net.Listener) {
-       conn, addr, err := l.Accept();
-       if err != nil {
-               println("accept:", err.String());
-               os.Exit(1);
-       }
-       server.Serve(conn);
-}
-
 func startServer() {
        server := new(Server);
        server.Add(new(Arith));
@@ -72,14 +65,13 @@ func startServer() {
        }
        serverAddr = l.Addr();
        log.Stderr("Test RPC server listening on ", serverAddr);
-//     go http.Serve(l, nil);
-       go run(server, l);
+       go server.Accept(l);
 }
 
 func TestRPC(t *testing.T) {
        var i int;
 
-       startServer();
+       once.Do(startServer);
 
        conn, err := net.Dial("tcp", "", serverAddr);
        if err != nil {
@@ -106,9 +98,9 @@ func TestRPC(t *testing.T) {
        // Out of order.
        args = &Args{7,8};
        mulReply := new(Reply);
-       mulCall := client.Start("Arith.Mul", args, mulReply, nil);
+       mulCall := client.Go("Arith.Mul", args, mulReply, nil);
        addReply := new(Reply);
-       addCall := client.Start("Arith.Add", args, addReply, nil);
+       addCall := client.Go("Arith.Add", args, addReply, nil);
 
        <-addCall.Done;
        if addReply.C != args.A + args.B {
@@ -126,6 +118,73 @@ func TestRPC(t *testing.T) {
        err = client.Call("Arith.Div", args, reply);
        // expect an error: zero divide
        if err == nil {
-               t.Errorf("Div: expected error");
+               t.Error("Div: expected error");
+       } else if err.String() != "divide by zero" {
+               t.Error("Div: expected divide by zero error; got", err);
+       }
+}
+
+func TestCheckUnknownService(t *testing.T) {
+       var i int;
+
+       once.Do(startServer);
+
+       conn, err := net.Dial("tcp", "", serverAddr);
+       if err != nil {
+               t.Fatal("dialing:", err)
+       }
+
+       client := NewClient(conn);
+
+       args := &Args{7,8};
+       reply := new(Reply);
+       err = client.Call("Unknown.Add", args, reply);
+       if err == nil {
+               t.Error("expected error calling unknown service");
+       } else if strings.Index(err.String(), "service") < 0 {
+               t.Error("expected error about service; got", err);
+       }
+}
+
+func TestCheckUnknownMethod(t *testing.T) {
+       var i int;
+
+       once.Do(startServer);
+
+       conn, err := net.Dial("tcp", "", serverAddr);
+       if err != nil {
+               t.Fatal("dialing:", err)
+       }
+
+       client := NewClient(conn);
+
+       args := &Args{7,8};
+       reply := new(Reply);
+       err = client.Call("Arith.Unknown", args, reply);
+       if err == nil {
+               t.Error("expected error calling unknown service");
+       } else if strings.Index(err.String(), "method") < 0 {
+               t.Error("expected error about method; got", err);
+       }
+}
+
+func TestCheckBadType(t *testing.T) {
+       var i int;
+
+       once.Do(startServer);
+
+       conn, err := net.Dial("tcp", "", serverAddr);
+       if err != nil {
+               t.Fatal("dialing:", err)
+       }
+
+       client := NewClient(conn);
+
+       reply := new(Reply);
+       err = client.Call("Arith.Add", reply, reply);   // args, reply would be the correct thing to use
+       if err == nil {
+               t.Error("expected error calling Arith.Add with wrong arg type");
+       } else if strings.Index(err.String(), "type") < 0 {
+               t.Error("expected error about type; got", err);
        }
 }