]> Cypherpunks repositories - gostls13.git/commitdiff
net/rpc: give hint to pass a pointer to Register
authorShenghou Ma <minux.ma@gmail.com>
Tue, 6 Nov 2012 21:03:16 +0000 (05:03 +0800)
committerShenghou Ma <minux.ma@gmail.com>
Tue, 6 Nov 2012 21:03:16 +0000 (05:03 +0800)
Fixes #4325.

R=golang-dev, r
CC=golang-dev
https://golang.org/cl/6819081

src/pkg/net/rpc/server.go
src/pkg/net/rpc/server_test.go

index ee1df823ebfad4b54561551ccaea9a7dae7429f0..8898b98abad488a9b3a713e549a87b232f8cd1f7 100644 (file)
@@ -261,8 +261,30 @@ func (server *Server) register(rcvr interface{}, name string, useName bool) erro
        s.method = make(map[string]*methodType)
 
        // Install the methods
-       for m := 0; m < s.typ.NumMethod(); m++ {
-               method := s.typ.Method(m)
+       s.method = suitableMethods(s.typ, true)
+
+       if len(s.method) == 0 {
+               str := ""
+               // To help the user, see if a pointer receiver would work.
+               method := suitableMethods(reflect.PtrTo(s.typ), false)
+               if len(method) != 0 {
+                       str = "rpc.Register: type " + sname + " has no exported methods of suitable type (hint: pass a pointer to value of that type)"
+               } else {
+                       str = "rpc.Register: type " + sname + " has no exported methods of suitable type"
+               }
+               log.Print(str)
+               return errors.New(str)
+       }
+       server.serviceMap[s.name] = s
+       return nil
+}
+
+// suitableMethods returns suitable Rpc methods of typ, it will report
+// error using log if reportErr is true.
+func suitableMethods(typ reflect.Type, reportErr bool) map[string]*methodType {
+       methods := make(map[string]*methodType)
+       for m := 0; m < typ.NumMethod(); m++ {
+               method := typ.Method(m)
                mtype := method.Type
                mname := method.Name
                // Method must be exported.
@@ -271,46 +293,51 @@ func (server *Server) register(rcvr interface{}, name string, useName bool) erro
                }
                // Method needs three ins: receiver, *args, *reply.
                if mtype.NumIn() != 3 {
-                       log.Println("method", mname, "has wrong number of ins:", mtype.NumIn())
+                       if reportErr {
+                               log.Println("method", mname, "has wrong number of ins:", mtype.NumIn())
+                       }
                        continue
                }
                // First arg need not be a pointer.
                argType := mtype.In(1)
                if !isExportedOrBuiltinType(argType) {
-                       log.Println(mname, "argument type not exported:", argType)
+                       if reportErr {
+                               log.Println(mname, "argument type not exported:", argType)
+                       }
                        continue
                }
                // Second arg must be a pointer.
                replyType := mtype.In(2)
                if replyType.Kind() != reflect.Ptr {
-                       log.Println("method", mname, "reply type not a pointer:", replyType)
+                       if reportErr {
+                               log.Println("method", mname, "reply type not a pointer:", replyType)
+                       }
                        continue
                }
                // Reply type must be exported.
                if !isExportedOrBuiltinType(replyType) {
-                       log.Println("method", mname, "reply type not exported:", replyType)
+                       if reportErr {
+                               log.Println("method", mname, "reply type not exported:", replyType)
+                       }
                        continue
                }
                // Method needs one out.
                if mtype.NumOut() != 1 {
-                       log.Println("method", mname, "has wrong number of outs:", mtype.NumOut())
+                       if reportErr {
+                               log.Println("method", mname, "has wrong number of outs:", mtype.NumOut())
+                       }
                        continue
                }
                // The return type of the method must be error.
                if returnType := mtype.Out(0); returnType != typeOfError {
-                       log.Println("method", mname, "returns", returnType.String(), "not error")
+                       if reportErr {
+                               log.Println("method", mname, "returns", returnType.String(), "not error")
+                       }
                        continue
                }
-               s.method[mname] = &methodType{method: method, ArgType: argType, ReplyType: replyType}
-       }
-
-       if len(s.method) == 0 {
-               s := "rpc Register: type " + sname + " has no exported methods of suitable type"
-               log.Print(s)
-               return errors.New(s)
+               methods[mname] = &methodType{method: method, ArgType: argType, ReplyType: replyType}
        }
-       server.serviceMap[s.name] = s
-       return nil
+       return methods
 }
 
 // A value sent as a placeholder for the server's response value when the server
index a718e8a9407d8bc362e5ac625519b48e31de8880..d9ebe71e5c310a6cecff9a6435681ae9de18d4a1 100644 (file)
@@ -349,6 +349,7 @@ func testServeRequest(t *testing.T, server *Server) {
 type ReplyNotPointer int
 type ArgNotPublic int
 type ReplyNotPublic int
+type NeedsPtrType int
 type local struct{}
 
 func (t *ReplyNotPointer) ReplyNotPointer(args *Args, reply Reply) error {
@@ -363,19 +364,29 @@ func (t *ReplyNotPublic) ReplyNotPublic(args *Args, reply *local) error {
        return nil
 }
 
+func (t *NeedsPtrType) NeedsPtrType(args *Args, reply *Reply) error {
+       return nil
+}
+
 // Check that registration handles lots of bad methods and a type with no suitable methods.
 func TestRegistrationError(t *testing.T) {
        err := Register(new(ReplyNotPointer))
        if err == nil {
-               t.Errorf("expected error registering ReplyNotPointer")
+               t.Error("expected error registering ReplyNotPointer")
        }
        err = Register(new(ArgNotPublic))
        if err == nil {
-               t.Errorf("expected error registering ArgNotPublic")
+               t.Error("expected error registering ArgNotPublic")
        }
        err = Register(new(ReplyNotPublic))
        if err == nil {
-               t.Errorf("expected error registering ReplyNotPublic")
+               t.Error("expected error registering ReplyNotPublic")
+       }
+       err = Register(NeedsPtrType(0))
+       if err == nil {
+               t.Error("expected error registering NeedsPtrType")
+       } else if !strings.Contains(err.Error(), "pointer") {
+               t.Error("expected hint when registering NeedsPtrType")
        }
 }