]> Cypherpunks repositories - gostls13.git/commitdiff
database/sql: simplify retry logic when the connection is bad
authorJinzhu <wosmvp@gmail.com>
Mon, 5 Sep 2022 09:02:11 +0000 (09:02 +0000)
committerDaniel Theophanes <kardianos@gmail.com>
Wed, 7 Sep 2022 14:33:56 +0000 (14:33 +0000)
Simplify retry logic when got bad connection

Change-Id: I92494c6c020576ec01bc4868334ee920ded7aa57
GitHub-Last-Rev: 7499b0c9419a31c9adce6d5096a1924aa3612f1d
GitHub-Pull-Request: golang/go#54043
Reviewed-on: https://go-review.googlesource.com/c/go/+/419182
Reviewed-by: Daniel Theophanes <kardianos@gmail.com>
Reviewed-by: Cherry Mui <cherryyz@google.com>
Reviewed-by: Benny Siegert <bsiegert@gmail.com>
Run-TryBot: hopehook <hopehook@golangcn.org>
TryBot-Result: Gopher Robot <gobot@golang.org>

src/database/sql/sql.go

index 854a895281ac47e3a04dba3c64b2ff15eaf6acea..e74dd875f900a7ba174451c3302854f27e69f268 100644 (file)
@@ -846,17 +846,12 @@ func (db *DB) pingDC(ctx context.Context, dc *driverConn, release func(error)) e
 func (db *DB) PingContext(ctx context.Context) error {
        var dc *driverConn
        var err error
-       var isBadConn bool
-       for i := 0; i < maxBadConnRetries; i++ {
-               dc, err = db.conn(ctx, cachedOrNewConn)
-               isBadConn = errors.Is(err, driver.ErrBadConn)
-               if !isBadConn {
-                       break
-               }
-       }
-       if isBadConn {
-               dc, err = db.conn(ctx, alwaysNewConn)
-       }
+
+       err = db.retry(func(strategy connReuseStrategy) error {
+               dc, err = db.conn(ctx, strategy)
+               return err
+       })
+
        if err != nil {
                return err
        }
@@ -1539,6 +1534,18 @@ func (db *DB) putConnDBLocked(dc *driverConn, err error) bool {
 // connection to be opened.
 const maxBadConnRetries = 2
 
+func (db *DB) retry(fn func(strategy connReuseStrategy) error) error {
+       for i := int64(0); i < maxBadConnRetries; i++ {
+               err := fn(cachedOrNewConn)
+               // retry if err is driver.ErrBadConn
+               if err == nil || !errors.Is(err, driver.ErrBadConn) {
+                       return err
+               }
+       }
+
+       return fn(alwaysNewConn)
+}
+
 // PrepareContext creates a prepared statement for later queries or executions.
 // Multiple queries or executions may be run concurrently from the
 // returned statement.
@@ -1550,17 +1557,12 @@ const maxBadConnRetries = 2
 func (db *DB) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
        var stmt *Stmt
        var err error
-       var isBadConn bool
-       for i := 0; i < maxBadConnRetries; i++ {
-               stmt, err = db.prepare(ctx, query, cachedOrNewConn)
-               isBadConn = errors.Is(err, driver.ErrBadConn)
-               if !isBadConn {
-                       break
-               }
-       }
-       if isBadConn {
-               return db.prepare(ctx, query, alwaysNewConn)
-       }
+
+       err = db.retry(func(strategy connReuseStrategy) error {
+               stmt, err = db.prepare(ctx, query, strategy)
+               return err
+       })
+
        return stmt, err
 }
 
@@ -1628,17 +1630,12 @@ func (db *DB) prepareDC(ctx context.Context, dc *driverConn, release func(error)
 func (db *DB) ExecContext(ctx context.Context, query string, args ...any) (Result, error) {
        var res Result
        var err error
-       var isBadConn bool
-       for i := 0; i < maxBadConnRetries; i++ {
-               res, err = db.exec(ctx, query, args, cachedOrNewConn)
-               isBadConn = errors.Is(err, driver.ErrBadConn)
-               if !isBadConn {
-                       break
-               }
-       }
-       if isBadConn {
-               return db.exec(ctx, query, args, alwaysNewConn)
-       }
+
+       err = db.retry(func(strategy connReuseStrategy) error {
+               res, err = db.exec(ctx, query, args, strategy)
+               return err
+       })
+
        return res, err
 }
 
@@ -1703,17 +1700,12 @@ func (db *DB) execDC(ctx context.Context, dc *driverConn, release func(error), q
 func (db *DB) QueryContext(ctx context.Context, query string, args ...any) (*Rows, error) {
        var rows *Rows
        var err error
-       var isBadConn bool
-       for i := 0; i < maxBadConnRetries; i++ {
-               rows, err = db.query(ctx, query, args, cachedOrNewConn)
-               isBadConn = errors.Is(err, driver.ErrBadConn)
-               if !isBadConn {
-                       break
-               }
-       }
-       if isBadConn {
-               return db.query(ctx, query, args, alwaysNewConn)
-       }
+
+       err = db.retry(func(strategy connReuseStrategy) error {
+               rows, err = db.query(ctx, query, args, strategy)
+               return err
+       })
+
        return rows, err
 }
 
@@ -1840,17 +1832,12 @@ func (db *DB) QueryRow(query string, args ...any) *Row {
 func (db *DB) BeginTx(ctx context.Context, opts *TxOptions) (*Tx, error) {
        var tx *Tx
        var err error
-       var isBadConn bool
-       for i := 0; i < maxBadConnRetries; i++ {
-               tx, err = db.begin(ctx, opts, cachedOrNewConn)
-               isBadConn = errors.Is(err, driver.ErrBadConn)
-               if !isBadConn {
-                       break
-               }
-       }
-       if isBadConn {
-               return db.begin(ctx, opts, alwaysNewConn)
-       }
+
+       err = db.retry(func(strategy connReuseStrategy) error {
+               tx, err = db.begin(ctx, opts, strategy)
+               return err
+       })
+
        return tx, err
 }
 
@@ -1921,17 +1908,12 @@ var ErrConnDone = errors.New("sql: connection is already closed")
 func (db *DB) Conn(ctx context.Context) (*Conn, error) {
        var dc *driverConn
        var err error
-       var isBadConn bool
-       for i := 0; i < maxBadConnRetries; i++ {
-               dc, err = db.conn(ctx, cachedOrNewConn)
-               isBadConn = errors.Is(err, driver.ErrBadConn)
-               if !isBadConn {
-                       break
-               }
-       }
-       if isBadConn {
-               dc, err = db.conn(ctx, alwaysNewConn)
-       }
+
+       err = db.retry(func(strategy connReuseStrategy) error {
+               dc, err = db.conn(ctx, strategy)
+               return err
+       })
+
        if err != nil {
                return nil, err
        }
@@ -2620,26 +2602,18 @@ func (s *Stmt) ExecContext(ctx context.Context, args ...any) (Result, error) {
        defer s.closemu.RUnlock()
 
        var res Result
-       strategy := cachedOrNewConn
-       for i := 0; i < maxBadConnRetries+1; i++ {
-               if i == maxBadConnRetries {
-                       strategy = alwaysNewConn
-               }
+       err := s.db.retry(func(strategy connReuseStrategy) error {
                dc, releaseConn, ds, err := s.connStmt(ctx, strategy)
                if err != nil {
-                       if errors.Is(err, driver.ErrBadConn) {
-                               continue
-                       }
-                       return nil, err
+                       return err
                }
 
                res, err = resultFromStatement(ctx, dc.ci, ds, args...)
                releaseConn(err)
-               if !errors.Is(err, driver.ErrBadConn) {
-                       return res, err
-               }
-       }
-       return nil, driver.ErrBadConn
+               return err
+       })
+
+       return res, err
 }
 
 // Exec executes a prepared statement with the given arguments and
@@ -2768,24 +2742,19 @@ func (s *Stmt) QueryContext(ctx context.Context, args ...any) (*Rows, error) {
        defer s.closemu.RUnlock()
 
        var rowsi driver.Rows
-       strategy := cachedOrNewConn
-       for i := 0; i < maxBadConnRetries+1; i++ {
-               if i == maxBadConnRetries {
-                       strategy = alwaysNewConn
-               }
+       var rows *Rows
+
+       err := s.db.retry(func(strategy connReuseStrategy) error {
                dc, releaseConn, ds, err := s.connStmt(ctx, strategy)
                if err != nil {
-                       if errors.Is(err, driver.ErrBadConn) {
-                               continue
-                       }
-                       return nil, err
+                       return err
                }
 
                rowsi, err = rowsiFromStatement(ctx, dc.ci, ds, args...)
                if err == nil {
                        // Note: ownership of ci passes to the *Rows, to be freed
                        // with releaseConn.
-                       rows := &Rows{
+                       rows = &Rows{
                                dc:    dc,
                                rowsi: rowsi,
                                // releaseConn set below
@@ -2805,15 +2774,14 @@ func (s *Stmt) QueryContext(ctx context.Context, args ...any) (*Rows, error) {
                                txctx = s.cg.txCtx()
                        }
                        rows.initContextClose(ctx, txctx)
-                       return rows, nil
+                       return nil
                }
 
                releaseConn(err)
-               if !errors.Is(err, driver.ErrBadConn) {
-                       return nil, err
-               }
-       }
-       return nil, driver.ErrBadConn
+               return err
+       })
+
+       return rows, err
 }
 
 // Query executes a prepared query statement with the given arguments