]> Cypherpunks repositories - gostls13.git/commitdiff
database/sql: add driver.ResetSessioner and add pool support
authorDaniel Theophanes <kardianos@gmail.com>
Mon, 2 Oct 2017 18:16:53 +0000 (11:16 -0700)
committerDaniel Theophanes <kardianos@gmail.com>
Tue, 24 Oct 2017 21:37:46 +0000 (21:37 +0000)
A single database connection ususally maps to a single session.
A connection pool is logically also a session pool. Most
sessions have a way to reset the session state which is desirable
to prevent one bad query from poisoning another later query with
temp table name conflicts or other persistent session resources.

It also lets drivers provide users with better error messages from
queryies when the underlying transport or query method fails.
Internally the driver connection should now be marked as bad, but
return the actual connection. When ResetSession is called on the
connection it should return driver.ErrBadConn to remove it from
the connection pool. Previously drivers had to choose between
meaningful error messages or poisoning the connection pool.

Lastly update TestPoolExhaustOnCancel from relying on a
WAIT query fixing a flaky timeout issue exposed by this
change.

Fixes #22049
Fixes #20807

Change-Id: I2b5df6d954a38d0ad93bf1922ec16e74c827274c
Reviewed-on: https://go-review.googlesource.com/73033
Run-TryBot: Daniel Theophanes <kardianos@gmail.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: Ian Lance Taylor <iant@golang.org>
src/database/sql/driver/driver.go
src/database/sql/fakedb_test.go
src/database/sql/sql.go
src/database/sql/sql_test.go

index f5a2e7c16c9370be8b9804be2dd9716fa2bdc536..6113af79c573b4de28842126b8831b5cd2d5367a 100644 (file)
@@ -222,6 +222,18 @@ type ConnBeginTx interface {
        BeginTx(ctx context.Context, opts TxOptions) (Tx, error)
 }
 
