]> Cypherpunks repositories - gostls13.git/commitdiff
rpc: abstract client and server encodings
authorRuss Cox <rsc@golang.org>
Tue, 27 Apr 2010 20:51:25 +0000 (13:51 -0700)
committerRuss Cox <rsc@golang.org>
Tue, 27 Apr 2010 20:51:25 +0000 (13:51 -0700)
R=r
CC=golang-dev, rog
https://golang.org/cl/811046

src/pkg/rpc/client.go
src/pkg/rpc/server.go

index 6b2ddd6f0ab5c8c92628fa53deaca767e0d983f5..d742d099fb3deb3414a72649945ece9a8550ccfe 100644 (file)
@@ -33,13 +33,25 @@ type Client struct {
        shutdown os.Error   // non-nil if the client is shut down
        sending  sync.Mutex
        seq      uint64
-       conn     io.ReadWriteCloser
-       enc      *gob.Encoder
-       dec      *gob.Decoder
+       codec    ClientCodec
        pending  map[uint64]*Call
        closing  bool
 }
 
+// A ClientCodec implements writing of RPC requests and
+// reading of RPC responses for the client side of an RPC session.
+// The client calls WriteRequest to write a request to the connection
+// and calls ReadResponseHeader and ReadResponseBody in pairs
+// to read responses.  The client calls Close when finished with the
+// connection.
+type ClientCodec interface {
+       WriteRequest(*Request, interface{}) os.Error
+       ReadResponseHeader(*Response) os.Error
+       ReadResponseBody(interface{}) os.Error
+
+       Close() os.Error
+}
+
 func (client *Client) send(c *Call) {
        // Register this call.
        client.mutex.Lock()
@@ -59,9 +71,7 @@ func (client *Client) send(c *Call) {
        client.sending.Lock()
        request.Seq = c.seq
        request.ServiceMethod = c.ServiceMethod
-       client.enc.Encode(request)
-       err := client.enc.Encode(c.Args)
-       if err != nil {
+       if err := client.codec.WriteRequest(request, c.Args); err != nil {
                panic("rpc: client encode error: " + err.String())
        }
        client.sending.Unlock()
@@ -71,7 +81,7 @@ func (client *Client) input() {
        var err os.Error
        for err == nil {
                response := new(Response)
-               err = client.dec.Decode(response)
+               err = client.codec.ReadResponseHeader(response)
                if err != nil {
                        if err == os.EOF && !client.closing {
                                err = io.ErrUnexpectedEOF
@@ -83,7 +93,7 @@ func (client *Client) input() {
                c := client.pending[seq]
                client.pending[seq] = c, false
                client.mutex.Unlock()
-               err = client.dec.Decode(c.Reply)
+               err = client.codec.ReadResponseBody(c.Reply)
                // Empty strings should turn into nil os.Errors
                if response.Error != "" {
                        c.Error = os.ErrorString(response.Error)
@@ -110,17 +120,49 @@ func (client *Client) input() {
 // NewClient returns a new Client to handle requests to the
 // set of services at the other end of the connection.
 func NewClient(conn io.ReadWriteCloser) *Client {
-       client := new(Client)
-       client.conn = conn
-       client.enc = gob.NewEncoder(conn)
-       client.dec = gob.NewDecoder(conn)
-       client.pending = make(map[uint64]*Call)
+       return NewClientWithCodec(&gobClientCodec{conn, gob.NewDecoder(conn), gob.NewEncoder(conn)})
+}
+
+// NewClientWithCodec is like NewClient but uses the specified
+// codec to encode requests and decode responses.
+func NewClientWithCodec(codec ClientCodec) *Client {
+       client := &Client{
+               codec:   codec,
+               pending: make(map[uint64]*Call),
+       }
        go client.input()
        return client
 }
 
+type gobClientCodec struct {
+       rwc io.ReadWriteCloser
+       dec *gob.Decoder
+       enc *gob.Encoder
+}
+
+func (c *gobClientCodec) WriteRequest(r *Request, body interface{}) os.Error {
+       if err := c.enc.Encode(r); err != nil {
+               return err
+       }
+       return c.enc.Encode(body)
+}
+
+func (c *gobClientCodec) ReadResponseHeader(r *Response) os.Error {
+       return c.dec.Decode(r)
+}
+
+func (c *gobClientCodec) ReadResponseBody(body interface{}) os.Error {
+       return c.dec.Decode(body)
+}
+
+func (c *gobClientCodec) Close() os.Error {
+       return c.rwc.Close()
+}
+
+
 // DialHTTP connects to an HTTP RPC server at the specified network address.
 func DialHTTP(network, address string) (*Client, os.Error) {
+       var err os.Error
        conn, err := net.Dial(network, "", address)
        if err != nil {
                return nil, err
@@ -156,7 +198,7 @@ func (client *Client) Close() os.Error {
        client.mutex.Lock()
        client.closing = true
        client.mutex.Unlock()
-       return client.conn.Close()
+       return client.codec.Close()
 }
 
 // Go invokes the function asynchronously.  It returns the Call structure representing
index 413f9a59acf1363dc3f3e6fd66d719ad093c9b4c..4c957597bc7448991d5bd4badaa3bb500f3bf2fb 100644 (file)
@@ -272,7 +272,7 @@ func _new(t *reflect.PtrType) *reflect.PtrValue {
        return v
 }
 
-func sendResponse(sending *sync.Mutex, req *Request, reply interface{}, enc *gob.Encoder, errmsg string) {
+func sendResponse(sending *sync.Mutex, req *Request, reply interface{}, codec ServerCodec, errmsg string) {
        resp := new(Response)
        // Encode the response header
        resp.ServiceMethod = req.ServiceMethod
@@ -281,13 +281,14 @@ func sendResponse(sending *sync.Mutex, req *Request, reply interface{}, enc *gob
        }
        resp.Seq = req.Seq
        sending.Lock()
-       enc.Encode(resp)
-       // Encode the reply value.
-       enc.Encode(reply)
+       err := codec.WriteResponse(resp, reply)
+       if err != nil {
+               log.Stderr("rpc: writing response: ", err)
+       }
        sending.Unlock()
 }
 
-func (s *service) call(sending *sync.Mutex, mtype *methodType, req *Request, argv, replyv reflect.Value, enc *gob.Encoder) {
+func (s *service) call(sending *sync.Mutex, mtype *methodType, req *Request, argv, replyv reflect.Value, codec ServerCodec) {
        mtype.Lock()
        mtype.numCalls++
        mtype.Unlock()
@@ -300,17 +301,40 @@ func (s *service) call(sending *sync.Mutex, mtype *methodType, req *Request, arg
        if errInter != nil {
                errmsg = errInter.(os.Error).String()
        }
-       sendResponse(sending, req, replyv.Interface(), enc, errmsg)
+       sendResponse(sending, req, replyv.Interface(), codec, errmsg)
+}
+
+type gobServerCodec struct {
+       rwc io.ReadWriteCloser
+       dec *gob.Decoder
+       enc *gob.Encoder
+}
+
+func (c *gobServerCodec) ReadRequestHeader(r *Request) os.Error {
+       return c.dec.Decode(r)
+}
+
+func (c *gobServerCodec) ReadRequestBody(body interface{}) os.Error {
+       return c.dec.Decode(body)
+}
+
+func (c *gobServerCodec) WriteResponse(r *Response, body interface{}) os.Error {
+       if err := c.enc.Encode(r); err != nil {
+               return err
+       }
+       return c.enc.Encode(body)
 }
 
-func (server *serverType) input(conn io.ReadWriteCloser) {
-       dec := gob.NewDecoder(conn)
-       enc := gob.NewEncoder(conn)
+func (c *gobServerCodec) Close() os.Error {
+       return c.rwc.Close()
+}
+
+func (server *serverType) input(codec ServerCodec) {
        sending := new(sync.Mutex)
        for {
                // Grab the request header.
                req := new(Request)
-               err := dec.Decode(req)
+               err := codec.ReadRequestHeader(req)
                if err != nil {
                        if err == os.EOF || err == io.ErrUnexpectedEOF {
                                if err == io.ErrUnexpectedEOF {
@@ -319,13 +343,13 @@ func (server *serverType) input(conn io.ReadWriteCloser) {
                                break
                        }
                        s := "rpc: server cannot decode request: " + err.String()
-                       sendResponse(sending, req, invalidRequest, enc, s)
-                       continue
+                       sendResponse(sending, req, invalidRequest, codec, s)
+                       break
                }
                serviceMethod := strings.Split(req.ServiceMethod, ".", 0)
                if len(serviceMethod) != 2 {
-                       s := "rpc: service/method request ill:formed: " + req.ServiceMethod
-                       sendResponse(sending, req, invalidRequest, enc, s)
+                       s := "rpc: service/method request ill-formed: " + req.ServiceMethod
+                       sendResponse(sending, req, invalidRequest, codec, s)
                        continue
                }
                // Look up the request.
@@ -334,27 +358,27 @@ func (server *serverType) input(conn io.ReadWriteCloser) {
                server.Unlock()
                if !ok {
                        s := "rpc: can't find service " + req.ServiceMethod
-                       sendResponse(sending, req, invalidRequest, enc, s)
+                       sendResponse(sending, req, invalidRequest, codec, s)
                        continue
                }
                mtype, ok := service.method[serviceMethod[1]]
                if !ok {
                        s := "rpc: can't find method " + req.ServiceMethod
-                       sendResponse(sending, req, invalidRequest, enc, s)
+                       sendResponse(sending, req, invalidRequest, codec, s)
                        continue
                }
                // Decode the argument value.
                argv := _new(mtype.argType)
                replyv := _new(mtype.replyType)
-               err = dec.Decode(argv.Interface())
+               err = codec.ReadRequestBody(argv.Interface())
                if err != nil {
                        log.Stderr("rpc: tearing down", serviceMethod[0], "connection:", err)
-                       sendResponse(sending, req, replyv.Interface(), enc, err.String())
-                       continue
+                       sendResponse(sending, req, replyv.Interface(), codec, err.String())
+                       break
                }
-               go service.call(sending, mtype, req, argv, replyv, enc)
+               go service.call(sending, mtype, req, argv, replyv, codec)
        }
-       conn.Close()
+       codec.Close()
 }
 
 func (server *serverType) accept(lis net.Listener) {
@@ -363,7 +387,7 @@ func (server *serverType) accept(lis net.Listener) {
                if err != nil {
                        log.Exit("rpc.Serve: accept:", err.String()) // TODO(r): exit?
                }
-               go server.input(conn)
+               go ServeConn(conn)
        }
 }
 
@@ -376,10 +400,34 @@ func (server *serverType) accept(lis net.Listener) {
 // suitable methods.
 func Register(rcvr interface{}) os.Error { return server.register(rcvr) }
 
-// ServeConn runs the server on a single connection.  When the connection
-// completes, service terminates.  ServeConn blocks; the caller typically
-// invokes it in a go statement.
-func ServeConn(conn io.ReadWriteCloser) { server.input(conn) }
+// A ServerCodec implements reading of RPC requests and writing of
+// RPC responses for the server side of an RPC session.
+// The server calls ReadRequestHeader and ReadRequestBody in pairs
+// to read requests from the connection, and it calls WriteResponse to
+// write a response back.  The server calls Close when finished with the
+// connection.
+type ServerCodec interface {
+       ReadRequestHeader(*Request) os.Error
+       ReadRequestBody(interface{}) os.Error
+       WriteResponse(*Response, interface{}) os.Error
+
+       Close() os.Error
+}
+
+// ServeConn runs the server on a single connection.
+// ServeConn blocks, serving the connection until the client hangs up.
+// The caller typically invokes ServeConn in a go statement.
+// ServeConn uses the gob wire format (see package gob) on the
+// connection.  To use an alternate codec, use ServeCodec.
+func ServeConn(conn io.ReadWriteCloser) {
+       ServeCodec(&gobServerCodec{conn, gob.NewDecoder(conn), gob.NewEncoder(conn)})
+}
+
+// ServeCodec is like ServeConn but uses the specified codec to
+// decode requests and encode responses.
+func ServeCodec(codec ServerCodec) {
+       server.input(codec)
+}
 
 // Accept accepts connections on the listener and serves requests
 // for each incoming connection.  Accept blocks; the caller typically
@@ -404,7 +452,7 @@ func serveHTTP(c *http.Conn, req *http.Request) {
                return
        }
        io.WriteString(conn, "HTTP/1.0 "+connected+"\n\n")
-       server.input(conn)
+       ServeConn(conn)
 }
 
 // HandleHTTP registers an HTTP handler for RPC messages.