From: Daniel Theophanes Date: Mon, 12 Jun 2017 17:46:15 +0000 (-0700) Subject: database/sql: ensure a Stmt from a Conn executes on the same driver.Conn X-Git-Tag: go1.9beta1~33 X-Git-Url: http://www.git.cypherpunks.su/?a=commitdiff_plain;h=cd24a8a5509376e4d5c492256a0e1e120cab63e7;p=gostls13.git database/sql: ensure a Stmt from a Conn executes on the same driver.Conn Ensure a Stmt prepared on a Conn executes on the same driver.Conn. This also removes another instance of duplicated prepare logic as a side effect. Fixes #20647 Change-Id: Ia00a19e4dd15e19e4d754105babdff5dc127728f Reviewed-on: https://go-review.googlesource.com/45391 Reviewed-by: Brad Fitzpatrick Run-TryBot: Brad Fitzpatrick TryBot-Result: Gobot Gobot --- diff --git a/src/database/sql/sql.go b/src/database/sql/sql.go index 9bc6414f2b..59bbf59c30 100644 --- a/src/database/sql/sql.go +++ b/src/database/sql/sql.go @@ -393,11 +393,19 @@ func (dc *driverConn) expired(timeout time.Duration) bool { return dc.createdAt.Add(timeout).Before(nowFunc()) } -func (dc *driverConn) prepareLocked(ctx context.Context, query string) (*driverStmt, error) { +// prepareLocked prepares the query on dc. When cg == nil the dc must keep track of +// the prepared statements in a pool. +func (dc *driverConn) prepareLocked(ctx context.Context, cg stmtConnGrabber, query string) (*driverStmt, error) { si, err := ctxDriverPrepare(ctx, dc.ci, query) if err != nil { return nil, err } + ds := &driverStmt{Locker: dc, si: si} + + // No need to manage open statements if there is a single connection grabber. + if cg != nil { + return ds, nil + } // Track each driverConn's open statements, so we can close them // before closing the conn. @@ -406,9 +414,7 @@ func (dc *driverConn) prepareLocked(ctx context.Context, query string) (*driverS if dc.openStmt == nil { dc.openStmt = make(map[*driverStmt]bool) } - ds := &driverStmt{Locker: dc, si: si} dc.openStmt[ds] = true - return ds, nil } @@ -1165,28 +1171,39 @@ func (db *DB) prepare(ctx context.Context, query string, strategy connReuseStrat if err != nil { return nil, err } - return db.prepareDC(ctx, dc, dc.releaseConn, query) + return db.prepareDC(ctx, dc, dc.releaseConn, nil, query) } -func (db *DB) prepareDC(ctx context.Context, dc *driverConn, release func(error), query string) (*Stmt, error) { +// prepareDC prepares a query on the driverConn and calls release before +// returning. When cg == nil it implies that a connection pool is used, and +// when cg != nil only a single driver connection is used. +func (db *DB) prepareDC(ctx context.Context, dc *driverConn, release func(error), cg stmtConnGrabber, query string) (*Stmt, error) { var ds *driverStmt var err error defer func() { release(err) }() withLock(dc, func() { - ds, err = dc.prepareLocked(ctx, query) + ds, err = dc.prepareLocked(ctx, cg, query) }) if err != nil { return nil, err } stmt := &Stmt{ - db: db, - query: query, - css: []connStmt{{dc, ds}}, - lastNumClosed: atomic.LoadUint64(&db.numClosed), + db: db, + query: query, + cg: cg, + cgds: ds, + } + + // When cg == nil this statement will need to keep track of various + // connections they are prepared on and record the stmt dependency on + // the DB. + if cg == nil { + stmt.css = []connStmt{{dc, ds}} + stmt.lastNumClosed = atomic.LoadUint64(&db.numClosed) + db.addDep(stmt, stmt) } - db.addDep(stmt, stmt) return stmt, nil } @@ -1474,6 +1491,8 @@ func (db *DB) Conn(ctx context.Context) (*Conn, error) { return conn, nil } +type releaseConn func(error) + // Conn represents a single database session rather a pool of database // sessions. Prefer running queries from DB unless there is a specific // need for a continuous single database session. @@ -1501,46 +1520,41 @@ type Conn struct { done int32 } -func (c *Conn) grabConn() (*driverConn, error) { +func (c *Conn) grabConn(context.Context) (*driverConn, releaseConn, error) { if atomic.LoadInt32(&c.done) != 0 { - return nil, ErrConnDone + return nil, nil, ErrConnDone } - return c.dc, nil + c.closemu.RLock() + return c.dc, c.closemuRUnlockCondReleaseConn, nil } // PingContext verifies the connection to the database is still alive. func (c *Conn) PingContext(ctx context.Context) error { - dc, err := c.grabConn() + dc, release, err := c.grabConn(ctx) if err != nil { return err } - - c.closemu.RLock() - return c.db.pingDC(ctx, dc, c.closemuRUnlockCondReleaseConn) + return c.db.pingDC(ctx, dc, release) } // ExecContext executes a query without returning any rows. // The args are for any placeholder parameters in the query. func (c *Conn) ExecContext(ctx context.Context, query string, args ...interface{}) (Result, error) { - dc, err := c.grabConn() + dc, release, err := c.grabConn(ctx) if err != nil { return nil, err } - - c.closemu.RLock() - return c.db.execDC(ctx, dc, c.closemuRUnlockCondReleaseConn, query, args) + return c.db.execDC(ctx, dc, release, query, args) } // QueryContext executes a query that returns rows, typically a SELECT. // The args are for any placeholder parameters in the query. func (c *Conn) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) { - dc, err := c.grabConn() + dc, release, err := c.grabConn(ctx) if err != nil { return nil, err } - - c.closemu.RLock() - return c.db.queryDC(ctx, nil, dc, c.closemuRUnlockCondReleaseConn, query, args) + return c.db.queryDC(ctx, nil, dc, release, query, args) } // QueryRowContext executes a query that is expected to return at most one row. @@ -1563,13 +1577,11 @@ func (c *Conn) QueryRowContext(ctx context.Context, query string, args ...interf // The provided context is used for the preparation of the statement, not for the // execution of the statement. func (c *Conn) PrepareContext(ctx context.Context, query string) (*Stmt, error) { - dc, err := c.grabConn() + dc, release, err := c.grabConn(ctx) if err != nil { return nil, err } - - c.closemu.RLock() - return c.db.prepareDC(ctx, dc, c.closemuRUnlockCondReleaseConn, query) + return c.db.prepareDC(ctx, dc, release, c, query) } // BeginTx starts a transaction. @@ -1583,13 +1595,11 @@ func (c *Conn) PrepareContext(ctx context.Context, query string) (*Stmt, error) // If a non-default isolation level is used that the driver doesn't support, // an error will be returned. func (c *Conn) BeginTx(ctx context.Context, opts *TxOptions) (*Tx, error) { - dc, err := c.grabConn() + dc, release, err := c.grabConn(ctx) if err != nil { return nil, err } - - c.closemu.RLock() - return c.db.beginDC(ctx, dc, c.closemuRUnlockCondReleaseConn, opts) + return c.db.beginDC(ctx, dc, release, opts) } // closemuRUnlockCondReleaseConn read unlocks closemu @@ -1601,6 +1611,10 @@ func (c *Conn) closemuRUnlockCondReleaseConn(err error) { } } +func (c *Conn) txCtx() context.Context { + return nil +} + func (c *Conn) close(err error) error { if !atomic.CompareAndSwapInt32(&c.done, 0, 1) { return ErrConnDone @@ -1712,19 +1726,28 @@ func (tx *Tx) close(err error) { // a successful call to (*Tx).grabConn. For tests. var hookTxGrabConn func() -func (tx *Tx) grabConn(ctx context.Context) (*driverConn, error) { +func (tx *Tx) grabConn(ctx context.Context) (*driverConn, releaseConn, error) { select { default: case <-ctx.Done(): - return nil, ctx.Err() + return nil, nil, ctx.Err() } + + // closeme.RLock must come before the check for isDone to prevent the Tx from + // closing while a query is executing. + tx.closemu.RLock() if tx.isDone() { - return nil, ErrTxDone + tx.closemu.RUnlock() + return nil, nil, ErrTxDone } if hookTxGrabConn != nil { // test hook hookTxGrabConn() } - return tx.dc, nil + return tx.dc, tx.closemuRUnlockRelease, nil +} + +func (tx *Tx) txCtx() context.Context { + return tx.ctx } // closemuRUnlockRelease is used as a func(error) method value in @@ -1801,31 +1824,15 @@ func (tx *Tx) Rollback() error { // for the execution of the returned statement. The returned statement // will run in the transaction context. func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) { - tx.closemu.RLock() - defer tx.closemu.RUnlock() - - dc, err := tx.grabConn(ctx) + dc, release, err := tx.grabConn(ctx) if err != nil { return nil, err } - var si driver.Stmt - withLock(dc, func() { - si, err = ctxDriverPrepare(ctx, dc.ci, query) - }) + stmt, err := tx.db.prepareDC(ctx, dc, release, tx, query) if err != nil { return nil, err } - - stmt := &Stmt{ - db: tx.db, - tx: tx, - txds: &driverStmt{ - Locker: dc, - si: si, - }, - query: query, - } tx.stmts.Lock() tx.stmts.v = append(tx.stmts.v, stmt) tx.stmts.Unlock() @@ -1855,20 +1862,19 @@ func (tx *Tx) Prepare(query string) (*Stmt, error) { // The returned statement operates within the transaction and will be closed // when the transaction has been committed or rolled back. func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt { - tx.closemu.RLock() - defer tx.closemu.RUnlock() + dc, release, err := tx.grabConn(ctx) + if err != nil { + return &Stmt{stickyErr: err} + } + defer release(nil) if tx.db != stmt.db { return &Stmt{stickyErr: errors.New("sql: Tx.Stmt: statement from different database used")} } - dc, err := tx.grabConn(ctx) - if err != nil { - return &Stmt{stickyErr: err} - } var si driver.Stmt var parentStmt *Stmt stmt.mu.Lock() - if stmt.closed || stmt.tx != nil { + if stmt.closed || stmt.cg != nil { // If the statement has been closed or already belongs to a // transaction, we can't reuse it in this connection. // Since tx.StmtContext should never need to be called with a @@ -1907,8 +1913,8 @@ func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt { txs := &Stmt{ db: tx.db, - tx: tx, - txds: &driverStmt{ + cg: tx, + cgds: &driverStmt{ Locker: dc, si: si, }, @@ -1943,14 +1949,11 @@ func (tx *Tx) Stmt(stmt *Stmt) *Stmt { // ExecContext executes a query that doesn't return rows. // For example: an INSERT and UPDATE. func (tx *Tx) ExecContext(ctx context.Context, query string, args ...interface{}) (Result, error) { - tx.closemu.RLock() - - dc, err := tx.grabConn(ctx) + dc, release, err := tx.grabConn(ctx) if err != nil { - tx.closemu.RUnlock() return nil, err } - return tx.db.execDC(ctx, dc, tx.closemuRUnlockRelease, query, args) + return tx.db.execDC(ctx, dc, release, query, args) } // Exec executes a query that doesn't return rows. @@ -1961,15 +1964,12 @@ func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) { // QueryContext executes a query that returns rows, typically a SELECT. func (tx *Tx) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) { - tx.closemu.RLock() - - dc, err := tx.grabConn(ctx) + dc, release, err := tx.grabConn(ctx) if err != nil { - tx.closemu.RUnlock() return nil, err } - return tx.db.queryDC(ctx, tx.ctx, dc, tx.closemuRUnlockRelease, query, args) + return tx.db.queryDC(ctx, tx.ctx, dc, release, query, args) } // Query executes a query that returns rows, typically a SELECT. @@ -2004,6 +2004,24 @@ type connStmt struct { ds *driverStmt } +// stmtConnGrabber represents a Tx or Conn that will return the underlying +// driverConn and release function. +type stmtConnGrabber interface { + // grabConn returns the driverConn and the associated release function + // that must be called when the operation completes. + grabConn(context.Context) (*driverConn, releaseConn, error) + + // txCtx returns the transaction context if available. + // The returned context should be selected on along with + // any query context when awaiting a cancel. + txCtx() context.Context +} + +var ( + _ stmtConnGrabber = &Tx{} + _ stmtConnGrabber = &Conn{} +) + // Stmt is a prepared statement. // A Stmt is safe for concurrent use by multiple goroutines. type Stmt struct { @@ -2014,9 +2032,13 @@ type Stmt struct { closemu sync.RWMutex // held exclusively during close, for read otherwise. - // If in a transaction, else both nil: - tx *Tx - txds *driverStmt + // If Stmt is prepared on a Tx or Conn then cg is present and will + // only ever grab a connection from cg. + // If cg is nil then the Stmt must grab an arbitrary connection + // from db and determine if it must prepare the stmt again by + // inspecting css. + cg stmtConnGrabber + cgds *driverStmt // parentStmt is set when a transaction-specific statement // is requested from an identical statement prepared on the same @@ -2031,8 +2053,8 @@ type Stmt struct { // css is a list of underlying driver statement interfaces // that are valid on particular connections. This is only - // used if tx == nil and one is found that has idle - // connections. If tx != nil, txds is always used. + // used if cg == nil and one is found that has idle + // connections. If cg != nil, cgds is always used. css []connStmt // lastNumClosed is copied from db.numClosed when Stmt is created @@ -2120,7 +2142,7 @@ func (s *Stmt) removeClosedStmtLocked() { // connStmt returns a free driver connection on which to execute the // statement, a function to call to release the connection, and a // statement bound to that connection. -func (s *Stmt) connStmt(ctx context.Context, strategy connReuseStrategy) (ci *driverConn, releaseConn func(error), ds *driverStmt, err error) { +func (s *Stmt) connStmt(ctx context.Context, strategy connReuseStrategy) (dc *driverConn, releaseConn func(error), ds *driverStmt, err error) { if err = s.stickyErr; err != nil { return } @@ -2131,22 +2153,21 @@ func (s *Stmt) connStmt(ctx context.Context, strategy connReuseStrategy) (ci *dr return } - // In a transaction, we always use the connection that the - // transaction was created on. - if s.tx != nil { + // In a transaction or connection, we always use the connection that the + // the stmt was created on. + if s.cg != nil { s.mu.Unlock() - ci, err = s.tx.grabConn(ctx) // blocks, waiting for the connection. + dc, releaseConn, err = s.cg.grabConn(ctx) // blocks, waiting for the connection. if err != nil { return } - releaseConn = func(error) {} - return ci, releaseConn, s.txds, nil + return dc, releaseConn, s.cgds, nil } s.removeClosedStmtLocked() s.mu.Unlock() - dc, err := s.db.conn(ctx, strategy) + dc, err = s.db.conn(ctx, strategy) if err != nil { return nil, nil, nil, err } @@ -2175,7 +2196,7 @@ func (s *Stmt) connStmt(ctx context.Context, strategy connReuseStrategy) (ci *dr // prepareOnConnLocked prepares the query in Stmt s on dc and adds it to the list of // open connStmt on the statement. It assumes the caller is holding the lock on dc. func (s *Stmt) prepareOnConnLocked(ctx context.Context, dc *driverConn) (*driverStmt, error) { - si, err := dc.prepareLocked(ctx, s.query) + si, err := dc.prepareLocked(ctx, s.cg, s.query) if err != nil { return nil, err } @@ -2226,8 +2247,8 @@ func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, er s.db.removeDep(s, rows) } var txctx context.Context - if s.tx != nil { - txctx = s.tx.ctx + if s.cg != nil { + txctx = s.cg.txCtx() } rows.initContextClose(ctx, txctx) return rows, nil @@ -2323,9 +2344,12 @@ func (s *Stmt) Close() error { return nil } s.closed = true + txds := s.cgds + s.cgds = nil + s.mu.Unlock() - if s.tx == nil { + if s.cg == nil { return s.db.removeDep(s, s) } @@ -2334,7 +2358,7 @@ func (s *Stmt) Close() error { // in the css array of the parentStmt. return s.db.removeDep(s.parentStmt, s) } - return s.txds.Close() + return txds.Close() } func (s *Stmt) finalClose() error { diff --git a/src/database/sql/sql_test.go b/src/database/sql/sql_test.go index 7895aa0404..c935eb4348 100644 --- a/src/database/sql/sql_test.go +++ b/src/database/sql/sql_test.go @@ -877,7 +877,7 @@ func TestStatementClose(t *testing.T) { msg string }{ {&Stmt{stickyErr: want}, "stickyErr not propagated"}, - {&Stmt{tx: &Tx{}, txds: &driverStmt{Locker: &sync.Mutex{}, si: stubDriverStmt{want}}}, "driverStmt.Close() error not propagated"}, + {&Stmt{cg: &Tx{}, cgds: &driverStmt{Locker: &sync.Mutex{}, si: stubDriverStmt{want}}}, "driverStmt.Close() error not propagated"}, } for _, test := range tests { if err := test.stmt.Close(); err != want { @@ -3231,6 +3231,42 @@ func TestIssue18719(t *testing.T) { cancel() } +func TestIssue20647(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + conn, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + stmt, err := conn.PrepareContext(ctx, "SELECT|people|name|") + if err != nil { + t.Fatal(err) + } + defer stmt.Close() + + rows1, err := stmt.QueryContext(ctx) + if err != nil { + t.Fatal("rows1", err) + } + defer rows1.Close() + + rows2, err := stmt.QueryContext(ctx) + if err != nil { + t.Fatal("rows2", err) + } + defer rows2.Close() + + if rows1.dc != rows2.dc { + t.Fatal("stmt prepared on Conn does not use same connection") + } +} + func TestConcurrency(t *testing.T) { list := []struct { name string