]> Cypherpunks repositories - gostls13.git/commitdiff
database/sql: add context methods
authorDaniel Theophanes <kardianos@gmail.com>
Mon, 19 Sep 2016 18:19:32 +0000 (11:19 -0700)
committerBrad Fitzpatrick <bradfitz@golang.org>
Tue, 27 Sep 2016 19:41:09 +0000 (19:41 +0000)
Add context methods to sql and sql/driver methods. If
the driver doesn't implement context methods the connection
pool will still handle timeouts when a query fails to return
in time or when a connection is not available from the pool
in time.

There will be a follow-up CL that will add support for
context values that specify transaction levels and modes
that a driver can use.

Fixes #15123

Change-Id: Ia99f3957aa3f177b23044dd99d4ec217491a30a7
Reviewed-on: https://go-review.googlesource.com/29381
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>

src/database/sql/ctxutil.go [new file with mode: 0644]
src/database/sql/driver/driver.go
src/database/sql/sql.go
src/database/sql/sql_test.go
src/go/build/deps_test.go

diff --git a/src/database/sql/ctxutil.go b/src/database/sql/ctxutil.go
new file mode 100644 (file)
index 0000000..65e1652
--- /dev/null
@@ -0,0 +1,231 @@
+// Copyright 2016 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package sql
+
+import (
+       "context"
+       "database/sql/driver"
+       "errors"
+)
+
+func ctxDriverPrepare(ctx context.Context, ci driver.Conn, query string) (driver.Stmt, error) {
+       if ciCtx, is := ci.(driver.ConnPrepareContext); is {
+               return ciCtx.PrepareContext(ctx, query)
+       }
+       type R struct {
+               err   error
+               panic interface{}
+               si    driver.Stmt
+       }
+
+       rc := make(chan R, 1)
+       go func() {
+               r := R{}
+               defer func() {
+                       if v := recover(); v != nil {
+                               r.panic = v
+                       }
+                       rc <- r
+               }()
+               r.si, r.err = ci.Prepare(query)
+       }()
+       select {
+       case <-ctx.Done():
+               go func() {
+                       <-rc
+                       close(rc)
+               }()
+               return nil, ctx.Err()
+       case r := <-rc:
+               if r.panic != nil {
+                       panic(r.panic)
+               }
+               return r.si, r.err
+       }
+}
+
+func ctxDriverExec(ctx context.Context, execer driver.Execer, query string, dargs []driver.Value) (driver.Result, error) {
+       if execerCtx, is := execer.(driver.ExecerContext); is {
+               return execerCtx.ExecContext(ctx, query, dargs)
+       }
+       type R struct {
+               err   error
+               panic interface{}
+               resi  driver.Result
+       }
+
+       rc := make(chan R, 1)
+       go func() {
+               r := R{}
+               defer func() {
+                       if v := recover(); v != nil {
+                               r.panic = v
+                       }
+                       rc <- r
+               }()
+               r.resi, r.err = execer.Exec(query, dargs)
+       }()
+       select {
+       case <-ctx.Done():
+               go func() {
+                       <-rc
+                       close(rc)
+               }()
+               return nil, ctx.Err()
+       case r := <-rc:
+               if r.panic != nil {
+                       panic(r.panic)
+               }
+               return r.resi, r.err
+       }
+}
+
+func ctxDriverQuery(ctx context.Context, queryer driver.Queryer, query string, dargs []driver.Value) (driver.Rows, error) {
+       if queryerCtx, is := queryer.(driver.QueryerContext); is {
+               return queryerCtx.QueryContext(ctx, query, dargs)
+       }
+       type R struct {
+               err   error
+               panic interface{}
+               rowsi driver.Rows
+       }
+
+       rc := make(chan R, 1)
+       go func() {
+               r := R{}
+               defer func() {
+                       if v := recover(); v != nil {
+                               r.panic = v
+                       }
+                       rc <- r
+               }()
+               r.rowsi, r.err = queryer.Query(query, dargs)
+       }()
+       select {
+       case <-ctx.Done():
+               go func() {
+                       <-rc
+                       close(rc)
+               }()
+               return nil, ctx.Err()
+       case r := <-rc:
+               if r.panic != nil {
+                       panic(r.panic)
+               }
+               return r.rowsi, r.err
+       }
+}
+
+func ctxDriverStmtExec(ctx context.Context, si driver.Stmt, dargs []driver.Value) (driver.Result, error) {
+       if siCtx, is := si.(driver.StmtExecContext); is {
+               return siCtx.ExecContext(ctx, dargs)
+       }
+       type R struct {
+               err   error
+               panic interface{}
+               resi  driver.Result
+       }
+
+       rc := make(chan R, 1)
+       go func() {
+               r := R{}
+               defer func() {
+                       if v := recover(); v != nil {
+                               r.panic = v
+                       }
+                       rc <- r
+               }()
+               r.resi, r.err = si.Exec(dargs)
+       }()
+       select {
+       case <-ctx.Done():
+               go func() {
+                       <-rc
+                       close(rc)
+               }()
+               return nil, ctx.Err()
+       case r := <-rc:
+               if r.panic != nil {
+                       panic(r.panic)
+               }
+               return r.resi, r.err
+       }
+}
+
+func ctxDriverStmtQuery(ctx context.Context, si driver.Stmt, dargs []driver.Value) (driver.Rows, error) {
+       if siCtx, is := si.(driver.StmtQueryContext); is {
+               return siCtx.QueryContext(ctx, dargs)
+       }
+       type R struct {
+               err   error
+               panic interface{}
+               rowsi driver.Rows
+       }
+
+       rc := make(chan R, 1)
+       go func() {
+               r := R{}
+               defer func() {
+                       if v := recover(); v != nil {
+                               r.panic = v
+                       }
+                       rc <- r
+               }()
+               r.rowsi, r.err = si.Query(dargs)
+       }()
+       select {
+       case <-ctx.Done():
+               go func() {
+                       <-rc
+                       close(rc)
+               }()
+               return nil, ctx.Err()
+       case r := <-rc:
+               if r.panic != nil {
+                       panic(r.panic)
+               }
+               return r.rowsi, r.err
+       }
+}
+
+var errLevelNotSupported = errors.New("sql: selected isolation level is not supported")
+
+func ctxDriverBegin(ctx context.Context, ci driver.Conn) (driver.Tx, error) {
+       if ciCtx, is := ci.(driver.ConnBeginContext); is {
+               return ciCtx.BeginContext(ctx)
+       }
+       // TODO(kardianos): check the transaction level in ctx. If set and non-default
+       // then return an error here as the BeginContext driver value is not supported.
+
+       type R struct {
+               err   error
+               panic interface{}
+               txi   driver.Tx
+       }
+       rc := make(chan R, 1)
+       go func() {
+               r := R{}
+               defer func() {
+                       if v := recover(); v != nil {
+                               r.panic = v
+                       }
+                       rc <- r
+               }()
+               r.txi, r.err = ci.Begin()
+       }()
+       select {
+       case <-ctx.Done():
+               go func() {
+                       <-rc
+                       close(rc)
+               }()
+               return nil, ctx.Err()
+       case r := <-rc:
+               if r.panic != nil {
+                       panic(r.panic)
+               }
+               return r.txi, r.err
+       }
+}
index 4dba85a6d343148912d00f12456117a37c8d38ea..ccc283d3738959ef41ced17db7937d68e1191963 100644 (file)
@@ -8,7 +8,10 @@
 // Most code should use package sql.
 package driver
 
