ci driver.Conn
closed bool
finalClosed bool // ci.Close has been called
- openStmt map[driver.Stmt]bool
+ openStmt map[*driverStmt]bool
// guarded by db.mu
inUse bool
dc.db.putConn(dc, err)
}
-func (dc *driverConn) removeOpenStmt(si driver.Stmt) {
+func (dc *driverConn) removeOpenStmt(ds *driverStmt) {
dc.Lock()
defer dc.Unlock()
- delete(dc.openStmt, si)
+ delete(dc.openStmt, ds)
}
func (dc *driverConn) expired(timeout time.Duration) bool {
return dc.createdAt.Add(timeout).Before(nowFunc())
}
-func (dc *driverConn) prepareLocked(ctx context.Context, query string) (driver.Stmt, error) {
+func (dc *driverConn) prepareLocked(ctx context.Context, query string) (*driverStmt, error) {
si, err := ctxDriverPrepare(ctx, dc.ci, query)
- if err == nil {
- // Track each driverConn's open statements, so we can close them
- // before closing the conn.
- //
- // TODO(bradfitz): let drivers opt out of caring about
- // stmt closes if the conn is about to close anyway? For now
- // do the safe thing, in case stmts need to be closed.
- //
- // TODO(bradfitz): after Go 1.2, closing driver.Stmts
- // should be moved to driverStmt, using unique
- // *driverStmts everywhere (including from
- // *Stmt.connStmt, instead of returning a
- // driver.Stmt), using driverStmt as a pointer
- // everywhere, and making it a finalCloser.
- if dc.openStmt == nil {
- dc.openStmt = make(map[driver.Stmt]bool)
- }
- dc.openStmt[si] = true
+ if err != nil {
+ return nil, err
}
- return si, err
+
+ // Track each driverConn's open statements, so we can close them
+ // before closing the conn.
+ //
+ // Wrap all driver.Stmt is *driverStmt to ensure they are only closed once.
+ if dc.openStmt == nil {
+ dc.openStmt = make(map[*driverStmt]bool)
+ }
+ ds := &driverStmt{Locker: dc, si: si}
+ dc.openStmt[ds] = true
+
+ return ds, nil
}
// the dc.db's Mutex is held.
func (dc *driverConn) finalClose() error {
var err error
+
+ // Each *driverStmt has a lock to the dc. Copy the list out of the dc
+ // before calling close on each stmt.
+ var openStmt []*driverStmt
withLock(dc, func() {
- defer func() { // In case si.Close panics.
- dc.openStmt = nil
- dc.finalClosed = true
- err = dc.ci.Close()
- dc.ci = nil
- }()
-
- for si := range dc.openStmt {
- si.Close()
+ openStmt = make([]*driverStmt, 0, len(dc.openStmt))
+ for ds := range dc.openStmt {
+ openStmt = append(openStmt, ds)
}
+ dc.openStmt = nil
+ })
+ for _, ds := range openStmt {
+ ds.Close()
+ }
+ withLock(dc, func() {
+ dc.finalClosed = true
+ err = dc.ci.Close()
+ dc.ci = nil
})
dc.db.mu.Lock()
type driverStmt struct {
sync.Locker // the *driverConn
si driver.Stmt
+ closed bool
+ closeErr error // return value of previous Close call
}
+// Close ensures dirver.Stmt is only closed once any always returns the same
+// result.
func (ds *driverStmt) Close() error {
ds.Lock()
defer ds.Unlock()
- return ds.si.Close()
+ if ds.closed {
+ return ds.closeErr
+ }
+ ds.closed = true
+ ds.closeErr = ds.si.Close()
+ return ds.closeErr
}
// depSet is a finalCloser's outstanding dependencies
// putConnHook is a hook for testing.
var putConnHook func(*DB, *driverConn)
-// noteUnusedDriverStatement notes that si is no longer used and should
+// noteUnusedDriverStatement notes that ds is no longer used and should
// be closed whenever possible (when c is next not in use), unless c is
// already closed.
-func (db *DB) noteUnusedDriverStatement(c *driverConn, si driver.Stmt) {
+func (db *DB) noteUnusedDriverStatement(c *driverConn, ds *driverStmt) {
db.mu.Lock()
defer db.mu.Unlock()
if c.inUse {
c.onPut = append(c.onPut, func() {
- si.Close()
+ ds.Close()
})
} else {
c.Lock()
- defer c.Unlock()
- if !c.finalClosed {
- si.Close()
+ fc := c.finalClosed
+ c.Unlock()
+ if !fc {
+ ds.Close()
}
}
}
if err != nil {
return nil, err
}
- var si driver.Stmt
+ var ds *driverStmt
withLock(dc, func() {
- si, err = dc.prepareLocked(ctx, query)
+ ds, err = dc.prepareLocked(ctx, query)
})
if err != nil {
db.putConn(dc, err)
stmt := &Stmt{
db: db,
query: query,
- css: []connStmt{{dc, si}},
+ css: []connStmt{{dc, ds}},
lastNumClosed: atomic.LoadUint64(&db.numClosed),
}
db.addDep(stmt, stmt)
if err != nil {
return nil, err
}
- defer withLock(dc, func() { si.Close() })
- return resultFromStatement(ctx, driverStmt{dc, si}, args...)
+ ds := &driverStmt{Locker: dc, si: si}
+ defer ds.Close()
+ return resultFromStatement(ctx, ds, args...)
}
// QueryContext executes a query that returns rows, typically a SELECT.
return nil, err
}
- ds := driverStmt{dc, si}
+ ds := &driverStmt{Locker: dc, si: si}
rowsi, err := rowsiFromStatement(ctx, ds, args...)
if err != nil {
- withLock(dc, func() {
- si.Close()
- })
+ ds.Close()
releaseConn(err)
return nil, err
}
dc: dc,
releaseConn: releaseConn,
rowsi: rowsi,
- closeStmt: si,
+ closeStmt: ds,
}
rows.initContextClose(ctx)
return rows, nil
stmt := &Stmt{
db: tx.db,
tx: tx,
- txsi: &driverStmt{
+ txds: &driverStmt{
Locker: dc,
si: si,
},
txs := &Stmt{
db: tx.db,
tx: tx,
- txsi: &driverStmt{
+ txds: &driverStmt{
Locker: dc,
si: si,
},
if err != nil {
return nil, err
}
- defer withLock(dc, func() { si.Close() })
+ ds := &driverStmt{Locker: dc, si: si}
+ defer ds.Close()
- return resultFromStatement(ctx, driverStmt{dc, si}, args...)
+ return resultFromStatement(ctx, ds, args...)
}
// Exec executes a query that doesn't return rows.
// connStmt is a prepared statement on a particular connection.
type connStmt struct {
dc *driverConn
- si driver.Stmt
+ ds *driverStmt
}
// Stmt is a prepared statement.
// If in a transaction, else both nil:
tx *Tx
- txsi *driverStmt
+ txds *driverStmt
mu sync.Mutex // protects the rest of the fields
closed bool
var res Result
for i := 0; i < maxBadConnRetries; i++ {
- dc, releaseConn, si, err := s.connStmt(ctx)
+ _, releaseConn, ds, err := s.connStmt(ctx)
if err != nil {
if err == driver.ErrBadConn {
continue
return nil, err
}
- res, err = resultFromStatement(ctx, driverStmt{dc, si}, args...)
+ res, err = resultFromStatement(ctx, ds, args...)
releaseConn(err)
if err != driver.ErrBadConn {
return res, err
return s.ExecContext(context.Background(), args...)
}
-func driverNumInput(ds driverStmt) int {
+func driverNumInput(ds *driverStmt) int {
ds.Lock()
defer ds.Unlock() // in case NumInput panics
return ds.si.NumInput()
}
-func resultFromStatement(ctx context.Context, ds driverStmt, args ...interface{}) (Result, error) {
+func resultFromStatement(ctx context.Context, ds *driverStmt, args ...interface{}) (Result, error) {
want := driverNumInput(ds)
// -1 means the driver doesn't know how to count the number of
return nil, fmt.Errorf("sql: expected %d arguments, got %d", want, len(args))
}
- dargs, err := driverArgs(&ds, args)
+ dargs, err := driverArgs(ds, args)
if err != nil {
return nil, err
}
// connStmt returns a free driver connection on which to execute the
// statement, a function to call to release the connection, and a
// statement bound to that connection.
-func (s *Stmt) connStmt(ctx context.Context) (ci *driverConn, releaseConn func(error), si driver.Stmt, err error) {
+func (s *Stmt) connStmt(ctx context.Context) (ci *driverConn, releaseConn func(error), ds *driverStmt, err error) {
if err = s.stickyErr; err != nil {
return
}
return
}
releaseConn = func(error) {}
- return ci, releaseConn, s.txsi.si, nil
+ return ci, releaseConn, s.txds, nil
}
s.removeClosedStmtLocked()
for _, v := range s.css {
if v.dc == dc {
s.mu.Unlock()
- return dc, dc.releaseConn, v.si, nil
+ return dc, dc.releaseConn, v.ds, nil
}
}
s.mu.Unlock()
// No luck; we need to prepare the statement on this connection
withLock(dc, func() {
- si, err = dc.prepareLocked(ctx, s.query)
+ ds, err = dc.prepareLocked(ctx, s.query)
})
if err != nil {
s.db.putConn(dc, err)
return nil, nil, nil, err
}
s.mu.Lock()
- cs := connStmt{dc, si}
+ cs := connStmt{dc, ds}
s.css = append(s.css, cs)
s.mu.Unlock()
- return dc, dc.releaseConn, si, nil
+ return dc, dc.releaseConn, ds, nil
}
// QueryContext executes a prepared query statement with the given arguments
var rowsi driver.Rows
for i := 0; i < maxBadConnRetries; i++ {
- dc, releaseConn, si, err := s.connStmt(ctx)
+ dc, releaseConn, ds, err := s.connStmt(ctx)
if err != nil {
if err == driver.ErrBadConn {
continue
return nil, err
}
- rowsi, err = rowsiFromStatement(ctx, driverStmt{dc, si}, args...)
+ rowsi, err = rowsiFromStatement(ctx, ds, args...)
if err == nil {
// Note: ownership of ci passes to the *Rows, to be freed
// with releaseConn.
return s.QueryContext(context.Background(), args...)
}
-func rowsiFromStatement(ctx context.Context, ds driverStmt, args ...interface{}) (driver.Rows, error) {
+func rowsiFromStatement(ctx context.Context, ds *driverStmt, args ...interface{}) (driver.Rows, error) {
var want int
withLock(ds, func() {
want = ds.si.NumInput()
return nil, fmt.Errorf("sql: statement expects %d inputs; got %d", want, len(args))
}
- dargs, err := driverArgs(&ds, args)
+ dargs, err := driverArgs(ds, args)
if err != nil {
return nil, err
}
return nil
}
s.closed = true
+ s.mu.Unlock()
if s.tx != nil {
- defer s.mu.Unlock()
- return s.txsi.Close()
+ return s.txds.Close()
}
- s.mu.Unlock()
return s.db.removeDep(s, s)
}
defer s.mu.Unlock()
if s.css != nil {
for _, v := range s.css {
- s.db.noteUnusedDriverStatement(v.dc, v.si)
- v.dc.removeOpenStmt(v.si)
+ s.db.noteUnusedDriverStatement(v.dc, v.ds)
+ v.dc.removeOpenStmt(v.ds)
}
s.css = nil
}
ctxClose chan struct{} // closed when Rows is closed, may be null.
lastcols []driver.Value
lasterr error // non-nil only if closed is true
- closeStmt driver.Stmt // if non-nil, statement to Close on close
+ closeStmt *driverStmt // if non-nil, statement to Close on close
}
func (rs *Rows) initContextClose(ctx context.Context) {