]> Cypherpunks repositories - gostls13.git/commitdiff
database/sql: ensure a Stmt from a Conn executes on the same driver.Conn
authorDaniel Theophanes <kardianos@gmail.com>
Mon, 12 Jun 2017 17:46:15 +0000 (10:46 -0700)
committerBrad Fitzpatrick <bradfitz@golang.org>
Tue, 13 Jun 2017 19:16:54 +0000 (19:16 +0000)
Ensure a Stmt prepared on a Conn executes on the same driver.Conn.
This also removes another instance of duplicated prepare logic
as a side effect.

Fixes #20647

Change-Id: Ia00a19e4dd15e19e4d754105babdff5dc127728f
Reviewed-on: https://go-review.googlesource.com/45391
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>

src/database/sql/sql.go
src/database/sql/sql_test.go

index 9bc6414f2bfd13bd7e96e340e2f452e1e4041fe1..59bbf59c30e53230572e2d760958b65eb603e491 100644 (file)
@@ -393,11 +393,19 @@ func (dc *driverConn) expired(timeout time.Duration) bool {
        return dc.createdAt.Add(timeout).Before(nowFunc())
 }
 
-func (dc *driverConn) prepareLocked(ctx context.Context, query string) (*driverStmt, error) {
+// prepareLocked prepares the query on dc. When cg == nil the dc must keep track of
+// the prepared statements in a pool.
+func (dc *driverConn) prepareLocked(ctx context.Context, cg stmtConnGrabber, query string) (*driverStmt, error) {
        si, err := ctxDriverPrepare(ctx, dc.ci, query)
        if err != nil {
                return nil, err
        }
+       ds := &driverStmt{Locker: dc, si: si}
+
+       // No need to manage open statements if there is a single connection grabber.
+       if cg != nil {
+               return ds, nil
+       }
 
        // Track each driverConn's open statements, so we can close them
        // before closing the conn.
@@ -406,9 +414,7 @@ func (dc *driverConn) prepareLocked(ctx context.Context, query string) (*driverS
        if dc.openStmt == nil {
                dc.openStmt = make(map[*driverStmt]bool)
        }
-       ds := &driverStmt{Locker: dc, si: si}
        dc.openStmt[ds] = true
-
        return ds, nil
 }
 
@@ -1165,28 +1171,39 @@ func (db *DB) prepare(ctx context.Context, query string, strategy connReuseStrat
        if err != nil {
                return nil, err
        }
-       return db.prepareDC(ctx, dc, dc.releaseConn, query)
+       return db.prepareDC(ctx, dc, dc.releaseConn, nil, query)
 }
 
-func (db *DB) prepareDC(ctx context.Context, dc *driverConn, release func(error), query string) (*Stmt, error) {
+// prepareDC prepares a query on the driverConn and calls release before
+// returning. When cg == nil it implies that a connection pool is used, and
+// when cg != nil only a single driver connection is used.
+func (db *DB) prepareDC(ctx context.Context, dc *driverConn, release func(error), cg stmtConnGrabber, query string) (*Stmt, error) {
        var ds *driverStmt
        var err error
        defer func() {
                release(err)
        }()
        withLock(dc, func() {
-               ds, err = dc.prepareLocked(ctx, query)
+               ds, err = dc.prepareLocked(ctx, cg, query)
        })
        if err != nil {
                return nil, err
        }
        stmt := &Stmt{
-               db:            db,
-               query:         query,
-               css:           []connStmt{{dc, ds}},
-               lastNumClosed: atomic.LoadUint64(&db.numClosed),
+               db:    db,
+               query: query,
+               cg:    cg,
+               cgds:  ds,
+       }
+
+       // When cg == nil this statement will need to keep track of various
+       // connections they are prepared on and record the stmt dependency on
+       // the DB.
+       if cg == nil {
+               stmt.css = []connStmt{{dc, ds}}
+               stmt.lastNumClosed = atomic.LoadUint64(&db.numClosed)
+               db.addDep(stmt, stmt)
        }
-       db.addDep(stmt, stmt)
        return stmt, nil
 }
 
@@ -1474,6 +1491,8 @@ func (db *DB) Conn(ctx context.Context) (*Conn, error) {
        return conn, nil
 }
 
+type releaseConn func(error)
+
 // Conn represents a single database session rather a pool of database
 // sessions. Prefer running queries from DB unless there is a specific
 // need for a continuous single database session.
@@ -1501,46 +1520,41 @@ type Conn struct {
        done int32
 }
 
-func (c *Conn) grabConn() (*driverConn, error) {
+func (c *Conn) grabConn(context.Context) (*driverConn, releaseConn, error) {
        if atomic.LoadInt32(&c.done) != 0 {
-               return nil, ErrConnDone
+               return nil, nil, ErrConnDone
        }
-       return c.dc, nil
+       c.closemu.RLock()
+       return c.dc, c.closemuRUnlockCondReleaseConn, nil
 }
 
 // PingContext verifies the connection to the database is still alive.
 func (c *Conn) PingContext(ctx context.Context) error {
-       dc, err := c.grabConn()
+       dc, release, err := c.grabConn(ctx)
        if err != nil {
                return err
        }
-
-       c.closemu.RLock()
-       return c.db.pingDC(ctx, dc, c.closemuRUnlockCondReleaseConn)
+       return c.db.pingDC(ctx, dc, release)
 }
 
 // ExecContext executes a query without returning any rows.
 // The args are for any placeholder parameters in the query.
 func (c *Conn) ExecContext(ctx context.Context, query string, args ...interface{}) (Result, error) {
-       dc, err := c.grabConn()
+       dc, release, err := c.grabConn(ctx)
        if err != nil {
                return nil, err
        }
-
-       c.closemu.RLock()
-       return c.db.execDC(ctx, dc, c.closemuRUnlockCondReleaseConn, query, args)
+       return c.db.execDC(ctx, dc, release, query, args)
 }
 
 // QueryContext executes a query that returns rows, typically a SELECT.
 // The args are for any placeholder parameters in the query.
 func (c *Conn) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) {
-       dc, err := c.grabConn()
+       dc, release, err := c.grabConn(ctx)
        if err != nil {
                return nil, err
        }
-
-       c.closemu.RLock()
-       return c.db.queryDC(ctx, nil, dc, c.closemuRUnlockCondReleaseConn, query, args)
+       return c.db.queryDC(ctx, nil, dc, release, query, args)
 }
 
 // QueryRowContext executes a query that is expected to return at most one row.
@@ -1563,13 +1577,11 @@ func (c *Conn) QueryRowContext(ctx context.Context, query string, args ...interf
 // The provided context is used for the preparation of the statement, not for the
 // execution of the statement.
 func (c *Conn) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
-       dc, err := c.grabConn()
+       dc, release, err := c.grabConn(ctx)
        if err != nil {
                return nil, err
        }
-
-       c.closemu.RLock()
-       return c.db.prepareDC(ctx, dc, c.closemuRUnlockCondReleaseConn, query)
+       return c.db.prepareDC(ctx, dc, release, c, query)
 }
 
 // BeginTx starts a transaction.
@@ -1583,13 +1595,11 @@ func (c *Conn) PrepareContext(ctx context.Context, query string) (*Stmt, error)
 // If a non-default isolation level is used that the driver doesn't support,
 // an error will be returned.
 func (c *Conn) BeginTx(ctx context.Context, opts *TxOptions) (*Tx, error) {
-       dc, err := c.grabConn()
+       dc, release, err := c.grabConn(ctx)
        if err != nil {
                return nil, err
        }
-
-       c.closemu.RLock()
-       return c.db.beginDC(ctx, dc, c.closemuRUnlockCondReleaseConn, opts)
+       return c.db.beginDC(ctx, dc, release, opts)
 }
 
 // closemuRUnlockCondReleaseConn read unlocks closemu
@@ -1601,6 +1611,10 @@ func (c *Conn) closemuRUnlockCondReleaseConn(err error) {
        }
 }
 
+func (c *Conn) txCtx() context.Context {
+       return nil
+}
+
 func (c *Conn) close(err error) error {
        if !atomic.CompareAndSwapInt32(&c.done, 0, 1) {
                return ErrConnDone
@@ -1712,19 +1726,28 @@ func (tx *Tx) close(err error) {
 // a successful call to (*Tx).grabConn. For tests.
 var hookTxGrabConn func()
 
-func (tx *Tx) grabConn(ctx context.Context) (*driverConn, error) {
+func (tx *Tx) grabConn(ctx context.Context) (*driverConn, releaseConn, error) {
        select {
        default:
        case <-ctx.Done():
-               return nil, ctx.Err()
+               return nil, nil, ctx.Err()
        }
+
+       // closeme.RLock must come before the check for isDone to prevent the Tx from
+       // closing while a query is executing.
+       tx.closemu.RLock()
        if tx.isDone() {
-               return nil, ErrTxDone
+               tx.closemu.RUnlock()
+               return nil, nil, ErrTxDone
        }
        if hookTxGrabConn != nil { // test hook
                hookTxGrabConn()
        }
-       return tx.dc, nil
+       return tx.dc, tx.closemuRUnlockRelease, nil
+}
+
+func (tx *Tx) txCtx() context.Context {
+       return tx.ctx
 }
 
 // closemuRUnlockRelease is used as a func(error) method value in
@@ -1801,31 +1824,15 @@ func (tx *Tx) Rollback() error {
 // for the execution of the returned statement. The returned statement
 // will run in the transaction context.
 func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
-       tx.closemu.RLock()
-       defer tx.closemu.RUnlock()
-
-       dc, err := tx.grabConn(ctx)
+       dc, release, err := tx.grabConn(ctx)
        if err != nil {
                return nil, err
        }
 
-       var si driver.Stmt
-       withLock(dc, func() {
-               si, err = ctxDriverPrepare(ctx, dc.ci, query)
-       })
+       stmt, err := tx.db.prepareDC(ctx, dc, release, tx, query)
        if err != nil {
                return nil, err
        }
-
-       stmt := &Stmt{
-               db: tx.db,
-               tx: tx,
-               txds: &driverStmt{
-                       Locker: dc,
-                       si:     si,
-               },
-               query: query,
-       }
        tx.stmts.Lock()
        tx.stmts.v = append(tx.stmts.v, stmt)
        tx.stmts.Unlock()
@@ -1855,20 +1862,19 @@ func (tx *Tx) Prepare(query string) (*Stmt, error) {
 // The returned statement operates within the transaction and will be closed
 // when the transaction has been committed or rolled back.
 func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt {
-       tx.closemu.RLock()
-       defer tx.closemu.RUnlock()
+       dc, release, err := tx.grabConn(ctx)
+       if err != nil {
+               return &Stmt{stickyErr: err}
+       }
+       defer release(nil)
 
        if tx.db != stmt.db {
                return &Stmt{stickyErr: errors.New("sql: Tx.Stmt: statement from different database used")}
        }
-       dc, err := tx.grabConn(ctx)
-       if err != nil {
-               return &Stmt{stickyErr: err}
-       }
        var si driver.Stmt
        var parentStmt *Stmt
        stmt.mu.Lock()
-       if stmt.closed || stmt.tx != nil {
+       if stmt.closed || stmt.cg != nil {
                // If the statement has been closed or already belongs to a
                // transaction, we can't reuse it in this connection.
                // Since tx.StmtContext should never need to be called with a
@@ -1907,8 +1913,8 @@ func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt {
 
        txs := &Stmt{
                db: tx.db,
-               tx: tx,
-               txds: &driverStmt{
+               cg: tx,
+               cgds: &driverStmt{
                        Locker: dc,
                        si:     si,
                },
@@ -1943,14 +1949,11 @@ func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
 // ExecContext executes a query that doesn't return rows.
 // For example: an INSERT and UPDATE.
 func (tx *Tx) ExecContext(ctx context.Context, query string, args ...interface{}) (Result, error) {
-       tx.closemu.RLock()
-
-       dc, err := tx.grabConn(ctx)
+       dc, release, err := tx.grabConn(ctx)
        if err != nil {
-               tx.closemu.RUnlock()
                return nil, err
        }
-       return tx.db.execDC(ctx, dc, tx.closemuRUnlockRelease, query, args)
+       return tx.db.execDC(ctx, dc, release, query, args)
 }
 
 // Exec executes a query that doesn't return rows.
@@ -1961,15 +1964,12 @@ func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) {
 
 // QueryContext executes a query that returns rows, typically a SELECT.
 func (tx *Tx) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) {
-       tx.closemu.RLock()
-
-       dc, err := tx.grabConn(ctx)
+       dc, release, err := tx.grabConn(ctx)
        if err != nil {
-               tx.closemu.RUnlock()
                return nil, err
        }
 
-       return tx.db.queryDC(ctx, tx.ctx, dc, tx.closemuRUnlockRelease, query, args)
+       return tx.db.queryDC(ctx, tx.ctx, dc, release, query, args)
 }
 
 // Query executes a query that returns rows, typically a SELECT.
@@ -2004,6 +2004,24 @@ type connStmt struct {
        ds *driverStmt
 }
 
+// stmtConnGrabber represents a Tx or Conn that will return the underlying
+// driverConn and release function.
+type stmtConnGrabber interface {
+       // grabConn returns the driverConn and the associated release function
+       // that must be called when the operation completes.
+       grabConn(context.Context) (*driverConn, releaseConn, error)
+
+       // txCtx returns the transaction context if available.
+       // The returned context should be selected on along with
+       // any query context when awaiting a cancel.
+       txCtx() context.Context
+}
+
+var (
+       _ stmtConnGrabber = &Tx{}
+       _ stmtConnGrabber = &Conn{}
+)
+
 // Stmt is a prepared statement.
 // A Stmt is safe for concurrent use by multiple goroutines.
 type Stmt struct {
@@ -2014,9 +2032,13 @@ type Stmt struct {
 
        closemu sync.RWMutex // held exclusively during close, for read otherwise.
 
-       // If in a transaction, else both nil:
-       tx   *Tx
-       txds *driverStmt
+       // If Stmt is prepared on a Tx or Conn then cg is present and will
+       // only ever grab a connection from cg.
+       // If cg is nil then the Stmt must grab an arbitrary connection
+       // from db and determine if it must prepare the stmt again by
+       // inspecting css.
+       cg   stmtConnGrabber
+       cgds *driverStmt
 
        // parentStmt is set when a transaction-specific statement
        // is requested from an identical statement prepared on the same
@@ -2031,8 +2053,8 @@ type Stmt struct {
 
        // css is a list of underlying driver statement interfaces
        // that are valid on particular connections. This is only
-       // used if tx == nil and one is found that has idle
-       // connections. If tx != nil, txds is always used.
+       // used if cg == nil and one is found that has idle
+       // connections. If cg != nil, cgds is always used.
        css []connStmt
 
        // lastNumClosed is copied from db.numClosed when Stmt is created
@@ -2120,7 +2142,7 @@ func (s *Stmt) removeClosedStmtLocked() {
 // connStmt returns a free driver connection on which to execute the
 // statement, a function to call to release the connection, and a
 // statement bound to that connection.
-func (s *Stmt) connStmt(ctx context.Context, strategy connReuseStrategy) (ci *driverConn, releaseConn func(error), ds *driverStmt, err error) {
+func (s *Stmt) connStmt(ctx context.Context, strategy connReuseStrategy) (dc *driverConn, releaseConn func(error), ds *driverStmt, err error) {
        if err = s.stickyErr; err != nil {
                return
        }
@@ -2131,22 +2153,21 @@ func (s *Stmt) connStmt(ctx context.Context, strategy connReuseStrategy) (ci *dr
                return
        }
 
-       // In a transaction, we always use the connection that the
-       // transaction was created on.
-       if s.tx != nil {
+       // In a transaction or connection, we always use the connection that the
+       // the stmt was created on.
+       if s.cg != nil {
                s.mu.Unlock()
-               ci, err = s.tx.grabConn(ctx) // blocks, waiting for the connection.
+               dc, releaseConn, err = s.cg.grabConn(ctx) // blocks, waiting for the connection.
                if err != nil {
                        return
                }
-               releaseConn = func(error) {}
-               return ci, releaseConn, s.txds, nil
+               return dc, releaseConn, s.cgds, nil
        }
 
        s.removeClosedStmtLocked()
        s.mu.Unlock()
 
-       dc, err := s.db.conn(ctx, strategy)
+       dc, err = s.db.conn(ctx, strategy)
        if err != nil {
                return nil, nil, nil, err
        }
@@ -2175,7 +2196,7 @@ func (s *Stmt) connStmt(ctx context.Context, strategy connReuseStrategy) (ci *dr
 // prepareOnConnLocked prepares the query in Stmt s on dc and adds it to the list of
 // open connStmt on the statement. It assumes the caller is holding the lock on dc.
 func (s *Stmt) prepareOnConnLocked(ctx context.Context, dc *driverConn) (*driverStmt, error) {
-       si, err := dc.prepareLocked(ctx, s.query)
+       si, err := dc.prepareLocked(ctx, s.cg, s.query)
        if err != nil {
                return nil, err
        }
@@ -2226,8 +2247,8 @@ func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, er
                                s.db.removeDep(s, rows)
                        }
                        var txctx context.Context
-                       if s.tx != nil {
-                               txctx = s.tx.ctx
+                       if s.cg != nil {
+                               txctx = s.cg.txCtx()
                        }
                        rows.initContextClose(ctx, txctx)
                        return rows, nil
@@ -2323,9 +2344,12 @@ func (s *Stmt) Close() error {
                return nil
        }
        s.closed = true
+       txds := s.cgds
+       s.cgds = nil
+
        s.mu.Unlock()
 
-       if s.tx == nil {
+       if s.cg == nil {
                return s.db.removeDep(s, s)
        }
 
@@ -2334,7 +2358,7 @@ func (s *Stmt) Close() error {
                // in the css array of the parentStmt.
                return s.db.removeDep(s.parentStmt, s)
        }
-       return s.txds.Close()
+       return txds.Close()
 }
 
 func (s *Stmt) finalClose() error {
index 7895aa04047dd5d1c150e0b54d5491607dcd6c92..c935eb4348025756269926628187719ed7d8754d 100644 (file)
@@ -877,7 +877,7 @@ func TestStatementClose(t *testing.T) {
                msg  string
        }{
                {&Stmt{stickyErr: want}, "stickyErr not propagated"},
-               {&Stmt{tx: &Tx{}, txds: &driverStmt{Locker: &sync.Mutex{}, si: stubDriverStmt{want}}}, "driverStmt.Close() error not propagated"},
+               {&Stmt{cg: &Tx{}, cgds: &driverStmt{Locker: &sync.Mutex{}, si: stubDriverStmt{want}}}, "driverStmt.Close() error not propagated"},
        }
        for _, test := range tests {
                if err := test.stmt.Close(); err != want {
@@ -3231,6 +3231,42 @@ func TestIssue18719(t *testing.T) {
        cancel()
 }
 
+func TestIssue20647(t *testing.T) {
+       db := newTestDB(t, "people")
+       defer closeDB(t, db)
+
+       ctx, cancel := context.WithCancel(context.Background())
+       defer cancel()
+
+       conn, err := db.Conn(ctx)
+       if err != nil {
+               t.Fatal(err)
+       }
+       defer conn.Close()
+
+       stmt, err := conn.PrepareContext(ctx, "SELECT|people|name|")
+       if err != nil {
+               t.Fatal(err)
+       }
+       defer stmt.Close()
+
+       rows1, err := stmt.QueryContext(ctx)
+       if err != nil {
+               t.Fatal("rows1", err)
+       }
+       defer rows1.Close()
+
+       rows2, err := stmt.QueryContext(ctx)
+       if err != nil {
+               t.Fatal("rows2", err)
+       }
+       defer rows2.Close()
+
+       if rows1.dc != rows2.dc {
+               t.Fatal("stmt prepared on Conn does not use same connection")
+       }
+}
+
 func TestConcurrency(t *testing.T) {
        list := []struct {
                name string