]> Cypherpunks repositories - gostls13.git/commitdiff
database/sql: don't close a driver.Conn until its Stmts are closed
authorBrad Fitzpatrick <bradfitz@golang.org>
Mon, 25 Mar 2013 23:50:27 +0000 (16:50 -0700)
committerBrad Fitzpatrick <bradfitz@golang.org>
Mon, 25 Mar 2013 23:50:27 +0000 (16:50 -0700)
Fixes #5046

R=golang-dev, r
CC=golang-dev
https://golang.org/cl/8016044

src/pkg/database/sql/fakedb_test.go
src/pkg/database/sql/sql.go
src/pkg/database/sql/sql_test.go

index 55597f7de3ba38628ba4dc50c4963529b4f1ec66..24c255f6e08edac49d6915ca80df07b8a093a130 100644 (file)
@@ -229,7 +229,26 @@ func (c *fakeConn) Begin() (driver.Tx, error) {
        return c.currTx, nil
 }
 
-func (c *fakeConn) Close() error {
+var hookPostCloseConn struct {
+       sync.Mutex
+       fn func(*fakeConn, error)
+}
+
+func setHookpostCloseConn(fn func(*fakeConn, error)) {
+       hookPostCloseConn.Lock()
+       defer hookPostCloseConn.Unlock()
+       hookPostCloseConn.fn = fn
+}
+
+func (c *fakeConn) Close() (err error) {
+       defer func() {
+               hookPostCloseConn.Lock()
+               fn := hookPostCloseConn.fn
+               hookPostCloseConn.Unlock()
+               if fn != nil {
+                       fn(c, err)
+               }
+       }()
        if c.currTx != nil {
                return errors.New("can't close fakeConn; in a Transaction")
        }
index bc92ecd8e67a1ccc8043596bd2887c07afe872db..d1f929e7cba9e0f7e352fd7c0e335b6fd6458cf5 100644 (file)
@@ -204,8 +204,42 @@ type DB struct {
 // interfaces returned via that Conn, such as calls on Tx, Stmt,
 // Result, Rows)
 type driverConn struct {
-       sync.Mutex
-       ci driver.Conn
+       db *DB
+
+       sync.Mutex // guards following
+       ci         driver.Conn
+       closed     bool
+}
+
+// the dc.db's Mutex is held.
+func (dc *driverConn) closeDBLocked() error {
+       dc.Lock()
+       if dc.closed {
+               dc.Unlock()
+               return errors.New("sql: duplicate driverConn close")
+       }
+       dc.closed = true
+       dc.Unlock() // not defer; removeDep finalClose calls may need to lock
+       return dc.db.removeDepLocked(dc, dc)()
+}
+
+func (dc *driverConn) Close() error {
+       dc.Lock()
+       if dc.closed {
+               dc.Unlock()
+               return errors.New("sql: duplicate driverConn close")
+       }
+       dc.closed = true
+       dc.Unlock() // not defer; removeDep finalClose calls may need to lock
+       return dc.db.removeDep(dc, dc)
+}
+
+func (dc *driverConn) finalClose() error {
+       dc.Lock()
+       err := dc.ci.Close()
+       dc.ci = nil
+       dc.Unlock()
+       return err
 }
 
 // driverStmt associates a driver.Stmt with the
@@ -238,6 +272,10 @@ func (db *DB) addDep(x finalCloser, dep interface{}) {
        //println(fmt.Sprintf("addDep(%T %p, %T %p)", x, x, dep, dep))
        db.mu.Lock()
        defer db.mu.Unlock()
+       db.addDepLocked(x, dep)
+}
+
+func (db *DB) addDepLocked(x finalCloser, dep interface{}) {
        if db.dep == nil {
                db.dep = make(map[finalCloser]depSet)
        }
@@ -254,10 +292,16 @@ func (db *DB) addDep(x finalCloser, dep interface{}) {
 // If x no longer has any dependencies, its finalClose method will be
 // called and its error value will be returned.
 func (db *DB) removeDep(x finalCloser, dep interface{}) error {
+       db.mu.Lock()
+       fn := db.removeDepLocked(x, dep)
+       db.mu.Unlock()
+       return fn()
+}
+
+func (db *DB) removeDepLocked(x finalCloser, dep interface{}) func() error {
        //println(fmt.Sprintf("removeDep(%T %p, %T %p)", x, x, dep, dep))
        done := false
 
-       db.mu.Lock()
        xdep := db.dep[x]
        if xdep != nil {
                delete(xdep, dep)
@@ -266,13 +310,14 @@ func (db *DB) removeDep(x finalCloser, dep interface{}) error {
                        done = true
                }
        }
-       db.mu.Unlock()
 
        if !done {
-               return nil
+               return func() error { return nil }
+       }
+       return func() error {
+               //println(fmt.Sprintf("calling final close on %T %v (%#v)", x, x, x))
+               return x.finalClose()
        }
-       //println(fmt.Sprintf("calling final close on %T %v (%#v)", x, x, x))
-       return x.finalClose()
 }
 
 // Open opens a database specified by its database driver name and a
@@ -320,9 +365,7 @@ func (db *DB) Close() error {
        defer db.mu.Unlock()
        var err error
        for _, dc := range db.freeConn {
-               dc.Lock()
-               err1 := dc.ci.Close()
-               dc.Unlock()
+               err1 := dc.closeDBLocked()
                if err1 != nil {
                        err = err1
                }
@@ -365,11 +408,7 @@ func (db *DB) SetMaxIdleConns(n int) {
                dc := db.freeConn[nfree-1]
                db.freeConn[nfree-1] = nil
                db.freeConn = db.freeConn[:nfree-1]
-               go func() {
-                       dc.Lock()
-                       dc.ci.Close()
-                       dc.Unlock()
-               }()
+               go dc.Close()
        }
 }
 
@@ -393,8 +432,12 @@ func (db *DB) conn() (*driverConn, error) {
        if err != nil {
                return nil, err
        }
-       dc := &driverConn{ci: ci}
+       dc := &driverConn{
+               db: db,
+               ci: ci,
+       }
        db.mu.Lock()
+       db.addDepLocked(dc, dc)
        db.outConn[dc] = true
        db.mu.Unlock()
        return dc, nil
@@ -484,9 +527,7 @@ func (db *DB) putConn(dc *driverConn, err error) {
        // statements which are still active?
        db.mu.Unlock()
 
-       dc.Lock()
-       dc.ci.Close()
-       dc.Unlock()
+       dc.Close()
 }
 
 // Prepare creates a prepared statement for later queries or executions.
@@ -528,6 +569,7 @@ func (db *DB) prepare(query string) (*Stmt, error) {
                css:   []connStmt{{dc, si}},
        }
        db.addDep(stmt, stmt)
+       db.addDep(dc, stmt)
        db.putConn(dc, nil)
        return stmt, nil
 }
@@ -1031,6 +1073,7 @@ func (s *Stmt) connStmt() (ci *driverConn, releaseConn func(error), si driver.St
                        if err != nil {
                                return nil, nil, nil, err
                        }
+                       s.db.addDep(dc, s)
                        s.mu.Lock()
                        cs = connStmt{dc, si}
                        s.css = append(s.css, cs)
@@ -1149,6 +1192,7 @@ func (s *Stmt) Close() error {
 func (s *Stmt) finalClose() error {
        for _, v := range s.css {
                s.db.noteUnusedDriverStatement(v.dc, v.si)
+               s.db.removeDep(v.dc, s)
        }
        s.css = nil
        return nil
index 2a9592e104c30ae70902240e7d2c91fde79568ce..54aad3a5d01b06d476fc2cd2b6327edae833b920 100644 (file)
@@ -65,6 +65,12 @@ func closeDB(t *testing.T, db *DB) {
                fmt.Printf("Panic: %v\n", e)
                panic(e)
        }
+       defer setHookpostCloseConn(nil)
+       setHookpostCloseConn(func(_ *fakeConn, err error) {
+               if err != nil {
+                       t.Errorf("Error closing fakeConn: %v", err)
+               }
+       })
        err := db.Close()
        if err != nil {
                t.Fatalf("error closing DB: %v", err)
@@ -790,3 +796,51 @@ func TestMaxIdleConns(t *testing.T) {
                t.Errorf("freeConns = %d; want 0", got)
        }
 }
+
+// golang.org/issue/5046
+func TestCloseConnBeforeStmts(t *testing.T) {
+       defer setHookpostCloseConn(nil)
+       setHookpostCloseConn(func(_ *fakeConn, err error) {
+               if err != nil {
+                       t.Errorf("Error closing fakeConn: %v", err)
+               }
+       })
+
+       db := newTestDB(t, "people")
+
+       stmt, err := db.Prepare("SELECT|people|name|")
+       if err != nil {
+               t.Fatal(err)
+       }
+
+       if len(db.freeConn) != 1 {
+               t.Fatalf("expected 1 freeConn; got %d", len(db.freeConn))
+       }
+       dc := db.freeConn[0]
+       if dc.closed {
+               t.Errorf("conn shouldn't be closed")
+       }
+
+       err = db.Close()
+       if err != nil {
+               t.Errorf("db Close = %v", err)
+       }
+       if !dc.closed {
+               t.Errorf("after db.Close, driverConn should be closed")
+       }
+       if dc.ci == nil {
+               t.Errorf("after db.Close, driverConn should still have its Conn interface")
+       }
+
+       err = stmt.Close()
+       if err != nil {
+               t.Errorf("Stmt close = %v", err)
+       }
+
+       if !dc.closed {
+               t.Errorf("conn should be closed")
+       }
+       if dc.ci != nil {
+               t.Errorf("after Stmt Close, driverConn's Conn interface should be nil")
+       }
+}