From 968742a824de0a6459d2820d11b9e2e58803f472 Mon Sep 17 00:00:00 2001 From: Daniel Theophanes Date: Mon, 29 Oct 2018 16:22:37 -0700 Subject: [PATCH] database/sql: add support for returning cursors to client This CL add support for converting a returned cursor (presented to this package as a driver.Rows) and scanning it into a *Rows. Fixes #28515 Change-Id: Id8191c568dc135af9e5e8555efcd01987708edcb Reviewed-on: https://go-review.googlesource.com/c/145738 Reviewed-by: Brad Fitzpatrick --- src/database/sql/convert.go | 45 +++++++++++++++++++++++++++--- src/database/sql/driver/driver.go | 5 ++++ src/database/sql/fakedb_test.go | 20 ++++++++++++-- src/database/sql/sql.go | 8 +++++- src/database/sql/sql_test.go | 46 +++++++++++++++++++++++++++++++ 5 files changed, 116 insertions(+), 8 deletions(-) diff --git a/src/database/sql/convert.go b/src/database/sql/convert.go index 92a2ebe0e9..c450d987a4 100644 --- a/src/database/sql/convert.go +++ b/src/database/sql/convert.go @@ -203,10 +203,18 @@ func driverArgsConnLocked(ci driver.Conn, ds *driverStmt, args []interface{}) ([ } -// convertAssign copies to dest the value in src, converting it if possible. -// An error is returned if the copy would result in loss of information. -// dest should be a pointer type. +// convertAssign is the same as convertAssignRows, but without the optional +// rows argument. func convertAssign(dest, src interface{}) error { + return convertAssignRows(dest, src, nil) +} + +// convertAssignRows copies to dest the value in src, converting it if possible. +// An error is returned if the copy would result in loss of information. +// dest should be a pointer type. If rows is passed in, the rows will +// be used as the parent for any cursor values converted from a +// driver.Rows to a *Rows. +func convertAssignRows(dest, src interface{}, rows *Rows) error { // Common cases, without reflect. switch s := src.(type) { case string: @@ -299,6 +307,35 @@ func convertAssign(dest, src interface{}) error { *d = nil return nil } + // The driver is returning a cursor the client may iterate over. + case driver.Rows: + switch d := dest.(type) { + case *Rows: + if d == nil { + return errNilPtr + } + if rows == nil { + return errors.New("invalid context to convert cursor rows, missing parent *Rows") + } + rows.closemu.Lock() + *d = Rows{ + dc: rows.dc, + releaseConn: func(error) {}, + rowsi: s, + } + // Chain the cancel function. + parentCancel := rows.cancel + rows.cancel = func() { + // When Rows.cancel is called, the closemu will be locked as well. + // So we can access rs.lasterr. + d.close(rows.lasterr) + if parentCancel != nil { + parentCancel() + } + } + rows.closemu.Unlock() + return nil + } } var sv reflect.Value @@ -381,7 +418,7 @@ func convertAssign(dest, src interface{}) error { return nil } dv.Set(reflect.New(dv.Type().Elem())) - return convertAssign(dv.Interface(), src) + return convertAssignRows(dv.Interface(), src, rows) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: s := asString(src) i64, err := strconv.ParseInt(s, 10, dv.Type().Bits()) diff --git a/src/database/sql/driver/driver.go b/src/database/sql/driver/driver.go index 70b3ddc470..5ff2bc9735 100644 --- a/src/database/sql/driver/driver.go +++ b/src/database/sql/driver/driver.go @@ -24,6 +24,11 @@ import ( // []byte // string // time.Time +// +// If the driver supports cursors, a returned Value may also implement the Rows interface +// in this package. This is used when, for example, when a user selects a cursor +// such as "select cursor(select * from my_table) from dual". If the Rows +// from the select is closed, the cursor Rows will also be closed. type Value interface{} // NamedValue holds both the value name and value. diff --git a/src/database/sql/fakedb_test.go b/src/database/sql/fakedb_test.go index a21bae61ba..dcdd264baa 100644 --- a/src/database/sql/fakedb_test.go +++ b/src/database/sql/fakedb_test.go @@ -539,7 +539,7 @@ func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (*fakeStmt, err } // parts are table|col=?,col2=val -func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (*fakeStmt, error) { +func (c *fakeConn) prepareInsert(ctx context.Context, stmt *fakeStmt, parts []string) (*fakeStmt, error) { if len(parts) != 2 { stmt.Close() return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts)) @@ -574,6 +574,20 @@ func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (*fakeStmt, err return nil, errf("invalid conversion to int32 from %q", value) } subsetVal = int64(i) // int64 is a subset type, but not int32 + case "table": // For testing cursor reads. + c.skipDirtySession = true + vparts := strings.Split(value, "!") + + substmt, err := c.PrepareContext(ctx, fmt.Sprintf("SELECT|%s|%s|", vparts[0], strings.Join(vparts[1:], ","))) + if err != nil { + return nil, err + } + cursor, err := (substmt.(driver.StmtQueryContext)).QueryContext(ctx, []driver.NamedValue{}) + substmt.Close() + if err != nil { + return nil, err + } + subsetVal = cursor default: stmt.Close() return nil, errf("unsupported conversion for pre-bound parameter %q to type %q", value, ctype) @@ -658,11 +672,11 @@ func (c *fakeConn) PrepareContext(ctx context.Context, query string) (driver.Stm case "CREATE": stmt, err = c.prepareCreate(stmt, parts) case "INSERT": - stmt, err = c.prepareInsert(stmt, parts) + stmt, err = c.prepareInsert(ctx, stmt, parts) case "NOSERT": // Do all the prep-work like for an INSERT but don't actually insert the row. // Used for some of the concurrent tests. - stmt, err = c.prepareInsert(stmt, parts) + stmt, err = c.prepareInsert(ctx, stmt, parts) default: stmt.Close() return nil, errf("unsupported command type %q", cmd) diff --git a/src/database/sql/sql.go b/src/database/sql/sql.go index 099701ce7c..71800aae83 100644 --- a/src/database/sql/sql.go +++ b/src/database/sql/sql.go @@ -2882,6 +2882,7 @@ func rowsColumnInfoSetupConnLocked(rowsi driver.Rows) []*ColumnType { // *float32, *float64 // *interface{} // *RawBytes +// *Rows (cursor value) // any type implementing Scanner (see Scanner docs) // // In the most simple case, if the type of the value from the source @@ -2918,6 +2919,11 @@ func rowsColumnInfoSetupConnLocked(rowsi driver.Rows) []*ColumnType { // // For scanning into *bool, the source may be true, false, 1, 0, or // string inputs parseable by strconv.ParseBool. +// +// Scan can also convert a cursor returned from a query, such as +// "select cursor(select * from my_table) from dual", into a +// *Rows value that can itself be scanned from. The parent +// select query will close any cursor *Rows if the parent *Rows is closed. func (rs *Rows) Scan(dest ...interface{}) error { rs.closemu.RLock() @@ -2939,7 +2945,7 @@ func (rs *Rows) Scan(dest ...interface{}) error { return fmt.Errorf("sql: expected %d destination arguments in Scan, not %d", len(rs.lastcols), len(dest)) } for i, sv := range rs.lastcols { - err := convertAssign(dest[i], sv) + err := convertAssignRows(dest[i], sv, rs) if err != nil { return fmt.Errorf(`sql: Scan error on column index %d, name %q: %v`, i, rs.rowsi.Columns()[i], err) } diff --git a/src/database/sql/sql_test.go b/src/database/sql/sql_test.go index 82f3f316c6..64b9dfea5c 100644 --- a/src/database/sql/sql_test.go +++ b/src/database/sql/sql_test.go @@ -1338,6 +1338,52 @@ func TestConnQuery(t *testing.T) { } } +func TestCursorFake(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) + defer cancel() + + exec(t, db, "CREATE|peoplecursor|list=table") + exec(t, db, "INSERT|peoplecursor|list=people!name!age") + + rows, err := db.QueryContext(ctx, `SELECT|peoplecursor|list|`) + if err != nil { + t.Fatal(err) + } + defer rows.Close() + + if !rows.Next() { + t.Fatal("no rows") + } + var cursor = &Rows{} + err = rows.Scan(cursor) + if err != nil { + t.Fatal(err) + } + defer cursor.Close() + + const expectedRows = 3 + var currentRow int64 + + var n int64 + var s string + for cursor.Next() { + currentRow++ + err = cursor.Scan(&s, &n) + if err != nil { + t.Fatal(err) + } + if n != currentRow { + t.Errorf("expected number(Age)=%d, got %d", currentRow, n) + } + } + if currentRow != expectedRows { + t.Errorf("expected %d rows, got %d rows", expectedRows, currentRow) + } +} + func TestInvalidNilValues(t *testing.T) { var date1 time.Time var date2 int -- 2.48.1