package sql
import (
+ "container/list"
"database/sql/driver"
"errors"
"fmt"
driver driver.Driver
dsn string
- mu sync.Mutex // protects following fields
- freeConn []*driverConn
+ mu sync.Mutex // protects following fields
+ freeConn *list.List // of *driverConn
+ connRequests *list.List // of connRequest
+ numOpen int
+ pendingOpens int
+ // Used to sygnal the need for new connections
+ // a goroutine running connectionOpener() reads on this chan and
+ // maybeOpenNewConnections sends on the chan (one send per needed connection)
+ // It is closed during db.Close(). The close tells the connectionOpener
+ // goroutine to exit.
+ openerCh chan struct{}
closed bool
dep map[finalCloser]depSet
lastPut map[*driverConn]string // stacktrace of last conn's put; debug only
maxIdle int // zero means defaultMaxIdleConns; negative means 0
+ maxOpen int // <= 0 means unlimited
}
// driverConn wraps a driver.Conn with a mutex, to
inUse bool
onPut []func() // code (with db.mu held) run when conn is next returned
dbmuClosed bool // same as closed, but guarded by db.mu, for connIfFree
+ // This is the Element returned by db.freeConn.PushFront(conn).
+ // It's used by connIfFree to remove the conn from the freeConn list.
+ listElem *list.Element
}
func (dc *driverConn) releaseConn(err error) {
}
// the dc.db's Mutex is held.
-func (dc *driverConn) closeDBLocked() error {
+func (dc *driverConn) closeDBLocked() func() error {
dc.Lock()
+ defer dc.Unlock()
if dc.closed {
- dc.Unlock()
- return errors.New("sql: duplicate driverConn close")
+ return func() error { return errors.New("sql: duplicate driverConn close") }
}
dc.closed = true
- dc.Unlock() // not defer; removeDep finalClose calls may need to lock
- return dc.db.removeDepLocked(dc, dc)()
+ return dc.db.removeDepLocked(dc, dc)
}
func (dc *driverConn) Close() error {
err := dc.ci.Close()
dc.ci = nil
dc.finalClosed = true
-
dc.Unlock()
+
+ dc.db.mu.Lock()
+ dc.db.numOpen--
+ dc.db.maybeOpenNewConnections()
+ dc.db.mu.Unlock()
+
return err
}
}
}
+// This is the size of the connectionOpener request chan (dn.openerCh).
+// This value should be larger than the maximum typical value
+// used for db.maxOpen. If maxOpen is significantly larger than
+// connectionRequestQueueSize then it is possible for ALL calls into the *DB
+// to block until the connectionOpener can satify the backlog of requests.
+var connectionRequestQueueSize = 1000000
+
// Open opens a database specified by its database driver name and a
// driver-specific data source name, usually consisting of at least a
// database name and connection information.
return nil, fmt.Errorf("sql: unknown driver %q (forgotten import?)", driverName)
}
db := &DB{
- driver: driveri,
- dsn: dataSourceName,
- lastPut: make(map[*driverConn]string),
- }
+ driver: driveri,
+ dsn: dataSourceName,
+ openerCh: make(chan struct{}, connectionRequestQueueSize),
+ lastPut: make(map[*driverConn]string),
+ }
+ db.freeConn = list.New()
+ db.connRequests = list.New()
+ go db.connectionOpener()
return db, nil
}
// Close closes the database, releasing any open resources.
func (db *DB) Close() error {
db.mu.Lock()
- defer db.mu.Unlock()
+ if db.closed { // Make DB.Close idempotent
+ db.mu.Unlock()
+ return nil
+ }
+ close(db.openerCh)
var err error
- for _, dc := range db.freeConn {
- err1 := dc.closeDBLocked()
+ fns := make([]func() error, 0, db.freeConn.Len())
+ for db.freeConn.Front() != nil {
+ dc := db.freeConn.Front().Value.(*driverConn)
+ dc.listElem = nil
+ fns = append(fns, dc.closeDBLocked())
+ db.freeConn.Remove(db.freeConn.Front())
+ }
+ db.closed = true
+ for db.connRequests.Front() != nil {
+ req := db.connRequests.Front().Value.(connRequest)
+ db.connRequests.Remove(db.connRequests.Front())
+ close(req)
+ }
+ db.mu.Unlock()
+ for _, fn := range fns {
+ err1 := fn()
if err1 != nil {
err = err1
}
}
- db.freeConn = nil
- db.closed = true
return err
}
// SetMaxIdleConns sets the maximum number of connections in the idle
// connection pool.
//
+// If MaxOpenConns is greater than 0 but less than the new MaxIdleConns
+// then the new MaxIdleConns will be reduced to match the MaxOpenConns limit
+//
// If n <= 0, no idle connections are retained.
func (db *DB) SetMaxIdleConns(n int) {
db.mu.Lock()
// No idle connections.
db.maxIdle = -1
}
- for len(db.freeConn) > 0 && len(db.freeConn) > n {
- nfree := len(db.freeConn)
- dc := db.freeConn[nfree-1]
- db.freeConn[nfree-1] = nil
- db.freeConn = db.freeConn[:nfree-1]
+ // Make sure maxIdle doesn't exceed maxOpen
+ if db.maxOpen > 0 && db.maxIdleConnsLocked() > db.maxOpen {
+ db.maxIdle = db.maxOpen
+ }
+ for db.freeConn.Len() > db.maxIdleConnsLocked() {
+ dc := db.freeConn.Back().Value.(*driverConn)
+ dc.listElem = nil
+ db.freeConn.Remove(db.freeConn.Back())
go dc.Close()
}
}
+// SetMaxOpenConns sets the maximum number of open connections to the database.
+//
+// If MaxIdleConns is greater than 0 and the new MaxOpenConns is less than
+// MaxIdleConns, then MaxIdleConns will be reduced to match the new
+// MaxOpenConns limit
+//
+// If n <= 0, then there is no limit on the number of open connections.
+// The default is 0 (unlimited).
+func (db *DB) SetMaxOpenConns(n int) {
+ db.mu.Lock()
+ db.maxOpen = n
+ if n < 0 {
+ db.maxOpen = 0
+ }
+ syncMaxIdle := db.maxOpen > 0 && db.maxIdleConnsLocked() > db.maxOpen
+ db.mu.Unlock()
+ if syncMaxIdle {
+ db.SetMaxIdleConns(n)
+ }
+}
+
+// Assumes db.mu is locked.
+// If there are connRequests and the connection limit hasn't been reached,
+// then tell the connectionOpener to open new connections.
+func (db *DB) maybeOpenNewConnections() {
+ numRequests := db.connRequests.Len() - db.pendingOpens
+ if db.maxOpen > 0 {
+ numCanOpen := db.maxOpen - (db.numOpen + db.pendingOpens)
+ if numRequests > numCanOpen {
+ numRequests = numCanOpen
+ }
+ }
+ for numRequests > 0 {
+ db.pendingOpens++
+ numRequests--
+ db.openerCh <- struct{}{}
+ }
+}
+
+// Runs in a seperate goroutine, opens new connections when requested.
+func (db *DB) connectionOpener() {
+ for _ = range db.openerCh {
+ db.openNewConnection()
+ }
+}
+
+// Open one new connection
+func (db *DB) openNewConnection() {
+ ci, err := db.driver.Open(db.dsn)
+ db.mu.Lock()
+ defer db.mu.Unlock()
+ if db.closed {
+ if err == nil {
+ ci.Close()
+ }
+ return
+ }
+ db.pendingOpens--
+ if err != nil {
+ db.putConnDBLocked(nil, err)
+ return
+ }
+ dc := &driverConn{
+ db: db,
+ ci: ci,
+ }
+ db.addDepLocked(dc, dc)
+ db.numOpen++
+ db.putConnDBLocked(dc, err)
+}
+
+// connRequest represents one request for a new connection
+// When there are no idle connections available, DB.conn will create
+// a new connRequest and put it on the db.connRequests list.
+type connRequest chan<- interface{} // takes either a *driverConn or an error
+
+var errDBClosed = errors.New("sql: database is closed")
+
// conn returns a newly-opened or cached *driverConn
func (db *DB) conn() (*driverConn, error) {
db.mu.Lock()
if db.closed {
db.mu.Unlock()
- return nil, errors.New("sql: database is closed")
+ return nil, errDBClosed
+ }
+
+ // If db.maxOpen > 0 and the number of open connections is over the limit
+ // or there are no free connection, then make a request and wait.
+ if db.maxOpen > 0 && (db.numOpen >= db.maxOpen || db.freeConn.Len() == 0) {
+ // Make the connRequest channel. It's buffered so that the
+ // connectionOpener doesn't block while waiting for the req to be read.
+ ch := make(chan interface{}, 1)
+ req := connRequest(ch)
+ db.connRequests.PushBack(req)
+ db.maybeOpenNewConnections()
+ db.mu.Unlock()
+ ret, ok := <-ch
+ if !ok {
+ return nil, errDBClosed
+ }
+ switch ret.(type) {
+ case *driverConn:
+ return ret.(*driverConn), nil
+ case error:
+ return nil, ret.(error)
+ default:
+ panic("sql: Unexpected type passed through connRequest.ch")
+ }
}
- if n := len(db.freeConn); n > 0 {
- conn := db.freeConn[n-1]
- db.freeConn = db.freeConn[:n-1]
+
+ if f := db.freeConn.Front(); f != nil {
+ conn := f.Value.(*driverConn)
+ conn.listElem = nil
+ db.freeConn.Remove(f)
conn.inUse = true
db.mu.Unlock()
return conn, nil
}
- db.mu.Unlock()
+ db.mu.Unlock()
ci, err := db.driver.Open(db.dsn)
if err != nil {
return nil, err
}
+ db.mu.Lock()
+ db.numOpen++
dc := &driverConn{
db: db,
ci: ci,
}
- db.mu.Lock()
db.addDepLocked(dc, dc)
dc.inUse = true
db.mu.Unlock()
if wanted.inUse {
return nil, errConnBusy
}
- for i, conn := range db.freeConn {
- if conn != wanted {
- continue
- }
- db.freeConn[i] = db.freeConn[len(db.freeConn)-1]
- db.freeConn = db.freeConn[:len(db.freeConn)-1]
+ if wanted.listElem != nil {
+ db.freeConn.Remove(wanted.listElem)
+ wanted.listElem = nil
wanted.inUse = true
return wanted, nil
}
if err == driver.ErrBadConn {
// Don't reuse bad connections.
+ // Since the conn is considered bad and is being discarded, treat it
+ // as closed. Decrement the open count.
+ db.numOpen--
+ db.maybeOpenNewConnections()
db.mu.Unlock()
dc.Close()
return
if putConnHook != nil {
putConnHook(db, dc)
}
- if n := len(db.freeConn); !db.closed && n < db.maxIdleConnsLocked() {
- db.freeConn = append(db.freeConn, dc)
- db.mu.Unlock()
- return
- }
+ added := db.putConnDBLocked(dc, nil)
db.mu.Unlock()
- dc.Close()
+ if !added {
+ dc.Close()
+ }
+}
+
+// Satisfy a connRequest or put the driverConn in the idle pool and return true
+// or return false.
+// putConnDBLocked will satisfy a connRequest if there is one, or it will
+// return the *driverConn to the freeConn list if err != nil and the idle
+// connection limit would not be reached.
+// If err != nil, the value of dc is ignored.
+// If err == nil, then dc must not equal nil.
+// If a connRequest was fullfilled or the *driverConn was placed in the
+// freeConn list, then true is returned, otherwise false is returned.
+func (db *DB) putConnDBLocked(dc *driverConn, err error) bool {
+ if db.connRequests.Len() > 0 {
+ req := db.connRequests.Front().Value.(connRequest)
+ db.connRequests.Remove(db.connRequests.Front())
+ if err != nil {
+ req <- err
+ } else {
+ dc.inUse = true
+ req <- dc
+ }
+ return true
+ } else if err == nil && !db.closed && db.maxIdleConnsLocked() > 0 && db.maxIdleConnsLocked() > db.freeConn.Len() {
+ dc.listElem = db.freeConn.PushFront(dc)
+ return true
+ }
+ return false
}
// Prepare creates a prepared statement for later queries or executions.
return s.stickyErr
}
s.mu.Lock()
- defer s.mu.Unlock()
if s.closed {
+ s.mu.Unlock()
return nil
}
s.closed = true
if s.tx != nil {
s.txsi.Close()
+ s.mu.Unlock()
return nil
}
+ s.mu.Unlock()
return s.db.removeDep(s, s)
}
func (s *Stmt) finalClose() error {
- for _, v := range s.css {
- s.db.noteUnusedDriverStatement(v.dc, v.si)
- v.dc.removeOpenStmt(v.si)
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ if s.css != nil {
+ for _, v := range s.css {
+ s.db.noteUnusedDriverStatement(v.dc, v.si)
+ v.dc.removeOpenStmt(v.si)
+ }
+ s.css = nil
}
- s.css = nil
return nil
}
"database/sql/driver"
"errors"
"fmt"
+ "math/rand"
"reflect"
"runtime"
"strings"
}
freedFrom := make(map[dbConn]string)
putConnHook = func(db *DB, c *driverConn) {
- for _, oc := range db.freeConn {
- if oc == c {
- // print before panic, as panic may get lost due to conflicting panic
- // (all goroutines asleep) elsewhere, since we might not unlock
- // the mutex in freeConn here.
- println("double free of conn. conflicts are:\nA) " + freedFrom[dbConn{db, c}] + "\n\nand\nB) " + stack())
- panic("double free of conn.")
- }
+ if c.listElem != nil {
+ // print before panic, as panic may get lost due to conflicting panic
+ // (all goroutines asleep) elsewhere, since we might not unlock
+ // the mutex in freeConn here.
+ println("double free of conn. conflicts are:\nA) " + freedFrom[dbConn{db, c}] + "\n\nand\nB) " + stack())
+ panic("double free of conn.")
}
freedFrom[dbConn{db, c}] = stack()
}
t.Errorf("Error closing fakeConn: %v", err)
}
})
- for i, dc := range db.freeConn {
+ for node, i := db.freeConn.Front(), 0; node != nil; node, i = node.Next(), i+1 {
+ dc := node.Value.(*driverConn)
if n := len(dc.openStmt); n > 0 {
// Just a sanity check. This is legal in
// general, but if we make the tests clean up
// their statements first, then we can safely
// verify this is always zero here, and any
// other value is a leak.
- t.Errorf("while closing db, freeConn %d/%d had %d open stmts; want 0", i, len(db.freeConn), n)
+ t.Errorf("while closing db, freeConn %d/%d had %d open stmts; want 0", i, db.freeConn.Len(), n)
}
}
err := db.Close()
// numPrepares assumes that db has exactly 1 idle conn and returns
// its count of calls to Prepare
func numPrepares(t *testing.T, db *DB) int {
- if n := len(db.freeConn); n != 1 {
+ if n := db.freeConn.Len(); n != 1 {
t.Fatalf("free conns = %d; want 1", n)
}
- return db.freeConn[0].ci.(*fakeConn).numPrepare
+ return (db.freeConn.Front().Value.(*driverConn)).ci.(*fakeConn).numPrepare
}
func (db *DB) numDeps() int {
func (db *DB) numFreeConns() int {
db.mu.Lock()
defer db.mu.Unlock()
- return len(db.freeConn)
+ return db.freeConn.Len()
}
func (db *DB) dumpDeps(t *testing.T) {
if err != nil {
t.Fatal(err)
}
- if len(db.freeConn) != 1 {
+ if db.freeConn.Len() != 1 {
t.Fatalf("expected 1 free conn")
}
- fakeConn := db.freeConn[0].ci.(*fakeConn)
+ fakeConn := (db.freeConn.Front().Value.(*driverConn)).ci.(*fakeConn)
if made, closed := fakeConn.stmtsMade, fakeConn.stmtsClosed; made != closed {
t.Errorf("statement close mismatch: made %d, closed %d", made, closed)
}
t.Fatal(err)
}
tx.Commit()
- if got := len(db.freeConn); got != 1 {
+ if got := db.freeConn.Len(); got != 1 {
t.Errorf("freeConns = %d; want 1", got)
}
db.SetMaxIdleConns(0)
- if got := len(db.freeConn); got != 0 {
+ if got := db.freeConn.Len(); got != 0 {
t.Errorf("freeConns after set to zero = %d; want 0", got)
}
t.Fatal(err)
}
tx.Commit()
- if got := len(db.freeConn); got != 0 {
+ if got := db.freeConn.Len(); got != 0 {
t.Errorf("freeConns = %d; want 0", got)
}
}
+func TestMaxOpenConns(t *testing.T) {
+ if testing.Short() {
+ t.Skip("skipping in short mode")
+ }
+ defer setHookpostCloseConn(nil)
+ setHookpostCloseConn(func(_ *fakeConn, err error) {
+ if err != nil {
+ t.Errorf("Error closing fakeConn: %v", err)
+ }
+ })
+
+ db := newTestDB(t, "magicquery")
+ defer closeDB(t, db)
+
+ driver := db.driver.(*fakeDriver)
+
+ // Force the number of open connections to 0 so we can get an accurate
+ // count for the test
+ db.SetMaxIdleConns(0)
+
+ if g, w := db.numFreeConns(), 0; g != w {
+ t.Errorf("free conns = %d; want %d", g, w)
+ }
+
+ if n := db.numDepsPollUntil(0, time.Second); n > 0 {
+ t.Errorf("number of dependencies = %d; expected 0", n)
+ db.dumpDeps(t)
+ }
+
+ driver.mu.Lock()
+ opens0 := driver.openCount
+ closes0 := driver.closeCount
+ driver.mu.Unlock()
+
+ db.SetMaxIdleConns(10)
+ db.SetMaxOpenConns(10)
+
+ stmt, err := db.Prepare("SELECT|magicquery|op|op=?,millis=?")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ // Start 50 parallel slow queries.
+ const (
+ nquery = 50
+ sleepMillis = 25
+ nbatch = 2
+ )
+ var wg sync.WaitGroup
+ for batch := 0; batch < nbatch; batch++ {
+ for i := 0; i < nquery; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ var op string
+ if err := stmt.QueryRow("sleep", sleepMillis).Scan(&op); err != nil && err != ErrNoRows {
+ t.Error(err)
+ }
+ }()
+ }
+ // Sleep for twice the expected length of time for the
+ // batch of 50 queries above to finish before starting
+ // the next round.
+ time.Sleep(2 * sleepMillis * time.Millisecond)
+ }
+ wg.Wait()
+
+ if g, w := db.numFreeConns(), 10; g != w {
+ t.Errorf("free conns = %d; want %d", g, w)
+ }
+
+ if n := db.numDepsPollUntil(20, time.Second); n > 20 {
+ t.Errorf("number of dependencies = %d; expected <= 20", n)
+ db.dumpDeps(t)
+ }
+
+ driver.mu.Lock()
+ opens := driver.openCount - opens0
+ closes := driver.closeCount - closes0
+ driver.mu.Unlock()
+
+ if opens > 10 {
+ t.Logf("open calls = %d", opens)
+ t.Logf("close calls = %d", closes)
+ t.Errorf("db connections opened = %d; want <= 10", opens)
+ db.dumpDeps(t)
+ }
+
+ if err := stmt.Close(); err != nil {
+ t.Fatal(err)
+ }
+
+ if g, w := db.numFreeConns(), 10; g != w {
+ t.Errorf("free conns = %d; want %d", g, w)
+ }
+
+ if n := db.numDepsPollUntil(10, time.Second); n > 10 {
+ t.Errorf("number of dependencies = %d; expected <= 10", n)
+ db.dumpDeps(t)
+ }
+
+ db.SetMaxOpenConns(5)
+
+ if g, w := db.numFreeConns(), 5; g != w {
+ t.Errorf("free conns = %d; want %d", g, w)
+ }
+
+ if n := db.numDepsPollUntil(5, time.Second); n > 5 {
+ t.Errorf("number of dependencies = %d; expected 0", n)
+ db.dumpDeps(t)
+ }
+
+ db.SetMaxOpenConns(0)
+
+ if g, w := db.numFreeConns(), 5; g != w {
+ t.Errorf("free conns = %d; want %d", g, w)
+ }
+
+ if n := db.numDepsPollUntil(5, time.Second); n > 5 {
+ t.Errorf("number of dependencies = %d; expected 0", n)
+ db.dumpDeps(t)
+ }
+
+ db.SetMaxIdleConns(0)
+
+ if g, w := db.numFreeConns(), 0; g != w {
+ t.Errorf("free conns = %d; want %d", g, w)
+ }
+
+ if n := db.numDepsPollUntil(0, time.Second); n > 0 {
+ t.Errorf("number of dependencies = %d; expected 0", n)
+ db.dumpDeps(t)
+ }
+}
+
// golang.org/issue/5323
func TestStmtCloseDeps(t *testing.T) {
if testing.Short() {
driver.mu.Lock()
opens := driver.openCount - opens0
closes := driver.closeCount - closes0
- driver.mu.Unlock()
openDelta := (driver.openCount - driver.closeCount) - openDelta0
+ driver.mu.Unlock()
if openDelta > 2 {
t.Logf("open calls = %d", opens)
t.Fatal(err)
}
- if len(db.freeConn) != 1 {
- t.Fatalf("expected 1 freeConn; got %d", len(db.freeConn))
+ if db.freeConn.Len() != 1 {
+ t.Fatalf("expected 1 freeConn; got %d", db.freeConn.Len())
}
- dc := db.freeConn[0]
+ dc := db.freeConn.Front().Value.(*driverConn)
if dc.closed {
t.Errorf("conn shouldn't be closed")
}
}
}
+type concurrentTest interface {
+ init(t testing.TB, db *DB)
+ finish(t testing.TB)
+ test(t testing.TB) error
+}
+
+type concurrentDBQueryTest struct {
+ db *DB
+}
+
+func (c *concurrentDBQueryTest) init(t testing.TB, db *DB) {
+ c.db = db
+}
+
+func (c *concurrentDBQueryTest) finish(t testing.TB) {
+ c.db = nil
+}
+
+func (c *concurrentDBQueryTest) test(t testing.TB) error {
+ rows, err := c.db.Query("SELECT|people|name|")
+ if err != nil {
+ t.Error(err)
+ return err
+ }
+ var name string
+ for rows.Next() {
+ rows.Scan(&name)
+ }
+ rows.Close()
+ return nil
+}
+
+type concurrentDBExecTest struct {
+ db *DB
+}
+
+func (c *concurrentDBExecTest) init(t testing.TB, db *DB) {
+ c.db = db
+}
+
+func (c *concurrentDBExecTest) finish(t testing.TB) {
+ c.db = nil
+}
+
+func (c *concurrentDBExecTest) test(t testing.TB) error {
+ _, err := c.db.Exec("NOSERT|people|name=Chris,age=?,photo=CPHOTO,bdate=?", 3, chrisBirthday)
+ if err != nil {
+ t.Error(err)
+ return err
+ }
+ return nil
+}
+
+type concurrentStmtQueryTest struct {
+ db *DB
+ stmt *Stmt
+}
+
+func (c *concurrentStmtQueryTest) init(t testing.TB, db *DB) {
+ c.db = db
+ var err error
+ c.stmt, err = db.Prepare("SELECT|people|name|")
+ if err != nil {
+ t.Fatal(err)
+ }
+}
+
+func (c *concurrentStmtQueryTest) finish(t testing.TB) {
+ if c.stmt != nil {
+ c.stmt.Close()
+ c.stmt = nil
+ }
+ c.db = nil
+}
+
+func (c *concurrentStmtQueryTest) test(t testing.TB) error {
+ rows, err := c.stmt.Query()
+ if err != nil {
+ t.Errorf("error on query: %v", err)
+ return err
+ }
+
+ var name string
+ for rows.Next() {
+ rows.Scan(&name)
+ }
+ rows.Close()
+ return nil
+}
+
+type concurrentStmtExecTest struct {
+ db *DB
+ stmt *Stmt
+}
+
+func (c *concurrentStmtExecTest) init(t testing.TB, db *DB) {
+ c.db = db
+ var err error
+ c.stmt, err = db.Prepare("NOSERT|people|name=Chris,age=?,photo=CPHOTO,bdate=?")
+ if err != nil {
+ t.Fatal(err)
+ }
+}
+
+func (c *concurrentStmtExecTest) finish(t testing.TB) {
+ if c.stmt != nil {
+ c.stmt.Close()
+ c.stmt = nil
+ }
+ c.db = nil
+}
+
+func (c *concurrentStmtExecTest) test(t testing.TB) error {
+ _, err := c.stmt.Exec(3, chrisBirthday)
+ if err != nil {
+ t.Errorf("error on exec: %v", err)
+ return err
+ }
+ return nil
+}
+
+type concurrentTxQueryTest struct {
+ db *DB
+ tx *Tx
+}
+
+func (c *concurrentTxQueryTest) init(t testing.TB, db *DB) {
+ c.db = db
+ var err error
+ c.tx, err = c.db.Begin()
+ if err != nil {
+ t.Fatal(err)
+ }
+}
+
+func (c *concurrentTxQueryTest) finish(t testing.TB) {
+ if c.tx != nil {
+ c.tx.Rollback()
+ c.tx = nil
+ }
+ c.db = nil
+}
+
+func (c *concurrentTxQueryTest) test(t testing.TB) error {
+ rows, err := c.db.Query("SELECT|people|name|")
+ if err != nil {
+ t.Error(err)
+ return err
+ }
+ var name string
+ for rows.Next() {
+ rows.Scan(&name)
+ }
+ rows.Close()
+ return nil
+}
+
+type concurrentTxExecTest struct {
+ db *DB
+ tx *Tx
+}
+
+func (c *concurrentTxExecTest) init(t testing.TB, db *DB) {
+ c.db = db
+ var err error
+ c.tx, err = c.db.Begin()
+ if err != nil {
+ t.Fatal(err)
+ }
+}
+
+func (c *concurrentTxExecTest) finish(t testing.TB) {
+ if c.tx != nil {
+ c.tx.Rollback()
+ c.tx = nil
+ }
+ c.db = nil
+}
+
+func (c *concurrentTxExecTest) test(t testing.TB) error {
+ _, err := c.tx.Exec("NOSERT|people|name=Chris,age=?,photo=CPHOTO,bdate=?", 3, chrisBirthday)
+ if err != nil {
+ t.Error(err)
+ return err
+ }
+ return nil
+}
+
+type concurrentTxStmtQueryTest struct {
+ db *DB
+ tx *Tx
+ stmt *Stmt
+}
+
+func (c *concurrentTxStmtQueryTest) init(t testing.TB, db *DB) {
+ c.db = db
+ var err error
+ c.tx, err = c.db.Begin()
+ if err != nil {
+ t.Fatal(err)
+ }
+ c.stmt, err = c.tx.Prepare("SELECT|people|name|")
+ if err != nil {
+ t.Fatal(err)
+ }
+}
+
+func (c *concurrentTxStmtQueryTest) finish(t testing.TB) {
+ if c.stmt != nil {
+ c.stmt.Close()
+ c.stmt = nil
+ }
+ if c.tx != nil {
+ c.tx.Rollback()
+ c.tx = nil
+ }
+ c.db = nil
+}
+
+func (c *concurrentTxStmtQueryTest) test(t testing.TB) error {
+ rows, err := c.stmt.Query()
+ if err != nil {
+ t.Errorf("error on query: %v", err)
+ return err
+ }
+
+ var name string
+ for rows.Next() {
+ rows.Scan(&name)
+ }
+ rows.Close()
+ return nil
+}
+
+type concurrentTxStmtExecTest struct {
+ db *DB
+ tx *Tx
+ stmt *Stmt
+}
+
+func (c *concurrentTxStmtExecTest) init(t testing.TB, db *DB) {
+ c.db = db
+ var err error
+ c.tx, err = c.db.Begin()
+ if err != nil {
+ t.Fatal(err)
+ }
+ c.stmt, err = c.tx.Prepare("NOSERT|people|name=Chris,age=?,photo=CPHOTO,bdate=?")
+ if err != nil {
+ t.Fatal(err)
+ }
+}
+
+func (c *concurrentTxStmtExecTest) finish(t testing.TB) {
+ if c.stmt != nil {
+ c.stmt.Close()
+ c.stmt = nil
+ }
+ if c.tx != nil {
+ c.tx.Rollback()
+ c.tx = nil
+ }
+ c.db = nil
+}
+
+func (c *concurrentTxStmtExecTest) test(t testing.TB) error {
+ _, err := c.stmt.Exec(3, chrisBirthday)
+ if err != nil {
+ t.Errorf("error on exec: %v", err)
+ return err
+ }
+ return nil
+}
+
+type concurrentRandomTest struct {
+ tests []concurrentTest
+}
+
+func (c *concurrentRandomTest) init(t testing.TB, db *DB) {
+ c.tests = []concurrentTest{
+ new(concurrentDBQueryTest),
+ new(concurrentDBExecTest),
+ new(concurrentStmtQueryTest),
+ new(concurrentStmtExecTest),
+ new(concurrentTxQueryTest),
+ new(concurrentTxExecTest),
+ new(concurrentTxStmtQueryTest),
+ new(concurrentTxStmtExecTest),
+ }
+ for _, ct := range c.tests {
+ ct.init(t, db)
+ }
+}
+
+func (c *concurrentRandomTest) finish(t testing.TB) {
+ for _, ct := range c.tests {
+ ct.finish(t)
+ }
+}
+
+func (c *concurrentRandomTest) test(t testing.TB) error {
+ ct := c.tests[rand.Intn(len(c.tests))]
+ return ct.test(t)
+}
+
+func doConcurrentTest(t testing.TB, ct concurrentTest) {
+ maxProcs, numReqs := 1, 500
+ if testing.Short() {
+ maxProcs, numReqs = 4, 50
+ }
+ defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(maxProcs))
+
+ db := newTestDB(t, "people")
+ defer closeDB(t, db)
+
+ ct.init(t, db)
+ defer ct.finish(t)
+
+ var wg sync.WaitGroup
+ wg.Add(numReqs)
+
+ reqs := make(chan bool)
+ defer close(reqs)
+
+ for i := 0; i < maxProcs*2; i++ {
+ go func() {
+ for _ = range reqs {
+ err := ct.test(t)
+ if err != nil {
+ wg.Done()
+ continue
+ }
+ wg.Done()
+ }
+ }()
+ }
+
+ for i := 0; i < numReqs; i++ {
+ reqs <- true
+ }
+
+ wg.Wait()
+}
+
func manyConcurrentQueries(t testing.TB) {
maxProcs, numReqs := 16, 500
if testing.Short() {
}
func TestConcurrency(t *testing.T) {
- manyConcurrentQueries(t)
+ doConcurrentTest(t, new(concurrentDBQueryTest))
+ doConcurrentTest(t, new(concurrentDBExecTest))
+ doConcurrentTest(t, new(concurrentStmtQueryTest))
+ doConcurrentTest(t, new(concurrentStmtExecTest))
+ doConcurrentTest(t, new(concurrentTxQueryTest))
+ doConcurrentTest(t, new(concurrentTxExecTest))
+ doConcurrentTest(t, new(concurrentTxStmtQueryTest))
+ doConcurrentTest(t, new(concurrentTxStmtExecTest))
+ doConcurrentTest(t, new(concurrentRandomTest))
+}
+
+func BenchmarkConcurrentDBExec(b *testing.B) {
+ b.ReportAllocs()
+ ct := new(concurrentDBExecTest)
+ for i := 0; i < b.N; i++ {
+ doConcurrentTest(b, ct)
+ }
+}
+
+func BenchmarkConcurrentStmtQuery(b *testing.B) {
+ b.ReportAllocs()
+ ct := new(concurrentStmtQueryTest)
+ for i := 0; i < b.N; i++ {
+ doConcurrentTest(b, ct)
+ }
+}
+
+func BenchmarkConcurrentStmtExec(b *testing.B) {
+ b.ReportAllocs()
+ ct := new(concurrentStmtExecTest)
+ for i := 0; i < b.N; i++ {
+ doConcurrentTest(b, ct)
+ }
+}
+
+func BenchmarkConcurrentTxQuery(b *testing.B) {
+ b.ReportAllocs()
+ ct := new(concurrentTxQueryTest)
+ for i := 0; i < b.N; i++ {
+ doConcurrentTest(b, ct)
+ }
+}
+
+func BenchmarkConcurrentTxExec(b *testing.B) {
+ b.ReportAllocs()
+ ct := new(concurrentTxExecTest)
+ for i := 0; i < b.N; i++ {
+ doConcurrentTest(b, ct)
+ }
+}
+
+func BenchmarkConcurrentTxStmtQuery(b *testing.B) {
+ b.ReportAllocs()
+ ct := new(concurrentTxStmtQueryTest)
+ for i := 0; i < b.N; i++ {
+ doConcurrentTest(b, ct)
+ }
+}
+
+func BenchmarkConcurrentTxStmtExec(b *testing.B) {
+ b.ReportAllocs()
+ ct := new(concurrentTxStmtExecTest)
+ for i := 0; i < b.N; i++ {
+ doConcurrentTest(b, ct)
+ }
}
-func BenchmarkConcurrency(b *testing.B) {
+func BenchmarkConcurrentRandom(b *testing.B) {
b.ReportAllocs()
+ ct := new(concurrentRandomTest)
for i := 0; i < b.N; i++ {
- manyConcurrentQueries(b)
+ doConcurrentTest(b, ct)
}
}