-import "errors"
+import (
+       "context"
+       "errors"
+)
 
 // Value is a value that drivers must be able to handle.
 // It is either nil or an instance of one of these types:
@@ -65,6 +68,12 @@ type Execer interface {
        Exec(query string, args []Value) (Result, error)
 }
 
+// ExecerContext is like execer, but must honor the context timeout and return
+// when the context is cancelled.
+type ExecerContext interface {
+       ExecContext(ctx context.Context, query string, args []Value) (Result, error)
+}
+
 // Queryer is an optional interface that may be implemented by a Conn.
 //
 // If a Conn does not implement Queryer, the sql package's DB.Query will
@@ -76,6 +85,12 @@ type Queryer interface {
        Query(query string, args []Value) (Rows, error)
 }
 
+// QueryerContext is like Queryer, but most honor the context timeout and return
+// when the context is cancelled.
+type QueryerContext interface {
+       QueryContext(ctx context.Context, query string, args []Value) (Rows, error)
+}
+
 // Conn is a connection to a database. It is not used concurrently
 // by multiple goroutines.
 //
@@ -98,6 +113,23 @@ type Conn interface {
        Begin() (Tx, error)
 }
 
+// ConnPrepareContext enhances the Conn interface with context.
+type ConnPrepareContext interface {
+       // PrepareContext returns a prepared statement, bound to this connection.
+       // context is for the preparation of the statement,
+       // it must not store the context within the statement itself.
+       PrepareContext(ctx context.Context, query string) (Stmt, error)
+}
+
+// ConnBeginContext enhances the Conn interface with context.
+type ConnBeginContext interface {
+       // BeginContext starts and returns a new transaction.
+       // the provided context should be used to roll the transaction back
+       // if it is cancelled. If there is an isolation level in context
+       // that is not supported by the driver an error must be returned.
+       BeginContext(ctx context.Context) (Tx, error)
+}
+
 // Result is the result of a query execution.
 type Result interface {
        // LastInsertId returns the database's auto-generated ID
@@ -139,6 +171,18 @@ type Stmt interface {
        Query(args []Value) (Rows, error)
 }
 
+// StmtExecContext enhances the Stmt interface by providing Exec with context.
+type StmtExecContext interface {
+       // ExecContext must honor the context timeout and return when it is cancelled.
+       ExecContext(ctx context.Context, args []Value) (Result, error)
+}
+
+// StmtQueryContext enhances the Stmt interface by providing Query with context.
+type StmtQueryContext interface {
+       // QueryContext must honor the context timeout and return when it is cancelled.
+       QueryContext(ctx context.Context, args []Value) (Rows, error)
+}
+
 // ColumnConverter may be optionally implemented by Stmt if the
 // statement is aware of its own columns' types and can convert from
 // any type to a driver Value.
index 1e09a313ac3b3e73b031672669ba9d3ce2bd3eaf..4c44e2b6f466b982b83570eaf43813df499687ba 100644 (file)
@@ -13,6 +13,7 @@
 package sql
 
 import (
+       "context"
        "database/sql/driver"
        "errors"
        "fmt"
@@ -297,8 +298,8 @@ func (dc *driverConn) expired(timeout time.Duration) bool {
        return dc.createdAt.Add(timeout).Before(nowFunc())
 }
 
-func (dc *driverConn) prepareLocked(query string) (driver.Stmt, error) {
-       si, err := dc.ci.Prepare(query)
+func (dc *driverConn) prepareLocked(ctx context.Context, query string) (driver.Stmt, error) {
+       si, err := ctxDriverPrepare(ctx, dc.ci, query)
        if err == nil {
                // Track each driverConn's open statements, so we can close them
                // before closing the conn.
@@ -494,13 +495,13 @@ func Open(driverName, dataSourceName string) (*DB, error) {
        return db, nil
 }
 
-// Ping verifies a connection to the database is still alive,
+// PingContext verifies a connection to the database is still alive,
 // establishing a connection if necessary.
-func (db *DB) Ping() error {
+func (db *DB) PingContext(ctx context.Context) error {
        // TODO(bradfitz): give drivers an optional hook to implement
        // this in a more efficient or more reliable way, if they
        // have one.
-       dc, err := db.conn(cachedOrNewConn)
+       dc, err := db.conn(ctx, cachedOrNewConn)
        if err != nil {
                return err
        }
@@ -508,6 +509,12 @@ func (db *DB) Ping() error {
        return nil
 }
 
+// Ping verifies a connection to the database is still alive,
+// establishing a connection if necessary.
+func (db *DB) Ping() error {
+       return db.PingContext(context.Background())
+}
+
 // Close closes the database, releasing any open resources.
 //
 // It is rare to Close a DB, as the DB handle is meant to be
@@ -777,12 +784,16 @@ type connRequest struct {
 var errDBClosed = errors.New("sql: database is closed")
 
 // conn returns a newly-opened or cached *driverConn.
-func (db *DB) conn(strategy connReuseStrategy) (*driverConn, error) {
+func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn, error) {
        db.mu.Lock()
        if db.closed {
                db.mu.Unlock()
                return nil, errDBClosed
        }
+       // Check if the context is expired.
+       if err := ctx.Err(); err != nil {
+               return nil, err
+       }
        lifetime := db.maxLifetime
 
        // Prefer a free connection, if possible.
@@ -808,15 +819,21 @@ func (db *DB) conn(strategy connReuseStrategy) (*driverConn, error) {
                req := make(chan connRequest, 1)
                db.connRequests = append(db.connRequests, req)
                db.mu.Unlock()
-               ret, ok := <-req
-               if !ok {
-                       return nil, errDBClosed
-               }
-               if ret.err == nil && ret.conn.expired(lifetime) {
-                       ret.conn.Close()
-                       return nil, driver.ErrBadConn
+
+               // Timeout the connection request with the context.
+               select {
+               case <-ctx.Done():
+                       return nil, ctx.Err()
+               case ret, ok := <-req:
+                       if !ok {
+                               return nil, errDBClosed
+                       }
+                       if ret.err == nil && ret.conn.expired(lifetime) {
+                               ret.conn.Close()
+                               return nil, driver.ErrBadConn
+                       }
+                       return ret.conn, ret.err
                }
-               return ret.conn, ret.err
        }
 
        db.numOpen++ // optimistically
@@ -952,40 +969,51 @@ func (db *DB) putConnDBLocked(dc *driverConn, err error) bool {
 // connection to be opened.
 const maxBadConnRetries = 2
 
-// Prepare creates a prepared statement for later queries or executions.
+// PrepareContext creates a prepared statement for later queries or executions.
 // Multiple queries or executions may be run concurrently from the
 // returned statement.
 // The caller must call the statement's Close method
 // when the statement is no longer needed.
-func (db *DB) Prepare(query string) (*Stmt, error) {
+// Context is for the preparation of the statment, not for the execution of
+// the statement.
+func (db *DB) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
        var stmt *Stmt
        var err error
        for i := 0; i < maxBadConnRetries; i++ {
-               stmt, err = db.prepare(query, cachedOrNewConn)
+               stmt, err = db.prepare(ctx, query, cachedOrNewConn)
                if err != driver.ErrBadConn {
                        break
                }
        }
        if err == driver.ErrBadConn {
-               return db.prepare(query, alwaysNewConn)
+               return db.prepare(ctx, query, alwaysNewConn)
        }
        return stmt, err
 }
 
-func (db *DB) prepare(query string, strategy connReuseStrategy) (*Stmt, error) {
+// Prepare creates a prepared statement for later queries or executions.
+// Multiple queries or executions may be run concurrently from the
+// returned statement.
+// The caller must call the statement's Close method
+// when the statement is no longer needed.
+func (db *DB) Prepare(query string) (*Stmt, error) {
+       return db.PrepareContext(context.Background(), query)
+}
+
+func (db *DB) prepare(ctx context.Context, query string, strategy connReuseStrategy) (*Stmt, error) {
        // TODO: check if db.driver supports an optional
        // driver.Preparer interface and call that instead, if so,
        // otherwise we make a prepared statement that's bound
        // to a connection, and to execute this prepared statement
        // we either need to use this connection (if it's free), else
        // get a new connection + re-prepare + execute on that one.
-       dc, err := db.conn(strategy)
+       dc, err := db.conn(ctx, strategy)
        if err != nil {
                return nil, err
        }
        var si driver.Stmt
        withLock(dc, func() {
-               si, err = dc.prepareLocked(query)
+               si, err = dc.prepareLocked(ctx, query)
        })
        if err != nil {
                db.putConn(dc, err)
@@ -1002,25 +1030,31 @@ func (db *DB) prepare(query string, strategy connReuseStrategy) (*Stmt, error) {
        return stmt, nil
 }
 
-// Exec executes a query without returning any rows.
+// ExecContext executes a query without returning any rows.
 // The args are for any placeholder parameters in the query.
-func (db *DB) Exec(query string, args ...interface{}) (Result, error) {
+func (db *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (Result, error) {
        var res Result
        var err error
        for i := 0; i < maxBadConnRetries; i++ {
-               res, err = db.exec(query, args, cachedOrNewConn)
+               res, err = db.exec(ctx, query, args, cachedOrNewConn)
                if err != driver.ErrBadConn {
                        break
                }
        }
        if err == driver.ErrBadConn {
-               return db.exec(query, args, alwaysNewConn)
+               return db.exec(ctx, query, args, alwaysNewConn)
        }
        return res, err
 }
 
-func (db *DB) exec(query string, args []interface{}, strategy connReuseStrategy) (res Result, err error) {
-       dc, err := db.conn(strategy)
+// Exec executes a query without returning any rows.
+// The args are for any placeholder parameters in the query.
+func (db *DB) Exec(query string, args ...interface{}) (Result, error) {
+       return db.ExecContext(context.Background(), query, args...)
+}
+
+func (db *DB) exec(ctx context.Context, query string, args []interface{}, strategy connReuseStrategy) (res Result, err error) {
+       dc, err := db.conn(ctx, strategy)
        if err != nil {
                return nil, err
        }
@@ -1036,7 +1070,7 @@ func (db *DB) exec(query string, args []interface{}, strategy connReuseStrategy)
                }
                var resi driver.Result
                withLock(dc, func() {
-                       resi, err = execer.Exec(query, dargs)
+                       resi, err = ctxDriverExec(ctx, execer, query, dargs)
                })
                if err != driver.ErrSkip {
                        if err != nil {
@@ -1048,44 +1082,50 @@ func (db *DB) exec(query string, args []interface{}, strategy connReuseStrategy)
 
        var si driver.Stmt
        withLock(dc, func() {
-               si, err = dc.ci.Prepare(query)
+               si, err = ctxDriverPrepare(ctx, dc.ci, query)
        })
        if err != nil {
                return nil, err
        }
        defer withLock(dc, func() { si.Close() })
-       return resultFromStatement(driverStmt{dc, si}, args...)
+       return resultFromStatement(ctx, driverStmt{dc, si}, args...)
 }
 
-// Query executes a query that returns rows, typically a SELECT.
+// QueryContext executes a query that returns rows, typically a SELECT.
 // The args are for any placeholder parameters in the query.
-func (db *DB) Query(query string, args ...interface{}) (*Rows, error) {
+func (db *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) {
        var rows *Rows
        var err error
        for i := 0; i < maxBadConnRetries; i++ {
-               rows, err = db.query(query, args, cachedOrNewConn)
+               rows, err = db.query(ctx, query, args, cachedOrNewConn)
                if err != driver.ErrBadConn {
                        break
                }
        }
        if err == driver.ErrBadConn {
-               return db.query(query, args, alwaysNewConn)
+               return db.query(ctx, query, args, alwaysNewConn)
        }
        return rows, err
 }
 
-func (db *DB) query(query string, args []interface{}, strategy connReuseStrategy) (*Rows, error) {
-       ci, err := db.conn(strategy)
+// Query executes a query that returns rows, typically a SELECT.
+// The args are for any placeholder parameters in the query.
+func (db *DB) Query(query string, args ...interface{}) (*Rows, error) {
+       return db.QueryContext(context.Background(), query, args...)
+}
+
+func (db *DB) query(ctx context.Context, query string, args []interface{}, strategy connReuseStrategy) (*Rows, error) {
+       ci, err := db.conn(ctx, strategy)
        if err != nil {
                return nil, err
        }
 
-       return db.queryConn(ci, ci.releaseConn, query, args)
+       return db.queryConn(ctx, ci, ci.releaseConn, query, args)
 }
 
 // queryConn executes a query on the given connection.
 // The connection gets released by the releaseConn function.
-func (db *DB) queryConn(dc *driverConn, releaseConn func(error), query string, args []interface{}) (*Rows, error) {
+func (db *DB) queryConn(ctx context.Context, dc *driverConn, releaseConn func(error), query string, args []interface{}) (*Rows, error) {
        if queryer, ok := dc.ci.(driver.Queryer); ok {
                dargs, err := driverArgs(nil, args)
                if err != nil {
@@ -1094,7 +1134,7 @@ func (db *DB) queryConn(dc *driverConn, releaseConn func(error), query string, a
                }
                var rowsi driver.Rows
                withLock(dc, func() {
-                       rowsi, err = queryer.Query(query, dargs)
+                       rowsi, err = ctxDriverQuery(ctx, queryer, query, dargs)
                })
                if err != driver.ErrSkip {
                        if err != nil {
@@ -1115,7 +1155,7 @@ func (db *DB) queryConn(dc *driverConn, releaseConn func(error), query string, a
        var si driver.Stmt
        var err error
        withLock(dc, func() {
-               si, err = dc.ci.Prepare(query)
+               si, err = ctxDriverPrepare(ctx, dc.ci, query)
        })
        if err != nil {
                releaseConn(err)
@@ -1123,7 +1163,7 @@ func (db *DB) queryConn(dc *driverConn, releaseConn func(error), query string, a
        }
 
        ds := driverStmt{dc, si}
-       rowsi, err := rowsiFromStatement(ds, args...)
+       rowsi, err := rowsiFromStatement(ctx, ds, args...)
        if err != nil {
                withLock(dc, func() {
                        si.Close()
@@ -1143,49 +1183,77 @@ func (db *DB) queryConn(dc *driverConn, releaseConn func(error), query string, a
        return rows, nil
 }
 
+// QueryRowContext executes a query that is expected to return at most one row.
+// QueryRowContext always returns a non-nil value. Errors are deferred until
+// Row's Scan method is called.
+func (db *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *Row {
+       rows, err := db.QueryContext(ctx, query, args...)
+       return &Row{rows: rows, err: err}
+}
+
 // QueryRow executes a query that is expected to return at most one row.
 // QueryRow always returns a non-nil value. Errors are deferred until
 // Row's Scan method is called.
 func (db *DB) QueryRow(query string, args ...interface{}) *Row {
-       rows, err := db.Query(query, args...)
-       return &Row{rows: rows, err: err}
+       return db.QueryRowContext(context.Background(), query, args...)
 }
 
-// Begin starts a transaction. The isolation level is dependent on
-// the driver.
-func (db *DB) Begin() (*Tx, error) {
+// BeginContext starts a transaction. If a non-default isolation level is used
+// that the driver doesn't support an error will be returned. Different drivers
+// may have slightly different meanings for the same isolation level.
+func (db *DB) BeginContext(ctx context.Context) (*Tx, error) {
        var tx *Tx
        var err error
        for i := 0; i < maxBadConnRetries; i++ {
-               tx, err = db.begin(cachedOrNewConn)
+               tx, err = db.begin(ctx, cachedOrNewConn)
                if err != driver.ErrBadConn {
                        break
                }
        }
        if err == driver.ErrBadConn {
-               return db.begin(alwaysNewConn)
+               return db.begin(ctx, alwaysNewConn)
        }
        return tx, err
 }
 
-func (db *DB) begin(strategy connReuseStrategy) (tx *Tx, err error) {
-       dc, err := db.conn(strategy)
+// Begin starts a transaction. The default isolation level is dependent on
+// the driver.
+func (db *DB) Begin() (*Tx, error) {
+       return db.BeginContext(context.Background())
+}
+
+func (db *DB) begin(ctx context.Context, strategy connReuseStrategy) (tx *Tx, err error) {
+       dc, err := db.conn(ctx, strategy)
        if err != nil {
                return nil, err
        }
        var txi driver.Tx
        withLock(dc, func() {
-               txi, err = dc.ci.Begin()
+               txi, err = ctxDriverBegin(ctx, dc.ci)
        })
        if err != nil {
                db.putConn(dc, err)
                return nil, err
        }
-       return &Tx{
-               db:  db,
-               dc:  dc,
-               txi: txi,
-       }, nil
+
+       // Schedule the transaction to rollback when the context is cancelled.
+       // The cancel function in Tx will be called after done is set to true.
+       ctx, cancel := context.WithCancel(ctx)
+       tx = &Tx{
+               db:     db,
+               dc:     dc,
+               txi:    txi,
+               cancel: cancel,
+       }
+       go func() {
+               select {
+               case <-ctx.Done():
+                       if !tx.done {
+                               tx.Rollback()
+                       }
+               }
+       }()
+       return tx, nil
 }
 
 // Driver returns the database's underlying driver.
@@ -1222,6 +1290,9 @@ type Tx struct {
                sync.Mutex
                v []*Stmt
        }
+
+       // cancel is called after done transitions from false to true.
+       cancel func()
 }
 
 // ErrTxDone is returned by any operation that is performed on a transaction
@@ -1234,11 +1305,12 @@ func (tx *Tx) close(err error) {
        }
        tx.done = true
        tx.db.putConn(tx.dc, err)
+       tx.cancel()
        tx.dc = nil
        tx.txi = nil
 }
 
-func (tx *Tx) grabConn() (*driverConn, error) {
+func (tx *Tx) grabConn(ctx context.Context) (*driverConn, error) {
        if tx.done {
                return nil, ErrTxDone
        }
@@ -1292,7 +1364,10 @@ func (tx *Tx) Rollback() error {
 // be used once the transaction has been committed or rolled back.
 //
 // To use an existing prepared statement on this transaction, see Tx.Stmt.
-func (tx *Tx) Prepare(query string) (*Stmt, error) {
+// Context will be used for the preparation of the context, not
+// for the execution of the returned statement. The returned statement
+// will run in the transaction context.
+func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
        // TODO(bradfitz): We could be more efficient here and either
        // provide a method to take an existing Stmt (created on
        // perhaps a different Conn), and re-create it on this Conn if
@@ -1306,7 +1381,7 @@ func (tx *Tx) Prepare(query string) (*Stmt, error) {
        // Perhaps just looking at the reference count (by noting
        // Stmt.Close) would be enough. We might also want a finalizer
        // on Stmt to drop the reference count.
-       dc, err := tx.grabConn()
+       dc, err := tx.grabConn(ctx)
        if err != nil {
                return nil, err
        }
@@ -1334,7 +1409,17 @@ func (tx *Tx) Prepare(query string) (*Stmt, error) {
        return stmt, nil
 }
 
-// Stmt returns a transaction-specific prepared statement from
+// Prepare creates a prepared statement for use within a transaction.
+//
+// The returned statement operates within the transaction and can no longer
+// be used once the transaction has been committed or rolled back.
+//
+// To use an existing prepared statement on this transaction, see Tx.Stmt.
+func (tx *Tx) Prepare(query string) (*Stmt, error) {
+       return tx.PrepareContext(context.Background(), query)
+}
+
+// StmtContext returns a transaction-specific prepared statement from
 // an existing statement.
 //
 // Example:
@@ -1342,11 +1427,11 @@ func (tx *Tx) Prepare(query string) (*Stmt, error) {
 //  ...
 //  tx, err := db.Begin()
 //  ...
-//  res, err := tx.Stmt(updateMoney).Exec(123.45, 98293203)
+//  res, err := tx.StmtContext(ctx, updateMoney).Exec(123.45, 98293203)
 //
 // The returned statement operates within the transaction and can no longer
 // be used once the transaction has been committed or rolled back.
-func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
+func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt {
        // TODO(bradfitz): optimize this. Currently this re-prepares
        // each time. This is fine for now to illustrate the API but
        // we should really cache already-prepared statements
@@ -1355,7 +1440,7 @@ func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
        if tx.db != stmt.db {
                return &Stmt{stickyErr: errors.New("sql: Tx.Stmt: statement from different database used")}
        }
-       dc, err := tx.grabConn()
+       dc, err := tx.grabConn(ctx)
        if err != nil {
                return &Stmt{stickyErr: err}
        }
@@ -1379,10 +1464,26 @@ func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
        return txs
 }
 
-// Exec executes a query that doesn't return rows.
+// Stmt returns a transaction-specific prepared statement from
+// an existing statement.
+//
+// Example:
+//  updateMoney, err := db.Prepare("UPDATE balance SET money=money+? WHERE id=?")
+//  ...
+//  tx, err := db.Begin()
+//  ...
+//  res, err := tx.Stmt(updateMoney).Exec(123.45, 98293203)
+//
+// The returned statement operates within the transaction and can no longer
+// be used once the transaction has been committed or rolled back.
+func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
+       return tx.StmtContext(context.Background(), stmt)
+}
+
+// ExecContext executes a query that doesn't return rows.
 // For example: an INSERT and UPDATE.
-func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) {
-       dc, err := tx.grabConn()
+func (tx *Tx) ExecContext(ctx context.Context, query string, args ...interface{}) (Result, error) {
+       dc, err := tx.grabConn(ctx)
        if err != nil {
                return nil, err
        }
@@ -1413,25 +1514,43 @@ func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) {
        }
        defer withLock(dc, func() { si.Close() })
 
-       return resultFromStatement(driverStmt{dc, si}, args...)
+       return resultFromStatement(ctx, driverStmt{dc, si}, args...)
 }
 
-// Query executes a query that returns rows, typically a SELECT.
-func (tx *Tx) Query(query string, args ...interface{}) (*Rows, error) {
-       dc, err := tx.grabConn()
+// Exec executes a query that doesn't return rows.
+// For example: an INSERT and UPDATE.
+func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) {
+       return tx.ExecContext(context.Background(), query, args...)
+}
+
+// QueryContext executes a query that returns rows, typically a SELECT.
+func (tx *Tx) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) {
+       dc, err := tx.grabConn(ctx)
        if err != nil {
                return nil, err
        }
        releaseConn := func(error) {}
-       return tx.db.queryConn(dc, releaseConn, query, args)
+       return tx.db.queryConn(ctx, dc, releaseConn, query, args)
+}
+
+// Query executes a query that returns rows, typically a SELECT.
+func (tx *Tx) Query(query string, args ...interface{}) (*Rows, error) {
+       return tx.QueryContext(context.Background(), query, args...)
+}
+
+// QueryRowContext executes a query that is expected to return at most one row.
+// QueryRowContext always returns a non-nil value. Errors are deferred until
+// Row's Scan method is called.
+func (tx *Tx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *Row {
+       rows, err := tx.QueryContext(ctx, query, args...)
+       return &Row{rows: rows, err: err}
 }
 
 // QueryRow executes a query that is expected to return at most one row.
 // QueryRow always returns a non-nil value. Errors are deferred until
 // Row's Scan method is called.
 func (tx *Tx) QueryRow(query string, args ...interface{}) *Row {
-       rows, err := tx.Query(query, args...)
-       return &Row{rows: rows, err: err}
+       return tx.QueryRowContext(context.Background(), query, args...)
 }
 
 // connStmt is a prepared statement on a particular connection.
@@ -1468,15 +1587,15 @@ type Stmt struct {
        lastNumClosed uint64
 }
 
-// Exec executes a prepared statement with the given arguments and
+// ExecContext executes a prepared statement with the given arguments and
 // returns a Result summarizing the effect of the statement.
-func (s *Stmt) Exec(args ...interface{}) (Result, error) {
+func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (Result, error) {
        s.closemu.RLock()
        defer s.closemu.RUnlock()
 
        var res Result
        for i := 0; i < maxBadConnRetries; i++ {
-               dc, releaseConn, si, err := s.connStmt()
+               dc, releaseConn, si, err := s.connStmt(ctx)
                if err != nil {
                        if err == driver.ErrBadConn {
                                continue
@@ -1484,7 +1603,7 @@ func (s *Stmt) Exec(args ...interface{}) (Result, error) {
                        return nil, err
                }
 
-               res, err = resultFromStatement(driverStmt{dc, si}, args...)
+               res, err = resultFromStatement(ctx, driverStmt{dc, si}, args...)
                releaseConn(err)
                if err != driver.ErrBadConn {
                        return res, err
@@ -1493,13 +1612,19 @@ func (s *Stmt) Exec(args ...interface{}) (Result, error) {
        return nil, driver.ErrBadConn
 }
 
+// Exec executes a prepared statement with the given arguments and
+// returns a Result summarizing the effect of the statement.
+func (s *Stmt) Exec(args ...interface{}) (Result, error) {
+       return s.ExecContext(context.Background(), args...)
+}
+
 func driverNumInput(ds driverStmt) int {
        ds.Lock()
        defer ds.Unlock() // in case NumInput panics
        return ds.si.NumInput()
 }
 
-func resultFromStatement(ds driverStmt, args ...interface{}) (Result, error) {
+func resultFromStatement(ctx context.Context, ds driverStmt, args ...interface{}) (Result, error) {
        want := driverNumInput(ds)
 
        // -1 means the driver doesn't know how to count the number of
@@ -1516,7 +1641,8 @@ func resultFromStatement(ds driverStmt, args ...interface{}) (Result, error) {
 
        ds.Lock()
        defer ds.Unlock()
-       resi, err := ds.si.Exec(dargs)
+
+       resi, err := ctxDriverStmtExec(ctx, ds.si, dargs)
        if err != nil {
                return nil, err
        }
@@ -1552,7 +1678,7 @@ func (s *Stmt) removeClosedStmtLocked() {
 // connStmt returns a free driver connection on which to execute the
 // statement, a function to call to release the connection, and a
 // statement bound to that connection.
-func (s *Stmt) connStmt() (ci *driverConn, releaseConn func(error), si driver.Stmt, err error) {
+func (s *Stmt) connStmt(ctx context.Context) (ci *driverConn, releaseConn func(error), si driver.Stmt, err error) {
        if err = s.stickyErr; err != nil {
                return
        }
@@ -1567,7 +1693,7 @@ func (s *Stmt) connStmt() (ci *driverConn, releaseConn func(error), si driver.St
        // transaction was created on.
        if s.tx != nil {
                s.mu.Unlock()
-               ci, err = s.tx.grabConn() // blocks, waiting for the connection.
+               ci, err = s.tx.grabConn(ctx) // blocks, waiting for the connection.
                if err != nil {
                        return
                }
@@ -1578,8 +1704,7 @@ func (s *Stmt) connStmt() (ci *driverConn, releaseConn func(error), si driver.St
        s.removeClosedStmtLocked()
        s.mu.Unlock()
 
-       // TODO(bradfitz): or always wait for one? make configurable later?
-       dc, err := s.db.conn(cachedOrNewConn)
+       dc, err := s.db.conn(ctx, cachedOrNewConn)
        if err != nil {
                return nil, nil, nil, err
        }
@@ -1595,7 +1720,7 @@ func (s *Stmt) connStmt() (ci *driverConn, releaseConn func(error), si driver.St
 
        // No luck; we need to prepare the statement on this connection
        withLock(dc, func() {
-               si, err = dc.prepareLocked(s.query)
+               si, err = dc.prepareLocked(ctx, s.query)
        })
        if err != nil {
                s.db.putConn(dc, err)
@@ -1609,15 +1734,15 @@ func (s *Stmt) connStmt() (ci *driverConn, releaseConn func(error), si driver.St
        return dc, dc.releaseConn, si, nil
 }
 
-// Query executes a prepared query statement with the given arguments
+// QueryContext executes a prepared query statement with the given arguments
 // and returns the query results as a *Rows.
-func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
+func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, error) {
        s.closemu.RLock()
        defer s.closemu.RUnlock()
 
        var rowsi driver.Rows
        for i := 0; i < maxBadConnRetries; i++ {
-               dc, releaseConn, si, err := s.connStmt()
+               dc, releaseConn, si, err := s.connStmt(ctx)
                if err != nil {
                        if err == driver.ErrBadConn {
                                continue
@@ -1625,7 +1750,7 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
                        return nil, err
                }
 
-               rowsi, err = rowsiFromStatement(driverStmt{dc, si}, args...)
+               rowsi, err = rowsiFromStatement(ctx, driverStmt{dc, si}, args...)
                if err == nil {
                        // Note: ownership of ci passes to the *Rows, to be freed
                        // with releaseConn.
@@ -1650,7 +1775,13 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
        return nil, driver.ErrBadConn
 }
 
-func rowsiFromStatement(ds driverStmt, args ...interface{}) (driver.Rows, error) {
+// Query executes a prepared query statement with the given arguments
+// and returns the query results as a *Rows.
+func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
+       return s.QueryContext(context.Background(), args...)
+}
+
+func rowsiFromStatement(ctx context.Context, ds driverStmt, args ...interface{}) (driver.Rows, error) {
        var want int
        withLock(ds, func() {
                want = ds.si.NumInput()
@@ -1670,14 +1801,15 @@ func rowsiFromStatement(ds driverStmt, args ...interface{}) (driver.Rows, error)
 
        ds.Lock()
        defer ds.Unlock()
-       rowsi, err := ds.si.Query(dargs)
+
+       rowsi, err := ctxDriverStmtQuery(ctx, ds.si, dargs)
        if err != nil {
                return nil, err
        }
        return rowsi, nil
 }
 
-// QueryRow executes a prepared query statement with the given arguments.
+// QueryRowContext executes a prepared query statement with the given arguments.
 // If an error occurs during the execution of the statement, that error will
 // be returned by a call to Scan on the returned *Row, which is always non-nil.
 // If the query selects no rows, the *Row's Scan will return ErrNoRows.
@@ -1687,15 +1819,30 @@ func rowsiFromStatement(ds driverStmt, args ...interface{}) (driver.Rows, error)
 // Example usage:
 //
 //  var name string
-//  err := nameByUseridStmt.QueryRow(id).Scan(&name)
-func (s *Stmt) QueryRow(args ...interface{}) *Row {
-       rows, err := s.Query(args...)
+//  err := nameByUseridStmt.QueryRowContext(ctx, id).Scan(&name)
+func (s *Stmt) QueryRowContext(ctx context.Context, args ...interface{}) *Row {
+       rows, err := s.QueryContext(ctx, args...)
        if err != nil {
                return &Row{err: err}
        }
        return &Row{rows: rows}
 }
 
+// QueryRow executes a prepared query statement with the given arguments.
+// If an error occurs during the execution of the statement, that error will
+// be returned by a call to Scan on the returned *Row, which is always non-nil.
+// If the query selects no rows, the *Row's Scan will return ErrNoRows.
+// Otherwise, the *Row's Scan scans the first selected row and discards
+// the rest.
+//
+// Example usage:
+//
+//  var name string
+//  err := nameByUseridStmt.QueryRow(id).Scan(&name)
+func (s *Stmt) QueryRow(args ...interface{}) *Row {
+       return s.QueryRowContext(context.Background(), args...)
+}
+
 // Close closes the statement.
 func (s *Stmt) Close() error {
        s.closemu.Lock()
index 41afd00e92e07d9d9114f22c46bb9a8d6fbbfe64..9fcb2e38c1031b0e69d4dac556cb8eebaca9910f 100644 (file)
@@ -5,6 +5,7 @@
 package sql
 
 import (
+       "context"
        "database/sql/driver"
        "errors"
        "fmt"
@@ -1159,17 +1160,19 @@ func TestMaxOpenConnsOnBusy(t *testing.T) {
 
        db.SetMaxOpenConns(3)
 
-       conn0, err := db.conn(cachedOrNewConn)
+       ctx := context.Background()
+
+       conn0, err := db.conn(ctx, cachedOrNewConn)
        if err != nil {
                t.Fatalf("db open conn fail: %v", err)
        }
 
-       conn1, err := db.conn(cachedOrNewConn)
+       conn1, err := db.conn(ctx, cachedOrNewConn)
        if err != nil {
                t.Fatalf("db open conn fail: %v", err)
        }
 
-       conn2, err := db.conn(cachedOrNewConn)
+       conn2, err := db.conn(ctx, cachedOrNewConn)
        if err != nil {
                t.Fatalf("db open conn fail: %v", err)
        }
index d8eb2ee7261853aff6a44fbfbce5f6c43be36c36..c7e11498ddb7246cdf99393207d77ba461c30fea 100644 (file)
@@ -228,8 +228,8 @@ var pkgDeps = map[string][]string{
        "compress/lzw":             {"L4"},
        "compress/zlib":            {"L4", "compress/flate"},
        "context":                  {"errors", "fmt", "reflect", "sync", "time"},
-       "database/sql":             {"L4", "container/list", "database/sql/driver"},
-       "database/sql/driver":      {"L4", "time"},
+       "database/sql":             {"L4", "container/list", "context", "database/sql/driver"},
+       "database/sql/driver":      {"L4", "context", "time"},
        "debug/dwarf":              {"L4"},
        "debug/elf":                {"L4", "OS", "debug/dwarf", "compress/zlib"},
        "debug/gosym":              {"L4"},