]> Cypherpunks repositories - gostls13.git/commitdiff
rpc: make more tolerant of errors.
authorRoger Peppe <rogpeppe@gmail.com>
Wed, 9 Feb 2011 18:57:59 +0000 (10:57 -0800)
committerRob Pike <r@golang.org>
Wed, 9 Feb 2011 18:57:59 +0000 (10:57 -0800)
Add Error type to enable clients to distinguish
between local and remote errors.
Also return "connection shut down error" after
the first error return rather than returning the
same error each time.

R=r
CC=golang-dev
https://golang.org/cl/4080058

src/pkg/rpc/client.go
src/pkg/rpc/server.go
src/pkg/rpc/server_test.go

index 6f028c10d968eb40f8b135d4bb60121b7aceda95..cb21cf907a177adf5715fda1a85b0aa98f165cf2 100644 (file)
@@ -15,6 +15,16 @@ import (
        "sync"
 )
 
+// ServerError represents an error that has been returned from
+// the remote side of the RPC connection.
+type ServerError string
+
+func (e ServerError) String() string {
+       return string(e)
+}
+
+const ErrShutdown = os.ErrorString("connection is shut down")
+
 // Call represents an active RPC.
 type Call struct {
        ServiceMethod string      // The name of the service and method to call.
@@ -30,12 +40,12 @@ type Call struct {
 // with a single Client.
 type Client struct {
        mutex    sync.Mutex // protects pending, seq
-       shutdown os.Error   // non-nil if the client is shut down
        sending  sync.Mutex
        seq      uint64
        codec    ClientCodec
        pending  map[uint64]*Call
        closing  bool
+       shutdown bool
 }
 
 // A ClientCodec implements writing of RPC requests and
@@ -55,8 +65,8 @@ type ClientCodec interface {
 func (client *Client) send(c *Call) {
        // Register this call.
        client.mutex.Lock()
-       if client.shutdown != nil {
-               c.Error = client.shutdown
+       if client.shutdown {
+               c.Error = ErrShutdown
                client.mutex.Unlock()
                c.done()
                return
@@ -79,6 +89,7 @@ func (client *Client) send(c *Call) {
 
 func (client *Client) input() {
        var err os.Error
+       var marker struct{}
        for err == nil {
                response := new(Response)
                err = client.codec.ReadResponseHeader(response)
@@ -93,20 +104,27 @@ func (client *Client) input() {
                c := client.pending[seq]
                client.pending[seq] = c, false
                client.mutex.Unlock()
-               err = client.codec.ReadResponseBody(c.Reply)
-               if response.Error != "" {
-                       c.Error = os.ErrorString(response.Error)
-               } else if err != nil {
-                       c.Error = err
+
+               if response.Error == "" {
+                       err = client.codec.ReadResponseBody(c.Reply)
+                       if err != nil {
+                               c.Error = os.ErrorString("reading body " + err.String())
+                       }
                } else {
-                       // Empty strings should turn into nil os.Errors
-                       c.Error = nil
+                       // We've got an error response. Give this to the request;
+                       // any subsequent requests will get the ReadResponseBody
+                       // error if there is one.
+                       c.Error = ServerError(response.Error)
+                       err = client.codec.ReadResponseBody(&marker)
+                       if err != nil {
+                               err = os.ErrorString("reading error body: " + err.String())
+                       }
                }
                c.done()
        }
        // Terminate pending calls.
        client.mutex.Lock()
-       client.shutdown = err
+       client.shutdown = true
        for _, call := range client.pending {
                call.Error = err
                call.done()
@@ -209,10 +227,11 @@ func Dial(network, address string) (*Client, os.Error) {
 }
 
 func (client *Client) Close() os.Error {
-       if client.shutdown != nil || client.closing {
-               return os.ErrorString("rpc: already closed")
-       }
        client.mutex.Lock()
+       if client.shutdown || client.closing {
+               client.mutex.Unlock()
+               return ErrShutdown
+       }
        client.closing = true
        client.mutex.Unlock()
        return client.codec.Close()
@@ -239,8 +258,8 @@ func (client *Client) Go(serviceMethod string, args interface{}, reply interface
                }
        }
        c.Done = done
-       if client.shutdown != nil {
-               c.Error = client.shutdown
+       if client.shutdown {
+               c.Error = ErrShutdown
                c.done()
                return c
        }
@@ -250,8 +269,8 @@ func (client *Client) Go(serviceMethod string, args interface{}, reply interface
 
 // Call invokes the named function, waits for it to complete, and returns its error status.
 func (client *Client) Call(serviceMethod string, args interface{}, reply interface{}) os.Error {
-       if client.shutdown != nil {
-               return client.shutdown
+       if client.shutdown {
+               return ErrShutdown
        }
        call := <-client.Go(serviceMethod, args, reply, nil).Done
        return call.Error
index 91e9cd5c8d88bf33816c8956fbb23b04a0d0e5e0..4b622d4e5b80d049a7087b120c4594e8adaa0f0b 100644 (file)
@@ -299,7 +299,7 @@ func (server *Server) register(rcvr interface{}, name string, useName bool) os.E
 
 // A value sent as a placeholder for the response when the server receives an invalid request.
 type InvalidRequest struct {
-       marker int
+       Marker int
 }
 
 var invalidRequest = InvalidRequest{1}
@@ -316,6 +316,7 @@ func sendResponse(sending *sync.Mutex, req *Request, reply interface{}, codec Se
        resp.ServiceMethod = req.ServiceMethod
        if errmsg != "" {
                resp.Error = errmsg
+               reply = invalidRequest
        }
        resp.Seq = req.Seq
        sending.Lock()
@@ -389,54 +390,74 @@ func (server *Server) ServeConn(conn io.ReadWriteCloser) {
 func (server *Server) ServeCodec(codec ServerCodec) {
        sending := new(sync.Mutex)
        for {
-               // Grab the request header.
-               req := new(Request)
-               err := codec.ReadRequestHeader(req)
+               req, service, mtype, err := server.readRequest(codec)
                if err != nil {
+                       if err != os.EOF {
+                               log.Println("rpc:", err)
+                       }
                        if err == os.EOF || err == io.ErrUnexpectedEOF {
-                               if err == io.ErrUnexpectedEOF {
-                                       log.Println("rpc:", err)
-                               }
                                break
                        }
-                       s := "rpc: server cannot decode request: " + err.String()
-                       sendResponse(sending, req, invalidRequest, codec, s)
-                       break
-               }
-               serviceMethod := strings.Split(req.ServiceMethod, ".", -1)
-               if len(serviceMethod) != 2 {
-                       s := "rpc: service/method request ill-formed: " + req.ServiceMethod
-                       sendResponse(sending, req, invalidRequest, codec, s)
-                       continue
-               }
-               // Look up the request.
-               server.Lock()
-               service, ok := server.serviceMap[serviceMethod[0]]
-               server.Unlock()
-               if !ok {
-                       s := "rpc: can't find service " + req.ServiceMethod
-                       sendResponse(sending, req, invalidRequest, codec, s)
-                       continue
-               }
-               mtype, ok := service.method[serviceMethod[1]]
-               if !ok {
-                       s := "rpc: can't find method " + req.ServiceMethod
-                       sendResponse(sending, req, invalidRequest, codec, s)
+                       // discard body
+                       codec.ReadRequestBody(new(interface{}))
+
+                       // send a response if we actually managed to read a header.
+                       if req != nil {
+                               sendResponse(sending, req, invalidRequest, codec, err.String())
+                       }
                        continue
                }
+
                // Decode the argument value.
                argv := _new(mtype.ArgType)
                replyv := _new(mtype.ReplyType)
                err = codec.ReadRequestBody(argv.Interface())
                if err != nil {
-                       log.Println("rpc: tearing down", serviceMethod[0], "connection:", err)
+                       if err == os.EOF || err == io.ErrUnexpectedEOF {
+                               if err == io.ErrUnexpectedEOF {
+                                       log.Println("rpc:", err)
+                               }
+                               break
+                       }
                        sendResponse(sending, req, replyv.Interface(), codec, err.String())
-                       break
+                       continue
                }
                go service.call(sending, mtype, req, argv, replyv, codec)
        }
        codec.Close()
 }
+func (server *Server) readRequest(codec ServerCodec) (req *Request, service *service, mtype *methodType, err os.Error) {
+       // Grab the request header.
+       req = new(Request)
+       err = codec.ReadRequestHeader(req)
+       if err != nil {
+               req = nil
+               if err == os.EOF || err == io.ErrUnexpectedEOF {
+                       return
+               }
+               err = os.ErrorString("rpc: server cannot decode request: " + err.String())
+               return
+       }
+
+       serviceMethod := strings.Split(req.ServiceMethod, ".", -1)
+       if len(serviceMethod) != 2 {
+               err = os.ErrorString("rpc: service/method request ill-formed: " + req.ServiceMethod)
+               return
+       }
+       // Look up the request.
+       server.Lock()
+       service = server.serviceMap[serviceMethod[0]]
+       server.Unlock()
+       if service == nil {
+               err = os.ErrorString("rpc: can't find service " + req.ServiceMethod)
+               return
+       }
+       mtype = service.method[serviceMethod[1]]
+       if mtype == nil {
+               err = os.ErrorString("rpc: can't find method " + req.ServiceMethod)
+       }
+       return
+}
 
 // Accept accepts connections on the listener and serves requests
 // for each incoming connection.  Accept blocks; the caller typically
index 1f080faa5b6e927bceabbba417e852ac5ec201b2..05aaebceb487bf132a17deeb56efe01afe8f572b 100644 (file)
@@ -134,14 +134,25 @@ func testRPC(t *testing.T, addr string) {
                t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
        }
 
-       args = &Args{7, 8}
+       // Nonexistent method
+       args = &Args{7, 0}
        reply = new(Reply)
-       err = client.Call("Arith.Mul", args, reply)
-       if err != nil {
-               t.Errorf("Mul: expected no error but got string %q", err.String())
+       err = client.Call("Arith.BadOperation", args, reply)
+       // expect an error
+       if err == nil {
+               t.Error("BadOperation: expected error")
+       } else if !strings.HasPrefix(err.String(), "rpc: can't find method ") {
+               t.Errorf("BadOperation: expected can't find method error; got %q", err)
        }
-       if reply.C != args.A*args.B {
-               t.Errorf("Mul: expected %d got %d", reply.C, args.A*args.B)
+
+       // Unknown service
+       args = &Args{7, 8}
+       reply = new(Reply)
+       err = client.Call("Arith.Unknown", args, reply)
+       if err == nil {
+               t.Error("expected error calling unknown service")
+       } else if strings.Index(err.String(), "method") < 0 {
+               t.Error("expected error about method; got", err)
        }
 
        // Out of order.
@@ -178,6 +189,15 @@ func testRPC(t *testing.T, addr string) {
                t.Error("Div: expected divide by zero error; got", err)
        }
 
+       // Bad type.
+       reply = new(Reply)
+       err = client.Call("Arith.Add", reply, reply) // args, reply would be the correct thing to use
+       if err == nil {
+               t.Error("expected error calling Arith.Add with wrong arg type")
+       } else if strings.Index(err.String(), "type") < 0 {
+               t.Error("expected error about type; got", err)
+       }
+
        // Non-struct argument
        const Val = 12345
        str := fmt.Sprint(Val)
@@ -200,9 +220,19 @@ func testRPC(t *testing.T, addr string) {
        if str != expect {
                t.Errorf("String: expected %s got %s", expect, str)
        }
+
+       args = &Args{7, 8}
+       reply = new(Reply)
+       err = client.Call("Arith.Mul", args, reply)
+       if err != nil {
+               t.Errorf("Mul: expected no error but got string %q", err.String())
+       }
+       if reply.C != args.A*args.B {
+               t.Errorf("Mul: expected %d got %d", reply.C, args.A*args.B)
+       }
 }
 
-func TestHTTPRPC(t *testing.T) {
+func TestHTTP(t *testing.T) {
        once.Do(startServer)
        testHTTPRPC(t, "")
        newOnce.Do(startNewServer)
@@ -233,65 +263,6 @@ func testHTTPRPC(t *testing.T, path string) {
        }
 }
 
-func TestCheckUnknownService(t *testing.T) {
-       once.Do(startServer)
-
-       conn, err := net.Dial("tcp", "", serverAddr)
-       if err != nil {
-               t.Fatal("dialing:", err)
-       }
-
-       client := NewClient(conn)
-
-       args := &Args{7, 8}
-       reply := new(Reply)
-       err = client.Call("Unknown.Add", args, reply)
-       if err == nil {
-               t.Error("expected error calling unknown service")
-       } else if strings.Index(err.String(), "service") < 0 {
-               t.Error("expected error about service; got", err)
-       }
-}
-
-func TestCheckUnknownMethod(t *testing.T) {
-       once.Do(startServer)
-
-       conn, err := net.Dial("tcp", "", serverAddr)
-       if err != nil {
-               t.Fatal("dialing:", err)
-       }
-
-       client := NewClient(conn)
-
-       args := &Args{7, 8}
-       reply := new(Reply)
-       err = client.Call("Arith.Unknown", args, reply)
-       if err == nil {
-               t.Error("expected error calling unknown service")
-       } else if strings.Index(err.String(), "method") < 0 {
-               t.Error("expected error about method; got", err)
-       }
-}
-
-func TestCheckBadType(t *testing.T) {
-       once.Do(startServer)
-
-       conn, err := net.Dial("tcp", "", serverAddr)
-       if err != nil {
-               t.Fatal("dialing:", err)
-       }
-
-       client := NewClient(conn)
-
-       reply := new(Reply)
-       err = client.Call("Arith.Add", reply, reply) // args, reply would be the correct thing to use
-       if err == nil {
-               t.Error("expected error calling Arith.Add with wrong arg type")
-       } else if strings.Index(err.String(), "type") < 0 {
-               t.Error("expected error about type; got", err)
-       }
-}
-
 type ArgNotPointer int
 type ReplyNotPointer int
 type ArgNotPublic int