]> Cypherpunks repositories - gostls13.git/commitdiff
rpc: allow the argument (first arg of method) to be a value rather than a pointer.
authorRob Pike <r@golang.org>
Tue, 26 Apr 2011 22:07:25 +0000 (15:07 -0700)
committerRob Pike <r@golang.org>
Tue, 26 Apr 2011 22:07:25 +0000 (15:07 -0700)
Can make the API nicer in some cases.

R=rsc, rog, r2
CC=golang-dev
https://golang.org/cl/4428064

src/pkg/rpc/server.go
src/pkg/rpc/server_test.go

index 086457963a2bff09dba1cf2db138151f220ca62f..acadeec37f0a7b95f0974bdfe07f794531f35fae 100644 (file)
        Only methods that satisfy these criteria will be made available for remote access;
        other methods will be ignored:
 
-               - the method receiver and name are exported, that is, begin with an upper case letter.
-               - the method has two arguments, both pointers to exported types.
+               - the method name is exported, that is, begins with an upper case letter.
+               - the method receiver is exported or local (defined in the package
+                 registering the service).
+               - the method has two arguments, both exported or local types.
+               - the method's second argument is a pointer.
                - the method has return type os.Error.
 
        The method's first argument represents the arguments provided by the caller; the
@@ -193,6 +196,14 @@ func isExported(name string) bool {
        return unicode.IsUpper(rune)
 }
 
+// Is this type exported or local to this package?
+func isExportedOrLocalType(t reflect.Type) bool {
+       for t.Kind() == reflect.Ptr {
+               t = t.Elem()
+       }
+       return t.PkgPath() == "" || isExported(t.Name())
+}
+
 // Register publishes in the server the set of methods of the
 // receiver value that satisfy the following conditions:
 //     - exported method
@@ -252,23 +263,20 @@ func (server *Server) register(rcvr interface{}, name string, useName bool) os.E
                        log.Println("method", mname, "has wrong number of ins:", mtype.NumIn())
                        continue
                }
+               // First arg need not be a pointer.
                argType := mtype.In(1)
-               ok := argType.Kind() == reflect.Ptr
-               if !ok {
-                       log.Println(mname, "arg type not a pointer:", mtype.In(1))
+               if !isExportedOrLocalType(argType) {
+                       log.Println(mname, "argument type not exported or local:", argType)
                        continue
                }
+               // Second arg must be a pointer.
                replyType := mtype.In(2)
                if replyType.Kind() != reflect.Ptr {
-                       log.Println(mname, "reply type not a pointer:", mtype.In(2))
+                       log.Println("method", mname, "reply type not a pointer:", replyType)
                        continue
                }
-               if argType.Elem().PkgPath() != "" && !isExported(argType.Elem().Name()) {
-                       log.Println(mname, "argument type not exported:", argType)
-                       continue
-               }
-               if replyType.Elem().PkgPath() != "" && !isExported(replyType.Elem().Name()) {
-                       log.Println(mname, "reply type not exported:", replyType)
+               if !isExportedOrLocalType(replyType) {
+                       log.Println("method", mname, "reply type not exported or local:", replyType)
                        continue
                }
                // Method needs one out: os.Error.
@@ -405,7 +413,15 @@ func (server *Server) ServeCodec(codec ServerCodec) {
                }
 
                // Decode the argument value.
-               argv := reflect.New(mtype.ArgType.Elem())
+               var argv reflect.Value
+               argIsValue := false // if true, need to indirect before calling.
+               if mtype.ArgType.Kind() == reflect.Ptr {
+                       argv = reflect.New(mtype.ArgType.Elem())
+               } else {
+                       argv = reflect.New(mtype.ArgType)
+                       argIsValue = true
+               }
+               // argv guaranteed to be a pointer now.
                replyv := reflect.New(mtype.ReplyType.Elem())
                err = codec.ReadRequestBody(argv.Interface())
                if err != nil {
@@ -418,6 +434,9 @@ func (server *Server) ServeCodec(codec ServerCodec) {
                        server.sendResponse(sending, req, replyv.Interface(), codec, err.String())
                        continue
                }
+               if argIsValue {
+                       argv = argv.Elem()
+               }
                go service.call(server, sending, mtype, req, argv, replyv, codec)
        }
        codec.Close()
index d4041ae70ce8ac406c6fff5eb5c0e04c3aa039cf..eb7b673d661de5be59de8c1dee74ab72270fd9a5 100644 (file)
@@ -38,7 +38,9 @@ type Reply struct {
 
 type Arith int
 
-func (t *Arith) Add(args *Args, reply *Reply) os.Error {
+// Some of Arith's methods have value args, some have pointer args. That's deliberate.
+
+func (t *Arith) Add(args Args, reply *Reply) os.Error {
        reply.C = args.A + args.B
        return nil
 }
@@ -48,7 +50,7 @@ func (t *Arith) Mul(args *Args, reply *Reply) os.Error {
        return nil
 }
 
-func (t *Arith) Div(args *Args, reply *Reply) os.Error {
+func (t *Arith) Div(args Args, reply *Reply) os.Error {
        if args.B == 0 {
                return os.ErrorString("divide by zero")
        }
@@ -61,8 +63,8 @@ func (t *Arith) String(args *Args, reply *string) os.Error {
        return nil
 }
 
-func (t *Arith) Scan(args *string, reply *Reply) (err os.Error) {
-       _, err = fmt.Sscan(*args, &reply.C)
+func (t *Arith) Scan(args string, reply *Reply) (err os.Error) {
+       _, err = fmt.Sscan(args, &reply.C)
        return
 }
 
@@ -262,16 +264,11 @@ func testHTTPRPC(t *testing.T, path string) {
        }
 }
 
-type ArgNotPointer int
 type ReplyNotPointer int
 type ArgNotPublic int
 type ReplyNotPublic int
 type local struct{}
 
-func (t *ArgNotPointer) ArgNotPointer(args Args, reply *Reply) os.Error {
-       return nil
-}
-
 func (t *ReplyNotPointer) ReplyNotPointer(args *Args, reply Reply) os.Error {
        return nil
 }
@@ -286,11 +283,7 @@ func (t *ReplyNotPublic) ReplyNotPublic(args *Args, reply *local) os.Error {
 
 // Check that registration handles lots of bad methods and a type with no suitable methods.
 func TestRegistrationError(t *testing.T) {
-       err := Register(new(ArgNotPointer))
-       if err == nil {
-               t.Errorf("expected error registering ArgNotPointer")
-       }
-       err = Register(new(ReplyNotPointer))
+       err := Register(new(ReplyNotPointer))
        if err == nil {
                t.Errorf("expected error registering ReplyNotPointer")
        }