]> Cypherpunks repositories - gostls13.git/commitdiff
database/sql: allow drivers to override Scan behavior
authorJack Christensen <jack@jackchristensen.com>
Sat, 31 May 2025 15:27:15 +0000 (15:27 +0000)
committerSean Liao <sean@liao.dev>
Mon, 11 Aug 2025 21:27:36 +0000 (14:27 -0700)
Implementing RowsColumnScanner allows the driver
to completely control how values are scanned.

Fixes #67546

Change-Id: Id8e7c3a973479c9665e4476fe2d29e1255aee687
GitHub-Last-Rev: ed0cacaec4a4feead56b09c0d6eee86ed58fe1ee
GitHub-Pull-Request: golang/go#67648
Reviewed-on: https://go-review.googlesource.com/c/go/+/588435
Reviewed-by: David Chase <drchase@google.com>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
Reviewed-by: Sean Liao <sean@liao.dev>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>

api/next/67546.txt [new file with mode: 0644]
doc/next/6-stdlib/99-minor/database/sql/driver/67546.md [new file with mode: 0644]
src/database/sql/driver/driver.go
src/database/sql/sql.go
src/database/sql/sql_test.go

diff --git a/api/next/67546.txt b/api/next/67546.txt
new file mode 100644 (file)
index 0000000..0b5b4b9
--- /dev/null
@@ -0,0 +1,5 @@
+pkg database/sql/driver, type RowsColumnScanner interface { Close, Columns, Next, ScanColumn } #67546
+pkg database/sql/driver, type RowsColumnScanner interface, Close() error #67546
+pkg database/sql/driver, type RowsColumnScanner interface, Columns() []string #67546
+pkg database/sql/driver, type RowsColumnScanner interface, Next([]Value) error #67546
+pkg database/sql/driver, type RowsColumnScanner interface, ScanColumn(interface{}, int) error #67546
diff --git a/doc/next/6-stdlib/99-minor/database/sql/driver/67546.md b/doc/next/6-stdlib/99-minor/database/sql/driver/67546.md
new file mode 100644 (file)
index 0000000..8cb9089
--- /dev/null
@@ -0,0 +1 @@
+A database driver may implement [RowsColumnScanner] to entirely override `Scan` behavior.
index d0892e80fc28d5e54275751e2861306f926698bd..487870be63209ebb711c97983e5a9c55db2b66ab 100644 (file)
@@ -515,6 +515,18 @@ type RowsColumnTypePrecisionScale interface {
        ColumnTypePrecisionScale(index int) (precision, scale int64, ok bool)
 }
 
+// RowsColumnScanner may be implemented by [Rows]. It allows the driver to completely
+// take responsibility for how values are scanned and replace the normal [database/sql].
+// scanning path. This allows drivers to directly support types that do not implement
+// [database/sql.Scanner].
+type RowsColumnScanner interface {
+       Rows
+
+       // ScanColumn copies the column in the current row into the value pointed at by
+       // dest. It returns [ErrSkip] to fall back to the normal [database/sql] scanning path.
+       ScanColumn(dest any, index int) error
+}
+
 // Tx is a transaction.
 type Tx interface {
        Commit() error
index 4be450ca87687728797bafdb10e976adc5a71f58..85b9ffc37d9445c469cad271f7fa76686cad6657 100644 (file)
@@ -3396,7 +3396,16 @@ func (rs *Rows) scanLocked(dest ...any) error {
        }
 
        for i, sv := range rs.lastcols {
-               err := convertAssignRows(dest[i], sv, rs)
+               err := driver.ErrSkip
+
+               if rcs, ok := rs.rowsi.(driver.RowsColumnScanner); ok {
+                       err = rcs.ScanColumn(dest[i], i)
+               }
+
+               if err == driver.ErrSkip {
+                       err = convertAssignRows(dest[i], sv, rs)
+               }
+
                if err != nil {
                        return fmt.Errorf(`sql: Scan error on column index %d, name %q: %w`, i, rs.rowsi.Columns()[i], err)
                }
index 4750edd4717561d6795e06d7af861b62aacdd091..f706610b87e85b791fa44fdee9ec1bcd32ee4161 100644 (file)
@@ -4201,6 +4201,102 @@ func TestNamedValueCheckerSkip(t *testing.T) {
        }
 }
 
+type rcsDriver struct {
+       fakeDriver
+}
+
+func (d *rcsDriver) Open(dsn string) (driver.Conn, error) {
+       c, err := d.fakeDriver.Open(dsn)
+       fc := c.(*fakeConn)
+       fc.db.allowAny = true
+       return &rcsConn{fc}, err
+}
+
+type rcsConn struct {
+       *fakeConn
+}
+
+func (c *rcsConn) PrepareContext(ctx context.Context, q string) (driver.Stmt, error) {
+       stmt, err := c.fakeConn.PrepareContext(ctx, q)
+       if err != nil {
+               return stmt, err
+       }
+       return &rcsStmt{stmt.(*fakeStmt)}, nil
+}
+
+type rcsStmt struct {
+       *fakeStmt
+}
+
+func (s *rcsStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
+       rows, err := s.fakeStmt.QueryContext(ctx, args)
+       if err != nil {
+               return rows, err
+       }
+       return &rcsRows{rows.(*rowsCursor)}, nil
+}
+
+type rcsRows struct {
+       *rowsCursor
+}
+
+func (r *rcsRows) ScanColumn(dest any, index int) error {
+       switch d := dest.(type) {
+       case *int64:
+               *d = 42
+               return nil
+       }
+
+       return driver.ErrSkip
+}
+
+func TestRowsColumnScanner(t *testing.T) {
+       Register("RowsColumnScanner", &rcsDriver{})
+       db, err := Open("RowsColumnScanner", "")
+       if err != nil {
+               t.Fatal(err)
+       }
+       defer db.Close()
+
+       ctx, cancel := context.WithCancel(context.Background())
+       defer cancel()
+
+       _, err = db.ExecContext(ctx, "CREATE|t|str=string,n=int64")
+       if err != nil {
+               t.Fatal("exec create", err)
+       }
+
+       _, err = db.ExecContext(ctx, "INSERT|t|str=?,n=?", "foo", int64(1))
+       if err != nil {
+               t.Fatal("exec insert", err)
+       }
+       var (
+               str string
+               i64 int64
+               i   int
+               f64 float64
+               ui  uint
+       )
+       err = db.QueryRowContext(ctx, "SELECT|t|str,n,n,n,n|").Scan(&str, &i64, &i, &f64, &ui)
+       if err != nil {
+               t.Fatal("select", err)
+       }
+
+       list := []struct{ got, want any }{
+               {str, "foo"},
+               {i64, int64(42)},
+               {i, int(1)},
+               {f64, float64(1)},
+               {ui, uint(1)},
+       }
+
+       for index, item := range list {
+               if !reflect.DeepEqual(item.got, item.want) {
+                       t.Errorf("got %#v wanted %#v for index %d", item.got, item.want, index)
+               }
+       }
+}
+
 func TestOpenConnector(t *testing.T) {
        Register("testctx", &fakeDriverCtx{})
        db, err := Open("testctx", "people")