]> Cypherpunks repositories - gostls13.git/commitdiff
database/sql: ensure all driver Stmt are closed once
authorDaniel Theophanes <kardianos@gmail.com>
Thu, 17 Nov 2016 17:33:31 +0000 (09:33 -0800)
committerBrad Fitzpatrick <bradfitz@golang.org>
Thu, 17 Nov 2016 18:13:41 +0000 (18:13 +0000)
Previously  driver.Stmt could could be closed multiple times in
edge cases that drivers may not test for initially. Make their
job easier by ensuring the driver is only closed a single time.

Fixes #16019

Change-Id: I1e4777ef70697a849602e6ef9da73054a8feb4cd
Reviewed-on: https://go-review.googlesource.com/33352
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>

src/database/sql/sql.go
src/database/sql/sql_test.go

index a549e859a4792769caed3d9cc784b0835fb463ea..a02aa35b7be8508ec22a7043f557b552f43dbbe7 100644 (file)
@@ -330,7 +330,7 @@ type driverConn struct {
        ci          driver.Conn
        closed      bool
        finalClosed bool // ci.Close has been called
-       openStmt    map[driver.Stmt]bool
+       openStmt    map[*driverStmt]bool
 
        // guarded by db.mu
        inUse      bool
@@ -342,10 +342,10 @@ func (dc *driverConn) releaseConn(err error) {
        dc.db.putConn(dc, err)
 }
 
-func (dc *driverConn) removeOpenStmt(si driver.Stmt) {
+func (dc *driverConn) removeOpenStmt(ds *driverStmt) {
        dc.Lock()
        defer dc.Unlock()
-       delete(dc.openStmt, si)
+       delete(dc.openStmt, ds)
 }
 
 func (dc *driverConn) expired(timeout time.Duration) bool {
@@ -355,28 +355,23 @@ func (dc *driverConn) expired(timeout time.Duration) bool {
        return dc.createdAt.Add(timeout).Before(nowFunc())
 }
 
-func (dc *driverConn) prepareLocked(ctx context.Context, query string) (driver.Stmt, error) {
+func (dc *driverConn) prepareLocked(ctx context.Context, query string) (*driverStmt, error) {
        si, err := ctxDriverPrepare(ctx, dc.ci, query)
-       if err == nil {
-               // Track each driverConn's open statements, so we can close them
-               // before closing the conn.
-               //
-               // TODO(bradfitz): let drivers opt out of caring about
-               // stmt closes if the conn is about to close anyway? For now
-               // do the safe thing, in case stmts need to be closed.
-               //
-               // TODO(bradfitz): after Go 1.2, closing driver.Stmts
-               // should be moved to driverStmt, using unique
-               // *driverStmts everywhere (including from
-               // *Stmt.connStmt, instead of returning a
-               // driver.Stmt), using driverStmt as a pointer
-               // everywhere, and making it a finalCloser.
-               if dc.openStmt == nil {
-                       dc.openStmt = make(map[driver.Stmt]bool)
-               }
-               dc.openStmt[si] = true
+       if err != nil {
+               return nil, err
        }
-       return si, err
+
+       // Track each driverConn's open statements, so we can close them
+       // before closing the conn.
+       //
+       // Wrap all driver.Stmt is *driverStmt to ensure they are only closed once.
+       if dc.openStmt == nil {
+               dc.openStmt = make(map[*driverStmt]bool)
+       }
+       ds := &driverStmt{Locker: dc, si: si}
+       dc.openStmt[ds] = true
+
+       return ds, nil
 }
 
 // the dc.db's Mutex is held.
@@ -409,17 +404,24 @@ func (dc *driverConn) Close() error {
 
 func (dc *driverConn) finalClose() error {
        var err error
+
+       // Each *driverStmt has a lock to the dc. Copy the list out of the dc
+       // before calling close on each stmt.
+       var openStmt []*driverStmt
        withLock(dc, func() {
-               defer func() { // In case si.Close panics.
-                       dc.openStmt = nil
-                       dc.finalClosed = true
-                       err = dc.ci.Close()
-                       dc.ci = nil
-               }()
-
-               for si := range dc.openStmt {
-                       si.Close()
+               openStmt = make([]*driverStmt, 0, len(dc.openStmt))
+               for ds := range dc.openStmt {
+                       openStmt = append(openStmt, ds)
                }
+               dc.openStmt = nil
+       })
+       for _, ds := range openStmt {
+               ds.Close()
+       }
+       withLock(dc, func() {
+               dc.finalClosed = true
+               err = dc.ci.Close()
+               dc.ci = nil
        })
 
        dc.db.mu.Lock()
@@ -437,12 +439,21 @@ func (dc *driverConn) finalClose() error {
 type driverStmt struct {
        sync.Locker // the *driverConn
        si          driver.Stmt
+       closed      bool
+       closeErr    error // return value of previous Close call
 }
 
+// Close ensures dirver.Stmt is only closed once any always returns the same
+// result.
 func (ds *driverStmt) Close() error {
        ds.Lock()
        defer ds.Unlock()
-       return ds.si.Close()
+       if ds.closed {
+               return ds.closeErr
+       }
+       ds.closed = true
+       ds.closeErr = ds.si.Close()
+       return ds.closeErr
 }
 
 // depSet is a finalCloser's outstanding dependencies
@@ -933,21 +944,22 @@ func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn
 // putConnHook is a hook for testing.
 var putConnHook func(*DB, *driverConn)
 
-// noteUnusedDriverStatement notes that si is no longer used and should
+// noteUnusedDriverStatement notes that ds is no longer used and should
 // be closed whenever possible (when c is next not in use), unless c is
 // already closed.
-func (db *DB) noteUnusedDriverStatement(c *driverConn, si driver.Stmt) {
+func (db *DB) noteUnusedDriverStatement(c *driverConn, ds *driverStmt) {
        db.mu.Lock()
        defer db.mu.Unlock()
        if c.inUse {
                c.onPut = append(c.onPut, func() {
-                       si.Close()
+                       ds.Close()
                })
        } else {
                c.Lock()
-               defer c.Unlock()
-               if !c.finalClosed {
-                       si.Close()
+               fc := c.finalClosed
+               c.Unlock()
+               if !fc {
+                       ds.Close()
                }
        }
 }
@@ -1084,9 +1096,9 @@ func (db *DB) prepare(ctx context.Context, query string, strategy connReuseStrat
        if err != nil {
                return nil, err
        }
-       var si driver.Stmt
+       var ds *driverStmt
        withLock(dc, func() {
-               si, err = dc.prepareLocked(ctx, query)
+               ds, err = dc.prepareLocked(ctx, query)
        })
        if err != nil {
                db.putConn(dc, err)
@@ -1095,7 +1107,7 @@ func (db *DB) prepare(ctx context.Context, query string, strategy connReuseStrat
        stmt := &Stmt{
                db:            db,
                query:         query,
-               css:           []connStmt{{dc, si}},
+               css:           []connStmt{{dc, ds}},
                lastNumClosed: atomic.LoadUint64(&db.numClosed),
        }
        db.addDep(stmt, stmt)
@@ -1160,8 +1172,9 @@ func (db *DB) exec(ctx context.Context, query string, args []interface{}, strate
        if err != nil {
                return nil, err
        }
-       defer withLock(dc, func() { si.Close() })
-       return resultFromStatement(ctx, driverStmt{dc, si}, args...)
+       ds := &driverStmt{Locker: dc, si: si}
+       defer ds.Close()
+       return resultFromStatement(ctx, ds, args...)
 }
 
 // QueryContext executes a query that returns rows, typically a SELECT.
@@ -1236,12 +1249,10 @@ func (db *DB) queryConn(ctx context.Context, dc *driverConn, releaseConn func(er
                return nil, err
        }
 
-       ds := driverStmt{dc, si}
+       ds := &driverStmt{Locker: dc, si: si}
        rowsi, err := rowsiFromStatement(ctx, ds, args...)
        if err != nil {
-               withLock(dc, func() {
-                       si.Close()
-               })
+               ds.Close()
                releaseConn(err)
                return nil, err
        }
@@ -1252,7 +1263,7 @@ func (db *DB) queryConn(ctx context.Context, dc *driverConn, releaseConn func(er
                dc:          dc,
                releaseConn: releaseConn,
                rowsi:       rowsi,
-               closeStmt:   si,
+               closeStmt:   ds,
        }
        rows.initContextClose(ctx)
        return rows, nil
@@ -1476,7 +1487,7 @@ func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
        stmt := &Stmt{
                db: tx.db,
                tx: tx,
-               txsi: &driverStmt{
+               txds: &driverStmt{
                        Locker: dc,
                        si:     si,
                },
@@ -1530,7 +1541,7 @@ func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt {
        txs := &Stmt{
                db: tx.db,
                tx: tx,
-               txsi: &driverStmt{
+               txds: &driverStmt{
                        Locker: dc,
                        si:     si,
                },
@@ -1591,9 +1602,10 @@ func (tx *Tx) ExecContext(ctx context.Context, query string, args ...interface{}
        if err != nil {
                return nil, err
        }
-       defer withLock(dc, func() { si.Close() })
+       ds := &driverStmt{Locker: dc, si: si}
+       defer ds.Close()
 
-       return resultFromStatement(ctx, driverStmt{dc, si}, args...)
+       return resultFromStatement(ctx, ds, args...)
 }
 
 // Exec executes a query that doesn't return rows.
@@ -1635,7 +1647,7 @@ func (tx *Tx) QueryRow(query string, args ...interface{}) *Row {
 // connStmt is a prepared statement on a particular connection.
 type connStmt struct {
        dc *driverConn
-       si driver.Stmt
+       ds *driverStmt
 }
 
 // Stmt is a prepared statement.
@@ -1650,7 +1662,7 @@ type Stmt struct {
 
        // If in a transaction, else both nil:
        tx   *Tx
-       txsi *driverStmt
+       txds *driverStmt
 
        mu     sync.Mutex // protects the rest of the fields
        closed bool
@@ -1674,7 +1686,7 @@ func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (Result, er
 
        var res Result
        for i := 0; i < maxBadConnRetries; i++ {
-               dc, releaseConn, si, err := s.connStmt(ctx)
+               _, releaseConn, ds, err := s.connStmt(ctx)
                if err != nil {
                        if err == driver.ErrBadConn {
                                continue
@@ -1682,7 +1694,7 @@ func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (Result, er
                        return nil, err
                }
 
-               res, err = resultFromStatement(ctx, driverStmt{dc, si}, args...)
+               res, err = resultFromStatement(ctx, ds, args...)
                releaseConn(err)
                if err != driver.ErrBadConn {
                        return res, err
@@ -1697,13 +1709,13 @@ func (s *Stmt) Exec(args ...interface{}) (Result, error) {
        return s.ExecContext(context.Background(), args...)
 }
 
-func driverNumInput(ds driverStmt) int {
+func driverNumInput(ds *driverStmt) int {
        ds.Lock()
        defer ds.Unlock() // in case NumInput panics
        return ds.si.NumInput()
 }
 
-func resultFromStatement(ctx context.Context, ds driverStmt, args ...interface{}) (Result, error) {
+func resultFromStatement(ctx context.Context, ds *driverStmt, args ...interface{}) (Result, error) {
        want := driverNumInput(ds)
 
        // -1 means the driver doesn't know how to count the number of
@@ -1713,7 +1725,7 @@ func resultFromStatement(ctx context.Context, ds driverStmt, args ...interface{}
                return nil, fmt.Errorf("sql: expected %d arguments, got %d", want, len(args))
        }
 
-       dargs, err := driverArgs(&ds, args)
+       dargs, err := driverArgs(ds, args)
        if err != nil {
                return nil, err
        }
@@ -1757,7 +1769,7 @@ func (s *Stmt) removeClosedStmtLocked() {
 // connStmt returns a free driver connection on which to execute the
 // statement, a function to call to release the connection, and a
 // statement bound to that connection.
-func (s *Stmt) connStmt(ctx context.Context) (ci *driverConn, releaseConn func(error), si driver.Stmt, err error) {
+func (s *Stmt) connStmt(ctx context.Context) (ci *driverConn, releaseConn func(error), ds *driverStmt, err error) {
        if err = s.stickyErr; err != nil {
                return
        }
@@ -1777,7 +1789,7 @@ func (s *Stmt) connStmt(ctx context.Context) (ci *driverConn, releaseConn func(e
                        return
                }
                releaseConn = func(error) {}
-               return ci, releaseConn, s.txsi.si, nil
+               return ci, releaseConn, s.txds, nil
        }
 
        s.removeClosedStmtLocked()
@@ -1792,25 +1804,25 @@ func (s *Stmt) connStmt(ctx context.Context) (ci *driverConn, releaseConn func(e
        for _, v := range s.css {
                if v.dc == dc {
                        s.mu.Unlock()
-                       return dc, dc.releaseConn, v.si, nil
+                       return dc, dc.releaseConn, v.ds, nil
                }
        }
        s.mu.Unlock()
 
        // No luck; we need to prepare the statement on this connection
        withLock(dc, func() {
-               si, err = dc.prepareLocked(ctx, s.query)
+               ds, err = dc.prepareLocked(ctx, s.query)
        })
        if err != nil {
                s.db.putConn(dc, err)
                return nil, nil, nil, err
        }
        s.mu.Lock()
-       cs := connStmt{dc, si}
+       cs := connStmt{dc, ds}
        s.css = append(s.css, cs)
        s.mu.Unlock()
 
-       return dc, dc.releaseConn, si, nil
+       return dc, dc.releaseConn, ds, nil
 }
 
 // QueryContext executes a prepared query statement with the given arguments
@@ -1821,7 +1833,7 @@ func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, er
 
        var rowsi driver.Rows
        for i := 0; i < maxBadConnRetries; i++ {
-               dc, releaseConn, si, err := s.connStmt(ctx)
+               dc, releaseConn, ds, err := s.connStmt(ctx)
                if err != nil {
                        if err == driver.ErrBadConn {
                                continue
@@ -1829,7 +1841,7 @@ func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, er
                        return nil, err
                }
 
-               rowsi, err = rowsiFromStatement(ctx, driverStmt{dc, si}, args...)
+               rowsi, err = rowsiFromStatement(ctx, ds, args...)
                if err == nil {
                        // Note: ownership of ci passes to the *Rows, to be freed
                        // with releaseConn.
@@ -1861,7 +1873,7 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
        return s.QueryContext(context.Background(), args...)
 }
 
-func rowsiFromStatement(ctx context.Context, ds driverStmt, args ...interface{}) (driver.Rows, error) {
+func rowsiFromStatement(ctx context.Context, ds *driverStmt, args ...interface{}) (driver.Rows, error) {
        var want int
        withLock(ds, func() {
                want = ds.si.NumInput()
@@ -1874,7 +1886,7 @@ func rowsiFromStatement(ctx context.Context, ds driverStmt, args ...interface{})
                return nil, fmt.Errorf("sql: statement expects %d inputs; got %d", want, len(args))
        }
 
-       dargs, err := driverArgs(&ds, args)
+       dargs, err := driverArgs(ds, args)
        if err != nil {
                return nil, err
        }
@@ -1937,12 +1949,11 @@ func (s *Stmt) Close() error {
                return nil
        }
        s.closed = true
+       s.mu.Unlock()
 
        if s.tx != nil {
-               defer s.mu.Unlock()
-               return s.txsi.Close()
+               return s.txds.Close()
        }
-       s.mu.Unlock()
 
        return s.db.removeDep(s, s)
 }
@@ -1952,8 +1963,8 @@ func (s *Stmt) finalClose() error {
        defer s.mu.Unlock()
        if s.css != nil {
                for _, v := range s.css {
-                       s.db.noteUnusedDriverStatement(v.dc, v.si)
-                       v.dc.removeOpenStmt(v.si)
+                       s.db.noteUnusedDriverStatement(v.dc, v.ds)
+                       v.dc.removeOpenStmt(v.ds)
                }
                s.css = nil
        }
@@ -1985,7 +1996,7 @@ type Rows struct {
        ctxClose  chan struct{} // closed when Rows is closed, may be null.
        lastcols  []driver.Value
        lasterr   error       // non-nil only if closed is true
-       closeStmt driver.Stmt // if non-nil, statement to Close on close
+       closeStmt *driverStmt // if non-nil, statement to Close on close
 }
 
 func (rs *Rows) initContextClose(ctx context.Context) {
index ea86264ae6884261cabe297487747fee3ac5a567..c46aaf60f85c94d90bddad93bfe9b04e7a61b05a 100644 (file)
@@ -672,7 +672,7 @@ func TestStatementClose(t *testing.T) {
                msg  string
        }{
                {&Stmt{stickyErr: want}, "stickyErr not propagated"},
-               {&Stmt{tx: &Tx{}, txsi: &driverStmt{&sync.Mutex{}, stubDriverStmt{want}}}, "driverStmt.Close() error not propagated"},
+               {&Stmt{tx: &Tx{}, txds: &driverStmt{Locker: &sync.Mutex{}, si: stubDriverStmt{want}}}, "driverStmt.Close() error not propagated"},
        }
        for _, test := range tests {
                if err := test.stmt.Close(); err != want {