]> Cypherpunks repositories - gostls13.git/commitdiff
http: fixed race condition in persist.go
authorPetar Maymounkov <petarm@gmail.com>
Sat, 5 Mar 2011 19:44:05 +0000 (14:44 -0500)
committerRuss Cox <rsc@golang.org>
Sat, 5 Mar 2011 19:44:05 +0000 (14:44 -0500)
R=rsc, bradfitzgo, bradfitzwork
CC=golang-dev
https://golang.org/cl/4266042

src/pkg/http/persist.go

index 000a4200e59a8a4d7e8730ffe2e93948585346b1..53efd7c8c6a9979f56aac428d58ec19083c9be86 100644 (file)
@@ -25,15 +25,15 @@ var (
 // i.e. requests can be read out of sync (but in the same order) while the
 // respective responses are sent.
 type ServerConn struct {
+       lk              sync.Mutex // read-write protects the following fields
        c               net.Conn
        r               *bufio.Reader
-       clsd            bool     // indicates a graceful close
        re, we          os.Error // read/write errors
        lastbody        io.ReadCloser
        nread, nwritten int
-       pipe            textproto.Pipeline
        pipereq         map[*Request]uint
-       lk              sync.Mutex // protected read/write to re,we
+
+       pipe textproto.Pipeline
 }
 
 // NewServerConn returns a new ServerConn reading and writing c.  If r is not
@@ -90,15 +90,21 @@ func (sc *ServerConn) Read() (req *Request, err os.Error) {
                defer sc.lk.Unlock()
                return nil, sc.re
        }
+       if sc.r == nil { // connection closed by user in the meantime
+               defer sc.lk.Unlock()
+               return nil, os.EBADF
+       }
+       r := sc.r
+       lastbody := sc.lastbody
+       sc.lastbody = nil
        sc.lk.Unlock()
 
        // Make sure body is fully consumed, even if user does not call body.Close
-       if sc.lastbody != nil {
+       if lastbody != nil {
                // body.Close is assumed to be idempotent and multiple calls to
                // it should return the error that its first invokation
                // returned.
-               err = sc.lastbody.Close()
-               sc.lastbody = nil
+               err = lastbody.Close()
                if err != nil {
                        sc.lk.Lock()
                        defer sc.lk.Unlock()
@@ -107,10 +113,10 @@ func (sc *ServerConn) Read() (req *Request, err os.Error) {
                }
        }
 
-       req, err = ReadRequest(sc.r)
+       req, err = ReadRequest(r)
+       sc.lk.Lock()
+       defer sc.lk.Unlock()
        if err != nil {
-               sc.lk.Lock()
-               defer sc.lk.Unlock()
                if err == io.ErrUnexpectedEOF {
                        // A close from the opposing client is treated as a
                        // graceful close, even if there was some unparse-able
@@ -119,18 +125,16 @@ func (sc *ServerConn) Read() (req *Request, err os.Error) {
                        return nil, sc.re
                } else {
                        sc.re = err
-                       return
+                       return req, err
                }
        }
        sc.lastbody = req.Body
        sc.nread++
        if req.Close {
-               sc.lk.Lock()
-               defer sc.lk.Unlock()
                sc.re = ErrPersistEOF
                return req, sc.re
        }
-       return
+       return req, err
 }
 
 // Pending returns the number of unanswered requests
@@ -165,24 +169,27 @@ func (sc *ServerConn) Write(req *Request, resp *Response) os.Error {
                defer sc.lk.Unlock()
                return sc.we
        }
-       sc.lk.Unlock()
+       if sc.c == nil { // connection closed by user in the meantime
+               defer sc.lk.Unlock()
+               return os.EBADF
+       }
+       c := sc.c
        if sc.nread <= sc.nwritten {
+               defer sc.lk.Unlock()
                return os.NewError("persist server pipe count")
        }
-
        if resp.Close {
                // After signaling a keep-alive close, any pipelined unread
                // requests will be lost. It is up to the user to drain them
                // before signaling.
-               sc.lk.Lock()
                sc.re = ErrPersistEOF
-               sc.lk.Unlock()
        }
+       sc.lk.Unlock()
 
-       err := resp.Write(sc.c)
+       err := resp.Write(c)
+       sc.lk.Lock()
+       defer sc.lk.Unlock()
        if err != nil {
-               sc.lk.Lock()
-               defer sc.lk.Unlock()
                sc.we = err
                return err
        }
@@ -196,14 +203,15 @@ func (sc *ServerConn) Write(req *Request, resp *Response) os.Error {
 // responsible for closing the underlying connection. One must call Close to
 // regain control of that connection and deal with it as desired.
 type ClientConn struct {
+       lk              sync.Mutex // read-write protects the following fields
        c               net.Conn
        r               *bufio.Reader
        re, we          os.Error // read/write errors
        lastbody        io.ReadCloser
        nread, nwritten int
-       pipe            textproto.Pipeline
        pipereq         map[*Request]uint
-       lk              sync.Mutex // protects read/write to re,we,pipereq,etc.
+
+       pipe textproto.Pipeline
 }
 
 // NewClientConn returns a new ClientConn reading and writing c.  If r is not
@@ -221,11 +229,11 @@ func NewClientConn(c net.Conn, r *bufio.Reader) *ClientConn {
 // logic. The user should not call Close while Read or Write is in progress.
 func (cc *ClientConn) Close() (c net.Conn, r *bufio.Reader) {
        cc.lk.Lock()
+       defer cc.lk.Unlock()
        c = cc.c
        r = cc.r
        cc.c = nil
        cc.r = nil
-       cc.lk.Unlock()
        return
 }
 
@@ -261,20 +269,22 @@ func (cc *ClientConn) Write(req *Request) (err os.Error) {
                defer cc.lk.Unlock()
                return cc.we
        }
-       cc.lk.Unlock()
-
+       if cc.c == nil { // connection closed by user in the meantime
+               defer cc.lk.Unlock()
+               return os.EBADF
+       }
+       c := cc.c
        if req.Close {
                // We write the EOF to the write-side error, because there
                // still might be some pipelined reads
-               cc.lk.Lock()
                cc.we = ErrPersistEOF
-               cc.lk.Unlock()
        }
+       cc.lk.Unlock()
 
-       err = req.Write(cc.c)
+       err = req.Write(c)
+       cc.lk.Lock()
+       defer cc.lk.Unlock()
        if err != nil {
-               cc.lk.Lock()
-               defer cc.lk.Unlock()
                cc.we = err
                return err
        }
@@ -316,15 +326,21 @@ func (cc *ClientConn) Read(req *Request) (resp *Response, err os.Error) {
                defer cc.lk.Unlock()
                return nil, cc.re
        }
+       if cc.r == nil { // connection closed by user in the meantime
+               defer cc.lk.Unlock()
+               return nil, os.EBADF
+       }
+       r := cc.r
+       lastbody := cc.lastbody
+       cc.lastbody = nil
        cc.lk.Unlock()
 
        // Make sure body is fully consumed, even if user does not call body.Close
-       if cc.lastbody != nil {
+       if lastbody != nil {
                // body.Close is assumed to be idempotent and multiple calls to
                // it should return the error that its first invokation
                // returned.
-               err = cc.lastbody.Close()
-               cc.lastbody = nil
+               err = lastbody.Close()
                if err != nil {
                        cc.lk.Lock()
                        defer cc.lk.Unlock()
@@ -333,24 +349,22 @@ func (cc *ClientConn) Read(req *Request) (resp *Response, err os.Error) {
                }
        }
 
-       resp, err = ReadResponse(cc.r, req.Method)
+       resp, err = ReadResponse(r, req.Method)
+       cc.lk.Lock()
+       defer cc.lk.Unlock()
        if err != nil {
-               cc.lk.Lock()
-               defer cc.lk.Unlock()
                cc.re = err
-               return
+               return resp, err
        }
        cc.lastbody = resp.Body
 
        cc.nread++
 
        if resp.Close {
-               cc.lk.Lock()
-               defer cc.lk.Unlock()
                cc.re = ErrPersistEOF // don't send any more requests
                return resp, cc.re
        }
-       return
+       return resp, err
 }
 
 // Do is convenience method that writes a request and reads a response.