// you shouldn't return ErrBadConn.
var ErrBadConn = errors.New("driver: bad connection")
+// Pinger is an optional interface that may be implemented by a Conn.
+//
+// If a Conn does not implement Pinger, the sql package's DB.Ping and
+// DB.PingContext will check if there is at least one Conn available.
+//
+// If Conn.Ping returns ErrBadConn, DB.Ping and DB.PingContext will remove
+// the Conn from pool.
+type Pinger interface {
+ Ping(ctx context.Context) error
+}
+
// Execer is an optional interface that may be implemented by a Conn.
//
// If a Conn does not implement Execer, the sql package's DB.Exec will
// PingContext verifies a connection to the database is still alive,
// establishing a connection if necessary.
func (db *DB) PingContext(ctx context.Context) error {
- // TODO(bradfitz): give drivers an optional hook to implement
- // this in a more efficient or more reliable way, if they
- // have one.
- dc, err := db.conn(ctx, cachedOrNewConn)
+ var dc *driverConn
+ var err error
+
+ for i := 0; i < maxBadConnRetries; i++ {
+ dc, err = db.conn(ctx, cachedOrNewConn)
+ if err != driver.ErrBadConn {
+ break
+ }
+ }
+ if err == driver.ErrBadConn {
+ dc, err = db.conn(ctx, alwaysNewConn)
+ }
if err != nil {
return err
}
- db.putConn(dc, nil)
- return nil
+
+ if pinger, ok := dc.ci.(driver.Pinger); ok {
+ err = pinger.Ping(ctx)
+ }
+ db.putConn(dc, err)
+ return err
}
// Ping verifies a connection to the database is still alive,
db.Exec("ignored")
}
+type pingDriver struct {
+ fails bool
+}
+
+type pingConn struct {
+ badConn
+ driver *pingDriver
+}
+
+var pingError = errors.New("Ping failed")
+
+func (pc pingConn) Ping(ctx context.Context) error {
+ if pc.driver.fails {
+ return pingError
+ }
+ return nil
+}
+
+var _ driver.Pinger = pingConn{}
+
+func (pd *pingDriver) Open(name string) (driver.Conn, error) {
+ return pingConn{driver: pd}, nil
+}
+
+func TestPing(t *testing.T) {
+ driver := &pingDriver{}
+ Register("ping", driver)
+
+ db, err := Open("ping", "ignored")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if err := db.Ping(); err != nil {
+ t.Errorf("err was %#v, expected nil", err)
+ return
+ }
+
+ driver.fails = true
+ if err := db.Ping(); err != pingError {
+ t.Errorf("err was %#v, expected pingError", err)
+ }
+}
+
func BenchmarkConcurrentDBExec(b *testing.B) {
b.ReportAllocs()
ct := new(concurrentDBExecTest)