]> Cypherpunks repositories - gostls13.git/commitdiff
sql: add Tx.Stmt to use an existing prepared stmt in a transaction
authorBrad Fitzpatrick <bradfitz@golang.org>
Mon, 28 Nov 2011 16:00:32 +0000 (11:00 -0500)
committerBrad Fitzpatrick <bradfitz@golang.org>
Mon, 28 Nov 2011 16:00:32 +0000 (11:00 -0500)
R=rsc
CC=golang-dev
https://golang.org/cl/5433059

src/pkg/exp/sql/sql.go
src/pkg/exp/sql/sql_test.go

index c055fdd68c63648194d7165790e23f33efaf8449..f17d12eaa13e27f69cd5eec288688617b02dd3ef 100644 (file)
@@ -344,25 +344,26 @@ func (tx *Tx) Rollback() error {
        return tx.txi.Rollback()
 }
 
-// Prepare creates a prepared statement.
+// Prepare creates a prepared statement for use within a transaction.
 //
-// The statement is only valid within the scope of this transaction.
+// The returned statement operates within the transaction and can no longer
+// be used once the transaction has been committed or rolled back.
+//
+// To use an existing prepared statement on this transaction, see Tx.Stmt.
 func (tx *Tx) Prepare(query string) (*Stmt, error) {
-       // TODO(bradfitz): the restriction that the returned statement
-       // is only valid for this Transaction is lame and negates a
-       // lot of the benefit of prepared statements.  We could be
-       // more efficient here and either provide a method to take an
-       // existing Stmt (created on perhaps a different Conn), and
-       // re-create it on this Conn if necessary. Or, better: keep a
-       // map in DB of query string to Stmts, and have Stmt.Execute
-       // do the right thing and re-prepare if the Conn in use
-       // doesn't have that prepared statement.  But we'll want to
-       // avoid caching the statement in the case where we only call
-       // conn.Prepare implicitly (such as in db.Exec or tx.Exec),
-       // but the caller package can't be holding a reference to the
-       // returned statement.  Perhaps just looking at the reference
-       // count (by noting Stmt.Close) would be enough. We might also
-       // want a finalizer on Stmt to drop the reference count.
+       // TODO(bradfitz): We could be more efficient here and either
+       // provide a method to take an existing Stmt (created on
+       // perhaps a different Conn), and re-create it on this Conn if
+       // necessary. Or, better: keep a map in DB of query string to
+       // Stmts, and have Stmt.Execute do the right thing and
+       // re-prepare if the Conn in use doesn't have that prepared
+       // statement.  But we'll want to avoid caching the statement
+       // in the case where we only call conn.Prepare implicitly
+       // (such as in db.Exec or tx.Exec), but the caller package
+       // can't be holding a reference to the returned statement.
+       // Perhaps just looking at the reference count (by noting
+       // Stmt.Close) would be enough. We might also want a finalizer
+       // on Stmt to drop the reference count.
        ci, err := tx.grabConn()
        if err != nil {
                return nil, err
@@ -383,6 +384,39 @@ func (tx *Tx) Prepare(query string) (*Stmt, error) {
        return stmt, nil
 }
 
+// Stmt returns a transaction-specific prepared statement from
+// an existing statement.
+//
+// Example:
+//  updateMoney, err := db.Prepare("UPDATE balance SET money=money+? WHERE id=?")
+//  ...
+//  tx, err := db.Begin()
+//  ...
+//  res, err := tx.Stmt(updateMoney).Exec(123.45, 98293203)
+func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
+       // TODO(bradfitz): optimize this. Currently this re-prepares
+       // each time.  This is fine for now to illustrate the API but
+       // we should really cache already-prepared statements
+       // per-Conn. See also the big comment in Tx.Prepare.
+
+       if tx.db != stmt.db {
+               return &Stmt{stickyErr: errors.New("sql: Tx.Stmt: statement from different database used")}
+       }
+       ci, err := tx.grabConn()
+       if err != nil {
+               return &Stmt{stickyErr: err}
+       }
+       defer tx.releaseConn()
+       si, err := ci.Prepare(stmt.query)
+       return &Stmt{
+               db:        tx.db,
+               tx:        tx,
+               txsi:      si,
+               query:     stmt.query,
+               stickyErr: err,
+       }
+}
+
 // Exec executes a query that doesn't return rows.
 // For example: an INSERT and UPDATE.
 func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) {
@@ -448,8 +482,9 @@ type connStmt struct {
 // Stmt is a prepared statement. Stmt is safe for concurrent use by multiple goroutines.
 type Stmt struct {
        // Immutable:
-       db    *DB    // where we came from
-       query string // that created the Sttm
+       db        *DB    // where we came from
+       query     string // that created the Stmt
+       stickyErr error  // if non-nil, this error is returned for all operations
 
        // If in a transaction, else both nil:
        tx   *Tx
@@ -513,6 +548,9 @@ func (s *Stmt) Exec(args ...interface{}) (Result, error) {
 // statement, a function to call to release the connection, and a
 // statement bound to that connection.
 func (s *Stmt) connStmt() (ci driver.Conn, releaseConn func(), si driver.Stmt, err error) {
+       if s.stickyErr != nil {
+               return nil, nil, nil, s.stickyErr
+       }
        s.mu.Lock()
        if s.closed {
                s.mu.Unlock()
@@ -621,6 +659,9 @@ func (s *Stmt) QueryRow(args ...interface{}) *Row {
 
 // Close closes the statement.
 func (s *Stmt) Close() error {
+       if s.stickyErr != nil {
+               return s.stickyErr
+       }
        s.mu.Lock()
        defer s.mu.Unlock()
        if s.closed {
index 5b8bcc914227625e43d0f1af0df27a591a466d78..4f8318d26ef7972884804f2a7a6dde463e7bae0b 100644 (file)
@@ -166,7 +166,7 @@ func TestBogusPreboundParameters(t *testing.T) {
        }
 }
 
-func TestDb(t *testing.T) {
+func TestExec(t *testing.T) {
        db := newTestDB(t, "foo")
        defer closeDB(t, db)
        exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool")
@@ -206,3 +206,25 @@ func TestDb(t *testing.T) {
                }
        }
 }
+
+func TestTxStmt(t *testing.T) {
+       db := newTestDB(t, "")
+       defer closeDB(t, db)
+       exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool")
+       stmt, err := db.Prepare("INSERT|t1|name=?,age=?")
+       if err != nil {
+               t.Fatalf("Stmt, err = %v, %v", stmt, err)
+       }
+       tx, err := db.Begin()
+       if err != nil {
+               t.Fatalf("Begin = %v", err)
+       }
+       _, err = tx.Stmt(stmt).Exec("Bobby", 7)
+       if err != nil {
+               t.Fatalf("Exec = %v", err)
+       }
+       err = tx.Commit()
+       if err != nil {
+               t.Fatalf("Commit = %v", err)
+       }
+}