mu sync.Mutex // protects following fields
freeConn []*driverConn
- connRequests []chan connRequest
- numOpen int // number of opened and pending open connections
+ connRequests map[uint64]chan connRequest
+ nextRequest uint64 // Next key to use in connRequests.
+ numOpen int // number of opened and pending open connections
// Used to signal the need for new connections
// a goroutine running connectionOpener() reads on this chan and
// maybeOpenNewConnections sends on the chan (one send per needed connection)
return nil, fmt.Errorf("sql: unknown driver %q (forgotten import?)", driverName)
}
db := &DB{
- driver: driveri,
- dsn: dataSourceName,
- openerCh: make(chan struct{}, connectionRequestQueueSize),
- lastPut: make(map[*driverConn]string),
+ driver: driveri,
+ dsn: dataSourceName,
+ openerCh: make(chan struct{}, connectionRequestQueueSize),
+ lastPut: make(map[*driverConn]string),
+ connRequests: make(map[uint64]chan connRequest),
}
go db.connectionOpener()
return db, nil
var errDBClosed = errors.New("sql: database is closed")
+// nextRequestKeyLocked returns the next connection request key.
+// It is assumed that nextRequest will not overflow.
+func (db *DB) nextRequestKeyLocked() uint64 {
+ next := db.nextRequest
+ db.nextRequest++
+ return next
+}
+
// conn returns a newly-opened or cached *driverConn.
func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn, error) {
db.mu.Lock()
// Make the connRequest channel. It's buffered so that the
// connectionOpener doesn't block while waiting for the req to be read.
req := make(chan connRequest, 1)
- db.connRequests = append(db.connRequests, req)
+ reqKey := db.nextRequestKeyLocked()
+ db.connRequests[reqKey] = req
db.mu.Unlock()
// Timeout the connection request with the context.
select {
case <-ctx.Done():
+ // Remove the connection request and ensure no value has been sent
+ // on it after removing.
+ db.mu.Lock()
+ delete(db.connRequests, reqKey)
+ select {
+ default:
+ case ret, ok := <-req:
+ if ok {
+ db.putConnDBLocked(ret.conn, ret.err)
+ }
+ }
+ db.mu.Unlock()
return nil, ctx.Err()
case ret, ok := <-req:
if !ok {
return false
}
if c := len(db.connRequests); c > 0 {
- req := db.connRequests[0]
- // This copy is O(n) but in practice faster than a linked list.
- // TODO: consider compacting it down less often and
- // moving the base instead?
- copy(db.connRequests, db.connRequests[1:])
- db.connRequests = db.connRequests[:c-1]
+ var req chan connRequest
+ var reqKey uint64
+ for reqKey, req = range db.connRequests {
+ break
+ }
+ delete(db.connRequests, reqKey) // Remove from pending requests.
if err == nil {
dc.inUse = true
}
}
}
+func TestPoolExhaustOnCancel(t *testing.T) {
+ if testing.Short() {
+ t.Skip("long test")
+ }
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+
+ max := 3
+
+ db.SetMaxOpenConns(max)
+
+ // First saturate the connection pool.
+ // Then start new requests for a connection that is cancelled after it is requested.
+
+ var saturate, saturateDone sync.WaitGroup
+ saturate.Add(max)
+ saturateDone.Add(max)
+
+ for i := 0; i < max; i++ {
+ go func() {
+ saturate.Done()
+ rows, err := db.Query("WAIT|500ms|SELECT|people|name,photo|")
+ if err != nil {
+ t.Fatalf("Query: %v", err)
+ }
+ rows.Close()
+ saturateDone.Done()
+ }()
+ }
+
+ saturate.Wait()
+
+ // Now cancel the request while it is waiting.
+ ctx, cancel := context.WithTimeout(context.Background(), time.Second*2)
+ defer cancel()
+
+ for i := 0; i < max; i++ {
+ ctxReq, cancelReq := context.WithCancel(ctx)
+ go func() {
+ time.Sleep(time.Millisecond * 100)
+ cancelReq()
+ }()
+ err := db.PingContext(ctxReq)
+ if err != context.Canceled {
+ t.Fatalf("PingContext (Exhaust): %v", err)
+ }
+ }
+
+ saturateDone.Wait()
+
+ // Now try to open a normal connection.
+ err := db.PingContext(ctx)
+ if err != nil {
+ t.Fatalf("PingContext (Normal): %v", err)
+ }
+}
+
func TestByteOwnership(t *testing.T) {
db := newTestDB(t, "people")
defer closeDB(t, db)