if rows == nil {
return errors.New("invalid context to convert cursor rows, missing parent *Rows")
}
- rows.closemu.Lock()
*d = Rows{
dc: rows.dc,
releaseConn: func(error) {},
parentCancel()
}
}
- rows.closemu.Unlock()
return nil
}
}
package sql
import (
+ "bytes"
"context"
"database/sql/driver"
"errors"
"strconv"
"strings"
"sync"
- "sync/atomic"
"testing"
"time"
)
type fakeDB struct {
name string
- useRawBytes atomic.Bool
-
mu sync.Mutex
tables map[string]*table
badConn bool
switch cmd {
case "WIPE":
// Nothing
- case "USE_RAWBYTES":
- c.db.useRawBytes.Store(true)
case "SELECT":
stmt, err = c.prepareSelect(stmt, parts)
case "CREATE":
case "WIPE":
db.wipe()
return driver.ResultNoRows, nil
- case "USE_RAWBYTES":
- s.c.db.useRawBytes.Store(true)
- return driver.ResultNoRows, nil
case "CREATE":
if err := db.createTable(s.table, s.colName, s.colType); err != nil {
return nil, err
errPos int
err error
- // a clone of slices to give out to clients, indexed by the
- // original slice's first byte address. we clone them
- // just so we're able to corrupt them on close.
- bytesClone map[*byte][]byte
+ // Data returned to clients.
+ // We clone and stash it here so it can be invalidated by Close and Next.
+ driverOwnedMemory [][]byte
// Every operation writes to line to enable the race detector
// check for data races.
rc.line++
}
+func (rc *rowsCursor) invalidateDriverOwnedMemory() {
+ for _, buf := range rc.driverOwnedMemory {
+ for i := range buf {
+ buf[i] = 'x'
+ }
+ }
+ rc.driverOwnedMemory = nil
+}
+
func (rc *rowsCursor) Close() error {
rc.touchMem()
rc.parentMem.touchMem()
+ rc.invalidateDriverOwnedMemory()
rc.closed = true
return rc.closeErr
}
if rc.posRow >= len(rc.rows[rc.posSet]) {
return io.EOF // per interface spec
}
+ // Corrupt any previously returned bytes.
+ rc.invalidateDriverOwnedMemory()
for i, v := range rc.rows[rc.posSet][rc.posRow].cols {
// TODO(bradfitz): convert to subset types? naah, I
// think the subset types should only be input to
// a wider range of types coming out of drivers. all
// for ease of drivers, and to prevent drivers from
// messing up conversions or doing them differently.
- dest[i] = v
-
- if bs, ok := v.([]byte); ok && !rc.db.useRawBytes.Load() {
- if rc.bytesClone == nil {
- rc.bytesClone = make(map[*byte][]byte)
- }
- clone, ok := rc.bytesClone[&bs[0]]
- if !ok {
- clone = make([]byte, len(bs))
- copy(clone, bs)
- rc.bytesClone[&bs[0]] = clone
- }
- dest[i] = clone
+ if bs, ok := v.([]byte); ok {
+ // Clone []bytes and stash for later invalidation.
+ bs = bytes.Clone(bs)
+ rc.driverOwnedMemory = append(rc.driverOwnedMemory, bs)
+ v = bs
}
+ dest[i] = v
}
return nil
}
// without calling Next.
return fmt.Errorf("sql: Scan called without calling Next (closemuScanHold)")
}
+
rs.closemu.RLock()
+ rs.raw = rs.raw[:0]
+ err := rs.scanLocked(dest...)
+ if err == nil && scanArgsContainRawBytes(dest) {
+ rs.closemuScanHold = true
+ } else {
+ rs.closemu.RUnlock()
+ }
+ return err
+}
+func (rs *Rows) scanLocked(dest ...any) error {
if rs.lasterr != nil && rs.lasterr != io.EOF {
- rs.closemu.RUnlock()
return rs.lasterr
}
if rs.closed {
- err := rs.lasterrOrErrLocked(errRowsClosed)
- rs.closemu.RUnlock()
- return err
- }
-
- if scanArgsContainRawBytes(dest) {
- rs.closemuScanHold = true
- rs.raw = rs.raw[:0]
- } else {
- rs.closemu.RUnlock()
+ return rs.lasterrOrErrLocked(errRowsClosed)
}
if rs.lastcols == nil {
- rs.closemuRUnlockIfHeldByScan()
return errors.New("sql: Scan called without calling Next")
}
if len(dest) != len(rs.lastcols) {
- rs.closemuRUnlockIfHeldByScan()
return fmt.Errorf("sql: expected %d destination arguments in Scan, not %d", len(rs.lastcols), len(dest))
}
for i, sv := range rs.lastcols {
err := convertAssignRows(dest[i], sv, rs)
if err != nil {
- rs.closemuRUnlockIfHeldByScan()
return fmt.Errorf(`sql: Scan error on column index %d, name %q: %w`, i, rs.rowsi.Columns()[i], err)
}
}
package sql
import (
+ "bytes"
"context"
"database/sql/driver"
"errors"
db := newTestDB(t, "people")
defer closeDB(t, db)
- if _, err := db.Exec("USE_RAWBYTES"); err != nil {
- t.Fatal(err)
- }
-
// cancel used to call close asynchronously.
// This test checks that it waits so as not to interfere with RawBytes.
ctx, cancel := context.WithCancel(context.Background())
}
}
+type testScanner struct {
+ scanf func(src any) error
+}
+
+func (ts testScanner) Scan(src any) error { return ts.scanf(src) }
+
+func TestContextCancelDuringScan(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+
+ ctx, cancel := context.WithCancel(context.Background())
+ defer cancel()
+
+ scanStart := make(chan any)
+ scanEnd := make(chan error)
+ scanner := &testScanner{
+ scanf: func(src any) error {
+ scanStart <- src
+ return <-scanEnd
+ },
+ }
+
+ // Start a query, and pause it mid-scan.
+ want := []byte("Alice")
+ r, err := db.QueryContext(ctx, "SELECT|people|name|name=?", string(want))
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !r.Next() {
+ t.Fatalf("r.Next() = false, want true")
+ }
+ go func() {
+ r.Scan(scanner)
+ }()
+ got := <-scanStart
+ defer close(scanEnd)
+ gotBytes, ok := got.([]byte)
+ if !ok {
+ t.Fatalf("r.Scan returned %T, want []byte", got)
+ }
+ if !bytes.Equal(gotBytes, want) {
+ t.Fatalf("before cancel: r.Scan returned %q, want %q", gotBytes, want)
+ }
+
+ // Cancel the query.
+ // Sleep to give it a chance to finish canceling.
+ cancel()
+ time.Sleep(10 * time.Millisecond)
+
+ // Cancelling the query should not have changed the result.
+ if !bytes.Equal(gotBytes, want) {
+ t.Fatalf("after cancel: r.Scan result is now %q, want %q", gotBytes, want)
+ }
+}
+
func TestNilErrorAfterClose(t *testing.T) {
db := newTestDB(t, "people")
defer closeDB(t, db)
db := newTestDB(t, "people")
defer closeDB(t, db)
- if _, err := db.Exec("USE_RAWBYTES"); err != nil {
- t.Fatal(err)
- }
-
var raw RawBytes
// The RawBytes in this query aliases driver-owned memory.