]> Cypherpunks repositories - gostls13.git/commitdiff
database/sql: fix auto-reconnect in prepared statements
authorJulien Schmidt <google@julienschmidt.com>
Tue, 17 Dec 2013 19:57:30 +0000 (11:57 -0800)
committerBrad Fitzpatrick <bradfitz@golang.org>
Tue, 17 Dec 2013 19:57:30 +0000 (11:57 -0800)
This also fixes several connection leaks.
Fixes #5718

R=bradfitz, adg
CC=alberto.garcia.hierro, golang-dev
https://golang.org/cl/14920046

src/pkg/database/sql/fakedb_test.go
src/pkg/database/sql/sql.go
src/pkg/database/sql/sql_test.go

index 775f67d19ed5fbc6bd67e36be45214e3a9e2289c..00ab799981d6af42fa087e4d21635d95b1391682 100644 (file)
@@ -433,11 +433,19 @@ func (c *fakeConn) prepareInsert(stmt *fakeStmt, parts []string) (driver.Stmt, e
        return stmt, nil
 }
 
+// hook to simulate broken connections
+var hookPrepareBadConn func() bool
+
 func (c *fakeConn) Prepare(query string) (driver.Stmt, error) {
        c.numPrepare++
        if c.db == nil {
                panic("nil c.db; conn = " + fmt.Sprintf("%#v", c))
        }
+
+       if hookPrepareBadConn != nil && hookPrepareBadConn() {
+               return nil, driver.ErrBadConn
+       }
+
        parts := strings.Split(query, "|")
        if len(parts) < 1 {
                return nil, errf("empty query")
@@ -489,10 +497,18 @@ func (s *fakeStmt) Close() error {
 
 var errClosed = errors.New("fakedb: statement has been closed")
 
+// hook to simulate broken connections
+var hookExecBadConn func() bool
+
 func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) {
        if s.closed {
                return nil, errClosed
        }
+
+       if hookExecBadConn != nil && hookExecBadConn() {
+               return nil, driver.ErrBadConn
+       }
+
        err := checkSubsetTypes(args)
        if err != nil {
                return nil, err
@@ -565,10 +581,18 @@ func (s *fakeStmt) execInsert(args []driver.Value, doInsert bool) (driver.Result
        return driver.RowsAffected(1), nil
 }
 
+// hook to simulate broken connections
+var hookQueryBadConn func() bool
+
 func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) {
        if s.closed {
                return nil, errClosed
        }
+
+       if hookQueryBadConn != nil && hookQueryBadConn() {
+               return nil, driver.ErrBadConn
+       }
+
        err := checkSubsetTypes(args)
        if err != nil {
                return nil, err
index fae109f2523832ca55d7790d1966fdccb5b17d8f..df989cd66c7da7a7a7c09c5600c0033181be8c73 100644 (file)
@@ -256,7 +256,7 @@ func (dc *driverConn) prepareLocked(query string) (driver.Stmt, error) {
                // stmt closes if the conn is about to close anyway? For now
                // do the safe thing, in case stmts need to be closed.
                //
-               // TODO(bradfitz): after Go 1.1, closing driver.Stmts
+               // TODO(bradfitz): after Go 1.2, closing driver.Stmts
                // should be moved to driverStmt, using unique
                // *driverStmts everywhere (including from
                // *Stmt.connStmt, instead of returning a
@@ -798,13 +798,17 @@ func (db *DB) putConnDBLocked(dc *driverConn, err error) bool {
        return false
 }
 
+// maxBadConnRetries is the number of maximum retries if the driver returns
+// driver.ErrBadConn to signal a broken connection.
+const maxBadConnRetries = 10
+
 // Prepare creates a prepared statement for later queries or executions.
 // Multiple queries or executions may be run concurrently from the
 // returned statement.
 func (db *DB) Prepare(query string) (*Stmt, error) {
        var stmt *Stmt
        var err error
-       for i := 0; i < 10; i++ {
+       for i := 0; i < maxBadConnRetries; i++ {
                stmt, err = db.prepare(query)
                if err != driver.ErrBadConn {
                        break
@@ -846,7 +850,7 @@ func (db *DB) prepare(query string) (*Stmt, error) {
 func (db *DB) Exec(query string, args ...interface{}) (Result, error) {
        var res Result
        var err error
-       for i := 0; i < 10; i++ {
+       for i := 0; i < maxBadConnRetries; i++ {
                res, err = db.exec(query, args)
                if err != driver.ErrBadConn {
                        break
@@ -895,7 +899,7 @@ func (db *DB) exec(query string, args []interface{}) (res Result, err error) {
 func (db *DB) Query(query string, args ...interface{}) (*Rows, error) {
        var rows *Rows
        var err error
-       for i := 0; i < 10; i++ {
+       for i := 0; i < maxBadConnRetries; i++ {
                rows, err = db.query(query, args)
                if err != driver.ErrBadConn {
                        break
@@ -983,7 +987,7 @@ func (db *DB) QueryRow(query string, args ...interface{}) *Row {
 func (db *DB) Begin() (*Tx, error) {
        var tx *Tx
        var err error
-       for i := 0; i < 10; i++ {
+       for i := 0; i < maxBadConnRetries; i++ {
                tx, err = db.begin()
                if err != driver.ErrBadConn {
                        break
@@ -1245,13 +1249,24 @@ type Stmt struct {
 func (s *Stmt) Exec(args ...interface{}) (Result, error) {
        s.closemu.RLock()
        defer s.closemu.RUnlock()
-       dc, releaseConn, si, err := s.connStmt()
-       if err != nil {
-               return nil, err
-       }
-       defer releaseConn(nil)
 
-       return resultFromStatement(driverStmt{dc, si}, args...)
+       var res Result
+       for i := 0; i < maxBadConnRetries; i++ {
+               dc, releaseConn, si, err := s.connStmt()
+               if err != nil {
+                       if err == driver.ErrBadConn {
+                               continue
+                       }
+                       return nil, err
+               }
+
+               res, err = resultFromStatement(driverStmt{dc, si}, args...)
+               releaseConn(err)
+               if err != driver.ErrBadConn {
+                       return res, err
+               }
+       }
+       return nil, driver.ErrBadConn
 }
 
 func resultFromStatement(ds driverStmt, args ...interface{}) (Result, error) {
@@ -1329,26 +1344,21 @@ func (s *Stmt) connStmt() (ci *driverConn, releaseConn func(error), si driver.St
        // Make a new conn if all are busy.
        // TODO(bradfitz): or wait for one? make configurable later?
        if !match {
-               for i := 0; ; i++ {
-                       dc, err := s.db.conn()
-                       if err != nil {
-                               return nil, nil, nil, err
-                       }
-                       dc.Lock()
-                       si, err := dc.prepareLocked(s.query)
-                       dc.Unlock()
-                       if err == driver.ErrBadConn && i < 10 {
-                               continue
-                       }
-                       if err != nil {
-                               return nil, nil, nil, err
-                       }
-                       s.mu.Lock()
-                       cs = connStmt{dc, si}
-                       s.css = append(s.css, cs)
-                       s.mu.Unlock()
-                       break
+               dc, err := s.db.conn()
+               if err != nil {
+                       return nil, nil, nil, err
                }
+               dc.Lock()
+               si, err := dc.prepareLocked(s.query)
+               dc.Unlock()
+               if err != nil {
+                       s.db.putConn(dc, err)
+                       return nil, nil, nil, err
+               }
+               s.mu.Lock()
+               cs = connStmt{dc, si}
+               s.css = append(s.css, cs)
+               s.mu.Unlock()
        }
 
        conn := cs.dc
@@ -1361,31 +1371,39 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
        s.closemu.RLock()
        defer s.closemu.RUnlock()
 
-       dc, releaseConn, si, err := s.connStmt()
-       if err != nil {
-               return nil, err
-       }
+       var rowsi driver.Rows
+       for i := 0; i < maxBadConnRetries; i++ {
+               dc, releaseConn, si, err := s.connStmt()
+               if err != nil {
+                       if err == driver.ErrBadConn {
+                               continue
+                       }
+                       return nil, err
+               }
 
-       ds := driverStmt{dc, si}
-       rowsi, err := rowsiFromStatement(ds, args...)
-       if err != nil {
-               releaseConn(err)
-               return nil, err
-       }
+               rowsi, err = rowsiFromStatement(driverStmt{dc, si}, args...)
+               if err == nil {
+                       // Note: ownership of ci passes to the *Rows, to be freed
+                       // with releaseConn.
+                       rows := &Rows{
+                               dc:    dc,
+                               rowsi: rowsi,
+                               // releaseConn set below
+                       }
+                       s.db.addDep(s, rows)
+                       rows.releaseConn = func(err error) {
+                               releaseConn(err)
+                               s.db.removeDep(s, rows)
+                       }
+                       return rows, nil
+               }
 
-       // Note: ownership of ci passes to the *Rows, to be freed
-       // with releaseConn.
-       rows := &Rows{
-               dc:    dc,
-               rowsi: rowsi,
-               // releaseConn set below
-       }
-       s.db.addDep(s, rows)
-       rows.releaseConn = func(err error) {
                releaseConn(err)
-               s.db.removeDep(s, rows)
+               if err != driver.ErrBadConn {
+                       return nil, err
+               }
        }
-       return rows, nil
+       return nil, driver.ErrBadConn
 }
 
 func rowsiFromStatement(ds driverStmt, args ...interface{}) (driver.Rows, error) {
index a3720c4e76056ced71af1f9923b9f1b3722371d9..aac36f87d3bc945dc8bff381928dc506ed41e5f2 100644 (file)
@@ -348,7 +348,6 @@ func TestStatementQueryRow(t *testing.T) {
                        t.Errorf("%d: age=%d, want %d", n, age, tt.want)
                }
        }
-
 }
 
 // golang.org/issue/3734
@@ -1255,6 +1254,111 @@ func TestStmtCloseOrder(t *testing.T) {
        }
 }
 
+// golang.org/issue/5781
+func TestErrBadConnReconnect(t *testing.T) {
+       db := newTestDB(t, "foo")
+       defer closeDB(t, db)
+       exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool")
+
+       simulateBadConn := func(name string, hook *func() bool, op func() error) {
+               broken, retried := false, false
+               numOpen := db.numOpen
+
+               // simulate a broken connection on the first try
+               *hook = func() bool {
+                       if !broken {
+                               broken = true
+                               return true
+                       }
+                       retried = true
+                       return false
+               }
+
+               if err := op(); err != nil {
+                       t.Errorf(name+": %v", err)
+                       return
+               }
+
+               if !broken || !retried {
+                       t.Error(name + ": Failed to simulate broken connection")
+               }
+               *hook = nil
+
+               if numOpen != db.numOpen {
+                       t.Errorf(name+": leaked %d connection(s)!", db.numOpen-numOpen)
+                       numOpen = db.numOpen
+               }
+       }
+
+       // db.Exec
+       dbExec := func() error {
+               _, err := db.Exec("INSERT|t1|name=?,age=?,dead=?", "Gordon", 3, true)
+               return err
+       }
+       simulateBadConn("db.Exec prepare", &hookPrepareBadConn, dbExec)
+       simulateBadConn("db.Exec exec", &hookExecBadConn, dbExec)
+
+       // db.Query
+       dbQuery := func() error {
+               rows, err := db.Query("SELECT|t1|age,name|")
+               if err == nil {
+                       err = rows.Close()
+               }
+               return err
+       }
+       simulateBadConn("db.Query prepare", &hookPrepareBadConn, dbQuery)
+       simulateBadConn("db.Query query", &hookQueryBadConn, dbQuery)
+
+       // db.Prepare
+       simulateBadConn("db.Prepare", &hookPrepareBadConn, func() error {
+               stmt, err := db.Prepare("INSERT|t1|name=?,age=?,dead=?")
+               if err != nil {
+                       return err
+               }
+               stmt.Close()
+               return nil
+       })
+
+       // stmt.Exec
+       stmt1, err := db.Prepare("INSERT|t1|name=?,age=?,dead=?")
+       if err != nil {
+               t.Fatalf("prepare: %v", err)
+       }
+       defer stmt1.Close()
+       // make sure we must prepare the stmt first
+       for _, cs := range stmt1.css {
+               cs.dc.inUse = true
+       }
+
+       stmtExec := func() error {
+               _, err := stmt1.Exec("Gopher", 3, false)
+               return err
+       }
+       simulateBadConn("stmt.Exec prepare", &hookPrepareBadConn, stmtExec)
+       simulateBadConn("stmt.Exec exec", &hookExecBadConn, stmtExec)
+
+       // stmt.Query
+       stmt2, err := db.Prepare("SELECT|t1|age,name|")
+       if err != nil {
+               t.Fatalf("prepare: %v", err)
+       }
+       defer stmt2.Close()
+       // make sure we must prepare the stmt first
+       for _, cs := range stmt2.css {
+               cs.dc.inUse = true
+       }
+
+       stmtQuery := func() error {
+               rows, err := stmt2.Query()
+               if err == nil {
+                       err = rows.Close()
+               }
+               return err
+       }
+       simulateBadConn("stmt.Query prepare", &hookPrepareBadConn, stmtQuery)
+       simulateBadConn("stmt.Query exec", &hookQueryBadConn, stmtQuery)
+}
+
 type concurrentTest interface {
        init(t testing.TB, db *DB)
        finish(t testing.TB)