]> Cypherpunks repositories - gostls13.git/commitdiff
handle errors better:
authorRob Pike <r@golang.org>
Wed, 15 Jul 2009 18:47:29 +0000 (11:47 -0700)
committerRob Pike <r@golang.org>
Wed, 15 Jul 2009 18:47:29 +0000 (11:47 -0700)
1) terminate outstanding calls on the client when we see EOF from server
2) allow data to drain on server before closing the connection

R=rsc
DELTA=41  (23 added, 4 deleted, 14 changed)
OCL=31687
CL=31689

src/pkg/rpc/client.go
src/pkg/rpc/server.go

index a18f9f15aeec4e9bbdc4b39170728d418c1ab665..122763f4e8bdfbd0675af86911063154c6616004 100644 (file)
@@ -30,7 +30,7 @@ type Call struct {
 // Client represents an RPC Client.
 type Client struct {
        sync.Mutex;     // protects pending, seq
-       closed  bool;
+       shutdown        os.Error;       // non-nil if the client is shut down
        sending sync.Mutex;
        seq     uint64;
        conn io.ReadWriteCloser;
@@ -42,6 +42,12 @@ type Client struct {
 func (client *Client) send(c *Call) {
        // Register this call.
        client.Lock();
+       if client.shutdown != nil {
+               client.Unlock();
+               c.Error = client.shutdown;
+               doNotBlock := c.Done <- c;
+               return;
+       }
        c.seq = client.seq;
        client.seq++;
        client.pending[c.seq] = c;
@@ -66,10 +72,7 @@ func (client *Client) serve() {
                response := new(Response);
                err = client.dec.Decode(response);
                if err != nil {
-                       if err == os.EOF {
-                               break;
-                       }
-                       break;
+                       break
                }
                seq := response.Seq;
                client.Lock();
@@ -82,7 +85,14 @@ func (client *Client) serve() {
                // sure the channel has enough buffer space. See comment in Go().
                doNotBlock := c.Done <- c;
        }
-       client.closed = true;
+       // Terminate pending calls.
+       client.Lock();
+       client.shutdown = err;
+       for seq, call := range client.pending {
+               call.Error = err;
+               doNotBlock := call.Done <- call;
+       }
+       client.Unlock();
        log.Stderr("client protocol error:", err);
 }
 
@@ -144,8 +154,8 @@ func (client *Client) Go(serviceMethod string, args interface{}, reply interface
                // RPCs that will be using that channel.
        }
        c.Done = done;
-       if client.closed {
-               c.Error = os.EOF;
+       if client.shutdown != nil {
+               c.Error = client.shutdown;
                doNotBlock := c.Done <- c;
                return c;
        }
@@ -155,8 +165,8 @@ func (client *Client) Go(serviceMethod string, args interface{}, reply interface
 
 // Call invokes the named function, waits for it to complete, and returns its error status.
 func (client *Client) Call(serviceMethod string, args interface{}, reply interface{}) os.Error {
-       if client.closed {
-               return os.EOF
+       if client.shutdown != nil {
+               return client.shutdown
        }
        call := <-client.Go(serviceMethod, args, reply, nil).Done;
        return call.Error;
index 142f00acf4339fba7ae48ab9e14edae1c937a513..79feccc699691dfd79911743bd019cc14118dbe1 100644 (file)
@@ -14,6 +14,7 @@ import (
        "reflect";
        "strings";
        "sync";
+       "time"; // See TODO in serve()
        "unicode";
        "utf8";
 )
@@ -148,13 +149,13 @@ func _new(t *reflect.PtrType) *reflect.PtrValue {
        return v;
 }
 
-func (s *service) sendResponse(sending *sync.Mutex, req *Request, reply interface{}, enc *gob.Encoder, errmsg string) {
+func sendResponse(sending *sync.Mutex, req *Request, reply interface{}, enc *gob.Encoder, errmsg string) {
        resp := new(Response);
        // Encode the response header
-       sending.Lock();
        resp.ServiceMethod = req.ServiceMethod;
        resp.Error = errmsg;
        resp.Seq = req.Seq;
+       sending.Lock();
        enc.Encode(resp);
        // Encode the reply value.
        enc.Encode(reply);
@@ -170,7 +171,7 @@ func (s *service) call(sending *sync.Mutex, function *reflect.FuncValue, req *Re
        if errInter != nil {
                errmsg = errInter.(os.Error).String();
        }
-       s.sendResponse(sending, req, replyv.Interface(), enc, errmsg);
+       sendResponse(sending, req, replyv.Interface(), enc, errmsg);
 }
 
 func (server *serverType) serve(conn io.ReadWriteCloser) {
@@ -182,25 +183,27 @@ func (server *serverType) serve(conn io.ReadWriteCloser) {
                req := new(Request);
                err := dec.Decode(req);
                if err != nil {
-                       log.Stderr("rpc: server cannot decode request:", err);
+                       s := "rpc: server cannot decode request: " + err.String();
+                       sendResponse(sending, req, invalidRequest, enc, s);
                        break;
                }
                serviceMethod := strings.Split(req.ServiceMethod, ".", 0);
                if len(serviceMethod) != 2 {
-                       log.Stderr("rpc: service/Method request ill-formed:", req.ServiceMethod);
+                       s := "rpc: service/method request ill:formed: " + req.ServiceMethod;
+                       sendResponse(sending, req, invalidRequest, enc, s);
                        break;
                }
                // Look up the request.
                service, ok := server.serviceMap[serviceMethod[0]];
                if !ok {
                        s := "rpc: can't find service " + req.ServiceMethod;
-                       service.sendResponse(sending, req, invalidRequest, enc, s);
+                       sendResponse(sending, req, invalidRequest, enc, s);
                        break;
                }
                mtype, ok := service.method[serviceMethod[1]];
                if !ok {
                        s := "rpc: can't find method " + req.ServiceMethod;
-                       service.sendResponse(sending, req, invalidRequest, enc, s);
+                       sendResponse(sending, req, invalidRequest, enc, s);
                        break;
                }
                method := mtype.method;
@@ -210,11 +213,17 @@ func (server *serverType) serve(conn io.ReadWriteCloser) {
                err = dec.Decode(argv.Interface());
                if err != nil {
                        log.Stderr("tearing down connection:", err);
-                       service.sendResponse(sending, req, replyv.Interface(), enc, err.String());
+                       sendResponse(sending, req, replyv.Interface(), enc, err.String());
                        break;
                }
                go service.call(sending, method.Func, req, argv, replyv, enc);
        }
+       // TODO(r):  Gobs cannot handle unexpected data yet.  Once they can, we can
+       // ignore it and the connection can persist.  For now, though, bad data
+       // ruins the connection and we must shut down.  The sleep is necessary to
+       // guarantee all the data gets out before we close the connection, so the
+       // client can see the error description.
+       time.Sleep(2e9);
        conn.Close();
 }