From: Daniel Theophanes Date: Thu, 6 Oct 2016 18:06:21 +0000 (-0700) Subject: database/sql: add support for multiple result sets X-Git-Tag: go1.8beta1~872 X-Git-Url: http://www.git.cypherpunks.su/?a=commitdiff_plain;h=86b2f29676c52774d91dda96e0ba5d4d7bcd3b47;p=gostls13.git database/sql: add support for multiple result sets Many database systems allow returning multiple result sets in a single query. This can be useful when dealing with many intermediate results on the server and there is a need to return more then one arity of data to the client. Fixes #12382 Change-Id: I480a9ac6dadfc8743e0ba8b6d868ccf8442a9ca1 Reviewed-on: https://go-review.googlesource.com/30592 Reviewed-by: Brad Fitzpatrick Run-TryBot: Brad Fitzpatrick TryBot-Result: Gobot Gobot --- diff --git a/src/database/sql/driver/driver.go b/src/database/sql/driver/driver.go index ccc283d373..b3d83f3ff4 100644 --- a/src/database/sql/driver/driver.go +++ b/src/database/sql/driver/driver.go @@ -213,6 +213,22 @@ type Rows interface { Next(dest []Value) error } +// RowsNextResultSet extends the Rows interface by providing a way to signal +// the driver to advance to the next result set. +type RowsNextResultSet interface { + Rows + + // HasNextResultSet is called at the end of the current result set and + // reports whether there is another result set after the current one. + HasNextResultSet() bool + + // NextResultSet advances the driver to the next result set even + // if there are remaining rows in the current result set. + // + // NextResultSet should return io.EOF when there are no more result sets. + NextResultSet() error +} + // Tx is a transaction. type Tx interface { Commit() error diff --git a/src/database/sql/example_test.go b/src/database/sql/example_test.go index dcb74e0699..9032eac2d2 100644 --- a/src/database/sql/example_test.go +++ b/src/database/sql/example_test.go @@ -44,3 +44,65 @@ func ExampleDB_QueryRow() { fmt.Printf("Username is %s\n", username) } } + +func ExampleDB_QueryMultipleResultSets() { + age := 27 + q := ` +create temp table uid (id bigint); -- Create temp table for queries. +insert into uid +select id from users where age < ?; -- Populate temp table. + +-- First result set. +select + users.id, name +from + users + join uid on users.id = uid.id +; + +-- Second result set. +select + ur.user, ur.role +from + user_roles as ur + join uid on uid.id = ur.user +; + ` + rows, err := db.Query(q, age) + if err != nil { + log.Fatal(err) + } + defer rows.Close() + + for rows.Next() { + var ( + id int64 + name string + ) + if err := rows.Scan(&id, &name); err != nil { + log.Fatal(err) + } + fmt.Printf("id %d name is %s\n", id, name) + } + if !rows.NextResultSet() { + log.Fatal("expected more result sets", rows.Err()) + } + var roleMap = map[int64]string{ + 1: "user", + 2: "admin", + 3: "gopher", + } + for rows.Next() { + var ( + id int64 + role int64 + ) + if err := rows.Scan(&id, &role); err != nil { + log.Fatal(err) + } + fmt.Printf("id %d has role %s\n", id, roleMap[role]) + } + if err := rows.Err(); err != nil { + log.Fatal(err) + } +} diff --git a/src/database/sql/fakedb_test.go b/src/database/sql/fakedb_test.go index 5b238bfc5c..aaa13a6799 100644 --- a/src/database/sql/fakedb_test.go +++ b/src/database/sql/fakedb_test.go @@ -36,6 +36,8 @@ var _ = log.Printf // Any of these can be preceded by PANIC||, to cause the // named method on fakeStmt to panic. // +// Multiple of these can be combined when separated with a semicolon. +// // When opening a fakeDriver's database, it starts empty with no // tables. All tables and data are stored in memory only. type fakeDriver struct { @@ -109,6 +111,8 @@ type fakeStmt struct { table string panic string + next *fakeStmt // used for returning multiple results. + closed bool colName []string // used by CREATE, INSERT, SELECT (selected columns) @@ -377,7 +381,7 @@ func errf(msg string, args ...interface{}) error { // parts are table|selectCol1,selectCol2|whereCol=?,whereCol2=? // (note that where columns must always contain ? marks, // just a limitation for fakedb) -func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (driver.Stmt, error) { +func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (*fakeStmt, error) { if len(parts) != 3 { stmt.Close() return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts)) @@ -411,7 +415,7 @@ func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (driver.Stmt, e } // parts are table|col=type,col2=type2 -func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (driver.Stmt, error) { +func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (*fakeStmt, error) { if len(parts) != 2 { stmt.Close() return nil, errf("invalid CREATE syntax with %d parts; want 2", len(parts)) @@ -430,7 +434,7 @@ func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (driver.Stmt, e } // parts are table|col=?,col2=val -func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (driver.Stmt, error) { +func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (*fakeStmt, error) { if len(parts) != 2 { stmt.Close() return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts)) @@ -492,38 +496,52 @@ func (c *fakeConn) Prepare(query string) (driver.Stmt, error) { return nil, driver.ErrBadConn } - parts := strings.Split(query, "|") - if len(parts) < 1 { - return nil, errf("empty query") - } - stmt := &fakeStmt{q: query, c: c} - if len(parts) >= 3 && parts[0] == "PANIC" { - stmt.panic = parts[1] - parts = parts[2:] - } - cmd := parts[0] - stmt.cmd = cmd - parts = parts[1:] - - c.incrStat(&c.stmtsMade) - switch cmd { - case "WIPE": - // Nothing - case "SELECT": - return c.prepareSelect(stmt, parts) - case "CREATE": - return c.prepareCreate(stmt, parts) - case "INSERT": - return c.prepareInsert(stmt, parts) - case "NOSERT": - // Do all the prep-work like for an INSERT but don't actually insert the row. - // Used for some of the concurrent tests. - return c.prepareInsert(stmt, parts) - default: - stmt.Close() - return nil, errf("unsupported command type %q", cmd) + var firstStmt, prev *fakeStmt + for _, query := range strings.Split(query, ";") { + parts := strings.Split(query, "|") + if len(parts) < 1 { + return nil, errf("empty query") + } + stmt := &fakeStmt{q: query, c: c} + if firstStmt == nil { + firstStmt = stmt + } + if len(parts) >= 3 && parts[0] == "PANIC" { + stmt.panic = parts[1] + parts = parts[2:] + } + cmd := parts[0] + stmt.cmd = cmd + parts = parts[1:] + + c.incrStat(&c.stmtsMade) + var err error + switch cmd { + case "WIPE": + // Nothing + case "SELECT": + stmt, err = c.prepareSelect(stmt, parts) + case "CREATE": + stmt, err = c.prepareCreate(stmt, parts) + case "INSERT": + stmt, err = c.prepareInsert(stmt, parts) + case "NOSERT": + // Do all the prep-work like for an INSERT but don't actually insert the row. + // Used for some of the concurrent tests. + stmt, err = c.prepareInsert(stmt, parts) + default: + stmt.Close() + return nil, errf("unsupported command type %q", cmd) + } + if err != nil { + return nil, err + } + if prev != nil { + prev.next = stmt + } + prev = stmt } - return stmt, nil + return firstStmt, nil } func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter { @@ -550,6 +568,9 @@ func (s *fakeStmt) Close() error { s.c.incrStat(&s.c.stmtsClosed) s.closed = true } + if s.next != nil { + s.next.Close() + } return nil } @@ -667,64 +688,80 @@ func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) { panic("error in pkg db; should only get here if size is correct") } - db.mu.Lock() - t, ok := db.table(s.table) - db.mu.Unlock() - if !ok { - return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table) - } + setMRows := make([][]*row, 0, 1) + setColumns := make([][]string, 0, 1) - if s.table == "magicquery" { - if len(s.whereCol) == 2 && s.whereCol[0] == "op" && s.whereCol[1] == "millis" { - if args[0] == "sleep" { - time.Sleep(time.Duration(args[1].(int64)) * time.Millisecond) - } + for { + db.mu.Lock() + t, ok := db.table(s.table) + db.mu.Unlock() + if !ok { + return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table) } - } - t.mu.Lock() - defer t.mu.Unlock() - - colIdx := make(map[string]int) // select column name -> column index in table - for _, name := range s.colName { - idx := t.columnIndex(name) - if idx == -1 { - return nil, fmt.Errorf("fakedb: unknown column name %q", name) + if s.table == "magicquery" { + if len(s.whereCol) == 2 && s.whereCol[0] == "op" && s.whereCol[1] == "millis" { + if args[0] == "sleep" { + time.Sleep(time.Duration(args[1].(int64)) * time.Millisecond) + } + } } - colIdx[name] = idx - } - mrows := []*row{} -rows: - for _, trow := range t.rows { - // Process the where clause, skipping non-match rows. This is lazy - // and just uses fmt.Sprintf("%v") to test equality. Good enough - // for test code. - for widx, wcol := range s.whereCol { - idx := t.columnIndex(wcol) + t.mu.Lock() + + colIdx := make(map[string]int) // select column name -> column index in table + for _, name := range s.colName { + idx := t.columnIndex(name) if idx == -1 { - return nil, fmt.Errorf("db: invalid where clause column %q", wcol) + t.mu.Unlock() + return nil, fmt.Errorf("fakedb: unknown column name %q", name) } - tcol := trow.cols[idx] - if bs, ok := tcol.([]byte); ok { - // lazy hack to avoid sprintf %v on a []byte - tcol = string(bs) + colIdx[name] = idx + } + + mrows := []*row{} + rows: + for _, trow := range t.rows { + // Process the where clause, skipping non-match rows. This is lazy + // and just uses fmt.Sprintf("%v") to test equality. Good enough + // for test code. + for widx, wcol := range s.whereCol { + idx := t.columnIndex(wcol) + if idx == -1 { + t.mu.Unlock() + return nil, fmt.Errorf("db: invalid where clause column %q", wcol) + } + tcol := trow.cols[idx] + if bs, ok := tcol.([]byte); ok { + // lazy hack to avoid sprintf %v on a []byte + tcol = string(bs) + } + if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", args[widx]) { + continue rows + } } - if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", args[widx]) { - continue rows + mrow := &row{cols: make([]interface{}, len(s.colName))} + for seli, name := range s.colName { + mrow.cols[seli] = trow.cols[colIdx[name]] } + mrows = append(mrows, mrow) } - mrow := &row{cols: make([]interface{}, len(s.colName))} - for seli, name := range s.colName { - mrow.cols[seli] = trow.cols[colIdx[name]] + + t.mu.Unlock() + + setMRows = append(setMRows, mrows) + setColumns = append(setColumns, s.colName) + + if s.next == nil { + break } - mrows = append(mrows, mrow) + s = s.next } cursor := &rowsCursor{ - pos: -1, - rows: mrows, - cols: s.colName, + posRow: -1, + rows: setMRows, + cols: setColumns, errPos: -1, } return cursor, nil @@ -760,9 +797,10 @@ func (tx *fakeTx) Rollback() error { } type rowsCursor struct { - cols []string - pos int - rows []*row + cols [][]string + posSet int + posRow int + rows [][]*row closed bool // errPos and err are for making Next return early with error. @@ -786,7 +824,7 @@ func (rc *rowsCursor) Close() error { } func (rc *rowsCursor) Columns() []string { - return rc.cols + return rc.cols[rc.posSet] } var rowsCursorNextHook func(dest []driver.Value) error @@ -799,14 +837,14 @@ func (rc *rowsCursor) Next(dest []driver.Value) error { if rc.closed { return errors.New("fakedb: cursor is closed") } - rc.pos++ - if rc.pos == rc.errPos { + rc.posRow++ + if rc.posRow == rc.errPos { return rc.err } - if rc.pos >= len(rc.rows) { + if rc.posRow >= len(rc.rows[rc.posSet]) { return io.EOF // per interface spec } - for i, v := range rc.rows[rc.pos].cols { + for i, v := range rc.rows[rc.posSet][rc.posRow].cols { // TODO(bradfitz): convert to subset types? naah, I // think the subset types should only be input to // driver, but the sql package should be able to handle @@ -831,6 +869,19 @@ func (rc *rowsCursor) Next(dest []driver.Value) error { return nil } +func (rc *rowsCursor) HasNextResultSet() bool { + return rc.posSet < len(rc.rows)-1 +} + +func (rc *rowsCursor) NextResultSet() error { + if rc.HasNextResultSet() { + rc.posSet++ + rc.posRow = -1 + return nil + } + return io.EOF // Per interface spec. +} + // fakeDriverString is like driver.String, but indirects pointers like // DefaultValueConverter. // diff --git a/src/database/sql/sql.go b/src/database/sql/sql.go index c26d7d3063..970334269d 100644 --- a/src/database/sql/sql.go +++ b/src/database/sql/sql.go @@ -1942,6 +1942,47 @@ func (rs *Rows) Next() bool { rs.lastcols = make([]driver.Value, len(rs.rowsi.Columns())) } rs.lasterr = rs.rowsi.Next(rs.lastcols) + if rs.lasterr != nil { + // Close the connection if there is a driver error. + if rs.lasterr != io.EOF { + rs.Close() + return false + } + nextResultSet, ok := rs.rowsi.(driver.RowsNextResultSet) + if !ok { + rs.Close() + return false + } + // The driver is at the end of the current result set. + // Test to see if there is another result set after the current one. + // Only close Rows if there is no futher result sets to read. + if !nextResultSet.HasNextResultSet() { + rs.Close() + } + return false + } + return true +} + +// NextResultSet prepares the next result set for reading. It returns true if +// there is further result sets, or false if there is no further result set +// or if there is an error advancing to it. The Err method should be consulted +// to distinguish between the two cases. +// +// After calling NextResultSet, the Next method should always be called before +// scanning. If there are further result sets they may not have rows in the result +// set. +func (rs *Rows) NextResultSet() bool { + if rs.isClosed() { + return false + } + rs.lastcols = nil + nextResultSet, ok := rs.rowsi.(driver.RowsNextResultSet) + if !ok { + rs.Close() + return false + } + rs.lasterr = nextResultSet.NextResultSet() if rs.lasterr != nil { rs.Close() return false @@ -2047,7 +2088,8 @@ func (rs *Rows) isClosed() bool { return atomic.LoadInt32(&rs.closed) != 0 } -// Close closes the Rows, preventing further enumeration. If Next returns +// Close closes the Rows, preventing further enumeration. If Next and +// NextResultSet both return // false, the Rows are closed automatically and it will suffice to check the // result of Err. Close is idempotent and does not affect the result of Err. func (rs *Rows) Close() error { diff --git a/src/database/sql/sql_test.go b/src/database/sql/sql_test.go index ca14af79e7..bce210da97 100644 --- a/src/database/sql/sql_test.go +++ b/src/database/sql/sql_test.go @@ -319,6 +319,82 @@ func TestQueryContext(t *testing.T) { } } +func TestMultiResultSetQuery(t *testing.T) { + db := newTestDB(t, "people") + defer closeDB(t, db) + prepares0 := numPrepares(t, db) + rows, err := db.Query("SELECT|people|age,name|;SELECT|people|name|") + if err != nil { + t.Fatalf("Query: %v", err) + } + type row1 struct { + age int + name string + } + type row2 struct { + name string + } + got1 := []row1{} + for rows.Next() { + var r row1 + err = rows.Scan(&r.age, &r.name) + if err != nil { + t.Fatalf("Scan: %v", err) + } + got1 = append(got1, r) + } + err = rows.Err() + if err != nil { + t.Fatalf("Err: %v", err) + } + want1 := []row1{ + {age: 1, name: "Alice"}, + {age: 2, name: "Bob"}, + {age: 3, name: "Chris"}, + } + if !reflect.DeepEqual(got1, want1) { + t.Errorf("mismatch.\n got1: %#v\nwant: %#v", got1, want1) + } + + if !rows.NextResultSet() { + t.Errorf("expected another result set") + } + + got2 := []row2{} + for rows.Next() { + var r row2 + err = rows.Scan(&r.name) + if err != nil { + t.Fatalf("Scan: %v", err) + } + got2 = append(got2, r) + } + err = rows.Err() + if err != nil { + t.Fatalf("Err: %v", err) + } + want2 := []row2{ + {name: "Alice"}, + {name: "Bob"}, + {name: "Chris"}, + } + if !reflect.DeepEqual(got2, want2) { + t.Errorf("mismatch.\n got: %#v\nwant: %#v", got2, want2) + } + if rows.NextResultSet() { + t.Errorf("expected no more result sets") + } + + // And verify that the final rows.Next() call, which hit EOF, + // also closed the rows connection. + if n := db.numFreeConns(); n != 1 { + t.Fatalf("free conns after query hitting EOF = %d; want 1", n) + } + if prepares := numPrepares(t, db) - prepares0; prepares != 1 { + t.Errorf("executed %d Prepare statements; want 1", prepares) + } +} + func TestByteOwnership(t *testing.T) { db := newTestDB(t, "people") defer closeDB(t, db)