]> Cypherpunks repositories - gostls13.git/commitdiff
database/sql: refcounting and lifetime fixes
authorBrad Fitzpatrick <bradfitz@golang.org>
Wed, 20 Feb 2013 23:35:27 +0000 (15:35 -0800)
committerBrad Fitzpatrick <bradfitz@golang.org>
Wed, 20 Feb 2013 23:35:27 +0000 (15:35 -0800)
Simplifies the contract for Driver.Stmt.Close in
the process of fixing issue 3865.

Fixes #3865
Update #4459 (maybe fixes it; uninvestigated)

R=golang-dev, rsc
CC=golang-dev
https://golang.org/cl/7363043

src/pkg/database/sql/driver/driver.go
src/pkg/database/sql/sql.go
src/pkg/database/sql/sql_test.go

index 88c87eeea080cd3544942b519ad923d515b0285c..2434e419baa106e56a2bc589a5f816cde463ab51 100644 (file)
@@ -115,23 +115,8 @@ type Result interface {
 type Stmt interface {
        // Close closes the statement.
        //
-       // Closing a statement should not interrupt any outstanding
-       // query created from that statement. That is, the following
-       // order of operations is valid:
-       //
-       //  * create a driver statement
-       //  * call Query on statement, returning Rows
-       //  * close the statement
-       //  * read from Rows
-       //
-       // If closing a statement invalidates currently-running
-       // queries, the final step above will incorrectly fail.
-       //
-       // TODO(bradfitz): possibly remove the restriction above, if
-       // enough driver authors object and find it complicates their
-       // code too much. The sql package could be smarter about
-       // refcounting the statement and closing it at the appropriate
-       // time.
+       // As of Go 1.1, a Stmt will not be closed if it's in use
+       // by any queries.
        Close() error
 
        // NumInput returns the number of placeholder parameters.
index 376390aa7197d2010a27c0df4fa1924ae5fed942..28339921091aa2420e1a27f06d44e64e2e8cd60e 100644 (file)
@@ -11,6 +11,7 @@ import (
        "errors"
        "fmt"
        "io"
+       "runtime"
        "sync"
 )
 
@@ -189,9 +190,66 @@ type DB struct {
        driver driver.Driver
        dsn    string
 
-       mu       sync.Mutex // protects freeConn and closed
-       freeConn []driver.Conn
-       closed   bool
+       mu        sync.Mutex           // protects following fields
+       outConn   map[driver.Conn]bool // whether the conn is in use
+       freeConn  []driver.Conn
+       closed    bool
+       dep       map[finalCloser]depSet
+       onConnPut map[driver.Conn][]func() // code (with mu held) run when conn is next returned
+       lastPut   map[driver.Conn]string   // stacktrace of last conn's put; debug only
+}
+
+// depSet is a finalCloser's outstanding dependencies
+type depSet map[interface{}]bool // set of true bools
+
+// The finalCloser interface is used by (*DB).addDep and (*DB).get
+type finalCloser interface {
+       // finalClose is called when the reference count of an object
+       // goes to zero. (*DB).mu is not held while calling it.
+       finalClose() error
+}
+
+// addDep notes that x now depends on dep, and x's finalClose won't be
+// called until all of x's dependencies are removed with removeDep.
+func (db *DB) addDep(x finalCloser, dep interface{}) {
+       //println(fmt.Sprintf("addDep(%T %p, %T %p)", x, x, dep, dep))
+       db.mu.Lock()
+       defer db.mu.Unlock()
+       if db.dep == nil {
+               db.dep = make(map[finalCloser]depSet)
+       }
+       xdep := db.dep[x]
+       if xdep == nil {
+               xdep = make(depSet)
+               db.dep[x] = xdep
+       }
+       xdep[dep] = true
+}
+
+// removeDep notes that x no longer depends on dep.
+// If x still has dependencies, nil is returned.
+// If x no longer has any dependencies, its finalClose method will be
+// called and its error value will be returned.
+func (db *DB) removeDep(x finalCloser, dep interface{}) error {
+       //println(fmt.Sprintf("removeDep(%T %p, %T %p)", x, x, dep, dep))
+       done := false
+
+       db.mu.Lock()
+       xdep := db.dep[x]
+       if xdep != nil {
+               delete(xdep, dep)
+               if len(xdep) == 0 {
+                       delete(db.dep, x)
+                       done = true
+               }
+       }
+       db.mu.Unlock()
+
+       if !done {
+               return nil
+       }
+       //println(fmt.Sprintf("calling final close on %T %v (%#v)", x, x, x))
+       return x.finalClose()
 }
 
 // Open opens a database specified by its database driver name and a
@@ -201,11 +259,20 @@ type DB struct {
 // Most users will open a database via a driver-specific connection
 // helper function that returns a *DB.
 func Open(driverName, dataSourceName string) (*DB, error) {
-       driver, ok := drivers[driverName]
+       driveri, ok := drivers[driverName]
        if !ok {
                return nil, fmt.Errorf("sql: unknown driver %q (forgotten import?)", driverName)
        }
-       return &DB{driver: driver, dsn: dataSourceName}, nil
+       // TODO: optionally proactively connect to a Conn to check
+       // the dataSourceName: golang.org/issue/4804
+       db := &DB{
+               driver:    driveri,
+               dsn:       dataSourceName,
+               outConn:   make(map[driver.Conn]bool),
+               lastPut:   make(map[driver.Conn]string),
+               onConnPut: make(map[driver.Conn][]func()),
+       }
+       return db, nil
 }
 
 // Close closes the database, releasing any open resources.
@@ -241,22 +308,38 @@ func (db *DB) conn() (driver.Conn, error) {
        if n := len(db.freeConn); n > 0 {
                conn := db.freeConn[n-1]
                db.freeConn = db.freeConn[:n-1]
+               db.outConn[conn] = true
                db.mu.Unlock()
                return conn, nil
        }
        db.mu.Unlock()
-       return db.driver.Open(db.dsn)
+       conn, err := db.driver.Open(db.dsn)
+       if err == nil {
+               db.mu.Lock()
+               db.outConn[conn] = true
+               db.mu.Unlock()
+       }
+       return conn, err
 }
 
+// connIfFree returns (wanted, true) if wanted is still a valid conn and
+// isn't in use.
+//
+// If wanted is valid but in use, connIfFree returns (wanted, false).
+// If wanted is invalid, connIfFre returns (nil, false).
 func (db *DB) connIfFree(wanted driver.Conn) (conn driver.Conn, ok bool) {
        db.mu.Lock()
        defer db.mu.Unlock()
+       if db.outConn[wanted] {
+               return conn, false
+       }
        for i, conn := range db.freeConn {
                if conn != wanted {
                        continue
                }
                db.freeConn[i] = db.freeConn[len(db.freeConn)-1]
                db.freeConn = db.freeConn[:len(db.freeConn)-1]
+               db.outConn[wanted] = true
                return wanted, true
        }
        return nil, false
@@ -265,14 +348,52 @@ func (db *DB) connIfFree(wanted driver.Conn) (conn driver.Conn, ok bool) {
 // putConnHook is a hook for testing.
 var putConnHook func(*DB, driver.Conn)
 
+// noteUnusedDriverStatement notes that si is no longer used and should
+// be closed whenever possible (when c is next not in use), unless c is
+// already closed.
+func (db *DB) noteUnusedDriverStatement(c driver.Conn, si driver.Stmt) {
+       db.mu.Lock()
+       defer db.mu.Unlock()
+       if db.outConn[c] {
+               db.onConnPut[c] = append(db.onConnPut[c], func() {
+                       si.Close()
+               })
+       } else {
+               si.Close()
+       }
+}
+
+// debugGetPut determines whether getConn & putConn calls' stack traces
+// are returned for more verbose crashes.
+const debugGetPut = false
+
 // putConn adds a connection to the db's free pool.
 // err is optionally the last error that occurred on this connection.
 func (db *DB) putConn(c driver.Conn, err error) {
+       db.mu.Lock()
+       if !db.outConn[c] {
+               if debugGetPut {
+                       fmt.Printf("putConn(%v) DUPLICATE was: %s\n\nPREVIOUS was: %s", c, stack(), db.lastPut[c])
+               }
+               panic("sql: connection returned that was never out")
+       }
+       if debugGetPut {
+               db.lastPut[c] = stack()
+       }
+       delete(db.outConn, c)
+
+       if fns, ok := db.onConnPut[c]; ok {
+               for _, fn := range fns {
+                       fn()
+               }
+               delete(db.onConnPut, c)
+       }
+
        if err == driver.ErrBadConn {
                // Don't reuse bad connections.
+               db.mu.Unlock()
                return
        }
-       db.mu.Lock()
        if putConnHook != nil {
                putConnHook(db, c)
        }
@@ -300,7 +421,7 @@ func (db *DB) Prepare(query string) (*Stmt, error) {
        return stmt, err
 }
 
-func (db *DB) prepare(query string) (stmt *Stmt, err error) {
+func (db *DB) prepare(query string) (*Stmt, error) {
        // TODO: check if db.driver supports an optional
        // driver.Preparer interface and call that instead, if so,
        // otherwise we make a prepared statement that's bound
@@ -311,19 +432,17 @@ func (db *DB) prepare(query string) (stmt *Stmt, err error) {
        if err != nil {
                return nil, err
        }
-       defer func() {
-               db.putConn(ci, err)
-       }()
-
        si, err := ci.Prepare(query)
        if err != nil {
+               db.putConn(ci, err)
                return nil, err
        }
-       stmt = &Stmt{
+       stmt := &Stmt{
                db:    db,
                query: query,
                css:   []connStmt{{ci, si}},
        }
+       db.addDep(stmt, stmt)
        return stmt, nil
 }
 
@@ -698,6 +817,8 @@ type Stmt struct {
        query     string // that created the Stmt
        stickyErr error  // if non-nil, this error is returned for all operations
 
+       closemu sync.RWMutex // held exclusively during close, for read otherwise.
+
        // If in a transaction, else both nil:
        tx   *Tx
        txsi driver.Stmt
@@ -715,6 +836,8 @@ type Stmt struct {
 // Exec executes a prepared statement with the given arguments and
 // returns a Result summarizing the effect of the statement.
 func (s *Stmt) Exec(args ...interface{}) (Result, error) {
+       s.closemu.RLock()
+       defer s.closemu.RUnlock()
        _, releaseConn, si, err := s.connStmt()
        if err != nil {
                return nil, err
@@ -813,6 +936,9 @@ func (s *Stmt) connStmt() (ci driver.Conn, releaseConn func(error), si driver.St
 // Query executes a prepared query statement with the given arguments
 // and returns the query results as a *Rows.
 func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
+       s.closemu.RLock()
+       defer s.closemu.RUnlock()
+
        ci, releaseConn, si, err := s.connStmt()
        if err != nil {
                return nil, err
@@ -827,10 +953,15 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
        // Note: ownership of ci passes to the *Rows, to be freed
        // with releaseConn.
        rows := &Rows{
-               db:          s.db,
-               ci:          ci,
-               releaseConn: releaseConn,
-               rowsi:       rowsi,
+               db:    s.db,
+               ci:    ci,
+               rowsi: rowsi,
+               // releaseConn set below
+       }
+       s.db.addDep(s, rows)
+       rows.releaseConn = func(err error) {
+               releaseConn(err)
+               s.db.removeDep(s, rows)
        }
        return rows, nil
 }
@@ -876,6 +1007,9 @@ func (s *Stmt) QueryRow(args ...interface{}) *Row {
 
 // Close closes the statement.
 func (s *Stmt) Close() error {
+       s.closemu.Lock()
+       defer s.closemu.Unlock()
+
        if s.stickyErr != nil {
                return s.stickyErr
        }
@@ -888,18 +1022,17 @@ func (s *Stmt) Close() error {
 
        if s.tx != nil {
                s.txsi.Close()
-       } else {
-               for _, v := range s.css {
-                       if ci, match := s.db.connIfFree(v.ci); match {
-                               v.si.Close()
-                               s.db.putConn(ci, nil)
-                       } else {
-                               // TODO(bradfitz): care that we can't close
-                               // this statement because the statement's
-                               // connection is in use?
-                       }
-               }
+               return nil
+       }
+
+       return s.db.removeDep(s, s)
+}
+
+func (s *Stmt) finalClose() error {
+       for _, v := range s.css {
+               s.db.noteUnusedDriverStatement(v.ci, v.si)
        }
+       s.css = nil
        return nil
 }
 
@@ -918,7 +1051,7 @@ func (s *Stmt) Close() error {
 //     ...
 type Rows struct {
        db          *DB
-       ci          driver.Conn // owned; must call putconn when closed to release
+       ci          driver.Conn // owned; must call releaseConn when closed to release
        releaseConn func(error)
        rowsi       driver.Rows
 
@@ -1094,3 +1227,8 @@ type Result interface {
 type result struct {
        driver.Result
 }
+
+func stack() string {
+       var buf [1024]byte
+       return string(buf[:runtime.Stack(buf[:], false)])
+}
index 6571cfd8466f87dc3cf78ca95e8426e1b101f8f1..e6fc6a1d57c53720150124f0748058f1dfc6ba13 100644 (file)
@@ -8,7 +8,6 @@ import (
        "database/sql/driver"
        "fmt"
        "reflect"
-       "runtime"
        "strings"
        "testing"
        "time"
@@ -63,6 +62,10 @@ func exec(t *testing.T, db *DB, query string, args ...interface{}) {
 }
 
 func closeDB(t *testing.T, db *DB) {
+       if e := recover(); e != nil {
+               fmt.Printf("Panic: %v\n", e)
+               panic(e)
+       }
        err := db.Close()
        if err != nil {
                t.Fatalf("error closing DB: %v", err)
@@ -448,10 +451,8 @@ func TestIssue2542Deadlock(t *testing.T) {
        }
 }
 
+// From golang.org/issue/3865
 func TestCloseStmtBeforeRows(t *testing.T) {
-       t.Skip("known broken test; golang.org/issue/3865")
-       return
-
        db := newTestDB(t, "people")
        defer closeDB(t, db)
 
@@ -666,8 +667,3 @@ func nullTestRun(t *testing.T, spec nullTestSpec) {
                }
        }
 }
-
-func stack() string {
-       buf := make([]byte, 1024)
-       return string(buf[:runtime.Stack(buf, false)])
-}