From: Daniel Theophanes Date: Wed, 28 Sep 2016 19:51:39 +0000 (-0700) Subject: database/sql: close Rows when context is cancelled X-Git-Tag: go1.8beta1~1105 X-Git-Url: http://www.git.cypherpunks.su/?a=commitdiff_plain;h=d2df8498f366669acbae24f38e3683b3acdab102;p=gostls13.git database/sql: close Rows when context is cancelled To prevent leaking connections, close any open Rows when the context is cancelled. Also enforce context cancel while reading rows off of the wire. Change-Id: I62237ecdb7d250d6734f6ce3d2b0bcb16dc6fda7 Reviewed-on: https://go-review.googlesource.com/29957 Reviewed-by: Brad Fitzpatrick --- diff --git a/src/database/sql/ctxutil.go b/src/database/sql/ctxutil.go index 65e1652657..e1d4c03c9a 100644 --- a/src/database/sql/ctxutil.go +++ b/src/database/sql/ctxutil.go @@ -14,6 +14,10 @@ func ctxDriverPrepare(ctx context.Context, ci driver.Conn, query string) (driver if ciCtx, is := ci.(driver.ConnPrepareContext); is { return ciCtx.PrepareContext(ctx, query) } + if ctx.Done() == context.Background().Done() { + return ci.Prepare(query) + } + type R struct { err error panic interface{} @@ -50,6 +54,10 @@ func ctxDriverExec(ctx context.Context, execer driver.Execer, query string, darg if execerCtx, is := execer.(driver.ExecerContext); is { return execerCtx.ExecContext(ctx, query, dargs) } + if ctx.Done() == context.Background().Done() { + return execer.Exec(query, dargs) + } + type R struct { err error panic interface{} @@ -86,6 +94,10 @@ func ctxDriverQuery(ctx context.Context, queryer driver.Queryer, query string, d if queryerCtx, is := queryer.(driver.QueryerContext); is { return queryerCtx.QueryContext(ctx, query, dargs) } + if ctx.Done() == context.Background().Done() { + return queryer.Query(query, dargs) + } + type R struct { err error panic interface{} @@ -122,6 +134,10 @@ func ctxDriverStmtExec(ctx context.Context, si driver.Stmt, dargs []driver.Value if siCtx, is := si.(driver.StmtExecContext); is { return siCtx.ExecContext(ctx, dargs) } + if ctx.Done() == context.Background().Done() { + return si.Exec(dargs) + } + type R struct { err error panic interface{} @@ -158,6 +174,10 @@ func ctxDriverStmtQuery(ctx context.Context, si driver.Stmt, dargs []driver.Valu if siCtx, is := si.(driver.StmtQueryContext); is { return siCtx.QueryContext(ctx, dargs) } + if ctx.Done() == context.Background().Done() { + return si.Query(dargs) + } + type R struct { err error panic interface{} @@ -196,6 +216,10 @@ func ctxDriverBegin(ctx context.Context, ci driver.Conn) (driver.Tx, error) { if ciCtx, is := ci.(driver.ConnBeginContext); is { return ciCtx.BeginContext(ctx) } + if ctx.Done() == context.Background().Done() { + return ci.Begin() + } + // TODO(kardianos): check the transaction level in ctx. If set and non-default // then return an error here as the BeginContext driver value is not supported. diff --git a/src/database/sql/sql.go b/src/database/sql/sql.go index 4c44e2b6f4..f56c71a638 100644 --- a/src/database/sql/sql.go +++ b/src/database/sql/sql.go @@ -974,7 +974,8 @@ const maxBadConnRetries = 2 // returned statement. // The caller must call the statement's Close method // when the statement is no longer needed. -// Context is for the preparation of the statment, not for the execution of +// +// The provided context is for the preparation of the statment, not for the execution of // the statement. func (db *DB) PrepareContext(ctx context.Context, query string) (*Stmt, error) { var stmt *Stmt @@ -1148,6 +1149,7 @@ func (db *DB) queryConn(ctx context.Context, dc *driverConn, releaseConn func(er releaseConn: releaseConn, rowsi: rowsi, } + rows.initContextClose(ctx) return rows, nil } } @@ -1180,6 +1182,7 @@ func (db *DB) queryConn(ctx context.Context, dc *driverConn, releaseConn func(er rowsi: rowsi, closeStmt: si, } + rows.initContextClose(ctx) return rows, nil } @@ -1364,7 +1367,8 @@ func (tx *Tx) Rollback() error { // be used once the transaction has been committed or rolled back. // // To use an existing prepared statement on this transaction, see Tx.Stmt. -// Context will be used for the preparation of the context, not +// +// The provided context will be used for the preparation of the context, not // for the execution of the returned statement. The returned statement // will run in the transaction context. func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) { @@ -1759,6 +1763,7 @@ func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, er rowsi: rowsi, // releaseConn set below } + rows.initContextClose(ctx) s.db.addDep(s, rows) rows.releaseConn = func(err error) { releaseConn(err) @@ -1899,12 +1904,30 @@ type Rows struct { releaseConn func(error) rowsi driver.Rows - closed bool + // closed value is 1 when the Rows is closed. + // Use atomic operations on value when checking value. + closed int32 + ctxClose chan struct{} // closed when Rows is closed, may be null. lastcols []driver.Value lasterr error // non-nil only if closed is true closeStmt driver.Stmt // if non-nil, statement to Close on close } +func (rs *Rows) initContextClose(ctx context.Context) { + if ctx.Done() == context.Background().Done() { + return + } + + rs.ctxClose = make(chan struct{}) + go func() { + select { + case <-ctx.Done(): + rs.Close() + case <-rs.ctxClose: + } + }() +} + // Next prepares the next result row for reading with the Scan method. It // returns true on success, or false if there is no next result row or an error // happened while preparing it. Err should be consulted to distinguish between @@ -1912,7 +1935,7 @@ type Rows struct { // // Every call to Scan, even the first one, must be preceded by a call to Next. func (rs *Rows) Next() bool { - if rs.closed { + if rs.isClosed() { return false } if rs.lastcols == nil { @@ -1939,7 +1962,7 @@ func (rs *Rows) Err() error { // Columns returns an error if the rows are closed, or if the rows // are from QueryRow and there was a deferred error. func (rs *Rows) Columns() ([]string, error) { - if rs.closed { + if rs.isClosed() { return nil, errors.New("sql: Rows are closed") } if rs.rowsi == nil { @@ -2000,7 +2023,7 @@ func (rs *Rows) Columns() ([]string, error) { // For scanning into *bool, the source may be true, false, 1, 0, or // string inputs parseable by strconv.ParseBool. func (rs *Rows) Scan(dest ...interface{}) error { - if rs.closed { + if rs.isClosed() { return errors.New("sql: Rows are closed") } if rs.lastcols == nil { @@ -2020,14 +2043,20 @@ func (rs *Rows) Scan(dest ...interface{}) error { var rowsCloseHook func(*Rows, *error) +func (rs *Rows) isClosed() bool { + return atomic.LoadInt32(&rs.closed) != 0 +} + // Close closes the Rows, preventing further enumeration. If Next returns // false, the Rows are closed automatically and it will suffice to check the // result of Err. Close is idempotent and does not affect the result of Err. func (rs *Rows) Close() error { - if rs.closed { + if !atomic.CompareAndSwapInt32(&rs.closed, 0, 1) { return nil } - rs.closed = true + if rs.ctxClose != nil { + close(rs.ctxClose) + } err := rs.rowsi.Close() if fn := rowsCloseHook; fn != nil { fn(rs, &err) diff --git a/src/database/sql/sql_test.go b/src/database/sql/sql_test.go index 9fcb2e38c1..ca14af79e7 100644 --- a/src/database/sql/sql_test.go +++ b/src/database/sql/sql_test.go @@ -261,6 +261,64 @@ func TestQuery(t *testing.T) { } } +func TestQueryContext(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + prepares0 := numPrepares(t, db) + + ctx, cancel := context.WithCancel(context.Background()) + + rows, err := db.QueryContext(ctx, "SELECT|people|age,name|") + if err != nil { + t.Fatalf("Query: %v", err) + } + type row struct { + age int + name string + } + got := []row{} + index := 0 + for rows.Next() { + if index == 2 { + cancel() + time.Sleep(10 * time.Millisecond) + } + var r row + err = rows.Scan(&r.age, &r.name) + if err != nil { + if index == 2 { + break + } + t.Fatalf("Scan: %v", err) + } + if index == 2 && err == nil { + t.Fatal("expected an error on last scan") + } + got = append(got, r) + index++ + } + err = rows.Err() + if err != nil { + t.Fatalf("Err: %v", err) + } + want := []row{ + {age: 1, name: "Alice"}, + {age: 2, name: "Bob"}, + } + if !reflect.DeepEqual(got, want) { + t.Errorf("mismatch.\n got: %#v\nwant: %#v", got, want) + } + + // And verify that the final rows.Next() call, which hit EOF, + // also closed the rows connection. + if n := db.numFreeConns(); n != 1 { + t.Fatalf("free conns after query hitting EOF = %d; want 1", n) + } + if prepares := numPrepares(t, db) - prepares0; prepares != 1 { + t.Errorf("executed %d Prepare statements; want 1", prepares) + } +} + func TestByteOwnership(t *testing.T) { db := newTestDB(t, "people") defer closeDB(t, db)