"sync"
)
+// ServerError represents an error that has been returned from
+// the remote side of the RPC connection.
+type ServerError string
+
+func (e ServerError) String() string {
+ return string(e)
+}
+
+const ErrShutdown = os.ErrorString("connection is shut down")
+
// Call represents an active RPC.
type Call struct {
ServiceMethod string // The name of the service and method to call.
// with a single Client.
type Client struct {
mutex sync.Mutex // protects pending, seq
- shutdown os.Error // non-nil if the client is shut down
sending sync.Mutex
seq uint64
codec ClientCodec
pending map[uint64]*Call
closing bool
+ shutdown bool
}
// A ClientCodec implements writing of RPC requests and
func (client *Client) send(c *Call) {
// Register this call.
client.mutex.Lock()
- if client.shutdown != nil {
- c.Error = client.shutdown
+ if client.shutdown {
+ c.Error = ErrShutdown
client.mutex.Unlock()
c.done()
return
func (client *Client) input() {
var err os.Error
+ var marker struct{}
for err == nil {
response := new(Response)
err = client.codec.ReadResponseHeader(response)
c := client.pending[seq]
client.pending[seq] = c, false
client.mutex.Unlock()
- err = client.codec.ReadResponseBody(c.Reply)
- if response.Error != "" {
- c.Error = os.ErrorString(response.Error)
- } else if err != nil {
- c.Error = err
+
+ if response.Error == "" {
+ err = client.codec.ReadResponseBody(c.Reply)
+ if err != nil {
+ c.Error = os.ErrorString("reading body " + err.String())
+ }
} else {
- // Empty strings should turn into nil os.Errors
- c.Error = nil
+ // We've got an error response. Give this to the request;
+ // any subsequent requests will get the ReadResponseBody
+ // error if there is one.
+ c.Error = ServerError(response.Error)
+ err = client.codec.ReadResponseBody(&marker)
+ if err != nil {
+ err = os.ErrorString("reading error body: " + err.String())
+ }
}
c.done()
}
// Terminate pending calls.
client.mutex.Lock()
- client.shutdown = err
+ client.shutdown = true
for _, call := range client.pending {
call.Error = err
call.done()
}
func (client *Client) Close() os.Error {
- if client.shutdown != nil || client.closing {
- return os.ErrorString("rpc: already closed")
- }
client.mutex.Lock()
+ if client.shutdown || client.closing {
+ client.mutex.Unlock()
+ return ErrShutdown
+ }
client.closing = true
client.mutex.Unlock()
return client.codec.Close()
}
}
c.Done = done
- if client.shutdown != nil {
- c.Error = client.shutdown
+ if client.shutdown {
+ c.Error = ErrShutdown
c.done()
return c
}
// Call invokes the named function, waits for it to complete, and returns its error status.
func (client *Client) Call(serviceMethod string, args interface{}, reply interface{}) os.Error {
- if client.shutdown != nil {
- return client.shutdown
+ if client.shutdown {
+ return ErrShutdown
}
call := <-client.Go(serviceMethod, args, reply, nil).Done
return call.Error
// A value sent as a placeholder for the response when the server receives an invalid request.
type InvalidRequest struct {
- marker int
+ Marker int
}
var invalidRequest = InvalidRequest{1}
resp.ServiceMethod = req.ServiceMethod
if errmsg != "" {
resp.Error = errmsg
+ reply = invalidRequest
}
resp.Seq = req.Seq
sending.Lock()
func (server *Server) ServeCodec(codec ServerCodec) {
sending := new(sync.Mutex)
for {
- // Grab the request header.
- req := new(Request)
- err := codec.ReadRequestHeader(req)
+ req, service, mtype, err := server.readRequest(codec)
if err != nil {
+ if err != os.EOF {
+ log.Println("rpc:", err)
+ }
if err == os.EOF || err == io.ErrUnexpectedEOF {
- if err == io.ErrUnexpectedEOF {
- log.Println("rpc:", err)
- }
break
}
- s := "rpc: server cannot decode request: " + err.String()
- sendResponse(sending, req, invalidRequest, codec, s)
- break
- }
- serviceMethod := strings.Split(req.ServiceMethod, ".", -1)
- if len(serviceMethod) != 2 {
- s := "rpc: service/method request ill-formed: " + req.ServiceMethod
- sendResponse(sending, req, invalidRequest, codec, s)
- continue
- }
- // Look up the request.
- server.Lock()
- service, ok := server.serviceMap[serviceMethod[0]]
- server.Unlock()
- if !ok {
- s := "rpc: can't find service " + req.ServiceMethod
- sendResponse(sending, req, invalidRequest, codec, s)
- continue
- }
- mtype, ok := service.method[serviceMethod[1]]
- if !ok {
- s := "rpc: can't find method " + req.ServiceMethod
- sendResponse(sending, req, invalidRequest, codec, s)
+ // discard body
+ codec.ReadRequestBody(new(interface{}))
+
+ // send a response if we actually managed to read a header.
+ if req != nil {
+ sendResponse(sending, req, invalidRequest, codec, err.String())
+ }
continue
}
+
// Decode the argument value.
argv := _new(mtype.ArgType)
replyv := _new(mtype.ReplyType)
err = codec.ReadRequestBody(argv.Interface())
if err != nil {
- log.Println("rpc: tearing down", serviceMethod[0], "connection:", err)
+ if err == os.EOF || err == io.ErrUnexpectedEOF {
+ if err == io.ErrUnexpectedEOF {
+ log.Println("rpc:", err)
+ }
+ break
+ }
sendResponse(sending, req, replyv.Interface(), codec, err.String())
- break
+ continue
}
go service.call(sending, mtype, req, argv, replyv, codec)
}
codec.Close()
}
+func (server *Server) readRequest(codec ServerCodec) (req *Request, service *service, mtype *methodType, err os.Error) {
+ // Grab the request header.
+ req = new(Request)
+ err = codec.ReadRequestHeader(req)
+ if err != nil {
+ req = nil
+ if err == os.EOF || err == io.ErrUnexpectedEOF {
+ return
+ }
+ err = os.ErrorString("rpc: server cannot decode request: " + err.String())
+ return
+ }
+
+ serviceMethod := strings.Split(req.ServiceMethod, ".", -1)
+ if len(serviceMethod) != 2 {
+ err = os.ErrorString("rpc: service/method request ill-formed: " + req.ServiceMethod)
+ return
+ }
+ // Look up the request.
+ server.Lock()
+ service = server.serviceMap[serviceMethod[0]]
+ server.Unlock()
+ if service == nil {
+ err = os.ErrorString("rpc: can't find service " + req.ServiceMethod)
+ return
+ }
+ mtype = service.method[serviceMethod[1]]
+ if mtype == nil {
+ err = os.ErrorString("rpc: can't find method " + req.ServiceMethod)
+ }
+ return
+}
// Accept accepts connections on the listener and serves requests
// for each incoming connection. Accept blocks; the caller typically
t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
}
- args = &Args{7, 8}
+ // Nonexistent method
+ args = &Args{7, 0}
reply = new(Reply)
- err = client.Call("Arith.Mul", args, reply)
- if err != nil {
- t.Errorf("Mul: expected no error but got string %q", err.String())
+ err = client.Call("Arith.BadOperation", args, reply)
+ // expect an error
+ if err == nil {
+ t.Error("BadOperation: expected error")
+ } else if !strings.HasPrefix(err.String(), "rpc: can't find method ") {
+ t.Errorf("BadOperation: expected can't find method error; got %q", err)
}
- if reply.C != args.A*args.B {
- t.Errorf("Mul: expected %d got %d", reply.C, args.A*args.B)
+
+ // Unknown service
+ args = &Args{7, 8}
+ reply = new(Reply)
+ err = client.Call("Arith.Unknown", args, reply)
+ if err == nil {
+ t.Error("expected error calling unknown service")
+ } else if strings.Index(err.String(), "method") < 0 {
+ t.Error("expected error about method; got", err)
}
// Out of order.
t.Error("Div: expected divide by zero error; got", err)
}
+ // Bad type.
+ reply = new(Reply)
+ err = client.Call("Arith.Add", reply, reply) // args, reply would be the correct thing to use
+ if err == nil {
+ t.Error("expected error calling Arith.Add with wrong arg type")
+ } else if strings.Index(err.String(), "type") < 0 {
+ t.Error("expected error about type; got", err)
+ }
+
// Non-struct argument
const Val = 12345
str := fmt.Sprint(Val)
if str != expect {
t.Errorf("String: expected %s got %s", expect, str)
}
+
+ args = &Args{7, 8}
+ reply = new(Reply)
+ err = client.Call("Arith.Mul", args, reply)
+ if err != nil {
+ t.Errorf("Mul: expected no error but got string %q", err.String())
+ }
+ if reply.C != args.A*args.B {
+ t.Errorf("Mul: expected %d got %d", reply.C, args.A*args.B)
+ }
}
-func TestHTTPRPC(t *testing.T) {
+func TestHTTP(t *testing.T) {
once.Do(startServer)
testHTTPRPC(t, "")
newOnce.Do(startNewServer)
}
}
-func TestCheckUnknownService(t *testing.T) {
- once.Do(startServer)
-
- conn, err := net.Dial("tcp", "", serverAddr)
- if err != nil {
- t.Fatal("dialing:", err)
- }
-
- client := NewClient(conn)
-
- args := &Args{7, 8}
- reply := new(Reply)
- err = client.Call("Unknown.Add", args, reply)
- if err == nil {
- t.Error("expected error calling unknown service")
- } else if strings.Index(err.String(), "service") < 0 {
- t.Error("expected error about service; got", err)
- }
-}
-
-func TestCheckUnknownMethod(t *testing.T) {
- once.Do(startServer)
-
- conn, err := net.Dial("tcp", "", serverAddr)
- if err != nil {
- t.Fatal("dialing:", err)
- }
-
- client := NewClient(conn)
-
- args := &Args{7, 8}
- reply := new(Reply)
- err = client.Call("Arith.Unknown", args, reply)
- if err == nil {
- t.Error("expected error calling unknown service")
- } else if strings.Index(err.String(), "method") < 0 {
- t.Error("expected error about method; got", err)
- }
-}
-
-func TestCheckBadType(t *testing.T) {
- once.Do(startServer)
-
- conn, err := net.Dial("tcp", "", serverAddr)
- if err != nil {
- t.Fatal("dialing:", err)
- }
-
- client := NewClient(conn)
-
- reply := new(Reply)
- err = client.Call("Arith.Add", reply, reply) // args, reply would be the correct thing to use
- if err == nil {
- t.Error("expected error calling Arith.Add with wrong arg type")
- } else if strings.Index(err.String(), "type") < 0 {
- t.Error("expected error about type; got", err)
- }
-}
-
type ArgNotPointer int
type ReplyNotPointer int
type ArgNotPublic int