+// ResetSessioner may be implemented by Conn to allow drivers to reset the
+// session state associated with the connection and to signal a bad connection.
+type ResetSessioner interface {
+       // ResetSession is called while a connection is in the connection
+       // pool. No queries will run on this connection until this method returns.
+       //
+       // If the connection is bad this should return driver.ErrBadConn to prevent
+       // the connection from being returned to the connection pool. Any other
+       // error will be discarded.
+       ResetSession(ctx context.Context) error
+}
+
 // Result is the result of a query execution.
 type Result interface {
        // LastInsertId returns the database's auto-generated ID
index 4dcd096ca4d20db0e4fa4d8dc0f3a37ff32f3a50..070b7834539d52f02f7b070695871c9177acf763 100644 (file)
@@ -55,6 +55,22 @@ type fakeDriver struct {
        dbs        map[string]*fakeDB
 }
 
+type fakeConnector struct {
+       name string
+
+       waiter func(context.Context)
+}
+
+func (c *fakeConnector) Connect(context.Context) (driver.Conn, error) {
+       conn, err := fdriver.Open(c.name)
+       conn.(*fakeConn).waiter = c.waiter
+       return conn, err
+}
+
+func (c *fakeConnector) Driver() driver.Driver {
+       return fdriver
+}
+
 type fakeDB struct {
        name string
 
@@ -107,6 +123,16 @@ type fakeConn struct {
        // bad connection tests; see isBad()
        bad       bool
        stickyBad bool
+
+       skipDirtySession bool // tests that use Conn should set this to true.
+
+       // dirtySession tests ResetSession, true if a query has executed
+       // until ResetSession is called.
+       dirtySession bool
+
+       // The waiter is called before each query. May be used in place of the "WAIT"
+       // directive.
+       waiter func(context.Context)
 }
 
 func (c *fakeConn) touchMem() {
@@ -298,6 +324,9 @@ func (c *fakeConn) isBad() bool {
        if c.stickyBad {
                return true
        } else if c.bad {
+               if c.db == nil {
+                       return false
+               }
                // alternate between bad conn and not bad conn
                c.db.badConn = !c.db.badConn
                return c.db.badConn
@@ -306,6 +335,21 @@ func (c *fakeConn) isBad() bool {
        }
 }
 
+func (c *fakeConn) isDirtyAndMark() bool {
+       if c.skipDirtySession {
+               return false
+       }
+       if c.currTx != nil {
+               c.dirtySession = true
+               return false
+       }
+       if c.dirtySession {
+               return true
+       }
+       c.dirtySession = true
+       return false
+}
+
 func (c *fakeConn) Begin() (driver.Tx, error) {
        if c.isBad() {
                return nil, driver.ErrBadConn
@@ -337,6 +381,14 @@ func setStrictFakeConnClose(t *testing.T) {
        testStrictClose = t
 }
 
+func (c *fakeConn) ResetSession(ctx context.Context) error {
+       c.dirtySession = false
+       if c.isBad() {
+               return driver.ErrBadConn
+       }
+       return nil
+}
+
 func (c *fakeConn) Close() (err error) {
        drv := fdriver.(*fakeDriver)
        defer func() {
@@ -572,6 +624,10 @@ func (c *fakeConn) PrepareContext(ctx context.Context, query string) (driver.Stm
                stmt.cmd = cmd
                parts = parts[1:]
 
+               if c.waiter != nil {
+                       c.waiter(ctx)
+               }
+
                if stmt.wait > 0 {
                        wait := time.NewTimer(stmt.wait)
                        select {
@@ -662,6 +718,9 @@ func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (d
        if s.c.stickyBad || (hookExecBadConn != nil && hookExecBadConn()) {
                return nil, driver.ErrBadConn
        }
+       if s.c.isDirtyAndMark() {
+               return nil, errors.New("session is dirty")
+       }
 
        err := checkSubsetTypes(s.c.db.allowAny, args)
        if err != nil {
@@ -774,6 +833,9 @@ func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (
        if s.c.stickyBad || (hookQueryBadConn != nil && hookQueryBadConn()) {
                return nil, driver.ErrBadConn
        }
+       if s.c.isDirtyAndMark() {
+               return nil, errors.New("session is dirty")
+       }
 
        err := checkSubsetTypes(s.c.db.allowAny, args)
        if err != nil {
index 7c357106882d28dc973143ef59b4d1861656b4f2..c17b2b543b57202e89fa7c7b3474651ad7f5ee1b 100644 (file)
@@ -334,6 +334,7 @@ type DB struct {
        // It is closed during db.Close(). The close tells the connectionOpener
        // goroutine to exit.
        openerCh    chan struct{}
+       resetterCh  chan *driverConn
        closed      bool
        dep         map[finalCloser]depSet
        lastPut     map[*driverConn]string // stacktrace of last conn's put; debug only
@@ -341,6 +342,8 @@ type DB struct {
        maxOpen     int                    // <= 0 means unlimited
        maxLifetime time.Duration          // maximum amount of time a connection may be reused
        cleanerCh   chan struct{}
+
+       stop func() // stop cancels the connection opener and the session resetter.
 }
 
 // connReuseStrategy determines how (*DB).conn returns database connections.
@@ -368,6 +371,7 @@ type driverConn struct {
        closed      bool
        finalClosed bool // ci.Close has been called
        openStmt    map[*driverStmt]bool
+       lastErr     error // lastError captures the result of the session resetter.
 
        // guarded by db.mu
        inUse      bool
@@ -376,7 +380,7 @@ type driverConn struct {
 }
 
 func (dc *driverConn) releaseConn(err error) {
-       dc.db.putConn(dc, err)
+       dc.db.putConn(dc, err, true)
 }
 
 func (dc *driverConn) removeOpenStmt(ds *driverStmt) {
@@ -417,6 +421,19 @@ func (dc *driverConn) prepareLocked(ctx context.Context, cg stmtConnGrabber, que
        return ds, nil
 }
 
+// resetSession resets the connection session and sets the lastErr
+// that is checked before returning the connection to another query.
+//
+// resetSession assumes that the embedded mutex is locked when the connection
+// was returned to the pool. This unlocks the mutex.
+func (dc *driverConn) resetSession(ctx context.Context) {
+       defer dc.Unlock() // In case of panic.
+       if dc.closed {    // Check if the database has been closed.
+               return
+       }
+       dc.lastErr = dc.ci.(driver.ResetSessioner).ResetSession(ctx)
+}
+
 // the dc.db's Mutex is held.
 func (dc *driverConn) closeDBLocked() func() error {
        dc.Lock()
@@ -604,14 +621,18 @@ func (t dsnConnector) Driver() driver.Driver {
 // function should be called just once. It is rarely necessary to
 // close a DB.
 func OpenDB(c driver.Connector) *DB {
+       ctx, cancel := context.WithCancel(context.Background())
        db := &DB{
                connector:    c,
                openerCh:     make(chan struct{}, connectionRequestQueueSize),
+               resetterCh:   make(chan *driverConn, 50),
                lastPut:      make(map[*driverConn]string),
                connRequests: make(map[uint64]chan connRequest),
+               stop:         cancel,
        }
 
-       go db.connectionOpener()
+       go db.connectionOpener(ctx)
+       go db.connectionResetter(ctx)
 
        return db
 }
@@ -693,7 +714,6 @@ func (db *DB) Close() error {
                db.mu.Unlock()
                return nil
        }
-       close(db.openerCh)
        if db.cleanerCh != nil {
                close(db.cleanerCh)
        }
@@ -714,6 +734,7 @@ func (db *DB) Close() error {
                        err = err1
                }
        }
+       db.stop()
        return err
 }
 
@@ -901,18 +922,39 @@ func (db *DB) maybeOpenNewConnections() {
 }
 
 // Runs in a separate goroutine, opens new connections when requested.
-func (db *DB) connectionOpener() {
-       for range db.openerCh {
-               db.openNewConnection()
+func (db *DB) connectionOpener(ctx context.Context) {
+       for {
+               select {
+               case <-ctx.Done():
+                       return
+               case <-db.openerCh:
+                       db.openNewConnection(ctx)
+               }
+       }
+}
+
+// connectionResetter runs in a separate goroutine to reset connections async
+// to exported API.
+func (db *DB) connectionResetter(ctx context.Context) {
+       for {
+               select {
+               case <-ctx.Done():
+                       for dc := range db.resetterCh {
+                               dc.Unlock()
+                       }
+                       return
+               case dc := <-db.resetterCh:
+                       dc.resetSession(ctx)
+               }
        }
 }
 
 // Open one new connection
-func (db *DB) openNewConnection() {
+func (db *DB) openNewConnection(ctx context.Context) {
        // maybeOpenNewConnctions has already executed db.numOpen++ before it sent
        // on db.openerCh. This function must execute db.numOpen-- if the
        // connection fails or is closed before returning.
-       ci, err := db.connector.Connect(context.Background())
+       ci, err := db.connector.Connect(ctx)
        db.mu.Lock()
        defer db.mu.Unlock()
        if db.closed {
@@ -987,6 +1029,14 @@ func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn
                        conn.Close()
                        return nil, driver.ErrBadConn
                }
+               // Lock around reading lastErr to ensure the session resetter finished.
+               conn.Lock()
+               err := conn.lastErr
+               conn.Unlock()
+               if err == driver.ErrBadConn {
+                       conn.Close()
+                       return nil, driver.ErrBadConn
+               }
                return conn, nil
        }
 
@@ -1012,7 +1062,7 @@ func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn
                        default:
                        case ret, ok := <-req:
                                if ok {
-                                       db.putConn(ret.conn, ret.err)
+                                       db.putConn(ret.conn, ret.err, false)
                                }
                        }
                        return nil, ctx.Err()
@@ -1024,6 +1074,17 @@ func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn
                                ret.conn.Close()
                                return nil, driver.ErrBadConn
                        }
+                       if ret.conn == nil {
+                               return nil, ret.err
+                       }
+                       // Lock around reading lastErr to ensure the session resetter finished.
+                       ret.conn.Lock()
+                       err := ret.conn.lastErr
+                       ret.conn.Unlock()
+                       if err == driver.ErrBadConn {
+                               ret.conn.Close()
+                               return nil, driver.ErrBadConn
+                       }
                        return ret.conn, ret.err
                }
        }
@@ -1079,7 +1140,7 @@ const debugGetPut = false
 
 // putConn adds a connection to the db's free pool.
 // err is optionally the last error that occurred on this connection.
-func (db *DB) putConn(dc *driverConn, err error) {
+func (db *DB) putConn(dc *driverConn, err error, resetSession bool) {
        db.mu.Lock()
        if !dc.inUse {
                if debugGetPut {
@@ -1110,11 +1171,35 @@ func (db *DB) putConn(dc *driverConn, err error) {
        if putConnHook != nil {
                putConnHook(db, dc)
        }
+       if resetSession {
+               if _, resetSession = dc.ci.(driver.ResetSessioner); resetSession {
+                       // Lock the driverConn here so it isn't released until
+                       // the connection is reset.
+                       // The lock must be taken before the connection is put into
+                       // the pool to prevent it from being taken out before it is reset.
+                       dc.Lock()
+               }
+       }
        added := db.putConnDBLocked(dc, nil)
        db.mu.Unlock()
 
        if !added {
+               if resetSession {
+                       dc.Unlock()
+               }
                dc.Close()
+               return
+       }
+       if !resetSession {
+               return
+       }
+       select {
+       default:
+               // If the resetterCh is blocking then mark the connection
+               // as bad and continue on.
+               dc.lastErr = driver.ErrBadConn
+               dc.Unlock()
+       case db.resetterCh <- dc:
        }
 }
 
index 3551366369c60c40f6d6bf372f69f565862cfe91..dead273503f5ad64ce292579f403e77ea95f03de 100644 (file)
@@ -60,10 +60,12 @@ const fakeDBName = "foo"
 var chrisBirthday = time.Unix(123456789, 0)
 
 func newTestDB(t testing.TB, name string) *DB {
-       db, err := Open("test", fakeDBName)
-       if err != nil {
-               t.Fatalf("Open: %v", err)
-       }
+       return newTestDBConnector(t, &fakeConnector{name: fakeDBName}, name)
+}
+
+func newTestDBConnector(t testing.TB, fc *fakeConnector, name string) *DB {
+       fc.name = fakeDBName
+       db := OpenDB(fc)
        if _, err := db.Exec("WIPE"); err != nil {
                t.Fatalf("exec wipe: %v", err)
        }
@@ -585,24 +587,46 @@ func TestPoolExhaustOnCancel(t *testing.T) {
        if testing.Short() {
                t.Skip("long test")
        }
-       db := newTestDB(t, "people")
-       defer closeDB(t, db)
 
        max := 3
+       var saturate, saturateDone sync.WaitGroup
+       saturate.Add(max)
+       saturateDone.Add(max)
+
+       donePing := make(chan bool)
+       state := 0
+
+       // waiter will be called for all queries, including
+       // initial setup queries. The state is only assigned when no
+       // no queries are made.
+       //
+       // Only allow the first batch of queries to finish once the
+       // second batch of Ping queries have finished.
+       waiter := func(ctx context.Context) {
+               switch state {
+               case 0:
+                       // Nothing. Initial database setup.
+               case 1:
+                       saturate.Done()
+                       select {
+                       case <-ctx.Done():
+                       case <-donePing:
+                       }
+               case 2:
+               }
+       }
+       db := newTestDBConnector(t, &fakeConnector{waiter: waiter}, "people")
+       defer closeDB(t, db)
 
        db.SetMaxOpenConns(max)
 
        // First saturate the connection pool.
        // Then start new requests for a connection that is cancelled after it is requested.
 
-       var saturate, saturateDone sync.WaitGroup
-       saturate.Add(max)
-       saturateDone.Add(max)
-
+       state = 1
        for i := 0; i < max; i++ {
                go func() {
-                       saturate.Done()
-                       rows, err := db.Query("WAIT|500ms|SELECT|people|name,photo|")
+                       rows, err := db.Query("SELECT|people|name,photo|")
                        if err != nil {
                                t.Fatalf("Query: %v", err)
                        }
@@ -612,6 +636,7 @@ func TestPoolExhaustOnCancel(t *testing.T) {
        }
 
        saturate.Wait()
+       state = 2
 
        // Now cancel the request while it is waiting.
        ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
@@ -628,7 +653,7 @@ func TestPoolExhaustOnCancel(t *testing.T) {
                        t.Fatalf("PingContext (Exhaust): %v", err)
                }
        }
-
+       close(donePing)
        saturateDone.Wait()
 
        // Now try to open a normal connection.
@@ -1332,6 +1357,7 @@ func TestConnQuery(t *testing.T) {
        if err != nil {
                t.Fatal(err)
        }
+       conn.dc.ci.(*fakeConn).skipDirtySession = true
        defer conn.Close()
 
        var name string
@@ -1359,6 +1385,7 @@ func TestConnTx(t *testing.T) {
        if err != nil {
                t.Fatal(err)
        }
+       conn.dc.ci.(*fakeConn).skipDirtySession = true
        defer conn.Close()
 
        tx, err := conn.BeginTx(ctx, nil)
@@ -2384,7 +2411,9 @@ func TestManyErrBadConn(t *testing.T) {
                        t.Fatalf("unexpected len(db.freeConn) %d (was expecting %d)", len(db.freeConn), nconn)
                }
                for _, conn := range db.freeConn {
+                       conn.Lock()
                        conn.ci.(*fakeConn).stickyBad = true
+                       conn.Unlock()
                }
                return db
        }
@@ -2474,6 +2503,7 @@ func TestManyErrBadConn(t *testing.T) {
        if err != nil {
                t.Fatal(err)
        }
+       conn.dc.ci.(*fakeConn).skipDirtySession = true
        err = conn.Close()
        if err != nil {
                t.Fatal(err)
@@ -3238,9 +3268,8 @@ func TestIssue18719(t *testing.T) {
 
        // This call will grab the connection and cancel the context
        // after it has done so. Code after must deal with the canceled state.
-       rows, err := tx.QueryContext(ctx, "SELECT|people|name|")
+       _, err = tx.QueryContext(ctx, "SELECT|people|name|")
        if err != nil {
-               rows.Close()
                t.Fatalf("expected error %v but got %v", nil, err)
        }
 
@@ -3263,6 +3292,7 @@ func TestIssue20647(t *testing.T) {
        if err != nil {
                t.Fatal(err)
        }
+       conn.dc.ci.(*fakeConn).skipDirtySession = true
        defer conn.Close()
 
        stmt, err := conn.PrepareContext(ctx, "SELECT|people|name|")
@@ -3567,6 +3597,8 @@ func TestQueryExecContextOnly(t *testing.T) {
                t.Fatal("db.Conn", err)
        }
        defer conn.Close()
+       coc := conn.dc.ci.(*ctxOnlyConn)
+       coc.fc.skipDirtySession = true
 
        _, err = conn.ExecContext(ctx, "WIPE")
        if err != nil {
@@ -3599,7 +3631,6 @@ func TestQueryExecContextOnly(t *testing.T) {
                t.Fatalf("expected %q, got %q", expectedValue, v1)
        }
 
-       coc := conn.dc.ci.(*ctxOnlyConn)
        if !coc.execCtxCalled {
                t.Error("ExecContext not called")
        }