]> Cypherpunks repositories - gostls13.git/commitdiff
[release-branch.go1.23] database/sql: avoid closing Rows while scan is in progress
authorDamien Neil <dneil@google.com>
Wed, 23 Jul 2025 21:26:54 +0000 (14:26 -0700)
committerGopher Robot <gobot@golang.org>
Wed, 6 Aug 2025 17:51:00 +0000 (10:51 -0700)
A database/sql/driver.Rows can return database-owned data
from Rows.Next. The driver.Rows documentation doesn't explicitly
document the lifetime guarantees for this data, but a reasonable
expectation is that the caller of Next should only access it
until the next call to Rows.Close or Rows.Next.

Avoid violating that constraint when a query is cancelled while
a call to database/sql.Rows.Scan (note the difference between
the two different Rows types!) is in progress. We previously
took care to avoid closing a driver.Rows while the user has
access to driver-owned memory via a RawData, but we could still
close a driver.Rows while a Scan call was in the process of
reading previously-returned driver-owned data.

Update the fake DB used in database/sql tests to invalidate
returned data to help catch other places we might be
incorrectly retaining it.

Updates #74831
Fixes #74832

Change-Id: Ice45b5fad51b679c38e3e1d21ef39156b56d6037
Reviewed-on: https://go-internal-review.googlesource.com/c/go/+/2540
Reviewed-by: Roland Shoemaker <bracewell@google.com>
Reviewed-by: Neal Patel <nealpatel@google.com>
Reviewed-on: https://go-internal-review.googlesource.com/c/go/+/2601
Reviewed-on: https://go-review.googlesource.com/c/go/+/693558
TryBot-Bypass: Dmitri Shuralyov <dmitshur@golang.org>
Reviewed-by: Mark Freeman <markfreeman@google.com>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
Auto-Submit: Dmitri Shuralyov <dmitshur@google.com>

src/database/sql/convert.go
src/database/sql/fakedb_test.go
src/database/sql/sql.go
src/database/sql/sql_test.go

index c261046b187e5271861d75cedb2ff872f76dc721..396833c2fc800983ef8c0a93db3a40811ef13755 100644 (file)
@@ -335,7 +335,6 @@ func convertAssignRows(dest, src any, rows *Rows) error {
                        if rows == nil {
                                return errors.New("invalid context to convert cursor rows, missing parent *Rows")
                        }
-                       rows.closemu.Lock()
                        *d = Rows{
                                dc:          rows.dc,
                                releaseConn: func(error) {},
@@ -351,7 +350,6 @@ func convertAssignRows(dest, src any, rows *Rows) error {
                                        parentCancel()
                                }
                        }
-                       rows.closemu.Unlock()
                        return nil
                }
        }
