]> Cypherpunks repositories - gostls13.git/commitdiff
rpc: avoid infinite loop on input error
authorRuss Cox <rsc@golang.org>
Tue, 1 Nov 2011 04:29:41 +0000 (00:29 -0400)
committerRuss Cox <rsc@golang.org>
Tue, 1 Nov 2011 04:29:41 +0000 (00:29 -0400)
Fixes #1828.
Fixes #2179.

R=golang-dev, r
CC=golang-dev
https://golang.org/cl/5305084

src/pkg/rpc/jsonrpc/all_test.go
src/pkg/rpc/server.go
src/pkg/rpc/server_test.go

index c1a9e8ecbc5395f3e968ab5938b2396b23a7a26d..99253baf3cb7d571ec05d357a3f8b743e103d65b 100644 (file)
@@ -6,6 +6,7 @@ package jsonrpc
 
 import (
        "fmt"
+       "io"
        "json"
        "net"
        "os"
@@ -154,3 +155,67 @@ func TestClient(t *testing.T) {
                t.Error("Div: expected divide by zero error; got", err)
        }
 }
+
+func TestMalformedInput(t *testing.T) {
+       cli, srv := net.Pipe()
+       go cli.Write([]byte(`{id:1}`)) // invalid json
+       ServeConn(srv)                 // must return, not loop
+}
+
+func TestUnexpectedError(t *testing.T) {
+       cli, srv := myPipe()
+       go cli.PipeWriter.CloseWithError(os.NewError("unexpected error!")) // reader will get this error
+       ServeConn(srv)                                                     // must return, not loop
+}
+
+// Copied from package net.
+func myPipe() (*pipe, *pipe) {
+       r1, w1 := io.Pipe()
+       r2, w2 := io.Pipe()
+
+       return &pipe{r1, w2}, &pipe{r2, w1}
+}
+
+type pipe struct {
+       *io.PipeReader
+       *io.PipeWriter
+}
+
+type pipeAddr int
+
+func (pipeAddr) Network() string {
+       return "pipe"
+}
+
+func (pipeAddr) String() string {
+       return "pipe"
+}
+
+func (p *pipe) Close() os.Error {
+       err := p.PipeReader.Close()
+       err1 := p.PipeWriter.Close()
+       if err == nil {
+               err = err1
+       }
+       return err
+}
+
+func (p *pipe) LocalAddr() net.Addr {
+       return pipeAddr(0)
+}
+
+func (p *pipe) RemoteAddr() net.Addr {
+       return pipeAddr(0)
+}
+
+func (p *pipe) SetTimeout(nsec int64) os.Error {
+       return os.NewError("net.Pipe does not support timeouts")
+}
+
+func (p *pipe) SetReadTimeout(nsec int64) os.Error {
+       return os.NewError("net.Pipe does not support timeouts")
+}
+
+func (p *pipe) SetWriteTimeout(nsec int64) os.Error {
+       return os.NewError("net.Pipe does not support timeouts")
+}
index f03710061a4cc4943fc0ac2d0bdaea3218785a9f..142bf8a5294b78a3212d9c06a114e956577fe95f 100644 (file)
@@ -394,12 +394,12 @@ func (server *Server) ServeConn(conn io.ReadWriteCloser) {
 func (server *Server) ServeCodec(codec ServerCodec) {
        sending := new(sync.Mutex)
        for {
-               service, mtype, req, argv, replyv, err := server.readRequest(codec)
+               service, mtype, req, argv, replyv, keepReading, err := server.readRequest(codec)
                if err != nil {
                        if err != os.EOF {
                                log.Println("rpc:", err)
                        }
-                       if err == os.EOF || err == io.ErrUnexpectedEOF {
+                       if !keepReading {
                                break
                        }
                        // send a response if we actually managed to read a header.
@@ -418,9 +418,9 @@ func (server *Server) ServeCodec(codec ServerCodec) {
 // It does not close the codec upon completion.
 func (server *Server) ServeRequest(codec ServerCodec) os.Error {
        sending := new(sync.Mutex)
-       service, mtype, req, argv, replyv, err := server.readRequest(codec)
+       service, mtype, req, argv, replyv, keepReading, err := server.readRequest(codec)
        if err != nil {
-               if err == os.EOF || err == io.ErrUnexpectedEOF {
+               if !keepReading {
                        return err
                }
                // send a response if we actually managed to read a header.
@@ -474,10 +474,10 @@ func (server *Server) freeResponse(resp *Response) {
        server.respLock.Unlock()
 }
 
-func (server *Server) readRequest(codec ServerCodec) (service *service, mtype *methodType, req *Request, argv, replyv reflect.Value, err os.Error) {
-       service, mtype, req, err = server.readRequestHeader(codec)
+func (server *Server) readRequest(codec ServerCodec) (service *service, mtype *methodType, req *Request, argv, replyv reflect.Value, keepReading bool, err os.Error) {
+       service, mtype, req, keepReading, err = server.readRequestHeader(codec)
        if err != nil {
-               if err == os.EOF || err == io.ErrUnexpectedEOF {
+               if !keepReading {
                        return
                }
                // discard body
@@ -505,7 +505,7 @@ func (server *Server) readRequest(codec ServerCodec) (service *service, mtype *m
        return
 }
 
-func (server *Server) readRequestHeader(codec ServerCodec) (service *service, mtype *methodType, req *Request, err os.Error) {
+func (server *Server) readRequestHeader(codec ServerCodec) (service *service, mtype *methodType, req *Request, keepReading bool, err os.Error) {
        // Grab the request header.
        req = server.getRequest()
        err = codec.ReadRequestHeader(req)
@@ -518,6 +518,10 @@ func (server *Server) readRequestHeader(codec ServerCodec) (service *service, mt
                return
        }
 
+       // We read the header successfully.  If we see an error now,
+       // we can still recover and move on to the next request.
+       keepReading = true
+
        serviceMethod := strings.Split(req.ServiceMethod, ".")
        if len(serviceMethod) != 2 {
                err = os.NewError("rpc: service/method request ill-formed: " + req.ServiceMethod)
index 029741b28b51a85efa6e1ab4d594db5de852d0c4..3e9fe297d4502cf148d5f208455a6d48db2595f4 100644 (file)
@@ -311,8 +311,9 @@ func (codec *CodecEmulator) ReadRequestBody(argv interface{}) os.Error {
 func (codec *CodecEmulator) WriteResponse(resp *Response, reply interface{}) os.Error {
        if resp.Error != "" {
                codec.err = os.NewError(resp.Error)
+       } else {
+               *codec.reply = *(reply.(*Reply))
        }
-       *codec.reply = *(reply.(*Reply))
        return nil
 }