]> Cypherpunks repositories - gostls13.git/commitdiff
database/sql: add support for multiple result sets
authorDaniel Theophanes <kardianos@gmail.com>
Thu, 6 Oct 2016 18:06:21 +0000 (11:06 -0700)
committerBrad Fitzpatrick <bradfitz@golang.org>
Sat, 15 Oct 2016 07:13:17 +0000 (07:13 +0000)
Many database systems allow returning multiple result sets
in a single query. This can be useful when dealing with many
intermediate results on the server and there is a need
to return more then one arity of data to the client.

Fixes #12382

Change-Id: I480a9ac6dadfc8743e0ba8b6d868ccf8442a9ca1
Reviewed-on: https://go-review.googlesource.com/30592
Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>

src/database/sql/driver/driver.go
src/database/sql/example_test.go
src/database/sql/fakedb_test.go
src/database/sql/sql.go
src/database/sql/sql_test.go

index ccc283d3738959ef41ced17db7937d68e1191963..b3d83f3ff449a0b3f32a323d7ef5e678a8dccdbc 100644 (file)
@@ -213,6 +213,22 @@ type Rows interface {
        Next(dest []Value) error
 }
 
+// RowsNextResultSet extends the Rows interface by providing a way to signal
+// the driver to advance to the next result set.
+type RowsNextResultSet interface {
+       Rows
+
+       // HasNextResultSet is called at the end of the current result set and
+       // reports whether there is another result set after the current one.
+       HasNextResultSet() bool
+
+       // NextResultSet advances the driver to the next result set even
+       // if there are remaining rows in the current result set.
+       //
+       // NextResultSet should return io.EOF when there are no more result sets.
+       NextResultSet() error
+}
+
 // Tx is a transaction.
 type Tx interface {
        Commit() error
index dcb74e06998d9f8184c39ca1a6aa6f78b6432a31..9032eac2d28573f6eae698370308bdd36d15fdfa 100644 (file)
@@ -44,3 +44,65 @@ func ExampleDB_QueryRow() {
                fmt.Printf("Username is %s\n", username)
        }
 }
+
+func ExampleDB_QueryMultipleResultSets() {
+       age := 27
+       q := `
+create temp table uid (id bigint); -- Create temp table for queries.
+insert into uid
+select id from users where age < ?; -- Populate temp table.
+
+-- First result set.
+select
+       users.id, name
+from
+       users
+       join uid on users.id = uid.id
+;
+
+-- Second result set.
+select 
+       ur.user, ur.role
+from
+       user_roles as ur
+       join uid on uid.id = ur.user
+;
+       `
+       rows, err := db.Query(q, age)
+       if err != nil {
+               log.Fatal(err)
+       }
+       defer rows.Close()
+
+       for rows.Next() {
+               var (
+                       id   int64
+                       name string
+               )
+               if err := rows.Scan(&id, &name); err != nil {
+                       log.Fatal(err)
+               }
+               fmt.Printf("id %d name is %s\n", id, name)
+       }
+       if !rows.NextResultSet() {
+               log.Fatal("expected more result sets", rows.Err())
+       }
+       var roleMap = map[int64]string{
+               1: "user",
+               2: "admin",
+               3: "gopher",
+       }
+       for rows.Next() {
+               var (
+                       id   int64
+                       role int64
+               )
+               if err := rows.Scan(&id, &role); err != nil {
+                       log.Fatal(err)
+               }
+               fmt.Printf("id %d has role %s\n", id, roleMap[role])
+       }
+       if err := rows.Err(); err != nil {
+               log.Fatal(err)
+       }
+}
index 5b238bfc5cbcff7510a1fb6bb9fa39310e35c6e9..aaa13a679964d4242c7fc27b89c14ee92552d647 100644 (file)
@@ -36,6 +36,8 @@ var _ = log.Printf
 // Any of these can be preceded by PANIC|<method>|, to cause the
 // named method on fakeStmt to panic.
 //
