tx.closemu.RLock()
defer tx.closemu.RUnlock()
- // TODO(bradfitz): We could be more efficient here and either
- // provide a method to take an existing Stmt (created on
- // perhaps a different Conn), and re-create it on this Conn if
- // necessary. Or, better: keep a map in DB of query string to
- // Stmts, and have Stmt.Execute do the right thing and
- // re-prepare if the Conn in use doesn't have that prepared
- // statement. But we'll want to avoid caching the statement
- // in the case where we only call conn.Prepare implicitly
- // (such as in db.Exec or tx.Exec), but the caller package
- // can't be holding a reference to the returned statement.
- // Perhaps just looking at the reference count (by noting
- // Stmt.Close) would be enough. We might also want a finalizer
- // on Stmt to drop the reference count.
dc, err := tx.grabConn(ctx)
if err != nil {
return nil, err
tx.closemu.RLock()
defer tx.closemu.RUnlock()
- // TODO(bradfitz): optimize this. Currently this re-prepares
- // each time. This is fine for now to illustrate the API but
- // we should really cache already-prepared statements
- // per-Conn. See also the big comment in Tx.Prepare.
-
if tx.db != stmt.db {
return &Stmt{stickyErr: errors.New("sql: Tx.Stmt: statement from different database used")}
}
return &Stmt{stickyErr: err}
}
var si driver.Stmt
- withLock(dc, func() {
- si, err = ctxDriverPrepare(ctx, dc.ci, stmt.query)
- })
+ var parentStmt *Stmt
+ stmt.mu.Lock()
+ if stmt.closed || stmt.tx != nil {
+ // If the statement has been closed or already belongs to a
+ // transaction, we can't reuse it in this connection.
+ // Since tx.StmtContext should never need to be called with a
+ // Stmt already belonging to tx, we ignore this edge case and
+ // re-prepare the statement in this case. No need to add
+ // code-complexity for this.
+ stmt.mu.Unlock()
+ withLock(dc, func() {
+ si, err = ctxDriverPrepare(ctx, dc.ci, stmt.query)
+ })
+ if err != nil {
+ return &Stmt{stickyErr: err}
+ }
+ } else {
+ stmt.removeClosedStmtLocked()
+ // See if the statement has already been prepared on this connection,
+ // and reuse it if possible.
+ for _, v := range stmt.css {
+ if v.dc == dc {
+ si = v.ds.si
+ break
+ }
+ }
+
+ stmt.mu.Unlock()
+
+ if si == nil {
+ cs, err := stmt.prepareOnConnLocked(ctx, dc)
+ if err != nil {
+ return &Stmt{stickyErr: err}
+ }
+ si = cs.si
+ }
+ parentStmt = stmt
+ }
+
txs := &Stmt{
db: tx.db,
tx: tx,
Locker: dc,
si: si,
},
- query: stmt.query,
- stickyErr: err,
+ parentStmt: parentStmt,
+ query: stmt.query,
+ }
+ if parentStmt != nil {
+ tx.db.addDep(parentStmt, txs)
}
tx.stmts.Lock()
tx.stmts.v = append(tx.stmts.v, txs)
tx *Tx
txds *driverStmt
+ // parentStmt is set when a transaction-specific statement
+ // is requested from an identical statement prepared on the same
+ // conn. parentStmt is used to track the dependency of this statement
+ // on its originating ("parent") statement so that parentStmt may
+ // be closed by the user without them having to know whether or not
+ // any transactions are still using it.
+ parentStmt *Stmt
+
mu sync.Mutex // protects the rest of the fields
closed bool
// css is a list of underlying driver statement interfaces
// that are valid on particular connections. This is only
// used if tx == nil and one is found that has idle
- // connections. If tx != nil, txsi is always used.
+ // connections. If tx != nil, txds is always used.
css []connStmt
// lastNumClosed is copied from db.numClosed when Stmt is created
// No luck; we need to prepare the statement on this connection
withLock(dc, func() {
- ds, err = dc.prepareLocked(ctx, s.query)
+ ds, err = s.prepareOnConnLocked(ctx, dc)
})
if err != nil {
s.db.putConn(dc, err)
return nil, nil, nil, err
}
+
+ return dc, dc.releaseConn, ds, nil
+}
+
+// prepareOnConnLocked prepares the query in Stmt s on dc and adds it to the list of
+// open connStmt on the statement. It assumes the caller is holding the lock on dc.
+func (s *Stmt) prepareOnConnLocked(ctx context.Context, dc *driverConn) (*driverStmt, error) {
+ si, err := dc.prepareLocked(ctx, s.query)
+ if err != nil {
+ return nil, err
+ }
+ cs := connStmt{dc, si}
s.mu.Lock()
- cs := connStmt{dc, ds}
s.css = append(s.css, cs)
s.mu.Unlock()
-
- return dc, dc.releaseConn, ds, nil
+ return cs.ds, nil
}
// QueryContext executes a prepared query statement with the given arguments
s.closed = true
s.mu.Unlock()
- if s.tx != nil {
- return s.txds.Close()
+ if s.tx == nil {
+ return s.db.removeDep(s, s)
}
- return s.db.removeDep(s, s)
+ if s.parentStmt != nil {
+ // If parentStmt is set, we must not close s.txds since it's stored
+ // in the css array of the parentStmt.
+ return s.db.removeDep(s.parentStmt, s)
+ }
+ return s.txds.Close()
}
func (s *Stmt) finalClose() error {
}
}
+func TestTxStmtPreparedOnce(t *testing.T) {
+ db := newTestDB(t, "")
+ defer closeDB(t, db)
+ exec(t, db, "CREATE|t1|name=string,age=int32")
+
+ prepares0 := numPrepares(t, db)
+
+ // db.Prepare increments numPrepares.
+ stmt, err := db.Prepare("INSERT|t1|name=?,age=?")
+ if err != nil {
+ t.Fatalf("Stmt, err = %v, %v", stmt, err)
+ }
+ defer stmt.Close()
+
+ tx, err := db.Begin()
+ if err != nil {
+ t.Fatalf("Begin = %v", err)
+ }
+
+ txs1 := tx.Stmt(stmt)
+ txs2 := tx.Stmt(stmt)
+
+ _, err = txs1.Exec("Go", 7)
+ if err != nil {
+ t.Fatalf("Exec = %v", err)
+ }
+ txs1.Close()
+
+ _, err = txs2.Exec("Gopher", 8)
+ if err != nil {
+ t.Fatalf("Exec = %v", err)
+ }
+ txs2.Close()
+
+ err = tx.Commit()
+ if err != nil {
+ t.Fatalf("Commit = %v", err)
+ }
+
+ if prepares := numPrepares(t, db) - prepares0; prepares != 1 {
+ t.Errorf("executed %d Prepare statements; want 1", prepares)
+ }
+}
+
+func TestTxStmtClosedRePrepares(t *testing.T) {
+ db := newTestDB(t, "")
+ defer closeDB(t, db)
+ exec(t, db, "CREATE|t1|name=string,age=int32")
+
+ prepares0 := numPrepares(t, db)
+
+ // db.Prepare increments numPrepares.
+ stmt, err := db.Prepare("INSERT|t1|name=?,age=?")
+ if err != nil {
+ t.Fatalf("Stmt, err = %v, %v", stmt, err)
+ }
+ tx, err := db.Begin()
+ if err != nil {
+ t.Fatalf("Begin = %v", err)
+ }
+ err = stmt.Close()
+ if err != nil {
+ t.Fatalf("stmt.Close() = %v", err)
+ }
+ // tx.Stmt increments numPrepares because stmt is closed.
+ txs := tx.Stmt(stmt)
+ if txs.stickyErr != nil {
+ t.Fatal(txs.stickyErr)
+ }
+ if txs.parentStmt != nil {
+ t.Fatal("expected nil parentStmt")
+ }
+ _, err = txs.Exec(`Eric`, 82)
+ if err != nil {
+ t.Fatalf("txs.Exec = %v", err)
+ }
+
+ err = txs.Close()
+ if err != nil {
+ t.Fatalf("txs.Close = %v", err)
+ }
+
+ tx.Rollback()
+
+ if prepares := numPrepares(t, db) - prepares0; prepares != 2 {
+ t.Errorf("executed %d Prepare statements; want 2", prepares)
+ }
+}
+
+func TestParentStmtOutlivesTxStmt(t *testing.T) {
+ db := newTestDB(t, "")
+ defer closeDB(t, db)
+ exec(t, db, "CREATE|t1|name=string,age=int32")
+
+ // Make sure everything happens on the same connection.
+ db.SetMaxOpenConns(1)
+
+ prepares0 := numPrepares(t, db)
+
+ // db.Prepare increments numPrepares.
+ stmt, err := db.Prepare("INSERT|t1|name=?,age=?")
+ if err != nil {
+ t.Fatalf("Stmt, err = %v, %v", stmt, err)
+ }
+ defer stmt.Close()
+ tx, err := db.Begin()
+ if err != nil {
+ t.Fatalf("Begin = %v", err)
+ }
+ txs := tx.Stmt(stmt)
+ if len(stmt.css) != 1 {
+ t.Fatalf("len(stmt.css) = %v; want 1", len(stmt.css))
+ }
+ err = txs.Close()
+ if err != nil {
+ t.Fatalf("txs.Close() = %v", err)
+ }
+ err = tx.Rollback()
+ if err != nil {
+ t.Fatalf("tx.Rollback() = %v", err)
+ }
+ // txs must not be valid.
+ _, err = txs.Exec("Suzan", 30)
+ if err == nil {
+ t.Fatalf("txs.Exec(), expected err")
+ }
+ // Stmt must still be valid.
+ _, err = stmt.Exec("Janina", 25)
+ if err != nil {
+ t.Fatalf("stmt.Exec() = %v", err)
+ }
+
+ if prepares := numPrepares(t, db) - prepares0; prepares != 1 {
+ t.Errorf("executed %d Prepare statements; want 1", prepares)
+ }
+}
+
+// Test that tx.Stmt called with a statment already
+// associated with tx as argument re-prepares the same
+// statement again.
+func TestTxStmtFromTxStmtRePrepares(t *testing.T) {
+ db := newTestDB(t, "")
+ defer closeDB(t, db)
+ exec(t, db, "CREATE|t1|name=string,age=int32")
+ prepares0 := numPrepares(t, db)
+ // db.Prepare increments numPrepares.
+ stmt, err := db.Prepare("INSERT|t1|name=?,age=?")
+ if err != nil {
+ t.Fatalf("Stmt, err = %v, %v", stmt, err)
+ }
+ defer stmt.Close()
+
+ tx, err := db.Begin()
+ if err != nil {
+ t.Fatalf("Begin = %v", err)
+ }
+ txs1 := tx.Stmt(stmt)
+
+ // tx.Stmt(txs1) increments numPrepares because txs1 already
+ // belongs to a transaction (albeit the same transaction).
+ txs2 := tx.Stmt(txs1)
+ if txs2.stickyErr != nil {
+ t.Fatal(txs2.stickyErr)
+ }
+ if txs2.parentStmt != nil {
+ t.Fatal("expected nil parentStmt")
+ }
+ _, err = txs2.Exec(`Eric`, 82)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ err = txs1.Close()
+ if err != nil {
+ t.Fatalf("txs1.Close = %v", err)
+ }
+ err = txs2.Close()
+ if err != nil {
+ t.Fatalf("txs1.Close = %v", err)
+ }
+ err = tx.Rollback()
+ if err != nil {
+ t.Fatalf("tx.Rollback = %v", err)
+ }
+
+ if prepares := numPrepares(t, db) - prepares0; prepares != 2 {
+ t.Errorf("executed %d Prepare statements; want 2", prepares)
+ }
+}
+
// Issue: https://golang.org/issue/2784
// This test didn't fail before because we got lucky with the fakedb driver.
// It was failing, and now not, in github.com/bradfitz/go-sql-test