]> Cypherpunks repositories - gostls13.git/commitdiff
database/sql: fix race when canceling queries immediately
authorDaniel Theophanes <kardianos@gmail.com>
Sat, 21 Jan 2017 01:12:50 +0000 (17:12 -0800)
committerBrad Fitzpatrick <bradfitz@golang.org>
Thu, 26 Jan 2017 06:25:37 +0000 (06:25 +0000)
Previously the following could happen, though in practice it would
be rare.

Goroutine 1:
(*Tx).QueryContext begins a query, passing in userContext

Goroutine 2:
(*Tx).awaitDone starts to wait on the context derived from the passed in context

Goroutine 1:
(*Tx).grabConn returns a valid (*driverConn)
The (*driverConn) passes to (*DB).queryConn

Goroutine 3:
userContext is canceled

Goroutine 2:
(*Tx).awaitDone unblocks and calls (*Tx).rollback
(*driverConn).finalClose obtains dc.Mutex
(*driverConn).finalClose sets dc.ci = nil

Goroutine 1:
(*DB).queryConn obtains dc.Mutex in withLock
ctxDriverPrepare accepts dc.ci which is now nil
ctxCriverPrepare panics on the nil ci

The fix for this is to guard the Tx methods with a RWLock
holding it exclusivly when closing the Tx and holding a read lock
when executing a query.

Fixes #18719

Change-Id: I37aa02c37083c9793dabd28f7f934a1c5cbc05ea
Reviewed-on: https://go-review.googlesource.com/35550
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
src/database/sql/sql.go
src/database/sql/sql_test.go

index 0fa7c34a13e3bb7ee209b9b01a33befe2f414e80..feb91223a9e9b99f48fa65e041afd6f7d309e791 100644 (file)
@@ -1357,16 +1357,7 @@ func (db *DB) begin(ctx context.Context, opts *TxOptions, strategy connReuseStra
                cancel: cancel,
                ctx:    ctx,
        }
-       go func(tx *Tx) {
-               select {
-               case <-tx.ctx.Done():
-                       if !tx.isDone() {
-                               // Discard and close the connection used to ensure the transaction
-                               // is closed and the resources are released.
-                               tx.rollback(true)
-                       }
-               }
-       }(tx)
+       go tx.awaitDone()
        return tx, nil
 }
 
