]> Cypherpunks repositories - gostls13.git/commitdiff
net/rpc: fix data race on Call.Error
authorDmitriy Vyukov <dvyukov@google.com>
Fri, 27 Jan 2012 07:27:05 +0000 (11:27 +0400)
committerDmitriy Vyukov <dvyukov@google.com>
Fri, 27 Jan 2012 07:27:05 +0000 (11:27 +0400)
+eliminates a possibility of sending a call to Done several times.
+fixes memory leak in case of temporal Write errors.
+fixes data race on Client.shutdown.
+fixes data race on Client.closing.
+fixes comments.
Fixes #2780.

R=r, rsc
CC=golang-dev, mpimenov
https://golang.org/cl/5571063

src/pkg/net/rpc/client.go

index abc1e59cd510858c66b3c70ff33c7a9a2af09957..69c440769500b4235ea6937d0cf721a53bd9228d 100644 (file)
@@ -31,8 +31,7 @@ type Call struct {
        Args          interface{} // The argument to the function (*struct).
        Reply         interface{} // The reply from the function (*struct).
        Error         error       // After completion, the error status.
-       Done          chan *Call  // Strobes when call is complete; value is the error status.
-       seq           uint64
+       Done          chan *Call  // Strobes when call is complete.
 }
 
 // Client represents an RPC Client.
@@ -65,28 +64,33 @@ type ClientCodec interface {
        Close() error
 }
 
-func (client *Client) send(c *Call) {
+func (client *Client) send(call *Call) {
+       client.sending.Lock()
+       defer client.sending.Unlock()
+
        // Register this call.
        client.mutex.Lock()
        if client.shutdown {
-               c.Error = ErrShutdown
+               call.Error = ErrShutdown
                client.mutex.Unlock()
-               c.done()
+               call.done()
                return
        }
-       c.seq = client.seq
+       seq := client.seq
        client.seq++
-       client.pending[c.seq] = c
+       client.pending[seq] = call
        client.mutex.Unlock()
 
        // Encode and send the request.
-       client.sending.Lock()
-       defer client.sending.Unlock()
-       client.request.Seq = c.seq
-       client.request.ServiceMethod = c.ServiceMethod
-       if err := client.codec.WriteRequest(&client.request, c.Args); err != nil {
-               c.Error = err
-               c.done()
+       client.request.Seq = seq
+       client.request.ServiceMethod = call.ServiceMethod
+       err := client.codec.WriteRequest(&client.request, call.Args)
+       if err != nil {
+               client.mutex.Lock()
+               delete(client.pending, seq)
+               client.mutex.Unlock()
+               call.Error = err
+               call.done()
        }
 }
 
@@ -104,36 +108,39 @@ func (client *Client) input() {
                }
                seq := response.Seq
                client.mutex.Lock()
-               c := client.pending[seq]
+               call := client.pending[seq]
                delete(client.pending, seq)
                client.mutex.Unlock()
 
                if response.Error == "" {
-                       err = client.codec.ReadResponseBody(c.Reply)
+                       err = client.codec.ReadResponseBody(call.Reply)
                        if err != nil {
-                               c.Error = errors.New("reading body " + err.Error())
+                               call.Error = errors.New("reading body " + err.Error())
                        }
                } else {
                        // We've got an error response. Give this to the request;
                        // any subsequent requests will get the ReadResponseBody
                        // error if there is one.
-                       c.Error = ServerError(response.Error)
+                       call.Error = ServerError(response.Error)
                        err = client.codec.ReadResponseBody(nil)
                        if err != nil {
                                err = errors.New("reading error body: " + err.Error())
                        }
                }
-               c.done()
+               call.done()
        }
        // Terminate pending calls.
+       client.sending.Lock()
        client.mutex.Lock()
        client.shutdown = true
+       closing := client.closing
        for _, call := range client.pending {
                call.Error = err
                call.done()
        }
        client.mutex.Unlock()
-       if err != io.EOF || !client.closing {
+       client.sending.Unlock()
+       if err != io.EOF || !closing {
                log.Println("rpc: client protocol error:", err)
        }
 }
@@ -269,20 +276,12 @@ func (client *Client) Go(serviceMethod string, args interface{}, reply interface
                }
        }
        call.Done = done
-       if client.shutdown {
-               call.Error = ErrShutdown
-               call.done()
-               return call
-       }
        client.send(call)
        return call
 }
 
 // Call invokes the named function, waits for it to complete, and returns its error status.
 func (client *Client) Call(serviceMethod string, args interface{}, reply interface{}) error {
-       if client.shutdown {
-               return ErrShutdown
-       }
        call := <-client.Go(serviceMethod, args, reply, make(chan *Call, 1)).Done
        return call.Error
 }