From: Daniel Theophanes Date: Thu, 17 Nov 2016 17:33:31 +0000 (-0800) Subject: database/sql: ensure all driver Stmt are closed once X-Git-Tag: go1.8beta1~114 X-Git-Url: http://www.git.cypherpunks.su/?a=commitdiff_plain;h=90b8a0ca2d0b565c7c7199ffcf77b15ea6b6db3a;p=gostls13.git database/sql: ensure all driver Stmt are closed once Previously driver.Stmt could could be closed multiple times in edge cases that drivers may not test for initially. Make their job easier by ensuring the driver is only closed a single time. Fixes #16019 Change-Id: I1e4777ef70697a849602e6ef9da73054a8feb4cd Reviewed-on: https://go-review.googlesource.com/33352 Reviewed-by: Brad Fitzpatrick Run-TryBot: Brad Fitzpatrick TryBot-Result: Gobot Gobot --- diff --git a/src/database/sql/sql.go b/src/database/sql/sql.go index a549e859a4..a02aa35b7b 100644 --- a/src/database/sql/sql.go +++ b/src/database/sql/sql.go @@ -330,7 +330,7 @@ type driverConn struct { ci driver.Conn closed bool finalClosed bool // ci.Close has been called - openStmt map[driver.Stmt]bool + openStmt map[*driverStmt]bool // guarded by db.mu inUse bool @@ -342,10 +342,10 @@ func (dc *driverConn) releaseConn(err error) { dc.db.putConn(dc, err) } -func (dc *driverConn) removeOpenStmt(si driver.Stmt) { +func (dc *driverConn) removeOpenStmt(ds *driverStmt) { dc.Lock() defer dc.Unlock() - delete(dc.openStmt, si) + delete(dc.openStmt, ds) } func (dc *driverConn) expired(timeout time.Duration) bool { @@ -355,28 +355,23 @@ func (dc *driverConn) expired(timeout time.Duration) bool { return dc.createdAt.Add(timeout).Before(nowFunc()) } -func (dc *driverConn) prepareLocked(ctx context.Context, query string) (driver.Stmt, error) { +func (dc *driverConn) prepareLocked(ctx context.Context, query string) (*driverStmt, error) { si, err := ctxDriverPrepare(ctx, dc.ci, query) - if err == nil { - // Track each driverConn's open statements, so we can close them - // before closing the conn. - // - // TODO(bradfitz): let drivers opt out of caring about - // stmt closes if the conn is about to close anyway? For now - // do the safe thing, in case stmts need to be closed. - // - // TODO(bradfitz): after Go 1.2, closing driver.Stmts - // should be moved to driverStmt, using unique - // *driverStmts everywhere (including from - // *Stmt.connStmt, instead of returning a - // driver.Stmt), using driverStmt as a pointer - // everywhere, and making it a finalCloser. - if dc.openStmt == nil { - dc.openStmt = make(map[driver.Stmt]bool) - } - dc.openStmt[si] = true + if err != nil { + return nil, err } - return si, err + + // Track each driverConn's open statements, so we can close them + // before closing the conn. + // + // Wrap all driver.Stmt is *driverStmt to ensure they are only closed once. + if dc.openStmt == nil { + dc.openStmt = make(map[*driverStmt]bool) + } + ds := &driverStmt{Locker: dc, si: si} + dc.openStmt[ds] = true + + return ds, nil } // the dc.db's Mutex is held. @@ -409,17 +404,24 @@ func (dc *driverConn) Close() error { func (dc *driverConn) finalClose() error { var err error + + // Each *driverStmt has a lock to the dc. Copy the list out of the dc + // before calling close on each stmt. + var openStmt []*driverStmt withLock(dc, func() { - defer func() { // In case si.Close panics. - dc.openStmt = nil - dc.finalClosed = true - err = dc.ci.Close() - dc.ci = nil - }() - - for si := range dc.openStmt { - si.Close() + openStmt = make([]*driverStmt, 0, len(dc.openStmt)) + for ds := range dc.openStmt { + openStmt = append(openStmt, ds) } + dc.openStmt = nil + }) + for _, ds := range openStmt { + ds.Close() + } + withLock(dc, func() { + dc.finalClosed = true + err = dc.ci.Close() + dc.ci = nil }) dc.db.mu.Lock() @@ -437,12 +439,21 @@ func (dc *driverConn) finalClose() error { type driverStmt struct { sync.Locker // the *driverConn si driver.Stmt + closed bool + closeErr error // return value of previous Close call } +// Close ensures dirver.Stmt is only closed once any always returns the same +// result. func (ds *driverStmt) Close() error { ds.Lock() defer ds.Unlock() - return ds.si.Close() + if ds.closed { + return ds.closeErr + } + ds.closed = true + ds.closeErr = ds.si.Close() + return ds.closeErr } // depSet is a finalCloser's outstanding dependencies @@ -933,21 +944,22 @@ func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn // putConnHook is a hook for testing. var putConnHook func(*DB, *driverConn) -// noteUnusedDriverStatement notes that si is no longer used and should +// noteUnusedDriverStatement notes that ds is no longer used and should // be closed whenever possible (when c is next not in use), unless c is // already closed. -func (db *DB) noteUnusedDriverStatement(c *driverConn, si driver.Stmt) { +func (db *DB) noteUnusedDriverStatement(c *driverConn, ds *driverStmt) { db.mu.Lock() defer db.mu.Unlock() if c.inUse { c.onPut = append(c.onPut, func() { - si.Close() + ds.Close() }) } else { c.Lock() - defer c.Unlock() - if !c.finalClosed { - si.Close() + fc := c.finalClosed + c.Unlock() + if !fc { + ds.Close() } } } @@ -1084,9 +1096,9 @@ func (db *DB) prepare(ctx context.Context, query string, strategy connReuseStrat if err != nil { return nil, err } - var si driver.Stmt + var ds *driverStmt withLock(dc, func() { - si, err = dc.prepareLocked(ctx, query) + ds, err = dc.prepareLocked(ctx, query) }) if err != nil { db.putConn(dc, err) @@ -1095,7 +1107,7 @@ func (db *DB) prepare(ctx context.Context, query string, strategy connReuseStrat stmt := &Stmt{ db: db, query: query, - css: []connStmt{{dc, si}}, + css: []connStmt{{dc, ds}}, lastNumClosed: atomic.LoadUint64(&db.numClosed), } db.addDep(stmt, stmt) @@ -1160,8 +1172,9 @@ func (db *DB) exec(ctx context.Context, query string, args []interface{}, strate if err != nil { return nil, err } - defer withLock(dc, func() { si.Close() }) - return resultFromStatement(ctx, driverStmt{dc, si}, args...) + ds := &driverStmt{Locker: dc, si: si} + defer ds.Close() + return resultFromStatement(ctx, ds, args...) } // QueryContext executes a query that returns rows, typically a SELECT. @@ -1236,12 +1249,10 @@ func (db *DB) queryConn(ctx context.Context, dc *driverConn, releaseConn func(er return nil, err } - ds := driverStmt{dc, si} + ds := &driverStmt{Locker: dc, si: si} rowsi, err := rowsiFromStatement(ctx, ds, args...) if err != nil { - withLock(dc, func() { - si.Close() - }) + ds.Close() releaseConn(err) return nil, err } @@ -1252,7 +1263,7 @@ func (db *DB) queryConn(ctx context.Context, dc *driverConn, releaseConn func(er dc: dc, releaseConn: releaseConn, rowsi: rowsi, - closeStmt: si, + closeStmt: ds, } rows.initContextClose(ctx) return rows, nil @@ -1476,7 +1487,7 @@ func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) { stmt := &Stmt{ db: tx.db, tx: tx, - txsi: &driverStmt{ + txds: &driverStmt{ Locker: dc, si: si, }, @@ -1530,7 +1541,7 @@ func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt { txs := &Stmt{ db: tx.db, tx: tx, - txsi: &driverStmt{ + txds: &driverStmt{ Locker: dc, si: si, }, @@ -1591,9 +1602,10 @@ func (tx *Tx) ExecContext(ctx context.Context, query string, args ...interface{} if err != nil { return nil, err } - defer withLock(dc, func() { si.Close() }) + ds := &driverStmt{Locker: dc, si: si} + defer ds.Close() - return resultFromStatement(ctx, driverStmt{dc, si}, args...) + return resultFromStatement(ctx, ds, args...) } // Exec executes a query that doesn't return rows. @@ -1635,7 +1647,7 @@ func (tx *Tx) QueryRow(query string, args ...interface{}) *Row { // connStmt is a prepared statement on a particular connection. type connStmt struct { dc *driverConn - si driver.Stmt + ds *driverStmt } // Stmt is a prepared statement. @@ -1650,7 +1662,7 @@ type Stmt struct { // If in a transaction, else both nil: tx *Tx - txsi *driverStmt + txds *driverStmt mu sync.Mutex // protects the rest of the fields closed bool @@ -1674,7 +1686,7 @@ func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (Result, er var res Result for i := 0; i < maxBadConnRetries; i++ { - dc, releaseConn, si, err := s.connStmt(ctx) + _, releaseConn, ds, err := s.connStmt(ctx) if err != nil { if err == driver.ErrBadConn { continue @@ -1682,7 +1694,7 @@ func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (Result, er return nil, err } - res, err = resultFromStatement(ctx, driverStmt{dc, si}, args...) + res, err = resultFromStatement(ctx, ds, args...) releaseConn(err) if err != driver.ErrBadConn { return res, err @@ -1697,13 +1709,13 @@ func (s *Stmt) Exec(args ...interface{}) (Result, error) { return s.ExecContext(context.Background(), args...) } -func driverNumInput(ds driverStmt) int { +func driverNumInput(ds *driverStmt) int { ds.Lock() defer ds.Unlock() // in case NumInput panics return ds.si.NumInput() } -func resultFromStatement(ctx context.Context, ds driverStmt, args ...interface{}) (Result, error) { +func resultFromStatement(ctx context.Context, ds *driverStmt, args ...interface{}) (Result, error) { want := driverNumInput(ds) // -1 means the driver doesn't know how to count the number of @@ -1713,7 +1725,7 @@ func resultFromStatement(ctx context.Context, ds driverStmt, args ...interface{} return nil, fmt.Errorf("sql: expected %d arguments, got %d", want, len(args)) } - dargs, err := driverArgs(&ds, args) + dargs, err := driverArgs(ds, args) if err != nil { return nil, err } @@ -1757,7 +1769,7 @@ func (s *Stmt) removeClosedStmtLocked() { // connStmt returns a free driver connection on which to execute the // statement, a function to call to release the connection, and a // statement bound to that connection. -func (s *Stmt) connStmt(ctx context.Context) (ci *driverConn, releaseConn func(error), si driver.Stmt, err error) { +func (s *Stmt) connStmt(ctx context.Context) (ci *driverConn, releaseConn func(error), ds *driverStmt, err error) { if err = s.stickyErr; err != nil { return } @@ -1777,7 +1789,7 @@ func (s *Stmt) connStmt(ctx context.Context) (ci *driverConn, releaseConn func(e return } releaseConn = func(error) {} - return ci, releaseConn, s.txsi.si, nil + return ci, releaseConn, s.txds, nil } s.removeClosedStmtLocked() @@ -1792,25 +1804,25 @@ func (s *Stmt) connStmt(ctx context.Context) (ci *driverConn, releaseConn func(e for _, v := range s.css { if v.dc == dc { s.mu.Unlock() - return dc, dc.releaseConn, v.si, nil + return dc, dc.releaseConn, v.ds, nil } } s.mu.Unlock() // No luck; we need to prepare the statement on this connection withLock(dc, func() { - si, err = dc.prepareLocked(ctx, s.query) + ds, err = dc.prepareLocked(ctx, s.query) }) if err != nil { s.db.putConn(dc, err) return nil, nil, nil, err } s.mu.Lock() - cs := connStmt{dc, si} + cs := connStmt{dc, ds} s.css = append(s.css, cs) s.mu.Unlock() - return dc, dc.releaseConn, si, nil + return dc, dc.releaseConn, ds, nil } // QueryContext executes a prepared query statement with the given arguments @@ -1821,7 +1833,7 @@ func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, er var rowsi driver.Rows for i := 0; i < maxBadConnRetries; i++ { - dc, releaseConn, si, err := s.connStmt(ctx) + dc, releaseConn, ds, err := s.connStmt(ctx) if err != nil { if err == driver.ErrBadConn { continue @@ -1829,7 +1841,7 @@ func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, er return nil, err } - rowsi, err = rowsiFromStatement(ctx, driverStmt{dc, si}, args...) + rowsi, err = rowsiFromStatement(ctx, ds, args...) if err == nil { // Note: ownership of ci passes to the *Rows, to be freed // with releaseConn. @@ -1861,7 +1873,7 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) { return s.QueryContext(context.Background(), args...) } -func rowsiFromStatement(ctx context.Context, ds driverStmt, args ...interface{}) (driver.Rows, error) { +func rowsiFromStatement(ctx context.Context, ds *driverStmt, args ...interface{}) (driver.Rows, error) { var want int withLock(ds, func() { want = ds.si.NumInput() @@ -1874,7 +1886,7 @@ func rowsiFromStatement(ctx context.Context, ds driverStmt, args ...interface{}) return nil, fmt.Errorf("sql: statement expects %d inputs; got %d", want, len(args)) } - dargs, err := driverArgs(&ds, args) + dargs, err := driverArgs(ds, args) if err != nil { return nil, err } @@ -1937,12 +1949,11 @@ func (s *Stmt) Close() error { return nil } s.closed = true + s.mu.Unlock() if s.tx != nil { - defer s.mu.Unlock() - return s.txsi.Close() + return s.txds.Close() } - s.mu.Unlock() return s.db.removeDep(s, s) } @@ -1952,8 +1963,8 @@ func (s *Stmt) finalClose() error { defer s.mu.Unlock() if s.css != nil { for _, v := range s.css { - s.db.noteUnusedDriverStatement(v.dc, v.si) - v.dc.removeOpenStmt(v.si) + s.db.noteUnusedDriverStatement(v.dc, v.ds) + v.dc.removeOpenStmt(v.ds) } s.css = nil } @@ -1985,7 +1996,7 @@ type Rows struct { ctxClose chan struct{} // closed when Rows is closed, may be null. lastcols []driver.Value lasterr error // non-nil only if closed is true - closeStmt driver.Stmt // if non-nil, statement to Close on close + closeStmt *driverStmt // if non-nil, statement to Close on close } func (rs *Rows) initContextClose(ctx context.Context) { diff --git a/src/database/sql/sql_test.go b/src/database/sql/sql_test.go index ea86264ae6..c46aaf60f8 100644 --- a/src/database/sql/sql_test.go +++ b/src/database/sql/sql_test.go @@ -672,7 +672,7 @@ func TestStatementClose(t *testing.T) { msg string }{ {&Stmt{stickyErr: want}, "stickyErr not propagated"}, - {&Stmt{tx: &Tx{}, txsi: &driverStmt{&sync.Mutex{}, stubDriverStmt{want}}}, "driverStmt.Close() error not propagated"}, + {&Stmt{tx: &Tx{}, txds: &driverStmt{Locker: &sync.Mutex{}, si: stubDriverStmt{want}}}, "driverStmt.Close() error not propagated"}, } for _, test := range tests { if err := test.stmt.Close(); err != want {