// Any of these can be preceded by PANIC|<method>|, to cause the
// named method on fakeStmt to panic.
//
+// Multiple of these can be combined when separated with a semicolon.
+//
// When opening a fakeDriver's database, it starts empty with no
// tables. All tables and data are stored in memory only.
type fakeDriver struct {
table string
panic string
+ next *fakeStmt // used for returning multiple results.
+
closed bool
colName []string // used by CREATE, INSERT, SELECT (selected columns)
// parts are table|selectCol1,selectCol2|whereCol=?,whereCol2=?
// (note that where columns must always contain ? marks,
// just a limitation for fakedb)
-func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (driver.Stmt, error) {
+func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (*fakeStmt, error) {
if len(parts) != 3 {
stmt.Close()
return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts))
}
// parts are table|col=type,col2=type2
-func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (driver.Stmt, error) {
+func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (*fakeStmt, error) {
if len(parts) != 2 {
stmt.Close()
return nil, errf("invalid CREATE syntax with %d parts; want 2", len(parts))
}
// parts are table|col=?,col2=val
-func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (driver.Stmt, error) {
+func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (*fakeStmt, error) {
if len(parts) != 2 {
stmt.Close()
return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts))
return nil, driver.ErrBadConn
}
- parts := strings.Split(query, "|")
- if len(parts) < 1 {
- return nil, errf("empty query")
- }
- stmt := &fakeStmt{q: query, c: c}
- if len(parts) >= 3 && parts[0] == "PANIC" {
- stmt.panic = parts[1]
- parts = parts[2:]
- }
- cmd := parts[0]
- stmt.cmd = cmd
- parts = parts[1:]
-
- c.incrStat(&c.stmtsMade)
- switch cmd {
- case "WIPE":
- // Nothing
- case "SELECT":
- return c.prepareSelect(stmt, parts)
- case "CREATE":
- return c.prepareCreate(stmt, parts)
- case "INSERT":
- return c.prepareInsert(stmt, parts)
- case "NOSERT":
- // Do all the prep-work like for an INSERT but don't actually insert the row.
- // Used for some of the concurrent tests.
- return c.prepareInsert(stmt, parts)
- default:
- stmt.Close()
- return nil, errf("unsupported command type %q", cmd)
+ var firstStmt, prev *fakeStmt
+ for _, query := range strings.Split(query, ";") {
+ parts := strings.Split(query, "|")
+ if len(parts) < 1 {
+ return nil, errf("empty query")
+ }
+ stmt := &fakeStmt{q: query, c: c}
+ if firstStmt == nil {
+ firstStmt = stmt
+ }
+ if len(parts) >= 3 && parts[0] == "PANIC" {
+ stmt.panic = parts[1]
+ parts = parts[2:]
+ }
+ cmd := parts[0]
+ stmt.cmd = cmd
+ parts = parts[1:]
+
+ c.incrStat(&c.stmtsMade)
+ var err error
+ switch cmd {
+ case "WIPE":
+ // Nothing
+ case "SELECT":
+ stmt, err = c.prepareSelect(stmt, parts)
+ case "CREATE":
+ stmt, err = c.prepareCreate(stmt, parts)
+ case "INSERT":
+ stmt, err = c.prepareInsert(stmt, parts)
+ case "NOSERT":
+ // Do all the prep-work like for an INSERT but don't actually insert the row.
+ // Used for some of the concurrent tests.
+ stmt, err = c.prepareInsert(stmt, parts)
+ default:
+ stmt.Close()
+ return nil, errf("unsupported command type %q", cmd)
+ }
+ if err != nil {
+ return nil, err
+ }
+ if prev != nil {
+ prev.next = stmt
+ }
+ prev = stmt
}
- return stmt, nil
+ return firstStmt, nil
}
func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter {
s.c.incrStat(&s.c.stmtsClosed)
s.closed = true
}
+ if s.next != nil {
+ s.next.Close()
+ }
return nil
}
panic("error in pkg db; should only get here if size is correct")
}
- db.mu.Lock()
- t, ok := db.table(s.table)
- db.mu.Unlock()
- if !ok {
- return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
- }
+ setMRows := make([][]*row, 0, 1)
+ setColumns := make([][]string, 0, 1)
- if s.table == "magicquery" {
- if len(s.whereCol) == 2 && s.whereCol[0] == "op" && s.whereCol[1] == "millis" {
- if args[0] == "sleep" {
- time.Sleep(time.Duration(args[1].(int64)) * time.Millisecond)
- }
+ for {
+ db.mu.Lock()
+ t, ok := db.table(s.table)
+ db.mu.Unlock()
+ if !ok {
+ return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
}
- }
- t.mu.Lock()
- defer t.mu.Unlock()
-
- colIdx := make(map[string]int) // select column name -> column index in table
- for _, name := range s.colName {
- idx := t.columnIndex(name)
- if idx == -1 {
- return nil, fmt.Errorf("fakedb: unknown column name %q", name)
+ if s.table == "magicquery" {
+ if len(s.whereCol) == 2 && s.whereCol[0] == "op" && s.whereCol[1] == "millis" {
+ if args[0] == "sleep" {
+ time.Sleep(time.Duration(args[1].(int64)) * time.Millisecond)
+ }
+ }
}
- colIdx[name] = idx
- }
- mrows := []*row{}
-rows:
- for _, trow := range t.rows {
- // Process the where clause, skipping non-match rows. This is lazy
- // and just uses fmt.Sprintf("%v") to test equality. Good enough
- // for test code.
- for widx, wcol := range s.whereCol {
- idx := t.columnIndex(wcol)
+ t.mu.Lock()
+
+ colIdx := make(map[string]int) // select column name -> column index in table
+ for _, name := range s.colName {
+ idx := t.columnIndex(name)
if idx == -1 {
- return nil, fmt.Errorf("db: invalid where clause column %q", wcol)
+ t.mu.Unlock()
+ return nil, fmt.Errorf("fakedb: unknown column name %q", name)
}
- tcol := trow.cols[idx]
- if bs, ok := tcol.([]byte); ok {
- // lazy hack to avoid sprintf %v on a []byte
- tcol = string(bs)
+ colIdx[name] = idx
+ }
+
+ mrows := []*row{}
+ rows:
+ for _, trow := range t.rows {
+ // Process the where clause, skipping non-match rows. This is lazy
+ // and just uses fmt.Sprintf("%v") to test equality. Good enough
+ // for test code.
+ for widx, wcol := range s.whereCol {
+ idx := t.columnIndex(wcol)
+ if idx == -1 {
+ t.mu.Unlock()
+ return nil, fmt.Errorf("db: invalid where clause column %q", wcol)
+ }
+ tcol := trow.cols[idx]
+ if bs, ok := tcol.([]byte); ok {
+ // lazy hack to avoid sprintf %v on a []byte
+ tcol = string(bs)
+ }
+ if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", args[widx]) {
+ continue rows
+ }
}
- if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", args[widx]) {
- continue rows
+ mrow := &row{cols: make([]interface{}, len(s.colName))}
+ for seli, name := range s.colName {
+ mrow.cols[seli] = trow.cols[colIdx[name]]
}
+ mrows = append(mrows, mrow)
}
- mrow := &row{cols: make([]interface{}, len(s.colName))}
- for seli, name := range s.colName {
- mrow.cols[seli] = trow.cols[colIdx[name]]
+
+ t.mu.Unlock()
+
+ setMRows = append(setMRows, mrows)
+ setColumns = append(setColumns, s.colName)
+
+ if s.next == nil {
+ break
}
- mrows = append(mrows, mrow)
+ s = s.next
}
cursor := &rowsCursor{
- pos: -1,
- rows: mrows,
- cols: s.colName,
+ posRow: -1,
+ rows: setMRows,
+ cols: setColumns,
errPos: -1,
}
return cursor, nil
}
type rowsCursor struct {
- cols []string
- pos int
- rows []*row
+ cols [][]string
+ posSet int
+ posRow int
+ rows [][]*row
closed bool
// errPos and err are for making Next return early with error.
}
func (rc *rowsCursor) Columns() []string {
- return rc.cols
+ return rc.cols[rc.posSet]
}
var rowsCursorNextHook func(dest []driver.Value) error
if rc.closed {
return errors.New("fakedb: cursor is closed")
}
- rc.pos++
- if rc.pos == rc.errPos {
+ rc.posRow++
+ if rc.posRow == rc.errPos {
return rc.err
}
- if rc.pos >= len(rc.rows) {
+ if rc.posRow >= len(rc.rows[rc.posSet]) {
return io.EOF // per interface spec
}
- for i, v := range rc.rows[rc.pos].cols {
+ for i, v := range rc.rows[rc.posSet][rc.posRow].cols {
// TODO(bradfitz): convert to subset types? naah, I
// think the subset types should only be input to
// driver, but the sql package should be able to handle
return nil
}
+func (rc *rowsCursor) HasNextResultSet() bool {
+ return rc.posSet < len(rc.rows)-1
+}
+
+func (rc *rowsCursor) NextResultSet() error {
+ if rc.HasNextResultSet() {
+ rc.posSet++
+ rc.posRow = -1
+ return nil
+ }
+ return io.EOF // Per interface spec.
+}
+
// fakeDriverString is like driver.String, but indirects pointers like
// DefaultValueConverter.
//
}
}
+func TestMultiResultSetQuery(t *testing.T) {
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+ prepares0 := numPrepares(t, db)
+ rows, err := db.Query("SELECT|people|age,name|;SELECT|people|name|")
+ if err != nil {
+ t.Fatalf("Query: %v", err)
+ }
+ type row1 struct {
+ age int
+ name string
+ }
+ type row2 struct {
+ name string
+ }
+ got1 := []row1{}
+ for rows.Next() {
+ var r row1
+ err = rows.Scan(&r.age, &r.name)
+ if err != nil {
+ t.Fatalf("Scan: %v", err)
+ }
+ got1 = append(got1, r)
+ }
+ err = rows.Err()
+ if err != nil {
+ t.Fatalf("Err: %v", err)
+ }
+ want1 := []row1{
+ {age: 1, name: "Alice"},
+ {age: 2, name: "Bob"},
+ {age: 3, name: "Chris"},
+ }
+ if !reflect.DeepEqual(got1, want1) {
+ t.Errorf("mismatch.\n got1: %#v\nwant: %#v", got1, want1)
+ }
+
+ if !rows.NextResultSet() {
+ t.Errorf("expected another result set")
+ }
+
+ got2 := []row2{}
+ for rows.Next() {
+ var r row2
+ err = rows.Scan(&r.name)
+ if err != nil {
+ t.Fatalf("Scan: %v", err)
+ }
+ got2 = append(got2, r)
+ }
+ err = rows.Err()
+ if err != nil {
+ t.Fatalf("Err: %v", err)
+ }
+ want2 := []row2{
+ {name: "Alice"},
+ {name: "Bob"},
+ {name: "Chris"},
+ }
+ if !reflect.DeepEqual(got2, want2) {
+ t.Errorf("mismatch.\n got: %#v\nwant: %#v", got2, want2)
+ }
+ if rows.NextResultSet() {
+ t.Errorf("expected no more result sets")
+ }
+
+ // And verify that the final rows.Next() call, which hit EOF,
+ // also closed the rows connection.
+ if n := db.numFreeConns(); n != 1 {
+ t.Fatalf("free conns after query hitting EOF = %d; want 1", n)
+ }
+ if prepares := numPrepares(t, db) - prepares0; prepares != 1 {
+ t.Errorf("executed %d Prepare statements; want 1", prepares)
+ }
+}
+
func TestByteOwnership(t *testing.T) {
db := newTestDB(t, "people")
defer closeDB(t, db)