]> Cypherpunks repositories - gostls13.git/commitdiff
database/sql: optimize connection request pool
authorBrad Fitzpatrick <bradfitz@golang.org>
Sun, 17 Mar 2024 01:29:06 +0000 (18:29 -0700)
committerGopher Robot <gobot@golang.org>
Mon, 18 Mar 2024 23:40:44 +0000 (23:40 +0000)
This replaces a map used as a set with a slice.

We were using a surprising amount of CPU in this code, making mapiters
to pull out a random element of the map. Instead, just rand.IntN to pick
a random element of the slice.

It also adds a benchmark:

                     │    before    │                after                │
                     │    sec/op    │   sec/op     vs base                │
    ConnRequestSet-8   1818.0n ± 0%   452.4n ± 0%  -75.12% (p=0.000 n=10)

(whether random is a good policy is a bigger question, but this
 optimizes the current policy without changing behavior)

Updates #66361

Change-Id: I3d456a819cc720c2d18e1befffd2657e5f50f1e7
Reviewed-on: https://go-review.googlesource.com/c/go/+/572119
Reviewed-by: Cherry Mui <cherryyz@google.com>
Reviewed-by: Ian Lance Taylor <iant@google.com>
Reviewed-by: Emmanuel Odeke <emmanuel@orijtech.com>
Reviewed-by: David Chase <drchase@google.com>
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Auto-Submit: Brad Fitzpatrick <bradfitz@golang.org>

src/database/sql/sql.go
src/database/sql/sql_test.go
src/go/build/deps_test.go

