From 209f6b1d2ca84541505c29c6158cde9d5a0bbd57 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Mon, 25 Mar 2013 16:50:27 -0700 Subject: [PATCH] database/sql: don't close a driver.Conn until its Stmts are closed Fixes #5046 R=golang-dev, r CC=golang-dev https://golang.org/cl/8016044 --- src/pkg/database/sql/fakedb_test.go | 21 +++++++- src/pkg/database/sql/sql.go | 82 ++++++++++++++++++++++------- src/pkg/database/sql/sql_test.go | 54 +++++++++++++++++++ 3 files changed, 137 insertions(+), 20 deletions(-) diff --git a/src/pkg/database/sql/fakedb_test.go b/src/pkg/database/sql/fakedb_test.go index 55597f7de3..24c255f6e0 100644 --- a/src/pkg/database/sql/fakedb_test.go +++ b/src/pkg/database/sql/fakedb_test.go @@ -229,7 +229,26 @@ func (c *fakeConn) Begin() (driver.Tx, error) { return c.currTx, nil } -func (c *fakeConn) Close() error { +var hookPostCloseConn struct { + sync.Mutex + fn func(*fakeConn, error) +} + +func setHookpostCloseConn(fn func(*fakeConn, error)) { + hookPostCloseConn.Lock() + defer hookPostCloseConn.Unlock() + hookPostCloseConn.fn = fn +} + +func (c *fakeConn) Close() (err error) { + defer func() { + hookPostCloseConn.Lock() + fn := hookPostCloseConn.fn + hookPostCloseConn.Unlock() + if fn != nil { + fn(c, err) + } + }() if c.currTx != nil { return errors.New("can't close fakeConn; in a Transaction") } diff --git a/src/pkg/database/sql/sql.go b/src/pkg/database/sql/sql.go index bc92ecd8e6..d1f929e7cb 100644 --- a/src/pkg/database/sql/sql.go +++ b/src/pkg/database/sql/sql.go @@ -204,8 +204,42 @@ type DB struct { // interfaces returned via that Conn, such as calls on Tx, Stmt, // Result, Rows) type driverConn struct { - sync.Mutex - ci driver.Conn + db *DB + + sync.Mutex // guards following + ci driver.Conn + closed bool +} + +// the dc.db's Mutex is held. +func (dc *driverConn) closeDBLocked() error { + dc.Lock() + if dc.closed { + dc.Unlock() + return errors.New("sql: duplicate driverConn close") + } + dc.closed = true + dc.Unlock() // not defer; removeDep finalClose calls may need to lock + return dc.db.removeDepLocked(dc, dc)() +} + +func (dc *driverConn) Close() error { + dc.Lock() + if dc.closed { + dc.Unlock() + return errors.New("sql: duplicate driverConn close") + } + dc.closed = true + dc.Unlock() // not defer; removeDep finalClose calls may need to lock + return dc.db.removeDep(dc, dc) +} + +func (dc *driverConn) finalClose() error { + dc.Lock() + err := dc.ci.Close() + dc.ci = nil + dc.Unlock() + return err } // driverStmt associates a driver.Stmt with the @@ -238,6 +272,10 @@ func (db *DB) addDep(x finalCloser, dep interface{}) { //println(fmt.Sprintf("addDep(%T %p, %T %p)", x, x, dep, dep)) db.mu.Lock() defer db.mu.Unlock() + db.addDepLocked(x, dep) +} + +func (db *DB) addDepLocked(x finalCloser, dep interface{}) { if db.dep == nil { db.dep = make(map[finalCloser]depSet) } @@ -254,10 +292,16 @@ func (db *DB) addDep(x finalCloser, dep interface{}) { // If x no longer has any dependencies, its finalClose method will be // called and its error value will be returned. func (db *DB) removeDep(x finalCloser, dep interface{}) error { + db.mu.Lock() + fn := db.removeDepLocked(x, dep) + db.mu.Unlock() + return fn() +} + +func (db *DB) removeDepLocked(x finalCloser, dep interface{}) func() error { //println(fmt.Sprintf("removeDep(%T %p, %T %p)", x, x, dep, dep)) done := false - db.mu.Lock() xdep := db.dep[x] if xdep != nil { delete(xdep, dep) @@ -266,13 +310,14 @@ func (db *DB) removeDep(x finalCloser, dep interface{}) error { done = true } } - db.mu.Unlock() if !done { - return nil + return func() error { return nil } + } + return func() error { + //println(fmt.Sprintf("calling final close on %T %v (%#v)", x, x, x)) + return x.finalClose() } - //println(fmt.Sprintf("calling final close on %T %v (%#v)", x, x, x)) - return x.finalClose() } // Open opens a database specified by its database driver name and a @@ -320,9 +365,7 @@ func (db *DB) Close() error { defer db.mu.Unlock() var err error for _, dc := range db.freeConn { - dc.Lock() - err1 := dc.ci.Close() - dc.Unlock() + err1 := dc.closeDBLocked() if err1 != nil { err = err1 } @@ -365,11 +408,7 @@ func (db *DB) SetMaxIdleConns(n int) { dc := db.freeConn[nfree-1] db.freeConn[nfree-1] = nil db.freeConn = db.freeConn[:nfree-1] - go func() { - dc.Lock() - dc.ci.Close() - dc.Unlock() - }() + go dc.Close() } } @@ -393,8 +432,12 @@ func (db *DB) conn() (*driverConn, error) { if err != nil { return nil, err } - dc := &driverConn{ci: ci} + dc := &driverConn{ + db: db, + ci: ci, + } db.mu.Lock() + db.addDepLocked(dc, dc) db.outConn[dc] = true db.mu.Unlock() return dc, nil @@ -484,9 +527,7 @@ func (db *DB) putConn(dc *driverConn, err error) { // statements which are still active? db.mu.Unlock() - dc.Lock() - dc.ci.Close() - dc.Unlock() + dc.Close() } // Prepare creates a prepared statement for later queries or executions. @@ -528,6 +569,7 @@ func (db *DB) prepare(query string) (*Stmt, error) { css: []connStmt{{dc, si}}, } db.addDep(stmt, stmt) + db.addDep(dc, stmt) db.putConn(dc, nil) return stmt, nil } @@ -1031,6 +1073,7 @@ func (s *Stmt) connStmt() (ci *driverConn, releaseConn func(error), si driver.St if err != nil { return nil, nil, nil, err } + s.db.addDep(dc, s) s.mu.Lock() cs = connStmt{dc, si} s.css = append(s.css, cs) @@ -1149,6 +1192,7 @@ func (s *Stmt) Close() error { func (s *Stmt) finalClose() error { for _, v := range s.css { s.db.noteUnusedDriverStatement(v.dc, v.si) + s.db.removeDep(v.dc, s) } s.css = nil return nil diff --git a/src/pkg/database/sql/sql_test.go b/src/pkg/database/sql/sql_test.go index 2a9592e104..54aad3a5d0 100644 --- a/src/pkg/database/sql/sql_test.go +++ b/src/pkg/database/sql/sql_test.go @@ -65,6 +65,12 @@ func closeDB(t *testing.T, db *DB) { fmt.Printf("Panic: %v\n", e) panic(e) } + defer setHookpostCloseConn(nil) + setHookpostCloseConn(func(_ *fakeConn, err error) { + if err != nil { + t.Errorf("Error closing fakeConn: %v", err) + } + }) err := db.Close() if err != nil { t.Fatalf("error closing DB: %v", err) @@ -790,3 +796,51 @@ func TestMaxIdleConns(t *testing.T) { t.Errorf("freeConns = %d; want 0", got) } } + +// golang.org/issue/5046 +func TestCloseConnBeforeStmts(t *testing.T) { + defer setHookpostCloseConn(nil) + setHookpostCloseConn(func(_ *fakeConn, err error) { + if err != nil { + t.Errorf("Error closing fakeConn: %v", err) + } + }) + + db := newTestDB(t, "people") + + stmt, err := db.Prepare("SELECT|people|name|") + if err != nil { + t.Fatal(err) + } + + if len(db.freeConn) != 1 { + t.Fatalf("expected 1 freeConn; got %d", len(db.freeConn)) + } + dc := db.freeConn[0] + if dc.closed { + t.Errorf("conn shouldn't be closed") + } + + err = db.Close() + if err != nil { + t.Errorf("db Close = %v", err) + } + if !dc.closed { + t.Errorf("after db.Close, driverConn should be closed") + } + if dc.ci == nil { + t.Errorf("after db.Close, driverConn should still have its Conn interface") + } + + err = stmt.Close() + if err != nil { + t.Errorf("Stmt close = %v", err) + } + + if !dc.closed { + t.Errorf("conn should be closed") + } + if dc.ci != nil { + t.Errorf("after Stmt Close, driverConn's Conn interface should be nil") + } +} -- 2.48.1