]> Cypherpunks repositories - gostls13.git/commitdiff
database/sql: close Rows when context is cancelled
authorDaniel Theophanes <kardianos@gmail.com>
Wed, 28 Sep 2016 19:51:39 +0000 (12:51 -0700)
committerBrad Fitzpatrick <bradfitz@golang.org>
Thu, 29 Sep 2016 22:26:42 +0000 (22:26 +0000)
To prevent leaking connections, close any open Rows when the
context is cancelled. Also enforce context cancel while reading
rows off of the wire.

Change-Id: I62237ecdb7d250d6734f6ce3d2b0bcb16dc6fda7
Reviewed-on: https://go-review.googlesource.com/29957
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
src/database/sql/ctxutil.go
src/database/sql/sql.go
src/database/sql/sql_test.go

index 65e1652657e09a19c344d986b9a399ed4449ead9..e1d4c03c9a3f09e69a5c594d7c6385851d3aecff 100644 (file)
@@ -14,6 +14,10 @@ func ctxDriverPrepare(ctx context.Context, ci driver.Conn, query string) (driver
        if ciCtx, is := ci.(driver.ConnPrepareContext); is {
                return ciCtx.PrepareContext(ctx, query)
        }
+       if ctx.Done() == context.Background().Done() {
+               return ci.Prepare(query)
+       }
+
        type R struct {
                err   error
                panic interface{}
@@ -50,6 +54,10 @@ func ctxDriverExec(ctx context.Context, execer driver.Execer, query string, darg
        if execerCtx, is := execer.(driver.ExecerContext); is {
                return execerCtx.ExecContext(ctx, query, dargs)
        }
+       if ctx.Done() == context.Background().Done() {
+               return execer.Exec(query, dargs)
+       }
+
        type R struct {
                err   error
                panic interface{}
@@ -86,6 +94,10 @@ func ctxDriverQuery(ctx context.Context, queryer driver.Queryer, query string, d
        if queryerCtx, is := queryer.(driver.QueryerContext); is {
                return queryerCtx.QueryContext(ctx, query, dargs)
        }
+       if ctx.Done() == context.Background().Done() {
+               return queryer.Query(query, dargs)
+       }
+
        type R struct {
                err   error
                panic interface{}
@@ -122,6 +134,10 @@ func ctxDriverStmtExec(ctx context.Context, si driver.Stmt, dargs []driver.Value
        if siCtx, is := si.(driver.StmtExecContext); is {
                return siCtx.ExecContext(ctx, dargs)
        }
+       if ctx.Done() == context.Background().Done() {
+               return si.Exec(dargs)
+       }
+
        type R struct {
                err   error
                panic interface{}
@@ -158,6 +174,10 @@ func ctxDriverStmtQuery(ctx context.Context, si driver.Stmt, dargs []driver.Valu
        if siCtx, is := si.(driver.StmtQueryContext); is {
                return siCtx.QueryContext(ctx, dargs)
        }
+       if ctx.Done() == context.Background().Done() {
+               return si.Query(dargs)
+       }
+
        type R struct {
                err   error
                panic interface{}
@@ -196,6 +216,10 @@ func ctxDriverBegin(ctx context.Context, ci driver.Conn) (driver.Tx, error) {
        if ciCtx, is := ci.(driver.ConnBeginContext); is {
                return ciCtx.BeginContext(ctx)
        }
+       if ctx.Done() == context.Background().Done() {
+               return ci.Begin()
+       }
+
        // TODO(kardianos): check the transaction level in ctx. If set and non-default
        // then return an error here as the BeginContext driver value is not supported.
 
index 4c44e2b6f466b982b83570eaf43813df499687ba..f56c71a638bfac10b277bec0a5b71cba0ad81f14 100644 (file)
@@ -974,7 +974,8 @@ const maxBadConnRetries = 2
 // returned statement.
 // The caller must call the statement's Close method
 // when the statement is no longer needed.
-// Context is for the preparation of the statment, not for the execution of
+//
+// The provided context is for the preparation of the statment, not for the execution of
 // the statement.
 func (db *DB) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
        var stmt *Stmt
@@ -1148,6 +1149,7 @@ func (db *DB) queryConn(ctx context.Context, dc *driverConn, releaseConn func(er
                                releaseConn: releaseConn,
                                rowsi:       rowsi,
                        }
+                       rows.initContextClose(ctx)
                        return rows, nil
                }
        }
@@ -1180,6 +1182,7 @@ func (db *DB) queryConn(ctx context.Context, dc *driverConn, releaseConn func(er
                rowsi:       rowsi,
                closeStmt:   si,
        }
+       rows.initContextClose(ctx)
        return rows, nil
 }
 
@@ -1364,7 +1367,8 @@ func (tx *Tx) Rollback() error {
 // be used once the transaction has been committed or rolled back.
 //
 // To use an existing prepared statement on this transaction, see Tx.Stmt.
-// Context will be used for the preparation of the context, not
+//
+// The provided context will be used for the preparation of the context, not
 // for the execution of the returned statement. The returned statement
 // will run in the transaction context.
 func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
@@ -1759,6 +1763,7 @@ func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, er
                                rowsi: rowsi,
                                // releaseConn set below
                        }
+                       rows.initContextClose(ctx)
                        s.db.addDep(s, rows)
                        rows.releaseConn = func(err error) {
                                releaseConn(err)
@@ -1899,12 +1904,30 @@ type Rows struct {
        releaseConn func(error)
        rowsi       driver.Rows
 
-       closed    bool
+       // closed value is 1 when the Rows is closed.
+       // Use atomic operations on value when checking value.
+       closed    int32
+       ctxClose  chan struct{} // closed when Rows is closed, may be null.
        lastcols  []driver.Value
        lasterr   error       // non-nil only if closed is true
        closeStmt driver.Stmt // if non-nil, statement to Close on close
 }
 
+func (rs *Rows) initContextClose(ctx context.Context) {
+       if ctx.Done() == context.Background().Done() {
+               return
+       }
+
+       rs.ctxClose = make(chan struct{})
+       go func() {
+               select {
+               case <-ctx.Done():
+                       rs.Close()
+               case <-rs.ctxClose:
+               }
+       }()
+}
+
 // Next prepares the next result row for reading with the Scan method. It
 // returns true on success, or false if there is no next result row or an error
 // happened while preparing it. Err should be consulted to distinguish between
@@ -1912,7 +1935,7 @@ type Rows struct {
 //
 // Every call to Scan, even the first one, must be preceded by a call to Next.
 func (rs *Rows) Next() bool {
-       if rs.closed {
+       if rs.isClosed() {
                return false
        }
        if rs.lastcols == nil {
@@ -1939,7 +1962,7 @@ func (rs *Rows) Err() error {
 // Columns returns an error if the rows are closed, or if the rows
 // are from QueryRow and there was a deferred error.
 func (rs *Rows) Columns() ([]string, error) {
-       if rs.closed {
+       if rs.isClosed() {
                return nil, errors.New("sql: Rows are closed")
        }
        if rs.rowsi == nil {
@@ -2000,7 +2023,7 @@ func (rs *Rows) Columns() ([]string, error) {
 // For scanning into *bool, the source may be true, false, 1, 0, or
 // string inputs parseable by strconv.ParseBool.
 func (rs *Rows) Scan(dest ...interface{}) error {
-       if rs.closed {
+       if rs.isClosed() {
                return errors.New("sql: Rows are closed")
        }
        if rs.lastcols == nil {
@@ -2020,14 +2043,20 @@ func (rs *Rows) Scan(dest ...interface{}) error {
 
 var rowsCloseHook func(*Rows, *error)
 
+func (rs *Rows) isClosed() bool {
+       return atomic.LoadInt32(&rs.closed) != 0
+}
+
 // Close closes the Rows, preventing further enumeration. If Next returns
 // false, the Rows are closed automatically and it will suffice to check the
 // result of Err. Close is idempotent and does not affect the result of Err.
 func (rs *Rows) Close() error {
-       if rs.closed {
+       if !atomic.CompareAndSwapInt32(&rs.closed, 0, 1) {
                return nil
        }
-       rs.closed = true
+       if rs.ctxClose != nil {
+               close(rs.ctxClose)
+       }
        err := rs.rowsi.Close()
        if fn := rowsCloseHook; fn != nil {
                fn(rs, &err)
index 9fcb2e38c1031b0e69d4dac556cb8eebaca9910f..ca14af79e7cb85b90fdcbadbdd74c6a668556d1c 100644 (file)
@@ -261,6 +261,64 @@ func TestQuery(t *testing.T) {
        }
 }
 
+func TestQueryContext(t *testing.T) {
+       db := newTestDB(t, "people")
+       defer closeDB(t, db)
+       prepares0 := numPrepares(t, db)
+
+       ctx, cancel := context.WithCancel(context.Background())
+
+       rows, err := db.QueryContext(ctx, "SELECT|people|age,name|")
+       if err != nil {
+               t.Fatalf("Query: %v", err)
+       }
+       type row struct {
+               age  int
+               name string
+       }
+       got := []row{}
+       index := 0
+       for rows.Next() {
+               if index == 2 {
+                       cancel()
+                       time.Sleep(10 * time.Millisecond)
+               }
+               var r row
+               err = rows.Scan(&r.age, &r.name)
+               if err != nil {
+                       if index == 2 {
+                               break
+                       }
+                       t.Fatalf("Scan: %v", err)
+               }
+               if index == 2 && err == nil {
+                       t.Fatal("expected an error on last scan")
+               }
+               got = append(got, r)
+               index++
+       }
+       err = rows.Err()
+       if err != nil {
+               t.Fatalf("Err: %v", err)
+       }
+       want := []row{
+               {age: 1, name: "Alice"},
+               {age: 2, name: "Bob"},
+       }
+       if !reflect.DeepEqual(got, want) {
+               t.Errorf("mismatch.\n got: %#v\nwant: %#v", got, want)
+       }
+
+       // And verify that the final rows.Next() call, which hit EOF,
+       // also closed the rows connection.
+       if n := db.numFreeConns(); n != 1 {
+               t.Fatalf("free conns after query hitting EOF = %d; want 1", n)
+       }
+       if prepares := numPrepares(t, db) - prepares0; prepares != 1 {
+               t.Errorf("executed %d Prepare statements; want 1", prepares)
+       }
+}
+
 func TestByteOwnership(t *testing.T) {
        db := newTestDB(t, "people")
        defer closeDB(t, db)