if ciCtx, is := ci.(driver.ConnPrepareContext); is {
return ciCtx.PrepareContext(ctx, query)
}
- if ctx.Done() == context.Background().Done() {
- return ci.Prepare(query)
- }
-
- type R struct {
- err error
- panic interface{}
- si driver.Stmt
- }
-
- rc := make(chan R, 1)
- go func() {
- r := R{}
- defer func() {
- if v := recover(); v != nil {
- r.panic = v
- }
- rc <- r
- }()
- r.si, r.err = ci.Prepare(query)
- }()
- select {
- case <-ctx.Done():
- go func() {
- <-rc
- close(rc)
- }()
- return nil, ctx.Err()
- case r := <-rc:
- if r.panic != nil {
- panic(r.panic)
+ si, err := ci.Prepare(query)
+ if err == nil {
+ select {
+ default:
+ case <-ctx.Done():
+ si.Close()
+ return nil, ctx.Err()
}
- return r.si, r.err
}
+ return si, err
}
func ctxDriverExec(ctx context.Context, execer driver.Execer, query string, nvdargs []driver.NamedValue) (driver.Result, error) {
if err != nil {
return nil, err
}
- if ctx.Done() == context.Background().Done() {
- return execer.Exec(query, dargs)
- }
-
- type R struct {
- err error
- panic interface{}
- resi driver.Result
- }
- rc := make(chan R, 1)
- go func() {
- r := R{}
- defer func() {
- if v := recover(); v != nil {
- r.panic = v
- }
- rc <- r
- }()
- r.resi, r.err = execer.Exec(query, dargs)
- }()
- select {
- case <-ctx.Done():
- go func() {
- <-rc
- close(rc)
- }()
- return nil, ctx.Err()
- case r := <-rc:
- if r.panic != nil {
- panic(r.panic)
+ resi, err := execer.Exec(query, dargs)
+ if err == nil {
+ select {
+ default:
+ case <-ctx.Done():
+ return resi, ctx.Err()
}
- return r.resi, r.err
}
+ return resi, err
}
func ctxDriverQuery(ctx context.Context, queryer driver.Queryer, query string, nvdargs []driver.NamedValue) (driver.Rows, error) {
if queryerCtx, is := queryer.(driver.QueryerContext); is {
- return queryerCtx.QueryContext(ctx, query, nvdargs)
+ ret, err := queryerCtx.QueryContext(ctx, query, nvdargs)
+ return ret, err
}
dargs, err := namedValueToValue(nvdargs)
if err != nil {
return nil, err
}
- if ctx.Done() == context.Background().Done() {
- return queryer.Query(query, dargs)
- }
- type R struct {
- err error
- panic interface{}
- rowsi driver.Rows
- }
-
- rc := make(chan R, 1)
- go func() {
- r := R{}
- defer func() {
- if v := recover(); v != nil {
- r.panic = v
- }
- rc <- r
- }()
- r.rowsi, r.err = queryer.Query(query, dargs)
- }()
- select {
- case <-ctx.Done():
- go func() {
- <-rc
- close(rc)
- }()
- return nil, ctx.Err()
- case r := <-rc:
- if r.panic != nil {
- panic(r.panic)
+ rowsi, err := queryer.Query(query, dargs)
+ if err == nil {
+ select {
+ default:
+ case <-ctx.Done():
+ rowsi.Close()
+ return nil, ctx.Err()
}
- return r.rowsi, r.err
}
+ return rowsi, err
}
func ctxDriverStmtExec(ctx context.Context, si driver.Stmt, nvdargs []driver.NamedValue) (driver.Result, error) {
if err != nil {
return nil, err
}
- if ctx.Done() == context.Background().Done() {
- return si.Exec(dargs)
- }
-
- type R struct {
- err error
- panic interface{}
- resi driver.Result
- }
- rc := make(chan R, 1)
- go func() {
- r := R{}
- defer func() {
- if v := recover(); v != nil {
- r.panic = v
- }
- rc <- r
- }()
- r.resi, r.err = si.Exec(dargs)
- }()
- select {
- case <-ctx.Done():
- go func() {
- <-rc
- close(rc)
- }()
- return nil, ctx.Err()
- case r := <-rc:
- if r.panic != nil {
- panic(r.panic)
+ resi, err := si.Exec(dargs)
+ if err == nil {
+ select {
+ default:
+ case <-ctx.Done():
+ return resi, ctx.Err()
}
- return r.resi, r.err
}
+ return resi, err
}
func ctxDriverStmtQuery(ctx context.Context, si driver.Stmt, nvdargs []driver.NamedValue) (driver.Rows, error) {
if err != nil {
return nil, err
}
- if ctx.Done() == context.Background().Done() {
- return si.Query(dargs)
- }
- type R struct {
- err error
- panic interface{}
- rowsi driver.Rows
- }
-
- rc := make(chan R, 1)
- go func() {
- r := R{}
- defer func() {
- if v := recover(); v != nil {
- r.panic = v
- }
- rc <- r
- }()
- r.rowsi, r.err = si.Query(dargs)
- }()
- select {
- case <-ctx.Done():
- go func() {
- <-rc
- close(rc)
- }()
- return nil, ctx.Err()
- case r := <-rc:
- if r.panic != nil {
- panic(r.panic)
+ rowsi, err := si.Query(dargs)
+ if err == nil {
+ select {
+ default:
+ case <-ctx.Done():
+ rowsi.Close()
+ return nil, ctx.Err()
}
- return r.rowsi, r.err
}
+ return rowsi, err
}
var errLevelNotSupported = errors.New("sql: selected isolation level is not supported")
return nil, errors.New("sql: driver does not support read-only transactions")
}
- type R struct {
- err error
- panic interface{}
- txi driver.Tx
- }
- rc := make(chan R, 1)
- go func() {
- r := R{}
- defer func() {
- if v := recover(); v != nil {
- r.panic = v
- }
- rc <- r
- }()
- r.txi, r.err = ci.Begin()
- }()
- select {
- case <-ctx.Done():
- go func() {
- <-rc
- close(rc)
- }()
- return nil, ctx.Err()
- case r := <-rc:
- if r.panic != nil {
- panic(r.panic)
+ txi, err := ci.Begin()
+ if err == nil {
+ select {
+ default:
+ case <-ctx.Done():
+ txi.Rollback()
+ return nil, ctx.Err()
}
- return r.txi, r.err
}
+ return txi, err
}
func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) {
// statement.
//
// Exec may return ErrSkip.
+//
+// Deprecated: Drivers should implement ExecerContext instead (or additionally).
type Execer interface {
Exec(query string, args []Value) (Result, error)
}
-// ExecerContext is like execer, but must honor the context timeout and return
-// when the context is cancelled.
+// ExecerContext is an optional interface that may be implemented by a Conn.
+//
+// If a Conn does not implement ExecerContext, the sql package's DB.Exec will
+// first prepare a query, execute the statement, and then close the
+// statement.
+//
+// ExecerContext may return ErrSkip.
+//
+// ExecerContext must honor the context timeout and return when the context is canceled.
type ExecerContext interface {
ExecContext(ctx context.Context, query string, args []NamedValue) (Result, error)
}
// statement.
//
// Query may return ErrSkip.
+//
+// Deprecated: Drivers should implement QueryerContext instead (or additionally).
type Queryer interface {
Query(query string, args []Value) (Rows, error)
}
-// QueryerContext is like Queryer, but most honor the context timeout and return
-// when the context is cancelled.
+// QueryerContext is an optional interface that may be implemented by a Conn.
+//
+// If a Conn does not implement QueryerContext, the sql package's DB.Query will
+// first prepare a query, execute the statement, and then close the
+// statement.
+//
+// QueryerContext may return ErrSkip.
+//
+// QueryerContext must honor the context timeout and return when the context is canceled.
type QueryerContext interface {
QueryContext(ctx context.Context, query string, args []NamedValue) (Rows, error)
}
Close() error
// Begin starts and returns a new transaction.
+ //
+ // Deprecated: Drivers should implement ConnBeginContext instead (or additionally).
Begin() (Tx, error)
}
// ConnBeginContext enhances the Conn interface with context.
type ConnBeginContext interface {
// BeginContext starts and returns a new transaction.
- // The provided context should be used to roll the transaction back
- // if it is cancelled.
+ // If the context is canceled by the user the sql package will
+ // call Tx.Rollback before discarding and closing the connection.
//
// This must call IsolationFromContext to determine if there is a set
// isolation level. If the driver does not support setting the isolation
// Exec executes a query that doesn't return rows, such
// as an INSERT or UPDATE.
+ //
+ // Deprecated: Drivers should implement StmtExecContext instead (or additionally).
Exec(args []Value) (Result, error)
// Query executes a query that may return rows, such as a
// SELECT.
+ //
+ // Deprecated: Drivers should implement StmtQueryContext instead (or additionally).
Query(args []Value) (Rows, error)
}
// StmtExecContext enhances the Stmt interface by providing Exec with context.
type StmtExecContext interface {
- // ExecContext must honor the context timeout and return when it is cancelled.
+ // ExecContext executes a query that doesn't return rows, such
+ // as an INSERT or UPDATE.
+ //
+ // ExecContext must honor the context timeout and return when it is canceled.
ExecContext(ctx context.Context, args []NamedValue) (Result, error)
}
// StmtQueryContext enhances the Stmt interface by providing Query with context.
type StmtQueryContext interface {
- // QueryContext must honor the context timeout and return when it is cancelled.
+ // QueryContext executes a query that may return rows, such as a
+ // SELECT.
+ //
+ // QueryContext must honor the context timeout and return when it is canceled.
QueryContext(ctx context.Context, args []NamedValue) (Rows, error)
}
// Any of these can be preceded by PANIC|<method>|, to cause the
// named method on fakeStmt to panic.
//
+// Any of these can be proceeded by WAIT|<duration>|, to cause the
+// named method on fakeStmt to sleep for the specified duration.
+//
// Multiple of these can be combined when separated with a semicolon.
//
// When opening a fakeDriver's database, it starts empty with no
cmd string
table string
panic string
+ wait time.Duration
next *fakeStmt // used for returning multiple results.
if firstStmt == nil {
firstStmt = stmt
}
- if len(parts) >= 3 && parts[0] == "PANIC" {
- stmt.panic = parts[1]
- parts = parts[2:]
+ if len(parts) >= 3 {
+ switch parts[0] {
+ case "PANIC":
+ stmt.panic = parts[1]
+ parts = parts[2:]
+ case "WAIT":
+ wait, err := time.ParseDuration(parts[1])
+ if err != nil {
+ return nil, errf("expected section after WAIT to be a duration, got %q %v", parts[1], err)
+ }
+ parts = parts[2:]
+ stmt.wait = wait
+ }
}
cmd := parts[0]
stmt.cmd = cmd
parts = parts[1:]
+ if stmt.wait > 0 {
+ time.Sleep(stmt.wait)
+ }
+
c.incrStat(&c.stmtsMade)
var err error
switch cmd {
return nil, err
}
+ if s.wait > 0 {
+ time.Sleep(s.wait)
+ }
+
+ select {
+ default:
+ case <-ctx.Done():
+ return nil, ctx.Err()
+ }
+
db := s.c.db
switch s.cmd {
case "WIPE":
return nil, errDBClosed
}
// Check if the context is expired.
- if err := ctx.Err(); err != nil {
+ select {
+ default:
+ case <-ctx.Done():
db.mu.Unlock()
- return nil, err
+ return nil, ctx.Err()
}
lifetime := db.maxLifetime
// BeginContext starts a transaction.
//
+// The provided context is used until the transaction is committed or rolled back.
+// If the context is canceled, the sql package will roll back
+// the transaction. Tx.Commit will return an error if the context provided to
+// BeginContext is canceled.
+//
// An isolation level may be set by setting the value in the context
// before calling this. If a non-default isolation level is used
// that the driver doesn't support an error will be returned. Different drivers
dc: dc,
txi: txi,
cancel: cancel,
+ ctx: ctx,
}
- go func() {
+ go func(tx *Tx) {
select {
- case <-ctx.Done():
- if !tx.done {
- tx.Rollback()
+ case <-tx.ctx.Done():
+ if !tx.isDone() {
+ // Discard and close the connection used to ensure the transaction
+ // is closed and the resources are released.
+ tx.rollback(true)
}
}
- }()
+ }(tx)
return tx, nil
}
dc *driverConn
txi driver.Tx
- // done transitions from false to true exactly once, on Commit
+ // done transitions from 0 to 1 exactly once, on Commit
// or Rollback. once done, all operations fail with
// ErrTxDone.
- done bool
+ // Use atomic operations on value when checking value.
+ done int32
// All Stmts prepared for this transaction. These will be closed after the
// transaction has been committed or rolled back.
// cancel is called after done transitions from false to true.
cancel func()
+
+ // ctx lives for the life of the transaction.
+ ctx context.Context
+}
+
+func (tx *Tx) isDone() bool {
+ return atomic.LoadInt32(&tx.done) != 0
}
// ErrTxDone is returned by any operation that is performed on a transaction
var ErrTxDone = errors.New("sql: Transaction has already been committed or rolled back")
func (tx *Tx) close(err error) {
- if tx.done {
+ if !atomic.CompareAndSwapInt32(&tx.done, 0, 1) {
panic("double close") // internal error
}
- tx.done = true
tx.db.putConn(tx.dc, err)
tx.cancel()
tx.dc = nil
}
func (tx *Tx) grabConn(ctx context.Context) (*driverConn, error) {
- if tx.done {
+ if tx.isDone() {
return nil, ErrTxDone
}
return tx.dc, nil
// Commit commits the transaction.
func (tx *Tx) Commit() error {
- if tx.done {
+ select {
+ default:
+ case <-tx.ctx.Done():
+ return tx.ctx.Err()
+ }
+ if tx.isDone() {
return ErrTxDone
}
var err error
return err
}
-// Rollback aborts the transaction.
-func (tx *Tx) Rollback() error {
- if tx.done {
+// rollback aborts the transaction and optionally forces the pool to discard
+// the connection.
+func (tx *Tx) rollback(discardConn bool) error {
+ if tx.isDone() {
return ErrTxDone
}
var err error
if err != driver.ErrBadConn {
tx.closePrepared()
}
+ if discardConn {
+ err = driver.ErrBadConn
+ }
tx.close(err)
return err
}
+// Rollback aborts the transaction.
+func (tx *Tx) Rollback() error {
+ return tx.rollback(false)
+}
+
// Prepare creates a prepared statement for use within a transaction.
//
// The returned statement operates within the transaction and will be closed
var si driver.Stmt
withLock(dc, func() {
- si, err = dc.ci.Prepare(query)
+ si, err = ctxDriverPrepare(ctx, dc.ci, query)
})
if err != nil {
return nil, err
}
var si driver.Stmt
withLock(dc, func() {
- si, err = dc.ci.Prepare(stmt.query)
+ si, err = ctxDriverPrepare(ctx, dc.ci, stmt.query)
})
txs := &Stmt{
db: tx.db,
if err != nil {
t.Fatalf("error closing DB: %v", err)
}
- db.mu.Lock()
- count := db.numOpen
- db.mu.Unlock()
- if count != 0 {
+ if count := db.numOpenConns(); count != 0 {
t.Fatalf("%d connections still open after closing DB", count)
}
}
return len(db.freeConn)
}
+func (db *DB) numOpenConns() int {
+ db.mu.Lock()
+ defer db.mu.Unlock()
+ return db.numOpen
+}
+
// clearAllConns closes all connections in db.
func (db *DB) clearAllConns(t *testing.T) {
db.SetMaxIdleConns(0)
}
}
+func waitCondition(waitFor, checkEvery time.Duration, fn func() bool) bool {
+ deadline := time.Now().Add(waitFor)
+ for time.Now().Before(deadline) {
+ if fn() {
+ return true
+ }
+ time.Sleep(checkEvery)
+ }
+ return false
+}
+
+func TestQueryContextWait(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+ prepares0 := numPrepares(t, db)
+
+ ctx, _ := context.WithTimeout(context.Background(), time.Millisecond*15)
+
+ // This will trigger the *fakeConn.Prepare method which will take time
+ // performing the query. The ctxDriverPrepare func will check the context
+ // after this and close the rows and return an error.
+ _, err := db.QueryContext(ctx, "WAIT|30ms|SELECT|people|age,name|")
+ if err != context.DeadlineExceeded {
+ t.Fatalf("expected QueryContext to error with context deadline exceeded but returned %v", err)
+ }
+
+ // Verify closed rows connection after error condition.
+ if n := db.numFreeConns(); n != 1 {
+ t.Fatalf("free conns after query hitting EOF = %d; want 1", n)
+ }
+ if prepares := numPrepares(t, db) - prepares0; prepares != 1 {
+ t.Errorf("executed %d Prepare statements; want 1", prepares)
+ }
+}
+
+func TestTxContextWait(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+
+ ctx, _ := context.WithTimeout(context.Background(), time.Millisecond*15)
+
+ tx, err := db.BeginContext(ctx)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // This will trigger the *fakeConn.Prepare method which will take time
+ // performing the query. The ctxDriverPrepare func will check the context
+ // after this and close the rows and return an error.
+ _, err = tx.QueryContext(ctx, "WAIT|30ms|SELECT|people|age,name|")
+ if err != context.DeadlineExceeded {
+ t.Fatalf("expected QueryContext to error with context deadline exceeded but returned %v", err)
+ }
+
+ var numFree int
+ if !waitCondition(5*time.Second, 5*time.Millisecond, func() bool {
+ numFree = db.numFreeConns()
+ return numFree == 0
+ }) {
+ t.Fatalf("free conns after hitting EOF = %d; want 0", numFree)
+ }
+
+ // Ensure the dropped connection allows more connections to be made.
+ // Checked on DB Close.
+ waitCondition(5*time.Second, 5*time.Millisecond, func() bool {
+ return db.numOpenConns() == 0
+ })
+}
+
func TestMultiResultSetQuery(t *testing.T) {
db := newTestDB(t, "people")
defer closeDB(t, db)