]> Cypherpunks repositories - gostls13.git/commitdiff
database/sql: ensure Rows is closed when Tx closes
authorDaniel Theophanes <kardianos@gmail.com>
Mon, 5 Jun 2017 16:04:05 +0000 (09:04 -0700)
committerDaniel Theophanes <kardianos@gmail.com>
Mon, 5 Jun 2017 19:48:49 +0000 (19:48 +0000)
Close any Rows queried within a Tx when the Tx is closed. This prevents
the Tx from blocking on rollback if a Rows query has not been closed yet.

Fixes #20575

Change-Id: I4efe9c4150e951d8a0f1c40d9d5e325964fdd608
Reviewed-on: https://go-review.googlesource.com/44812
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
src/database/sql/sql.go
src/database/sql/sql_test.go

index b7433f2374677bbc32cf6c97aad5c704f67c2737..f7919f983cac86e0350e6af1c680a766c076d6ee 100644 (file)
@@ -1284,12 +1284,14 @@ func (db *DB) query(ctx context.Context, query string, args []interface{}, strat
                return nil, err
        }
 
-       return db.queryDC(ctx, dc, dc.releaseConn, query, args)
+       return db.queryDC(ctx, nil, dc, dc.releaseConn, query, args)
 }
 
 // queryDC executes a query on the given connection.
 // The connection gets released by the releaseConn function.
-func (db *DB) queryDC(ctx context.Context, dc *driverConn, releaseConn func(error), query string, args []interface{}) (*Rows, error) {
+// The ctx context is from a query method and the txctx context is from an
+// optional transaction context.
+func (db *DB) queryDC(ctx, txctx context.Context, dc *driverConn, releaseConn func(error), query string, args []interface{}) (*Rows, error) {
        if queryer, ok := dc.ci.(driver.Queryer); ok {
                dargs, err := driverArgs(dc.ci, nil, args)
                if err != nil {
@@ -1312,7 +1314,7 @@ func (db *DB) queryDC(ctx context.Context, dc *driverConn, releaseConn func(erro
                                releaseConn: releaseConn,
                                rowsi:       rowsi,
                        }
-                       rows.initContextClose(ctx)
+                       rows.initContextClose(ctx, txctx)
                        return rows, nil
                }
        }
@@ -1343,7 +1345,7 @@ func (db *DB) queryDC(ctx context.Context, dc *driverConn, releaseConn func(erro
                rowsi:       rowsi,
                closeStmt:   ds,
        }
-       rows.initContextClose(ctx)
+       rows.initContextClose(ctx, txctx)
        return rows, nil
 }
 
@@ -1532,7 +1534,7 @@ func (c *Conn) QueryContext(ctx context.Context, query string, args ...interface
        }
 
        c.closemu.RLock()
-       return c.db.queryDC(ctx, dc, c.closemuRUnlockCondReleaseConn, query, args)
+       return c.db.queryDC(ctx, nil, dc, c.closemuRUnlockCondReleaseConn, query, args)
 }
 
 // QueryRowContext executes a query that is expected to return at most one row.
@@ -1687,11 +1689,12 @@ var ErrTxDone = errors.New("sql: Transaction has already been committed or rolle
 // close returns the connection to the pool and
 // must only be called by Tx.rollback or Tx.Commit.
 func (tx *Tx) close(err error) {
+       tx.cancel()
+
        tx.closemu.Lock()
        defer tx.closemu.Unlock()
 
        tx.releaseConn(err)
-       tx.cancel()
        tx.dc = nil
        tx.txi = nil
 }
@@ -1957,7 +1960,7 @@ func (tx *Tx) QueryContext(ctx context.Context, query string, args ...interface{
                return nil, err
        }
 
-       return tx.db.queryDC(ctx, dc, tx.closemuRUnlockRelease, query, args)
+       return tx.db.queryDC(ctx, tx.ctx, dc, tx.closemuRUnlockRelease, query, args)
 }
 
 // Query executes a query that returns rows, typically a SELECT.
@@ -2207,7 +2210,11 @@ func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, er
                                releaseConn(err)
                                s.db.removeDep(s, rows)
                        }
-                       rows.initContextClose(ctx)
+                       var txctx context.Context
+                       if s.tx != nil {
+                               txctx = s.tx.ctx
+                       }
+                       rows.initContextClose(ctx, txctx)
                        return rows, nil
                }
 
@@ -2363,14 +2370,24 @@ type Rows struct {
        lastcols []driver.Value
 }
 
-func (rs *Rows) initContextClose(ctx context.Context) {
+func (rs *Rows) initContextClose(ctx, txctx context.Context) {
        ctx, rs.cancel = context.WithCancel(ctx)
-       go rs.awaitDone(ctx)
+       go rs.awaitDone(ctx, txctx)
 }
 
-// awaitDone blocks until the rows are closed or the context canceled.
-func (rs *Rows) awaitDone(ctx context.Context) {
-       <-ctx.Done()
+// awaitDone blocks until either ctx or txctx is canceled. The ctx is provided
+// from the query context and is canceled when the query Rows is closed.
+// If the query was issued in a transaction, the transaction's context
+// is also provided in txctx to ensure Rows is closed if the Tx is closed.
+func (rs *Rows) awaitDone(ctx, txctx context.Context) {
+       var txctxDone <-chan struct{}
+       if txctx != nil {
+               txctxDone = txctx.Done()
+       }
+       select {
+       case <-ctx.Done():
+       case <-txctxDone:
+       }
        rs.close(ctx.Err())
 }
 
index f5bacc4324111ba0db8245d1de5798b3a70c4b63..8a477edf1ad310e05865788fc62926dd59a89b01 100644 (file)
@@ -2467,6 +2467,32 @@ func TestManyErrBadConn(t *testing.T) {
        }
 }
 
+// TestIssue20575 ensures the Rows from query does not block
+// closing a transaction. Ensure Rows is closed while closing a trasaction.
+func TestIssue20575(t *testing.T) {
+       db := newTestDB(t, "people")
+       tx, err := db.Begin()
+       if err != nil {
+               t.Fatal(err)
+       }
+       ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+       defer cancel()
+       _, err = tx.QueryContext(ctx, "SELECT|people|age,name|")
+       if err != nil {
+               t.Fatal(err)
+       }
+       // Do not close Rows from QueryContext.
+       err = tx.Rollback()
+       if err != nil {
+               t.Fatal(err)
+       }
+       select {
+       default:
+       case <-ctx.Done():
+               t.Fatal("timeout: failed to rollback query without closing rows:", ctx.Err())
+       }
+}
+
 // golang.org/issue/5718
 func TestErrBadConnReconnect(t *testing.T) {
        db := newTestDB(t, "foo")