]> Cypherpunks repositories - gostls13.git/commitdiff
database/sql: use errors.Is when checking ErrBadConn
authorDaniel Theophanes <kardianos@gmail.com>
Mon, 12 Jul 2021 14:25:04 +0000 (09:25 -0500)
committerDaniel Theophanes <kardianos@gmail.com>
Wed, 3 Nov 2021 22:51:09 +0000 (22:51 +0000)
When drivers return driver.ErrBadConn, no meaningful
information about what the cause of the problem is
returned. Ideally the driver.ErrBadConn would be
always caught with the retry loop, but this is not
always the case. Drivers today must choose between
returning a useful error and use the rety logic.
This allows supporting both.

Fixes #47142

Change-Id: I454573028f041dfdf874eed6c254fb194ccf6d96
Reviewed-on: https://go-review.googlesource.com/c/go/+/333949
Run-TryBot: Ian Lance Taylor <iant@golang.org>
Trust: Ian Lance Taylor <iant@golang.org>
Trust: Daniel Theophanes <kardianos@gmail.com>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Ian Lance Taylor <iant@golang.org>
src/database/sql/driver/driver.go
src/database/sql/fakedb_test.go
src/database/sql/sql.go
src/database/sql/sql_test.go

index f09396175acf86099e5d902fb944d1f8131210e2..ea1de5a8fb6a8b4f056b2b83f99c5ee985a5fce5 100644 (file)
@@ -156,6 +156,9 @@ var ErrSkip = errors.New("driver: skip fast-path; continue as if unimplemented")
 // if there's a possibility that the database server might have
 // performed the operation. Even if the server sends back an error,
 // you shouldn't return ErrBadConn.
+//
+// Errors will be checked using errors.Is. An error may
+// wrap ErrBadConn or implement the Is(error) bool method.
 var ErrBadConn = errors.New("driver: bad connection")
 
 // Pinger is an optional interface that may be implemented by a Conn.
index 4b68f1cba9fe1aa7f9b2806494a8eb4c6c9e066f..34e97e012b1a363d703aafae3a24e2642fe2a8e5 100644 (file)
@@ -96,6 +96,19 @@ type fakeDB struct {
        allowAny bool
 }
 
