if ciCtx, is := ci.(driver.ConnPrepareContext); is {
return ciCtx.PrepareContext(ctx, query)
}
+ if ctx.Done() == context.Background().Done() {
+ return ci.Prepare(query)
+ }
+
type R struct {
err error
panic interface{}
if execerCtx, is := execer.(driver.ExecerContext); is {
return execerCtx.ExecContext(ctx, query, dargs)
}
+ if ctx.Done() == context.Background().Done() {
+ return execer.Exec(query, dargs)
+ }
+
type R struct {
err error
panic interface{}
if queryerCtx, is := queryer.(driver.QueryerContext); is {
return queryerCtx.QueryContext(ctx, query, dargs)
}
+ if ctx.Done() == context.Background().Done() {
+ return queryer.Query(query, dargs)
+ }
+
type R struct {
err error
panic interface{}
if siCtx, is := si.(driver.StmtExecContext); is {
return siCtx.ExecContext(ctx, dargs)
}
+ if ctx.Done() == context.Background().Done() {
+ return si.Exec(dargs)
+ }
+
type R struct {
err error
panic interface{}
if siCtx, is := si.(driver.StmtQueryContext); is {
return siCtx.QueryContext(ctx, dargs)
}
+ if ctx.Done() == context.Background().Done() {
+ return si.Query(dargs)
+ }
+
type R struct {
err error
panic interface{}
if ciCtx, is := ci.(driver.ConnBeginContext); is {
return ciCtx.BeginContext(ctx)
}
+ if ctx.Done() == context.Background().Done() {
+ return ci.Begin()
+ }
+
// TODO(kardianos): check the transaction level in ctx. If set and non-default
// then return an error here as the BeginContext driver value is not supported.
// returned statement.
// The caller must call the statement's Close method
// when the statement is no longer needed.
-// Context is for the preparation of the statment, not for the execution of
+//
+// The provided context is for the preparation of the statment, not for the execution of
// the statement.
func (db *DB) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
var stmt *Stmt
releaseConn: releaseConn,
rowsi: rowsi,
}
+ rows.initContextClose(ctx)
return rows, nil
}
}
rowsi: rowsi,
closeStmt: si,
}
+ rows.initContextClose(ctx)
return rows, nil
}
// be used once the transaction has been committed or rolled back.
//
// To use an existing prepared statement on this transaction, see Tx.Stmt.
-// Context will be used for the preparation of the context, not
+//
+// The provided context will be used for the preparation of the context, not
// for the execution of the returned statement. The returned statement
// will run in the transaction context.
func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
rowsi: rowsi,
// releaseConn set below
}
+ rows.initContextClose(ctx)
s.db.addDep(s, rows)
rows.releaseConn = func(err error) {
releaseConn(err)
releaseConn func(error)
rowsi driver.Rows
- closed bool
+ // closed value is 1 when the Rows is closed.
+ // Use atomic operations on value when checking value.
+ closed int32
+ ctxClose chan struct{} // closed when Rows is closed, may be null.
lastcols []driver.Value
lasterr error // non-nil only if closed is true
closeStmt driver.Stmt // if non-nil, statement to Close on close
}
+func (rs *Rows) initContextClose(ctx context.Context) {
+ if ctx.Done() == context.Background().Done() {
+ return
+ }
+
+ rs.ctxClose = make(chan struct{})
+ go func() {
+ select {
+ case <-ctx.Done():
+ rs.Close()
+ case <-rs.ctxClose:
+ }
+ }()
+}
+
// Next prepares the next result row for reading with the Scan method. It
// returns true on success, or false if there is no next result row or an error
// happened while preparing it. Err should be consulted to distinguish between
//
// Every call to Scan, even the first one, must be preceded by a call to Next.
func (rs *Rows) Next() bool {
- if rs.closed {
+ if rs.isClosed() {
return false
}
if rs.lastcols == nil {
// Columns returns an error if the rows are closed, or if the rows
// are from QueryRow and there was a deferred error.
func (rs *Rows) Columns() ([]string, error) {
- if rs.closed {
+ if rs.isClosed() {
return nil, errors.New("sql: Rows are closed")
}
if rs.rowsi == nil {
// For scanning into *bool, the source may be true, false, 1, 0, or
// string inputs parseable by strconv.ParseBool.
func (rs *Rows) Scan(dest ...interface{}) error {
- if rs.closed {
+ if rs.isClosed() {
return errors.New("sql: Rows are closed")
}
if rs.lastcols == nil {
var rowsCloseHook func(*Rows, *error)
+func (rs *Rows) isClosed() bool {
+ return atomic.LoadInt32(&rs.closed) != 0
+}
+
// Close closes the Rows, preventing further enumeration. If Next returns
// false, the Rows are closed automatically and it will suffice to check the
// result of Err. Close is idempotent and does not affect the result of Err.
func (rs *Rows) Close() error {
- if rs.closed {
+ if !atomic.CompareAndSwapInt32(&rs.closed, 0, 1) {
return nil
}
- rs.closed = true
+ if rs.ctxClose != nil {
+ close(rs.ctxClose)
+ }
err := rs.rowsi.Close()
if fn := rowsCloseHook; fn != nil {
fn(rs, &err)
}
}
+func TestQueryContext(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+ prepares0 := numPrepares(t, db)
+
+ ctx, cancel := context.WithCancel(context.Background())
+
+ rows, err := db.QueryContext(ctx, "SELECT|people|age,name|")
+ if err != nil {
+ t.Fatalf("Query: %v", err)
+ }
+ type row struct {
+ age int
+ name string
+ }
+ got := []row{}
+ index := 0
+ for rows.Next() {
+ if index == 2 {
+ cancel()
+ time.Sleep(10 * time.Millisecond)
+ }
+ var r row
+ err = rows.Scan(&r.age, &r.name)
+ if err != nil {
+ if index == 2 {
+ break
+ }
+ t.Fatalf("Scan: %v", err)
+ }
+ if index == 2 && err == nil {
+ t.Fatal("expected an error on last scan")
+ }
+ got = append(got, r)
+ index++
+ }
+ err = rows.Err()
+ if err != nil {
+ t.Fatalf("Err: %v", err)
+ }
+ want := []row{
+ {age: 1, name: "Alice"},
+ {age: 2, name: "Bob"},
+ }
+ if !reflect.DeepEqual(got, want) {
+ t.Errorf("mismatch.\n got: %#v\nwant: %#v", got, want)
+ }
+
+ // And verify that the final rows.Next() call, which hit EOF,
+ // also closed the rows connection.
+ if n := db.numFreeConns(); n != 1 {
+ t.Fatalf("free conns after query hitting EOF = %d; want 1", n)
+ }
+ if prepares := numPrepares(t, db) - prepares0; prepares != 1 {
+ t.Errorf("executed %d Prepare statements; want 1", prepares)
+ }
+}
+
func TestByteOwnership(t *testing.T) {
db := newTestDB(t, "people")
defer closeDB(t, db)