]> Cypherpunks repositories - gostls13.git/commitdiff
rpc: implement ServeRequest to synchronously serve a single request.
authorSugu Sougoumarane <ssougou@gmail.com>
Mon, 15 Aug 2011 22:06:22 +0000 (08:06 +1000)
committerRob Pike <r@golang.org>
Mon, 15 Aug 2011 22:06:22 +0000 (08:06 +1000)
This is useful for applications that want to micromanage the rpc service.
Moved part of ServeCodec into a new readRequest function.
Renamed existing readRequest to readRequestHeader, and reordered
its parameters to align with the new readRequest and service.call.

R=golang-dev, r, rsc, sougou
CC=golang-dev, msolomon
https://golang.org/cl/4889043

src/pkg/rpc/server.go
src/pkg/rpc/server_test.go

index 86767abea32b6568ffea28de28cbfe52106428db..ac3f793047db1bfbe08c4cfe483cd235fcd392c7 100644 (file)
@@ -394,7 +394,7 @@ func (server *Server) ServeConn(conn io.ReadWriteCloser) {
 func (server *Server) ServeCodec(codec ServerCodec) {
        sending := new(sync.Mutex)
        for {
-               req, service, mtype, err := server.readRequest(codec)
+               service, mtype, req, argv, replyv, err := server.readRequest(codec)
                if err != nil {
                        if err != os.EOF {
                                log.Println("rpc:", err)
@@ -402,9 +402,6 @@ func (server *Server) ServeCodec(codec ServerCodec) {
                        if err == os.EOF || err == io.ErrUnexpectedEOF {
                                break
                        }
-                       // discard body
-                       codec.ReadRequestBody(nil)
-
                        // send a response if we actually managed to read a header.
                        if req != nil {
                                server.sendResponse(sending, req, invalidRequest, codec, err.String())
@@ -412,35 +409,29 @@ func (server *Server) ServeCodec(codec ServerCodec) {
                        }
                        continue
                }
+               go service.call(server, sending, mtype, req, argv, replyv, codec)
+       }
+       codec.Close()
+}
 
-               // Decode the argument value.
-               var argv reflect.Value
-               argIsValue := false // if true, need to indirect before calling.
-               if mtype.ArgType.Kind() == reflect.Ptr {
-                       argv = reflect.New(mtype.ArgType.Elem())
-               } else {
-                       argv = reflect.New(mtype.ArgType)
-                       argIsValue = true
-               }
-               // argv guaranteed to be a pointer now.
-               replyv := reflect.New(mtype.ReplyType.Elem())
-               err = codec.ReadRequestBody(argv.Interface())
-               if err != nil {
-                       if err == os.EOF || err == io.ErrUnexpectedEOF {
-                               if err == io.ErrUnexpectedEOF {
-                                       log.Println("rpc:", err)
-                               }
-                               break
-                       }
-                       server.sendResponse(sending, req, replyv.Interface(), codec, err.String())
-                       continue
+// ServeRequest is like ServeCodec but synchronously serves a single request.
+// It does not close the codec upon completion.
+func (server *Server) ServeRequest(codec ServerCodec) os.Error {
+       sending := new(sync.Mutex)
+       service, mtype, req, argv, replyv, err := server.readRequest(codec)
+       if err != nil {
+               if err == os.EOF || err == io.ErrUnexpectedEOF {
+                       return err
                }
-               if argIsValue {
-                       argv = argv.Elem()
+               // send a response if we actually managed to read a header.
+               if req != nil {
+                       server.sendResponse(sending, req, invalidRequest, codec, err.String())
+                       server.freeRequest(req)
                }
-               go service.call(server, sending, mtype, req, argv, replyv, codec)
+               return err
        }
-       codec.Close()
+       service.call(server, sending, mtype, req, argv, replyv, codec)
+       return nil
 }
 
 func (server *Server) getRequest() *Request {
@@ -483,7 +474,38 @@ func (server *Server) freeResponse(resp *Response) {
        server.respLock.Unlock()
 }
 
-func (server *Server) readRequest(codec ServerCodec) (req *Request, service *service, mtype *methodType, err os.Error) {
+func (server *Server) readRequest(codec ServerCodec) (service *service, mtype *methodType, req *Request, argv, replyv reflect.Value, err os.Error) {
+       service, mtype, req, err = server.readRequestHeader(codec)
+       if err != nil {
+               if err == os.EOF || err == io.ErrUnexpectedEOF {
+                       return
+               }
+               // discard body
+               codec.ReadRequestBody(nil)
+               return
+       }
+
+       // Decode the argument value.
+       argIsValue := false // if true, need to indirect before calling.
+       if mtype.ArgType.Kind() == reflect.Ptr {
+               argv = reflect.New(mtype.ArgType.Elem())
+       } else {
+               argv = reflect.New(mtype.ArgType)
+               argIsValue = true
+       }
+       // argv guaranteed to be a pointer now.
+       if err = codec.ReadRequestBody(argv.Interface()); err != nil {
+               return
+       }
+       if argIsValue {
+               argv = argv.Elem()
+       }
+
+       replyv = reflect.New(mtype.ReplyType.Elem())
+       return
+}
+
+func (server *Server) readRequestHeader(codec ServerCodec) (service *service, mtype *methodType, req *Request, err os.Error) {
        // Grab the request header.
        req = server.getRequest()
        err = codec.ReadRequestHeader(req)
@@ -568,6 +590,12 @@ func ServeCodec(codec ServerCodec) {
        DefaultServer.ServeCodec(codec)
 }
 
+// ServeRequest is like ServeCodec but synchronously serves a single request.
+// It does not close the codec upon completion.
+func ServeRequest(codec ServerCodec) os.Error {
+       return DefaultServer.ServeRequest(codec)
+}
+
 // Accept accepts connections on the listener and serves requests
 // to DefaultServer for each incoming connection.  
 // Accept blocks; the caller typically invokes it in a go statement.
index 459dd59d6a0fa2f30bbaa06c8de1f58d6436e3ca..e7bbfbe97d29a594d9b6d074ba3b670774922542 100644 (file)
@@ -7,6 +7,7 @@ package rpc
 import (
        "fmt"
        "http/httptest"
+       "io"
        "log"
        "net"
        "os"
@@ -18,6 +19,7 @@ import (
 )
 
 var (
+       newServer                 *Server
        serverAddr, newServerAddr string
        httpServerAddr            string
        once, newOnce, httpOnce   sync.Once
@@ -93,15 +95,15 @@ func startServer() {
 }
 
 func startNewServer() {
-       s := NewServer()
-       s.Register(new(Arith))
+       newServer = NewServer()
+       newServer.Register(new(Arith))
 
        var l net.Listener
        l, newServerAddr = listenTCP()
        log.Println("NewServer test RPC server listening on", newServerAddr)
        go Accept(l)
 
-       s.HandleHTTP(newHttpPath, "/bar")
+       newServer.HandleHTTP(newHttpPath, "/bar")
        httpOnce.Do(startHttpServer)
 }
 
@@ -264,6 +266,85 @@ func testHTTPRPC(t *testing.T, path string) {
        }
 }
 
+// CodecEmulator provides a client-like api and a ServerCodec interface.
+// Can be used to test ServeRequest.
+type CodecEmulator struct {
+       server        *Server
+       serviceMethod string
+       args          *Args
+       reply         *Reply
+       err           os.Error
+}
+
+func (codec *CodecEmulator) Call(serviceMethod string, args *Args, reply *Reply) os.Error {
+       codec.serviceMethod = serviceMethod
+       codec.args = args
+       codec.reply = reply
+       codec.err = nil
+       var serverError os.Error
+       if codec.server == nil {
+               serverError = ServeRequest(codec)
+       } else {
+               serverError = codec.server.ServeRequest(codec)
+       }
+       if codec.err == nil && serverError != nil {
+               codec.err = serverError
+       }
+       return codec.err
+}
+
+func (codec *CodecEmulator) ReadRequestHeader(req *Request) os.Error {
+       req.ServiceMethod = codec.serviceMethod
+       req.Seq = 0
+       return nil
+}
+
+func (codec *CodecEmulator) ReadRequestBody(argv interface{}) os.Error {
+       if codec.args == nil {
+               return io.ErrUnexpectedEOF
+       }
+       *(argv.(*Args)) = *codec.args
+       return nil
+}
+
+func (codec *CodecEmulator) WriteResponse(resp *Response, reply interface{}) os.Error {
+       if resp.Error != "" {
+               codec.err = os.NewError(resp.Error)
+       }
+       *codec.reply = *(reply.(*Reply))
+       return nil
+}
+
+func (codec *CodecEmulator) Close() os.Error {
+       return nil
+}
+
+func TestServeRequest(t *testing.T) {
+       once.Do(startServer)
+       testServeRequest(t, nil)
+       newOnce.Do(startNewServer)
+       testServeRequest(t, newServer)
+}
+
+func testServeRequest(t *testing.T, server *Server) {
+       client := CodecEmulator{server: server}
+
+       args := &Args{7, 8}
+       reply := new(Reply)
+       err := client.Call("Arith.Add", args, reply)
+       if err != nil {
+               t.Errorf("Add: expected no error but got string %q", err.String())
+       }
+       if reply.C != args.A+args.B {
+               t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
+       }
+
+       err = client.Call("Arith.Add", nil, reply)
+       if err == nil {
+               t.Errorf("expected error calling Arith.Add with nil arg")
+       }
+}
+
 type ReplyNotPointer int
 type ArgNotPublic int
 type ReplyNotPublic int