+type fakeError struct {
+       Message string
+       Wrapped error
+}
+
+func (err fakeError) Error() string {
+       return err.Message
+}
+
+func (err fakeError) Unwrap() error {
+       return err.Wrapped
+}
+
 type table struct {
        mu      sync.Mutex
        colname []string
@@ -368,7 +381,7 @@ func (c *fakeConn) isDirtyAndMark() bool {
 
 func (c *fakeConn) Begin() (driver.Tx, error) {
        if c.isBad() {
-               return nil, driver.ErrBadConn
+               return nil, fakeError{Wrapped: driver.ErrBadConn}
        }
        if c.currTx != nil {
                return nil, errors.New("fakedb: already in a transaction")
@@ -401,7 +414,7 @@ func (c *fakeConn) ResetSession(ctx context.Context) error {
        c.dirtySession = false
        c.currTx = nil
        if c.isBad() {
-               return driver.ErrBadConn
+               return fakeError{Message: "Reset Session: bad conn", Wrapped: driver.ErrBadConn}
        }
        return nil
 }
@@ -629,7 +642,7 @@ func (c *fakeConn) PrepareContext(ctx context.Context, query string) (driver.Stm
        }
 
        if c.stickyBad || (hookPrepareBadConn != nil && hookPrepareBadConn()) {
-               return nil, driver.ErrBadConn
+               return nil, fakeError{Message: "Preapre: Sticky Bad", Wrapped: driver.ErrBadConn}
        }
 
        c.touchMem()
@@ -756,7 +769,7 @@ func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (d
        }
 
        if s.c.stickyBad || (hookExecBadConn != nil && hookExecBadConn()) {
-               return nil, driver.ErrBadConn
+               return nil, fakeError{Message: "Exec: Sticky Bad", Wrapped: driver.ErrBadConn}
        }
        if s.c.isDirtyAndMark() {
                return nil, errFakeConnSessionDirty
@@ -870,7 +883,7 @@ func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (
        }
 
        if s.c.stickyBad || (hookQueryBadConn != nil && hookQueryBadConn()) {
-               return nil, driver.ErrBadConn
+               return nil, fakeError{Message: "Query: Sticky Bad", Wrapped: driver.ErrBadConn}
        }
        if s.c.isDirtyAndMark() {
                return nil, errFakeConnSessionDirty
@@ -1031,7 +1044,7 @@ var hookCommitBadConn func() bool
 func (tx *fakeTx) Commit() error {
        tx.c.currTx = nil
        if hookCommitBadConn != nil && hookCommitBadConn() {
-               return driver.ErrBadConn
+               return fakeError{Message: "Commit: Hook Bad Conn", Wrapped: driver.ErrBadConn}
        }
        tx.c.touchMem()
        return nil
@@ -1043,7 +1056,7 @@ var hookRollbackBadConn func() bool
 func (tx *fakeTx) Rollback() error {
        tx.c.currTx = nil
        if hookRollbackBadConn != nil && hookRollbackBadConn() {
-               return driver.ErrBadConn
+               return fakeError{Message: "Rollback: Hook Bad Conn", Wrapped: driver.ErrBadConn}
        }
        tx.c.touchMem()
        return nil
index e4a5a225b0ab793ee6106256ace4fbdf83ec4223..897bca059b85403692569c3d2239b5a5a2f1c6b0 100644 (file)
@@ -848,14 +848,15 @@ func (db *DB) pingDC(ctx context.Context, dc *driverConn, release func(error)) e
 func (db *DB) PingContext(ctx context.Context) error {
        var dc *driverConn
        var err error
-
+       var isBadConn bool
        for i := 0; i < maxBadConnRetries; i++ {
                dc, err = db.conn(ctx, cachedOrNewConn)
-               if err != driver.ErrBadConn {
+               isBadConn = errors.Is(err, driver.ErrBadConn)
+               if !isBadConn {
                        break
                }
        }
-       if err == driver.ErrBadConn {
+       if isBadConn {
                dc, err = db.conn(ctx, alwaysNewConn)
        }
        if err != nil {
@@ -1317,9 +1318,9 @@ func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn
                db.mu.Unlock()
 
                // Reset the session if required.
-               if err := conn.resetSession(ctx); err == driver.ErrBadConn {
+               if err := conn.resetSession(ctx); errors.Is(err, driver.ErrBadConn) {
                        conn.Close()
-                       return nil, driver.ErrBadConn
+                       return nil, err
                }
 
                return conn, nil
@@ -1381,9 +1382,9 @@ func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn
                        }
 
                        // Reset the session if required.
-                       if err := ret.conn.resetSession(ctx); err == driver.ErrBadConn {
+                       if err := ret.conn.resetSession(ctx); errors.Is(err, driver.ErrBadConn) {
                                ret.conn.Close()
-                               return nil, driver.ErrBadConn
+                               return nil, err
                        }
                        return ret.conn, ret.err
                }
@@ -1442,7 +1443,7 @@ const debugGetPut = false
 // putConn adds a connection to the db's free pool.
 // err is optionally the last error that occurred on this connection.
 func (db *DB) putConn(dc *driverConn, err error, resetSession bool) {
-       if err != driver.ErrBadConn {
+       if !errors.Is(err, driver.ErrBadConn) {
                if !dc.validateConnection(resetSession) {
                        err = driver.ErrBadConn
                }
@@ -1456,7 +1457,7 @@ func (db *DB) putConn(dc *driverConn, err error, resetSession bool) {
                panic("sql: connection returned that was never out")
        }
 
-       if err != driver.ErrBadConn && dc.expired(db.maxLifetime) {
+       if !errors.Is(err, driver.ErrBadConn) && dc.expired(db.maxLifetime) {
                db.maxLifetimeClosed++
                err = driver.ErrBadConn
        }
@@ -1471,7 +1472,7 @@ func (db *DB) putConn(dc *driverConn, err error, resetSession bool) {
        }
        dc.onPut = nil
 
-       if err == driver.ErrBadConn {
+       if errors.Is(err, driver.ErrBadConn) {
                // Don't reuse bad connections.
                // Since the conn is considered bad and is being discarded, treat it
                // as closed. Don't decrement the open count here, finalClose will
@@ -1551,13 +1552,15 @@ const maxBadConnRetries = 2
 func (db *DB) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
        var stmt *Stmt
        var err error
+       var isBadConn bool
        for i := 0; i < maxBadConnRetries; i++ {
                stmt, err = db.prepare(ctx, query, cachedOrNewConn)
-               if err != driver.ErrBadConn {
+               isBadConn = errors.Is(err, driver.ErrBadConn)
+               if !isBadConn {
                        break
                }
        }
-       if err == driver.ErrBadConn {
+       if isBadConn {
                return db.prepare(ctx, query, alwaysNewConn)
        }
        return stmt, err
@@ -1627,13 +1630,15 @@ func (db *DB) prepareDC(ctx context.Context, dc *driverConn, release func(error)
 func (db *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (Result, error) {
        var res Result
        var err error
+       var isBadConn bool
        for i := 0; i < maxBadConnRetries; i++ {
                res, err = db.exec(ctx, query, args, cachedOrNewConn)
-               if err != driver.ErrBadConn {
+               isBadConn = errors.Is(err, driver.ErrBadConn)
+               if !isBadConn {
                        break
                }
        }
-       if err == driver.ErrBadConn {
+       if isBadConn {
                return db.exec(ctx, query, args, alwaysNewConn)
        }
        return res, err
@@ -1700,13 +1705,15 @@ func (db *DB) execDC(ctx context.Context, dc *driverConn, release func(error), q
 func (db *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) {
        var rows *Rows
        var err error
+       var isBadConn bool
        for i := 0; i < maxBadConnRetries; i++ {
                rows, err = db.query(ctx, query, args, cachedOrNewConn)
-               if err != driver.ErrBadConn {
+               isBadConn = errors.Is(err, driver.ErrBadConn)
+               if !isBadConn {
                        break
                }
        }
-       if err == driver.ErrBadConn {
+       if isBadConn {
                return db.query(ctx, query, args, alwaysNewConn)
        }
        return rows, err
@@ -1835,13 +1842,15 @@ func (db *DB) QueryRow(query string, args ...interface{}) *Row {
 func (db *DB) BeginTx(ctx context.Context, opts *TxOptions) (*Tx, error) {
        var tx *Tx
        var err error
+       var isBadConn bool
        for i := 0; i < maxBadConnRetries; i++ {
                tx, err = db.begin(ctx, opts, cachedOrNewConn)
-               if err != driver.ErrBadConn {
+               isBadConn = errors.Is(err, driver.ErrBadConn)
+               if !isBadConn {
                        break
                }
        }
-       if err == driver.ErrBadConn {
+       if isBadConn {
                return db.begin(ctx, opts, alwaysNewConn)
        }
        return tx, err
@@ -1914,13 +1923,15 @@ var ErrConnDone = errors.New("sql: connection is already closed")
 func (db *DB) Conn(ctx context.Context) (*Conn, error) {
        var dc *driverConn
        var err error
+       var isBadConn bool
        for i := 0; i < maxBadConnRetries; i++ {
                dc, err = db.conn(ctx, cachedOrNewConn)
-               if err != driver.ErrBadConn {
+               isBadConn = errors.Is(err, driver.ErrBadConn)
+               if !isBadConn {
                        break
                }
        }
-       if err == driver.ErrBadConn {
+       if isBadConn {
                dc, err = db.conn(ctx, alwaysNewConn)
        }
        if err != nil {
@@ -2032,8 +2043,8 @@ func (c *Conn) PrepareContext(ctx context.Context, query string) (*Stmt, error)
 // Raw executes f exposing the underlying driver connection for the
 // duration of f. The driverConn must not be used outside of f.
 //
-// Once f returns and err is not equal to driver.ErrBadConn, the Conn will
-// continue to be usable until Conn.Close is called.
+// Once f returns and err is not driver.ErrBadConn, the Conn will continue to be usable
+// until Conn.Close is called.
 func (c *Conn) Raw(f func(driverConn interface{}) error) (err error) {
        var dc *driverConn
        var release releaseConn
@@ -2084,7 +2095,7 @@ func (c *Conn) BeginTx(ctx context.Context, opts *TxOptions) (*Tx, error) {
 // as the sql operation is done with the dc.
 func (c *Conn) closemuRUnlockCondReleaseConn(err error) {
        c.closemu.RUnlock()
-       if err == driver.ErrBadConn {
+       if errors.Is(err, driver.ErrBadConn) {
                c.close(err)
        }
 }
@@ -2278,7 +2289,7 @@ func (tx *Tx) Commit() error {
        withLock(tx.dc, func() {
                err = tx.txi.Commit()
        })
-       if err != driver.ErrBadConn {
+       if !errors.Is(err, driver.ErrBadConn) {
                tx.closePrepared()
        }
        tx.close(err)
@@ -2310,7 +2321,7 @@ func (tx *Tx) rollback(discardConn bool) error {
        withLock(tx.dc, func() {
                err = tx.txi.Rollback()
        })
-       if err != driver.ErrBadConn {
+       if !errors.Is(err, driver.ErrBadConn) {
                tx.closePrepared()
        }
        if discardConn {
@@ -2616,7 +2627,7 @@ func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (Result, er
                }
                dc, releaseConn, ds, err := s.connStmt(ctx, strategy)
                if err != nil {
-                       if err == driver.ErrBadConn {
+                       if errors.Is(err, driver.ErrBadConn) {
                                continue
                        }
                        return nil, err
@@ -2624,7 +2635,7 @@ func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (Result, er
 
                res, err = resultFromStatement(ctx, dc.ci, ds, args...)
                releaseConn(err)
-               if err != driver.ErrBadConn {
+               if !errors.Is(err, driver.ErrBadConn) {
                        return res, err
                }
        }
@@ -2764,7 +2775,7 @@ func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, er
                }
                dc, releaseConn, ds, err := s.connStmt(ctx, strategy)
                if err != nil {
-                       if err == driver.ErrBadConn {
+                       if errors.Is(err, driver.ErrBadConn) {
                                continue
                        }
                        return nil, err
@@ -2798,7 +2809,7 @@ func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, er
                }
 
                releaseConn(err)
-               if err != driver.ErrBadConn {
+               if !errors.Is(err, driver.ErrBadConn) {
                        return nil, err
                }
        }
index 15c30e0d00951e1718539fad4fcc47cb5ed50e42..889adc3164ed4d6294fb70912af5df7622ee918c 100644 (file)
@@ -3159,7 +3159,7 @@ func TestTxEndBadConn(t *testing.T) {
                        return broken
                }
 
-               if err := op(); err != driver.ErrBadConn {
+               if err := op(); !errors.Is(err, driver.ErrBadConn) {
                        t.Errorf(name+": %v", err)
                        return
                }