From e13df02e5fbf2c0cd8811b826a8c8567efa882dd Mon Sep 17 00:00:00 2001 From: Daniel Theophanes Date: Mon, 19 Sep 2016 11:19:32 -0700 Subject: [PATCH] database/sql: add context methods Add context methods to sql and sql/driver methods. If the driver doesn't implement context methods the connection pool will still handle timeouts when a query fails to return in time or when a connection is not available from the pool in time. There will be a follow-up CL that will add support for context values that specify transaction levels and modes that a driver can use. Fixes #15123 Change-Id: Ia99f3957aa3f177b23044dd99d4ec217491a30a7 Reviewed-on: https://go-review.googlesource.com/29381 Reviewed-by: Brad Fitzpatrick Run-TryBot: Brad Fitzpatrick TryBot-Result: Gobot Gobot --- src/database/sql/ctxutil.go | 231 ++++++++++++++++++++ src/database/sql/driver/driver.go | 46 +++- src/database/sql/sql.go | 335 +++++++++++++++++++++--------- src/database/sql/sql_test.go | 9 +- src/go/build/deps_test.go | 4 +- 5 files changed, 525 insertions(+), 100 deletions(-) create mode 100644 src/database/sql/ctxutil.go diff --git a/src/database/sql/ctxutil.go b/src/database/sql/ctxutil.go new file mode 100644 index 0000000000..65e1652657 --- /dev/null +++ b/src/database/sql/ctxutil.go @@ -0,0 +1,231 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package sql + +import ( + "context" + "database/sql/driver" + "errors" +) + +func ctxDriverPrepare(ctx context.Context, ci driver.Conn, query string) (driver.Stmt, error) { + if ciCtx, is := ci.(driver.ConnPrepareContext); is { + return ciCtx.PrepareContext(ctx, query) + } + type R struct { + err error + panic interface{} + si driver.Stmt + } + + rc := make(chan R, 1) + go func() { + r := R{} + defer func() { + if v := recover(); v != nil { + r.panic = v + } + rc <- r + }() + r.si, r.err = ci.Prepare(query) + }() + select { + case <-ctx.Done(): + go func() { + <-rc + close(rc) + }() + return nil, ctx.Err() + case r := <-rc: + if r.panic != nil { + panic(r.panic) + } + return r.si, r.err + } +} + +func ctxDriverExec(ctx context.Context, execer driver.Execer, query string, dargs []driver.Value) (driver.Result, error) { + if execerCtx, is := execer.(driver.ExecerContext); is { + return execerCtx.ExecContext(ctx, query, dargs) + } + type R struct { + err error + panic interface{} + resi driver.Result + } + + rc := make(chan R, 1) + go func() { + r := R{} + defer func() { + if v := recover(); v != nil { + r.panic = v + } + rc <- r + }() + r.resi, r.err = execer.Exec(query, dargs) + }() + select { + case <-ctx.Done(): + go func() { + <-rc + close(rc) + }() + return nil, ctx.Err() + case r := <-rc: + if r.panic != nil { + panic(r.panic) + } + return r.resi, r.err + } +} + +func ctxDriverQuery(ctx context.Context, queryer driver.Queryer, query string, dargs []driver.Value) (driver.Rows, error) { + if queryerCtx, is := queryer.(driver.QueryerContext); is { + return queryerCtx.QueryContext(ctx, query, dargs) + } + type R struct { + err error + panic interface{} + rowsi driver.Rows + } + + rc := make(chan R, 1) + go func() { + r := R{} + defer func() { + if v := recover(); v != nil { + r.panic = v + } + rc <- r + }() + r.rowsi, r.err = queryer.Query(query, dargs) + }() + select { + case <-ctx.Done(): + go func() { + <-rc + close(rc) + }() + return nil, ctx.Err() + case r := <-rc: + if r.panic != nil { + panic(r.panic) + } + return r.rowsi, r.err + } +} + +func ctxDriverStmtExec(ctx context.Context, si driver.Stmt, dargs []driver.Value) (driver.Result, error) { + if siCtx, is := si.(driver.StmtExecContext); is { + return siCtx.ExecContext(ctx, dargs) + } + type R struct { + err error + panic interface{} + resi driver.Result + } + + rc := make(chan R, 1) + go func() { + r := R{} + defer func() { + if v := recover(); v != nil { + r.panic = v + } + rc <- r + }() + r.resi, r.err = si.Exec(dargs) + }() + select { + case <-ctx.Done(): + go func() { + <-rc + close(rc) + }() + return nil, ctx.Err() + case r := <-rc: + if r.panic != nil { + panic(r.panic) + } + return r.resi, r.err + } +} + +func ctxDriverStmtQuery(ctx context.Context, si driver.Stmt, dargs []driver.Value) (driver.Rows, error) { + if siCtx, is := si.(driver.StmtQueryContext); is { + return siCtx.QueryContext(ctx, dargs) + } + type R struct { + err error + panic interface{} + rowsi driver.Rows + } + + rc := make(chan R, 1) + go func() { + r := R{} + defer func() { + if v := recover(); v != nil { + r.panic = v + } + rc <- r + }() + r.rowsi, r.err = si.Query(dargs) + }() + select { + case <-ctx.Done(): + go func() { + <-rc + close(rc) + }() + return nil, ctx.Err() + case r := <-rc: + if r.panic != nil { + panic(r.panic) + } + return r.rowsi, r.err + } +} + +var errLevelNotSupported = errors.New("sql: selected isolation level is not supported") + +func ctxDriverBegin(ctx context.Context, ci driver.Conn) (driver.Tx, error) { + if ciCtx, is := ci.(driver.ConnBeginContext); is { + return ciCtx.BeginContext(ctx) + } + // TODO(kardianos): check the transaction level in ctx. If set and non-default + // then return an error here as the BeginContext driver value is not supported. + + type R struct { + err error + panic interface{} + txi driver.Tx + } + rc := make(chan R, 1) + go func() { + r := R{} + defer func() { + if v := recover(); v != nil { + r.panic = v + } + rc <- r + }() + r.txi, r.err = ci.Begin() + }() + select { + case <-ctx.Done(): + go func() { + <-rc + close(rc) + }() + return nil, ctx.Err() + case r := <-rc: + if r.panic != nil { + panic(r.panic) + } + return r.txi, r.err + } +} diff --git a/src/database/sql/driver/driver.go b/src/database/sql/driver/driver.go index 4dba85a6d3..ccc283d373 100644 --- a/src/database/sql/driver/driver.go +++ b/src/database/sql/driver/driver.go @@ -8,7 +8,10 @@ // Most code should use package sql. package driver -import "errors" +import ( + "context" + "errors" +) // Value is a value that drivers must be able to handle. // It is either nil or an instance of one of these types: @@ -65,6 +68,12 @@ type Execer interface { Exec(query string, args []Value) (Result, error) } +// ExecerContext is like execer, but must honor the context timeout and return +// when the context is cancelled. +type ExecerContext interface { + ExecContext(ctx context.Context, query string, args []Value) (Result, error) +} + // Queryer is an optional interface that may be implemented by a Conn. // // If a Conn does not implement Queryer, the sql package's DB.Query will @@ -76,6 +85,12 @@ type Queryer interface { Query(query string, args []Value) (Rows, error) } +// QueryerContext is like Queryer, but most honor the context timeout and return +// when the context is cancelled. +type QueryerContext interface { + QueryContext(ctx context.Context, query string, args []Value) (Rows, error) +} + // Conn is a connection to a database. It is not used concurrently // by multiple goroutines. // @@ -98,6 +113,23 @@ type Conn interface { Begin() (Tx, error) } +// ConnPrepareContext enhances the Conn interface with context. +type ConnPrepareContext interface { + // PrepareContext returns a prepared statement, bound to this connection. + // context is for the preparation of the statement, + // it must not store the context within the statement itself. + PrepareContext(ctx context.Context, query string) (Stmt, error) +} + +// ConnBeginContext enhances the Conn interface with context. +type ConnBeginContext interface { + // BeginContext starts and returns a new transaction. + // the provided context should be used to roll the transaction back + // if it is cancelled. If there is an isolation level in context + // that is not supported by the driver an error must be returned. + BeginContext(ctx context.Context) (Tx, error) +} + // Result is the result of a query execution. type Result interface { // LastInsertId returns the database's auto-generated ID @@ -139,6 +171,18 @@ type Stmt interface { Query(args []Value) (Rows, error) } +// StmtExecContext enhances the Stmt interface by providing Exec with context. +type StmtExecContext interface { + // ExecContext must honor the context timeout and return when it is cancelled. + ExecContext(ctx context.Context, args []Value) (Result, error) +} + +// StmtQueryContext enhances the Stmt interface by providing Query with context. +type StmtQueryContext interface { + // QueryContext must honor the context timeout and return when it is cancelled. + QueryContext(ctx context.Context, args []Value) (Rows, error) +} + // ColumnConverter may be optionally implemented by Stmt if the // statement is aware of its own columns' types and can convert from // any type to a driver Value. diff --git a/src/database/sql/sql.go b/src/database/sql/sql.go index 1e09a313ac..4c44e2b6f4 100644 --- a/src/database/sql/sql.go +++ b/src/database/sql/sql.go @@ -13,6 +13,7 @@ package sql import ( + "context" "database/sql/driver" "errors" "fmt" @@ -297,8 +298,8 @@ func (dc *driverConn) expired(timeout time.Duration) bool { return dc.createdAt.Add(timeout).Before(nowFunc()) } -func (dc *driverConn) prepareLocked(query string) (driver.Stmt, error) { - si, err := dc.ci.Prepare(query) +func (dc *driverConn) prepareLocked(ctx context.Context, query string) (driver.Stmt, error) { + si, err := ctxDriverPrepare(ctx, dc.ci, query) if err == nil { // Track each driverConn's open statements, so we can close them // before closing the conn. @@ -494,13 +495,13 @@ func Open(driverName, dataSourceName string) (*DB, error) { return db, nil } -// Ping verifies a connection to the database is still alive, +// PingContext verifies a connection to the database is still alive, // establishing a connection if necessary. -func (db *DB) Ping() error { +func (db *DB) PingContext(ctx context.Context) error { // TODO(bradfitz): give drivers an optional hook to implement // this in a more efficient or more reliable way, if they // have one. - dc, err := db.conn(cachedOrNewConn) + dc, err := db.conn(ctx, cachedOrNewConn) if err != nil { return err } @@ -508,6 +509,12 @@ func (db *DB) Ping() error { return nil } +// Ping verifies a connection to the database is still alive, +// establishing a connection if necessary. +func (db *DB) Ping() error { + return db.PingContext(context.Background()) +} + // Close closes the database, releasing any open resources. // // It is rare to Close a DB, as the DB handle is meant to be @@ -777,12 +784,16 @@ type connRequest struct { var errDBClosed = errors.New("sql: database is closed") // conn returns a newly-opened or cached *driverConn. -func (db *DB) conn(strategy connReuseStrategy) (*driverConn, error) { +func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn, error) { db.mu.Lock() if db.closed { db.mu.Unlock() return nil, errDBClosed } + // Check if the context is expired. + if err := ctx.Err(); err != nil { + return nil, err + } lifetime := db.maxLifetime // Prefer a free connection, if possible. @@ -808,15 +819,21 @@ func (db *DB) conn(strategy connReuseStrategy) (*driverConn, error) { req := make(chan connRequest, 1) db.connRequests = append(db.connRequests, req) db.mu.Unlock() - ret, ok := <-req - if !ok { - return nil, errDBClosed - } - if ret.err == nil && ret.conn.expired(lifetime) { - ret.conn.Close() - return nil, driver.ErrBadConn + + // Timeout the connection request with the context. + select { + case <-ctx.Done(): + return nil, ctx.Err() + case ret, ok := <-req: + if !ok { + return nil, errDBClosed + } + if ret.err == nil && ret.conn.expired(lifetime) { + ret.conn.Close() + return nil, driver.ErrBadConn + } + return ret.conn, ret.err } - return ret.conn, ret.err } db.numOpen++ // optimistically @@ -952,40 +969,51 @@ func (db *DB) putConnDBLocked(dc *driverConn, err error) bool { // connection to be opened. const maxBadConnRetries = 2 -// Prepare creates a prepared statement for later queries or executions. +// PrepareContext creates a prepared statement for later queries or executions. // Multiple queries or executions may be run concurrently from the // returned statement. // The caller must call the statement's Close method // when the statement is no longer needed. -func (db *DB) Prepare(query string) (*Stmt, error) { +// Context is for the preparation of the statment, not for the execution of +// the statement. +func (db *DB) PrepareContext(ctx context.Context, query string) (*Stmt, error) { var stmt *Stmt var err error for i := 0; i < maxBadConnRetries; i++ { - stmt, err = db.prepare(query, cachedOrNewConn) + stmt, err = db.prepare(ctx, query, cachedOrNewConn) if err != driver.ErrBadConn { break } } if err == driver.ErrBadConn { - return db.prepare(query, alwaysNewConn) + return db.prepare(ctx, query, alwaysNewConn) } return stmt, err } -func (db *DB) prepare(query string, strategy connReuseStrategy) (*Stmt, error) { +// Prepare creates a prepared statement for later queries or executions. +// Multiple queries or executions may be run concurrently from the +// returned statement. +// The caller must call the statement's Close method +// when the statement is no longer needed. +func (db *DB) Prepare(query string) (*Stmt, error) { + return db.PrepareContext(context.Background(), query) +} + +func (db *DB) prepare(ctx context.Context, query string, strategy connReuseStrategy) (*Stmt, error) { // TODO: check if db.driver supports an optional // driver.Preparer interface and call that instead, if so, // otherwise we make a prepared statement that's bound // to a connection, and to execute this prepared statement // we either need to use this connection (if it's free), else // get a new connection + re-prepare + execute on that one. - dc, err := db.conn(strategy) + dc, err := db.conn(ctx, strategy) if err != nil { return nil, err } var si driver.Stmt withLock(dc, func() { - si, err = dc.prepareLocked(query) + si, err = dc.prepareLocked(ctx, query) }) if err != nil { db.putConn(dc, err) @@ -1002,25 +1030,31 @@ func (db *DB) prepare(query string, strategy connReuseStrategy) (*Stmt, error) { return stmt, nil } -// Exec executes a query without returning any rows. +// ExecContext executes a query without returning any rows. // The args are for any placeholder parameters in the query. -func (db *DB) Exec(query string, args ...interface{}) (Result, error) { +func (db *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (Result, error) { var res Result var err error for i := 0; i < maxBadConnRetries; i++ { - res, err = db.exec(query, args, cachedOrNewConn) + res, err = db.exec(ctx, query, args, cachedOrNewConn) if err != driver.ErrBadConn { break } } if err == driver.ErrBadConn { - return db.exec(query, args, alwaysNewConn) + return db.exec(ctx, query, args, alwaysNewConn) } return res, err } -func (db *DB) exec(query string, args []interface{}, strategy connReuseStrategy) (res Result, err error) { - dc, err := db.conn(strategy) +// Exec executes a query without returning any rows. +// The args are for any placeholder parameters in the query. +func (db *DB) Exec(query string, args ...interface{}) (Result, error) { + return db.ExecContext(context.Background(), query, args...) +} + +func (db *DB) exec(ctx context.Context, query string, args []interface{}, strategy connReuseStrategy) (res Result, err error) { + dc, err := db.conn(ctx, strategy) if err != nil { return nil, err } @@ -1036,7 +1070,7 @@ func (db *DB) exec(query string, args []interface{}, strategy connReuseStrategy) } var resi driver.Result withLock(dc, func() { - resi, err = execer.Exec(query, dargs) + resi, err = ctxDriverExec(ctx, execer, query, dargs) }) if err != driver.ErrSkip { if err != nil { @@ -1048,44 +1082,50 @@ func (db *DB) exec(query string, args []interface{}, strategy connReuseStrategy) var si driver.Stmt withLock(dc, func() { - si, err = dc.ci.Prepare(query) + si, err = ctxDriverPrepare(ctx, dc.ci, query) }) if err != nil { return nil, err } defer withLock(dc, func() { si.Close() }) - return resultFromStatement(driverStmt{dc, si}, args...) + return resultFromStatement(ctx, driverStmt{dc, si}, args...) } -// Query executes a query that returns rows, typically a SELECT. +// QueryContext executes a query that returns rows, typically a SELECT. // The args are for any placeholder parameters in the query. -func (db *DB) Query(query string, args ...interface{}) (*Rows, error) { +func (db *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) { var rows *Rows var err error for i := 0; i < maxBadConnRetries; i++ { - rows, err = db.query(query, args, cachedOrNewConn) + rows, err = db.query(ctx, query, args, cachedOrNewConn) if err != driver.ErrBadConn { break } } if err == driver.ErrBadConn { - return db.query(query, args, alwaysNewConn) + return db.query(ctx, query, args, alwaysNewConn) } return rows, err } -func (db *DB) query(query string, args []interface{}, strategy connReuseStrategy) (*Rows, error) { - ci, err := db.conn(strategy) +// Query executes a query that returns rows, typically a SELECT. +// The args are for any placeholder parameters in the query. +func (db *DB) Query(query string, args ...interface{}) (*Rows, error) { + return db.QueryContext(context.Background(), query, args...) +} + +func (db *DB) query(ctx context.Context, query string, args []interface{}, strategy connReuseStrategy) (*Rows, error) { + ci, err := db.conn(ctx, strategy) if err != nil { return nil, err } - return db.queryConn(ci, ci.releaseConn, query, args) + return db.queryConn(ctx, ci, ci.releaseConn, query, args) } // queryConn executes a query on the given connection. // The connection gets released by the releaseConn function. -func (db *DB) queryConn(dc *driverConn, releaseConn func(error), query string, args []interface{}) (*Rows, error) { +func (db *DB) queryConn(ctx context.Context, dc *driverConn, releaseConn func(error), query string, args []interface{}) (*Rows, error) { if queryer, ok := dc.ci.(driver.Queryer); ok { dargs, err := driverArgs(nil, args) if err != nil { @@ -1094,7 +1134,7 @@ func (db *DB) queryConn(dc *driverConn, releaseConn func(error), query string, a } var rowsi driver.Rows withLock(dc, func() { - rowsi, err = queryer.Query(query, dargs) + rowsi, err = ctxDriverQuery(ctx, queryer, query, dargs) }) if err != driver.ErrSkip { if err != nil { @@ -1115,7 +1155,7 @@ func (db *DB) queryConn(dc *driverConn, releaseConn func(error), query string, a var si driver.Stmt var err error withLock(dc, func() { - si, err = dc.ci.Prepare(query) + si, err = ctxDriverPrepare(ctx, dc.ci, query) }) if err != nil { releaseConn(err) @@ -1123,7 +1163,7 @@ func (db *DB) queryConn(dc *driverConn, releaseConn func(error), query string, a } ds := driverStmt{dc, si} - rowsi, err := rowsiFromStatement(ds, args...) + rowsi, err := rowsiFromStatement(ctx, ds, args...) if err != nil { withLock(dc, func() { si.Close() @@ -1143,49 +1183,77 @@ func (db *DB) queryConn(dc *driverConn, releaseConn func(error), query string, a return rows, nil } +// QueryRowContext executes a query that is expected to return at most one row. +// QueryRowContext always returns a non-nil value. Errors are deferred until +// Row's Scan method is called. +func (db *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *Row { + rows, err := db.QueryContext(ctx, query, args...) + return &Row{rows: rows, err: err} +} + // QueryRow executes a query that is expected to return at most one row. // QueryRow always returns a non-nil value. Errors are deferred until // Row's Scan method is called. func (db *DB) QueryRow(query string, args ...interface{}) *Row { - rows, err := db.Query(query, args...) - return &Row{rows: rows, err: err} + return db.QueryRowContext(context.Background(), query, args...) } -// Begin starts a transaction. The isolation level is dependent on -// the driver. -func (db *DB) Begin() (*Tx, error) { +// BeginContext starts a transaction. If a non-default isolation level is used +// that the driver doesn't support an error will be returned. Different drivers +// may have slightly different meanings for the same isolation level. +func (db *DB) BeginContext(ctx context.Context) (*Tx, error) { var tx *Tx var err error for i := 0; i < maxBadConnRetries; i++ { - tx, err = db.begin(cachedOrNewConn) + tx, err = db.begin(ctx, cachedOrNewConn) if err != driver.ErrBadConn { break } } if err == driver.ErrBadConn { - return db.begin(alwaysNewConn) + return db.begin(ctx, alwaysNewConn) } return tx, err } -func (db *DB) begin(strategy connReuseStrategy) (tx *Tx, err error) { - dc, err := db.conn(strategy) +// Begin starts a transaction. The default isolation level is dependent on +// the driver. +func (db *DB) Begin() (*Tx, error) { + return db.BeginContext(context.Background()) +} + +func (db *DB) begin(ctx context.Context, strategy connReuseStrategy) (tx *Tx, err error) { + dc, err := db.conn(ctx, strategy) if err != nil { return nil, err } var txi driver.Tx withLock(dc, func() { - txi, err = dc.ci.Begin() + txi, err = ctxDriverBegin(ctx, dc.ci) }) if err != nil { db.putConn(dc, err) return nil, err } - return &Tx{ - db: db, - dc: dc, - txi: txi, - }, nil + + // Schedule the transaction to rollback when the context is cancelled. + // The cancel function in Tx will be called after done is set to true. + ctx, cancel := context.WithCancel(ctx) + tx = &Tx{ + db: db, + dc: dc, + txi: txi, + cancel: cancel, + } + go func() { + select { + case <-ctx.Done(): + if !tx.done { + tx.Rollback() + } + } + }() + return tx, nil } // Driver returns the database's underlying driver. @@ -1222,6 +1290,9 @@ type Tx struct { sync.Mutex v []*Stmt } + + // cancel is called after done transitions from false to true. + cancel func() } // ErrTxDone is returned by any operation that is performed on a transaction @@ -1234,11 +1305,12 @@ func (tx *Tx) close(err error) { } tx.done = true tx.db.putConn(tx.dc, err) + tx.cancel() tx.dc = nil tx.txi = nil } -func (tx *Tx) grabConn() (*driverConn, error) { +func (tx *Tx) grabConn(ctx context.Context) (*driverConn, error) { if tx.done { return nil, ErrTxDone } @@ -1292,7 +1364,10 @@ func (tx *Tx) Rollback() error { // be used once the transaction has been committed or rolled back. // // To use an existing prepared statement on this transaction, see Tx.Stmt. -func (tx *Tx) Prepare(query string) (*Stmt, error) { +// Context will be used for the preparation of the context, not +// for the execution of the returned statement. The returned statement +// will run in the transaction context. +func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) { // TODO(bradfitz): We could be more efficient here and either // provide a method to take an existing Stmt (created on // perhaps a different Conn), and re-create it on this Conn if @@ -1306,7 +1381,7 @@ func (tx *Tx) Prepare(query string) (*Stmt, error) { // Perhaps just looking at the reference count (by noting // Stmt.Close) would be enough. We might also want a finalizer // on Stmt to drop the reference count. - dc, err := tx.grabConn() + dc, err := tx.grabConn(ctx) if err != nil { return nil, err } @@ -1334,7 +1409,17 @@ func (tx *Tx) Prepare(query string) (*Stmt, error) { return stmt, nil } -// Stmt returns a transaction-specific prepared statement from +// Prepare creates a prepared statement for use within a transaction. +// +// The returned statement operates within the transaction and can no longer +// be used once the transaction has been committed or rolled back. +// +// To use an existing prepared statement on this transaction, see Tx.Stmt. +func (tx *Tx) Prepare(query string) (*Stmt, error) { + return tx.PrepareContext(context.Background(), query) +} + +// StmtContext returns a transaction-specific prepared statement from // an existing statement. // // Example: @@ -1342,11 +1427,11 @@ func (tx *Tx) Prepare(query string) (*Stmt, error) { // ... // tx, err := db.Begin() // ... -// res, err := tx.Stmt(updateMoney).Exec(123.45, 98293203) +// res, err := tx.StmtContext(ctx, updateMoney).Exec(123.45, 98293203) // // The returned statement operates within the transaction and can no longer // be used once the transaction has been committed or rolled back. -func (tx *Tx) Stmt(stmt *Stmt) *Stmt { +func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt { // TODO(bradfitz): optimize this. Currently this re-prepares // each time. This is fine for now to illustrate the API but // we should really cache already-prepared statements @@ -1355,7 +1440,7 @@ func (tx *Tx) Stmt(stmt *Stmt) *Stmt { if tx.db != stmt.db { return &Stmt{stickyErr: errors.New("sql: Tx.Stmt: statement from different database used")} } - dc, err := tx.grabConn() + dc, err := tx.grabConn(ctx) if err != nil { return &Stmt{stickyErr: err} } @@ -1379,10 +1464,26 @@ func (tx *Tx) Stmt(stmt *Stmt) *Stmt { return txs } -// Exec executes a query that doesn't return rows. +// Stmt returns a transaction-specific prepared statement from +// an existing statement. +// +// Example: +// updateMoney, err := db.Prepare("UPDATE balance SET money=money+? WHERE id=?") +// ... +// tx, err := db.Begin() +// ... +// res, err := tx.Stmt(updateMoney).Exec(123.45, 98293203) +// +// The returned statement operates within the transaction and can no longer +// be used once the transaction has been committed or rolled back. +func (tx *Tx) Stmt(stmt *Stmt) *Stmt { + return tx.StmtContext(context.Background(), stmt) +} + +// ExecContext executes a query that doesn't return rows. // For example: an INSERT and UPDATE. -func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) { - dc, err := tx.grabConn() +func (tx *Tx) ExecContext(ctx context.Context, query string, args ...interface{}) (Result, error) { + dc, err := tx.grabConn(ctx) if err != nil { return nil, err } @@ -1413,25 +1514,43 @@ func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) { } defer withLock(dc, func() { si.Close() }) - return resultFromStatement(driverStmt{dc, si}, args...) + return resultFromStatement(ctx, driverStmt{dc, si}, args...) } -// Query executes a query that returns rows, typically a SELECT. -func (tx *Tx) Query(query string, args ...interface{}) (*Rows, error) { - dc, err := tx.grabConn() +// Exec executes a query that doesn't return rows. +// For example: an INSERT and UPDATE. +func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) { + return tx.ExecContext(context.Background(), query, args...) +} + +// QueryContext executes a query that returns rows, typically a SELECT. +func (tx *Tx) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) { + dc, err := tx.grabConn(ctx) if err != nil { return nil, err } releaseConn := func(error) {} - return tx.db.queryConn(dc, releaseConn, query, args) + return tx.db.queryConn(ctx, dc, releaseConn, query, args) +} + +// Query executes a query that returns rows, typically a SELECT. +func (tx *Tx) Query(query string, args ...interface{}) (*Rows, error) { + return tx.QueryContext(context.Background(), query, args...) +} + +// QueryRowContext executes a query that is expected to return at most one row. +// QueryRowContext always returns a non-nil value. Errors are deferred until +// Row's Scan method is called. +func (tx *Tx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *Row { + rows, err := tx.QueryContext(ctx, query, args...) + return &Row{rows: rows, err: err} } // QueryRow executes a query that is expected to return at most one row. // QueryRow always returns a non-nil value. Errors are deferred until // Row's Scan method is called. func (tx *Tx) QueryRow(query string, args ...interface{}) *Row { - rows, err := tx.Query(query, args...) - return &Row{rows: rows, err: err} + return tx.QueryRowContext(context.Background(), query, args...) } // connStmt is a prepared statement on a particular connection. @@ -1468,15 +1587,15 @@ type Stmt struct { lastNumClosed uint64 } -// Exec executes a prepared statement with the given arguments and +// ExecContext executes a prepared statement with the given arguments and // returns a Result summarizing the effect of the statement. -func (s *Stmt) Exec(args ...interface{}) (Result, error) { +func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (Result, error) { s.closemu.RLock() defer s.closemu.RUnlock() var res Result for i := 0; i < maxBadConnRetries; i++ { - dc, releaseConn, si, err := s.connStmt() + dc, releaseConn, si, err := s.connStmt(ctx) if err != nil { if err == driver.ErrBadConn { continue @@ -1484,7 +1603,7 @@ func (s *Stmt) Exec(args ...interface{}) (Result, error) { return nil, err } - res, err = resultFromStatement(driverStmt{dc, si}, args...) + res, err = resultFromStatement(ctx, driverStmt{dc, si}, args...) releaseConn(err) if err != driver.ErrBadConn { return res, err @@ -1493,13 +1612,19 @@ func (s *Stmt) Exec(args ...interface{}) (Result, error) { return nil, driver.ErrBadConn } +// Exec executes a prepared statement with the given arguments and +// returns a Result summarizing the effect of the statement. +func (s *Stmt) Exec(args ...interface{}) (Result, error) { + return s.ExecContext(context.Background(), args...) +} + func driverNumInput(ds driverStmt) int { ds.Lock() defer ds.Unlock() // in case NumInput panics return ds.si.NumInput() } -func resultFromStatement(ds driverStmt, args ...interface{}) (Result, error) { +func resultFromStatement(ctx context.Context, ds driverStmt, args ...interface{}) (Result, error) { want := driverNumInput(ds) // -1 means the driver doesn't know how to count the number of @@ -1516,7 +1641,8 @@ func resultFromStatement(ds driverStmt, args ...interface{}) (Result, error) { ds.Lock() defer ds.Unlock() - resi, err := ds.si.Exec(dargs) + + resi, err := ctxDriverStmtExec(ctx, ds.si, dargs) if err != nil { return nil, err } @@ -1552,7 +1678,7 @@ func (s *Stmt) removeClosedStmtLocked() { // connStmt returns a free driver connection on which to execute the // statement, a function to call to release the connection, and a // statement bound to that connection. -func (s *Stmt) connStmt() (ci *driverConn, releaseConn func(error), si driver.Stmt, err error) { +func (s *Stmt) connStmt(ctx context.Context) (ci *driverConn, releaseConn func(error), si driver.Stmt, err error) { if err = s.stickyErr; err != nil { return } @@ -1567,7 +1693,7 @@ func (s *Stmt) connStmt() (ci *driverConn, releaseConn func(error), si driver.St // transaction was created on. if s.tx != nil { s.mu.Unlock() - ci, err = s.tx.grabConn() // blocks, waiting for the connection. + ci, err = s.tx.grabConn(ctx) // blocks, waiting for the connection. if err != nil { return } @@ -1578,8 +1704,7 @@ func (s *Stmt) connStmt() (ci *driverConn, releaseConn func(error), si driver.St s.removeClosedStmtLocked() s.mu.Unlock() - // TODO(bradfitz): or always wait for one? make configurable later? - dc, err := s.db.conn(cachedOrNewConn) + dc, err := s.db.conn(ctx, cachedOrNewConn) if err != nil { return nil, nil, nil, err } @@ -1595,7 +1720,7 @@ func (s *Stmt) connStmt() (ci *driverConn, releaseConn func(error), si driver.St // No luck; we need to prepare the statement on this connection withLock(dc, func() { - si, err = dc.prepareLocked(s.query) + si, err = dc.prepareLocked(ctx, s.query) }) if err != nil { s.db.putConn(dc, err) @@ -1609,15 +1734,15 @@ func (s *Stmt) connStmt() (ci *driverConn, releaseConn func(error), si driver.St return dc, dc.releaseConn, si, nil } -// Query executes a prepared query statement with the given arguments +// QueryContext executes a prepared query statement with the given arguments // and returns the query results as a *Rows. -func (s *Stmt) Query(args ...interface{}) (*Rows, error) { +func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, error) { s.closemu.RLock() defer s.closemu.RUnlock() var rowsi driver.Rows for i := 0; i < maxBadConnRetries; i++ { - dc, releaseConn, si, err := s.connStmt() + dc, releaseConn, si, err := s.connStmt(ctx) if err != nil { if err == driver.ErrBadConn { continue @@ -1625,7 +1750,7 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) { return nil, err } - rowsi, err = rowsiFromStatement(driverStmt{dc, si}, args...) + rowsi, err = rowsiFromStatement(ctx, driverStmt{dc, si}, args...) if err == nil { // Note: ownership of ci passes to the *Rows, to be freed // with releaseConn. @@ -1650,7 +1775,13 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) { return nil, driver.ErrBadConn } -func rowsiFromStatement(ds driverStmt, args ...interface{}) (driver.Rows, error) { +// Query executes a prepared query statement with the given arguments +// and returns the query results as a *Rows. +func (s *Stmt) Query(args ...interface{}) (*Rows, error) { + return s.QueryContext(context.Background(), args...) +} + +func rowsiFromStatement(ctx context.Context, ds driverStmt, args ...interface{}) (driver.Rows, error) { var want int withLock(ds, func() { want = ds.si.NumInput() @@ -1670,14 +1801,15 @@ func rowsiFromStatement(ds driverStmt, args ...interface{}) (driver.Rows, error) ds.Lock() defer ds.Unlock() - rowsi, err := ds.si.Query(dargs) + + rowsi, err := ctxDriverStmtQuery(ctx, ds.si, dargs) if err != nil { return nil, err } return rowsi, nil } -// QueryRow executes a prepared query statement with the given arguments. +// QueryRowContext executes a prepared query statement with the given arguments. // If an error occurs during the execution of the statement, that error will // be returned by a call to Scan on the returned *Row, which is always non-nil. // If the query selects no rows, the *Row's Scan will return ErrNoRows. @@ -1687,15 +1819,30 @@ func rowsiFromStatement(ds driverStmt, args ...interface{}) (driver.Rows, error) // Example usage: // // var name string -// err := nameByUseridStmt.QueryRow(id).Scan(&name) -func (s *Stmt) QueryRow(args ...interface{}) *Row { - rows, err := s.Query(args...) +// err := nameByUseridStmt.QueryRowContext(ctx, id).Scan(&name) +func (s *Stmt) QueryRowContext(ctx context.Context, args ...interface{}) *Row { + rows, err := s.QueryContext(ctx, args...) if err != nil { return &Row{err: err} } return &Row{rows: rows} } +// QueryRow executes a prepared query statement with the given arguments. +// If an error occurs during the execution of the statement, that error will +// be returned by a call to Scan on the returned *Row, which is always non-nil. +// If the query selects no rows, the *Row's Scan will return ErrNoRows. +// Otherwise, the *Row's Scan scans the first selected row and discards +// the rest. +// +// Example usage: +// +// var name string +// err := nameByUseridStmt.QueryRow(id).Scan(&name) +func (s *Stmt) QueryRow(args ...interface{}) *Row { + return s.QueryRowContext(context.Background(), args...) +} + // Close closes the statement. func (s *Stmt) Close() error { s.closemu.Lock() diff --git a/src/database/sql/sql_test.go b/src/database/sql/sql_test.go index 41afd00e92..9fcb2e38c1 100644 --- a/src/database/sql/sql_test.go +++ b/src/database/sql/sql_test.go @@ -5,6 +5,7 @@ package sql import ( + "context" "database/sql/driver" "errors" "fmt" @@ -1159,17 +1160,19 @@ func TestMaxOpenConnsOnBusy(t *testing.T) { db.SetMaxOpenConns(3) - conn0, err := db.conn(cachedOrNewConn) + ctx := context.Background() + + conn0, err := db.conn(ctx, cachedOrNewConn) if err != nil { t.Fatalf("db open conn fail: %v", err) } - conn1, err := db.conn(cachedOrNewConn) + conn1, err := db.conn(ctx, cachedOrNewConn) if err != nil { t.Fatalf("db open conn fail: %v", err) } - conn2, err := db.conn(cachedOrNewConn) + conn2, err := db.conn(ctx, cachedOrNewConn) if err != nil { t.Fatalf("db open conn fail: %v", err) } diff --git a/src/go/build/deps_test.go b/src/go/build/deps_test.go index d8eb2ee726..c7e11498dd 100644 --- a/src/go/build/deps_test.go +++ b/src/go/build/deps_test.go @@ -228,8 +228,8 @@ var pkgDeps = map[string][]string{ "compress/lzw": {"L4"}, "compress/zlib": {"L4", "compress/flate"}, "context": {"errors", "fmt", "reflect", "sync", "time"}, - "database/sql": {"L4", "container/list", "database/sql/driver"}, - "database/sql/driver": {"L4", "time"}, + "database/sql": {"L4", "container/list", "context", "database/sql/driver"}, + "database/sql/driver": {"L4", "context", "time"}, "debug/dwarf": {"L4"}, "debug/elf": {"L4", "OS", "debug/dwarf", "compress/zlib"}, "debug/gosym": {"L4"}, -- 2.48.1