index 4f1197dc6e08f26c336d22449af70dfe83315478..b5facdbf2a16ae279962afc0ac4438ce503c01ad 100644 (file)
@@ -21,6 +21,7 @@ import (
        "errors"
        "fmt"
        "io"
+       "math/rand/v2"
        "reflect"
        "runtime"
        "sort"
@@ -497,9 +498,8 @@ type DB struct {
 
        mu           sync.Mutex    // protects following fields
        freeConn     []*driverConn // free connections ordered by returnedAt oldest to newest
-       connRequests map[uint64]chan connRequest
-       nextRequest  uint64 // Next key to use in connRequests.
-       numOpen      int    // number of opened and pending open connections
+       connRequests connRequestSet
+       numOpen      int // number of opened and pending open connections
        // Used to signal the need for new connections
        // a goroutine running connectionOpener() reads on this chan and
        // maybeOpenNewConnections sends on the chan (one send per needed connection)
@@ -814,11 +814,10 @@ func (t dsnConnector) Driver() driver.Driver {
 func OpenDB(c driver.Connector) *DB {
        ctx, cancel := context.WithCancel(context.Background())
        db := &DB{
-               connector:    c,
-               openerCh:     make(chan struct{}, connectionRequestQueueSize),
-               lastPut:      make(map[*driverConn]string),
-               connRequests: make(map[uint64]chan connRequest),
-               stop:         cancel,
+               connector: c,
+               openerCh:  make(chan struct{}, connectionRequestQueueSize),
+               lastPut:   make(map[*driverConn]string),
+               stop:      cancel,
        }
 
        go db.connectionOpener(ctx)
@@ -922,9 +921,7 @@ func (db *DB) Close() error {
        }
        db.freeConn = nil
        db.closed = true
-       for _, req := range db.connRequests {
-               close(req)
-       }
+       db.connRequests.CloseAndRemoveAll()
        db.mu.Unlock()
        for _, fn := range fns {
                err1 := fn()
@@ -1223,7 +1220,7 @@ func (db *DB) Stats() DBStats {
 // If there are connRequests and the connection limit hasn't been reached,
 // then tell the connectionOpener to open new connections.
 func (db *DB) maybeOpenNewConnections() {
-       numRequests := len(db.connRequests)
+       numRequests := db.connRequests.Len()
        if db.maxOpen > 0 {
                numCanOpen := db.maxOpen - db.numOpen
                if numRequests > numCanOpen {
@@ -1297,14 +1294,6 @@ type connRequest struct {
 
 var errDBClosed = errors.New("sql: database is closed")
 
-// nextRequestKeyLocked returns the next connection request key.
-// It is assumed that nextRequest will not overflow.
-func (db *DB) nextRequestKeyLocked() uint64 {
-       next := db.nextRequest
-       db.nextRequest++
-       return next
-}
-
 // conn returns a newly-opened or cached *driverConn.
 func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn, error) {
        db.mu.Lock()
@@ -1352,8 +1341,7 @@ func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn
                // Make the connRequest channel. It's buffered so that the
                // connectionOpener doesn't block while waiting for the req to be read.
                req := make(chan connRequest, 1)
-               reqKey := db.nextRequestKeyLocked()
-               db.connRequests[reqKey] = req
+               delHandle := db.connRequests.Add(req)
                db.waitCount++
                db.mu.Unlock()
 
@@ -1365,16 +1353,26 @@ func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn
                        // Remove the connection request and ensure no value has been sent
                        // on it after removing.
                        db.mu.Lock()
-                       delete(db.connRequests, reqKey)
+                       deleted := db.connRequests.Delete(delHandle)
                        db.mu.Unlock()
 
                        db.waitDuration.Add(int64(time.Since(waitStart)))
 
-                       select {
-                       default:
-                       case ret, ok := <-req:
-                               if ok && ret.conn != nil {
-                                       db.putConn(ret.conn, ret.err, false)
+                       // If we failed to delete it, that means something else
+                       // grabbed it and is about to send on it.
+                       if !deleted {
+                               // TODO(bradfitz): rather than this best effort select, we
+                               // should probably start a goroutine to read from req. This best
+                               // effort select existed before the change to check 'deleted'.
+                               // But if we know for sure it wasn't deleted and a sender is
+                               // outstanding, we should probably block on req (in a new
+                               // goroutine) to get the connection back.
+                               select {
+                               default:
+                               case ret, ok := <-req:
+                                       if ok && ret.conn != nil {
+                                               db.putConn(ret.conn, ret.err, false)
+                                       }
                                }
                        }
                        return nil, ctx.Err()
@@ -1530,13 +1528,7 @@ func (db *DB) putConnDBLocked(dc *driverConn, err error) bool {
        if db.maxOpen > 0 && db.numOpen > db.maxOpen {
                return false
        }
-       if c := len(db.connRequests); c > 0 {
-               var req chan connRequest
-               var reqKey uint64
-               for reqKey, req = range db.connRequests {
-                       break
-               }
-               delete(db.connRequests, reqKey) // Remove from pending requests.
+       if req, ok := db.connRequests.TakeRandom(); ok {
                if err == nil {
                        dc.inUse = true
                }
@@ -3529,3 +3521,104 @@ func withLock(lk sync.Locker, fn func()) {
        defer lk.Unlock() // in case fn panics
        fn()
 }
+
+// connRequestSet is a set of chan connRequest that's
+// optimized for:
+//
+//   - adding an element
+//   - removing an element (only by the caller who added it)
+//   - taking (get + delete) a random element
+//
+// We previously used a map for this but the take of a random element
+// was expensive, making mapiters. This type avoids a map entirely
+// and just uses a slice.
+type connRequestSet struct {
+       // s are the elements in the set.
+       s []connRequestAndIndex
+}
+
+type connRequestAndIndex struct {
+       // req is the element in the set.
+       req chan connRequest
+
+       // curIdx points to the current location of this element in
+       // connRequestSet.s. It gets set to -1 upon removal.
+       curIdx *int
+}
+
+// CloseAndRemoveAll closes all channels in the set
+// and clears the set.
+func (s *connRequestSet) CloseAndRemoveAll() {
+       for _, v := range s.s {
+               close(v.req)
+       }
+       s.s = nil
+}
+
+// Len returns the length of the set.
+func (s *connRequestSet) Len() int { return len(s.s) }
+
+// connRequestDelHandle is an opaque handle to delete an
+// item from calling Add.
+type connRequestDelHandle struct {
+       idx *int // pointer to index; or -1 if not in slice
+}
+
+// Add adds v to the set of waiting requests.
+// The returned connRequestDelHandle can be used to remove the item from
+// the set.
+func (s *connRequestSet) Add(v chan connRequest) connRequestDelHandle {
+       idx := len(s.s)
+       // TODO(bradfitz): for simplicity, this always allocates a new int-sized
+       // allocation to store the index. But generally the set will be small and
+       // under a scannable-threshold. As an optimization, we could permit the *int
+       // to be nil when the set is small and should be scanned. This works even if
+       // the set grows over the threshold with delete handles outstanding because
+       // an element can only move to a lower index. So if it starts with a nil
+       // position, it'll always be in a low index and thus scannable. But that
+       // can be done in a follow-up change.
+       idxPtr := &idx
+       s.s = append(s.s, connRequestAndIndex{v, idxPtr})
+       return connRequestDelHandle{idxPtr}
+}
+
+// Delete removes an element from the set.
+//
+// It reports whether the element was deleted. (It can return false if a caller
+// of TakeRandom took it meanwhile, or upon the second call to Delete)
+func (s *connRequestSet) Delete(h connRequestDelHandle) bool {
+       idx := *h.idx
+       if idx < 0 {
+               return false
+       }
+       s.deleteIndex(idx)
+       return true
+}
+
+func (s *connRequestSet) deleteIndex(idx int) {
+       // Mark item as deleted.
+       *(s.s[idx].curIdx) = -1
+       // Copy last element, updating its position
+       // to its new home.
+       if idx < len(s.s)-1 {
+               last := s.s[len(s.s)-1]
+               *last.curIdx = idx
+               s.s[idx] = last
+       }
+       // Zero out last element (for GC) before shrinking the slice.
+       s.s[len(s.s)-1] = connRequestAndIndex{}
+       s.s = s.s[:len(s.s)-1]
+}
+
+// TakeRandom returns and removes a random element from s
+// and reports whether there was one to take. (It returns ok=false
+// if the set is empty.)
+func (s *connRequestSet) TakeRandom() (v chan connRequest, ok bool) {
+       if len(s.s) == 0 {
+               return nil, false
+       }
+       pick := rand.IntN(len(s.s))
+       e := s.s[pick]
+       s.deleteIndex(pick)
+       return e.req, true
+}
index bf0ecc243f7bf3b928fe2601cd5746c04a091742..e786ecbfab9fb1b8afd08a4c8893f15fdfb36339 100644 (file)
@@ -14,6 +14,7 @@ import (
        "math/rand"
        "reflect"
        "runtime"
+       "slices"
        "strings"
        "sync"
        "sync/atomic"
@@ -2983,7 +2984,7 @@ func TestConnExpiresFreshOutOfPool(t *testing.T) {
                                                return
                                        }
                                        db.mu.Lock()
-                                       ct := len(db.connRequests)
+                                       ct := db.connRequests.Len()
                                        db.mu.Unlock()
                                        if ct > 0 {
                                                return
@@ -4803,3 +4804,104 @@ func BenchmarkGrabConn(b *testing.B) {
                release(nil)
        }
 }
+
+func TestConnRequestSet(t *testing.T) {
+       var s connRequestSet
+       wantLen := func(want int) {
+               t.Helper()
+               if got := s.Len(); got != want {
+                       t.Errorf("Len = %d; want %d", got, want)
+               }
+               if want == 0 && !t.Failed() {
+                       if _, ok := s.TakeRandom(); ok {
+                               t.Fatalf("TakeRandom returned result when empty")
+                       }
+               }
+       }
+       reset := func() { s = connRequestSet{} }
+
+       t.Run("add-delete", func(t *testing.T) {
+               reset()
+               wantLen(0)
+               dh := s.Add(nil)
+               wantLen(1)
+               if !s.Delete(dh) {
+                       t.Fatal("failed to delete")
+               }
+               wantLen(0)
+               if s.Delete(dh) {
+                       t.Error("delete worked twice")
+               }
+               wantLen(0)
+       })
+       t.Run("take-before-delete", func(t *testing.T) {
+               reset()
+               ch1 := make(chan connRequest)
+               dh := s.Add(ch1)
+               wantLen(1)
+               if got, ok := s.TakeRandom(); !ok || got != ch1 {
+                       t.Fatalf("wrong take; ok=%v", ok)
+               }
+               wantLen(0)
+               if s.Delete(dh) {
+                       t.Error("unexpected delete after take")
+               }
+       })
+       t.Run("get-take-many", func(t *testing.T) {
+               reset()
+               m := map[chan connRequest]bool{}
+               const N = 100
+               var inOrder, backOut []chan connRequest
+               for range N {
+                       c := make(chan connRequest)
+                       m[c] = true
+                       s.Add(c)
+                       inOrder = append(inOrder, c)
+               }
+               if s.Len() != N {
+                       t.Fatalf("Len = %v; want %v", s.Len(), N)
+               }
+               for s.Len() > 0 {
+                       c, ok := s.TakeRandom()
+                       if !ok {
+                               t.Fatal("failed to take when non-empty")
+                       }
+                       if !m[c] {
+                               t.Fatal("returned item not in remaining set")
+                       }
+                       delete(m, c)
+                       backOut = append(backOut, c)
+               }
+               if len(m) > 0 {
+                       t.Error("items remain in expected map")
+               }
+               if slices.Equal(inOrder, backOut) { // N! chance of flaking; N=100 is fine
+                       t.Error("wasn't random")
+               }
+       })
+}
+
+func BenchmarkConnRequestSet(b *testing.B) {
+       var s connRequestSet
+       for range b.N {
+               for range 16 {
+                       s.Add(nil)
+               }
+               for range 8 {
+                       if _, ok := s.TakeRandom(); !ok {
+                               b.Fatal("want ok")
+                       }
+               }
+               for range 8 {
+                       s.Add(nil)
+               }
+               for range 16 {
+                       if _, ok := s.TakeRandom(); !ok {
+                               b.Fatal("want ok")
+                       }
+               }
+               if _, ok := s.TakeRandom(); ok {
+                       b.Fatal("unexpected ok")
+               }
+       }
+}
index 26e6e8a77de6a36a6dc532545043e79bc708c73a..427f5a96b5af3a367f42109855b29cb1193e0d2a 100644 (file)
@@ -321,8 +321,9 @@ var depsRules = `
        # databases
        FMT
        < database/sql/internal
-       < database/sql/driver
-       < database/sql;
+       < database/sql/driver;
+
+       database/sql/driver, math/rand/v2 < database/sql;
 
        # images
        FMT, compress/lzw, compress/zlib