From 2e79e8e54920c005af29447a85d7b241460c34cb Mon Sep 17 00:00:00 2001 From: Russ Cox Date: Tue, 1 Nov 2011 00:29:41 -0400 Subject: [PATCH] rpc: avoid infinite loop on input error Fixes #1828. Fixes #2179. R=golang-dev, r CC=golang-dev https://golang.org/cl/5305084 --- src/pkg/rpc/jsonrpc/all_test.go | 65 +++++++++++++++++++++++++++++++++ src/pkg/rpc/server.go | 20 ++++++---- src/pkg/rpc/server_test.go | 3 +- 3 files changed, 79 insertions(+), 9 deletions(-) diff --git a/src/pkg/rpc/jsonrpc/all_test.go b/src/pkg/rpc/jsonrpc/all_test.go index c1a9e8ecbc..99253baf3c 100644 --- a/src/pkg/rpc/jsonrpc/all_test.go +++ b/src/pkg/rpc/jsonrpc/all_test.go @@ -6,6 +6,7 @@ package jsonrpc import ( "fmt" + "io" "json" "net" "os" @@ -154,3 +155,67 @@ func TestClient(t *testing.T) { t.Error("Div: expected divide by zero error; got", err) } } + +func TestMalformedInput(t *testing.T) { + cli, srv := net.Pipe() + go cli.Write([]byte(`{id:1}`)) // invalid json + ServeConn(srv) // must return, not loop +} + +func TestUnexpectedError(t *testing.T) { + cli, srv := myPipe() + go cli.PipeWriter.CloseWithError(os.NewError("unexpected error!")) // reader will get this error + ServeConn(srv) // must return, not loop +} + +// Copied from package net. +func myPipe() (*pipe, *pipe) { + r1, w1 := io.Pipe() + r2, w2 := io.Pipe() + + return &pipe{r1, w2}, &pipe{r2, w1} +} + +type pipe struct { + *io.PipeReader + *io.PipeWriter +} + +type pipeAddr int + +func (pipeAddr) Network() string { + return "pipe" +} + +func (pipeAddr) String() string { + return "pipe" +} + +func (p *pipe) Close() os.Error { + err := p.PipeReader.Close() + err1 := p.PipeWriter.Close() + if err == nil { + err = err1 + } + return err +} + +func (p *pipe) LocalAddr() net.Addr { + return pipeAddr(0) +} + +func (p *pipe) RemoteAddr() net.Addr { + return pipeAddr(0) +} + +func (p *pipe) SetTimeout(nsec int64) os.Error { + return os.NewError("net.Pipe does not support timeouts") +} + +func (p *pipe) SetReadTimeout(nsec int64) os.Error { + return os.NewError("net.Pipe does not support timeouts") +} + +func (p *pipe) SetWriteTimeout(nsec int64) os.Error { + return os.NewError("net.Pipe does not support timeouts") +} diff --git a/src/pkg/rpc/server.go b/src/pkg/rpc/server.go index f03710061a..142bf8a529 100644 --- a/src/pkg/rpc/server.go +++ b/src/pkg/rpc/server.go @@ -394,12 +394,12 @@ func (server *Server) ServeConn(conn io.ReadWriteCloser) { func (server *Server) ServeCodec(codec ServerCodec) { sending := new(sync.Mutex) for { - service, mtype, req, argv, replyv, err := server.readRequest(codec) + service, mtype, req, argv, replyv, keepReading, err := server.readRequest(codec) if err != nil { if err != os.EOF { log.Println("rpc:", err) } - if err == os.EOF || err == io.ErrUnexpectedEOF { + if !keepReading { break } // send a response if we actually managed to read a header. @@ -418,9 +418,9 @@ func (server *Server) ServeCodec(codec ServerCodec) { // It does not close the codec upon completion. func (server *Server) ServeRequest(codec ServerCodec) os.Error { sending := new(sync.Mutex) - service, mtype, req, argv, replyv, err := server.readRequest(codec) + service, mtype, req, argv, replyv, keepReading, err := server.readRequest(codec) if err != nil { - if err == os.EOF || err == io.ErrUnexpectedEOF { + if !keepReading { return err } // send a response if we actually managed to read a header. @@ -474,10 +474,10 @@ func (server *Server) freeResponse(resp *Response) { server.respLock.Unlock() } -func (server *Server) readRequest(codec ServerCodec) (service *service, mtype *methodType, req *Request, argv, replyv reflect.Value, err os.Error) { - service, mtype, req, err = server.readRequestHeader(codec) +func (server *Server) readRequest(codec ServerCodec) (service *service, mtype *methodType, req *Request, argv, replyv reflect.Value, keepReading bool, err os.Error) { + service, mtype, req, keepReading, err = server.readRequestHeader(codec) if err != nil { - if err == os.EOF || err == io.ErrUnexpectedEOF { + if !keepReading { return } // discard body @@ -505,7 +505,7 @@ func (server *Server) readRequest(codec ServerCodec) (service *service, mtype *m return } -func (server *Server) readRequestHeader(codec ServerCodec) (service *service, mtype *methodType, req *Request, err os.Error) { +func (server *Server) readRequestHeader(codec ServerCodec) (service *service, mtype *methodType, req *Request, keepReading bool, err os.Error) { // Grab the request header. req = server.getRequest() err = codec.ReadRequestHeader(req) @@ -518,6 +518,10 @@ func (server *Server) readRequestHeader(codec ServerCodec) (service *service, mt return } + // We read the header successfully. If we see an error now, + // we can still recover and move on to the next request. + keepReading = true + serviceMethod := strings.Split(req.ServiceMethod, ".") if len(serviceMethod) != 2 { err = os.NewError("rpc: service/method request ill-formed: " + req.ServiceMethod) diff --git a/src/pkg/rpc/server_test.go b/src/pkg/rpc/server_test.go index 029741b28b..3e9fe297d4 100644 --- a/src/pkg/rpc/server_test.go +++ b/src/pkg/rpc/server_test.go @@ -311,8 +311,9 @@ func (codec *CodecEmulator) ReadRequestBody(argv interface{}) os.Error { func (codec *CodecEmulator) WriteResponse(resp *Response, reply interface{}) os.Error { if resp.Error != "" { codec.err = os.NewError(resp.Error) + } else { + *codec.reply = *(reply.(*Reply)) } - *codec.reply = *(reply.(*Reply)) return nil } -- 2.48.1