package sql
import (
+ "context"
"database/sql/driver"
"errors"
"fmt"
return dc.createdAt.Add(timeout).Before(nowFunc())
}
-func (dc *driverConn) prepareLocked(query string) (driver.Stmt, error) {
- si, err := dc.ci.Prepare(query)
+func (dc *driverConn) prepareLocked(ctx context.Context, query string) (driver.Stmt, error) {
+ si, err := ctxDriverPrepare(ctx, dc.ci, query)
if err == nil {
// Track each driverConn's open statements, so we can close them
// before closing the conn.
return db, nil
}
-// Ping verifies a connection to the database is still alive,
+// PingContext verifies a connection to the database is still alive,
// establishing a connection if necessary.
-func (db *DB) Ping() error {
+func (db *DB) PingContext(ctx context.Context) error {
// TODO(bradfitz): give drivers an optional hook to implement
// this in a more efficient or more reliable way, if they
// have one.
- dc, err := db.conn(cachedOrNewConn)
+ dc, err := db.conn(ctx, cachedOrNewConn)
if err != nil {
return err
}
return nil
}
+// Ping verifies a connection to the database is still alive,
+// establishing a connection if necessary.
+func (db *DB) Ping() error {
+ return db.PingContext(context.Background())
+}
+
// Close closes the database, releasing any open resources.
//
// It is rare to Close a DB, as the DB handle is meant to be
var errDBClosed = errors.New("sql: database is closed")
// conn returns a newly-opened or cached *driverConn.
-func (db *DB) conn(strategy connReuseStrategy) (*driverConn, error) {
+func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn, error) {
db.mu.Lock()
if db.closed {
db.mu.Unlock()
return nil, errDBClosed
}
+ // Check if the context is expired.
+ if err := ctx.Err(); err != nil {
+ return nil, err
+ }
lifetime := db.maxLifetime
// Prefer a free connection, if possible.
req := make(chan connRequest, 1)
db.connRequests = append(db.connRequests, req)
db.mu.Unlock()
- ret, ok := <-req
- if !ok {
- return nil, errDBClosed
- }
- if ret.err == nil && ret.conn.expired(lifetime) {
- ret.conn.Close()
- return nil, driver.ErrBadConn
+
+ // Timeout the connection request with the context.
+ select {
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ case ret, ok := <-req:
+ if !ok {
+ return nil, errDBClosed
+ }
+ if ret.err == nil && ret.conn.expired(lifetime) {
+ ret.conn.Close()
+ return nil, driver.ErrBadConn
+ }
+ return ret.conn, ret.err
}
- return ret.conn, ret.err
}
db.numOpen++ // optimistically
// connection to be opened.
const maxBadConnRetries = 2
-// Prepare creates a prepared statement for later queries or executions.
+// PrepareContext creates a prepared statement for later queries or executions.
// Multiple queries or executions may be run concurrently from the
// returned statement.
// The caller must call the statement's Close method
// when the statement is no longer needed.
-func (db *DB) Prepare(query string) (*Stmt, error) {
+// Context is for the preparation of the statment, not for the execution of
+// the statement.
+func (db *DB) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
var stmt *Stmt
var err error
for i := 0; i < maxBadConnRetries; i++ {
- stmt, err = db.prepare(query, cachedOrNewConn)
+ stmt, err = db.prepare(ctx, query, cachedOrNewConn)
if err != driver.ErrBadConn {
break
}
}
if err == driver.ErrBadConn {
- return db.prepare(query, alwaysNewConn)
+ return db.prepare(ctx, query, alwaysNewConn)
}
return stmt, err
}
-func (db *DB) prepare(query string, strategy connReuseStrategy) (*Stmt, error) {
+// Prepare creates a prepared statement for later queries or executions.
+// Multiple queries or executions may be run concurrently from the
+// returned statement.
+// The caller must call the statement's Close method
+// when the statement is no longer needed.
+func (db *DB) Prepare(query string) (*Stmt, error) {
+ return db.PrepareContext(context.Background(), query)
+}
+
+func (db *DB) prepare(ctx context.Context, query string, strategy connReuseStrategy) (*Stmt, error) {
// TODO: check if db.driver supports an optional
// driver.Preparer interface and call that instead, if so,
// otherwise we make a prepared statement that's bound
// to a connection, and to execute this prepared statement
// we either need to use this connection (if it's free), else
// get a new connection + re-prepare + execute on that one.
- dc, err := db.conn(strategy)
+ dc, err := db.conn(ctx, strategy)
if err != nil {
return nil, err
}
var si driver.Stmt
withLock(dc, func() {
- si, err = dc.prepareLocked(query)
+ si, err = dc.prepareLocked(ctx, query)
})
if err != nil {
db.putConn(dc, err)
return stmt, nil
}
-// Exec executes a query without returning any rows.
+// ExecContext executes a query without returning any rows.
// The args are for any placeholder parameters in the query.
-func (db *DB) Exec(query string, args ...interface{}) (Result, error) {
+func (db *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (Result, error) {
var res Result
var err error
for i := 0; i < maxBadConnRetries; i++ {
- res, err = db.exec(query, args, cachedOrNewConn)
+ res, err = db.exec(ctx, query, args, cachedOrNewConn)
if err != driver.ErrBadConn {
break
}
}
if err == driver.ErrBadConn {
- return db.exec(query, args, alwaysNewConn)
+ return db.exec(ctx, query, args, alwaysNewConn)
}
return res, err
}
-func (db *DB) exec(query string, args []interface{}, strategy connReuseStrategy) (res Result, err error) {
- dc, err := db.conn(strategy)
+// Exec executes a query without returning any rows.
+// The args are for any placeholder parameters in the query.
+func (db *DB) Exec(query string, args ...interface{}) (Result, error) {
+ return db.ExecContext(context.Background(), query, args...)
+}
+
+func (db *DB) exec(ctx context.Context, query string, args []interface{}, strategy connReuseStrategy) (res Result, err error) {
+ dc, err := db.conn(ctx, strategy)
if err != nil {
return nil, err
}
}
var resi driver.Result
withLock(dc, func() {
- resi, err = execer.Exec(query, dargs)
+ resi, err = ctxDriverExec(ctx, execer, query, dargs)
})
if err != driver.ErrSkip {
if err != nil {
var si driver.Stmt
withLock(dc, func() {
- si, err = dc.ci.Prepare(query)
+ si, err = ctxDriverPrepare(ctx, dc.ci, query)
})
if err != nil {
return nil, err
}
defer withLock(dc, func() { si.Close() })
- return resultFromStatement(driverStmt{dc, si}, args...)
+ return resultFromStatement(ctx, driverStmt{dc, si}, args...)
}
-// Query executes a query that returns rows, typically a SELECT.
+// QueryContext executes a query that returns rows, typically a SELECT.
// The args are for any placeholder parameters in the query.
-func (db *DB) Query(query string, args ...interface{}) (*Rows, error) {
+func (db *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) {
var rows *Rows
var err error
for i := 0; i < maxBadConnRetries; i++ {
- rows, err = db.query(query, args, cachedOrNewConn)
+ rows, err = db.query(ctx, query, args, cachedOrNewConn)
if err != driver.ErrBadConn {
break
}
}
if err == driver.ErrBadConn {
- return db.query(query, args, alwaysNewConn)
+ return db.query(ctx, query, args, alwaysNewConn)
}
return rows, err
}
-func (db *DB) query(query string, args []interface{}, strategy connReuseStrategy) (*Rows, error) {
- ci, err := db.conn(strategy)
+// Query executes a query that returns rows, typically a SELECT.
+// The args are for any placeholder parameters in the query.
+func (db *DB) Query(query string, args ...interface{}) (*Rows, error) {
+ return db.QueryContext(context.Background(), query, args...)
+}
+
+func (db *DB) query(ctx context.Context, query string, args []interface{}, strategy connReuseStrategy) (*Rows, error) {
+ ci, err := db.conn(ctx, strategy)
if err != nil {
return nil, err
}
- return db.queryConn(ci, ci.releaseConn, query, args)
+ return db.queryConn(ctx, ci, ci.releaseConn, query, args)
}
// queryConn executes a query on the given connection.
// The connection gets released by the releaseConn function.
-func (db *DB) queryConn(dc *driverConn, releaseConn func(error), query string, args []interface{}) (*Rows, error) {
+func (db *DB) queryConn(ctx context.Context, dc *driverConn, releaseConn func(error), query string, args []interface{}) (*Rows, error) {
if queryer, ok := dc.ci.(driver.Queryer); ok {
dargs, err := driverArgs(nil, args)
if err != nil {
}
var rowsi driver.Rows
withLock(dc, func() {
- rowsi, err = queryer.Query(query, dargs)
+ rowsi, err = ctxDriverQuery(ctx, queryer, query, dargs)
})
if err != driver.ErrSkip {
if err != nil {
var si driver.Stmt
var err error
withLock(dc, func() {
- si, err = dc.ci.Prepare(query)
+ si, err = ctxDriverPrepare(ctx, dc.ci, query)
})
if err != nil {
releaseConn(err)
}
ds := driverStmt{dc, si}
- rowsi, err := rowsiFromStatement(ds, args...)
+ rowsi, err := rowsiFromStatement(ctx, ds, args...)
if err != nil {
withLock(dc, func() {
si.Close()
return rows, nil
}
+// QueryRowContext executes a query that is expected to return at most one row.
+// QueryRowContext always returns a non-nil value. Errors are deferred until
+// Row's Scan method is called.
+func (db *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *Row {
+ rows, err := db.QueryContext(ctx, query, args...)
+ return &Row{rows: rows, err: err}
+}
+
// QueryRow executes a query that is expected to return at most one row.
// QueryRow always returns a non-nil value. Errors are deferred until
// Row's Scan method is called.
func (db *DB) QueryRow(query string, args ...interface{}) *Row {
- rows, err := db.Query(query, args...)
- return &Row{rows: rows, err: err}
+ return db.QueryRowContext(context.Background(), query, args...)
}
-// Begin starts a transaction. The isolation level is dependent on
-// the driver.
-func (db *DB) Begin() (*Tx, error) {
+// BeginContext starts a transaction. If a non-default isolation level is used
+// that the driver doesn't support an error will be returned. Different drivers
+// may have slightly different meanings for the same isolation level.
+func (db *DB) BeginContext(ctx context.Context) (*Tx, error) {
var tx *Tx
var err error
for i := 0; i < maxBadConnRetries; i++ {
- tx, err = db.begin(cachedOrNewConn)
+ tx, err = db.begin(ctx, cachedOrNewConn)
if err != driver.ErrBadConn {
break
}
}
if err == driver.ErrBadConn {
- return db.begin(alwaysNewConn)
+ return db.begin(ctx, alwaysNewConn)
}
return tx, err
}
-func (db *DB) begin(strategy connReuseStrategy) (tx *Tx, err error) {
- dc, err := db.conn(strategy)
+// Begin starts a transaction. The default isolation level is dependent on
+// the driver.
+func (db *DB) Begin() (*Tx, error) {
+ return db.BeginContext(context.Background())
+}
+
+func (db *DB) begin(ctx context.Context, strategy connReuseStrategy) (tx *Tx, err error) {
+ dc, err := db.conn(ctx, strategy)
if err != nil {
return nil, err
}
var txi driver.Tx
withLock(dc, func() {
- txi, err = dc.ci.Begin()
+ txi, err = ctxDriverBegin(ctx, dc.ci)
})
if err != nil {
db.putConn(dc, err)
return nil, err
}
- return &Tx{
- db: db,
- dc: dc,
- txi: txi,
- }, nil
+
+ // Schedule the transaction to rollback when the context is cancelled.
+ // The cancel function in Tx will be called after done is set to true.
+ ctx, cancel := context.WithCancel(ctx)
+ tx = &Tx{
+ db: db,
+ dc: dc,
+ txi: txi,
+ cancel: cancel,
+ }
+ go func() {
+ select {
+ case <-ctx.Done():
+ if !tx.done {
+ tx.Rollback()
+ }
+ }
+ }()
+ return tx, nil
}
// Driver returns the database's underlying driver.
sync.Mutex
v []*Stmt
}
+
+ // cancel is called after done transitions from false to true.
+ cancel func()
}
// ErrTxDone is returned by any operation that is performed on a transaction
}
tx.done = true
tx.db.putConn(tx.dc, err)
+ tx.cancel()
tx.dc = nil
tx.txi = nil
}
-func (tx *Tx) grabConn() (*driverConn, error) {
+func (tx *Tx) grabConn(ctx context.Context) (*driverConn, error) {
if tx.done {
return nil, ErrTxDone
}
// be used once the transaction has been committed or rolled back.
//
// To use an existing prepared statement on this transaction, see Tx.Stmt.
-func (tx *Tx) Prepare(query string) (*Stmt, error) {
+// Context will be used for the preparation of the context, not
+// for the execution of the returned statement. The returned statement
+// will run in the transaction context.
+func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
// TODO(bradfitz): We could be more efficient here and either
// provide a method to take an existing Stmt (created on
// perhaps a different Conn), and re-create it on this Conn if
// Perhaps just looking at the reference count (by noting
// Stmt.Close) would be enough. We might also want a finalizer
// on Stmt to drop the reference count.
- dc, err := tx.grabConn()
+ dc, err := tx.grabConn(ctx)
if err != nil {
return nil, err
}
return stmt, nil
}
-// Stmt returns a transaction-specific prepared statement from
+// Prepare creates a prepared statement for use within a transaction.
+//
+// The returned statement operates within the transaction and can no longer
+// be used once the transaction has been committed or rolled back.
+//
+// To use an existing prepared statement on this transaction, see Tx.Stmt.
+func (tx *Tx) Prepare(query string) (*Stmt, error) {
+ return tx.PrepareContext(context.Background(), query)
+}
+
+// StmtContext returns a transaction-specific prepared statement from
// an existing statement.
//
// Example:
// ...
// tx, err := db.Begin()
// ...
-// res, err := tx.Stmt(updateMoney).Exec(123.45, 98293203)
+// res, err := tx.StmtContext(ctx, updateMoney).Exec(123.45, 98293203)
//
// The returned statement operates within the transaction and can no longer
// be used once the transaction has been committed or rolled back.
-func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
+func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt {
// TODO(bradfitz): optimize this. Currently this re-prepares
// each time. This is fine for now to illustrate the API but
// we should really cache already-prepared statements
if tx.db != stmt.db {
return &Stmt{stickyErr: errors.New("sql: Tx.Stmt: statement from different database used")}
}
- dc, err := tx.grabConn()
+ dc, err := tx.grabConn(ctx)
if err != nil {
return &Stmt{stickyErr: err}
}
return txs
}
-// Exec executes a query that doesn't return rows.
+// Stmt returns a transaction-specific prepared statement from
+// an existing statement.
+//
+// Example:
+// updateMoney, err := db.Prepare("UPDATE balance SET money=money+? WHERE id=?")
+// ...
+// tx, err := db.Begin()
+// ...
+// res, err := tx.Stmt(updateMoney).Exec(123.45, 98293203)
+//
+// The returned statement operates within the transaction and can no longer
+// be used once the transaction has been committed or rolled back.
+func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
+ return tx.StmtContext(context.Background(), stmt)
+}
+
+// ExecContext executes a query that doesn't return rows.
// For example: an INSERT and UPDATE.
-func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) {
- dc, err := tx.grabConn()
+func (tx *Tx) ExecContext(ctx context.Context, query string, args ...interface{}) (Result, error) {
+ dc, err := tx.grabConn(ctx)
if err != nil {
return nil, err
}
}
defer withLock(dc, func() { si.Close() })
- return resultFromStatement(driverStmt{dc, si}, args...)
+ return resultFromStatement(ctx, driverStmt{dc, si}, args...)
}
-// Query executes a query that returns rows, typically a SELECT.
-func (tx *Tx) Query(query string, args ...interface{}) (*Rows, error) {
- dc, err := tx.grabConn()
+// Exec executes a query that doesn't return rows.
+// For example: an INSERT and UPDATE.
+func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) {
+ return tx.ExecContext(context.Background(), query, args...)
+}
+
+// QueryContext executes a query that returns rows, typically a SELECT.
+func (tx *Tx) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) {
+ dc, err := tx.grabConn(ctx)
if err != nil {
return nil, err
}
releaseConn := func(error) {}
- return tx.db.queryConn(dc, releaseConn, query, args)
+ return tx.db.queryConn(ctx, dc, releaseConn, query, args)
+}
+
+// Query executes a query that returns rows, typically a SELECT.
+func (tx *Tx) Query(query string, args ...interface{}) (*Rows, error) {
+ return tx.QueryContext(context.Background(), query, args...)
+}
+
+// QueryRowContext executes a query that is expected to return at most one row.
+// QueryRowContext always returns a non-nil value. Errors are deferred until
+// Row's Scan method is called.
+func (tx *Tx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *Row {
+ rows, err := tx.QueryContext(ctx, query, args...)
+ return &Row{rows: rows, err: err}
}
// QueryRow executes a query that is expected to return at most one row.
// QueryRow always returns a non-nil value. Errors are deferred until
// Row's Scan method is called.
func (tx *Tx) QueryRow(query string, args ...interface{}) *Row {
- rows, err := tx.Query(query, args...)
- return &Row{rows: rows, err: err}
+ return tx.QueryRowContext(context.Background(), query, args...)
}
// connStmt is a prepared statement on a particular connection.
lastNumClosed uint64
}
-// Exec executes a prepared statement with the given arguments and
+// ExecContext executes a prepared statement with the given arguments and
// returns a Result summarizing the effect of the statement.
-func (s *Stmt) Exec(args ...interface{}) (Result, error) {
+func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (Result, error) {
s.closemu.RLock()
defer s.closemu.RUnlock()
var res Result
for i := 0; i < maxBadConnRetries; i++ {
- dc, releaseConn, si, err := s.connStmt()
+ dc, releaseConn, si, err := s.connStmt(ctx)
if err != nil {
if err == driver.ErrBadConn {
continue
return nil, err
}
- res, err = resultFromStatement(driverStmt{dc, si}, args...)
+ res, err = resultFromStatement(ctx, driverStmt{dc, si}, args...)
releaseConn(err)
if err != driver.ErrBadConn {
return res, err
return nil, driver.ErrBadConn
}
+// Exec executes a prepared statement with the given arguments and
+// returns a Result summarizing the effect of the statement.
+func (s *Stmt) Exec(args ...interface{}) (Result, error) {
+ return s.ExecContext(context.Background(), args...)
+}
+
func driverNumInput(ds driverStmt) int {
ds.Lock()
defer ds.Unlock() // in case NumInput panics
return ds.si.NumInput()
}
-func resultFromStatement(ds driverStmt, args ...interface{}) (Result, error) {
+func resultFromStatement(ctx context.Context, ds driverStmt, args ...interface{}) (Result, error) {
want := driverNumInput(ds)
// -1 means the driver doesn't know how to count the number of
ds.Lock()
defer ds.Unlock()
- resi, err := ds.si.Exec(dargs)
+
+ resi, err := ctxDriverStmtExec(ctx, ds.si, dargs)
if err != nil {
return nil, err
}
// connStmt returns a free driver connection on which to execute the
// statement, a function to call to release the connection, and a
// statement bound to that connection.
-func (s *Stmt) connStmt() (ci *driverConn, releaseConn func(error), si driver.Stmt, err error) {
+func (s *Stmt) connStmt(ctx context.Context) (ci *driverConn, releaseConn func(error), si driver.Stmt, err error) {
if err = s.stickyErr; err != nil {
return
}
// transaction was created on.
if s.tx != nil {
s.mu.Unlock()
- ci, err = s.tx.grabConn() // blocks, waiting for the connection.
+ ci, err = s.tx.grabConn(ctx) // blocks, waiting for the connection.
if err != nil {
return
}
s.removeClosedStmtLocked()
s.mu.Unlock()
- // TODO(bradfitz): or always wait for one? make configurable later?
- dc, err := s.db.conn(cachedOrNewConn)
+ dc, err := s.db.conn(ctx, cachedOrNewConn)
if err != nil {
return nil, nil, nil, err
}
// No luck; we need to prepare the statement on this connection
withLock(dc, func() {
- si, err = dc.prepareLocked(s.query)
+ si, err = dc.prepareLocked(ctx, s.query)
})
if err != nil {
s.db.putConn(dc, err)
return dc, dc.releaseConn, si, nil
}
-// Query executes a prepared query statement with the given arguments
+// QueryContext executes a prepared query statement with the given arguments
// and returns the query results as a *Rows.
-func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
+func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, error) {
s.closemu.RLock()
defer s.closemu.RUnlock()
var rowsi driver.Rows
for i := 0; i < maxBadConnRetries; i++ {
- dc, releaseConn, si, err := s.connStmt()
+ dc, releaseConn, si, err := s.connStmt(ctx)
if err != nil {
if err == driver.ErrBadConn {
continue
return nil, err
}
- rowsi, err = rowsiFromStatement(driverStmt{dc, si}, args...)
+ rowsi, err = rowsiFromStatement(ctx, driverStmt{dc, si}, args...)
if err == nil {
// Note: ownership of ci passes to the *Rows, to be freed
// with releaseConn.
return nil, driver.ErrBadConn
}
-func rowsiFromStatement(ds driverStmt, args ...interface{}) (driver.Rows, error) {
+// Query executes a prepared query statement with the given arguments
+// and returns the query results as a *Rows.
+func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
+ return s.QueryContext(context.Background(), args...)
+}
+
+func rowsiFromStatement(ctx context.Context, ds driverStmt, args ...interface{}) (driver.Rows, error) {
var want int
withLock(ds, func() {
want = ds.si.NumInput()
ds.Lock()
defer ds.Unlock()
- rowsi, err := ds.si.Query(dargs)
+
+ rowsi, err := ctxDriverStmtQuery(ctx, ds.si, dargs)
if err != nil {
return nil, err
}
return rowsi, nil
}
-// QueryRow executes a prepared query statement with the given arguments.
+// QueryRowContext executes a prepared query statement with the given arguments.
// If an error occurs during the execution of the statement, that error will
// be returned by a call to Scan on the returned *Row, which is always non-nil.
// If the query selects no rows, the *Row's Scan will return ErrNoRows.
// Example usage:
//
// var name string
-// err := nameByUseridStmt.QueryRow(id).Scan(&name)
-func (s *Stmt) QueryRow(args ...interface{}) *Row {
- rows, err := s.Query(args...)
+// err := nameByUseridStmt.QueryRowContext(ctx, id).Scan(&name)
+func (s *Stmt) QueryRowContext(ctx context.Context, args ...interface{}) *Row {
+ rows, err := s.QueryContext(ctx, args...)
if err != nil {
return &Row{err: err}
}
return &Row{rows: rows}
}
+// QueryRow executes a prepared query statement with the given arguments.
+// If an error occurs during the execution of the statement, that error will
+// be returned by a call to Scan on the returned *Row, which is always non-nil.
+// If the query selects no rows, the *Row's Scan will return ErrNoRows.
+// Otherwise, the *Row's Scan scans the first selected row and discards
+// the rest.
+//
+// Example usage:
+//
+// var name string
+// err := nameByUseridStmt.QueryRow(id).Scan(&name)
+func (s *Stmt) QueryRow(args ...interface{}) *Row {
+ return s.QueryRowContext(context.Background(), args...)
+}
+
// Close closes the statement.
func (s *Stmt) Close() error {
s.closemu.Lock()