]> Cypherpunks repositories - gostls13.git/commitdiff
database/sql: Close per-tx prepared statements when the associated tx ends
authorMarko Tiikkaja <marko@joh.to>
Mon, 22 Sep 2014 13:19:27 +0000 (09:19 -0400)
committerBrad Fitzpatrick <bradfitz@golang.org>
Mon, 22 Sep 2014 13:19:27 +0000 (09:19 -0400)
LGTM=bradfitz
R=golang-codereviews, bradfitz, mattn.jp
CC=golang-codereviews
https://golang.org/cl/131650043

src/database/sql/sql.go
src/database/sql/sql_test.go

index 90f813d8230ae06bf5e6a6f92cf19b0d43442381..731b7a7f797e749362b4f9f323c7e86e715a897d 100644 (file)
@@ -1043,6 +1043,13 @@ type Tx struct {
        // or Rollback. once done, all operations fail with
        // ErrTxDone.
        done bool
+
+       // All Stmts prepared for this transaction.  These will be closed after the
+       // transaction has been committed or rolled back.
+       stmts struct {
+               sync.Mutex
+               v []*Stmt
+       }
 }
 
 var ErrTxDone = errors.New("sql: Transaction has already been committed or rolled back")
@@ -1064,6 +1071,15 @@ func (tx *Tx) grabConn() (*driverConn, error) {
        return tx.dc, nil
 }
 
+// Closes all Stmts prepared for this transaction.
+func (tx *Tx) closePrepared() {
+       tx.stmts.Lock()
+       for _, stmt := range tx.stmts.v {
+               stmt.Close()
+       }
+       tx.stmts.Unlock()
+}
+
 // Commit commits the transaction.
 func (tx *Tx) Commit() error {
        if tx.done {
@@ -1071,8 +1087,12 @@ func (tx *Tx) Commit() error {
        }
        defer tx.close()
        tx.dc.Lock()
-       defer tx.dc.Unlock()
-       return tx.txi.Commit()
+       err := tx.txi.Commit()
+       tx.dc.Unlock()
+       if err != driver.ErrBadConn {
+               tx.closePrepared()
+       }
+       return err
 }
 
 // Rollback aborts the transaction.
@@ -1082,8 +1102,12 @@ func (tx *Tx) Rollback() error {
        }
        defer tx.close()
        tx.dc.Lock()
-       defer tx.dc.Unlock()
-       return tx.txi.Rollback()
+       err := tx.txi.Rollback()
+       tx.dc.Unlock()
+       if err != driver.ErrBadConn {
+               tx.closePrepared()
+       }
+       return err
 }
 
 // Prepare creates a prepared statement for use within a transaction.
@@ -1127,6 +1151,9 @@ func (tx *Tx) Prepare(query string) (*Stmt, error) {
                },
                query: query,
        }
+       tx.stmts.Lock()
+       tx.stmts.v = append(tx.stmts.v, stmt)
+       tx.stmts.Unlock()
        return stmt, nil
 }
 
@@ -1155,7 +1182,7 @@ func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
        dc.Lock()
        si, err := dc.ci.Prepare(stmt.query)
        dc.Unlock()
-       return &Stmt{
+       txs := &Stmt{
                db: tx.db,
                tx: tx,
                txsi: &driverStmt{
@@ -1165,6 +1192,10 @@ func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
                query:     stmt.query,
                stickyErr: err,
        }
+       tx.stmts.Lock()
+       tx.stmts.v = append(tx.stmts.v, txs)
+       tx.stmts.Unlock()
+       return txs
 }
 
 // Exec executes a query that doesn't return rows.
index 12e5a6fd6fc04b305a47fe0888091d2c7f64bd91..34efdf254c65309f9ea0c48a6921ff2a9add9f97 100644 (file)
@@ -441,6 +441,33 @@ func TestExec(t *testing.T) {
        }
 }
 
+func TestTxPrepare(t *testing.T) {
+       db := newTestDB(t, "")
+       defer closeDB(t, db)
+       exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool")
+       tx, err := db.Begin()
+       if err != nil {
+               t.Fatalf("Begin = %v", err)
+       }
+       stmt, err := tx.Prepare("INSERT|t1|name=?,age=?")
+       if err != nil {
+               t.Fatalf("Stmt, err = %v, %v", stmt, err)
+       }
+       defer stmt.Close()
+       _, err = stmt.Exec("Bobby", 7)
+       if err != nil {
+               t.Fatalf("Exec = %v", err)
+       }
+       err = tx.Commit()
+       if err != nil {
+               t.Fatalf("Commit = %v", err)
+       }
+       // Commit() should have closed the statement
+       if !stmt.closed {
+               t.Fatal("Stmt not closed after Commit")
+       }
+}
+
 func TestTxStmt(t *testing.T) {
        db := newTestDB(t, "")
        defer closeDB(t, db)
@@ -464,6 +491,10 @@ func TestTxStmt(t *testing.T) {
        if err != nil {
                t.Fatalf("Commit = %v", err)
        }
+       // Commit() should have closed the statement
+       if !txs.closed {
+               t.Fatal("Stmt not closed after Commit")
+       }
 }
 
 // Issue: http://golang.org/issue/2784