From 4f6d8a59eae73f3ce4e67e8e42a2bc25e7216ec3 Mon Sep 17 00:00:00 2001 From: Russ Cox Date: Wed, 22 Nov 2017 15:25:59 -0500 Subject: [PATCH] net/rpc: wait for responses to be written before closing Codec If there are no more requests being made, wait to shut down the response-writing codec until the pending requests are all answered. Fixes #17239. Change-Id: Ie62c63ada536171df4e70b73c95f98f778069972 Reviewed-on: https://go-review.googlesource.com/79515 Run-TryBot: Russ Cox TryBot-Result: Gobot Gobot Reviewed-by: Rob Pike --- src/net/rpc/server.go | 14 +++++++--- src/net/rpc/server_test.go | 52 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+), 3 deletions(-) diff --git a/src/net/rpc/server.go b/src/net/rpc/server.go index 29aae7ee7f..a021292603 100644 --- a/src/net/rpc/server.go +++ b/src/net/rpc/server.go @@ -372,7 +372,10 @@ func (m *methodType) NumCalls() (n uint) { return n } -func (s *service) call(server *Server, sending *sync.Mutex, mtype *methodType, req *Request, argv, replyv reflect.Value, codec ServerCodec) { +func (s *service) call(server *Server, sending *sync.Mutex, wg *sync.WaitGroup, mtype *methodType, req *Request, argv, replyv reflect.Value, codec ServerCodec) { + if wg != nil { + defer wg.Done() + } mtype.Lock() mtype.numCalls++ mtype.Unlock() @@ -456,6 +459,7 @@ func (server *Server) ServeConn(conn io.ReadWriteCloser) { // decode requests and encode responses. func (server *Server) ServeCodec(codec ServerCodec) { sending := new(sync.Mutex) + wg := new(sync.WaitGroup) for { service, mtype, req, argv, replyv, keepReading, err := server.readRequest(codec) if err != nil { @@ -472,8 +476,12 @@ func (server *Server) ServeCodec(codec ServerCodec) { } continue } - go service.call(server, sending, mtype, req, argv, replyv, codec) + wg.Add(1) + go service.call(server, sending, wg, mtype, req, argv, replyv, codec) } + // We've seen that there are no more requests. + // Wait for responses to be sent before closing codec. + wg.Wait() codec.Close() } @@ -493,7 +501,7 @@ func (server *Server) ServeRequest(codec ServerCodec) error { } return err } - service.call(server, sending, mtype, req, argv, replyv, codec) + service.call(server, sending, nil, mtype, req, argv, replyv, codec) return nil } diff --git a/src/net/rpc/server_test.go b/src/net/rpc/server_test.go index fb97f82a2f..e5d7fe0c8f 100644 --- a/src/net/rpc/server_test.go +++ b/src/net/rpc/server_test.go @@ -75,6 +75,11 @@ func (t *Arith) Error(args *Args, reply *Reply) error { panic("ERROR") } +func (t *Arith) SleepMilli(args *Args, reply *Reply) error { + time.Sleep(time.Duration(args.A) * time.Millisecond) + return nil +} + type hidden int func (t *hidden) Exported(args Args, reply *Reply) error { @@ -693,6 +698,53 @@ func TestAcceptExitAfterListenerClose(t *testing.T) { newServer.Accept(l) } +func TestShutdown(t *testing.T) { + var l net.Listener + l, _ = listenTCP() + ch := make(chan net.Conn, 1) + go func() { + defer l.Close() + c, err := l.Accept() + if err != nil { + t.Error(err) + } + ch <- c + }() + c, err := net.Dial("tcp", l.Addr().String()) + if err != nil { + t.Fatal(err) + } + c1 := <-ch + if c1 == nil { + t.Fatal(err) + } + + newServer := NewServer() + newServer.Register(new(Arith)) + go newServer.ServeConn(c1) + + args := &Args{7, 8} + reply := new(Reply) + client := NewClient(c) + err = client.Call("Arith.Add", args, reply) + if err != nil { + t.Fatal(err) + } + + // On an unloaded system 10ms is usually enough to fail 100% of the time + // with a broken server. On a loaded system, a broken server might incorrectly + // be reported as passing, but we're OK with that kind of flakiness. + // If the code is correct, this test will never fail, regardless of timeout. + args.A = 10 // 10 ms + done := make(chan *Call, 1) + call := client.Go("Arith.SleepMilli", args, reply, done) + c.(*net.TCPConn).CloseWrite() + <-done + if call.Error != nil { + t.Fatal(err) + } +} + func benchmarkEndToEnd(dial func() (*Client, error), b *testing.B) { once.Do(startServer) client, err := dial() -- 2.48.1