+// Multiple of these can be combined when separated with a semicolon.
+//
 // When opening a fakeDriver's database, it starts empty with no
 // tables. All tables and data are stored in memory only.
 type fakeDriver struct {
@@ -109,6 +111,8 @@ type fakeStmt struct {
        table string
        panic string
 
+       next *fakeStmt // used for returning multiple results.
+
        closed bool
 
        colName      []string      // used by CREATE, INSERT, SELECT (selected columns)
@@ -377,7 +381,7 @@ func errf(msg string, args ...interface{}) error {
 // parts are table|selectCol1,selectCol2|whereCol=?,whereCol2=?
 // (note that where columns must always contain ? marks,
 //  just a limitation for fakedb)
-func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (driver.Stmt, error) {
+func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (*fakeStmt, error) {
        if len(parts) != 3 {
                stmt.Close()
                return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts))
@@ -411,7 +415,7 @@ func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (driver.Stmt, e
 }
 
 // parts are table|col=type,col2=type2
-func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (driver.Stmt, error) {
+func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (*fakeStmt, error) {
        if len(parts) != 2 {
                stmt.Close()
                return nil, errf("invalid CREATE syntax with %d parts; want 2", len(parts))
@@ -430,7 +434,7 @@ func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (driver.Stmt, e
 }
 
 // parts are table|col=?,col2=val
-func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (driver.Stmt, error) {
+func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (*fakeStmt, error) {
        if len(parts) != 2 {
                stmt.Close()
                return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts))
@@ -492,38 +496,52 @@ func (c *fakeConn) Prepare(query string) (driver.Stmt, error) {
                return nil, driver.ErrBadConn
        }
 
-       parts := strings.Split(query, "|")
-       if len(parts) < 1 {
-               return nil, errf("empty query")
-       }
-       stmt := &fakeStmt{q: query, c: c}
-       if len(parts) >= 3 && parts[0] == "PANIC" {
-               stmt.panic = parts[1]
-               parts = parts[2:]
-       }
-       cmd := parts[0]
-       stmt.cmd = cmd
-       parts = parts[1:]
-
-       c.incrStat(&c.stmtsMade)
-       switch cmd {
-       case "WIPE":
-               // Nothing
-       case "SELECT":
-               return c.prepareSelect(stmt, parts)
-       case "CREATE":
-               return c.prepareCreate(stmt, parts)
-       case "INSERT":
-               return c.prepareInsert(stmt, parts)
-       case "NOSERT":
-               // Do all the prep-work like for an INSERT but don't actually insert the row.
-               // Used for some of the concurrent tests.
-               return c.prepareInsert(stmt, parts)
-       default:
-               stmt.Close()
-               return nil, errf("unsupported command type %q", cmd)
+       var firstStmt, prev *fakeStmt
+       for _, query := range strings.Split(query, ";") {
+               parts := strings.Split(query, "|")
+               if len(parts) < 1 {
+                       return nil, errf("empty query")
+               }
+               stmt := &fakeStmt{q: query, c: c}
+               if firstStmt == nil {
+                       firstStmt = stmt
+               }
+               if len(parts) >= 3 && parts[0] == "PANIC" {
+                       stmt.panic = parts[1]
+                       parts = parts[2:]
+               }
+               cmd := parts[0]
+               stmt.cmd = cmd
+               parts = parts[1:]
+
+               c.incrStat(&c.stmtsMade)
+               var err error
+               switch cmd {
+               case "WIPE":
+                       // Nothing
+               case "SELECT":
+                       stmt, err = c.prepareSelect(stmt, parts)
+               case "CREATE":
+                       stmt, err = c.prepareCreate(stmt, parts)
+               case "INSERT":
+                       stmt, err = c.prepareInsert(stmt, parts)
+               case "NOSERT":
+                       // Do all the prep-work like for an INSERT but don't actually insert the row.
+                       // Used for some of the concurrent tests.
+                       stmt, err = c.prepareInsert(stmt, parts)
+               default:
+                       stmt.Close()
+                       return nil, errf("unsupported command type %q", cmd)
+               }
+               if err != nil {
+                       return nil, err
+               }
+               if prev != nil {
+                       prev.next = stmt
+               }
+               prev = stmt
        }
-       return stmt, nil
+       return firstStmt, nil
 }
 
 func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter {
@@ -550,6 +568,9 @@ func (s *fakeStmt) Close() error {
                s.c.incrStat(&s.c.stmtsClosed)
                s.closed = true
        }
+       if s.next != nil {
+               s.next.Close()
+       }
        return nil
 }
 
@@ -667,64 +688,80 @@ func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) {
                panic("error in pkg db; should only get here if size is correct")
        }
 
-       db.mu.Lock()
-       t, ok := db.table(s.table)
-       db.mu.Unlock()
-       if !ok {
-               return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
-       }
+       setMRows := make([][]*row, 0, 1)
+       setColumns := make([][]string, 0, 1)
 
-       if s.table == "magicquery" {
-               if len(s.whereCol) == 2 && s.whereCol[0] == "op" && s.whereCol[1] == "millis" {
-                       if args[0] == "sleep" {
-                               time.Sleep(time.Duration(args[1].(int64)) * time.Millisecond)
-                       }
+       for {
+               db.mu.Lock()
+               t, ok := db.table(s.table)
+               db.mu.Unlock()
+               if !ok {
+                       return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
                }
-       }
 
-       t.mu.Lock()
-       defer t.mu.Unlock()
-
-       colIdx := make(map[string]int) // select column name -> column index in table
-       for _, name := range s.colName {
-               idx := t.columnIndex(name)
-               if idx == -1 {
-                       return nil, fmt.Errorf("fakedb: unknown column name %q", name)
+               if s.table == "magicquery" {
+                       if len(s.whereCol) == 2 && s.whereCol[0] == "op" && s.whereCol[1] == "millis" {
+                               if args[0] == "sleep" {
+                                       time.Sleep(time.Duration(args[1].(int64)) * time.Millisecond)
+                               }
+                       }
                }
-               colIdx[name] = idx
-       }
 
-       mrows := []*row{}
-rows:
-       for _, trow := range t.rows {
-               // Process the where clause, skipping non-match rows. This is lazy
-               // and just uses fmt.Sprintf("%v") to test equality. Good enough
-               // for test code.
-               for widx, wcol := range s.whereCol {
-                       idx := t.columnIndex(wcol)
+               t.mu.Lock()
+
+               colIdx := make(map[string]int) // select column name -> column index in table
+               for _, name := range s.colName {
+                       idx := t.columnIndex(name)
                        if idx == -1 {
-                               return nil, fmt.Errorf("db: invalid where clause column %q", wcol)
+                               t.mu.Unlock()
+                               return nil, fmt.Errorf("fakedb: unknown column name %q", name)
                        }
-                       tcol := trow.cols[idx]
-                       if bs, ok := tcol.([]byte); ok {
-                               // lazy hack to avoid sprintf %v on a []byte
-                               tcol = string(bs)
+                       colIdx[name] = idx
+               }
+
+               mrows := []*row{}
+       rows:
+               for _, trow := range t.rows {
+                       // Process the where clause, skipping non-match rows. This is lazy
+                       // and just uses fmt.Sprintf("%v") to test equality. Good enough
+                       // for test code.
+                       for widx, wcol := range s.whereCol {
+                               idx := t.columnIndex(wcol)
+                               if idx == -1 {
+                                       t.mu.Unlock()
+                                       return nil, fmt.Errorf("db: invalid where clause column %q", wcol)
+                               }
+                               tcol := trow.cols[idx]
+                               if bs, ok := tcol.([]byte); ok {
+                                       // lazy hack to avoid sprintf %v on a []byte
+                                       tcol = string(bs)
+                               }
+                               if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", args[widx]) {
+                                       continue rows
+                               }
                        }
-                       if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", args[widx]) {
-                               continue rows
+                       mrow := &row{cols: make([]interface{}, len(s.colName))}
+                       for seli, name := range s.colName {
+                               mrow.cols[seli] = trow.cols[colIdx[name]]
                        }
+                       mrows = append(mrows, mrow)
                }
-               mrow := &row{cols: make([]interface{}, len(s.colName))}
-               for seli, name := range s.colName {
-                       mrow.cols[seli] = trow.cols[colIdx[name]]
+
+               t.mu.Unlock()
+
+               setMRows = append(setMRows, mrows)
+               setColumns = append(setColumns, s.colName)
+
+               if s.next == nil {
+                       break
                }
-               mrows = append(mrows, mrow)
+               s = s.next
        }
 
        cursor := &rowsCursor{
-               pos:    -1,
-               rows:   mrows,
-               cols:   s.colName,
+               posRow: -1,
+               rows:   setMRows,
+               cols:   setColumns,
                errPos: -1,
        }
        return cursor, nil
@@ -760,9 +797,10 @@ func (tx *fakeTx) Rollback() error {
 }
 
 type rowsCursor struct {
-       cols   []string
-       pos    int
-       rows   []*row
+       cols   [][]string
+       posSet int
+       posRow int
+       rows   [][]*row
        closed bool
 
        // errPos and err are for making Next return early with error.
@@ -786,7 +824,7 @@ func (rc *rowsCursor) Close() error {
 }
 
 func (rc *rowsCursor) Columns() []string {
-       return rc.cols
+       return rc.cols[rc.posSet]
 }
 
 var rowsCursorNextHook func(dest []driver.Value) error
@@ -799,14 +837,14 @@ func (rc *rowsCursor) Next(dest []driver.Value) error {
        if rc.closed {
                return errors.New("fakedb: cursor is closed")
        }
-       rc.pos++
-       if rc.pos == rc.errPos {
+       rc.posRow++
+       if rc.posRow == rc.errPos {
                return rc.err
        }
-       if rc.pos >= len(rc.rows) {
+       if rc.posRow >= len(rc.rows[rc.posSet]) {
                return io.EOF // per interface spec
        }
-       for i, v := range rc.rows[rc.pos].cols {
+       for i, v := range rc.rows[rc.posSet][rc.posRow].cols {
                // TODO(bradfitz): convert to subset types? naah, I
                // think the subset types should only be input to
                // driver, but the sql package should be able to handle
@@ -831,6 +869,19 @@ func (rc *rowsCursor) Next(dest []driver.Value) error {
        return nil
 }
 
+func (rc *rowsCursor) HasNextResultSet() bool {
+       return rc.posSet < len(rc.rows)-1
+}
+
+func (rc *rowsCursor) NextResultSet() error {
+       if rc.HasNextResultSet() {
+               rc.posSet++
+               rc.posRow = -1
+               return nil
+       }
+       return io.EOF // Per interface spec.
+}
+
 // fakeDriverString is like driver.String, but indirects pointers like
 // DefaultValueConverter.
 //
index c26d7d30634ea5f702ddb44184c2362e4741f3f4..970334269ddf8b1540ffc357cac43e338c4a3e10 100644 (file)
@@ -1942,6 +1942,47 @@ func (rs *Rows) Next() bool {
                rs.lastcols = make([]driver.Value, len(rs.rowsi.Columns()))
        }
        rs.lasterr = rs.rowsi.Next(rs.lastcols)
+       if rs.lasterr != nil {
+               // Close the connection if there is a driver error.
+               if rs.lasterr != io.EOF {
+                       rs.Close()
+                       return false
+               }
+               nextResultSet, ok := rs.rowsi.(driver.RowsNextResultSet)
+               if !ok {
+                       rs.Close()
+                       return false
+               }
+               // The driver is at the end of the current result set.
+               // Test to see if there is another result set after the current one.
+               // Only close Rows if there is no futher result sets to read.
+               if !nextResultSet.HasNextResultSet() {
+                       rs.Close()
+               }
+               return false
+       }
+       return true
+}
+
+// NextResultSet prepares the next result set for reading. It returns true if
+// there is further result sets, or false if there is no further result set
+// or if there is an error advancing to it. The Err method should be consulted
+// to distinguish between the two cases.
+//
+// After calling NextResultSet, the Next method should always be called before
+// scanning. If there are further result sets they may not have rows in the result
+// set.
+func (rs *Rows) NextResultSet() bool {
+       if rs.isClosed() {
+               return false
+       }
+       rs.lastcols = nil
+       nextResultSet, ok := rs.rowsi.(driver.RowsNextResultSet)
+       if !ok {
+               rs.Close()
+               return false
+       }
+       rs.lasterr = nextResultSet.NextResultSet()
        if rs.lasterr != nil {
                rs.Close()
                return false
@@ -2047,7 +2088,8 @@ func (rs *Rows) isClosed() bool {
        return atomic.LoadInt32(&rs.closed) != 0
 }
 
-// Close closes the Rows, preventing further enumeration. If Next returns
+// Close closes the Rows, preventing further enumeration. If Next and
+// NextResultSet both return
 // false, the Rows are closed automatically and it will suffice to check the
 // result of Err. Close is idempotent and does not affect the result of Err.
 func (rs *Rows) Close() error {
index ca14af79e7cb85b90fdcbadbdd74c6a668556d1c..bce210da97000d1acbd1f26d989de1cfcebec176 100644 (file)
@@ -319,6 +319,82 @@ func TestQueryContext(t *testing.T) {
        }
 }
 
+func TestMultiResultSetQuery(t *testing.T) {
+       db := newTestDB(t, "people")
+       defer closeDB(t, db)
+       prepares0 := numPrepares(t, db)
+       rows, err := db.Query("SELECT|people|age,name|;SELECT|people|name|")
+       if err != nil {
+               t.Fatalf("Query: %v", err)
+       }
+       type row1 struct {
+               age  int
+               name string
+       }
+       type row2 struct {
+               name string
+       }
+       got1 := []row1{}
+       for rows.Next() {
+               var r row1
+               err = rows.Scan(&r.age, &r.name)
+               if err != nil {
+                       t.Fatalf("Scan: %v", err)
+               }
+               got1 = append(got1, r)
+       }
+       err = rows.Err()
+       if err != nil {
+               t.Fatalf("Err: %v", err)
+       }
+       want1 := []row1{
+               {age: 1, name: "Alice"},
+               {age: 2, name: "Bob"},
+               {age: 3, name: "Chris"},
+       }
+       if !reflect.DeepEqual(got1, want1) {
+               t.Errorf("mismatch.\n got1: %#v\nwant: %#v", got1, want1)
+       }
+
+       if !rows.NextResultSet() {
+               t.Errorf("expected another result set")
+       }
+
+       got2 := []row2{}
+       for rows.Next() {
+               var r row2
+               err = rows.Scan(&r.name)
+               if err != nil {
+                       t.Fatalf("Scan: %v", err)
+               }
+               got2 = append(got2, r)
+       }
+       err = rows.Err()
+       if err != nil {
+               t.Fatalf("Err: %v", err)
+       }
+       want2 := []row2{
+               {name: "Alice"},
+               {name: "Bob"},
+               {name: "Chris"},
+       }
+       if !reflect.DeepEqual(got2, want2) {
+               t.Errorf("mismatch.\n got: %#v\nwant: %#v", got2, want2)
+       }
+       if rows.NextResultSet() {
+               t.Errorf("expected no more result sets")
+       }
+
+       // And verify that the final rows.Next() call, which hit EOF,
+       // also closed the rows connection.
+       if n := db.numFreeConns(); n != 1 {
+               t.Fatalf("free conns after query hitting EOF = %d; want 1", n)
+       }
+       if prepares := numPrepares(t, db) - prepares0; prepares != 1 {
+               t.Errorf("executed %d Prepare statements; want 1", prepares)
+       }
+}
+
 func TestByteOwnership(t *testing.T) {
        db := newTestDB(t, "people")
        defer closeDB(t, db)