]> Cypherpunks repositories - gostls13.git/commitdiff
net/rpc: wait for responses to be written before closing Codec
authorRuss Cox <rsc@golang.org>
Wed, 22 Nov 2017 20:25:59 +0000 (15:25 -0500)
committerRuss Cox <rsc@golang.org>
Wed, 29 Nov 2017 16:26:57 +0000 (16:26 +0000)
If there are no more requests being made, wait to shut down
the response-writing codec until the pending requests are all
answered.

Fixes #17239.

Change-Id: Ie62c63ada536171df4e70b73c95f98f778069972
Reviewed-on: https://go-review.googlesource.com/79515
Run-TryBot: Russ Cox <rsc@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Rob Pike <r@golang.org>
src/net/rpc/server.go
src/net/rpc/server_test.go

index 29aae7ee7ff04b19815cfac7ebaf6b82266435d4..a0212926037cd8e7c7f9c0e072d92cafee7a2372 100644 (file)
@@ -372,7 +372,10 @@ func (m *methodType) NumCalls() (n uint) {
        return n
 }
 
-func (s *service) call(server *Server, sending *sync.Mutex, mtype *methodType, req *Request, argv, replyv reflect.Value, codec ServerCodec) {
+func (s *service) call(server *Server, sending *sync.Mutex, wg *sync.WaitGroup, mtype *methodType, req *Request, argv, replyv reflect.Value, codec ServerCodec) {
+       if wg != nil {
+               defer wg.Done()
+       }
        mtype.Lock()
        mtype.numCalls++
        mtype.Unlock()
@@ -456,6 +459,7 @@ func (server *Server) ServeConn(conn io.ReadWriteCloser) {
 // decode requests and encode responses.
 func (server *Server) ServeCodec(codec ServerCodec) {
        sending := new(sync.Mutex)
+       wg := new(sync.WaitGroup)
        for {
                service, mtype, req, argv, replyv, keepReading, err := server.readRequest(codec)
                if err != nil {
@@ -472,8 +476,12 @@ func (server *Server) ServeCodec(codec ServerCodec) {
                        }
                        continue
                }
-               go service.call(server, sending, mtype, req, argv, replyv, codec)
+               wg.Add(1)
+               go service.call(server, sending, wg, mtype, req, argv, replyv, codec)
        }
+       // We've seen that there are no more requests.
+       // Wait for responses to be sent before closing codec.
+       wg.Wait()
        codec.Close()
 }
 
@@ -493,7 +501,7 @@ func (server *Server) ServeRequest(codec ServerCodec) error {
                }
                return err
        }
-       service.call(server, sending, mtype, req, argv, replyv, codec)
+       service.call(server, sending, nil, mtype, req, argv, replyv, codec)
        return nil
 }
 
index fb97f82a2f73eb65ede361429e1f2b19736b25d3..e5d7fe0c8f55e79c30ae41aeaecd116670964b90 100644 (file)
@@ -75,6 +75,11 @@ func (t *Arith) Error(args *Args, reply *Reply) error {
        panic("ERROR")
 }
 
+func (t *Arith) SleepMilli(args *Args, reply *Reply) error {
+       time.Sleep(time.Duration(args.A) * time.Millisecond)
+       return nil
+}
+
 type hidden int
 
 func (t *hidden) Exported(args Args, reply *Reply) error {
@@ -693,6 +698,53 @@ func TestAcceptExitAfterListenerClose(t *testing.T) {
        newServer.Accept(l)
 }
 
+func TestShutdown(t *testing.T) {
+       var l net.Listener
+       l, _ = listenTCP()
+       ch := make(chan net.Conn, 1)
+       go func() {
+               defer l.Close()
+               c, err := l.Accept()
+               if err != nil {
+                       t.Error(err)
+               }
+               ch <- c
+       }()
+       c, err := net.Dial("tcp", l.Addr().String())
+       if err != nil {
+               t.Fatal(err)
+       }
+       c1 := <-ch
+       if c1 == nil {
+               t.Fatal(err)
+       }
+
+       newServer := NewServer()
+       newServer.Register(new(Arith))
+       go newServer.ServeConn(c1)
+
+       args := &Args{7, 8}
+       reply := new(Reply)
+       client := NewClient(c)
+       err = client.Call("Arith.Add", args, reply)
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       // On an unloaded system 10ms is usually enough to fail 100% of the time
+       // with a broken server. On a loaded system, a broken server might incorrectly
+       // be reported as passing, but we're OK with that kind of flakiness.
+       // If the code is correct, this test will never fail, regardless of timeout.
+       args.A = 10 // 10 ms
+       done := make(chan *Call, 1)
+       call := client.Go("Arith.SleepMilli", args, reply, done)
+       c.(*net.TCPConn).CloseWrite()
+       <-done
+       if call.Error != nil {
+               t.Fatal(err)
+       }
+}
+
 func benchmarkEndToEnd(dial func() (*Client, error), b *testing.B) {
        once.Do(startServer)
        client, err := dial()