]> Cypherpunks repositories - gostls13.git/commitdiff
database/sql: add support for returning cursors to client
authorDaniel Theophanes <kardianos@gmail.com>
Mon, 29 Oct 2018 23:22:37 +0000 (16:22 -0700)
committerDaniel Theophanes <kardianos@gmail.com>
Thu, 8 Nov 2018 21:19:17 +0000 (21:19 +0000)
This CL add support for converting a returned cursor (presented
to this package as a driver.Rows) and scanning it into a *Rows.

Fixes #28515

Change-Id: Id8191c568dc135af9e5e8555efcd01987708edcb
Reviewed-on: https://go-review.googlesource.com/c/145738
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
src/database/sql/convert.go
src/database/sql/driver/driver.go
src/database/sql/fakedb_test.go
src/database/sql/sql.go
src/database/sql/sql_test.go

index 92a2ebe0e991345fc9665a823c872e46bc6f416b..c450d987a465125db4e858dff1ae48ca05f9ab9c 100644 (file)
@@ -203,10 +203,18 @@ func driverArgsConnLocked(ci driver.Conn, ds *driverStmt, args []interface{}) ([
 
 }
 
-// convertAssign copies to dest the value in src, converting it if possible.
-// An error is returned if the copy would result in loss of information.
-// dest should be a pointer type.
+// convertAssign is the same as convertAssignRows, but without the optional
+// rows argument.
 func convertAssign(dest, src interface{}) error {
+       return convertAssignRows(dest, src, nil)
+}
+
+// convertAssignRows copies to dest the value in src, converting it if possible.
+// An error is returned if the copy would result in loss of information.
+// dest should be a pointer type. If rows is passed in, the rows will
+// be used as the parent for any cursor values converted from a
+// driver.Rows to a *Rows.
+func convertAssignRows(dest, src interface{}, rows *Rows) error {
        // Common cases, without reflect.
        switch s := src.(type) {
        case string:
@@ -299,6 +307,35 @@ func convertAssign(dest, src interface{}) error {
                        *d = nil
                        return nil
                }
+       // The driver is returning a cursor the client may iterate over.
+       case driver.Rows:
+               switch d := dest.(type) {
+               case *Rows:
+                       if d == nil {
+                               return errNilPtr
+                       }
+                       if rows == nil {
+                               return errors.New("invalid context to convert cursor rows, missing parent *Rows")
+                       }
+                       rows.closemu.Lock()
+                       *d = Rows{
+                               dc:          rows.dc,
+                               releaseConn: func(error) {},
+                               rowsi:       s,
+                       }
+                       // Chain the cancel function.
+                       parentCancel := rows.cancel
+                       rows.cancel = func() {
+                               // When Rows.cancel is called, the closemu will be locked as well.
+                               // So we can access rs.lasterr.
+                               d.close(rows.lasterr)
+                               if parentCancel != nil {
+                                       parentCancel()
+                               }
+                       }
+                       rows.closemu.Unlock()
+                       return nil
+               }
        }
 
        var sv reflect.Value
@@ -381,7 +418,7 @@ func convertAssign(dest, src interface{}) error {
                        return nil
                }
                dv.Set(reflect.New(dv.Type().Elem()))
-               return convertAssign(dv.Interface(), src)
+               return convertAssignRows(dv.Interface(), src, rows)
        case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
                s := asString(src)
                i64, err := strconv.ParseInt(s, 10, dv.Type().Bits())
index 70b3ddc47030ae7876107e65b8dae74cb6a56a8d..5ff2bc97350e9c8225c48158e87917020e844f5e 100644 (file)
@@ -24,6 +24,11 @@ import (
 //   []byte
 //   string
 //   time.Time
+//
+// If the driver supports cursors, a returned Value may also implement the Rows interface
+// in this package. This is used when, for example, when a user selects a cursor
+// such as "select cursor(select * from my_table) from dual". If the Rows
+// from the select is closed, the cursor Rows will also be closed.
 type Value interface{}
 
 // NamedValue holds both the value name and value.
index a21bae61baccaa284e7d36891377b849b7c4bc27..dcdd264baa8422e51ed3a20e1585fc0e5f7c9864 100644 (file)
@@ -539,7 +539,7 @@ func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (*fakeStmt, err
 }
 
 // parts are table|col=?,col2=val
-func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (*fakeStmt, error) {
+func (c *fakeConn) prepareInsert(ctx context.Context, stmt *fakeStmt, parts []string) (*fakeStmt, error) {
        if len(parts) != 2 {
                stmt.Close()
                return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts))
@@ -574,6 +574,20 @@ func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (*fakeStmt, err
                                        return nil, errf("invalid conversion to int32 from %q", value)
                                }
                                subsetVal = int64(i) // int64 is a subset type, but not int32
+                       case "table": // For testing cursor reads.
+                               c.skipDirtySession = true
+                               vparts := strings.Split(value, "!")
+
+                               substmt, err := c.PrepareContext(ctx, fmt.Sprintf("SELECT|%s|%s|", vparts[0], strings.Join(vparts[1:], ",")))
+                               if err != nil {
+                                       return nil, err
+                               }
+                               cursor, err := (substmt.(driver.StmtQueryContext)).QueryContext(ctx, []driver.NamedValue{})
+                               substmt.Close()
+                               if err != nil {
+                                       return nil, err
+                               }
+                               subsetVal = cursor
                        default:
                                stmt.Close()
                                return nil, errf("unsupported conversion for pre-bound parameter %q to type %q", value, ctype)
@@ -658,11 +672,11 @@ func (c *fakeConn) PrepareContext(ctx context.Context, query string) (driver.Stm
                case "CREATE":
                        stmt, err = c.prepareCreate(stmt, parts)
                case "INSERT":
-                       stmt, err = c.prepareInsert(stmt, parts)
+                       stmt, err = c.prepareInsert(ctx, stmt, parts)
                case "NOSERT":
                        // Do all the prep-work like for an INSERT but don't actually insert the row.
                        // Used for some of the concurrent tests.
-                       stmt, err = c.prepareInsert(stmt, parts)
+                       stmt, err = c.prepareInsert(ctx, stmt, parts)
                default:
                        stmt.Close()
                        return nil, errf("unsupported command type %q", cmd)
index 099701ce7c5a4852293445e30490bfe399955e69..71800aae83edb456d21eb8364867e83fb4c184f7 100644 (file)
@@ -2882,6 +2882,7 @@ func rowsColumnInfoSetupConnLocked(rowsi driver.Rows) []*ColumnType {
 //    *float32, *float64
 //    *interface{}
 //    *RawBytes
+//    *Rows (cursor value)
 //    any type implementing Scanner (see Scanner docs)
 //
 // In the most simple case, if the type of the value from the source
@@ -2918,6 +2919,11 @@ func rowsColumnInfoSetupConnLocked(rowsi driver.Rows) []*ColumnType {
 //
 // For scanning into *bool, the source may be true, false, 1, 0, or
 // string inputs parseable by strconv.ParseBool.
+//
+// Scan can also convert a cursor returned from a query, such as
+// "select cursor(select * from my_table) from dual", into a
+// *Rows value that can itself be scanned from. The parent
+// select query will close any cursor *Rows if the parent *Rows is closed.
 func (rs *Rows) Scan(dest ...interface{}) error {
        rs.closemu.RLock()
 
@@ -2939,7 +2945,7 @@ func (rs *Rows) Scan(dest ...interface{}) error {
                return fmt.Errorf("sql: expected %d destination arguments in Scan, not %d", len(rs.lastcols), len(dest))
        }
        for i, sv := range rs.lastcols {
-               err := convertAssign(dest[i], sv)
+               err := convertAssignRows(dest[i], sv, rs)
                if err != nil {
                        return fmt.Errorf(`sql: Scan error on column index %d, name %q: %v`, i, rs.rowsi.Columns()[i], err)
                }
index 82f3f316c6a5841e69ef4aa0a66e16d178a65b72..64b9dfea5c21303c25b8d185159a9e3909be1126 100644 (file)
@@ -1338,6 +1338,52 @@ func TestConnQuery(t *testing.T) {
        }
 }
 
+func TestCursorFake(t *testing.T) {
+       db := newTestDB(t, "people")
+       defer closeDB(t, db)
+
+       ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
+       defer cancel()
+
+       exec(t, db, "CREATE|peoplecursor|list=table")
+       exec(t, db, "INSERT|peoplecursor|list=people!name!age")
+
+       rows, err := db.QueryContext(ctx, `SELECT|peoplecursor|list|`)
+       if err != nil {
+               t.Fatal(err)
+       }
+       defer rows.Close()
+
+       if !rows.Next() {
+               t.Fatal("no rows")
+       }
+       var cursor = &Rows{}
+       err = rows.Scan(cursor)
+       if err != nil {
+               t.Fatal(err)
+       }
+       defer cursor.Close()
+
+       const expectedRows = 3
+       var currentRow int64
+
+       var n int64
+       var s string
+       for cursor.Next() {
+               currentRow++
+               err = cursor.Scan(&s, &n)
+               if err != nil {
+                       t.Fatal(err)
+               }
+               if n != currentRow {
+                       t.Errorf("expected number(Age)=%d, got %d", currentRow, n)
+               }
+       }
+       if currentRow != expectedRows {
+               t.Errorf("expected %d rows, got %d rows", expectedRows, currentRow)
+       }
+}
+
 func TestInvalidNilValues(t *testing.T) {
        var date1 time.Time
        var date2 int