return stmt, nil
}
+// hook to simulate broken connections
+var hookPrepareBadConn func() bool
+
func (c *fakeConn) Prepare(query string) (driver.Stmt, error) {
c.numPrepare++
if c.db == nil {
panic("nil c.db; conn = " + fmt.Sprintf("%#v", c))
}
+
+ if hookPrepareBadConn != nil && hookPrepareBadConn() {
+ return nil, driver.ErrBadConn
+ }
+
parts := strings.Split(query, "|")
if len(parts) < 1 {
return nil, errf("empty query")
var errClosed = errors.New("fakedb: statement has been closed")
+// hook to simulate broken connections
+var hookExecBadConn func() bool
+
func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) {
if s.closed {
return nil, errClosed
}
+
+ if hookExecBadConn != nil && hookExecBadConn() {
+ return nil, driver.ErrBadConn
+ }
+
err := checkSubsetTypes(args)
if err != nil {
return nil, err
return driver.RowsAffected(1), nil
}
+// hook to simulate broken connections
+var hookQueryBadConn func() bool
+
func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) {
if s.closed {
return nil, errClosed
}
+
+ if hookQueryBadConn != nil && hookQueryBadConn() {
+ return nil, driver.ErrBadConn
+ }
+
err := checkSubsetTypes(args)
if err != nil {
return nil, err
// stmt closes if the conn is about to close anyway? For now
// do the safe thing, in case stmts need to be closed.
//
- // TODO(bradfitz): after Go 1.1, closing driver.Stmts
+ // TODO(bradfitz): after Go 1.2, closing driver.Stmts
// should be moved to driverStmt, using unique
// *driverStmts everywhere (including from
// *Stmt.connStmt, instead of returning a
return false
}
+// maxBadConnRetries is the number of maximum retries if the driver returns
+// driver.ErrBadConn to signal a broken connection.
+const maxBadConnRetries = 10
+
// Prepare creates a prepared statement for later queries or executions.
// Multiple queries or executions may be run concurrently from the
// returned statement.
func (db *DB) Prepare(query string) (*Stmt, error) {
var stmt *Stmt
var err error
- for i := 0; i < 10; i++ {
+ for i := 0; i < maxBadConnRetries; i++ {
stmt, err = db.prepare(query)
if err != driver.ErrBadConn {
break
func (db *DB) Exec(query string, args ...interface{}) (Result, error) {
var res Result
var err error
- for i := 0; i < 10; i++ {
+ for i := 0; i < maxBadConnRetries; i++ {
res, err = db.exec(query, args)
if err != driver.ErrBadConn {
break
func (db *DB) Query(query string, args ...interface{}) (*Rows, error) {
var rows *Rows
var err error
- for i := 0; i < 10; i++ {
+ for i := 0; i < maxBadConnRetries; i++ {
rows, err = db.query(query, args)
if err != driver.ErrBadConn {
break
func (db *DB) Begin() (*Tx, error) {
var tx *Tx
var err error
- for i := 0; i < 10; i++ {
+ for i := 0; i < maxBadConnRetries; i++ {
tx, err = db.begin()
if err != driver.ErrBadConn {
break
func (s *Stmt) Exec(args ...interface{}) (Result, error) {
s.closemu.RLock()
defer s.closemu.RUnlock()
- dc, releaseConn, si, err := s.connStmt()
- if err != nil {
- return nil, err
- }
- defer releaseConn(nil)
- return resultFromStatement(driverStmt{dc, si}, args...)
+ var res Result
+ for i := 0; i < maxBadConnRetries; i++ {
+ dc, releaseConn, si, err := s.connStmt()
+ if err != nil {
+ if err == driver.ErrBadConn {
+ continue
+ }
+ return nil, err
+ }
+
+ res, err = resultFromStatement(driverStmt{dc, si}, args...)
+ releaseConn(err)
+ if err != driver.ErrBadConn {
+ return res, err
+ }
+ }
+ return nil, driver.ErrBadConn
}
func resultFromStatement(ds driverStmt, args ...interface{}) (Result, error) {
// Make a new conn if all are busy.
// TODO(bradfitz): or wait for one? make configurable later?
if !match {
- for i := 0; ; i++ {
- dc, err := s.db.conn()
- if err != nil {
- return nil, nil, nil, err
- }
- dc.Lock()
- si, err := dc.prepareLocked(s.query)
- dc.Unlock()
- if err == driver.ErrBadConn && i < 10 {
- continue
- }
- if err != nil {
- return nil, nil, nil, err
- }
- s.mu.Lock()
- cs = connStmt{dc, si}
- s.css = append(s.css, cs)
- s.mu.Unlock()
- break
+ dc, err := s.db.conn()
+ if err != nil {
+ return nil, nil, nil, err
}
+ dc.Lock()
+ si, err := dc.prepareLocked(s.query)
+ dc.Unlock()
+ if err != nil {
+ s.db.putConn(dc, err)
+ return nil, nil, nil, err
+ }
+ s.mu.Lock()
+ cs = connStmt{dc, si}
+ s.css = append(s.css, cs)
+ s.mu.Unlock()
}
conn := cs.dc
s.closemu.RLock()
defer s.closemu.RUnlock()
- dc, releaseConn, si, err := s.connStmt()
- if err != nil {
- return nil, err
- }
+ var rowsi driver.Rows
+ for i := 0; i < maxBadConnRetries; i++ {
+ dc, releaseConn, si, err := s.connStmt()
+ if err != nil {
+ if err == driver.ErrBadConn {
+ continue
+ }
+ return nil, err
+ }
- ds := driverStmt{dc, si}
- rowsi, err := rowsiFromStatement(ds, args...)
- if err != nil {
- releaseConn(err)
- return nil, err
- }
+ rowsi, err = rowsiFromStatement(driverStmt{dc, si}, args...)
+ if err == nil {
+ // Note: ownership of ci passes to the *Rows, to be freed
+ // with releaseConn.
+ rows := &Rows{
+ dc: dc,
+ rowsi: rowsi,
+ // releaseConn set below
+ }
+ s.db.addDep(s, rows)
+ rows.releaseConn = func(err error) {
+ releaseConn(err)
+ s.db.removeDep(s, rows)
+ }
+ return rows, nil
+ }
- // Note: ownership of ci passes to the *Rows, to be freed
- // with releaseConn.
- rows := &Rows{
- dc: dc,
- rowsi: rowsi,
- // releaseConn set below
- }
- s.db.addDep(s, rows)
- rows.releaseConn = func(err error) {
releaseConn(err)
- s.db.removeDep(s, rows)
+ if err != driver.ErrBadConn {
+ return nil, err
+ }
}
- return rows, nil
+ return nil, driver.ErrBadConn
}
func rowsiFromStatement(ds driverStmt, args ...interface{}) (driver.Rows, error) {
t.Errorf("%d: age=%d, want %d", n, age, tt.want)
}
}
-
}
// golang.org/issue/3734
}
}
+// golang.org/issue/5781
+func TestErrBadConnReconnect(t *testing.T) {
+ db := newTestDB(t, "foo")
+ defer closeDB(t, db)
+ exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool")
+
+ simulateBadConn := func(name string, hook *func() bool, op func() error) {
+ broken, retried := false, false
+ numOpen := db.numOpen
+
+ // simulate a broken connection on the first try
+ *hook = func() bool {
+ if !broken {
+ broken = true
+ return true
+ }
+ retried = true
+ return false
+ }
+
+ if err := op(); err != nil {
+ t.Errorf(name+": %v", err)
+ return
+ }
+
+ if !broken || !retried {
+ t.Error(name + ": Failed to simulate broken connection")
+ }
+ *hook = nil
+
+ if numOpen != db.numOpen {
+ t.Errorf(name+": leaked %d connection(s)!", db.numOpen-numOpen)
+ numOpen = db.numOpen
+ }
+ }
+
+ // db.Exec
+ dbExec := func() error {
+ _, err := db.Exec("INSERT|t1|name=?,age=?,dead=?", "Gordon", 3, true)
+ return err
+ }
+ simulateBadConn("db.Exec prepare", &hookPrepareBadConn, dbExec)
+ simulateBadConn("db.Exec exec", &hookExecBadConn, dbExec)
+
+ // db.Query
+ dbQuery := func() error {
+ rows, err := db.Query("SELECT|t1|age,name|")
+ if err == nil {
+ err = rows.Close()
+ }
+ return err
+ }
+ simulateBadConn("db.Query prepare", &hookPrepareBadConn, dbQuery)
+ simulateBadConn("db.Query query", &hookQueryBadConn, dbQuery)
+
+ // db.Prepare
+ simulateBadConn("db.Prepare", &hookPrepareBadConn, func() error {
+ stmt, err := db.Prepare("INSERT|t1|name=?,age=?,dead=?")
+ if err != nil {
+ return err
+ }
+ stmt.Close()
+ return nil
+ })
+
+ // stmt.Exec
+ stmt1, err := db.Prepare("INSERT|t1|name=?,age=?,dead=?")
+ if err != nil {
+ t.Fatalf("prepare: %v", err)
+ }
+ defer stmt1.Close()
+ // make sure we must prepare the stmt first
+ for _, cs := range stmt1.css {
+ cs.dc.inUse = true
+ }
+
+ stmtExec := func() error {
+ _, err := stmt1.Exec("Gopher", 3, false)
+ return err
+ }
+ simulateBadConn("stmt.Exec prepare", &hookPrepareBadConn, stmtExec)
+ simulateBadConn("stmt.Exec exec", &hookExecBadConn, stmtExec)
+
+ // stmt.Query
+ stmt2, err := db.Prepare("SELECT|t1|age,name|")
+ if err != nil {
+ t.Fatalf("prepare: %v", err)
+ }
+ defer stmt2.Close()
+ // make sure we must prepare the stmt first
+ for _, cs := range stmt2.css {
+ cs.dc.inUse = true
+ }
+
+ stmtQuery := func() error {
+ rows, err := stmt2.Query()
+ if err == nil {
+ err = rows.Close()
+ }
+ return err
+ }
+ simulateBadConn("stmt.Query prepare", &hookPrepareBadConn, stmtQuery)
+ simulateBadConn("stmt.Query exec", &hookQueryBadConn, stmtQuery)
+}
+
type concurrentTest interface {
init(t testing.TB, db *DB)
finish(t testing.TB)