index 3dfcd447b52bca1234c2407a7ce62c017d46015f..003e6c62986f31a071acf8a11f10b7ccf92893a5 100644 (file)
@@ -5,6 +5,7 @@
 package sql
 
 import (
+       "bytes"
        "context"
        "database/sql/driver"
        "errors"
@@ -15,7 +16,6 @@ import (
        "strconv"
        "strings"
        "sync"
-       "sync/atomic"
        "testing"
        "time"
 )
@@ -91,8 +91,6 @@ func (cc *fakeDriverCtx) OpenConnector(name string) (driver.Connector, error) {
 type fakeDB struct {
        name string
 
-       useRawBytes atomic.Bool
-
        mu       sync.Mutex
        tables   map[string]*table
        badConn  bool
@@ -684,8 +682,6 @@ func (c *fakeConn) PrepareContext(ctx context.Context, query string) (driver.Stm
                switch cmd {
                case "WIPE":
                        // Nothing
-               case "USE_RAWBYTES":
-                       c.db.useRawBytes.Store(true)
                case "SELECT":
                        stmt, err = c.prepareSelect(stmt, parts)
                case "CREATE":
@@ -789,9 +785,6 @@ func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (d
        case "WIPE":
                db.wipe()
                return driver.ResultNoRows, nil
-       case "USE_RAWBYTES":
-               s.c.db.useRawBytes.Store(true)
-               return driver.ResultNoRows, nil
        case "CREATE":
                if err := db.createTable(s.table, s.colName, s.colType); err != nil {
                        return nil, err
@@ -1076,10 +1069,9 @@ type rowsCursor struct {
        errPos int
        err    error
 
-       // a clone of slices to give out to clients, indexed by the
-       // original slice's first byte address.  we clone them
-       // just so we're able to corrupt them on close.
-       bytesClone map[*byte][]byte
+       // Data returned to clients.
+       // We clone and stash it here so it can be invalidated by Close and Next.
+       driverOwnedMemory [][]byte
 
        // Every operation writes to line to enable the race detector
        // check for data races.
@@ -1096,9 +1088,19 @@ func (rc *rowsCursor) touchMem() {
        rc.line++
 }
 
+func (rc *rowsCursor) invalidateDriverOwnedMemory() {
+       for _, buf := range rc.driverOwnedMemory {
+               for i := range buf {
+                       buf[i] = 'x'
+               }
+       }
+       rc.driverOwnedMemory = nil
+}
+
 func (rc *rowsCursor) Close() error {
        rc.touchMem()
        rc.parentMem.touchMem()
+       rc.invalidateDriverOwnedMemory()
        rc.closed = true
        return rc.closeErr
 }
@@ -1129,6 +1131,8 @@ func (rc *rowsCursor) Next(dest []driver.Value) error {
        if rc.posRow >= len(rc.rows[rc.posSet]) {
                return io.EOF // per interface spec
        }
+       // Corrupt any previously returned bytes.
+       rc.invalidateDriverOwnedMemory()
        for i, v := range rc.rows[rc.posSet][rc.posRow].cols {
                // TODO(bradfitz): convert to subset types? naah, I
                // think the subset types should only be input to
@@ -1136,20 +1140,13 @@ func (rc *rowsCursor) Next(dest []driver.Value) error {
                // a wider range of types coming out of drivers. all
                // for ease of drivers, and to prevent drivers from
                // messing up conversions or doing them differently.
-               dest[i] = v
-
-               if bs, ok := v.([]byte); ok && !rc.db.useRawBytes.Load() {
-                       if rc.bytesClone == nil {
-                               rc.bytesClone = make(map[*byte][]byte)
-                       }
-                       clone, ok := rc.bytesClone[&bs[0]]
-                       if !ok {
-                               clone = make([]byte, len(bs))
-                               copy(clone, bs)
-                               rc.bytesClone[&bs[0]] = clone
-                       }
-                       dest[i] = clone
+               if bs, ok := v.([]byte); ok {
+                       // Clone []bytes and stash for later invalidation.
+                       bs = bytes.Clone(bs)
+                       rc.driverOwnedMemory = append(rc.driverOwnedMemory, bs)
+                       v = bs
                }
+               dest[i] = v
        }
        return nil
 }
index c247a9b506bfabb38b03ed1f160116370fde0572..3346ad48f1b33e18c844e38f1da3492bc684458c 100644 (file)
@@ -3360,38 +3360,36 @@ func (rs *Rows) Scan(dest ...any) error {
                // without calling Next.
                return fmt.Errorf("sql: Scan called without calling Next (closemuScanHold)")
        }
+
        rs.closemu.RLock()
+       rs.raw = rs.raw[:0]
+       err := rs.scanLocked(dest...)
+       if err == nil && scanArgsContainRawBytes(dest) {
+               rs.closemuScanHold = true
+       } else {
+               rs.closemu.RUnlock()
+       }
+       return err
+}
 
+func (rs *Rows) scanLocked(dest ...any) error {
        if rs.lasterr != nil && rs.lasterr != io.EOF {
-               rs.closemu.RUnlock()
                return rs.lasterr
        }
        if rs.closed {
-               err := rs.lasterrOrErrLocked(errRowsClosed)
-               rs.closemu.RUnlock()
-               return err
-       }
-
-       if scanArgsContainRawBytes(dest) {
-               rs.closemuScanHold = true
-               rs.raw = rs.raw[:0]
-       } else {
-               rs.closemu.RUnlock()
+               return rs.lasterrOrErrLocked(errRowsClosed)
        }
 
        if rs.lastcols == nil {
-               rs.closemuRUnlockIfHeldByScan()
                return errors.New("sql: Scan called without calling Next")
        }
        if len(dest) != len(rs.lastcols) {
-               rs.closemuRUnlockIfHeldByScan()
                return fmt.Errorf("sql: expected %d destination arguments in Scan, not %d", len(rs.lastcols), len(dest))
        }
 
        for i, sv := range rs.lastcols {
                err := convertAssignRows(dest[i], sv, rs)
                if err != nil {
-                       rs.closemuRUnlockIfHeldByScan()
                        return fmt.Errorf(`sql: Scan error on column index %d, name %q: %w`, i, rs.rowsi.Columns()[i], err)
                }
        }
index 110a2bae5bd2473ec708923f651f1265d4744cc5..236bbbb1746d51cebc58cd100003c8edab7541d7 100644 (file)
@@ -5,6 +5,7 @@
 package sql
 
 import (
+       "bytes"
        "context"
        "database/sql/driver"
        "errors"
@@ -4446,10 +4447,6 @@ func testContextCancelDuringRawBytesScan(t *testing.T, mode string) {
        db := newTestDB(t, "people")
        defer closeDB(t, db)
 
-       if _, err := db.Exec("USE_RAWBYTES"); err != nil {
-               t.Fatal(err)
-       }
-
        // cancel used to call close asynchronously.
        // This test checks that it waits so as not to interfere with RawBytes.
        ctx, cancel := context.WithCancel(context.Background())
@@ -4541,6 +4538,61 @@ func TestContextCancelBetweenNextAndErr(t *testing.T) {
        }
 }
 
+type testScanner struct {
+       scanf func(src any) error
+}
+
+func (ts testScanner) Scan(src any) error { return ts.scanf(src) }
+
+func TestContextCancelDuringScan(t *testing.T) {
+       db := newTestDB(t, "people")
+       defer closeDB(t, db)
+
+       ctx, cancel := context.WithCancel(context.Background())
+       defer cancel()
+
+       scanStart := make(chan any)
+       scanEnd := make(chan error)
+       scanner := &testScanner{
+               scanf: func(src any) error {
+                       scanStart <- src
+                       return <-scanEnd
+               },
+       }
+
+       // Start a query, and pause it mid-scan.
+       want := []byte("Alice")
+       r, err := db.QueryContext(ctx, "SELECT|people|name|name=?", string(want))
+       if err != nil {
+               t.Fatal(err)
+       }
+       if !r.Next() {
+               t.Fatalf("r.Next() = false, want true")
+       }
+       go func() {
+               r.Scan(scanner)
+       }()
+       got := <-scanStart
+       defer close(scanEnd)
+       gotBytes, ok := got.([]byte)
+       if !ok {
+               t.Fatalf("r.Scan returned %T, want []byte", got)
+       }
+       if !bytes.Equal(gotBytes, want) {
+               t.Fatalf("before cancel: r.Scan returned %q, want %q", gotBytes, want)
+       }
+
+       // Cancel the query.
+       // Sleep to give it a chance to finish canceling.
+       cancel()
+       time.Sleep(10 * time.Millisecond)
+
+       // Cancelling the query should not have changed the result.
+       if !bytes.Equal(gotBytes, want) {
+               t.Fatalf("after cancel: r.Scan result is now %q, want %q", gotBytes, want)
+       }
+}
+
 func TestNilErrorAfterClose(t *testing.T) {
        db := newTestDB(t, "people")
        defer closeDB(t, db)
@@ -4574,10 +4626,6 @@ func TestRawBytesReuse(t *testing.T) {
        db := newTestDB(t, "people")
        defer closeDB(t, db)
 
-       if _, err := db.Exec("USE_RAWBYTES"); err != nil {
-               t.Fatal(err)
-       }
-
        var raw RawBytes
 
        // The RawBytes in this query aliases driver-owned memory.