"errors"
"fmt"
"io"
+ "math/rand/v2"
"reflect"
"runtime"
"sort"
mu sync.Mutex // protects following fields
freeConn []*driverConn // free connections ordered by returnedAt oldest to newest
- connRequests map[uint64]chan connRequest
- nextRequest uint64 // Next key to use in connRequests.
- numOpen int // number of opened and pending open connections
+ connRequests connRequestSet
+ numOpen int // number of opened and pending open connections
// Used to signal the need for new connections
// a goroutine running connectionOpener() reads on this chan and
// maybeOpenNewConnections sends on the chan (one send per needed connection)
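// For reference, the signaling loop described above has roughly this shape
// (a sketch of connectionOpener, which this change does not touch; shown only
// to make the openerCh handshake easier to follow):
//
//	func (db *DB) connectionOpener(ctx context.Context) {
//		for {
//			select {
//			case <-ctx.Done():
//				return
//			case <-db.openerCh:
//				db.openNewConnection(ctx)
//			}
//		}
//	}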
func OpenDB(c driver.Connector) *DB {
ctx, cancel := context.WithCancel(context.Background())
db := &DB{
- connector: c,
- openerCh: make(chan struct{}, connectionRequestQueueSize),
- lastPut: make(map[*driverConn]string),
- connRequests: make(map[uint64]chan connRequest),
- stop: cancel,
+ connector: c,
+ openerCh: make(chan struct{}, connectionRequestQueueSize),
+ lastPut: make(map[*driverConn]string),
+ stop: cancel,
}
go db.connectionOpener(ctx)
}
db.freeConn = nil
db.closed = true
- for _, req := range db.connRequests {
- close(req)
- }
+ db.connRequests.CloseAndRemoveAll()
db.mu.Unlock()
for _, fn := range fns {
err1 := fn()
// If there are connRequests and the connection limit hasn't been reached,
// then tell the connectionOpener to open new connections.
func (db *DB) maybeOpenNewConnections() {
- numRequests := len(db.connRequests)
+ numRequests := db.connRequests.Len()
if db.maxOpen > 0 {
numCanOpen := db.maxOpen - db.numOpen
if numRequests > numCanOpen {
var errDBClosed = errors.New("sql: database is closed")
-// nextRequestKeyLocked returns the next connection request key.
-// It is assumed that nextRequest will not overflow.
-func (db *DB) nextRequestKeyLocked() uint64 {
- next := db.nextRequest
- db.nextRequest++
- return next
-}
-
// conn returns a newly-opened or cached *driverConn.
func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn, error) {
db.mu.Lock()
// Make the connRequest channel. It's buffered so that the
// connectionOpener doesn't block while waiting for the req to be read.
req := make(chan connRequest, 1)
- reqKey := db.nextRequestKeyLocked()
- db.connRequests[reqKey] = req
+ delHandle := db.connRequests.Add(req)
db.waitCount++
db.mu.Unlock()
// Remove the connection request and ensure no value has been sent
// on it after removing.
db.mu.Lock()
- delete(db.connRequests, reqKey)
+ deleted := db.connRequests.Delete(delHandle)
db.mu.Unlock()
db.waitDuration.Add(int64(time.Since(waitStart)))
- select {
- default:
- case ret, ok := <-req:
- if ok && ret.conn != nil {
- db.putConn(ret.conn, ret.err, false)
+ // If we failed to delete it, that means something else
+ // grabbed it and is about to send on it.
+ if !deleted {
+ // TODO(bradfitz): rather than this best-effort select, we
+ // should probably start a goroutine to read from req. This
+ // best-effort select existed before the change to check 'deleted'.
+ // But if we know for sure it wasn't deleted and a sender is
+ // outstanding, we should probably block on req (in a new
+ // goroutine) to get the connection back.
+ select {
+ default:
+ case ret, ok := <-req:
+ if ok && ret.conn != nil {
+ db.putConn(ret.conn, ret.err, false)
+ }
}
}
return nil, ctx.Err()
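// Illustrative sketch of the TODO above (not part of this change): because a
// failed Delete means a sender took the request via TakeRandom and is about
// to send on it, a goroutine could block on req and return any delivered
// connection to the pool, rather than relying on the best-effort select:
//
//	if !deleted {
//		go func() {
//			if ret, ok := <-req; ok && ret.conn != nil {
//				db.putConn(ret.conn, ret.err, false)
//			}
//		}()
//	}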
if db.maxOpen > 0 && db.numOpen > db.maxOpen {
return false
}
- if c := len(db.connRequests); c > 0 {
- var req chan connRequest
- var reqKey uint64
- for reqKey, req = range db.connRequests {
- break
- }
- delete(db.connRequests, reqKey) // Remove from pending requests.
+ if req, ok := db.connRequests.TakeRandom(); ok {
if err == nil {
dc.inUse = true
}
defer lk.Unlock() // in case fn panics
fn()
}
+
+// connRequestSet is a set of chan connRequest that's
+// optimized for:
+//
+// - adding an element
+// - removing an element (only by the caller who added it)
+// - taking (get + delete) a random element
+//
+// We previously used a map for this, but taking a random element
+// was expensive, requiring a map iteration. This type avoids a map
+// entirely and just uses a slice.
+type connRequestSet struct {
+ // s are the elements in the set.
+ s []connRequestAndIndex
+}
+
+type connRequestAndIndex struct {
+ // req is the element in the set.
+ req chan connRequest
+
+ // curIdx points to the current location of this element in
+ // connRequestSet.s. It gets set to -1 upon removal.
+ curIdx *int
+}
+
+// CloseAndRemoveAll closes all channels in the set
+// and clears the set.
+func (s *connRequestSet) CloseAndRemoveAll() {
+ for _, v := range s.s {
+ close(v.req)
+ }
+ s.s = nil
+}
+
+// Len returns the length of the set.
+func (s *connRequestSet) Len() int { return len(s.s) }
+
+// connRequestDelHandle is an opaque handle, returned by Add, that can be
+// used to delete the corresponding item from the set.
+type connRequestDelHandle struct {
+ idx *int // pointer to index; or -1 if not in slice
+}
+
+// Add adds v to the set of waiting requests.
+// The returned connRequestDelHandle can be used to remove the item from
+// the set.
+func (s *connRequestSet) Add(v chan connRequest) connRequestDelHandle {
+ idx := len(s.s)
+ // TODO(bradfitz): for simplicity, this always makes a new int-sized
+ // allocation to store the index. But generally the set will be small and
+ // under a scannable threshold. As an optimization, we could permit the *int
+ // to be nil when the set is small and should be scanned. This works even if
+ // the set grows over the threshold with delete handles outstanding because
+ // an element can only move to a lower index. So if it starts with a nil
+ // position, it'll always be in a low index and thus scannable. But that
+ // can be done in a follow-up change.
+ idxPtr := &idx
+ s.s = append(s.s, connRequestAndIndex{v, idxPtr})
+ return connRequestDelHandle{idxPtr}
+}
+
+// Delete removes an element from the set.
+//
+// It reports whether the element was deleted. (It can return false if a caller
+// of TakeRandom took it meanwhile, or on a second call to Delete.)
+func (s *connRequestSet) Delete(h connRequestDelHandle) bool {
+ idx := *h.idx
+ if idx < 0 {
+ return false
+ }
+ s.deleteIndex(idx)
+ return true
+}
+
+func (s *connRequestSet) deleteIndex(idx int) {
+ // Mark item as deleted.
+ *(s.s[idx].curIdx) = -1
+ // Copy last element, updating its position
+ // to its new home.
+ if idx < len(s.s)-1 {
+ last := s.s[len(s.s)-1]
+ *last.curIdx = idx
+ s.s[idx] = last
+ }
+ // Zero out last element (for GC) before shrinking the slice.
+ s.s[len(s.s)-1] = connRequestAndIndex{}
+ s.s = s.s[:len(s.s)-1]
+}
+
+// TakeRandom returns and removes a random element from s
+// and reports whether there was one to take. (It returns ok=false
+// if the set is empty.)
+func (s *connRequestSet) TakeRandom() (v chan connRequest, ok bool) {
+ if len(s.s) == 0 {
+ return nil, false
+ }
+ pick := rand.IntN(len(s.s))
+ e := s.s[pick]
+ s.deleteIndex(pick)
+ return e.req, true
+}
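// A brief usage sketch (illustrative; the real callers are conn and
// putConnDBLocked above, and TestConnRequestSet below). Note how a successful
// TakeRandom invalidates the waiter's delete handle:
//
//	var s connRequestSet
//	req := make(chan connRequest, 1)
//	h := s.Add(req)                  // a waiting conn call registers itself
//	if c, ok := s.TakeRandom(); ok { // a producer picks a waiter...
//		c <- connRequest{}       // ...and hands off a connection
//	}
//	deleted := s.Delete(h) // false here: TakeRandom already removed the
//	                       // entry, so the waiter must drain req instead
//	_ = deleted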
"math/rand"
"reflect"
"runtime"
+ "slices"
"strings"
"sync"
"sync/atomic"
return
}
db.mu.Lock()
- ct := len(db.connRequests)
+ ct := db.connRequests.Len()
db.mu.Unlock()
if ct > 0 {
return
release(nil)
}
}
+
+func TestConnRequestSet(t *testing.T) {
+ var s connRequestSet
+ wantLen := func(want int) {
+ t.Helper()
+ if got := s.Len(); got != want {
+ t.Errorf("Len = %d; want %d", got, want)
+ }
+ if want == 0 && !t.Failed() {
+ if _, ok := s.TakeRandom(); ok {
+ t.Fatalf("TakeRandom returned result when empty")
+ }
+ }
+ }
+ reset := func() { s = connRequestSet{} }
+
+ t.Run("add-delete", func(t *testing.T) {
+ reset()
+ wantLen(0)
+ dh := s.Add(nil)
+ wantLen(1)
+ if !s.Delete(dh) {
+ t.Fatal("failed to delete")
+ }
+ wantLen(0)
+ if s.Delete(dh) {
+ t.Error("delete worked twice")
+ }
+ wantLen(0)
+ })
+ t.Run("take-before-delete", func(t *testing.T) {
+ reset()
+ ch1 := make(chan connRequest)
+ dh := s.Add(ch1)
+ wantLen(1)
+ if got, ok := s.TakeRandom(); !ok || got != ch1 {
+ t.Fatalf("wrong take; ok=%v", ok)
+ }
+ wantLen(0)
+ if s.Delete(dh) {
+ t.Error("unexpected delete after take")
+ }
+ })
+ t.Run("get-take-many", func(t *testing.T) {
+ reset()
+ m := map[chan connRequest]bool{}
+ const N = 100
+ var inOrder, backOut []chan connRequest
+ for range N {
+ c := make(chan connRequest)
+ m[c] = true
+ s.Add(c)
+ inOrder = append(inOrder, c)
+ }
+ if s.Len() != N {
+ t.Fatalf("Len = %v; want %v", s.Len(), N)
+ }
+ for s.Len() > 0 {
+ c, ok := s.TakeRandom()
+ if !ok {
+ t.Fatal("failed to take when non-empty")
+ }
+ if !m[c] {
+ t.Fatal("returned item not in remaining set")
+ }
+ delete(m, c)
+ backOut = append(backOut, c)
+ }
+ if len(m) > 0 {
+ t.Error("items remain in expected map")
+ }
+ if slices.Equal(inOrder, backOut) { // 1-in-N! chance of flaking; N=100 is fine
+ t.Error("wasn't random")
+ }
+ })
+}
+
+func BenchmarkConnRequestSet(b *testing.B) {
+ var s connRequestSet
+ for range b.N {
+ for range 16 {
+ s.Add(nil)
+ }
+ for range 8 {
+ if _, ok := s.TakeRandom(); !ok {
+ b.Fatal("want ok")
+ }
+ }
+ for range 8 {
+ s.Add(nil)
+ }
+ for range 16 {
+ if _, ok := s.TakeRandom(); !ok {
+ b.Fatal("want ok")
+ }
+ }
+ if _, ok := s.TakeRandom(); ok {
+ b.Fatal("unexpected ok")
+ }
+ }
+}