@@ -1388,6 +1379,11 @@ func (db *DB) Driver() driver.Driver {
 type Tx struct {
        db *DB
 
+       // closemu prevents the transaction from closing while there
+       // is an active query. It is held for read during queries
+       // and exclusively during close.
+       closemu sync.RWMutex
+
        // dc is owned exclusively until Commit or Rollback, at which point
        // it's returned with putConn.
        dc  *driverConn
@@ -1413,6 +1409,20 @@ type Tx struct {
        ctx context.Context
 }
 
+// awaitDone blocks until the context in Tx is canceled and rolls back
+// the transaction if it's not already done.
+func (tx *Tx) awaitDone() {
+       // Wait for either the transaction to be committed or rolled
+       // back, or for the associated context to be closed.
+       <-tx.ctx.Done()
+
+       // Discard and close the connection used to ensure the
+       // transaction is closed and the resources are released.  This
+       // rollback does nothing if the transaction has already been
+       // committed or rolled back.
+       tx.rollback(true)
+}
+
 func (tx *Tx) isDone() bool {
        return atomic.LoadInt32(&tx.done) != 0
 }
@@ -1424,16 +1434,31 @@ var ErrTxDone = errors.New("sql: Transaction has already been committed or rolle
 // close returns the connection to the pool and
 // must only be called by Tx.rollback or Tx.Commit.
 func (tx *Tx) close(err error) {
+       tx.closemu.Lock()
+       defer tx.closemu.Unlock()
+
        tx.db.putConn(tx.dc, err)
        tx.cancel()
        tx.dc = nil
        tx.txi = nil
 }
 
+// hookTxGrabConn specifies an optional hook to be called on
+// a successful call to (*Tx).grabConn. For tests.
+var hookTxGrabConn func()
+
 func (tx *Tx) grabConn(ctx context.Context) (*driverConn, error) {
+       select {
+       default:
+       case <-ctx.Done():
+               return nil, ctx.Err()
+       }
        if tx.isDone() {
                return nil, ErrTxDone
        }
+       if hookTxGrabConn != nil { // test hook
+               hookTxGrabConn()
+       }
        return tx.dc, nil
 }
 
@@ -1503,6 +1528,9 @@ func (tx *Tx) Rollback() error {
 // for the execution of the returned statement. The returned statement
 // will run in the transaction context.
 func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
+       tx.closemu.RLock()
+       defer tx.closemu.RUnlock()
+
        // TODO(bradfitz): We could be more efficient here and either
        // provide a method to take an existing Stmt (created on
        // perhaps a different Conn), and re-create it on this Conn if
@@ -1567,6 +1595,9 @@ func (tx *Tx) Prepare(query string) (*Stmt, error) {
 // The returned statement operates within the transaction and will be closed
 // when the transaction has been committed or rolled back.
 func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt {
+       tx.closemu.RLock()
+       defer tx.closemu.RUnlock()
+
        // TODO(bradfitz): optimize this. Currently this re-prepares
        // each time. This is fine for now to illustrate the API but
        // we should really cache already-prepared statements
@@ -1618,6 +1649,9 @@ func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
 // ExecContext executes a query that doesn't return rows.
 // For example: an INSERT and UPDATE.
 func (tx *Tx) ExecContext(ctx context.Context, query string, args ...interface{}) (Result, error) {
+       tx.closemu.RLock()
+       defer tx.closemu.RUnlock()
+
        dc, err := tx.grabConn(ctx)
        if err != nil {
                return nil, err
@@ -1661,6 +1695,9 @@ func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) {
 
 // QueryContext executes a query that returns rows, typically a SELECT.
 func (tx *Tx) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) {
+       tx.closemu.RLock()
+       defer tx.closemu.RUnlock()
+
        dc, err := tx.grabConn(ctx)
        if err != nil {
                return nil, err
@@ -2038,25 +2075,21 @@ type Rows struct {
        // closed value is 1 when the Rows is closed.
        // Use atomic operations on value when checking value.
        closed    int32
-       ctxClose  chan struct{} // closed when Rows is closed, may be null.
+       cancel    func() // called when Rows is closed, may be nil.
        lastcols  []driver.Value
        lasterr   error       // non-nil only if closed is true
        closeStmt *driverStmt // if non-nil, statement to Close on close
 }
 
 func (rs *Rows) initContextClose(ctx context.Context) {
-       if ctx.Done() == context.Background().Done() {
-               return
-       }
+       ctx, rs.cancel = context.WithCancel(ctx)
+       go rs.awaitDone(ctx)
+}
 
-       rs.ctxClose = make(chan struct{})
-       go func() {
-               select {
-               case <-ctx.Done():
-                       rs.Close()
-               case <-rs.ctxClose:
-               }
-       }()
+// awaitDone blocks until the rows are closed or the context canceled.
+func (rs *Rows) awaitDone(ctx context.Context) {
+       <-ctx.Done()
+       rs.Close()
 }
 
 // Next prepares the next result row for reading with the Scan method. It
@@ -2314,7 +2347,9 @@ func (rs *Rows) Scan(dest ...interface{}) error {
        return nil
 }
 
-var rowsCloseHook func(*Rows, *error)
+// rowsCloseHook returns a function so tests may install the
+// hook throug a test only mutex.
+var rowsCloseHook = func() func(*Rows, *error) { return nil }
 
 func (rs *Rows) isClosed() bool {
        return atomic.LoadInt32(&rs.closed) != 0
@@ -2328,13 +2363,15 @@ func (rs *Rows) Close() error {
        if !atomic.CompareAndSwapInt32(&rs.closed, 0, 1) {
                return nil
        }
-       if rs.ctxClose != nil {
-               close(rs.ctxClose)
-       }
+
        err := rs.rowsi.Close()
-       if fn := rowsCloseHook; fn != nil {
+       if fn := rowsCloseHook(); fn != nil {
                fn(rs, &err)
        }
+       if rs.cancel != nil {
+               rs.cancel()
+       }
+
        if rs.closeStmt != nil {
                rs.closeStmt.Close()
        }
index 3f8e03ce135426b241b089e9623a1b408b874ac2..898df3b455b22cac915bdb19510cc9ba22f83cbe 100644 (file)
@@ -14,6 +14,7 @@ import (
        "runtime"
        "strings"
        "sync"
+       "sync/atomic"
        "testing"
        "time"
 )
@@ -1135,6 +1136,24 @@ func TestQueryRowClosingStmt(t *testing.T) {
        }
 }
 
+var atomicRowsCloseHook atomic.Value // of func(*Rows, *error)
+
+func init() {
+       rowsCloseHook = func() func(*Rows, *error) {
+               fn, _ := atomicRowsCloseHook.Load().(func(*Rows, *error))
+               return fn
+       }
+}
+
+func setRowsCloseHook(fn func(*Rows, *error)) {
+       if fn == nil {
+               // Can't change an atomic.Value back to nil, so set it to this
+               // no-op func instead.
+               fn = func(*Rows, *error) {}
+       }
+       atomicRowsCloseHook.Store(fn)
+}
+
 // Test issue 6651
 func TestIssue6651(t *testing.T) {
        db := newTestDB(t, "people")
@@ -1147,6 +1166,7 @@ func TestIssue6651(t *testing.T) {
                return fmt.Errorf(want)
        }
        defer func() { rowsCursorNextHook = nil }()
+
        err := db.QueryRow("SELECT|people|name|").Scan(&v)
        if err == nil || err.Error() != want {
                t.Errorf("error = %q; want %q", err, want)
@@ -1154,10 +1174,10 @@ func TestIssue6651(t *testing.T) {
        rowsCursorNextHook = nil
 
        want = "error in rows.Close"
-       rowsCloseHook = func(rows *Rows, err *error) {
+       setRowsCloseHook(func(rows *Rows, err *error) {
                *err = fmt.Errorf(want)
-       }
-       defer func() { rowsCloseHook = nil }()
+       })
+       defer setRowsCloseHook(nil)
        err = db.QueryRow("SELECT|people|name|").Scan(&v)
        if err == nil || err.Error() != want {
                t.Errorf("error = %q; want %q", err, want)
@@ -1830,7 +1850,9 @@ func TestStmtCloseDeps(t *testing.T) {
                db.dumpDeps(t)
        }
 
-       if len(stmt.css) > nquery {
+       if !waitCondition(5*time.Second, 5*time.Millisecond, func() bool {
+               return len(stmt.css) <= nquery
+       }) {
                t.Errorf("len(stmt.css) = %d; want <= %d", len(stmt.css), nquery)
        }
 
@@ -2576,10 +2598,10 @@ func TestIssue6081(t *testing.T) {
        if err != nil {
                t.Fatal(err)
        }
-       rowsCloseHook = func(rows *Rows, err *error) {
+       setRowsCloseHook(func(rows *Rows, err *error) {
                *err = driver.ErrBadConn
-       }
-       defer func() { rowsCloseHook = nil }()
+       })
+       defer setRowsCloseHook(nil)
        for i := 0; i < 10; i++ {
                rows, err := stmt.Query()
                if err != nil {
@@ -2642,7 +2664,10 @@ func TestIssue18429(t *testing.T) {
                        if err != nil {
                                return
                        }
-                       rows, err := tx.QueryContext(ctx, "WAIT|"+qwait+"|SELECT|people|name|")
+                       // This is expected to give a cancel error many, but not all the time.
+                       // Test failure will happen with a panic or other race condition being
+                       // reported.
+                       rows, _ := tx.QueryContext(ctx, "WAIT|"+qwait+"|SELECT|people|name|")
                        if rows != nil {
                                rows.Close()
                        }
@@ -2655,6 +2680,56 @@ func TestIssue18429(t *testing.T) {
        time.Sleep(milliWait * 3 * time.Millisecond)
 }
 
+// TestIssue18719 closes the context right before use. The sql.driverConn
+// will nil out the ci on close in a lock, but if another process uses it right after
+// it will panic with on the nil ref.
+//
+// See https://golang.org/cl/35550 .
+func TestIssue18719(t *testing.T) {
+       db := newTestDB(t, "people")
+       defer closeDB(t, db)
+
+       ctx, cancel := context.WithCancel(context.Background())
+       defer cancel()
+
+       tx, err := db.BeginTx(ctx, nil)
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       hookTxGrabConn = func() {
+               cancel()
+
+               // Wait for the context to cancel and tx to rollback.
+               for tx.isDone() == false {
+                       time.Sleep(time.Millisecond * 3)
+               }
+       }
+       defer func() { hookTxGrabConn = nil }()
+
+       // This call will grab the connection and cancel the context
+       // after it has done so. Code after must deal with the canceled state.
+       rows, err := tx.QueryContext(ctx, "SELECT|people|name|")
+       if err != nil {
+               rows.Close()
+               t.Fatalf("expected error %v but got %v", nil, err)
+       }
+
+       // Rows may be ignored because it will be closed when the context is canceled.
+
+       // Do not explicitly rollback. The rollback will happen from the
+       // canceled context.
+
+       // Wait for connections to return to pool.
+       var numOpen int
+       if !waitCondition(5*time.Second, 5*time.Millisecond, func() bool {
+               numOpen = db.numOpenConns()
+               return numOpen == 0
+       }) {
+               t.Fatalf("open conns after hitting EOF = %d; want 0", numOpen)
+       }
+}
+
 func TestConcurrency(t *testing.T) {
        doConcurrentTest(t, new(concurrentDBQueryTest))
        doConcurrentTest(t, new(concurrentDBExecTest))